Skip to content

Commit

Permalink
allow start_at to be a point
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Jul 8, 2024
1 parent b623130 commit 8d18eae
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
4 changes: 3 additions & 1 deletion src/useq/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def replace(self, **kwargs: Any) -> "Self":
assumes that all objects are valid and will not perform any validation or
casting.
"""
return type(self).model_validate({**self.model_dump(exclude={"uid"}), **kwargs})
# only get values for top level fields
d = {k: getattr(self, k) for k in self.model_fields if k != "uid"}
return type(self).model_validate({**d, **kwargs})

def __repr_args__(self) -> "ReprArgs":
"""Only show fields that are not None or equal to their default value."""
Expand Down
30 changes: 20 additions & 10 deletions src/useq/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,12 @@ class RandomPoints(_MultiPointPlan[RelativePosition]):
Order in which the points will be visited. If None, order is simply the order
in which the points are generated (random). Use 'nearest_neighbor' or
'two_opt' to order the points in a more structured way.
start_at : int
Index of the point to start at. This is only used if `order` is
'nearest_neighbor' or 'two_opt'.
start_at : int | RelativePosition
Position or index of the point to start at. This is only used if `order` is
'nearest_neighbor' or 'two_opt'. If a position is provided, it will *always*
be included in the list of points. If an index is provided, it must be less than
the number of points, and corresponds to the index of the (randomly generated)
points; this likely only makes sense when `random_seed` is provided.
"""

num_points: Annotated[int, Gt(0)]
Expand All @@ -390,11 +393,11 @@ class RandomPoints(_MultiPointPlan[RelativePosition]):
random_seed: Optional[int] = None
allow_overlap: bool = True
order: Optional[TraversalOrder] = TraversalOrder.TWO_OPT
start_at: Annotated[int, Ge(0)] = 0
start_at: Union[RelativePosition, Annotated[int, Ge(0)]] = 0

@model_validator(mode="after")
def _validate_startat(self) -> Self:
if self.start_at > (self.num_points - 1):
if isinstance(self.start_at, int) and self.start_at > (self.num_points - 1):
warnings.warn(
"start_at is greater than the number of points. "
"Setting start_at to last point.",
Expand All @@ -407,16 +410,23 @@ def __iter__(self) -> Iterator[RelativePosition]: # type: ignore [override]
seed = np.random.RandomState(self.random_seed)
func = _POINTS_GENERATORS[self.shape]

points: Iterable[Tuple[float, float]]
points: list[Tuple[float, float]] = []
needed_points = self.num_points
start_at = self.start_at
if isinstance(start_at, RelativePosition):
points = [(start_at.x, start_at.y)]
needed_points -= 1
start_at = 0

# in the easy case, just generate the requested number of points
if self.allow_overlap or self.fov_width is None or self.fov_height is None:
points = func(seed, self.num_points, self.max_width, self.max_height)
_points = func(seed, needed_points, self.max_width, self.max_height)
points.extend(_points)

else:
# if we need to avoid overlap, generate points, check if they are valid, and
# repeat until we have enough
points = []
per_iter = 100
per_iter = needed_points
tries = 0
while tries < MIN_RANDOM_POINTS and len(points) < self.num_points:
candidates = func(seed, per_iter, self.max_width, self.max_height)
Expand All @@ -435,7 +445,7 @@ def __iter__(self) -> Iterator[RelativePosition]: # type: ignore [override]
)

if self.order is not None:
points = self.order(points, start_at=self.start_at)
points = self.order(points, start_at=start_at) # type: ignore [assignment]

for idx, (x, y) in enumerate(points):
yield RelativePosition(x=x, y=y, name=f"{str(idx).zfill(4)}")
Expand Down

0 comments on commit 8d18eae

Please sign in to comment.