Skip to content

Commit

Permalink
fix: fix numpy arrays in repr (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 authored Nov 25, 2024
1 parent 70b32dd commit 5563fc0
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
- id: validate-pyproject

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.4
rev: v0.8.0
hooks:
- id: ruff
args: [--fix, --unsafe-fixes]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ select = [
"C4", # comprehensions
"B", # bugbear
"A001", # Variable shadowing a python builtin
"TCH", # flake8-type-checking
"TC", # flake8-type-checking
"TID", # flake8-tidy-imports
"RUF", # ruff-specific rules
"PERF", # performance
Expand Down
10 changes: 7 additions & 3 deletions src/useq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
Shape,
)
from useq._hardware_autofocus import AnyAutofocusPlan, AutoFocusPlan, AxesBasedAF
from useq._mda_event import MDAEvent, PropertyTuple
from useq._mda_event import Channel as EventChannel
from useq._mda_event import MDAEvent, PropertyTuple, SLMImage
from useq._mda_sequence import MDASequence
from useq._plate import WellPlate, WellPlatePlan
from useq._plate_registry import register_well_plates, registered_well_plate_keys
Expand Down Expand Up @@ -49,6 +50,7 @@
"AxesBasedAF",
"Axis",
"Channel",
"EventChannel",
"GridFromEdges",
"GridRelative",
"GridRowsColumns",
Expand All @@ -62,10 +64,10 @@
"Position", # alias for AbsolutePosition
"PropertyTuple",
"RandomPoints",
"register_well_plates",
"registered_well_plate_keys",
"RelativeMultiPointPlan",
"RelativePosition",
"SLMImage",
"SLMImages",
"Shape",
"TDurationLoops",
"TIntervalDuration",
Expand All @@ -78,6 +80,8 @@
"ZRangeAround",
"ZRelativePositions",
"ZTopBottom",
"register_well_plates",
"registered_well_plate_keys",
]


Expand Down
36 changes: 20 additions & 16 deletions src/useq/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)

import numpy as np
import pydantic
from pydantic import BaseModel, ConfigDict

if TYPE_CHECKING:
Expand All @@ -20,28 +21,17 @@

ReprArgs = Iterable[tuple[str | None, Any]]

__all__ = ["UseqModel", "FrozenModel"]
__all__ = ["FrozenModel", "UseqModel"]

_T = TypeVar("_T", bound="FrozenModel")
_Y = TypeVar("_Y", bound="UseqModel")


def _non_default_repr_args(obj: BaseModel, fields: "ReprArgs") -> "ReprArgs":
"""Set fields on a model instance."""
return [
(k, val)
for k, val in fields
if k in obj.model_fields
and val
!= (
factory()
if (factory := obj.model_fields[k].default_factory) is not None
else obj.model_fields[k].default
)
]
PYDANTIC_VERSION = tuple(int(x) for x in pydantic.__version__.split(".")[:3])
GET_DEFAULT_KWARGS: dict = {}
if PYDANTIC_VERSION >= (2, 10):
GET_DEFAULT_KWARGS = {"validated_data": {}}


# TODO: consider removing this and using model_copy directly
class _ReplaceableModel(BaseModel):
def replace(self, **kwargs: Any) -> "Self":
"""Return a new instance replacing specified kwargs with new values.
Expand All @@ -64,6 +54,20 @@ def __repr_args__(self) -> "ReprArgs":
return _non_default_repr_args(self, super().__repr_args__())


def _non_default_repr_args(obj: BaseModel, fields: "ReprArgs") -> "ReprArgs":
"""Set fields on a model instance."""
for k, val in fields:
if k and (field := obj.model_fields.get(k)) and field.repr:
default = field.get_default(call_default_factory=True, **GET_DEFAULT_KWARGS)
try:
if val == default:
continue
except ValueError:
if np.array_equal(val, default):
continue
yield k, val


class FrozenModel(_ReplaceableModel):
model_config: ClassVar["ConfigDict"] = ConfigDict(
populate_by_name=True,
Expand Down
4 changes: 2 additions & 2 deletions src/useq/_iter_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

from typing_extensions import TypedDict

from useq._channel import Channel # noqa: TCH001 # noqa: TCH001
from useq._channel import Channel # noqa: TC001 # noqa: TCH001
from useq._mda_event import Channel as EventChannel
from useq._mda_event import MDAEvent
from useq._utils import AXES, Axis, _has_axes
from useq._z import AnyZPlan # noqa: TCH001 # noqa: TCH001
from useq._z import AnyZPlan # noqa: TC001 # noqa: TCH001

if TYPE_CHECKING:
from collections.abc import Iterator
Expand Down
6 changes: 3 additions & 3 deletions src/useq/_mda_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

from useq._base_model import UseqModel
from useq._channel import Channel
from useq._grid import MultiPointPlan # noqa: TCH001
from useq._grid import MultiPointPlan # noqa: TC001
from useq._hardware_autofocus import AnyAutofocusPlan, AxesBasedAF
from useq._iter_sequence import iter_sequence
from useq._plate import WellPlatePlan
from useq._position import Position, PositionBase
from useq._time import AnyTimePlan # noqa: TCH001
from useq._time import AnyTimePlan # noqa: TC001
from useq._utils import AXES, Axis, TimeEstimate, estimate_sequence_duration
from useq._z import AnyZPlan # noqa: TCH001
from useq._z import AnyZPlan # noqa: TC001

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
6 changes: 6 additions & 0 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ def test_slm_image() -> None:
# directly passing data
event = MDAEvent(slm_image=data)
assert isinstance(event.slm_image, SLMImage)
repr(event)

# we can cast SLMIamge to a numpy array
assert isinstance(np.asarray(event.slm_image), np.ndarray)
Expand All @@ -482,3 +483,8 @@ def test_slm_image() -> None:
assert event2.slm_image is not None
np.testing.assert_array_equal(event2.slm_image, np.array(data))
assert event2.slm_image.device == "SLM"
repr(event2)

# directly provide numpy array
event3 = MDAEvent(slm_image=SLMImage(data=np.ones((10, 10))))
print(repr(event3))

0 comments on commit 5563fc0

Please sign in to comment.