diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 34c5aa78..725f418b 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1565,7 +1565,9 @@ def _get_images_pair(self, idx): if len(background_np.shape) == 2: warnings.warn(f"Converting background[{idx}] to RGB") - background_np = np.stack([background_np] * 3, axis=2) if not None else None + background_np = ( + np.stack([background_np] * 3, axis=2) if background_np is not None else None + ) elif len(background_np.shape) == 3: pass @@ -1573,12 +1575,16 @@ def _get_images_pair(self, idx): if lensless_np.dtype == np.uint8: lensless_np = lensless_np.astype(np.float32) / 255 lensed_np = lensed_np.astype(np.float32) / 255 - background_np = background_np.astype(np.float32) / 255 if not None else None + background_np = ( + background_np.astype(np.float32) / 255 if background_np is not None else None + ) else: # 16 bit lensless_np = lensless_np.astype(np.float32) / 65535 lensed_np = lensed_np.astype(np.float32) / 65535 - background_np = background_np.astype(np.float32) / 65535 if not None else None + background_np = ( + background_np.astype(np.float32) / 65535 if background_np is not None else None + ) # downsample if necessary if self.downsample_lensless != 1.0: @@ -1591,13 +1597,13 @@ def _get_images_pair(self, idx): factor=1 / self.downsample_lensless, interpolation=cv2.INTER_NEAREST, ) - if not None + if background_np is not None else None ) lensless = lensless_np lensed = lensed_np - background = background_np if not None else None + background = background_np if background_np is not None else None if self.simulator is not None: # convert to torch @@ -1640,7 +1646,7 @@ def __getitem__(self, idx): # to torch lensless = torch.from_numpy(lensless) lensed = torch.from_numpy(lensed) - background = torch.from_numpy(background) if not None else None + background = torch.from_numpy(background) if background is not None else None # If [H, W, C] -> [D, H, W, C] if len(lensless.shape) == 3: lensless = lensless.unsqueeze(0)