Skip to content

Commit

Permalink
refactor connector layout and add tensorflow tests
Browse files Browse the repository at this point in the history
  • Loading branch information
epistoteles committed Jun 26, 2024
1 parent fc6f03b commit dc62c8e
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 14 deletions.
18 changes: 10 additions & 8 deletions tensorhue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import sys
import inspect
import tensorhue._numpy as np
from tensorhue.colors import COLORS, ColorScheme
from tensorhue._print_opts import PRINT_OPTS, set_printoptions
from tensorhue._numpy import NumpyArrayWrapper
from tensorhue._torch import _tensorhue_to_numpy_torch
from tensorhue._jax import _tensorhue_to_numpy_jax
from tensorhue._tensorflow import _tensorhue_to_numpy_tensorflow
from tensorhue.connectors.numpy import NumpyArrayWrapper
from tensorhue.connectors.torch import _tensorhue_to_numpy_torch
from tensorhue.connectors.jax import _tensorhue_to_numpy_jax
from tensorhue.connectors.tensorflow import _tensorhue_to_numpy_tensorflow
from tensorhue.eastereggs import pride
from tensorhue.viz import viz, _viz

Expand Down Expand Up @@ -35,6 +34,9 @@
setattr(jaxlib.xla_extension.ArrayImpl, "viz", _viz)
setattr(jaxlib.xla_extension.ArrayImpl, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
if "tensorflow" in sys.modules:
tf = sys.modules["tensorflow"]
setattr(tf.Tensor, "viz", _viz)
setattr(tf.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_tensorflow)
tensorflow = sys.modules["tensorflow"]
setattr(tensorflow.Tensor, "viz", _viz)
setattr(tensorflow.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_tensorflow)
composite_tensor = sys.modules["tensorflow.python.framework.composite_tensor"]
setattr(composite_tensor.CompositeTensor, "viz", _viz)
setattr(composite_tensor.CompositeTensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_tensorflow)
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def _tensorhue_to_numpy_tensorflow(tensor) -> np.ndarray:
)
return _tensorhue_to_numpy_tensorflow(tensor.to_tensor())
elif tensor.__class__.__name__ == "SparseTensor":
raise ValueError("Tensorflow SparseTensors are not yet supported.")
raise ValueError("Tensorflow SparseTensors are not yet supported by TensorHue.")
try:
return tensor.numpy()
except RuntimeError as e:
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion tensorhue/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from tensorhue.colors import ColorScheme
from tensorhue._print_opts import PRINT_OPTS
from tensorhue._numpy import NumpyArrayWrapper
from tensorhue.connectors.numpy import NumpyArrayWrapper


def viz(tensor, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion tests/test__jax.py → tests/test_connector_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp
from jax import core
import numpy as np
from tensorhue._jax import _tensorhue_to_numpy_jax
from tensorhue.connectors.jax import _tensorhue_to_numpy_jax


class NonConvertibleTensor:
Expand Down
2 changes: 1 addition & 1 deletion tests/test__numpy.py → tests/test_connector_numpy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from tensorhue._numpy import NumpyArrayWrapper
from tensorhue.connectors.numpy import NumpyArrayWrapper


def test_instantiation():
Expand Down
44 changes: 44 additions & 0 deletions tests/test_connector_tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import tensorflow as tf
import pytest
from tensorhue.connectors.tensorflow import _tensorhue_to_numpy_tensorflow


class NonConvertibleTensor:
def numpy(self):
raise RuntimeError("This tensor cannot be converted to numpy")


def test_tensor_dtypes():
dtypes = {
tf.float32: "float32",
tf.double: "float64",
tf.int32: "int32",
tf.int64: "int64",
tf.bool: "bool",
tf.complex128: "complex128",
}
tf_tensor = tf.constant([0.0, 1.0, 2.0, float("nan"), float("inf")])
for dtype_tf, dtype_np in dtypes.items():
tensor_casted = tf.cast(tf_tensor, dtype_tf)
converted = _tensorhue_to_numpy_tensorflow(tensor_casted)
assert np.array_equal(
converted.dtype, dtype_np
), f"dtype mismatch in torch to numpy conversion: expected {dtype_np}, got {converted.dtype}"


def test_runtime_error_for_non_convertible_tensor():
non_convertible = NonConvertibleTensor()
with pytest.raises(NotImplementedError) as exc_info:
_tensorhue_to_numpy_tensorflow(non_convertible)
assert "This tensor cannot be converted to numpy" in str(exc_info.value)


def test_unexpected_exception_for_other_errors():
class UnexpectedErrorTensor:
def numpy(self):
raise ValueError("Unexpected error")

with pytest.raises(RuntimeError) as exc_info:
_tensorhue_to_numpy_tensorflow(UnexpectedErrorTensor())
assert "Unexpected error" in str(exc_info.value)
2 changes: 1 addition & 1 deletion tests/test__torch.py → tests/test_connector_torch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import torch
import pytest
from tensorhue._torch import _tensorhue_to_numpy_torch
from tensorhue.connectors.torch import _tensorhue_to_numpy_torch


class NonConvertibleTensor(torch.Tensor):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import numpy as np
from tensorhue.viz import viz
from tensorhue._torch import _tensorhue_to_numpy_torch
from tensorhue.connectors.torch import _tensorhue_to_numpy_torch


@pytest.mark.parametrize("tensor", [np.ones(10), _tensorhue_to_numpy_torch(torch.ones(10))])
Expand Down

0 comments on commit dc62c8e

Please sign in to comment.