Skip to content

Commit

Permalink
Refactor treescope to mostly not depend on Penzai core.
Browse files Browse the repository at this point in the history
This change moves around parts of treescope and Penzai to reduce the dependencies
of treescope on the rest of Penzai core. This is the first step toward allowing
treescope to be installed and imported separately.

To support this, the IPython integration now looks for a special method
`__penzai_root_repr__` that allows types to customize their root representations.
Additionally, the `context` and `dataclass_util` modules have been moved to treescope
instead of Penzai core.

PiperOrigin-RevId: 641747922
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Jul 17, 2024
1 parent e750334 commit c8a2804
Show file tree
Hide file tree
Showing 35 changed files with 132 additions and 103 deletions.
2 changes: 0 additions & 2 deletions penzai/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

"""Penzai core types and functions."""

from . import context
from . import dataclass_util
from . import layer
from . import named_axes
from . import partitioning
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
import dataclasses
from typing import Any

from penzai.core import context
from penzai.core import layer
from penzai.core import shapecheck
from penzai.core._treescope_handlers import struct_handler
from penzai.data_effects import effect_base
from penzai.nn import grouping
from penzai.treescope import context
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
from penzai.treescope.foldable_representation import common_structures
Expand All @@ -33,7 +34,6 @@
from penzai.treescope.foldable_representation import part_interface
from penzai.treescope.handlers import builtin_structure_handler
from penzai.treescope.handlers import shared_value_postprocessor
from penzai.treescope.handlers.penzai import struct_handler


