Skip to content

Commit

Permalink
Configure uv, pyink, and pytype, and fix lint/type errors. (#90)
Browse files Browse the repository at this point in the history
- Penzai is now configured to use the `uv` package manager for testing and development.
- Configurations have been updated to run lint and type checks along with unit tests.
- Fixed a number of minor formatting and type errors to ensure the new checks pass.
  • Loading branch information
danieldjohnson authored Oct 31, 2024
1 parent 7c597d8 commit 281475b
Show file tree
Hide file tree
Showing 61 changed files with 3,845 additions and 144 deletions.
23 changes: 19 additions & 4 deletions .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,25 @@ jobs:
# cache: pip
# cache-dependency-path: '**/pyproject.toml'

- run: pip --version
- run: pip install -e .[dev,extras]
- run: pip freeze
- uses: astral-sh/setup-uv@v3
with:
version: "0.4.17"

- name: Install dependencies
run: |
uv sync --locked --extra extras --extra dev
# Check formatting
- name: Check pyink formatting
run: uv run pyink penzai --check

- name: Run pylint
run: uv run pylint penzai

# Run tests
- name: Run tests
run: python run_tests.py
run: uv run python run_tests.py

# Run typechecker
- name: Run pytype
run: uv run pytype --jobs auto penzai
12 changes: 9 additions & 3 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ disable=abstract-method,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
consider-using-in,
cmp-builtin,
cmp-method,
coerce-builtin,
coerce-method,
cyclic-import, # used for treescope handlers
delslice-method,
div-method,
duplicate-code,
Expand Down Expand Up @@ -113,6 +115,7 @@ disable=abstract-method,
no-self-use,
nonzero-method,
not-callable, # false positives for jax.jit
not-an-iterable, # false positives around dataclasses
oct-method,
old-division,
old-ne-operator,
Expand Down Expand Up @@ -141,14 +144,17 @@ disable=abstract-method,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-positional-arguments,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-lambda-assignment,
unnecessary-pass,
unpacking-in-except,
use-dict-literal,
useless-else-on-loop,
useless-object-inheritance,
useless-suppression,
Expand Down Expand Up @@ -434,6 +440,6 @@ valid-metaclass-classmethod-first-arg=mcs

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException
overgeneral-exceptions=builtins.StandardError,
builtins.Exception,
builtins.BaseException
3 changes: 2 additions & 1 deletion .vscode/extensions.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"recommendations": [
"ms-python.black-formatter"
"ms-python.black-formatter",
"ms-python.pylint"
]
}
23 changes: 12 additions & 11 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
"files.associations": {
".pylintrc": "ini"
},
"files.watcherExclude": {
"**/.git/**": true
},
"files.exclude": {
"**/__pycache__": true,
"**/.pytest_cache": true,
"**/*.egg-info": true
},
"python.testing.unittestEnabled": false,
"python.testing.nosetestsEnabled": false,
"python.testing.pytestEnabled": true,
Expand All @@ -13,16 +21,9 @@
"editor.rulers": [80],
"editor.tabSize": 2,
"editor.formatOnSave": true,
"editor.detectIndentation": false
"editor.detectIndentation": false,
"editor.defaultFormatter": "ms-python.black-formatter",
},
"python.formatting.provider": "black",
"python.formatting.blackPath": "pyink",
"files.watcherExclude": {
"**/.git/**": true
},
"files.exclude": {
"**/__pycache__": true,
"**/.pytest_cache": true,
"**/*.egg-info": true
}
"black-formatter.path": ["uvx", "pyink"],
"pylint.enabled": true,
}
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.

