Skip to content

Commit

Permalink
update ColorScheme and supported libraries and lint everything
Browse files Browse the repository at this point in the history
  • Loading branch information
epistoteles committed Jun 16, 2024
1 parent 2f9a6dd commit 16db8d3
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 86 deletions.
62 changes: 9 additions & 53 deletions tensorhue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,19 @@
import numpy as np
from tensorhue.colors import COLORS, ColorScheme
from tensorhue._print_opts import PRINT_OPTS, set_printoptions
from tensorhue.numpy import NumpyArrayWrapper
from tensorhue.torch import _tensorhue_to_numpy_torch
from tensorhue.eastereggs import pride
from tensorhue.viz import viz, _viz


__version__ = "0.0.2" # single source of version truth

__all__ = ["set_printoptions"]


# def viz(self) -> None:
# """
# Prints the tensor using a colored Unicode art representation.
# This method checks the type of the tensor and calls the `_viz_tensor` function with the appropriate colors.
# """
# if isinstance(self, (torch.FloatTensor, torch.IntTensor, torch.LongTensor)): # pylint: disable=possibly-used-before-assignment
# _viz_tensor(self)
# elif isinstance(self, torch.BoolTensor):
# _viz_tensor(self, (COLORS["false"], COLORS["true"]))


def viz(self, colorscheme: ColorScheme = None) -> None:
"""
Prints a tensor using colored Unicode art representation.
This function takes a tensor and a tuple of two tuples of integers representing the colors.
It converts the tensor data to a numpy array, calculates the colors for each element based on the input colors,
and generates a string representation of the tensor using the calculated colors.
The resulting string is then printed using the Console class from the rich library.
Parameters:
colors (tuple[tuple[int], tuple[int]]): A tuple of two RGB tuples representing the colors.
The first tuple represents the RGB color for the smallest value in the tensor, the second tuple represents
the RGB color for the biggest value of the tensor. (Default = None; uses default colors)
"""
if colorscheme is None:
colorscheme = PRINT_OPTS.colorscheme

data = self.data.numpy()
shape = data.shape

colors = colorscheme(data)[..., :3]

result_lines = [""]
for y in range(0, shape[0] - 1, 2):
for x in range(shape[-1]):
result_lines[
-1
] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]"
result_lines.append("")

if shape[0] % 2 == 1:
for x in range(shape[1]):
result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]"

c = Console(log_path=False, record=False)
c.print("\n".join(result_lines))
__version__ = "0.0.3" # single source of version truth

__all__ = ["set_printoptions", "viz", "pride"]

# automagically set up TensorHue
setattr(NumpyArrayWrapper, "viz", _viz)
if "torch" in sys.modules:
torch = sys.modules["torch"]
setattr(torch.Tensor, "viz", viz)
setattr(torch.Tensor, "viz", _viz)
setattr(torch.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_torch)
24 changes: 8 additions & 16 deletions tensorhue/_print_opts.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import dataclasses
from dataclasses import dataclass
from tensorhue.colors import ColorScheme


@dataclasses.dataclass
@dataclass
class __PrinterOptions:
threshold: float = 1000
edgeitems: int = 3
Expand All @@ -19,6 +19,7 @@ def set_printoptions(
edgeitems: int = None,
linewidth: int = None,
colorscheme: ColorScheme = None,
accessible: bool = False,
):
"""Set options for printing. Items shamelessly taken from NumPy
Expand All @@ -31,13 +32,8 @@ def set_printoptions(
inserting line breaks (default = 200). Thresholded matrices will
ignore this parameter.
colorscheme: The color scheme to use.
Example::
>>> TODO
accessible: Whether to use accessible mode or not (default = False).
"""

