From d181696a29b177e2f347ed54c4f0d241b73568e7 Mon Sep 17 00:00:00 2001 From: Korbinian Koch Date: Tue, 18 Jun 2024 21:46:31 +0200 Subject: [PATCH 01/10] implement JAX functionality and tests --- tensorhue/__init__.py | 8 +++++++ tensorhue/_jax.py | 24 ++++++++++++++++++++ tests/test__jax.py | 51 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+) create mode 100644 tensorhue/_jax.py create mode 100644 tests/test__jax.py diff --git a/tensorhue/__init__.py b/tensorhue/__init__.py index 105ef07..2eeeb37 100644 --- a/tensorhue/__init__.py +++ b/tensorhue/__init__.py @@ -5,6 +5,7 @@ from tensorhue._print_opts import PRINT_OPTS, set_printoptions from tensorhue._numpy import NumpyArrayWrapper from tensorhue._torch import _tensorhue_to_numpy_torch +from tensorhue._jax import _tensorhue_to_numpy_jax from tensorhue.eastereggs import pride from tensorhue.viz import viz, _viz @@ -19,3 +20,10 @@ torch = sys.modules["torch"] setattr(torch.Tensor, "viz", _viz) setattr(torch.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_torch) +if "jax" in sys.modules: + jax = sys.modules["jax"] + setattr(jax.Array, "viz", _viz) + setattr(jax.Array, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax) + jaxlib = sys.modules["jaxlib"] + setattr(jaxlib.xla_extension.DeviceArrayBase, "viz", _viz) + setattr(jaxlib.xla_extension.DeviceArrayBase, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax) diff --git a/tensorhue/_jax.py b/tensorhue/_jax.py new file mode 100644 index 0000000..493bd8c --- /dev/null +++ b/tensorhue/_jax.py @@ -0,0 +1,24 @@ +import numpy as np + + +def _tensorhue_to_numpy_jax(tensor) -> np.ndarray: + not_implemented = {"ShapedArray", "UnshapedArray", "AbstractArray"} + if {c.__name__ for c in tensor.__class__.__mro__}.intersection( + not_implemented + ): # hacky - but we shouldn't import jax here + raise NotImplementedError( + f"Jax arrays of type {tensor.__class__.__name__} cannot be visualized. Raise an issue if you believe this is wrong." + ) + try: + array = np.asarray(tensor) + if array.dtype == "object": + raise RuntimeError("Got non-visualizable dtype 'object'.") + return array + except RuntimeError as e: + raise NotImplementedError( + f"{e}: It looks like JAX arrays of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them." + ) from e + except Exception as e: + raise RuntimeError( + f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}" + ) from e diff --git a/tests/test__jax.py b/tests/test__jax.py new file mode 100644 index 0000000..29eea82 --- /dev/null +++ b/tests/test__jax.py @@ -0,0 +1,51 @@ +import pytest +import jax.numpy as jnp +from jax import core +import numpy as np +import jax +from tensorhue._jax import _tensorhue_to_numpy_jax + + +class NonConvertibleTensor(jax.Array): + pass + + +def test_jax_device_array(): + data = [[1, 2], [3, 4]] + device_array = jnp.array(data) + assert np.array_equal(_tensorhue_to_numpy_jax(device_array), np.array(data)) + + +def test_tensor_dtypes(): + dtypes = { + jnp.float32: "float32", + jnp.bfloat16: "bfloat16", + jnp.int32: "int32", + jnp.uint8: "uint8", + bool: "bool", + jnp.complex64: "complex64", + } + jnp_array = jnp.array([0.0, 1.0, 2.0, jnp.nan, jnp.inf]) + for dtype_jnp, dtype_np in dtypes.items(): + jnp_casted = jnp_array.astype(dtype_jnp) + converted = _tensorhue_to_numpy_jax(jnp_casted) + assert np.array_equal( + converted.dtype, dtype_np + ), f"dtype mismatch in jax.numpy to numpy conversion: expected {dtype_np}, got {converted.dtype}" + + +def test_jax_incompatible_arrays(): + shape = (2, 2) + dtype = jnp.float32 + + shaped_array = core.ShapedArray(shape, dtype) + with pytest.raises(NotImplementedError) as exc_info: + _tensorhue_to_numpy_jax(shaped_array) + assert "cannot be visualized" in str(exc_info.value) + + +def test_runtime_error_for_non_convertible_tensor(): + non_convertible = NonConvertibleTensor() + with pytest.raises(NotImplementedError) as exc_info: + _tensorhue_to_numpy_jax(non_convertible) + assert "Got non-visualizable dtype 'object'." in str(exc_info.value) From 6cd813f2bfbc6237da346d6bd65c7a2f656e333a Mon Sep 17 00:00:00 2001 From: Korbinian Koch Date: Tue, 18 Jun 2024 21:48:15 +0200 Subject: [PATCH 02/10] group imports --- tests/test__jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test__jax.py b/tests/test__jax.py index 29eea82..2a54538 100644 --- a/tests/test__jax.py +++ b/tests/test__jax.py @@ -1,8 +1,8 @@ import pytest import jax.numpy as jnp from jax import core -import numpy as np import jax +import numpy as np from tensorhue._jax import _tensorhue_to_numpy_jax From 737bf26b56b0123132ceda9b8717c46292923d2f Mon Sep 17 00:00:00 2001 From: Korbinian Koch Date: Tue, 18 Jun 2024 21:51:25 +0200 Subject: [PATCH 03/10] update JAX in README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 98ffcef..89c1a6a 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ TensorHue is a Python library that allows you to visualize tensors right in your console, making understanding and debugging tensor contents easier. -You can use it with your favorite tensor processing libraries, such as PyTorch, JAX*, and TensorFlow*. +You can use it with your favorite tensor processing libraries, such as PyTorch, JAX, and TensorFlow*. _*coming soon_ TensorHue automagically detects which kind of tensor you are visualizing and adjusts accordingly: From 8b562baa69c8e697e3c327eb926793ab758b3bdd Mon Sep 17 00:00:00 2001 From: Korbinian Koch Date: Tue, 18 Jun 2024 22:08:49 +0200 Subject: [PATCH 04/10] uppdate requirements --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 4347553..bf680e5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,7 @@ pre-commit pylint torch +jax tox pytest pytest-cov From 577158a3819465bb6a4a2bc4db9174ba9ca48ed2 Mon Sep 17 00:00:00 2001 From: Korbinian Koch Date: Tue, 18 Jun 2024 23:35:57 +0200 Subject: [PATCH 05/10] handle newer jax version --- tensorhue/__init__.py | 12 +++++++++--- tests/test__jax.py | 4 ++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tensorhue/__init__.py b/tensorhue/__init__.py index 2eeeb37..3b9d1a9 100644 --- a/tensorhue/__init__.py +++ b/tensorhue/__init__.py @@ -1,5 +1,5 @@ import sys -from rich.console import Console +import inspect import tensorhue._numpy as np from tensorhue.colors import COLORS, ColorScheme from tensorhue._print_opts import PRINT_OPTS, set_printoptions @@ -25,5 +25,11 @@ setattr(jax.Array, "viz", _viz) setattr(jax.Array, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax) jaxlib = sys.modules["jaxlib"] - setattr(jaxlib.xla_extension.DeviceArrayBase, "viz", _viz) - setattr(jaxlib.xla_extension.DeviceArrayBase, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax) + if "DeviceArrayBase" in {x[0] for x in inspect.getmembers(jaxlib.xla_extension)}: # jax < 0.4.X + setattr(jaxlib.xla_extension.DeviceArrayBase, "viz", _viz) + setattr(jaxlib.xla_extension.DeviceArrayBase, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax) + if "ArrayImpl" in { + x[0] for x in inspect.getmembers(jaxlib.xla_extension) + }: # jax >= 0.4.X (not sure about the exact version this changed) + setattr(jaxlib.xla_extension.ArrayImpl, "viz", _viz) + setattr(jaxlib.xla_extension.ArrayImpl, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax) diff --git a/tests/test__jax.py b/tests/test__jax.py index 2a54538..04e975a 100644 --- a/tests/test__jax.py +++ b/tests/test__jax.py @@ -1,12 +1,12 @@ import pytest import jax.numpy as jnp from jax import core -import jax +import jaxlib import numpy as np from tensorhue._jax import _tensorhue_to_numpy_jax -class NonConvertibleTensor(jax.Array): +class NonConvertibleTensor: pass From 6b157c894ba2d9fddc34be6747398d29d4947ab3 Mon Sep 17 00:00:00 2001 From: Korbinian Koch Date: Tue, 18 Jun 2024 23:36:59 +0200 Subject: [PATCH 06/10] lint jax test --- tests/test__jax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test__jax.py b/tests/test__jax.py index 04e975a..257c1bf 100644 --- a/tests/test__jax.py +++ b/tests/test__jax.py @@ -1,7 +1,6 @@ import pytest import jax.numpy as jnp from jax import core -import jaxlib import numpy as np from tensorhue._jax import _tensorhue_to_numpy_jax From fbe5fa39177c5a1badcd380eac18b11c283b4996 Mon Sep 17 00:00:00 2001 From: Korbinian Koch Date: Tue, 18 Jun 2024 23:47:46 +0200 Subject: [PATCH 07/10] update coverage workflow --- .github/workflows/update-coverage-badge.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/update-coverage-badge.yml b/.github/workflows/update-coverage-badge.yml index f69ca34..e319ce6 100644 --- a/.github/workflows/update-coverage-badge.yml +++ b/.github/workflows/update-coverage-badge.yml @@ -57,6 +57,14 @@ jobs: run: | git config user.email "${{ github.run_id }}+github-actions[bot]@users.noreply.github.com" git config user.name "github-actions[bot]" + + if [[ "${{ github.event_name }}" == 'push' ]]; then + target_branch="main" + else + target_branch="${{ github.event.pull_request.head.ref }}" + fi + + git checkout "${target_branch}" git add coverage-badge.svg git commit -m "gh-actions[bot]: update code coverage badge" git push From b6046ff6217708aa2c15cd78efef747028514a3e Mon Sep 17 00:00:00 2001 From: Korbinian Koch Date: Tue, 18 Jun 2024 23:51:57 +0200 Subject: [PATCH 08/10] update coverage workflow --- .github/workflows/update-coverage-badge.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/update-coverage-badge.yml b/.github/workflows/update-coverage-badge.yml index e319ce6..ae60e8e 100644 --- a/.github/workflows/update-coverage-badge.yml +++ b/.github/workflows/update-coverage-badge.yml @@ -59,12 +59,14 @@ jobs: git config user.name "github-actions[bot]" if [[ "${{ github.event_name }}" == 'push' ]]; then - target_branch="main" + target_branch=$(echo "${{ github.ref }}" | awk -F'/' '{print $3}') else target_branch="${{ github.event.pull_request.head.ref }}" fi - git checkout "${target_branch}" + git fetch origin "${target_branch}:${target_branch}" + git checkout "${target_branch}" || git checkout -b "${target_branch}" + git add coverage-badge.svg git commit -m "gh-actions[bot]: update code coverage badge" git push From cc1ecf7fdaa23eae4974a3e98f069cc938d7da31 Mon Sep 17 00:00:00 2001 From: Korbinian Koch Date: Tue, 18 Jun 2024 23:56:10 +0200 Subject: [PATCH 09/10] update coverage workflow --- .github/workflows/update-coverage-badge.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/update-coverage-badge.yml b/.github/workflows/update-coverage-badge.yml index ae60e8e..4829147 100644 --- a/.github/workflows/update-coverage-badge.yml +++ b/.github/workflows/update-coverage-badge.yml @@ -27,13 +27,13 @@ jobs: with: python-version: '3.x' - - name: "Install Dependencies" + - name: "Install dependencies" run: | python -m pip install --upgrade pip pip install -r requirements-dev.txt pip install -e . - - name: "Run Coverage and Generate Badge" + - name: "Run coverage and generate badge" run: | pytest . coverage report @@ -52,7 +52,7 @@ jobs: echo "svg_changed=true" >> $GITHUB_ENV fi - - name: "Commit and Push Changes" + - name: "Commit and push changes" if: env.svg_changed == 'true' run: | git config user.email "${{ github.run_id }}+github-actions[bot]@users.noreply.github.com" @@ -66,6 +66,7 @@ jobs: git fetch origin "${target_branch}:${target_branch}" git checkout "${target_branch}" || git checkout -b "${target_branch}" + git push --set-upstream origin "${target_branch}" git add coverage-badge.svg git commit -m "gh-actions[bot]: update code coverage badge" From 39e6e88b47c19fb446acc7db632c5bd31e9c58e0 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <9572741515+github-actions[bot]@users.noreply.github.com> Date: Tue, 18 Jun 2024 21:58:35 +0000 Subject: [PATCH 10/10] gh-actions[bot]: update code coverage badge --- coverage-badge.svg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coverage-badge.svg b/coverage-badge.svg index e18725f..166ad97 100644 --- a/coverage-badge.svg +++ b/coverage-badge.svg @@ -1 +1 @@ -coverage: 89.68%coverage89.68% \ No newline at end of file +coverage: 88.95%coverage88.95% \ No newline at end of file