diff --git a/example.py b/example.py new file mode 100644 index 0000000..82e79f1 --- /dev/null +++ b/example.py @@ -0,0 +1,143 @@ +import abc +import sys +from dataclasses import dataclass +from itertools import count, islice, product +from typing import Iterable, Iterator, Sequence, TypeVar, cast + +from useq._mda_event import MDAEvent + +T = TypeVar("T") + + +class AxisIterator(Iterable[T]): + INFINITE = -1 + + @property + @abc.abstractmethod + def axis_key(self) -> str: + """A string id representing the axis.""" + + def __iter__(self) -> Iterator[T]: + """Iterate over the axis.""" + + def length(self) -> int: + """Return the number of axis values. + + If the axis is infinite, return -1. + """ + return self.INFINITE + + @abc.abstractmethod + def create_event_kwargs(cls, val: T) -> dict: ... + + def should_skip(cls, kwargs: dict) -> bool: + return False + + +class TimePlan(AxisIterator[float]): + def __init__(self, tpoints: Sequence[float]) -> None: + self._tpoints = tpoints + + axis_key = "t" + + def __iter__(self) -> Iterator[float]: + yield from self._tpoints + + def length(self) -> int: + return len(self._tpoints) + + def create_event_kwargs(cls, val: float) -> dict: + return {"min_start_time": val} + + +class ZPlan(AxisIterator[int]): + def __init__(self, stop: int | None = None) -> None: + self._stop = stop + self.acquire_every = 2 + + axis_key = "z" + + def __iter__(self) -> Iterator[int]: + if self._stop is not None: + return iter(range(self._stop)) + return count() + + def length(self) -> int: + return self._stop or self.INFINITE + + def create_event_kwargs(cls, val: int) -> dict: + return {"z_pos": val} + + def should_skip(self, event: dict) -> bool: + index = event["index"] + if "t" in index and index["t"] % self.acquire_every: + return True + return False + + +@dataclass +class MySequence: + axes: tuple[AxisIterator, ...] + order: tuple[str, ...] + chunk_size = 1000 + + @property + def is_infinite(self) -> bool: + """Return `True` if the sequence is infinite.""" + return any(ax.length() == ax.INFINITE for ax in self.axes) + + def _enumerate_ax( + self, key: str, ax: Iterable[T], start: int = 0 + ) -> Iterable[tuple[str, int, T]]: + """Return the key for an enumerated axis.""" + for idx, val in enumerate(ax, start): + yield key, idx, val + + def __iter__(self) -> MDAEvent: + ax_map: dict[str, type[AxisIterator]] = {ax.axis_key: ax for ax in self.axes} + for item in self._iter_inner(): + event: dict = {"index": {}} + for axis_key, index, value in item: + ax_type = ax_map[axis_key] + event["index"][axis_key] = index + event.update(ax_type.create_event_kwargs(value)) + + if not any(ax_type.should_skip(event) for ax_type in ax_map.values()): + yield MDAEvent(**event) + + def _iter_inner(self) -> Iterator[tuple[str, int, T]]: + """Iterate over the sequence.""" + ax_map = {ax.axis_key: ax for ax in self.axes} + sorted_axes = [ax_map[key] for key in self.order] + if not self.is_infinite: + iterators = (self._enumerate_ax(ax.axis_key, ax) for ax in sorted_axes) + yield from product(*iterators) + else: + idx = 0 + while True: + yield from self._iter_infinite_slice(sorted_axes, idx, self.chunk_size) + idx += self.chunk_size + + def _iter_infinite_slice( + self, sorted_axes: list[AxisIterator], start: int, chunk_size: int + ) -> Iterator[tuple[str, T]]: + """Iterate over a slice of an infinite sequence.""" + iterators = [] + for ax in sorted_axes: + if ax.length() is not ax.INFINITE: + iterator, begin = cast("Iterable", ax), 0 + else: + # use islice to avoid calling product with infinite iterators + iterator, begin = islice(ax, start, start + chunk_size), start + iterators.append(self._enumerate_ax(ax.axis_key, iterator, begin)) + + return product(*iterators) + + +if __name__ == "__main__": + seq = MySequence(axes=(TimePlan((0, 1, 2, 3, 4)), ZPlan(3)), order=("t", "z")) + if seq.is_infinite: + print("Infinite sequence") + sys.exit(0) + for event in seq: + print(event) diff --git a/src/useq/_axis_iterable.py b/src/useq/_axis_iterable.py new file mode 100644 index 0000000..560d0f4 --- /dev/null +++ b/src/useq/_axis_iterable.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + ClassVar, + Iterator, + Protocol, + Sized, + TypeVar, + runtime_checkable, +) + +from pydantic import BaseModel + +if TYPE_CHECKING: + from useq._iter_sequence import MDAEventDict + + +# ------ Protocol that can be used as a field annotation in a Pydantic model ------ + +T = TypeVar("T") + + +@runtime_checkable +class AxisIterable(Protocol[T]): + @property + def axis_key(self) -> str: + """A string id representing the axis. Prefer lowercase.""" + + def __iter__(self) -> Iterator[T]: + """Iterate over the axis.""" + + def create_event_kwargs(self, val: T) -> MDAEventDict: + """Convert a value from the iterator to kwargs for an MDAEvent.""" + + def length(self) -> int: + """Return the number of axis values. + + If the axis is infinite, return -1. + """ + + def should_skip(self, kwargs: dict) -> bool: + """Return True if the event should be skipped.""" + return False + + +# ------- concrete base class/mixin that implements the above protocol ------- + + +class AxisIterableBase(BaseModel): + axis_key: ClassVar[str] + + def create_event_kwargs(self, val: T) -> MDAEventDict: + """Convert a value from the iterator to kwargs for an MDAEvent.""" + raise NotImplementedError + + def length(self) -> int: + """Return the number of axis values. + + If the axis is infinite, return -1. + """ + if isinstance(self, Sized): + return len(self) + raise NotImplementedError + + def should_skip(self, kwargs: dict) -> bool: + return False diff --git a/src/useq/_channel.py b/src/useq/_channel.py index 0566aaf..c5790ba 100644 --- a/src/useq/_channel.py +++ b/src/useq/_channel.py @@ -1,9 +1,13 @@ -from typing import Optional +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple -from pydantic import Field +from pydantic import Field, RootModel, model_validator +from useq._axis_iterable import AxisIterableBase from useq._base_model import FrozenModel +if TYPE_CHECKING: + from useq._iter_sequence import MDAEventDict + class Channel(FrozenModel): """Define an acquisition channel. @@ -38,3 +42,48 @@ class Channel(FrozenModel): z_offset: float = 0.0 acquire_every: int = Field(default=1, gt=0) # acquire every n frames camera: Optional[str] = None + + @model_validator(mode="before") + def _validate_model(cls, value: Any) -> Any: + if isinstance(value, str): + return {"config": value} + return value + + +class Channels(RootModel, AxisIterableBase): + root: Tuple[Channel, ...] + axis_key: ClassVar[str] = "c" + + def __iter__(self): + return iter(self.root) + + def __getitem__(self, item): + return self.root[item] + + def create_event_kwargs(self, val: Channel) -> "MDAEventDict": + """Convert a value from the iterator to kwargs for an MDAEvent.""" + d: MDAEventDict = {"channel": {"config": val.config, "group": val.group}} + if val.z_offset: + d["z_pos_rel"] = val.z_offset + return d + + def length(self) -> int: + """Return the number of axis values. + + If the axis is infinite, return -1. + """ + return len(self.root) + + def should_skip(cls, kwargs: dict) -> bool: + return False + # # skip channels + # if Axis.TIME in index and index[Axis.TIME] % channel.acquire_every: + # return True + + # # only acquire on the middle plane: + # if ( + # not channel.do_stack + # and z_plan is not None + # and index[Axis.Z] != z_plan.num_positions() // 2 + # ): + # return True diff --git a/src/useq/_iter_sequence.py b/src/useq/_iter_sequence.py index 5a23604..15fbb00 100644 --- a/src/useq/_iter_sequence.py +++ b/src/useq/_iter_sequence.py @@ -1,6 +1,5 @@ from __future__ import annotations -from functools import lru_cache from itertools import product from types import MappingProxyType from typing import TYPE_CHECKING, Any, Iterator, cast @@ -15,7 +14,7 @@ if TYPE_CHECKING: from useq._mda_sequence import MDASequence - from useq._position import Position, PositionBase, RelativePosition + from useq._position import Position, RelativePosition class MDAEventDict(TypedDict, total=False): @@ -25,8 +24,11 @@ class MDAEventDict(TypedDict, total=False): min_start_time: float | None pos_name: str | None x_pos: float | None + x_pos_rel: float | None y_pos: float | None + y_pos_rel: float | None z_pos: float | None + z_pos_rel: float | None sequence: MDASequence | None # properties: list[tuple] | None metadata: dict @@ -39,21 +41,6 @@ class PositionDict(TypedDict, total=False): z_pos: float -@lru_cache(maxsize=None) -def _iter_axis(seq: MDASequence, ax: str) -> tuple[Channel | float | PositionBase, ...]: - return tuple(seq.iter_axis(ax)) - - -@lru_cache(maxsize=None) -def _sizes(seq: MDASequence) -> dict[str, int]: - return {k: len(list(_iter_axis(seq, k))) for k in seq.axis_order} - - -@lru_cache(maxsize=None) -def _used_axes(seq: MDASequence) -> str: - return "".join(k for k in seq.axis_order if _sizes(seq)[k]) - - def iter_sequence(sequence: MDASequence) -> Iterator[MDAEvent]: """Iterate over all events in the MDA sequence.'. @@ -143,9 +130,8 @@ def _iter_sequence( MDAEvent Each event in the MDA sequence. """ - order = _used_axes(sequence) - # this needs to be tuple(...) to work for mypyc - axis_iterators = tuple(enumerate(_iter_axis(sequence, ax)) for ax in order) + order = sequence.used_axes + axis_iterators = (enumerate(sequence.iter_axis(ax)) for ax in order) for item in product(*axis_iterators): if not item: # the case with no events continue # pragma: no cover @@ -265,11 +251,11 @@ def _position_offsets( def _parse_axes( event: zip[tuple[str, Any]], ) -> tuple[ - dict[str, int], + dict[str, int], # index float | None, # time - Position | None, - RelativePosition | None, - Channel | None, + Position | None, # position + RelativePosition | None, # grid + Channel | None, # channel float | None, # z ]: """Parse an individual event from the product of axis iterators. diff --git a/src/useq/_mda_event.py b/src/useq/_mda_event.py index f4b0327..eb02dee 100644 --- a/src/useq/_mda_event.py +++ b/src/useq/_mda_event.py @@ -94,11 +94,6 @@ class MDAEvent(UseqModel): exposure : float | None Exposure time in milliseconds. If not provided, implies use current exposure time. By default, `None`. - min_start_time : float | None - Minimum start time of this event, in seconds. If provided, the engine will - pause until this time has elapsed before starting this event. Times are - relative to the start of the sequence, or the last event with - `reset_event_timer` set to `True`. pos_name : str | None The name assigned to the position. By default, `None`. x_pos : float | None @@ -131,9 +126,18 @@ class MDAEvent(UseqModel): This is useful when the sequence of events being executed use the same illumination scheme (such as a z-stack in a single channel), and closing and opening the shutter between events would be slow. + min_start_time : float | None + Minimum start time of this event, in seconds. If provided, the engine should + pause until this time has elapsed before starting this event. Times are + relative to the start of the sequence, or the last event with + `reset_event_timer` set to `True`. + min_end_time : float | None + If provided, the engine should stop the entire sequence if the current time + exceeds this value. Times are relative to the start of the sequence, or the + last event with `reset_event_timer` set to `True`. reset_event_timer : bool If `True`, the engine should reset the event timer to the time of this event, - and future `min_start_time` values will be relative to this event. By default, + and future `min_start_time` values should be relative to this event. By default, `False`. """ @@ -141,7 +145,6 @@ class MDAEvent(UseqModel): index: Mapping[str, int] = Field(default_factory=lambda: MappingProxyType({})) channel: Optional[Channel] = None exposure: Optional[float] = Field(default=None, gt=0.0) - min_start_time: Optional[float] = None # time in sec pos_name: Optional[str] = None x_pos: Optional[float] = None y_pos: Optional[float] = None @@ -151,12 +154,19 @@ class MDAEvent(UseqModel): metadata: Dict[str, Any] = Field(default_factory=dict) action: AnyAction = Field(default_factory=AcquireImage) keep_shutter_open: bool = False + + min_start_time: Optional[float] = None # time in sec + min_end_time: Optional[float] = None # time in sec reset_event_timer: bool = False @field_validator("channel", mode="before") def _validate_channel(cls, val: Any) -> Any: return Channel(config=val) if isinstance(val, str) else val + @field_validator("index", mode="after") + def _validate_channel(cls, val: Mapping) -> MappingProxyType: + return MappingProxyType(val) + if field_serializer is not None: _si = field_serializer("index", mode="plain")(lambda v: dict(v)) _sx = field_serializer("x_pos", mode="plain")(_float_or_none) diff --git a/src/useq/_mda_sequence.py b/src/useq/_mda_sequence.py index fdc6a79..72e6d85 100644 --- a/src/useq/_mda_sequence.py +++ b/src/useq/_mda_sequence.py @@ -1,5 +1,6 @@ from __future__ import annotations +from types import MappingProxyType from typing import ( TYPE_CHECKING, Any, @@ -399,17 +400,37 @@ def shape(self) -> Tuple[int, ...]: """ return tuple(s for s in self.sizes.values() if s) + def _axis_size(self, axis: str) -> int: + """Return the size of a given axis. + + -1 indicates an infinite iterator. + """ + # TODO: make a generic interface for axes + if axis == Axis.TIME: + # note that this may be -1, which implies infinite + return self.time_plan.num_timepoints() if self.time_plan else 0 + if axis == Axis.POSITION: + return len(self.stage_positions) + if axis == Axis.Z: + return self.z_plan.num_positions() if self.z_plan else 0 + if axis == Axis.CHANNEL: + return len(self.channels) + if axis == Axis.GRID: + return self.grid_plan.num_positions() if self.grid_plan else 0 + raise ValueError(f"Invalid axis: {axis}") + @property def sizes(self) -> Mapping[str, int]: """Mapping of axis name to size of that axis.""" if self._sizes is None: - self._sizes = {k: len(list(self.iter_axis(k))) for k in self.axis_order} - return self._sizes + self._sizes = {k: self._axis_size(k) for k in self.axis_order} + return MappingProxyType(self._sizes) @property def used_axes(self) -> str: """Single letter string of axes used in this sequence, e.g. `ztc`.""" - return "".join(k for k in self.axis_order if self.sizes[k]) + sz = self.sizes + return "".join(k for k in self.axis_order if sz[k]) def iter_axis(self, axis: str) -> Iterator[Channel | float | PositionBase]: """Iterate over the positions or items of a given axis.""" diff --git a/src/useq/_multi_axis_sequence.py b/src/useq/_multi_axis_sequence.py new file mode 100644 index 0000000..30f1b82 --- /dev/null +++ b/src/useq/_multi_axis_sequence.py @@ -0,0 +1,122 @@ +from itertools import islice, product +from typing import Any, Iterable, Iterator, Sequence, Tuple, TypeVar, cast + +from pydantic import ConfigDict, field_validator + +from useq._axis_iterable import AxisIterable +from useq._base_model import UseqModel +from useq._mda_event import MDAEvent + +T = TypeVar("T") + +INFINITE = NotImplemented + + +class MultiDimSequence(UseqModel): + """A multi-dimensional sequence of events. + + Attributes + ---------- + axes : Tuple[AxisIterable, ...] + The individual axes to iterate over. + axis_order: tuple[str, ...] | None + An explicit order in which to iterate over the axes. + If `None`, axes are iterated in the order provided in the `axes` attribute. + Note that this may also be manually passed as an argument to the `iterate` + method. + chunk_size: int + For infinite sequences, the number of events to generate at a time. + """ + + axes: Tuple[AxisIterable, ...] = () + # if none, axes are used in order provided + axis_order: tuple[str, ...] | None = None + chunk_size: int = 10 + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("axes", mode="after") + def _validate_axes(cls, v: tuple[AxisIterable, ...]) -> tuple[AxisIterable, ...]: + keys = [x.axis_key for x in v] + if not len(keys) == len(set(keys)): + dupes = {k for k in keys if keys.count(k) > 1} + raise ValueError( + f"The following axis keys appeared more than once: {dupes}" + ) + return v + + @field_validator("axis_order", mode="before") + @classmethod + def _validate_axis_order(cls, v: Any) -> tuple[str, ...]: + if not isinstance(v, Iterable): + raise ValueError(f"axis_order must be iterable, got {type(v)}") + order = tuple(str(x).lower() for x in v) + if len(set(order)) < len(order): + raise ValueError(f"Duplicate entries found in acquisition order: {order}") + + return order + + @property + def is_infinite(self) -> bool: + """Return `True` if the sequence is infinite.""" + return any(ax.length() is INFINITE for ax in self.axes) + + def _enumerate_ax( + self, key: str, ax: Iterable[T], start: int = 0 + ) -> Iterable[tuple[str, int, T]]: + """Return the key for an enumerated axis.""" + for idx, val in enumerate(ax, start): + yield key, idx, val + + def __iter__(self) -> Iterator[MDAEvent]: # type: ignore [override] + return self.iterate() + + def iterate(self, axis_order: Sequence[str] | None = None) -> Iterator[MDAEvent]: + ax_map: dict[str, AxisIterable] = {ax.axis_key: ax for ax in self.axes} + _axis_order = axis_order or self.axis_order or list(ax_map) + if unknown_keys := set(_axis_order) - set(ax_map): + raise KeyError( + f"Unknown axis key(s): {unknown_keys!r}. Recognized axes: {set(ax_map)}" + ) + sorted_axes = [ax_map[key] for key in _axis_order] + if not sorted_axes: + return + + for item in self._iter_inner(sorted_axes): + event_index = {} + values = {} + for axis_key, idx, value in item: + event_index[axis_key] = idx + values[axis_key] = ax_map[axis_key].create_event_kwargs(value) + + if not any(ax_type.should_skip(event) for ax_type in ax_map.values()): + yield MDAEvent(**event) + + def _iter_inner( + self, sorted_axes: Sequence[AxisIterable] + ) -> Iterable[tuple[str, int, Any]]: + """Iterate over the sequence.""" + + if not self.is_infinite: + iterators = (self._enumerate_ax(ax.axis_key, ax) for ax in sorted_axes) + yield from product(*iterators) + else: + idx = 0 + while True: + yield from self._iter_infinite_slice(sorted_axes, idx, self.chunk_size) + idx += self.chunk_size + + def _iter_infinite_slice( + self, sorted_axes: list[AxisIterable], start: int, chunk_size: int + ) -> Iterable[tuple[str, int, Any]]: + """Iterate over a slice of an infinite sequence.""" + iterators = [] + for ax in sorted_axes: + if ax.length() is not INFINITE: + iterator, begin = cast("Iterable", ax), 0 + else: + # use islice to avoid calling product with infinite iterators + iterator, begin = islice(ax, start, start + chunk_size), start + iterators.append(self._enumerate_ax(ax.axis_key, iterator, begin)) + + yield from product(*iterators) diff --git a/src/useq/_plate.py b/src/useq/_plate.py index aa7a767..e5f8f66 100644 --- a/src/useq/_plate.py +++ b/src/useq/_plate.py @@ -4,6 +4,7 @@ from typing import ( TYPE_CHECKING, Any, + ClassVar, Iterable, List, Sequence, @@ -24,6 +25,7 @@ ) from typing_extensions import Annotated +from useq._axis_iterable import AxisIterableBase from useq._base_model import FrozenModel, UseqModel from useq._grid import RandomPoints, RelativeMultiPointPlan, Shape from useq._plate_registry import _PLATE_REGISTRY @@ -125,7 +127,7 @@ def from_str(cls, name: str) -> WellPlate: return WellPlate.model_validate(obj) -class WellPlatePlan(UseqModel, Sequence[Position]): +class WellPlatePlan(UseqModel, AxisIterableBase, Sequence[Position]): """A plan for acquiring images from a multi-well plate. Parameters @@ -168,6 +170,12 @@ class WellPlatePlan(UseqModel, Sequence[Position]): default_factory=RelativePosition, union_mode="left_to_right" ) + axis_key: ClassVar[str] = "p" + + def create_event_kwargs(cls, val: Position) -> dict: + """Convert a value from the iterator to kwargs for an MDAEvent.""" + return {"x_pos": val.x, "y_pos": val.y} + def __repr_args__(self) -> Iterable[Tuple[str | None, Any]]: for item in super().__repr_args__(): if item[0] == "selected_wells": diff --git a/src/useq/_position.py b/src/useq/_position.py index ec1c552..d8016ca 100644 --- a/src/useq/_position.py +++ b/src/useq/_position.py @@ -1,6 +1,16 @@ -from typing import TYPE_CHECKING, Generic, Iterator, Optional, SupportsIndex, TypeVar - -from pydantic import Field +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterator, + Optional, + Sequence, + SupportsIndex, + TypeVar, +) + +import numpy as np +from pydantic import Field, model_validator from useq._base_model import FrozenModel, MutableModel @@ -71,6 +81,25 @@ def __round__(self, ndigits: "SupportsIndex | None" = None) -> "Self": # not sure why these Self types are not working return type(self).model_construct(**kwargs) # type: ignore [return-value] + @model_validator(mode="before") + @classmethod + def _validate_model(cls, value: Any) -> Any: + if isinstance(value, dict): + return value + if isinstance(value, Position): + return value.model_dump() + if isinstance(value, np.ndarray): + if value.ndim > 1: + raise ValueError(f"stage_positions must be 1D or 2D, got {value.ndim}D") + value = value.tolist() + if not isinstance(value, Sequence): # pragma: no cover + raise ValueError(f"stage_positions must be a sequence, got {type(value)}") + + x, *v = value + y, *v = v or (None,) + z = v[0] if v else None + return {"x": x, "y": y, "z": z} + class AbsolutePosition(PositionBase, FrozenModel): """An absolute position in 3D space.""" diff --git a/src/useq/_stage_positions.py b/src/useq/_stage_positions.py new file mode 100644 index 0000000..aa5a79b --- /dev/null +++ b/src/useq/_stage_positions.py @@ -0,0 +1,28 @@ +from typing import ClassVar, Tuple + +from pydantic import RootModel + +from useq import Position +from useq._axis_iterable import AxisIterableBase + + +class StagePositions(RootModel, AxisIterableBase): + root: Tuple[Position, ...] + axis_key: ClassVar[str] = "p" + + def __iter__(self): + return iter(self.root) + + def __getitem__(self, item): + return self.root[item] + + def create_event_kwargs(cls, val: Position) -> dict: + """Convert a value from the iterator to kwargs for an MDAEvent.""" + return {"x_pos": val.x, "y_pos": val.y} + + def length(self) -> int: + """Return the number of axis values. + + If the axis is infinite, return -1. + """ + return len(self.root) diff --git a/src/useq/_time.py b/src/useq/_time.py index 997b7aa..48cb879 100644 --- a/src/useq/_time.py +++ b/src/useq/_time.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Iterator, Sequence, Union +from typing import Any, Iterator, Sequence, Union from pydantic import BeforeValidator, Field, PlainSerializer from typing_extensions import Annotated @@ -23,14 +23,38 @@ def __iter__(self) -> Iterator[float]: # type: ignore for td in self.deltas(): yield td.total_seconds() + def length(self) -> int: + return self.num_timepoints() + + def should_skip(cls, kwargs: dict) -> bool: + return False + + def create_event_kwargs(self, val: Any) -> dict: + return {"min_start_time": val} + + @property + def axis_key(self) -> str: + """A string id representing the axis.""" + return "t" + def num_timepoints(self) -> int: + """Return the number of timepoints in the sequence. + + If the sequence is infinite, returns -1. + """ return self.loops # type: ignore # TODO def deltas(self) -> Iterator[timedelta]: + """Iterate over the time deltas between timepoints. + + If the sequence is infinite, yields indefinitely. + """ current = timedelta(0) - for _ in range(self.loops): # type: ignore # TODO + loops = self.num_timepoints() + while loops != 0: yield current current += self.interval # type: ignore # TODO + loops -= 1 class TIntervalLoops(TimePlan): @@ -101,6 +125,8 @@ class TIntervalDuration(TimePlan): @property def loops(self) -> int: + if self.interval == timedelta(0): + return -1 return self.duration // self.interval + 1 diff --git a/src/useq/_utils.py b/src/useq/_utils.py index 9a0cc87..ecc5693 100644 --- a/src/useq/_utils.py +++ b/src/useq/_utils.py @@ -4,6 +4,8 @@ from datetime import timedelta from typing import TYPE_CHECKING, Literal, NamedTuple, TypeVar +from useq._time import TIntervalDuration + if TYPE_CHECKING: from typing import Final @@ -158,7 +160,10 @@ def _time_phase_duration( # to actually acquire the data time_interval_s = s_per_timepoint - tot_duration = (phase.num_timepoints() - 1) * time_interval_s + s_per_timepoint + if isinstance(phase, TIntervalDuration): + tot_duration = phase.duration.total_seconds() + else: + tot_duration = (phase.num_timepoints() - 1) * time_interval_s + s_per_timepoint return tot_duration, time_interval_exceeded diff --git a/src/useq/_z.py b/src/useq/_z.py index c749cdb..cee1d81 100644 --- a/src/useq/_z.py +++ b/src/useq/_z.py @@ -1,7 +1,7 @@ from __future__ import annotations import math -from typing import Callable, Iterator, List, Sequence, Union +from typing import Any, Callable, ClassVar, Iterator, List, Sequence, Union import numpy as np from pydantic import field_validator @@ -23,6 +23,20 @@ def __iter__(self) -> Iterator[float]: # type: ignore positions = positions[::-1] yield from positions + def length(self) -> int: + return self.num_positions() + + def should_skip(cls, kwargs: dict) -> bool: + return False + + def create_event_kwargs(self, val: float) -> dict: + if self.is_relative: + return {"z_pos_rel": val} + else: + return {"z_pos": val} + + axis_key: ClassVar[str] = "z" + def _start_stop_step(self) -> tuple[float, float, float]: raise NotImplementedError @@ -31,7 +45,7 @@ def positions(self) -> Sequence[float]: if step == 0: return [start] stop += step / 2 # make sure we include the last point - return list(np.arange(start, stop, step)) + return [float(x) for x in np.arange(start, stop, step)] def num_positions(self) -> int: start, stop, step = self._start_stop_step() diff --git a/x.py b/x.py new file mode 100644 index 0000000..1c2f432 --- /dev/null +++ b/x.py @@ -0,0 +1,25 @@ +from rich import print + +from useq import TIntervalLoops, ZRangeAround +from useq._channel import Channel, Channels +from useq._mda_sequence import MDASequence +from useq._multi_axis_sequence import MultiDimSequence +from useq._stage_positions import StagePositions + +t = TIntervalLoops(interval=0.2, loops=4) +z = ZRangeAround(range=4, step=2) +p = StagePositions([(0, 0), (1, 1), (2, 2)]) +c = Channels( + [ + Channel(config="DAPI", do_stack=False), + Channel(config="FITC", z_offset=100), + Channel(config="Cy5", acquire_every=2), + ] +) +seq1 = MultiDimSequence(axes=(t, p, c, z)) +seq2 = MDASequence(time_plan=t, z_plan=z, stage_positions=list(p), channels=list(c)) +e1 = list(seq1) +e2 = list(seq2) + +print(e1[:5]) +print(e2[:5])