diff --git a/README.md b/README.md index f0031b1..6f5dfc4 100644 --- a/README.md +++ b/README.md @@ -37,3 +37,37 @@ That's it! You can now vizualize any tensor by calling .viz() in your Python con t = torch.rand(20,20) t.viz() ``` + + +## Custom colors + +You can pass along your own ColorScheme when visualizing a specific tensor: + +```python +from tensorhue import ColorScheme +from matplotlib import colormaps + +cs = ColorScheme(colormap=colormaps['inferno'], + true_color=(10,10,10), + false_color=(20,20,20)) +t.viz(cs) +``` + +Alternatively, you can overwrite the default ColorScheme: + + +```python +tensorhue.set_printoptions(colorscheme=cs) +``` + +## Advanced colors + +By default, TensorHue normalizes numerical values between 0 and 1 and then applies the matplotlib colormap. If you want to use diverging colormaps such as `coolwarm` or `bwr` and the value 0 to be mapped to the middle of the colormap, you need to specify the normailzer, e.g. `matplotlib.colors.CenteredNorm`: + +from matplotlib.colors import CenteredNorm + +```python +cs = ColorScheme(colormap=colormaps['bwr'], + normalize=CenteredNorm(vcenter=0)) +t.viz(cs) +``` diff --git a/tensorhue/__init__.py b/tensorhue/__init__.py index 42ea1b6..4793716 100644 --- a/tensorhue/__init__.py +++ b/tensorhue/__init__.py @@ -1,25 +1,27 @@ import sys from rich.console import Console import numpy as np -from tensorhue.colors import COLORS +from tensorhue.colors import COLORS, ColorScheme +from tensorhue._print_opts import PRINT_OPTS, set_printoptions + __version__ = "0.0.2" # single source of version truth -__all__ = [] +__all__ = ["set_printoptions"] -def viz(self) -> None: - """ - Prints the tensor using a colored Unicode art representation. - This method checks the type of the tensor and calls the `_viz_tensor` function with the appropriate colors. - """ - if isinstance(self, torch.FloatTensor): # pylint: disable=possibly-used-before-assignment - _viz_tensor(self) - elif isinstance(self, torch.BoolTensor): - _viz_tensor(self, (COLORS["false"], COLORS["true"])) +# def viz(self) -> None: +# """ +# Prints the tensor using a colored Unicode art representation. +# This method checks the type of the tensor and calls the `_viz_tensor` function with the appropriate colors. +# """ +# if isinstance(self, (torch.FloatTensor, torch.IntTensor, torch.LongTensor)): # pylint: disable=possibly-used-before-assignment +# _viz_tensor(self) +# elif isinstance(self, torch.BoolTensor): +# _viz_tensor(self, (COLORS["false"], COLORS["true"])) -def _viz_tensor(self, colors: tuple[tuple[int], tuple[int]] = None) -> None: +def viz(self, colorscheme: ColorScheme = None) -> None: """ Prints a tensor using colored Unicode art representation. @@ -33,59 +35,31 @@ def _viz_tensor(self, colors: tuple[tuple[int], tuple[int]] = None) -> None: The first tuple represents the RGB color for the smallest value in the tensor, the second tuple represents the RGB color for the biggest value of the tensor. (Default = None; uses default colors) """ - if colors is None: - colors = COLORS["default_dark"], COLORS["default_bright"] + if colorscheme is None: + colorscheme = PRINT_OPTS.colorscheme + data = self.data.numpy() shape = data.shape - color_a = np.array(colors[0]) - color_b = np.array(colors[1]) - color = ((1 - data[::2, :, None]) * color_a + data[::2, :, None] * color_b).astype(int) - bgcolor = ((1 - data[1::2, :, None]) * color_a + data[1::2, :, None] * color_b).astype(int) - result_parts = [] - for y in range(shape[0] // 2): - for x in range(shape[1]): - result_parts.append( - f"[rgb({color[y, x, 0]},{color[y, x, 1]},{color[y, x, 2]}) on rgb({bgcolor[y, x, 0]},{bgcolor[y, x, 1]},{bgcolor[y, x, 2]})]▀[/]" - ) - result_parts.append("\n") + colors = colorscheme(data)[..., :3] + + result_lines = [""] + for y in range(0, shape[0] - 1, 2): + for x in range(shape[-1]): + result_lines[ + -1 + ] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]" + result_lines.append("") + if shape[0] % 2 == 1: for x in range(shape[1]): - result_parts.append(f"[rgb({color[-1, x, 0]},{color[-1, x, 1]},{color[-1, x, 2]})]▀[/]") + result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]" c = Console(log_path=False, record=False) - c.print("".join(result_parts)) - return "".join(result_parts) + c.print("\n".join(result_lines)) # automagically set up TensorHue if "torch" in sys.modules: torch = sys.modules["torch"] setattr(torch.Tensor, "viz", viz) - - -# def _viz_tensor_alt(self, colors: tuple[tuple[int], tuple[int]] = None) -> None: -# if colors is None: -# colors = COLORS["default_dark"], COLORS["default_bright"] -# data = self.data.numpy() -# shape = data.shape -# dim = data.ndim -# color_a = np.array(colors[0]) -# color_b = np.array(colors[1]) -# color = ((1 - data[::2, :, None]) * color_a + data[::2, :, None] * color_b).astype(int) -# bgcolor = ((1 - data[1::2, :, None]) * color_a + data[1::2, :, None] * color_b).astype(int) - -# result_parts = [] -# for y in range(shape[0] // 2): -# for x in range(shape[1]): -# result_parts.append( -# f"[rgb({color[y, x, 0]},{color[y, x, 1]},{color[y, x, 2]}) on rgb({bgcolor[y, x, 0]},{bgcolor[y, x, 1]},{bgcolor[y, x, 2]})]▀[/]" -# ) -# result_parts.append("\n") -# if shape[0] % 2 == 1: -# for x in range(shape[1]): -# result_parts.append(f"[rgb({color[-1, x, 0]},{color[-1, x, 1]},{color[-1, x, 2]})]▀[/]") - -# c = Console(log_path=False, record=False) -# c.print("".join(result_parts)) -# # return "".join(result_parts) diff --git a/tensorhue/_print_opts.py b/tensorhue/_print_opts.py new file mode 100644 index 0000000..4116331 --- /dev/null +++ b/tensorhue/_print_opts.py @@ -0,0 +1,63 @@ +import dataclasses +from tensorhue.colors import ColorScheme + + +@dataclasses.dataclass +class __PrinterOptions: + threshold: float = 1000 + edgeitems: int = 3 + linewidth: int = 200 + colorscheme: ColorScheme = ColorScheme() + + +PRINT_OPTS = __PrinterOptions() + + +# We could use **kwargs, but this will give better docs +def set_printoptions( + threshold: int = None, + edgeitems: int = None, + linewidth: int = None, + colorscheme: ColorScheme = None, +): + """Set options for printing. Items shamelessly taken from NumPy + + Args: + threshold: Total number of array elements which trigger summarization + rather than full `repr` (default = 1000). + edgeitems: Number of array items in summary at beginning and end of + each dimension (default = 3). + linewidth: The number of characters per line for the purpose of + inserting line breaks (default = 200). Thresholded matrices will + ignore this parameter. + colorscheme: The color scheme to use. + + Example:: + + >>> TODO + + """ + + if threshold is not None: + assert isinstance(threshold, int) + PRINT_OPTS.threshold = threshold + if edgeitems is not None: + assert isinstance(edgeitems, int) + assert ( + edgeitems <= PRINT_OPTS.threshold // 2 + ), "edgeitems should not be larger than half the summarization threshold" + PRINT_OPTS.edgeitems = edgeitems + if linewidth is not None: + assert isinstance(linewidth, int) + PRINT_OPTS.linewidth = linewidth + if colorscheme is not None: + assert isinstance(colorscheme, ColorScheme) + PRINT_OPTS.colorscheme = colorscheme + + +def _get_printoptions() -> dict[str, any]: + """ + Gets the current options for printing, as a dictionary that can be passed + as ``**kwargs`` to set_printoptions(). + """ + return dataclasses.asdict(PRINT_OPTS) diff --git a/tensorhue/colors.py b/tensorhue/colors.py index 6c7eca2..867898f 100644 --- a/tensorhue/colors.py +++ b/tensorhue/colors.py @@ -1,11 +1,10 @@ from __future__ import annotations from dataclasses import dataclass, field -from collections.abc import Iterator from rich.color_triplet import ColorTriplet import numpy as np -from numpy.typing import NDArray -from scipy.interpolate import interp1d +from matplotlib import colormaps +from matplotlib.colors import Colormap, Normalize COLORS = { @@ -13,8 +12,8 @@ "default_dark": ColorTriplet(64, 17, 159), # dark purple "default_medium": ColorTriplet(255, 55, 140), # pink "default_bright": ColorTriplet(255, 210, 240), # light rose - "true": ColorTriplet(255, 80, 80), # green - "false": ColorTriplet(125, 215, 82), # red + "true": ColorTriplet(125, 215, 82), # green + "false": ColorTriplet(255, 80, 80), # red "accessible_true": ColorTriplet(255, 80, 80), # TODO "accessible_false": ColorTriplet(125, 215, 82), # TODO "black": ColorTriplet(0, 0, 0), # black @@ -23,58 +22,63 @@ @dataclass -class ColorGradient: - gradient: list[tuple[float, ColorTriplet]] = field( - default_factory=lambda: [ - (0.0, COLORS["default_dark"]), - (0.7, COLORS["default_medium"]), - (1.0, COLORS["default_bright"]), - ] - ) +class ColorScheme: + _colormap: Colormap = field(default_factory=lambda: colormaps["magma"]) + normalize: Normalize = field(default_factory=Normalize) + _masked_color: ColorTriplet = field(default_factory=lambda: COLORS["masked"]) + true_color: ColorTriplet = field(default_factory=lambda: COLORS["true"]) + false_color: ColorTriplet = field(default_factory=lambda: COLORS["false"]) + _inf_color: ColorTriplet = field(default_factory=lambda: COLORS["white"]) + _ninf_color: ColorTriplet = field(default_factory=lambda: COLORS["black"]) - def __post_init__(self): - if len(self.gradient) < 2: - raise ValueError("ColorGradient must have at least 2 points") - self.gradient = [(float(pos), color) for pos, color in self.gradient] - pos_values = [pos for pos, _ in self.gradient] - if len(set(pos_values)) != len(pos_values): - raise ValueError("ColorGradient must have unique position values") - if 0.0 not in pos_values: - raise ValueError("ColorGradient must include a color for position 0.0") - if 1.0 not in pos_values: - raise ValueError("ColorGradient must include a color for position 1.0") - if max(pos_values) > 1.0 or min(pos_values) < 0.0: - raise ValueError("ColorGradient positions must be between 0.0 and 1.0") - self.gradient = sorted(self.gradient, key=lambda x: x[0]) + @property + def colormap(self): + return self._colormap - def __iter__(self) -> Iterator[tuple[float, ColorTriplet]]: - return iter(self.gradient) + @colormap.setter + def colormap(self, value): + self._colormap = value + self._colormap.set_extreme(bad=self._masked_color, under=self._ninf_color, over=self._inf_color) + @property + def masked_color(self): + return self._masked_color -@dataclass -class ColorScheme: - gradient: ColorGradient = field(default_factory=ColorGradient) - masked_color: ColorTriplet = field(default_factory=lambda: COLORS["masked"]) - true_color: ColorTriplet = field(default_factory=lambda: COLORS["true"]) - false_color: ColorTriplet = field(default_factory=lambda: COLORS["false"]) - inf_color: ColorTriplet = field(default_factory=lambda: COLORS["white"]) - ninf_color: ColorTriplet = field(default_factory=lambda: COLORS["black"]) + @masked_color.setter + def masked_color(self, value): + self._masked_color = value + self._colormap.set_bad(value) + + @property + def inf_color(self): + return self._inf_color + + @inf_color.setter + def inf_color(self, value): + self._inf_color = value + self._colormap.set_over(value) + + @property + def ninf_color(self): + return self._ninf_color - def calculate_gradient_color_vectorized( - self, value_array: NDArray[np.number] | NDArray[bool] - ) -> NDArray[np.uint8] | any: - """ - Calculate the gradient color for each value in the input array using vectorized interpolation. + @ninf_color.setter + def ninf_color(self, value): + self._ninf_color = value + self._colormap.set_under(value) - Args: - value_array (NDArray[np.float64]): The input array of values. + def __call__(self, data: np.ndarray) -> np.ndarray: + return self.colormap(self.normalize(data), bytes=True) - Returns: - Union[Any, NDArray[np.uint8]]: The calculated gradient color for each value in the input array. - """ - positions = [pos for pos, _ in self.gradient] - color_data = np.array([[color.red, color.green, color.blue] for _, color in self.gradient]) - interp_functions = [interp1d(positions, channel) for channel in color_data.T] - flat_values = value_array.flatten() - interpolated_colors = np.stack([func(flat_values) for func in interp_functions], axis=0) - return interpolated_colors.reshape((3,) + value_array.shape).astype(np.uint8) + def __repr__(self): + return ( + f"ColorScheme(\n" + f" colormap={self.colormap},\n" + f" normalize={self.normalize},\n" + f" masked_color={self._masked_color},\n" + f" true_color={self.true_color},\n" + f" false_color={self.false_color},\n" + f" inf_color={self._inf_color},\n" + f" ninf_color={self._ninf_color}\n" + f")" + ) diff --git a/tests/test_colors.py b/tests/test_colors.py index f3ba479..78f6148 100644 --- a/tests/test_colors.py +++ b/tests/test_colors.py @@ -11,52 +11,20 @@ def test_COLORS(): assert isinstance(value, ColorTriplet) -def test_ColorGradient(): - with pytest.raises(ValueError): - ColorGradient(gradient=[(0.0, ColorTriplet(0, 0, 0))]) - - with pytest.raises(ValueError): - ColorGradient(gradient=[(0.0, ColorTriplet(0, 0, 0)), (0.0, ColorTriplet(0, 0, 0))]) - - with pytest.raises(ValueError): - ColorGradient(gradient=[(0.5, ColorTriplet(0, 0, 0)), (1.0, ColorTriplet(0, 0, 0))]) - - with pytest.raises(ValueError): - ColorGradient(gradient=[(0.0, ColorTriplet(0, 0, 0)), (0.5, ColorTriplet(0, 0, 0))]) - - with pytest.raises(ValueError): - ColorGradient( - gradient=[(0.0, ColorTriplet(0, 0, 0)), (1.0, ColorTriplet(0, 0, 0)), (2.0, ColorTriplet(0, 0, 0))] - ) - - cg = ColorGradient() - assert cg.gradient == [ - (0, COLORS["default_dark"]), - (0.7, COLORS["default_medium"]), - (1.0, COLORS["default_bright"]), - ] - - cg = ColorGradient(gradient=[(0.0, COLORS["white"]), (1.0, COLORS["black"])]) - assert cg.gradient == [(0.0, COLORS["white"]), (1.0, COLORS["black"])] - - cg = ColorGradient(gradient=[(0.0, ColorTriplet(255, 255, 255)), (1.0, ColorTriplet(0, 0, 0))]) - assert cg.gradient == [(0.0, COLORS["white"]), (1.0, COLORS["black"])] - - -def test_ColorScheme(): - cs = ColorScheme( - gradient=ColorGradient( - [(0.0, ColorTriplet(0, 0, 0)), (0.5, ColorTriplet(255, 255, 255)), (1.0, ColorTriplet(0, 0, 0))] - ) - ) - assert cs.gradient.gradient == [ - (0.0, ColorTriplet(0, 0, 0)), - (0.5, ColorTriplet(255, 255, 255)), - (1.0, ColorTriplet(0, 0, 0)), - ] - assert cs.masked_color == COLORS["masked"] - assert cs.true_color == COLORS["true"] - - values = np.array([0.0, 0.5, 0.75]) - result = cs.calculate_gradient_color_vectorized(values) - assert np.array_equal(result, np.array([[0, 255, 127], [0, 255, 127], [0, 255, 127]])) +# def test_ColorScheme(): +# cs = ColorScheme( +# gradient=ColorGradient( +# [(0.0, ColorTriplet(0, 0, 0)), (0.5, ColorTriplet(255, 255, 255)), (1.0, ColorTriplet(0, 0, 0))] +# ) +# ) +# assert cs.gradient.gradient == [ +# (0.0, ColorTriplet(0, 0, 0)), +# (0.5, ColorTriplet(255, 255, 255)), +# (1.0, ColorTriplet(0, 0, 0)), +# ] +# assert cs.masked_color == COLORS["masked"] +# assert cs.true_color == COLORS["true"] + +# values = np.array([0.0, 0.5, 0.75]) +# result = cs.calculate_gradient_color_vectorized(values) +# assert np.array_equal(result, np.array([[0, 255, 127], [0, 255, 127], [0, 255, 127]]))