Skip to content

Commit

Permalink
Merge pull request #3 from epistoteles/feature-tensorflow-support
Browse files Browse the repository at this point in the history
Add tensorflow support
  • Loading branch information
epistoteles authored Jun 26, 2024
2 parents d7a3461 + 6bca55d commit 2fda6a6
Show file tree
Hide file tree
Showing 14 changed files with 86 additions and 13 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

TensorHue is a Python library that allows you to visualize tensors right in your console, making understanding and debugging tensor contents easier.

You can use it with your favorite tensor processing libraries, such as PyTorch, JAX, and TensorFlow*.
_*coming soon_
You can use it with your favorite tensor processing libraries, such as PyTorch, JAX, and TensorFlow.

TensorHue automagically detects which kind of tensor you are visualizing and adjusts accordingly:

Expand Down
2 changes: 1 addition & 1 deletion coverage-badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pre-commit
pylint
torch
tensorflow
jax
tox
pytest
Expand Down
15 changes: 11 additions & 4 deletions tensorhue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import sys
import inspect
import tensorhue._numpy as np
from tensorhue.colors import COLORS, ColorScheme
from tensorhue._print_opts import PRINT_OPTS, set_printoptions
from tensorhue._numpy import NumpyArrayWrapper
from tensorhue._torch import _tensorhue_to_numpy_torch
from tensorhue._jax import _tensorhue_to_numpy_jax
from tensorhue.connectors._numpy import NumpyArrayWrapper
from tensorhue.connectors._torch import _tensorhue_to_numpy_torch
from tensorhue.connectors._jax import _tensorhue_to_numpy_jax
from tensorhue.connectors._tensorflow import _tensorhue_to_numpy_tensorflow
from tensorhue.eastereggs import pride
from tensorhue.viz import viz, _viz

Expand Down Expand Up @@ -33,3 +33,10 @@
}: # jax >= 0.4.X (not sure about the exact version this changed)
setattr(jaxlib.xla_extension.ArrayImpl, "viz", _viz)
setattr(jaxlib.xla_extension.ArrayImpl, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
if "tensorflow" in sys.modules:
tensorflow = sys.modules["tensorflow"]
setattr(tensorflow.Tensor, "viz", _viz)
setattr(tensorflow.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_tensorflow)
composite_tensor = sys.modules["tensorflow.python.framework.composite_tensor"]
setattr(composite_tensor.CompositeTensor, "viz", _viz)
setattr(composite_tensor.CompositeTensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_tensorflow)
File renamed without changes.
File renamed without changes.
22 changes: 22 additions & 0 deletions tensorhue/connectors/_tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import warnings
import numpy as np


def _tensorhue_to_numpy_tensorflow(tensor) -> np.ndarray:
if tensor.__class__.__name__ == "RaggedTensor": # hacky - but we shouldn't import torch here
warnings.warn(
"Tensorflow RaggedTensors are currently converted to dense tensors by filling with the value 0. Values that are actually 0 and filled-in values will appear indistinguishable. This behavior will change in the future."
)
return _tensorhue_to_numpy_tensorflow(tensor.to_tensor())
if tensor.__class__.__name__ == "SparseTensor":
raise ValueError("Tensorflow SparseTensors are not yet supported by TensorHue.")
try: # pylint: disable=duplicate-code
return tensor.numpy()
except RuntimeError as e:
raise NotImplementedError(
f"{e}: It looks like tensors of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them."
) from e
except Exception as e:
raise RuntimeError(
f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}"
) from e
2 changes: 1 addition & 1 deletion tensorhue/_torch.py → tensorhue/connectors/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def _tensorhue_to_numpy_torch(tensor) -> np.ndarray:
if tensor.__class__.__name__ == "MaskedTensor": # hacky - but we shouldn't import torch here
return np.ma.masked_array(tensor.get_data(), ~tensor.get_mask())
try:
try: # pylint: disable=duplicate-code
return tensor.numpy()
except RuntimeError as e:
raise NotImplementedError(
Expand Down
2 changes: 1 addition & 1 deletion tensorhue/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from tensorhue.colors import ColorScheme
from tensorhue._print_opts import PRINT_OPTS
from tensorhue._numpy import NumpyArrayWrapper
from tensorhue.connectors._numpy import NumpyArrayWrapper


def viz(tensor, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion tests/test__jax.py → tests/test_connector_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp
from jax import core
import numpy as np
from tensorhue._jax import _tensorhue_to_numpy_jax
from tensorhue.connectors._jax import _tensorhue_to_numpy_jax


class NonConvertibleTensor:
Expand Down
2 changes: 1 addition & 1 deletion tests/test__numpy.py → tests/test_connector_numpy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from tensorhue._numpy import NumpyArrayWrapper
from tensorhue.connectors._numpy import NumpyArrayWrapper


def test_instantiation():
Expand Down
44 changes: 44 additions & 0 deletions tests/test_connector_tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import tensorflow as tf
import pytest
from tensorhue.connectors._tensorflow import _tensorhue_to_numpy_tensorflow


class NonConvertibleTensor:
def numpy(self):
raise RuntimeError("This tensor cannot be converted to numpy")


def test_tensor_dtypes():
dtypes = {
tf.float32: "float32",
tf.double: "float64",
tf.int32: "int32",
tf.int64: "int64",
tf.bool: "bool",
tf.complex128: "complex128",
}
tf_tensor = tf.constant([0.0, 1.0, 2.0, float("nan"), float("inf")])
for dtype_tf, dtype_np in dtypes.items():
tensor_casted = tf.cast(tf_tensor, dtype_tf)
converted = _tensorhue_to_numpy_tensorflow(tensor_casted)
assert np.array_equal(
converted.dtype, dtype_np
), f"dtype mismatch in torch to numpy conversion: expected {dtype_np}, got {converted.dtype}"


def test_runtime_error_for_non_convertible_tensor():
non_convertible = NonConvertibleTensor()
with pytest.raises(NotImplementedError) as exc_info:
_tensorhue_to_numpy_tensorflow(non_convertible)
assert "This tensor cannot be converted to numpy" in str(exc_info.value)


def test_unexpected_exception_for_other_errors():
class UnexpectedErrorTensor:
def numpy(self):
raise ValueError("Unexpected error")

with pytest.raises(RuntimeError) as exc_info:
_tensorhue_to_numpy_tensorflow(UnexpectedErrorTensor())
assert "Unexpected error" in str(exc_info.value)
2 changes: 1 addition & 1 deletion tests/test__torch.py → tests/test_connector_torch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import torch
import pytest
from tensorhue._torch import _tensorhue_to_numpy_torch
from tensorhue.connectors._torch import _tensorhue_to_numpy_torch


class NonConvertibleTensor(torch.Tensor):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import numpy as np
from tensorhue.viz import viz
from tensorhue._torch import _tensorhue_to_numpy_torch
from tensorhue.connectors._torch import _tensorhue_to_numpy_torch


@pytest.mark.parametrize("tensor", [np.ones(10), _tensorhue_to_numpy_torch(torch.ones(10))])
Expand Down

0 comments on commit 2fda6a6

Please sign in to comment.