-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from epistoteles/feature-tensorflow-support
Add tensorflow support
- Loading branch information
Showing
14 changed files
with
86 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
pre-commit | ||
pylint | ||
torch | ||
tensorflow | ||
jax | ||
tox | ||
pytest | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import warnings | ||
import numpy as np | ||
|
||
|
||
def _tensorhue_to_numpy_tensorflow(tensor) -> np.ndarray: | ||
if tensor.__class__.__name__ == "RaggedTensor": # hacky - but we shouldn't import torch here | ||
warnings.warn( | ||
"Tensorflow RaggedTensors are currently converted to dense tensors by filling with the value 0. Values that are actually 0 and filled-in values will appear indistinguishable. This behavior will change in the future." | ||
) | ||
return _tensorhue_to_numpy_tensorflow(tensor.to_tensor()) | ||
if tensor.__class__.__name__ == "SparseTensor": | ||
raise ValueError("Tensorflow SparseTensors are not yet supported by TensorHue.") | ||
try: # pylint: disable=duplicate-code | ||
return tensor.numpy() | ||
except RuntimeError as e: | ||
raise NotImplementedError( | ||
f"{e}: It looks like tensors of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them." | ||
) from e | ||
except Exception as e: | ||
raise RuntimeError( | ||
f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}" | ||
) from e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
import pytest | ||
from tensorhue.connectors._tensorflow import _tensorhue_to_numpy_tensorflow | ||
|
||
|
||
class NonConvertibleTensor: | ||
def numpy(self): | ||
raise RuntimeError("This tensor cannot be converted to numpy") | ||
|
||
|
||
def test_tensor_dtypes(): | ||
dtypes = { | ||
tf.float32: "float32", | ||
tf.double: "float64", | ||
tf.int32: "int32", | ||
tf.int64: "int64", | ||
tf.bool: "bool", | ||
tf.complex128: "complex128", | ||
} | ||
tf_tensor = tf.constant([0.0, 1.0, 2.0, float("nan"), float("inf")]) | ||
for dtype_tf, dtype_np in dtypes.items(): | ||
tensor_casted = tf.cast(tf_tensor, dtype_tf) | ||
converted = _tensorhue_to_numpy_tensorflow(tensor_casted) | ||
assert np.array_equal( | ||
converted.dtype, dtype_np | ||
), f"dtype mismatch in torch to numpy conversion: expected {dtype_np}, got {converted.dtype}" | ||
|
||
|
||
def test_runtime_error_for_non_convertible_tensor(): | ||
non_convertible = NonConvertibleTensor() | ||
with pytest.raises(NotImplementedError) as exc_info: | ||
_tensorhue_to_numpy_tensorflow(non_convertible) | ||
assert "This tensor cannot be converted to numpy" in str(exc_info.value) | ||
|
||
|
||
def test_unexpected_exception_for_other_errors(): | ||
class UnexpectedErrorTensor: | ||
def numpy(self): | ||
raise ValueError("Unexpected error") | ||
|
||
with pytest.raises(RuntimeError) as exc_info: | ||
_tensorhue_to_numpy_tensorflow(UnexpectedErrorTensor()) | ||
assert "Unexpected error" in str(exc_info.value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters