Skip to content

Commit

Permalink
Refactor array rendering, add type registries, add PyTorch renderer.
Browse files Browse the repository at this point in the history
This change significantly reworks how penzai.treescope renders custom types,
by addding a "type registry" of type-specific pretty printers, similar to e.g. the
IPython pretty printer. (This is implemented via a new handler step, and can be overridden
if needed.) It also introduces a mechanism for dynamic type-dependent setup logic, so that
new handlers can be added to the registry when a library is imported, without having to
eagerly import that library.

Additionally, it adds a new NDArrayAdapter system, and modifies the array visualization
functions to use these adapters. The adapters make it possible to add support for new
ndarray-like types, including np.ndarray, jax.Array, pz.nx.NamedArray, and torch.Tensor,
using a uniform interface. Types in the adapter registry can be automatically visualized
by the array autovisualizer and manually rendered via `pz.ts.render_array`.

Furthermore, it adds initial support for PyTorch tensors (via the NDArrayAdapter registry)
and PyTorch modules, making it possible to visualize them using treescope whenever torch
is imported (but doing nothing if torch is not installed). PyTorch tensors support automatic
visualization similar to JAX Arrays. PyTorch modules are dynamically inspected to build a
visualization. (Note that due to the object semantics of PyTorch modules, and the convention
of mutating the module state in __init__ or afterward, PyTorch module renderings are in
general not round-trippable.)

Other minor changes:
- Removes or adjusts JAX imports so that Treescope can be used without importing JAX or running JAX device computations.
- Moves around some tests to improve organization.

PiperOrigin-RevId: 653411395
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Jul 18, 2024
1 parent 64380f0 commit c8bf57e
Show file tree
Hide file tree
Showing 34 changed files with 3,532 additions and 1,773 deletions.
15 changes: 8 additions & 7 deletions penzai/core/_treescope_handlers/named_axes_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
import numpy as np
from penzai.core import named_axes
from penzai.core._treescope_handlers import struct_handler
from penzai.treescope import ndarray_summarization
from penzai.treescope import dtype_util
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
from penzai.treescope.foldable_representation import common_structures
from penzai.treescope.foldable_representation import common_styles
from penzai.treescope.foldable_representation import foldable_impl
from penzai.treescope.foldable_representation import part_interface
from penzai.treescope.handlers import builtin_structure_handler
from penzai.treescope.handlers.interop import jax_support


def named_array_and_contained_type_summary(
Expand Down Expand Up @@ -59,7 +60,7 @@ def named_array_and_contained_type_summary(

# Give a short summary for our named arrays.
summary_parts = []
summary_parts.append(ndarray_summarization.get_dtype_name(named_array.dtype))
summary_parts.append(dtype_util.get_dtype_name(named_array.dtype))
summary_parts.append("(")
for i, size in enumerate(named_array.positional_shape):
if i:
Expand All @@ -79,13 +80,13 @@ def named_array_and_contained_type_summary(
summary_parts.append(f"{name}:{size}")
summary_parts.append(")")

if inspect_device_data and ndarray_summarization.safe_to_summarize(
named_array.data_array
if (
inspect_device_data
and isinstance(named_array.data_array, jax.Array)
and jax_support.safe_to_summarize(named_array.data_array)
):
summary_parts.append(
ndarray_summarization.summarize_ndarray(
named_array.data_array, include_shape_and_dtype=False
)
jax_support.summarize_array_data(named_array.data_array)
)

return "".join(summary_parts), contained_type
Expand Down
4 changes: 2 additions & 2 deletions penzai/core/_treescope_handlers/shapecheck_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import numpy as np
from penzai.core import shapecheck
from penzai.treescope import dtype_util
from penzai.treescope import html_escaping
from penzai.treescope import ndarray_summarization
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
from penzai.treescope.foldable_representation import common_structures
Expand Down Expand Up @@ -83,7 +83,7 @@ def _arraystructure_summary(
if structure.dtype is np.generic:
summary_parts.append("any")
else:
summary_parts.append(ndarray_summarization.get_dtype_name(structure.dtype))
summary_parts.append(dtype_util.get_dtype_name(structure.dtype))
summary_parts.append("(")
for i, dim in enumerate(structure.shape):
if i:
Expand Down
31 changes: 18 additions & 13 deletions penzai/core/named_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,29 +900,35 @@ def order_like(
Args:
other: Another named array or named array view. Must have the same set of
named axes as this one. If this is a `NamedArrayView`, must also have
the same positional axes.
named axes as ``self``. If ``other`` is a `NamedArrayView`, ``other``
must also have the same number of positional axes.
Returns:
A new `NamedArray` or `NamedArrayView` that has the content of ``self``
but is possibly transposed to have the same PyTree structure as ``other``
(as long as the arrays have the same shape).
but is possibly transposed to have the axes appear in the same order as
``other`` in the data array. If the arrays have the same named and
positional shapes, the result will have the same PyTree structure as
``other``.
"""
self.check_valid()
other.check_valid()
if isinstance(other, NamedArray):
return self.order_as(*other.named_shape.keys())
elif isinstance(other, NamedArrayView):
if (
self.positional_shape != other.positional_shape
or self.named_shape != other.named_shape
):
if len(self.positional_shape) != len(other.positional_shape):
raise ValueError(
"Calling `order_like` with a NamedArrayView requires the two"
" arrays have the same positional and named shapes."
f" {self.positional_shape=}, {self.named_shape=},"
f" {other.positional_shape=}, {other.named_shape=}"
" arrays to have the same number of positional axes, but got"
f" positional shapes {self.positional_shape=},"
f" {other.positional_shape=}"
)
if set(self.named_shape.keys()) != set(other.named_shape.keys()):
raise ValueError(
"Calling `order_like` with a NamedArrayView requires the two"
" arrays to have the axis names, but got"
f" named shapes {self.named_shape=}, {other.named_shape=}"
)

self_view = self.as_namedarrayview()
new_to_old_data_axis = {}
for old_data_axis, new_data_axis in zip(
Expand All @@ -935,9 +941,8 @@ def order_like(
self_view.data_array,
[new_to_old_data_axis[i] for i in range(self_view.data_array.ndim)],
)
assert new_data_array.shape == other.data_shape
return NamedArrayView(
data_shape=other.data_shape,
data_shape=new_data_array.shape,
data_axis_for_logical_axis=other.data_axis_for_logical_axis,
data_axis_for_name=other.data_axis_for_name,
data_array=new_data_array,
Expand Down
1 change: 0 additions & 1 deletion penzai/pz/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
render_array,
text_on_color,
render_array_sharding,
render_sharded_shape,
)
from penzai.treescope.autovisualize import (
Autovisualizer,
Expand Down
2 changes: 2 additions & 0 deletions penzai/treescope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from . import handlers
from . import html_encapsulation
from . import html_escaping
from . import ndarray_adapters
from . import renderer
from . import repr_lib
from . import treescope_ipython
from . import type_registries
Loading

0 comments on commit c8bf57e

Please sign in to comment.