Skip to content

Commit

Permalink
add 1d tensor support, legends, and dynamic terminal width adjustment…
Browse files Browse the repository at this point in the history
… when printing
  • Loading branch information
epistoteles committed Jun 21, 2024
1 parent ffe95b5 commit 0165add
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 15 deletions.
17 changes: 13 additions & 4 deletions tensorhue/eastereggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,19 @@
from matplotlib.colors import LinearSegmentedColormap
from rich.color_triplet import ColorTriplet
from tensorhue.colors import ColorScheme
from tensorhue.viz import viz
from tensorhue.viz import viz, get_terminal_width


def pride():
def pride(width: int = None):
"""
Prints a pride flag in the terminal
Args:
width (int, optional): The width of the pride flag. If none is specified,
the full width of the terminal is used.
"""
if width is None:
width = get_terminal_width(default_width=10)
pride_colors = [
ColorTriplet(228, 3, 3),
ColorTriplet(255, 140, 0),
Expand All @@ -16,5 +25,5 @@ def pride():
]
pride_cm = LinearSegmentedColormap.from_list(colors=[c.normalized for c in pride_colors], name="pride")
pride_cs = ColorScheme(colormap=pride_cm)
arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), 10, axis=1)
viz(arr, colorscheme=pride_cs)
arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), width, axis=1)
viz(arr, colorscheme=pride_cs, legend=False)
142 changes: 133 additions & 9 deletions tensorhue/viz.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import sys
from rich.console import Console
import numpy as np
from tensorhue.colors import ColorScheme
Expand All @@ -18,38 +20,160 @@ def viz(tensor, *args, **kwargs):
) from e


def _viz(self, colorscheme: ColorScheme = None):
def _viz(self, colorscheme: ColorScheme = None, legend: bool = True):
"""
Prints a tensor using colored Unicode art representation.
Args:
colorscheme (ColorScheme, optional): The color scheme to use.
Defaults to None, which means the global default color scheme is used.
legend (bool, optional): Whether or not to include legend information (like the shape)
"""
if colorscheme is None:
colorscheme = PRINT_OPTS.colorscheme

self = self._tensorhue_to_numpy()
shape = self.shape

if len(shape) > 2:
if len(shape) == 1:
self = self[np.newaxis, :]
elif len(shape) > 2:
raise NotImplementedError(
"Visualization for tensors with more than 2 dimensions is under development. Please slice them for now."
)

colors = colorscheme(self)[..., :3]
result_lines = _viz_2d(self, colorscheme)

if legend:
result_lines.append(f"[italic]shape = {shape}[/]")

c = Console(log_path=False, record=False)
c.print("\n".join(result_lines))


# def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme) -> list[str]:
# """
# Constructs a list of rich-compatible strings out of a 2D numpy array.

# Args:
# array_2d (np.ndarray): The 2-dimensional numpy array
# colorscheme (ColorScheme): The color scheme to use
# """
# result_lines = [""]
# terminal_width = get_terminal_width()
# shape = array_2d.shape

# if shape[1] <= terminal_width:
# colors = colorscheme(array_2d)[..., :3]

# for y in range(0, shape[0] - 1, 2):
# for x in range(shape[-1]):
# result_lines[
# -1
# ] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]"
# result_lines.append("")

# if shape[0] % 2 == 1:
# for x in range(shape[1]):
# result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]"
# else:
# result_lines = result_lines[:-1]

# else:
# slice_left = (terminal_width - 5) // 2
# slice_right = slice_left + 5
# colors_left = colorscheme(array_2d[:, :slice_left])[..., :3]
# colors_right = colorscheme(array_2d[:, slice_right:])[..., :3]

# for y in range(0, shape[0] - 1, 2):
# for x in range(slice_left):
# result_lines[
# -1
# ] += f"[rgb({colors_left[y, x, 0]},{colors_left[y, x, 1]},{colors_left[y, x, 2]}) on rgb({colors_left[y+1, x, 0]},{colors_left[y+1, x, 1]},{colors_left[y+1, x, 2]})]▀[/]"
# result_lines[-1] += " ··· "
# for x in range(terminal_width - slice_right):
# result_lines[
# -1
# ] += f"[rgb({colors_right[y, x, 0]},{colors_right[y, x, 1]},{colors_right[y, x, 2]}) on rgb({colors_right[y+1, x, 0]},{colors_right[y+1, x, 1]},{colors_right[y+1, x, 2]})]▀[/]"
# result_lines.append("")

