diff --git a/waveprop/simulation.py b/waveprop/simulation.py index d6a06b7..3ade96e 100644 --- a/waveprop/simulation.py +++ b/waveprop/simulation.py @@ -38,6 +38,7 @@ def __init__( device_conv="cpu", random_shift=False, is_torch=False, + quantize=True, **kwargs ): """ @@ -63,6 +64,8 @@ def __init__( Whether to randomly shift the image, by default False. is_torch : bool, optional Whether to use pytorch, by default False. + quantize : bool, optional + Whether to quantize image, by default True. """ if is_torch: self.axes = (-2, -1) @@ -78,6 +81,7 @@ def __init__( self.mask2sensor = mask2sensor self.sensor = sensor_dict[sensor] self.random_shift = random_shift + self.quantize = quantize # for convolution if psf is not None: @@ -162,13 +166,14 @@ def propagate(self, obj, return_object_plane=False): if self.snr_db is not None: image_plane = add_shot_noise(image_plane, snr_db=self.snr_db) - # 5) Quantize as on sensor - image_plane = image_plane / image_plane.max() - image_plane = image_plane * self.max_val - if torch.is_tensor(image_plane): - image_plane = image_plane.to(self.output_dtype) - else: - image_plane = image_plane.astype(self.output_dtype) + # 5) (Optionaly) Quantize as on sensor + if self.quantize: + image_plane = image_plane / image_plane.max() + image_plane = image_plane * self.max_val + if torch.is_tensor(image_plane): + image_plane = image_plane.to(self.output_dtype) + else: + image_plane = image_plane.astype(self.output_dtype) if return_object_plane: return image_plane, object_plane