Skip to content

Commit

Permalink
refactor: multi-point-plan (#176)
Browse files Browse the repository at this point in the history
* refactor: multi-point-plan

* add back replace

* add tests

* minor refactor

* swap defaults

* reduce dupe
  • Loading branch information
tlambert03 authored Jul 7, 2024
1 parent 7cc1c96 commit 5801ee6
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 111 deletions.
29 changes: 22 additions & 7 deletions src/useq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
"""Implementation agnostic schema for multi-dimensional microscopy experiments."""

import warnings
from typing import Any

from useq._actions import AcquireImage, Action, HardwareAutofocus
from useq._channel import Channel
from useq._grid import (
AnyGridPlan,
GridFromEdges,
GridRowsColumns,
GridWidthHeight,
MultiPointPlan,
OrderMode,
RandomPoints,
RelativeMultiPointPlan,
Shape,
)
from useq._hardware_autofocus import AnyAutofocusPlan, AutoFocusPlan, AxesBasedAF
from useq._mda_event import MDAEvent, PropertyTuple
from useq._mda_sequence import MDASequence
from useq._plate import WellPlate, WellPlatePlan
from useq._plate_registry import register_well_plates, registered_well_plate_keys
from useq._position import Position, RelativePosition
from useq._position import AbsolutePosition, Position, RelativePosition
from useq._time import (
AnyTimePlan,
MultiPhaseTimePlan,
Expand All @@ -34,13 +38,10 @@
)

__all__ = [
"Position",
"AbsolutePosition",
"AcquireImage",
"Action",
"register_well_plates",
"registered_well_plate_keys",
"AnyAutofocusPlan",
"AnyGridPlan",
"AnyTimePlan",
"AnyZPlan",
"AutoFocusPlan",
Expand All @@ -54,14 +55,21 @@
"MDAEvent",
"MDASequence",
"MultiPhaseTimePlan",
"MultiPointPlan",
"OrderMode",
"Position", # alias for AbsolutePosition
"PropertyTuple",
"RandomPoints",
"register_well_plates",
"registered_well_plate_keys",
"RelativeMultiPointPlan",
"RelativePosition",
"Shape",
"TDurationLoops",
"TIntervalDuration",
"TIntervalLoops",
"WellPlatePlan",
"WellPlate",
"WellPlatePlan",
"ZAboveBelow",
"ZAbsolutePositions",
"ZRangeAround",
Expand All @@ -85,4 +93,11 @@ def __getattr__(name: str) -> Any:
# )

return GridRowsColumns
if name == "AnyGridPlan": # pragma: no cover
warnings.warn(
"useq.AnyGridPlan has been renamed to useq.MultiPointPlan",
DeprecationWarning,
stacklevel=2,
)
return MultiPointPlan
raise AttributeError(f"module {__name__} has no attribute {name}")
72 changes: 45 additions & 27 deletions src/useq/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
TYPE_CHECKING,
Any,
ClassVar,
Iterable,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Expand All @@ -17,49 +16,68 @@
from pydantic import BaseModel, ConfigDict

if TYPE_CHECKING:
ReprArgs = Sequence[Tuple[Optional[str], Any]]
from typing_extensions import Self

ReprArgs = Iterable[tuple[str | None, Any]]

__all__ = ["UseqModel", "FrozenModel"]

_T = TypeVar("_T", bound="FrozenModel")
_Y = TypeVar("_Y", bound="UseqModel")


class FrozenModel(BaseModel):
model_config: ClassVar["ConfigDict"] = ConfigDict(
populate_by_name=True,
extra="ignore",
frozen=True,
json_encoders={MappingProxyType: dict},
)
def _non_default_repr_args(obj: BaseModel, fields: "ReprArgs") -> "ReprArgs":
"""Set fields on a model instance."""
return [
(k, val)
for k, val in fields
if k in obj.model_fields
and val
!= (
factory()
if (factory := obj.model_fields[k].default_factory) is not None
else obj.model_fields[k].default
)
]


def replace(self: _T, **kwargs: Any) -> _T:
# TODO: consider removing this and using model_copy directly
class _ReplaceableModel(BaseModel):
def replace(self, **kwargs: Any) -> "Self":
"""Return a new instance replacing specified kwargs with new values.
This model is immutable, so this method is useful for creating a new
sequence with only a few fields changed. The uid of the new sequence will
be different from the original.
The difference between this and `self.copy(update={...})` is that this method
will perform validation and casting on the new values, whereas `copy` assumes
that all objects are valid and will not perform any validation or casting.
The difference between this and `self.model_copy(update={...})` is that this
method will perform validation and casting on the new values, whereas `copy`
assumes that all objects are valid and will not perform any validation or
casting.
"""
state = self.model_dump(exclude={"uid"})
return type(self)(**{**state, **kwargs})
return type(self).model_validate({**self.model_dump(exclude={"uid"}), **kwargs})

def __repr_args__(self) -> "ReprArgs":
"""Only show fields that are not None or equal to their default value."""
return [
(k, val)
for k, val in super().__repr_args__()
if k in self.model_fields
and val
!= (
factory()
if (factory := self.model_fields[k].default_factory) is not None
else self.model_fields[k].default
)
]
return _non_default_repr_args(self, super().__repr_args__())


class FrozenModel(_ReplaceableModel):
model_config: ClassVar["ConfigDict"] = ConfigDict(
populate_by_name=True,
extra="ignore",
frozen=True,
json_encoders={MappingProxyType: dict},
)


class MutableModel(_ReplaceableModel):
model_config: ClassVar["ConfigDict"] = ConfigDict(
populate_by_name=True,
validate_assignment=True,
validate_default=True,
extra="ignore",
)


class UseqModel(FrozenModel):
Expand Down
79 changes: 22 additions & 57 deletions src/useq/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,17 @@
import warnings
from enum import Enum
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Generic,
Iterator,
Literal, # noqa: F401
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from typing import Any, Callable, Iterator, Optional, Sequence, Tuple, Union

import numpy as np
from pydantic import Field, field_validator

from useq._base_model import FrozenModel
from useq._position import AbsolutePosition, PositionBase, RelativePosition

if TYPE_CHECKING:
from pydantic import ConfigDict
from useq._position import (
AbsolutePosition,
PositionT,
RelativePosition,
_MultiPointPlan,
)

MIN_RANDOM_POINTS = 5000

Expand Down Expand Up @@ -133,28 +121,8 @@ def _rect_indices(
OrderMode.spiral: _spiral_indices,
}

PositionT = TypeVar("PositionT", bound=PositionBase)


class _PointsPlan(FrozenModel, Generic[PositionT]):
# Overriding FrozenModel to make fov_width and fov_height mutable.
model_config: ClassVar[ConfigDict] = {"validate_assignment": True, "frozen": False}

fov_width: Optional[float] = Field(None)
fov_height: Optional[float] = Field(None)

@property
def is_relative(self) -> bool:
return False

def __iter__(self) -> Iterator[PositionT]: # type: ignore [override]
raise NotImplementedError("This method must be implemented by subclasses.")

def num_positions(self) -> int:
raise NotImplementedError("This method must be implemented by subclasses.")


class _GridPlan(_PointsPlan[PositionT]):
class _GridPlan(_MultiPointPlan[PositionT]):
"""Base class for all grid plans.
Attributes
Expand Down Expand Up @@ -246,7 +214,7 @@ def iter_grid_positions(

pos_cls = RelativePosition if self.is_relative else AbsolutePosition
for idx, (r, c) in enumerate(_INDEX_GENERATORS[mode](rows, cols)):
yield pos_cls( # type: ignore [misc]
yield pos_cls(
x=x0 + c * dx,
y=y0 - r * dy,
row=r,
Expand Down Expand Up @@ -303,6 +271,10 @@ class GridFromEdges(_GridPlan[AbsolutePosition]):
bottom: float = Field(..., frozen=True)
right: float = Field(..., frozen=True)

@property
def is_relative(self) -> bool:
return False

def _nrows(self, dy: float) -> int:
total_height = abs(self.top - self.bottom) + dy
return math.ceil(total_height / dy)
Expand All @@ -319,7 +291,7 @@ def _offset_y(self, dy: float) -> float:


class GridRowsColumns(_GridPlan[RelativePosition]):
"""Yield relative delta increments to build a grid acquisition.
"""Grid plan based on number of rows and columns.
Attributes
----------
Expand Down Expand Up @@ -354,10 +326,6 @@ class GridRowsColumns(_GridPlan[RelativePosition]):
columns: int = Field(..., frozen=True, ge=1)
relative_to: RelativeTo = Field(RelativeTo.center, frozen=True)

@property
def is_relative(self) -> bool:
return True

def _nrows(self, dy: float) -> int:
return self.rows

Expand All @@ -381,7 +349,7 @@ def _offset_y(self, dy: float) -> float:


class GridWidthHeight(_GridPlan[RelativePosition]):
"""Yield relative delta increments to build a grid acquisition.
"""Grid plan based on total width and height.
Attributes
----------
Expand Down Expand Up @@ -416,10 +384,6 @@ class GridWidthHeight(_GridPlan[RelativePosition]):
height: float = Field(..., frozen=True, gt=0)
relative_to: RelativeTo = Field(RelativeTo.center, frozen=True)

@property
def is_relative(self) -> bool:
return True

def _nrows(self, dy: float) -> int:
return math.ceil(self.height / dy)

Expand Down Expand Up @@ -459,7 +423,7 @@ class Shape(Enum):
RECTANGLE = "rectangle"


class RandomPoints(_PointsPlan[RelativePosition]):
class RandomPoints(_MultiPointPlan[RelativePosition]):
"""Yield random points in a specified geometric shape.
Attributes
Expand Down Expand Up @@ -487,10 +451,6 @@ class RandomPoints(_PointsPlan[RelativePosition]):
random_seed: Optional[int] = None
allow_overlap: bool = True

@property
def is_relative(self) -> bool:
return True

def __iter__(self) -> Iterator[RelativePosition]: # type: ignore [override]
seed = np.random.RandomState(self.random_seed)
func = _POINTS_GENERATORS[self.shape]
Expand Down Expand Up @@ -571,4 +531,9 @@ def _random_points_in_rectangle(
}


AnyGridPlan = Union[GridFromEdges, GridRowsColumns, GridWidthHeight, RandomPoints]
# all of these support __iter__() -> Iterator[PositionBase] and num_positions() -> int
RelativeMultiPointPlan = Union[
GridRowsColumns, GridWidthHeight, RandomPoints, RelativePosition
]
AbsoluteMultiPointPlan = Union[GridFromEdges]
MultiPointPlan = Union[AbsoluteMultiPointPlan, RelativeMultiPointPlan]
6 changes: 3 additions & 3 deletions src/useq/_mda_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from useq._base_model import UseqModel
from useq._channel import Channel
from useq._grid import AnyGridPlan # noqa: TCH001
from useq._grid import MultiPointPlan # noqa: TCH001
from useq._hardware_autofocus import AnyAutofocusPlan, AxesBasedAF
from useq._iter_sequence import iter_sequence
from useq._plate import WellPlatePlan # noqa: TCH001
Expand Down Expand Up @@ -184,7 +184,7 @@ class MDASequence(UseqModel):
stage_positions: Union[WellPlatePlan, Tuple[Position, ...]] = Field(
default_factory=tuple
)
grid_plan: Optional[AnyGridPlan] = None
grid_plan: Optional[MultiPointPlan] = None
channels: Tuple[Channel, ...] = Field(default_factory=tuple)
time_plan: Optional[AnyTimePlan] = None
z_plan: Optional[AnyZPlan] = None
Expand Down Expand Up @@ -317,7 +317,7 @@ def _check_order(
z_plan: Optional[AnyZPlan] = None,
stage_positions: Sequence[Position] = (),
channels: Sequence[Channel] = (),
grid_plan: Optional[AnyGridPlan] = None,
grid_plan: Optional[MultiPointPlan] = None,
autofocus_plan: Optional[AnyAutofocusPlan] = None,
) -> None:
if (
Expand Down
10 changes: 3 additions & 7 deletions src/useq/_plate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing_extensions import Annotated

from useq._base_model import FrozenModel
from useq._grid import GridRowsColumns, RandomPoints, Shape, _PointsPlan
from useq._grid import RandomPoints, RelativeMultiPointPlan, Shape
from useq._plate_registry import _PLATE_REGISTRY
from useq._position import Position, PositionBase, RelativePosition

Expand Down Expand Up @@ -159,9 +159,7 @@ class WellPlatePlan(FrozenModel, Sequence[Position]):
a1_center_xy: Tuple[float, float]
rotation: Union[float, None] = None
selected_wells: Union[IndexExpression, None] = None
well_points_plan: Union[GridRowsColumns, RandomPoints, RelativePosition] = Field(
default_factory=lambda: RelativePosition(x=0, y=0)
)
well_points_plan: RelativeMultiPointPlan = Field(default_factory=RelativePosition)

@field_validator("plate", mode="before")
@classmethod
Expand Down Expand Up @@ -414,9 +412,7 @@ def plot(self, show_axis: bool = True) -> None:
ax.add_patch(sh)

################ plot image positions ################
w = h = None
if isinstance(self.well_points_plan, _PointsPlan):
w, h = self.well_points_plan.fov_width, self.well_points_plan.fov_height
w, h = self.well_points_plan.fov_width, self.well_points_plan.fov_height

for img_point in self.image_positions:
x, y = float(img_point.x), float(img_point.y) # type: ignore[arg-type] # µm
Expand Down
Loading

0 comments on commit 5801ee6

Please sign in to comment.