# pylint: disable=g-bad-import-order
# pylint: disable=g-import-not-at-top
# pylint: disable=import-outside-toplevel
import inspect
import logging
import os
Expand Down
2 changes: 1 addition & 1 deletion penzai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""A JAX research toolkit for building, editing, and visualizing neural networks."""
"""A JAX research toolkit for building, editing and visualizing neural networks.""" # pylint: disable=line-too-long

__version__ = '0.2.2'
5 changes: 5 additions & 0 deletions penzai/core/_treescope_handlers/named_axes_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import dataclasses
from typing import Any
import typing

import jax
import numpy as np
Expand Down Expand Up @@ -252,6 +253,10 @@ def get_sharding_info_for_array_data(

def should_autovisualize(self, array: named_axes.NamedArrayBase) -> bool:
assert isinstance(array, named_axes.NamedArray | named_axes.NamedArrayView)
array = typing.cast(
named_axes.NamedArray | named_axes.NamedArrayView,
array,
)
return (
isinstance(array.data_array, jax.Array)
and not isinstance(array.data_array, jax.core.Tracer)
Expand Down
4 changes: 2 additions & 2 deletions penzai/core/named_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,13 +1218,13 @@ def __iter__(self):
# Rendering
def __treescope_repr__(self, path: str | None, subtree_renderer: Any):
"""Treescope handler for named arrays."""
from penzai.core._treescope_handlers import named_axes_handlers # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import named_axes_handlers # pylint: disable=import-outside-toplevel

return named_axes_handlers.handle_named_arrays(self, path, subtree_renderer)

def __treescope_ndarray_adapter__(self):
"""Treescope handler for named arrays."""
from penzai.core._treescope_handlers import named_axes_handlers # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import named_axes_handlers # pylint: disable=import-outside-toplevel

return named_axes_handlers.NamedArrayAdapter()

Expand Down
12 changes: 6 additions & 6 deletions penzai/core/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def safe_filter_fn(*args):
if with_keypath or innermost:
if with_keypath:
wrapped_filter_fn = safe_filter_fn
if not with_keypath:
else:
wrapped_filter_fn = lambda _, s: safe_filter_fn(s)

def process_subtree(keypath, leaf_or_subtree) -> tuple[bool, Any]:
Expand Down Expand Up @@ -920,7 +920,7 @@ def at_instances_of(
lambda subtree: isinstance(subtree, cls), innermost=innermost
)

def at_equal_to(self, template: OtherSubtree) -> Selection[OtherSubtree]: # pytype: disable=invalid-annotation
def at_equal_to(self, template: OtherSubtree) -> Selection[OtherSubtree]: # pytype: disable=invalid-annotation # pylint: disable=line-too-long
"""Selects subtrees that are equal to a particular object.
Mostly a convenience wrapper for ::
Expand All @@ -939,7 +939,7 @@ def at_equal_to(self, template: OtherSubtree) -> Selection[OtherSubtree]: # pyt
equal to this object (with other on the left).
"""
# Lazy import to avoid circular dependencies
import penzai.core.named_axes # pylint: disable=g-import-not-at-top
import penzai.core.named_axes # pylint: disable=import-outside-toplevel

bypass_equal_types = (
jax.Array,
Expand Down Expand Up @@ -1292,7 +1292,7 @@ def show_selection(self, ignore_exceptions: bool = False):
fallback for those subtrees.
"""
# Import selection_rendering here to avoid a circular import.
from penzai.core._treescope_handlers import selection_rendering # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import selection_rendering # pylint: disable=import-outside-toplevel

