Skip to content

Commit

Permalink
Treescope API refactoring: Separate internal implementation and handl…
Browse files Browse the repository at this point in the history
…er interface.

- Moves internal implementation details into an _internal subdirectory.
- Adds function construction wrappers for internal part classes.
- Adds new `handlers` and `rendering_parts` aliases.
- Rewrites most handlers to use the non-internal interface.
- Refactors "figures" to better separate them from renderable parts.

PiperOrigin-RevId: 651758768
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Jul 18, 2024
1 parent c8bf57e commit 7df32a5
Show file tree
Hide file tree
Showing 58 changed files with 3,411 additions and 2,819 deletions.
2 changes: 1 addition & 1 deletion docs/api/treescope.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ construct their own renderer objects and handlers, or directly construct
renderings using Treescope's intermediate representation. Renderer objects and
the expected types of handlers are defined in ``penzai.treescope.renderer``, and
the intermediate representation is currently defined in
``penzai.treescope.foldable_representation``.
``penzai.treescope.rendering_parts``.

.. warning::
The Treescope intermediate representation and handler system will be changing
Expand Down
89 changes: 39 additions & 50 deletions penzai/core/_treescope_handlers/layer_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,11 @@
from penzai.data_effects import effect_base
from penzai.nn import grouping
from penzai.treescope import context
from penzai.treescope import formatting_util
from penzai.treescope import handlers
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
from penzai.treescope.foldable_representation import common_structures
from penzai.treescope.foldable_representation import common_styles
from penzai.treescope.foldable_representation import layout_algorithms
from penzai.treescope.foldable_representation import part_interface
from penzai.treescope.handlers import builtin_structure_handler
from penzai.treescope.handlers import shared_value_postprocessor

from penzai.treescope import rendering_parts
from penzai.treescope._internal import layout_algorithms