# if shape[0] % 2 == 1:
# for x in range(slice_left):
# result_lines[
# -1
# ] += f"[rgb({colors_left[-1, x, 0]},{colors_left[-1, x, 1]},{colors_left[-1, x, 2]})]▀[/]"
# result_lines[-1] += " ··· "
# for x in range(terminal_width - slice_right):
# result_lines[
# -1
# ] += f"[rgb({colors_right[-1, x, 0]},{colors_right[-1, x, 1]},{colors_right[-1, x, 2]})]▀[/]"
# else:
# result_lines = result_lines[:-1]

# return result_lines


def _viz_2d(array_2d: np.ndarray, colorscheme: ColorScheme) -> list[str]:
"""
Constructs a list of rich-compatible strings out of a 2D numpy array.
Args:
array_2d (np.ndarray): The 2-dimensional numpy array
colorscheme (ColorScheme): The color scheme to use
"""
result_lines = [""]
terminal_width = get_terminal_width()
shape = array_2d.shape

if shape[1] > terminal_width:
slice_left = (terminal_width - 5) // 2
slice_right = slice_left + 5
colors_right = colorscheme(array_2d[:, slice_right:])[..., :3]
else:
slice_left = shape[1]
slice_right = colors_right = False

colors_left = colorscheme(array_2d[:, :slice_left])[..., :3]

for y in range(0, shape[0] - 1, 2):
for x in range(shape[-1]):
for x in range(slice_left):
result_lines[
-1
] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]"
] += f"[rgb({colors_left[y, x, 0]},{colors_left[y, x, 1]},{colors_left[y, x, 2]}) on rgb({colors_left[y+1, x, 0]},{colors_left[y+1, x, 1]},{colors_left[y+1, x, 2]})]▀[/]"
if slice_right:
result_lines[-1] += " ··· "
for x in range(terminal_width - slice_right):
result_lines[
-1
] += f"[rgb({colors_right[y, x, 0]},{colors_right[y, x, 1]},{colors_right[y, x, 2]}) on rgb({colors_right[y+1, x, 0]},{colors_right[y+1, x, 1]},{colors_right[y+1, x, 2]})]▀[/]"
result_lines.append("")

if shape[0] % 2 == 1:
for x in range(shape[1]):
result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]"
for x in range(slice_left):
result_lines[-1] += f"[rgb({colors_left[-1, x, 0]},{colors_left[-1, x, 1]},{colors_left[-1, x, 2]})]▀[/]"
if slice_right:
result_lines[-1] += " ··· "
for x in range(terminal_width - slice_right):
result_lines[
-1
] += f"[rgb({colors_right[-1, x, 0]},{colors_right[-1, x, 1]},{colors_right[-1, x, 2]})]▀[/]"
else:
result_lines = result_lines[:-1]

c = Console(log_path=False, record=False)
c.print("\n".join(result_lines))
return result_lines


def get_terminal_width(default_width: int = 100) -> int:
"""
Returns the terminal width if the standard output is connected to a terminal. Otherwise, returns default_width.
Args:
default_width (int, optional): The default width to use if there is no terminal.
"""
if sys.stdout.isatty():
try:
return os.get_terminal_size().columns
except OSError:
return default_width
else:
return default_width
5 changes: 3 additions & 2 deletions tests/test_eastereggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
def test_pride_output(capsys):
pride()
captured = capsys.readouterr()
assert len(captured.out.split("\n")) == 5
assert captured.out.count("▀") == 30
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 3
assert out.count("▀") == 30
55 changes: 55 additions & 0 deletions tests/test_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
import numpy as np
from tensorhue.viz import viz
from tensorhue._torch import _tensorhue_to_numpy_torch


def test_1d_tensor_numpy(capsys):
n = np.ones(10)
viz(n)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 2
assert out.count("▀") == 10
assert out.split("\n")[-1] == f"shape = {n.shape}"


def test_2d_tensor_numpy(capsys):
n = np.ones((10, 10))
viz(n)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 6
assert out.count("▀") == 50
assert out.split("\n")[-1] == f"shape = {n.shape}"


def test_1d_tensor_torch(capsys):
t = torch.ones(10)
n = _tensorhue_to_numpy_torch(t)
viz(n)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 2
assert out.count("▀") == 10
assert out.split("\n")[-1] == f"shape = {n.shape}"


def test_2d_tensor_torch(capsys):
t = torch.ones(10, 10)
n = _tensorhue_to_numpy_torch(t)
viz(n)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 6
assert out.count("▀") == 50
assert out.split("\n")[-1] == f"shape = {n.shape}"


def test_no_legend(capsys):
n = np.ones(10)
viz(n, legend=False)
captured = capsys.readouterr()
out = captured.out.rstrip("\n")
assert len(out.split("\n")) == 1
assert out.count("▀") == 10

0 comments on commit 0165add

Please sign in to comment.