diff --git a/examples/square_ap_poly_video_gpu.py b/examples/square_ap_poly_video_gpu.py
index fed8ce8..c5ae996 100644
--- a/examples/square_ap_poly_video_gpu.py
+++ b/examples/square_ap_poly_video_gpu.py
@@ -53,21 +53,21 @@
 filenames = []
 frames = []
 
+
 def simulate(i):
-    u_out_wv, _, _ = angular_spectrum(
-        u_in=u_in * gain, wv=cs.wv[i], d1=d1, dz=dz, device=device
-    )
+    u_out_wv, _, _ = angular_spectrum(u_in=u_in * gain, wv=cs.wv[i], d1=d1, dz=dz, device=device)
     if plot_int:
         res = torch.real(u_out_wv * np.conjugate(u_out_wv))
     else:
         res = torch.abs(u_out_wv)
     return res
 
+
 for dz in dz_vals:
     """loop over wavelengths for simulation"""
     start_time = time.time()
-
+
     u_out = Parallel(n_jobs=n_jobs)(delayed(simulate)(i) for i in range(cs.n_wavelength))
     u_out = torch.stack(u_out).permute(1, 2, 0)
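Note on the hunk above: each wavelength is simulated independently, so the loop parallelizes cleanly with joblib, and the per-wavelength maps are stacked into an (H, W, n_wavelength) cube. A minimal self-contained sketch of the same pattern, with a dummy propagator standing in for `angular_spectrum` and hypothetical shapes:

```python
import cmath

import numpy as np
import torch
from joblib import Parallel, delayed

wv = np.array([460e-9, 550e-9, 640e-9])  # hypothetical wavelengths [m]
u_in = torch.randn(512, 512, dtype=torch.complex64)  # hypothetical input field


def simulate(i):
    # dummy stand-in for angular_spectrum(u_in=..., wv=wv[i], d1=..., dz=..., device=...)
    u_out_wv = u_in * cmath.exp(2j * cmath.pi * 1e-9 / wv[i])
    return torch.abs(u_out_wv)


# one job per wavelength, as in the hunk above
u_out = Parallel(n_jobs=3)(delayed(simulate)(i) for i in range(len(wv)))
u_out = torch.stack(u_out).permute(1, 2, 0)
print(u_out.shape)  # torch.Size([512, 512, 3])
```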
+ """ + + self.target = target + + # random transforms + self._transform = None + if transform_list is None: + transform_list = [] + if random_vflip: + transform_list.append(transforms.RandomVerticalFlip(p=random_vflip)) + if random_hflip: + transform_list.append(transforms.RandomHorizontalFlip(p=random_hflip)) + if random_rotate: + transform_list.append(transforms.RandomRotation(random_rotate)) + if len(transform_list) > 0: + self._transform = transforms.Compose(transform_list) + + # initialize simulator + if psf is not None: + if psf.shape[-1] <= 3: + raise ValueError("Channel dimension should not be last.") + self.sim = FarFieldSimulator(psf=psf, is_torch=True, **kwargs) + + @abstractmethod + def get_image(self, index): + raise NotImplementedError + + def __getitem__(self, index): + + # load image + img, label = self.get_image(index) + if self._transform is not None: + img = self._transform(img) + + # propagate and return with desired output + if self.target == "original": + return self.sim.propagate(img), img + elif self.target == "object_plane": + return self.sim.propagate(img, return_object_plane=True) + elif self.target == "label": + return self.sim.propagate(img), label + + def __len__(self): + return self.n_files + + +class SimulatedDatasetFolder(SimulatedDataset): + """ + Dataset of propagated images from a folder of images. + """ + + def __init__(self, path, image_ext="jpg", n_files=None, **kwargs): + """ + Parameters + ---------- + path : str + Path to folder of images. + image_ext : str, optional + Extension of images, by default "jpg". + n_files : int, optional + Number of files to load, by default load all. + """ + + self.path = path + self._files = glob.glob(os.path.join(self.path, f"*.{image_ext}")) + if n_files is None: + self.n_files = len(self._files) + else: + self.n_files = n_files + self._files = self._files[:n_files] + + # initialize simulator + super(SimulatedDatasetFolder, self).__init__( + transform_list=[transforms.ToTensor()], **kwargs + ) + + def get_image(self, index): + img = Image.open(self._files[index]) + label = None + return img, label + + +class SimulatedPytorchDataset(SimulatedDataset): + """ + Dataset of propagated images from a torch Dataset. + """ + + def __init__(self, dataset, **kwargs): + """ + Parameters + ---------- + dataset : torch.utils.data.Dataset + Dataset to propagate. + """ + + assert isinstance(dataset, Dataset) + self.dataset = dataset + self.n_files = len(dataset) + + # initialize simulator + super(SimulatedPytorchDataset, self).__init__(**kwargs) + + def get_image(self, index): + return self.dataset[index] + + class MNISTDataset(datasets.MNIST): def __init__( self, @@ -34,7 +177,7 @@ def __init__( vflip=True, grayscale=True, scale=(1, 1), - **kwargs + **kwargs, ): """ MNIST - 60'000 examples of 28x28 @@ -103,7 +246,7 @@ def __init__( scale=(1, 1), download=True, vflip=True, - **kwargs + **kwargs, ): """ CIFAR10 - 50;000 examples of 32x32 @@ -172,7 +315,7 @@ def __init__( grayscale=False, device=None, scale=(1, 1), - **kwargs + **kwargs, ): """ Flickr8k - varied, around 400x500 diff --git a/waveprop/noise.py b/waveprop/noise.py new file mode 100644 index 0000000..d57b6e8 --- /dev/null +++ b/waveprop/noise.py @@ -0,0 +1,44 @@ +import numpy as np +from scipy import ndimage +import torch + + +def add_shot_noise(image, snr_db, tol=1e-6): + """ + Add shot noise to image. + + Parameters + ---------- + image : np.ndarray + Image. + snr_db : float + Signal-to-noise ratio in dB. 
diff --git a/waveprop/noise.py b/waveprop/noise.py
new file mode 100644
index 0000000..d57b6e8
--- /dev/null
+++ b/waveprop/noise.py
@@ -0,0 +1,44 @@
+import numpy as np
+from scipy import ndimage
+import torch
+
+
+def add_shot_noise(image, snr_db, tol=1e-6):
+    """
+    Add shot noise to image.
+
+    Parameters
+    ----------
+    image : np.ndarray or torch.Tensor
+        Image.
+    snr_db : float
+        Signal-to-noise ratio in dB.
+    tol : float, optional
+        Tolerance for noise variance, by default 1e-6.
+
+    Returns
+    -------
+    np.ndarray or torch.Tensor
+        Image with added shot noise.
+
+    """
+
+    if torch.is_tensor(image):
+        with torch.no_grad():
+            image_np = image.cpu().numpy()
+    else:
+        image_np = image
+
+    # Poisson sampling requires non-negative values
+    if image_np.min() < 0:
+        image_np = image_np - image_np.min()
+    noise = np.random.poisson(image_np)
+
+    # scale noise term to reach the requested SNR
+    sig_var = ndimage.variance(image_np)
+    noise_var = np.maximum(ndimage.variance(noise), tol)
+    fact = np.sqrt(sig_var / noise_var / (10 ** (snr_db / 10)))
+
+    noise = fact * noise
+    if torch.is_tensor(image):
+        noise = torch.from_numpy(noise).to(image.device)
+
+    return image + noise
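The scale factor above is chosen so that `sig_var / var(fact * noise)` equals `10 ** (snr_db / 10)`, i.e. the variance ratio between the signal and the added noise term matches the requested SNR. A quick check of that relationship (image content and SNR value are arbitrary):

```python
import numpy as np
from scipy import ndimage
from waveprop.noise import add_shot_noise

img = np.random.rand(64, 64) * 255  # hypothetical image in [0, 255]
noisy = add_shot_noise(img, snr_db=20)

# empirical SNR between the clean image and the added noise term
noise_term = noisy - img
snr_est = 10 * np.log10(ndimage.variance(img) / ndimage.variance(noise_term))
print(f"requested 20.0 dB, measured {snr_est:.1f} dB")
```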
diff --git a/waveprop/pytorch_util.py b/waveprop/pytorch_util.py
index c5e7391..b2176ef 100644
--- a/waveprop/pytorch_util.py
+++ b/waveprop/pytorch_util.py
@@ -72,3 +72,64 @@ def fftconvolve(in1, in2, mode=None, axes=None):
             return torch.complex(_real, _imag)
     else:
         return crop(ret, top=y_pad_edge, left=x_pad_edge, height=s1[axes[0]], width=s1[axes[1]])
+
+
+class RealFFTConvolve2D:
+    def __init__(self, filter, mode=None, axes=(-2, -1), img_shape=None, device=None):
+        """
+        Operator that performs convolution in the Fourier domain and assumes
+        real-valued signals. Useful when repeatedly convolving with the same
+        filter, as the filter's FFT is computed only once.
+
+        Assumes arrays of shape [..., H, W], where ... means an arbitrary
+        number of leading dimensions.
+
+        Parameters
+        ----------
+        filter : array_like
+            2D filter to use. Must be of shape (channels, height, width) even if
+            only one channel.
+        mode : str, optional
+            Only "same" is supported.
+        axes : tuple, optional
+            Axes along which to convolve, by default (-2, -1).
+        img_shape : tuple
+            If images have a different shape than the filter, specify it here.
+        device : str, optional
+            Device on which to perform the convolution.
+        """
+        assert torch.is_tensor(filter)
+        if device is not None:
+            filter = filter.to(device)
+        self.device = device
+
+        self.filter_shape = filter.shape
+        if img_shape is None:
+            self.img_shape = filter.shape
+        else:
+            assert len(img_shape) == 3
+            self.img_shape = img_shape
+        if axes is None:
+            self.shape = [
+                self.filter_shape[i] + self.img_shape[i] - 1 for i in range(len(self.filter_shape))
+            ]
+        else:
+            self.shape = [self.filter_shape[i] + self.img_shape[i] - 1 for i in axes]
+        self.axes = axes
+        if mode is not None and mode != "same":
+            raise ValueError(f"{mode} mode not supported")
+
+        self.filter_freq = torch.fft.rfftn(filter, self.shape, dim=axes)
+
+    def __call__(self, x):
+        orig_device = x.device
+        if self.device is not None:
+            x = x.to(self.device)
+        x_freq = torch.fft.rfftn(x, self.shape, dim=self.axes)
+        ret = torch.fft.irfftn(self.filter_freq * x_freq, self.shape, dim=self.axes)
+
+        # crop to "same" output size
+        y_pad_edge = int((self.shape[0] - self.img_shape[self.axes[0]]) / 2)
+        x_pad_edge = int((self.shape[1] - self.img_shape[self.axes[1]]) / 2)
+        return crop(
+            ret,
+            top=y_pad_edge,
+            left=x_pad_edge,
+            height=self.img_shape[self.axes[0]],
+            width=self.img_shape[self.axes[1]],
+        ).to(orig_device)
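Since the filter spectrum is cached at construction, repeated calls only pay for the image-side FFTs. A usage sketch with hypothetical shapes, checked against scipy's "same"-mode convolution:

```python
import numpy as np
import torch
from scipy.signal import fftconvolve
from waveprop.pytorch_util import RealFFTConvolve2D

psf = torch.rand(1, 32, 32)    # (channels, height, width)
conv = RealFFTConvolve2D(psf)  # filter FFT computed once, here

img = torch.rand(1, 32, 32)
out = conv(img)                # "same"-size output: (1, 32, 32)

ref = fftconvolve(img[0].numpy(), psf[0].numpy(), mode="same")
print(np.allclose(out[0].numpy(), ref, atol=1e-4))  # True
```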
diff --git a/waveprop/simulation.py b/waveprop/simulation.py
new file mode 100644
index 0000000..394c57e
--- /dev/null
+++ b/waveprop/simulation.py
@@ -0,0 +1,180 @@
+import numpy as np
+from waveprop.util import prepare_object_plane, resize
+from torchvision.transforms.functional import resize as resize_torch
+from waveprop.devices import sensor_dict, SensorParam
+from waveprop.noise import add_shot_noise
+import torch
+from waveprop.pytorch_util import RealFFTConvolve2D
+import warnings
+
+
+class FarFieldSimulator(object):
+    """
+    Simulate far-field propagation with the following steps:
+    1. Resize digital image for desired object height and to PSF resolution.
+    2. Convolve with PSF.
+    3. (Optionally) Resize to lower sensor resolution.
+    4. (Optionally) Add shot noise.
+    5. Quantize.
+
+    Images and PSFs should have one of the following shapes:
+    - For numpy arrays: (H, W) for grayscale and (H, W, 3) for RGB.
+    - For PyTorch tensors: (..., H, W).
+    """
+
+    def __init__(
+        self,
+        object_height,
+        scene2mask,
+        mask2sensor,
+        sensor,
+        psf=None,
+        output_dim=None,
+        snr_db=None,
+        max_val=255,
+        device_conv="cpu",
+        random_shift=False,
+        is_torch=False,
+        **kwargs,
+    ):
+        """
+        Parameters
+        ----------
+        object_height : float or (float, float)
+            Height of object in meters, or a range of values to randomly sample from.
+        scene2mask : float
+            Distance from scene to mask in meters.
+        mask2sensor : float
+            Distance from mask to sensor in meters.
+        sensor : str
+            Sensor name.
+        psf : np.ndarray, optional
+            Point spread function. If not provided, return image at object plane.
+        output_dim : tuple, optional
+            Dimension to resize to at the sensor, by default None (keep PSF resolution).
+        snr_db : float, optional
+            Signal-to-noise ratio in dB, by default None.
+        max_val : int, optional
+            Maximum value of image, by default 255.
+        device_conv : str, optional
+            Device to use for convolution (when using PyTorch), by default "cpu".
+        random_shift : bool, optional
+            Whether to randomly shift the image, by default False.
+        is_torch : bool, optional
+            Whether to use PyTorch, by default False.
+        """
+        if is_torch:
+            self.axes = (-2, -1)
+            output_dtype = torch.uint8
+        else:
+            self.axes = (0, 1)
+            output_dtype = np.uint8
+        self.is_torch = is_torch
+
+        # for resizing
+        self.object_height = object_height
+        self.scene2mask = scene2mask
+        self.mask2sensor = mask2sensor
+        self.sensor = sensor_dict[sensor]
+        self.random_shift = random_shift
+
+        # for convolution
+        if psf is not None:
+            self.conv_dim = np.array([psf.shape[_ax] for _ax in self.axes])
+            self.fft_shape = 2 * np.array(self.conv_dim) - 1
+            if torch.is_tensor(psf):
+                self.conv = RealFFTConvolve2D(psf, device=device_conv)
+            else:
+                self.H = np.fft.rfft2(psf, s=self.fft_shape, axes=self.axes)
+                # -- for removing padding
+                self.y_pad_edge = int(
+                    (self.fft_shape[self.axes[0]] - self.conv_dim[self.axes[0]]) / 2
+                )
+                self.x_pad_edge = int(
+                    (self.fft_shape[self.axes[1]] - self.conv_dim[self.axes[1]]) / 2
+                )
+
+            # at sensor
+            self.output_dim = output_dim
+            self.snr_db = snr_db
+            self.max_val = max_val
+            self.output_dtype = output_dtype
+
+        else:
+            # simply return object / scene plane
+            warnings.warn("No PSF provided. Returning image at object plane.")
+            self.fft_shape = None
+            assert output_dim is not None
+            self.conv_dim = np.array(output_dim)
+
+    def propagate(self, obj, return_object_plane=False):
+        """
+        Parameters
+        ----------
+        obj : np.ndarray or torch.Tensor
+            Object to propagate.
+        return_object_plane : bool, optional
+            Whether to return object plane, by default False.
+        """
+
+        if self.is_torch:
+            assert torch.is_tensor(obj)
+
+        # 1) Resize image to PSF dimensions while keeping aspect ratio and
+        # setting object height to desired value.
+        if hasattr(self.object_height, "__len__"):
+            object_height = np.random.uniform(
+                low=self.object_height[0], high=self.object_height[1]
+            )
+        else:
+            object_height = self.object_height
+
+        object_plane = prepare_object_plane(
+            obj=obj,
+            object_height=object_height,
+            scene2mask=self.scene2mask,
+            mask2sensor=self.mask2sensor,
+            sensor_size=self.sensor[SensorParam.SIZE],
+            sensor_dim=self.conv_dim,
+            random_shift=self.random_shift,
+        )
+
+        if self.fft_shape is not None:
+            # 2) Convolve with PSF
+            if torch.is_tensor(object_plane):
+                image_plane = self.conv(object_plane)
+            else:
+                I = np.fft.rfft2(object_plane, s=self.fft_shape, axes=self.axes)
+                image_plane = np.fft.irfft2(self.H * I, s=self.fft_shape, axes=self.axes)
+                image_plane = image_plane[
+                    self.y_pad_edge : self.y_pad_edge + self.conv_dim[0],
+                    self.x_pad_edge : self.x_pad_edge + self.conv_dim[1],
+                ]
+
+            # 3) (Optionally) Downsample to sensor size
+            if self.output_dim is not None:
+                if torch.is_tensor(obj):
+                    image_plane = resize_torch(image_plane, size=self.output_dim)
+                else:
+                    image_plane = resize(image_plane, shape=self.output_dim)
+
+            # 4) (Optionally) Add shot noise
+            if self.snr_db is not None:
+                image_plane = add_shot_noise(image_plane, snr_db=self.snr_db)
+
+            # 5) Quantize as on sensor
+            image_plane /= image_plane.max()
+            image_plane *= self.max_val
+            if torch.is_tensor(image_plane):
+                image_plane = image_plane.to(self.output_dtype)
+            else:
+                image_plane = image_plane.astype(self.output_dtype)
+
+            if return_object_plane:
+                return image_plane, object_plane
+            else:
+                return image_plane
+
+        else:
+            # return object plane for simulation with PSF at a different stage
+            return object_plane
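A hypothetical end-to-end use of the simulator; the distances, the random PSF, and the "rpi_hq" sensor key are assumptions for illustration (any key in `waveprop.devices.sensor_dict` works):

```python
import numpy as np
from waveprop.simulation import FarFieldSimulator

psf = np.random.rand(256, 256)  # hypothetical grayscale PSF

simulator = FarFieldSimulator(
    object_height=0.2,   # 20 cm tall object
    scene2mask=0.4,      # 40 cm from scene to mask
    mask2sensor=4e-3,    # 4 mm from mask to sensor
    sensor="rpi_hq",     # assumed key in waveprop.devices.sensor_dict
    psf=psf,
    snr_db=30,
)

obj = np.random.rand(28, 28)  # hypothetical input image
image_plane, object_plane = simulator.propagate(obj, return_object_plane=True)
print(image_plane.shape, image_plane.dtype)  # PSF resolution, uint8
```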
+ """ + image_shape = np.array(image.shape) + + fft_shape = image_shape + np.array(kernel.shape) - 1 + + H = np.fft.rfft2(kernel, s=fft_shape) + I = np.fft.rfft2(image, s=fft_shape) + output = np.fft.irfft2(H * I, s=fft_shape) + + # crop out zero padding + y_pad_edge = int((fft_shape[0] - image_shape[0]) / 2) + x_pad_edge = int((fft_shape[1] - image_shape[1]) / 2) + output = output[ + y_pad_edge : y_pad_edge + image_shape[0], x_pad_edge : x_pad_edge + image_shape[1] + ] + return output + + +def prepare_object_plane( + obj, + object_height, + scene2mask, + mask2sensor, + sensor_size, + sensor_dim, + random_shift=False, +): + """ + Prepare object plane for convolution with PSF. + + Parameters + ---------- + obj : np.ndarray + Input image (HxWx3). + object_height : float + Height of object plane in meters. + scene2mask : float + Distance from scene to mask in meters. + mask2sensor : float + Distance from mask to sensor in meters. + sensor_size : tuple + Size of sensor in meters. + sensor_dim : tuple + Dimension of sensor in pixels. + random_shift : bool + Randomly shift resized obj in its plane. + + Returns + ------- + np.ndarray + Object plane. + """ + if torch.is_tensor(obj): + axes = (-2, -1) + else: + axes = (0, 1) + + # determine object height in pixels + input_dim = np.array([obj.shape[_ax] for _ax in axes]) + magnification = mask2sensor / scene2mask + scene_dim = np.array(sensor_size) / magnification + object_height_pix = int(np.round(object_height / scene_dim[1] * sensor_dim[1])) + scaling = object_height_pix / input_dim[1] + object_dim = tuple((np.round(input_dim * scaling)).astype(int)) + + if torch.is_tensor(obj): + object_plane = resize_torch(obj, size=object_dim) + else: + object_plane = resize(obj, shape=object_dim) + + # pad object plane to convolution size + padding = sensor_dim - object_dim + left = padding[1] // 2 + right = padding[1] - left + top = padding[0] // 2 + bottom = padding[0] - top + + if top < 0: + top = 0 + bottom = 0 + if left < 0: + left = 0 + right = 0 + + if torch.is_tensor(obj): + object_plane = torch.nn.functional.pad( + object_plane, pad=(left, right, top, bottom), mode="constant", value=0.0 + ) + + object_plane_shape = np.array(object_plane.shape[-2:]) + + else: + pad_width = [(0, 0) for _ in range(len(obj.shape))] + pad_width[axes[0]] = (top, bottom) + pad_width[axes[1]] = (left, right) + pad_width = tuple(pad_width) + object_plane = np.pad(object_plane, pad_width=pad_width, mode="constant") + + object_plane_shape = np.array(object_plane.shape[:2]) + + # remove extra pixels if height extended beyond sensor + if (object_plane_shape != sensor_dim).any(): + object_plane = crop(object_plane, shape=sensor_dim) + + if random_shift: + hshift = int(np.random.uniform(low=-left, high=right)) + vshift = int(np.random.uniform(low=-bottom, high=top)) + if torch.is_tensor(obj): + object_plane = torch.roll(object_plane, shifts=(vshift, hshift), dims=axes) + else: + object_plane = np.roll(object_plane, shift=hshift, axis=axes[1]) + object_plane = np.roll(object_plane, shift=vshift, axis=axes[0]) + + return object_plane