Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New iterator pattern [wip] #192

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import abc
import sys
from dataclasses import dataclass
from itertools import count, islice, product
from typing import Iterable, Iterator, Sequence, TypeVar, cast

from useq._mda_event import MDAEvent

T = TypeVar("T")


class AxisIterator(Iterable[T]):
INFINITE = -1

@property
@abc.abstractmethod
def axis_key(self) -> str:
"""A string id representing the axis."""

def __iter__(self) -> Iterator[T]:
"""Iterate over the axis."""

def length(self) -> int:
"""Return the number of axis values.

If the axis is infinite, return -1.
"""
return self.INFINITE

@abc.abstractmethod
def create_event_kwargs(cls, val: T) -> dict: ...

def should_skip(cls, kwargs: dict) -> bool:
return False


class TimePlan(AxisIterator[float]):
def __init__(self, tpoints: Sequence[float]) -> None:
self._tpoints = tpoints

axis_key = "t"

def __iter__(self) -> Iterator[float]:
yield from self._tpoints

def length(self) -> int:
return len(self._tpoints)

def create_event_kwargs(cls, val: float) -> dict:
return {"min_start_time": val}


class ZPlan(AxisIterator[int]):
def __init__(self, stop: int | None = None) -> None:
self._stop = stop
self.acquire_every = 2

axis_key = "z"

def __iter__(self) -> Iterator[int]:
if self._stop is not None:
return iter(range(self._stop))
return count()

def length(self) -> int:
return self._stop or self.INFINITE

def create_event_kwargs(cls, val: int) -> dict:
return {"z_pos": val}

def should_skip(self, event: dict) -> bool:
index = event["index"]
if "t" in index and index["t"] % self.acquire_every:
return True
return False


@dataclass
class MySequence:
axes: tuple[AxisIterator, ...]
order: tuple[str, ...]
chunk_size = 1000

@property
def is_infinite(self) -> bool:
"""Return `True` if the sequence is infinite."""
return any(ax.length() == ax.INFINITE for ax in self.axes)

def _enumerate_ax(
self, key: str, ax: Iterable[T], start: int = 0
) -> Iterable[tuple[str, int, T]]:
"""Return the key for an enumerated axis."""
for idx, val in enumerate(ax, start):
yield key, idx, val

def __iter__(self) -> MDAEvent:
ax_map: dict[str, type[AxisIterator]] = {ax.axis_key: ax for ax in self.axes}
for item in self._iter_inner():
event: dict = {"index": {}}
for axis_key, index, value in item:
ax_type = ax_map[axis_key]
event["index"][axis_key] = index
event.update(ax_type.create_event_kwargs(value))

if not any(ax_type.should_skip(event) for ax_type in ax_map.values()):
yield MDAEvent(**event)

def _iter_inner(self) -> Iterator[tuple[str, int, T]]:
"""Iterate over the sequence."""
ax_map = {ax.axis_key: ax for ax in self.axes}
sorted_axes = [ax_map[key] for key in self.order]
if not self.is_infinite:
iterators = (self._enumerate_ax(ax.axis_key, ax) for ax in sorted_axes)
yield from product(*iterators)
else:
idx = 0
while True:
yield from self._iter_infinite_slice(sorted_axes, idx, self.chunk_size)
idx += self.chunk_size

def _iter_infinite_slice(
self, sorted_axes: list[AxisIterator], start: int, chunk_size: int
) -> Iterator[tuple[str, T]]:
"""Iterate over a slice of an infinite sequence."""
iterators = []
for ax in sorted_axes:
if ax.length() is not ax.INFINITE:
iterator, begin = cast("Iterable", ax), 0
else:
# use islice to avoid calling product with infinite iterators
iterator, begin = islice(ax, start, start + chunk_size), start
iterators.append(self._enumerate_ax(ax.axis_key, iterator, begin))

return product(*iterators)


if __name__ == "__main__":
seq = MySequence(axes=(TimePlan((0, 1, 2, 3, 4)), ZPlan(3)), order=("t", "z"))
if seq.is_infinite:
print("Infinite sequence")
sys.exit(0)
for event in seq:
print(event)
67 changes: 67 additions & 0 deletions src/useq/_axis_iterable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from typing import (
TYPE_CHECKING,
ClassVar,
Iterator,
Protocol,
Sized,
TypeVar,
runtime_checkable,
)

from pydantic import BaseModel

if TYPE_CHECKING:
from useq._iter_sequence import MDAEventDict


# ------ Protocol that can be used as a field annotation in a Pydantic model ------

T = TypeVar("T")


@runtime_checkable
class AxisIterable(Protocol[T]):
@property
def axis_key(self) -> str:
"""A string id representing the axis. Prefer lowercase."""

def __iter__(self) -> Iterator[T]:
"""Iterate over the axis."""

def create_event_kwargs(self, val: T) -> MDAEventDict:
"""Convert a value from the iterator to kwargs for an MDAEvent."""

def length(self) -> int:
"""Return the number of axis values.

If the axis is infinite, return -1.
"""

