diff --git a/docs/api/treescope.rst b/docs/api/treescope.rst index 056c1bd..8fb1225 100644 --- a/docs/api/treescope.rst +++ b/docs/api/treescope.rst @@ -109,7 +109,7 @@ construct their own renderer objects and handlers, or directly construct renderings using Treescope's intermediate representation. Renderer objects and the expected types of handlers are defined in ``penzai.treescope.renderer``, and the intermediate representation is currently defined in -``penzai.treescope.foldable_representation``. +``penzai.treescope.rendering_parts``. .. warning:: The Treescope intermediate representation and handler system will be changing diff --git a/penzai/core/_treescope_handlers/layer_handler.py b/penzai/core/_treescope_handlers/layer_handler.py index 8556000..05287ef 100644 --- a/penzai/core/_treescope_handlers/layer_handler.py +++ b/penzai/core/_treescope_handlers/layer_handler.py @@ -26,15 +26,11 @@ from penzai.data_effects import effect_base from penzai.nn import grouping from penzai.treescope import context +from penzai.treescope import formatting_util +from penzai.treescope import handlers from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import layout_algorithms -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import builtin_structure_handler -from penzai.treescope.handlers import shared_value_postprocessor - +from penzai.treescope import rendering_parts +from penzai.treescope._internal import layout_algorithms _already_seen_layer: context.ContextualValue[bool] = context.ContextualValue( module=__name__, qualname="_already_seen_layer", initial_value=False @@ -54,8 +50,8 @@ def handle_layers( grouping.CheckedSequential, ), ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders a penzai layer. @@ -101,11 +97,11 @@ def handle_layers( or exc_message is not None ): if exc_message is not None: - structure_annotation = basic_parts.FoldCondition( - expanded=basic_parts.ScopedSelectableAnnotation( - common_styles.DashedGrayOutlineBox( - common_styles.ErrorColor( - basic_parts.Text( + structure_annotation = rendering_parts.fold_condition( + expanded=rendering_parts.floating_annotation_with_separate_focus( + rendering_parts.dashed_gray_outline_box( + rendering_parts.error_color( + rendering_parts.text( "Error while inferring input/output structure:" f" {exc_message}" ) @@ -118,17 +114,17 @@ def handle_layers( # Add input and output type annotations. # Don't worry about shared values while rendering these, since they # don't actually appear in the tree. - with shared_value_postprocessor.setup_shared_value_context(): - structure_annotation = basic_parts.FoldCondition( - expanded=basic_parts.ScopedSelectableAnnotation( - common_styles.DashedGrayOutlineBox( - common_styles.CommentColor( - basic_parts.OnSeparateLines.build([ - basic_parts.siblings_with_annotations( + with handlers.setup_shared_value_context(): + structure_annotation = rendering_parts.fold_condition( + expanded=rendering_parts.floating_annotation_with_separate_focus( + rendering_parts.dashed_gray_outline_box( + rendering_parts.comment_color( + rendering_parts.on_separate_lines([ + rendering_parts.siblings_with_annotations( "# Input: ", subtree_renderer(input_structure), ), - basic_parts.siblings_with_annotations( + rendering_parts.siblings_with_annotations( "# Output: ", subtree_renderer(output_structure), ), @@ -149,21 +145,14 @@ def handle_layers( for effect_protocol in free_effects: effect_type_blobs.append(" ") effect_type_blobs.append( - common_styles.WithBlockColor( - common_styles.ColoredSingleLineSpanGroup( - common_structures.maybe_qualified_type_name( - effect_protocol - ) - ), - color=effect_base.get_effect_color(effect_protocol), - ) + rendering_parts.maybe_qualified_type_name(effect_protocol) ) extra_annotations.append( - basic_parts.FoldCondition( - expanded=basic_parts.ScopedSelectableAnnotation( - common_styles.DashedGrayOutlineBox( - common_styles.CommentColor( - basic_parts.siblings( + rendering_parts.fold_condition( + expanded=rendering_parts.floating_annotation_with_separate_focus( + rendering_parts.dashed_gray_outline_box( + rendering_parts.comment_color( + rendering_parts.siblings( "# Unhandled effects:", *effect_type_blobs ) ) @@ -173,13 +162,13 @@ def handle_layers( ) broken_refs = effect_base.broken_handler_refs(node) if broken_refs: - with shared_value_postprocessor.setup_shared_value_context(): - broken_annotation = basic_parts.FoldCondition( - expanded=basic_parts.ScopedSelectableAnnotation( - common_styles.DashedGrayOutlineBox( - basic_parts.build_full_line_with_annotations( - common_styles.ErrorColor( - basic_parts.Text("# Broken handler refs: ") + with handlers.setup_shared_value_context(): + broken_annotation = rendering_parts.fold_condition( + expanded=rendering_parts.floating_annotation_with_separate_focus( + rendering_parts.dashed_gray_outline_box( + rendering_parts.build_full_line_with_annotations( + rendering_parts.error_color( + rendering_parts.text("# Broken handler refs: ") ), subtree_renderer(broken_refs, path=None), ) @@ -189,7 +178,7 @@ def handle_layers( layout_algorithms.expand_to_depth(broken_annotation, 0) extra_annotations.append(broken_annotation) - children = builtin_structure_handler.build_field_children( + children = rendering_parts.build_field_children( node, path, subtree_renderer, @@ -198,7 +187,7 @@ def handle_layers( ) background_color, background_pattern = ( - builtin_structure_handler.parse_color_and_pattern( + formatting_util.parse_simple_color_and_pattern_spec( node.treescope_color(), type(node).__name__ ) ) @@ -208,21 +197,21 @@ def handle_layers( isinstance(node, grouping.Sequential) and type(node) is not grouping.Sequential ): - first_line_annotation = common_styles.CommentColor( - basic_parts.Text(" # Sequential") + first_line_annotation = rendering_parts.comment_color( + rendering_parts.text(" # Sequential") ) elif ( isinstance(node, grouping.CheckedSequential) and type(node) is not grouping.CheckedSequential ): - first_line_annotation = common_styles.CommentColor( - basic_parts.Text(" # CheckedSequential") + first_line_annotation = rendering_parts.comment_color( + rendering_parts.text(" # CheckedSequential") ) else: first_line_annotation = None # pylint: enable=unidiomatic-typecheck - return common_structures.build_foldable_tree_node_from_children( + return rendering_parts.build_foldable_tree_node_from_children( prefix=constructor_open, children=extra_annotations + children, suffix=")", diff --git a/penzai/core/_treescope_handlers/named_axes_handlers.py b/penzai/core/_treescope_handlers/named_axes_handlers.py index 11e9734..7355add 100644 --- a/penzai/core/_treescope_handlers/named_axes_handlers.py +++ b/penzai/core/_treescope_handlers/named_axes_handlers.py @@ -24,14 +24,10 @@ from penzai.core import named_axes from penzai.core._treescope_handlers import struct_handler from penzai.treescope import dtype_util +from penzai.treescope import lowering from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import builtin_structure_handler -from penzai.treescope.handlers.interop import jax_support +from penzai.treescope import rendering_parts +from penzai.treescope._internal.handlers.interop import jax_support def named_array_and_contained_type_summary( @@ -97,8 +93,8 @@ def handle_named_arrays( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders NamedArrays.""" @@ -114,26 +110,26 @@ def _make_label(inspect_device_data): summary, contained_type = named_array_and_contained_type_summary( node, inspect_device_data=inspect_device_data ) - return basic_parts.SummarizableCondition( - summary=common_styles.AbbreviationColor( - basic_parts.Text( + return rendering_parts.summarizable_condition( + summary=rendering_parts.abbreviation_color( + rendering_parts.text( f"<{type(node).__name__} {summary} (wrapping" f" {contained_type})>" ) ), - detail=basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(node)), + detail=rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(node)), "(", - basic_parts.FoldCondition( - expanded=common_styles.CommentColor( - basic_parts.Text(" # " + summary) + rendering_parts.fold_condition( + expanded=rendering_parts.comment_color( + rendering_parts.text(" # " + summary) ) ), ), ) fields = dataclasses.fields(node) - children = builtin_structure_handler.build_field_children( + children = rendering_parts.build_field_children( node, path, subtree_renderer, @@ -141,21 +137,18 @@ def _make_label(inspect_device_data): attr_style_fn=struct_handler.struct_attr_style_fn_for_fields(fields), ) - indented_children = basic_parts.IndentedChildren.build(children) + indented_children = rendering_parts.indented_children(children) - return common_structures.build_custom_foldable_tree_node( - label=foldable_impl.maybe_defer_rendering( + return rendering_parts.build_custom_foldable_tree_node( + label=lowering.maybe_defer_rendering( main_thunk=lambda _: _make_label(inspect_device_data=True), placeholder_thunk=lambda: _make_label(inspect_device_data=False), ), - contents=basic_parts.SummarizableCondition( - detail=basic_parts.siblings( - indented_children, - ")", - ) + contents=rendering_parts.summarizable_condition( + detail=rendering_parts.siblings(indented_children, ")") ), path=path, - expand_state=part_interface.ExpandState.COLLAPSED, + expand_state=rendering_parts.ExpandState.COLLAPSED, ) return NotImplemented diff --git a/penzai/core/_treescope_handlers/selection_rendering.py b/penzai/core/_treescope_handlers/selection_rendering.py index 858e8fa..dba8200 100644 --- a/penzai/core/_treescope_handlers/selection_rendering.py +++ b/penzai/core/_treescope_handlers/selection_rendering.py @@ -22,14 +22,13 @@ from penzai.core import selectors from penzai.treescope import context from penzai.treescope import default_renderer -from penzai.treescope import html_escaping +from penzai.treescope import lowering from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import layout_algorithms -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope import rendering_parts +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal import layout_algorithms +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import part_interface @dataclasses.dataclass @@ -131,8 +130,8 @@ def _wrap_selected_nodes( path: str | None, node_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Custom wrapper hook that intercepts selected nodes.""" @@ -148,19 +147,19 @@ def _wrap_selected_nodes( # Tag the child, and possibly annotate its visualization. rendering = node_renderer(node, path) - tagged_rendering = part_interface.RenderableAndLineAnnotations( + tagged_rendering = rendering_parts.RenderableAndLineAnnotations( renderable=SelectionTaggedGroup(rendering.renderable), annotations=rendering.annotations, ) if tracker.visible_boundary: - wrapped_rendering = basic_parts.siblings_with_annotations( - common_structures.build_custom_foldable_tree_node( + wrapped_rendering = rendering_parts.siblings_with_annotations( + rendering_parts.build_custom_foldable_tree_node( contents=SelectionBoundaryBox( - basic_parts.OnSeparateLines.build([ - basic_parts.FoldCondition( + rendering_parts.on_separate_lines([ + rendering_parts.fold_condition( expanded=SelectionBoundaryLabel( - basic_parts.Text("# Selected:") + rendering_parts.text("# Selected:") ) ), tagged_rendering.renderable, @@ -181,7 +180,7 @@ def render_selection_to_foldable_representation( selection: selectors.Selection, visible_selection: bool = True, ignore_exceptions: bool = False, -) -> part_interface.RenderableTreePart: +) -> rendering_parts.RenderableTreePart: """Renders a top-level selection object to its foldable representation. This function produces a rendering of either the selection @@ -226,7 +225,7 @@ def render_selection_to_foldable_representation( visible_boundary=visible_selection, ) with _selected_nodes.set_scoped(tracker): - rendered_ir = basic_parts.build_full_line_with_annotations( + rendered_ir = rendering_parts.build_full_line_with_annotations( extended_renderer.to_foldable_representation( selection.deselect(), ignore_exceptions=ignore_exceptions, @@ -266,7 +265,7 @@ def render_selection_to_foldable_representation( if visible_selection: # Render the keypaths: - keypath_rendering = basic_parts.build_full_line_with_annotations( + keypath_rendering = rendering_parts.build_full_line_with_annotations( base_renderer.to_foldable_representation( tuple(key for key in selection.selected_by_path.keys()), ignore_exceptions=ignore_exceptions, @@ -277,35 +276,37 @@ def render_selection_to_foldable_representation( # Combine everything into a rendering of the selection itself. count = len(selection) - result = basic_parts.Siblings.build( - common_styles.CommentColor(basic_parts.Text("pz.select(")), - basic_parts.IndentedChildren.build([rendered_ir]), - common_styles.CommentColor(basic_parts.Text(").at_keypaths(")), - common_structures.build_custom_foldable_tree_node( - label=common_styles.CommentColor( - basic_parts.FoldCondition( - collapsed=basic_parts.Text( + result = rendering_parts.siblings( + rendering_parts.comment_color(rendering_parts.text("pz.select(")), + rendering_parts.indented_children([rendered_ir]), + rendering_parts.comment_color(rendering_parts.text(").at_keypaths(")), + rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.comment_color( + rendering_parts.fold_condition( + collapsed=rendering_parts.text( f"<{count} subtrees, highlighted above>" ), - expanded=basic_parts.Text( + expanded=rendering_parts.text( f"# {count} subtrees, highlighted above" ), ) ), - contents=basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren.build([keypath_rendering]) + contents=rendering_parts.fold_condition( + expanded=rendering_parts.indented_children([keypath_rendering]) ), - expand_state=part_interface.ExpandState.COLLAPSED, + expand_state=rendering_parts.ExpandState.COLLAPSED, ).renderable, - common_styles.CommentColor(basic_parts.Text(")")), + rendering_parts.comment_color(rendering_parts.text(")")), ) else: # Just return our existing rendering. result = rendered_ir if warnings: - result = basic_parts.OnSeparateLines.build([ - common_styles.ErrorColor(basic_parts.OnSeparateLines.build(warnings)), + result = rendering_parts.on_separate_lines([ + rendering_parts.error_color( + rendering_parts.on_separate_lines(warnings) + ), result, ]) @@ -344,12 +345,10 @@ def display_selection_streaming( being called directly, e.g. when registering this as a default pretty-printer. """ - with foldable_impl.collecting_deferred_renderings() as deferreds: + with lowering.collecting_deferred_renderings() as deferreds: rendered_ir = render_selection_to_foldable_representation( selection, visible_selection=visible_selection, ignore_exceptions=ignore_exceptions, ) - foldable_impl.display_streaming_as_root( - rendered_ir, deferreds, roundtrip=False - ) + lowering.display_streaming_as_root(rendered_ir, deferreds, roundtrip=False) diff --git a/penzai/core/_treescope_handlers/shapecheck_handlers.py b/penzai/core/_treescope_handlers/shapecheck_handlers.py index fc2e340..4be0b58 100644 --- a/penzai/core/_treescope_handlers/shapecheck_handlers.py +++ b/penzai/core/_treescope_handlers/shapecheck_handlers.py @@ -22,12 +22,11 @@ import numpy as np from penzai.core import shapecheck from penzai.treescope import dtype_util -from penzai.treescope import html_escaping from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import builtin_structure_handler +from penzai.treescope import rendering_parts +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import part_interface class ArrayVariableStyle(basic_parts.BaseSpanGroup): @@ -69,13 +68,13 @@ def _span_css_rule( ) -def _wrap_dimvar(msg: str) -> part_interface.RenderableTreePart: - return ArrayVariableStyle(basic_parts.Text(msg)) +def _wrap_dimvar(msg: str) -> rendering_parts.RenderableTreePart: + return ArrayVariableStyle(rendering_parts.text(msg)) def _arraystructure_summary( structure: shapecheck.ArraySpec, -) -> basic_parts.Siblings: +) -> rendering_parts.RenderableTreePart: """Creates a summary line for an array structure.""" # Give a short summary for our named arrays. @@ -128,7 +127,7 @@ def _arraystructure_summary( summary_parts.append(f"{name}:{dim}") summary_parts.append(")") - return basic_parts.siblings(*summary_parts) + return rendering_parts.siblings(*summary_parts) def handle_arraystructures( @@ -136,17 +135,17 @@ def handle_arraystructures( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders ArraySpec and contained variables.""" if isinstance(node, shapecheck.Wildcard): summary = "*" if node.description is None else f"<{node.description}>" - return common_structures.build_one_line_tree_node( - line=basic_parts.RoundtripCondition( - roundtrip=basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(node)), + return rendering_parts.build_one_line_tree_node( + line=rendering_parts.roundtrip_condition( + roundtrip=rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(node)), f"({repr(node.description)})", ), not_roundtrip=_wrap_dimvar(summary), @@ -156,32 +155,32 @@ def handle_arraystructures( elif isinstance(node, shapecheck.ArraySpec): summary = _arraystructure_summary(node) - children = builtin_structure_handler.build_field_children( + children = rendering_parts.build_field_children( node, path, subtree_renderer, fields_or_attribute_names=dataclasses.fields(node), ) - indented_children = basic_parts.IndentedChildren.build(children) + indented_children = rendering_parts.indented_children(children) - return common_structures.build_custom_foldable_tree_node( - label=basic_parts.SummarizableCondition( + return rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.summarizable_condition( summary=ArraySpecStyle( - basic_parts.siblings("") + rendering_parts.siblings("") ), - detail=basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(node)), + detail=rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(node)), "(", ), ), - contents=basic_parts.SummarizableCondition( - detail=basic_parts.siblings( + contents=rendering_parts.summarizable_condition( + detail=rendering_parts.siblings( indented_children, ")", ) ), path=path, - expand_state=part_interface.ExpandState.COLLAPSED, + expand_state=rendering_parts.ExpandState.COLLAPSED, ) else: diff --git a/penzai/core/_treescope_handlers/struct_handler.py b/penzai/core/_treescope_handlers/struct_handler.py index 801aa77..3af454e 100644 --- a/penzai/core/_treescope_handlers/struct_handler.py +++ b/penzai/core/_treescope_handlers/struct_handler.py @@ -21,13 +21,12 @@ from penzai.core import struct from penzai.treescope import dataclass_util -from penzai.treescope import html_escaping +from penzai.treescope import formatting_util from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import builtin_structure_handler +from penzai.treescope import rendering_parts +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import part_interface class PyTreeNodeFieldName(basic_parts.BaseSpanGroup): @@ -52,17 +51,17 @@ def _span_css_rule( def render_struct_constructor( node: struct.Struct, -) -> part_interface.RenderableTreePart: +) -> rendering_parts.RenderableTreePart: """Renders the constructor of a Struct, with an open parenthesis.""" if dataclass_util.init_takes_fields(type(node)): - return basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(node)), "(" + return rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(node)), "(" ) else: - return basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(node)), - basic_parts.RoundtripCondition( - roundtrip=basic_parts.Text(".from_attributes") + return rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(node)), + rendering_parts.roundtrip_condition( + roundtrip=rendering_parts.text(".from_attributes") ), "(", ) @@ -70,7 +69,7 @@ def render_struct_constructor( def render_short_struct_summary( the_struct: struct.Struct, -) -> part_interface.RenderableTreePart: +) -> rendering_parts.RenderableTreePart: """Renders a short summary of a struct. Can be used by other handlers that manipulate structs. @@ -81,17 +80,21 @@ def render_short_struct_summary( Returns: A short, single-line summary of the struct. """ - return common_styles.WithBlockColor( - common_styles.ColoredSingleLineSpanGroup( - basic_parts.Text(type(the_struct).__name__ + "(...)") - ), - color=the_struct.treescope_color(), + background_color, background_pattern = ( + formatting_util.parse_simple_color_and_pattern_spec( + the_struct.treescope_color(), type(the_struct).__name__ + ) ) + return rendering_parts.build_one_line_tree_node( + rendering_parts.text(type(the_struct).__name__ + "(...)"), + background_color=background_color, + background_pattern=background_pattern, + ).renderable def struct_attr_style_fn_for_fields( fields, -) -> Callable[[str], part_interface.RenderableTreePart]: +) -> Callable[[str], rendering_parts.RenderableTreePart]: """Builds a function to render attributes of a struct. The resulting function will render pytree node fields in a different style. @@ -107,9 +110,9 @@ def struct_attr_style_fn_for_fields( def attr_style_fn(field_name): if struct.is_pytree_node_field(fields_by_name[field_name]): - return PyTreeNodeFieldName(basic_parts.Text(field_name)) + return PyTreeNodeFieldName(rendering_parts.text(field_name)) else: - return basic_parts.Text(field_name) + return rendering_parts.text(field_name) return attr_style_fn @@ -119,8 +122,8 @@ def handle_structs( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders a penzai struct or layer. @@ -152,12 +155,12 @@ def handle_structs( fields = dataclasses.fields(node) background_color, background_pattern = ( - builtin_structure_handler.parse_color_and_pattern( + formatting_util.parse_simple_color_and_pattern_spec( node.treescope_color(), type(node).__name__ ) ) - children = builtin_structure_handler.build_field_children( + children = rendering_parts.build_field_children( node, path, subtree_renderer, @@ -165,7 +168,7 @@ def handle_structs( attr_style_fn=struct_attr_style_fn_for_fields(fields), ) - return common_structures.build_foldable_tree_node_from_children( + return rendering_parts.build_foldable_tree_node_from_children( prefix=constructor_open, children=children, suffix=")", diff --git a/penzai/data_effects/_treescope_handlers.py b/penzai/data_effects/_treescope_handlers.py index 2497020..44514f1 100644 --- a/penzai/data_effects/_treescope_handlers.py +++ b/penzai/data_effects/_treescope_handlers.py @@ -27,12 +27,8 @@ from penzai.treescope import context from penzai.treescope import formatting_util from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import builtin_structure_handler +from penzai.treescope import rendering_parts +from penzai.treescope._internal.parts import foldable_impl _known_handlers: context.ContextualValue[ dict[str, tuple[effect_base.EffectHandler, str | None]] | None @@ -52,8 +48,8 @@ def handle_data_effects_objects( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Handles data effects objects.""" @@ -66,15 +62,15 @@ def handler_id_interceptor( hyperlink_path=None, ): if isinstance(node, str) and node == handler_id: - child = basic_parts.Text(repr(node)) + child = rendering_parts.text(repr(node)) if hyperlink_path is not None: child = foldable_impl.NodeHyperlink( child=child, target_keypath=hyperlink_path ) - return common_structures.build_one_line_tree_node( - common_styles.CustomTextColor( + return rendering_parts.build_one_line_tree_node( + rendering_parts.custom_text_color( child, - color=formatting_util.color_from_string( + css_color=formatting_util.color_from_string( node, lightness=0.51, chroma=0.11 ), ), @@ -104,11 +100,11 @@ def handler_id_interceptor( if cur_known is not None and node.handler_id in cur_known: handler, handler_path = cur_known[node.handler_id] comment = [ - common_styles.CommentColor( - basic_parts.siblings( + rendering_parts.comment_color( + rendering_parts.siblings( " # Handled by ", foldable_impl.NodeHyperlink( - child=common_structures.maybe_qualified_type_name( + child=rendering_parts.maybe_qualified_type_name( type(handler) ), target_keypath=handler_path, @@ -123,7 +119,7 @@ def handler_id_interceptor( assert dataclasses.is_dataclass(node), "Every struct.Struct is a dataclass" constructor_open = struct_handler.render_struct_constructor(node) fields = dataclasses.fields(node) - children = builtin_structure_handler.build_field_children( + children = rendering_parts.build_field_children( node, path, functools.partial( @@ -135,8 +131,8 @@ def handler_id_interceptor( attr_style_fn=struct_handler.struct_attr_style_fn_for_fields(fields), ) background_color = node.treescope_color() - return basic_parts.siblings_with_annotations( - common_structures.build_foldable_tree_node_from_children( + return rendering_parts.siblings_with_annotations( + rendering_parts.build_foldable_tree_node_from_children( prefix=constructor_open, children=children, suffix=")", @@ -155,16 +151,14 @@ def handler_id_interceptor( if isinstance(node, struct.Struct): constructor_open = struct_handler.render_struct_constructor(node) elif dataclasses.is_dataclass(node): - constructor_open = builtin_structure_handler.render_dataclass_constructor( - node - ) + constructor_open = rendering_parts.render_dataclass_constructor(node) else: return NotImplemented background_color = node.treescope_color() - return common_structures.build_foldable_tree_node_from_children( + return rendering_parts.build_foldable_tree_node_from_children( prefix=constructor_open, - children=builtin_structure_handler.build_field_children( + children=rendering_parts.build_field_children( node, path, functools.partial( diff --git a/penzai/experimental/v2/nn/_treescope_handlers/layer_handler.py b/penzai/experimental/v2/nn/_treescope_handlers/layer_handler.py index b297c5d..2d4ca5c 100644 --- a/penzai/experimental/v2/nn/_treescope_handlers/layer_handler.py +++ b/penzai/experimental/v2/nn/_treescope_handlers/layer_handler.py @@ -21,12 +21,9 @@ from penzai.core._treescope_handlers import struct_handler from penzai.experimental.v2.nn import grouping from penzai.experimental.v2.nn import layer +from penzai.treescope import formatting_util from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import builtin_structure_handler +from penzai.treescope import rendering_parts def handle_layer( @@ -34,8 +31,8 @@ def handle_layer( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations ): """Renders a penzai layer. @@ -55,7 +52,7 @@ def handle_layer( constructor_open = struct_handler.render_struct_constructor(node) fields = dataclasses.fields(node) - children = builtin_structure_handler.build_field_children( + children = rendering_parts.build_field_children( node, path, subtree_renderer, @@ -64,7 +61,7 @@ def handle_layer( ) background_color, background_pattern = ( - builtin_structure_handler.parse_color_and_pattern( + formatting_util.parse_simple_color_and_pattern_spec( node.treescope_color(), type(node).__name__ ) ) @@ -74,21 +71,21 @@ def handle_layer( isinstance(node, grouping.Sequential) and type(node) is not grouping.Sequential ): - first_line_annotation = common_styles.CommentColor( - basic_parts.Text(" # Sequential") + first_line_annotation = rendering_parts.comment_color( + rendering_parts.text(" # Sequential") ) elif ( isinstance(node, grouping.CheckedSequential) and type(node) is not grouping.CheckedSequential ): - first_line_annotation = common_styles.CommentColor( - basic_parts.Text(" # CheckedSequential") + first_line_annotation = rendering_parts.comment_color( + rendering_parts.text(" # CheckedSequential") ) else: first_line_annotation = None # pylint: enable=unidiomatic-typecheck - return common_structures.build_foldable_tree_node_from_children( + return rendering_parts.build_foldable_tree_node_from_children( prefix=constructor_open, children=children, suffix=")", diff --git a/penzai/toolshed/token_visualization.py b/penzai/toolshed/token_visualization.py index 161b814..f4b790b 100644 --- a/penzai/toolshed/token_visualization.py +++ b/penzai/toolshed/token_visualization.py @@ -24,7 +24,6 @@ import numpy as np from penzai import pz from penzai.treescope import figures -from penzai.treescope.foldable_representation import basic_parts # pylint: disable=invalid-name @@ -152,13 +151,7 @@ def show_token_array( # Add an indentation level, but allow sequences to wrap in the indented # block. - parts.append( - figures.TreescopeRenderingFigure( - basic_parts.IndentedChildren.build( - [pz.ts.inline(*subparts, wrap=True)] - ), - ) - ) + parts.append(figures.indented(pz.ts.inline(*subparts, wrap=True))) return pz.ts.inline(*parts) @@ -211,12 +204,6 @@ def show_token_scores( # Add an indentation level, but allow sequences to wrap in the indented # block. - parts.append( - figures.TreescopeRenderingFigure( - basic_parts.IndentedChildren.build( - [pz.ts.inline(*subparts, wrap=True)] - ), - ) - ) + parts.append(figures.indented(pz.ts.inline(*subparts, wrap=True))) return pz.ts.inline(*parts) diff --git a/penzai/treescope/__init__.py b/penzai/treescope/__init__.py index 3e6c400..14559c7 100644 --- a/penzai/treescope/__init__.py +++ b/penzai/treescope/__init__.py @@ -39,13 +39,12 @@ from . import dataclass_util from . import default_renderer from . import figures -from . import foldable_representation from . import formatting_util from . import handlers -from . import html_encapsulation -from . import html_escaping +from . import lowering from . import ndarray_adapters from . import renderer +from . import rendering_parts from . import repr_lib from . import treescope_ipython from . import type_registries diff --git a/penzai/treescope/_internal/arrayviz_impl.py b/penzai/treescope/_internal/arrayviz_impl.py new file mode 100644 index 0000000..6b98934 --- /dev/null +++ b/penzai/treescope/_internal/arrayviz_impl.py @@ -0,0 +1,708 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal implementation of array visualizer.""" + + +from __future__ import annotations + +import base64 +import dataclasses +import io +import json +import os +from typing import Any, Literal, Sequence + +import numpy as np +from penzai.treescope import ndarray_adapters +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import part_interface + + +AxisName = Any + +PositionalAxisInfo = ndarray_adapters.PositionalAxisInfo +NamedPositionlessAxisInfo = ndarray_adapters.NamedPositionlessAxisInfo +NamedPositionalAxisInfo = ndarray_adapters.NamedPositionalAxisInfo +AxisInfo = ndarray_adapters.AxisInfo + +ArrayInRegistry = Any + + +def load_arrayvis_javascript() -> str: + """Loads the contents of `arrayvis.js` from the Python package. + + Returns: + Source code for arrayviz. + """ + filepath = __file__ + if filepath is None: + raise ValueError("Could not find the path to arrayviz.js!") + + # Look for the resource relative to the current module's filesystem path. + base = filepath.removesuffix("arrayviz_impl.py") + load_path = os.path.join(base, "js", "arrayviz.js") + + with open(load_path, "r") as f: + return f.read() + + +def html_setup() -> ( + set[part_interface.CSSStyleRule | part_interface.JavaScriptDefn] +): + """Builds the setup HTML that should be included in any arrayviz output cell.""" + arrayviz_src = html_escaping.heuristic_strip_javascript_comments( + load_arrayvis_javascript() + ) + return { + part_interface.CSSStyleRule(html_escaping.without_repeated_whitespace(""" + .arrayviz_container { + white-space: normal; + } + .arrayviz_container .info { + font-family: monospace; + color: #aaaaaa; + margin-bottom: 0.25em; + white-space: pre; + } + .arrayviz_container .info input[type="range"] { + vertical-align: middle; + filter: grayscale(1) opacity(0.5); + } + .arrayviz_container .info input[type="range"]:hover { + filter: grayscale(0.5); + } + .arrayviz_container .info input[type="number"]:not(:focus) { + border-radius: 3px; + } + .arrayviz_container .info input[type="number"]:not(:focus):not(:hover) { + color: #777777; + border: 1px solid #777777; + } + .arrayviz_container .info.sliders { + white-space: pre; + } + .arrayviz_container .hovertip { + display: none; + position: absolute; + background-color: white; + border: 1px solid black; + padding: 0.25ch; + pointer-events: none; + width: fit-content; + overflow: visible; + white-space: pre; + z-index: 1000; + } + .arrayviz_container .hoverbox { + display: none; + position: absolute; + box-shadow: 0 0 0 1px black, 0 0 0 2px white; + pointer-events: none; + z-index: 900; + } + .arrayviz_container .clickdata { + white-space: pre; + } + .arrayviz_container .loading_message { + color: #aaaaaa; + } + """)), + part_interface.JavaScriptDefn( + arrayviz_src + " this.getRootNode().host.defns.arrayviz = arrayviz;" + ), + } + + +def render_array_data_to_html( + array_data: np.ndarray, + valid_mask: np.ndarray, + column_axes: Sequence[int], + row_axes: Sequence[int], + slider_axes: Sequence[int], + axis_labels: list[str], + vmin: float, + vmax: float, + cmap_type: Literal["continuous", "palette_index", "digitbox"], + cmap_data: list[tuple[int, int, int]], + info: str = "", + formatting_instructions: list[dict[str, Any]] | None = None, + dynamic_continous_cmap: bool = False, + raw_min_abs: float | None = None, + raw_max_abs: float | None = None, +) -> str: + """Helper to render an array to HTML by passing arguments to javascript. + + Args: + array_data: Array data to render. + valid_mask: Mask array, of same shape as array_data, that is True for items + we should render. + column_axes: Axes (by index into `array_data`) to arrange as columns, + ordered from outermost group to innermost group. + row_axes: Axes (by index into `array_data`) to arrange as rows, ordered from + outermost group to innermost group. + slider_axes: Axes to bind to sliders. + axis_labels: Labels for each axis. + vmin: Minimum for the colormap. + vmax: Maximum for the colormap. + cmap_type: Type of colormap (see `render_array`) + cmap_data: Data for the colormap, as a sequence of RGB triples. + info: Info for the plot. + formatting_instructions: Formatting instructions for values on mouse hover + or click. These will be interpreted by `formatValueAndIndices` on the + JavaScript side. Can assume each axis is named "a0", "a1", etc. when + running in JavaScript. + dynamic_continous_cmap: Whether to dynamically adjust the colormap during + rendering. + raw_min_abs: Minimum absolute value of the array, for dynamic remapping. + raw_max_abs: Maximum absolute value of the array, for dynamic remapping. + + Returns: + HTML source for an arrayviz rendering. + """ + assert len(array_data.shape) == len(axis_labels) + assert len(valid_mask.shape) == len(axis_labels) + + if formatting_instructions is None: + formatting_instructions = [{"type": "value"}] + + # Compute strides for each axis. We refer to each axis as "a0", "a1", etc + # across the JavaScript boundary. + stride = 1 + strides = {} + for i, axis_size in reversed(list(enumerate(array_data.shape))): + strides[f"a{i}"] = stride + stride *= axis_size + + if cmap_type == "continuous": + converted_array_data = array_data.astype(np.float32) + array_dtype = "float32" + else: + converted_array_data = array_data.astype(np.int32) + array_dtype = "int32" + + def axis_spec_arg(i): + return { + "name": f"a{i}", + "label": axis_labels[i], + "start": 0, + "end": array_data.shape[i], + } + + x_axis_specs_arg = [] + for axis in column_axes: + x_axis_specs_arg.append(axis_spec_arg(axis)) + + y_axis_specs_arg = [] + for axis in row_axes: + y_axis_specs_arg.append(axis_spec_arg(axis)) + + sliced_axis_specs_arg = [] + for axis in slider_axes: + sliced_axis_specs_arg.append(axis_spec_arg(axis)) + + args_json = json.dumps({ + "info": info, + "arrayBase64": base64.b64encode(converted_array_data.tobytes()).decode( + "ascii" + ), + "arrayDtype": array_dtype, + "validMaskBase64": base64.b64encode( + valid_mask.astype(np.uint8).tobytes() + ).decode("ascii"), + "dataStrides": strides, + "xAxisSpecs": x_axis_specs_arg, + "yAxisSpecs": y_axis_specs_arg, + "slicedAxisSpecs": sliced_axis_specs_arg, + "colormapConfig": { + "type": cmap_type, + "min": vmin, + "max": vmax, + "dynamic": dynamic_continous_cmap, + "rawMinAbs": raw_min_abs, + "rawMaxAbs": raw_max_abs, + "cmapData": cmap_data, + }, + "valueFormattingInstructions": formatting_instructions, + }) + # Note: We need to save the parent of the treescope-run-here element first, + # because it will be removed before the runSoon callback executes. + inner_fn = html_escaping.without_repeated_whitespace(""" + const parent = this.parentNode; + const defns = this.getRootNode().host.defns; + defns.runSoon(() => { + const tpl = parent.querySelector('template.deferred_args'); + const config = JSON.parse( + tpl.content.querySelector('script').textContent + ); + tpl.remove(); + defns.arrayviz.buildArrayvizFigure(parent, config); + }); + """) + src = ( + '
' + 'Rendering array...' + f'" + '
' + ) + return src + + +def infer_rows_and_columns( + all_axes: Sequence[AxisInfo], + known_rows: Sequence[AxisInfo] = (), + known_columns: Sequence[AxisInfo] = (), + edge_items_per_axis: tuple[int | None, ...] | None = None, +) -> tuple[list[AxisInfo], list[AxisInfo]]: + """Infers an ordered assignment of axis indices or names to rows and columns. + + The unassigned axes are sorted by size and then assigned to rows and columns + to try to balance the total number of elements along the row and column axes. + This curently uses a greedy algorithm with an adjustment to try to keep + columns longer than rows, except when there are exactly two axes and both are + positional, in which case it lays out axis 0 as the rows and axis 1 as the + columns. + + Axes with logical positions are sorted before axes with only names + (in reverse order, so that later axes are rendered on the inside). Axes with + names only appear afterward, with explicitly-assigned ones before unassigned + ones. + + Args: + all_axes: Sequence of axis infos in the array that should be assigned. + known_rows: Sequence of axis indices or names that must map to rows. + known_columns: Sequence of axis indices or names that must map to columns. + edge_items_per_axis: Optional edge items specification, determining + truncated size of each axis. Must match the ordering of `all_axes`. + + Returns: + Tuple (rows, columns) of assignments, which consist of `known_rows` and + `known_columns` followed by the remaining unassigned axes in a balanced + order. + """ + if edge_items_per_axis is None: + edge_items_per_axis = (None,) * len(all_axes) + + if not known_rows and not known_columns and len(all_axes) == 2: + ax_a, ax_b = all_axes + if ( + isinstance(ax_a, PositionalAxisInfo) + and isinstance(ax_b, PositionalAxisInfo) + and {ax_a.axis_logical_index, ax_b.axis_logical_index} == {0, 1} + ): + # Two-dimensional positional array. Always do rows then columns. + if ax_a.axis_logical_index == 0: + return ([ax_a], [ax_b]) + else: + return ([ax_b], [ax_a]) + + truncated_sizes = { + ax: ax.size if edge_items is None else 2 * edge_items + 1 + for ax, edge_items in zip(all_axes, edge_items_per_axis) + } + unassigned = [ + ax for ax in all_axes if ax not in known_rows and ax not in known_columns + ] + + # Sort by size descending, so that we make the most important layout decisions + # first. + unassigned = sorted( + unassigned, key=lambda ax: (truncated_sizes[ax], ax.size), reverse=True + ) + + # Compute the total size every axis would have if we assigned them to the + # same axis. + unassigned_size = np.prod([truncated_sizes[ax] for ax in unassigned]) + + rows = list(known_rows) + row_size = np.prod([truncated_sizes[ax] for ax in rows]) + columns = list(known_columns) + column_size = np.prod([truncated_sizes[ax] for ax in columns]) + + for ax in unassigned: + axis_size = truncated_sizes[ax] + unassigned_size = unassigned_size // axis_size + if row_size * axis_size > column_size * unassigned_size: + # If we assign this to the row axis, we'll end up with a visualization + # with more rows than columns regardless of what we do later, which can + # waste screen space. Assign to columns instead. + columns.append(ax) + column_size *= axis_size + else: + # Assign to the row axis. We'll assign columns later. + rows.append(ax) + row_size *= axis_size + + # The specific ordering of axes along the rows and the columns is somewhat + # arbitrary. Re-order each so that they have positional then named axes, and + # so that position axes are in reverse position order, and the explicitly + # mentioned named axes are before the unassigned ones. + def ax_sort_key(ax: AxisInfo): + if isinstance(ax, PositionalAxisInfo | NamedPositionalAxisInfo): + return (0, -ax.axis_logical_index) + elif ax in unassigned: + return (2,) + else: + return (1,) + + return sorted(rows, key=ax_sort_key), sorted(columns, key=ax_sort_key) + + +def infer_vmin_vmax( + array: np.ndarray, + mask: np.ndarray, + vmin: float | None, + vmax: float | None, + around_zero: bool, + trim_outliers: bool, +) -> tuple[float, float]: + """Infer reasonable lower and upper colormap bounds from an array.""" + inferring_both_bounds = vmax is None and vmin is None + finite_mask = np.logical_and(np.isfinite(array), mask) + if vmax is None: + if around_zero: + if vmin is not None: + vmax = -vmin # pylint: disable=invalid-unary-operand-type + else: + vmax = np.max(np.where(finite_mask, np.abs(array), 0)) + else: + vmax = np.max(np.where(finite_mask, array, -np.inf)) + + assert vmax is not None + + if vmin is None: + if around_zero: + vmin = -vmax # pylint: disable=invalid-unary-operand-type + else: + vmin = np.min(np.where(finite_mask, array, np.inf)) + + if inferring_both_bounds and trim_outliers: + if around_zero: + center = 0 + else: + center = np.nanmean(np.where(finite_mask, array, np.nan)) + center = np.where(np.isfinite(center), center, 0.0) + + second_moment = np.nanmean( + np.where(finite_mask, np.square(array - center), np.nan) + ) + sigma = np.where( + np.isfinite(second_moment), np.sqrt(second_moment), vmax - vmin + ) + + vmin_limit = center - 3 * sigma + vmin = np.maximum(vmin, vmin_limit) + vmax_limit = center + 3 * sigma + vmax = np.minimum(vmax, vmax_limit) + + return vmin, vmax + + +def infer_abs_min_max( + array: np.ndarray, mask: np.ndarray +) -> tuple[float, float]: + """Infer smallest and largest absolute values in array.""" + finite_mask = np.logical_and(np.isfinite(array), mask) + absmin = np.min( + np.where(np.logical_and(finite_mask, array != 0), np.abs(array), np.inf) + ) + absmin = np.where(np.isinf(absmin), 0.0, absmin) + absmax = np.max(np.where(finite_mask, np.abs(array), -np.inf)) + absmax = np.where(np.isinf(absmax), 0.0, absmax) + return absmin, absmax + + +def infer_balanced_truncation( + shape: Sequence[int], + maximum_size: int, + cutoff_size_per_axis: int, + minimum_edge_items: int, + doubling_bonus: float = 10.0, +) -> tuple[int | None, ...]: + """Infers a balanced truncation from a shape. + + This function computes a set of truncation sizes for each axis of the array + such that it obeys the constraints about array and axis sizes, while also + keeping the relative proportions of the array consistent (e.g. we keep more + elements along axes that were originally longer). This means that the aspect + ratio of the truncated array will still resemble the aspect ratio of the + original array. + + To avoid very-unbalanced renderings and truncate longer axes more than short + ones, this function truncates based on the square-root of the axis size by + default. + + Args: + shape: The shape of the array we are truncating. + maximum_size: Maximum number of elements of an array to show. Arrays larger + than this will be truncated along one or more axes. + cutoff_size_per_axis: Maximum number of elements of each individual axis to + show without truncation. Any axis longer than this will be truncated, with + their visual size increasing logarithmically with the true axis size + beyond this point. + minimum_edge_items: How many values to keep along each axis for truncated + arrays. We may keep more than this up to the budget of maximum_size. + doubling_bonus: Number of elements to add to each axis each time it doubles + beyond `cutoff_size_per_axis`. Used to make longer axes appear visually + longer while still keeping them a reasonable size. + + Returns: + A tuple of edge sizes. Each element corresponds to an axis in `shape`, + and is either `None` (for no truncation) or an integer (corresponding to + the number of elements to keep at the beginning and and at the end). + """ + shape_arr = np.array(list(shape)) + remaining_elements_to_divide = maximum_size + edge_items_per_axis = {} + # Order our shape from smallest to largest, since the smallest axes will + # require the least amount of truncation and will have the most stringent + # constraints. + sorted_axes = np.argsort(shape_arr) + sorted_shape = shape_arr[sorted_axes] + + # Figure out maximum sizes based on the cutoff + cutoff_adjusted_maximum_sizes = np.where( + sorted_shape <= cutoff_size_per_axis, + sorted_shape, + cutoff_size_per_axis + + doubling_bonus * np.log2(sorted_shape / cutoff_size_per_axis), + ) + + # Suppose we want to make a scaled version of the array with relative + # axis sizes + # s0, s1, s2, ... + # The total size is then + # size = (c * s0) * (c * s1) * (c * s2) * ... + # log(size) = ndim * log(c) + [ log s0 + log s1 + log s2 + ... ] + # If we have a known final size we want to reach, we can solve for c as + # c = exp( (log size - [ log s0 + log s1 + log s2 + ... ]) / ndim ) + axis_proportions = np.sqrt(sorted_shape) + log_axis_proportions = np.log(axis_proportions) + for i in range(len(sorted_axes)): + original_axis = sorted_axes[i] + size = shape_arr[original_axis] + # If we truncated this axis and every axis after it proportional to + # their weights, how small of an axis size would we need for this + # axis? + log_c = ( + np.log(remaining_elements_to_divide) - np.sum(log_axis_proportions[i:]) + ) / (len(shape) - i) + soft_limit_for_this_axis = np.exp(log_c + log_axis_proportions[i]) + cutoff_limit_for_this_axis = np.floor( + np.minimum( + soft_limit_for_this_axis, + cutoff_adjusted_maximum_sizes[i], + ) + ) + if size <= 2 * minimum_edge_items + 1 or size <= cutoff_limit_for_this_axis: + # If this axis is already smaller than the minimum size it would have + # after truncation, there's no reason to truncate it. + # But pretend we did, so that other axes still grow monotonically if + # their axis sizes increase. + remaining_elements_to_divide = ( + remaining_elements_to_divide / soft_limit_for_this_axis + ) + edge_items_per_axis[original_axis] = None + elif cutoff_limit_for_this_axis < 2 * minimum_edge_items + 1: + # If this axis is big enough to truncate, but our naive target size is + # smaller than the minimum allowed truncation, we should truncate it + # to the minimum size allowed instead. + edge_items_per_axis[original_axis] = minimum_edge_items + remaining_elements_to_divide = remaining_elements_to_divide / ( + 2 * minimum_edge_items + 1 + ) + else: + # Otherwise, truncate it and all remaining axes based on our target + # truncations. + for j in range(i, len(sorted_axes)): + visual_size = np.floor( + np.minimum( + np.exp(log_c + log_axis_proportions[j]), + cutoff_adjusted_maximum_sizes[j], + ) + ) + edge_items_per_axis[sorted_axes[j]] = int(visual_size // 2) + break + + return tuple( + edge_items_per_axis[orig_axis] for orig_axis in range(len(shape)) + ) + + +def compute_truncated_shape( + shape: tuple[int, ...], + edge_items: tuple[int | None, ...], +) -> tuple[int, ...]: + """Computes the shape of a truncated array. + + This can be used to estimate the size of an array visualization after it has + been truncated by `infer_balanced_truncation`. + + Args: + shape: The original array shape. + edge_items: Number of edge items to keep along each axis. + + Returns: + The shape of the truncated array. + """ + return tuple( + orig if edge is None else 2 * edge + 1 + for orig, edge in zip(shape, edge_items) + ) + + +@dataclasses.dataclass(frozen=True) +class ArrayvizRendering(part_interface.RenderableTreePart): + """A rendering of an array with Arrayviz. + + Attributes: + html_src: HTML source for the rendering. + """ + + html_src: str + + def _compute_collapsed_width(self) -> int: + return 80 + + def _compute_newlines_in_expanded_parent(self) -> int: + return 10 + + def foldables_in_this_part(self) -> Sequence[part_interface.FoldableTreeNode]: + return () + + def _compute_tags_in_this_part(self) -> frozenset[Any]: + return frozenset() + + def render_to_text( + self, + stream: io.TextIOBase, + *, + expanded_parent: bool, + indent: int, + roundtrip_mode: bool, + render_context: dict[Any, Any], + ): + stream.write("") + + def html_setup_parts( + self, setup_context: part_interface.HtmlContextForSetup + ) -> set[part_interface.CSSStyleRule | part_interface.JavaScriptDefn]: + del setup_context + return html_setup() + + def render_to_html( + self, + stream: io.TextIOBase, + *, + at_beginning_of_line: bool = False, + render_context: dict[Any, Any], + ): + stream.write(self.html_src) + + +@dataclasses.dataclass(frozen=True) +class ArrayvizDigitboxRendering(ArrayvizRendering): + """A rendering of a single digitbox with Arrayviz.""" + + def _compute_collapsed_width(self) -> int: + return 2 + + def _compute_newlines_in_expanded_parent(self) -> int: + return 1 + + +@dataclasses.dataclass(frozen=True) +class ValueColoredTextbox(basic_parts.DeferringToChild): + """A rendering of text with a colored background. + + Attributes: + child: Child part to render. + text_color: Color for the text. + background_color: Color for the background, usually from a colormap. + out_of_bounds: Whether this value was out of bounds of the colormap. + value: Underlying float value that is being visualized. Rendered on hover. + """ + + child: part_interface.RenderableTreePart + text_color: str + background_color: str + out_of_bounds: bool + value: float + + def html_setup_parts( + self, setup_context: part_interface.HtmlContextForSetup + ) -> set[part_interface.CSSStyleRule | part_interface.JavaScriptDefn]: + return ( + { + part_interface.CSSStyleRule( + html_escaping.without_repeated_whitespace(""" + .arrayviz_textbox { + padding-left: 0.5ch; + padding-right: 0.5ch; + outline: 1px solid black; + position: relative; + display: inline-block; + font-family: monospace; + white-space: pre; + margin-top: 1px; + box-sizing: border-box; + } + .arrayviz_textbox.out_of_bounds { + outline: 3px double darkorange; + } + .arrayviz_textbox .value { + display: none; + position: absolute; + bottom: 110%; + left: 0; + overflow: visible; + color: black; + background-color: white; + font-size: 0.7em; + } + .arrayviz_textbox:hover .value { + display: block; + } + """) + ) + } + | self.child.html_setup_parts(setup_context) + ) + + def render_to_html( + self, + stream: io.TextIOBase, + *, + at_beginning_of_line: bool = False, + render_context: dict[Any, Any], + ): + class_string = "arrayviz_textbox" + if self.out_of_bounds: + class_string += " out_of_bounds" + bg_color = html_escaping.escape_html_attribute(self.background_color) + text_color = html_escaping.escape_html_attribute(self.text_color) + stream.write( + f'' + f'{float(self.value):.4g}' + ) + self.child.render_to_html( + stream, + at_beginning_of_line=False, + render_context=render_context, + ) + stream.write("") diff --git a/penzai/treescope/_internal/figures_impl.py b/penzai/treescope/_internal/figures_impl.py new file mode 100644 index 0000000..191d219 --- /dev/null +++ b/penzai/treescope/_internal/figures_impl.py @@ -0,0 +1,133 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal definitions for figure utilities.""" + + +from __future__ import annotations + +import dataclasses +import io +from typing import Any + +from penzai.treescope import lowering +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import part_interface + + +@dataclasses.dataclass(frozen=True) +class TreescopeFigure: + """Wrapper that renders its child Treescope part as an IPython figure. + + This class implements the IPython display methods, so that it can be rendered + to IPython at the top level. + + Attributes: + treescope_part: Child to render. + """ + + treescope_part: part_interface.RenderableTreePart + + def _repr_html_(self) -> str: + """Returns a rich HTML representation of this part.""" + return lowering.render_to_html_as_root(self.treescope_part, compressed=True) + + def _repr_pretty_(self, p, cycle): + """Builds a representation of this part for the IPython text prettyprinter.""" + del cycle + p.text(lowering.render_to_text_as_root(self.treescope_part)) + + +class InlineBlock(basic_parts.BaseSpanGroup): + """Renders an object in "inline-block" mode.""" + + def _span_css_class(self) -> str: + return "inline_block" + + def _span_css_rule( + self, context: part_interface.HtmlContextForSetup + ) -> part_interface.CSSStyleRule: + return part_interface.CSSStyleRule( + html_escaping.without_repeated_whitespace(""" + .inline_block { + display: inline-block; + } + """) + ) + + +class AllowWordWrap(basic_parts.BaseSpanGroup): + """Allows line breaks in its child..""" + + def _span_css_class(self) -> str: + return "allow_wrap" + + def _span_css_rule( + self, context: part_interface.HtmlContextForSetup + ) -> part_interface.CSSStyleRule: + return part_interface.CSSStyleRule( + html_escaping.without_repeated_whitespace(""" + .allow_wrap { + white-space: pre-wrap; + } + """) + ) + + +class PreventWordWrap(basic_parts.BaseSpanGroup): + """Allows line breaks in its child..""" + + def _span_css_class(self) -> str: + return "prevent_wrap" + + def _span_css_rule( + self, context: part_interface.HtmlContextForSetup + ) -> part_interface.CSSStyleRule: + return part_interface.CSSStyleRule( + html_escaping.without_repeated_whitespace(""" + .prevent_wrap { + white-space: pre; + } + """) + ) + + +@dataclasses.dataclass(frozen=True) +class CSSStyled(basic_parts.DeferringToChild): + """Adjusts the CSS style of its child. + + Attributes: + child: Child to render. + css: A CSS style string. + """ + + child: part_interface.RenderableTreePart + style: str + + def render_to_html( + self, + stream: io.TextIOBase, + *, + at_beginning_of_line: bool = False, + render_context: dict[Any, Any], + ): + style = html_escaping.escape_html_attribute(self.style) + stream.write(f'') + self.child.render_to_html( + stream, + at_beginning_of_line=at_beginning_of_line, + render_context=render_context, + ) + stream.write("") diff --git a/penzai/treescope/handlers/__init__.py b/penzai/treescope/_internal/handlers/__init__.py similarity index 100% rename from penzai/treescope/handlers/__init__.py rename to penzai/treescope/_internal/handlers/__init__.py diff --git a/penzai/treescope/handlers/autovisualizer_hook.py b/penzai/treescope/_internal/handlers/autovisualizer_hook.py similarity index 59% rename from penzai/treescope/handlers/autovisualizer_hook.py rename to penzai/treescope/_internal/handlers/autovisualizer_hook.py index a4b4094..c0db584 100644 --- a/penzai/treescope/handlers/autovisualizer_hook.py +++ b/penzai/treescope/_internal/handlers/autovisualizer_hook.py @@ -19,13 +19,10 @@ from typing import Any from penzai.treescope import autovisualize +from penzai.treescope import lowering from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import embedded_iframe -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope import rendering_parts +from penzai.treescope._internal import object_inspection IPythonVisualization = autovisualize.IPythonVisualization CustomTreescopeVisualization = autovisualize.CustomTreescopeVisualization @@ -37,8 +34,8 @@ def use_autovisualizer_if_present( path: str | None, node_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Treescope wrapper hook that runs the active autovisualizer.""" @@ -59,35 +56,37 @@ def use_autovisualizer_if_present( ordinary_result = node_renderer(node, path) if isinstance(result, IPythonVisualization): - if isinstance(result.display_object, embedded_iframe.HasReprHtml): + if isinstance(result.display_object, object_inspection.HasReprHtml): obj = result.display_object def _thunk(_): - html_rendering = embedded_iframe.to_html(obj) + html_rendering = object_inspection.to_html(obj) if html_rendering: - return embedded_iframe.EmbeddedIFrame( + return rendering_parts.embedded_iframe( embedded_html=html_rendering, - fallback_in_text_mode=common_styles.AbbreviationColor( - basic_parts.Text("") + fallback_in_text_mode=rendering_parts.abbreviation_color( + rendering_parts.text("") ), ) else: - return common_styles.ErrorColor( - basic_parts.Text( + return rendering_parts.error_color( + rendering_parts.text( "" ) ) - ipy_rendering = foldable_impl.maybe_defer_rendering( + ipy_rendering = lowering.maybe_defer_rendering( _thunk, - lambda: basic_parts.Text(""), + lambda: rendering_parts.text( + "" + ), ) else: # Bad display object - ipy_rendering = common_structures.build_one_line_tree_node( - line=common_styles.ErrorColor( - basic_parts.Text( + ipy_rendering = rendering_parts.build_one_line_tree_node( + line=rendering_parts.error_color( + rendering_parts.text( "" ) @@ -97,32 +96,36 @@ def _thunk(_): if result.replace: replace = True rendering_and_annotations = ( - common_structures.build_custom_foldable_tree_node( - label=common_styles.AbbreviationColor( - basic_parts.Text(f"")), + rendering_parts.abbreviation_color( + rendering_parts.text(">") + ), ), path=path, - expand_state=part_interface.ExpandState.EXPANDED, + expand_state=rendering_parts.ExpandState.EXPANDED, ) ) else: replace = False - rendering_and_annotations = part_interface.RenderableAndLineAnnotations( - renderable=basic_parts.ScopedSelectableAnnotation( - common_styles.DashedGrayOutlineBox(ipy_rendering) + rendering_and_annotations = rendering_parts.RenderableAndLineAnnotations( + renderable=rendering_parts.floating_annotation_with_separate_focus( + rendering_parts.dashed_gray_outline_box(ipy_rendering) ), - annotations=basic_parts.EmptyPart(), + annotations=rendering_parts.empty_part(), ) else: assert isinstance(result, CustomTreescopeVisualization) @@ -130,26 +133,28 @@ def _thunk(_): rendering_and_annotations = result.rendering if replace: - in_roundtrip_with_annotations = basic_parts.siblings_with_annotations( + in_roundtrip_with_annotations = rendering_parts.siblings_with_annotations( ordinary_result, extra_annotations=[ - common_styles.CommentColor( - basic_parts.Text(" # Visualization hidden in roundtrip mode") + rendering_parts.comment_color( + rendering_parts.text( + " # Visualization hidden in roundtrip mode" + ) ) ], ) - return part_interface.RenderableAndLineAnnotations( - renderable=basic_parts.RoundtripCondition( + return rendering_parts.RenderableAndLineAnnotations( + renderable=rendering_parts.roundtrip_condition( roundtrip=in_roundtrip_with_annotations.renderable, not_roundtrip=rendering_and_annotations.renderable, ), - annotations=basic_parts.RoundtripCondition( + annotations=rendering_parts.roundtrip_condition( roundtrip=in_roundtrip_with_annotations.annotations, not_roundtrip=rendering_and_annotations.annotations, ), ) else: - return basic_parts.siblings_with_annotations( + return rendering_parts.siblings_with_annotations( ordinary_result, rendering_and_annotations ) @@ -159,9 +164,9 @@ def _thunk(_): return node_renderer(node, path) else: - return common_structures.build_one_line_tree_node( - line=common_styles.ErrorColor( - basic_parts.Text( + return rendering_parts.build_one_line_tree_node( + line=rendering_parts.error_color( + rendering_parts.text( f"" diff --git a/penzai/treescope/_internal/handlers/basic_types_handler.py b/penzai/treescope/_internal/handlers/basic_types_handler.py new file mode 100644 index 0000000..70ced90 --- /dev/null +++ b/penzai/treescope/_internal/handlers/basic_types_handler.py @@ -0,0 +1,456 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Handlers for basic Python types, along with namedtuples and dataclasses.""" + + +import ast +import dataclasses +import enum +import types +from typing import Any + +from penzai.treescope import formatting_util +from penzai.treescope import renderer +from penzai.treescope import rendering_parts +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import part_interface + +CSSStyleRule = part_interface.CSSStyleRule +HtmlContextForSetup = part_interface.HtmlContextForSetup + + +class KeywordColor(basic_parts.BaseSpanGroup): + """Renders its child in a color for keywords.""" + + def _span_css_class(self) -> str: + return "color_keyword" + + def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: + return CSSStyleRule(html_escaping.without_repeated_whitespace(""" + .color_keyword + { + color: #0000ff; + } + """)) + + +class NumberColor(basic_parts.BaseSpanGroup): + """Renders its child in a color for numbers.""" + + def _span_css_class(self) -> str: + return "color_number" + + def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: + return CSSStyleRule(html_escaping.without_repeated_whitespace(""" + .color_number + { + color: #098156; + } + """)) + + +class StringLiteralColor(basic_parts.BaseSpanGroup): + """Renders its child in a color for string literals.""" + + def _span_css_class(self) -> str: + return "color_string" + + def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: + return CSSStyleRule(html_escaping.without_repeated_whitespace(""" + .color_string + { + color: #a31515; + } + """)) + + +def render_string_or_bytes( + node: str | bytes, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders a string or bytes literal.""" + del subtree_renderer + lines = node.splitlines(keepends=True) + if len(lines) > 1: + # For multiline strings, we use two renderings: + # - When collapsed, they render with ordinary `repr`, + # - When expanded, they render as the implicit concatenation of per-line + # string literals. + # Note that the `repr` for a string sometimes switches delimiters + # depending on whether the string contains quotes or not, so we can't do + # much manipulation of the strings themselves. This means that the safest + # thing to do is to just embed two copies of the string into the IR, + # one for the full string and the other for each line. + return rendering_parts.build_custom_foldable_tree_node( + contents=StringLiteralColor( + rendering_parts.fold_condition( + collapsed=rendering_parts.text(repr(node)), + expanded=rendering_parts.indented_children( + children=[ + rendering_parts.text(repr(line)) for line in lines + ], + comma_separated=False, + ), + ) + ), + path=path, + ) + else: + # No newlines, so render it on a single line. + return rendering_parts.build_one_line_tree_node( + StringLiteralColor(rendering_parts.text(repr(node))), path + ) + + +def render_numeric_literal( + node: int | float, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders a numeric literal.""" + del subtree_renderer + return rendering_parts.build_one_line_tree_node( + NumberColor(rendering_parts.text(repr(node))), path + ) + + +def render_keyword( + node: bool | None | type(Ellipsis) | type(NotImplemented), + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders a builtin constant (None, False, True, ..., NotImplemented).""" + del subtree_renderer + return rendering_parts.build_one_line_tree_node( + KeywordColor(rendering_parts.text(repr(node))), path + ) + + +def render_enum( + node: enum.Enum, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders an enum (roundtrippably, unlike the normal enum `repr`).""" + del subtree_renderer + cls = type(node) + if node is getattr(cls, node.name): + return rendering_parts.build_one_line_tree_node( + rendering_parts.siblings_with_annotations( + rendering_parts.maybe_qualified_type_name(cls), + "." + node.name, + extra_annotations=[ + rendering_parts.comment_color( + rendering_parts.text(f" # value: {repr(node.value)}") + ) + ], + ), + path, + ) + + +def render_dict( + node: dict[Any, Any], + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> rendering_parts.RenderableAndLineAnnotations: + """Renders a dictionary.""" + + children = [] + for i, (key, child) in enumerate(node.items()): + if i < len(node) - 1: + # Not the last child. Always show a comma, and add a space when + # collapsed. + comma_after = rendering_parts.siblings( + ",", + rendering_parts.fold_condition(collapsed=rendering_parts.text(" ")), + ) + else: + # Last child: only show the comma when the node is expanded. + comma_after = rendering_parts.fold_condition( + expanded=rendering_parts.text(",") + ) + + child_path = None if path is None else f"{path}[{repr(key)}]" + # Figure out whether this key is simple enough to render inline with + # its value. + key_rendering = subtree_renderer(key) + value_rendering = subtree_renderer(child, path=child_path) + + if ( + key_rendering.renderable.collapsed_width < 40 + and not key_rendering.renderable.foldables_in_this_part() + and key_rendering.annotations.collapsed_width == 0 + ): + # Simple enough to render on one line. + children.append( + rendering_parts.siblings_with_annotations( + key_rendering, ": ", value_rendering, comma_after + ) + ) + else: + # Should render on multiple lines. + children.append( + rendering_parts.siblings( + rendering_parts.build_full_line_with_annotations( + key_rendering, + ":", + rendering_parts.fold_condition( + collapsed=rendering_parts.text(" ") + ), + ), + rendering_parts.indented_children([ + rendering_parts.siblings_with_annotations( + value_rendering, comma_after + ), + rendering_parts.fold_condition( + expanded=rendering_parts.vertical_space("0.5em") + ), + ]), + ) + ) + + if type(node) is dict: # pylint: disable=unidiomatic-typecheck + start = "{" + end = "}" + else: + start = rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(node)), "({" + ) + end = "})" + + if not children: + return rendering_parts.build_one_line_tree_node( + line=rendering_parts.siblings(start, end), path=path + ) + else: + return rendering_parts.build_foldable_tree_node_from_children( + prefix=start, + children=children, + suffix=end, + path=path, + ) + + +def render_sequence_or_set( + sequence: dict[Any, Any], + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> rendering_parts.RenderableAndLineAnnotations: + """Renders a sequence or set to a foldable.""" + if ( + isinstance(sequence, tuple) + and type(sequence) is not tuple # pylint: disable=unidiomatic-typecheck + and hasattr(type(sequence), "_fields") + ): + # This is actually a namedtuple, which renders with keyword arguments. + return render_namedtuple_or_ast(sequence, path, subtree_renderer) + + children = [] + for i, child in enumerate(sequence): + child_path = None if path is None else f"{path}[{repr(i)}]" + children.append(subtree_renderer(child, path=child_path)) + + force_trailing_comma = False + if isinstance(sequence, tuple): + before = "(" + after = ")" + if type(sequence) is not tuple: # pylint: disable=unidiomatic-typecheck + # Subclass of `tuple`. + assert not hasattr(type(sequence), "_fields"), "impossible: checked above" + # Unusual situation: this is a subclass of `tuple`, but it isn't a + # namedtuple. Assume we can call it with a single ordinary tuple as an + # argument. + before = rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(sequence)), + "(" + before, + ) + after = after + ")" + force_trailing_comma = len(sequence) == 1 + elif isinstance(sequence, list): + before = "[" + after = "]" + if type(sequence) is not list: # pylint: disable=unidiomatic-typecheck + before = rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(sequence)), + "(" + before, + ) + after = after + ")" + elif isinstance(sequence, set): + if not sequence: + before = "set(" + after = ")" + else: # pylint: disable=unidiomatic-typecheck + before = "{" + after = "}" + + if type(sequence) is not set: # pylint: disable=unidiomatic-typecheck + before = rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(sequence)), + "(" + before, + ) + after = after + ")" + elif isinstance(sequence, frozenset): + before = "frozenset({" + after = "})" + if type(sequence) is not frozenset: # pylint: disable=unidiomatic-typecheck + before = rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(sequence)), + "(" + before, + ) + after = after + ")" + else: + raise ValueError(f"Unrecognized sequence {sequence}") + + if not children: + return rendering_parts.build_one_line_tree_node( + line=rendering_parts.siblings(before, after), path=path + ) + else: + return rendering_parts.build_foldable_tree_node_from_children( + prefix=before, + children=children, + suffix=after, + path=path, + comma_separated=True, + force_trailing_comma=force_trailing_comma, + ) + + +def render_simplenamespace( + node: types.SimpleNamespace, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> rendering_parts.RenderableAndLineAnnotations: + """Renders a SimpleNamespace.""" + return rendering_parts.build_foldable_tree_node_from_children( + prefix=rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(type(node)), "(" + ), + children=rendering_parts.build_field_children( + node, + path, + subtree_renderer, + fields_or_attribute_names=tuple(node.__dict__.keys()), + ), + suffix=")", + path=path, + ) + + +def render_namedtuple_or_ast( + node: tuple[Any, ...] | ast.AST, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> rendering_parts.RenderableAndLineAnnotations: + """Renders a namedtuple or AST class.""" + ty = type(node) + assert hasattr(ty, "_fields") + return rendering_parts.build_foldable_tree_node_from_children( + prefix=rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(ty), "(" + ), + children=rendering_parts.build_field_children( + node, path, subtree_renderer, fields_or_attribute_names=ty._fields + ), + suffix=")", + path=path, + ) + + +BUILTINS_REGISTRY = { + # Builtin atomic types. + str: render_string_or_bytes, + bytes: render_string_or_bytes, + int: render_numeric_literal, + float: render_numeric_literal, + bool: render_keyword, + type(None): render_keyword, + type(NotImplemented): render_keyword, + type(Ellipsis): render_keyword, + enum.Enum: render_enum, + # Builtin basic structures. + dict: render_dict, + tuple: render_sequence_or_set, + list: render_sequence_or_set, + set: render_sequence_or_set, + frozenset: render_sequence_or_set, + types.SimpleNamespace: render_simplenamespace, + ast.AST: render_namedtuple_or_ast, +} + + +def handle_basic_types( + node: Any, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, +) -> ( + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations + | type(NotImplemented) +): + """Renders basic builtin Python types.""" + candidate_type = type(node) + for supertype in candidate_type.__mro__: + if supertype in BUILTINS_REGISTRY: + return BUILTINS_REGISTRY[supertype](node, path, subtree_renderer) + + if dataclasses.is_dataclass(node) and not isinstance(node, type): + constructor_open = rendering_parts.render_dataclass_constructor(node) + if hasattr(node, "__treescope_color__") and callable( + node.__treescope_color__ + ): + background_color, background_pattern = ( + formatting_util.parse_simple_color_and_pattern_spec( + node.__treescope_color__(), type(node).__name__ + ) + ) + else: + background_color = None + background_pattern = None + + return rendering_parts.build_foldable_tree_node_from_children( + prefix=constructor_open, + children=rendering_parts.build_field_children( + node, + path, + subtree_renderer, + fields_or_attribute_names=dataclasses.fields(node), + ), + suffix=")", + path=path, + background_color=background_color, + background_pattern=background_pattern, + ) + + return NotImplemented diff --git a/penzai/treescope/handlers/canonical_alias_postprocessor.py b/penzai/treescope/_internal/handlers/canonical_alias_postprocessor.py similarity index 75% rename from penzai/treescope/handlers/canonical_alias_postprocessor.py rename to penzai/treescope/_internal/handlers/canonical_alias_postprocessor.py index f994d3d..762b6e1 100644 --- a/penzai/treescope/handlers/canonical_alias_postprocessor.py +++ b/penzai/treescope/_internal/handlers/canonical_alias_postprocessor.py @@ -36,10 +36,7 @@ from penzai.treescope import canonical_aliases from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope import rendering_parts def replace_with_canonical_aliases( @@ -48,8 +45,8 @@ def replace_with_canonical_aliases( node_renderer: renderer.TreescopeSubtreeRenderer, summarization_threshold: int = 20, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Rewrites objects to use well-known aliases when known. @@ -91,35 +88,37 @@ def replace_with_canonical_aliases( else: prefix, suffix = qualified_name.rsplit(".", 1) if len(qualified_name) > summarization_threshold: - name_rendering = basic_parts.siblings( - basic_parts.SummarizableCondition( - detail=common_styles.QualifiedTypeNameSpanGroup( - basic_parts.Text(prefix + ".") + name_rendering = rendering_parts.siblings( + rendering_parts.summarizable_condition( + detail=rendering_parts.qualified_type_name_style( + rendering_parts.text(prefix + ".") ) ), suffix, ) else: - name_rendering = basic_parts.siblings( - common_styles.QualifiedTypeNameSpanGroup( - basic_parts.Text(prefix + ".") + name_rendering = rendering_parts.siblings( + rendering_parts.qualified_type_name_style( + rendering_parts.text(prefix + ".") ), suffix, ) original_rendering = node_renderer(node, path=path) - return common_structures.build_custom_foldable_tree_node( - label=common_styles.CommentColorWhenExpanded( - basic_parts.siblings( - basic_parts.FoldCondition(expanded=basic_parts.Text("# ")), + return rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.comment_color_when_expanded( + rendering_parts.siblings( + rendering_parts.fold_condition( + expanded=rendering_parts.text("# ") + ), name_rendering, ) ), - contents=basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren.build([original_rendering]) + contents=rendering_parts.fold_condition( + expanded=rendering_parts.indented_children([original_rendering]) ), # Aliases should start collapsed regardless of layout. - expand_state=basic_parts.ExpandState.COLLAPSED, + expand_state=rendering_parts.ExpandState.COLLAPSED, path=path, ) diff --git a/penzai/treescope/handlers/custom_type_handlers.py b/penzai/treescope/_internal/handlers/custom_type_handlers.py similarity index 92% rename from penzai/treescope/handlers/custom_type_handlers.py rename to penzai/treescope/_internal/handlers/custom_type_handlers.py index 0dc4344..6464f28 100644 --- a/penzai/treescope/handlers/custom_type_handlers.py +++ b/penzai/treescope/_internal/handlers/custom_type_handlers.py @@ -18,10 +18,10 @@ from typing import Any -from penzai.treescope import object_inspection from penzai.treescope import renderer +from penzai.treescope import rendering_parts from penzai.treescope import type_registries -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope._internal import object_inspection def handle_via_penzai_repr_method( @@ -29,8 +29,8 @@ def handle_via_penzai_repr_method( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders a type by calling its __penzai_repr__ method, if it exists. @@ -71,8 +71,8 @@ def handle_via_global_registry( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders a type by looking it up in the global handler registry. diff --git a/penzai/treescope/handlers/function_reflection_handlers.py b/penzai/treescope/_internal/handlers/function_reflection_handlers.py similarity index 68% rename from penzai/treescope/handlers/function_reflection_handlers.py rename to penzai/treescope/_internal/handlers/function_reflection_handlers.py index 32fbfa5..82c8638 100644 --- a/penzai/treescope/handlers/function_reflection_handlers.py +++ b/penzai/treescope/_internal/handlers/function_reflection_handlers.py @@ -26,12 +26,9 @@ import re from typing import Any -from penzai.treescope import html_escaping from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope import rendering_parts +from penzai.treescope._internal import html_escaping @functools.cache @@ -51,7 +48,7 @@ def _get_filepath_and_lineno(value) -> tuple[str, int] | tuple[None, None]: def format_source_location( filepath: str, lineno: int -) -> part_interface.RenderableTreePart: +) -> rendering_parts.RenderableTreePart: """Formats a reference to a given filepath and line number.""" # Try to match it as an IPython file @@ -60,9 +57,9 @@ def format_source_location( ) if ipython_output_path: cell_number = ipython_output_path.group("cell_number") - return basic_parts.Text(f"line {lineno} of output cell {cell_number}") + return rendering_parts.text(f"line {lineno} of output cell {cell_number}") - return basic_parts.Text(f"line {lineno} of {filepath}") + return rendering_parts.text(f"line {lineno} of {filepath}") def handle_code_objects_with_reflection( @@ -71,8 +68,8 @@ def handle_code_objects_with_reflection( subtree_renderer: renderer.TreescopeSubtreeRenderer, show_closure_vars: bool = False, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders code objects using source-code reflection and closure inspection.""" @@ -93,28 +90,30 @@ def handle_code_objects_with_reflection( filepath, lineno = _get_filepath_and_lineno(node) if filepath is not None: annotations.append( - common_styles.CommentColor( - basic_parts.siblings( - basic_parts.Text(" # Defined at "), + rendering_parts.comment_color( + rendering_parts.siblings( + rendering_parts.text(" # Defined at "), format_source_location(filepath, lineno), ) ) ) if closure_vars: - boxed_closure_var_rendering = common_styles.DashedGrayOutlineBox( - basic_parts.OnSeparateLines.build([ - common_styles.CommentColor( - basic_parts.Text("# Closure variables:") + boxed_closure_var_rendering = rendering_parts.dashed_gray_outline_box( + rendering_parts.on_separate_lines([ + rendering_parts.comment_color( + rendering_parts.text("# Closure variables:") ), subtree_renderer(closure_vars), ]) ) - return basic_parts.siblings_with_annotations( - common_structures.build_custom_foldable_tree_node( - label=common_styles.AbbreviationColor(basic_parts.Text(repr(node))), - contents=basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren.build( + return rendering_parts.siblings_with_annotations( + rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.abbreviation_color( + rendering_parts.text(repr(node)) + ), + contents=rendering_parts.fold_condition( + expanded=rendering_parts.indented_children( [boxed_closure_var_rendering] ) ), @@ -124,9 +123,11 @@ def handle_code_objects_with_reflection( ) else: - return basic_parts.siblings_with_annotations( - common_structures.build_one_line_tree_node( - line=common_styles.AbbreviationColor(basic_parts.Text(repr(node))), + return rendering_parts.siblings_with_annotations( + rendering_parts.build_one_line_tree_node( + line=rendering_parts.abbreviation_color( + rendering_parts.text(repr(node)) + ), path=path, ), extra_annotations=annotations, diff --git a/penzai/treescope/handlers/generic_pytree_handler.py b/penzai/treescope/_internal/handlers/generic_pytree_handler.py similarity index 65% rename from penzai/treescope/handlers/generic_pytree_handler.py rename to penzai/treescope/_internal/handlers/generic_pytree_handler.py index 02cb425..69448c0 100644 --- a/penzai/treescope/handlers/generic_pytree_handler.py +++ b/penzai/treescope/_internal/handlers/generic_pytree_handler.py @@ -17,12 +17,9 @@ import sys from typing import Any +from penzai.treescope import handlers from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import generic_repr_handler +from penzai.treescope import rendering_parts def handle_arbitrary_pytrees( @@ -30,8 +27,8 @@ def handle_arbitrary_pytrees( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Generic foldable fallback for an unrecognized pytree type.""" @@ -52,7 +49,7 @@ def handle_arbitrary_pytrees( return NotImplemented # First, render the object with repr. - repr_rendering = generic_repr_handler.handle_anything_with_repr( + repr_rendering = handlers.handle_anything_with_repr( node=node, path=path, subtree_renderer=subtree_renderer, @@ -63,7 +60,7 @@ def handle_arbitrary_pytrees( for (key,), child in paths_and_subtrees: child_path = None if path is None else path + str(key) list_items.append( - basic_parts.siblings_with_annotations( + rendering_parts.siblings_with_annotations( subtree_renderer(key, path=None), ": ", subtree_renderer(child, path=child_path), @@ -71,22 +68,22 @@ def handle_arbitrary_pytrees( ) ) - boxed_pytree_children = basic_parts.IndentedChildren.build([ - common_styles.DashedGrayOutlineBox( - basic_parts.build_full_line_with_annotations( - common_structures.build_custom_foldable_tree_node( - label=common_styles.CommentColor( - basic_parts.Text("# PyTree children: ") + boxed_pytree_children = rendering_parts.indented_children([ + rendering_parts.dashed_gray_outline_box( + rendering_parts.build_full_line_with_annotations( + rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.comment_color( + rendering_parts.text("# PyTree children: ") ), - contents=basic_parts.IndentedChildren.build(list_items), + contents=rendering_parts.indented_children(list_items), ) ), ) ]) - return basic_parts.siblings_with_annotations( + return rendering_parts.siblings_with_annotations( repr_rendering, - basic_parts.FoldCondition( - expanded=basic_parts.RoundtripCondition( + rendering_parts.fold_condition( + expanded=rendering_parts.roundtrip_condition( not_roundtrip=boxed_pytree_children ) ), diff --git a/penzai/treescope/handlers/generic_repr_handler.py b/penzai/treescope/_internal/handlers/generic_repr_handler.py similarity index 53% rename from penzai/treescope/handlers/generic_repr_handler.py rename to penzai/treescope/_internal/handlers/generic_repr_handler.py index 1a496b1..5fa4c83 100644 --- a/penzai/treescope/handlers/generic_repr_handler.py +++ b/penzai/treescope/_internal/handlers/generic_repr_handler.py @@ -20,14 +20,7 @@ from penzai.treescope import copypaste_fallback from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import part_interface - - -CSSStyleRule = part_interface.CSSStyleRule -HtmlContextForSetup = part_interface.HtmlContextForSetup +from penzai.treescope import rendering_parts def handle_anything_with_repr( @@ -35,8 +28,8 @@ def handle_anything_with_repr( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Builds a foldable from its repr.""" @@ -52,60 +45,66 @@ def handle_anything_with_repr( # comment lines = node_repr.split("\n") lines_with_markers = [ - basic_parts.siblings( - basic_parts.Text(line), - basic_parts.FoldCondition( - expanded=basic_parts.Text("\n"), - collapsed=common_styles.CommentColor(basic_parts.Text("↩")), + rendering_parts.siblings( + rendering_parts.text(line), + rendering_parts.fold_condition( + expanded=rendering_parts.text("\n"), + collapsed=rendering_parts.comment_color( + rendering_parts.text("↩") + ), ), ) for line in lines[:-1] ] - lines_with_markers.append(basic_parts.Text(lines[-1])) - return basic_parts.siblings_with_annotations( - common_structures.build_custom_foldable_tree_node( - contents=basic_parts.RoundtripCondition( + lines_with_markers.append(rendering_parts.text(lines[-1])) + return rendering_parts.siblings_with_annotations( + rendering_parts.build_custom_foldable_tree_node( + contents=rendering_parts.roundtrip_condition( roundtrip=fallback, - not_roundtrip=common_styles.AbbreviationColor( - basic_parts.siblings(*lines_with_markers) + not_roundtrip=rendering_parts.abbreviation_color( + rendering_parts.siblings(*lines_with_markers) ), ), path=path, ), extra_annotations=[ - common_styles.CommentColor(basic_parts.Text(" # " + basic_repr)) + rendering_parts.comment_color( + rendering_parts.text(" # " + basic_repr) + ) ], ) else: # Use basic repr as the summary. - return common_structures.build_custom_foldable_tree_node( - label=basic_parts.RoundtripCondition( + return rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.roundtrip_condition( roundtrip=fallback, - not_roundtrip=common_styles.AbbreviationColor( - common_styles.CommentColorWhenExpanded( - basic_parts.siblings( - basic_parts.FoldCondition( - expanded=basic_parts.Text("# ") + not_roundtrip=rendering_parts.abbreviation_color( + rendering_parts.comment_color_when_expanded( + rendering_parts.siblings( + rendering_parts.fold_condition( + expanded=rendering_parts.text("# ") ), basic_repr, ) ) ), ), - contents=basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren.build( - [common_styles.AbbreviationColor(basic_parts.Text(node_repr))] - ) + contents=rendering_parts.fold_condition( + expanded=rendering_parts.indented_children([ + rendering_parts.abbreviation_color( + rendering_parts.text(node_repr) + ) + ]) ), path=path, ) elif node_repr == basic_repr: # Just use the basic repr as the summary, since we don't have anything else. - return common_structures.build_one_line_tree_node( - line=basic_parts.RoundtripCondition( + return rendering_parts.build_one_line_tree_node( + line=rendering_parts.roundtrip_condition( roundtrip=fallback, - not_roundtrip=common_styles.AbbreviationColor( - basic_parts.Text(node_repr) + not_roundtrip=rendering_parts.abbreviation_color( + rendering_parts.text(node_repr) ), ), path=path, @@ -114,17 +113,19 @@ def handle_anything_with_repr( # Use the custom repr as a one-line summary, but float the basic repr to # the right to tell the user what the type is in case the custom repr # doesn't include that info. - return basic_parts.siblings_with_annotations( - common_structures.build_one_line_tree_node( - line=basic_parts.RoundtripCondition( + return rendering_parts.siblings_with_annotations( + rendering_parts.build_one_line_tree_node( + line=rendering_parts.roundtrip_condition( roundtrip=fallback, - not_roundtrip=common_styles.AbbreviationColor( - basic_parts.Text(node_repr) + not_roundtrip=rendering_parts.abbreviation_color( + rendering_parts.text(node_repr) ), ), path=path, ), extra_annotations=[ - common_styles.CommentColor(basic_parts.Text(" # " + basic_repr)) + rendering_parts.comment_color( + rendering_parts.text(" # " + basic_repr) + ) ], ) diff --git a/penzai/treescope/handlers/interop/jax_support.py b/penzai/treescope/_internal/handlers/interop/jax_support.py similarity index 90% rename from penzai/treescope/handlers/interop/jax_support.py rename to penzai/treescope/_internal/handlers/interop/jax_support.py index 6c2fe66..8c77c6c 100644 --- a/penzai/treescope/handlers/interop/jax_support.py +++ b/penzai/treescope/_internal/handlers/interop/jax_support.py @@ -23,15 +23,13 @@ from penzai.treescope import canonical_aliases from penzai.treescope import context from penzai.treescope import dtype_util +from penzai.treescope import lowering from penzai.treescope import ndarray_adapters from penzai.treescope import renderer +from penzai.treescope import rendering_parts from penzai.treescope import repr_lib from penzai.treescope import type_registries -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope._internal.parts import part_interface # pylint: disable=g-import-not-at-top try: @@ -295,8 +293,8 @@ def render_shape_dtype_struct( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders jax.ShapeDtypeStruct.""" @@ -332,8 +330,8 @@ def render_precision( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders jax.lax.Precision.""" @@ -496,8 +494,8 @@ def render_jax_arrays( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders a JAX array.""" @@ -510,14 +508,16 @@ def render_jax_arrays( adapter = JAXArrayAdapter() if node.is_deleted(): - return common_styles.ErrorColor( - basic_parts.Text("<" + adapter.get_array_summary(node, fast=True) + ">") + return rendering_parts.error_color( + rendering_parts.text( + "<" + adapter.get_array_summary(node, fast=True) + ">" + ) ) - def _placeholder() -> part_interface.RenderableTreePart: - return common_structures.fake_placeholder_foldable( - common_styles.DeferredPlaceholderStyle( - basic_parts.Text(adapter.get_array_summary(node, fast=True)) + def _placeholder() -> rendering_parts.RenderableTreePart: + return rendering_parts.fake_placeholder_foldable( + rendering_parts.deferred_placeholder_style( + rendering_parts.text(adapter.get_array_summary(node, fast=True)) ), extra_newlines_guess=8, ) @@ -526,8 +526,8 @@ def _thunk(placeholder): # Is this array simple enough to render without a summary? node_repr = faster_array_repr(node) if "\n" not in node_repr and "..." not in node_repr: - rendering = common_styles.AbbreviationColor( - basic_parts.Text(f"") + rendering = rendering_parts.abbreviation_color( + rendering_parts.text(f"") ) else: if node_repr.count("\n") <= 15: @@ -538,28 +538,28 @@ def _thunk(placeholder): default_expand_state = part_interface.ExpandState.WEAKLY_EXPANDED else: # Always start big NDArrays in collapsed mode to hide irrelevant detail. - default_expand_state = part_interface.ExpandState.COLLAPSED + default_expand_state = rendering_parts.ExpandState.COLLAPSED # Render it with a summary. summarized = adapter.get_array_summary(node, fast=False) - rendering = common_structures.build_custom_foldable_tree_node( - label=common_styles.AbbreviationColor( - common_styles.CommentColorWhenExpanded( - basic_parts.siblings( - basic_parts.FoldCondition( - expanded=basic_parts.Text("# "), - collapsed=basic_parts.Text("<"), + rendering = rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.abbreviation_color( + rendering_parts.comment_color_when_expanded( + rendering_parts.siblings( + rendering_parts.fold_condition( + expanded=rendering_parts.text("# "), + collapsed=rendering_parts.text("<"), ), summarized, - basic_parts.FoldCondition( - collapsed=basic_parts.Text(">") + rendering_parts.fold_condition( + collapsed=rendering_parts.text(">") ), ) ) ), - contents=basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren.build( - [basic_parts.Text(node_repr)] + contents=rendering_parts.fold_condition( + expanded=rendering_parts.indented_children( + [rendering_parts.text(node_repr)] ) ), path=path, @@ -568,11 +568,11 @@ def _thunk(placeholder): return rendering - return basic_parts.RenderableAndLineAnnotations( - renderable=foldable_impl.maybe_defer_rendering( + return rendering_parts.RenderableAndLineAnnotations( + renderable=lowering.maybe_defer_rendering( main_thunk=_thunk, placeholder_thunk=_placeholder ), - annotations=common_structures.build_copy_button(path), + annotations=rendering_parts.build_copy_button(path), ) diff --git a/penzai/treescope/handlers/interop/numpy_support.py b/penzai/treescope/_internal/handlers/interop/numpy_support.py similarity index 83% rename from penzai/treescope/handlers/interop/numpy_support.py rename to penzai/treescope/_internal/handlers/interop/numpy_support.py index 0e4ee89..8e5c4d4 100644 --- a/penzai/treescope/handlers/interop/numpy_support.py +++ b/penzai/treescope/_internal/handlers/interop/numpy_support.py @@ -20,14 +20,12 @@ import numpy as np from penzai.treescope import canonical_aliases from penzai.treescope import dtype_util +from penzai.treescope import lowering from penzai.treescope import ndarray_adapters from penzai.treescope import renderer +from penzai.treescope import rendering_parts from penzai.treescope import type_registries -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope._internal.parts import part_interface def _truncate_and_copy( @@ -204,8 +202,8 @@ def render_ndarrays( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders a numpy array.""" @@ -213,10 +211,10 @@ def render_ndarrays( assert isinstance(node, np.ndarray) adapter = NumpyArrayAdapter() - def _placeholder() -> part_interface.RenderableTreePart: - return common_structures.fake_placeholder_foldable( - common_styles.DeferredPlaceholderStyle( - basic_parts.Text(adapter.get_array_summary(node, fast=True)) + def _placeholder() -> rendering_parts.RenderableTreePart: + return rendering_parts.fake_placeholder_foldable( + rendering_parts.deferred_placeholder_style( + rendering_parts.text(adapter.get_array_summary(node, fast=True)) ), extra_newlines_guess=8, ) @@ -225,7 +223,7 @@ def _thunk(placeholder): # Is this array simple enough to render without a summary? node_repr = repr(node) if "\n" not in node_repr and "..." not in node_repr: - rendering = basic_parts.Text(f"np.{node_repr}") + rendering = rendering_parts.text(f"np.{node_repr}") else: if node_repr.count("\n") <= 15: if isinstance(placeholder, part_interface.FoldableTreeNode): @@ -235,28 +233,28 @@ def _thunk(placeholder): default_expand_state = part_interface.ExpandState.WEAKLY_EXPANDED else: # Always start big NDArrays in collapsed mode to hide irrelevant detail. - default_expand_state = part_interface.ExpandState.COLLAPSED + default_expand_state = rendering_parts.ExpandState.COLLAPSED # Render it with a summary. summarized = adapter.get_array_summary(node, fast=False) - rendering = common_structures.build_custom_foldable_tree_node( - label=common_styles.AbbreviationColor( - common_styles.CommentColorWhenExpanded( - basic_parts.siblings( - basic_parts.FoldCondition( - expanded=basic_parts.Text("# "), - collapsed=basic_parts.Text("<"), + rendering = rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.abbreviation_color( + rendering_parts.comment_color_when_expanded( + rendering_parts.siblings( + rendering_parts.fold_condition( + expanded=rendering_parts.text("# "), + collapsed=rendering_parts.text("<"), ), summarized, - basic_parts.FoldCondition( - collapsed=basic_parts.Text(">") + rendering_parts.fold_condition( + collapsed=rendering_parts.text(">") ), ) ) ), - contents=basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren.build( - [basic_parts.Text(node_repr)] + contents=rendering_parts.fold_condition( + expanded=rendering_parts.indented_children( + [rendering_parts.text(node_repr)] ) ), path=path, @@ -265,11 +263,11 @@ def _thunk(placeholder): return rendering - return basic_parts.RenderableAndLineAnnotations( - renderable=foldable_impl.maybe_defer_rendering( + return rendering_parts.RenderableAndLineAnnotations( + renderable=lowering.maybe_defer_rendering( main_thunk=_thunk, placeholder_thunk=_placeholder ), - annotations=common_structures.build_copy_button(path), + annotations=rendering_parts.build_copy_button(path), ) @@ -278,8 +276,8 @@ def render_dtype_instances( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders a np.dtype, adding the `np.` qualifier.""" @@ -299,9 +297,11 @@ def render_dtype_instances( # and add the "numpy." prefix as needed. dtype_string = repr(node) - return common_structures.build_one_line_tree_node( - line=basic_parts.siblings( - basic_parts.RoundtripCondition(roundtrip=basic_parts.Text("np.")), + return rendering_parts.build_one_line_tree_node( + line=rendering_parts.siblings( + rendering_parts.roundtrip_condition( + roundtrip=rendering_parts.text("np.") + ), dtype_string, ), path=path, diff --git a/penzai/treescope/handlers/interop/penzai_core_support.py b/penzai/treescope/_internal/handlers/interop/penzai_core_support.py similarity index 98% rename from penzai/treescope/handlers/interop/penzai_core_support.py rename to penzai/treescope/_internal/handlers/interop/penzai_core_support.py index 602e543..90f1b27 100644 --- a/penzai/treescope/handlers/interop/penzai_core_support.py +++ b/penzai/treescope/_internal/handlers/interop/penzai_core_support.py @@ -25,7 +25,7 @@ from penzai.core._treescope_handlers import named_axes_handlers from penzai.treescope import ndarray_adapters from penzai.treescope import type_registries -from penzai.treescope.handlers.interop import jax_support +from penzai.treescope._internal.handlers.interop import jax_support class NamedArrayAdapter( diff --git a/penzai/treescope/handlers/interop/torch_support.py b/penzai/treescope/_internal/handlers/interop/torch_support.py similarity index 80% rename from penzai/treescope/handlers/interop/torch_support.py rename to penzai/treescope/_internal/handlers/interop/torch_support.py index bc6921d..8e592c1 100644 --- a/penzai/treescope/handlers/interop/torch_support.py +++ b/penzai/treescope/_internal/handlers/interop/torch_support.py @@ -22,19 +22,16 @@ import numpy as np from penzai.treescope import context from penzai.treescope import formatting_util +from penzai.treescope import lowering from penzai.treescope import ndarray_adapters from penzai.treescope import renderer +from penzai.treescope import rendering_parts from penzai.treescope import type_registries -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import builtin_structure_handler +from penzai.treescope._internal.parts import part_interface # pylint: disable=g-import-not-at-top try: - import torch + import torch # pytype: disable=import-error except ImportError: assert not typing.TYPE_CHECKING torch = None @@ -264,8 +261,8 @@ def render_torch_tensors( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders a numpy array.""" @@ -274,10 +271,10 @@ def render_torch_tensors( assert isinstance(node, torch.Tensor) adapter = TorchTensorAdapter() - def _placeholder() -> part_interface.RenderableTreePart: - return common_structures.fake_placeholder_foldable( - common_styles.DeferredPlaceholderStyle( - basic_parts.Text(adapter.get_array_summary(node, fast=True)) + def _placeholder() -> rendering_parts.RenderableTreePart: + return rendering_parts.fake_placeholder_foldable( + rendering_parts.deferred_placeholder_style( + rendering_parts.text(adapter.get_array_summary(node, fast=True)) ), extra_newlines_guess=8, ) @@ -289,7 +286,7 @@ def _thunk(placeholder): if node_repr.startswith("tensor("): # Add module path, for consistency with other Treescope renderings. node_repr = f"torch.{node_repr}" - rendering = basic_parts.Text(node_repr) + rendering = rendering_parts.text(node_repr) else: if node_repr.count("\n") <= 15: if isinstance(placeholder, part_interface.FoldableTreeNode): @@ -299,28 +296,28 @@ def _thunk(placeholder): default_expand_state = part_interface.ExpandState.WEAKLY_EXPANDED else: # Always start big NDArrays in collapsed mode to hide irrelevant detail. - default_expand_state = part_interface.ExpandState.COLLAPSED + default_expand_state = rendering_parts.ExpandState.COLLAPSED # Render it with a summary. summarized = adapter.get_array_summary(node, fast=False) - rendering = common_structures.build_custom_foldable_tree_node( - label=common_styles.AbbreviationColor( - common_styles.CommentColorWhenExpanded( - basic_parts.siblings( - basic_parts.FoldCondition( - expanded=basic_parts.Text("# "), - collapsed=basic_parts.Text("<"), + rendering = rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.abbreviation_color( + rendering_parts.comment_color_when_expanded( + rendering_parts.siblings( + rendering_parts.fold_condition( + expanded=rendering_parts.text("# "), + collapsed=rendering_parts.text("<"), ), summarized, - basic_parts.FoldCondition( - collapsed=basic_parts.Text(">") + rendering_parts.fold_condition( + collapsed=rendering_parts.text(">") ), ) ) ), - contents=basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren.build( - [basic_parts.Text(node_repr)] + contents=rendering_parts.fold_condition( + expanded=rendering_parts.indented_children( + [rendering_parts.text(node_repr)] ) ), path=path, @@ -329,11 +326,11 @@ def _thunk(placeholder): return rendering - return basic_parts.RenderableAndLineAnnotations( - renderable=foldable_impl.maybe_defer_rendering( + return rendering_parts.RenderableAndLineAnnotations( + renderable=lowering.maybe_defer_rendering( main_thunk=_thunk, placeholder_thunk=_placeholder ), - annotations=common_structures.build_copy_button(path), + annotations=rendering_parts.build_copy_button(path), ) @@ -342,29 +339,29 @@ def render_torch_modules( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Renders a torch module.""" assert torch is not None, "PyTorch is not available." assert isinstance(node, torch.nn.Module) node_type = type(node) - constructor = basic_parts.siblings( - basic_parts.RoundtripCondition(roundtrip=basic_parts.Text("<")), - common_structures.maybe_qualified_type_name(node_type), + constructor = rendering_parts.siblings( + rendering_parts.roundtrip_condition(roundtrip=rendering_parts.text("<")), + rendering_parts.maybe_qualified_type_name(node_type), "(", ) - closing_suffix = basic_parts.siblings( + closing_suffix = rendering_parts.siblings( ")", - basic_parts.RoundtripCondition(roundtrip=basic_parts.Text(">")), + rendering_parts.roundtrip_condition(roundtrip=rendering_parts.text(">")), ) if hasattr(node, "__treescope_color__") and callable( node.__treescope_color__ ): background_color, background_pattern = ( - builtin_structure_handler.parse_color_and_pattern( + formatting_util.parse_simple_color_and_pattern_spec( node.__treescope_color__(), node_type.__name__ ) ) @@ -400,12 +397,14 @@ def render_torch_modules( value = vars(node)[attr] child_path = None if path is None else f"{path}.{attr}" attr_children.append( - basic_parts.build_full_line_with_annotations( - basic_parts.siblings_with_annotations( + rendering_parts.build_full_line_with_annotations( + rendering_parts.siblings_with_annotations( f"{attr}=", subtree_renderer(value, path=child_path), ",", - basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")), + rendering_parts.fold_condition( + collapsed=rendering_parts.text(" ") + ), ) ) ) @@ -414,15 +413,15 @@ def render_torch_modules( else: has_attr_children_expander = True children.append( - common_structures.build_custom_foldable_tree_node( - label=basic_parts.FoldCondition( - expanded=common_styles.CommentColor( - basic_parts.Text("# Attributes:") + rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.fold_condition( + expanded=rendering_parts.comment_color( + rendering_parts.text("# Attributes:") ), ), - contents=basic_parts.OnSeparateLines.build(attr_children), + contents=rendering_parts.on_separate_lines(attr_children), path=None, - expand_state=part_interface.ExpandState.COLLAPSED, + expand_state=rendering_parts.ExpandState.COLLAPSED, ) ) else: @@ -432,11 +431,11 @@ def render_torch_modules( extra_repr = extra_repr + ", " if "\n" in extra_repr: children.append( - basic_parts.OnSeparateLines.build(extra_repr.split("\n")) + rendering_parts.on_separate_lines(extra_repr.split("\n")) ) prefers_expand = True else: - children.append(basic_parts.Text(extra_repr)) + children.append(rendering_parts.text(extra_repr)) # Render parameters and buffers for group_name, group in ( @@ -446,21 +445,23 @@ def render_torch_modules( group = list(group) if group: children.append( - basic_parts.FoldCondition( - expanded=common_styles.CommentColor( - basic_parts.Text(f"# {group_name}:") + rendering_parts.fold_condition( + expanded=rendering_parts.comment_color( + rendering_parts.text(f"# {group_name}:") ) ) ) for name, value in group: child_path = None if path is None else f"{path}.{name}" children.append( - basic_parts.build_full_line_with_annotations( - basic_parts.siblings_with_annotations( + rendering_parts.build_full_line_with_annotations( + rendering_parts.siblings_with_annotations( f"{name}=", subtree_renderer(value, path=child_path), ",", - basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")), + rendering_parts.fold_condition( + collapsed=rendering_parts.text(" ") + ), ) ) ) @@ -469,9 +470,9 @@ def render_torch_modules( submodules = list(node.named_children()) if submodules: children.append( - basic_parts.FoldCondition( - expanded=common_styles.CommentColor( - basic_parts.Text("# Child modules:") + rendering_parts.fold_condition( + expanded=rendering_parts.comment_color( + rendering_parts.text("# Child modules:") ) ) ) @@ -484,12 +485,14 @@ def render_torch_modules( child_path = f"{path}.get_submodule({repr(name)})" keystr = f"({name}): " children.append( - basic_parts.build_full_line_with_annotations( - basic_parts.siblings_with_annotations( + rendering_parts.build_full_line_with_annotations( + rendering_parts.siblings_with_annotations( keystr, subtree_renderer(submod, path=child_path), ",", - basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")), + rendering_parts.fold_condition( + collapsed=rendering_parts.text(" ") + ), ) ) ) @@ -501,11 +504,11 @@ def render_torch_modules( # Heuristic: If a module doesn't have any submodules, mark it collapsed, to # match the behavior of PyTorch repr. if prefers_expand: - expand_state = part_interface.ExpandState.WEAKLY_EXPANDED + expand_state = rendering_parts.ExpandState.WEAKLY_EXPANDED else: - expand_state = part_interface.ExpandState.COLLAPSED + expand_state = rendering_parts.ExpandState.COLLAPSED - return common_structures.build_foldable_tree_node_from_children( + return rendering_parts.build_foldable_tree_node_from_children( prefix=constructor, children=children, suffix=closing_suffix, diff --git a/penzai/treescope/handlers/repr_html_postprocessor.py b/penzai/treescope/_internal/handlers/repr_html_postprocessor.py similarity index 63% rename from penzai/treescope/handlers/repr_html_postprocessor.py rename to penzai/treescope/_internal/handlers/repr_html_postprocessor.py index 1c924ef..e832b72 100644 --- a/penzai/treescope/handlers/repr_html_postprocessor.py +++ b/penzai/treescope/_internal/handlers/repr_html_postprocessor.py @@ -28,13 +28,10 @@ from typing import Any from penzai.treescope import context +from penzai.treescope import lowering from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import embedded_iframe -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope import rendering_parts +from penzai.treescope._internal import object_inspection _already_processing_repr_html: context.ContextualValue[bool] = ( @@ -57,8 +54,8 @@ def append_repr_html_when_present( path: str | None, node_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Appends rich HTML representations of objects that have them.""" @@ -66,7 +63,7 @@ def append_repr_html_when_present( # We've processed the repr_html for a parent of this node already. return NotImplemented - if not isinstance(node, embedded_iframe.HasReprHtml): + if not isinstance(node, object_inspection.HasReprHtml): return NotImplemented # Make sure we don't try to call _repr_html_ on the children of this node, @@ -77,35 +74,35 @@ def append_repr_html_when_present( node_rendering = node_renderer(node, path=path) def _thunk(_): - html_rendering = embedded_iframe.to_html(node) + html_rendering = object_inspection.to_html(node) if html_rendering: - return embedded_iframe.EmbeddedIFrame( + return rendering_parts.embedded_iframe( embedded_html=html_rendering, - fallback_in_text_mode=common_styles.AbbreviationColor( - basic_parts.Text("# (not shown in text mode)") + fallback_in_text_mode=rendering_parts.abbreviation_color( + rendering_parts.text("# (not shown in text mode)") ), ) else: - return common_styles.ErrorColor( + return rendering_parts.error_color( "# (couldn't compute HTML representation)" ) - iframe_rendering = foldable_impl.maybe_defer_rendering( + iframe_rendering = lowering.maybe_defer_rendering( main_thunk=_thunk, - placeholder_thunk=lambda: common_styles.DeferredPlaceholderStyle( - basic_parts.Text("...") + placeholder_thunk=lambda: rendering_parts.deferred_placeholder_style( + rendering_parts.text("...") ), ) - boxed_html_repr = basic_parts.IndentedChildren.build([ - basic_parts.ScopedSelectableAnnotation( - common_styles.DashedGrayOutlineBox( - basic_parts.build_full_line_with_annotations( - common_structures.build_custom_foldable_tree_node( - label=common_styles.CommentColor( - basic_parts.Text("# Rich HTML representation") + boxed_html_repr = rendering_parts.indented_children([ + rendering_parts.floating_annotation_with_separate_focus( + rendering_parts.dashed_gray_outline_box( + rendering_parts.build_full_line_with_annotations( + rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.comment_color( + rendering_parts.text("# Rich HTML representation") ), - contents=basic_parts.FoldCondition( + contents=rendering_parts.fold_condition( expanded=iframe_rendering ), ) @@ -113,9 +110,11 @@ def _thunk(_): ) ) ]) - return basic_parts.siblings_with_annotations( + return rendering_parts.siblings_with_annotations( node_rendering, - basic_parts.FoldCondition( - expanded=basic_parts.RoundtripCondition(not_roundtrip=boxed_html_repr) + rendering_parts.fold_condition( + expanded=rendering_parts.roundtrip_condition( + not_roundtrip=boxed_html_repr + ) ), ) diff --git a/penzai/treescope/handlers/shared_value_postprocessor.py b/penzai/treescope/_internal/handlers/shared_value_postprocessor.py similarity index 94% rename from penzai/treescope/handlers/shared_value_postprocessor.py rename to penzai/treescope/_internal/handlers/shared_value_postprocessor.py index 1be45e7..5a78d62 100644 --- a/penzai/treescope/handlers/shared_value_postprocessor.py +++ b/penzai/treescope/_internal/handlers/shared_value_postprocessor.py @@ -24,11 +24,12 @@ from typing import Any, Optional, Sequence from penzai.treescope import context -from penzai.treescope import html_escaping from penzai.treescope import renderer +from penzai.treescope import rendering_parts from penzai.treescope import type_registries -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import part_interface @dataclasses.dataclass @@ -93,7 +94,7 @@ def _span_css_rule( @dataclasses.dataclass(frozen=False) -class DynamicSharedCheck(part_interface.RenderableTreePart): +class DynamicSharedCheck(rendering_parts.RenderableTreePart): """Dynamic group that renders its child only if a node is shared. This node is used to apply special rendering to nodes that are encountered in @@ -118,7 +119,7 @@ class DynamicSharedCheck(part_interface.RenderableTreePart): active `_shared_object_ids_seen` context. """ - if_shared: part_interface.RenderableTreePart + if_shared: rendering_parts.RenderableTreePart node_id: int seen_more_than_once: set[int] @@ -199,7 +200,7 @@ class WithDynamicSharedPip(basic_parts.DeferringToChild): active `_shared_object_ids_seen` context. """ - child: part_interface.RenderableTreePart + child: rendering_parts.RenderableTreePart node_id: int seen_more_than_once: set[int] @@ -277,8 +278,8 @@ def check_for_shared_values( path: str | None, node_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): # pylint: disable=g-doc-args,g-doc-return-or-yield @@ -329,16 +330,18 @@ def check_for_shared_values( # Wrap it in a shared value wrapper; this will check to see if the same # node was seen more than once, and add an annotation if so. - return part_interface.RenderableAndLineAnnotations( + return rendering_parts.RenderableAndLineAnnotations( renderable=WithDynamicSharedPip( rendering.renderable, node_id=node_id, seen_more_than_once=shared_object_tracker.seen_more_than_once, ), - annotations=basic_parts.siblings( + annotations=rendering_parts.siblings( DynamicSharedCheck( if_shared=SharedWarningLabel( - basic_parts.Text(f" # Repeated python obj at 0x{node_id:x}") + rendering_parts.text( + f" # Repeated python obj at 0x{node_id:x}" + ) ), node_id=node_id, seen_more_than_once=shared_object_tracker.seen_more_than_once, diff --git a/penzai/treescope/html_encapsulation.py b/penzai/treescope/_internal/html_encapsulation.py similarity index 100% rename from penzai/treescope/html_encapsulation.py rename to penzai/treescope/_internal/html_encapsulation.py diff --git a/penzai/treescope/html_escaping.py b/penzai/treescope/_internal/html_escaping.py similarity index 100% rename from penzai/treescope/html_escaping.py rename to penzai/treescope/_internal/html_escaping.py diff --git a/penzai/treescope/js/arrayviz.js b/penzai/treescope/_internal/js/arrayviz.js similarity index 100% rename from penzai/treescope/js/arrayviz.js rename to penzai/treescope/_internal/js/arrayviz.js diff --git a/penzai/treescope/foldable_representation/layout_algorithms.py b/penzai/treescope/_internal/layout_algorithms.py similarity index 99% rename from penzai/treescope/foldable_representation/layout_algorithms.py rename to penzai/treescope/_internal/layout_algorithms.py index 2297813..7aab515 100644 --- a/penzai/treescope/foldable_representation/layout_algorithms.py +++ b/penzai/treescope/_internal/layout_algorithms.py @@ -19,7 +19,7 @@ import collections from typing import Any, Collection -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope._internal.parts import part_interface ExpandState = part_interface.ExpandState RenderableTreePart = part_interface.RenderableTreePart diff --git a/penzai/treescope/object_inspection.py b/penzai/treescope/_internal/object_inspection.py similarity index 63% rename from penzai/treescope/object_inspection.py rename to penzai/treescope/_internal/object_inspection.py index 4d70745..97aaa92 100644 --- a/penzai/treescope/object_inspection.py +++ b/penzai/treescope/_internal/object_inspection.py @@ -16,6 +16,7 @@ from __future__ import annotations +import abc import types from typing import Any, Callable @@ -47,3 +48,30 @@ def safely_get_real_method( return retrieved except Exception: # pylint: disable=broad-exception-caught return None + + +class HasReprHtml(abc.ABC): + """Abstract base class for rich-display objects in IPython.""" + + @abc.abstractmethod + def _repr_html_(self) -> str | tuple[str, Any]: + """Returns a rich HTML representation of an object.""" + ... + + @classmethod + def __subclasshook__(cls, subclass, /): + """Checks if a class is a subclass of HasReprHtml.""" + return hasattr(subclass, '_repr_html_') and callable(subclass._repr_html_) # pylint: disable=protected-access + + +def to_html(node: Any) -> str | None: + """Extracts a rich HTML representation of node using _repr_html_.""" + repr_html_method = safely_get_real_method(node, '_repr_html_') + if repr_html_method is None: + return None + html_for_node_and_maybe_metadata = repr_html_method() + if isinstance(html_for_node_and_maybe_metadata, tuple): + html_for_node, _ = html_for_node_and_maybe_metadata + else: + html_for_node = html_for_node_and_maybe_metadata + return html_for_node diff --git a/penzai/treescope/foldable_representation/__init__.py b/penzai/treescope/_internal/parts/__init__.py similarity index 96% rename from penzai/treescope/foldable_representation/__init__.py rename to penzai/treescope/_internal/parts/__init__.py index 65f6b8c..4f5b28c 100644 --- a/penzai/treescope/foldable_representation/__init__.py +++ b/penzai/treescope/_internal/parts/__init__.py @@ -19,5 +19,4 @@ from . import common_styles from . import embedded_iframe from . import foldable_impl -from . import layout_algorithms from . import part_interface diff --git a/penzai/treescope/foldable_representation/basic_parts.py b/penzai/treescope/_internal/parts/basic_parts.py similarity index 83% rename from penzai/treescope/foldable_representation/basic_parts.py rename to penzai/treescope/_internal/parts/basic_parts.py index ada6d88..968a583 100644 --- a/penzai/treescope/foldable_representation/basic_parts.py +++ b/penzai/treescope/_internal/parts/basic_parts.py @@ -26,8 +26,8 @@ import typing from typing import Any, Sequence -from penzai.treescope import html_escaping -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal.parts import part_interface CSSStyleRule = part_interface.CSSStyleRule @@ -91,6 +91,11 @@ class EmptyPart(BaseContentlessLeaf): """A definitely-empty part, which can be detected and special-cased.""" +def empty_part() -> EmptyPart: + """Returns an empty part.""" + return EmptyPart() + + @dataclasses.dataclass(frozen=True) class Text(RenderableTreePart): """A raw text literal.""" @@ -147,34 +152,17 @@ def render_to_text( stream.write(("\n" + " " * indent).join(self.text.split("\n"))) +def text(text_content: str) -> RenderableTreePart: + """Builds a one-line text part.""" + return Text(text_content) + + @dataclasses.dataclass(frozen=True) class Siblings(RenderableTreePart): """A sequence of children parts, rendered inline.""" children: Sequence[RenderableTreePart] - @classmethod - def build(cls, *args: RenderableTreePart | str) -> Siblings: - """Builds a Siblings part from inline arguments. - - Args: - *args: Sequence of renderables or strings (which will be wrapped in Text) - - Returns: - A new Siblings part containing these concatenated together. - """ - parts = [] - for arg in args: - if isinstance(arg, str): - parts.append(Text(arg)) - elif isinstance(arg, Siblings): - parts.extend(arg.children) - elif isinstance(arg, EmptyPart): - pass - else: - parts.append(arg) - return cls(tuple(parts)) - def _compute_collapsed_width(self) -> int: return sum(part.collapsed_width for part in self.children) @@ -239,7 +227,28 @@ def render_to_text( ) -siblings = Siblings.build +def siblings(*args: RenderableTreePart | str) -> RenderableTreePart: + """Builds a Siblings part from inline arguments. + + Args: + *args: Sequence of renderables or strings (which will be wrapped in Text). + + Returns: + A new Siblings part containing these concatenated together. + """ + parts = [] + for arg in args: + if isinstance(arg, str): + parts.append(Text(arg)) + elif isinstance(arg, Siblings): + parts.extend(arg.children) + elif isinstance(arg, EmptyPart): + pass + elif isinstance(arg, RenderableTreePart): + parts.append(arg) + else: + raise ValueError(f"Invalid argument type {type(arg)}") + return Siblings(tuple(parts)) class DeferringToChild(RenderableTreePart): @@ -427,6 +436,21 @@ def render_to_html( ) +def vertical_space(css_height: str) -> RenderableTreePart: + """Returns a vertical space with the given height in HTML mode. + + Args: + css_height: The height of the space, as a CSS length string. + + Returns: + A renderable part that renders as a vertical space in HTML mode, and does + not render in text mode. + """ + if not isinstance(css_height, str): + raise ValueError(f"css_height must be a string, got {css_height}") + return VerticalSpace(height=css_height) + + ################################################################################ # Conditional rendering ################################################################################ @@ -533,6 +557,37 @@ def render_to_text( ) +def fold_condition( + collapsed: RenderableTreePart | None = None, + expanded: RenderableTreePart | None = None, +) -> RenderableTreePart: + """Builds a part that renders differently when collapsed or expanded. + + Args: + collapsed: Contents to render when parent is collapsed. + expanded: Contents to render when parent is expanded. + + Returns: + A renderable part that renders as ``collapsed`` when the parent is collapsed + and as ``expanded`` when the parent is expanded. + """ + if collapsed is None: + collapsed = EmptyPart() + if expanded is None: + expanded = EmptyPart() + if not isinstance(collapsed, RenderableTreePart): + raise ValueError( + "`collapsed` must be a renderable part or None. Got" + f" {type(collapsed).__name__}" + ) + if not isinstance(expanded, RenderableTreePart): + raise ValueError( + "`expanded` must be a renderable part or None. Got" + f" {type(expanded).__name__}" + ) + return FoldCondition(collapsed=collapsed, expanded=expanded) + + @dataclasses.dataclass(frozen=True) class RoundtripCondition(RenderableTreePart): """Renders conditionally depending on whether it's in roundtrip mode. @@ -637,6 +692,37 @@ def render_to_text( ) +def roundtrip_condition( + roundtrip: RenderableTreePart | None = None, + not_roundtrip: RenderableTreePart | None = None, +) -> RenderableTreePart: + """Builds a part that renders differently in roundtrip mode. + + Args: + roundtrip: Contents to render when rendering in round trip mode. + not_roundtrip: Contents to render when renderingin ordinary mode. + + Returns: + A renderable part that renders as ``roundtrip`` in roundtrip mode + and as ``not_roundtrip`` in ordinary mode. + """ + if roundtrip is None: + roundtrip = EmptyPart() + if not_roundtrip is None: + not_roundtrip = EmptyPart() + if not isinstance(roundtrip, RenderableTreePart): + raise ValueError( + "`roundtrip` must be a renderable part or None. Got" + f" {type(roundtrip).__name__}" + ) + if not isinstance(not_roundtrip, RenderableTreePart): + raise ValueError( + "`not_roundtrip` must be a renderable part or None. Got" + f" {type(not_roundtrip).__name__}" + ) + return RoundtripCondition(roundtrip=roundtrip, not_roundtrip=not_roundtrip) + + @dataclasses.dataclass(frozen=True) class SummarizableCondition(RenderableTreePart): """Renders conditionally depending on combination of roundtrip/collapsed. @@ -745,6 +831,41 @@ def render_to_text( ) +def summarizable_condition( + summary: RenderableTreePart | None = None, + detail: RenderableTreePart | None = None, +) -> RenderableTreePart: + """Builds a part that renders depending on combination of roundtrip/collapsed. + + The idea is that, when collapsed and not in roundtrip mode, it's sometimes + convenient to summarize a compound node with a simpler non-roundtrippable + representation. + + Args: + summary: Contents to render when collapsed and not in roundtrip mode. + detail: Contents to render when either expanded or in roundtrip mode. + + Returns: + A renderable part that renders as ``summary`` when both collpased and not + in roundtrip mode, and as ``detail`` otherwise. + """ + if summary is None: + summary = EmptyPart() + if detail is None: + detail = EmptyPart() + if not isinstance(summary, RenderableTreePart): + raise ValueError( + "`summary` must be a renderable part or None. Got" + f" {type(summary).__name__}" + ) + if not isinstance(detail, RenderableTreePart): + raise ValueError( + "`detail` must be a renderable part or None. Got" + f" {type(detail).__name__}" + ) + return SummarizableCondition(summary=summary, detail=detail) + + ################################################################################ # Line comments ################################################################################ @@ -775,17 +896,20 @@ def siblings_with_annotations( parts.append(arg) elif isinstance(arg, str): parts.append(Text(arg)) - else: + elif isinstance(arg, RenderableAndLineAnnotations): parts.append(arg.renderable) if arg.annotations is not None: annotations.append(arg.annotations) + else: + raise ValueError( + "Expected a renderable tree part (possibly with line annotations) or" + f" a string, but got: {type(arg)}" + ) for annotation in extra_annotations: annotations.append(annotation) - return RenderableAndLineAnnotations( - Siblings.build(*parts), Siblings.build(*annotations) - ) + return RenderableAndLineAnnotations(siblings(*parts), siblings(*annotations)) def build_full_line_with_annotations( @@ -816,7 +940,7 @@ def build_full_line_with_annotations( ) ): return combined.renderable - return Siblings.build( + return siblings( combined.renderable, FoldCondition(expanded=combined.annotations), ) @@ -842,17 +966,7 @@ def build( cls, children: Sequence[RenderableAndLineAnnotations | RenderableTreePart], ) -> OnSeparateLines: - """Builds a OnSeparateLines instance, supporting annotations. - - This method stacks the children together, moving any comments to the end of - their lines. - - Args: - children: Children to render. - - Returns: - New OnSeparateLines instance. - """ + """Builds a OnSeparateLines instance, supporting annotations.""" return cls([build_full_line_with_annotations(line) for line in children]) def _compute_collapsed_width(self) -> int: @@ -949,6 +1063,24 @@ def render_to_text( stream.write("\n" + " " * indent) +def on_separate_lines( + children: Sequence[RenderableAndLineAnnotations | RenderableTreePart], +) -> RenderableTreePart: + """Builds a part that renders its children on separate lines. + + The resulting part stacks the children together, moving any comments to the + end of their lines. + + Args: + children: Children to render. + + Returns: + A renderable part that renders the children on separate lines when expanded. + When collapsed, it instead concatenates them. + """ + return OnSeparateLines.build(children) + + @dataclasses.dataclass(frozen=True) class IndentedChildren(RenderableTreePart): """A sequence of children, one per line, and indented. @@ -969,29 +1101,14 @@ def build( comma_separated: bool = False, force_trailing_comma: bool = False, ) -> IndentedChildren: - """Builds a IndentedChildren instance, supporting annotations and delimiters. - - This method stacks the children together, optionally inserting delimiters, - and moving any comments to the end of their lines. - - Args: - children: Children to render. - comma_separated: Whether to automatically insert commas between children. - If False, delimiters can be manually inserted into `children` first - instead. - force_trailing_comma: Whether to render a trailing comma in collapsed - mode. - - Returns: - New IndentedChildren instance. - """ + """Builds a IndentedChildren instance.""" lines = [] for i, child in enumerate(children): if comma_separated: if i < len(children) - 1: # Not the last child. Always show a comma, and add a space when # collapsed. - delimiter = Siblings.build(",", FoldCondition(collapsed=Text(" "))) + delimiter = siblings(",", FoldCondition(collapsed=Text(" "))) elif force_trailing_comma: # Last child, forced comma. delimiter = Text(",") @@ -1107,6 +1224,33 @@ def render_to_text( stream.write("\n" + " " * indent) +def indented_children( + children: Sequence[RenderableAndLineAnnotations | RenderableTreePart], + comma_separated: bool = False, + force_trailing_comma: bool = False, +) -> RenderableTreePart: + """Builds a IndentedChildren instance, supporting annotations and delimiters. + + This method stacks the children together, optionally inserting delimiters, + and moving any comments to the end of their lines. + + Args: + children: Children to render. + comma_separated: Whether to automatically insert commas between children. If + False, delimiters can be manually inserted into `children` first instead. + force_trailing_comma: Whether to render a trailing comma in collapsed mode. + + Returns: + A renderable part that renders the children on separate lines with an + indent. When collapsed, it instead concatenates them. + """ + return IndentedChildren.build( + children=children, + comma_separated=comma_separated, + force_trailing_comma=force_trailing_comma, + ) + + @dataclasses.dataclass(frozen=True) class BaseBoxWithOutline(RenderableTreePart, abc.ABC): """An outlined box, which displays in "block" mode when rendered to HTML. @@ -1403,3 +1547,36 @@ def render_to_text( roundtrip_mode=roundtrip_mode, render_context=render_context, ) + + +def floating_annotation_with_separate_focus( + child: part_interface.RenderableTreePart, +) -> part_interface.RenderableTreePart: + """Wraps a child so that selections outside it don't include it. + + It is sometimes useful to add additional annotations to an object that aren't + pretty-printed parts of that object. This causes some problems for ordinary + roundtrip mode, since we want it to be possible to exactly round-trip an + object based on its printed representation. + + This object marks its child so that it doesn't get selected by the mouse + when the selection starts outside the node. This means that if you copy the + object normally, you don't copy the annotation, so that what you copied + stays roundtrippable. + + In text mode, selections can't be manipulated. We fake the same thing by + rendering it as a "comment" (by adding comment markers before every line) + and hiding it if collapsed. + + Args: + child: Child to render. + + Returns: + A wrapped version of ``child`` that does not participate in text selection + in HTML mode, and adds comments in text mode. + """ + if not isinstance(child, RenderableTreePart): + raise ValueError( + f"`child` must be a renderable part, but got {type(child).__name__}" + ) + return ScopedSelectableAnnotation(child) diff --git a/penzai/treescope/foldable_representation/common_structures.py b/penzai/treescope/_internal/parts/common_structures.py similarity index 95% rename from penzai/treescope/foldable_representation/common_structures.py rename to penzai/treescope/_internal/parts/common_structures.py index 3b803df..3102259 100644 --- a/penzai/treescope/foldable_representation/common_structures.py +++ b/penzai/treescope/_internal/parts/common_structures.py @@ -20,10 +20,10 @@ from typing import Any, Sequence from penzai.treescope import canonical_aliases -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import common_styles +from penzai.treescope._internal.parts import foldable_impl +from penzai.treescope._internal.parts import part_interface CSSStyleRule = part_interface.CSSStyleRule @@ -106,9 +106,7 @@ def build_one_line_tree_node( if isinstance(line, RenderableAndLineAnnotations): line_primary = line.renderable - annotations = basic_parts.Siblings.build( - maybe_copy_button, line.annotations - ) + annotations = basic_parts.siblings(maybe_copy_button, line.annotations) elif isinstance(line, str): line_primary = basic_parts.Text(line) annotations = maybe_copy_button @@ -182,7 +180,7 @@ def build_foldable_tree_node_from_children( """ if not children: return build_one_line_tree_node( - line=basic_parts.Siblings.build(prefix, suffix), + line=basic_parts.siblings(prefix, suffix), path=path, background_color=background_color, ) @@ -239,7 +237,7 @@ def wrap_block(block): wrap_topline(prefix), keypath=path, ), - contents=basic_parts.Siblings.build( + contents=basic_parts.siblings( maybe_copy_button, maybe_first_line_annotation, indented_child_class.build( @@ -278,7 +276,7 @@ def maybe_qualified_type_name(ty: type[Any]) -> RenderableTreePart: access_path = f".{class_name}" if access_path.endswith(class_name): - return basic_parts.Siblings.build( + return basic_parts.siblings( basic_parts.RoundtripCondition( roundtrip=common_styles.QualifiedTypeNameSpanGroup( basic_parts.Text(access_path.removesuffix(class_name)) diff --git a/penzai/treescope/foldable_representation/common_styles.py b/penzai/treescope/_internal/parts/common_styles.py similarity index 81% rename from penzai/treescope/foldable_representation/common_styles.py rename to penzai/treescope/_internal/parts/common_styles.py index 95a7ee9..6c72e57 100644 --- a/penzai/treescope/foldable_representation/common_styles.py +++ b/penzai/treescope/_internal/parts/common_styles.py @@ -20,9 +20,9 @@ import io from typing import Any -from penzai.treescope import html_escaping -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import part_interface CSSStyleRule = part_interface.CSSStyleRule @@ -46,6 +46,17 @@ def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: """)) +def abbreviation_color( + child: part_interface.RenderableTreePart, +) -> part_interface.RenderableTreePart: + """Returns a wrapped child in a color for non-roundtrippable abbreviations.""" + if not isinstance(child, RenderableTreePart): + raise ValueError( + f"`child` must be a renderable part, but got {type(child).__name__}" + ) + return AbbreviationColor(child) + + class CommentColor(basic_parts.BaseSpanGroup): """Renders its child in a color for comments.""" @@ -63,6 +74,17 @@ def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: """)) +def comment_color( + child: part_interface.RenderableTreePart, +) -> part_interface.RenderableTreePart: + """Returns a wrapped child in a color for comments.""" + if not isinstance(child, RenderableTreePart): + raise ValueError( + f"`child` must be a renderable part, but got {type(child).__name__}" + ) + return CommentColor(child) + + class ErrorColor(basic_parts.BaseSpanGroup): """Renders its child in red to indicate errors / problems during rendering.""" @@ -77,6 +99,17 @@ def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: """)) +def error_color( + child: part_interface.RenderableTreePart, +) -> part_interface.RenderableTreePart: + """Returns a wrapped child in red to indicate errors / problems during rendering.""" + if not isinstance(child, RenderableTreePart): + raise ValueError( + f"`child` must be a renderable part, but got {type(child).__name__}" + ) + return ErrorColor(child) + + class DeferredPlaceholderStyle(basic_parts.BaseSpanGroup): """Renders its child in italics to indicate a deferred placeholder.""" @@ -92,6 +125,17 @@ def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: """)) +def deferred_placeholder_style( + child: part_interface.RenderableTreePart, +) -> part_interface.RenderableTreePart: + """Returns a wrapped child in italics to indicate a deferred placeholder.""" + if not isinstance(child, RenderableTreePart): + raise ValueError( + f"`child` must be a renderable part, but got {type(child).__name__}" + ) + return DeferredPlaceholderStyle(child) + + class CommentColorWhenExpanded(basic_parts.BaseSpanGroup): """Renders its child in a color for comments, but only when expanded. @@ -112,6 +156,17 @@ def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: """)) +def comment_color_when_expanded( + child: part_interface.RenderableTreePart, +) -> part_interface.RenderableTreePart: + """Returns a wrapped child in a color for comments, but only when expanded.""" + if not isinstance(child, RenderableTreePart): + raise ValueError( + f"`child` must be a renderable part, but got {type(child).__name__}" + ) + return CommentColorWhenExpanded(child) + + @dataclasses.dataclass(frozen=True) class CustomTextColor(basic_parts.DeferringToChild): """A group that wraps its child in a span with a custom text color. @@ -141,8 +196,23 @@ def render_to_html( stream.write("") +def custom_text_color( + child: part_interface.RenderableTreePart, css_color: str +) -> part_interface.RenderableTreePart: + """Returns a wrapped child that renders in a particular CSS color.""" + if not isinstance(child, RenderableTreePart): + raise ValueError( + f"`child` must be a renderable part, but got {type(child).__name__}" + ) + if not isinstance(css_color, str): + raise ValueError( + f"`css_color` must be a string, but got {type(css_color).__name__}" + ) + return CustomTextColor(child, css_color) + + class DashedGrayOutlineBox(basic_parts.BaseBoxWithOutline): - """A highlighted box that identifies a part as being selected.""" + """A dashed gray box.""" def _box_css_class(self) -> str: return "dashed_gray_outline" @@ -160,6 +230,17 @@ def _box_css_rule( ) +def dashed_gray_outline_box( + child: part_interface.RenderableTreePart, +) -> part_interface.RenderableTreePart: + """Returns a wrapped child that displays in a dashed gray box.""" + if not isinstance(child, RenderableTreePart): + raise ValueError( + f"`child` must be a renderable part, but got {type(child).__name__}" + ) + return DashedGrayOutlineBox(child) + + @dataclasses.dataclass(frozen=True) class ColoredBorderIndentedChildren(basic_parts.IndentedChildren): """A sequence of children that also draws a colored line on the left. @@ -488,3 +569,14 @@ def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: font-size: 0.8em; } """)) + + +def qualified_type_name_style( + child: part_interface.RenderableTreePart, +) -> part_interface.RenderableTreePart: + """Returns a wrapped child in a small font to indicate a qualified name.""" + if not isinstance(child, RenderableTreePart): + raise ValueError( + f"`child` must be a renderable part, but got {type(child).__name__}" + ) + return QualifiedTypeNameSpanGroup(child) diff --git a/penzai/treescope/_internal/parts/custom_dataclass_util.py b/penzai/treescope/_internal/parts/custom_dataclass_util.py new file mode 100644 index 0000000..be3bca9 --- /dev/null +++ b/penzai/treescope/_internal/parts/custom_dataclass_util.py @@ -0,0 +1,161 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Handler for built-in collection types (lists, sets, dicts, etc).""" + +from __future__ import annotations + +import dataclasses +from typing import Any, Callable, Optional, Sequence + +from penzai.treescope import dataclass_util +from penzai.treescope import renderer +from penzai.treescope._internal import layout_algorithms +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import common_structures +from penzai.treescope._internal.parts import common_styles +from penzai.treescope._internal.parts import part_interface + + +CSSStyleRule = part_interface.CSSStyleRule +HtmlContextForSetup = part_interface.HtmlContextForSetup + + +def build_field_children( + node: Any, + path: str | None, + subtree_renderer: renderer.TreescopeSubtreeRenderer, + fields_or_attribute_names: Sequence[dataclasses.Field[Any] | str], + attr_style_fn: ( + Callable[[str], part_interface.RenderableTreePart] | None + ) = None, +) -> list[part_interface.RenderableTreePart]: + """Renders a set of fields/attributes into a list of comma-separated children. + + This is a helper function used for rendering dataclasses, namedtuples, and + similar objects, of the form :: + + ClassName( + field_name_one=value1, + field_name_two=value2, + ) + + If `fields_or_attribute_names` includes dataclass fields: + + * Metadata for the fields will be visible on hover, + + * Fields with ``repr=False`` will be hidden unless roundtrip mode is enabled. + + Args: + node: Node to render. + path: Path to this node. + subtree_renderer: How to render subtrees (see `TreescopeSubtreeRenderer`) + fields_or_attribute_names: Sequence of fields or attribute names to render. + Any field with the metadata key "treescope_always_collapse" set to True + will always render collapsed. + attr_style_fn: Optional function which makes attributes to a part that + should render them. If not provided, all parts are rendered as plain text. + + Returns: + A list of child objects. This can be passed to + `common_structures.build_foldable_tree_node_from_children` (with + ``comma_separated=False``) + """ + if attr_style_fn is None: + attr_style_fn = basic_parts.Text + + field_names = [] + fields: list[Optional[dataclasses.Field[Any]]] = [] + for field_or_name in fields_or_attribute_names: + if isinstance(field_or_name, str): + field_names.append(field_or_name) + fields.append(None) + else: + field_names.append(field_or_name.name) + fields.append(field_or_name) + + children = [] + for i, (field_name, maybe_field) in enumerate(zip(field_names, fields)): + child_path = None if path is None else f"{path}.{field_name}" + + if i < len(fields) - 1: + # Not the last child. Always show a comma, and add a space when + # collapsed. + comma_after = basic_parts.siblings( + ",", basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")) + ) + else: + # Last child: only show the comma when the node is expanded. + comma_after = basic_parts.FoldCondition(expanded=basic_parts.Text(",")) + + if maybe_field is not None: + hide_except_in_roundtrip = not maybe_field.repr + force_collapsed = maybe_field.metadata.get( + "treescope_always_collapse", False + ) + else: + hide_except_in_roundtrip = False + force_collapsed = False + + field_name_rendering = attr_style_fn(field_name) + + try: + field_value = getattr(node, field_name) + except AttributeError: + child = basic_parts.FoldCondition( + expanded=common_styles.CommentColor( + basic_parts.siblings("# ", field_name_rendering, " is missing") + ) + ) + else: + child = basic_parts.siblings_with_annotations( + field_name_rendering, + "=", + subtree_renderer(field_value, path=child_path), + ) + + child_line = basic_parts.build_full_line_with_annotations( + child, comma_after + ) + if force_collapsed: + layout_algorithms.expand_to_depth(child_line, 0) + if hide_except_in_roundtrip: + child_line = basic_parts.RoundtripCondition(roundtrip=child_line) + + children.append(child_line) + + return children + + +def render_dataclass_constructor( + node: Any, +) -> part_interface.RenderableTreePart: + """Renders the constructor for a dataclass, including the open parenthesis.""" + assert dataclasses.is_dataclass(node) and not isinstance(node, type) + if not dataclass_util.init_takes_fields(type(node)): + constructor_open = basic_parts.siblings( + basic_parts.RoundtripCondition( + roundtrip=basic_parts.Text("pz.dataclass_from_attributes(") + ), + common_structures.maybe_qualified_type_name(type(node)), + basic_parts.RoundtripCondition( + roundtrip=basic_parts.Text(", "), + not_roundtrip=basic_parts.Text("("), + ), + ) + else: + constructor_open = basic_parts.siblings( + common_structures.maybe_qualified_type_name(type(node)), "(" + ) + return constructor_open diff --git a/penzai/treescope/foldable_representation/embedded_iframe.py b/penzai/treescope/_internal/parts/embedded_iframe.py similarity index 78% rename from penzai/treescope/foldable_representation/embedded_iframe.py rename to penzai/treescope/_internal/parts/embedded_iframe.py index 429f583..86ee0fb 100644 --- a/penzai/treescope/foldable_representation/embedded_iframe.py +++ b/penzai/treescope/_internal/parts/embedded_iframe.py @@ -15,14 +15,12 @@ """Embedding of external HTML content into treescope's IR.""" from __future__ import annotations -import abc import dataclasses import io from typing import Any, Sequence -from penzai.treescope import html_escaping -from penzai.treescope import object_inspection -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal.parts import part_interface CSSStyleRule = part_interface.CSSStyleRule JavaScriptDefn = part_interface.JavaScriptDefn @@ -32,35 +30,6 @@ FoldableTreeNode = part_interface.FoldableTreeNode -class HasReprHtml(abc.ABC): - """Abstract base class for rich-display objects in IPython.""" - - @abc.abstractmethod - def _repr_html_(self) -> str | tuple[str, Any]: - """Returns a rich HTML representation of an object.""" - ... - - @classmethod - def __subclasshook__(cls, subclass, /): - """Checks if a class is a subclass of HasReprHtml.""" - return hasattr(subclass, '_repr_html_') and callable(subclass._repr_html_) # pylint: disable=protected-access - - -def to_html(node: Any) -> str | None: - """Extracts a rich HTML representation of node using _repr_html_.""" - repr_html_method = object_inspection.safely_get_real_method( - node, '_repr_html_' - ) - if repr_html_method is None: - return None - html_for_node_and_maybe_metadata = repr_html_method() - if isinstance(html_for_node_and_maybe_metadata, tuple): - html_for_node, _ = html_for_node_and_maybe_metadata - else: - html_for_node = html_for_node_and_maybe_metadata - return html_for_node - - @dataclasses.dataclass(frozen=True) class EmbeddedIFrame(RenderableTreePart): """Builds an HTML iframe containing scoped HTML for a rich display object. @@ -168,3 +137,38 @@ def render_to_html( ' onload="this.getRootNode().host.defns.resize_iframe_by_content(this)">' '' ) + + +def embedded_iframe( + embedded_html: str, + fallback_in_text_mode: RenderableTreePart, + virtual_width: int = 80, + virtual_height: int = 2, +) -> part_interface.RenderableTreePart: + """Returns a wrapped child in a color for non-roundtrippable abbreviations.""" + if not isinstance(embedded_html, str): + raise ValueError( + '`embedded_html` must be a string, but got' + f' {type(embedded_html).__name__}' + ) + if not isinstance(fallback_in_text_mode, RenderableTreePart): + raise ValueError( + '`fallback_in_text_mode` must be a renderable part, but got' + f' {type(fallback_in_text_mode).__name__}' + ) + if not isinstance(virtual_width, int): + raise ValueError( + '`virtual_width` must be an integer, but got' + f' {type(virtual_width).__name__}' + ) + if not isinstance(virtual_height, int): + raise ValueError( + '`virtual_height` must be an integer, but got' + f' {type(virtual_height).__name__}' + ) + return EmbeddedIFrame( + embedded_html=embedded_html, + fallback_in_text_mode=fallback_in_text_mode, + virtual_width=virtual_width, + virtual_height=virtual_height, + ) diff --git a/penzai/treescope/foldable_representation/foldable_impl.py b/penzai/treescope/_internal/parts/foldable_impl.py similarity index 60% rename from penzai/treescope/foldable_representation/foldable_impl.py rename to penzai/treescope/_internal/parts/foldable_impl.py index 0e24e65..e3e2838 100644 --- a/penzai/treescope/foldable_representation/foldable_impl.py +++ b/penzai/treescope/_internal/parts/foldable_impl.py @@ -20,19 +20,13 @@ from __future__ import annotations -import contextlib import dataclasses import io -import json -from typing import Any, Callable, Iterator, Sequence -import uuid +from typing import Any, Callable, Sequence -from penzai.treescope import context -from penzai.treescope import html_encapsulation -from penzai.treescope import html_escaping -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import part_interface CSSStyleRule = part_interface.CSSStyleRule @@ -89,7 +83,7 @@ def set_expand_state(self, expand_state: ExpandState): def as_expanded_part(self) -> RenderableTreePart: """Returns the contents of this foldable when expanded.""" - return basic_parts.Siblings.build(self.label, self.contents) + return basic_parts.siblings(self.label, self.contents) def html_setup_parts( self, setup_context: HtmlContextForSetup @@ -617,385 +611,3 @@ class DeferredWithThunk: placeholder: DeferredPlaceholder thunk: Callable[[RenderableTreePart | None], RenderableTreePart] - - -_deferrables: context.ContextualValue[list[DeferredWithThunk] | None] = ( - context.ContextualValue( - module=__name__, qualname="_deferrables", initial_value=None - ) -) -"""An optional list of accumulated deferrables, for use by this module.""" - - -def maybe_defer_rendering( - main_thunk: Callable[[RenderableTreePart | None], RenderableTreePart], - placeholder_thunk: Callable[[], RenderableTreePart], -) -> RenderableTreePart: - """Possibly defers rendering of a part in interactive contexts. - - This function can be used by advanced handlers and autovisualizers to delay - the rendering of "expensive" leaves such as `jax.Array` until after the tree - structure is drawn. If run in a non-interactive context, this just calls the - main thunk. If run in an interactive context, it instead calls the placeholder - thunk, and enqueues the placeholder thunk to be called later. - - Rendering can be performed in a deferred context by running the handlers under - the `collecting_deferred_renderings` context manager, and then rendered to - a sequence of streaming HTML updates using the `display_streaming_as_root` - function. - - Note that handlers who call this are responsible for ensuring that the - logic in `main_thunk` is safe to run at a later point in time. In particular, - any rendering context managers may have been exited by the time this main - thunk is called. As a best practice, handlers should control all of the logic - in `main_thunk` and shouldn't recursively call the subtree renderer inside it; - subtrees should be rendered before calling `maybe_defer_rendering`. - - Args: - main_thunk: A callable producing the main part to render. If not deferred, - will be called with None. If deferred, will be called with the placeholder - part, which can be inspected to e.g. infer folding state. - placeholder_thunk: A callable producing a placeholder object, which will be - rendered if we are deferring rendering. - - Returns: - Either the rendered main part or a wrapped placeholder that will later be - replaced with the main part. - """ - deferral_list = _deferrables.get() - if deferral_list is None: - return main_thunk(None) - else: - placeholder = DeferredPlaceholder( - child=placeholder_thunk(), - replacement_id="deferred_" + uuid.uuid4().hex, - ) - deferral_list.append(DeferredWithThunk(placeholder, main_thunk)) - return placeholder - - -@contextlib.contextmanager -def collecting_deferred_renderings() -> Iterator[list[DeferredWithThunk]]: - # pylint: disable=g-doc-return-or-yield - """Context manager that defers and collects `maybe_defer_rendering` calls. - - This context manager can be used by renderers that wish to render deferred - objects in a streaming fashion. When used in a - `with collecting_deferred_renderings() as deferreds:` - expression, `deferreds` will be a list that is populated by calls to - `maybe_defer_rendering`. This can later be passed to - `display_streaming_as_root` to render the deferred object in a streaming - fashion. - - Returns: - A context manager in which `maybe_defer_rendering` calls will be deferred - and collected into the result list. - """ - # pylint: enable=g-doc-return-or-yield - try: - target = [] - with _deferrables.set_scoped(target): - yield target - finally: - pass - - -################################################################################ -# Top-level rendering and roundtrip mode implementation -################################################################################ - - -def render_to_text_as_root( - root_node: RenderableTreePart, - roundtrip: bool = False, - strip_trailing_whitespace: bool = True, - strip_whitespace_lines: bool = True, -) -> str: - """Renders a root node to text. - - Args: - root_node: The root node to render. - roundtrip: Whether to render in roundtrip mode. - strip_trailing_whitespace: Whether to remove trailing whitespace from lines. - strip_whitespace_lines: Whether to remove lines that are entirely - whitespace. These lines can sometimes be generated by layout code being - conservative about line breaks. Should only be True if - `strip_trailing_whitespace` is True. - - Returns: - Text for the rendered node. - """ - if strip_whitespace_lines and not strip_trailing_whitespace: - raise ValueError("strip_whitespace_lines must be False if " - "strip_trailing_whitespace is False.") - - stream = io.StringIO() - root_node.render_to_text( - stream, - expanded_parent=True, - indent=0, - roundtrip_mode=roundtrip, - render_context={}, - ) - result = stream.getvalue() - - if strip_trailing_whitespace: - trimmed_lines = [] - for line in result.split("\n"): - line = line.rstrip() - if line or not strip_whitespace_lines: - trimmed_lines.append(line) - result = "\n".join(trimmed_lines) - return result - - -TREESCOPE_PREAMBLE_SCRIPT = """(()=> { - const defns = this.getRootNode().host.defns; - let _pendingActions = []; - let _pendingActionHandle = null; - defns.runSoon = (work) => { - const doWork = () => { - const tick = performance.now(); - while (performance.now() - tick < 32) { - if (_pendingActions.length == 0) { - _pendingActionHandle = null; - return; - } else { - const thunk = _pendingActions.shift(); - thunk(); - } - } - _pendingActionHandle = ( - window.requestAnimationFrame(doWork)); - }; - _pendingActions.push(work); - if (_pendingActionHandle === null) { - _pendingActionHandle = ( - window.requestAnimationFrame(doWork)); - } - }; - defns.toggle_root_roundtrip = (rootelt, event) => { - if (event.key == "r") { - rootelt.classList.toggle("roundtrip_mode"); - } - }; -})(); -""" - - -def _render_to_html_as_root_streaming( - root_node: RenderableTreePart, - roundtrip: bool, - deferreds: Sequence[DeferredWithThunk], -) -> Iterator[str]: - """Helper function: renders a root node to HTML one step at a time. - - Args: - root_node: The root node to render. - roundtrip: Whether to render in roundtrip mode. - deferreds: Sequence of deferred objects to render and splice in. - - Yields: - HTML source for the rendered node, followed by logic to substitute each - deferred object. - """ - all_css_styles = set() - all_js_defns = set() - - def _render_one( - node, - at_beginning_of_line: bool, - render_context: dict[Any, Any], - stream: io.StringIO, - ): - # Extract setup rules. - setup_parts = node.html_setup_parts(SETUP_CONTEXT) - current_styles = [] - current_js_defns = [] - for part in setup_parts: - if isinstance(part, CSSStyleRule): - if part not in all_css_styles: - current_styles.append(part) - all_css_styles.add(part) - elif isinstance(part, JavaScriptDefn): - if part not in all_js_defns: - current_js_defns.append(part) - all_js_defns.add(part) - else: - raise ValueError(f"Invalid setup object: {part}") - - if current_styles: - stream.write("") - - if current_js_defns: - stream.write( - "") - - # Render the node itself. - node.render_to_html( - stream, - at_beginning_of_line=at_beginning_of_line, - render_context=render_context, - ) - - # Set up the styles and scripts for the root object. - stream = io.StringIO() - stream.write("") - # These scripts allow us to defer execution of javascript blocks until after - # the content is loaded, avoiding locking up the browser rendering process. - stream.write("") - - # Render the root node. - classnames = "treescope_root" - if roundtrip: - classnames += " roundtrip_mode" - stream.write( - f'
' - ) - _render_one(root_node, True, {}, stream) - stream.write("
") - - yield stream.getvalue() - - # Render any deferred parts. We insert each part into a hidden element, then - # move them all out to their appropriate positions. - if deferreds: - stream = io.StringIO() - for deferred in deferreds: - stream.write( - '") - - all_ids = [deferred.placeholder.replacement_id for deferred in deferreds] - inner_script = ( - f"const targetIds = {json.dumps(all_ids)};" - + html_escaping.without_repeated_whitespace(""" - const docroot = this.getRootNode(); - const treeroot = docroot.querySelector(".treescope_root"); - const fragment = document.createDocumentFragment(); - const treerootClone = fragment.appendChild(treeroot.cloneNode(true)); - for (let i = 0; i < targetIds.length; i++) { - let target = fragment.getElementById(targetIds[i]); - let sourceDiv = docroot.querySelector("#for_" + targetIds[i]); - target.replaceWith(sourceDiv.firstElementChild); - sourceDiv.remove(); - } - treeroot.replaceWith(treerootClone); - """) - ) - stream.write( - '" - ) - yield stream.getvalue() - - -def render_to_html_as_root( - root_node: RenderableTreePart, - roundtrip: bool = False, - compressed: bool = False, -) -> str: - """Renders a root node to HTML. - - This handles collecting styles and JS definitions and inserting the root - HTML element. - - Args: - root_node: The root node to render. - roundtrip: Whether to render in roundtrip mode. - compressed: Whether to compress the HTML for display. - - Returns: - HTML source for the rendered node. - """ - render_iterator = _render_to_html_as_root_streaming(root_node, roundtrip, []) - html_src = "".join(render_iterator) - return html_encapsulation.encapsulate_html(html_src, compress=compressed) - - -def display_streaming_as_root( - root_node: RenderableTreePart, - deferreds: Sequence[DeferredWithThunk], - roundtrip: bool = False, - compressed: bool = True, - stealable: bool = False, -) -> str | None: - """Displays a root node in an IPython notebook in a streaming fashion. - - Args: - root_node: The root node to render. - deferreds: Deferred objects to render and splice in. - roundtrip: Whether to render in roundtrip mode. - compressed: Whether to compress the HTML for display. - stealable: Whether to return an extra HTML snippet that allows the streaming - rendering to be relocated after it is shown. - - Returns: - If ``stealable`` is True, a final HTML snippet which, if inserted into a - document, will "steal" the root node rendering, moving the DOM nodes for it - into itself. In particular, using this as the HTML rendering of the root - node during pretty printing will correctly associate the rendering with the - IPython "cell output", which is visible in some IPython backends (e.g. - JupyterLab). If ``stealable`` is False, returns None. - """ - import IPython.display # pylint: disable=g-import-not-at-top - - render_iterator = _render_to_html_as_root_streaming( - root_node, roundtrip, deferreds - ) - encapsulated_iterator = html_encapsulation.encapsulate_streaming_html( - render_iterator, compress=compressed, stealable=stealable - ) - - for step in encapsulated_iterator: - if step.segment_type == html_encapsulation.SegmentType.FINAL_OUTPUT_STEALER: - return step.html_src - else: - IPython.display.display(IPython.display.HTML(step.html_src)) diff --git a/penzai/treescope/foldable_representation/part_interface.py b/penzai/treescope/_internal/parts/part_interface.py similarity index 99% rename from penzai/treescope/foldable_representation/part_interface.py rename to penzai/treescope/_internal/parts/part_interface.py index aa39a13..e19d939 100644 --- a/penzai/treescope/foldable_representation/part_interface.py +++ b/penzai/treescope/_internal/parts/part_interface.py @@ -96,6 +96,9 @@ class HtmlContextForSetup: class RenderableTreePart(abc.ABC): """Abstract base class for a formatted part of a foldable tree. + WARNING: The details of this interface are an implementation detail, and are + subject to change. + Formatted objects are produced by treescope handlers from the original Python objects, and know how to convert themselves to a concrete renderable representation. diff --git a/penzai/treescope/array_autovisualizer.py b/penzai/treescope/array_autovisualizer.py index 9bc094f..4b8868d 100644 --- a/penzai/treescope/array_autovisualizer.py +++ b/penzai/treescope/array_autovisualizer.py @@ -23,13 +23,12 @@ from penzai.treescope import arrayviz from penzai.treescope import autovisualize from penzai.treescope import dtype_util +from penzai.treescope import lowering from penzai.treescope import ndarray_adapters +from penzai.treescope import rendering_parts from penzai.treescope import type_registries -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope._internal import arrayviz_impl +from penzai.treescope._internal.parts import part_interface PositionalAxisInfo = ndarray_adapters.PositionalAxisInfo @@ -90,8 +89,8 @@ def _autovisualize_array( adapter: ndarray_adapters.NDArrayAdapter, path: str | None, label: str, - expand_state: part_interface.ExpandState, - ) -> part_interface.RenderableTreePart: + expand_state: rendering_parts.ExpandState, + ) -> rendering_parts.RenderableTreePart: """Helper to visualize an array.""" # Extract information about axis names, indices, and sizes. array_axis_info = adapter.get_axis_info_for_array_data(array) @@ -107,14 +106,14 @@ def _autovisualize_array( row_axes.append(info.axis_name) # Infer a good truncated shape for this array. - edge_items_per_axis = arrayviz.infer_balanced_truncation( + edge_items_per_axis = arrayviz_impl.infer_balanced_truncation( tuple(info.size for info in array_axis_info), maximum_size=self.maximum_size, cutoff_size_per_axis=self.cutoff_size_per_axis, minimum_edge_items=self.edge_items, ) - row_axes, column_axes = arrayviz.infer_rows_and_columns( + row_axes, column_axes = arrayviz_impl.infer_rows_and_columns( all_axes=array_axis_info, known_rows=row_axes, known_columns=column_axes, @@ -162,7 +161,7 @@ def _autovisualize_array( value_item_labels=value_item_labels, axis_labels=None, ) - rendering_parts = [array_rendering] + outputs = [array_rendering] last_line_parts = [] # Render the sharding as well. @@ -189,20 +188,24 @@ def _autovisualize_array( array, columns=[c.logical_key() for c in column_axes], rows=[r.logical_key() for r in row_axes], - ) - rendering_parts.append( - common_structures.build_custom_foldable_tree_node( - label=common_styles.AbbreviationColor( - basic_parts.siblings( - basic_parts.Text(sharding_summary_str), - basic_parts.FoldCondition( - expanded=basic_parts.Text(":"), - collapsed=basic_parts.Text(" (click to expand)"), + ).treescope_part + outputs.append( + rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.abbreviation_color( + rendering_parts.siblings( + rendering_parts.text(sharding_summary_str), + rendering_parts.fold_condition( + expanded=rendering_parts.text(":"), + collapsed=rendering_parts.text( + " (click to expand)" + ), ), ) ), - contents=basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren([sharding_rendering]), + contents=rendering_parts.fold_condition( + expanded=rendering_parts.indented_children( + [sharding_rendering] + ), ), ) ) @@ -210,21 +213,21 @@ def _autovisualize_array( # We render it with a path, but remove the copy path button. This will be # added back by the caller. if last_line_parts: - last_line = basic_parts.siblings( - basic_parts.FoldCondition( - expanded=basic_parts.Text("".join(last_line_parts)), + last_line = rendering_parts.siblings( + rendering_parts.fold_condition( + expanded=rendering_parts.text("".join(last_line_parts)), ), - basic_parts.Text(">"), + rendering_parts.text(">"), ) else: - last_line = basic_parts.Text(">") - custom_rendering = common_structures.build_custom_foldable_tree_node( - label=common_styles.AbbreviationColor(label), - contents=basic_parts.siblings( - basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren.build(rendering_parts) + last_line = rendering_parts.text(">") + custom_rendering = rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.abbreviation_color(label), + contents=rendering_parts.siblings( + rendering_parts.fold_condition( + expanded=rendering_parts.indented_children(outputs) ), - common_styles.AbbreviationColor(last_line), + rendering_parts.abbreviation_color(last_line), ), path=path, expand_state=expand_state, @@ -256,34 +259,34 @@ def __call__( if not _supported_dtype(np_dtype): return None - def _placeholder() -> part_interface.RenderableTreePart: + def _placeholder() -> rendering_parts.RenderableTreePart: summary = adapter.get_array_summary(value, fast=True) - return common_structures.fake_placeholder_foldable( - common_styles.DeferredPlaceholderStyle( - basic_parts.Text(f"<{summary}>") + return rendering_parts.fake_placeholder_foldable( + rendering_parts.deferred_placeholder_style( + rendering_parts.text(f"<{summary}>") ), extra_newlines_guess=8, ) - def _thunk(placeholder) -> part_interface.RenderableTreePart: + def _thunk(placeholder) -> rendering_parts.RenderableTreePart: # Full rendering of the array. if isinstance(placeholder, part_interface.FoldableTreeNode): expand_state = placeholder.get_expand_state() else: assert placeholder is None - expand_state = part_interface.ExpandState.WEAKLY_EXPANDED + expand_state = rendering_parts.ExpandState.WEAKLY_EXPANDED summary = adapter.get_array_summary(value, fast=False) - label = common_styles.AbbreviationColor(basic_parts.Text(f"<{summary}")) + label = rendering_parts.abbreviation_color( + rendering_parts.text(f"<{summary}") + ) return self._autovisualize_array( value, adapter, path, label, expand_state ) return autovisualize.CustomTreescopeVisualization( - basic_parts.RenderableAndLineAnnotations( - renderable=foldable_impl.maybe_defer_rendering( - _thunk, _placeholder - ), - annotations=common_structures.build_copy_button(path), + rendering_parts.RenderableAndLineAnnotations( + renderable=lowering.maybe_defer_rendering(_thunk, _placeholder), + annotations=rendering_parts.build_copy_button(path), ) ) @@ -350,19 +353,19 @@ def _thunk(placeholder) -> part_interface.RenderableTreePart: shardvis = arrayviz.render_sharding_info( array_axis_info=fake_axis_info, sharding_info=sharding_info, - ) - custom_rendering = common_structures.build_custom_foldable_tree_node( - label=common_styles.AbbreviationColor( - basic_parts.Text("<" + repr_oneline) + ).treescope_part + custom_rendering = rendering_parts.build_custom_foldable_tree_node( + label=rendering_parts.abbreviation_color( + rendering_parts.text("<" + repr_oneline) ), - contents=basic_parts.siblings( - basic_parts.FoldCondition( - expanded=basic_parts.IndentedChildren.build([shardvis]) + contents=rendering_parts.siblings( + rendering_parts.fold_condition( + expanded=rendering_parts.indented_children([shardvis]) ), - common_styles.AbbreviationColor(basic_parts.Text(">")), + rendering_parts.abbreviation_color(rendering_parts.text(">")), ), path=path, - expand_state=part_interface.ExpandState.EXPANDED, + expand_state=rendering_parts.ExpandState.EXPANDED, ) return autovisualize.CustomTreescopeVisualization(custom_rendering) else: diff --git a/penzai/treescope/arrayviz.py b/penzai/treescope/arrayviz.py index a8e3709..30692f8 100644 --- a/penzai/treescope/arrayviz.py +++ b/penzai/treescope/arrayviz.py @@ -14,32 +14,27 @@ """Single-purpose ndarray visualizer for Python in vanilla Javascript. -Designed to visualize the contents of arbitrarily-high-dimensional NDArrays -quickly and without any dependencies, to allow them to be visualized by default -instead of requiring lots of manual effort. +Designed to quickly visualize the contents of arbitrarily-high-dimensional +NDArrays, to allow them to be visualized by default instead of requiring lots +of manual effort. """ from __future__ import annotations -import base64 import collections -import dataclasses -import io import itertools import json -import os from typing import Any, Literal, Sequence import numpy as np from penzai.treescope import context from penzai.treescope import dtype_util -from penzai.treescope import figures -from penzai.treescope import html_escaping from penzai.treescope import ndarray_adapters +from penzai.treescope import rendering_parts from penzai.treescope import type_registries -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import part_interface - +from penzai.treescope._internal import arrayviz_impl +from penzai.treescope._internal import figures_impl +from penzai.treescope._internal import html_escaping AxisName = Any @@ -51,579 +46,6 @@ ArrayInRegistry = Any -def load_arrayvis_javascript() -> str: - """Loads the contents of `arrayvis.js` from the Python package. - - Returns: - Source code for arrayviz. - """ - filepath = __file__ - if filepath is None: - raise ValueError("Could not find the path to arrayviz.js!") - - # Look for the resource relative to the current module's filesystem path. - base = filepath.removesuffix("arrayviz.py") - load_path = os.path.join(base, "js", "arrayviz.js") - - with open(load_path, "r") as f: - return f.read() - - -def _html_setup() -> ( - set[part_interface.CSSStyleRule | part_interface.JavaScriptDefn] -): - """Builds the setup HTML that should be included in any arrayviz output cell.""" - arrayviz_src = html_escaping.heuristic_strip_javascript_comments( - load_arrayvis_javascript() - ) - return { - part_interface.CSSStyleRule(html_escaping.without_repeated_whitespace(""" - .arrayviz_container { - white-space: normal; - } - .arrayviz_container .info { - font-family: monospace; - color: #aaaaaa; - margin-bottom: 0.25em; - white-space: pre; - } - .arrayviz_container .info input[type="range"] { - vertical-align: middle; - filter: grayscale(1) opacity(0.5); - } - .arrayviz_container .info input[type="range"]:hover { - filter: grayscale(0.5); - } - .arrayviz_container .info input[type="number"]:not(:focus) { - border-radius: 3px; - } - .arrayviz_container .info input[type="number"]:not(:focus):not(:hover) { - color: #777777; - border: 1px solid #777777; - } - .arrayviz_container .info.sliders { - white-space: pre; - } - .arrayviz_container .hovertip { - display: none; - position: absolute; - background-color: white; - border: 1px solid black; - padding: 0.25ch; - pointer-events: none; - width: fit-content; - overflow: visible; - white-space: pre; - z-index: 1000; - } - .arrayviz_container .hoverbox { - display: none; - position: absolute; - box-shadow: 0 0 0 1px black, 0 0 0 2px white; - pointer-events: none; - z-index: 900; - } - .arrayviz_container .clickdata { - white-space: pre; - } - .arrayviz_container .loading_message { - color: #aaaaaa; - } - """)), - part_interface.JavaScriptDefn( - arrayviz_src + " this.getRootNode().host.defns.arrayviz = arrayviz;" - ), - } - - -def _render_array_to_html( - array_data: np.ndarray, - valid_mask: np.ndarray, - column_axes: Sequence[int], - row_axes: Sequence[int], - slider_axes: Sequence[int], - axis_labels: list[str], - vmin: float, - vmax: float, - cmap_type: Literal["continuous", "palette_index", "digitbox"], - cmap_data: list[tuple[int, int, int]], - info: str = "", - formatting_instructions: list[dict[str, Any]] | None = None, - dynamic_continous_cmap: bool = False, - raw_min_abs: float | None = None, - raw_max_abs: float | None = None, -) -> str: - """Helper to render an array to HTML by passing arguments to javascript. - - Args: - array_data: Array data to render. - valid_mask: Mask array, of same shape as array_data, that is True for items - we should render. - column_axes: Axes (by index into `array_data`) to arrange as columns, - ordered from outermost group to innermost group. - row_axes: Axes (by index into `array_data`) to arrange as rows, ordered from - outermost group to innermost group. - slider_axes: Axes to bind to sliders. - axis_labels: Labels for each axis. - vmin: Minimum for the colormap. - vmax: Maximum for the colormap. - cmap_type: Type of colormap (see `render_array`) - cmap_data: Data for the colormap, as a sequence of RGB triples. - info: Info for the plot. - formatting_instructions: Formatting instructions for values on mouse hover - or click. These will be interpreted by `formatValueAndIndices` on the - JavaScript side. Can assume each axis is named "a0", "a1", etc. when - running in JavaScript. - dynamic_continous_cmap: Whether to dynamically adjust the colormap during - rendering. - raw_min_abs: Minimum absolute value of the array, for dynamic remapping. - raw_max_abs: Maximum absolute value of the array, for dynamic remapping. - - Returns: - HTML source for an arrayviz rendering. - """ - assert len(array_data.shape) == len(axis_labels) - assert len(valid_mask.shape) == len(axis_labels) - - if formatting_instructions is None: - formatting_instructions = [{"type": "value"}] - - # Compute strides for each axis. We refer to each axis as "a0", "a1", etc - # across the JavaScript boundary. - stride = 1 - strides = {} - for i, axis_size in reversed(list(enumerate(array_data.shape))): - strides[f"a{i}"] = stride - stride *= axis_size - - if cmap_type == "continuous": - converted_array_data = array_data.astype(np.float32) - array_dtype = "float32" - else: - converted_array_data = array_data.astype(np.int32) - array_dtype = "int32" - - def axis_spec_arg(i): - return { - "name": f"a{i}", - "label": axis_labels[i], - "start": 0, - "end": array_data.shape[i], - } - - x_axis_specs_arg = [] - for axis in column_axes: - x_axis_specs_arg.append(axis_spec_arg(axis)) - - y_axis_specs_arg = [] - for axis in row_axes: - y_axis_specs_arg.append(axis_spec_arg(axis)) - - sliced_axis_specs_arg = [] - for axis in slider_axes: - sliced_axis_specs_arg.append(axis_spec_arg(axis)) - - args_json = json.dumps({ - "info": info, - "arrayBase64": base64.b64encode(converted_array_data.tobytes()).decode( - "ascii" - ), - "arrayDtype": array_dtype, - "validMaskBase64": base64.b64encode( - valid_mask.astype(np.uint8).tobytes() - ).decode("ascii"), - "dataStrides": strides, - "xAxisSpecs": x_axis_specs_arg, - "yAxisSpecs": y_axis_specs_arg, - "slicedAxisSpecs": sliced_axis_specs_arg, - "colormapConfig": { - "type": cmap_type, - "min": vmin, - "max": vmax, - "dynamic": dynamic_continous_cmap, - "rawMinAbs": raw_min_abs, - "rawMaxAbs": raw_max_abs, - "cmapData": cmap_data, - }, - "valueFormattingInstructions": formatting_instructions, - }) - # Note: We need to save the parent of the treescope-run-here element first, - # because it will be removed before the runSoon callback executes. - inner_fn = html_escaping.without_repeated_whitespace(""" - const parent = this.parentNode; - const defns = this.getRootNode().host.defns; - defns.runSoon(() => { - const tpl = parent.querySelector('template.deferred_args'); - const config = JSON.parse( - tpl.content.querySelector('script').textContent - ); - tpl.remove(); - defns.arrayviz.buildArrayvizFigure(parent, config); - }); - """) - src = ( - '
' - 'Rendering array...' - f'" - '
' - ) - return src - - -def infer_rows_and_columns( - all_axes: Sequence[AxisInfo], - known_rows: Sequence[AxisInfo] = (), - known_columns: Sequence[AxisInfo] = (), - edge_items_per_axis: tuple[int | None, ...] | None = None, -) -> tuple[list[AxisInfo], list[AxisInfo]]: - """Infers an ordered assignment of axis indices or names to rows and columns. - - The unassigned axes are sorted by size and then assigned to rows and columns - to try to balance the total number of elements along the row and column axes. - This curently uses a greedy algorithm with an adjustment to try to keep - columns longer than rows, except when there are exactly two axes and both are - positional, in which case it lays out axis 0 as the rows and axis 1 as the - columns. - - Axes with logical positions are sorted before axes with only names - (in reverse order, so that later axes are rendered on the inside). Axes with - names only appear afterward, with explicitly-assigned ones before unassigned - ones. - - Args: - all_axes: Sequence of axis infos in the array that should be assigned. - known_rows: Sequence of axis indices or names that must map to rows. - known_columns: Sequence of axis indices or names that must map to columns. - edge_items_per_axis: Optional edge items specification, determining - truncated size of each axis. Must match the ordering of `all_axes`. - - Returns: - Tuple (rows, columns) of assignments, which consist of `known_rows` and - `known_columns` followed by the remaining unassigned axes in a balanced - order. - """ - if edge_items_per_axis is None: - edge_items_per_axis = (None,) * len(all_axes) - - if not known_rows and not known_columns and len(all_axes) == 2: - ax_a, ax_b = all_axes - if ( - isinstance(ax_a, PositionalAxisInfo) - and isinstance(ax_b, PositionalAxisInfo) - and {ax_a.axis_logical_index, ax_b.axis_logical_index} == {0, 1} - ): - # Two-dimensional positional array. Always do rows then columns. - if ax_a.axis_logical_index == 0: - return ([ax_a], [ax_b]) - else: - return ([ax_b], [ax_a]) - - truncated_sizes = { - ax: ax.size if edge_items is None else 2 * edge_items + 1 - for ax, edge_items in zip(all_axes, edge_items_per_axis) - } - unassigned = [ - ax for ax in all_axes if ax not in known_rows and ax not in known_columns - ] - - # Sort by size descending, so that we make the most important layout decisions - # first. - unassigned = sorted( - unassigned, key=lambda ax: (truncated_sizes[ax], ax.size), reverse=True - ) - - # Compute the total size every axis would have if we assigned them to the - # same axis. - unassigned_size = np.prod([truncated_sizes[ax] for ax in unassigned]) - - rows = list(known_rows) - row_size = np.prod([truncated_sizes[ax] for ax in rows]) - columns = list(known_columns) - column_size = np.prod([truncated_sizes[ax] for ax in columns]) - - for ax in unassigned: - axis_size = truncated_sizes[ax] - unassigned_size = unassigned_size // axis_size - if row_size * axis_size > column_size * unassigned_size: - # If we assign this to the row axis, we'll end up with a visualization - # with more rows than columns regardless of what we do later, which can - # waste screen space. Assign to columns instead. - columns.append(ax) - column_size *= axis_size - else: - # Assign to the row axis. We'll assign columns later. - rows.append(ax) - row_size *= axis_size - - # The specific ordering of axes along the rows and the columns is somewhat - # arbitrary. Re-order each so that they have positional then named axes, and - # so that position axes are in reverse position order, and the explicitly - # mentioned named axes are before the unassigned ones. - def ax_sort_key(ax: AxisInfo): - if isinstance(ax, PositionalAxisInfo | NamedPositionalAxisInfo): - return (0, -ax.axis_logical_index) - elif ax in unassigned: - return (2,) - else: - return (1,) - - return sorted(rows, key=ax_sort_key), sorted(columns, key=ax_sort_key) - - -def _infer_vmin_vmax( - array: np.ndarray, - mask: np.ndarray, - vmin: float | None, - vmax: float | None, - around_zero: bool, - trim_outliers: bool, -) -> tuple[float, float]: - """Infer reasonable lower and upper colormap bounds from an array.""" - inferring_both_bounds = vmax is None and vmin is None - finite_mask = np.logical_and(np.isfinite(array), mask) - if vmax is None: - if around_zero: - if vmin is not None: - vmax = -vmin # pylint: disable=invalid-unary-operand-type - else: - vmax = np.max(np.where(finite_mask, np.abs(array), 0)) - else: - vmax = np.max(np.where(finite_mask, array, -np.inf)) - - assert vmax is not None - - if vmin is None: - if around_zero: - vmin = -vmax # pylint: disable=invalid-unary-operand-type - else: - vmin = np.min(np.where(finite_mask, array, np.inf)) - - if inferring_both_bounds and trim_outliers: - if around_zero: - center = 0 - else: - center = np.nanmean(np.where(finite_mask, array, np.nan)) - center = np.where(np.isfinite(center), center, 0.0) - - second_moment = np.nanmean( - np.where(finite_mask, np.square(array - center), np.nan) - ) - sigma = np.where( - np.isfinite(second_moment), np.sqrt(second_moment), vmax - vmin - ) - - vmin_limit = center - 3 * sigma - vmin = np.maximum(vmin, vmin_limit) - vmax_limit = center + 3 * sigma - vmax = np.minimum(vmax, vmax_limit) - - return vmin, vmax - - -def _infer_abs_min_max( - array: np.ndarray, mask: np.ndarray -) -> tuple[float, float]: - """Infer smallest and largest absolute values in array.""" - finite_mask = np.logical_and(np.isfinite(array), mask) - absmin = np.min( - np.where(np.logical_and(finite_mask, array != 0), np.abs(array), np.inf) - ) - absmin = np.where(np.isinf(absmin), 0.0, absmin) - absmax = np.max(np.where(finite_mask, np.abs(array), -np.inf)) - absmax = np.where(np.isinf(absmax), 0.0, absmax) - return absmin, absmax - - -def infer_balanced_truncation( - shape: Sequence[int], - maximum_size: int, - cutoff_size_per_axis: int, - minimum_edge_items: int, - doubling_bonus: float = 10.0, -) -> tuple[int | None, ...]: - """Infers a balanced truncation from a shape. - - This function computes a set of truncation sizes for each axis of the array - such that it obeys the constraints about array and axis sizes, while also - keeping the relative proportions of the array consistent (e.g. we keep more - elements along axes that were originally longer). This means that the aspect - ratio of the truncated array will still resemble the aspect ratio of the - original array. - - To avoid very-unbalanced renderings and truncate longer axes more than short - ones, this function truncates based on the square-root of the axis size by - default. - - Args: - shape: The shape of the array we are truncating. - maximum_size: Maximum number of elements of an array to show. Arrays larger - than this will be truncated along one or more axes. - cutoff_size_per_axis: Maximum number of elements of each individual axis to - show without truncation. Any axis longer than this will be truncated, with - their visual size increasing logarithmically with the true axis size - beyond this point. - minimum_edge_items: How many values to keep along each axis for truncated - arrays. We may keep more than this up to the budget of maximum_size. - doubling_bonus: Number of elements to add to each axis each time it doubles - beyond `cutoff_size_per_axis`. Used to make longer axes appear visually - longer while still keeping them a reasonable size. - - Returns: - A tuple of edge sizes. Each element corresponds to an axis in `shape`, - and is either `None` (for no truncation) or an integer (corresponding to - the number of elements to keep at the beginning and and at the end). - """ - shape_arr = np.array(list(shape)) - remaining_elements_to_divide = maximum_size - edge_items_per_axis = {} - # Order our shape from smallest to largest, since the smallest axes will - # require the least amount of truncation and will have the most stringent - # constraints. - sorted_axes = np.argsort(shape_arr) - sorted_shape = shape_arr[sorted_axes] - - # Figure out maximum sizes based on the cutoff - cutoff_adjusted_maximum_sizes = np.where( - sorted_shape <= cutoff_size_per_axis, - sorted_shape, - cutoff_size_per_axis - + doubling_bonus * np.log2(sorted_shape / cutoff_size_per_axis), - ) - - # Suppose we want to make a scaled version of the array with relative - # axis sizes - # s0, s1, s2, ... - # The total size is then - # size = (c * s0) * (c * s1) * (c * s2) * ... - # log(size) = ndim * log(c) + [ log s0 + log s1 + log s2 + ... ] - # If we have a known final size we want to reach, we can solve for c as - # c = exp( (log size - [ log s0 + log s1 + log s2 + ... ]) / ndim ) - axis_proportions = np.sqrt(sorted_shape) - log_axis_proportions = np.log(axis_proportions) - for i in range(len(sorted_axes)): - original_axis = sorted_axes[i] - size = shape_arr[original_axis] - # If we truncated this axis and every axis after it proportional to - # their weights, how small of an axis size would we need for this - # axis? - log_c = ( - np.log(remaining_elements_to_divide) - np.sum(log_axis_proportions[i:]) - ) / (len(shape) - i) - soft_limit_for_this_axis = np.exp(log_c + log_axis_proportions[i]) - cutoff_limit_for_this_axis = np.floor( - np.minimum( - soft_limit_for_this_axis, - cutoff_adjusted_maximum_sizes[i], - ) - ) - if size <= 2 * minimum_edge_items + 1 or size <= cutoff_limit_for_this_axis: - # If this axis is already smaller than the minimum size it would have - # after truncation, there's no reason to truncate it. - # But pretend we did, so that other axes still grow monotonically if - # their axis sizes increase. - remaining_elements_to_divide = ( - remaining_elements_to_divide / soft_limit_for_this_axis - ) - edge_items_per_axis[original_axis] = None - elif cutoff_limit_for_this_axis < 2 * minimum_edge_items + 1: - # If this axis is big enough to truncate, but our naive target size is - # smaller than the minimum allowed truncation, we should truncate it - # to the minimum size allowed instead. - edge_items_per_axis[original_axis] = minimum_edge_items - remaining_elements_to_divide = remaining_elements_to_divide / ( - 2 * minimum_edge_items + 1 - ) - else: - # Otherwise, truncate it and all remaining axes based on our target - # truncations. - for j in range(i, len(sorted_axes)): - visual_size = np.floor( - np.minimum( - np.exp(log_c + log_axis_proportions[j]), - cutoff_adjusted_maximum_sizes[j], - ) - ) - edge_items_per_axis[sorted_axes[j]] = int(visual_size // 2) - break - - return tuple( - edge_items_per_axis[orig_axis] for orig_axis in range(len(shape)) - ) - - -def compute_truncated_shape( - shape: tuple[int, ...], - edge_items: tuple[int | None, ...], -) -> tuple[int, ...]: - """Computes the shape of a truncated array. - - This can be used to estimate the size of an array visualization after it has - been truncated by `infer_balanced_truncation`. - - Args: - shape: The original array shape. - edge_items: Number of edge items to keep along each axis. - - Returns: - The shape of the truncated array. - """ - return tuple( - orig if edge is None else 2 * edge + 1 - for orig, edge in zip(shape, edge_items) - ) - - -@dataclasses.dataclass(frozen=True) -class ArrayvizRendering(figures.RendersAsRootInIPython): - """A rendering of an array with Arrayviz. - - Attributes: - html_src: HTML source for the rendering. - """ - - html_src: str - - def _compute_collapsed_width(self) -> int: - return 80 - - def _compute_newlines_in_expanded_parent(self) -> int: - return 10 - - def foldables_in_this_part(self) -> Sequence[part_interface.FoldableTreeNode]: - return () - - def _compute_tags_in_this_part(self) -> frozenset[Any]: - return frozenset() - - def render_to_text( - self, - stream: io.TextIOBase, - *, - expanded_parent: bool, - indent: int, - roundtrip_mode: bool, - render_context: dict[Any, Any], - ): - stream.write("") - - def html_setup_parts( - self, setup_context: part_interface.HtmlContextForSetup - ) -> set[part_interface.CSSStyleRule | part_interface.JavaScriptDefn]: - del setup_context - return _html_setup() - - def render_to_html( - self, - stream: io.TextIOBase, - *, - at_beginning_of_line: bool = False, - render_context: dict[Any, Any], - ): - stream.write(self.html_src) - - default_sequential_colormap: context.ContextualValue[ list[tuple[int, int, int]] ] = context.ContextualValue( @@ -713,7 +135,7 @@ def render_array( axis_item_labels: dict[AxisName | int, list[str]] | None = None, value_item_labels: dict[int, str] | None = None, axis_labels: dict[AxisName | int, str] | None = None, -) -> ArrayvizRendering: +) -> figures_impl.TreescopeFigure: """Renders an array (positional or named) to a displayable HTML object. Each element of the array is rendered to a fixed-size square, with its @@ -890,7 +312,7 @@ def render_array( if truncate: # Infer a good truncated shape for this array. - edge_items_per_axis = infer_balanced_truncation( + edge_items_per_axis = arrayviz_impl.infer_balanced_truncation( tuple(info.size for info in array_axis_info), maximum_size=maximum_size, cutoff_size_per_axis=cutoff_size_per_axis, @@ -913,31 +335,33 @@ def render_array( # axes to the rows and columns until we've assigned all of them, trying to # balance rows and columns. - row_infos, column_infos = infer_rows_and_columns( + row_infos, column_infos = arrayviz_impl.infer_rows_and_columns( all_axes=[ax for ax in array_axis_info if ax not in slider_infos], known_rows=row_infos, known_columns=column_infos, edge_items_per_axis=edge_items_per_axis, ) - return _render_pretruncated( - array_axis_info=array_axis_info, - row_infos=row_infos, - column_infos=column_infos, - slider_infos=slider_infos, - truncated_array_data=truncated_array_data, - truncated_mask_data=truncated_mask_data, - edge_items_per_axis=edge_items_per_axis, - continuous=continuous, - around_zero=around_zero, - vmax=vmax, - vmin=vmin, - trim_outliers=trim_outliers, - dynamic_colormap=dynamic_colormap, - colormap=colormap, - axis_item_labels=axis_item_labels, - value_item_labels=value_item_labels, - axis_labels=axis_labels, + return figures_impl.TreescopeFigure( + _render_pretruncated( + array_axis_info=array_axis_info, + row_infos=row_infos, + column_infos=column_infos, + slider_infos=slider_infos, + truncated_array_data=truncated_array_data, + truncated_mask_data=truncated_mask_data, + edge_items_per_axis=edge_items_per_axis, + continuous=continuous, + around_zero=around_zero, + vmax=vmax, + vmin=vmin, + trim_outliers=trim_outliers, + dynamic_colormap=dynamic_colormap, + colormap=colormap, + axis_item_labels=axis_item_labels, + value_item_labels=value_item_labels, + axis_labels=axis_labels, + ) ) @@ -960,7 +384,7 @@ def _render_pretruncated( axis_item_labels: dict[AxisName | int, list[str]] | None, value_item_labels: dict[int, str] | None, axis_labels: dict[AxisName | int, str] | None, -) -> ArrayvizRendering: +) -> arrayviz_impl.ArrayvizRendering: """Internal helper to render an array that has already been truncated.""" if axis_item_labels is None: axis_item_labels = {} @@ -1190,7 +614,7 @@ def _render_pretruncated( if not around_zero: raise ValueError("Cannot use dynamic_colormap without around_zero.") - raw_min_abs, raw_max_abs = _infer_abs_min_max( + raw_min_abs, raw_max_abs = arrayviz_impl.infer_abs_min_max( truncated_array_data, truncated_mask_data ) raw_min_abs = float(raw_min_abs) @@ -1201,7 +625,7 @@ def _render_pretruncated( # Infer concrete `vmin` and `vmax`. if continuous and (vmin is None or vmax is None): - vmin, vmax = _infer_vmin_vmax( + vmin, vmax = arrayviz_impl.infer_vmin_vmax( array=truncated_array_data, mask=truncated_mask_data, vmin=vmin, @@ -1247,7 +671,7 @@ def _render_pretruncated( info_parts.append(" Hover/click for array data.") # Step 8: Render it! - html_src = _render_array_to_html( + html_src = arrayviz_impl.render_array_data_to_html( array_data=truncated_array_data, valid_mask=truncated_mask_data, column_axes=column_data_axes, @@ -1266,7 +690,7 @@ def _render_pretruncated( raw_min_abs=raw_min_abs, raw_max_abs=raw_max_abs, ) - return ArrayvizRendering(html_src) + return arrayviz_impl.ArrayvizRendering(html_src) def render_sharding_info( @@ -1274,7 +698,7 @@ def render_sharding_info( sharding_info: ndarray_adapters.ShardingInfo, rows: Sequence[int | AxisName] = (), columns: Sequence[int | AxisName] = (), -) -> ArrayvizRendering: +) -> figures_impl.TreescopeFigure: """Renders the sharding of an array. This is a helper function for rendering array shardings. It can be used either @@ -1318,7 +742,7 @@ def render_sharding_info( # Compute a truncation for visualizing a single shard. Each shard will be # shown as a shrunken version of the actual shard dimensions, roughly # proportional to the shard sizes. - mini_trunc = infer_balanced_truncation( + mini_trunc = arrayviz_impl.infer_balanced_truncation( shape=array_shape, maximum_size=1000, cutoff_size_per_axis=10, @@ -1328,7 +752,7 @@ def render_sharding_info( # Infer an axis ordering. known_row_infos = [info_by_name_or_position[spec] for spec in rows] known_column_infos = [info_by_name_or_position[spec] for spec in columns] - row_infos, column_infos = infer_rows_and_columns( + row_infos, column_infos = arrayviz_impl.infer_rows_and_columns( all_axes=array_axis_info, known_rows=known_row_infos, known_columns=known_column_infos, @@ -1498,7 +922,7 @@ def render_sharding_info( # Build the rendering. html_srcs = [] html_srcs.append( - _render_array_to_html( + arrayviz_impl.render_array_data_to_html( array_data=dest, valid_mask=destmask, column_axes=[data_axis_from_axis_info[c] for c in column_infos], @@ -1523,20 +947,22 @@ def render_sharding_info( if i == 0: html_srcs.append(f"{sharding_info.device_type}") label = ",".join(f"{d}" for d in shard_devices) - subsrc = integer_digitbox( - shard_offset_values[shard_offsets], - label_bottom="", - ).html_src - html_srcs.append(f" {subsrc} {label}") + part = integer_digitbox( + shard_offset_values[shard_offsets], label_bottom="" + ).treescope_part + assert isinstance(part, arrayviz_impl.ArrayvizDigitboxRendering) + html_srcs.append(f" {part.html_src} {label}") html_srcs.append("") - return ArrayvizRendering("".join(html_srcs)) + return figures_impl.TreescopeFigure( + arrayviz_impl.ArrayvizRendering("".join(html_srcs)) + ) def render_array_sharding( array: ArrayInRegistry, rows: Sequence[int | AxisName] = (), columns: Sequence[int | AxisName] = (), -) -> ArrayvizRendering: +) -> figures_impl.TreescopeFigure: """Renders the sharding of an array. Args: @@ -1576,23 +1002,12 @@ def render_array_sharding( ) -@dataclasses.dataclass(frozen=True) -class ArrayvizDigitboxRendering(ArrayvizRendering): - """A rendering of a single digitbox with Arrayviz.""" - - def _compute_collapsed_width(self) -> int: - return 2 - - def _compute_newlines_in_expanded_parent(self) -> int: - return 1 - - def integer_digitbox( value: int, label_top: str = "", label_bottom: str | None = None, size: str = "1em", -) -> ArrayvizDigitboxRendering: +) -> figures_impl.TreescopeFigure: """Returns a "digitbox" rendering of a single integer. Args: @@ -1630,92 +1045,9 @@ def integer_digitbox( "" "" ) - return ArrayvizRendering(src) - - -@dataclasses.dataclass(frozen=True) -class ValueColoredTextbox( - figures.RendersAsRootInIPython, basic_parts.DeferringToChild -): - """A rendering of text with a colored background. - - Attributes: - child: Child part to render. - text_color: Color for the text. - background_color: Color for the background, usually from a colormap. - out_of_bounds: Whether this value was out of bounds of the colormap. - value: Underlying float value that is being visualized. Rendered on hover. - """ - - child: part_interface.RenderableTreePart - text_color: str - background_color: str - out_of_bounds: bool - value: float - - def html_setup_parts( - self, setup_context: part_interface.HtmlContextForSetup - ) -> set[part_interface.CSSStyleRule | part_interface.JavaScriptDefn]: - return ( - { - part_interface.CSSStyleRule( - html_escaping.without_repeated_whitespace(""" - .arrayviz_textbox { - padding-left: 0.5ch; - padding-right: 0.5ch; - outline: 1px solid black; - position: relative; - display: inline-block; - font-family: monospace; - white-space: pre; - margin-top: 1px; - box-sizing: border-box; - } - .arrayviz_textbox.out_of_bounds { - outline: 3px double darkorange; - } - .arrayviz_textbox .value { - display: none; - position: absolute; - bottom: 110%; - left: 0; - overflow: visible; - color: black; - background-color: white; - font-size: 0.7em; - } - .arrayviz_textbox:hover .value { - display: block; - } - """) - ) - } - | self.child.html_setup_parts(setup_context) - ) - - def render_to_html( - self, - stream: io.TextIOBase, - *, - at_beginning_of_line: bool = False, - render_context: dict[Any, Any], - ): - class_string = "arrayviz_textbox" - if self.out_of_bounds: - class_string += " out_of_bounds" - bg_color = html_escaping.escape_html_attribute(self.background_color) - text_color = html_escaping.escape_html_attribute(self.text_color) - stream.write( - f'' - f'{float(self.value):.4g}' - ) - self.child.render_to_html( - stream, - at_beginning_of_line=False, - render_context=render_context, - ) - stream.write("") + return figures_impl.TreescopeFigure( + arrayviz_impl.ArrayvizDigitboxRendering(src) + ) def text_on_color( @@ -1724,7 +1056,7 @@ def text_on_color( vmax: float = 1.0, vmin: float | None = None, colormap: list[tuple[int, int, int]] | None = None, -) -> ValueColoredTextbox: +) -> figures_impl.TreescopeFigure: """Renders some text on colored background, similar to arrayviz coloring. Args: @@ -1778,10 +1110,12 @@ def text_on_color( text_color = "black" else: text_color = "white" - return ValueColoredTextbox( - child=basic_parts.Text(text), - text_color=text_color, - background_color=f"rgb({r} {g} {b})", - out_of_bounds=is_out_of_bounds, - value=value, + return figures_impl.TreescopeFigure( + arrayviz_impl.ValueColoredTextbox( + child=rendering_parts.text(text), + text_color=text_color, + background_color=f"rgb({r} {g} {b})", + out_of_bounds=is_out_of_bounds, + value=value, + ) ) diff --git a/penzai/treescope/autovisualize.py b/penzai/treescope/autovisualize.py index f659c1c..fb0c86e 100644 --- a/penzai/treescope/autovisualize.py +++ b/penzai/treescope/autovisualize.py @@ -32,8 +32,8 @@ from typing import Any, Protocol from penzai.treescope import context -from penzai.treescope.foldable_representation import embedded_iframe -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope import rendering_parts +from penzai.treescope._internal import object_inspection @dataclasses.dataclass @@ -46,7 +46,7 @@ class IPythonVisualization: object. If False, the display object will be rendered below the object. """ - display_object: embedded_iframe.HasReprHtml + display_object: object_inspection.HasReprHtml replace: bool = False @@ -59,7 +59,7 @@ class CustomTreescopeVisualization: subtree. """ - rendering: part_interface.RenderableAndLineAnnotations + rendering: rendering_parts.RenderableAndLineAnnotations @dataclasses.dataclass diff --git a/penzai/treescope/copypaste_fallback.py b/penzai/treescope/copypaste_fallback.py index cb2b5aa..16621ae 100644 --- a/penzai/treescope/copypaste_fallback.py +++ b/penzai/treescope/copypaste_fallback.py @@ -20,12 +20,9 @@ import sys from typing import Any +from penzai.treescope import handlers from penzai.treescope import renderer -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import builtin_atom_handler -from penzai.treescope.handlers import builtin_structure_handler -from penzai.treescope.handlers import canonical_alias_postprocessor -from penzai.treescope.handlers import function_reflection_handlers +from penzai.treescope import rendering_parts @dataclasses.dataclass(frozen=True) @@ -73,7 +70,7 @@ def __treescope_color__(self): def render_not_roundtrippable( obj: NotRoundtrippable, repr_override: str | None = None, -) -> part_interface.RenderableTreePart: +) -> rendering_parts.RenderableTreePart: """Renders an object as a `NotRoundtrippable` instance. This can be used inside handlers for non-roundtrippable objects to render @@ -89,12 +86,11 @@ def render_not_roundtrippable( """ fallback_renderer = renderer.TreescopeRenderer( handlers=[ - builtin_atom_handler.handle_builtin_atoms, - builtin_structure_handler.handle_builtin_structures, - function_reflection_handlers.handle_code_objects_with_reflection, + handlers.handle_basic_types, + handlers.handle_code_objects_with_reflection, ], wrapper_hooks=[ - canonical_alias_postprocessor.replace_with_canonical_aliases, + handlers.replace_with_canonical_aliases, ], context_builders=[], ) diff --git a/penzai/treescope/default_renderer.py b/penzai/treescope/default_renderer.py index 797bfb9..cb2cd84 100644 --- a/penzai/treescope/default_renderer.py +++ b/penzai/treescope/default_renderer.py @@ -18,22 +18,12 @@ import functools from typing import Any, Callable from penzai.treescope import context +from penzai.treescope import handlers +from penzai.treescope import lowering from penzai.treescope import renderer +from penzai.treescope import rendering_parts from penzai.treescope import type_registries -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import layout_algorithms -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import autovisualizer_hook -from penzai.treescope.handlers import builtin_atom_handler -from penzai.treescope.handlers import builtin_structure_handler -from penzai.treescope.handlers import canonical_alias_postprocessor -from penzai.treescope.handlers import custom_type_handlers -from penzai.treescope.handlers import function_reflection_handlers -from penzai.treescope.handlers import generic_pytree_handler -from penzai.treescope.handlers import generic_repr_handler -from penzai.treescope.handlers import repr_html_postprocessor -from penzai.treescope.handlers import shared_value_postprocessor +from penzai.treescope._internal import layout_algorithms active_renderer: context.ContextualValue[renderer.TreescopeRenderer] = ( @@ -43,37 +33,35 @@ initial_value=renderer.TreescopeRenderer( handlers=[ # Objects with `__penzai_repr__` defined. - custom_type_handlers.handle_via_penzai_repr_method, - # Objects in the global registry of type handlers. - custom_type_handlers.handle_via_global_registry, + handlers.handle_via_penzai_repr_method, + # Objects in the global registry of custom type handlers. + handlers.handle_via_global_registry, # Reflection of functions and classes. - function_reflection_handlers.handle_code_objects_with_reflection, - # Numbers, strings, constants, enums, etc. - builtin_atom_handler.handle_builtin_atoms, - # Lists, dicts, tuples, dataclasses, namedtuples, etc. - builtin_structure_handler.handle_builtin_structures, + handlers.handle_code_objects_with_reflection, + # Render basic builtin types, namedtuples, and dataclasses. + handlers.handle_basic_types, # Fallback for unknown pytree types: Show repr and also the # PyTree children. Note: This is a no-op unless JAX has been # imported. - generic_pytree_handler.handle_arbitrary_pytrees, + handlers.handle_arbitrary_pytrees, # Fallback to ordinary `repr` for any other object. - generic_repr_handler.handle_anything_with_repr, + handlers.handle_anything_with_repr, ], wrapper_hooks=[ # Allow user-configurable visualizations. - autovisualizer_hook.use_autovisualizer_if_present, + handlers.use_autovisualizer_if_present, # Show display objects inline. - repr_html_postprocessor.append_repr_html_when_present, + handlers.append_repr_html_when_present, # Collapse well-known objects into aliases. - canonical_alias_postprocessor.replace_with_canonical_aliases, + handlers.replace_with_canonical_aliases, # Annotate multiple references to the same mutable Python # object. - shared_value_postprocessor.check_for_shared_values, + handlers.check_for_shared_values, ], context_builders=[ # Set up a new context for each rendered object when rendering # shared values. - shared_value_postprocessor.setup_shared_value_context, + handlers.setup_shared_value_context, # Update type registries to account for newly imported # modules before rendering. type_registries.update_registries_for_imports, @@ -95,7 +83,7 @@ """ active_expansion_strategy = context.ContextualValue[ - Callable[[part_interface.RenderableTreePart], None] + Callable[[rendering_parts.RenderableTreePart], None] ]( module=__name__, qualname="active_expansion_strategy", @@ -162,7 +150,7 @@ def using_expansion_strategy( def build_foldable_representation( value: Any, ignore_exceptions: bool = False, -) -> part_interface.RenderableAndLineAnnotations: +) -> rendering_parts.RenderableAndLineAnnotations: """Builds a foldable representation of an object using default configuration. Uses the default renderer and expansion strategy. @@ -205,10 +193,10 @@ def render_to_text( Returns: A text representation of the object. """ - foldable_ir = basic_parts.build_full_line_with_annotations( + foldable_ir = rendering_parts.build_full_line_with_annotations( build_foldable_representation(value, ignore_exceptions=ignore_exceptions) ) - return foldable_impl.render_to_text_as_root(foldable_ir, roundtrip_mode) + return lowering.render_to_text_as_root(foldable_ir, roundtrip_mode) def render_to_html( @@ -230,9 +218,9 @@ def render_to_html( Returns: HTML source code for the foldable representation of the object. """ - foldable_ir = basic_parts.build_full_line_with_annotations( + foldable_ir = rendering_parts.build_full_line_with_annotations( build_foldable_representation(value, ignore_exceptions=ignore_exceptions) ) - return foldable_impl.render_to_html_as_root( + return lowering.render_to_html_as_root( foldable_ir, roundtrip_mode, compressed=compressed ) diff --git a/penzai/treescope/figures.py b/penzai/treescope/figures.py index 1c27447..22e8e2c 100644 --- a/penzai/treescope/figures.py +++ b/penzai/treescope/figures.py @@ -22,224 +22,74 @@ from __future__ import annotations -import dataclasses -import io from typing import Any from penzai.treescope import default_renderer -from penzai.treescope import html_escaping -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import embedded_iframe -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope import rendering_parts +from penzai.treescope._internal import figures_impl +from penzai.treescope._internal import object_inspection +from penzai.treescope._internal.parts import basic_parts +from penzai.treescope._internal.parts import embedded_iframe -class RendersAsRootInIPython(part_interface.RenderableTreePart): - """Base class / mixin that implements ``_repr_html_`` for treescope parts. - - Subclasses of this class will render themselves as rich display objects when - displayed in IPython, instead of having their contents formatted with a - pretty printer or ``repr``. - """ - - def _repr_html_(self) -> str: - """Returns a rich HTML representation of this part.""" - return foldable_impl.render_to_html_as_root(self, compressed=True) - - def _repr_pretty_(self, p, cycle): - """Builds a representation of this part for the IPython text prettyprinter.""" - del cycle - p.text(foldable_impl.render_to_text_as_root(self)) - - -@dataclasses.dataclass(frozen=True) -class TreescopeRenderingFigure( - basic_parts.DeferringToChild, RendersAsRootInIPython -): - """Wrapper that renders its child rendering as a HTML figure. - - Attributes: - child: Child to render as a figure. - """ - - child: part_interface.RenderableTreePart - - -class InlineBlock(basic_parts.BaseSpanGroup): - """Renders an object in "inline-block" mode.""" - - def _span_css_class(self) -> str: - return "inline_block" - - def _span_css_rule( - self, context: part_interface.HtmlContextForSetup - ) -> part_interface.CSSStyleRule: - return part_interface.CSSStyleRule( - html_escaping.without_repeated_whitespace(""" - .inline_block { - display: inline-block; - } - """) - ) - - -def wrap_as_treescope_figure(value: Any) -> part_interface.RenderableTreePart: - """Converts an arbitrary object to a renderable treescope part if possible. - - Behavior depends on the type of `value`: - - * If ``value`` is an instance of `RendersAsRootInIPython`, returns it - unchanged, since it knows how to render itself. - * If ``value`` is a string, returns a rendering of that string. - * If ``value`` has a ``_repr_html_`` method (but isn't an instance of - `RendersAsRootInIPython`), returns an embedded iframe with the given HTML - contents. - * Otherwise, renders the value using the default treescope renderer, but - strips off any top-level comments / copy button annotations. - - The typical use is to provide helper constructors for containers to allow - rendering lots of different objects in the "obvious" way. - - Args: - value: Value to wrap. - - Returns: - A renderable treescope part showing the value. - """ - if isinstance(value, RendersAsRootInIPython): - return value - elif isinstance(value, str): - return basic_parts.Text(value) - else: - maybe_html = embedded_iframe.to_html(value) - if maybe_html: - return InlineBlock( - embedded_iframe.EmbeddedIFrame( - maybe_html, - fallback_in_text_mode=basic_parts.Text(object.__repr__(value)), - ) - ) - else: - return default_renderer.build_foldable_representation(value).renderable - - -class AllowWordWrap(basic_parts.BaseSpanGroup): - """Allows line breaks in its child..""" - - def _span_css_class(self) -> str: - return "allow_wrap" - - def _span_css_rule( - self, context: part_interface.HtmlContextForSetup - ) -> part_interface.CSSStyleRule: - return part_interface.CSSStyleRule( - html_escaping.without_repeated_whitespace(""" - .allow_wrap { - white-space: pre-wrap; - } - """) - ) - - -class PreventWordWrap(basic_parts.BaseSpanGroup): - """Allows line breaks in its child..""" - - def _span_css_class(self) -> str: - return "prevent_wrap" - - def _span_css_rule( - self, context: part_interface.HtmlContextForSetup - ) -> part_interface.CSSStyleRule: - return part_interface.CSSStyleRule( - html_escaping.without_repeated_whitespace(""" - .prevent_wrap { - white-space: pre; - } - """) - ) - - -def inline(*parts: Any, wrap: bool = False) -> RendersAsRootInIPython: +def inline( + *subfigures: Any, wrap: bool = False +) -> figures_impl.TreescopeFigure: """Returns a figure that arranges a set of displayable objects along a line. Args: - *parts: Subfigures to display inline. These will be displayed using - `wrap_as_treescope_figure`. + *subfigures: Subfigures to display inline. wrap: Whether to wrap (insert newlines) between words at the end of a line. Returns: A figure which can be rendered in IPython or used to build more complex figures. """ - siblings = basic_parts.siblings( - *(wrap_as_treescope_figure(part) for part in parts) + siblings = rendering_parts.siblings( + *(treescope_part_from_display_object(subfig) for subfig in subfigures) ) if wrap: - return TreescopeRenderingFigure(AllowWordWrap(siblings)) + return figures_impl.TreescopeFigure(figures_impl.AllowWordWrap(siblings)) else: - return TreescopeRenderingFigure(PreventWordWrap(siblings)) + return figures_impl.TreescopeFigure(figures_impl.PreventWordWrap(siblings)) -def indented(subfigure: Any) -> RendersAsRootInIPython: +def indented(subfigure: Any) -> figures_impl.TreescopeFigure: """Returns a figure object that displays a value with an indent. Args: - subfigure: A value to render indented. Will be wrapped using - `wrap_as_treescope_figure`. + subfigure: A value to render indented. """ - return TreescopeRenderingFigure( - basic_parts.IndentedChildren.build([wrap_as_treescope_figure(subfigure)]) + return figures_impl.TreescopeFigure( + rendering_parts.indented_children([ + rendering_parts.vertical_space("0.25em"), + treescope_part_from_display_object(subfigure), + rendering_parts.vertical_space("0.25em"), + ]) ) -@dataclasses.dataclass(frozen=True) -class CSSStyled(basic_parts.DeferringToChild): - """Adjusts the CSS style of its child. - - Attributes: - child: Child to render. - css: A CSS style string. - """ - - child: part_interface.RenderableTreePart - style: str - - def render_to_html( - self, - stream: io.TextIOBase, - *, - at_beginning_of_line: bool = False, - render_context: dict[Any, Any], - ): - style = html_escaping.escape_html_attribute(self.style) - stream.write(f'') - self.child.render_to_html( - stream, - at_beginning_of_line=at_beginning_of_line, - render_context=render_context, - ) - stream.write("") - - -def styled(subfigure: Any, style: str) -> RendersAsRootInIPython: +def styled(subfigure: Any, style: str) -> figures_impl.TreescopeFigure: """Returns a CSS-styled version of the first figure. Args: - subfigure: A value to render. Will be wrapped using - `wrap_as_treescope_figure`. + subfigure: A value to render. style: A CSS style string. """ - return TreescopeRenderingFigure( - CSSStyled(wrap_as_treescope_figure(subfigure), style) + return figures_impl.TreescopeFigure( + figures_impl.CSSStyled( + treescope_part_from_display_object(subfigure), style + ) ) -def with_font_size(subfigure: Any, size: str | float) -> RendersAsRootInIPython: +def with_font_size( + subfigure: Any, size: str | float +) -> figures_impl.TreescopeFigure: """Returns a scaled version of the first figure. Args: - subfigure: A value to render. Will be wrapped using - `wrap_as_treescope_figure`. + subfigure: A value to render. size: A multiplier for the font size (as a float) or a string giving a specific CSS font size (e.g. "14pt" or "2em"). """ @@ -247,31 +97,92 @@ def with_font_size(subfigure: Any, size: str | float) -> RendersAsRootInIPython: style = f"font-size: {size}" else: style = f"font-size: {size}em" - return TreescopeRenderingFigure( - CSSStyled(wrap_as_treescope_figure(subfigure), style) + return figures_impl.TreescopeFigure( + figures_impl.CSSStyled( + treescope_part_from_display_object(subfigure), style + ) ) -def with_color(subfigure: Any, color: str) -> RendersAsRootInIPython: +def with_color(subfigure: Any, color: str) -> figures_impl.TreescopeFigure: """Returns a colored version of the first figure. Args: - subfigure: A value to render. Will be wrapped using - `wrap_as_treescope_figure`. + subfigure: A value to render. color: Any CSS color string. """ - return TreescopeRenderingFigure( - CSSStyled(wrap_as_treescope_figure(subfigure), f"color: {color}") + return figures_impl.TreescopeFigure( + figures_impl.CSSStyled( + treescope_part_from_display_object(subfigure), f"color: {color}" + ) ) -def bolded(subfigure: Any) -> RendersAsRootInIPython: +def bolded(subfigure: Any) -> figures_impl.TreescopeFigure: """Returns a bolded version of the first figure. Args: - subfigure: A value to render. Will be wrapped using - `wrap_as_treescope_figure`. + subfigure: A value to render. """ - return TreescopeRenderingFigure( - CSSStyled(wrap_as_treescope_figure(subfigure), "font-weight: bold") + return figures_impl.TreescopeFigure( + figures_impl.CSSStyled( + treescope_part_from_display_object(subfigure), "font-weight: bold" + ) ) + + +def figure_from_treescope_rendering_part( + part: rendering_parts.RenderableTreePart, +) -> figures_impl.TreescopeFigure: + """Returns a figure object that displays a Treescope rendering part. + + Args: + part: A Treescope rendering part to display, usually constructed via + `repr_lib` or `rendering_parts`. + + Returns: + A figure object that can be rendered in IPython. + """ + return figures_impl.TreescopeFigure(part) + + +def treescope_part_from_display_object( + value: Any, +) -> rendering_parts.RenderableTreePart: + """Converts an arbitrary object to a renderable treescope part if possible. + + Behavior depends on the type of `value`: + + * If ``value`` is an instance of `TreescopeFigure`, unwraps the + underlying treescope part. + * If ``value`` is a string, returns a rendering of that string. + * If ``value`` has a ``_repr_html_`` method (but isn't an instance of + `TreescopeFigure`), returns an embedded iframe with the given HTML + contents. + * Otherwise, renders the value using the default treescope renderer, but + strips off any top-level comments / copy button annotations. + + The typical use is to provide helper constructors for containers to allow + rendering lots of different objects in the "obvious" way. + + Args: + value: Value to wrap. + + Returns: + A renderable treescope part showing the value. + """ + if isinstance(value, figures_impl.TreescopeFigure): + return value.treescope_part + elif isinstance(value, str): + return basic_parts.Text(value) + else: + maybe_html = object_inspection.to_html(value) + if maybe_html: + return figures_impl.InlineBlock( + embedded_iframe.embedded_iframe( + maybe_html, + fallback_in_text_mode=basic_parts.Text(object.__repr__(value)), + ) + ) + else: + return default_renderer.build_foldable_representation(value).renderable diff --git a/penzai/treescope/formatting_util.py b/penzai/treescope/formatting_util.py index b5ba922..576297a 100644 --- a/penzai/treescope/formatting_util.py +++ b/penzai/treescope/formatting_util.py @@ -15,6 +15,7 @@ """Utilities for formatting and rendering.""" import hashlib +import warnings def oklch_color( @@ -71,3 +72,39 @@ def color_from_string( uniform = (fingerprint % (2**16)) / (2**16) # Convert to a hue. return oklch_color(lightness, chroma, 360 * uniform, alpha) + + +def parse_simple_color_and_pattern_spec( + requested_color: str | tuple[str, str], typename_for_warning: str | None +) -> tuple[str | None, str | None]: + """Parses a background color and pattern from a user-provided color request. + + Args: + requested_color: A color request, which is either a single CSS color or a + tuple of outline and background colors. + typename_for_warning: If provided, and the color is invalid, a warning will + be issued with this typename as context. + + Returns: + A tuple (background_color, background_pattern) that can be passed to + the Treescope low-level representation construction functions (such as + `build_foldable_tree_node_from_children`) that will configure it with the + given background color and outline. + """ + if isinstance(requested_color, str): + background_color = requested_color + background_pattern = None + elif isinstance(requested_color, tuple) and len(requested_color) == 2: + background_color = requested_color[0] + background_pattern = ( + f"linear-gradient({requested_color[1]},{requested_color[1]})" + ) + else: + if typename_for_warning: + warnings.warn( + f"{typename_for_warning} requested an invalid color:" + f" {requested_color} (not a string or a tuple)" + ) + background_color = None + background_pattern = None + return background_color, background_pattern diff --git a/penzai/treescope/handlers.py b/penzai/treescope/handlers.py new file mode 100644 index 0000000..bb30b41 --- /dev/null +++ b/penzai/treescope/handlers.py @@ -0,0 +1,52 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Treescope rendering handlers, for use in renderer configurations. + +These handlers are responsible for implementing each of the steps of the +default Treescope renderer, and can be used individually to build custom +renderer configurations. +""" + +# pylint: disable=g-importing-member,g-multiple-import,unused-import + +from penzai.treescope._internal.handlers.autovisualizer_hook import ( + use_autovisualizer_if_present, +) +from penzai.treescope._internal.handlers.basic_types_handler import ( + handle_basic_types, +) +from penzai.treescope._internal.handlers.canonical_alias_postprocessor import ( + replace_with_canonical_aliases, +) +from penzai.treescope._internal.handlers.custom_type_handlers import ( + handle_via_global_registry, + handle_via_penzai_repr_method, +) +from penzai.treescope._internal.handlers.function_reflection_handlers import ( + handle_code_objects_with_reflection, +) +from penzai.treescope._internal.handlers.generic_pytree_handler import ( + handle_arbitrary_pytrees, +) +from penzai.treescope._internal.handlers.generic_repr_handler import ( + handle_anything_with_repr, +) +from penzai.treescope._internal.handlers.repr_html_postprocessor import ( + append_repr_html_when_present, +) +from penzai.treescope._internal.handlers.shared_value_postprocessor import ( + check_for_shared_values, + setup_shared_value_context, +) diff --git a/penzai/treescope/handlers/builtin_atom_handler.py b/penzai/treescope/handlers/builtin_atom_handler.py deleted file mode 100644 index 370dfee..0000000 --- a/penzai/treescope/handlers/builtin_atom_handler.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2024 The Penzai Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Handlers for builtin atom types (e.g. constants or literals).""" -from __future__ import annotations - -import enum -from typing import Any - -from penzai.treescope import html_escaping -from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import part_interface - - -CSSStyleRule = part_interface.CSSStyleRule -HtmlContextForSetup = part_interface.HtmlContextForSetup - - -class KeywordColor(basic_parts.BaseSpanGroup): - """Renders its child in a color for keywords.""" - - def _span_css_class(self) -> str: - return "color_keyword" - - def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: - return CSSStyleRule(html_escaping.without_repeated_whitespace(""" - .color_keyword - { - color: #0000ff; - } - """)) - - -class NumberColor(basic_parts.BaseSpanGroup): - """Renders its child in a color for numbers.""" - - def _span_css_class(self) -> str: - return "color_number" - - def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: - return CSSStyleRule(html_escaping.without_repeated_whitespace(""" - .color_number - { - color: #098156; - } - """)) - - -class StringLiteralColor(basic_parts.BaseSpanGroup): - """Renders its child in a color for string literals.""" - - def _span_css_class(self) -> str: - return "color_string" - - def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule: - return CSSStyleRule(html_escaping.without_repeated_whitespace(""" - .color_string - { - color: #a31515; - } - """)) - - -def handle_builtin_atoms( - node: Any, - path: str | None, - subtree_renderer: renderer.TreescopeSubtreeRenderer, -) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations - | type(NotImplemented) -): - """Handles builtin atom types.""" - del subtree_renderer - - # String literals. - if isinstance(node, (str, bytes)): - lines = node.splitlines(keepends=True) - if len(lines) > 1: - # For multiline strings, we use two renderings: - # - When collapsed, they render with ordinary `repr`, - # - When expanded, they render as the implicit concatenation of per-line - # string literals. - # Note that the `repr` for a string sometimes switches delimiters - # depending on whether the string contains quotes or not, so we can't do - # much manipulation of the strings themselves. This means that the safest - # thing to do is to just embed two copies of the string into the IR, - # one for the full string and the other for each line. - return common_structures.build_custom_foldable_tree_node( - contents=StringLiteralColor( - basic_parts.FoldCondition( - collapsed=basic_parts.Text(repr(node)), - expanded=basic_parts.IndentedChildren.build( - children=[basic_parts.Text(repr(line)) for line in lines], - comma_separated=False, - ), - ) - ), - path=path, - ) - else: - # No newlines, so render it on a single line. - return common_structures.build_one_line_tree_node( - StringLiteralColor(basic_parts.Text(repr(node))), path - ) - - # Numeric literals. - if isinstance(node, (int, float)): - return common_structures.build_one_line_tree_node( - NumberColor(basic_parts.Text(repr(node))), path - ) - - # Keyword objects. - if any( - node is literal - for literal in (False, True, None, Ellipsis, NotImplemented) - ): - return common_structures.build_one_line_tree_node( - KeywordColor(basic_parts.Text(repr(node))), path - ) - - # Enums. (Rendered roundtrippably, unlike the normal enum `repr`.) - if isinstance(node, enum.Enum): - cls = type(node) - if node is getattr(cls, node.name): - return common_structures.build_one_line_tree_node( - basic_parts.siblings_with_annotations( - common_structures.maybe_qualified_type_name(cls), - "." + node.name, - extra_annotations=[ - common_styles.CommentColor( - basic_parts.Text(f" # value: {repr(node.value)}") - ) - ], - ), - path, - ) - - return NotImplemented diff --git a/penzai/treescope/handlers/builtin_structure_handler.py b/penzai/treescope/handlers/builtin_structure_handler.py deleted file mode 100644 index 080ee12..0000000 --- a/penzai/treescope/handlers/builtin_structure_handler.py +++ /dev/null @@ -1,421 +0,0 @@ -# Copyright 2024 The Penzai Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Handler for built-in collection types (lists, sets, dicts, etc).""" - -from __future__ import annotations - -import ast -import dataclasses -import types -from typing import Any, Callable, Optional, Sequence -import warnings - -from penzai.treescope import dataclass_util -from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import layout_algorithms -from penzai.treescope.foldable_representation import part_interface - - -CSSStyleRule = part_interface.CSSStyleRule -HtmlContextForSetup = part_interface.HtmlContextForSetup - - -def _dict_to_foldable( - node: dict[Any, Any], - path: str | None, - subtree_renderer: renderer.TreescopeSubtreeRenderer, -) -> part_interface.RenderableAndLineAnnotations: - """Renders a dictionary.""" - - children = [] - for i, (key, child) in enumerate(node.items()): - if i < len(node) - 1: - # Not the last child. Always show a comma, and add a space when - # collapsed. - comma_after = basic_parts.siblings( - ",", basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")) - ) - else: - # Last child: only show the comma when the node is expanded. - comma_after = basic_parts.FoldCondition(expanded=basic_parts.Text(",")) - - child_path = None if path is None else f"{path}[{repr(key)}]" - # Figure out whether this key is simple enough to render inline with - # its value. - key_rendering = subtree_renderer(key) - value_rendering = subtree_renderer(child, path=child_path) - - if ( - key_rendering.renderable.collapsed_width < 40 - and not key_rendering.renderable.foldables_in_this_part() - and key_rendering.annotations.collapsed_width == 0 - ): - # Simple enough to render on one line. - children.append( - basic_parts.siblings_with_annotations( - key_rendering, ": ", value_rendering, comma_after - ) - ) - else: - # Should render on multiple lines. - children.append( - basic_parts.siblings( - basic_parts.build_full_line_with_annotations( - key_rendering, - ":", - basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")), - ), - basic_parts.IndentedChildren.build([ - basic_parts.siblings_with_annotations( - value_rendering, comma_after - ), - basic_parts.FoldCondition( - expanded=basic_parts.VerticalSpace("0.5em") - ), - ]), - ) - ) - - if type(node) is dict: # pylint: disable=unidiomatic-typecheck - start = "{" - end = "}" - else: - start = basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(node)), "({" - ) - end = "})" - - if not children: - return common_structures.build_one_line_tree_node( - line=basic_parts.siblings(start, end), path=path - ) - else: - return common_structures.build_foldable_tree_node_from_children( - prefix=start, - children=children, - suffix=end, - path=path, - ) - - -def _sequence_or_set_to_foldable( - sequence: dict[Any, Any], - path: str | None, - subtree_renderer: renderer.TreescopeSubtreeRenderer, -) -> part_interface.RenderableAndLineAnnotations: - """Renders a sequence or set to a foldable.""" - - children = [] - for i, child in enumerate(sequence): - child_path = None if path is None else f"{path}[{repr(i)}]" - children.append(subtree_renderer(child, path=child_path)) - - force_trailing_comma = False - if isinstance(sequence, tuple): - before = "(" - after = ")" - if type(sequence) is not tuple: # pylint: disable=unidiomatic-typecheck - # Unusual situation: this is a subclass of `tuple`, but it shouldn't be - # a namedtuple because we look for _fields already. - assert not hasattr(type(sequence), "_fields") - # It's unclear what the constructor will be; try calling it with a single - # ordinary tuple as an argument. - before = basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(sequence)), - "(" + before, - ) - after = after + ")" - force_trailing_comma = len(sequence) == 1 - elif isinstance(sequence, list): - before = "[" - after = "]" - if type(sequence) is not list: # pylint: disable=unidiomatic-typecheck - before = basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(sequence)), - "(" + before, - ) - after = after + ")" - elif isinstance(sequence, set): - if not sequence: - before = "set(" - after = ")" - else: # pylint: disable=unidiomatic-typecheck - before = "{" - after = "}" - - if type(sequence) is not set: # pylint: disable=unidiomatic-typecheck - before = basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(sequence)), - "(" + before, - ) - after = after + ")" - elif isinstance(sequence, frozenset): - before = "frozenset({" - after = "})" - if type(sequence) is not frozenset: # pylint: disable=unidiomatic-typecheck - before = basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(sequence)), - "(" + before, - ) - after = after + ")" - else: - raise ValueError(f"Unrecognized sequence {sequence}") - - if not children: - return common_structures.build_one_line_tree_node( - line=basic_parts.siblings(before, after), path=path - ) - else: - return common_structures.build_foldable_tree_node_from_children( - prefix=before, - children=children, - suffix=after, - path=path, - comma_separated=True, - force_trailing_comma=force_trailing_comma, - ) - - -def build_field_children( - node: dict[Any, Any], - path: str | None, - subtree_renderer: renderer.TreescopeSubtreeRenderer, - fields_or_attribute_names: Sequence[dataclasses.Field[Any] | str], - attr_style_fn: ( - Callable[[str], part_interface.RenderableTreePart] | None - ) = None, -) -> list[part_interface.RenderableTreePart]: - """Renders a set of fields/attributes into a list of comma-separated children. - - This is a helper function used for rendering dataclasses, namedtuples, and - similar objects, of the form :: - - ClassName( - field_name_one=value1, - field_name_two=value2, - ) - - If `fields_or_attribute_names` includes dataclass fields: - - * Metadata for the fields will be visible on hover, - - * Fields with ``repr=False`` will be hidden unless roundtrip mode is enabled. - - Args: - node: Node to render. - path: Path to this node. - subtree_renderer: How to render subtrees (see `TreescopeSubtreeRenderer`) - fields_or_attribute_names: Sequence of fields or attribute names to render. - Any field with the metadata key "treescope_always_collapse" set to True - will always render collapsed. - attr_style_fn: Optional function which makes attributes to a part that - should render them. If not provided, all parts are rendered as plain text. - - Returns: - A list of child objects. This can be passed to - `common_structures.build_foldable_tree_node_from_children` (with - ``comma_separated=False``) - """ - if attr_style_fn is None: - attr_style_fn = basic_parts.Text - - field_names = [] - fields: list[Optional[dataclasses.Field[Any]]] = [] - for field_or_name in fields_or_attribute_names: - if isinstance(field_or_name, str): - field_names.append(field_or_name) - fields.append(None) - else: - field_names.append(field_or_name.name) - fields.append(field_or_name) - - children = [] - for i, (field_name, maybe_field) in enumerate(zip(field_names, fields)): - child_path = None if path is None else f"{path}.{field_name}" - - if i < len(fields) - 1: - # Not the last child. Always show a comma, and add a space when - # collapsed. - comma_after = basic_parts.siblings( - ",", basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")) - ) - else: - # Last child: only show the comma when the node is expanded. - comma_after = basic_parts.FoldCondition(expanded=basic_parts.Text(",")) - - if maybe_field is not None: - hide_except_in_roundtrip = not maybe_field.repr - force_collapsed = maybe_field.metadata.get( - "treescope_always_collapse", False - ) - else: - hide_except_in_roundtrip = False - force_collapsed = False - - field_name_rendering = attr_style_fn(field_name) - - try: - field_value = getattr(node, field_name) - except AttributeError: - child = basic_parts.FoldCondition( - expanded=common_styles.CommentColor( - basic_parts.siblings("# ", field_name_rendering, " is missing") - ) - ) - else: - child = basic_parts.siblings_with_annotations( - field_name_rendering, - "=", - subtree_renderer(field_value, path=child_path), - ) - - child_line = basic_parts.build_full_line_with_annotations( - child, comma_after - ) - if force_collapsed: - layout_algorithms.expand_to_depth(child_line, 0) - if hide_except_in_roundtrip: - child_line = basic_parts.RoundtripCondition(roundtrip=child_line) - - children.append(child_line) - - return children - - -def render_dataclass_constructor( - node: Any, -) -> part_interface.RenderableTreePart: - """Renders the constructor for a dataclass, including the open parenthesis.""" - assert dataclasses.is_dataclass(node) and not isinstance(node, type) - if not dataclass_util.init_takes_fields(type(node)): - constructor_open = basic_parts.siblings( - basic_parts.RoundtripCondition( - roundtrip=basic_parts.Text("pz.dataclass_from_attributes(") - ), - common_structures.maybe_qualified_type_name(type(node)), - basic_parts.RoundtripCondition( - roundtrip=basic_parts.Text(", "), - not_roundtrip=basic_parts.Text("("), - ), - ) - else: - constructor_open = basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(node)), "(" - ) - return constructor_open - - -def parse_color_and_pattern( - requested_color: str | tuple[str, str], typename_for_warning: str | None -) -> tuple[str | None, str | None]: - """Parses a background color and pattern from a user-provided color request.""" - if isinstance(requested_color, str): - background_color = requested_color - background_pattern = None - elif isinstance(requested_color, tuple) and len(requested_color) == 2: - background_color = requested_color[0] - background_pattern = ( - f"linear-gradient({requested_color[1]},{requested_color[1]})" - ) - else: - if typename_for_warning: - warnings.warn( - f"{typename_for_warning} requested an invalid color:" - f" {requested_color} (not a string or a tuple)" - ) - background_color = None - background_pattern = None - return background_color, background_pattern - - -def handle_builtin_structures( - node: Any, - path: str | None, - subtree_renderer: renderer.TreescopeSubtreeRenderer, -) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations - | type(NotImplemented) -): - """Renders builtin structure types.""" - if dataclasses.is_dataclass(node) and not isinstance(node, type): - constructor_open = render_dataclass_constructor(node) - - if hasattr(node, "__treescope_color__") and callable( - node.__treescope_color__ - ): - background_color, background_pattern = parse_color_and_pattern( - node.__treescope_color__(), type(node).__name__ - ) - else: - background_color = None - background_pattern = None - - return common_structures.build_foldable_tree_node_from_children( - prefix=constructor_open, - children=build_field_children( - node, - path, - subtree_renderer, - fields_or_attribute_names=dataclasses.fields(node), - ), - suffix=")", - path=path, - background_color=background_color, - background_pattern=background_pattern, - ) - - elif isinstance(node, (tuple, ast.AST)) and hasattr(type(node), "_fields"): - # Namedtuple or AST class. - return common_structures.build_foldable_tree_node_from_children( - prefix=basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(node)), "(" - ), - children=build_field_children( - node, - path, - subtree_renderer, - fields_or_attribute_names=type(node)._fields, - ), - suffix=")", - path=path, - ) - - if isinstance(node, dict): - return _dict_to_foldable(node, path, subtree_renderer) - - if isinstance(node, (tuple, list, set, frozenset)): - # Sequence or set. (Not a namedtuple; those are handled above.) - return _sequence_or_set_to_foldable(node, path, subtree_renderer) - - elif isinstance(node, types.SimpleNamespace): - return common_structures.build_foldable_tree_node_from_children( - prefix=basic_parts.siblings( - common_structures.maybe_qualified_type_name(type(node)), "(" - ), - children=build_field_children( - node, - path, - subtree_renderer, - fields_or_attribute_names=tuple(node.__dict__.keys()), - ), - suffix=")", - path=path, - ) - - return NotImplemented diff --git a/penzai/treescope/lowering.py b/penzai/treescope/lowering.py new file mode 100644 index 0000000..f691fe9 --- /dev/null +++ b/penzai/treescope/lowering.py @@ -0,0 +1,422 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Lowering of Treescope's renderable parts to text or HTML. + +This module defines the final logic that converts the output of node handlers +and figure parts to the final Treescope representation. +""" + +from __future__ import annotations + +import contextlib +import io +import json +from typing import Any, Callable, Iterator, Sequence +import uuid + +from penzai.treescope import context +from penzai.treescope import rendering_parts +from penzai.treescope._internal import html_encapsulation +from penzai.treescope._internal import html_escaping +from penzai.treescope._internal.parts import foldable_impl +from penzai.treescope._internal.parts import part_interface + +_deferrables: context.ContextualValue[ + list[foldable_impl.DeferredWithThunk] | None +] = context.ContextualValue( + module=__name__, qualname="_deferrables", initial_value=None +) +"""An optional list of accumulated deferrables, for use by this module.""" + + +def maybe_defer_rendering( + main_thunk: Callable[ + [rendering_parts.RenderableTreePart | None], + rendering_parts.RenderableTreePart, + ], + placeholder_thunk: Callable[[], rendering_parts.RenderableTreePart], +) -> rendering_parts.RenderableTreePart: + """Possibly defers rendering of a part in interactive contexts. + + This function can be used by advanced handlers and autovisualizers to delay + the rendering of "expensive" leaves such as `jax.Array` until after the tree + structure is drawn. If run in a non-interactive context, this just calls the + main thunk. If run in an interactive context, it instead calls the placeholder + thunk, and enqueues the placeholder thunk to be called later. + + Rendering can be performed in a deferred context by running the handlers under + the `collecting_deferred_renderings` context manager, and then rendered to + a sequence of streaming HTML updates using the `display_streaming_as_root` + function. + + Note that handlers who call this are responsible for ensuring that the + logic in `main_thunk` is safe to run at a later point in time. In particular, + any rendering context managers may have been exited by the time this main + thunk is called. As a best practice, handlers should control all of the logic + in `main_thunk` and shouldn't recursively call the subtree renderer inside it; + subtrees should be rendered before calling `maybe_defer_rendering`. + + Args: + main_thunk: A callable producing the main part to render. If not deferred, + will be called with None. If deferred, will be called with the placeholder + part, which can be inspected to e.g. infer folding state. + placeholder_thunk: A callable producing a placeholder object, which will be + rendered if we are deferring rendering. + + Returns: + Either the rendered main part or a wrapped placeholder that will later be + replaced with the main part. + """ + deferral_list = _deferrables.get() + if deferral_list is None: + return main_thunk(None) + else: + placeholder = foldable_impl.DeferredPlaceholder( + child=placeholder_thunk(), + replacement_id="deferred_" + uuid.uuid4().hex, + ) + deferral_list.append( + foldable_impl.DeferredWithThunk(placeholder, main_thunk) + ) + return placeholder + + +@contextlib.contextmanager +def collecting_deferred_renderings() -> ( + Iterator[list[foldable_impl.DeferredWithThunk]] +): + # pylint: disable=g-doc-return-or-yield + """Context manager that defers and collects `maybe_defer_rendering` calls. + + This context manager can be used by renderers that wish to render deferred + objects in a streaming fashion. When used in a + `with collecting_deferred_renderings() as deferreds:` + expression, `deferreds` will be a list that is populated by calls to + `maybe_defer_rendering`. This can later be passed to + `display_streaming_as_root` to render the deferred object in a streaming + fashion. + + Returns: + A context manager in which `maybe_defer_rendering` calls will be deferred + and collected into the result list. + """ + # pylint: enable=g-doc-return-or-yield + try: + target = [] + with _deferrables.set_scoped(target): + yield target + finally: + pass + + +################################################################################ +# Top-level rendering and roundtrip mode implementation +################################################################################ + + +def render_to_text_as_root( + root_node: rendering_parts.RenderableTreePart, + roundtrip: bool = False, + strip_trailing_whitespace: bool = True, + strip_whitespace_lines: bool = True, +) -> str: + """Renders a root node to text. + + Args: + root_node: The root node to render. + roundtrip: Whether to render in roundtrip mode. + strip_trailing_whitespace: Whether to remove trailing whitespace from lines. + strip_whitespace_lines: Whether to remove lines that are entirely + whitespace. These lines can sometimes be generated by layout code being + conservative about line breaks. Should only be True if + `strip_trailing_whitespace` is True. + + Returns: + Text for the rendered node. + """ + if strip_whitespace_lines and not strip_trailing_whitespace: + raise ValueError("strip_whitespace_lines must be False if " + "strip_trailing_whitespace is False.") + + stream = io.StringIO() + root_node.render_to_text( + stream, + expanded_parent=True, + indent=0, + roundtrip_mode=roundtrip, + render_context={}, + ) + result = stream.getvalue() + + if strip_trailing_whitespace: + trimmed_lines = [] + for line in result.split("\n"): + line = line.rstrip() + if line or not strip_whitespace_lines: + trimmed_lines.append(line) + result = "\n".join(trimmed_lines) + return result + + +_TREESCOPE_PREAMBLE_SCRIPT = """(()=> { + const defns = this.getRootNode().host.defns; + let _pendingActions = []; + let _pendingActionHandle = null; + defns.runSoon = (work) => { + const doWork = () => { + const tick = performance.now(); + while (performance.now() - tick < 32) { + if (_pendingActions.length == 0) { + _pendingActionHandle = null; + return; + } else { + const thunk = _pendingActions.shift(); + thunk(); + } + } + _pendingActionHandle = ( + window.requestAnimationFrame(doWork)); + }; + _pendingActions.push(work); + if (_pendingActionHandle === null) { + _pendingActionHandle = ( + window.requestAnimationFrame(doWork)); + } + }; + defns.toggle_root_roundtrip = (rootelt, event) => { + if (event.key == "r") { + rootelt.classList.toggle("roundtrip_mode"); + } + }; +})(); +""" + + +def _render_to_html_as_root_streaming( + root_node: rendering_parts.RenderableTreePart, + roundtrip: bool, + deferreds: Sequence[foldable_impl.DeferredWithThunk], +) -> Iterator[str]: + """Helper function: renders a root node to HTML one step at a time. + + Args: + root_node: The root node to render. + roundtrip: Whether to render in roundtrip mode. + deferreds: Sequence of deferred objects to render and splice in. + + Yields: + HTML source for the rendered node, followed by logic to substitute each + deferred object. + """ + all_css_styles = set() + all_js_defns = set() + + def _render_one( + node, + at_beginning_of_line: bool, + render_context: dict[Any, Any], + stream: io.StringIO, + ): + # Extract setup rules. + setup_parts = node.html_setup_parts(foldable_impl.SETUP_CONTEXT) + current_styles = [] + current_js_defns = [] + for part in setup_parts: + if isinstance(part, part_interface.CSSStyleRule): + if part not in all_css_styles: + current_styles.append(part) + all_css_styles.add(part) + elif isinstance(part, part_interface.JavaScriptDefn): + if part not in all_js_defns: + current_js_defns.append(part) + all_js_defns.add(part) + else: + raise ValueError(f"Invalid setup object: {part}") + + if current_styles: + stream.write("") + + if current_js_defns: + stream.write( + "") + + # Render the node itself. + node.render_to_html( + stream, + at_beginning_of_line=at_beginning_of_line, + render_context=render_context, + ) + + # Set up the styles and scripts for the root object. + stream = io.StringIO() + stream.write("") + # These scripts allow us to defer execution of javascript blocks until after + # the content is loaded, avoiding locking up the browser rendering process. + stream.write("") + + # Render the root node. + classnames = "treescope_root" + if roundtrip: + classnames += " roundtrip_mode" + stream.write( + f'
' + ) + _render_one(root_node, True, {}, stream) + stream.write("
") + + yield stream.getvalue() + + # Render any deferred parts. We insert each part into a hidden element, then + # move them all out to their appropriate positions. + if deferreds: + stream = io.StringIO() + for deferred in deferreds: + stream.write( + '") + + all_ids = [deferred.placeholder.replacement_id for deferred in deferreds] + inner_script = ( + f"const targetIds = {json.dumps(all_ids)};" + + html_escaping.without_repeated_whitespace(""" + const docroot = this.getRootNode(); + const treeroot = docroot.querySelector(".treescope_root"); + const fragment = document.createDocumentFragment(); + const treerootClone = fragment.appendChild(treeroot.cloneNode(true)); + for (let i = 0; i < targetIds.length; i++) { + let target = fragment.getElementById(targetIds[i]); + let sourceDiv = docroot.querySelector("#for_" + targetIds[i]); + target.replaceWith(sourceDiv.firstElementChild); + sourceDiv.remove(); + } + treeroot.replaceWith(treerootClone); + """) + ) + stream.write( + '" + ) + yield stream.getvalue() + + +def render_to_html_as_root( + root_node: rendering_parts.RenderableTreePart, + roundtrip: bool = False, + compressed: bool = False, +) -> str: + """Renders a root node to HTML. + + This handles collecting styles and JS definitions and inserting the root + HTML element. + + Args: + root_node: The root node to render. + roundtrip: Whether to render in roundtrip mode. + compressed: Whether to compress the HTML for display. + + Returns: + HTML source for the rendered node. + """ + render_iterator = _render_to_html_as_root_streaming(root_node, roundtrip, []) + html_src = "".join(render_iterator) + return html_encapsulation.encapsulate_html(html_src, compress=compressed) + + +def display_streaming_as_root( + root_node: rendering_parts.RenderableTreePart, + deferreds: Sequence[foldable_impl.DeferredWithThunk], + roundtrip: bool = False, + compressed: bool = True, + stealable: bool = False, +) -> str | None: + """Displays a root node in an IPython notebook in a streaming fashion. + + Args: + root_node: The root node to render. + deferreds: Deferred objects to render and splice in. + roundtrip: Whether to render in roundtrip mode. + compressed: Whether to compress the HTML for display. + stealable: Whether to return an extra HTML snippet that allows the streaming + rendering to be relocated after it is shown. + + Returns: + If ``stealable`` is True, a final HTML snippet which, if inserted into a + document, will "steal" the root node rendering, moving the DOM nodes for it + into itself. In particular, using this as the HTML rendering of the root + node during pretty printing will correctly associate the rendering with the + IPython "cell output", which is visible in some IPython backends (e.g. + JupyterLab). If ``stealable`` is False, returns None. + """ + import IPython.display # pylint: disable=g-import-not-at-top + + render_iterator = _render_to_html_as_root_streaming( + root_node, roundtrip, deferreds + ) + encapsulated_iterator = html_encapsulation.encapsulate_streaming_html( + render_iterator, compress=compressed, stealable=stealable + ) + + for step in encapsulated_iterator: + if step.segment_type == html_encapsulation.SegmentType.FINAL_OUTPUT_STEALER: + return step.html_src + else: + IPython.display.display(IPython.display.HTML(step.html_src)) diff --git a/penzai/treescope/renderer.py b/penzai/treescope/renderer.py index c25c1c9..c9379e6 100644 --- a/penzai/treescope/renderer.py +++ b/penzai/treescope/renderer.py @@ -23,11 +23,9 @@ from typing import Any, Callable, Iterable import warnings -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import layout_algorithms -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope import lowering +from penzai.treescope import rendering_parts +from penzai.treescope._internal import layout_algorithms class TreescopeSubtreeRenderer(typing.Protocol): @@ -41,7 +39,7 @@ def __call__( self, node: Any, path: str | None = None, - ) -> part_interface.RenderableAndLineAnnotations: + ) -> rendering_parts.RenderableAndLineAnnotations: """Signature for a (recursive) subtree renderer. Args: @@ -70,8 +68,8 @@ def __call__( path: str | None, subtree_renderer: TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Signature for a rendering handler for a particular node type. @@ -114,8 +112,8 @@ def __call__( path: str | None, node_renderer: TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations | type(NotImplemented) ): """Signature for a custom wrapper hook. @@ -197,7 +195,7 @@ def _render_subtree( already_executed_wrapper_count: int, node: Any, path: str | None = None, - ) -> part_interface.RenderableAndLineAnnotations: + ) -> rendering_parts.RenderableAndLineAnnotations: """Renders a specific subtree using the renderer. Args: @@ -234,15 +232,15 @@ def _render_subtree( continue elif isinstance( - postprocessed_result, part_interface.RenderableAndLineAnnotations + postprocessed_result, rendering_parts.RenderableAndLineAnnotations ): return postprocessed_result elif isinstance( - postprocessed_result, part_interface.RenderableTreePart + postprocessed_result, rendering_parts.RenderableTreePart ): - return part_interface.RenderableAndLineAnnotations( + return rendering_parts.RenderableAndLineAnnotations( renderable=postprocessed_result, - annotations=basic_parts.EmptyPart(), + annotations=rendering_parts.empty_part(), ) else: raise ValueError( @@ -267,14 +265,14 @@ def _render_subtree( if node_id in rendering_stack: # Cycle! This object contains itself. - return part_interface.RenderableAndLineAnnotations( - renderable=common_styles.ErrorColor( - basic_parts.Text( + return rendering_parts.RenderableAndLineAnnotations( + renderable=rendering_parts.error_color( + rendering_parts.text( f"" ) ), - annotations=basic_parts.EmptyPart(), + annotations=rendering_parts.empty_part(), ) else: # Track cyclic references. We use `try: ... finally: ...` to ensure we @@ -296,15 +294,15 @@ def _render_subtree( # Try the next handler. continue elif isinstance( - maybe_result, part_interface.RenderableAndLineAnnotations + maybe_result, rendering_parts.RenderableAndLineAnnotations ): # Found a result! return maybe_result - elif isinstance(maybe_result, part_interface.RenderableTreePart): + elif isinstance(maybe_result, rendering_parts.RenderableTreePart): # Wrap it with empty annotations. - return part_interface.RenderableAndLineAnnotations( + return rendering_parts.RenderableAndLineAnnotations( renderable=maybe_result, - annotations=basic_parts.EmptyPart(), + annotations=rendering_parts.empty_part(), ) else: raise ValueError( @@ -330,11 +328,11 @@ def _render_subtree( ) # Fall back to a basic `repr` so that we still render something even # without a handler for it. - return part_interface.RenderableAndLineAnnotations( - renderable=common_styles.AbbreviationColor( - basic_parts.Text(repr(node)) + return rendering_parts.RenderableAndLineAnnotations( + renderable=rendering_parts.abbreviation_color( + rendering_parts.text(repr(node)) ), - annotations=basic_parts.EmptyPart(), + annotations=rendering_parts.empty_part(), ) else: raise ValueError( @@ -351,7 +349,7 @@ def to_foldable_representation( value: Any, ignore_exceptions: bool = False, root_keypath: str | None = "", - ) -> part_interface.RenderableAndLineAnnotations: + ) -> rendering_parts.RenderableAndLineAnnotations: """Renders an object to the foldable intermediate representation. Args: @@ -393,11 +391,11 @@ def to_text(self, value: Any, roundtrip_mode: bool = False) -> str: Returns: A text representation of the object. """ - foldable_ir = basic_parts.build_full_line_with_annotations( + foldable_ir = rendering_parts.build_full_line_with_annotations( self.to_foldable_representation(value) ) layout_algorithms.expand_for_balanced_layout(foldable_ir) - return foldable_impl.render_to_text_as_root(foldable_ir, roundtrip_mode) + return lowering.render_to_text_as_root(foldable_ir, roundtrip_mode) def to_html(self, value: Any, roundtrip_mode: bool = False) -> str: """Convenience method to render an object to HTML. @@ -409,8 +407,8 @@ def to_html(self, value: Any, roundtrip_mode: bool = False) -> str: Returns: HTML source code for the foldable representation of the object. """ - foldable_ir = basic_parts.build_full_line_with_annotations( + foldable_ir = rendering_parts.build_full_line_with_annotations( self.to_foldable_representation(value) ) layout_algorithms.expand_for_balanced_layout(foldable_ir) - return foldable_impl.render_to_html_as_root(foldable_ir, roundtrip_mode) + return lowering.render_to_html_as_root(foldable_ir, roundtrip_mode) diff --git a/penzai/treescope/rendering_parts.py b/penzai/treescope/rendering_parts.py new file mode 100644 index 0000000..f0a6ed5 --- /dev/null +++ b/penzai/treescope/rendering_parts.py @@ -0,0 +1,75 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parts and builders for Treescope's intermediate output format. + +The functions exposed here can be used to construct a tree of parts that can be +rendered to text or to interactive HTML. Node handlers and the __penzai_repr__ +method can use them to render custom types. + +Note that the internal definition of `RenderableTreePart` is considered an +implementation detail, and is subject to change. To build renderable tree parts, +you should instead use the functions exposed here (or higher-level wrappers in +`penzai.treescope.repr_lib`). +""" + +# pylint: disable=g-importing-member,g-multiple-import,unused-import + + +from penzai.treescope._internal.parts.basic_parts import ( + build_full_line_with_annotations, + empty_part, + floating_annotation_with_separate_focus, + fold_condition, + indented_children, + on_separate_lines, + roundtrip_condition, + siblings_with_annotations, + siblings, + summarizable_condition, + text, + vertical_space, +) +from penzai.treescope._internal.parts.common_structures import ( + build_copy_button, + build_custom_foldable_tree_node, + build_foldable_tree_node_from_children, + build_one_line_tree_node, + fake_placeholder_foldable, + maybe_qualified_type_name, +) +from penzai.treescope._internal.parts.common_styles import ( + abbreviation_color, + comment_color_when_expanded, + comment_color, + custom_text_color, + dashed_gray_outline_box, + deferred_placeholder_style, + error_color, + qualified_type_name_style, +) +from penzai.treescope._internal.parts.custom_dataclass_util import ( + build_field_children, + render_dataclass_constructor, +) +from penzai.treescope._internal.parts.embedded_iframe import ( + embedded_iframe, +) +from penzai.treescope._internal.parts.part_interface import ( + RenderableTreePart, + RenderableAndLineAnnotations, + Rendering, + ExpandState, + NodePath, +) diff --git a/penzai/treescope/repr_lib.py b/penzai/treescope/repr_lib.py index e8d3a0f..0464b50 100644 --- a/penzai/treescope/repr_lib.py +++ b/penzai/treescope/repr_lib.py @@ -29,10 +29,7 @@ from typing import Any, Mapping from penzai.treescope import renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import common_structures -from penzai.treescope.foldable_representation import common_styles -from penzai.treescope.foldable_representation import part_interface +from penzai.treescope import rendering_parts def render_object_constructor( @@ -42,7 +39,7 @@ def render_object_constructor( subtree_renderer: renderer.TreescopeSubtreeRenderer, roundtrippable: bool = False, color: str | None = None, -) -> part_interface.Rendering: +) -> rendering_parts.Rendering: """Renders an object in "constructor format", similar to a dataclass. This produces a rendering like `Foo(bar=1, baz=2)`, where Foo identifies the @@ -69,8 +66,8 @@ def __penzai_repr__(self, path, subtree_renderer): from `__penzai_repr__`, this should come from the `path` argument to `__penzai_repr__`. subtree_renderer: The renderer to use to render subtrees. When - `render_object_constructor` is called from `__penzai_repr__`, this - should come from the `subtree_renderer` argument to `__penzai_repr__`. + `render_object_constructor` is called from `__penzai_repr__`, this should + come from the `subtree_renderer` argument to `__penzai_repr__`. roundtrippable: Whether evaluating the rendering as Python code will produce an object that is equal to the original object. This implies that the keyword arguments are actually the keyword arguments to the constructor, @@ -83,19 +80,23 @@ def __penzai_repr__(self, path, subtree_renderer): A rendering of the object, suitable for returning from `__penzai_repr__`. """ if roundtrippable: - constructor = basic_parts.siblings( - common_structures.maybe_qualified_type_name(object_type), "(" + constructor = rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(object_type), "(" ) - closing_suffix = basic_parts.Text(")") + closing_suffix = rendering_parts.text(")") else: - constructor = basic_parts.siblings( - basic_parts.RoundtripCondition(roundtrip=basic_parts.Text("<")), - common_structures.maybe_qualified_type_name(object_type), + constructor = rendering_parts.siblings( + rendering_parts.roundtrip_condition( + roundtrip=rendering_parts.text("<") + ), + rendering_parts.maybe_qualified_type_name(object_type), "(", ) - closing_suffix = basic_parts.siblings( + closing_suffix = rendering_parts.siblings( ")", - basic_parts.RoundtripCondition(roundtrip=basic_parts.Text(">")), + rendering_parts.roundtrip_condition( + roundtrip=rendering_parts.text(">") + ), ) children = [] @@ -105,15 +106,18 @@ def __penzai_repr__(self, path, subtree_renderer): if i < len(attributes) - 1: # Not the last child. Always show a comma, and add a space when # collapsed. - comma_after = basic_parts.siblings( - ",", basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")) + comma_after = rendering_parts.siblings( + ",", + rendering_parts.fold_condition(collapsed=rendering_parts.text(" ")), ) else: # Last child: only show the comma when the node is expanded. - comma_after = basic_parts.FoldCondition(expanded=basic_parts.Text(",")) + comma_after = rendering_parts.fold_condition( + expanded=rendering_parts.text(",") + ) - child_line = basic_parts.build_full_line_with_annotations( - basic_parts.siblings_with_annotations( + child_line = rendering_parts.build_full_line_with_annotations( + rendering_parts.siblings_with_annotations( f"{name}=", subtree_renderer(value, path=child_path), ), @@ -121,7 +125,7 @@ def __penzai_repr__(self, path, subtree_renderer): ) children.append(child_line) - return common_structures.build_foldable_tree_node_from_children( + return rendering_parts.build_foldable_tree_node_from_children( prefix=constructor, children=children, suffix=closing_suffix, @@ -137,7 +141,7 @@ def render_dictionary_wrapper( subtree_renderer: renderer.TreescopeSubtreeRenderer, roundtrippable: bool = False, color: str | None = None, -) -> part_interface.Rendering: +) -> rendering_parts.Rendering: """Renders an object in "wrapped dictionary format". This produces a rendering like `Foo({"bar": 1, "baz": 2})`, where Foo @@ -163,8 +167,8 @@ def __penzai_repr__(self, path, subtree_renderer): from `__penzai_repr__`, this should come from the `path` argument to `__penzai_repr__`. subtree_renderer: The renderer to use to render subtrees. When - `render_object_constructor` is called from `__penzai_repr__`, this - should come from the `subtree_renderer` argument to `__penzai_repr__`. + `render_object_constructor` is called from `__penzai_repr__`, this should + come from the `subtree_renderer` argument to `__penzai_repr__`. roundtrippable: Whether evaluating the rendering as Python code will produce an object that is equal to the original object. This implies that the constructor for `object_type` takes a single argument, which is a @@ -179,19 +183,23 @@ def __penzai_repr__(self, path, subtree_renderer): A rendering of the object, suitable for returning from `__penzai_repr__`. """ if roundtrippable: - constructor = basic_parts.siblings( - common_structures.maybe_qualified_type_name(object_type), "({" + constructor = rendering_parts.siblings( + rendering_parts.maybe_qualified_type_name(object_type), "({" ) - closing_suffix = basic_parts.Text("})") + closing_suffix = rendering_parts.text("})") else: - constructor = basic_parts.siblings( - basic_parts.RoundtripCondition(roundtrip=basic_parts.Text("<")), - common_structures.maybe_qualified_type_name(object_type), + constructor = rendering_parts.siblings( + rendering_parts.roundtrip_condition( + roundtrip=rendering_parts.text("<") + ), + rendering_parts.maybe_qualified_type_name(object_type), "({", ) - closing_suffix = basic_parts.siblings( + closing_suffix = rendering_parts.siblings( "})", - basic_parts.RoundtripCondition(roundtrip=basic_parts.Text(">")), + rendering_parts.roundtrip_condition( + roundtrip=rendering_parts.text(">") + ), ) children = [] @@ -201,12 +209,15 @@ def __penzai_repr__(self, path, subtree_renderer): if i < len(wrapped_dict) - 1: # Not the last child. Always show a comma, and add a space when # collapsed. - comma_after = basic_parts.siblings( - ",", basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")) + comma_after = rendering_parts.siblings( + ",", + rendering_parts.fold_condition(collapsed=rendering_parts.text(" ")), ) else: # Last child: only show the comma when the node is expanded. - comma_after = basic_parts.FoldCondition(expanded=basic_parts.Text(",")) + comma_after = rendering_parts.fold_condition( + expanded=rendering_parts.text(",") + ) key_rendering = subtree_renderer(key) value_rendering = subtree_renderer(value, path=child_path) @@ -221,31 +232,33 @@ def __penzai_repr__(self, path, subtree_renderer): ): # Simple enough to render on one line. children.append( - basic_parts.siblings_with_annotations( + rendering_parts.siblings_with_annotations( key_rendering, ": ", value_rendering, comma_after ) ) else: # Should render on multiple lines. children.append( - basic_parts.siblings( - basic_parts.build_full_line_with_annotations( + rendering_parts.siblings( + rendering_parts.build_full_line_with_annotations( key_rendering, ":", - basic_parts.FoldCondition(collapsed=basic_parts.Text(" ")), + rendering_parts.fold_condition( + collapsed=rendering_parts.text(" ") + ), ), - basic_parts.IndentedChildren.build([ - basic_parts.siblings_with_annotations( + rendering_parts.indented_children([ + rendering_parts.siblings_with_annotations( value_rendering, comma_after ), - basic_parts.FoldCondition( - expanded=basic_parts.VerticalSpace("0.5em") + rendering_parts.fold_condition( + expanded=rendering_parts.vertical_space("0.5em") ), ]), ) ) - return common_structures.build_foldable_tree_node_from_children( + return rendering_parts.build_foldable_tree_node_from_children( prefix=constructor, children=children, suffix=closing_suffix, @@ -261,8 +274,8 @@ def render_enumlike_item( path: str | None, subtree_renderer: renderer.TreescopeSubtreeRenderer, ) -> ( - part_interface.RenderableTreePart - | part_interface.RenderableAndLineAnnotations + rendering_parts.RenderableTreePart + | rendering_parts.RenderableAndLineAnnotations ): """Renders a value of an enum-like type (e.g. like `enum.Enum`). @@ -286,13 +299,13 @@ def render_enumlike_item( A rendering of the object, suitable for returning from `__treescope_repr__`. """ del subtree_renderer - return common_structures.build_one_line_tree_node( - basic_parts.siblings_with_annotations( - common_structures.maybe_qualified_type_name(object_type), + return rendering_parts.build_one_line_tree_node( + rendering_parts.siblings_with_annotations( + rendering_parts.maybe_qualified_type_name(object_type), "." + item_name, extra_annotations=[ - common_styles.CommentColor( - basic_parts.Text(f" # value: {repr(item_value)}") + rendering_parts.comment_color( + rendering_parts.text(f" # value: {repr(item_value)}") ) ], ), diff --git a/penzai/treescope/treescope_ipython.py b/penzai/treescope/treescope_ipython.py index 2c6e61b..6078c64 100644 --- a/penzai/treescope/treescope_ipython.py +++ b/penzai/treescope/treescope_ipython.py @@ -22,9 +22,9 @@ from penzai.treescope import context from penzai.treescope import default_renderer from penzai.treescope import figures -from penzai.treescope import object_inspection -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import foldable_impl +from penzai.treescope import lowering +from penzai.treescope import rendering_parts +from penzai.treescope._internal import object_inspection # pylint: disable=g-import-not-at-top try: @@ -166,7 +166,7 @@ def _render_for_ipython(value): # in an interactive context, we can defer rendering of leaves that # support deferral and splice them in one at a time. deferreds = stack.enter_context( - foldable_impl.collecting_deferred_renderings() + lowering.collecting_deferred_renderings() ) else: deferreds = None @@ -177,13 +177,13 @@ def _render_for_ipython(value): if root_repr_method: foldable_ir = root_repr_method() else: - foldable_ir = basic_parts.build_full_line_with_annotations( + foldable_ir = rendering_parts.build_full_line_with_annotations( default_renderer.build_foldable_representation( value, ignore_exceptions=True ) ) if streaming: - output_stealer = foldable_impl.display_streaming_as_root( + output_stealer = lowering.display_streaming_as_root( foldable_ir, deferreds, roundtrip=False, @@ -198,7 +198,7 @@ def _render_for_ipython(value): return output_stealer else: assert deferreds is None - return foldable_impl.render_to_html_as_root( + return lowering.render_to_html_as_root( foldable_ir, roundtrip=False, compressed=compress_html, diff --git a/penzai/treescope/type_registries.py b/penzai/treescope/type_registries.py index 2ba2054..3b6d29f 100644 --- a/penzai/treescope/type_registries.py +++ b/penzai/treescope/type_registries.py @@ -41,7 +41,6 @@ from penzai.treescope import ndarray_adapters from penzai.treescope import renderer - T = TypeVar("T") @@ -111,19 +110,19 @@ # rendering system, but we define its setup function here as well for # consistency with the other array modules. "numpy": ( - "penzai.treescope.handlers.interop.numpy_support", + "penzai.treescope._internal.handlers.interop.numpy_support", "set_up_treescope", ), "jax": ( - "penzai.treescope.handlers.interop.jax_support", + "penzai.treescope._internal.handlers.interop.jax_support", "set_up_treescope", ), "penzai.core": ( - "penzai.treescope.handlers.interop.penzai_core_support", + "penzai.treescope._internal.handlers.interop.penzai_core_support", "set_up_treescope", ), "torch": ( - "penzai.treescope.handlers.interop.torch_support", + "penzai.treescope._internal.handlers.interop.torch_support", "set_up_treescope", ), } diff --git a/tests/treescope/ndarray_adapters_test.py b/tests/treescope/ndarray_adapters_test.py index 04440f6..0a50529 100644 --- a/tests/treescope/ndarray_adapters_test.py +++ b/tests/treescope/ndarray_adapters_test.py @@ -258,17 +258,17 @@ def test_array_rendering_without_error(self, array_type, dtype): with self.subTest("explicit_unmasked"): res = arrayviz.render_array(array) - self.assertIsInstance(res, arrayviz.ArrayvizRendering) + self.assertTrue(hasattr(res, "_repr_html_")) with self.subTest("explicit_masked"): res = arrayviz.render_array(array, valid_mask=array > 100) - self.assertIsInstance(res, arrayviz.ArrayvizRendering) + self.assertTrue(hasattr(res, "_repr_html_")) with self.subTest("explicit_masked_truncated"): res = arrayviz.render_array( array, valid_mask=array > 100, truncate=True, maximum_size=100 ) - self.assertIsInstance(res, arrayviz.ArrayvizRendering) + self.assertTrue(hasattr(res, "_repr_html_")) with self.subTest("automatic"): with autovisualize.active_autovisualizer.set_scoped( diff --git a/tests/treescope/renderer_test.py b/tests/treescope/renderer_test.py index 429dcd6..f170e4c 100644 --- a/tests/treescope/renderer_test.py +++ b/tests/treescope/renderer_test.py @@ -35,11 +35,10 @@ from tests.treescope.fixtures import treescope_examples_fixture as fixture_lib from penzai.treescope import autovisualize from penzai.treescope import default_renderer -from penzai.treescope.foldable_representation import basic_parts -from penzai.treescope.foldable_representation import foldable_impl -from penzai.treescope.foldable_representation import layout_algorithms -from penzai.treescope.foldable_representation import part_interface -from penzai.treescope.handlers import function_reflection_handlers +from penzai.treescope import handlers +from penzai.treescope import lowering +from penzai.treescope import rendering_parts +from penzai.treescope._internal import layout_algorithms import torch @@ -64,7 +63,7 @@ def test_renderer_interface(self): rendering = renderer.to_foldable_representation({"key": "value"}) self.assertIsInstance( - rendering, part_interface.RenderableAndLineAnnotations + rendering, rendering_parts.RenderableAndLineAnnotations ) def test_high_level_interface(self): @@ -94,8 +93,8 @@ def hook_that_crashes(node, path, node_renderer): rendering = renderer.to_foldable_representation([1, 2, 3, "foo", 4]) layout_algorithms.expand_to_depth(rendering.renderable, 1) self.assertEqual( - foldable_impl.render_to_text_as_root( - basic_parts.build_full_line_with_annotations(rendering) + lowering.render_to_text_as_root( + rendering_parts.build_full_line_with_annotations(rendering) ), "[\n 1,\n 2,\n 3,\n 'foo',\n 4,\n]", ) @@ -117,8 +116,8 @@ def hook_that_crashes(node, path, node_renderer): ) layout_algorithms.expand_to_depth(rendering.renderable, 1) self.assertEqual( - foldable_impl.render_to_text_as_root( - basic_parts.build_full_line_with_annotations(rendering) + lowering.render_to_text_as_root( + rendering_parts.build_full_line_with_annotations(rendering) ), "[\n 1,\n 2,\n 3,\n 'trigger handler error',\n 'trigger hook" " error',\n 4,\n]", @@ -740,7 +739,7 @@ def test_object_rendering( renderer = default_renderer.active_renderer.get() # Render it to IR. - rendering = basic_parts.build_full_line_with_annotations( + rendering = rendering_parts.build_full_line_with_annotations( renderer.to_foldable_representation(target) ) @@ -750,14 +749,14 @@ def test_object_rendering( if expected_collapsed is not None: with self.subTest("collapsed"): self.assertEqual( - foldable_impl.render_to_text_as_root(rendering), + lowering.render_to_text_as_root(rendering), expected_collapsed, ) if expected_roundtrip_collapsed is not None: with self.subTest("roundtrip_collapsed"): self.assertEqual( - foldable_impl.render_to_text_as_root(rendering, roundtrip=True), + lowering.render_to_text_as_root(rendering, roundtrip=True), expected_roundtrip_collapsed, ) @@ -766,20 +765,20 @@ def test_object_rendering( if expected_expanded is not None: with self.subTest("expanded"): self.assertEqual( - foldable_impl.render_to_text_as_root(rendering), + lowering.render_to_text_as_root(rendering), expected_expanded, ) if expected_roundtrip is not None: with self.subTest("roundtrip"): self.assertEqual( - foldable_impl.render_to_text_as_root(rendering, roundtrip=True), + lowering.render_to_text_as_root(rendering, roundtrip=True), expected_roundtrip, ) # Render to HTML; make sure it doesn't raise any errors. with self.subTest("html_no_errors"): - _ = foldable_impl.render_to_html_as_root(rendering) + _ = lowering.render_to_html_as_root(rendering) def test_closure_rendering(self): def outer_fn(x): @@ -795,13 +794,13 @@ def inner_fn(y): renderer = renderer.extended_with( handlers=[ functools.partial( - function_reflection_handlers.handle_code_objects_with_reflection, + handlers.handle_code_objects_with_reflection, show_closure_vars=True, ) ] ) # Render it to IR. - rendering = basic_parts.build_full_line_with_annotations( + rendering = rendering_parts.build_full_line_with_annotations( renderer.to_foldable_representation(closure) ) @@ -818,23 +817,23 @@ def inner_fn(y): " of ", "tests/treescope/renderer_test.py", ], - foldable_impl.render_to_text_as_root(rendering), + lowering.render_to_text_as_root(rendering), ) def test_fallback_repr_pytree_node(self): target = [fixture_lib.UnknownPytreeNode(1234, 5678)] renderer = default_renderer.active_renderer.get() - rendering = basic_parts.build_full_line_with_annotations( + rendering = rendering_parts.build_full_line_with_annotations( renderer.to_foldable_representation(target) ) layout_algorithms.expand_to_depth(rendering, 0) self.assertEqual( - foldable_impl.render_to_text_as_root(rendering), + lowering.render_to_text_as_root(rendering), "[]", ) layout_algorithms.expand_to_depth(rendering, 2) - rendered_text = foldable_impl.render_to_text_as_root(rendering) + rendered_text = lowering.render_to_text_as_root(rendering) self.assertEqual( "\n".join( line.rstrip() for line in rendered_text.splitlines(keepends=True) @@ -854,17 +853,17 @@ def test_fallback_repr_pytree_node(self): def test_fallback_repr_one_line(self): target = [fixture_lib.UnknownObjectWithOneLineRepr()] renderer = default_renderer.active_renderer.get() - rendering = basic_parts.build_full_line_with_annotations( + rendering = rendering_parts.build_full_line_with_annotations( renderer.to_foldable_representation(target) ) layout_algorithms.expand_to_depth(rendering, 0) self.assertEqual( - foldable_impl.render_to_text_as_root(rendering), + lowering.render_to_text_as_root(rendering), "[]", ) layout_algorithms.expand_to_depth(rendering, 2) self.assertEqual( - foldable_impl.render_to_text_as_root(rendering), + lowering.render_to_text_as_root(rendering), textwrap.dedent(f"""\ [ , # {object.__repr__(target[0])} @@ -874,17 +873,17 @@ def test_fallback_repr_one_line(self): def test_fallback_repr_multiline_idiomatic(self): target = [fixture_lib.UnknownObjectWithMultiLineRepr()] renderer = default_renderer.active_renderer.get() - rendering = basic_parts.build_full_line_with_annotations( + rendering = rendering_parts.build_full_line_with_annotations( renderer.to_foldable_representation(target) ) layout_algorithms.expand_to_depth(rendering, 0) self.assertEqual( - foldable_impl.render_to_text_as_root(rendering), + lowering.render_to_text_as_root(rendering), "[]", ) layout_algorithms.expand_to_depth(rendering, 2) self.assertEqual( - foldable_impl.render_to_text_as_root(rendering), + lowering.render_to_text_as_root(rendering), textwrap.dedent(f"""\ [