diff --git a/.github/workflows/update-coverage-badge.yml b/.github/workflows/update-coverage-badge.yml index f69ca34..4829147 100644 --- a/.github/workflows/update-coverage-badge.yml +++ b/.github/workflows/update-coverage-badge.yml @@ -27,13 +27,13 @@ jobs: with: python-version: '3.x' - - name: "Install Dependencies" + - name: "Install dependencies" run: | python -m pip install --upgrade pip pip install -r requirements-dev.txt pip install -e . - - name: "Run Coverage and Generate Badge" + - name: "Run coverage and generate badge" run: | pytest . coverage report @@ -52,11 +52,22 @@ jobs: echo "svg_changed=true" >> $GITHUB_ENV fi - - name: "Commit and Push Changes" + - name: "Commit and push changes" if: env.svg_changed == 'true' run: | git config user.email "${{ github.run_id }}+github-actions[bot]@users.noreply.github.com" git config user.name "github-actions[bot]" + + if [[ "${{ github.event_name }}" == 'push' ]]; then + target_branch=$(echo "${{ github.ref }}" | awk -F'/' '{print $3}') + else + target_branch="${{ github.event.pull_request.head.ref }}" + fi + + git fetch origin "${target_branch}:${target_branch}" + git checkout "${target_branch}" || git checkout -b "${target_branch}" + git push --set-upstream origin "${target_branch}" + git add coverage-badge.svg git commit -m "gh-actions[bot]: update code coverage badge" git push diff --git a/README.md b/README.md index 98ffcef..89c1a6a 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ TensorHue is a Python library that allows you to visualize tensors right in your console, making understanding and debugging tensor contents easier. -You can use it with your favorite tensor processing libraries, such as PyTorch, JAX*, and TensorFlow*. +You can use it with your favorite tensor processing libraries, such as PyTorch, JAX, and TensorFlow*. _*coming soon_ TensorHue automagically detects which kind of tensor you are visualizing and adjusts accordingly: diff --git a/coverage-badge.svg b/coverage-badge.svg index e18725f..166ad97 100644 --- a/coverage-badge.svg +++ b/coverage-badge.svg @@ -1 +1 @@ -coverage: 89.68%coverage89.68% \ No newline at end of file +coverage: 88.95%coverage88.95% \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 4347553..bf680e5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,7 @@ pre-commit pylint torch +jax tox pytest pytest-cov diff --git a/tensorhue/__init__.py b/tensorhue/__init__.py index 105ef07..3b9d1a9 100644 --- a/tensorhue/__init__.py +++ b/tensorhue/__init__.py @@ -1,10 +1,11 @@ import sys -from rich.console import Console +import inspect import tensorhue._numpy as np from tensorhue.colors import COLORS, ColorScheme from tensorhue._print_opts import PRINT_OPTS, set_printoptions from tensorhue._numpy import NumpyArrayWrapper from tensorhue._torch import _tensorhue_to_numpy_torch +from tensorhue._jax import _tensorhue_to_numpy_jax from tensorhue.eastereggs import pride from tensorhue.viz import viz, _viz @@ -19,3 +20,16 @@ torch = sys.modules["torch"] setattr(torch.Tensor, "viz", _viz) setattr(torch.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_torch) +if "jax" in sys.modules: + jax = sys.modules["jax"] + setattr(jax.Array, "viz", _viz) + setattr(jax.Array, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax) + jaxlib = sys.modules["jaxlib"] + if "DeviceArrayBase" in {x[0] for x in inspect.getmembers(jaxlib.xla_extension)}: # jax < 0.4.X + setattr(jaxlib.xla_extension.DeviceArrayBase, "viz", _viz) + setattr(jaxlib.xla_extension.DeviceArrayBase, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax) + if "ArrayImpl" in { + x[0] for x in inspect.getmembers(jaxlib.xla_extension) + }: # jax >= 0.4.X (not sure about the exact version this changed) + setattr(jaxlib.xla_extension.ArrayImpl, "viz", _viz) + setattr(jaxlib.xla_extension.ArrayImpl, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax) diff --git a/tensorhue/_jax.py b/tensorhue/_jax.py new file mode 100644 index 0000000..493bd8c --- /dev/null +++ b/tensorhue/_jax.py @@ -0,0 +1,24 @@ +import numpy as np + + +def _tensorhue_to_numpy_jax(tensor) -> np.ndarray: + not_implemented = {"ShapedArray", "UnshapedArray", "AbstractArray"} + if {c.__name__ for c in tensor.__class__.__mro__}.intersection( + not_implemented + ): # hacky - but we shouldn't import jax here + raise NotImplementedError( + f"Jax arrays of type {tensor.__class__.__name__} cannot be visualized. Raise an issue if you believe this is wrong." + ) + try: + array = np.asarray(tensor) + if array.dtype == "object": + raise RuntimeError("Got non-visualizable dtype 'object'.") + return array + except RuntimeError as e: + raise NotImplementedError( + f"{e}: It looks like JAX arrays of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them." + ) from e + except Exception as e: + raise RuntimeError( + f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}" + ) from e diff --git a/tests/test__jax.py b/tests/test__jax.py new file mode 100644 index 0000000..257c1bf --- /dev/null +++ b/tests/test__jax.py @@ -0,0 +1,50 @@ +import pytest +import jax.numpy as jnp +from jax import core +import numpy as np +from tensorhue._jax import _tensorhue_to_numpy_jax + + +class NonConvertibleTensor: + pass + + +def test_jax_device_array(): + data = [[1, 2], [3, 4]] + device_array = jnp.array(data) + assert np.array_equal(_tensorhue_to_numpy_jax(device_array), np.array(data)) + + +def test_tensor_dtypes(): + dtypes = { + jnp.float32: "float32", + jnp.bfloat16: "bfloat16", + jnp.int32: "int32", + jnp.uint8: "uint8", + bool: "bool", + jnp.complex64: "complex64", + } + jnp_array = jnp.array([0.0, 1.0, 2.0, jnp.nan, jnp.inf]) + for dtype_jnp, dtype_np in dtypes.items(): + jnp_casted = jnp_array.astype(dtype_jnp) + converted = _tensorhue_to_numpy_jax(jnp_casted) + assert np.array_equal( + converted.dtype, dtype_np + ), f"dtype mismatch in jax.numpy to numpy conversion: expected {dtype_np}, got {converted.dtype}" + + +def test_jax_incompatible_arrays(): + shape = (2, 2) + dtype = jnp.float32 + + shaped_array = core.ShapedArray(shape, dtype) + with pytest.raises(NotImplementedError) as exc_info: + _tensorhue_to_numpy_jax(shaped_array) + assert "cannot be visualized" in str(exc_info.value) + + +def test_runtime_error_for_non_convertible_tensor(): + non_convertible = NonConvertibleTensor() + with pytest.raises(NotImplementedError) as exc_info: + _tensorhue_to_numpy_jax(non_convertible) + assert "Got non-visualizable dtype 'object'." in str(exc_info.value)