Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX functionality #1

Merged
merged 10 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions .github/workflows/update-coverage-badge.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ jobs:
with:
python-version: '3.x'

- name: "Install Dependencies"
- name: "Install dependencies"
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt
pip install -e .

- name: "Run Coverage and Generate Badge"
- name: "Run coverage and generate badge"
run: |
pytest .
coverage report
Expand All @@ -52,11 +52,22 @@ jobs:
echo "svg_changed=true" >> $GITHUB_ENV
fi

- name: "Commit and Push Changes"
- name: "Commit and push changes"
if: env.svg_changed == 'true'
run: |
git config user.email "${{ github.run_id }}+github-actions[bot]@users.noreply.github.com"
git config user.name "github-actions[bot]"

if [[ "${{ github.event_name }}" == 'push' ]]; then
target_branch=$(echo "${{ github.ref }}" | awk -F'/' '{print $3}')
else
target_branch="${{ github.event.pull_request.head.ref }}"
fi

git fetch origin "${target_branch}:${target_branch}"
git checkout "${target_branch}" || git checkout -b "${target_branch}"
git push --set-upstream origin "${target_branch}"

git add coverage-badge.svg
git commit -m "gh-actions[bot]: update code coverage badge"
git push
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

TensorHue is a Python library that allows you to visualize tensors right in your console, making understanding and debugging tensor contents easier.

You can use it with your favorite tensor processing libraries, such as PyTorch, JAX*, and TensorFlow*.
You can use it with your favorite tensor processing libraries, such as PyTorch, JAX, and TensorFlow*.
_*coming soon_

TensorHue automagically detects which kind of tensor you are visualizing and adjusts accordingly:
Expand Down
2 changes: 1 addition & 1 deletion coverage-badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pre-commit
pylint
torch
jax
tox
pytest
pytest-cov
Expand Down
16 changes: 15 additions & 1 deletion tensorhue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import sys
from rich.console import Console
import inspect
import tensorhue._numpy as np
from tensorhue.colors import COLORS, ColorScheme
from tensorhue._print_opts import PRINT_OPTS, set_printoptions
from tensorhue._numpy import NumpyArrayWrapper
from tensorhue._torch import _tensorhue_to_numpy_torch
from tensorhue._jax import _tensorhue_to_numpy_jax
from tensorhue.eastereggs import pride
from tensorhue.viz import viz, _viz

Expand All @@ -19,3 +20,16 @@
torch = sys.modules["torch"]
setattr(torch.Tensor, "viz", _viz)
setattr(torch.Tensor, "_tensorhue_to_numpy", _tensorhue_to_numpy_torch)
if "jax" in sys.modules:
jax = sys.modules["jax"]
setattr(jax.Array, "viz", _viz)
setattr(jax.Array, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
jaxlib = sys.modules["jaxlib"]
if "DeviceArrayBase" in {x[0] for x in inspect.getmembers(jaxlib.xla_extension)}: # jax < 0.4.X
setattr(jaxlib.xla_extension.DeviceArrayBase, "viz", _viz)
setattr(jaxlib.xla_extension.DeviceArrayBase, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
if "ArrayImpl" in {
x[0] for x in inspect.getmembers(jaxlib.xla_extension)
}: # jax >= 0.4.X (not sure about the exact version this changed)
setattr(jaxlib.xla_extension.ArrayImpl, "viz", _viz)
setattr(jaxlib.xla_extension.ArrayImpl, "_tensorhue_to_numpy", _tensorhue_to_numpy_jax)
24 changes: 24 additions & 0 deletions tensorhue/_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np


def _tensorhue_to_numpy_jax(tensor) -> np.ndarray:
not_implemented = {"ShapedArray", "UnshapedArray", "AbstractArray"}
if {c.__name__ for c in tensor.__class__.__mro__}.intersection(
not_implemented
): # hacky - but we shouldn't import jax here
raise NotImplementedError(
f"Jax arrays of type {tensor.__class__.__name__} cannot be visualized. Raise an issue if you believe this is wrong."
)
try:
array = np.asarray(tensor)
if array.dtype == "object":
raise RuntimeError("Got non-visualizable dtype 'object'.")
return array
except RuntimeError as e:
raise NotImplementedError(
f"{e}: It looks like JAX arrays of type {type(tensor)} cannot be converted to numpy arrays out-of-the-box. Raise an issue if you need to visualize them."
) from e
except Exception as e:
raise RuntimeError(
f"An unexpected error occurred while converting tensor of type {type(tensor)} to numpy array: {e}"
) from e
50 changes: 50 additions & 0 deletions tests/test__jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
import jax.numpy as jnp
from jax import core
import numpy as np
from tensorhue._jax import _tensorhue_to_numpy_jax


class NonConvertibleTensor:
pass


def test_jax_device_array():
data = [[1, 2], [3, 4]]
device_array = jnp.array(data)
assert np.array_equal(_tensorhue_to_numpy_jax(device_array), np.array(data))


def test_tensor_dtypes():
dtypes = {
jnp.float32: "float32",
jnp.bfloat16: "bfloat16",
jnp.int32: "int32",
jnp.uint8: "uint8",
bool: "bool",
jnp.complex64: "complex64",
}
jnp_array = jnp.array([0.0, 1.0, 2.0, jnp.nan, jnp.inf])
for dtype_jnp, dtype_np in dtypes.items():
jnp_casted = jnp_array.astype(dtype_jnp)
converted = _tensorhue_to_numpy_jax(jnp_casted)
assert np.array_equal(
converted.dtype, dtype_np
), f"dtype mismatch in jax.numpy to numpy conversion: expected {dtype_np}, got {converted.dtype}"


def test_jax_incompatible_arrays():
shape = (2, 2)
dtype = jnp.float32

shaped_array = core.ShapedArray(shape, dtype)
with pytest.raises(NotImplementedError) as exc_info:
_tensorhue_to_numpy_jax(shaped_array)
assert "cannot be visualized" in str(exc_info.value)


def test_runtime_error_for_non_convertible_tensor():
non_convertible = NonConvertibleTensor()
with pytest.raises(NotImplementedError) as exc_info:
_tensorhue_to_numpy_jax(non_convertible)
assert "Got non-visualizable dtype 'object'." in str(exc_info.value)