selection_rendering.display_selection_streaming(
self, visible_selection=True, ignore_exceptions=ignore_exceptions
Expand All @@ -1312,7 +1312,7 @@ def show_value(self, ignore_exceptions: bool = False):
fallback for those subtrees.
"""
# Import selection_rendering here to avoid a circular import.
from penzai.core._treescope_handlers import selection_rendering # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import selection_rendering # pylint: disable=import-outside-toplevel

selection_rendering.display_selection_streaming(
self, visible_selection=False, ignore_exceptions=ignore_exceptions
Expand Down Expand Up @@ -1404,7 +1404,7 @@ def apply_with_selected_index(

def __treescope_root_repr__(self):
"""Renders this selection as the root object in a treescope rendering."""
from penzai.core._treescope_handlers import selection_rendering # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import selection_rendering # pylint: disable=import-outside-toplevel

return selection_rendering.render_selection_to_foldable_representation(
self, visible_selection=True, ignore_exceptions=True
Expand Down
4 changes: 2 additions & 2 deletions penzai/core/shapecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Wildcard(struct.Struct):
)

def __treescope_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.core._treescope_handlers import shapecheck_handlers # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import shapecheck_handlers # pylint: disable=import-outside-toplevel

return shapecheck_handlers.handle_arraystructures(
self, path, subtree_renderer
Expand Down Expand Up @@ -358,7 +358,7 @@ def into_pytree(self) -> jax.ShapeDtypeStruct | named_axes.NamedArray:
return jax.ShapeDtypeStruct(self.shape, self.dtype)

def __treescope_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.core._treescope_handlers import shapecheck_handlers # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import shapecheck_handlers # pylint: disable=import-outside-toplevel

return shapecheck_handlers.handle_arraystructures(
self, path, subtree_renderer
Expand Down
10 changes: 5 additions & 5 deletions penzai/core/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def select(self) -> selectors.Selection[Struct]:
A singleton selection containing this struct.
"""
# Dynamic import to avoid circular import issues.
from penzai.core import selectors # pylint: disable=g-import-not-at-top
from penzai.core import selectors # pylint: disable=import-outside-toplevel

return selectors.select(self)

Expand Down Expand Up @@ -679,7 +679,7 @@ def treescope_color(self) -> str | tuple[str, str]:
"""
# By default, we render structs in color if they define __call__.
if hasattr(self, "__call__"):
from treescope import formatting_util # pylint: disable=g-import-not-at-top
from treescope import formatting_util # pylint: disable=import-outside-toplevel

type_string = type(self).__module__ + "." + type(self).__qualname__
return formatting_util.color_from_string(type_string)
Expand All @@ -689,7 +689,7 @@ def treescope_color(self) -> str | tuple[str, str]:
def __repr__(self):
"""Renders this object with treescope, on a single line."""
# Defer to Treescope.
import treescope # pylint: disable=g-import-not-at-top
import treescope # pylint: disable=import-outside-toplevel

with treescope.using_expansion_strategy(max_height=1):
return treescope.render_to_text(self, ignore_exceptions=True)
Expand All @@ -698,7 +698,7 @@ def _repr_pretty_(self, p, cycle):
"""Pretty-prints this object for an IPython pretty-printer."""
del cycle
# Defer to Treescope.
import treescope # pylint: disable=g-import-not-at-top
import treescope # pylint: disable=import-outside-toplevel

rendering = treescope.render_to_text(self, ignore_exceptions=True)
for i, line in enumerate(rendering.split("\n")):
Expand All @@ -707,6 +707,6 @@ def _repr_pretty_(self, p, cycle):
p.text(line)

def __treescope_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.core._treescope_handlers import struct_handler # pylint: disable=g-import-not-at-top
from penzai.core._treescope_handlers import struct_handler # pylint: disable=import-outside-toplevel

return struct_handler.handle_structs(self, path, subtree_renderer)
10 changes: 5 additions & 5 deletions penzai/core/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,13 +539,13 @@ def treescope_color(self) -> str | tuple[str, str]:

def __repr__(self):
# Defer to Treescope.
import treescope # pylint: disable=g-import-not-at-top
import treescope # pylint: disable=import-outside-toplevel

with treescope.using_expansion_strategy(max_height=1):
return treescope.render_to_text(self, ignore_exceptions=True)

def __treescope_repr__(self, path: str | None, subtree_renderer: Any):
from treescope import repr_lib # pylint: disable=g-import-not-at-top
from treescope import repr_lib # pylint: disable=import-outside-toplevel

return repr_lib.render_object_constructor(
type(self),
Expand Down Expand Up @@ -606,7 +606,7 @@ def treescope_color(self):
return "#93cce1"


@struct.pytree_dataclass(has_implicitly_inherited_fields=True)
@struct.pytree_dataclass(has_implicitly_inherited_fields=True) # pytype: disable=wrong-keyword-args # pylint: disable=line-too-long
class ParameterValue(LabeledVariableValue[T]):
"""The value of a Parameter, as a frozen JAX pytree.
Expand Down Expand Up @@ -656,7 +656,7 @@ def treescope_color(self):

@dataclasses.dataclass(frozen=True)
class AutoStateVarLabel(auto_order_types.AutoOrderedAcrossTypes):
"""A label for a StateVariable that is unique based on its Python object ID."""
"""A label for a StateVariable that is unique based on Python object ID."""

var_id: int

Expand Down Expand Up @@ -758,7 +758,7 @@ def treescope_color(self):
return "#f57603"


@struct.pytree_dataclass(has_implicitly_inherited_fields=True)
@struct.pytree_dataclass(has_implicitly_inherited_fields=True) # pytype: disable=wrong-keyword-args # pylint: disable=line-too-long
class StateVariableValue(LabeledVariableValue[T]):
"""The value of a StateVariable, as a frozen JAX pytree.
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def handle_layers(
)
extra_annotations.append(
rendering_parts.fold_condition(
expanded=rendering_parts.floating_annotation_with_separate_focus(
expanded=rendering_parts.floating_annotation_with_separate_focus( # pylint: disable=line-too-long
rendering_parts.in_outlined_box(
rendering_parts.comment_color(
rendering_parts.siblings(
Expand Down
4 changes: 2 additions & 2 deletions penzai/deprecated/v1/core/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def output_structure(self) -> shapecheck.StructureAnnotation:
return shapecheck.ANY

def __init_subclass__(cls, **kwargs):
"""Checks that new subclasses of Layer have wrapped ``__call__`` if needed."""
"""Checks that new subclasses of Layer have wrapped ``__call__`` if needed.""" # pylint: disable=line-too-long
super().__init_subclass__(**kwargs)
if cls.__call__ is not Layer.__call__ and (
cls.input_structure is not Layer.input_structure
Expand All @@ -220,7 +220,7 @@ def __init_subclass__(cls, **kwargs):
)

def __treescope_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.deprecated.v1.core._treescope_handlers import layer_handler # pylint: disable=g-import-not-at-top
from penzai.deprecated.v1.core._treescope_handlers import layer_handler # pylint: disable=import-outside-toplevel

return layer_handler.handle_layers(self, path, subtree_renderer)

Expand Down
Loading

0 comments on commit 281475b

Please sign in to comment.