diff --git a/.github/workflows/update-coverage-badge.yml b/.github/workflows/update-coverage-badge.yml
index f69ca34..4829147 100644
--- a/.github/workflows/update-coverage-badge.yml
+++ b/.github/workflows/update-coverage-badge.yml
@@ -27,13 +27,13 @@ jobs:
with:
python-version: '3.x'
- - name: "Install Dependencies"
+ - name: "Install dependencies"
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
pip install -e .
- - name: "Run Coverage and Generate Badge"
+ - name: "Run coverage and generate badge"
run: |
pytest .
coverage report
@@ -52,11 +52,22 @@ jobs:
echo "svg_changed=true" >> $GITHUB_ENV
fi
- - name: "Commit and Push Changes"
+ - name: "Commit and push changes"
if: env.svg_changed == 'true'
run: |
git config user.email "${{ github.run_id }}+github-actions[bot]@users.noreply.github.com"
git config user.name "github-actions[bot]"
+
+ if [[ "${{ github.event_name }}" == 'push' ]]; then
+ target_branch=$(echo "${{ github.ref }}" | awk -F'/' '{print $3}')
+ else
+ target_branch="${{ github.event.pull_request.head.ref }}"
+ fi
+
+ git fetch origin "${target_branch}:${target_branch}"
+ git checkout "${target_branch}" || git checkout -b "${target_branch}"
+ git push --set-upstream origin "${target_branch}"
+
git add coverage-badge.svg
git commit -m "gh-actions[bot]: update code coverage badge"
git push
diff --git a/README.md b/README.md
index 98ffcef..89c1a6a 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
TensorHue is a Python library that allows you to visualize tensors right in your console, making understanding and debugging tensor contents easier.
-You can use it with your favorite tensor processing libraries, such as PyTorch, JAX*, and TensorFlow*.
+You can use it with your favorite tensor processing libraries, such as PyTorch, JAX, and TensorFlow*.
_*coming soon_
TensorHue automagically detects which kind of tensor you are visualizing and adjusts accordingly:
diff --git a/coverage-badge.svg b/coverage-badge.svg
index e18725f..166ad97 100644
--- a/coverage-badge.svg
+++ b/coverage-badge.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 4347553..bf680e5 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,6 +2,7 @@
pre-commit
pylint
torch
+jax
tox
pytest
pytest-cov
diff --git a/tensorhue/__init__.py b/tensorhue/__init__.py
index 105ef07..3b9d1a9 100644
--- a/tensorhue/__init__.py
+++ b/tensorhue/__init__.py
@@ -1,10 +1,11 @@
import sys
-from rich.console import Console
+import inspect
import tensorhue._numpy as np
from tensorhue.colors import COLORS, ColorScheme
from tensorhue._print_opts import PRINT_OPTS, set_printoptions
from tensorhue._numpy import NumpyArrayWrapper
from tensorhue._torch import _tensorhue_to_numpy_torch
+from tensorhue._jax import _tensorhue_to_numpy_jax
from tensorhue.eastereggs import pride
from tensorhue.viz import viz, _viz
@@ -19,3 +20,16 @@
torch = sys.modules["torch"]
setattr(torch.Tensor, "viz", _viz)
setattr(torch.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_torch)
+if "jax" in sys.modules:
+ jax = sys.modules["jax"]
+ setattr(jax.Array, "viz", _viz)
+ setattr(jax.Array, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
+ jaxlib = sys.modules["jaxlib"]
+ if "DeviceArrayBase" in {x[0] for x in inspect.getmembers(jaxlib.xla_extension)}: # jax < 0.4.X
+ setattr(jaxlib.xla_extension.DeviceArrayBase, "viz", _viz)
+ setattr(jaxlib.xla_extension.DeviceArrayBase, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
+ if "ArrayImpl" in {
+ x[0] for x in inspect.getmembers(jaxlib.xla_extension)
+ }: # jax >= 0.4.X (not sure about the exact version this changed)
+ setattr(jaxlib.xla_extension.ArrayImpl, "viz", _viz)
+ setattr(jaxlib.xla_extension.ArrayImpl, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
diff --git a/tensorhue/_jax.py b/tensorhue/_jax.py
new file mode 100644
index 0000000..493bd8c
--- /dev/null
+++ b/tensorhue/_jax.py
@@ -0,0 +1,24 @@
+import numpy as np
+
+
+def _tensorhue_to_numpy_jax(tensor) -> np.ndarray:
+ not_implemented = {"ShapedArray", "UnshapedArray", "AbstractArray"}
+ if {c.__name__ for c in tensor.__class__.__mro__}.intersection(
+ not_implemented
+ ): # hacky - but we shouldn't import jax here
+ raise NotImplementedError(
+ f"Jax arrays of type {tensor.__class__.__name__} cannot be visualized. Raise an issue if you believe this is wrong."
+ )
+ try:
+ array = np.asarray(tensor)
+ if array.dtype == "object":
+ raise RuntimeError("Got non-visualizable dtype 'object'.")
+ return array
+ except RuntimeError as e:
+ raise NotImplementedError(
+ f"{e}: It looks like JAX arrays of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them."
+ ) from e
+ except Exception as e:
+ raise RuntimeError(
+ f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}"
+ ) from e
diff --git a/tests/test__jax.py b/tests/test__jax.py
new file mode 100644
index 0000000..257c1bf
--- /dev/null
+++ b/tests/test__jax.py
@@ -0,0 +1,50 @@
+import pytest
+import jax.numpy as jnp
+from jax import core
+import numpy as np
+from tensorhue._jax import _tensorhue_to_numpy_jax
+
+
+class NonConvertibleTensor:
+ pass
+
+
+def test_jax_device_array():
+ data = [[1, 2], [3, 4]]
+ device_array = jnp.array(data)
+ assert np.array_equal(_tensorhue_to_numpy_jax(device_array), np.array(data))
+
+
+def test_tensor_dtypes():
+ dtypes = {
+ jnp.float32: "float32",
+ jnp.bfloat16: "bfloat16",
+ jnp.int32: "int32",
+ jnp.uint8: "uint8",
+ bool: "bool",
+ jnp.complex64: "complex64",
+ }
+ jnp_array = jnp.array([0.0, 1.0, 2.0, jnp.nan, jnp.inf])
+ for dtype_jnp, dtype_np in dtypes.items():
+ jnp_casted = jnp_array.astype(dtype_jnp)
+ converted = _tensorhue_to_numpy_jax(jnp_casted)
+ assert np.array_equal(
+ converted.dtype, dtype_np
+ ), f"dtype mismatch in jax.numpy to numpy conversion: expected {dtype_np}, got {converted.dtype}"
+
+
+def test_jax_incompatible_arrays():
+ shape = (2, 2)
+ dtype = jnp.float32
+
+ shaped_array = core.ShapedArray(shape, dtype)
+ with pytest.raises(NotImplementedError) as exc_info:
+ _tensorhue_to_numpy_jax(shaped_array)
+ assert "cannot be visualized" in str(exc_info.value)
+
+
+def test_runtime_error_for_non_convertible_tensor():
+ non_convertible = NonConvertibleTensor()
+ with pytest.raises(NotImplementedError) as exc_info:
+ _tensorhue_to_numpy_jax(non_convertible)
+ assert "Got non-visualizable dtype 'object'." in str(exc_info.value)