Skip to content

Commit

Permalink
Merge branch 'expose_sim_param' of github.com:LCAV/LenslessPiCam into…
Browse files Browse the repository at this point in the history
… expose_sim_param
  • Loading branch information
ebezzam committed Jul 21, 2024
2 parents ee1e221 + 937a59c commit d0f9f8e
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 48 deletions.
36 changes: 27 additions & 9 deletions configs/recon_digicam_mirflickr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,37 @@ defaults:
- defaults_recon
- _self_

cache_dir: /dev/shm

# fn: null # if not null, download this file from https://huggingface.co/datasets/bezzam/DigiCam-Mirflickr-SingleMask-25K/tree/main
# fn: raw_box.png
# rotate: False
# alignment:
# dim: [190, 260]
# top_left: [145, 130]

fn: raw_stuffed_animals.png
rotate: False
alignment:
dim: [200, 280]
top_left: [115, 120]


# - Learned reconstructions: see "lensless/recon/model_dict.py"
# model: U10
# model: Unet8M
# model: TrainInv+Unet8M
# model: U10+Unet8M
# model: Unet4M+TrainInv+Unet4M
# model: Unet4M+U10+Unet4M
### dataset: mirflickr_single_25k
# model: U5+Unet8M_wave
# model: Unet4M+U5+Unet4M_wave
# model: TrainInv+Unet8M_wave
# model: Unet4M+TrainInv+Unet4M_wave

# ## dataset: mirflickr_multi_25k
# model: Unet4M+U5+Unet4M_wave

# -- for ADMM with fixed parameters
model: admm
n_iter: 10
n_iter: 100

device: cuda:0
n_trials: 100 # more if you want to get average inference time
device: cuda:2
n_trials: 1 # more if you want to get average inference time
idx: 1 # index from test set to reconstruct
save: True
13 changes: 13 additions & 0 deletions lensless/recon/model_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"Unet2M+MWDN6M_wave": "bezzam/digicam-mirflickr-single-25k-unet2M-mwdn-6M",
"Unet4M+U5+Unet4M_wave_aux1": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm5-unet4M-wave-aux1",
"Unet4M+U5+Unet4M_wave_flips": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm5-unet4M-wave-flips",
"Unet4M+U5+Unet4M_wave_flips_rotate10": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm5-unet4M-wave-flips-rotate10",
# measured PSF
"Unet4M+U10+Unet4M_measured": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-measured",
# simulated PSF (with waveprop, no deadspace)
Expand All @@ -131,6 +132,7 @@
"Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M-wave",
"Unet4M+U5+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm5-unet4M-wave",
"Unet4M+U5+Unet4M_wave_aux1": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm5-unet4M-wave-aux1",
"Unet4M+U5+Unet4M_wave_flips": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm5-unet4M-wave-flips",
},
},
"tapecam": {
Expand All @@ -145,6 +147,8 @@
"Unet2M+MWDN6M": "bezzam/tapecam-mirflickr-unet2M-mwdn-6M",
"Unet4M+U10+Unet4M": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm10-unet4M",
"Unet4M+U5+Unet4M_flips": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M-flips",
"Unet4M+U5+Unet4M_flips_rotate10": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M-flips-rotate10",
"Unet4M+U5+Unet4M_aux1": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M-aux1",
},
},
}
Expand Down Expand Up @@ -206,6 +210,7 @@ def load_model(
verbose=True,
skip_pre=False,
skip_post=False,
train_last_layer=False,
):

