Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor array rendering, add type registries, add PyTorch renderer. #65

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions penzai/core/_treescope_handlers/named_axes_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@
import numpy as np
from penzai.core import named_axes
from penzai.core._treescope_handlers import struct_handler
from penzai.treescope import ndarray_summarization
from penzai.treescope import dtype_util
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
from penzai.treescope.foldable_representation import common_structures
from penzai.treescope.foldable_representation import common_styles
from penzai.treescope.foldable_representation import foldable_impl
from penzai.treescope.foldable_representation import part_interface
from penzai.treescope.handlers import builtin_structure_handler
from penzai.treescope.handlers.interop import jax_support


def named_array_and_contained_type_summary(
Expand Down Expand Up @@ -59,7 +60,7 @@ def named_array_and_contained_type_summary(

# Give a short summary for our named arrays.
summary_parts = []
summary_parts.append(ndarray_summarization.get_dtype_name(named_array.dtype))
summary_parts.append(dtype_util.get_dtype_name(named_array.dtype))
summary_parts.append("(")
for i, size in enumerate(named_array.positional_shape):
if i:
Expand All @@ -79,13 +80,13 @@ def named_array_and_contained_type_summary(
summary_parts.append(f"{name}:{size}")
summary_parts.append(")")

if inspect_device_data and ndarray_summarization.safe_to_summarize(
named_array.data_array
if (
inspect_device_data
and isinstance(named_array.data_array, jax.Array)
and jax_support.safe_to_summarize(named_array.data_array)
):
summary_parts.append(
ndarray_summarization.summarize_ndarray(
named_array.data_array, include_shape_and_dtype=False
)
jax_support.summarize_array_data(named_array.data_array)
)

return "".join(summary_parts), contained_type
Expand Down
4 changes: 2 additions & 2 deletions penzai/core/_treescope_handlers/shapecheck_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import numpy as np
from penzai.core import shapecheck
from penzai.treescope import dtype_util
from penzai.treescope import html_escaping
from penzai.treescope import ndarray_summarization
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
from penzai.treescope.foldable_representation import common_structures
Expand Down Expand Up @@ -83,7 +83,7 @@ def _arraystructure_summary(
if structure.dtype is np.generic:
summary_parts.append("any")
else:
summary_parts.append(ndarray_summarization.get_dtype_name(structure.dtype))
summary_parts.append(dtype_util.get_dtype_name(structure.dtype))
summary_parts.append("(")
for i, dim in enumerate(structure.shape):
if i:
Expand Down
31 changes: 18 additions & 13 deletions penzai/core/named_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,29 +900,35 @@ def order_like(

Args:
other: Another named array or named array view. Must have the same set of
named axes as this one. If this is a `NamedArrayView`, must also have
the same positional axes.
named axes as ``self``. If ``other`` is a `NamedArrayView`, ``other``
must also have the same number of positional axes.

Returns:
A new `NamedArray` or `NamedArrayView` that has the content of ``self``
but is possibly transposed to have the same PyTree structure as ``other``
(as long as the arrays have the same shape).
but is possibly transposed to have the axes appear in the same order as
``other`` in the data array. If the arrays have the same named and
positional shapes, the result will have the same PyTree structure as
``other``.
"""
self.check_valid()
other.check_valid()
if isinstance(other, NamedArray):
return self.order_as(*other.named_shape.keys())
elif isinstance(other, NamedArrayView):
if (
self.positional_shape != other.positional_shape
or self.named_shape != other.named_shape
):
if len(self.positional_shape) != len(other.positional_shape):
raise ValueError(
"Calling `order_like` with a NamedArrayView requires the two"
" arrays have the same positional and named shapes."
f" {self.positional_shape=}, {self.named_shape=},"
f" {other.positional_shape=}, {other.named_shape=}"
" arrays to have the same number of positional axes, but got"
f" positional shapes {self.positional_shape=},"
f" {other.positional_shape=}"
)
if set(self.named_shape.keys()) != set(other.named_shape.keys()):
raise ValueError(
"Calling `order_like` with a NamedArrayView requires the two"
" arrays to have the axis names, but got"
f" named shapes {self.named_shape=}, {other.named_shape=}"
)

self_view = self.as_namedarrayview()
new_to_old_data_axis = {}
for old_data_axis, new_data_axis in zip(
Expand All @@ -935,9 +941,8 @@ def order_like(
self_view.data_array,
[new_to_old_data_axis[i] for i in range(self_view.data_array.ndim)],
)
assert new_data_array.shape == other.data_shape
return NamedArrayView(
data_shape=other.data_shape,
data_shape=new_data_array.shape,
data_axis_for_logical_axis=other.data_axis_for_logical_axis,
data_axis_for_name=other.data_axis_for_name,
data_array=new_data_array,
Expand Down
1 change: 0 additions & 1 deletion penzai/pz/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
render_array,
text_on_color,
render_array_sharding,
render_sharded_shape,
)
from penzai.treescope.autovisualize import (
Autovisualizer,
Expand Down
2 changes: 2 additions & 0 deletions penzai/treescope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from . import handlers
from . import html_encapsulation
from . import html_escaping
from . import ndarray_adapters
from . import renderer
from . import repr_lib
from . import treescope_ipython
from . import type_registries
Loading
Loading