if threshold is not None:
assert isinstance(threshold, int)
PRINT_OPTS.threshold = threshold
Expand All @@ -53,11 +49,7 @@ def set_printoptions(
if colorscheme is not None:
assert isinstance(colorscheme, ColorScheme)
PRINT_OPTS.colorscheme = colorscheme


def _get_printoptions() -> dict[str, any]:
"""
Gets the current options for printing, as a dictionary that can be passed
as ``**kwargs`` to set_printoptions().
"""
return dataclasses.asdict(PRINT_OPTS)
if accessible:
raise NotImplementedError(
"Accessible mode is not implemented yet. If you have a vision impairment and can provide feedback or beta-test color schemes please let us know."
)
55 changes: 38 additions & 17 deletions tensorhue/colors.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,47 @@
from __future__ import annotations

from dataclasses import dataclass, field
from rich.color_triplet import ColorTriplet
import numpy as np
from matplotlib import colormaps
from matplotlib.colors import Colormap, Normalize


COLORS = {
"masked": ColorTriplet(140, 140, 140), # medium grey
"masked": ColorTriplet(127, 127, 127), # medium grey
"default_dark": ColorTriplet(64, 17, 159), # dark purple
"default_medium": ColorTriplet(255, 55, 140), # pink
"default_bright": ColorTriplet(255, 210, 240), # light rose
"true": ColorTriplet(125, 215, 82), # green
"false": ColorTriplet(255, 80, 80), # red
"accessible_true": ColorTriplet(255, 80, 80), # TODO
"accessible_false": ColorTriplet(125, 215, 82), # TODO
"accessible_true": ColorTriplet(255, 255, 255), # TODO
"accessible_false": ColorTriplet(0, 0, 0), # TODO
"black": ColorTriplet(0, 0, 0), # black
"white": ColorTriplet(255, 255, 255), # white
}


@dataclass
class ColorScheme:
_colormap: Colormap = field(default_factory=lambda: colormaps["magma"])
normalize: Normalize = field(default_factory=Normalize)
_masked_color: ColorTriplet = field(default_factory=lambda: COLORS["masked"])
true_color: ColorTriplet = field(default_factory=lambda: COLORS["true"])
false_color: ColorTriplet = field(default_factory=lambda: COLORS["false"])
_inf_color: ColorTriplet = field(default_factory=lambda: COLORS["white"])
_ninf_color: ColorTriplet = field(default_factory=lambda: COLORS["black"])
def __init__(
self,
colormap: Colormap = colormaps["magma"],
normalize: Normalize = Normalize(),
masked_color: ColorTriplet = COLORS["masked"],
true_color: ColorTriplet = COLORS["true"],
false_color: ColorTriplet = COLORS["false"],
inf_color: ColorTriplet = COLORS["white"],
ninf_color: ColorTriplet = COLORS["black"],
):
self._colormap = colormap
self.normalize = normalize
self._masked_color = masked_color
self.true_color = true_color
self.false_color = false_color
self._inf_color = inf_color
self._ninf_color = ninf_color

self.colormap.set_extremes(
bad=self.masked_color.normalized, under=self.ninf_color.normalized, over=self.inf_color.normalized
)

@property
def colormap(self):
Expand All @@ -38,7 +50,9 @@ def colormap(self):
@colormap.setter
def colormap(self, value):
self._colormap = value
self._colormap.set_extreme(bad=self._masked_color, under=self._ninf_color, over=self._inf_color)
self._colormap.set_extremes(
bad=self._masked_color.normalized, under=self._ninf_color.normalized, over=self._inf_color.normalized
)

@property
def masked_color(self):
Expand All @@ -47,7 +61,7 @@ def masked_color(self):
@masked_color.setter
def masked_color(self, value):
self._masked_color = value
self._colormap.set_bad(value)
self._colormap.set_bad(value.normalized)

@property
def inf_color(self):
Expand All @@ -56,7 +70,7 @@ def inf_color(self):
@inf_color.setter
def inf_color(self, value):
self._inf_color = value
self._colormap.set_over(value)
self._colormap.set_over(value.normalized)

@property
def ninf_color(self):
Expand All @@ -65,15 +79,22 @@ def ninf_color(self):
@ninf_color.setter
def ninf_color(self, value):
self._ninf_color = value
self._colormap.set_under(value)
self._colormap.set_under(value.normalized)

def __call__(self, data: np.ndarray) -> np.ndarray:
if data.dtype == "bool":
true_values = np.array(self.true_color, dtype=np.uint8)
false_values = np.array(self.false_color, dtype=np.uint8)
return np.where(data[..., np.newaxis], true_values, false_values)
data_noinf = np.where(np.isinf(data), np.nan, data)
self.normalize.vmin = np.nanmin(data_noinf)
self.normalize.vmax = np.nanmax(data_noinf)
return self.colormap(self.normalize(data), bytes=True)

def __repr__(self):
return (
f"ColorScheme(\n"
f" colormap={self.colormap},\n"
f" colormap={self._colormap},\n"
f" normalize={self.normalize},\n"
f" masked_color={self._masked_color},\n"
f" true_color={self.true_color},\n"
Expand Down
20 changes: 20 additions & 0 deletions tensorhue/eastereggs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from rich.color_triplet import ColorTriplet
from tensorhue.colors import ColorScheme
from tensorhue.viz import viz


def pride():
pride_colors = [
ColorTriplet(228, 3, 3),
ColorTriplet(255, 140, 0),
ColorTriplet(255, 237, 0),
ColorTriplet(0, 128, 38),
ColorTriplet(0, 76, 255),
ColorTriplet(115, 41, 130),
]
pride_cm = LinearSegmentedColormap.from_list(colors=[c.normalized for c in pride_colors], name="pride")
pride_cs = ColorScheme(colormap=pride_cm)
arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), 10, axis=1)
viz(arr, colorscheme=pride_cs)
10 changes: 10 additions & 0 deletions tensorhue/numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import numpy as np


