Skip to content

Commit

Permalink
remove t.viz() functionality and show deprecation warning instead
Browse files Browse the repository at this point in the history
  • Loading branch information
epistoteles committed Sep 9, 2024
1 parent 4f161e2 commit 9022cb6
Show file tree
Hide file tree
Showing 16 changed files with 212 additions and 209 deletions.
26 changes: 8 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
<img src="https://img.shields.io/badge/contributions-welcome-orange.svg">
</div>

> [!IMPORTANT]
> t.viz() has been deprecated. Please use tensorhue.viz(t) instead.
> [!WARNING]
> TensorHue is currently in alpha. Expect bugs. We appreciate any feedback!
Expand Down Expand Up @@ -44,22 +47,9 @@ That's it! You can now visualize any tensor by calling .viz() on it in your Pyth

```python
t = torch.rand(20,20)
t.viz() ✅
```

You can also visualize them like this:

```python
tensorhue.viz(t) ✅
```

Numpy arrays can only be visualized with `tensorhue.viz(...)` (because np.ndarray is immutable):

```python
np.array([1,2,3]).viz() ❌
tensorhue.viz(np.array([1,2,3])) ✅
```

## Images

Pillow images can be visualized in RGB using `.viz()`:
Expand All @@ -68,7 +58,7 @@ Pillow images can be visualized in RGB using `.viz()`:
from torchvision.datasets import CIFAR10
dataset = CIFAR10('.', dowload=True)
img = dataset[0][0]
img.viz() ✅
tensorhue.viz(img) ✅
```

<div align="center">
Expand All @@ -78,7 +68,7 @@ img.viz() ✅
By default, images get downscaled to the size of your terminal, but you can make them even smaller if you want:

```python
img.viz(max_size=(40,40)) ✅
tensorhue.viz(img, max_size=(40,40)) ✅
```

## Custom colors
Expand All @@ -92,7 +82,7 @@ from matplotlib import colormaps
cs = ColorScheme(colormap=colormaps['inferno'],
true_color=(255,255,255),
false_color=(0,0,0))
t.viz(cs)
tensorhue.viz(t, cs)
```

