-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update ColorScheme and supported libraries and lint everything
- Loading branch information
1 parent
2f9a6dd
commit 16db8d3
Showing
7 changed files
with
157 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import numpy as np | ||
from matplotlib.colors import LinearSegmentedColormap | ||
from rich.color_triplet import ColorTriplet | ||
from tensorhue.colors import ColorScheme | ||
from tensorhue.viz import viz | ||
|
||
|
||
def pride(): | ||
pride_colors = [ | ||
ColorTriplet(228, 3, 3), | ||
ColorTriplet(255, 140, 0), | ||
ColorTriplet(255, 237, 0), | ||
ColorTriplet(0, 128, 38), | ||
ColorTriplet(0, 76, 255), | ||
ColorTriplet(115, 41, 130), | ||
] | ||
pride_cm = LinearSegmentedColormap.from_list(colors=[c.normalized for c in pride_colors], name="pride") | ||
pride_cs = ColorScheme(colormap=pride_cm) | ||
arr = np.repeat(np.linspace(0, 1, 6).reshape(-1, 1), 10, axis=1) | ||
viz(arr, colorscheme=pride_cs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import numpy as np | ||
|
||
|
||
class NumpyArrayWrapper(np.ndarray): | ||
def __new__(cls, input_array): | ||
obj = np.asarray(input_array).view(cls) | ||
return obj | ||
|
||
def _tensorhue_to_numpy(self): | ||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import torch | ||
import numpy as np | ||
|
||
|
||
def _tensorhue_to_numpy_torch(tensor: torch.Tensor) -> np.ndarray: | ||
if isinstance(tensor, torch.masked.MaskedTensor): | ||
return np.ma.masked_array(tensor.get_data(), torch.logical_not(tensor.get_mask())) | ||
try: | ||
return tensor.numpy() | ||
except RuntimeError as e: | ||
raise NotImplementedError( | ||
f"{e}: It looks like tensors of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them." | ||
) from e | ||
except Exception as e: | ||
raise RuntimeError( | ||
f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}" | ||
) from e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from rich.console import Console | ||
import numpy as np | ||
from tensorhue.colors import ColorScheme | ||
from tensorhue._print_opts import PRINT_OPTS | ||
from tensorhue.numpy import NumpyArrayWrapper | ||
|
||
|
||
def viz(tensor, *args, **kwargs): | ||
if isinstance(tensor, np.ndarray): | ||
tensor = NumpyArrayWrapper(tensor) | ||
tensor.viz(*args, **kwargs) # pylint: disable=no-member | ||
else: | ||
try: | ||
tensor.viz(*args, **kwargs) | ||
except Exception as e: | ||
raise NotImplementedError( | ||
f"TensorHue does not support type {type(tensor)}. Raise an issue if you need to visualize them. Alternatively, check if you imported tensorhue *after* your other library." | ||
) from e | ||
|
||
|
||
def _viz(self, colorscheme: ColorScheme = None): | ||
""" | ||
Prints a tensor using colored Unicode art representation. | ||
Args: | ||
colorscheme (ColorScheme, optional): The color scheme to use. | ||
Defaults to None, which means the global default color scheme is used. | ||
""" | ||
if colorscheme is None: | ||
colorscheme = PRINT_OPTS.colorscheme | ||
|
||
self = self._tensorhue_to_numpy() | ||
shape = self.shape | ||
|
||
if len(shape) > 2: | ||
raise NotImplementedError( | ||
"Visualization for tensors with more than 2 dimensions is under development. Please slice them for now." | ||
) | ||
|
||
colors = colorscheme(self)[..., :3] | ||
|
||
result_lines = [""] | ||
for y in range(0, shape[0] - 1, 2): | ||
for x in range(shape[-1]): | ||
result_lines[ | ||
-1 | ||
] += f"[rgb({colors[y, x, 0]},{colors[y, x, 1]},{colors[y, x, 2]}) on rgb({colors[y+1, x, 0]},{colors[y+1, x, 1]},{colors[y+1, x, 2]})]▀[/]" | ||
result_lines.append("") | ||
|
||
if shape[0] % 2 == 1: | ||
for x in range(shape[1]): | ||
result_lines[-1] += f"[rgb({colors[-1, x, 0]},{colors[-1, x, 1]},{colors[-1, x, 2]})]▀[/]" | ||
|
||
c = Console(log_path=False, record=False) | ||
c.print("\n".join(result_lines)) |