class NumpyArrayWrapper(np.ndarray):
def __new__(cls, input_array):
obj = np.asarray(input_array).view(cls)
return obj

def _tensorhue_to_numpy(self):
return self
17 changes: 17 additions & 0 deletions tensorhue/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch
import numpy as np


def _tensorhue_to_numpy_torch(tensor: torch.Tensor) -> np.ndarray:
if isinstance(tensor, torch.masked.MaskedTensor):
return np.ma.masked_array(tensor.get_data(), torch.logical_not(tensor.get_mask()))
try:
return tensor.numpy()
except RuntimeError as e:
raise NotImplementedError(
f"{e}: It looks like tensors of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them."
) from e
except Exception as e:
raise RuntimeError(
f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}"
) from e
55 changes: 55 additions & 0 deletions tensorhue/viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from rich.console import Console
import numpy as np
from tensorhue.colors import ColorScheme
from tensorhue._print_opts import PRINT_OPTS
from tensorhue.numpy import NumpyArrayWrapper


def viz(tensor, *args, **kwargs):
if isinstance(tensor, np.ndarray):
tensor = NumpyArrayWrapper(tensor)
tensor.viz(*args, **kwargs) # pylint: disable=no-member
else:
try:
tensor.viz(*args, **kwargs)
except Exception as e:
raise NotImplementedError(
f"TensorHue does not support type {type(tensor)}. Raise an issue if you need to visualize them. Alternatively, check if you imported tensorhue *after* your other library."
) from e


def _viz(self, colorscheme: ColorScheme = None):
"""
Prints a tensor using colored Unicode art representation.
Args:
colorscheme (ColorScheme, optional): The color scheme to use.
Defaults to None, which means the global default color scheme is used.
"""
if colorscheme is None:
colorscheme = PRINT_OPTS.colorscheme

self = self._tensorhue_to_numpy()
shape = self.shape

if len(shape) > 2:
raise NotImplementedError(
"Visualization for tensors with more than 2 dimensions is under development. Please slice them for now."
)

colors = colorscheme(self)[..., :3]

result_lines = [""]
for y in range(0, shape[0] - 1, 2):
for x in range(shape[-1]):
result_lines[
-1
] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]"
result_lines.append("")

if shape[0] % 2 == 1:
for x in range(shape[1]):
result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]"

c = Console(log_path=False, record=False)
c.print("\n".join(result_lines))

0 comments on commit 16db8d3

Please sign in to comment.