Alternatively, you can overwrite the default ColorScheme:
Expand All @@ -110,13 +100,13 @@ By default, TensorHue normalizes numerical values between 0 and 1 and then appli
from matplotlib.colors import CenteredNorm
cs = ColorScheme(colormap=colormaps['bwr'],
normalize=CenteredNorm(vcenter=0))
t.viz(cs)
tensorhue.viz(t, cs)
```

You can also specify the normalization range manually, for example when you want to visualize a confusion matrix where colors should be mapped to the range [0, 1], but the actual values of the tensor are in the range [0.12, 0.73]:

```
conf_matrix.viz(vmin=0, vmax=1, scale=3)
tensorhue.viz(conf_matrix, vmin=0, vmax=1, scale=3)
```

<div align="center">
Expand Down
39 changes: 17 additions & 22 deletions tensorhue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,41 @@
import inspect
from tensorhue.colors import COLORS, ColorScheme
from tensorhue._print_opts import PRINT_OPTS, set_printoptions
from tensorhue.connectors._numpy import NumpyArrayWrapper
from tensorhue.connectors._torch import _tensorhue_to_numpy_torch
from tensorhue.connectors._jax import _tensorhue_to_numpy_jax
from tensorhue.connectors._tensorflow import _tensorhue_to_numpy_tensorflow
from tensorhue.connectors._pillow import _tensorhue_to_numpy_pillow
from tensorhue.eastereggs import pride
from tensorhue.viz import viz, _viz, _viz_image


__version__ = "0.0.17" # single source of version truth
__version__ = "0.1.0" # single source of version truth

__all__ = ["set_printoptions", "viz", "pride"]

# automagically set up TensorHue
setattr(NumpyArrayWrapper, "viz", _viz)

# show deprecation warning for t.viz() syntax
# delete everything below this line after version 0.2.0


def _viz_is_deprecated():
raise DeprecationWarning("The tensor.viz() function has been deprecated. Please use tensorhue.viz(tensor) instead.")


if "torch" in sys.modules:
torch = sys.modules["torch"]
setattr(torch.Tensor, "viz", _viz)
setattr(torch.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_torch)
setattr(torch.Tensor, "viz", _viz_is_deprecated)
if "jax" in sys.modules:
jax = sys.modules["jax"]
setattr(jax.Array, "viz", _viz)
setattr(jax.Array, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
setattr(jax.Array, "viz", _viz_is_deprecated)
jaxlib = sys.modules["jaxlib"]
if "DeviceArrayBase" in {x[0] for x in inspect.getmembers(jaxlib.xla_extension)}: # jax < 0.4.X
setattr(jaxlib.xla_extension.DeviceArrayBase, "viz", _viz)
setattr(jaxlib.xla_extension.DeviceArrayBase, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
setattr(jaxlib.xla_extension.DeviceArrayBase, "viz", _viz_is_deprecated)
if "ArrayImpl" in {
x[0] for x in inspect.getmembers(jaxlib.xla_extension)
}: # jax >= 0.4.X (not sure about the exact version this changed)
setattr(jaxlib.xla_extension.ArrayImpl, "viz", _viz)
setattr(jaxlib.xla_extension.ArrayImpl, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
setattr(jaxlib.xla_extension.ArrayImpl, "viz", _viz_is_deprecated)
if "tensorflow" in sys.modules:
tensorflow = sys.modules["tensorflow"]
setattr(tensorflow.Tensor, "viz", _viz)
setattr(tensorflow.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_tensorflow)
setattr(tensorflow.Tensor, "viz", _viz_is_deprecated)
composite_tensor = sys.modules["tensorflow.python.framework.composite_tensor"]
setattr(composite_tensor.CompositeTensor, "viz", _viz)
setattr(composite_tensor.CompositeTensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_tensorflow)
setattr(composite_tensor.CompositeTensor, "viz", _viz_is_deprecated)
if "PIL" in sys.modules:
PIL = sys.modules["PIL"]
setattr(PIL.Image.Image, "viz", _viz_image)
setattr(PIL.Image.Image, "_tensorhue_to_numpy", _tensorhue_to_numpy_pillow)
setattr(PIL.Image.Image, "viz", _viz_is_deprecated)
Empty file removed tensorhue/connectors/__init__.py
Empty file.
24 changes: 0 additions & 24 deletions tensorhue/connectors/_jax.py

This file was deleted.

10 changes: 0 additions & 10 deletions tensorhue/connectors/_numpy.py

This file was deleted.

16 changes: 0 additions & 16 deletions tensorhue/connectors/_pillow.py

This file was deleted.

22 changes: 0 additions & 22 deletions tensorhue/connectors/_tensorflow.py

This file was deleted.

16 changes: 0 additions & 16 deletions tensorhue/connectors/_torch.py

This file was deleted.

117 changes: 117 additions & 0 deletions tensorhue/converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from __future__ import annotations
import warnings
import numpy as np


def tensor_to_numpy(tensor, **kwargs) -> np.ndarray:
"""
Converts a tensor of unknown type to a numpy array.
Args:
tensor: The tensor to be converted.
Returns:
The converted numpy array.
"""
mro_strings = mro_to_strings(tensor.__class__.__mro__)

if "numpy.ndarray" in mro_strings:
return tensor
elif "torch.Tensor" in mro_strings:
return _tensor_to_numpy_torch(tensor, **kwargs)
elif "tensorflow.python.types.core.Tensor" in mro_strings:
return _tensor_to_numpy_tensorflow(tensor, **kwargs)
elif "jaxlib.xla_extension.DeviceArray" in mro_strings:
return _tensor_to_numpy_jax(tensor, **kwargs)
elif "PIL.Image.Image" in mro_strings:
return _tensor_to_numpy_pillow(tensor, **kwargs)
else:
raise NotImplementedError(
f"Conversion of tensor of type {type(tensor)} is not supported. Please raise an issue of you think this is a bug or should be implemented."
)


def mro_to_strings(mro) -> list[str]:
"""
Converts the __mro__ of a class to a list of module.class_name strings.
Args:
mro: The __mro__ to be converted.
Returns:
The converted list of strings.
"""
return [f"{c.__module__}.{c.__name__}" for c in mro]


def _tensor_to_numpy_torch(tensor) -> np.ndarray:
if tensor.__class__.__name__ == "MaskedTensor": # hacky - but we shouldn't import torch here
return np.ma.masked_array(tensor.get_data(), ~tensor.get_mask())
try: # pylint: disable=duplicate-code
return tensor.numpy()
except RuntimeError as e:
raise NotImplementedError(
f"{e}: It looks like tensors of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them."
) from e
except Exception as e:
raise RuntimeError(
f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}"
) from e


def _tensor_to_numpy_tensorflow(tensor) -> np.ndarray:
if tensor.__class__.__name__ == "RaggedTensor": # hacky - but we shouldn't import torch here
warnings.warn(
"Tensorflow RaggedTensors are currently converted to dense tensors by filling with the value 0. Values that are actually 0 and filled-in values will appear indistinguishable. This behavior will change in the future."
)
return _tensor_to_numpy_tensorflow(tensor.to_tensor())
if tensor.__class__.__name__ == "SparseTensor":
raise ValueError("Tensorflow SparseTensors are not yet supported by TensorHue.")
try: # pylint: disable=duplicate-code
return tensor.numpy()
except RuntimeError as e:
raise NotImplementedError(
f"{e}: It looks like tensors of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them."
) from e
except Exception as e:
raise RuntimeError(
f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}"
) from e


def _tensor_to_numpy_jax(tensor) -> np.ndarray:
not_implemented = {"ShapedArray", "UnshapedArray", "AbstractArray"}
if {c.__name__ for c in tensor.__class__.__mro__}.intersection(
not_implemented
): # hacky - but we shouldn't import jax here
raise NotImplementedError(
f"Jax arrays of type {tensor.__class__.__name__} cannot be visualized. Raise an issue if you believe this is wrong."
)
try:
array = np.asarray(tensor)
if array.dtype == "object":
raise RuntimeError("Got non-visualizable dtype 'object'.")
return array
except RuntimeError as e:
raise NotImplementedError(
f"{e}: It looks like JAX arrays of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them."
) from e
except Exception as e:
raise RuntimeError(
f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}"
) from e


def _tensor_to_numpy_pillow(image, thumbnail, max_size) -> np.ndarray:
try:
image = image.convert("RGB")
except Exception as e:
raise ValueError("Could not convert image from mode '{mode}' to 'RGB'.") from e

if thumbnail:
image.thumbnail(max_size)

array = np.array(image)
assert array.dtype == "uint8"

return array
Loading

0 comments on commit 9022cb6

Please sign in to comment.