Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Breaking change: deprecate tensor.viz() #12

Merged
merged 5 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified .github/images.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified .github/tensor_types.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
32 changes: 14 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
<img src="https://img.shields.io/badge/contributions-welcome-orange.svg">
</div>

> [!IMPORTANT]
> t.viz() has been deprecated. Please use tensorhue.viz(t) instead.

> [!WARNING]
> TensorHue is currently in alpha. Expect bugs. We appreciate any feedback!

Expand All @@ -33,42 +36,35 @@ Install TensorHue with pip:
pip install tensorhue
```

Using TensorHue is easy, simply import TensorHue *after* importing the library of your choice:
Using TensorHue is easy, simply import TensorHue together with the library of your choice:

```python
import torch
import tensorhue
```

That's it! You can now visualize any tensor by calling .viz() on it in your Python console:
Or, alternatively:

```python
t = torch.rand(20,20)
t.viz() ✅
from tensorhue import viz
```

You can also visualize them like this:
That's it! You can now visualize any tensor by calling .viz() on it in your Python console:

```python
t = torch.rand(20,20)
tensorhue.viz(t) ✅
```

Numpy arrays can only be visualized with `tensorhue.viz(...)` (because np.ndarray is immutable):

```python
np.array([1,2,3]).viz() ❌
tensorhue.viz(np.array([1,2,3])) ✅
```

## Images

Pillow images can be visualized in RGB using `.viz()`:
Pillow images can be visualized in RGB and other color modes:

```python
from torchvision.datasets import CIFAR10
dataset = CIFAR10('.', dowload=True)
img = dataset[0][0]
img.viz() ✅
tensorhue.viz(img) ✅
```

<div align="center">
Expand All @@ -78,7 +74,7 @@ img.viz() ✅
By default, images get downscaled to the size of your terminal, but you can make them even smaller if you want:

```python
img.viz(max_size=(40,40)) ✅
tensorhue.viz(img, max_size=(40,40)) ✅
```

## Custom colors
Expand All @@ -92,7 +88,7 @@ from matplotlib import colormaps
cs = ColorScheme(colormap=colormaps['inferno'],
true_color=(255,255,255),
false_color=(0,0,0))
t.viz(cs)
tensorhue.viz(t, colorscheme=cs)
```

