Skip to content

Commit

Permalink
Add simulator (#5)
Browse files Browse the repository at this point in the history
* Add utils for simulation.

* Add utils for simulating dataset.

* Clean up simulators for custom and pytorch datasets.
  • Loading branch information
ebezzam authored Jan 6, 2023
1 parent a7acee6 commit 552d2fd
Show file tree
Hide file tree
Showing 6 changed files with 571 additions and 12 deletions.
8 changes: 4 additions & 4 deletions examples/square_ap_poly_video_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,21 @@
filenames = []
frames = []


def simulate(i):
u_out_wv, _, _ = angular_spectrum(
u_in=u_in * gain, wv=cs.wv[i], d1=d1, dz=dz, device=device
)
u_out_wv, _, _ = angular_spectrum(u_in=u_in * gain, wv=cs.wv[i], d1=d1, dz=dz, device=device)
if plot_int:
res = torch.real(u_out_wv * np.conjugate(u_out_wv))
else:
res = torch.abs(u_out_wv)
return res


for dz in dz_vals:
"""loop over wavelengths for simulation"""

start_time = time.time()

u_out = Parallel(n_jobs=n_jobs)(delayed(simulate)(i) for i in range(cs.n_wavelength))
u_out = torch.stack(u_out).permute(1, 2, 0)

Expand Down
151 changes: 147 additions & 4 deletions waveprop/dataset_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from PIL import Image
import os
import numpy as np
import glob
from waveprop.devices import sensor_dict
from waveprop.simulation import FarFieldSimulator
from abc import abstractmethod


class Datasets(object):
Expand All @@ -13,10 +17,149 @@ class Datasets(object):
FLICKR8k = "FLICKR"


# TODO : abstract parent class for Dataset, add distance for far-field propagation. See MNIST
# TODO : take into account FOV and offset


class SimulatedDataset(Dataset):
"""
Abstract class for simulated datasets.
"""

def __init__(
self,
transform_list=None,
psf=None,
target="original",
random_vflip=False,
random_hflip=False,
random_rotate=False,
**kwargs,
):
"""
Parameters
----------
transform_list : list of torchvision.transforms, optional
List of transforms to apply to image, by default None.
psf : np.ndarray, optional
Point spread function, by default None.
target : str, optional
Target to return, by default "original".
"original" : return propagated image and original image.
"object_plane" : return propagated image and object plane.
"label" : return propagated image and label.
random_vflip : float, optional
Probability of vertical flip, by default False.
random_hflip : float, optional
Probability of horizontal flip, by default False.
random_rotate : float, optional
Maximum angle of rotation, by default False.
"""

self.target = target

# random transforms
self._transform = None
if transform_list is None:
transform_list = []
if random_vflip:
transform_list.append(transforms.RandomVerticalFlip(p=random_vflip))
if random_hflip:
transform_list.append(transforms.RandomHorizontalFlip(p=random_hflip))
if random_rotate:
transform_list.append(transforms.RandomRotation(random_rotate))
if len(transform_list) > 0:
self._transform = transforms.Compose(transform_list)

# initialize simulator
if psf is not None:
if psf.shape[-1] <= 3:
raise ValueError("Channel dimension should not be last.")
self.sim = FarFieldSimulator(psf=psf, is_torch=True, **kwargs)

@abstractmethod
def get_image(self, index):
raise NotImplementedError

def __getitem__(self, index):

# load image
img, label = self.get_image(index)
if self._transform is not None:
img = self._transform(img)

# propagate and return with desired output
if self.target == "original":
return self.sim.propagate(img), img
elif self.target == "object_plane":
return self.sim.propagate(img, return_object_plane=True)
elif self.target == "label":
return self.sim.propagate(img), label

def __len__(self):
return self.n_files


class SimulatedDatasetFolder(SimulatedDataset):
"""
Dataset of propagated images from a folder of images.
"""

def __init__(self, path, image_ext="jpg", n_files=None, **kwargs):
"""
Parameters
----------
path : str
Path to folder of images.
image_ext : str, optional
Extension of images, by default "jpg".
n_files : int, optional
Number of files to load, by default load all.
"""

self.path = path
self._files = glob.glob(os.path.join(self.path, f"*.{image_ext}"))
if n_files is None:
self.n_files = len(self._files)
else:
self.n_files = n_files
self._files = self._files[:n_files]

# initialize simulator
super(SimulatedDatasetFolder, self).__init__(
transform_list=[transforms.ToTensor()], **kwargs
)

def get_image(self, index):
img = Image.open(self._files[index])
label = None
return img, label


class SimulatedPytorchDataset(SimulatedDataset):
"""
Dataset of propagated images from a torch Dataset.
"""

def __init__(self, dataset, **kwargs):
"""
Parameters
----------
dataset : torch.utils.data.Dataset
Dataset to propagate.
"""

assert isinstance(dataset, Dataset)
self.dataset = dataset
self.n_files = len(dataset)

# initialize simulator
super(SimulatedPytorchDataset, self).__init__(**kwargs)

def get_image(self, index):
return self.dataset[index]


class MNISTDataset(datasets.MNIST):
def __init__(
self,
Expand All @@ -34,7 +177,7 @@ def __init__(
vflip=True,
grayscale=True,
scale=(1, 1),
**kwargs
**kwargs,
):
"""
MNIST - 60'000 examples of 28x28
Expand Down Expand Up @@ -103,7 +246,7 @@ def __init__(
scale=(1, 1),
download=True,
vflip=True,
**kwargs
**kwargs,
):
"""
CIFAR10 - 50;000 examples of 32x32
Expand Down Expand Up @@ -172,7 +315,7 @@ def __init__(
grayscale=False,
device=None,
scale=(1, 1),
**kwargs
**kwargs,
):
"""
Flickr8k - varied, around 400x500
Expand Down
44 changes: 44 additions & 0 deletions waveprop/noise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
from scipy import ndimage
import torch


def add_shot_noise(image, snr_db, tol=1e-6):
"""
Add shot noise to image.
Parameters
----------
image : np.ndarray
Image.
snr_db : float
Signal-to-noise ratio in dB.
tol : float, optional
Tolerance for noise variance, by default 1e-6.
Returns
-------
np.ndarray
Image with added shot noise.
"""

if torch.is_tensor(image):
with torch.no_grad():
image_np = image.cpu().numpy()
else:
image_np = image

if image_np.min() < 0:
image_np -= image_np.min()
noise = np.random.poisson(image_np)

sig_var = ndimage.variance(image_np)
noise_var = np.maximum(ndimage.variance(noise), tol)
fact = np.sqrt(sig_var / noise_var / (10 ** (snr_db / 10)))

noise = fact * noise
if torch.is_tensor(image):
noise = torch.from_numpy(noise).to(image.device)

return image + fact * noise
61 changes: 61 additions & 0 deletions waveprop/pytorch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,64 @@ def fftconvolve(in1, in2, mode=None, axes=None):
return torch.complex(_real, _imag)
else:
return crop(ret, top=y_pad_edge, left=x_pad_edge, height=s1[axes[0]], width=s1[axes[1]])


class RealFFTConvolve2D:
def __init__(self, filter, mode=None, axes=(-2, -1), img_shape=None, device=None):
"""
Operator that performs convolution in Fourier domain, and assumes
real-valued signals. Useful if convolving with the same filter, i.e.
avoid computing FFT of same filter.
Assume arrays of shape [..., H, W], where ... means an arbitrary number of leading dimensions.
Parameters
----------
filter array_like
2D filter to use. Must be of shape (channels, height, width) even if
only one channel.
img_shape : tuple
If image different shape than filter, specify here.
dtype : float32 or float64
Data type to use for optimization.
"""
assert torch.is_tensor(filter)
if device is not None:
filter = filter.to(device)
self.device = device

self.filter_shape = filter.shape
if img_shape is None:
self.img_shape = filter.shape
else:
assert len(img_shape) == 3
self.img_shape = img_shape
if axes is None:
self.shape = [
self.filter_shape[i] + self.img_shape[i] - 1 for i in range(len(self.filter_shape))
]
else:
self.shape = [self.filter_shape[i] + self.img_shape[i] - 1 for i in axes]
self.axes = axes
if mode is not None:
if mode != "same":
raise ValueError(f"{mode} mode not supported ")

self.filter_freq = torch.fft.rfftn(filter, self.shape, dim=axes)

def __call__(self, x):
orig_device = x.device
if self.device is not None:
x = x.to(self.device)
x_freq = torch.fft.rfftn(x, self.shape, dim=self.axes)
ret = torch.fft.irfftn(self.filter_freq * x_freq, self.shape, dim=self.axes)

y_pad_edge = int((self.shape[0] - self.img_shape[self.axes[0]]) / 2)
x_pad_edge = int((self.shape[1] - self.img_shape[self.axes[1]]) / 2)
return crop(
ret,
top=y_pad_edge,
left=x_pad_edge,
height=self.img_shape[self.axes[0]],
width=self.img_shape[self.axes[1]],
).to(orig_device)
Loading

0 comments on commit 552d2fd

Please sign in to comment.