From c8bf57ee705e8b456ca98a0ddb4857b4056a191f Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Wed, 17 Jul 2024 17:21:06 -0700 Subject: [PATCH] Refactor array rendering, add type registries, add PyTorch renderer. This change significantly reworks how penzai.treescope renders custom types, by adding a "type registry" of type-specific pretty printers, similar to e.g. the IPython pretty printer. (This is implemented via a new handler step, and can be overridden if needed.) It also introduces a mechanism for dynamic type-dependent setup logic, so that new handlers can be added to the registry when a library is imported, without having to eagerly import that library. Additionally, it adds a new NDArrayAdapter system, and modifies the array visualization functions to use these adapters. The adapters make it possible to add support for new ndarray-like types, including np.ndarray, jax.Array, pz.nx.NamedArray, and torch.Tensor, using a uniform interface. Types in the adapter registry can be automatically visualized by the array autovisualizer and manually rendered via `pz.ts.render_array`. Furthermore, it adds initial support for PyTorch tensors (via the NDArrayAdapter registry) and PyTorch modules, making it possible to visualize them using treescope whenever torch is imported (but doing nothing if torch is not installed). PyTorch tensors support automatic visualization similar to JAX Arrays. PyTorch modules are dynamically inspected to build a visualization. (Note that due to the object semantics of PyTorch modules, and the convention of mutating the module state in __init__ or afterward, PyTorch module renderings are in general not round-trippable.) Other minor changes: - Removes or adjusts JAX imports so that Treescope can be used without importing JAX or running JAX device computations. - Moves around some tests to improve organization. 
PiperOrigin-RevId: 653411395 --- .../named_axes_handlers.py | 15 +- .../shapecheck_handlers.py | 4 +- penzai/core/named_axes.py | 31 +- penzai/pz/ts.py | 1 - penzai/treescope/__init__.py | 2 + penzai/treescope/array_autovisualizer.py | 468 ++++----- penzai/treescope/arrayviz.py | 945 ++++++++++-------- penzai/treescope/canonical_aliases.py | 46 +- penzai/treescope/copypaste_fallback.py | 14 +- penzai/treescope/default_renderer.py | 39 +- penzai/treescope/dtype_util.py | 89 ++ .../common_structures.py | 3 + .../foldable_representation/foldable_impl.py | 23 +- .../handlers/builtin_structure_handler.py | 5 +- ...hod_handler.py => custom_type_handlers.py} | 40 +- .../handlers/generic_pytree_handler.py | 8 +- .../handlers/hardcoded_structure_handlers.py | 192 ---- .../treescope/handlers/interop/jax_support.py | 611 +++++++++++ .../handlers/interop/numpy_support.py | 319 ++++++ .../handlers/interop/penzai_core_support.py | 141 +++ .../handlers/interop/torch_support.py | 532 ++++++++++ penzai/treescope/handlers/ndarray_handler.py | 176 ---- .../handlers/shared_value_postprocessor.py | 21 +- penzai/treescope/ndarray_adapters.py | 295 ++++++ penzai/treescope/ndarray_summarization.py | 532 ---------- penzai/treescope/repr_lib.py | 47 + penzai/treescope/treescope_ipython.py | 13 +- penzai/treescope/type_registries.py | 190 ++++ .../canonical_aliases_test.py} | 9 +- tests/{ => treescope}/fixtures/__init__.py | 0 .../fixtures/treescope_examples_fixture.py | 25 + tests/treescope/ndarray_adapters_test.py | 285 ++++++ .../renderer_test.py} | 170 ++-- .../representation_test.py} | 14 +- 34 files changed, 3532 insertions(+), 1773 deletions(-) create mode 100644 penzai/treescope/dtype_util.py rename penzai/treescope/handlers/{extension_method_handler.py => custom_type_handlers.py} (61%) delete mode 100644 penzai/treescope/handlers/hardcoded_structure_handlers.py create mode 100644 penzai/treescope/handlers/interop/jax_support.py create mode 100644 
penzai/treescope/handlers/interop/numpy_support.py create mode 100644 penzai/treescope/handlers/interop/penzai_core_support.py create mode 100644 penzai/treescope/handlers/interop/torch_support.py delete mode 100644 penzai/treescope/handlers/ndarray_handler.py create mode 100644 penzai/treescope/ndarray_adapters.py delete mode 100644 penzai/treescope/ndarray_summarization.py create mode 100644 penzai/treescope/type_registries.py rename tests/{treescope_canonical_aliases_test.py => treescope/canonical_aliases_test.py} (98%) rename tests/{ => treescope}/fixtures/__init__.py (100%) rename tests/{ => treescope}/fixtures/treescope_examples_fixture.py (86%) create mode 100644 tests/treescope/ndarray_adapters_test.py rename tests/{treescope_renderer_test.py => treescope/renderer_test.py} (87%) rename tests/{treescope_representation_test.py => treescope/representation_test.py} (99%) diff --git a/penzai/core/_treescope_handlers/named_axes_handlers.py b/penzai/core/_treescope_handlers/named_axes_handlers.py index 916f8f7..11e9734 100644 --- a/penzai/core/_treescope_handlers/named_axes_handlers.py +++ b/penzai/core/_treescope_handlers/named_axes_handlers.py @@ -23,7 +23,7 @@ import numpy as np from penzai.core import named_axes from penzai.core._treescope_handlers import struct_handler -from penzai.treescope import ndarray_summarization +from penzai.treescope import dtype_util from penzai.treescope import renderer from penzai.treescope.foldable_representation import basic_parts from penzai.treescope.foldable_representation import common_structures @@ -31,6 +31,7 @@ from penzai.treescope.foldable_representation import foldable_impl from penzai.treescope.foldable_representation import part_interface from penzai.treescope.handlers import builtin_structure_handler +from penzai.treescope.handlers.interop import jax_support def named_array_and_contained_type_summary( @@ -59,7 +60,7 @@ def named_array_and_contained_type_summary( # Give a short summary for our named arrays. 
summary_parts = [] - summary_parts.append(ndarray_summarization.get_dtype_name(named_array.dtype)) + summary_parts.append(dtype_util.get_dtype_name(named_array.dtype)) summary_parts.append("(") for i, size in enumerate(named_array.positional_shape): if i: @@ -79,13 +80,13 @@ def named_array_and_contained_type_summary( summary_parts.append(f"{name}:{size}") summary_parts.append(")") - if inspect_device_data and ndarray_summarization.safe_to_summarize( - named_array.data_array + if ( + inspect_device_data + and isinstance(named_array.data_array, jax.Array) + and jax_support.safe_to_summarize(named_array.data_array) ): summary_parts.append( - ndarray_summarization.summarize_ndarray( - named_array.data_array, include_shape_and_dtype=False - ) + jax_support.summarize_array_data(named_array.data_array) ) return "".join(summary_parts), contained_type diff --git a/penzai/core/_treescope_handlers/shapecheck_handlers.py b/penzai/core/_treescope_handlers/shapecheck_handlers.py index 76dd7b3..fc2e340 100644 --- a/penzai/core/_treescope_handlers/shapecheck_handlers.py +++ b/penzai/core/_treescope_handlers/shapecheck_handlers.py @@ -21,8 +21,8 @@ import numpy as np from penzai.core import shapecheck +from penzai.treescope import dtype_util from penzai.treescope import html_escaping -from penzai.treescope import ndarray_summarization from penzai.treescope import renderer from penzai.treescope.foldable_representation import basic_parts from penzai.treescope.foldable_representation import common_structures @@ -83,7 +83,7 @@ def _arraystructure_summary( if structure.dtype is np.generic: summary_parts.append("any") else: - summary_parts.append(ndarray_summarization.get_dtype_name(structure.dtype)) + summary_parts.append(dtype_util.get_dtype_name(structure.dtype)) summary_parts.append("(") for i, dim in enumerate(structure.shape): if i: diff --git a/penzai/core/named_axes.py b/penzai/core/named_axes.py index 343d9d3..b921471 100644 --- a/penzai/core/named_axes.py +++ 
b/penzai/core/named_axes.py @@ -900,29 +900,35 @@ def order_like( Args: other: Another named array or named array view. Must have the same set of - named axes as this one. If this is a `NamedArrayView`, must also have - the same positional axes. + named axes as ``self``. If ``other`` is a `NamedArrayView`, ``other`` + must also have the same number of positional axes. Returns: A new `NamedArray` or `NamedArrayView` that has the content of ``self`` - but is possibly transposed to have the same PyTree structure as ``other`` - (as long as the arrays have the same shape). + but is possibly transposed to have the axes appear in the same order as + ``other`` in the data array. If the arrays have the same named and + positional shapes, the result will have the same PyTree structure as + ``other``. """ self.check_valid() other.check_valid() if isinstance(other, NamedArray): return self.order_as(*other.named_shape.keys()) elif isinstance(other, NamedArrayView): - if ( - self.positional_shape != other.positional_shape - or self.named_shape != other.named_shape - ): + if len(self.positional_shape) != len(other.positional_shape): raise ValueError( "Calling `order_like` with a NamedArrayView requires the two" - " arrays have the same positional and named shapes." 
- f" {self.positional_shape=}, {self.named_shape=}," - f" {other.positional_shape=}, {other.named_shape=}" + " arrays to have the same number of positional axes, but got" + f" positional shapes {self.positional_shape=}," + f" {other.positional_shape=}" ) + if set(self.named_shape.keys()) != set(other.named_shape.keys()): + raise ValueError( + "Calling `order_like` with a NamedArrayView requires the two" + " arrays to have the axis names, but got" + f" named shapes {self.named_shape=}, {other.named_shape=}" + ) + self_view = self.as_namedarrayview() new_to_old_data_axis = {} for old_data_axis, new_data_axis in zip( @@ -935,9 +941,8 @@ def order_like( self_view.data_array, [new_to_old_data_axis[i] for i in range(self_view.data_array.ndim)], ) - assert new_data_array.shape == other.data_shape return NamedArrayView( - data_shape=other.data_shape, + data_shape=new_data_array.shape, data_axis_for_logical_axis=other.data_axis_for_logical_axis, data_axis_for_name=other.data_axis_for_name, data_array=new_data_array, diff --git a/penzai/pz/ts.py b/penzai/pz/ts.py index 24936a0..ca65879 100644 --- a/penzai/pz/ts.py +++ b/penzai/pz/ts.py @@ -26,7 +26,6 @@ render_array, text_on_color, render_array_sharding, - render_sharded_shape, ) from penzai.treescope.autovisualize import ( Autovisualizer, diff --git a/penzai/treescope/__init__.py b/penzai/treescope/__init__.py index dde42cc..3e6c400 100644 --- a/penzai/treescope/__init__.py +++ b/penzai/treescope/__init__.py @@ -44,6 +44,8 @@ from . import handlers from . import html_encapsulation from . import html_escaping +from . import ndarray_adapters from . import renderer from . import repr_lib from . import treescope_ipython +from . 
import type_registries diff --git a/penzai/treescope/array_autovisualizer.py b/penzai/treescope/array_autovisualizer.py index 207d8d7..9bc094f 100644 --- a/penzai/treescope/array_autovisualizer.py +++ b/penzai/treescope/array_autovisualizer.py @@ -16,16 +16,15 @@ from __future__ import annotations import dataclasses +import sys from typing import Any, Callable, Collection -import jax -import jax.numpy as jnp import numpy as np -from penzai.core import named_axes -from penzai.core._treescope_handlers import named_axes_handlers from penzai.treescope import arrayviz from penzai.treescope import autovisualize -from penzai.treescope import ndarray_summarization +from penzai.treescope import dtype_util +from penzai.treescope import ndarray_adapters +from penzai.treescope import type_registries from penzai.treescope.foldable_representation import basic_parts from penzai.treescope.foldable_representation import common_structures from penzai.treescope.foldable_representation import common_styles @@ -33,11 +32,19 @@ from penzai.treescope.foldable_representation import part_interface -def _supported_dtype(dtype): - return ( - jnp.issubdtype(dtype, np.integer) - or jnp.issubdtype(dtype, np.floating) - or jnp.issubdtype(dtype, np.bool_) +PositionalAxisInfo = ndarray_adapters.PositionalAxisInfo +NamedPositionlessAxisInfo = ndarray_adapters.NamedPositionlessAxisInfo +NamedPositionalAxisInfo = ndarray_adapters.NamedPositionalAxisInfo +AxisInfo = ndarray_adapters.AxisInfo + +ArrayInRegistry = Any + + +def _supported_dtype(dtype: np.dtype | None): + return dtype is not None and ( + dtype_util.is_integer_dtype(dtype) + or dtype_util.is_floating_dtype(dtype) + or np.issubdtype(dtype, np.bool_) ) @@ -45,6 +52,9 @@ def _supported_dtype(dtype): class ArrayAutovisualizer: """An automatic visualizer for arrays. + ArrayAutovisualizer supports any array type registered with an NDArrayAdapter + in the global type registry. 
+ Attributes: maximum_size: Maximum numer of elements of an array to show. Arrays larger than this will be truncated along one or more axes. @@ -74,139 +84,147 @@ class ArrayAutovisualizer: include_repr_line_threshold: int = 5 token_lookup_fn: Callable[[int], str] | None = None - def _autovisualize_namedarray( + def _autovisualize_array( self, - named_array: named_axes.NamedArrayBase, + array: ArrayInRegistry, + adapter: ndarray_adapters.NDArrayAdapter, path: str | None, label: str, expand_state: part_interface.ExpandState, ) -> part_interface.RenderableTreePart: - """Visualizes a named array.""" - named_array = named_array.as_namedarrayview() + """Helper to visualize an array.""" + # Extract information about axis names, indices, and sizes. + array_axis_info = adapter.get_axis_info_for_array_data(array) - # Assign axes with a preference. + # Assign axes, using preferred axes if possible. row_axes = [] column_axes = [] - names_in_array = set(named_array.named_shape.keys()) - unassigned = set(names_in_array) - - for name in self.prefers_column: - if name in names_in_array: - column_axes.append(name) - unassigned.remove(name) - - for name in self.prefers_row: - if name in names_in_array: - row_axes.append(name) - unassigned.remove(name) - - # Infer remaining assignment. - shape_after_truncation = ndarray_summarization.compute_truncated_shape( - named_array.data_array.shape, - ndarray_summarization.infer_balanced_truncation( - named_array.data_array.shape, - maximum_size=self.maximum_size, - cutoff_size_per_axis=self.cutoff_size_per_axis, - minimum_edge_items=self.edge_items, - ), + for info in array_axis_info: + if isinstance(info, NamedPositionalAxisInfo | NamedPositionlessAxisInfo): + if info.axis_name in self.prefers_column: + column_axes.append(info.axis_name) + elif info.axis_name in self.prefers_row: + row_axes.append(info.axis_name) + + # Infer a good truncated shape for this array. 
+ edge_items_per_axis = arrayviz.infer_balanced_truncation( + tuple(info.size for info in array_axis_info), + maximum_size=self.maximum_size, + cutoff_size_per_axis=self.cutoff_size_per_axis, + minimum_edge_items=self.edge_items, ) + row_axes, column_axes = arrayviz.infer_rows_and_columns( - axis_sizes={ - **{ - name: shape_after_truncation[data_axis] - for name, data_axis in named_array.data_axis_for_name.items() - }, - **{ - i: shape_after_truncation[data_axis] - for i, data_axis in enumerate( - named_array.data_axis_for_logical_axis - ) - }, - }, - unassigned=( - list(unassigned) + list(range(len(named_array.positional_shape))) - ), + all_axes=array_axis_info, known_rows=row_axes, known_columns=column_axes, + edge_items_per_axis=edge_items_per_axis, + ) + + # Obtain truncated array and mask data from the adapter. + truncated_array_data, truncated_mask_data = ( + adapter.get_array_data_with_truncation( + array=array, + mask=None, + edge_items_per_axis=edge_items_per_axis, + ) ) # Maybe infer value labels from a tokenizer. 
if ( self.token_lookup_fn and not self.force_continuous - and np.issubdtype(named_array.dtype, np.integer) - and named_array.data_array.size < self.maximum_size + and dtype_util.is_integer_dtype(truncated_array_data.dtype) ): - tokens = np.unique(named_array.data_array.flatten()).tolist() + tokens = np.unique(truncated_array_data.flatten()).tolist() value_item_labels = { token: self.token_lookup_fn(token) for token in tokens } else: value_item_labels = None - array_rendering = arrayviz.render_array( - named_array, - columns=column_axes, - rows=row_axes, - truncate=True, - maximum_size=self.maximum_size, - cutoff_size_per_axis=self.cutoff_size_per_axis, - minimum_edge_items=self.edge_items, + array_rendering = arrayviz._render_pretruncated( # pylint: disable=protected-access + array_axis_info=array_axis_info, + row_infos=row_axes, + column_infos=column_axes, + slider_infos=(), + truncated_array_data=truncated_array_data, + truncated_mask_data=truncated_mask_data, + edge_items_per_axis=edge_items_per_axis, + continuous="auto", around_zero=self.around_zero, - continuous=True if self.force_continuous else "auto", + vmax=None, + vmin=None, + trim_outliers=True, + dynamic_colormap="auto", + colormap=None, + axis_item_labels=None, value_item_labels=value_item_labels, + axis_labels=None, ) rendering_parts = [array_rendering] + last_line_parts = [] # Render the sharding as well. 
- if ( - isinstance(named_array.data_array, jax.Array) - and hasattr(named_array.data_array, "sharding") - and not isinstance( - named_array.data_array.sharding, jax.sharding.SingleDeviceSharding - ) - ): - sharding = named_array.data_array.sharding - platform = next(iter(sharding.device_set)).platform - sharding_rendering = arrayviz.render_array_sharding( - named_array, columns=column_axes, rows=row_axes - ) - if sharding.is_fully_replicated: - sharding_summary_str = ( - f"Replicated across {len(sharding.device_set)}" - f" {platform.upper()} devices" + sharding_info = adapter.get_sharding_info_for_array_data(array) + if sharding_info: + num_devices = len(sharding_info.device_index_to_shard_slices) + if num_devices == 1: + [device_id] = sharding_info.device_index_to_shard_slices.keys() + last_line_parts.append( + f"| Device: {sharding_info.device_type} {device_id}" ) else: - sharding_summary_str = ( - f"Sharded across {len(sharding.device_set)}" - f" {platform.upper()} devices" - ) - rendering_parts.append( - common_structures.build_custom_foldable_tree_node( - label=common_styles.AbbreviationColor( - basic_parts.siblings( - basic_parts.Text(sharding_summary_str), - basic_parts.FoldCondition( - expanded=basic_parts.Text(":"), - collapsed=basic_parts.Text(" (click to expand)"), - ), - ) - ), - contents=basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren([sharding_rendering]), - ), + if sharding_info.fully_replicated: + sharding_summary_str = ( + "Replicated across" + f" {num_devices} {sharding_info.device_type} devices" ) - ) + else: + sharding_summary_str = ( + "Sharded across" + f" {num_devices} {sharding_info.device_type} devices" + ) + sharding_rendering = arrayviz.render_array_sharding( + array, + columns=[c.logical_key() for c in column_axes], + rows=[r.logical_key() for r in row_axes], + ) + rendering_parts.append( + common_structures.build_custom_foldable_tree_node( + label=common_styles.AbbreviationColor( + basic_parts.siblings( + 
basic_parts.Text(sharding_summary_str), + basic_parts.FoldCondition( + expanded=basic_parts.Text(":"), + collapsed=basic_parts.Text(" (click to expand)"), + ), + ) + ), + contents=basic_parts.FoldCondition( + expanded=basic_parts.IndentedChildren([sharding_rendering]), + ), + ) + ) # We render it with a path, but remove the copy path button. This will be # added back by the caller. + if last_line_parts: + last_line = basic_parts.siblings( + basic_parts.FoldCondition( + expanded=basic_parts.Text("".join(last_line_parts)), + ), + basic_parts.Text(">"), + ) + else: + last_line = basic_parts.Text(">") custom_rendering = common_structures.build_custom_foldable_tree_node( label=common_styles.AbbreviationColor(label), contents=basic_parts.siblings( basic_parts.FoldCondition( expanded=basic_parts.IndentedChildren.build(rendering_parts) ), - common_styles.AbbreviationColor(basic_parts.Text(">")), + common_styles.AbbreviationColor(last_line), ), path=path, expand_state=expand_state, @@ -217,153 +235,64 @@ def __call__( self, value: Any, path: str | None ) -> autovisualize.CustomTreescopeVisualization | None: """Implementation of an autovisualizer, visualizing arrays.""" - with jax.core.ensure_compile_time_eval(): - if isinstance(value, named_axes.NamedArray | named_axes.NamedArrayView): - try: - value.check_valid() - except ValueError: - return None - if not _supported_dtype(value.dtype): - return None - - if not isinstance(value.data_array, jax.Array): - return None - - if ( - isinstance(value.data_array, jax.core.Tracer) - or value.data_array.is_deleted() - ): - return None - - if value.data_array.size == 1: - # Don't visualize scalars. - return None - - def _placeholder() -> part_interface.RenderableTreePart: - # Quick summary of the array that doesn't require device - # computation. 
- summary, contained_type = ( - named_axes_handlers.named_array_and_contained_type_summary( - value, inspect_device_data=False - ) - ) - short_summary = ( - f"<{type(value).__name__} {summary} (wrapping {contained_type})>" - ) - return common_structures.fake_placeholder_foldable( - common_styles.DeferredPlaceholderStyle( - basic_parts.Text(short_summary) - ), - extra_newlines_guess=8, - ) + # Retrieve the adapter for this array, which we will use to construct + # the rendering. + adapter = type_registries.lookup_by_mro( + type_registries.NDARRAY_ADAPTER_REGISTRY, type(value) + ) + if adapter is not None: + if not adapter.should_autovisualize(value): + return None - def _thunk(placeholder) -> part_interface.RenderableTreePart: - # Full rendering of the array. - if isinstance(placeholder, part_interface.FoldableTreeNode): - expand_state = placeholder.get_expand_state() - else: - assert placeholder is None - expand_state = part_interface.ExpandState.WEAKLY_EXPANDED - summary, contained_type = ( - named_axes_handlers.named_array_and_contained_type_summary(value) - ) - label = common_styles.AbbreviationColor( - basic_parts.Text( - f"<{type(value).__name__} {summary} (wrapping" - f" {contained_type})" - ) - ) - named_array = value - return self._autovisualize_namedarray( - named_array, path, label, expand_state - ) + # This is an array we can visualize! + # Extract information about axis names, indices, and sizes. + array_axis_info = adapter.get_axis_info_for_array_data(value) + total_size = np.prod([ax.size for ax in array_axis_info]) + if total_size == 1: + # Don't visualize scalars. 
+ return None - return autovisualize.CustomTreescopeVisualization( - basic_parts.RenderableAndLineAnnotations( - renderable=foldable_impl.maybe_defer_rendering( - _thunk, _placeholder - ), - annotations=common_structures.build_copy_button(path), - ) + np_dtype = adapter.get_numpy_dtype(value) + if not _supported_dtype(np_dtype): + return None + + def _placeholder() -> part_interface.RenderableTreePart: + summary = adapter.get_array_summary(value, fast=True) + return common_structures.fake_placeholder_foldable( + common_styles.DeferredPlaceholderStyle( + basic_parts.Text(f"<{summary}>") + ), + extra_newlines_guess=8, ) - elif isinstance(value, (np.ndarray, jax.Array)): - if not ( - jnp.issubdtype(value.dtype, np.integer) - or jnp.issubdtype(value.dtype, np.floating) - ): - return None - - if value.size == 1: - # Don't visualize scalars. - return None - - if isinstance(value, np.ndarray): - contained_type = "np.ndarray" - elif isinstance(value, jax.Array) and not isinstance( - value, jax.core.Tracer - ): - contained_type = "jax.Array" - if value.is_deleted(): - return None + def _thunk(placeholder) -> part_interface.RenderableTreePart: + # Full rendering of the array. + if isinstance(placeholder, part_interface.FoldableTreeNode): + expand_state = placeholder.get_expand_state() else: - # Unsupported type - return None - - def _placeholder() -> part_interface.RenderableTreePart: - # Quick summary of the array that doesn't require device - # computation. - dtypestr = ndarray_summarization.get_dtype_name(value.dtype) - short_summary = ( - f"<{contained_type} {dtypestr}{repr(value.shape)} ... 
>" - ) - return common_structures.fake_placeholder_foldable( - common_styles.DeferredPlaceholderStyle( - basic_parts.Text(short_summary) - ), - extra_newlines_guess=8, - ) + assert placeholder is None + expand_state = part_interface.ExpandState.WEAKLY_EXPANDED + summary = adapter.get_array_summary(value, fast=False) + label = common_styles.AbbreviationColor(basic_parts.Text(f"<{summary}")) + return self._autovisualize_array( + value, adapter, path, label, expand_state + ) - def _thunk(placeholder) -> part_interface.RenderableTreePart: - # Full rendering of the array. - if isinstance(placeholder, part_interface.FoldableTreeNode): - expand_state = placeholder.get_expand_state() - else: - assert placeholder is None - expand_state = part_interface.ExpandState.WEAKLY_EXPANDED - value_repr = ndarray_summarization.faster_array_repr(value) - if "\n" not in value_repr and "..." not in value_repr: - if value_repr.startswith("<") and value_repr.endswith(">"): - label = basic_parts.Text(value_repr[:-1]) - else: - label = basic_parts.Text("<" + value_repr) - else: - label = basic_parts.Text( - "<" - + contained_type - + " " - + ndarray_summarization.summarize_ndarray(value) - ) - # Convert it to a named array so we can render it. 
- if isinstance(value, np.ndarray): - to_wrap = jax.device_put(value, jax.local_devices(backend="cpu")[0]) - else: - to_wrap = value - named_array = named_axes.wrap(to_wrap) - return self._autovisualize_namedarray( - named_array, path, label, expand_state + return autovisualize.CustomTreescopeVisualization( + basic_parts.RenderableAndLineAnnotations( + renderable=foldable_impl.maybe_defer_rendering( + _thunk, _placeholder + ), + annotations=common_structures.build_copy_button(path), ) + ) - return autovisualize.CustomTreescopeVisualization( - basic_parts.RenderableAndLineAnnotations( - renderable=foldable_impl.maybe_defer_rendering( - _thunk, _placeholder - ), - annotations=common_structures.build_copy_button(path), - ) - ) + # Not an array in the registry. But it might be a JAX sharding that we can + # visualize (if JAX is imported). + if "jax" in sys.modules: + import jax # pylint: disable=g-import-not-at-top - elif isinstance( + if isinstance( value, jax.sharding.PositionalSharding | jax.sharding.NamedSharding @@ -371,36 +300,57 @@ def _thunk(placeholder) -> part_interface.RenderableTreePart: ): raw_repr = repr(value) repr_oneline = " ".join(line.strip() for line in raw_repr.split("\n")) + if isinstance(value, jax.sharding.PositionalSharding): - shardvis = arrayviz.render_sharded_shape(value, value.shape) + sharding = value + fake_axis_info = [ + PositionalAxisInfo(i, size) + for i, size in enumerate(sharding.shape) + ] elif isinstance(value, jax.sharding.NamedSharding): + sharding = value # Named shardings still act on positional arrays, so show them for # the positional shape they require. 
- smallest_shape = [] + fake_sizes = [] for part in value.spec: if part is None: - smallest_shape.append(1) + fake_sizes.append(1) elif isinstance(part, str): - smallest_shape.append(value.mesh.shape[part]) + fake_sizes.append(value.mesh.shape[part]) else: - smallest_shape.append( - int(np.prod([value.mesh.shape[a] for a in part])) - ) - smallest_shape = tuple(smallest_shape) - shardvis = arrayviz.render_sharded_shape(value, smallest_shape) - elif isinstance(value, jax.sharding.Mesh): - shardvis = arrayviz.render_sharded_shape( - jax.sharding.NamedSharding( - value, jax.sharding.PartitionSpec(*value.axis_names) - ), - jax.eval_shape( - lambda x: named_axes.wrap(x, *value.axis_names), - value.device_ids, - ), - ) + size = int(np.prod([value.mesh.shape[a] for a in part])) + fake_sizes.append(size) + fake_axis_info = [ + PositionalAxisInfo(i, size) for i, size in enumerate(fake_sizes) + ] else: - assert False # impossible - + # Meshes are based on named axes. We build a temporary positional + # sharding for visualization, but keep track of name order. 
+ assert isinstance(value, jax.sharding.Mesh) + mesh = value + sharding = jax.sharding.NamedSharding( + value, jax.sharding.PartitionSpec(*mesh.axis_names) + ) + fake_axis_info = [ + NamedPositionlessAxisInfo(name, mesh.shape[name]) + for name in mesh.axis_names + ] + + fake_shape = tuple(ax.size for ax in fake_axis_info) + some_device = next(iter(sharding.device_set)) + device_index_map = sharding.devices_indices_map(fake_shape) + sharding_info = ndarray_adapters.ShardingInfo( + shard_shape=sharding.shard_shape(fake_shape), + device_index_to_shard_slices={ + d.id: v for d, v in device_index_map.items() + }, + device_type=some_device.platform.upper(), + fully_replicated=sharding.is_fully_replicated, + ) + shardvis = arrayviz.render_sharding_info( + array_axis_info=fake_axis_info, + sharding_info=sharding_info, + ) custom_rendering = common_structures.build_custom_foldable_tree_node( label=common_styles.AbbreviationColor( basic_parts.Text("<" + repr_oneline) diff --git a/penzai/treescope/arrayviz.py b/penzai/treescope/arrayviz.py index b963b43..a8e3709 100644 --- a/penzai/treescope/arrayviz.py +++ b/penzai/treescope/arrayviz.py @@ -24,25 +24,33 @@ import base64 import collections import dataclasses -import functools import io import itertools import json import os -from typing import Any, Literal, Mapping, Sequence +from typing import Any, Literal, Sequence -import jax -import jax.numpy as jnp import numpy as np -from penzai.core import named_axes from penzai.treescope import context +from penzai.treescope import dtype_util from penzai.treescope import figures from penzai.treescope import html_escaping -from penzai.treescope import ndarray_summarization +from penzai.treescope import ndarray_adapters +from penzai.treescope import type_registries from penzai.treescope.foldable_representation import basic_parts from penzai.treescope.foldable_representation import part_interface +AxisName = Any + +PositionalAxisInfo = ndarray_adapters.PositionalAxisInfo 
+NamedPositionlessAxisInfo = ndarray_adapters.NamedPositionlessAxisInfo +NamedPositionalAxisInfo = ndarray_adapters.NamedPositionalAxisInfo +AxisInfo = ndarray_adapters.AxisInfo + +ArrayInRegistry = Any + + def load_arrayvis_javascript() -> str: """Loads the contents of `arrayvis.js` from the Python package. @@ -129,8 +137,8 @@ def _html_setup() -> ( def _render_array_to_html( - array_data: np.ndarray | jax.Array, - valid_mask: np.ndarray | jax.Array, + array_data: np.ndarray, + valid_mask: np.ndarray, column_axes: Sequence[int], row_axes: Sequence[int], slider_axes: Sequence[int], @@ -265,63 +273,78 @@ def axis_spec_arg(i): def infer_rows_and_columns( - axis_sizes: dict[int | named_axes.AxisName, int], - unassigned: Sequence[int | named_axes.AxisName] | None = None, - known_rows: Sequence[int | named_axes.AxisName] = (), - known_columns: Sequence[int | named_axes.AxisName] = (), -) -> tuple[list[int | named_axes.AxisName], list[int | named_axes.AxisName]]: + all_axes: Sequence[AxisInfo], + known_rows: Sequence[AxisInfo] = (), + known_columns: Sequence[AxisInfo] = (), + edge_items_per_axis: tuple[int | None, ...] | None = None, +) -> tuple[list[AxisInfo], list[AxisInfo]]: """Infers an ordered assignment of axis indices or names to rows and columns. The unassigned axes are sorted by size and then assigned to rows and columns to try to balance the total number of elements along the row and column axes. - Curently uses a greedy algorithm with an adjustment to try to keep columns - longer than rows, except when there are exactly two axes and both are + This curently uses a greedy algorithm with an adjustment to try to keep + columns longer than rows, except when there are exactly two axes and both are positional, in which case it lays out axis 0 as the rows and axis 1 as the columns. + Axes with logical positions are sorted before axes with only names + (in reverse order, so that later axes are rendered on the inside). 
Axes with + names only appear afterward, with explicitly-assigned ones before unassigned + ones. + Args: - axis_sizes: Mapping from axis indices or names to their axis size. - unassigned: Sequence of unassigned axis indices or names. Inferred from the - axis_sizes if not provided. + all_axes: Sequence of axis infos in the array that should be assigned. known_rows: Sequence of axis indices or names that must map to rows. known_columns: Sequence of axis indices or names that must map to columns. + edge_items_per_axis: Optional edge items specification, determining + truncated size of each axis. Must match the ordering of `all_axes`. Returns: Tuple (rows, columns) of assignments, which consist of `known_rows` and `known_columns` followed by the remaining unassigned axes in a balanced order. """ - if unassigned is None: - unassigned = [ - key - for key in axis_sizes.keys() - if key not in known_rows and key not in known_columns - ] + if edge_items_per_axis is None: + edge_items_per_axis = (None,) * len(all_axes) + + if not known_rows and not known_columns and len(all_axes) == 2: + ax_a, ax_b = all_axes + if ( + isinstance(ax_a, PositionalAxisInfo) + and isinstance(ax_b, PositionalAxisInfo) + and {ax_a.axis_logical_index, ax_b.axis_logical_index} == {0, 1} + ): + # Two-dimensional positional array. Always do rows then columns. + if ax_a.axis_logical_index == 0: + return ([ax_a], [ax_b]) + else: + return ([ax_b], [ax_a]) - if ( - not known_rows - and not known_columns - and len(unassigned) == 2 - and set(unassigned) == {0, 1} - ): - # Two-dimensional positional array. Always do rows then columns. - return ([0], [1]) + truncated_sizes = { + ax: ax.size if edge_items is None else 2 * edge_items + 1 + for ax, edge_items in zip(all_axes, edge_items_per_axis) + } + unassigned = [ + ax for ax in all_axes if ax not in known_rows and ax not in known_columns + ] # Sort by size descending, so that we make the most important layout decisions # first. 
- unassigned = sorted(unassigned, key=lambda ax: -axis_sizes[ax]) + unassigned = sorted( + unassigned, key=lambda ax: (truncated_sizes[ax], ax.size), reverse=True + ) # Compute the total size every axis would have if we assigned them to the # same axis. - unassigned_size = np.prod([axis_sizes[ax] for ax in unassigned]) + unassigned_size = np.prod([truncated_sizes[ax] for ax in unassigned]) rows = list(known_rows) - row_size = np.prod([axis_sizes[ax] for ax in rows]) + row_size = np.prod([truncated_sizes[ax] for ax in rows]) columns = list(known_columns) - column_size = np.prod([axis_sizes[ax] for ax in columns]) + column_size = np.prod([truncated_sizes[ax] for ax in columns]) for ax in unassigned: - axis_size = axis_sizes[ax] + axis_size = truncated_sizes[ax] unassigned_size = unassigned_size // axis_size if row_size * axis_size > column_size * unassigned_size: # If we assign this to the row axis, we'll end up with a visualization @@ -338,9 +361,9 @@ def infer_rows_and_columns( # arbitrary. Re-order each so that they have positional then named axes, and # so that position axes are in reverse position order, and the explicitly # mentioned named axes are before the unassigned ones. 
- def ax_sort_key(ax: int | named_axes.AxisName): - if isinstance(ax, int): - return (0, -ax) + def ax_sort_key(ax: AxisInfo): + if isinstance(ax, PositionalAxisInfo | NamedPositionalAxisInfo): + return (0, -ax.axis_logical_index) elif ax in unassigned: return (2,) else: @@ -349,26 +372,25 @@ def ax_sort_key(ax: int | named_axes.AxisName): return sorted(rows, key=ax_sort_key), sorted(columns, key=ax_sort_key) -@functools.partial(jax.jit, static_argnames=("around_zero", "trim_outliers")) def _infer_vmin_vmax( - array: jnp.Array, - mask: jnp.Array, + array: np.ndarray, + mask: np.ndarray, vmin: float | None, vmax: float | None, around_zero: bool, trim_outliers: bool, -) -> tuple[float | jax.Array, float | jax.Array]: +) -> tuple[float, float]: """Infer reasonable lower and upper colormap bounds from an array.""" inferring_both_bounds = vmax is None and vmin is None - finite_mask = jnp.logical_and(jnp.isfinite(array), mask) + finite_mask = np.logical_and(np.isfinite(array), mask) if vmax is None: if around_zero: if vmin is not None: vmax = -vmin # pylint: disable=invalid-unary-operand-type else: - vmax = jnp.max(jnp.where(finite_mask, jnp.abs(array), 0)) + vmax = np.max(np.where(finite_mask, np.abs(array), 0)) else: - vmax = jnp.max(jnp.where(finite_mask, array, -np.inf)) + vmax = np.max(np.where(finite_mask, array, -np.inf)) assert vmax is not None @@ -376,47 +398,183 @@ def _infer_vmin_vmax( if around_zero: vmin = -vmax # pylint: disable=invalid-unary-operand-type else: - vmin = jnp.min(jnp.where(finite_mask, array, np.inf)) + vmin = np.min(np.where(finite_mask, array, np.inf)) if inferring_both_bounds and trim_outliers: if around_zero: center = 0 else: - center = jnp.nanmean(jnp.where(finite_mask, array, np.nan)) - center = jnp.where(jnp.isfinite(center), center, 0.0) + center = np.nanmean(np.where(finite_mask, array, np.nan)) + center = np.where(np.isfinite(center), center, 0.0) - second_moment = jnp.nanmean( - jnp.where(finite_mask, jnp.square(array - center), 
np.nan) + second_moment = np.nanmean( + np.where(finite_mask, np.square(array - center), np.nan) ) - sigma = jnp.where( - jnp.isfinite(second_moment), jnp.sqrt(second_moment), vmax - vmin + sigma = np.where( + np.isfinite(second_moment), np.sqrt(second_moment), vmax - vmin ) vmin_limit = center - 3 * sigma - vmin = jnp.maximum(vmin, vmin_limit) + vmin = np.maximum(vmin, vmin_limit) vmax_limit = center + 3 * sigma - vmax = jnp.minimum(vmax, vmax_limit) + vmax = np.minimum(vmax, vmax_limit) return vmin, vmax -@jax.jit def _infer_abs_min_max( - array: jnp.Array, mask: jnp.Array -) -> tuple[float | jax.Array, float | jax.Array]: + array: np.ndarray, mask: np.ndarray +) -> tuple[float, float]: """Infer smallest and largest absolute values in array.""" - finite_mask = jnp.logical_and(jnp.isfinite(array), mask) - absmin = jnp.min( - jnp.where( - jnp.logical_and(finite_mask, array != 0), jnp.abs(array), np.inf - ) + finite_mask = np.logical_and(np.isfinite(array), mask) + absmin = np.min( + np.where(np.logical_and(finite_mask, array != 0), np.abs(array), np.inf) ) - absmin = jnp.where(jnp.isinf(absmin), 0.0, absmin) - absmax = jnp.max(jnp.where(finite_mask, jnp.abs(array), -np.inf)) - absmax = jnp.where(jnp.isinf(absmax), 0.0, absmax) + absmin = np.where(np.isinf(absmin), 0.0, absmin) + absmax = np.max(np.where(finite_mask, np.abs(array), -np.inf)) + absmax = np.where(np.isinf(absmax), 0.0, absmax) return absmin, absmax +def infer_balanced_truncation( + shape: Sequence[int], + maximum_size: int, + cutoff_size_per_axis: int, + minimum_edge_items: int, + doubling_bonus: float = 10.0, +) -> tuple[int | None, ...]: + """Infers a balanced truncation from a shape. + + This function computes a set of truncation sizes for each axis of the array + such that it obeys the constraints about array and axis sizes, while also + keeping the relative proportions of the array consistent (e.g. we keep more + elements along axes that were originally longer). 
This means that the aspect + ratio of the truncated array will still resemble the aspect ratio of the + original array. + + To avoid very-unbalanced renderings and truncate longer axes more than short + ones, this function truncates based on the square-root of the axis size by + default. + + Args: + shape: The shape of the array we are truncating. + maximum_size: Maximum number of elements of an array to show. Arrays larger + than this will be truncated along one or more axes. + cutoff_size_per_axis: Maximum number of elements of each individual axis to + show without truncation. Any axis longer than this will be truncated, with + its visual size increasing logarithmically with the true axis size + beyond this point. + minimum_edge_items: How many values to keep along each axis for truncated + arrays. We may keep more than this up to the budget of maximum_size. + doubling_bonus: Number of elements to add to each axis each time it doubles + beyond `cutoff_size_per_axis`. Used to make longer axes appear visually + longer while still keeping them a reasonable size. + + Returns: + A tuple of edge sizes. Each element corresponds to an axis in `shape`, + and is either `None` (for no truncation) or an integer (corresponding to + the number of elements to keep at the beginning and at the end). + """ + shape_arr = np.array(list(shape)) + remaining_elements_to_divide = maximum_size + edge_items_per_axis = {} + # Order our shape from smallest to largest, since the smallest axes will + # require the least amount of truncation and will have the most stringent + # constraints.
+ sorted_axes = np.argsort(shape_arr) + sorted_shape = shape_arr[sorted_axes] + + # Figure out maximum sizes based on the cutoff + cutoff_adjusted_maximum_sizes = np.where( + sorted_shape <= cutoff_size_per_axis, + sorted_shape, + cutoff_size_per_axis + + doubling_bonus * np.log2(sorted_shape / cutoff_size_per_axis), + ) + + # Suppose we want to make a scaled version of the array with relative + # axis sizes + # s0, s1, s2, ... + # The total size is then + # size = (c * s0) * (c * s1) * (c * s2) * ... + # log(size) = ndim * log(c) + [ log s0 + log s1 + log s2 + ... ] + # If we have a known final size we want to reach, we can solve for c as + # c = exp( (log size - [ log s0 + log s1 + log s2 + ... ]) / ndim ) + axis_proportions = np.sqrt(sorted_shape) + log_axis_proportions = np.log(axis_proportions) + for i in range(len(sorted_axes)): + original_axis = sorted_axes[i] + size = shape_arr[original_axis] + # If we truncated this axis and every axis after it proportional to + # their weights, how small of an axis size would we need for this + # axis? + log_c = ( + np.log(remaining_elements_to_divide) - np.sum(log_axis_proportions[i:]) + ) / (len(shape) - i) + soft_limit_for_this_axis = np.exp(log_c + log_axis_proportions[i]) + cutoff_limit_for_this_axis = np.floor( + np.minimum( + soft_limit_for_this_axis, + cutoff_adjusted_maximum_sizes[i], + ) + ) + if size <= 2 * minimum_edge_items + 1 or size <= cutoff_limit_for_this_axis: + # If this axis is already smaller than the minimum size it would have + # after truncation, there's no reason to truncate it. + # But pretend we did, so that other axes still grow monotonically if + # their axis sizes increase. 
+ remaining_elements_to_divide = ( + remaining_elements_to_divide / soft_limit_for_this_axis + ) + edge_items_per_axis[original_axis] = None + elif cutoff_limit_for_this_axis < 2 * minimum_edge_items + 1: + # If this axis is big enough to truncate, but our naive target size is + # smaller than the minimum allowed truncation, we should truncate it + # to the minimum size allowed instead. + edge_items_per_axis[original_axis] = minimum_edge_items + remaining_elements_to_divide = remaining_elements_to_divide / ( + 2 * minimum_edge_items + 1 + ) + else: + # Otherwise, truncate it and all remaining axes based on our target + # truncations. + for j in range(i, len(sorted_axes)): + visual_size = np.floor( + np.minimum( + np.exp(log_c + log_axis_proportions[j]), + cutoff_adjusted_maximum_sizes[j], + ) + ) + edge_items_per_axis[sorted_axes[j]] = int(visual_size // 2) + break + + return tuple( + edge_items_per_axis[orig_axis] for orig_axis in range(len(shape)) + ) + + +def compute_truncated_shape( + shape: tuple[int, ...], + edge_items: tuple[int | None, ...], +) -> tuple[int, ...]: + """Computes the shape of a truncated array. + + This can be used to estimate the size of an array visualization after it has + been truncated by `infer_balanced_truncation`. + + Args: + shape: The original array shape. + edge_items: Number of edge items to keep along each axis. + + Returns: + The shape of the truncated array. + """ + return tuple( + orig if edge is None else 2 * edge + 1 + for orig, edge in zip(shape, edge_items) + ) + + @dataclasses.dataclass(frozen=True) class ArrayvizRendering(figures.RendersAsRootInIPython): """A rendering of an array with Arrayviz. 
@@ -535,23 +693,12 @@ def render_to_html( def render_array( - array: ( - named_axes.NamedArray - | named_axes.NamedArrayView - | np.ndarray - | jax.Array - ), + array: ArrayInRegistry, *, - columns: Sequence[named_axes.AxisName | int] = (), - rows: Sequence[named_axes.AxisName | int] = (), - sliders: Sequence[named_axes.AxisName | int] = (), - valid_mask: ( - named_axes.NamedArray - | named_axes.NamedArrayView - | np.ndarray - | jax.Array - | None - ) = None, + columns: Sequence[AxisName | int] = (), + rows: Sequence[AxisName | int] = (), + sliders: Sequence[AxisName | int] = (), + valid_mask: Any | None = None, continuous: bool | Literal["auto"] = "auto", around_zero: bool | Literal["auto"] = "auto", vmax: float | None = None, @@ -563,9 +710,9 @@ def render_array( maximum_size: int = 10_000, cutoff_size_per_axis: int = 512, minimum_edge_items: int = 5, - axis_item_labels: dict[named_axes.AxisName | int, list[str]] | None = None, + axis_item_labels: dict[AxisName | int, list[str]] | None = None, value_item_labels: dict[int, str] | None = None, - axis_labels: dict[named_axes.AxisName | int, str] | None = None, + axis_labels: dict[AxisName | int, str] | None = None, ) -> ArrayvizRendering: """Renders an array (positional or named) to a displayable HTML object. @@ -619,7 +766,8 @@ def render_array( colormap. Args: - array: The array to render. + array: The array to render. The type of this array must be registered in + the `type_registries.NDARRAY_ADAPTER_REGISTRY`. columns: Sequence of axis names or positional axis indices that should be placed on the x axis, from innermost to outermost. If not provided, inferred automatically. @@ -693,195 +841,186 @@ def render_array( An object which can be rendered in an IPython notebook, containing the HTML source of an arrayviz rendering. 
""" - if axis_item_labels is None: - axis_item_labels = {} - - if value_item_labels is None: - value_item_labels = {} - - if axis_labels is None: - axis_labels = {} - - # Step 1: Wrap as named arrays if needed, for consistency of the following - # steps. But keep them on the CPU if they were numpy arrays. - if not isinstance(array, named_axes.NamedArray | named_axes.NamedArrayView): - if not isinstance(array, jax.Array): - array = jax.device_put(array, jax.devices("cpu")[0]) - array = named_axes.wrap(array) + # Retrieve the adapter for this array, which we will use to construct + # the rendering. + adapter = type_registries.lookup_by_mro( + type_registries.NDARRAY_ADAPTER_REGISTRY, type(array) + ) + if adapter is None: + raise TypeError( + f"Cannot render array with unrecognized type {type(array)} (not found" + " in array adapter registry)" + ) - array.check_valid() + # Extract information about axis names, indices, and sizes. + array_axis_info = adapter.get_axis_info_for_array_data(array) - if valid_mask is not None: - if not isinstance( - valid_mask, named_axes.NamedArray | named_axes.NamedArrayView - ): - if not isinstance(valid_mask, jax.Array): - valid_mask = jax.device_put(valid_mask, jax.devices("cpu")[0]) - valid_mask = named_axes.wrap(valid_mask) - - valid_mask.check_valid() + data_axis_from_axis_info = { + info: axis for axis, info in enumerate(array_axis_info) + } + assert len(data_axis_from_axis_info) == len(array_axis_info) + + info_by_name_or_position = {} + for info in array_axis_info: + if isinstance(info, NamedPositionalAxisInfo): + info_by_name_or_position[info.axis_name] = info + info_by_name_or_position[info.axis_logical_index] = info + elif isinstance(info, PositionalAxisInfo): + info_by_name_or_position[info.axis_logical_index] = info + elif isinstance(info, NamedPositionlessAxisInfo): + info_by_name_or_position[info.axis_name] = info + else: + raise ValueError(f"Unrecognized axis info {type(info)}") - # Make sure they are 
broadcast-compatible, and add length-1 axes for any - # that are missing. - bad_names = set(valid_mask.named_shape.keys()) - set( - array.named_shape.keys() - ) - if bad_names: - raise ValueError( - "Valid mask must be broadcastable to the shape of `array`, but it" - f" had extra axis names {bad_names}" - ) + row_infos = [info_by_name_or_position[spec] for spec in rows] + column_infos = [info_by_name_or_position[spec] for spec in columns] + slider_infos = [info_by_name_or_position[spec] for spec in sliders] - vshape = valid_mask.positional_shape - ashape = array.positional_shape - if vshape != ashape[len(ashape) - len(vshape) :]: + unassigned_axes = set(array_axis_info) + seen_axes = set() + for axis_info in itertools.chain(row_infos, column_infos, slider_infos): + if axis_info in seen_axes: raise ValueError( - "Valid mask must be broadcastable to the shape of `array`, but its" - f" positional shape ({vshape}) was not a suffix of those of `array`" - f" ({ashape})" + f"Axis {axis_info} appeared multiple times in rows/columns/sliders" + " specifications. Each axis must be assigned to at most one" + " location." ) + seen_axes.add(axis_info) + unassigned_axes.remove(axis_info) - # Insert new axes. - new_names = set(array.named_shape.keys()) - set( - valid_mask.named_shape.keys() - ) - if new_names: - valid_mask = valid_mask[{name: None for name in new_names}] - - new_positional_axis_count = len(ashape) - len(vshape) - if new_positional_axis_count: - valid_mask = valid_mask[(None,) * new_positional_axis_count + (...,)] - - # Step 2: Extract a positionally-indexed array of data, and remember the - # mapping from the original axis names and indices to their new data axes. - # We try to avoid transposing the initial array if possible, so this won't - # necessarily match the display order. - # (Recall that integers are NOT valid names for a NamedArray, so there are - # no possibilities of conflict between original axis names and indices.) 
- tmp_names_for_positional_axes = [ - object() for _ in range(len(array.positional_shape)) - ] - - fully_named_array = array.tag( - *tmp_names_for_positional_axes - ).as_namedarrayview() - array_data = fully_named_array.data_array - - data_axis_from_tmp_axis = {} - tmp_axis_from_data_axis = {} - for name, data_axis in fully_named_array.data_axis_for_name.items(): - data_axis_from_tmp_axis[name] = data_axis - tmp_axis_from_data_axis[data_axis] = name - - data_axis_from_orig_axis = {} - for name in array.named_shape.keys(): - data_axis_from_orig_axis[name] = data_axis_from_tmp_axis[name] - for idx in range(len(array.positional_shape)): - data_axis_from_orig_axis[idx] = data_axis_from_tmp_axis[ - tmp_names_for_positional_axes[idx] - ] - - # Step 3: If the mask exists, extract its data in the same order, and add - # length-one axes for any axes that were missing. Otherwise, create a new - # mask array with only length-one axes. - if valid_mask is not None: - assert isinstance(valid_mask, named_axes.NamedArrayBase) - fully_named_mask = ( - valid_mask.tag(*tmp_names_for_positional_axes) - .order_as(*(tmp_axis_from_data_axis[i] for i in range(array_data.ndim))) - .as_namedarrayview() - ) - assert ( - fully_named_mask.data_axis_for_name - == fully_named_array.data_axis_for_name - ) - mask_data = fully_named_mask.data_array - else: - mask_data = np.ones([1] * array_data.ndim, dtype=bool) - - # Step 4: Truncate the array and valid masks if requested, and ensure that the - # mask has the same shape as the array. if truncate: - edge_items_per_axis = ndarray_summarization.infer_balanced_truncation( - array_data.shape, + # Infer a good truncated shape for this array. 
+ edge_items_per_axis = infer_balanced_truncation( + tuple(info.size for info in array_axis_info), maximum_size=maximum_size, cutoff_size_per_axis=cutoff_size_per_axis, minimum_edge_items=minimum_edge_items, ) - truncated_array_data, truncated_mask_data = ( - ndarray_summarization.truncate_array_and_mask( - array=array_data, - mask=mask_data, - edge_items_per_axis=edge_items_per_axis, - ) - ) else: - edge_items_per_axis = (None,) * array_data.ndim - truncated_array_data = array_data - truncated_mask_data = jnp.broadcast_to(mask_data, array_data.shape) - - # (Ensure they are fetched to the CPU to avoid device computation / sharding - # issues) - truncated_array_data, truncated_mask_data = jax.device_get( - (truncated_array_data, truncated_mask_data) + edge_items_per_axis = (None,) * len(array_axis_info) + + # Obtain truncated array and mask data from the adapter. + truncated_array_data, truncated_mask_data = ( + adapter.get_array_data_with_truncation( + array=array, + mask=valid_mask, + edge_items_per_axis=edge_items_per_axis, + ) ) - skip_start_indices = [ - edge_items if edge_items is not None else size - for edge_items, size in zip(edge_items_per_axis, array_data.shape) - ] - skip_end_indices = [ - size - edge_items if edge_items is not None else size - for edge_items, size in zip(edge_items_per_axis, array_data.shape) - ] - # Step 5: Figure out which axes to render as rows, columns, and sliders and - # in which order. We start with the explicitly-requested axes, then add more + # in which order. We start with the explicitly-requested axes, then add more # axes to the rows and columns until we've assigned all of them, trying to # balance rows and columns. 
- unassigned_axes = set(array.named_shape.keys()) | set( - range(len(array.positional_shape)) + row_infos, column_infos = infer_rows_and_columns( + all_axes=[ax for ax in array_axis_info if ax not in slider_infos], + known_rows=row_infos, + known_columns=column_infos, + edge_items_per_axis=edge_items_per_axis, ) - seen_axes = set() - rows = list(rows) - columns = list(columns) - sliders = list(sliders) - for axis in itertools.chain(rows, columns, sliders): - if axis in seen_axes: - raise ValueError( - f"Axis {repr(axis)} appeared multiple times in rows/columns/sliders" - " specifications. Each axis must be assigned to at most one" - " location." - ) - elif axis not in unassigned_axes: - raise ValueError( - f"Axis {repr(axis)} was assigned a location in rows/columns/sliders" - " but was not present in the array to render." - ) - seen_axes.add(axis) - unassigned_axes.remove(axis) - - rows, columns = infer_rows_and_columns( - unassigned=list(unassigned_axes), - known_rows=rows, - known_columns=columns, - axis_sizes={ - **{ - orig: truncated_array_data.shape[data_axis] - for orig, data_axis in data_axis_from_orig_axis.items() - }, - }, + return _render_pretruncated( + array_axis_info=array_axis_info, + row_infos=row_infos, + column_infos=column_infos, + slider_infos=slider_infos, + truncated_array_data=truncated_array_data, + truncated_mask_data=truncated_mask_data, + edge_items_per_axis=edge_items_per_axis, + continuous=continuous, + around_zero=around_zero, + vmax=vmax, + vmin=vmin, + trim_outliers=trim_outliers, + dynamic_colormap=dynamic_colormap, + colormap=colormap, + axis_item_labels=axis_item_labels, + value_item_labels=value_item_labels, + axis_labels=axis_labels, ) + +def _render_pretruncated( + *, + array_axis_info: Sequence[AxisInfo], + row_infos: Sequence[AxisInfo], + column_infos: Sequence[AxisInfo], + slider_infos: Sequence[AxisInfo], + truncated_array_data: np.ndarray, + truncated_mask_data: np.ndarray, + edge_items_per_axis: Sequence[int | None], + 
continuous: bool | Literal["auto"], + around_zero: bool | Literal["auto"], + vmax: float | None, + vmin: float | None, + trim_outliers: bool, + dynamic_colormap: bool | Literal["auto"], + colormap: list[tuple[int, int, int]] | None, + axis_item_labels: dict[AxisName | int, list[str]] | None, + value_item_labels: dict[int, str] | None, + axis_labels: dict[AxisName | int, str] | None, +) -> ArrayvizRendering: + """Internal helper to render an array that has already been truncated.""" + if axis_item_labels is None: + axis_item_labels = {} + + if value_item_labels is None: + value_item_labels = {} + + if axis_labels is None: + axis_labels = {} + + data_axis_from_axis_info = { + info: axis for axis, info in enumerate(array_axis_info) + } + assert len(data_axis_from_axis_info) == len(array_axis_info) + + has_name_only = False + positional_count = 0 + + info_by_name_or_position = {} + for info in array_axis_info: + if isinstance(info, NamedPositionalAxisInfo): + info_by_name_or_position[info.axis_name] = info + info_by_name_or_position[info.axis_logical_index] = info + positional_count += 1 + elif isinstance(info, PositionalAxisInfo): + info_by_name_or_position[info.axis_logical_index] = info + positional_count += 1 + elif isinstance(info, NamedPositionlessAxisInfo): + info_by_name_or_position[info.axis_name] = info + has_name_only = True + else: + raise ValueError(f"Unrecognized axis info {type(info)}") + + axis_labels_by_info = { + info_by_name_or_position[orig_key]: value + for orig_key, value in axis_labels.items() + } + axis_item_labels_by_info = { + info_by_name_or_position[orig_key]: value + for orig_key, value in axis_item_labels.items() + } + + skip_start_indices = [ + edge_items if edge_items is not None else axis_info.size + for edge_items, axis_info in zip(edge_items_per_axis, array_axis_info) + ] + skip_end_indices = [ + axis_info.size - edge_items if edge_items is not None else axis_info.size + for edge_items, axis_info in zip(edge_items_per_axis, 
array_axis_info) + ] + # Convert the axis names into indices into our data array. column_data_axes = [ - data_axis_from_orig_axis[orig_axis] for orig_axis in columns + data_axis_from_axis_info[orig_axis] for orig_axis in column_infos + ] + row_data_axes = [ + data_axis_from_axis_info[orig_axis] for orig_axis in row_infos ] - row_data_axes = [data_axis_from_orig_axis[orig_axis] for orig_axis in rows] slider_data_axes = [ - data_axis_from_orig_axis[orig_axis] for orig_axis in sliders + data_axis_from_axis_info[orig_axis] for orig_axis in slider_infos ] # Step 6: Figure out how to render the labels and indices of each axis. @@ -893,19 +1032,22 @@ def render_array( axis_label_instructions = [] - if array.named_shape: + if has_name_only: formatting_instructions.append({"type": "literal", "value": "[{"}) - for i, (name, size) in enumerate(array.named_shape.items()): - data_axis = data_axis_from_orig_axis[name] + first = True + for data_axis, axis_info in enumerate(array_axis_info): + if not isinstance(axis_info, NamedPositionlessAxisInfo): + continue - if i: + if first: formatting_instructions.append( - {"type": "literal", "value": f", {repr(name)}:"} + {"type": "literal", "value": f"{repr(axis_info.axis_name)}:"} ) + first = False else: formatting_instructions.append( - {"type": "literal", "value": f"{repr(name)}:"} + {"type": "literal", "value": f", {repr(axis_info.axis_name)}:"} ) formatting_instructions.append({ @@ -915,16 +1057,19 @@ def render_array( "skip_end": skip_end_indices[data_axis], }) - if name in axis_labels: - data_axis_labels[data_axis] = axis_labels[name] - elif name in sliders: - data_axis_labels[data_axis] = f"{str(name)}" + if axis_info in axis_labels_by_info: + data_axis_labels[data_axis] = axis_labels_by_info[axis_info] + label_name = f"{axis_labels_by_info[axis_info]} ({axis_info.axis_name})" + elif axis_info in slider_infos: + label_name = f"{str(axis_info.axis_name)}" + data_axis_labels[data_axis] = label_name else: - 
data_axis_labels[data_axis] = f"{str(name)}: {size}" + label_name = f"{str(axis_info.axis_name)}" + data_axis_labels[data_axis] = f"{label_name}: {axis_info.size}" - if name in axis_item_labels: + if axis_info in axis_item_labels_by_info: axis_label_instructions.extend([ - {"type": "literal", "value": f"\n{str(name)} @ "}, + {"type": "literal", "value": f"\n{label_name} @ "}, { "type": "index", "axis": f"a{data_axis}", @@ -937,17 +1082,20 @@ def render_array( "axis": f"a{data_axis}", "skip_start": skip_start_indices[data_axis], "skip_end": skip_end_indices[data_axis], - "lookup_table": axis_item_labels[name], + "lookup_table": axis_item_labels_by_info[axis_info], }, ]) formatting_instructions.append({"type": "literal", "value": "}]"}) - if array.positional_shape: + if positional_count: formatting_instructions.append({"type": "literal", "value": "["}) - for orig_index, size in enumerate(array.positional_shape): - data_axis = data_axis_from_orig_axis[orig_index] - if orig_index: + for logical_index in range(positional_count): + axis_info = info_by_name_or_position[logical_index] + assert isinstance(axis_info, PositionalAxisInfo | NamedPositionalAxisInfo) + assert axis_info.axis_logical_index == logical_index + data_axis = data_axis_from_axis_info[axis_info] + if logical_index > 0: formatting_instructions.append({"type": "literal", "value": ", "}) formatting_instructions.append({ "type": "index", @@ -956,16 +1104,22 @@ def render_array( "skip_end": skip_end_indices[data_axis], }) - if orig_index in axis_labels: - data_axis_labels[data_axis] = axis_labels[orig_index] - elif orig_index in sliders: - data_axis_labels[data_axis] = f"axis{orig_index}" + if axis_info in axis_labels_by_info: + data_axis_labels[data_axis] = axis_labels_by_info[axis_info] + label_name = f"{axis_labels_by_info[axis_info]} (axis {logical_index})" else: - data_axis_labels[data_axis] = f"axis {orig_index}: {size}" - - if orig_index in axis_item_labels: + if isinstance(axis_info, 
NamedPositionalAxisInfo): + label_name = f"{axis_info.axis_name} (axis {logical_index})" + else: + label_name = f"axis {logical_index}" + if axis_info in slider_infos: + data_axis_labels[data_axis] = label_name + else: + data_axis_labels[data_axis] = f"{label_name}: {axis_info.size}" + + if axis_info in axis_item_labels_by_info: axis_label_instructions.extend([ - {"type": "literal", "value": f"\nAxis {orig_index} @ "}, + {"type": "literal", "value": f"\n{label_name} @ "}, { "type": "index", "axis": f"a{data_axis}", @@ -978,7 +1132,7 @@ def render_array( "axis": f"a{data_axis}", "skip_start": skip_start_indices[data_axis], "skip_end": skip_end_indices[data_axis], - "lookup_table": axis_item_labels[orig_index], + "lookup_table": axis_item_labels_by_info[axis_info], }, ]) @@ -990,7 +1144,7 @@ def render_array( # Step 7: Infer the colormap and rendering strategy. # Figure out whether the array is continuous. - inferred_continuous = jnp.issubdtype(array_data.dtype, np.floating) + inferred_continuous = dtype_util.is_floating_dtype(truncated_array_data.dtype) if continuous == "auto": continuous = inferred_continuous elif not continuous and inferred_continuous: @@ -999,6 +1153,10 @@ def render_array( " cast it to an integer array first." ) + if inferred_continuous: + # Cast to float32 to ensure we can easily manipulate the truncated data. 
+ truncated_array_data = truncated_array_data.astype(np.float32) + if value_item_labels and not continuous: formatting_instructions.append({"type": "literal", "value": " # "}) formatting_instructions.append( @@ -1095,7 +1253,9 @@ def render_array( column_axes=column_data_axes, row_axes=row_data_axes, slider_axes=slider_data_axes, - axis_labels=[data_axis_labels[i] for i in range(array_data.ndim)], + axis_labels=[ + data_axis_labels[i] for i in range(truncated_array_data.ndim) + ], vmin=vmin, vmax=vmax, cmap_type=colormap_type, @@ -1109,14 +1269,11 @@ def render_array( return ArrayvizRendering(html_src) -def _render_sharding( - array_shape: tuple[int, ...], - shard_shape: tuple[int, ...], - device_indices_map: Mapping[Any, tuple[slice, ...]], - rows: list[int | named_axes.AxisName] | None = None, - columns: list[int | named_axes.AxisName] | None = None, - name_to_data_axis: dict[named_axes.AxisName, int] | None = None, - position_to_data_axis: tuple[int, ...] | None = None, +def render_sharding_info( + array_axis_info: tuple[AxisInfo, ...], + sharding_info: ndarray_adapters.ShardingInfo, + rows: Sequence[int | AxisName] = (), + columns: Sequence[int | AxisName] = (), ) -> ArrayvizRendering: """Renders the sharding of an array. @@ -1125,50 +1282,58 @@ def _render_sharding( given shape and sharding. Args: - array_shape: Shape of the sharded array. - shard_shape: Shape of each array shard. - device_indices_map: Map from devices to tuples of slices into the array, - identifying which parts of the array it corresponds to. Usually obtained - from a JAX sharding. + array_axis_info: Axis info for each axis of the array data. + sharding_info: Sharding info for the array, as produced by a NDArrayAdapter. rows: Optional explicit ordering of rows in the visualization. columns: Optional explicit ordering of columns in the visualization. - name_to_data_axis: Optional mapping from named axes to their axis in the - data array. 
- position_to_data_axis: Optional mapping from virtual positional axes to - their axis in the data array. Returns: A rendering of the sharding, which re-uses the digitbox rendering mode to render sets of devices. """ - if name_to_data_axis is None and position_to_data_axis is None: - name_to_data_axis = {} - position_to_data_axis = {i: i for i in range(len(array_shape))} - else: - assert name_to_data_axis is not None - assert position_to_data_axis is not None - if rows is None and columns is None: - rows, columns = infer_rows_and_columns( - { - name_or_pos: array_shape[data_axis] - for name_or_pos, data_axis in itertools.chain( - name_to_data_axis.items(), enumerate(position_to_data_axis) - ) - }, - tuple(name_to_data_axis.keys()) - + tuple(range(len(position_to_data_axis))), - ) + data_axis_from_axis_info = { + info: axis for axis, info in enumerate(array_axis_info) + } + + info_by_name_or_position = {} + has_name_only = False + positional_count = 0 + for info in array_axis_info: + if isinstance(info, NamedPositionalAxisInfo): + info_by_name_or_position[info.axis_name] = info + info_by_name_or_position[info.axis_logical_index] = info + positional_count += 1 + elif isinstance(info, PositionalAxisInfo): + info_by_name_or_position[info.axis_logical_index] = info + positional_count += 1 + elif isinstance(info, NamedPositionlessAxisInfo): + info_by_name_or_position[info.axis_name] = info + has_name_only = True + else: + raise ValueError(f"Unrecognized axis info {type(info)}") + + array_shape = [info.size for info in array_axis_info] + shard_shape = sharding_info.shard_shape num_shards = np.prod(array_shape) // np.prod(shard_shape) # Compute a truncation for visualizing a single shard. Each shard will be # shown as a shrunken version of the actual shard dimensions, roughly # proportional to the shard sizes. 
- mini_trunc = ndarray_summarization.infer_balanced_truncation( + mini_trunc = infer_balanced_truncation( shape=array_shape, maximum_size=1000, cutoff_size_per_axis=10, minimum_edge_items=2, doubling_bonus=5, ) + # Infer an axis ordering. + known_row_infos = [info_by_name_or_position[spec] for spec in rows] + known_column_infos = [info_by_name_or_position[spec] for spec in columns] + row_infos, column_infos = infer_rows_and_columns( + all_axes=array_axis_info, + known_rows=known_row_infos, + known_columns=known_column_infos, + edge_items_per_axis=mini_trunc, + ) # Build an actual matrix to represent each shard, with a size determined by # the inferred truncation. shard_mask = np.ones((), dtype=np.bool_) @@ -1183,9 +1348,10 @@ def _render_sharding( vec = np.array([True] * candidate + [False] + [True] * candidate) shard_mask = shard_mask[..., None] * vec # Figure out which device is responsible for each shard. + device_indices_map = sharding_info.device_index_to_shard_slices device_to_shard_offsets = {} shard_offsets_to_devices = collections.defaultdict(list) - for device, slices in device_indices_map.items(): + for device_index, slices in device_indices_map.items(): shard_offsets = [] for i, slc in enumerate(slices): assert slc.step is None @@ -1196,14 +1362,14 @@ def _render_sharding( assert slc.stop == slc.start + shard_shape[i] shard_offsets.append(slc.start // shard_shape[i]) shard_offsets = tuple(shard_offsets) - device_to_shard_offsets[device] = shard_offsets - shard_offsets_to_devices[shard_offsets].append(device) + device_to_shard_offsets[device_index] = shard_offsets + shard_offsets_to_devices[shard_offsets].append(device_index) # Figure out what value to show for each shard. This determines the # visualization color. 
shard_offset_values = {} shard_value_descriptions = {} if len(device_indices_map) <= 10 and all( - device.id < 10 for device in device_indices_map.keys() + device_index < 10 for device_index in device_indices_map.keys() ): # Map each device to an integer digit 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, and # then draw replicas as collections of base-10 digits. @@ -1215,14 +1381,15 @@ def _render_sharding( vis_value = 1234567 else: acc = 0 - for i, device in enumerate(shard_devices): - acc += 10 ** (len(shard_devices) - i - 1) * (device.id + 1) + for i, device_index in enumerate(shard_devices): + acc += 10 ** (len(shard_devices) - i - 1) * (device_index + 1) vis_value = acc shard_offset_values[shard_offsets] = vis_value - platform = shard_devices[0].platform.upper() assert vis_value not in shard_value_descriptions shard_value_descriptions[vis_value] = ( - platform + " " + ",".join(f"{d.id}" for d in shard_devices) + sharding_info.device_type + + " " + + ",".join(f"{d}" for d in shard_devices) ) render_info_message = "Colored by device index." 
elif num_shards < 10: @@ -1281,36 +1448,47 @@ def _render_sharding( data_axis_labels = {} formatting_instructions = [] formatting_instructions.append({"type": "literal", "value": "array"}) - if name_to_data_axis: + + if has_name_only: formatting_instructions.append({"type": "literal", "value": "[{"}) - for k, (name, data_axis) in enumerate(name_to_data_axis.items()): - if k: + + first = True + for data_axis, axis_info in enumerate(array_axis_info): + if not isinstance(axis_info, NamedPositionlessAxisInfo): + continue + + if first: formatting_instructions.append( - {"type": "literal", "value": f", {repr(name)}:["} + {"type": "literal", "value": f"{repr(axis_info.axis_name)}:["} ) + first = False else: formatting_instructions.append( - {"type": "literal", "value": f"{repr(name)}:["} + {"type": "literal", "value": f", {repr(axis_info.axis_name)}:["} ) + formatting_instructions.append(axis_lookups[data_axis]) formatting_instructions.append({"type": "literal", "value": "]"}) axshards = array_shape[data_axis] // shard_shape[data_axis] data_axis_labels[data_axis] = ( - f"{str(name)}: {array_shape[data_axis]}/{axshards}" + f"{axis_info.axis_name}: {array_shape[data_axis]}/{axshards}" ) formatting_instructions.append({"type": "literal", "value": "}]"}) - if position_to_data_axis: + + if positional_count: formatting_instructions.append({"type": "literal", "value": "["}) - for k in range(len(position_to_data_axis)): - data_axis = position_to_data_axis[k] - if k: + for logical_index in range(positional_count): + axis_info = info_by_name_or_position[logical_index] + data_axis = data_axis_from_axis_info[axis_info] + if logical_index: formatting_instructions.append({"type": "literal", "value": ", "}) formatting_instructions.append(axis_lookups[data_axis]) axshards = array_shape[data_axis] // shard_shape[data_axis] data_axis_labels[data_axis] = ( - f"axis {k}: {array_shape[data_axis]}/{axshards}" + f"axis {logical_index}: {array_shape[data_axis]}/{axshards}" ) 
formatting_instructions.append({"type": "literal", "value": "]"}) + formatting_instructions.append({"type": "literal", "value": ":\n "}) formatting_instructions.append({ "type": "value_lookup", @@ -1319,13 +1497,12 @@ def _render_sharding( }) # Build the rendering. html_srcs = [] - to_data_axis = {**dict(enumerate(position_to_data_axis)), **name_to_data_axis} html_srcs.append( _render_array_to_html( array_data=dest, valid_mask=destmask, - column_axes=[to_data_axis[c] for c in columns], - row_axes=[to_data_axis[r] for r in rows], + column_axes=[data_axis_from_axis_info[c] for c in column_infos], + row_axes=[data_axis_from_axis_info[r] for r in row_infos], slider_axes=(), axis_labels=[data_axis_labels[i] for i in range(len(array_shape))], vmin=0, @@ -1344,9 +1521,8 @@ def _render_sharding( shard_offsets_to_devices.items() ): if i == 0: - device = shard_devices[0] - html_srcs.append(f"{device.platform.upper()}") - label = ",".join(f"{d.id}" for d in shard_devices) + html_srcs.append(f"{sharding_info.device_type}") + label = ",".join(f"{d}" for d in shard_devices) subsrc = integer_digitbox( shard_offset_values[shard_offsets], label_bottom="", @@ -1357,9 +1533,9 @@ def _render_sharding( def render_array_sharding( - array: jax.Array | named_axes.NamedArray, - rows: list[int | named_axes.AxisName] | None = None, - columns: list[int | named_axes.AxisName] | None = None, + array: ArrayInRegistry, + rows: Sequence[int | AxisName] = (), + columns: Sequence[int | AxisName] = (), ) -> ArrayvizRendering: """Renders the sharding of an array. @@ -1371,79 +1547,30 @@ def render_array_sharding( Returns: A rendering of that array's sharding. """ - # Wrap as named arrays if needed, for consistency of the following steps. - if not isinstance(array, named_axes.NamedArrayBase): - if not isinstance(array, jax.Array): - raise ValueError( - "render_array_sharding can only be used on jax.Arrays and" - " pz.nx.NamedArray / NamedArrayView." 
- ) - array = named_axes.wrap(array) - array.check_valid() - array = array.as_namedarrayview() - assert array.data_array.shape == array.data_shape - if not hasattr(array.data_array, "sharding"): - raise ValueError( - "Provided array does not have a sharding! Is this a tracer?" - ) - sharding = array.data_array.sharding - - return _render_sharding( - array_shape=array.data_shape, - shard_shape=sharding.shard_shape(array.data_shape), - device_indices_map=sharding.devices_indices_map(array.data_shape), - name_to_data_axis=array.data_axis_for_name, - position_to_data_axis=array.data_axis_for_logical_axis, - rows=rows, - columns=columns, + # Retrieve the adapter for this array, which we will use to construct + # the rendering. + adapter = type_registries.lookup_by_mro( + type_registries.NDARRAY_ADAPTER_REGISTRY, type(array) ) - - -def render_sharded_shape( - sharding: jax.sharding.Sharding, - shape_or_namedarray_struct: ( - jax.ShapeDtypeStruct | named_axes.NamedArrayBase | tuple[int, ...] | Any - ), - rows: list[int | named_axes.AxisName] | None = None, - columns: list[int | named_axes.AxisName] | None = None, -) -> ArrayvizRendering: - """Renders the sharding an array would have, based on its shape. - - Args: - sharding: A sharding to visualize. - shape_or_namedarray_struct: Either an arbitrary object with a ``.shape`` - attribute, a tuple of integers, or a NamedArray wrapping a - `jax.lax.ShapeDtypeStruct`. - rows: Optional explicit ordering of axes for the visualization rows. - columns: Optional explicit ordering of axes for the visualization columns. - - Returns: - A rendering of the result of sharding an array with this shape using the - given sharding. 
- """ - if isinstance(shape_or_namedarray_struct, tuple): - shape_or_namedarray_struct = jax.ShapeDtypeStruct( - shape=shape_or_namedarray_struct, dtype=jnp.float32 + if adapter is None: + raise TypeError( + "Cannot render sharding for array with unrecognized type" + f" {type(array)} (not found in array adapter registry)" ) - elif not isinstance(shape_or_namedarray_struct, named_axes.NamedArrayBase): - shape_or_namedarray_struct = jax.ShapeDtypeStruct( - shape=shape_or_namedarray_struct.shape, dtype=jnp.float32 + + # Extract information about axis names, indices, and sizes, along with the + # sharding info. + array_axis_info = adapter.get_axis_info_for_array_data(array) + sharding_info = adapter.get_sharding_info_for_array_data(array) + if sharding_info is None: + raise ValueError( + "Cannot render sharding for array without sharding info (not provided" + f" by array adapter for {type(array)})." ) - def _traced_fixup(array): - if not isinstance(array, named_axes.NamedArrayBase): - array = named_axes.wrap(array) - array.check_valid() - return array.as_namedarrayview() - - view = jax.eval_shape(_traced_fixup, shape_or_namedarray_struct) - assert view.data_array.shape == view.data_shape - return _render_sharding( - array_shape=view.data_shape, - shard_shape=sharding.shard_shape(view.data_shape), - device_indices_map=sharding.devices_indices_map(view.data_shape), - name_to_data_axis=view.data_axis_for_name, - position_to_data_axis=view.data_axis_for_logical_axis, + return render_sharding_info( + array_axis_info=array_axis_info, + sharding_info=sharding_info, rows=rows, columns=columns, ) diff --git a/penzai/treescope/canonical_aliases.py b/penzai/treescope/canonical_aliases.py index c7720b7..324bc6d 100644 --- a/penzai/treescope/canonical_aliases.py +++ b/penzai/treescope/canonical_aliases.py @@ -191,22 +191,16 @@ class CanonicalAliasEnvironment: Attributes: aliases: A mapping from id(some_object) to the path where we expect to find that object. 
- lazy_populate_if_imported: A list of module names we should populate lazily - if they are imported, without importing them directly, along with a - predicate to use for them. """ aliases: dict[int, ModuleAttributePath] - lazy_populate_if_imported: list[ - tuple[str, Callable[[Any, ModuleAttributePath], bool]] - ] _alias_environment: context.ContextualValue[CanonicalAliasEnvironment] = ( context.ContextualValue( module=__name__, qualname="_alias_environment", - initial_value=CanonicalAliasEnvironment({}, []), + initial_value=CanonicalAliasEnvironment({}), ) ) """The current environment for module-level canonical aliases. @@ -274,25 +268,6 @@ def add_alias( """ -def update_lazy_aliases() -> None: - """Checks for newly-imported modules and defines aliases for them. - - This function loops over the modules listed in `lazy_populate_if_imported` - for the active environment, and adds canonical aliases for any modules that - were recently imported. - """ - alias_env = _alias_environment.get() - # Check for newly-imported modules that we should define aliases for. - all_handled = [] - for name, predicate in alias_env.lazy_populate_if_imported: - if name in sys.modules: - populate_from_public_api(sys.modules[name], predicate) - all_handled.append((name, predicate)) - if all_handled: - for pair in all_handled: - alias_env.lazy_populate_if_imported.remove(pair) - - def lookup_alias( the_object: Any, infer_from_attributes: bool = True, @@ -343,6 +318,7 @@ def lookup_alias( if ( hasattr(unwrapped, "__module__") and hasattr(unwrapped, "__qualname__") + and unwrapped.__qualname__ is not None and "<" not in unwrapped.__qualname__ ): alias = ModuleAttributePath( @@ -578,21 +554,3 @@ def predicate(the_object: Any, path: ModuleAttributePath) -> bool: return True return predicate - - -# Register well-known aliases for the functions defined in these modules, since -# they are likely to be used in penzai code. 
-_alias_environment.get().lazy_populate_if_imported.extend([ - # Third-party libraries with useful APIs: - ("numpy", prefix_filter("numpy", excludes=("numpy.core",))), - ("jax.lax", prefix_filter("jax")), - ("jax.numpy", prefix_filter("jax")), - ("jax.scipy", prefix_filter("jax")), - ("jax.random", prefix_filter("jax")), - ("jax.nn", prefix_filter("jax")), - ("jax.custom_derivatives", prefix_filter("jax")), - ("jax.experimental.pjit", prefix_filter("jax")), - ("jax.experimental.shard_map", prefix_filter("jax")), - ("jax", prefix_filter("jax")), - ("equinox", prefix_filter("equinox")), -]) diff --git a/penzai/treescope/copypaste_fallback.py b/penzai/treescope/copypaste_fallback.py index 7d80d47..cb2b5aa 100644 --- a/penzai/treescope/copypaste_fallback.py +++ b/penzai/treescope/copypaste_fallback.py @@ -17,9 +17,9 @@ from __future__ import annotations import dataclasses +import sys from typing import Any -import jax from penzai.treescope import renderer from penzai.treescope.foldable_representation import part_interface from penzai.treescope.handlers import builtin_atom_handler @@ -53,11 +53,13 @@ def from_object( cls, obj: Any, repr_override: str | None = None ) -> NotRoundtrippable: """Constructs a NotRoundtrippable from an object.""" - if isinstance(obj, jax.Array) and not isinstance(obj, jax.core.Tracer): - # Don't output the internal implementation type. - ty = jax.Array - else: - ty = type(obj) + ty = type(obj) + # Hide implementation details of JAX arrays if JAX is imported. + if "jax" in sys.modules: + jax = sys.modules["jax"] + if isinstance(obj, jax.Array) and not isinstance(obj, jax.core.Tracer): + # Don't output the internal implementation type. 
+ ty = jax.Array if repr_override is None: repr_value = repr(obj) else: diff --git a/penzai/treescope/default_renderer.py b/penzai/treescope/default_renderer.py index 32b98ef..797bfb9 100644 --- a/penzai/treescope/default_renderer.py +++ b/penzai/treescope/default_renderer.py @@ -14,14 +14,12 @@ """Configures the default renderer, and allows reconfiguring it dynamically.""" -import ast import contextlib import functools from typing import Any, Callable -import jax -from penzai.treescope import canonical_aliases from penzai.treescope import context from penzai.treescope import renderer +from penzai.treescope import type_registries from penzai.treescope.foldable_representation import basic_parts from penzai.treescope.foldable_representation import foldable_impl from penzai.treescope.foldable_representation import layout_algorithms @@ -30,12 +28,10 @@ from penzai.treescope.handlers import builtin_atom_handler from penzai.treescope.handlers import builtin_structure_handler from penzai.treescope.handlers import canonical_alias_postprocessor -from penzai.treescope.handlers import extension_method_handler +from penzai.treescope.handlers import custom_type_handlers from penzai.treescope.handlers import function_reflection_handlers from penzai.treescope.handlers import generic_pytree_handler from penzai.treescope.handlers import generic_repr_handler -from penzai.treescope.handlers import hardcoded_structure_handlers -from penzai.treescope.handlers import ndarray_handler from penzai.treescope.handlers import repr_html_postprocessor from penzai.treescope.handlers import shared_value_postprocessor @@ -46,34 +42,19 @@ qualname="active_renderer", initial_value=renderer.TreescopeRenderer( handlers=[ - # Objects with their own handlers. - extension_method_handler.handle_via_penzai_repr_method, - # NDArrays. - ndarray_handler.handle_ndarrays, + # Objects with `__penzai_repr__` defined. 
+ custom_type_handlers.handle_via_penzai_repr_method, + # Objects in the global registry of type handlers. + custom_type_handlers.handle_via_global_registry, # Reflection of functions and classes. function_reflection_handlers.handle_code_objects_with_reflection, # Numbers, strings, constants, enums, etc. builtin_atom_handler.handle_builtin_atoms, # Lists, dicts, tuples, dataclasses, namedtuples, etc. builtin_structure_handler.handle_builtin_structures, - # Hardcoded simple types. - hardcoded_structure_handlers.HardcodedStructureHandler({ - ast.AST: hardcoded_structure_handlers.HasFieldsInClassAttr( - "_fields", render_subclasses=True - ), - jax.ShapeDtypeStruct: ( - hardcoded_structure_handlers.HasFieldsInClassAttr( - "__slots__" - ) - ), - jax.lax.Precision: ( - hardcoded_structure_handlers.IsEnumLike() - ), - }), - # Dtype objects. - ndarray_handler.handle_dtype_instances, # Fallback for unknown pytree types: Show repr and also the - # PyTree children. + # PyTree children. Note: This is a no-op unless JAX has been + # imported. generic_pytree_handler.handle_arbitrary_pytrees, # Fallback to ordinary `repr` for any other object. generic_repr_handler.handle_anything_with_repr, @@ -93,9 +74,9 @@ # Set up a new context for each rendered object when rendering # shared values. shared_value_postprocessor.setup_shared_value_context, - # Update canonical aliases to account for newly imported + # Update type registries to account for newly imported # modules before rendering. - canonical_aliases.update_lazy_aliases, + type_registries.update_registries_for_imports, ], ), ) diff --git a/penzai/treescope/dtype_util.py b/penzai/treescope/dtype_util.py new file mode 100644 index 0000000..7e26c54 --- /dev/null +++ b/penzai/treescope/dtype_util.py @@ -0,0 +1,89 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for working with dtypes. + +JAX extends the Numpy dtype system using the `ml_dtypes` package. Unfortunately, +these extended dtypes do not integrate directly with Numpy subdtype checks. +This module provides utilities to perform these checks that do not depend on +JAX being installed. +""" + +import numpy as np +import numpy.typing + + +def is_integer_dtype(dtype: numpy.typing.DTypeLike) -> bool: + """Returns whether the given dtype is an integer dtype. + + Supports both basic numpy dtypes and the extended dtypes in the `ml_dtypes` + package (if installed). + + Args: + dtype: The dtype to check. + + Returns: + True if the given dtype is an integer dtype. + """ + dtype = np.dtype(dtype) + if np.issubdtype(dtype, np.integer): + return True + if isinstance(dtype.type, type) and dtype.type.__module__ == "ml_dtypes": + import ml_dtypes # pylint: disable=g-import-not-at-top + + try: + _ = ml_dtypes.iinfo(dtype) + return True + except ValueError: + return False + return False + + +def is_floating_dtype(dtype: numpy.typing.DTypeLike) -> bool: + """Returns whether the given dtype is a floating dtype. + + Supports both basic numpy dtypes and the extended dtypes in the `ml_dtypes` + package (if installed). + + Args: + dtype: The dtype to check. + + Returns: + True if the given dtype is a floating dtype. 
+ """ + dtype = np.dtype(dtype) + if np.issubdtype(dtype, np.floating): + return True + if isinstance(dtype.type, type) and dtype.type.__module__ == "ml_dtypes": + import ml_dtypes # pylint: disable=g-import-not-at-top + + try: + _ = ml_dtypes.finfo(dtype) + return True + except ValueError: + return False + return False + + +def get_dtype_name(dtype: numpy.typing.DTypeLike) -> str: + """Safely extracts a name for a dtype.""" + # Render scalar type objects as their literal names. + if isinstance(dtype, type) and issubclass(dtype, np.generic): + return dtype.__name__ + # Render any other dtype-like objects as the name of the concrete dtype they + # convert to. + try: + return np.dtype(dtype).name + except TypeError: + return str(dtype) diff --git a/penzai/treescope/foldable_representation/common_structures.py b/penzai/treescope/foldable_representation/common_structures.py index 4d54063..3b803df 100644 --- a/penzai/treescope/foldable_representation/common_structures.py +++ b/penzai/treescope/foldable_representation/common_structures.py @@ -151,6 +151,7 @@ def build_foldable_tree_node_from_children( background_color: str | None = None, background_pattern: str | None = None, first_line_annotation: RenderableTreePart | None = None, + expand_state: part_interface.ExpandState = part_interface.ExpandState.WEAKLY_COLLAPSED, ) -> RenderableAndLineAnnotations: """Builds a foldable tree node with path buttons and hyperlink support. @@ -173,6 +174,7 @@ def build_foldable_tree_node_from_children( the border for the pattern. first_line_annotation: An annotation for the first line of the node when it is expanded. + expand_state: Initial expand state for the foldable. 
Returns: A new renderable part, possibly with a copy button annotation, for use @@ -247,6 +249,7 @@ def wrap_block(block): ), wrap_bottomline(suffix), ), + expand_state=expand_state, ) ), annotations=maybe_copy_button, diff --git a/penzai/treescope/foldable_representation/foldable_impl.py b/penzai/treescope/foldable_representation/foldable_impl.py index 97a1a44..0e24e65 100644 --- a/penzai/treescope/foldable_representation/foldable_impl.py +++ b/penzai/treescope/foldable_representation/foldable_impl.py @@ -708,6 +708,7 @@ def collecting_deferred_renderings() -> Iterator[list[DeferredWithThunk]]: def render_to_text_as_root( root_node: RenderableTreePart, roundtrip: bool = False, + strip_trailing_whitespace: bool = True, strip_whitespace_lines: bool = True, ) -> str: """Renders a root node to text. @@ -715,13 +716,19 @@ def render_to_text_as_root( Args: root_node: The root node to render. roundtrip: Whether to render in roundtrip mode. + strip_trailing_whitespace: Whether to remove trailing whitespace from lines. strip_whitespace_lines: Whether to remove lines that are entirely whitespace. These lines can sometimes be generated by layout code being - conservative about line breaks. + conservative about line breaks. Should only be True if + `strip_trailing_whitespace` is True. Returns: Text for the rendered node. 
""" + if strip_whitespace_lines and not strip_trailing_whitespace: + raise ValueError("strip_whitespace_lines must be False if " + "strip_trailing_whitespace is False.") + stream = io.StringIO() root_node.render_to_text( stream, @@ -732,13 +739,13 @@ def render_to_text_as_root( ) result = stream.getvalue() - if strip_whitespace_lines: - postprocess_stream = io.StringIO() - for line in result.splitlines(keepends=True): - if line.strip(): - postprocess_stream.write(line) - result = postprocess_stream.getvalue() - + if strip_trailing_whitespace: + trimmed_lines = [] + for line in result.split("\n"): + line = line.rstrip() + if line or not strip_whitespace_lines: + trimmed_lines.append(line) + result = "\n".join(trimmed_lines) return result diff --git a/penzai/treescope/handlers/builtin_structure_handler.py b/penzai/treescope/handlers/builtin_structure_handler.py index 2e11f91..080ee12 100644 --- a/penzai/treescope/handlers/builtin_structure_handler.py +++ b/penzai/treescope/handlers/builtin_structure_handler.py @@ -16,6 +16,7 @@ from __future__ import annotations +import ast import dataclasses import types from typing import Any, Callable, Optional, Sequence @@ -379,8 +380,8 @@ def handle_builtin_structures( background_pattern=background_pattern, ) - elif isinstance(node, tuple) and hasattr(type(node), "_fields"): - # Namedtuple class. + elif isinstance(node, (tuple, ast.AST)) and hasattr(type(node), "_fields"): + # Namedtuple or AST class. 
return common_structures.build_foldable_tree_node_from_children( prefix=basic_parts.siblings( common_structures.maybe_qualified_type_name(type(node)), "(" diff --git a/penzai/treescope/handlers/extension_method_handler.py b/penzai/treescope/handlers/custom_type_handlers.py similarity index 61% rename from penzai/treescope/handlers/extension_method_handler.py rename to penzai/treescope/handlers/custom_type_handlers.py index 0ea20f2..0dc4344 100644 --- a/penzai/treescope/handlers/extension_method_handler.py +++ b/penzai/treescope/handlers/custom_type_handlers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Handler for custom types via the __penzai_repr__ method.""" +"""Handler for custom types via __penzai_repr__ or the global registry.""" from __future__ import annotations @@ -20,6 +20,7 @@ from penzai.treescope import object_inspection from penzai.treescope import renderer +from penzai.treescope import type_registries from penzai.treescope.foldable_representation import part_interface @@ -63,3 +64,40 @@ def handle_via_penzai_repr_method( return penzai_repr_method(path, subtree_renderer) else: return NotImplemented + + +def handle_via_global_registry( + node: Any, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + part_interface.RenderableTreePart + | part_interface.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders a type by looking it up in the global handler registry. + + If it is not feasible to define ``__penzai_repr__`` for a type, it can + instead be registered in the global handler registry. This is a dictionary + mapping types to functions that render a node of that type. + + Currently, the exact structure of the intermediate representation is an + implementation detail and may change in future releases.
Instead of building + a rendering directly, most types should use the construction helpers in + `penzai.treescope.repr_lib` to implement this method. + + Args: + node: The node to render. + path: An optional path to this node from the root. + subtree_renderer: The renderer for subtrees of this node. + + Returns: + A rendering of this node, if it was found in the global registry. + """ + maybe_handler = type_registries.lookup_by_mro( + type_registries.TREESCOPE_HANDLER_REGISTRY, type(node) + ) + if maybe_handler: + return maybe_handler(node, path, subtree_renderer) + else: + return NotImplemented diff --git a/penzai/treescope/handlers/generic_pytree_handler.py b/penzai/treescope/handlers/generic_pytree_handler.py index a737e58..02cb425 100644 --- a/penzai/treescope/handlers/generic_pytree_handler.py +++ b/penzai/treescope/handlers/generic_pytree_handler.py @@ -14,9 +14,9 @@ """Pretty-print handlers for generic pytrees.""" +import sys from typing import Any -import jax from penzai.treescope import renderer from penzai.treescope.foldable_representation import basic_parts from penzai.treescope.foldable_representation import common_structures @@ -35,6 +35,12 @@ def handle_arbitrary_pytrees( | type(NotImplemented) ): """Generic foldable fallback for an unrecognized pytree type.""" + if "jax" not in sys.modules: + # JAX isn't imported, so we can't check the JAX pytree registry. + return NotImplemented + + jax = sys.modules["jax"] + # Is this a pytree? paths_and_subtrees, treedef = jax.tree_util.tree_flatten_with_path( node, is_leaf=lambda subtree: subtree is not node diff --git a/penzai/treescope/handlers/hardcoded_structure_handlers.py b/penzai/treescope/handlers/hardcoded_structure_handlers.py deleted file mode 100644 index b608e14..0000000 --- a/penzai/treescope/handlers/hardcoded_structure_handlers.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright 2024 The Penzai Authors.
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Handles a hardcoded list of dataclass-like structures.""" -from __future__ import annotations - -import dataclasses -import functools -import inspect -from typing import Any, Sequence - -from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import builtin_structure_handler - - -@dataclasses.dataclass(frozen=True) -class HasFieldsInClassAttr: - """Marks a type as having fields listed as a class attribute. - - We assume that the constructor for such a type takes each of the fields as - keyword arguments, and that the fields are also attributes. - - Attributes: - fields_class_attr: Attribute on the class object that specifies what the - fields are, sometimes called "_fields", "__slots__" - render_subclasses: Whether to also render subclasses of the class. - """ - - fields_class_attr: str - render_subclasses: bool = False - - -@dataclasses.dataclass(frozen=True) -class HasExplicitFields: - """Marks a type as having an explicit set of fields. - - We assume that the constructor for such a type takes each of the fields as - keyword arguments, and that the fields are also attributes. 
- - Attributes: - fields: Collection of fields to render. - render_subclasses: Whether to also render subclasses of the class. - """ - - fields: Sequence[str] - render_subclasses: bool = False - - -@dataclasses.dataclass(frozen=True) -class HasFieldsLikeInit: - """Marks a type as having fields based on the signature of `__init__`. - - We assume that every argument to __init__ is also an attribute. - - Attributes: - render_subclasses: Whether to also render subclasses of the class. - """ - - render_subclasses: bool = False - - -@dataclasses.dataclass(frozen=True) -class IsEnumLike: - """Marks a type as behaving like an enum. - - Instances of enum-like types are assumed to have `name` and `value` - attributes, and those instances should be accessible through attribute lookup - on the class. - - Attributes: - render_subclasses: Whether to also render subclasses of the class. - """ - - render_subclasses: bool = False - - -def _dataclass_like( - fields: Sequence[str], - node: Any, - path: str, - subtree_renderer: renderer.TreescopeSubtreeRenderer, -): - """Renders a dataclass-like object.""" - return common_structures.build_foldable_tree_node_from_children( - prefix=basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(node)), "(" - ), - children=builtin_structure_handler.build_field_children( - node, - path, - subtree_renderer, - fields_or_attribute_names=fields, - ), - suffix=")", - path=path, - ) - - -def _enum_like( - node: Any, - path: str, - subtree_renderer: renderer.TreescopeSubtreeRenderer, -): - """Renders a enum-like object.""" - del subtree_renderer - cls = type(node) - if node == getattr(cls, node.name): - return common_structures.build_one_line_tree_node( - basic_parts.siblings_with_annotations( - common_structures.maybe_qualified_type_name(cls), - "." 
+ node.name, - extra_annotations=[ - common_styles.CommentColor( - basic_parts.Text(f" # value: {repr(node.value)}") - ) - ], - ), - path, - ) - else: - return NotImplemented - - -@functools.cache -def _get_init_args(cls: type[Any]) -> Sequence[str]: - return tuple(inspect.signature(cls).parameters.keys()) - - -@dataclasses.dataclass(frozen=True) -class HardcodedStructureHandler: - """A handler for a specific hardcoded list of dataclass-like/enum-like types. - - Each of these types will be shown like a dataclass or an enum. This is - intended to support structures that act like dataclasses, namedtuples, or - enums, but are not implemented as such (e.g. JAX's ShapeDtypeStruct.) - - Attributes: - known_structure_types: Mapping from handled types to a tuple of their - attribute names. - """ - - known_structure_types: dict[ - type[Any], HasFieldsInClassAttr | HasExplicitFields | IsEnumLike - ] - - def __call__( - self, - node: Any, - path: str | None, - subtree_renderer: renderer.TreescopeSubtreeRenderer, - ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations - | type(NotImplemented) - ): - """Renders the hardcoded types from `known_structure_types`.""" - for candidate_type, spec in self.known_structure_types.items(): - if spec.render_subclasses: - matched = isinstance(node, candidate_type) - else: - matched = type(node) is candidate_type # pylint: disable=unidiomatic-typecheck - - if matched: - if isinstance(spec, HasFieldsInClassAttr): - fields = getattr(type(node), spec.fields_class_attr) - return _dataclass_like(fields, node, path, subtree_renderer) - elif isinstance(spec, HasExplicitFields): - fields = spec.fields - return _dataclass_like(fields, node, path, subtree_renderer) - elif isinstance(spec, HasFieldsLikeInit): - fields = _get_init_args(type(node)) - return _dataclass_like(fields, node, path, subtree_renderer) - elif isinstance(spec, IsEnumLike): - return _enum_like(node, path, subtree_renderer) - - else: - return 
NotImplemented diff --git a/penzai/treescope/handlers/interop/jax_support.py b/penzai/treescope/handlers/interop/jax_support.py new file mode 100644 index 0000000..6c2fe66 --- /dev/null +++ b/penzai/treescope/handlers/interop/jax_support.py @@ -0,0 +1,611 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lazy setup logic for adding JAX support to treescope.""" + +from __future__ import annotations + +import typing +from typing import Mapping + +import numpy as np +from penzai.treescope import canonical_aliases +from penzai.treescope import context +from penzai.treescope import dtype_util +from penzai.treescope import ndarray_adapters +from penzai.treescope import renderer +from penzai.treescope import repr_lib +from penzai.treescope import type_registries +from penzai.treescope.foldable_representation import basic_parts +from penzai.treescope.foldable_representation import common_structures +from penzai.treescope.foldable_representation import common_styles +from penzai.treescope.foldable_representation import foldable_impl +from penzai.treescope.foldable_representation import part_interface + +# pylint: disable=g-import-not-at-top +try: + import jax +except ImportError: + assert not typing.TYPE_CHECKING + jax = None +# pylint: enable=g-import-not-at-top + + +def _finite_mean_std_any(array): + """Helper to compute mean and standard deviation only over finite elements.""" + assert jax is not None + jnp = jax.numpy + isfinite 
= jnp.isfinite(array) + inf_to_nan = jnp.where(isfinite, array, jnp.array(jnp.nan, dtype=array.dtype)) + mean = jnp.nanmean(inf_to_nan) + std = jnp.nanstd(inf_to_nan) + return mean, std, jnp.any(isfinite) + + +def _is_subdtype(dtype, base) -> bool: + """Safely checks for dtype subtyping.""" + assert jax is not None + jnp = jax.numpy + try: + return jnp.issubdtype(dtype, base) + except TypeError: + return False + + +summarization_threshold: context.ContextualValue[Mapping[str, int | None]] = ( + context.ContextualValue( + module=__name__, + qualname="summarization_threshold", + initial_value={ + "tpu": 1_000_000_000, + "gpu": 10_000_000, + "default": 100_000, + }, + ) +) +"""Threshold for summarization of NDArrays for each backend. + +This threshold determines the largest number of elements we will +summarize with summary statistics (e.g. mean, standard deviation) +when rendering in treescope. Larger values may make it slower to +display large NDArrays. + +Each key should be the name of a JAX array platform, e.g. "cpu" or +"tpu". It can also be "numpy" to refer to Numpy arrays, or "default" +to refer to any other accelerator. The value is the size of the +array at which point we avoid showing summary statistics. `None` +means no limit. + +This configuration argument is intended to be set at the top level +by the user, e.g. in IPython. +""" + + +def safe_to_summarize(array: jax.Array) -> bool: + """Checks if the array is safe to summarize (not a tracer, not deleted, fully addressable or replicated, and below the size threshold).""" + assert jax is not None, "JAX is not available." 
+ if isinstance(array, jax.core.Tracer): + return False + if array.is_deleted(): + return False + if not ( + getattr(array, "is_fully_addressable", False) + or getattr(array, "is_fully_replicated", False) + ): + return False + thresh_dict = summarization_threshold.get() + [platform] = set(device.platform for device in array.devices()) + thresh = thresh_dict.get(platform) + if thresh is None: + thresh = thresh_dict["default"] + return thresh is None or array.size < thresh + + +def _truncate_part_with_slices( + array: jax.Array, + mask: jax.Array, + prefix_slices: tuple[slice, ...], + remaining_edge_items_per_axis: tuple[int | None, ...], +) -> tuple[jax.Array, jax.Array]: + """Helper to truncate the axes of an array. + + Args: + array: An array to truncate. + mask: Mask array, which must have the same number of dimensions as `array`, + and whose axis sizes must be either 1 or the same as that axis of `array` + (i.e. they are broadcast compatible). + prefix_slices: Slices to apply to each axis of `array` and `mask`, starting + at axis 0, which we have already computed. + remaining_edge_items_per_axis: Number of edge items to keep for each axis, + ignoring any axes whose slices are already computed in `prefix_slices`. + + Returns: + Truncated array and mask, which will both be the same shape. + """ + assert jax is not None, "JAX is not available." + jnp = jax.numpy + if not remaining_edge_items_per_axis: + # Perform the base case slice. + assert len(prefix_slices) == len(array.shape) + truncated_array = array[prefix_slices] + + valid_mask_slices = tuple( + slice(None) if mask.shape[i] == 1 else array_slice + for i, array_slice in enumerate(prefix_slices) + ) + truncated_mask = jnp.broadcast_to( + jnp.array(mask[valid_mask_slices]), truncated_array.shape + ) + return truncated_array, truncated_mask + + # Recursive step: extract one axis, run the function on each side, and + # concatenate. 
+ axis = len(prefix_slices) + edge_items = remaining_edge_items_per_axis[0] + if edge_items is None: + # Don't need to slice. + return _truncate_part_with_slices( + array, + mask, + prefix_slices=prefix_slices + (slice(None),), + remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], + ) + else: + assert array.shape[axis] > 2 * edge_items + result_a, valid_a = _truncate_part_with_slices( + array, + mask, + prefix_slices=prefix_slices + (slice(None, edge_items),), + remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], + ) + result_b, valid_b = _truncate_part_with_slices( + array, + mask, + prefix_slices=prefix_slices + (slice(-edge_items, None),), + remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], + ) + padding_shape = list(result_a.shape) + padding_shape[axis] = 1 + result = jnp.concatenate( + [result_a, jnp.zeros(padding_shape, result_a.dtype), result_b], + axis=axis, + ) + valid = jnp.concatenate( + [valid_a, jnp.zeros(padding_shape, valid_a.dtype), valid_b], axis=axis + ) + return result, valid + + +def truncate_array_and_mask( + array: jax.Array, + mask: jax.Array, + edge_items_per_axis: tuple[int | None, ...], +) -> tuple[jax.Array, jax.Array]: + """Truncates an array and its mask along each of their axes. + + Args: + array: Array to truncate. + mask: Mask array, which must have the same number of dimensions as `array`, + and whose axis sizes must be either 1 or the same as that axis of `array` + (i.e. they are broadcast compatible). + edge_items_per_axis: Number of edge items to keep for each axis, or None + to keep every element along that axis. + + Returns: + A tuple containing a truncated version of the array along with a valid mask. + Values taken from the original array have the valid mask as True, and there + is one extra element in the middle with valid as False (standing in for the + omitted elements). 
The return value is always fully replicated, because + we cannot guarantee that it is evenly sharded across devices, and this + function is usually used immediately before copying to the host. + """ + assert jax is not None, "JAX is not available." + sharding_kwargs = {} + if hasattr(array, "sharding") and hasattr( + array.sharding, "_device_assignment" + ): + # _truncate_part_with_slices usually returns slices that have odd + # dimensions, which aren't divisible by most shardings. Unfortunately, + # the XLA GSPMD partitioner sometimes still infers a sharding over one of + # these axes, which then leads to partitioning errors in JAX whenever we + # try to `device_get` the resulting array or call any additional operations + # on it. To avoid this, we'd like to tell JAX to always produce an output + # that is not sharded over any axis. Unfortunately, this is difficult + # because JAX requires the in_shardings and out_shardings to have the same + # devices in the same internal order, and at the time of writing JAX does + # not provide any public API to look up the order of the devices in a + # sharding (it allows looking up the device *set*, but not their order). + # Whether or not this error happens seems to be somewhat nondeterministic. + # To avoid this, we use the private property `_device_assignment` of + # each sharding in order to figure out what device order it has, and then + # explicitly request a fully-replicated output that is definitely safe to + # retrieve. + sharding_kwargs["out_shardings"] = ( + jax.sharding.GSPMDSharding.get_replicated( + array.sharding._device_assignment # pylint: disable=protected-access + ) + ) + fn = jax.jit( + _truncate_part_with_slices, static_argnums=(2, 3), **sharding_kwargs + ) + return fn(array, mask, (), edge_items_per_axis) + + +def faster_array_repr(array: jax.Array) -> str: + """Computes ``repr(array)``, only copying the rendered array elements. 
+ + ``repr(array)`` on a very large jax Array can be slow, because it copies the + entire array to host memory even when only a few elements are actually needed. + We can avoid this by truncating the array on device before fetching it. + + Args: + array: The array to summarize. + + Returns: + A string representation of the array. May differ slightly from the ordinary + ``repr``, but should contain the same elements. + """ + assert jax is not None, "JAX is not available." + jnp = jax.numpy + if array.size < np.get_printoptions()["threshold"]: + return repr(array) + + if array.aval is not None and array.aval.weak_type: + dtype_str = f"dtype={array.dtype.name}, weak_type=True)" + else: + dtype_str = f"dtype={array.dtype.name})" + + edgeitems = np.get_printoptions()["edgeitems"] + edge_items_per_axis = [] + for size in array.shape: + if size > 2 * edgeitems + 1: + edge_items_per_axis.append(edgeitems) + else: + edge_items_per_axis.append(None) + array_edges, _ = truncate_array_and_mask( + array, + jnp.ones((1,) * array.ndim, dtype=jnp.bool_), + edge_items_per_axis=tuple(edge_items_per_axis), + ) + prefix = "Array(" + datastring = np.array2string( + np.array(array_edges), + prefix=prefix, + suffix=",", + separator=", ", + threshold=0, + edgeitems=edgeitems, + ) + return f"{prefix}{datastring}, {dtype_str}" + + +def render_shape_dtype_struct( + node: jax.ShapeDtypeStruct, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + part_interface.RenderableTreePart + | part_interface.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders jax.ShapeDtypeStruct.""" + assert jax is not None, "JAX is not available." 
+ if type(node) is not jax.ShapeDtypeStruct: # pylint: disable=unidiomatic-typecheck + return NotImplemented + attributes = { + "shape": node.shape, + "dtype": node.dtype, + } + if node.named_shape: + attributes["named_shape"] = node.named_shape + if node.sharding is not None: + attributes["sharding"] = node.sharding + + # Make sure we can correctly round-trip it. We check because ShapeDtypeStruct + # occasionally adds new attributes for new JAX features. + rebuilt = jax.ShapeDtypeStruct(**attributes) + if rebuilt != node: + return NotImplemented + else: + return repr_lib.render_object_constructor( + object_type=jax.ShapeDtypeStruct, + attributes=attributes, + path=path, + subtree_renderer=subtree_renderer, + roundtrippable=True, + ) + + +def render_precision( + node: jax.lax.Precision, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + part_interface.RenderableTreePart + | part_interface.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders jax.lax.Precision.""" + assert jax is not None, "JAX is not available." + if type(node) is not jax.lax.Precision: # pylint: disable=unidiomatic-typecheck + return NotImplemented + return repr_lib.render_enumlike_item( + object_type=jax.lax.Precision, + item_name=node.name, + item_value=node.value, + path=path, + subtree_renderer=subtree_renderer, + ) + + +def summarize_array_data(array: jax.Array) -> str: + """Summarizes the data of a JAX array. + + Args: + array: The array to summarize. + + Returns: + A string summarizing the data of the array. + """ + assert jax is not None, "JAX is not available." 
+ jnp = jax.numpy + output_parts = [] + if array.is_deleted(): + output_parts.append(" - deleted!") + elif safe_to_summarize(array): + with jax.core.ensure_compile_time_eval(): + is_floating = _is_subdtype(array.dtype, jnp.floating) + is_integer = _is_subdtype(array.dtype, jnp.integer) + is_bool = _is_subdtype(array.dtype, jnp.bool_) + + if is_floating: + mean, std, any_finite = jax.jit(_finite_mean_std_any)(array) + + if any_finite: + output_parts.append(f" ≈{float(mean):.2} ±{float(std):.2}") + output_parts.append( + f" [≥{float(jnp.nanmin(array)):.2}," + f" ≤{float(jnp.nanmax(array)):.2}]" + ) + + if is_integer: + output_parts.append(f" [≥{jnp.min(array):_d}, ≤{jnp.max(array):_d}]") + + if is_floating or is_integer: + ct_zero = jnp.count_nonzero(array == 0) + if ct_zero: + output_parts.append(f" zero:{ct_zero:_d}") + + ct_nonzero = jnp.count_nonzero(array) + if ct_nonzero: + output_parts.append(f" nonzero:{ct_nonzero:_d}") + + if is_floating: + ct_nan = jnp.count_nonzero(jnp.isnan(array)) + if ct_nan: + output_parts.append(f" nan:{ct_nan:_d}") + + ct_inf = jnp.count_nonzero(jnp.isposinf(array)) + if ct_inf: + output_parts.append(f" inf:{ct_inf:_d}") + + ct_neginf = jnp.count_nonzero(jnp.isneginf(array)) + if ct_neginf: + output_parts.append(f" -inf:{ct_neginf:_d}") + + if is_bool: + ct_true = jnp.count_nonzero(array) + if ct_true: + output_parts.append(f" true:{ct_true:_d}") + + ct_false = jnp.count_nonzero(jnp.logical_not(array)) + if ct_false: + output_parts.append(f" false:{ct_false:_d}") + return "".join(output_parts) + + +class JAXArrayAdapter(ndarray_adapters.NDArrayAdapter[jax.Array]): + """Array adapter for JAX arrays.""" + + def get_axis_info_for_array_data( + self, array: jax.Array + ) -> tuple[ndarray_adapters.AxisInfo, ...]: + assert jax is not None, "JAX is not available." 
+ return tuple( + ndarray_adapters.PositionalAxisInfo(i, size) + for i, size in enumerate(array.shape) + ) + + def get_array_data_with_truncation( + self, + array: jax.Array, + mask: jax.Array | None, + edge_items_per_axis: tuple[int | None, ...], + ) -> tuple[jax.Array, jax.Array]: + assert jax is not None, "JAX is not available." + jnp = jax.numpy + assert not isinstance(array, jax.core.Tracer) + assert not array.is_deleted() + if mask is not None: + # Make sure we can broadcast the shape correctly. + _ = jax.eval_shape(lambda: jnp.broadcast_to(mask, array.shape)) + mask = mask[(None,) * (array.ndim - mask.ndim) + (...,)] + else: + mask = jnp.ones((1,) * array.ndim, dtype=jnp.bool_) + + if edge_items_per_axis == (None,) * array.ndim: + # No truncation. + return array, jnp.broadcast_to(mask, array.shape) + + return truncate_array_and_mask(array, mask, edge_items_per_axis) + + def get_array_summary(self, array: jax.Array, fast: bool) -> str: + output_parts = ["jax.Array "] + + output_parts.append(dtype_util.get_dtype_name(array.dtype)) + output_parts.append(repr(array.shape)) + if array.is_deleted(): + output_parts.append(" - deleted!") + elif not fast: + output_parts.append(summarize_array_data(array)) + + return "".join(output_parts) + + def get_numpy_dtype(self, array: jax.Array) -> np.dtype | None: + if isinstance(array.dtype, np.dtype): + return array.dtype + else: + return None + + def get_sharding_info_for_array_data( + self, array: jax.Array + ) -> ndarray_adapters.ShardingInfo | None: + assert jax is not None, "JAX is not available." 
+ if isinstance(array, jax.core.Tracer) or array.is_deleted(): + return None + + [platform] = set(device.platform for device in array.sharding.device_set) + device_map = array.sharding.devices_indices_map(array.shape) + return ndarray_adapters.ShardingInfo( + shard_shape=array.sharding.shard_shape(array.shape), + device_index_to_shard_slices={ + device.id: slices for device, slices in device_map.items() + }, + device_type=platform.upper(), + fully_replicated=array.is_fully_replicated, + ) + + def should_autovisualize(self, array: jax.Array) -> bool: + assert jax is not None, "JAX is not available." + return not isinstance(array, jax.core.Tracer) and not array.is_deleted() + + +def render_jax_arrays( + node: jax.Array, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + part_interface.RenderableTreePart + | part_interface.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders a JAX array.""" + assert jax is not None, "JAX is not available." + del subtree_renderer + assert isinstance(node, jax.Array) + if isinstance(node, jax.core.Tracer): + return NotImplemented + + adapter = JAXArrayAdapter() + + if node.is_deleted(): + return common_styles.ErrorColor( + basic_parts.Text("<" + adapter.get_array_summary(node, fast=True) + ">") + ) + + def _placeholder() -> part_interface.RenderableTreePart: + return common_structures.fake_placeholder_foldable( + common_styles.DeferredPlaceholderStyle( + basic_parts.Text(adapter.get_array_summary(node, fast=True)) + ), + extra_newlines_guess=8, + ) + + def _thunk(placeholder): + # Is this array simple enough to render without a summary? + node_repr = faster_array_repr(node) + if "\n" not in node_repr and "..." 
not in node_repr: + rendering = common_styles.AbbreviationColor( + basic_parts.Text(f"<jax.{node_repr}>") + ) + else: + if node_repr.count("\n") <= 15: + if isinstance(placeholder, part_interface.FoldableTreeNode): + default_expand_state = placeholder.get_expand_state() + else: + assert placeholder is None + default_expand_state = part_interface.ExpandState.WEAKLY_EXPANDED + else: + # Always start big NDArrays in collapsed mode to hide irrelevant detail. + default_expand_state = part_interface.ExpandState.COLLAPSED + + # Render it with a summary. + summarized = adapter.get_array_summary(node, fast=False) + rendering = common_structures.build_custom_foldable_tree_node( + label=common_styles.AbbreviationColor( + common_styles.CommentColorWhenExpanded( + basic_parts.siblings( + basic_parts.FoldCondition( + expanded=basic_parts.Text("# "), + collapsed=basic_parts.Text("<"), + ), + summarized, + basic_parts.FoldCondition( + collapsed=basic_parts.Text(">") + ), + ) + ) + ), + contents=basic_parts.FoldCondition( + expanded=basic_parts.IndentedChildren.build( + [basic_parts.Text(node_repr)] + ) + ), + path=path, + expand_state=default_expand_state, + ).renderable + + return rendering + + return basic_parts.RenderableAndLineAnnotations( + renderable=foldable_impl.maybe_defer_rendering( + main_thunk=_thunk, placeholder_thunk=_placeholder + ), + annotations=common_structures.build_copy_button(path), + ) + + +def set_up_treescope(): + """Sets up treescope to render JAX objects.""" + if jax is None: + raise RuntimeError( + "Cannot set up JAX support in treescope: JAX cannot be imported." + ) + type_registries.TREESCOPE_HANDLER_REGISTRY[jax.ShapeDtypeStruct] = ( + render_shape_dtype_struct + ) + type_registries.TREESCOPE_HANDLER_REGISTRY[jax.lax.Precision] = ( + render_precision + ) + + # The concrete type of a JAX array is a private type that is dynamically + # registered as a jax.Array subclass, so we need to add it to the list of + # dynamically-checked virtual base classes. 
+ type_registries.VIRTUAL_BASE_CLASSES.append(jax.Array) + type_registries.IMMUTABLE_TYPES_REGISTRY[jax.Array] = True + type_registries.NDARRAY_ADAPTER_REGISTRY[jax.Array] = JAXArrayAdapter() + type_registries.TREESCOPE_HANDLER_REGISTRY[jax.Array] = render_jax_arrays + + for jax_api_module in [ + jax.lax, + jax.numpy, + jax.scipy, + jax.random, + jax.nn, + jax.custom_derivatives, + jax, + ]: + canonical_aliases.populate_from_public_api( + jax_api_module, canonical_aliases.prefix_filter("jax") + ) diff --git a/penzai/treescope/handlers/interop/numpy_support.py b/penzai/treescope/handlers/interop/numpy_support.py new file mode 100644 index 0000000..0e4ee89 --- /dev/null +++ b/penzai/treescope/handlers/interop/numpy_support.py @@ -0,0 +1,319 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Lazy setup logic for adding Numpy support to treescope.""" +from __future__ import annotations + +from typing import Any + +import numpy as np +from penzai.treescope import canonical_aliases +from penzai.treescope import dtype_util +from penzai.treescope import ndarray_adapters +from penzai.treescope import renderer +from penzai.treescope import type_registries +from penzai.treescope.foldable_representation import basic_parts +from penzai.treescope.foldable_representation import common_structures +from penzai.treescope.foldable_representation import common_styles +from penzai.treescope.foldable_representation import foldable_impl +from penzai.treescope.foldable_representation import part_interface + + +def _truncate_and_copy( + array_source: np.ndarray, + array_dest: np.ndarray, + prefix_slices: tuple[slice, ...], + remaining_edge_items_per_axis: tuple[int | None, ...], +) -> None: + """Recursively copy values along the edges of a source into a destination. + + This function mutates the destination array in place, copying parts of the + input array into it, so that it contains a truncated version of the original + array. + + Args: + array_source: Source array, which we will truncate. + array_dest: Destination array, whose axis sizes will be either the same as + `array_source` or of size `2 * edge_items + 1` depending on the + truncation. + prefix_slices: Prefix of slices for the source and destination. + remaining_edge_items_per_axis: Number of edge items to keep for each axis, + ignoring any axes whose slices are already computed in `prefix_slices`. + """ + if not remaining_edge_items_per_axis: + # Perform the base case slice. + assert ( + len(prefix_slices) == len(array_source.shape) == len(array_dest.shape) + ) + array_dest[prefix_slices] = array_source[prefix_slices] + else: + # Recursive step. + axis = len(prefix_slices) + edge_items = remaining_edge_items_per_axis[0] + if edge_items is None: + # Don't need to slice. 
+ _truncate_and_copy( + array_source=array_source, + array_dest=array_dest, + prefix_slices=prefix_slices + (slice(None),), + remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], + ) + else: + assert array_source.shape[axis] > 2 * edge_items + _truncate_and_copy( + array_source=array_source, + array_dest=array_dest, + prefix_slices=prefix_slices + (slice(None, edge_items),), + remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], + ) + _truncate_and_copy( + array_source=array_source, + array_dest=array_dest, + prefix_slices=prefix_slices + (slice(-edge_items, None),), + remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], + ) + + +class NumpyArrayAdapter(ndarray_adapters.NDArrayAdapter[np.ndarray]): + """NDArray adapter for numpy arrays.""" + + def get_axis_info_for_array_data( + self, array: np.ndarray + ) -> tuple[ndarray_adapters.AxisInfo, ...]: + return tuple( + ndarray_adapters.PositionalAxisInfo(i, size) + for i, size in enumerate(array.shape) + ) + + def get_array_data_with_truncation( + self, + array: np.ndarray, + mask: np.ndarray | None, + edge_items_per_axis: tuple[int | None, ...], + ) -> tuple[np.ndarray, np.ndarray]: + + if mask is None: + mask = np.ones((1,) * array.ndim, dtype=bool) + + # Broadcast mask. (Note: Broadcasting a Numpy array does not copy data.) + mask = np.broadcast_to(mask, array.shape) + + if edge_items_per_axis == (None,) * array.ndim: + # No truncation. 
+ return array, mask + + dest_shape = [ + size if edge_items is None else 2 * edge_items + 1 + for size, edge_items in zip(array.shape, edge_items_per_axis) + ] + array_dest = np.zeros(dest_shape, array.dtype) + mask_dest = np.zeros(dest_shape, bool) + _truncate_and_copy( + array_source=array, + array_dest=array_dest, + prefix_slices=(), + remaining_edge_items_per_axis=edge_items_per_axis, + ) + _truncate_and_copy( + array_source=mask, + array_dest=mask_dest, + prefix_slices=(), + remaining_edge_items_per_axis=edge_items_per_axis, + ) + return array_dest, mask_dest + + def get_array_summary(self, array: np.ndarray, fast: bool) -> str: + output_parts = ["np.ndarray "] + + output_parts.append(dtype_util.get_dtype_name(array.dtype)) + output_parts.append(repr(array.shape)) + + if array.size > 0 and array.size < 100_000 and not fast: + is_floating = dtype_util.is_floating_dtype(array.dtype) + is_integer = dtype_util.is_integer_dtype(array.dtype) + is_bool = np.issubdtype(array.dtype, np.bool_) + + if is_floating: + isfinite = np.isfinite(array) + any_finite = np.any(isfinite) + inf_to_nan = np.where( + isfinite, array, np.array(np.nan, dtype=array.dtype) + ) + mean = np.nanmean(inf_to_nan) + std = np.nanstd(inf_to_nan) + + if any_finite: + output_parts.append(f" ≈{float(mean):.2} ±{float(std):.2}") + output_parts.append( + f" [≥{float(np.nanmin(array)):.2}, ≤{float(np.nanmax(array)):.2}]" + ) + + if is_integer: + output_parts.append(f" [≥{np.min(array):_d}, ≤{np.max(array):_d}]") + + if is_floating or is_integer: + ct_zero = np.count_nonzero(array == 0) + if ct_zero: + output_parts.append(f" zero:{ct_zero:_d}") + + ct_nonzero = np.count_nonzero(array) + if ct_nonzero: + output_parts.append(f" nonzero:{ct_nonzero:_d}") + + if is_floating: + ct_nan = np.count_nonzero(np.isnan(array)) + if ct_nan: + output_parts.append(f" nan:{ct_nan:_d}") + + ct_inf = np.count_nonzero(np.isposinf(array)) + if ct_inf: + output_parts.append(f" inf:{ct_inf:_d}") + + ct_neginf = 
np.count_nonzero(np.isneginf(array)) + if ct_neginf: + output_parts.append(f" -inf:{ct_neginf:_d}") + + if is_bool: + ct_true = np.count_nonzero(array) + if ct_true: + output_parts.append(f" true:{ct_true:_d}") + + ct_false = np.count_nonzero(np.logical_not(array)) + if ct_false: + output_parts.append(f" false:{ct_false:_d}") + + return "".join(output_parts) + + def get_numpy_dtype(self, array: np.ndarray) -> np.dtype: + return array.dtype + + +def render_ndarrays( + node: np.ndarray, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + part_interface.RenderableTreePart + | part_interface.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders a numpy array.""" + del subtree_renderer + assert isinstance(node, np.ndarray) + adapter = NumpyArrayAdapter() + + def _placeholder() -> part_interface.RenderableTreePart: + return common_structures.fake_placeholder_foldable( + common_styles.DeferredPlaceholderStyle( + basic_parts.Text(adapter.get_array_summary(node, fast=True)) + ), + extra_newlines_guess=8, + ) + + def _thunk(placeholder): + # Is this array simple enough to render without a summary? + node_repr = repr(node) + if "\n" not in node_repr and "..." not in node_repr: + rendering = basic_parts.Text(f"np.{node_repr}") + else: + if node_repr.count("\n") <= 15: + if isinstance(placeholder, part_interface.FoldableTreeNode): + default_expand_state = placeholder.get_expand_state() + else: + assert placeholder is None + default_expand_state = part_interface.ExpandState.WEAKLY_EXPANDED + else: + # Always start big NDArrays in collapsed mode to hide irrelevant detail. + default_expand_state = part_interface.ExpandState.COLLAPSED + + # Render it with a summary. 
+ summarized = adapter.get_array_summary(node, fast=False) + rendering = common_structures.build_custom_foldable_tree_node( + label=common_styles.AbbreviationColor( + common_styles.CommentColorWhenExpanded( + basic_parts.siblings( + basic_parts.FoldCondition( + expanded=basic_parts.Text("# "), + collapsed=basic_parts.Text("<"), + ), + summarized, + basic_parts.FoldCondition( + collapsed=basic_parts.Text(">") + ), + ) + ) + ), + contents=basic_parts.FoldCondition( + expanded=basic_parts.IndentedChildren.build( + [basic_parts.Text(node_repr)] + ) + ), + path=path, + expand_state=default_expand_state, + ).renderable + + return rendering + + return basic_parts.RenderableAndLineAnnotations( + renderable=foldable_impl.maybe_defer_rendering( + main_thunk=_thunk, placeholder_thunk=_placeholder + ), + annotations=common_structures.build_copy_button(path), + ) + + +def render_dtype_instances( + node: Any, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + part_interface.RenderableTreePart + | part_interface.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders a np.dtype, adding the `np.` qualifier.""" + del subtree_renderer + if not isinstance(node, np.dtype): + return NotImplemented + + dtype_name = node.name + if dtype_name in np.sctypeDict and node is np.dtype( + np.sctypeDict[dtype_name] + ): + # Use the named type. (Sometimes extended dtypes don't print in a + # roundtrippable way otherwise.) + dtype_string = f"dtype({repr(dtype_name)})" + else: + # Hope that `repr` is already round-trippable (true for builtin numpy types) + # and add the "numpy." prefix as needed. 
+ dtype_string = repr(node) + + return common_structures.build_one_line_tree_node( + line=basic_parts.siblings( + basic_parts.RoundtripCondition(roundtrip=basic_parts.Text("np.")), + dtype_string, + ), + path=path, + ) + + +def set_up_treescope(): + """Sets up treescope to render Numpy objects.""" + type_registries.NDARRAY_ADAPTER_REGISTRY[np.ndarray] = NumpyArrayAdapter() + type_registries.TREESCOPE_HANDLER_REGISTRY[np.ndarray] = render_ndarrays + type_registries.TREESCOPE_HANDLER_REGISTRY[np.dtype] = render_dtype_instances + + canonical_aliases.populate_from_public_api( + np, canonical_aliases.prefix_filter("numpy", excludes=("numpy.core",)) + ) diff --git a/penzai/treescope/handlers/interop/penzai_core_support.py b/penzai/treescope/handlers/interop/penzai_core_support.py new file mode 100644 index 0000000..602e543 --- /dev/null +++ b/penzai/treescope/handlers/interop/penzai_core_support.py @@ -0,0 +1,141 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lazy setup logic for adding `penzai.core` support to treescope. + +This is defined in a separate lazily-imported module to allow `penzai.treescope` +to render Penzai's named arrays if they are used, but not require importing +`penzai.core` if they are not used. 
+""" + +import jax +import numpy as np +from penzai.core import named_axes +from penzai.core._treescope_handlers import named_axes_handlers +from penzai.treescope import ndarray_adapters +from penzai.treescope import type_registries +from penzai.treescope.handlers.interop import jax_support + + +class NamedArrayAdapter( + ndarray_adapters.NDArrayAdapter[named_axes.NamedArrayBase] +): + """Array adapter for Penzai named arrays.""" + + def get_axis_info_for_array_data( + self, array: named_axes.NamedArrayBase + ) -> tuple[ndarray_adapters.AxisInfo, ...]: + array = array.as_namedarrayview() + infos = {} + for name, axis in array.data_axis_for_name.items(): + infos[axis] = ndarray_adapters.NamedPositionlessAxisInfo( + axis_name=name, + size=array.data_shape[axis], + ) + for logical_axis, axis in enumerate(array.data_axis_for_logical_axis): + infos[axis] = ndarray_adapters.PositionalAxisInfo( + axis_logical_index=logical_axis, + size=array.data_shape[axis], + ) + return tuple(infos[i] for i in range(len(infos))) + + def get_array_data_with_truncation( + self, + array: named_axes.NamedArrayBase, + mask: named_axes.NamedArrayBase | jax.Array | np.ndarray | None, + edge_items_per_axis: tuple[int | None, ...], + ) -> tuple[named_axes.NamedArrayBase, named_axes.NamedArrayBase]: + array = array.as_namedarrayview() + if mask is None: + mask_data = None + else: + # Make sure mask is compatible. 
+ if not isinstance(mask, named_axes.NamedArrayBase): + mask = named_axes.wrap(mask) + bad_names = set(mask.named_shape.keys()) - set(array.named_shape.keys()) + if bad_names: + raise ValueError( + "Valid mask must be broadcastable to the shape of `array`, but it" + f" had extra axis names {bad_names}" + ) + + vshape = mask.positional_shape + ashape = array.positional_shape + if np.broadcast_shapes(vshape, ashape) != ashape: + raise ValueError( + "Valid mask must be broadcastable to the shape of `array`, but its" + f" positional shape {vshape} does not match (a suffix of) the" + f" positional shape {ashape} of `array`" + ) + + # Insert new length-1 axes. + new_names = set(array.named_shape.keys()) - set(mask.named_shape.keys()) + if new_names: + mask = mask[{name: None for name in new_names}] + new_positional_axis_count = len(ashape) - len(vshape) + if new_positional_axis_count: + mask = mask[(None,) * new_positional_axis_count + (...,)] + + # Possibly transpose the mask to match the main array, and extract its + # data. 
+ mask_data = mask.order_like(array).data_array + + return jax_support.JAXArrayAdapter().get_array_data_with_truncation( + array=array.data_array, + mask=mask_data, + edge_items_per_axis=edge_items_per_axis, + ) + + def get_array_summary( + self, array: named_axes.NamedArrayBase, fast: bool + ) -> str: + summary, contained_type = ( + named_axes_handlers.named_array_and_contained_type_summary( + array, inspect_device_data=not fast + ) + ) + return f"{type(array).__name__} {summary} (wrapping {contained_type})" + + def get_numpy_dtype( + self, array: named_axes.NamedArrayBase + ) -> np.dtype | None: + if isinstance(array.dtype, np.dtype): + return array.dtype + else: + return None + + def get_sharding_info_for_array_data( + self, array: named_axes.NamedArrayBase + ) -> ndarray_adapters.ShardingInfo | None: + array = array.as_namedarrayview() + if not isinstance(array.data_array, jax.Array): + return None + return jax_support.JAXArrayAdapter().get_sharding_info_for_array_data( + array.data_array + ) + + def should_autovisualize(self, array: named_axes.NamedArrayBase) -> bool: + array = array.as_namedarrayview() + return ( + isinstance(array.data_array, jax.Array) + and not isinstance(array.data_array, jax.core.Tracer) + and not array.data_array.is_deleted() + ) + + +def set_up_treescope(): + """Sets up treescope to render Penzai named arrays.""" + type_registries.NDARRAY_ADAPTER_REGISTRY[named_axes.NamedArrayBase] = ( + NamedArrayAdapter() + ) diff --git a/penzai/treescope/handlers/interop/torch_support.py b/penzai/treescope/handlers/interop/torch_support.py new file mode 100644 index 0000000..bc6921d --- /dev/null +++ b/penzai/treescope/handlers/interop/torch_support.py @@ -0,0 +1,532 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lazy setup logic for adding PyTorch support to treescope.""" + +from __future__ import annotations + +import keyword +import typing + +import numpy as np +from penzai.treescope import context +from penzai.treescope import formatting_util +from penzai.treescope import ndarray_adapters +from penzai.treescope import renderer +from penzai.treescope import type_registries +from penzai.treescope.foldable_representation import basic_parts +from penzai.treescope.foldable_representation import common_structures +from penzai.treescope.foldable_representation import common_styles +from penzai.treescope.foldable_representation import foldable_impl +from penzai.treescope.foldable_representation import part_interface +from penzai.treescope.handlers import builtin_structure_handler + +# pylint: disable=g-import-not-at-top +try: + import torch +except ImportError: + assert not typing.TYPE_CHECKING + torch = None +# pylint: enable=g-import-not-at-top + +show_dynamic_attributes: context.ContextualValue[bool] = ( + context.ContextualValue( + module=__name__, + qualname="show_dynamic_attributes", + initial_value=True, + ) +) +"""Whether to inspect and show all non-private attributes of Torch modules. + +If set to True, when rendering a Torch module, we will walk all of its +attributes (the entries in its `__dict__`) and render every attribute that does +not start with an underscore. If set to False, we will defer to the `extra_repr` +method instead, which matches default Torch repr behavior. 
+""" + + +def _truncate_and_copy( + array_source: torch.Tensor, + array_dest: np.ndarray, + prefix_slices: tuple[slice, ...], + remaining_edge_items_per_axis: tuple[int | None, ...], +) -> None: + """Recursively copy values on the edges of a torch tensor into a numpy array. + + This function mutates the destination array in place, copying parts of the + input array into it, so that it contains a truncated version of the original + array. + + Args: + array_source: Source array, which we will truncate. + array_dest: Destination array, whose axis sizes will be either the same as + `array_source` or of size `2 * edge_items + 1` depending on the + truncation. + prefix_slices: Prefix of slices for the source and destination. + remaining_edge_items_per_axis: Number of edge items to keep for each axis, + ignoring any axes whose slices are already computed in `prefix_slices`. + """ + assert torch is not None, "PyTorch is not available." + if not remaining_edge_items_per_axis: + # Perform the base case slice. + assert ( + len(prefix_slices) == len(array_source.shape) == len(array_dest.shape) + ) + # Move the slice to host memory first: `.numpy()` raises for tensors that + # live on an accelerator device, and `.cpu()` is a no-op for host tensors. + array_dest[prefix_slices] = array_source[prefix_slices].cpu().numpy() + else: + # Recursive step. + axis = len(prefix_slices) + edge_items = remaining_edge_items_per_axis[0] + if edge_items is None: + # Don't need to slice.
+ _truncate_and_copy( + array_source=array_source, + array_dest=array_dest, + prefix_slices=prefix_slices + (slice(None),), + remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], + ) + else: + assert array_source.shape[axis] > 2 * edge_items + _truncate_and_copy( + array_source=array_source, + array_dest=array_dest, + prefix_slices=prefix_slices + (slice(None, edge_items),), + remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], + ) + _truncate_and_copy( + array_source=array_source, + array_dest=array_dest, + prefix_slices=prefix_slices + (slice(-edge_items, None),), + remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], + ) + + +class TorchTensorAdapter(ndarray_adapters.NDArrayAdapter[torch.Tensor]): + """NDArray adapter for Torch tensors.""" + + def get_axis_info_for_array_data( + self, array: torch.Tensor + ) -> tuple[ndarray_adapters.AxisInfo, ...]: + infos = [] + for i, (size, name) in enumerate(zip(array.shape, array.names)): + if name is None: + infos.append(ndarray_adapters.PositionalAxisInfo(i, size)) + else: + infos.append( + ndarray_adapters.NamedPositionalAxisInfo( + axis_logical_index=i, axis_name=name, size=size + ) + ) + return tuple(infos) + + def get_array_data_with_truncation( + self, + array: torch.Tensor, + mask: torch.Tensor | None, + edge_items_per_axis: tuple[int | None, ...], + ) -> tuple[np.ndarray, np.ndarray]: + assert torch is not None, "PyTorch is not available." + array = array.detach() + + if mask is None: + mask = np.ones((1,) * array.ndim, dtype=bool) + + mask = torch.as_tensor(mask).detach() + + # Broadcast mask. (Note: Broadcasting a Numpy array does not copy data.) + mask = torch.broadcast_to(mask, array.shape) + + if edge_items_per_axis == (None,) * array.ndim: + # No truncation. 
+ return array.cpu().numpy(), mask.cpu().numpy() + + dest_shape = [ + size if edge_items is None else 2 * edge_items + 1 + for size, edge_items in zip(array.shape, edge_items_per_axis) + ] + array_dest = np.zeros(dest_shape, dtype=self.get_numpy_dtype(array)) + mask_dest = np.zeros(dest_shape, dtype=np.bool_) + _truncate_and_copy( + array_source=array, + array_dest=array_dest, + prefix_slices=(), + remaining_edge_items_per_axis=edge_items_per_axis, + ) + _truncate_and_copy( + array_source=mask, + array_dest=mask_dest, + prefix_slices=(), + remaining_edge_items_per_axis=edge_items_per_axis, + ) + return array_dest, mask_dest + + def get_array_summary(self, array: torch.Tensor, fast: bool) -> str: + assert torch is not None, "PyTorch is not available." + ty = type(array) + array = array.detach() + typename = f"{ty.__module__}.{ty.__name__}" + if typename == "torch.nn.parameter.Parameter": + typename = "torch.nn.Parameter" + output_parts = [f"{typename} "] + + output_parts.append(repr(array.dtype).removeprefix("torch.")) + name_parts = [] + for size, name in zip(array.shape, array.names): + if name: + name_parts.append(f"{name}:{size}") + else: + name_parts.append(f"{size}") + if len(name_parts) == 1: + output_parts.append("(" + name_parts[0] + ",)") + else: + output_parts.append("(" + ", ".join(name_parts) + ")") + + # Drop axis names.
+ array = array.rename(None) + size = np.prod(array.shape) + if size > 0 and size < 100_000 and not fast: + is_floating = array.dtype.is_floating_point + is_bool = array.dtype == torch.bool + is_integer = ( + not is_floating and not is_bool and not array.dtype.is_complex + ) + + if is_floating: + isfinite = torch.isfinite(array) + any_finite = torch.any(isfinite) + inf_to_nan = torch.where(isfinite, array, torch.nan) + mean = torch.nanmean(inf_to_nan) + # The "±" figure is a standard deviation, so take the square root of the + # mean squared deviation (computed over the finite entries only). + std = torch.sqrt(torch.nanmean(torch.square(inf_to_nan - mean))) + + if any_finite: + output_parts.append(f" ≈{float(mean):.2} ±{float(std):.2}") + nanmin = torch.amin(torch.where(isfinite, array, torch.inf)) + nanmax = torch.amax(torch.where(isfinite, array, -torch.inf)) + output_parts.append(f" [≥{float(nanmin):.2}, ≤{float(nanmax):.2}]") + + if is_integer: + output_parts.append( + f" [≥{torch.amin(array):_d}, ≤{torch.amax(array):_d}]" + ) + + if is_floating or is_integer: + ct_zero = torch.count_nonzero(array == 0) + if ct_zero: + output_parts.append(f" zero:{ct_zero:_d}") + + ct_nonzero = torch.count_nonzero(array) + if ct_nonzero: + output_parts.append(f" nonzero:{ct_nonzero:_d}") + + if is_floating: + ct_nan = torch.count_nonzero(torch.isnan(array)) + if ct_nan: + output_parts.append(f" nan:{ct_nan:_d}") + + ct_inf = torch.count_nonzero(torch.isposinf(array)) + if ct_inf: + output_parts.append(f" inf:{ct_inf:_d}") + + ct_neginf = torch.count_nonzero(torch.isneginf(array)) + if ct_neginf: + output_parts.append(f" -inf:{ct_neginf:_d}") + + if is_bool: + ct_true = torch.count_nonzero(array) + if ct_true: + output_parts.append(f" true:{ct_true:_d}") + + ct_false = torch.count_nonzero(torch.logical_not(array)) + if ct_false: + output_parts.append(f" false:{ct_false:_d}") + + return "".join(output_parts) + + def get_numpy_dtype(self, array: torch.Tensor) -> np.dtype: + assert torch is not None, "PyTorch is not available." + # Convert a zero-sized tensor to a numpy array to get its dtype.
+ return torch.zeros((0,), dtype=array.dtype).numpy().dtype + + +def render_torch_tensors( + node: torch.Tensor, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + part_interface.RenderableTreePart + | part_interface.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders a torch tensor.""" + assert torch is not None, "PyTorch is not available." + del subtree_renderer + assert isinstance(node, torch.Tensor) + adapter = TorchTensorAdapter() + + def _placeholder() -> part_interface.RenderableTreePart: + return common_structures.fake_placeholder_foldable( + common_styles.DeferredPlaceholderStyle( + basic_parts.Text(adapter.get_array_summary(node, fast=True)) + ), + extra_newlines_guess=8, + ) + + def _thunk(placeholder): + # Is this array simple enough to render without a summary? + node_repr = repr(node) + if "\n" not in node_repr and "..." not in node_repr: + if node_repr.startswith("tensor("): + # Add module path, for consistency with other Treescope renderings. + node_repr = f"torch.{node_repr}" + rendering = basic_parts.Text(node_repr) + else: + if node_repr.count("\n") <= 15: + if isinstance(placeholder, part_interface.FoldableTreeNode): + default_expand_state = placeholder.get_expand_state() + else: + assert placeholder is None + default_expand_state = part_interface.ExpandState.WEAKLY_EXPANDED + else: + # Always start big NDArrays in collapsed mode to hide irrelevant detail. + default_expand_state = part_interface.ExpandState.COLLAPSED + + # Render it with a summary.
+ summarized = adapter.get_array_summary(node, fast=False) + rendering = common_structures.build_custom_foldable_tree_node( + label=common_styles.AbbreviationColor( + common_styles.CommentColorWhenExpanded( + basic_parts.siblings( + basic_parts.FoldCondition( + expanded=basic_parts.Text("# "), + collapsed=basic_parts.Text("<"), + ), + summarized, + basic_parts.FoldCondition( + collapsed=basic_parts.Text(">") + ), + ) + ) + ), + contents=basic_parts.FoldCondition( + expanded=basic_parts.IndentedChildren.build( + [basic_parts.Text(node_repr)] + ) + ), + path=path, + expand_state=default_expand_state, + ).renderable + + return rendering + + return basic_parts.RenderableAndLineAnnotations( + renderable=foldable_impl.maybe_defer_rendering( + main_thunk=_thunk, placeholder_thunk=_placeholder + ), + annotations=common_structures.build_copy_button(path), + ) + + +def render_torch_modules( + node: torch.nn.Module, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + part_interface.RenderableTreePart + | part_interface.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders a torch module.""" + assert torch is not None, "PyTorch is not available." 
+ assert isinstance(node, torch.nn.Module) + node_type = type(node) + constructor = basic_parts.siblings( + basic_parts.RoundtripCondition(roundtrip=basic_parts.Text("<")), + common_structures.maybe_qualified_type_name(node_type), + "(", + ) + closing_suffix = basic_parts.siblings( + ")", + basic_parts.RoundtripCondition(roundtrip=basic_parts.Text(">")), + ) + + if hasattr(node, "__treescope_color__") and callable( + node.__treescope_color__ + ): + background_color, background_pattern = ( + builtin_structure_handler.parse_color_and_pattern( + node.__treescope_color__(), node_type.__name__ + ) + ) + elif type(node) is torch.nn.Sequential: # pylint: disable=unidiomatic-typecheck + background_color = "#cdcdcd" + background_pattern = "color-mix(in oklab, #cdcdcd 25%, white)" + elif type(node).forward is torch.nn.Module.forward: + # No implementation of forward. Don't color-code; this is probably a + # container like ModuleList or ModuleDict. + background_color = None + background_pattern = None + else: + type_string = node_type.__module__ + "." + node_type.__qualname__ + background_color = formatting_util.color_from_string(type_string) + background_pattern = None + + children = [] + prefers_expand = False + attr_children = None + has_attr_children_expander = False + + # Render constant attributes. 
+ if show_dynamic_attributes.get(): + attr_children = [] + key_order = [ + key + for key in vars(node) + if not key.startswith("_") and key != "training" + ] + if "training" in vars(node): + key_order.append("training") + for attr in key_order: + value = vars(node)[attr] + child_path = None if path is None else f"{path}.{attr}" + attr_children.append( + basic_parts.build_full_line_with_annotations( + basic_parts.siblings_with_annotations( + f"{attr}=", + subtree_renderer(value, path=child_path), + ",", + basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")), + ) + ) + ) + if len(attr_children) <= 1: + children.extend(attr_children) + else: + has_attr_children_expander = True + children.append( + common_structures.build_custom_foldable_tree_node( + label=basic_parts.FoldCondition( + expanded=common_styles.CommentColor( + basic_parts.Text("# Attributes:") + ), + ), + contents=basic_parts.OnSeparateLines.build(attr_children), + path=None, + expand_state=part_interface.ExpandState.COLLAPSED, + ) + ) + else: + extra_repr = node.extra_repr() + if extra_repr: + if not extra_repr.strip().endswith(","): + extra_repr = extra_repr + ", " + if "\n" in extra_repr: + children.append( + basic_parts.OnSeparateLines.build(extra_repr.split("\n")) + ) + prefers_expand = True + else: + children.append(basic_parts.Text(extra_repr)) + + # Render parameters and buffers + for group_name, group in ( + ("Parameters", node.named_parameters(recurse=False)), + ("Buffers", node.named_buffers(recurse=False)), + ): + group = list(group) + if group: + children.append( + basic_parts.FoldCondition( + expanded=common_styles.CommentColor( + basic_parts.Text(f"# {group_name}:") + ) + ) + ) + for name, value in group: + child_path = None if path is None else f"{path}.{name}" + children.append( + basic_parts.build_full_line_with_annotations( + basic_parts.siblings_with_annotations( + f"{name}=", + subtree_renderer(value, path=child_path), + ",", + 
basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")), + ) + ) + ) + + # Render submodules. + submodules = list(node.named_children()) + if submodules: + children.append( + basic_parts.FoldCondition( + expanded=common_styles.CommentColor( + basic_parts.Text("# Child modules:") + ) + ) + ) + for name, submod in submodules: + prefers_expand = True + if name.isidentifier() and not keyword.iskeyword(name): + child_path = None if path is None else f"{path}.{name}" + keystr = f"{name}=" + else: + # Names that are not valid attribute accesses (e.g. "0" in a Sequential) + # are addressed via `get_submodule`. Keep the path as None when the + # parent has no path, instead of producing "None.get_submodule(...)". + child_path = ( + None if path is None else f"{path}.get_submodule({repr(name)})" + ) + keystr = f"({name}): " + children.append( + basic_parts.build_full_line_with_annotations( + basic_parts.siblings_with_annotations( + keystr, + subtree_renderer(submod, path=child_path), + ",", + basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")), + ) + ) + ) + + # If there are only dynamic attributes, don't add the level of indirection. + if has_attr_children_expander and len(children) == 1: + children = attr_children + + # Heuristic: If a module doesn't have any submodules, mark it collapsed, to + # match the behavior of PyTorch repr. + if prefers_expand: + expand_state = part_interface.ExpandState.WEAKLY_EXPANDED + else: + expand_state = part_interface.ExpandState.COLLAPSED + + return common_structures.build_foldable_tree_node_from_children( + prefix=constructor, + children=children, + suffix=closing_suffix, + path=path, + background_color=background_color, + background_pattern=background_pattern, + expand_state=expand_state, + ) + + +def set_up_treescope(): + """Sets up treescope to render PyTorch objects.""" + if torch is None: + raise RuntimeError( + "Cannot set up PyTorch support in treescope: PyTorch cannot be" + " imported."
+ ) + type_registries.NDARRAY_ADAPTER_REGISTRY[torch.Tensor] = TorchTensorAdapter() + type_registries.TREESCOPE_HANDLER_REGISTRY[torch.Tensor] = ( + render_torch_tensors + ) + type_registries.TREESCOPE_HANDLER_REGISTRY[torch.nn.Module] = ( + render_torch_modules + ) diff --git a/penzai/treescope/handlers/ndarray_handler.py b/penzai/treescope/handlers/ndarray_handler.py deleted file mode 100644 index d105a18..0000000 --- a/penzai/treescope/handlers/ndarray_handler.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2024 The Penzai Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Handler for NDArrays.""" - -from typing import Any - -import jax -import numpy as np -from penzai.treescope import canonical_aliases -from penzai.treescope import copypaste_fallback -from penzai.treescope import ndarray_summarization -from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import part_interface - - -def handle_ndarrays( - node: Any, - path: str | None, - subtree_renderer: renderer.TreescopeSubtreeRenderer, -) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations - | type(NotImplemented) -): - """Renders a NDArray.""" - del subtree_renderer - if not isinstance(node, (np.ndarray, jax.Array)): - return NotImplemented - - if isinstance(node, jax.core.Tracer): - return NotImplemented - - # What to call it? - if isinstance(node, np.ndarray): - np_name = canonical_aliases.maybe_local_module_name(np) - prefix = f"{np_name}." - short = f"{np_name}.ndarray" - else: - jax_name = canonical_aliases.maybe_local_module_name(jax) - prefix = f"{jax_name}." - short = f"{jax_name}.Array" - - if node.is_deleted(): - return common_styles.ErrorColor( - basic_parts.Text( - f"<{short} {ndarray_summarization.get_dtype_name(node.dtype)}{repr(node.shape)} -" - " deleted!>" - ) - ) - - def _placeholder() -> part_interface.RenderableTreePart: - short_summary = ( - f"<{short} {ndarray_summarization.get_dtype_name(node.dtype)}{repr(node.shape)} ... >" - ) - return common_structures.fake_placeholder_foldable( - common_styles.DeferredPlaceholderStyle(basic_parts.Text(short_summary)), - extra_newlines_guess=8, - ) - - def _thunk(placeholder): - # Is this array simple enough to render without a summary? 
- node_repr = ndarray_summarization.faster_array_repr(node) - if "\n" not in node_repr and "..." not in node_repr: - rendering = common_styles.AbbreviationColor( - basic_parts.Text(f"<{prefix}{node_repr}>") - ) - repr_summary = node_repr - else: - if node_repr.count("\n") <= 15: - if isinstance(placeholder, part_interface.FoldableTreeNode): - default_expand_state = placeholder.get_expand_state() - else: - assert placeholder is None - default_expand_state = part_interface.ExpandState.WEAKLY_EXPANDED - else: - # Always start big NDArrays in collapsed mode to hide irrelevant detail. - default_expand_state = part_interface.ExpandState.COLLAPSED - - # Render it with a summary. - summarized = ndarray_summarization.summarize_ndarray(node) - repr_summary = f"<{short} {summarized}>" - rendering = common_structures.build_custom_foldable_tree_node( - label=common_styles.AbbreviationColor( - common_styles.CommentColorWhenExpanded( - basic_parts.siblings( - basic_parts.FoldCondition( - expanded=basic_parts.Text("# "), - collapsed=basic_parts.Text("<"), - ), - f"{short} " + summarized, - basic_parts.FoldCondition( - collapsed=basic_parts.Text(">") - ), - ) - ) - ), - contents=basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren.build( - [basic_parts.Text(node_repr)] - ) - ), - path=path, - expand_state=default_expand_state, - ).renderable - - fallback_rendering = copypaste_fallback.render_not_roundtrippable( - node, repr_override=repr_summary - ) - return basic_parts.RoundtripCondition( - roundtrip=fallback_rendering, - not_roundtrip=rendering, - ) - - return basic_parts.RenderableAndLineAnnotations( - renderable=foldable_impl.maybe_defer_rendering( - main_thunk=_thunk, placeholder_thunk=_placeholder - ), - annotations=common_structures.build_copy_button(path), - ) - - -def handle_dtype_instances( - node: Any, - path: str | None, - subtree_renderer: renderer.TreescopeSubtreeRenderer, -) -> ( - part_interface.RenderableTreePart - | 
part_interface.RenderableAndLineAnnotations - | type(NotImplemented) -): - """Renders a np.dtype, adding the `np.` qualifier.""" - del subtree_renderer - if not isinstance(node, np.dtype): - return NotImplemented - - dtype_name = node.name - if dtype_name in np.sctypeDict and node is np.dtype( - np.sctypeDict[dtype_name] - ): - # Use the named type. (Sometimes extended dtypes don't print in a - # roundtrippable way otherwise.) - dtype_string = f"dtype({repr(dtype_name)})" - else: - # Hope that `repr` is already round-trippable (true for builtin numpy types) - # and add the "numpy." prefix as needed. - dtype_string = repr(node) - - # Use an alias for numpy if one is defined, since people often import numpy - # as np. - np_name = canonical_aliases.maybe_local_module_name(np) - - return common_structures.build_one_line_tree_node( - line=basic_parts.siblings( - basic_parts.RoundtripCondition( - roundtrip=basic_parts.Text(f"{np_name}.") - ), - dtype_string, - ), - path=path, - ) diff --git a/penzai/treescope/handlers/shared_value_postprocessor.py b/penzai/treescope/handlers/shared_value_postprocessor.py index 494c109..1be45e7 100644 --- a/penzai/treescope/handlers/shared_value_postprocessor.py +++ b/penzai/treescope/handlers/shared_value_postprocessor.py @@ -21,13 +21,12 @@ import contextlib import dataclasses import io -import types from typing import Any, Optional, Sequence -import jax from penzai.treescope import context from penzai.treescope import html_escaping from penzai.treescope import renderer +from penzai.treescope import type_registries from penzai.treescope.foldable_representation import basic_parts from penzai.treescope.foldable_representation import part_interface @@ -258,30 +257,18 @@ def setup_shared_value_context() -> contextlib.AbstractContextManager[None]: return _shared_object_ids_seen.set_scoped(_SharedObjectTracker({}, set())) -# Types that can have multiple references in the same object without it being -# necessary or important to highlight the 
shared reference. -_SAFE_TO_SHARE_TYPES = { - jax.Array, - types.FunctionType, - types.MethodType, - types.ModuleType, - type, - type(None), - type(NotImplemented), - type(Ellipsis), -} - - def _is_safe_to_share(node: Any) -> bool: """Returns whether the given node is immutable.""" # According to the Python data model, "If a class defines mutable objects and # implements an __eq__() method, it should not implement __hash__()". So, if # we find an object that implements __eq__ and __hash__, we can generally # assume it is immutable. - return isinstance(node, tuple(_SAFE_TO_SHARE_TYPES)) or ( + return ( type(node).__hash__ is not None and type(node).__hash__ is not object.__hash__ and type(node).__eq__ is not object.__eq__ + ) or type_registries.lookup_by_mro( + type_registries.IMMUTABLE_TYPES_REGISTRY, type(node) ) diff --git a/penzai/treescope/ndarray_adapters.py b/penzai/treescope/ndarray_adapters.py new file mode 100644 index 0000000..10980ec --- /dev/null +++ b/penzai/treescope/ndarray_adapters.py @@ -0,0 +1,295 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NDArray adapter interface. + +This module defines an interface for adapters that support rendering of +multi-dimensional arrays (tensors) in treescope. This can be used to add support +for a variety of array types, including numpy arrays, JAX arrays, and others. 
+""" + +from __future__ import annotations + +import abc +import dataclasses +from typing import Any, Generic, TypeVar + +import numpy as np + +T = TypeVar("T") + + +@dataclasses.dataclass(frozen=True) +class PositionalAxisInfo: + """Marks an axis as being an ordinary (positional) axis. + + Attributes: + axis_logical_index: The logical index of the axis. This is the index of the + axis in the underlying array data. + size: The size of the axis. + """ + + axis_logical_index: int + size: int + + def logical_key(self) -> int: + return self.axis_logical_index + + +@dataclasses.dataclass(frozen=True) +class NamedPositionlessAxisInfo: + """Marks an axis as being accessible by name only. + + Attributes: + axis_name: The name of the axis. + size: The size of the axis. + """ + + axis_name: Any + size: int + + def logical_key(self) -> Any: + return self.axis_name + + +@dataclasses.dataclass(frozen=True) +class NamedPositionalAxisInfo: + """Marks an axis as being accessible by name or by position. + + Attributes: + axis_logical_index: The logical index of the axis. This is the index of the + axis in the underlying array data. + axis_name: The name of the axis. + size: The size of the axis. + """ + + axis_logical_index: int + axis_name: Any + size: int + + def logical_key(self) -> int: + return self.axis_logical_index + + +@dataclasses.dataclass(frozen=True) +class ArraySummary: + """Summary of the contents of an array. + + Any of the attributes of this summary can be None to indicate that the + corresponding statistic is not applicable to the array (either because of + dtype or because there are no finite values). + + Attributes: + finite_mean: The mean of the finite values in the array. + finite_stddev: The standard deviation of the finite values in the array. + finite_min: The minimum of the finite values in the array. + finite_max: The maximum of the finite values in the array. + count_zero: The number of zero values in the array.
+ count_nonzero: The number of nonzero values in the array. + count_nan: The number of NaN values in the array. + count_posinf: The number of positive infinity values in the array. + count_neginf: The number of negative infinity values in the array. + """ + + finite_mean: float | None + finite_stddev: float | None + finite_min: float | None + finite_max: float | None + count_zero: int | None + count_nonzero: float | None + count_nan: float | None + count_posinf: float | None + count_neginf: float | None + + +@dataclasses.dataclass(frozen=True) +class ShardingInfo: + """Summary of the sharding of an array. + + Attributes: + shard_shape: Shape of a single shard. + device_index_to_shard_slices: A mapping from device index to the tuple of + per-axis slices of the original array that is assigned to that device. The + length of each axis slice must match the `shard_shape` along that axis (or + be the full slice ``slice(None)``). + device_type: The type of device that the array is sharded across, as a + string (e.g. "CPU", "TPU", "GPU"). + fully_replicated: Whether the array is fully replicated across all devices. + """ + + shard_shape: tuple[int, ...] + device_index_to_shard_slices: dict[int, tuple[slice, ...]] + device_type: str + fully_replicated: bool = False + + +AxisInfo = ( + PositionalAxisInfo | NamedPositionlessAxisInfo | NamedPositionalAxisInfo +) + + +class NDArrayAdapter(abc.ABC, Generic[T]): + """An adapter to support rendering a multi-dimensional array (tensor) type.""" + + @abc.abstractmethod + def get_axis_info_for_array_data(self, array: T) -> tuple[AxisInfo, ...]: + """Returns axis information for each axis in the given array. + + This method should return a tuple with an AxisInfo entry for each axis in + the array. Array axes can be one of three types: + + * Positional axes have an index and a size, and can be accessed by + position. This is common in ordinary NDArrays. 
+ * Named positionless axes have a name and a size, and can be accessed by + name only. This is how `penzai.core.named_axes` treats named axes. + * Named positional axes have an index, a name, and a size, and can be + accessed by either position or name. This is how PyTorch treats named + axes. + + Note that positional axes have an explicit "logical index", which may or may + not match their position in the underlying array data; this makes it + possible to support "views" of underlying array data that have a different + axis ordering than the original data. (`penzai.core.named_axes` uses this.) + + Args: + array: The array to get axis information for. + + Returns: + A tuple with an AxisInfo entry for each axis in the array. The ordering + must be consistent with the ordering expected by + `get_array_data_with_truncation`. + """ + raise NotImplementedError( + "Subclasses must override `get_axis_info_for_array_data`." + ) + + @abc.abstractmethod + def get_array_data_with_truncation( + self, + array: T, + mask: T | None, + edge_items_per_axis: tuple[int | None, ...], + ) -> tuple[np.ndarray, np.ndarray]: + """Returns a numpy array with truncated array (and mask) data. + + This method should construct a numpy array whose contents are a truncated + version of the given array's data; this array will be used to construct the + actual array visualization. It is also responsible for broadcasting the mask + appropriately and returning a compatible truncation of it. + + This method may be called many times when rendering a large structure of + arrays (once per array), so it should be as fast as possible. We suggest + doing truncation on an accelerator device and then copying the result, if + possible, to avoid unnecessary data transfer. + + Args: + array: The array to get data for. + mask: An optional mask array provided by the user, which should be + broadcast-compatible with `array`. 
(If it is not compatible, the user + has provided an invalid mask, and this method should raise an + informative exception.) Can be None if no mask is provided. + edge_items_per_axis: A tuple with one entry for each axis in the array. + Each entry is either the number of items to keep on each side of this + axis, or None to keep all items. The ordering will be consistent with + the axis order returned by `get_axis_info_for_array_data`, i.e. the + `k`th entry in `edge_items_per_axis` corresponds to the `k`th entry in + the axis info tuple, regardless of the logical indices or axis names. + + Returns: + A tuple ``(truncated_data, truncated_mask)``. ``truncated_data`` should be + a numpy array with a truncated version of the given array's data. If the + ``k``th entry in ``edge_items_per_axis`` is ``None``, the ``k``th axis + should have the same size as the ``size`` field of the ``k``th entry + returned by ``get_axis_info_for_array_data``. Otherwise, the ``k``th axis + should have a size of ``edge_items_per_axis[k] * 2 + 1``, and the middle + element can be arbitrary. ``truncated_mask`` should be a numpy array with + the same shape as ``truncated_data`` containing a truncated, broadcasted + version of the mask; the middle element of the mask must be ``False`` for + each truncated axis. + """ + raise NotImplementedError( + "Subclasses must override `get_array_data_with_truncation`." + ) + + @abc.abstractmethod + def get_array_summary(self, array: T, fast: bool) -> str: + """Summarizes the contents of the given array. + + The summary returned by this method will be used as a one-line summary of + the array in treescope when automatically visualized. + + If the ``fast`` argument is True, the method should return a summary that + can be computed quickly, ideally without any device computation.
If it is + False, the method can return a more detailed summary, but it should still + be fast enough to be called many times when rendering a large structure of + arrays. + + Args: + array: The array to summarize. + fast: Whether to return a fast summary that can be computed without + expensive device computation. + + Returns: + A summary of the given array's contents. The summary should be a single + line of text. It will be wrapped between angle brackets (< and >) when + rendered. + """ + raise NotImplementedError("Subclasses must override `get_array_summary`.") + + def get_numpy_dtype(self, array: T) -> np.dtype | None: + """Returns the numpy dtype of the given array. + + This should match the dtype of the array returned by + `get_array_data_with_truncation`. + + Args: + array: The array to summarize. + + Returns: + The numpy dtype of the given array, or None if the array does not have a + numpy dtype. + """ + raise NotImplementedError("Subclasses must override `get_numpy_dtype`.") + + def get_sharding_info_for_array_data(self, array: T) -> ShardingInfo | None: + """Summarizes the sharding of the given array's data. + + The summary returned by this method will be used to render a sharding for + the array when automatic visualization is enabled. + + Args: + array: The array to summarize. + + Returns: + A summary of the given array's sharding, or None if it does not have a + sharding. + """ + # Default implementation: don't show any sharding information. + del array + return None + + def should_autovisualize(self, array: T) -> bool: + """Returns True if the given array should be automatically visualized. + + If this method returns True, the array will be automatically visualized + by the array visualizer if it is enabled. + + Args: + array: The array to possibly visualize. + + Returns: + True if the given array should be automatically visualized. 
+ """ + del array + return True diff --git a/penzai/treescope/ndarray_summarization.py b/penzai/treescope/ndarray_summarization.py deleted file mode 100644 index 60f7be1..0000000 --- a/penzai/treescope/ndarray_summarization.py +++ /dev/null @@ -1,532 +0,0 @@ -# Copyright 2024 The Penzai Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for summarizing ndarray data.""" - -from __future__ import annotations - -from typing import Any, Mapping, Union - -import jax -import jax.numpy as jnp -import numpy as np -from penzai.treescope import context - - -def get_dtype_name(dtype) -> str: - """Safely extracts a name for a dtype.""" - # Render scalar type objects as their literal names. - if isinstance(dtype, type) and issubclass(dtype, np.generic): - return dtype.__name__ - # Render any other dtype-like objects as the name of the concrete dtype they - # convert to. 
- try: - return np.dtype(dtype).name - except TypeError: - return str(dtype) - - -def _is_subdtype(dtype, base) -> bool: - """Safely checks for dtype subtyping.""" - try: - return jnp.issubdtype(dtype, base) - except TypeError: - return False - - -@jax.jit -def _finite_mean_std_any(array): - """Helper to compute mean and standard deviation only over finite elements.""" - isfinite = jnp.isfinite(array) - inf_to_nan = jnp.where(isfinite, array, jnp.array(jnp.nan, dtype=array.dtype)) - mean = jnp.nanmean(inf_to_nan) - std = jnp.nanstd(inf_to_nan) - return mean, std, jnp.any(isfinite) - - -def summarize_ndarray( - array: Union[np.ndarray, jax.Array], include_shape_and_dtype: bool = True -) -> str: - """Summarizes an NDArray as a string. - - The summaries generated by this function have one of the forms: - - - floatX(shape) mean ±std [min, max] zero:? nonzero:? nan:? inf:? -inf:? - - intX(shape) [min, max] zero:? nonzero:? - - bool(shape) true:? false:? - - other_dtype(shape) - - where ? represents the number of elements that have that property, and any - property with no elements is omitted. When the array is empty, mean, std, - min, and max are also omitted. When there are infinite elements, the mean - and std are taken only over the finite ones, and nan is ignored for min/max. - - Args: - array: An array to summarize. - include_shape_and_dtype: Whether to include the shape and dtype in the - summary. If False, just summarizes the values. - - Returns: - A summary of the array. 
- """ - output_parts = [] - if include_shape_and_dtype: - output_parts.append(get_dtype_name(array.dtype)) - output_parts.append(repr(array.shape)) - - if safe_to_summarize(array): - if isinstance(array, np.ndarray): - xnp = np - is_numpy = True - elif isinstance(array, jax.Array): - # checked by safe_to_summarize - assert not isinstance(array, jax.core.Tracer) - xnp = jnp - is_numpy = False - else: - raise ValueError(f"Not a known array type: {array}") - - with jax.core.ensure_compile_time_eval(): - - if array.size: - is_floating = _is_subdtype(array.dtype, jnp.floating) - is_integer = _is_subdtype(array.dtype, jnp.integer) - is_bool = _is_subdtype(array.dtype, jnp.bool_) - - if is_floating: - if is_numpy: - isfinite = np.isfinite(array) - any_finite = np.any(isfinite) - inf_to_nan = np.where( - isfinite, array, np.array(np.nan, dtype=array.dtype) - ) - mean = np.nanmean(inf_to_nan) - std = np.nanstd(inf_to_nan) - else: - mean, std, any_finite = _finite_mean_std_any(array) - if any_finite: - output_parts.append(f" ≈{float(mean):.2} ±{float(std):.2}") - output_parts.append( - f" [≥{float(xnp.nanmin(array)):.2}," - f" ≤{float(xnp.nanmax(array)):.2}]" - ) - - if is_integer: - output_parts.append(f" [≥{xnp.min(array):_d}, ≤{xnp.max(array):_d}]") - - if is_floating or is_integer: - ct_zero = xnp.count_nonzero(array == 0) - if ct_zero: - output_parts.append(f" zero:{ct_zero:_d}") - - ct_nonzero = xnp.count_nonzero(array) - if ct_nonzero: - output_parts.append(f" nonzero:{ct_nonzero:_d}") - - if is_floating: - ct_nan = xnp.count_nonzero(xnp.isnan(array)) - if ct_nan: - output_parts.append(f" nan:{ct_nan:_d}") - - ct_inf = xnp.count_nonzero(xnp.isposinf(array)) - if ct_inf: - output_parts.append(f" inf:{ct_inf:_d}") - - ct_neginf = xnp.count_nonzero(xnp.isneginf(array)) - if ct_neginf: - output_parts.append(f" -inf:{ct_neginf:_d}") - - if is_bool: - ct_true = xnp.count_nonzero(array) - if ct_true: - output_parts.append(f" true:{ct_true:_d}") - - ct_false = 
xnp.count_nonzero(xnp.logical_not(array)) - if ct_false: - output_parts.append(f" false:{ct_false:_d}") - - return "".join(output_parts) - - -summarization_threshold: context.ContextualValue[Mapping[str, int | None]] = ( - context.ContextualValue( - module=__name__, - qualname="summarization_threshold", - initial_value={ - "tpu": 1_000_000_000, - "gpu": 10_000_000, - "default": 100_000, - }, - ) -) -"""Threshold for summarization of NDArrays for each backend. - -This threshold determines the largest number of elements we will -summarize with summary statistics (e.g. mean, standard deviation) -when rendering in treescope. Larger values may make it slower to -display large NDArrays. - -Each key should be the name of a JAX array platform, e.g. "cpu" or -"tpu". It can also be "numpy" to refer to Numpy arrays, or "default" -to refer to any other accelerator. The value is the size of the -array at which point we avoid showing summary statistics. `None` -means no limit. - -This configuration argument is intended to be set at the top level -by the user, e.g. in IPython. 
-""" - - -def safe_to_summarize(array: Union[np.ndarray, jax.Array, Any]) -> bool: - """Checks if the array is safe to summarize (not a tracer and not replicated).""" - thresh_dict = summarization_threshold.get() - if isinstance(array, jax.core.Tracer): - return False - if isinstance(array, np.ndarray): - thresh = thresh_dict.get("numpy") - if thresh is None: - thresh = thresh_dict.get("cpu") - if thresh is None: - thresh = thresh_dict["default"] - return thresh is None or array.size < thresh - if isinstance(array, jax.Array): - if array.is_deleted(): - return False - if not ( - getattr(array, "is_fully_addressable", False) - or getattr(array, "is_fully_replicated", False) - ): - return False - [platform] = set(device.platform for device in array.devices()) - thresh = thresh_dict.get(platform) - if thresh is None: - thresh = thresh_dict["default"] - return thresh is None or array.size < thresh - return False - - -def _truncate_part_with_slices( - array: jax.Array, - mask: jax.Array, - prefix_slices: tuple[slice, ...], - remaining_edge_items_per_axis: tuple[int | None, ...], -) -> tuple[jax.Array, jax.Array]: - """Helper to truncate names of an array. - - Args: - array: An array to truncate. - mask: Mask array, which must have the same number of dimensions as `array`, - and whose axis sizes must be either 1 or the same as that axis of `array` - (e.g. they are broadcast compatible). - prefix_slices: Slices to apply to each axis of `array` and `mask`, starting - at axis 0, which we have already computed. - remaining_edge_items_per_axis: Number of edge items to keep for each axis, - ignoring any axes whose slices are already computed in `prefix_slices`. - - Returns: - Truncated array and mask, which will both be the same shape. - """ - if not remaining_edge_items_per_axis: - # Perform the base case slice. 
- assert len(prefix_slices) == len(array.shape) - truncated_array = array[prefix_slices] - - valid_mask_slices = tuple( - slice(None) if mask.shape[i] == 1 else array_slice - for i, array_slice in enumerate(prefix_slices) - ) - truncated_mask = jnp.broadcast_to( - jnp.array(mask[valid_mask_slices]), truncated_array.shape - ) - return truncated_array, truncated_mask - - # Recursive step: extract one name, run the function on each side, and - # concatenate. - axis = len(prefix_slices) - edge_items = remaining_edge_items_per_axis[0] - if edge_items is None: - # Don't need to slice. - return _truncate_part_with_slices( - array, - mask, - prefix_slices=prefix_slices + (slice(None),), - remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], - ) - else: - assert array.shape[axis] > 2 * edge_items - result_a, valid_a = _truncate_part_with_slices( - array, - mask, - prefix_slices=prefix_slices + (slice(None, edge_items),), - remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], - ) - result_b, valid_b = _truncate_part_with_slices( - array, - mask, - prefix_slices=prefix_slices + (slice(-edge_items, None),), - remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:], - ) - padding_shape = list(result_a.shape) - padding_shape[axis] = 1 - result = jnp.concatenate( - [result_a, jnp.zeros(padding_shape, result_a.dtype), result_b], - axis=axis, - ) - valid = jnp.concatenate( - [valid_a, jnp.zeros(padding_shape, valid_a.dtype), valid_b], axis=axis - ) - return result, valid - - -def truncate_array_and_mask( - array: jax.Array, - mask: jax.Array, - edge_items_per_axis: tuple[int | None, ...], -) -> tuple[jax.Array, jax.Array]: - """Truncates an array along the given axis names. - - Args: - array: Array to truncate. - mask: Mask array, which must have the same number of dimensions as `array`, - and whose axis sizes must be either 1 or the same as that axis of `array` - (e.g. they are broadcast compatible). 
- edge_items_per_axis: Number of edge items to keep for each axis, ignoring - any axes whose slices are already computed in `prefix_slices`. - - Returns: - A tuple containing a truncated version of the array along with a valid mask. - Values taken from the original array have the valid mask as True, and there - is one extra element in the middle with valid as False (standing in for the - omitted elements). The return value is always fully replicated, because - we cannot guarantee that it is evenly sharded across devices, and this - function is usually used immediately before copying to the host. - """ - sharding_kwargs = {} - if hasattr(array, "sharding") and hasattr( - array.sharding, "_device_assignment" - ): - # _truncate_part_with_slices usually returns slices that have odd - # dimensions, which aren't divisible by most shardings. Unfortunately, - # the XLA GSPMD partitioner sometimes still infers a sharding over one of - # these axes, which then leads to partitioning errors in JAX whenever we - # try to `device_get` the resulting array or call any additional operations - # on it. To avoid this, we'd like to tell JAX to always produce an output - # that is not sharded over any axis. Unfortunately, this is difficult - # because JAX requires the in_shardings and out_shardings to have the same - # devices in the same internal order, and at the time of writing JAX does - # not provide any public API to look up the order of the devices in a - # sharding (it allows looking up the device *set*, but not their order). - # Whether or not this error happens seems to be somewhat nondeterministic. - # To avoid this, we use the private property `_device_assignment` of - # each sharding in order to figure out what device order it has, and then - # explicitly request a fully-replicated output that is definitely safe to - # retrieve. 
- sharding_kwargs["out_shardings"] = ( - jax.sharding.GSPMDSharding.get_replicated( - array.sharding._device_assignment # pylint: disable=protected-access - ) - ) - fn = jax.jit( - _truncate_part_with_slices, static_argnums=(2, 3), **sharding_kwargs - ) - return fn(array, mask, (), edge_items_per_axis) - - -def infer_balanced_truncation( - shape: tuple[int, ...], - maximum_size: int, - cutoff_size_per_axis: int, - minimum_edge_items: int, - doubling_bonus: float = 10.0, -) -> tuple[int | None, ...]: - """Infers a balanced truncation from a shape. - - This function computes a set of truncation sizes for each axis of the array - such that it obeys the constraints about array and axis sizes, while also - keeping the relative proportions of the array consistent (e.g. we keep more - elements along axes that were originally longer). This means that the aspect - ratio of the truncated array will still resemble the aspect ratio of the - original array. - - To avoid very-unbalanced renderings and truncate longer axes more than short - ones, this function truncates based on the square-root of the axis size by - default. - - Args: - shape: The shape of the array we are truncating. - maximum_size: Maximum number of elements of an array to show. Arrays larger - than this will be truncated along one or more axes. - cutoff_size_per_axis: Maximum number of elements of each individual axis to - show without truncation. Any axis longer than this will be truncated, with - their visual size increasing logarithmically with the true axis size - beyond this point. - minimum_edge_items: How many values to keep along each axis for truncated - arrays. We may keep more than this up to the budget of maximum_size. - doubling_bonus: Number of elements to add to each axis each time it doubles - beyond `cutoff_size_per_axis`. Used to make longer axes appear visually - longer while still keeping them a reasonable size. - - Returns: - A tuple of edge sizes. 
Each element corresponds to an axis in `shape`, - and is either `None` (for no truncation) or an integer (corresponding to - the number of elements to keep at the beginning and and at the end). - """ - shape_arr = np.array(list(shape)) - remaining_elements_to_divide = maximum_size - edge_items_per_axis = {} - # Order our shape from smallest to largest, since the smallest axes will - # require the least amount of truncation and will have the most stringent - # constraints. - sorted_axes = np.argsort(shape_arr) - sorted_shape = shape_arr[sorted_axes] - - # Figure out maximum sizes based on the cutoff - cutoff_adjusted_maximum_sizes = np.where( - sorted_shape <= cutoff_size_per_axis, - sorted_shape, - cutoff_size_per_axis - + doubling_bonus * np.log2(sorted_shape / cutoff_size_per_axis), - ) - - # Suppose we want to make a scaled version of the array with relative - # axis sizes - # s0, s1, s2, ... - # The total size is then - # size = (c * s0) * (c * s1) * (c * s2) * ... - # log(size) = ndim * log(c) + [ log s0 + log s1 + log s2 + ... ] - # If we have a known final size we want to reach, we can solve for c as - # c = exp( (log size - [ log s0 + log s1 + log s2 + ... ]) / ndim ) - axis_proportions = np.sqrt(sorted_shape) - log_axis_proportions = np.log(axis_proportions) - for i in range(len(sorted_axes)): - original_axis = sorted_axes[i] - size = shape_arr[original_axis] - # If we truncated this axis and every axis after it proportional to - # their weights, how small of an axis size would we need for this - # axis? 
- log_c = ( - np.log(remaining_elements_to_divide) - np.sum(log_axis_proportions[i:]) - ) / (len(shape) - i) - soft_limit_for_this_axis = np.exp(log_c + log_axis_proportions[i]) - cutoff_limit_for_this_axis = np.floor( - np.minimum( - soft_limit_for_this_axis, - cutoff_adjusted_maximum_sizes[i], - ) - ) - if size <= 2 * minimum_edge_items + 1 or size <= cutoff_limit_for_this_axis: - # If this axis is already smaller than the minimum size it would have - # after truncation, there's no reason to truncate it. - # But pretend we did, so that other axes still grow monotonically if - # their axis sizes increase. - remaining_elements_to_divide = ( - remaining_elements_to_divide / soft_limit_for_this_axis - ) - edge_items_per_axis[original_axis] = None - elif cutoff_limit_for_this_axis < 2 * minimum_edge_items + 1: - # If this axis is big enough to truncate, but our naive target size is - # smaller than the minimum allowed truncation, we should truncate it - # to the minimum size allowed instead. - edge_items_per_axis[original_axis] = minimum_edge_items - remaining_elements_to_divide = remaining_elements_to_divide / ( - 2 * minimum_edge_items + 1 - ) - else: - # Otherwise, truncate it and all remaining axes based on our target - # truncations. - for j in range(i, len(sorted_axes)): - visual_size = np.floor( - np.minimum( - np.exp(log_c + log_axis_proportions[j]), - cutoff_adjusted_maximum_sizes[j], - ) - ) - edge_items_per_axis[sorted_axes[j]] = int(visual_size // 2) - break - - return tuple( - edge_items_per_axis[orig_axis] for orig_axis in range(len(shape)) - ) - - -def compute_truncated_shape( - shape: tuple[int, ...], - edge_items: tuple[int | None, ...], -) -> tuple[int, ...]: - """Computes the shape of a truncated array. - - This can be used to estimate the size of an array visualization after it has - been truncated by `infer_balanced_truncation`. - - Args: - shape: The original array shape. - edge_items: Number of edge items to keep along each axis. 
- - Returns: - The shape of the truncated array. - """ - return tuple( - orig if edge is None else 2 * edge + 1 - for orig, edge in zip(shape, edge_items) - ) - - -def faster_array_repr(array: np.ndarray | jax.Array) -> str: - """Computes ``repr(array)``, only copying the rendered array elements. - - ``repr(array)`` on a very large jax Array can be slow, because it copies the - entire array to host memory even when only a few elements are actually needed. - We can avoid this by truncating the array on device before fetching it. - - Args: - array: The array to summarize. - - Returns: - A string representation of the array. May differ slightly from the ordinary - ``repr``, but should contain the same elements. - """ - if isinstance(array, np.ndarray): - return repr(array) - else: - assert isinstance(array, jax.Array) - - if array.size < np.get_printoptions()["threshold"]: - return repr(array) - - if array.aval is not None and array.aval.weak_type: - dtype_str = f"dtype={array.dtype.name}, weak_type=True)" - else: - dtype_str = f"dtype={array.dtype.name})" - - edgeitems = np.get_printoptions()["edgeitems"] - edge_items_per_axis = [] - for size in array.shape: - if size > 2 * edgeitems + 1: - edge_items_per_axis.append(edgeitems) - else: - edge_items_per_axis.append(None) - array_edges, _ = truncate_array_and_mask( - array, - jnp.ones((1,) * array.ndim, dtype=jnp.bool_), - edge_items_per_axis=tuple(edge_items_per_axis), - ) - prefix = "Array(" - datastring = np.array2string( - np.array(array_edges), - prefix=prefix, - suffix=",", - separator=", ", - threshold=0, - edgeitems=edgeitems, - ) - return f"{prefix}{datastring}, {dtype_str}" diff --git a/penzai/treescope/repr_lib.py b/penzai/treescope/repr_lib.py index d754810..e8d3a0f 100644 --- a/penzai/treescope/repr_lib.py +++ b/penzai/treescope/repr_lib.py @@ -31,6 +31,7 @@ from penzai.treescope import renderer from penzai.treescope.foldable_representation import basic_parts from penzai.treescope.foldable_representation 
import common_structures +from penzai.treescope.foldable_representation import common_styles from penzai.treescope.foldable_representation import part_interface @@ -251,3 +252,49 @@ def __penzai_repr__(self, path, subtree_renderer): path=path, background_color=color, ) + + +def render_enumlike_item( + object_type: type[Any], + item_name: str, + item_value: Any, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + part_interface.RenderableTreePart + | part_interface.RenderableAndLineAnnotations +): + """Renders a value of an enum-like type (e.g. like `enum.Enum`). + + This function can be used to render a value of a type that acts like a Python + enum, in that there is a finite set of possible instances of the type, each of + which has a name and a value, and where the instance can be accessed as an + attribute (e.g. ``mytype.FOO`` is an instance of ``mytype`` with name "FOO"). + + Args: + object_type: The type of the object. + item_name: The name of the item. + item_value: The value of the item (``{object_type}.{item_name}.value``). + path: The path to the object. When `render_enumlike_item` is called + from `__treescope_repr__`, this should come from the `path` argument to + `__treescope_repr__`. + subtree_renderer: The renderer to use to render subtrees. When + `render_enumlike_item` is called from `__treescope_repr__`, this + should come from the `subtree_renderer` argument to `__treescope_repr__`. + + Returns: + A rendering of the object, suitable for returning from `__treescope_repr__`. + """ + del subtree_renderer + return common_structures.build_one_line_tree_node( + basic_parts.siblings_with_annotations( + common_structures.maybe_qualified_type_name(object_type), + "."
+ item_name, + extra_annotations=[ + common_styles.CommentColor( + basic_parts.Text(f" # value: {repr(item_value)}") + ) + ], + ), + path, + ) diff --git a/penzai/treescope/treescope_ipython.py b/penzai/treescope/treescope_ipython.py index 3c4161e..2c6e61b 100644 --- a/penzai/treescope/treescope_ipython.py +++ b/penzai/treescope/treescope_ipython.py @@ -17,8 +17,6 @@ import contextlib from typing import Any -import jax.numpy as jnp - from penzai.treescope import array_autovisualizer from penzai.treescope import autovisualize from penzai.treescope import context @@ -153,9 +151,7 @@ def _render_for_ipython(value): elif isinstance(value, ipython_display.DisplayObject) or ( object_inspection.safely_get_real_method(value, "_repr_pretty_") and not ( - object_inspection.safely_get_real_method( - value, "__penzai_repr__" - ) + object_inspection.safely_get_real_method(value, "__penzai_repr__") or object_inspection.safely_get_real_method( value, "__penzai_root_repr__" ) @@ -221,9 +217,12 @@ def _render_as_text_oneline(value, p, cycle): p.break_() p.text(line) + # Override the text formatter to render jax.Array without copying the entire + # array. cur_text_formatter = display_formatter.formatters["text/plain"] - arrayimpl = type(jnp.ones([0])) - cur_text_formatter.for_type(arrayimpl, _render_as_text_oneline) + cur_text_formatter.for_type_by_name( + "jaxlib.xla_extension", "ArrayImpl", _render_as_text_oneline + ) # Make sure the HTML formatter runs first, so streaming outputs work # correctly. diff --git a/penzai/treescope/type_registries.py b/penzai/treescope/type_registries.py new file mode 100644 index 0000000..2ba2054 --- /dev/null +++ b/penzai/treescope/type_registries.py @@ -0,0 +1,190 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Global registries for adding treescope support to external types. + +This module defines global registries which can be used to add treescope support +for new types that it does not natively support, or types defined in libraries +that may not be installed. + +These registries are intended to be used by either the module that defines the +objects being registered, or by `penzai.treescope` itself. If a type already has +a global registry entry, you should generally avoid modifying it. This is +because the registries are defined as global variables, without a mechanism for +resolving conflicts between multiple entries. If you would like to customize the +rendering of a type that treescope already supports, you should generally either +define your own treescope renderer object and use it directly, or override the +default renderer or autovisualizer defined in +`penzai.treescope.default_renderer` using the `set_scoped` and `set_interactive` +methods. This will take precedence over any global registry entries. +""" + +from __future__ import annotations + +import abc +import importlib +import sys +import types +from typing import Any, TypeVar + +from penzai.treescope import ndarray_adapters +from penzai.treescope import renderer + + +T = TypeVar("T") + + +NDARRAY_ADAPTER_REGISTRY: dict[ + type[Any], ndarray_adapters.NDArrayAdapter[Any] +] = {} +"""Global registry of NDArray adapters, keyed by type. + +The value for a given type should be an instance of `NDArrayAdapter`, and will +be used to render any arrays of that type. 
+ +If a type is not present in this registry, the entries of that type's `__mro__` +will also be searched. Additionally, virtual base classes will be checked if +the abstract base class is in `VIRTUAL_BASE_CLASSES`. +""" + +TREESCOPE_HANDLER_REGISTRY: dict[type[Any], renderer.TreescopeNodeHandler] = {} +"""Global registry of custom treescope handlers, keyed by type. + +If a type is not present in this registry, the entries of that type's `__mro__` +will also be searched. Additionally, virtual base classes will be checked if +the abstract base class is in `VIRTUAL_BASE_CLASSES`. + +The handler itself will be passed the object, and can either return a treescope +rendering or the `NotImplemented` sentinel, just like an ordinary treescope +handler. + +This registry is primarily intended to add treescope support to custom types +without requiring the type to be modified. If you can modify the type, you can +instead define the `__treescope_repr__` method on the type itself; this has +precedence over the registry. +""" + +VIRTUAL_BASE_CLASSES: list[type[abc.ABC]] = [] +"""List of abstract base classes that should be checked for virtual subclasses. + +This list should contain a list of abstract base classes that have virtual +subclasses (defined using the ``.register`` method), and which appear in the +global type registries `NDARRAY_ADAPTER_REGISTRY` or +`TREESCOPE_HANDLER_REGISTRY`. If a type is a subclass of one of these base +classes, the corresponding registry entry will be used. +""" + +IMMUTABLE_TYPES_REGISTRY: dict[type[Any], bool] = { + types.FunctionType: True, + types.MethodType: True, + types.ModuleType: True, + type: True, + type(None): True, + type(NotImplemented): True, + type(Ellipsis): True, +} +"""Global registry of non-hashable types that are considered immutable. + +By default, treescope will detect repeated values of any non-hashable type and +render a warning that they are shared across different parts of the tree.
This +is intended to help catch bugs in which a value is accidentally shared between +different parts of a tree, which could cause problems when the tree is mutated. + +Some types are not hashable, but are still immutable. For instance, `jax.Array` +is immutable and can be safely shared. This registry is used to suppress the +"shared value" warning for these types. +""" + +_LAZY_MODULE_SETUP_FUNCTIONS: dict[str, tuple[str, str]] = { + # Note: Numpy is always imported because it is used by the core array + # rendering system, but we define its setup function here as well for + # consistency with the other array modules. + "numpy": ( + "penzai.treescope.handlers.interop.numpy_support", + "set_up_treescope", + ), + "jax": ( + "penzai.treescope.handlers.interop.jax_support", + "set_up_treescope", + ), + "penzai.core": ( + "penzai.treescope.handlers.interop.penzai_core_support", + "set_up_treescope", + ), + "torch": ( + "penzai.treescope.handlers.interop.torch_support", + "set_up_treescope", + ), +} +"""Delayed setup functions that run only once a module is imported. + +This dictionary maps module name keys to a ``(setup_module, setup_attribute)`` +tuple, where ``setup_module`` is the name of a module and ``setup_attribute`` +is the name of a zero-argument function in that module that can be used to set +up support for the key module. + +When `update_registries_for_imports` is called (usually at the start of +rendering an object), if any of the key modules have already been imported, +the corresponding setup module will be imported as well +and the setup attribute will be called. This function can then modify the +global values `NDARRAY_ADAPTER_REGISTRY`, `TREESCOPE_HANDLER_REGISTRY`, +`IMMUTABLE_TYPES_REGISTRY`, or `VIRTUAL_BASE_CLASSES` to add support for this module. +It can also register the public API of the module using +`penzai.treescope.canonical_aliases`, if applicable. + +After being called, the setup function will be removed from this dictionary.
+""" + + +def update_registries_for_imports(): + """Updates registries by running setup logic for newly-imported modules.""" + for module_name, (setup_module, setup_attribute) in tuple( + _LAZY_MODULE_SETUP_FUNCTIONS.items() + ): + if module_name in sys.modules: + module = importlib.import_module(setup_module) + setup_fn = getattr(module, setup_attribute) + setup_fn() + del _LAZY_MODULE_SETUP_FUNCTIONS[module_name] + + +def lookup_by_mro( + registry: dict[type[Any], T], candidate_type: type[Any] +) -> T | None: + """Looks up the given type in the given registry, or in its base classes. + + This function looks up the given type in the given registry, or in the + registry for any base class of the given type, in method resolution order. + It does not run any lazy setup functions itself; call + `update_registries_for_imports` beforehand if that is needed. + + If no concrete base class is found in the registry, each of the entries of + `VIRTUAL_BASE_CLASSES` will be checked to see if it is a virtual base class. + The first such base class that has an entry in the registry will be used. + + Args: + registry: The registry to look up in. + candidate_type: The type to look up. + + Returns: + The value associated with the given type (or a base class of it) in the + given registry, or None if no entry was found.
+ """ + for supertype in candidate_type.__mro__: + if supertype in registry: + return registry[supertype] + for base_class in VIRTUAL_BASE_CLASSES: + if issubclass(candidate_type, base_class) and base_class in registry: + return registry[base_class] + return None diff --git a/tests/treescope_canonical_aliases_test.py b/tests/treescope/canonical_aliases_test.py similarity index 98% rename from tests/treescope_canonical_aliases_test.py rename to tests/treescope/canonical_aliases_test.py index 2d626fc..6bafeb4 100644 --- a/tests/treescope_canonical_aliases_test.py +++ b/tests/treescope/canonical_aliases_test.py @@ -22,14 +22,15 @@ import numpy import penzai.core.named_axes import penzai.core.struct -from tests import fixtures as fixture_parent -from tests.fixtures import treescope_examples_fixture as fixture_lib +from tests.treescope import fixtures as fixture_parent +from tests.treescope.fixtures import treescope_examples_fixture as fixture_lib from penzai.treescope import canonical_aliases +from penzai.treescope import type_registries def fresh_canonical_aliases(): return canonical_aliases._alias_environment.set_scoped( - canonical_aliases.CanonicalAliasEnvironment({}, []) + canonical_aliases.CanonicalAliasEnvironment({}) ) @@ -420,7 +421,7 @@ def test_local_aliases(self): ), ) def test_default_canonical_aliases(self, target, alias_string): - canonical_aliases.update_lazy_aliases() + type_registries.update_registries_for_imports() self.assertEqual(str(canonical_aliases.lookup_alias(target)), alias_string) diff --git a/tests/fixtures/__init__.py b/tests/treescope/fixtures/__init__.py similarity index 100% rename from tests/fixtures/__init__.py rename to tests/treescope/fixtures/__init__.py diff --git a/tests/fixtures/treescope_examples_fixture.py b/tests/treescope/fixtures/treescope_examples_fixture.py similarity index 86% rename from tests/fixtures/treescope_examples_fixture.py rename to tests/treescope/fixtures/treescope_examples_fixture.py index eefa52a..aedac02 
100644 --- a/tests/fixtures/treescope_examples_fixture.py +++ b/tests/treescope/fixtures/treescope_examples_fixture.py @@ -27,6 +27,7 @@ import jax from penzai import pz +import torch class MyTestEnum(enum.Enum): @@ -196,3 +197,27 @@ def output_structure(self): @pz.checked_layer_call def __call__(self, value: int) -> int: return value + + +class SomePyTorchModule(torch.nn.Module): + """A basic PyTorch module to test rendering.""" + + def __init__(self): + super().__init__() + # Attributes + self.attr_one = 123 + self.attr_two = "abc" + # Child modules + self.linear = torch.nn.Linear(10, 10) + self.mod_list = torch.nn.ModuleList( + [torch.nn.LayerNorm(10), torch.nn.SiLU()] + ) + # Parameters + self.foo = torch.nn.Parameter(torch.ones(5)) + # Buffers + self.register_buffer("bar", torch.zeros(5)) + + @classmethod + def build(cls): + torch.random.manual_seed(1234) + return cls() diff --git a/tests/treescope/ndarray_adapters_test.py b/tests/treescope/ndarray_adapters_test.py new file mode 100644 index 0000000..04440f6 --- /dev/null +++ b/tests/treescope/ndarray_adapters_test.py @@ -0,0 +1,285 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for NDArray adapters and array visualization.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax.numpy as jnp +import numpy as np +from penzai.core import named_axes +from penzai.treescope import array_autovisualizer +from penzai.treescope import arrayviz +from penzai.treescope import autovisualize +from penzai.treescope import default_renderer +from penzai.treescope import ndarray_adapters +from penzai.treescope import type_registries +import torch + + +class NdarrayAdaptersTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + type_registries.update_registries_for_imports() + + @parameterized.product( + array_type=["jax", "torch", "NamedArray", "NamedArrayView"], + dtype=[np.int32, np.float32, np.bool_], + ) + def test_adapter_positional_numpy_consistency(self, array_type, dtype): + + reshaped_arange = np.arange(19 * 23).reshape((19, 23)) + + if dtype == np.bool_: + array_np = (reshaped_arange % 2) == 0 + else: + array_np = reshaped_arange.astype(dtype) + + mask_np = (reshaped_arange % 3) != 0 + + if array_type == "jax": + array = jnp.array(array_np) + mask = jnp.array(mask_np) + elif array_type == "torch": + array = torch.tensor(array_np) + mask = torch.tensor(mask_np) + elif array_type == "NamedArray": + array = named_axes.wrap(array_np) + mask = ( + named_axes.wrap(mask_np.transpose((1, 0))) + .tag("a", "b") + .untag("b", "a") + ) + assert array.positional_shape == mask.positional_shape + elif array_type == "NamedArrayView": + array = named_axes.wrap(array_np).as_namedarrayview() + mask = ( + named_axes.wrap(mask_np.transpose((1, 0))) + .tag("a", "b") + .untag("b", "a") + ).with_positional_prefix() + assert array.positional_shape == mask.positional_shape + else: + raise ValueError(f"Unsupported array_type: {array_type}") + + np_adapter = type_registries.lookup_by_mro( + type_registries.NDARRAY_ADAPTER_REGISTRY, np.ndarray + ) + self.assertIsNotNone(np_adapter) + + cur_adapter = 
type_registries.lookup_by_mro( + type_registries.NDARRAY_ADAPTER_REGISTRY, type(array) + ) + self.assertIsNotNone(cur_adapter) + + with self.subTest("axis_info"): + self.assertEqual( + np_adapter.get_axis_info_for_array_data(array_np), + cur_adapter.get_axis_info_for_array_data(array), + ) + + with self.subTest("data_untruncated_unmasked"): + data_out_np, mask_out_np = np_adapter.get_array_data_with_truncation( + array=array_np, mask=None, edge_items_per_axis=(None, None) + ) + data_out, mask_out = cur_adapter.get_array_data_with_truncation( + array=array, mask=None, edge_items_per_axis=(None, None) + ) + np.testing.assert_array_equal(data_out_np, data_out) + np.testing.assert_array_equal(mask_out_np, mask_out) + + with self.subTest("data_untruncated_masked"): + data_out_np, mask_out_np = np_adapter.get_array_data_with_truncation( + array=array_np, mask=mask_np, edge_items_per_axis=(None, None) + ) + data_out, mask_out = cur_adapter.get_array_data_with_truncation( + array=array, mask=mask, edge_items_per_axis=(None, None) + ) + np.testing.assert_array_equal(data_out_np, data_out) + np.testing.assert_array_equal(mask_out_np, mask_out) + + with self.subTest("data_semitruncated_unmasked"): + data_out_np, mask_out_np = np_adapter.get_array_data_with_truncation( + array=array_np, mask=None, edge_items_per_axis=(2, None) + ) + data_out, mask_out = cur_adapter.get_array_data_with_truncation( + array=array, mask=None, edge_items_per_axis=(2, None) + ) + np.testing.assert_array_equal(data_out_np, data_out) + np.testing.assert_array_equal(mask_out_np, mask_out) + + with self.subTest("data_semitruncated_unmasked"): + data_out_np, mask_out_np = np_adapter.get_array_data_with_truncation( + array=array_np, mask=mask_np, edge_items_per_axis=(2, None) + ) + data_out, mask_out = cur_adapter.get_array_data_with_truncation( + array=array, mask=mask, edge_items_per_axis=(2, None) + ) + np.testing.assert_array_equal(data_out_np, data_out) + np.testing.assert_array_equal(mask_out_np, 
mask_out) + + with self.subTest("data_truncated_unmasked"): + data_out_np, mask_out_np = np_adapter.get_array_data_with_truncation( + array=array_np, mask=None, edge_items_per_axis=(2, 4) + ) + data_out, mask_out = cur_adapter.get_array_data_with_truncation( + array=array, mask=None, edge_items_per_axis=(2, 4) + ) + np.testing.assert_array_equal(data_out_np, data_out) + np.testing.assert_array_equal(mask_out_np, mask_out) + + with self.subTest("data_truncated_unmasked"): + data_out_np, mask_out_np = np_adapter.get_array_data_with_truncation( + array=array_np, mask=mask_np, edge_items_per_axis=(2, 4) + ) + data_out, mask_out = cur_adapter.get_array_data_with_truncation( + array=array, mask=mask, edge_items_per_axis=(2, 4) + ) + np.testing.assert_array_equal(data_out_np, data_out) + np.testing.assert_array_equal(mask_out_np, mask_out) + + with self.subTest("data_truncated_broadcast_mask"): + data_out_np, mask_out_np = np_adapter.get_array_data_with_truncation( + array=array_np, mask=mask_np[0, :], edge_items_per_axis=(3, 7) + ) + data_out, mask_out = cur_adapter.get_array_data_with_truncation( + array=array, mask=mask[0, :], edge_items_per_axis=(3, 7) + ) + np.testing.assert_array_equal(data_out_np, data_out) + np.testing.assert_array_equal(mask_out_np, mask_out) + + for fast in (True, False): + with self.subTest("summary_fast" if fast else "summary_slow"): + summary_info_np = np_adapter.get_array_summary(array_np, fast=True) + summary_info = cur_adapter.get_array_summary(array, fast=True) + if array_type in ("NamedArray", "NamedArrayView"): + summary_info_np = ( + summary_info_np.replace("(19, 23)", "(19, 23 |)") + + " (wrapping jax.Array)" + ) + self.assertEqual( + summary_info_np.split(" ", 1)[1], summary_info.split(" ", 1)[1] + ) + + def test_penzai_named_axes_info(self): + adapter = type_registries.lookup_by_mro( + type_registries.NDARRAY_ADAPTER_REGISTRY, named_axes.NamedArray + ) + data = np.arange(19 * 23).reshape((19, 23)) + + array = 
named_axes.wrap(data).tag("foo", "bar") + self.assertEqual( + adapter.get_axis_info_for_array_data(array), + ( + ndarray_adapters.NamedPositionlessAxisInfo("foo", 19), + ndarray_adapters.NamedPositionlessAxisInfo("bar", 23), + ), + ) + + array = named_axes.wrap(data).tag("foo", "bar").untag("bar") + self.assertEqual( + adapter.get_axis_info_for_array_data(array), + ( + ndarray_adapters.NamedPositionlessAxisInfo("foo", 19), + ndarray_adapters.PositionalAxisInfo(0, 23), + ), + ) + + array = named_axes.wrap(data).tag("foo", "bar").untag("bar", "foo") + self.assertEqual( + adapter.get_axis_info_for_array_data(array), + ( + ndarray_adapters.PositionalAxisInfo(1, 19), + ndarray_adapters.PositionalAxisInfo(0, 23), + ), + ) + + def test_pytorch_named_axes_info(self): + adapter = type_registries.lookup_by_mro( + type_registries.NDARRAY_ADAPTER_REGISTRY, torch.Tensor + ) + data = np.arange(19 * 23).reshape((19, 23)) + array = torch.tensor(data).rename("foo", None) + self.assertEqual( + adapter.get_axis_info_for_array_data(array), + ( + ndarray_adapters.NamedPositionalAxisInfo(0, "foo", 19), + ndarray_adapters.PositionalAxisInfo(1, 23), + ), + ) + + @parameterized.product( + array_type=[ + "numpy", + "jax", + "torch", + "NamedArray:positional", + "NamedArray:named", + "NamedArrayView", + ], + dtype=[np.int32, np.float32, np.bool_], + ) + def test_array_rendering_without_error(self, array_type, dtype): + reshaped_arange = np.arange(19 * 23).reshape((19, 23)) + + if dtype == np.bool_: + array_np = (reshaped_arange % 2) == 0 + else: + array_np = reshaped_arange.astype(dtype) + + if array_type == "numpy": + array = array_np + elif array_type == "jax": + array = jnp.array(array_np) + elif array_type == "torch": + array = torch.tensor(array_np) + elif array_type == "NamedArray:positional": + array = named_axes.wrap(array_np) + elif array_type == "NamedArray:named": + array = named_axes.wrap(array_np).tag("a", "b") + elif array_type == "NamedArrayView": + array = ( + 
named_axes.wrap(array_np.transpose((1, 0))).tag("a", "b").untag("b") + ) + else: + raise ValueError(f"Unsupported array_type: {array_type}") + + with self.subTest("explicit_unmasked"): + res = arrayviz.render_array(array) + self.assertIsInstance(res, arrayviz.ArrayvizRendering) + + with self.subTest("explicit_masked"): + res = arrayviz.render_array(array, valid_mask=array > 100) + self.assertIsInstance(res, arrayviz.ArrayvizRendering) + + with self.subTest("explicit_masked_truncated"): + res = arrayviz.render_array( + array, valid_mask=array > 100, truncate=True, maximum_size=100 + ) + self.assertIsInstance(res, arrayviz.ArrayvizRendering) + + with self.subTest("automatic"): + with autovisualize.active_autovisualizer.set_scoped( + array_autovisualizer.ArrayAutovisualizer() + ): + res = default_renderer.render_to_html( + array, ignore_exceptions=False, compressed=False + ) + self.assertIsInstance(res, str) + self.assertIn("arrayviz", res) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/treescope_renderer_test.py b/tests/treescope/renderer_test.py similarity index 87% rename from tests/treescope_renderer_test.py rename to tests/treescope/renderer_test.py index d0644b5..429dcd6 100644 --- a/tests/treescope_renderer_test.py +++ b/tests/treescope/renderer_test.py @@ -32,7 +32,7 @@ from penzai.core._treescope_handlers import selection_rendering import penzai.core.selectors import penzai.core.struct -from tests.fixtures import treescope_examples_fixture as fixture_lib +from tests.treescope.fixtures import treescope_examples_fixture as fixture_lib from penzai.treescope import autovisualize from penzai.treescope import default_renderer from penzai.treescope.foldable_representation import basic_parts @@ -40,6 +40,7 @@ from penzai.treescope.foldable_representation import layout_algorithms from penzai.treescope.foldable_representation import part_interface from penzai.treescope.handlers import function_reflection_handlers +import torch 
@dataclasses.dataclass @@ -162,7 +163,7 @@ def hook_that_crashes(node, path, node_renderer): target="some string\n with \n newlines in it", expected_collapsed="'some string\\n with \\n newlines in it'", expected_expanded=( - " 'some string\\n'\n ' with \\n'\n ' newlines in it'\n" + " 'some string\\n'\n ' with \\n'\n ' newlines in it'" ), ), dict( @@ -179,8 +180,8 @@ def hook_that_crashes(node, path, node_renderer): ]"""), expected_roundtrip=textwrap.dedent("""\ [ - tests.fixtures.treescope_examples_fixture.MyTestEnum.FOO, # value: 1 - tests.fixtures.treescope_examples_fixture.MyTestEnum.BAR, # value: 2 + tests.treescope.fixtures.treescope_examples_fixture.MyTestEnum.FOO, # value: 1 + tests.treescope.fixtures.treescope_examples_fixture.MyTestEnum.BAR, # value: 2 ]"""), ), dict( @@ -319,7 +320,7 @@ def hook_that_crashes(node, path, node_renderer): bar='qux', )"""), expected_roundtrip=textwrap.dedent("""\ - tests.fixtures.treescope_examples_fixture.SomeNamedtupleClass( + tests.treescope.fixtures.treescope_examples_fixture.SomeNamedtupleClass( foo='baz', bar='qux', )"""), @@ -334,7 +335,7 @@ def hook_that_crashes(node, path, node_renderer): bar='qux', )"""), expected_roundtrip=textwrap.dedent("""\ - tests.fixtures.treescope_examples_fixture.DataclassWithTwoChildren( + tests.treescope.fixtures.treescope_examples_fixture.DataclassWithTwoChildren( foo='baz', bar='qux', )"""), @@ -345,7 +346,7 @@ def hook_that_crashes(node, path, node_renderer): expected_collapsed="EmptyDataclass()", expected_expanded="EmptyDataclass()", expected_roundtrip=( - "tests.fixtures.treescope_examples_fixture.EmptyDataclass()" + "tests.treescope.fixtures.treescope_examples_fixture.EmptyDataclass()" ), ), dict( @@ -357,28 +358,27 @@ def hook_that_crashes(node, path, node_renderer): foo=100, )"""), expected_roundtrip=textwrap.dedent("""\ - tests.fixtures.treescope_examples_fixture.ExampleLayer( + tests.treescope.fixtures.treescope_examples_fixture.ExampleLayer( foo=100, )"""), ), dict( 
testcase_name="ndarray_small", target=np.array([1, 2, 4, 8, 16]), - expected_collapsed="", - expected_expanded="", + expected_collapsed="np.array([ 1, 2, 4, 8, 16])", + expected_expanded="np.array([ 1, 2, 4, 8, 16])", ), dict( testcase_name="ndarray_large", target=np.arange(3 * 7).reshape((3, 7)), expected_collapsed=( - "" + "" ), expected_expanded=textwrap.dedent("""\ - # numpy.ndarray int64(3, 7) [≥0, ≤20] zero:1 nonzero:20 + # np.ndarray int64(3, 7) [≥0, ≤20] zero:1 nonzero:20 array([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13], - [14, 15, 16, 17, 18, 19, 20]]) - """), + [14, 15, 16, 17, 18, 19, 20]])"""), ), dict( testcase_name="jax_array_large", @@ -390,8 +390,7 @@ def hook_that_crashes(node, path, node_renderer): # jax.Array int32(3, 7) [≥0, ≤20] zero:1 nonzero:20 Array([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13], - [14, 15, 16, 17, 18, 19, 20]], dtype=int32) - """), + [14, 15, 16, 17, 18, 19, 20]], dtype=int32)"""), ), dict( testcase_name="named_array_jax", @@ -409,22 +408,6 @@ def hook_that_crashes(node, path, node_renderer): data_array=, )"""), ), - dict( - testcase_name="named_array_np", - target_builder=lambda: penzai.core.named_axes.NamedArray( - named_axes=collections.OrderedDict({"bar": 5, "baz": 7}), - data_array=jnp.arange(3 * 5 * 7).reshape((3, 5, 7)), - ), - expected_collapsed=( - "" - ), - expected_expanded=textwrap.dedent("""\ - NamedArray( # int32(3 | bar:5, baz:7) [≥0, ≤104] zero:1 nonzero:104 - named_axes=OrderedDict({'bar': 5, 'baz': 7}), - data_array=, - )"""), - ), dict( testcase_name="named_array_view_jax", target_builder=lambda: penzai.core.named_axes.NamedArrayView( @@ -445,6 +428,24 @@ def hook_that_crashes(node, path, node_renderer): data_array=, )"""), ), + dict( + testcase_name="pytorch_tensor_small", + target_builder=lambda: torch.tensor(np.array([1, 2, 4, 8, 16])), + expected_collapsed="torch.tensor([ 1, 2, 4, 8, 16])", + expected_expanded="torch.tensor([ 1, 2, 4, 8, 16])", + ), + dict( + 
testcase_name="pytorch_tensor_large", + target_builder=lambda: torch.tensor(np.arange(3 * 7).reshape((3, 7))), + expected_collapsed=( + "" + ), + expected_expanded=textwrap.dedent("""\ + # torch.Tensor int64(3, 7) [≥0, ≤20] zero:1 nonzero:20 + tensor([[ 0, 1, 2, 3, 4, 5, 6], + [ 7, 8, 9, 10, 11, 12, 13], + [14, 15, 16, 17, 18, 19, 20]])"""), + ), dict( testcase_name="well_known_function", target=default_renderer.render_to_text, @@ -482,13 +483,13 @@ def hook_that_crashes(node, path, node_renderer): testcase_name="dtype_standard", target=np.dtype(np.float32), expected_collapsed="dtype('float32')", - expected_roundtrip_collapsed="numpy.dtype('float32')", + expected_roundtrip_collapsed="np.dtype('float32')", ), dict( testcase_name="dtype_extended", target=np.dtype(jnp.bfloat16), expected_collapsed="dtype('bfloat16')", - expected_roundtrip_collapsed="numpy.dtype('bfloat16')", + expected_roundtrip_collapsed="np.dtype('bfloat16')", ), dict( testcase_name="jax_precision", @@ -628,7 +629,7 @@ def hook_that_crashes(node, path, node_renderer): # # Output: penzai.core.shapecheck.Wildcard('output from body') # #╰┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄╯ handler_id='foo', - body=tests.fixtures.treescope_examples_fixture.LayerThatHoldsStuff( + body=tests.treescope.fixtures.treescope_examples_fixture.LayerThatHoldsStuff( # #╭┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄╮ # # Input: { # 'input': penzai.core.shapecheck.ArraySpec(shape=(1, 2, 3), dtype=numpy.generic, named_shape={}), @@ -647,6 +648,80 @@ def hook_that_crashes(node, path, node_renderer): ], )"""), ), + dict( + testcase_name="pytorch_module", + target_builder=fixture_lib.SomePyTorchModule.build, + expected_collapsed=( + "SomePyTorchModule(attr_one=123, attr_two='abc', training=True," + " foo=, bar=torch.tensor([0., 0., 0., 0., 0.])," + " linear=Linear(in_features=10, out_features=10, training=True," + " weight=, bias=, )," + " mod_list=ModuleList(training=True, (0):" + " LayerNorm(normalized_shape=(10,), eps=1e-05," + " 
elementwise_affine=True, training=True," + " weight=, bias=, ), (1): SiLU(inplace=False," + " training=True, ), ), )" + ), + expected_expanded=textwrap.dedent("""\ + SomePyTorchModule( + attr_one=123, attr_two='abc', training=True, + # Parameters: + foo=, + # Buffers: + bar=torch.tensor([0., 0., 0., 0., 0.]), + # Child modules: + linear=Linear(in_features=10, out_features=10, training=True, weight=, bias=, ), + mod_list=ModuleList(training=True, (0): LayerNorm(normalized_shape=(10,), eps=1e-05, elementwise_affine=True, training=True, weight=, bias=, ), (1): SiLU(inplace=False, training=True, ), ), + )"""), + expected_roundtrip=textwrap.dedent("""\ + , + # Buffers: + bar=torch.tensor([0., 0., 0., 0., 0.]), + # Child modules: + linear=, bias=, )>, + mod_list=, bias=, )>, (1): , )>, + )>"""), + ), + dict( + testcase_name="pytorch_module_expanded", + target_builder=fixture_lib.SomePyTorchModule.build, + expand_depth=2, + expected_expanded=textwrap.dedent("""\ + SomePyTorchModule( + # Attributes: + attr_one=123, + attr_two='abc', + training=True, + # Parameters: + foo=# torch.nn.Parameter float32(5,) ≈1.0 ±0.0 [≥1.0, ≤1.0] nonzero:5 + Parameter containing: + tensor([1., 1., 1., 1., 1.], requires_grad=True) + , + # Buffers: + bar=torch.tensor([0., 0., 0., 0., 0.]), + # Child modules: + linear=Linear( + in_features=10, out_features=10, training=True, + # Parameters: + weight=, + bias=, + ), + mod_list=ModuleList( + training=True, + # Child modules: + (0): LayerNorm(normalized_shape=(10,), eps=1e-05, elementwise_affine=True, training=True, weight=, bias=, ), + (1): SiLU(inplace=False, training=True, ), + ), + )"""), + ), ) def test_object_rendering( self, @@ -741,7 +816,7 @@ def inner_fn(y): "{'x': 100}", "# Defined at line ", " of ", - "tests/treescope_renderer_test.py", + "tests/treescope/renderer_test.py", ], foldable_impl.render_to_text_as_root(rendering), ) @@ -857,12 +932,12 @@ def test_fallback_repr_basic(self): self.assertContainsInOrder( [ ( - 
"penzai.treescope.copypaste_fallback.NotRoundtrippable(original_repr='