Skip to content

Commit

Permalink
allow a slicer as the input of drift_correction
Browse files Browse the repository at this point in the history
  • Loading branch information
hanjinliu committed Jul 19, 2024
1 parent 88e4d2e commit fbdcd23
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
10 changes: 6 additions & 4 deletions impy/arrays/imgarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3996,7 +3996,7 @@ def track_drift(
def drift_correction(
self,
shift: Coords = None,
ref: ImgArray = None,
ref: ImgArray | Any | None = None,
*,
zero_ave: bool = True,
along: AxisLike | None = None,
Expand All @@ -4015,8 +4015,10 @@ def drift_correction(
shift : DataFrame or (N, D) array, optional
Translation vectors. If DataFrame, it must have columns named with all the symbols
contained in ``dims``.
ref : ImgArray, optional
The reference n-D image to determine drift, if ``shift`` was not given.
ref : ImgArray or slicer, optional
The reference n-D image to determine drift, if ``shift`` was not given. This
parameter can be a slicer, which will be used to slice the image to make a
reference.
zero_ave : bool, default is True
If True, average shift will be zero.
along : AxisLike, optional
Expand Down Expand Up @@ -4046,7 +4048,7 @@ def drift_correction(
if ref is None:
ref = self
elif not isinstance(ref, ImgArray):
raise TypeError(f"'ref' must be an ImgArray object, but got {type(ref)}")
ref = self[ref]
if ref.axes != [along] + dims:
from itertools import product
_c_axes = complement_axes([along] + dims, str(ref.axes))
Expand Down
13 changes: 7 additions & 6 deletions impy/arrays/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,10 +1149,11 @@ def pcc(x):
@same_dtype(asfloat=True)
@check_input_and_output_lazy
@dims_to_spatial_axes
def drift_correction(self, shift: Coords = None, ref: ImgArray = None, *,
zero_ave: bool = True, along: str = None, dims: Dims = 2,
update: bool = False, **affine_kwargs) -> LazyImgArray:

def drift_correction(
self, shift: Coords | None = None, ref: ImgArray | Any | None = None, *,
zero_ave: bool = True, along: str = None, dims: Dims = 2,
update: bool = False, **affine_kwargs,
) -> LazyImgArray:
if along is None:
along = find_first_appeared("tpzcia", include=self.axes, exclude=dims)
elif len(along) != 1:
Expand All @@ -1171,8 +1172,8 @@ def drift_correction(self, shift: Coords = None, ref: ImgArray = None, *,
UserWarning)
dims = _dims
elif not isinstance(ref, self.__class__):
raise TypeError(f"'ref' must be LazyImgArray object, but got {type(ref)}")
elif ref.axes != along + dims:
ref = self[ref]
if ref.axes != along + dims:
raise ValueError(f"Arguments `along`({along}) + `dims`({dims}) do not match "
f"axes of `ref`({ref.axes})")

Expand Down

0 comments on commit fbdcd23

Please sign in to comment.