Skip to content

Commit

Permalink
feat: formalizing point-visiting strategies (#177)
Browse files Browse the repository at this point in the history
* point visiting wip

* style(pre-commit.ci): auto fixes [...]

* wip

* working well

* add tests

* fix type

* fixes and tests

* future

* fix typing

* add doc

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
tlambert03 and pre-commit-ci[bot] authored Jul 8, 2024
1 parent 19cc774 commit fb1773e
Show file tree
Hide file tree
Showing 9 changed files with 473 additions and 208 deletions.
3 changes: 2 additions & 1 deletion src/useq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
GridRowsColumns,
GridWidthHeight,
MultiPointPlan,
OrderMode,
RandomPoints,
RelativeMultiPointPlan,
Shape,
Expand All @@ -20,6 +19,7 @@
from useq._mda_sequence import MDASequence
from useq._plate import WellPlate, WellPlatePlan
from useq._plate_registry import register_well_plates, registered_well_plate_keys
from useq._point_visiting import OrderMode, TraversalOrder
from useq._position import AbsolutePosition, Position, RelativePosition
from useq._time import (
AnyTimePlan,
Expand Down Expand Up @@ -68,6 +68,7 @@
"TDurationLoops",
"TIntervalDuration",
"TIntervalLoops",
"TraversalOrder",
"WellPlate",
"WellPlatePlan",
"ZAboveBelow",
Expand Down
206 changes: 85 additions & 121 deletions src/useq/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,38 @@
import math
import warnings
from enum import Enum
from functools import partial
from typing import Any, Callable, Iterator, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
Optional,
Sequence,
Tuple,
Union,
)

import numpy as np
from pydantic import Field, field_validator
from annotated_types import Ge, Gt # noqa: TCH002
from pydantic import Field, field_validator, model_validator

from useq._point_visiting import OrderMode, TraversalOrder
from useq._position import (
AbsolutePosition,
PositionT,
RelativePosition,
_MultiPointPlan,
)

MIN_RANDOM_POINTS = 5000
if TYPE_CHECKING:
from typing_extensions import Annotated, Self, TypeAlias

PointGenerator: TypeAlias = Callable[
[np.random.RandomState, int, float, float], Iterable[tuple[float, float]]
]

MIN_RANDOM_POINTS = 10000


class RelativeTo(Enum):
Expand All @@ -35,93 +53,7 @@ class RelativeTo(Enum):
top_left: str = "top_left"


class OrderMode(Enum):
"""Order in which grid positions will be iterated.
Attributes
----------
row_wise : Literal['row_wise']
Iterate row by row.
column_wise : Literal['column_wise']
Iterate column by column.
row_wise_snake : Literal['row_wise_snake']
Iterate row by row, but alternate the direction of the columns.
column_wise_snake : Literal['column_wise_snake']
Iterate column by column, but alternate the direction of the rows.
spiral : Literal['spiral']
Iterate in a spiral pattern, starting from the center.
"""

row_wise = "row_wise"
column_wise = "column_wise"
row_wise_snake = "row_wise_snake"
column_wise_snake = "column_wise_snake"
spiral = "spiral"


def _spiral_indices(
rows: int, columns: int, center_origin: bool = False
) -> Iterator[Tuple[int, int]]:
"""Return a spiral iterator over a 2D grid.
Parameters
----------
rows : int
Number of rows.
columns : int
Number of columns.
center_origin : bool
If center_origin is True, all indices are centered around (0, 0), and some will
be negative. Otherwise, the indices are centered around (rows//2, columns//2)
Yields
------
(x, y) : tuple[int, int]
Indices of the next element in the spiral.
"""
# direction: first down and then clockwise (assuming positive Y is down)

x = y = 0
if center_origin: # see docstring
xshift = yshift = 0
else:
xshift = (columns - 1) // 2
yshift = (rows - 1) // 2
dx = 0
dy = -1
for _ in range(max(columns, rows) ** 2):
if (-columns / 2 < x <= columns / 2) and (-rows / 2 < y <= rows / 2):
yield y + yshift, x + xshift
if x == y or (x < 0 and x == -y) or (x > 0 and x == 1 - y):
dx, dy = -dy, dx
x, y = x + dx, y + dy


# function that iterates indices (row, col) in a grid where (0, 0) is the top left
def _rect_indices(
rows: int, columns: int, snake: bool = False, row_wise: bool = True
) -> Iterator[Tuple[int, int]]:
"""Return a row or column-wise iterator over a 2D grid."""
c, r = np.meshgrid(np.arange(columns), np.arange(rows))
if snake:
if row_wise:
c[1::2, :] = c[1::2, :][:, ::-1]
else:
r[:, 1::2] = r[:, 1::2][::-1, :]
return zip(r.ravel(), c.ravel()) if row_wise else zip(r.T.ravel(), c.T.ravel())


# used in iter_indices below, to determine the order in which indices are yielded
IndexGenerator = Callable[[int, int], Iterator[Tuple[int, int]]]
_INDEX_GENERATORS: dict[OrderMode, IndexGenerator] = {
OrderMode.row_wise: partial(_rect_indices, snake=False, row_wise=True),
OrderMode.column_wise: partial(_rect_indices, snake=False, row_wise=False),
OrderMode.row_wise_snake: partial(_rect_indices, snake=True, row_wise=True),
OrderMode.column_wise_snake: partial(_rect_indices, snake=True, row_wise=False),
OrderMode.spiral: _spiral_indices,
}


class _GridPlan(_MultiPointPlan[PositionT]):
"""Base class for all grid plans.
Expand Down Expand Up @@ -199,12 +131,12 @@ def iter_grid_positions(
fov_width: float | None = None,
fov_height: float | None = None,
*,
mode: OrderMode | None = None,
order: OrderMode | None = None,
) -> Iterator[PositionT]:
"""Iterate over all grid positions, given a field of view size."""
_fov_width = fov_width or self.fov_width or 1.0
_fov_height = fov_height or self.fov_height or 1.0
mode = self.mode if mode is None else OrderMode(mode)
order = self.mode if order is None else OrderMode(order)

dx, dy = self._step_size(_fov_width, _fov_height)
rows = self._nrows(dy)
Expand All @@ -213,7 +145,7 @@ def iter_grid_positions(
y0 = self._offset_y(dy)

pos_cls = RelativePosition if self.is_relative else AbsolutePosition
for idx, (r, c) in enumerate(_INDEX_GENERATORS[mode](rows, cols)):
for idx, (r, c) in enumerate(order.generate_indices(rows, cols)):
yield pos_cls(
x=x0 + c * dx,
y=y0 - r * dy,
Expand Down Expand Up @@ -431,9 +363,9 @@ class RandomPoints(_MultiPointPlan[RelativePosition]):
num_points : int
Number of points to generate.
max_width : float
Maximum width of the bounding box.
Maximum width of the bounding box in microns.
max_height : float
Maximum height of the bounding box.
Maximum height of the bounding box in microns.
shape : Shape
Shape of the bounding box. Current options are "ellipse" and "rectangle".
random_seed : Optional[int]
Expand All @@ -442,39 +374,71 @@ class RandomPoints(_MultiPointPlan[RelativePosition]):
allow_overlap : bool
By defaut, True. If False and `fov_width` and `fov_height` are specified, points
will not overlap and will be at least `fov_width` and `fov_height apart.
order : TraversalOrder
Order in which the points will be visited. If None, order is simply the order
in which the points are generated (random). Use 'nearest_neighbor' or
'two_opt' to order the points in a more structured way.
start_at : int
Index of the point to start at. This is only used if `order` is
'nearest_neighbor' or 'two_opt'.
"""

num_points: int
max_width: float = np.inf
max_height: float = np.inf
num_points: Annotated[int, Gt(1)]
max_width: Annotated[float, Gt(0)] = 1
max_height: Annotated[float, Gt(0)] = 1
shape: Shape = Shape.ELLIPSE
random_seed: Optional[int] = None
allow_overlap: bool = True
order: TraversalOrder = TraversalOrder.TWO_OPT
start_at: Annotated[int, Ge(0)] = 0

@model_validator(mode="after")
def _validate_startat(self) -> Self:
if self.start_at > (self.num_points - 1):
warnings.warn(
"start_at is greater than the number of points. "
"Setting start_at to last point.",
stacklevel=2,
)
self.start_at = self.num_points - 1
return self

def __iter__(self) -> Iterator[RelativePosition]: # type: ignore [override]
seed = np.random.RandomState(self.random_seed)
func = _POINTS_GENERATORS[self.shape]
n_points = max(self.num_points, MIN_RANDOM_POINTS)
points: list[Tuple[float, float]] = []
for idx, (x, y) in enumerate(
func(seed, n_points, self.max_width, self.max_height)
):
if (
self.allow_overlap
or self.fov_width is None
or self.fov_height is None
or _is_a_valid_point(points, x, y, self.fov_width, self.fov_height)
):
yield RelativePosition(x=x, y=y, name=f"{str(idx).zfill(4)}")
points.append((x, y))
if len(points) >= self.num_points:
break

points: Iterable[Tuple[float, float]]
# in the easy case, just generate the requested number of points
if self.allow_overlap or self.fov_width is None or self.fov_height is None:
points = func(seed, self.num_points, self.max_width, self.max_height)

else:
warnings.warn(
f"Unable to generate {self.num_points} non-overlapping points. "
f"Only {len(points)} points were found.",
stacklevel=2,
)
# if we need to avoid overlap, generate points, check if they are valid, and
# repeat until we have enough
points = []
per_iter = 100
tries = 0
while tries < MIN_RANDOM_POINTS and len(points) < self.num_points:
candidates = func(seed, per_iter, self.max_width, self.max_height)
tries += per_iter
for p in candidates:
if _is_a_valid_point(points, *p, self.fov_width, self.fov_height):
points.append(p)
if len(points) >= self.num_points:
break

if len(points) < self.num_points:
warnings.warn(
f"Unable to generate {self.num_points} non-overlapping points. "
f"Only {len(points)} points were found.",
stacklevel=2,
)

if self.order is not None:
points = self.order(points, start_at=self.start_at)

for idx, (x, y) in enumerate(points):
yield RelativePosition(x=x, y=y, name=f"{str(idx).zfill(4)}")

def num_positions(self) -> int:
return self.num_points
Expand Down Expand Up @@ -504,8 +468,9 @@ def _random_points_in_ellipse(
The point is within +/- radius_x and +/- radius_y at a random angle.
"""
xy = np.sqrt(seed.uniform(0, 1, size=(n_points, 2)))
angle = seed.uniform(0, 2 * np.pi, size=n_points)
points = seed.uniform(0, 1, size=(n_points, 3))
xy = points[:, :2]
angle = points[:, 2] * 2 * np.pi
xy[:, 0] *= (max_width / 2) * np.cos(angle)
xy[:, 1] *= (max_height / 2) * np.sin(angle)
return xy
Expand All @@ -524,7 +489,6 @@ def _random_points_in_rectangle(
return xy


PointGenerator = Callable[[np.random.RandomState, int, float, float], np.ndarray]
_POINTS_GENERATORS: dict[Shape, PointGenerator] = {
Shape.ELLIPSE: _random_points_in_ellipse,
Shape.RECTANGLE: _random_points_in_rectangle,
Expand Down
31 changes: 16 additions & 15 deletions src/useq/_mda_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
from useq._grid import MultiPointPlan # noqa: TCH001
from useq._hardware_autofocus import AnyAutofocusPlan, AxesBasedAF
from useq._iter_sequence import iter_sequence
from useq._plate import WellPlatePlan # noqa: TCH001
from useq._plate import WellPlatePlan
from useq._position import Position, PositionBase
from useq._time import AnyTimePlan # noqa: TCH001
from useq._utils import AXES, Axis, TimeEstimate, estimate_sequence_duration
from useq._z import AnyZPlan # noqa: TCH001

if TYPE_CHECKING:
from typing_extensions import Self

from useq._mda_event import MDAEvent


Expand Down Expand Up @@ -282,27 +284,26 @@ def _validate_axis_order(cls, v: Any) -> tuple[str, ...]:
return order

@model_validator(mode="after")
@classmethod
def _validate_mda(cls, values: Any) -> Any:
if values.axis_order:
cls._check_order(
values.axis_order,
z_plan=values.z_plan,
stage_positions=values.stage_positions,
channels=values.channels,
grid_plan=values.grid_plan,
autofocus_plan=values.autofocus_plan,
def _validate_mda(self) -> Self:
if self.axis_order:
self._check_order(
self.axis_order,
z_plan=self.z_plan,
stage_positions=self.stage_positions,
channels=self.channels,
grid_plan=self.grid_plan,
autofocus_plan=self.autofocus_plan,
)
if values.stage_positions:
for p in values.stage_positions:
if self.stage_positions and not isinstance(self.stage_positions, WellPlatePlan):
for p in self.stage_positions:
if hasattr(p, "sequence") and getattr(
p.sequence, "keep_shutter_open_across", None
): # pragma: no cover
raise ValueError(
"keep_shutter_open_across cannot currently be set on a "
"Position sequence"
)
return values
return self

def __eq__(self, other: Any) -> bool:
"""Return `True` if two `MDASequences` are equal (uid is excluded)."""
Expand All @@ -315,7 +316,7 @@ def __eq__(self, other: Any) -> bool:

@staticmethod
def _check_order(
order: str,
order: tuple[str, ...],
z_plan: Optional[AnyZPlan] = None,
stage_positions: Sequence[Position] = (),
channels: Sequence[Channel] = (),
Expand Down
Loading

0 comments on commit fb1773e

Please sign in to comment.