"""
Expand Down Expand Up @@ -279,6 +284,14 @@ def load_model(
else False,
)

if train_last_layer:
for param in post_process.parameters():
for name, param in post_process.named_parameters():
if "m_tail" in name:
param.requires_grad = True
else:
param.requires_grad = False

if config["reconstruction"]["method"] == "unrolled_admm":
recon = UnrolledADMM(
psf if mask is None else psf_learned,
Expand Down
188 changes: 168 additions & 20 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from lensless.hardware.trainable_mask import prep_trainable_mask, AdafruitLCD
from lensless.utils.simulation import FarFieldSimulator
from lensless.utils.io import load_image, load_psf, save_image
from lensless.utils.image import is_grayscale, resize, rgb2gray, rotate_HWC
from lensless.utils.image import is_grayscale, resize, rgb2gray
import re
from lensless.hardware.utils import capture
from lensless.hardware.utils import display
Expand All @@ -30,6 +30,8 @@
from lensless.hardware.sensor import sensor_dict, SensorParam
from scipy.ndimage import rotate
import warnings
from waveprop.noise import add_shot_noise
from PIL import Image


def convert(text):
Expand Down Expand Up @@ -168,8 +170,6 @@ def __getitem__(self, idx):

# add noise
if self.input_snr is not None:
from waveprop.noise import add_shot_noise

lensless = add_shot_noise(lensless, self.input_snr)

# flip image x and y if needed
Expand Down Expand Up @@ -1031,6 +1031,12 @@ def __init__(
cache_dir=None,
single_channel_psf=False,
flipud=False,
display_res=None,
alignment=None,
sensor="rpi_hq",
slm="adafruit",
simulation_config=dict(),
snr_db=40,
**kwargs,
):
"""
Expand Down Expand Up @@ -1058,26 +1064,89 @@ def __init__(
# download PSF from huggingface
# TODO : assuming psf is not None
self.multimask = False
psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset")
psf, _ = load_psf(
psf_fp,
shape=self.lensless_shape,
return_float=True,
return_bg=True,
flip_ud=flipud,
bg_pix=(0, 15),
single_psf=single_channel_psf,
)
self.psf = torch.from_numpy(psf)
if single_channel_psf:
# replicate across three channels
self.psf = self.psf.repeat(1, 1, 1, 3)
self.convolver = None
if psf is not None:
psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset")
psf, _ = load_psf(
psf_fp,
shape=self.lensless_shape,
return_float=True,
return_bg=True,
flip_ud=flipud,
bg_pix=(0, 15),
single_psf=single_channel_psf,
)
self.psf = torch.from_numpy(psf)
if single_channel_psf:
# replicate across three channels
self.psf = self.psf.repeat(1, 1, 1, 3)

# create convolver object
self.convolver = RealFFTConvolve2D(self.psf)

elif "mask_label" in data_0:
self.multimask = True
mask_labels = []
for i in range(len(self.dataset)):
mask_labels.append(self.dataset[i]["mask_label"])
mask_labels = list(set(mask_labels))

# simulate all PSFs
self.psf = dict()
for label in mask_labels:
mask_fp = hf_hub_download(
repo_id=huggingface_repo,
filename=f"masks/mask_{label}.npy",
repo_type="dataset",
)
mask_vals = np.load(mask_fp)

if psf is None:
sensor_res = sensor_dict[sensor][SensorParam.RESOLUTION]
downsample_fact = min(sensor_res / lensless.shape[:2])
else:
downsample_fact = 1

mask = AdafruitLCD(
initial_vals=torch.from_numpy(mask_vals.astype(np.float32)),
sensor=sensor,
slm=slm,
downsample=downsample_fact,
flipud=rotate or flipud, # TODO separate commands?
use_waveprop=simulation_config.get("use_waveprop", False),
scene2mask=simulation_config.get("scene2mask", None),
mask2sensor=simulation_config.get("mask2sensor", None),
deadspace=simulation_config.get("deadspace", True),
)
self.psf[label] = mask.get_psf().detach()

assert (
self.psf[label].shape[-3:-1] == lensless.shape[:2]
), f"PSF shape should match lensless shape: PSF {self.psf[label].shape[-3:-1]} vs lensless {lensless.shape[:2]}"

# create convolver object
self.convolver = RealFFTConvolve2D(self.psf[label])
assert self.convolver is not None

# TODO create convolver object
self.convolver = RealFFTConvolve2D(self.psf)
self.crop = None
self.random_flip = None
self.flipud = flipud
self.snr_db = snr_db

self.display_res = display_res
self.alignment = None
self.cropped_lensed_shape = None
if alignment is not None:
self.alignment = dict(alignment.copy())
self.alignment["top_left"] = (
int(self.alignment["top_left"][0] / downsample),
int(self.alignment["top_left"][1] / downsample),
)
self.alignment["height"] = int(self.alignment["height"] / downsample)

original_aspect_ratio = display_res[1] / display_res[0]
self.alignment["width"] = int(self.alignment["height"] * original_aspect_ratio)
self.cropped_lensed_shape = (self.alignment["height"], self.alignment["width"], 3)

super(HFSimulated, self).__init__(**kwargs)

Expand All @@ -1099,21 +1168,100 @@ def _get_images_pair(self, idx):
lensed_np = lensed_np.astype(np.float32) / 65535

# resize if necessary
if (self.lensless_shape != np.array(lensed_np.shape[:2])).any():
if self.cropped_lensed_shape is not None:
cropped_lensed_np = resize(
lensed_np, shape=self.cropped_lensed_shape, interpolation=cv2.INTER_NEAREST
)
lensed_np = np.zeros(tuple(self.lensless_shape) + (3,), dtype=np.float32)
lensed_np[
self.alignment["top_left"][0] : self.alignment["top_left"][0]
+ self.alignment["height"],
self.alignment["top_left"][1] : self.alignment["top_left"][1]
+ self.alignment["width"],
] = cropped_lensed_np

elif (self.lensless_shape != np.array(lensed_np.shape[:2])).any():

lensed_np = resize(
lensed_np, shape=self.lensless_shape, interpolation=cv2.INTER_NEAREST
)
lensed = torch.from_numpy(lensed_np)

# simulate lensless with convolution
lensed = lensed.unsqueeze(0) # add batch dimension

if self.multimask:
mask_label = self.dataset[idx]["mask_label"]
self.convolver.set_psf(self.psf[mask_label])
lensless = self.convolver.convolve(lensed)

# add noise
if self.snr_db is not None:
lensless = add_shot_noise(lensless, self.snr_db)

if lensless.max() > 1:
print("CLIPPING!")
lensless /= lensless.max()

if self.cropped_lensed_shape:
return lensless, torch.from_numpy(cropped_lensed_np)
else:
return lensless, lensed

def __getitem__(self, idx):
lensless, lensed = super().__getitem__(idx)
if self.multimask:
mask_label = self.dataset[idx]["mask_label"]
return lensless, lensed, self.psf[mask_label]
return lensless, lensed

def extract_roi(self, reconstruction, lensed=None, axis=(1, 2), **kwargs):
"""
Extract region of interest from lensless and lensed images.
"""

n_dim = len(reconstruction.shape)
assert max(axis) < n_dim, "Axis should be within the dimensions of the reconstruction."

# add batch dimension
if n_dim == 3:
if isinstance(reconstruction, torch.Tensor):
reconstruction = reconstruction.unsqueeze(0)
else:
reconstruction = reconstruction[np.newaxis]
# increment axis
axis = (axis[0] + 1, axis[1] + 1)

# extract
if self.alignment is not None:
top_left = self.alignment["top_left"]
height = self.alignment["height"]
width = self.alignment["width"]

# extract according to axis
index = [slice(None)] * n_dim
index[axis[0]] = slice(top_left[0], top_left[0] + height)
index[axis[1]] = slice(top_left[1], top_left[1] + width)
reconstruction = reconstruction[tuple(index)]

# rotate if necessary
angle = self.alignment.get("angle", 0)
if isinstance(reconstruction, torch.Tensor) and angle:
reconstruction = F.rotate(reconstruction, angle, expand=False)
elif angle:
reconstruction = rotate(reconstruction, angle, axes=axis, reshape=False)

# remove batch dimension
if n_dim == 3:
if isinstance(reconstruction, torch.Tensor):
reconstruction = reconstruction.squeeze(0)
else:
reconstruction = reconstruction[0]

if lensed is None:
return reconstruction
return reconstruction, lensed


class HFDataset(DualDataset):
def __init__(
Expand Down
Loading

0 comments on commit d0f9f8e

Please sign in to comment.