-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor connector layout and add tensorflow tests
- Loading branch information
1 parent
fc6f03b
commit dc62c8e
Showing
11 changed files
with
60 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
import pytest | ||
from tensorhue.connectors.tensorflow import _tensorhue_to_numpy_tensorflow | ||
|
||
|
||
class NonConvertibleTensor: | ||
def numpy(self): | ||
raise RuntimeError("This tensor cannot be converted to numpy") | ||
|
||
|
||
def test_tensor_dtypes(): | ||
dtypes = { | ||
tf.float32: "float32", | ||
tf.double: "float64", | ||
tf.int32: "int32", | ||
tf.int64: "int64", | ||
tf.bool: "bool", | ||
tf.complex128: "complex128", | ||
} | ||
tf_tensor = tf.constant([0.0, 1.0, 2.0, float("nan"), float("inf")]) | ||
for dtype_tf, dtype_np in dtypes.items(): | ||
tensor_casted = tf.cast(tf_tensor, dtype_tf) | ||
converted = _tensorhue_to_numpy_tensorflow(tensor_casted) | ||
assert np.array_equal( | ||
converted.dtype, dtype_np | ||
), f"dtype mismatch in torch to numpy conversion: expected {dtype_np}, got {converted.dtype}" | ||
|
||
|
||
def test_runtime_error_for_non_convertible_tensor(): | ||
non_convertible = NonConvertibleTensor() | ||
with pytest.raises(NotImplementedError) as exc_info: | ||
_tensorhue_to_numpy_tensorflow(non_convertible) | ||
assert "This tensor cannot be converted to numpy" in str(exc_info.value) | ||
|
||
|
||
def test_unexpected_exception_for_other_errors(): | ||
class UnexpectedErrorTensor: | ||
def numpy(self): | ||
raise ValueError("Unexpected error") | ||
|
||
with pytest.raises(RuntimeError) as exc_info: | ||
_tensorhue_to_numpy_tensorflow(UnexpectedErrorTensor()) | ||
assert "Unexpected error" in str(exc_info.value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters