Skip to content

Commit

Permalink
fix bug with channels=None in CellposeDenoiseModel (#1098)
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Feb 10, 2025
1 parent b9d25e7 commit fd6c853
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions cellpose/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
styles (list, np.ndarray): style vector summarizing each image of size 256.
imgs (list of 2D/3D arrays): Restored images
"""

if isinstance(normalize, dict):
normalize_params = {**normalize_default, **normalize}
elif not isinstance(normalize, bool):
Expand All @@ -578,8 +579,11 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
# turn off special normalization for segmentation
normalize_params = normalize_default

# change channels for segmentation (denoise model outputs up to 2 channels)
channels_new = [0, 0] if channels[0] == 0 else [1, 2]
# change channels for segmentation
if channels is not None:
channels_new = [0, 0] if channels[0] == 0 else [1, 2]
else:
channels_new = None
# change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean)
diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
masks, flows, styles = self.cp.eval(
Expand Down Expand Up @@ -759,7 +763,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
else:
# reshape image
x = transforms.convert_image(x, channels, channel_axis=channel_axis,
z_axis=z_axis, do_3D=do_3D)
z_axis=z_axis, do_3D=do_3D, nchan=None)
if x.ndim < 4:
squeeze = True
x = x[np.newaxis, ...]
Expand Down Expand Up @@ -790,7 +794,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
elif rescale is None:
rescale = 1.0

if np.ptp(x[..., -1]) < 1e-3 or channels[-1] == 0:
if np.ptp(x[..., -1]) < 1e-3 or (channels is not None and channels[-1] == 0):
x = x[..., :1]

for c in range(x.shape[-1]):
Expand Down

0 comments on commit fd6c853

Please sign in to comment.