_already_seen_layer: context.ContextualValue[bool] = context.ContextualValue(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax
import numpy as np
from penzai.core import named_axes
from penzai.core._treescope_handlers import struct_handler
from penzai.treescope import ndarray_summarization
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
Expand All @@ -30,7 +31,6 @@
from penzai.treescope.foldable_representation import foldable_impl
from penzai.treescope.foldable_representation import part_interface
from penzai.treescope.handlers import builtin_structure_handler
from penzai.treescope.handlers.penzai import struct_handler


def named_array_and_contained_type_summary(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from typing import Any

import jax
from penzai.core import context
from penzai.core import selectors
from penzai.treescope import context
from penzai.treescope import default_renderer
from penzai.treescope import html_escaping
from penzai.treescope import renderer
Expand Down Expand Up @@ -69,6 +69,11 @@ class SelectedNodeTracker:
"""


def is_rendering_a_selection() -> bool:
"""Returns whether we are currently rendering a selection."""
return _selected_nodes.get() is not None


@dataclasses.dataclass(frozen=True)
class SelectionBoundaryTag:
"""A tag that can be used to identify selected nodes."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import dataclasses
from typing import Any, Callable

from penzai.core import dataclass_util
from penzai.core import struct
from penzai.treescope import dataclass_util
from penzai.treescope import html_escaping
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
Expand Down
2 changes: 1 addition & 1 deletion penzai/core/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init_subclass__(cls, **kwargs):
)

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import layer_handler # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import layer_handler # pylint: disable=g-import-not-at-top

return layer_handler.handle_layers(self, path, subtree_renderer)

Expand Down
2 changes: 1 addition & 1 deletion penzai/core/named_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ def __iter__(self):
# Rendering
def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
"""Treescope handler for named arrays."""
from penzai.treescope.handlers.penzai import named_axes_handlers # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import named_axes_handlers # pylint: disable=g-import-not-at-top

return named_axes_handlers.handle_named_arrays(self, path, subtree_renderer)

Expand Down
12 changes: 10 additions & 2 deletions penzai/core/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,7 @@ def show_selection(self):
This method should only be used when IPython is available.
"""
# Import selection_rendering here to avoid a circular import.
from penzai.treescope import selection_rendering # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import selection_rendering # pylint: disable=g-import-not-at-top

selection_rendering.display_selection_streaming(
self, visible_selection=True
Expand All @@ -1209,7 +1209,7 @@ def show_value(self):
This method should only be used when IPython is available.
"""
# Import selection_rendering here to avoid a circular import.
from penzai.treescope import selection_rendering # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import selection_rendering # pylint: disable=g-import-not-at-top

selection_rendering.display_selection_streaming(
self, visible_selection=False
Expand Down Expand Up @@ -1299,6 +1299,14 @@ def apply_with_selected_index(
else:
return new_selection.deselect()

def __penzai_root_repr__(self):
"""Renders this selection as the root object in a treescope rendering."""
from penzai.core._treescope_handlers import selection_rendering # pylint: disable=g-import-not-at-top

return selection_rendering.render_selection_to_foldable_representation(
self, visible_selection=True, ignore_exceptions=True
)


@contextlib.contextmanager
def _wrap_selection_errors(selection: "Selection"):
Expand Down
4 changes: 2 additions & 2 deletions penzai/core/shapecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Wildcard(struct.Struct):
)

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import shapecheck_handlers # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import shapecheck_handlers # pylint: disable=g-import-not-at-top

return shapecheck_handlers.handle_arraystructures(
self, path, subtree_renderer
Expand Down Expand Up @@ -358,7 +358,7 @@ def into_pytree(self) -> jax.ShapeDtypeStruct | named_axes.NamedArray:
return jax.ShapeDtypeStruct(self.shape, self.dtype)

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import shapecheck_handlers # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import shapecheck_handlers # pylint: disable=g-import-not-at-top

return shapecheck_handlers.handle_arraystructures(
self, path, subtree_renderer
Expand Down
19 changes: 16 additions & 3 deletions penzai/core/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from typing import Any, Callable, Hashable, Literal, Sequence, Type, TypeVar

import jax
from penzai.core import dataclass_util
from typing_extensions import dataclass_transform

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -514,7 +513,21 @@ def from_attributes(cls: Type[T], **field_values) -> T:
Returns:
A new instance of the class.
"""
return dataclass_util.dataclass_from_attributes(cls, **field_values)
# Make sure our fields are correct.
expected_fields = dataclasses.fields(cls) # pytype: disable=wrong-arg-types
expected_names = set(field.name for field in expected_fields)
given_names = set(field_values.keys())
if expected_names != given_names:
raise ValueError(
"Incorrect fields provided to `from_attributes`; expected"
f" {expected_names}, got {given_names}"
)
# Make a new object, bypassing the class's initializer.
value = object.__new__(cls)
# Set all the attributes, bypassing the class's __setattr__.
for k, v in field_values.items():
object.__setattr__(value, k, v)
return value

@typing.final
def attributes_dict(self) -> dict[str, Any]:
Expand Down Expand Up @@ -694,6 +707,6 @@ def _repr_pretty_(self, p, cycle):
p.text(line)

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import struct_handler # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import struct_handler # pylint: disable=g-import-not-at-top

return struct_handler.handle_structs(self, path, subtree_renderer)
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
import functools
from typing import Any

from penzai.core import context
from penzai.core import struct
from penzai.core._treescope_handlers import layer_handler
from penzai.core._treescope_handlers import struct_handler
from penzai.data_effects import effect_base
from penzai.treescope import context
from penzai.treescope import formatting_util
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
Expand All @@ -31,8 +33,6 @@
from penzai.treescope.foldable_representation import foldable_impl
from penzai.treescope.foldable_representation import part_interface
from penzai.treescope.handlers import builtin_structure_handler
from penzai.treescope.handlers.penzai import layer_handler
from penzai.treescope.handlers.penzai import struct_handler

_known_handlers: context.ContextualValue[
dict[str, tuple[effect_base.EffectHandler, str | None]] | None
Expand Down
12 changes: 6 additions & 6 deletions penzai/data_effects/effect_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,9 @@ def treescope_color(self):
return get_effect_color(self.effect_protocol())

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import data_effects_handlers # pylint: disable=g-import-not-at-top
from penzai.data_effects import _treescope_handlers # pylint: disable=g-import-not-at-top

return data_effects_handlers.handle_data_effects_objects(
return _treescope_handlers.handle_data_effects_objects(
self, path, subtree_renderer
)

Expand Down Expand Up @@ -444,9 +444,9 @@ def treescope_color(self):
return get_effect_color(self.effect_protocol())

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import data_effects_handlers # pylint: disable=g-import-not-at-top
from penzai.data_effects import _treescope_handlers # pylint: disable=g-import-not-at-top

return data_effects_handlers.handle_data_effects_objects(
return _treescope_handlers.handle_data_effects_objects(
self, path, subtree_renderer
)

Expand Down Expand Up @@ -518,8 +518,8 @@ def treescope_color(self):
return formatting_util.color_from_string(type(self).__qualname__)

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import data_effects_handlers # pylint: disable=g-import-not-at-top
from penzai.data_effects import _treescope_handlers # pylint: disable=g-import-not-at-top

return data_effects_handlers.handle_data_effects_objects(
return _treescope_handlers.handle_data_effects_objects(
self, path, subtree_renderer
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import dataclasses

from penzai.core._treescope_handlers import struct_handler
from penzai.experimental.v2.nn import grouping
from penzai.experimental.v2.nn import layer
from penzai.treescope import renderer
Expand All @@ -26,7 +27,6 @@
from penzai.treescope.foldable_representation import common_styles
from penzai.treescope.foldable_representation import part_interface
from penzai.treescope.handlers import builtin_structure_handler
from penzai.treescope.handlers.penzai import struct_handler


def handle_layer(
Expand Down
19 changes: 9 additions & 10 deletions penzai/experimental/v2/pz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,6 @@

# pylint: disable=g-multiple-import,g-importing-member,unused-import

from penzai.core.context import (
ContextualValue,
disable_interactive_context,
enable_interactive_context,
)
from penzai.core.dataclass_util import (
dataclass_from_attributes,
init_takes_fields,
)
import penzai.core.named_axes as nx
from penzai.core.partitioning import (
NotInThisPartition,
Expand Down Expand Up @@ -83,7 +74,15 @@
unbind_state_vars,
freeze_state_vars,
)

from penzai.treescope.context import (
ContextualValue,
disable_interactive_context,
enable_interactive_context,
)
from penzai.treescope.dataclass_util import (
dataclass_from_attributes,
init_takes_fields,
)
from penzai.treescope.treescope_ipython import show

from . import nn
Expand Down
18 changes: 9 additions & 9 deletions penzai/pz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,6 @@

# pylint: disable=g-multiple-import,g-importing-member,unused-import

from penzai.core.context import (
ContextualValue,
disable_interactive_context,
enable_interactive_context,
)
from penzai.core.dataclass_util import (
dataclass_from_attributes,
init_takes_fields,
)
from penzai.core.layer import (
Layer,
LayerLike,
Expand Down Expand Up @@ -60,6 +51,15 @@
from penzai.core.tree_util import (
pretty_keystr,
)
from penzai.treescope.context import (
ContextualValue,
disable_interactive_context,
enable_interactive_context,
)
from penzai.treescope.dataclass_util import (
dataclass_from_attributes,
init_takes_fields,
)
from penzai.treescope.treescope_ipython import show

from . import de
Expand Down
3 changes: 2 additions & 1 deletion penzai/treescope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
from . import arrayviz
from . import autovisualize
from . import canonical_aliases
from . import context
from . import copypaste_fallback
from . import dataclass_util
from . import default_renderer
from . import figures
from . import foldable_representation
Expand All @@ -43,5 +45,4 @@
from . import html_escaping
from . import renderer
from . import repr_lib
from . import selection_rendering
from . import treescope_ipython
2 changes: 1 addition & 1 deletion penzai/treescope/arrayviz/array_autovisualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax.numpy as jnp
import numpy as np
from penzai.core import named_axes
from penzai.core._treescope_handlers import named_axes_handlers
from penzai.treescope import autovisualize
from penzai.treescope import ndarray_summarization
from penzai.treescope.arrayviz import arrayviz
Expand All @@ -30,7 +31,6 @@
from penzai.treescope.foldable_representation import common_styles
from penzai.treescope.foldable_representation import foldable_impl
from penzai.treescope.foldable_representation import part_interface
from penzai.treescope.handlers.penzai import named_axes_handlers


def _supported_dtype(dtype):
Expand Down
2 changes: 1 addition & 1 deletion penzai/treescope/arrayviz/arrayviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
import jax
import jax.numpy as jnp
import numpy as np
from penzai.core import context
from penzai.core import named_axes
from penzai.treescope import context
from penzai.treescope import figures
from penzai.treescope import html_escaping
from penzai.treescope import ndarray_summarization
Expand Down
2 changes: 1 addition & 1 deletion penzai/treescope/autovisualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import dataclasses
from typing import Any, Protocol

from penzai.core import context
from penzai.treescope import context
from penzai.treescope.foldable_representation import embedded_iframe
from penzai.treescope.foldable_representation import part_interface

Expand Down
2 changes: 1 addition & 1 deletion penzai/treescope/canonical_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from typing import Any, Callable, Literal, Mapping
import warnings

from penzai.core import context
from penzai.treescope import context


@dataclasses.dataclass(frozen=True)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def dataclass_from_attributes(cls: type[T], **field_values) -> T:
given_names = set(field_values.keys())
if expected_names != given_names:
raise ValueError(
"Incorrect fields provided to `from_attributes`; expected"
"Incorrect fields provided to `dataclass_from_attributes`; expected"
f" {expected_names}, got {given_names}"
)
# Make a new object, bypassing the class's initializer.
Expand Down
Loading

0 comments on commit c8a2804

Please sign in to comment.