Skip to content

Commit

Permalink
add matplotlib Colormap instead of ColorGradient
Browse files Browse the repository at this point in the history
  • Loading branch information
epistoteles committed Jun 16, 2024
1 parent 3305871 commit 766c3ff
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 157 deletions.
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,37 @@ That's it! You can now vizualize any tensor by calling .viz() in your Python con
t = torch.rand(20,20)
t.viz()
```


## Custom colors

You can pass along your own ColorScheme when visualizing a specific tensor:

```python
from tensorhue import ColorScheme
from matplotlib import colormaps

cs = ColorScheme(colormap=colormaps['inferno'],
true_color=(10,10,10),
false_color=(20,20,20))
t.viz(cs)
```

Alternatively, you can overwrite the default ColorScheme:


```python
tensorhue.set_printoptions(colorscheme=cs)
```

## Advanced colors

By default, TensorHue normalizes numerical values between 0 and 1 and then applies the matplotlib colormap. If you want to use diverging colormaps such as `coolwarm` or `bwr` and the value 0 to be mapped to the middle of the colormap, you need to specify the normailzer, e.g. `matplotlib.colors.CenteredNorm`:

from matplotlib.colors import CenteredNorm

```python
cs = ColorScheme(colormap=colormaps['bwr'],
normalize=CenteredNorm(vcenter=0))
t.viz(cs)
```
84 changes: 29 additions & 55 deletions tensorhue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
import sys
from rich.console import Console
import numpy as np
from tensorhue.colors import COLORS
from tensorhue.colors import COLORS, ColorScheme
from tensorhue._print_opts import PRINT_OPTS, set_printoptions


__version__ = "0.0.2" # single source of version truth

__all__ = []
__all__ = ["set_printoptions"]


def viz(self) -> None:
"""
Prints the tensor using a colored Unicode art representation.
This method checks the type of the tensor and calls the `_viz_tensor` function with the appropriate colors.
"""
if isinstance(self, torch.FloatTensor): # pylint: disable=possibly-used-before-assignment
_viz_tensor(self)
elif isinstance(self, torch.BoolTensor):
_viz_tensor(self, (COLORS["false"], COLORS["true"]))
# def viz(self) -> None:
# """
# Prints the tensor using a colored Unicode art representation.
# This method checks the type of the tensor and calls the `_viz_tensor` function with the appropriate colors.
# """
# if isinstance(self, (torch.FloatTensor, torch.IntTensor, torch.LongTensor)): # pylint: disable=possibly-used-before-assignment
# _viz_tensor(self)
# elif isinstance(self, torch.BoolTensor):
# _viz_tensor(self, (COLORS["false"], COLORS["true"]))


def _viz_tensor(self, colors: tuple[tuple[int], tuple[int]] = None) -> None:
def viz(self, colorscheme: ColorScheme = None) -> None:
"""
Prints a tensor using colored Unicode art representation.
Expand All @@ -33,59 +35,31 @@ def _viz_tensor(self, colors: tuple[tuple[int], tuple[int]] = None) -> None:
The first tuple represents the RGB color for the smallest value in the tensor, the second tuple represents
the RGB color for the biggest value of the tensor. (Default = None; uses default colors)
"""
if colors is None:
colors = COLORS["default_dark"], COLORS["default_bright"]
if colorscheme is None:
colorscheme = PRINT_OPTS.colorscheme

data = self.data.numpy()
shape = data.shape
color_a = np.array(colors[0])
color_b = np.array(colors[1])
color = ((1 - data[::2, :, None]) * color_a + data[::2, :, None] * color_b).astype(int)
bgcolor = ((1 - data[1::2, :, None]) * color_a + data[1::2, :, None] * color_b).astype(int)

result_parts = []
for y in range(shape[0] // 2):
for x in range(shape[1]):
result_parts.append(
f"[rgb({color[y, x, 0]},{color[y, x, 1]},{color[y, x, 2]}) on rgb({bgcolor[y, x, 0]},{bgcolor[y, x, 1]},{bgcolor[y, x, 2]})]▀[/]"
)
result_parts.append("\n")
colors = colorscheme(data)[..., :3]

result_lines = [""]
for y in range(0, shape[0] - 1, 2):
for x in range(shape[-1]):
result_lines[
-1
] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]"
result_lines.append("")

if shape[0] % 2 == 1:
for x in range(shape[1]):
result_parts.append(f"[rgb({color[-1, x, 0]},{color[-1, x, 1]},{color[-1, x, 2]})]▀[/]")
result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]"

c = Console(log_path=False, record=False)
c.print("".join(result_parts))
return "".join(result_parts)
c.print("\n".join(result_lines))


# automagically set up TensorHue
if "torch" in sys.modules:
torch = sys.modules["torch"]
setattr(torch.Tensor, "viz", viz)


# def _viz_tensor_alt(self, colors: tuple[tuple[int], tuple[int]] = None) -> None:
# if colors is None:
# colors = COLORS["default_dark"], COLORS["default_bright"]
# data = self.data.numpy()
# shape = data.shape
# dim = data.ndim
# color_a = np.array(colors[0])
# color_b = np.array(colors[1])
# color = ((1 - data[::2, :, None]) * color_a + data[::2, :, None] * color_b).astype(int)
# bgcolor = ((1 - data[1::2, :, None]) * color_a + data[1::2, :, None] * color_b).astype(int)

# result_parts = []
# for y in range(shape[0] // 2):
# for x in range(shape[1]):
# result_parts.append(
# f"[rgb({color[y, x, 0]},{color[y, x, 1]},{color[y, x, 2]}) on rgb({bgcolor[y, x, 0]},{bgcolor[y, x, 1]},{bgcolor[y, x, 2]})]▀[/]"
# )
# result_parts.append("\n")
# if shape[0] % 2 == 1:
# for x in range(shape[1]):
# result_parts.append(f"[rgb({color[-1, x, 0]},{color[-1, x, 1]},{color[-1, x, 2]})]▀[/]")

# c = Console(log_path=False, record=False)
# c.print("".join(result_parts))
# # return "".join(result_parts)
63 changes: 63 additions & 0 deletions tensorhue/_print_opts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import dataclasses
from tensorhue.colors import ColorScheme


@dataclasses.dataclass
class __PrinterOptions:
threshold: float = 1000
edgeitems: int = 3
linewidth: int = 200
colorscheme: ColorScheme = ColorScheme()


PRINT_OPTS = __PrinterOptions()


# We could use **kwargs, but this will give better docs
def set_printoptions(
threshold: int = None,
edgeitems: int = None,
linewidth: int = None,
colorscheme: ColorScheme = None,
):
"""Set options for printing. Items shamelessly taken from NumPy
Args:
threshold: Total number of array elements which trigger summarization
rather than full `repr` (default = 1000).
edgeitems: Number of array items in summary at beginning and end of
each dimension (default = 3).
linewidth: The number of characters per line for the purpose of
inserting line breaks (default = 200). Thresholded matrices will
ignore this parameter.
colorscheme: The color scheme to use.
Example::
>>> TODO
"""

if threshold is not None:
assert isinstance(threshold, int)
PRINT_OPTS.threshold = threshold
if edgeitems is not None:
assert isinstance(edgeitems, int)
assert (
edgeitems <= PRINT_OPTS.threshold // 2
), "edgeitems should not be larger than half the summarization threshold"
PRINT_OPTS.edgeitems = edgeitems
if linewidth is not None:
assert isinstance(linewidth, int)
PRINT_OPTS.linewidth = linewidth
if colorscheme is not None:
assert isinstance(colorscheme, ColorScheme)
PRINT_OPTS.colorscheme = colorscheme


def _get_printoptions() -> dict[str, any]:
"""
Gets the current options for printing, as a dictionary that can be passed
as ``**kwargs`` to set_printoptions().
"""
return dataclasses.asdict(PRINT_OPTS)
110 changes: 57 additions & 53 deletions tensorhue/colors.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from __future__ import annotations

from dataclasses import dataclass, field
from collections.abc import Iterator
from rich.color_triplet import ColorTriplet
import numpy as np
from numpy.typing import NDArray
from scipy.interpolate import interp1d
from matplotlib import colormaps
from matplotlib.colors import Colormap, Normalize


COLORS = {
"masked": ColorTriplet(140, 140, 140), # medium grey
"default_dark": ColorTriplet(64, 17, 159), # dark purple
"default_medium": ColorTriplet(255, 55, 140), # pink
"default_bright": ColorTriplet(255, 210, 240), # light rose
"true": ColorTriplet(255, 80, 80), # green
"false": ColorTriplet(125, 215, 82), # red
"true": ColorTriplet(125, 215, 82), # green
"false": ColorTriplet(255, 80, 80), # red
"accessible_true": ColorTriplet(255, 80, 80), # TODO
"accessible_false": ColorTriplet(125, 215, 82), # TODO
"black": ColorTriplet(0, 0, 0), # black
Expand All @@ -23,58 +22,63 @@


@dataclass
class ColorGradient:
gradient: list[tuple[float, ColorTriplet]] = field(
default_factory=lambda: [
(0.0, COLORS["default_dark"]),
(0.7, COLORS["default_medium"]),
(1.0, COLORS["default_bright"]),
]
)
class ColorScheme:
_colormap: Colormap = field(default_factory=lambda: colormaps["magma"])
normalize: Normalize = field(default_factory=Normalize)
_masked_color: ColorTriplet = field(default_factory=lambda: COLORS["masked"])
true_color: ColorTriplet = field(default_factory=lambda: COLORS["true"])
false_color: ColorTriplet = field(default_factory=lambda: COLORS["false"])
_inf_color: ColorTriplet = field(default_factory=lambda: COLORS["white"])
_ninf_color: ColorTriplet = field(default_factory=lambda: COLORS["black"])

def __post_init__(self):
if len(self.gradient) < 2:
raise ValueError("ColorGradient must have at least 2 points")
self.gradient = [(float(pos), color) for pos, color in self.gradient]
pos_values = [pos for pos, _ in self.gradient]
if len(set(pos_values)) != len(pos_values):
raise ValueError("ColorGradient must have unique position values")
if 0.0 not in pos_values:
raise ValueError("ColorGradient must include a color for position 0.0")
if 1.0 not in pos_values:
raise ValueError("ColorGradient must include a color for position 1.0")
if max(pos_values) > 1.0 or min(pos_values) < 0.0:
raise ValueError("ColorGradient positions must be between 0.0 and 1.0")
self.gradient = sorted(self.gradient, key=lambda x: x[0])
@property
def colormap(self):
return self._colormap

def __iter__(self) -> Iterator[tuple[float, ColorTriplet]]:
return iter(self.gradient)
@colormap.setter
def colormap(self, value):
self._colormap = value
self._colormap.set_extreme(bad=self._masked_color, under=self._ninf_color, over=self._inf_color)

@property
def masked_color(self):
return self._masked_color

@dataclass
class ColorScheme:
gradient: ColorGradient = field(default_factory=ColorGradient)
masked_color: ColorTriplet = field(default_factory=lambda: COLORS["masked"])
true_color: ColorTriplet = field(default_factory=lambda: COLORS["true"])
false_color: ColorTriplet = field(default_factory=lambda: COLORS["false"])
inf_color: ColorTriplet = field(default_factory=lambda: COLORS["white"])
ninf_color: ColorTriplet = field(default_factory=lambda: COLORS["black"])
@masked_color.setter
def masked_color(self, value):
self._masked_color = value
self._colormap.set_bad(value)

@property
def inf_color(self):
return self._inf_color

@inf_color.setter
def inf_color(self, value):
self._inf_color = value
self._colormap.set_over(value)

@property
def ninf_color(self):
return self._ninf_color

def calculate_gradient_color_vectorized(
self, value_array: NDArray[np.number] | NDArray[bool]
) -> NDArray[np.uint8] | any:
"""
Calculate the gradient color for each value in the input array using vectorized interpolation.
@ninf_color.setter
def ninf_color(self, value):
self._ninf_color = value
self._colormap.set_under(value)

Args:
value_array (NDArray[np.float64]): The input array of values.
def __call__(self, data: np.ndarray) -> np.ndarray:
return self.colormap(self.normalize(data), bytes=True)

Returns:
Union[Any, NDArray[np.uint8]]: The calculated gradient color for each value in the input array.
"""
positions = [pos for pos, _ in self.gradient]
color_data = np.array([[color.red, color.green, color.blue] for _, color in self.gradient])
interp_functions = [interp1d(positions, channel) for channel in color_data.T]
flat_values = value_array.flatten()
interpolated_colors = np.stack([func(flat_values) for func in interp_functions], axis=0)
return interpolated_colors.reshape((3,) + value_array.shape).astype(np.uint8)
def __repr__(self):
return (
f"ColorScheme(\n"
f" colormap={self.colormap},\n"
f" normalize={self.normalize},\n"
f" masked_color={self._masked_color},\n"
f" true_color={self.true_color},\n"
f" false_color={self.false_color},\n"
f" inf_color={self._inf_color},\n"
f" ninf_color={self._ninf_color}\n"
f")"
)
Loading

0 comments on commit 766c3ff

Please sign in to comment.