Refactor array rendering, add type registries, add PyTorch renderer. #65
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Refactor array rendering, add type registries, add PyTorch renderer.
This change significantly reworks how penzai.treescope renders custom types,
by addding a "type registry" of type-specific pretty printers, similar to e.g. the
IPython pretty printer. (This is implemented via a new handler step, and can be overridden
if needed.) It also introduces a mechanism for dynamic type-dependent setup logic, so that
new handlers can be added to the registry when a library is imported, without having to
eagerly import that library.
Additionally, it adds a new NDArrayAdapter system, and modifies the array visualization
functions to use these adapters. The adapters make it possible to add support for new
ndarray-like types, including np.ndarray, jax.Array, pz.nx.NamedArray, and torch.Tensor,
using a uniform interface. Types in the adapter registry can be automatically visualized
by the array autovisualizer and manually rendered via
pz.ts.render_array
.Furthermore, it adds initial support for PyTorch tensors (via the NDArrayAdapter registry)
and PyTorch modules, making it possible to visualize them using treescope whenever torch
is imported (but doing nothing if torch is not installed). PyTorch tensors support automatic
visualization similar to JAX Arrays. PyTorch modules are dynamically inspected to build a
visualization. (Note that due to the object semantics of PyTorch modules, and the convention
of mutating the module state in init or afterward, PyTorch module renderings are in
general not round-trippable.)
Other minor changes: