diff --git a/tensorhue/__init__.py b/tensorhue/__init__.py index 4793716..46f4e2a 100644 --- a/tensorhue/__init__.py +++ b/tensorhue/__init__.py @@ -3,63 +3,19 @@ import numpy as np from tensorhue.colors import COLORS, ColorScheme from tensorhue._print_opts import PRINT_OPTS, set_printoptions +from tensorhue.numpy import NumpyArrayWrapper +from tensorhue.torch import _tensorhue_to_numpy_torch +from tensorhue.eastereggs import pride +from tensorhue.viz import viz, _viz -__version__ = "0.0.2" # single source of version truth - -__all__ = ["set_printoptions"] - - -# def viz(self) -> None: -# """ -# Prints the tensor using a colored Unicode art representation. -# This method checks the type of the tensor and calls the `_viz_tensor` function with the appropriate colors. -# """ -# if isinstance(self, (torch.FloatTensor, torch.IntTensor, torch.LongTensor)): # pylint: disable=possibly-used-before-assignment -# _viz_tensor(self) -# elif isinstance(self, torch.BoolTensor): -# _viz_tensor(self, (COLORS["false"], COLORS["true"])) - - -def viz(self, colorscheme: ColorScheme = None) -> None: - """ - Prints a tensor using colored Unicode art representation. - - This function takes a tensor and a tuple of two tuples of integers representing the colors. - It converts the tensor data to a numpy array, calculates the colors for each element based on the input colors, - and generates a string representation of the tensor using the calculated colors. - The resulting string is then printed using the Console class from the rich library. - - Parameters: - colors (tuple[tuple[int], tuple[int]]): A tuple of two RGB tuples representing the colors. - The first tuple represents the RGB color for the smallest value in the tensor, the second tuple represents - the RGB color for the biggest value of the tensor. (Default = None; uses default colors) - """ - if colorscheme is None: - colorscheme = PRINT_OPTS.colorscheme - - data = self.data.numpy() - shape = data.shape - - colors = colorscheme(data)[..., :3] - - result_lines = [""] - for y in range(0, shape[0] - 1, 2): - for x in range(shape[-1]): - result_lines[ - -1 - ] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]" - result_lines.append("") - - if shape[0] % 2 == 1: - for x in range(shape[1]): - result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]" - - c = Console(log_path=False, record=False) - c.print("\n".join(result_lines)) +__version__ = "0.0.3" # single source of version truth +__all__ = ["set_printoptions", "viz", "pride"] # automagically set up TensorHue +setattr(NumpyArrayWrapper, "viz", _viz) if "torch" in sys.modules: torch = sys.modules["torch"] - setattr(torch.Tensor, "viz", viz) + setattr(torch.Tensor, "viz", _viz) + setattr(torch.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_torch) diff --git a/tensorhue/_print_opts.py b/tensorhue/_print_opts.py index 4116331..72192d3 100644 --- a/tensorhue/_print_opts.py +++ b/tensorhue/_print_opts.py @@ -1,8 +1,8 @@ -import dataclasses +from dataclasses import dataclass from tensorhue.colors import ColorScheme -@dataclasses.dataclass +@dataclass class __PrinterOptions: threshold: float = 1000 edgeitems: int = 3 @@ -19,6 +19,7 @@ def set_printoptions( edgeitems: int = None, linewidth: int = None, colorscheme: ColorScheme = None, + accessible: bool = False, ): """Set options for printing. Items shamelessly taken from NumPy @@ -31,13 +32,8 @@ def set_printoptions( inserting line breaks (default = 200). Thresholded matrices will ignore this parameter. colorscheme: The color scheme to use. - - Example:: - - >>> TODO - + accessible: Whether to use accessible mode or not (default = False). """ - if threshold is not None: assert isinstance(threshold, int) PRINT_OPTS.threshold = threshold @@ -53,11 +49,7 @@ def set_printoptions( if colorscheme is not None: assert isinstance(colorscheme, ColorScheme) PRINT_OPTS.colorscheme = colorscheme - - -def _get_printoptions() -> dict[str, any]: - """ - Gets the current options for printing, as a dictionary that can be passed - as ``**kwargs`` to set_printoptions(). - """ - return dataclasses.asdict(PRINT_OPTS) + if accessible: + raise NotImplementedError( + "Accessible mode is not implemented yet. If you have a vision impairment and can provide feedback or beta-test color schemes please let us know." + ) diff --git a/tensorhue/colors.py b/tensorhue/colors.py index 867898f..e6955b7 100644 --- a/tensorhue/colors.py +++ b/tensorhue/colors.py @@ -1,6 +1,5 @@ from __future__ import annotations -from dataclasses import dataclass, field from rich.color_triplet import ColorTriplet import numpy as np from matplotlib import colormaps @@ -8,28 +7,41 @@ COLORS = { - "masked": ColorTriplet(140, 140, 140), # medium grey + "masked": ColorTriplet(127, 127, 127), # medium grey "default_dark": ColorTriplet(64, 17, 159), # dark purple "default_medium": ColorTriplet(255, 55, 140), # pink "default_bright": ColorTriplet(255, 210, 240), # light rose "true": ColorTriplet(125, 215, 82), # green "false": ColorTriplet(255, 80, 80), # red - "accessible_true": ColorTriplet(255, 80, 80), # TODO - "accessible_false": ColorTriplet(125, 215, 82), # TODO + "accessible_true": ColorTriplet(255, 255, 255), # TODO + "accessible_false": ColorTriplet(0, 0, 0), # TODO "black": ColorTriplet(0, 0, 0), # black "white": ColorTriplet(255, 255, 255), # white } -@dataclass class ColorScheme: - _colormap: Colormap = field(default_factory=lambda: colormaps["magma"]) - normalize: Normalize = field(default_factory=Normalize) - _masked_color: ColorTriplet = field(default_factory=lambda: COLORS["masked"]) - true_color: ColorTriplet = field(default_factory=lambda: COLORS["true"]) - false_color: ColorTriplet = field(default_factory=lambda: COLORS["false"]) - _inf_color: ColorTriplet = field(default_factory=lambda: COLORS["white"]) - _ninf_color: ColorTriplet = field(default_factory=lambda: COLORS["black"]) + def __init__( + self, + colormap: Colormap = colormaps["magma"], + normalize: Normalize = Normalize(), + masked_color: ColorTriplet = COLORS["masked"], + true_color: ColorTriplet = COLORS["true"], + false_color: ColorTriplet = COLORS["false"], + inf_color: ColorTriplet = COLORS["white"], + ninf_color: ColorTriplet = COLORS["black"], + ): + self._colormap = colormap + self.normalize = normalize + self._masked_color = masked_color + self.true_color = true_color + self.false_color = false_color + self._inf_color = inf_color + self._ninf_color = ninf_color + + self.colormap.set_extremes( + bad=self.masked_color.normalized, under=self.ninf_color.normalized, over=self.inf_color.normalized + ) @property def colormap(self): @@ -38,7 +50,9 @@ def colormap(self): @colormap.setter def colormap(self, value): self._colormap = value - self._colormap.set_extreme(bad=self._masked_color, under=self._ninf_color, over=self._inf_color) + self._colormap.set_extremes( + bad=self._masked_color.normalized, under=self._ninf_color.normalized, over=self._inf_color.normalized + ) @property def masked_color(self): @@ -47,7 +61,7 @@ def masked_color(self): @masked_color.setter def masked_color(self, value): self._masked_color = value - self._colormap.set_bad(value) + self._colormap.set_bad(value.normalized) @property def inf_color(self): @@ -56,7 +70,7 @@ def inf_color(self): @inf_color.setter def inf_color(self, value): self._inf_color = value - self._colormap.set_over(value) + self._colormap.set_over(value.normalized) @property def ninf_color(self): @@ -65,15 +79,22 @@ def ninf_color(self): @ninf_color.setter def ninf_color(self, value): self._ninf_color = value - self._colormap.set_under(value) + self._colormap.set_under(value.normalized) def __call__(self, data: np.ndarray) -> np.ndarray: + if data.dtype == "bool": + true_values = np.array(self.true_color, dtype=np.uint8) + false_values = np.array(self.false_color, dtype=np.uint8) + return np.where(data[..., np.newaxis], true_values, false_values) + data_noinf = np.where(np.isinf(data), np.nan, data) + self.normalize.vmin = np.nanmin(data_noinf) + self.normalize.vmax = np.nanmax(data_noinf) return self.colormap(self.normalize(data), bytes=True) def __repr__(self): return ( f"ColorScheme(\n" - f" colormap={self.colormap},\n" + f" colormap={self._colormap},\n" f" normalize={self.normalize},\n" f" masked_color={self._masked_color},\n" f" true_color={self.true_color},\n" diff --git a/tensorhue/eastereggs.py b/tensorhue/eastereggs.py new file mode 100644 index 0000000..e19548a --- /dev/null +++ b/tensorhue/eastereggs.py @@ -0,0 +1,20 @@ +import numpy as np +from matplotlib.colors import LinearSegmentedColormap +from rich.color_triplet import ColorTriplet +from tensorhue.colors import ColorScheme +from tensorhue.viz import viz + + +def pride(): + pride_colors = [ + ColorTriplet(228, 3, 3), + ColorTriplet(255, 140, 0), + ColorTriplet(255, 237, 0), + ColorTriplet(0, 128, 38), + ColorTriplet(0, 76, 255), + ColorTriplet(115, 41, 130), + ] + pride_cm = LinearSegmentedColormap.from_list(colors=[c.normalized for c in pride_colors], name="pride") + pride_cs = ColorScheme(colormap=pride_cm) + arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), 10, axis=1) + viz(arr, colorscheme=pride_cs) diff --git a/tensorhue/numpy.py b/tensorhue/numpy.py new file mode 100644 index 0000000..3425f96 --- /dev/null +++ b/tensorhue/numpy.py @@ -0,0 +1,10 @@ +import numpy as np + + +class NumpyArrayWrapper(np.ndarray): + def __new__(cls, input_array): + obj = np.asarray(input_array).view(cls) + return obj + + def _tensorhue_to_numpy(self): + return self diff --git a/tensorhue/torch.py b/tensorhue/torch.py new file mode 100644 index 0000000..0d643f9 --- /dev/null +++ b/tensorhue/torch.py @@ -0,0 +1,17 @@ +import torch +import numpy as np + + +def _tensorhue_to_numpy_torch(tensor: torch.Tensor) -> np.ndarray: + if isinstance(tensor, torch.masked.MaskedTensor): + return np.ma.masked_array(tensor.get_data(), torch.logical_not(tensor.get_mask())) + try: + return tensor.numpy() + except RuntimeError as e: + raise NotImplementedError( + f"{e}: It looks like tensors of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them." + ) from e + except Exception as e: + raise RuntimeError( + f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}" + ) from e diff --git a/tensorhue/viz.py b/tensorhue/viz.py new file mode 100644 index 0000000..a490909 --- /dev/null +++ b/tensorhue/viz.py @@ -0,0 +1,55 @@ +from rich.console import Console +import numpy as np +from tensorhue.colors import ColorScheme +from tensorhue._print_opts import PRINT_OPTS +from tensorhue.numpy import NumpyArrayWrapper + + +def viz(tensor, *args, **kwargs): + if isinstance(tensor, np.ndarray): + tensor = NumpyArrayWrapper(tensor) + tensor.viz(*args, **kwargs) # pylint: disable=no-member + else: + try: + tensor.viz(*args, **kwargs) + except Exception as e: + raise NotImplementedError( + f"TensorHue does not support type {type(tensor)}. Raise an issue if you need to visualize them. Alternatively, check if you imported tensorhue *after* your other library." + ) from e + + +def _viz(self, colorscheme: ColorScheme = None): + """ + Prints a tensor using colored Unicode art representation. + + Args: + colorscheme (ColorScheme, optional): The color scheme to use. + Defaults to None, which means the global default color scheme is used. + """ + if colorscheme is None: + colorscheme = PRINT_OPTS.colorscheme + + self = self._tensorhue_to_numpy() + shape = self.shape + + if len(shape) > 2: + raise NotImplementedError( + "Visualization for tensors with more than 2 dimensions is under development. Please slice them for now." + ) + + colors = colorscheme(self)[..., :3] + + result_lines = [""] + for y in range(0, shape[0] - 1, 2): + for x in range(shape[-1]): + result_lines[ + -1 + ] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]" + result_lines.append("") + + if shape[0] % 2 == 1: + for x in range(shape[1]): + result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]" + + c = Console(log_path=False, record=False) + c.print("\n".join(result_lines))