def should_skip(self, kwargs: dict) -> bool:
"""Return True if the event should be skipped."""
return False


# ------- concrete base class/mixin that implements the above protocol -------


class AxisIterableBase(BaseModel):
axis_key: ClassVar[str]

def create_event_kwargs(self, val: T) -> MDAEventDict:
"""Convert a value from the iterator to kwargs for an MDAEvent."""
raise NotImplementedError

def length(self) -> int:
"""Return the number of axis values.

If the axis is infinite, return -1.
"""
if isinstance(self, Sized):
return len(self)
raise NotImplementedError

def should_skip(self, kwargs: dict) -> bool:
return False
53 changes: 51 additions & 2 deletions src/useq/_channel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Optional
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple

from pydantic import Field
from pydantic import Field, RootModel, model_validator

from useq._axis_iterable import AxisIterableBase
from useq._base_model import FrozenModel

if TYPE_CHECKING:
from useq._iter_sequence import MDAEventDict


class Channel(FrozenModel):
"""Define an acquisition channel.
Expand Down Expand Up @@ -38,3 +42,48 @@ class Channel(FrozenModel):
z_offset: float = 0.0
acquire_every: int = Field(default=1, gt=0) # acquire every n frames
camera: Optional[str] = None

@model_validator(mode="before")
def _validate_model(cls, value: Any) -> Any:
if isinstance(value, str):
return {"config": value}
return value


class Channels(RootModel, AxisIterableBase):
root: Tuple[Channel, ...]
axis_key: ClassVar[str] = "c"

def __iter__(self):
return iter(self.root)

def __getitem__(self, item):
return self.root[item]

def create_event_kwargs(self, val: Channel) -> "MDAEventDict":
"""Convert a value from the iterator to kwargs for an MDAEvent."""
d: MDAEventDict = {"channel": {"config": val.config, "group": val.group}}
if val.z_offset:
d["z_pos_rel"] = val.z_offset
return d

def length(self) -> int:
"""Return the number of axis values.

If the axis is infinite, return -1.
"""
return len(self.root)

def should_skip(cls, kwargs: dict) -> bool:
return False
# # skip channels
# if Axis.TIME in index and index[Axis.TIME] % channel.acquire_every:
# return True

# # only acquire on the middle plane:
# if (
# not channel.do_stack
# and z_plan is not None
# and index[Axis.Z] != z_plan.num_positions() // 2
# ):
# return True
34 changes: 10 additions & 24 deletions src/useq/_iter_sequence.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from functools import lru_cache
from itertools import product
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Iterator, cast
Expand All @@ -15,7 +14,7 @@

if TYPE_CHECKING:
from useq._mda_sequence import MDASequence
from useq._position import Position, PositionBase, RelativePosition
from useq._position import Position, RelativePosition


class MDAEventDict(TypedDict, total=False):
Expand All @@ -25,8 +24,11 @@ class MDAEventDict(TypedDict, total=False):
min_start_time: float | None
pos_name: str | None
x_pos: float | None
x_pos_rel: float | None
y_pos: float | None
y_pos_rel: float | None
z_pos: float | None
z_pos_rel: float | None
sequence: MDASequence | None
# properties: list[tuple] | None
metadata: dict
Expand All @@ -39,21 +41,6 @@ class PositionDict(TypedDict, total=False):
z_pos: float


@lru_cache(maxsize=None)
def _iter_axis(seq: MDASequence, ax: str) -> tuple[Channel | float | PositionBase, ...]:
return tuple(seq.iter_axis(ax))


@lru_cache(maxsize=None)
def _sizes(seq: MDASequence) -> dict[str, int]:
return {k: len(list(_iter_axis(seq, k))) for k in seq.axis_order}


@lru_cache(maxsize=None)
def _used_axes(seq: MDASequence) -> str:
return "".join(k for k in seq.axis_order if _sizes(seq)[k])


def iter_sequence(sequence: MDASequence) -> Iterator[MDAEvent]:
"""Iterate over all events in the MDA sequence.'.

Expand Down Expand Up @@ -143,9 +130,8 @@ def _iter_sequence(
MDAEvent
Each event in the MDA sequence.
"""
order = _used_axes(sequence)
# this needs to be tuple(...) to work for mypyc
axis_iterators = tuple(enumerate(_iter_axis(sequence, ax)) for ax in order)
order = sequence.used_axes
axis_iterators = (enumerate(sequence.iter_axis(ax)) for ax in order)
for item in product(*axis_iterators):
if not item: # the case with no events
continue # pragma: no cover
Expand Down Expand Up @@ -265,11 +251,11 @@ def _position_offsets(
def _parse_axes(
event: zip[tuple[str, Any]],
) -> tuple[
dict[str, int],
dict[str, int], # index
float | None, # time
Position | None,
RelativePosition | None,
Channel | None,
Position | None, # position
RelativePosition | None, # grid
Channel | None, # channel
float | None, # z
]:
"""Parse an individual event from the product of axis iterators.
Expand Down
Loading