diff --git a/src/useq/__init__.py b/src/useq/__init__.py index bf3d2481..5ac235da 100644 --- a/src/useq/__init__.py +++ b/src/useq/__init__.py @@ -10,7 +10,6 @@ GridRowsColumns, GridWidthHeight, MultiPointPlan, - OrderMode, RandomPoints, RelativeMultiPointPlan, Shape, @@ -20,6 +19,7 @@ from useq._mda_sequence import MDASequence from useq._plate import WellPlate, WellPlatePlan from useq._plate_registry import register_well_plates, registered_well_plate_keys +from useq._point_visiting import OrderMode, TraversalOrder from useq._position import AbsolutePosition, Position, RelativePosition from useq._time import ( AnyTimePlan, @@ -68,6 +68,7 @@ "TDurationLoops", "TIntervalDuration", "TIntervalLoops", + "TraversalOrder", "WellPlate", "WellPlatePlan", "ZAboveBelow", diff --git a/src/useq/_grid.py b/src/useq/_grid.py index 51dcd660..b7e166a5 100644 --- a/src/useq/_grid.py +++ b/src/useq/_grid.py @@ -4,12 +4,23 @@ import math import warnings from enum import Enum -from functools import partial -from typing import Any, Callable, Iterator, Optional, Sequence, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Iterator, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np -from pydantic import Field, field_validator +from annotated_types import Ge, Gt # noqa: TCH002 +from pydantic import Field, field_validator, model_validator +from useq._point_visiting import OrderMode, TraversalOrder from useq._position import ( AbsolutePosition, PositionT, @@ -17,7 +28,14 @@ _MultiPointPlan, ) -MIN_RANDOM_POINTS = 5000 +if TYPE_CHECKING: + from typing_extensions import Annotated, Self, TypeAlias + + PointGenerator: TypeAlias = Callable[ + [np.random.RandomState, int, float, float], Iterable[tuple[float, float]] + ] + +MIN_RANDOM_POINTS = 10000 class RelativeTo(Enum): @@ -35,93 +53,7 @@ class RelativeTo(Enum): top_left: str = "top_left" -class OrderMode(Enum): - """Order in which grid positions will be iterated. - - Attributes - ---------- - row_wise : Literal['row_wise'] - Iterate row by row. - column_wise : Literal['column_wise'] - Iterate column by column. - row_wise_snake : Literal['row_wise_snake'] - Iterate row by row, but alternate the direction of the columns. - column_wise_snake : Literal['column_wise_snake'] - Iterate column by column, but alternate the direction of the rows. - spiral : Literal['spiral'] - Iterate in a spiral pattern, starting from the center. - """ - - row_wise = "row_wise" - column_wise = "column_wise" - row_wise_snake = "row_wise_snake" - column_wise_snake = "column_wise_snake" - spiral = "spiral" - - -def _spiral_indices( - rows: int, columns: int, center_origin: bool = False -) -> Iterator[Tuple[int, int]]: - """Return a spiral iterator over a 2D grid. - - Parameters - ---------- - rows : int - Number of rows. - columns : int - Number of columns. - center_origin : bool - If center_origin is True, all indices are centered around (0, 0), and some will - be negative. Otherwise, the indices are centered around (rows//2, columns//2) - - Yields - ------ - (x, y) : tuple[int, int] - Indices of the next element in the spiral. - """ - # direction: first down and then clockwise (assuming positive Y is down) - - x = y = 0 - if center_origin: # see docstring - xshift = yshift = 0 - else: - xshift = (columns - 1) // 2 - yshift = (rows - 1) // 2 - dx = 0 - dy = -1 - for _ in range(max(columns, rows) ** 2): - if (-columns / 2 < x <= columns / 2) and (-rows / 2 < y <= rows / 2): - yield y + yshift, x + xshift - if x == y or (x < 0 and x == -y) or (x > 0 and x == 1 - y): - dx, dy = -dy, dx - x, y = x + dx, y + dy - - -# function that iterates indices (row, col) in a grid where (0, 0) is the top left -def _rect_indices( - rows: int, columns: int, snake: bool = False, row_wise: bool = True -) -> Iterator[Tuple[int, int]]: - """Return a row or column-wise iterator over a 2D grid.""" - c, r = np.meshgrid(np.arange(columns), np.arange(rows)) - if snake: - if row_wise: - c[1::2, :] = c[1::2, :][:, ::-1] - else: - r[:, 1::2] = r[:, 1::2][::-1, :] - return zip(r.ravel(), c.ravel()) if row_wise else zip(r.T.ravel(), c.T.ravel()) - - # used in iter_indices below, to determine the order in which indices are yielded -IndexGenerator = Callable[[int, int], Iterator[Tuple[int, int]]] -_INDEX_GENERATORS: dict[OrderMode, IndexGenerator] = { - OrderMode.row_wise: partial(_rect_indices, snake=False, row_wise=True), - OrderMode.column_wise: partial(_rect_indices, snake=False, row_wise=False), - OrderMode.row_wise_snake: partial(_rect_indices, snake=True, row_wise=True), - OrderMode.column_wise_snake: partial(_rect_indices, snake=True, row_wise=False), - OrderMode.spiral: _spiral_indices, -} - - class _GridPlan(_MultiPointPlan[PositionT]): """Base class for all grid plans. @@ -199,12 +131,12 @@ def iter_grid_positions( fov_width: float | None = None, fov_height: float | None = None, *, - mode: OrderMode | None = None, + order: OrderMode | None = None, ) -> Iterator[PositionT]: """Iterate over all grid positions, given a field of view size.""" _fov_width = fov_width or self.fov_width or 1.0 _fov_height = fov_height or self.fov_height or 1.0 - mode = self.mode if mode is None else OrderMode(mode) + order = self.mode if order is None else OrderMode(order) dx, dy = self._step_size(_fov_width, _fov_height) rows = self._nrows(dy) @@ -213,7 +145,7 @@ def iter_grid_positions( y0 = self._offset_y(dy) pos_cls = RelativePosition if self.is_relative else AbsolutePosition - for idx, (r, c) in enumerate(_INDEX_GENERATORS[mode](rows, cols)): + for idx, (r, c) in enumerate(order.generate_indices(rows, cols)): yield pos_cls( x=x0 + c * dx, y=y0 - r * dy, @@ -431,9 +363,9 @@ class RandomPoints(_MultiPointPlan[RelativePosition]): num_points : int Number of points to generate. max_width : float - Maximum width of the bounding box. + Maximum width of the bounding box in microns. max_height : float - Maximum height of the bounding box. + Maximum height of the bounding box in microns. shape : Shape Shape of the bounding box. Current options are "ellipse" and "rectangle". random_seed : Optional[int] @@ -442,39 +374,71 @@ class RandomPoints(_MultiPointPlan[RelativePosition]): allow_overlap : bool By defaut, True. If False and `fov_width` and `fov_height` are specified, points will not overlap and will be at least `fov_width` and `fov_height apart. + order : TraversalOrder + Order in which the points will be visited. If None, order is simply the order + in which the points are generated (random). Use 'nearest_neighbor' or + 'two_opt' to order the points in a more structured way. + start_at : int + Index of the point to start at. This is only used if `order` is + 'nearest_neighbor' or 'two_opt'. """ - num_points: int - max_width: float = np.inf - max_height: float = np.inf + num_points: Annotated[int, Gt(1)] + max_width: Annotated[float, Gt(0)] = 1 + max_height: Annotated[float, Gt(0)] = 1 shape: Shape = Shape.ELLIPSE random_seed: Optional[int] = None allow_overlap: bool = True + order: TraversalOrder = TraversalOrder.TWO_OPT + start_at: Annotated[int, Ge(0)] = 0 + + @model_validator(mode="after") + def _validate_startat(self) -> Self: + if self.start_at > (self.num_points - 1): + warnings.warn( + "start_at is greater than the number of points. " + "Setting start_at to last point.", + stacklevel=2, + ) + self.start_at = self.num_points - 1 + return self def __iter__(self) -> Iterator[RelativePosition]: # type: ignore [override] seed = np.random.RandomState(self.random_seed) func = _POINTS_GENERATORS[self.shape] - n_points = max(self.num_points, MIN_RANDOM_POINTS) - points: list[Tuple[float, float]] = [] - for idx, (x, y) in enumerate( - func(seed, n_points, self.max_width, self.max_height) - ): - if ( - self.allow_overlap - or self.fov_width is None - or self.fov_height is None - or _is_a_valid_point(points, x, y, self.fov_width, self.fov_height) - ): - yield RelativePosition(x=x, y=y, name=f"{str(idx).zfill(4)}") - points.append((x, y)) - if len(points) >= self.num_points: - break + + points: Iterable[Tuple[float, float]] + # in the easy case, just generate the requested number of points + if self.allow_overlap or self.fov_width is None or self.fov_height is None: + points = func(seed, self.num_points, self.max_width, self.max_height) + else: - warnings.warn( - f"Unable to generate {self.num_points} non-overlapping points. " - f"Only {len(points)} points were found.", - stacklevel=2, - ) + # if we need to avoid overlap, generate points, check if they are valid, and + # repeat until we have enough + points = [] + per_iter = 100 + tries = 0 + while tries < MIN_RANDOM_POINTS and len(points) < self.num_points: + candidates = func(seed, per_iter, self.max_width, self.max_height) + tries += per_iter + for p in candidates: + if _is_a_valid_point(points, *p, self.fov_width, self.fov_height): + points.append(p) + if len(points) >= self.num_points: + break + + if len(points) < self.num_points: + warnings.warn( + f"Unable to generate {self.num_points} non-overlapping points. " + f"Only {len(points)} points were found.", + stacklevel=2, + ) + + if self.order is not None: + points = self.order(points, start_at=self.start_at) + + for idx, (x, y) in enumerate(points): + yield RelativePosition(x=x, y=y, name=f"{str(idx).zfill(4)}") def num_positions(self) -> int: return self.num_points @@ -504,8 +468,9 @@ def _random_points_in_ellipse( The point is within +/- radius_x and +/- radius_y at a random angle. """ - xy = np.sqrt(seed.uniform(0, 1, size=(n_points, 2))) - angle = seed.uniform(0, 2 * np.pi, size=n_points) + points = seed.uniform(0, 1, size=(n_points, 3)) + xy = points[:, :2] + angle = points[:, 2] * 2 * np.pi xy[:, 0] *= (max_width / 2) * np.cos(angle) xy[:, 1] *= (max_height / 2) * np.sin(angle) return xy @@ -524,7 +489,6 @@ def _random_points_in_rectangle( return xy -PointGenerator = Callable[[np.random.RandomState, int, float, float], np.ndarray] _POINTS_GENERATORS: dict[Shape, PointGenerator] = { Shape.ELLIPSE: _random_points_in_ellipse, Shape.RECTANGLE: _random_points_in_rectangle, diff --git a/src/useq/_mda_sequence.py b/src/useq/_mda_sequence.py index dc9e2eeb..fdc6a79d 100644 --- a/src/useq/_mda_sequence.py +++ b/src/useq/_mda_sequence.py @@ -23,13 +23,15 @@ from useq._grid import MultiPointPlan # noqa: TCH001 from useq._hardware_autofocus import AnyAutofocusPlan, AxesBasedAF from useq._iter_sequence import iter_sequence -from useq._plate import WellPlatePlan # noqa: TCH001 +from useq._plate import WellPlatePlan from useq._position import Position, PositionBase from useq._time import AnyTimePlan # noqa: TCH001 from useq._utils import AXES, Axis, TimeEstimate, estimate_sequence_duration from useq._z import AnyZPlan # noqa: TCH001 if TYPE_CHECKING: + from typing_extensions import Self + from useq._mda_event import MDAEvent @@ -282,19 +284,18 @@ def _validate_axis_order(cls, v: Any) -> tuple[str, ...]: return order @model_validator(mode="after") - @classmethod - def _validate_mda(cls, values: Any) -> Any: - if values.axis_order: - cls._check_order( - values.axis_order, - z_plan=values.z_plan, - stage_positions=values.stage_positions, - channels=values.channels, - grid_plan=values.grid_plan, - autofocus_plan=values.autofocus_plan, + def _validate_mda(self) -> Self: + if self.axis_order: + self._check_order( + self.axis_order, + z_plan=self.z_plan, + stage_positions=self.stage_positions, + channels=self.channels, + grid_plan=self.grid_plan, + autofocus_plan=self.autofocus_plan, ) - if values.stage_positions: - for p in values.stage_positions: + if self.stage_positions and not isinstance(self.stage_positions, WellPlatePlan): + for p in self.stage_positions: if hasattr(p, "sequence") and getattr( p.sequence, "keep_shutter_open_across", None ): # pragma: no cover @@ -302,7 +303,7 @@ def _validate_mda(cls, values: Any) -> Any: "keep_shutter_open_across cannot currently be set on a " "Position sequence" ) - return values + return self def __eq__(self, other: Any) -> bool: """Return `True` if two `MDASequences` are equal (uid is excluded).""" @@ -315,7 +316,7 @@ def __eq__(self, other: Any) -> bool: @staticmethod def _check_order( - order: str, + order: tuple[str, ...], z_plan: Optional[AnyZPlan] = None, stage_positions: Sequence[Position] = (), channels: Sequence[Channel] = (), diff --git a/src/useq/_plate.py b/src/useq/_plate.py index d8e79d69..568022cb 100644 --- a/src/useq/_plate.py +++ b/src/useq/_plate.py @@ -5,7 +5,6 @@ from functools import cached_property from typing import ( Any, - Callable, Iterable, List, Sequence, @@ -375,71 +374,9 @@ def affine_transform(self) -> np.ndarray: def plot(self, show_axis: bool = True) -> None: """Plot the selected positions on the plate.""" - import matplotlib.pyplot as plt - from matplotlib import patches + from useq._plot import plot_plate - _, ax = plt.subplots() - - # hide axes - if not show_axis: - ax.axis("off") - - # ################ draw outline of all wells ################ - height, width = self.plate.well_size # mm - height, width = height * 1000, width * 1000 # µm - - kwargs = {} - offset_x, offset_y = 0.0, 0.0 - if self.plate.circular_wells: - patch_type: Callable = patches.Ellipse - else: - patch_type = patches.Rectangle - offset_x, offset_y = -width / 2, -height / 2 - kwargs["rotation_point"] = "center" - - for well in self.all_well_positions: - sh = patch_type( - (well.x + offset_x, well.y + offset_y), # type: ignore[operator] - width=width, - height=height, - angle=self.rotation or 0, - facecolor="none", - edgecolor="gray", - linewidth=0.5, - linestyle="--", - **kwargs, - ) - ax.add_patch(sh) - - ################ plot image positions ################ - w, h = self.well_points_plan.fov_width, self.well_points_plan.fov_height - - for img_point in self.image_positions: - x, y = float(img_point.x), float(img_point.y) # type: ignore[arg-type] # µm - if w and h: - ax.add_patch( - patches.Rectangle( - (x - w / 2, y - h / 2), - width=w, - height=h, - facecolor="magenta", - edgecolor="gray", - linewidth=0.5, - alpha=0.5, - ) - ) - else: - plt.plot(x, y, "mo", markersize=3, alpha=0.5) - - # ################ draw names on used wells ################ - offset_x, offset_y = -width / 2, -height / 2 - for well in self.selected_well_positions: - x, y = float(well.x), float(well.y) # type: ignore[arg-type] - # draw name next to spot - ax.text(x + offset_x, y - offset_y, well.name, fontsize=7) - - ax.axis("equal") - plt.show() + plot_plate(self, show_axis=show_axis) def _index_to_row_name(index: int) -> str: diff --git a/src/useq/_plot.py b/src/useq/_plot.py new file mode 100644 index 00000000..52e798d8 --- /dev/null +++ b/src/useq/_plot.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Iterable + +try: + import matplotlib.pyplot as plt + from matplotlib import patches +except ImportError as e: + raise ImportError( + "Matplotlib is required for plotting functions. Please install matplotlib." + ) from e + +if TYPE_CHECKING: + from matplotlib.axes import Axes + + from useq._plate import WellPlatePlan + from useq._position import PositionBase + + +def plot_points(points: Iterable[PositionBase], ax: Axes | None = None) -> None: + """Plot a list of positions. + + Can be used with any iterable of PositionBase objects. + """ + if ax is None: + _, ax = plt.subplots() + + x, y = zip(*[(point.x, point.y) for point in points]) + ax.scatter(x, y) + ax.scatter(x[0], y[0], color="red") # mark the first point + ax.plot(x, y, alpha=0.5, color="gray") # connect the points + ax.axis("equal") + plt.show() + + +def plot_plate( + plate_plan: WellPlatePlan, show_axis: bool = True, ax: Axes | None = None +) -> None: + if ax is None: + _, ax = plt.subplots() + + # hide axes + if not show_axis: + ax.axis("off") + + # ################ draw outline of all wells ################ + height, width = plate_plan.plate.well_size # mm + height, width = height * 1000, width * 1000 # µm + + kwargs = {} + offset_x, offset_y = 0.0, 0.0 + if plate_plan.plate.circular_wells: + patch_type: Callable = patches.Ellipse + else: + patch_type = patches.Rectangle + offset_x, offset_y = -width / 2, -height / 2 + kwargs["rotation_point"] = "center" + + for well in plate_plan.all_well_positions: + sh = patch_type( + (well.x + offset_x, well.y + offset_y), # type: ignore[operator] + width=width, + height=height, + angle=plate_plan.rotation or 0, + facecolor="none", + edgecolor="gray", + linewidth=0.5, + linestyle="--", + **kwargs, + ) + ax.add_patch(sh) + + ################ plot image positions ################ + w, h = plate_plan.well_points_plan.fov_width, plate_plan.well_points_plan.fov_height + + for img_point in plate_plan.image_positions: + x, y = float(img_point.x), float(img_point.y) # type: ignore[arg-type] # µm + if w and h: + ax.add_patch( + patches.Rectangle( + (x - w / 2, y - h / 2), + width=w, + height=h, + facecolor="magenta", + edgecolor="gray", + linewidth=0.5, + alpha=0.5, + ) + ) + else: + plt.plot(x, y, "mo", markersize=3, alpha=0.5) + + # ################ draw names on used wells ################ + offset_x, offset_y = -width / 2, -height / 2 + for well in plate_plan.selected_well_positions: + x, y = float(well.x), float(well.y) # type: ignore[arg-type] + # draw name next to spot + ax.text(x + offset_x, y - offset_y, well.name, fontsize=7) + + ax.axis("equal") + plt.show() diff --git a/src/useq/_point_visiting.py b/src/useq/_point_visiting.py new file mode 100644 index 00000000..6ae2b864 --- /dev/null +++ b/src/useq/_point_visiting.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +from enum import Enum +from functools import partial +from typing import Callable, Iterable, Iterator, Tuple + +import numpy as np + +# ----------------------------- Random Points ----------------------------------- + + +class OrderMode(Enum): + """Order in which grid positions will be iterated. + + Attributes + ---------- + row_wise : Literal['row_wise'] + Iterate row by row. + column_wise : Literal['column_wise'] + Iterate column by column. + row_wise_snake : Literal['row_wise_snake'] + Iterate row by row, but alternate the direction of the columns. + column_wise_snake : Literal['column_wise_snake'] + Iterate column by column, but alternate the direction of the rows. + spiral : Literal['spiral'] + Iterate in a spiral pattern, starting from the center. + """ + + row_wise = "row_wise" + column_wise = "column_wise" + row_wise_snake = "row_wise_snake" + column_wise_snake = "column_wise_snake" + spiral = "spiral" + + def generate_indices(self, rows: int, columns: int) -> Iterator[Tuple[int, int]]: + """Generate indices for the given grid size.""" + return _INDEX_GENERATORS[self](rows, columns) + + +def _spiral_indices( + rows: int, columns: int, center_origin: bool = False +) -> Iterator[Tuple[int, int]]: + """Return a spiral iterator over a 2D grid. + + Parameters + ---------- + rows : int + Number of rows. + columns : int + Number of columns. + center_origin : bool + If center_origin is True, all indices are centered around (0, 0), and some will + be negative. Otherwise, the indices are centered around (rows//2, columns//2) + + Yields + ------ + (x, y) : tuple[int, int] + Indices of the next element in the spiral. + """ + # direction: first down and then clockwise (assuming positive Y is down) + + x = y = 0 + if center_origin: # see docstring + xshift = yshift = 0 + else: + xshift = (columns - 1) // 2 + yshift = (rows - 1) // 2 + dx = 0 + dy = -1 + for _ in range(max(columns, rows) ** 2): + if (-columns / 2 < x <= columns / 2) and (-rows / 2 < y <= rows / 2): + yield y + yshift, x + xshift + if x == y or (x < 0 and x == -y) or (x > 0 and x == 1 - y): + dx, dy = -dy, dx + x, y = x + dx, y + dy + + +# function that iterates indices (row, col) in a grid where (0, 0) is the top left +def _rect_indices( + rows: int, columns: int, snake: bool = False, row_wise: bool = True +) -> Iterator[Tuple[int, int]]: + """Return a row or column-wise iterator over a 2D grid.""" + c, r = np.meshgrid(np.arange(columns), np.arange(rows)) + if snake: + if row_wise: + c[1::2, :] = c[1::2, :][:, ::-1] + else: + r[:, 1::2] = r[:, 1::2][::-1, :] + return zip(r.ravel(), c.ravel()) if row_wise else zip(r.T.ravel(), c.T.ravel()) + + +IndexGenerator = Callable[[int, int], Iterator[Tuple[int, int]]] +_INDEX_GENERATORS: dict[OrderMode, IndexGenerator] = { + OrderMode.row_wise: partial(_rect_indices, snake=False, row_wise=True), + OrderMode.column_wise: partial(_rect_indices, snake=False, row_wise=False), + OrderMode.row_wise_snake: partial(_rect_indices, snake=True, row_wise=True), + OrderMode.column_wise_snake: partial(_rect_indices, snake=True, row_wise=False), + OrderMode.spiral: _spiral_indices, +} + +# ----------------------------- Random Points ----------------------------------- + + +class TraversalOrder(Enum): + NEAREST_NEIGHBOR = "nearest_neighbor" + TWO_OPT = "two_opt" + RANDOM = "random" + + def order_points(self, points: np.ndarray, start_at: int = 0) -> np.ndarray: + """Return the order of points based on the traversal order.""" + start_at = min(start_at, len(points) - 1) + if self == TraversalOrder.NEAREST_NEIGHBOR: + return _nearest_neighbor_order(points, start_at) + if self == TraversalOrder.TWO_OPT: + return _two_opt_order(points, start_at) + if self == TraversalOrder.RANDOM: + return np.random.permutation(len(points)) + raise ValueError(f"Unknown traversal order: {self}") # pragma: no cover + + def __call__( + self, points: Iterable[tuple[float, float]], start_at: int = 0 + ) -> np.ndarray: + """Sort the points based on the traversal order.""" + points = np.asarray(points) + order = self.order_points(points, start_at=start_at) + return points[order] # type: ignore [no-any-return] + + +def _nearest_neighbor_order(points: np.ndarray, start_at: int = 0) -> np.ndarray: + """Return the order of points based on the nearest neighbor algorithm. + + Parameters. + ---------- + points : np.ndarray + Array of 2D (Y, X) points in the format (n, 2). + start_at : int, optional + Index of the point to start at. By default, the first point is used. + + Examples + -------- + >>> points = np.array([[0, 0], [1, 0], [1, 1], [0, 1]]) + >>> order = _nearest_neighbor_order(points) + >>> sorted = points[order] + """ + n = len(points) + visited = np.zeros(n, dtype=bool) + order = np.zeros(n, dtype=int) + order[0] = start_at + visited[start_at] = True + + LARGE_NUMBER = 1e9 + # NOTE: at ~500+ points, scipy.spatial.cKDTree would begin to be faster + # but it's a new dependency and may not be a common use case + for i in range(1, n): + # calculate the distance from the last visited point to all other points + dist = np.linalg.norm(points - points[order[i - 1]], axis=1) + # find the nearest point that has not been visited + next_point = np.argmin(dist + visited * LARGE_NUMBER) + # store it and mark it as visited + order[i] = next_point + visited[next_point] = True + + return order + + +def _two_opt_order( + points: np.ndarray, start_at: int = 0, improvement_threshold: float = 0.05 +) -> np.ndarray: + """Return the order of points based on the 2-opt algorithm. + + https://en.wikipedia.org/wiki/2-opt + + Parameters + ---------- + points : np.ndarray + Array of 2D (Y, X) points in the format (n, 2). + start_at : int, optional + Index of the point to start at. By default, the first point is used. + improvement_threshold : float, optional + The minimum improvement factor required to continue the optimization. + By default, 0.05. + + Examples + -------- + >>> points = np.random.rand(100, 2) + >>> order = _two_opt_order(points) + >>> sorted = points[order] + """ + n = points.shape[0] + route = np.arange(n) + + if start_at != 0: + route = np.roll(route, -start_at) + + dist_matrix = _distance_matrix(points) + + # this will track the best distance found so far + best_distance = _total_distance(points[route]) + improvement_factor = 1.0 + while improvement_factor > improvement_threshold: + distance_to_beat = best_distance + for i in range(1, n - 2): + for k in range(i + 1, n): + # Calculate the distances involved in the potential swap + ri = route[i] + ri1 = route[i - 1] + rk = route[k] + y = route[(k + 1) % n] + + dist_before = dist_matrix[ri1, ri] + dist_matrix[rk, y] + dist_after = dist_matrix[ri1, rk] + dist_matrix[ri, y] + + # If the new distance is better, perform the swap + if dist_after < dist_before: + # Reverse the order of all elements from element i to element k. + route[i : k + 1] = route[i : k + 1][::-1] + best_distance = best_distance - dist_before + dist_after + + improvement_factor = 1 - best_distance / distance_to_beat + + return route + + +def _total_distance(points: np.ndarray) -> float: + # Calculate the total Euclidean distance of the route traversing the given points + # in the order provided + diffs = points - np.roll(points, shift=1, axis=0) + return np.sum(np.linalg.norm(diffs, axis=1)) # type: ignore [no-any-return] + + +def _distance_matrix(points: np.ndarray) -> np.ndarray: + # Calculate the distance matrix (euclidean distance between each pair of points) + return np.sqrt(np.sum((points[:, None] - points) ** 2, axis=2)) # type: ignore [no-any-return] diff --git a/src/useq/_position.py b/src/useq/_position.py index 2c19956f..ec1c5526 100644 --- a/src/useq/_position.py +++ b/src/useq/_position.py @@ -100,6 +100,12 @@ def __iter__(self) -> Iterator[PositionT]: # type: ignore [override] def num_positions(self) -> int: raise NotImplementedError("This method must be implemented by subclasses.") + def plot(self) -> None: + """Plot the positions in the plan.""" + from useq._plot import plot_points + + plot_points(self) + class RelativePosition(PositionBase, _MultiPointPlan["RelativePosition"]): """A relative position in 3D space. diff --git a/tests/test_grid.py b/tests/test_points_plans.py similarity index 87% rename from tests/test_grid.py rename to tests/test_points_plans.py index d623106b..8548851c 100644 --- a/tests/test_grid.py +++ b/tests/test_points_plans.py @@ -11,8 +11,9 @@ RandomPoints, RelativeMultiPointPlan, RelativePosition, + TraversalOrder, ) -from useq._grid import OrderMode, _rect_indices, _spiral_indices +from useq._point_visiting import OrderMode, _rect_indices, _spiral_indices if TYPE_CHECKING: from useq._position import PositionBase @@ -126,7 +127,7 @@ def test_num_position_error() -> None: expected_rectangle = [(0.2, 1.1), (0.4, 0.2), (-0.3, 0.7)] -expected_ellipse = [(-0.0, -2.1), (0.7, 1.7), (-1.0, 1.3)] +expected_ellipse = [(-0.9, -1.1), (0.9, -0.5), (-0.8, -0.4)] @pytest.mark.parametrize("n_points", [3, 100]) @@ -156,6 +157,22 @@ def test_random_points(n_points: int, shape: str, seed: Optional[int]) -> None: list(rp) +@pytest.mark.parametrize("order", list(TraversalOrder)) +def test_traversal(order: TraversalOrder): + pp = RandomPoints( + num_points=30, + max_height=3000, + max_width=3000, + order=order, + random_seed=1, + start_at=10, + fov_height=300, + fov_width=300, + allow_overlap=False, + ) + list(pp) + + fov = {"fov_height": 200, "fov_width": 200} @@ -168,7 +185,12 @@ def test_random_points(n_points: int, shape: str, seed: Optional[int]) -> None: RelativePosition(**fov), ], ) -def test_points_plans(obj: RelativeMultiPointPlan): +def test_points_plans(obj: RelativeMultiPointPlan, monkeypatch: pytest.MonkeyPatch): + mpl = pytest.importorskip("matplotlib.pyplot") + monkeypatch.setattr(mpl, "show", lambda: None) + assert isinstance(obj, get_args(RelativeMultiPointPlan)) assert all(isinstance(x, RelativePosition) for x in obj) assert isinstance(obj.num_positions(), int) + + obj.plot() diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 023e4a7a..3ffb746d 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -116,9 +116,9 @@ random_seed=0, ), [ - RelativePosition(x=-0.0, y=-2.1, name="0000"), - RelativePosition(x=0.7, y=1.7, name="0001"), - RelativePosition(x=-1.0, y=1.3, name="0002"), + RelativePosition(x=-0.9, y=-1.1, name="0000"), + RelativePosition(x=0.9, y=-0.5, name="0001"), + RelativePosition(x=-0.8, y=-0.4, name="0002"), ], ), ]