Skip to content

Commit

Permalink
parameterize tests
Browse files Browse the repository at this point in the history
  • Loading branch information
epistoteles committed Jun 21, 2024
1 parent 0165add commit b175432
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 92 deletions.
70 changes: 4 additions & 66 deletions tensorhue/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,68 +51,6 @@ def _viz(self, colorscheme: ColorScheme = None, legend: bool = True):
c.print("\n".join(result_lines))


# def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme) -> list[str]:
# """
# Constructs a list of rich-compatible strings out of a 2D numpy array.

# Args:
# array_2d (np.ndarray): The 2-dimensional numpy array
# colorscheme (ColorScheme): The color scheme to use
# """
# result_lines = [""]
# terminal_width = get_terminal_width()
# shape = array_2d.shape

# if shape[1] <= terminal_width:
# colors = colorscheme(array_2d)[..., :3]

# for y in range(0, shape[0] - 1, 2):
# for x in range(shape[-1]):
# result_lines[
# -1
# ] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]"
# result_lines.append("")

# if shape[0] % 2 == 1:
# for x in range(shape[1]):
# result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]"
# else:
# result_lines = result_lines[:-1]

# else:
# slice_left = (terminal_width - 5) // 2
# slice_right = slice_left + 5
# colors_left = colorscheme(array_2d[:, :slice_left])[..., :3]
# colors_right = colorscheme(array_2d[:, slice_right:])[..., :3]

# for y in range(0, shape[0] - 1, 2):
# for x in range(slice_left):
# result_lines[
# -1
# ] += f"[rgb({colors_left[y, x, 0]},{colors_left[y, x, 1]},{colors_left[y, x, 2]}) on rgb({colors_left[y+1, x, 0]},{colors_left[y+1, x, 1]},{colors_left[y+1, x, 2]})]▀[/]"
# result_lines[-1] += " ··· "
# for x in range(terminal_width - slice_right):
# result_lines[
# -1
# ] += f"[rgb({colors_right[y, x, 0]},{colors_right[y, x, 1]},{colors_right[y, x, 2]}) on rgb({colors_right[y+1, x, 0]},{colors_right[y+1, x, 1]},{colors_right[y+1, x, 2]})]▀[/]"
# result_lines.append("")

# if shape[0] % 2 == 1:
# for x in range(slice_left):
# result_lines[
# -1
# ] += f"[rgb({colors_left[-1, x, 0]},{colors_left[-1, x, 1]},{colors_left[-1, x, 2]})]▀[/]"
# result_lines[-1] += " ··· "
# for x in range(terminal_width - slice_right):
# result_lines[
# -1
# ] += f"[rgb({colors_right[-1, x, 0]},{colors_right[-1, x, 1]},{colors_right[-1, x, 2]})]▀[/]"
# else:
# result_lines = result_lines[:-1]

# return result_lines


def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme) -> list[str]:
"""
Constructs a list of rich-compatible strings out of a 2D numpy array.
Expand All @@ -127,8 +65,8 @@ def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme) -> list[str]:

if shape[1] > terminal_width:
slice_left = (terminal_width - 5) // 2
slice_right = slice_left + 5
colors_right = colorscheme(array_2d[:, slice_right:])[..., :3]
slice_right = slice_left + (terminal_width - 5) % 2
colors_right = colorscheme(array_2d[:, -slice_right:])[..., :3]
else:
slice_left = shape[1]
slice_right = colors_right = False
Expand All @@ -142,7 +80,7 @@ def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme) -> list[str]:
] += f"[rgb({colors_left[y, x, 0]},{colors_left[y, x, 1]},{colors_left[y, x, 2]}) on rgb({colors_left[y+1, x, 0]},{colors_left[y+1, x, 1]},{colors_left[y+1, x, 2]})]▀[/]"
if slice_right:
result_lines[-1] += " ··· "
for x in range(terminal_width - slice_right):
for x in range(slice_right):
result_lines[
-1
] += f"[rgb({colors_right[y, x, 0]},{colors_right[y, x, 1]},{colors_right[y, x, 2]}) on rgb({colors_right[y+1, x, 0]},{colors_right[y+1, x, 1]},{colors_right[y+1, x, 2]})]▀[/]"
Expand All @@ -153,7 +91,7 @@ def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme) -> list[str]:
result_lines[-1] += f"[rgb({colors_left[-1, x, 0]},{colors_left[-1, x, 1]},{colors_left[-1, x, 2]})]▀[/]"
if slice_right:
result_lines[-1] += " ··· "
for x in range(terminal_width - slice_right):
for x in range(slice_right):
result_lines[
-1
] += f"[rgb({colors_right[-1, x, 0]},{colors_right[-1, x, 1]},{colors_right[-1, x, 2]})]▀[/]"
Expand Down
51 changes: 25 additions & 26 deletions tests/test_viz.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,53 @@
import pytest
import torch
import numpy as np
from tensorhue.viz import viz
from tensorhue._torch import _tensorhue_to_numpy_torch


def test_1d_tensor_numpy(capsys):
n = np.ones(10)
viz(n)
@pytest.mark.parametrize("input", [np.ones(10), _tensorhue_to_numpy_torch(torch.ones(10))])
def test_1d_tensor(input, capsys):
viz(input)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 2
assert out.count("▀") == 10
assert out.split("\n")[-1] == f"shape = {n.shape}"
assert out.split("\n")[-1] == f"shape = {input.shape}"


def test_2d_tensor_numpy(capsys):
n = np.ones((10, 10))
viz(n)
@pytest.mark.parametrize("input", [np.ones((10, 10)), _tensorhue_to_numpy_torch(torch.ones(10, 10))])
def test_2d_tensor(input, capsys):
viz(input)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 6
assert out.count("▀") == 50
assert out.split("\n")[-1] == f"shape = {n.shape}"
assert out.count("▀") == 100 / 2
assert out.split("\n")[-1] == f"shape = {input.shape}"


def test_1d_tensor_torch(capsys):
t = torch.ones(10)
n = _tensorhue_to_numpy_torch(t)
viz(n)
@pytest.mark.parametrize("input", [np.ones(200), _tensorhue_to_numpy_torch(torch.ones(200))])
def test_1d_tensor_too_wide(input, capsys):
viz(input)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 2
assert out.count("▀") == 10
assert out.split("\n")[-1] == f"shape = {n.shape}"
assert out.count(" ··· ") == 1
assert out.count("▀") == 95
assert out.split("\n")[-1] == f"shape = {input.shape}"


def test_2d_tensor_torch(capsys):
t = torch.ones(10, 10)
n = _tensorhue_to_numpy_torch(t)
viz(n)
@pytest.mark.parametrize("input", [np.ones((10, 200)), _tensorhue_to_numpy_torch(torch.ones(10, 200))])
def test_2d_tensor_too_wide(input, capsys):
viz(input)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 6
assert out.count("▀") == 50
assert out.split("\n")[-1] == f"shape = {n.shape}"
assert out.count(" ··· ") == 5
assert out.count("▀") == 950 / 2
assert out.split("\n")[-1] == f"shape = {input.shape}"


def test_no_legend(capsys):
n = np.ones(10)
viz(n, legend=False)
@pytest.mark.parametrize("input", [np.ones(10), _tensorhue_to_numpy_torch(torch.ones(10))])
def test_no_legend(input, capsys):
viz(input, legend=False)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 1
Expand Down

0 comments on commit b175432

Please sign in to comment.