Skip to content

Commit

Permalink
update __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
epistoteles committed May 31, 2024
1 parent 6d12225 commit 231d7f0
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions tensorhue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,44 @@
import numpy as np
from tensorhue.version import VERSION
from tensorhue.colors import COLORS

__version__ = VERSION

__all__ = [
"set_printoptions",
"setup"
]
__all__ = ["set_printoptions", "setup"]


def viz(self) -> None:
if isinstance(self, torch.FloatTensor):
_viz_Tensor(self, COLORS['default_dark'], COLORS['default_bright'])
_viz_Tensor(self, COLORS["default_dark"], COLORS["default_bright"])
elif isinstance(self, torch.BoolTensor):
_viz_Tensor(self, COLORS['false'], COLORS['true'])
_viz_Tensor(self, COLORS["false"], COLORS["true"])


def _viz_Tensor(self, colors: tuple[tuple[int], tuple[int]] = None) -> None:
data = self.data.numpy()
shape = data.shape
color_a = np.array(color_a)
color_b = np.array(color_b)
color = ((1 - data[::2, :, None]) * color_a + data[::2, :, None] * color_b).astype(int)
bgcolor = ((1 - data[1::2, :, None]) * color_a + data[1::2, :, None] * color_b).astype(int)

color_a = np.array(colors[0])
color_b = np.array(colors[1])
color = ((1 - data[::2, :, None]) * color_a + data[::2, :, None] * color_b).astype(
int
)
bgcolor = (
(1 - data[1::2, :, None]) * color_a + data[1::2, :, None] * color_b
).astype(int)

result_parts = []
for y in range(shape[0] // 2):
for x in range(shape[1]):
result_parts.append(f"[rgb({color[y, x, 0]},{color[y, x, 1]},{color[y, x, 2]}) on rgb({bgcolor[y, x, 0]},{bgcolor[y, x, 1]},{bgcolor[y, x, 2]})]▀[/]")
result_parts.append(
f"[rgb({color[y, x, 0]},{color[y, x, 1]},{color[y, x, 2]}) on rgb({bgcolor[y, x, 0]},{bgcolor[y, x, 1]},{bgcolor[y, x, 2]})]▀[/]"
)
result_parts.append("\n")

c = Console(log_path=False, record=False)
c.print(''.join(result_parts))
c.print("".join(result_parts))


# automagically set up tensorhue
if 'torch' in sys.modules:
torch = sys.modules['torch']
# automagically set up TensorHue
if "torch" in sys.modules:
torch = sys.modules["torch"]
setattr(torch.Tensor, "viz", viz)

0 comments on commit 231d7f0

Please sign in to comment.