From 5801ee6fc7e5d0b376e35376679be61da5e474b6 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Sun, 7 Jul 2024 12:38:18 -0400 Subject: [PATCH] refactor: multi-point-plan (#176) * refactor: multi-point-plan * add back replace * add tests * minor refactor * swap defaults * reduce dupe --- src/useq/__init__.py | 29 ++++++++++---- src/useq/_base_model.py | 72 ++++++++++++++++++++++------------- src/useq/_grid.py | 79 +++++++++++---------------------------- src/useq/_mda_sequence.py | 6 +-- src/useq/_plate.py | 10 ++--- src/useq/_position.py | 48 ++++++++++++++++++++---- tests/test_grid.py | 29 +++++++++++++- 7 files changed, 162 insertions(+), 111 deletions(-) diff --git a/src/useq/__init__.py b/src/useq/__init__.py index afa0c625..bf3d2481 100644 --- a/src/useq/__init__.py +++ b/src/useq/__init__.py @@ -1,22 +1,26 @@ """Implementation agnostic schema for multi-dimensional microscopy experiments.""" +import warnings from typing import Any from useq._actions import AcquireImage, Action, HardwareAutofocus from useq._channel import Channel from useq._grid import ( - AnyGridPlan, GridFromEdges, GridRowsColumns, GridWidthHeight, + MultiPointPlan, + OrderMode, RandomPoints, + RelativeMultiPointPlan, + Shape, ) from useq._hardware_autofocus import AnyAutofocusPlan, AutoFocusPlan, AxesBasedAF from useq._mda_event import MDAEvent, PropertyTuple from useq._mda_sequence import MDASequence from useq._plate import WellPlate, WellPlatePlan from useq._plate_registry import register_well_plates, registered_well_plate_keys -from useq._position import Position, RelativePosition +from useq._position import AbsolutePosition, Position, RelativePosition from useq._time import ( AnyTimePlan, MultiPhaseTimePlan, @@ -34,13 +38,10 @@ ) __all__ = [ - "Position", + "AbsolutePosition", "AcquireImage", "Action", - "register_well_plates", - "registered_well_plate_keys", "AnyAutofocusPlan", - "AnyGridPlan", "AnyTimePlan", "AnyZPlan", "AutoFocusPlan", @@ -54,14 +55,21 @@ "MDAEvent", "MDASequence", "MultiPhaseTimePlan", + "MultiPointPlan", + "OrderMode", + "Position", # alias for AbsolutePosition "PropertyTuple", "RandomPoints", + "register_well_plates", + "registered_well_plate_keys", + "RelativeMultiPointPlan", "RelativePosition", + "Shape", "TDurationLoops", "TIntervalDuration", "TIntervalLoops", - "WellPlatePlan", "WellPlate", + "WellPlatePlan", "ZAboveBelow", "ZAbsolutePositions", "ZRangeAround", @@ -85,4 +93,11 @@ def __getattr__(name: str) -> Any: # ) return GridRowsColumns + if name == "AnyGridPlan": # pragma: no cover + warnings.warn( + "useq.AnyGridPlan has been renamed to useq.MultiPointPlan", + DeprecationWarning, + stacklevel=2, + ) + return MultiPointPlan raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/useq/_base_model.py b/src/useq/_base_model.py index 52e3d2b2..71b02a88 100644 --- a/src/useq/_base_model.py +++ b/src/useq/_base_model.py @@ -5,9 +5,8 @@ TYPE_CHECKING, Any, ClassVar, + Iterable, Optional, - Sequence, - Tuple, Type, TypeVar, Union, @@ -17,7 +16,9 @@ from pydantic import BaseModel, ConfigDict if TYPE_CHECKING: - ReprArgs = Sequence[Tuple[Optional[str], Any]] + from typing_extensions import Self + + ReprArgs = Iterable[tuple[str | None, Any]] __all__ = ["UseqModel", "FrozenModel"] @@ -25,41 +26,58 @@ _Y = TypeVar("_Y", bound="UseqModel") -class FrozenModel(BaseModel): - model_config: ClassVar["ConfigDict"] = ConfigDict( - populate_by_name=True, - extra="ignore", - frozen=True, - json_encoders={MappingProxyType: dict}, - ) +def _non_default_repr_args(obj: BaseModel, fields: "ReprArgs") -> "ReprArgs": + """Set fields on a model instance.""" + return [ + (k, val) + for k, val in fields + if k in obj.model_fields + and val + != ( + factory() + if (factory := obj.model_fields[k].default_factory) is not None + else obj.model_fields[k].default + ) + ] + - def replace(self: _T, **kwargs: Any) -> _T: +# TODO: consider removing this and using model_copy directly +class _ReplaceableModel(BaseModel): + def replace(self, **kwargs: Any) -> "Self": """Return a new instance replacing specified kwargs with new values. This model is immutable, so this method is useful for creating a new sequence with only a few fields changed. The uid of the new sequence will be different from the original. - The difference between this and `self.copy(update={...})` is that this method - will perform validation and casting on the new values, whereas `copy` assumes - that all objects are valid and will not perform any validation or casting. + The difference between this and `self.model_copy(update={...})` is that this + method will perform validation and casting on the new values, whereas `copy` + assumes that all objects are valid and will not perform any validation or + casting. """ - state = self.model_dump(exclude={"uid"}) - return type(self)(**{**state, **kwargs}) + return type(self).model_validate({**self.model_dump(exclude={"uid"}), **kwargs}) def __repr_args__(self) -> "ReprArgs": """Only show fields that are not None or equal to their default value.""" - return [ - (k, val) - for k, val in super().__repr_args__() - if k in self.model_fields - and val - != ( - factory() - if (factory := self.model_fields[k].default_factory) is not None - else self.model_fields[k].default - ) - ] + return _non_default_repr_args(self, super().__repr_args__()) + + +class FrozenModel(_ReplaceableModel): + model_config: ClassVar["ConfigDict"] = ConfigDict( + populate_by_name=True, + extra="ignore", + frozen=True, + json_encoders={MappingProxyType: dict}, + ) + + +class MutableModel(_ReplaceableModel): + model_config: ClassVar["ConfigDict"] = ConfigDict( + populate_by_name=True, + validate_assignment=True, + validate_default=True, + extra="ignore", + ) class UseqModel(FrozenModel): diff --git a/src/useq/_grid.py b/src/useq/_grid.py index de3ef702..51dcd660 100644 --- a/src/useq/_grid.py +++ b/src/useq/_grid.py @@ -5,29 +5,17 @@ import warnings from enum import Enum from functools import partial -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Generic, - Iterator, - Literal, # noqa: F401 - Optional, - Sequence, - Tuple, - TypeVar, - Union, -) +from typing import Any, Callable, Iterator, Optional, Sequence, Tuple, Union import numpy as np from pydantic import Field, field_validator -from useq._base_model import FrozenModel -from useq._position import AbsolutePosition, PositionBase, RelativePosition - -if TYPE_CHECKING: - from pydantic import ConfigDict +from useq._position import ( + AbsolutePosition, + PositionT, + RelativePosition, + _MultiPointPlan, +) MIN_RANDOM_POINTS = 5000 @@ -133,28 +121,8 @@ def _rect_indices( OrderMode.spiral: _spiral_indices, } -PositionT = TypeVar("PositionT", bound=PositionBase) - - -class _PointsPlan(FrozenModel, Generic[PositionT]): - # Overriding FrozenModel to make fov_width and fov_height mutable. - model_config: ClassVar[ConfigDict] = {"validate_assignment": True, "frozen": False} - - fov_width: Optional[float] = Field(None) - fov_height: Optional[float] = Field(None) - - @property - def is_relative(self) -> bool: - return False - - def __iter__(self) -> Iterator[PositionT]: # type: ignore [override] - raise NotImplementedError("This method must be implemented by subclasses.") - - def num_positions(self) -> int: - raise NotImplementedError("This method must be implemented by subclasses.") - -class _GridPlan(_PointsPlan[PositionT]): +class _GridPlan(_MultiPointPlan[PositionT]): """Base class for all grid plans. Attributes @@ -246,7 +214,7 @@ def iter_grid_positions( pos_cls = RelativePosition if self.is_relative else AbsolutePosition for idx, (r, c) in enumerate(_INDEX_GENERATORS[mode](rows, cols)): - yield pos_cls( # type: ignore [misc] + yield pos_cls( x=x0 + c * dx, y=y0 - r * dy, row=r, @@ -303,6 +271,10 @@ class GridFromEdges(_GridPlan[AbsolutePosition]): bottom: float = Field(..., frozen=True) right: float = Field(..., frozen=True) + @property + def is_relative(self) -> bool: + return False + def _nrows(self, dy: float) -> int: total_height = abs(self.top - self.bottom) + dy return math.ceil(total_height / dy) @@ -319,7 +291,7 @@ def _offset_y(self, dy: float) -> float: class GridRowsColumns(_GridPlan[RelativePosition]): - """Yield relative delta increments to build a grid acquisition. + """Grid plan based on number of rows and columns. Attributes ---------- @@ -354,10 +326,6 @@ class GridRowsColumns(_GridPlan[RelativePosition]): columns: int = Field(..., frozen=True, ge=1) relative_to: RelativeTo = Field(RelativeTo.center, frozen=True) - @property - def is_relative(self) -> bool: - return True - def _nrows(self, dy: float) -> int: return self.rows @@ -381,7 +349,7 @@ def _offset_y(self, dy: float) -> float: class GridWidthHeight(_GridPlan[RelativePosition]): - """Yield relative delta increments to build a grid acquisition. + """Grid plan based on total width and height. Attributes ---------- @@ -416,10 +384,6 @@ class GridWidthHeight(_GridPlan[RelativePosition]): height: float = Field(..., frozen=True, gt=0) relative_to: RelativeTo = Field(RelativeTo.center, frozen=True) - @property - def is_relative(self) -> bool: - return True - def _nrows(self, dy: float) -> int: return math.ceil(self.height / dy) @@ -459,7 +423,7 @@ class Shape(Enum): RECTANGLE = "rectangle" -class RandomPoints(_PointsPlan[RelativePosition]): +class RandomPoints(_MultiPointPlan[RelativePosition]): """Yield random points in a specified geometric shape. Attributes @@ -487,10 +451,6 @@ class RandomPoints(_PointsPlan[RelativePosition]): random_seed: Optional[int] = None allow_overlap: bool = True - @property - def is_relative(self) -> bool: - return True - def __iter__(self) -> Iterator[RelativePosition]: # type: ignore [override] seed = np.random.RandomState(self.random_seed) func = _POINTS_GENERATORS[self.shape] @@ -571,4 +531,9 @@ def _random_points_in_rectangle( } -AnyGridPlan = Union[GridFromEdges, GridRowsColumns, GridWidthHeight, RandomPoints] +# all of these support __iter__() -> Iterator[PositionBase] and num_positions() -> int +RelativeMultiPointPlan = Union[ + GridRowsColumns, GridWidthHeight, RandomPoints, RelativePosition +] +AbsoluteMultiPointPlan = Union[GridFromEdges] +MultiPointPlan = Union[AbsoluteMultiPointPlan, RelativeMultiPointPlan] diff --git a/src/useq/_mda_sequence.py b/src/useq/_mda_sequence.py index d0a5c899..726966f7 100644 --- a/src/useq/_mda_sequence.py +++ b/src/useq/_mda_sequence.py @@ -20,7 +20,7 @@ from useq._base_model import UseqModel from useq._channel import Channel -from useq._grid import AnyGridPlan # noqa: TCH001 +from useq._grid import MultiPointPlan # noqa: TCH001 from useq._hardware_autofocus import AnyAutofocusPlan, AxesBasedAF from useq._iter_sequence import iter_sequence from useq._plate import WellPlatePlan # noqa: TCH001 @@ -184,7 +184,7 @@ class MDASequence(UseqModel): stage_positions: Union[WellPlatePlan, Tuple[Position, ...]] = Field( default_factory=tuple ) - grid_plan: Optional[AnyGridPlan] = None + grid_plan: Optional[MultiPointPlan] = None channels: Tuple[Channel, ...] = Field(default_factory=tuple) time_plan: Optional[AnyTimePlan] = None z_plan: Optional[AnyZPlan] = None @@ -317,7 +317,7 @@ def _check_order( z_plan: Optional[AnyZPlan] = None, stage_positions: Sequence[Position] = (), channels: Sequence[Channel] = (), - grid_plan: Optional[AnyGridPlan] = None, + grid_plan: Optional[MultiPointPlan] = None, autofocus_plan: Optional[AnyAutofocusPlan] = None, ) -> None: if ( diff --git a/src/useq/_plate.py b/src/useq/_plate.py index f336fe5e..d8e79d69 100644 --- a/src/useq/_plate.py +++ b/src/useq/_plate.py @@ -21,7 +21,7 @@ from typing_extensions import Annotated from useq._base_model import FrozenModel -from useq._grid import GridRowsColumns, RandomPoints, Shape, _PointsPlan +from useq._grid import RandomPoints, RelativeMultiPointPlan, Shape from useq._plate_registry import _PLATE_REGISTRY from useq._position import Position, PositionBase, RelativePosition @@ -159,9 +159,7 @@ class WellPlatePlan(FrozenModel, Sequence[Position]): a1_center_xy: Tuple[float, float] rotation: Union[float, None] = None selected_wells: Union[IndexExpression, None] = None - well_points_plan: Union[GridRowsColumns, RandomPoints, RelativePosition] = Field( - default_factory=lambda: RelativePosition(x=0, y=0) - ) + well_points_plan: RelativeMultiPointPlan = Field(default_factory=RelativePosition) @field_validator("plate", mode="before") @classmethod @@ -414,9 +412,7 @@ def plot(self, show_axis: bool = True) -> None: ax.add_patch(sh) ################ plot image positions ################ - w = h = None - if isinstance(self.well_points_plan, _PointsPlan): - w, h = self.well_points_plan.fov_width, self.well_points_plan.fov_height + w, h = self.well_points_plan.fov_width, self.well_points_plan.fov_height for img_point in self.image_positions: x, y = float(img_point.x), float(img_point.y) # type: ignore[arg-type] # µm diff --git a/src/useq/_position.py b/src/useq/_position.py index ede64d58..2c19956f 100644 --- a/src/useq/_position.py +++ b/src/useq/_position.py @@ -1,8 +1,8 @@ -from typing import TYPE_CHECKING, ClassVar, Literal, Optional, SupportsIndex +from typing import TYPE_CHECKING, Generic, Iterator, Optional, SupportsIndex, TypeVar from pydantic import Field -from useq._base_model import FrozenModel +from useq._base_model import FrozenModel, MutableModel if TYPE_CHECKING: from typing_extensions import Self @@ -10,7 +10,7 @@ from useq import MDASequence -class PositionBase(FrozenModel): +class PositionBase(MutableModel): """Define a position in 3D space. Any of the attributes can be `None` to indicate that the position is not @@ -72,16 +72,48 @@ def __round__(self, ndigits: "SupportsIndex | None" = None) -> "Self": return type(self).model_construct(**kwargs) # type: ignore [return-value] -class AbsolutePosition(PositionBase): +class AbsolutePosition(PositionBase, FrozenModel): """An absolute position in 3D space.""" - is_relative: ClassVar[Literal[False]] = False + @property + def is_relative(self) -> bool: + return False Position = AbsolutePosition # for backwards compatibility +PositionT = TypeVar("PositionT", bound=PositionBase) -class RelativePosition(PositionBase): - """A relative position in 3D space.""" +class _MultiPointPlan(MutableModel, Generic[PositionT]): + """Any plan that yields multiple positions.""" - is_relative: ClassVar[Literal[True]] = True + fov_width: Optional[float] = None + fov_height: Optional[float] = None + + @property + def is_relative(self) -> bool: + return True + + def __iter__(self) -> Iterator[PositionT]: # type: ignore [override] + raise NotImplementedError("This method must be implemented by subclasses.") + + def num_positions(self) -> int: + raise NotImplementedError("This method must be implemented by subclasses.") + + +class RelativePosition(PositionBase, _MultiPointPlan["RelativePosition"]): + """A relative position in 3D space. + + Relative positions also support `fov_width` and `fov_height` attributes, and can + be used to define a single field of view for a "multi-point" plan. + """ + + x: float = 0 + y: float = 0 + z: float = 0 + + def __iter__(self) -> Iterator["RelativePosition"]: # type: ignore [override] + yield self + + def num_positions(self) -> int: + return 1 diff --git a/tests/test_grid.py b/tests/test_grid.py index c01b12b6..d623106b 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -1,10 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Optional, get_args import pytest -from useq import GridFromEdges, GridRowsColumns, GridWidthHeight, RandomPoints +from useq import ( + GridFromEdges, + GridRowsColumns, + GridWidthHeight, + RandomPoints, + RelativeMultiPointPlan, + RelativePosition, +) from useq._grid import OrderMode, _rect_indices, _spiral_indices if TYPE_CHECKING: @@ -147,3 +154,21 @@ def test_random_points(n_points: int, shape: str, seed: Optional[int]) -> None: else: with pytest.raises(UserWarning, match="Unable to generate"): list(rp) + + +fov = {"fov_height": 200, "fov_width": 200} + + +@pytest.mark.parametrize( + "obj", + [ + GridRowsColumns(rows=1, columns=2, **fov), + GridWidthHeight(width=10, height=10, **fov), + RandomPoints(num_points=10, **fov), + RelativePosition(**fov), + ], +) +def test_points_plans(obj: RelativeMultiPointPlan): + assert isinstance(obj, get_args(RelativeMultiPointPlan)) + assert all(isinstance(x, RelativePosition) for x in obj) + assert isinstance(obj.num_positions(), int)