Skip to content

Commit

Permalink
fix: update w talley's suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
fdrgsp committed Sep 4, 2023
1 parent 802ad35 commit c16da53
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 128 deletions.
1 change: 0 additions & 1 deletion src/useq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
"AnyZPlan",
"AutoFocusPlan",
"AxesBasedAF",
"Point",
"Channel",
"GridFromEdges",
"GridRelative",
Expand Down
184 changes: 57 additions & 127 deletions src/useq/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
Any,
Callable,
ClassVar,
Iterable,
Iterator,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
cast,
)

import numpy as np
Expand All @@ -28,7 +26,7 @@
if TYPE_CHECKING:
from pydantic import ConfigDict

MAX_ITER = 5000
MIN_RANDOM_POINTS = 5000


class RelativeTo(Enum):
Expand Down Expand Up @@ -405,149 +403,81 @@ def is_relative(self) -> bool:
return True

def __iter__(self) -> Iterator[GridPosition]: # type: ignore
seed = np.random.RandomState(self.random_seed)
func = _POINTS_GENERATORS[self.shape]

min_distance = (self.fov_width, self.fov_height)

for x, y in func(
self.num_points,
self.max_width,
self.max_height,
min_distance,
self.allow_overlap,
self.random_seed,
):
yield GridPosition(x, y, 0, 0, True)
n_points = max(self.num_points, MIN_RANDOM_POINTS)
points: list[Tuple[float, float]] = []
for x, y in func(seed, n_points, self.max_width, self.max_height):
if (
self.allow_overlap
or (None in (self.fov_width, self.fov_height))
or _is_a_valid_point(points, x, y, self.fov_width, self.fov_height)
):
yield GridPosition(x, y, 0, 0, True)
points.append((x, y))
if len(points) >= self.num_points:
break
if len(points) < self.num_points:
_raise_warning(n_points, points)

def num_positions(self) -> int:
return self.num_points


def _random_points_in_ellipse(
num_points: int,
max_width: float,
max_height: float,
min_distance: Tuple[Optional[float], Optional[float]],
allow_overlap: bool,
random_seed: Optional[int],
) -> Iterable[Tuple[float, float]]:
"""Generate a random point around a circle with center (0, 0).
The point is within +/- radius_x and +/- radius_y at a random angle.
"""
_iter = _get_iterator(Shape.ELLIPSE, num_points, random_seed)
points: list[Tuple[float, float]] = []

try:
while len(points) < num_points:
x0, y0, angle = next(_iter)
x, y = (
math.sqrt(x0) * (max_width / 2) * math.cos(angle * 2 * math.pi),
math.sqrt(y0) * (max_height / 2) * math.sin(angle * 2 * math.pi),
)
if _check_validity(allow_overlap, min_distance, points, x, y):
points.append((x, y))

except StopIteration:
_raise_warning(points)

return points


def _random_points_in_rectangle(
num_points: int,
max_width: float,
max_height: float,
min_distance: Tuple[Optional[float], Optional[float]],
allow_overlap: bool,
random_seed: Optional[int],
) -> Iterable[Tuple[float, float]]:
"""Generate a random point around a rectangle with center (0, 0).
The point is within the bounding box (-width/2, -height/2, width, height)
"""
_iter = _get_iterator(Shape.RECTANGLE, num_points, random_seed)
points: list[Tuple[float, float]] = []

try:
while len(points) < num_points:
x0, y0 = next(_iter)
x, y = (
(x0 * max_width) - (max_width / 2),
(y0 * max_height) - (max_height / 2),
)
if _check_validity(allow_overlap, min_distance, points, x, y):
points.append((x, y))

except StopIteration:
_raise_warning(points)

return points


def _get_iterator(
shape: Shape, num_points: int, random_seed: Optional[int]
) -> Iterator[Tuple[float, ...]]:
"""Return an iterator of random numbers between 0 and 1 with size depending on the
`shape`.
""" # noqa: D205
# set the numpy random seed
seed = np.random.RandomState(random_seed)
# setting the max array size as the max number of iterations
iter_size = max(num_points, MAX_ITER)
# generate random numbers between 0 and 1 and shape them in an iterator depending
# on the shape
size = (iter_size, 2) if shape == Shape.RECTANGLE else (iter_size, 3)
return iter(seed.uniform(0, 1, size=size))


def _check_validity(
allow_overlap: bool,
min_distance: Tuple[Optional[float], Optional[float]],
points: list[Tuple[float, float]],
x: float,
y: float,
) -> bool:
"""Return True if `allow_overlap` is True, if `None` is in `min_distance` or if
the `(x, y)` point is at least `min_distance` away from all other points.
""" # noqa: D205
if allow_overlap or None in min_distance:
return True
min_distance = cast(Tuple[float, float], min_distance)
return _is_a_valid_point(points, x, y, *min_distance)


def _is_a_valid_point(
points: list[Tuple[float, float]],
x: float,
y: float,
min_dist_x: float,
min_dist_y: float,
min_dist_x: float | None,
min_dist_y: float | None,
) -> bool:
"""Return True if the `(x, y)` point is at least `min_dist_x` and `min_dist_y` away
from all other points (using Manhattan distance).
""" # noqa: D205
for point in points:
point_x, point_y = point
if abs(x - point_x) < min_dist_x and abs(y - point_y) < min_dist_y:
return False
return True
"""Return True if the the point is at least min_dist away from all the others.
note: using Manhattan distance.
"""
if min_dist_x is None or min_dist_y is None:
return True
return not any(
abs(x - point_x) < min_dist_x and abs(y - point_y) < min_dist_y
for point_x, point_y in points
)

def _raise_warning(points: list[Tuple[float, float]]) -> None:

def _raise_warning(n_points: int, points: list[Tuple[float, float]]) -> None:
"""Raise a warning if the number of points is less than the requested number."""
warnings.warn(
f"Max number of iterations reached ({MAX_ITER}). "
f"Max number of iterations reached ({n_points}). "
f"Only {len(points)} points were found.",
stacklevel=2,
)


# function that takes in num_points, max_width, max_height and returns
# an iterable of (x, y) points
PointGenerator = Callable[
[int, float, float, Tuple[Optional[float], Optional[float]], bool, Optional[int]],
Iterable[Tuple[float, float]],
]
def _random_points_in_ellipse(
seed: np.random.RandomState, n_points: int, max_width: float, max_height: float
) -> Iterator[Tuple[float, float]]:
"""Generate a random point around a circle with center (0, 0).
The point is within +/- radius_x and +/- radius_y at a random angle.
"""
for x0, y0, angle in seed.uniform(0, 1, size=(n_points, 3)):
yield (
np.sqrt(x0) * (max_width / 2) * np.cos(angle * 2 * np.pi),
np.sqrt(y0) * (max_height / 2) * np.sin(angle * 2 * np.pi),
)


def _random_points_in_rectangle(
seed: np.random.RandomState, n_points: int, max_width: float, max_height: float
) -> Iterator[Tuple[float, float]]:
"""Generate a random point around a rectangle with center (0, 0).
The point is within the bounding box (-width/2, -height/2, width, height).
"""
for x0, y0 in seed.uniform(0, 1, size=(n_points, 2)):
yield (x0 * max_width) - (max_width / 2), (y0 * max_height) - (max_height / 2)


PointGenerator = Callable[[np.random.RandomState, int, float, float], np.ndarray]
_POINTS_GENERATORS: dict[Shape, PointGenerator] = {
Shape.ELLIPSE: _random_points_in_ellipse,
Shape.RECTANGLE: _random_points_in_rectangle,
Expand Down

0 comments on commit c16da53

Please sign in to comment.