Alternatively, you can overwrite the default ColorScheme:
Expand All @@ -110,13 +106,13 @@ By default, TensorHue normalizes numerical values between 0 and 1 and then appli
from matplotlib.colors import CenteredNorm
cs = ColorScheme(colormap=colormaps['bwr'],
normalize=CenteredNorm(vcenter=0))
t.viz(cs)
tensorhue.viz(t, colorscheme=cs)
```

You can also specify the normalization range manually, for example when you want to visualize a confusion matrix where colors should be mapped to the range [0, 1], but the actual values of the tensor are in the range [0.12, 0.73]:

```
conf_matrix.viz(vmin=0, vmax=1, scale=3)
tensorhue.viz(conf_matrix, vmin=0, vmax=1, scale=3)
```

<div align="center">
Expand Down
2 changes: 1 addition & 1 deletion coverage-badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
39 changes: 17 additions & 22 deletions tensorhue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,41 @@
import inspect
from tensorhue.colors import COLORS, ColorScheme
from tensorhue._print_opts import PRINT_OPTS, set_printoptions
from tensorhue.connectors._numpy import NumpyArrayWrapper
from tensorhue.connectors._torch import _tensorhue_to_numpy_torch
from tensorhue.connectors._jax import _tensorhue_to_numpy_jax
from tensorhue.connectors._tensorflow import _tensorhue_to_numpy_tensorflow
from tensorhue.connectors._pillow import _tensorhue_to_numpy_pillow
from tensorhue.eastereggs import pride
from tensorhue.viz import viz, _viz, _viz_image


__version__ = "0.0.17" # single source of version truth
__version__ = "0.1.0" # single source of version truth

__all__ = ["set_printoptions", "viz", "pride"]

# automagically set up TensorHue
setattr(NumpyArrayWrapper, "viz", _viz)

# show deprecation warning for t.viz() usage
# delete everything below this line after version 0.2.0


def _viz_is_deprecated(self):
raise DeprecationWarning("The tensor.viz() function has been deprecated. Please use tensorhue.viz(tensor) instead.")


if "torch" in sys.modules:
torch = sys.modules["torch"]
setattr(torch.Tensor, "viz", _viz)
setattr(torch.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_torch)
setattr(torch.Tensor, "viz", _viz_is_deprecated)
if "jax" in sys.modules:
jax = sys.modules["jax"]
setattr(jax.Array, "viz", _viz)
setattr(jax.Array, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
setattr(jax.Array, "viz", _viz_is_deprecated)
jaxlib = sys.modules["jaxlib"]
if "DeviceArrayBase" in {x[0] for x in inspect.getmembers(jaxlib.xla_extension)}: # jax < 0.4.X
setattr(jaxlib.xla_extension.DeviceArrayBase, "viz", _viz)
setattr(jaxlib.xla_extension.DeviceArrayBase, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
setattr(jaxlib.xla_extension.DeviceArrayBase, "viz", _viz_is_deprecated)
if "ArrayImpl" in {
x[0] for x in inspect.getmembers(jaxlib.xla_extension)
}: # jax >= 0.4.X (not sure about the exact version this changed)
setattr(jaxlib.xla_extension.ArrayImpl, "viz", _viz)
setattr(jaxlib.xla_extension.ArrayImpl, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
setattr(jaxlib.xla_extension.ArrayImpl, "viz", _viz_is_deprecated)
if "tensorflow" in sys.modules:
tensorflow = sys.modules["tensorflow"]
setattr(tensorflow.Tensor, "viz", _viz)
setattr(tensorflow.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_tensorflow)
setattr(tensorflow.Tensor, "viz", _viz_is_deprecated)
composite_tensor = sys.modules["tensorflow.python.framework.composite_tensor"]
setattr(composite_tensor.CompositeTensor, "viz", _viz)
setattr(composite_tensor.CompositeTensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_tensorflow)
setattr(composite_tensor.CompositeTensor, "viz", _viz_is_deprecated)
if "PIL" in sys.modules:
PIL = sys.modules["PIL"]
setattr(PIL.Image.Image, "viz", _viz_image)
setattr(PIL.Image.Image, "_tensorhue_to_numpy", _tensorhue_to_numpy_pillow)
setattr(PIL.Image.Image, "viz", _viz_is_deprecated)
Empty file removed tensorhue/connectors/__init__.py
Empty file.
24 changes: 0 additions & 24 deletions tensorhue/connectors/_jax.py

This file was deleted.

10 changes: 0 additions & 10 deletions tensorhue/connectors/_numpy.py

This file was deleted.

16 changes: 0 additions & 16 deletions tensorhue/connectors/_pillow.py

This file was deleted.

22 changes: 0 additions & 22 deletions tensorhue/connectors/_tensorflow.py

This file was deleted.

16 changes: 0 additions & 16 deletions tensorhue/connectors/_torch.py

This file was deleted.

117 changes: 117 additions & 0 deletions tensorhue/converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from __future__ import annotations
import warnings
import numpy as np


def tensor_to_numpy(tensor, **kwargs) -> np.ndarray:
"""
Converts a tensor of unknown type to a numpy array.

Args:
tensor (Any): The tensor to be converted.
**kwargs: Additional keyword arguments that are passed to the underlying converter functions.

Returns:
The converted numpy array.
"""
mro_strings = mro_to_strings(tensor.__class__.__mro__)

if "numpy.ndarray" in mro_strings:
return tensor
if "torch.Tensor" in mro_strings:
return _tensor_to_numpy_torch(tensor, **kwargs)
if "tensorflow.python.types.core.Tensor" in mro_strings:
return _tensor_to_numpy_tensorflow(tensor, **kwargs)
if "jaxlib.xla_extension.DeviceArray" in mro_strings:
return _tensor_to_numpy_jax(tensor, **kwargs)
if "PIL.Image.Image" in mro_strings:
return _tensor_to_numpy_pillow(tensor, **kwargs)
raise NotImplementedError(
f"Conversion of tensor of type {type(tensor)} is not supported. Please raise an issue of you think this is a bug or should be implemented."
)


def mro_to_strings(mro) -> list[str]:
"""
Converts the __mro__ of a class to a list of module.class_name strings.

Args:
mro (tuple[type]): The __mro__ to be converted.

Returns:
The converted list of strings.
"""
return [f"{c.__module__}.{c.__name__}" for c in mro]


def _tensor_to_numpy_torch(tensor) -> np.ndarray:
if tensor.__class__.__name__ == "MaskedTensor": # hacky - but we shouldn't import torch here
return np.ma.masked_array(tensor.get_data(), ~tensor.get_mask())
try: # pylint: disable=duplicate-code
return tensor.numpy()
except RuntimeError as e:
raise NotImplementedError(
f"{e}: It looks like tensors of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them."
) from e
except Exception as e:
raise RuntimeError(
f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}"
) from e


def _tensor_to_numpy_tensorflow(tensor) -> np.ndarray:
if tensor.__class__.__name__ == "RaggedTensor": # hacky - but we shouldn't import torch here
warnings.warn(
"Tensorflow RaggedTensors are currently converted to dense tensors by filling with the value 0. Values that are actually 0 and filled-in values will appear indistinguishable. This behavior will change in the future."
)
return _tensor_to_numpy_tensorflow(tensor.to_tensor())
if tensor.__class__.__name__ == "SparseTensor":
raise ValueError("Tensorflow SparseTensors are not yet supported by TensorHue.")
try: # pylint: disable=duplicate-code
return tensor.numpy()
except RuntimeError as e:
raise NotImplementedError(
f"{e}: It looks like tensors of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them."
) from e
except Exception as e:
raise RuntimeError(
f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}"
) from e


def _tensor_to_numpy_jax(tensor) -> np.ndarray:
not_implemented = {"ShapedArray", "UnshapedArray", "AbstractArray"}
if {c.__name__ for c in tensor.__class__.__mro__}.intersection(
not_implemented
): # hacky - but we shouldn't import jax here
raise NotImplementedError(
f"Jax arrays of type {tensor.__class__.__name__} cannot be visualized. Raise an issue if you believe this is wrong."
)
try:
array = np.asarray(tensor)
if array.dtype == "object":
raise RuntimeError("Got non-visualizable dtype 'object'.")
return array
except RuntimeError as e:
raise NotImplementedError(
f"{e}: It looks like JAX arrays of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them."
) from e
except Exception as e:
raise RuntimeError(
f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}"
) from e


def _tensor_to_numpy_pillow(image, thumbnail, max_size) -> np.ndarray:
try:
image = image.convert("RGB")
except Exception as e:
raise ValueError("Could not convert image from mode '{mode}' to 'RGB'.") from e

if thumbnail:
image.thumbnail(max_size)

array = np.array(image)
assert array.dtype == "uint8"

return array
Loading