_already_seen_layer: context.ContextualValue[bool] = context.ContextualValue(
module=__name__, qualname="_already_seen_layer", initial_value=False
Expand All @@ -54,8 +50,8 @@ def handle_layers(
grouping.CheckedSequential,
),
) -> (
part_interface.RenderableTreePart
| part_interface.RenderableAndLineAnnotations
rendering_parts.RenderableTreePart
| rendering_parts.RenderableAndLineAnnotations
| type(NotImplemented)
):
"""Renders a penzai layer.
Expand Down Expand Up @@ -101,11 +97,11 @@ def handle_layers(
or exc_message is not None
):
if exc_message is not None:
structure_annotation = basic_parts.FoldCondition(
expanded=basic_parts.ScopedSelectableAnnotation(
common_styles.DashedGrayOutlineBox(
common_styles.ErrorColor(
basic_parts.Text(
structure_annotation = rendering_parts.fold_condition(
expanded=rendering_parts.floating_annotation_with_separate_focus(
rendering_parts.dashed_gray_outline_box(
rendering_parts.error_color(
rendering_parts.text(
"Error while inferring input/output structure:"
f" {exc_message}"
)
Expand All @@ -118,17 +114,17 @@ def handle_layers(
# Add input and output type annotations.
# Don't worry about shared values while rendering these, since they
# don't actually appear in the tree.
with shared_value_postprocessor.setup_shared_value_context():
structure_annotation = basic_parts.FoldCondition(
expanded=basic_parts.ScopedSelectableAnnotation(
common_styles.DashedGrayOutlineBox(
common_styles.CommentColor(
basic_parts.OnSeparateLines.build([
basic_parts.siblings_with_annotations(
with handlers.setup_shared_value_context():
structure_annotation = rendering_parts.fold_condition(
expanded=rendering_parts.floating_annotation_with_separate_focus(
rendering_parts.dashed_gray_outline_box(
rendering_parts.comment_color(
rendering_parts.on_separate_lines([
rendering_parts.siblings_with_annotations(
"# Input: ",
subtree_renderer(input_structure),
),
basic_parts.siblings_with_annotations(
rendering_parts.siblings_with_annotations(
"# Output: ",
subtree_renderer(output_structure),
),
Expand All @@ -149,21 +145,14 @@ def handle_layers(
for effect_protocol in free_effects:
effect_type_blobs.append(" ")
effect_type_blobs.append(
common_styles.WithBlockColor(
common_styles.ColoredSingleLineSpanGroup(
common_structures.maybe_qualified_type_name(
effect_protocol
)
),
color=effect_base.get_effect_color(effect_protocol),
)
rendering_parts.maybe_qualified_type_name(effect_protocol)
)
extra_annotations.append(
basic_parts.FoldCondition(
expanded=basic_parts.ScopedSelectableAnnotation(
common_styles.DashedGrayOutlineBox(
common_styles.CommentColor(
basic_parts.siblings(
rendering_parts.fold_condition(
expanded=rendering_parts.floating_annotation_with_separate_focus(
rendering_parts.dashed_gray_outline_box(
rendering_parts.comment_color(
rendering_parts.siblings(
"# Unhandled effects:", *effect_type_blobs
)
)
Expand All @@ -173,13 +162,13 @@ def handle_layers(
)
broken_refs = effect_base.broken_handler_refs(node)
if broken_refs:
with shared_value_postprocessor.setup_shared_value_context():
broken_annotation = basic_parts.FoldCondition(
expanded=basic_parts.ScopedSelectableAnnotation(
common_styles.DashedGrayOutlineBox(
basic_parts.build_full_line_with_annotations(
common_styles.ErrorColor(
basic_parts.Text("# Broken handler refs: ")
with handlers.setup_shared_value_context():
broken_annotation = rendering_parts.fold_condition(
expanded=rendering_parts.floating_annotation_with_separate_focus(
rendering_parts.dashed_gray_outline_box(
rendering_parts.build_full_line_with_annotations(
rendering_parts.error_color(
rendering_parts.text("# Broken handler refs: ")
),
subtree_renderer(broken_refs, path=None),
)
Expand All @@ -189,7 +178,7 @@ def handle_layers(
layout_algorithms.expand_to_depth(broken_annotation, 0)
extra_annotations.append(broken_annotation)

children = builtin_structure_handler.build_field_children(
children = rendering_parts.build_field_children(
node,
path,
subtree_renderer,
Expand All @@ -198,7 +187,7 @@ def handle_layers(
)

background_color, background_pattern = (
builtin_structure_handler.parse_color_and_pattern(
formatting_util.parse_simple_color_and_pattern_spec(
node.treescope_color(), type(node).__name__
)
)
Expand All @@ -208,21 +197,21 @@ def handle_layers(
isinstance(node, grouping.Sequential)
and type(node) is not grouping.Sequential
):
first_line_annotation = common_styles.CommentColor(
basic_parts.Text(" # Sequential")
first_line_annotation = rendering_parts.comment_color(
rendering_parts.text(" # Sequential")
)
elif (
isinstance(node, grouping.CheckedSequential)
and type(node) is not grouping.CheckedSequential
):
first_line_annotation = common_styles.CommentColor(
basic_parts.Text(" # CheckedSequential")
first_line_annotation = rendering_parts.comment_color(
rendering_parts.text(" # CheckedSequential")
)
else:
first_line_annotation = None
# pylint: enable=unidiomatic-typecheck

return common_structures.build_foldable_tree_node_from_children(
return rendering_parts.build_foldable_tree_node_from_children(
prefix=constructor_open,
children=extra_annotations + children,
suffix=")",
Expand Down
47 changes: 20 additions & 27 deletions penzai/core/_treescope_handlers/named_axes_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,10 @@
from penzai.core import named_axes
from penzai.core._treescope_handlers import struct_handler
from penzai.treescope import dtype_util
from penzai.treescope import lowering
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
from penzai.treescope.foldable_representation import common_structures
from penzai.treescope.foldable_representation import common_styles
from penzai.treescope.foldable_representation import foldable_impl
from penzai.treescope.foldable_representation import part_interface
from penzai.treescope.handlers import builtin_structure_handler
from penzai.treescope.handlers.interop import jax_support
from penzai.treescope import rendering_parts
from penzai.treescope._internal.handlers.interop import jax_support


def named_array_and_contained_type_summary(
Expand Down Expand Up @@ -97,8 +93,8 @@ def handle_named_arrays(
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
| part_interface.RenderableAndLineAnnotations
rendering_parts.RenderableTreePart
| rendering_parts.RenderableAndLineAnnotations
| type(NotImplemented)
):
"""Renders NamedArrays."""
Expand All @@ -114,48 +110,45 @@ def _make_label(inspect_device_data):
summary, contained_type = named_array_and_contained_type_summary(
node, inspect_device_data=inspect_device_data
)
return basic_parts.SummarizableCondition(
summary=common_styles.AbbreviationColor(
basic_parts.Text(
return rendering_parts.summarizable_condition(
summary=rendering_parts.abbreviation_color(
rendering_parts.text(
f"<{type(node).__name__} {summary} (wrapping"
f" {contained_type})>"
)
),
detail=basic_parts.siblings(
common_structures.maybe_qualified_type_name(type(node)),
detail=rendering_parts.siblings(
rendering_parts.maybe_qualified_type_name(type(node)),
"(",
basic_parts.FoldCondition(
expanded=common_styles.CommentColor(
basic_parts.Text(" # " + summary)
rendering_parts.fold_condition(
expanded=rendering_parts.comment_color(
rendering_parts.text(" # " + summary)
)
),
),
)

fields = dataclasses.fields(node)
children = builtin_structure_handler.build_field_children(
children = rendering_parts.build_field_children(
node,
path,
subtree_renderer,
fields_or_attribute_names=fields,
attr_style_fn=struct_handler.struct_attr_style_fn_for_fields(fields),
)

indented_children = basic_parts.IndentedChildren.build(children)
indented_children = rendering_parts.indented_children(children)

return common_structures.build_custom_foldable_tree_node(
label=foldable_impl.maybe_defer_rendering(
return rendering_parts.build_custom_foldable_tree_node(
label=lowering.maybe_defer_rendering(
main_thunk=lambda _: _make_label(inspect_device_data=True),
placeholder_thunk=lambda: _make_label(inspect_device_data=False),
),
contents=basic_parts.SummarizableCondition(
detail=basic_parts.siblings(
indented_children,
")",
)
contents=rendering_parts.summarizable_condition(
detail=rendering_parts.siblings(indented_children, ")")
),
path=path,
expand_state=part_interface.ExpandState.COLLAPSED,
expand_state=rendering_parts.ExpandState.COLLAPSED,
)

return NotImplemented
Loading

0 comments on commit 7df32a5

Please sign in to comment.