diff --git a/configs/benchmark_hyperspectral.yaml b/configs/benchmark_hyperspectral.yaml new file mode 100644 index 00000000..53671cf4 --- /dev/null +++ b/configs/benchmark_hyperspectral.yaml @@ -0,0 +1,114 @@ +# python scripts/eval/benchmark_recon.py +#Hydra config +hydra: + run: + dir: "benchmark/${now:%Y-%m-%d}/${now:%H-%M-%S}" + job: + chdir: True + + +dataset: PolarLitis # DiffuserCam, DigiCamCelebA, HFDataset +seed: 0 +batchsize: 1 # must be 1 for iterative approaches + +huggingface: + repo: "noakraicer/polarlitis" + cache_dir: null # where to read/write dataset. Defaults to `"~/.cache/huggingface/datasets"`. + psf: psf.mat + mask: mask.npy # null for simulating PSF + image_res: [250, 250] # used during measurement + rotate: False # if measurement is upside-down + flipud: False + flip_lensed: False # if rotate or flipud is True, apply to lensed + + alignment: + top_left: null + height: null + + downsample: 1 + downsample_lensed: 2 + split_seed: null + single_channel_psf: True + +device: "cuda" +# numbers of iterations to benchmark +n_iter_range: [2000] +# number of files to benchmark +n_files: null # null for all files +#How much should the image be downsampled +downsample: 2 +#algorithm to benchmark +algorithms: ["HyperSpectralFISTA"] #["ADMM", "ADMM_Monakhova2019", "FISTA", "GradientDescent", "NesterovGradientDescent"] + +# baseline from Monakhova et al. 2019, https://arxiv.org/abs/1908.11502 +baseline: "MONAKHOVA 100iter" + +save_idx: [0, 1, 2, 3, 4] # provide index of files to save e.g. [1, 5, 10] +gamma_psf: 1.5 # gamma factor for PSF + + +# Hyperparameters +nesterov: + p: 0 + mu: 0.9 +fista: + tk: 1 +admm: + mu1: 1e-6 + mu2: 1e-5 + mu3: 4e-5 + tau: 0.0001 + + +# for DigiCamCelebA +files: + test_size: 0.15 + downsample: 1 + celeba_root: /scratch/bezzam + + + # dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K + # psf: data/psf/adafruit_random_2mm_20231907.png + # vertical_shift: null + # horizontal_shift: null + # crop: null + + dataset: /scratch/bezzam/celeba/celeba_adafruit_random_30cm_2mm_20231004_26K + psf: rpi_hq_adafruit_psf_2mm/raw_data_rgb.png + vertical_shift: -117 + horizontal_shift: -25 + crop: + vertical: [0, 525] + horizontal: [265, 695] + +# for prepping ground truth data +#for simulated dataset +simulation: + grayscale: False + output_dim: null # should be set if no PSF is used + # random variations + object_height: 0.33 # [m], range for random height or scalar + flip: True # change the orientation of the object (from vertical to horizontal) + random_shift: False + random_vflip: 0.5 + random_hflip: 0.5 + random_rotate: False + # these distance parameters are typically fixed for a given PSF + # for DiffuserCam psf # for tape_rgb psf + # scene2mask: 10e-2 # scene2mask: 40e-2 + # mask2sensor: 9e-3 # mask2sensor: 4e-3 + # -- for CelebA + scene2mask: 0.25 # [m] + mask2sensor: 0.002 # [m] + deadspace: True # whether to account for deadspace for programmable mask + # see waveprop.devices + use_waveprop: False # for PSF simulation + sensor: "rpi_hq" + snr_db: 10 + # simulate different sensor resolution + # output_dim: [24, 32] # [H, W] or null + # Downsampling for PSF + downsample: 8 + # max val in simulated measured (quantized 8 bits) + quantize: False # must be False for differentiability + max_val: 255 diff --git a/lensless/__init__.py b/lensless/__init__.py index 70990774..4d67f179 100644 --- a/lensless/__init__.py +++ b/lensless/__init__.py @@ -20,6 +20,7 @@ NesterovGradientDescent, FISTA, GradientDescentUpdate, + HyperSpectralFISTA ) from .recon.tikhonov import CodedApertureReconstruction from .hardware.sensor import VirtualSensor, SensorOptions diff --git a/lensless/recon/gd.py b/lensless/recon/gd.py index dc61e809..b7b69e98 100644 --- a/lensless/recon/gd.py +++ b/lensless/recon/gd.py @@ -238,6 +238,77 @@ def _update(self, iter): self._xk = xk +class HyperSpectralFISTA(FISTA): + """ + Applying HyperSpectral FISTA as in: https://github.com/Waller-Lab/SpectralDiffuserCam + + """ + + def __init__(self, psf, mask, **kwargs): + """ + + Parameters + ---------- + mask : + Hyperspectral mask + + """ + # same PSF for all hyperspectral channels + assert psf.shape[-1] == 1 + assert mask.shape[-3:-1] == psf.shape[-3:-1] + self._mask = mask[None, ...] # adding batch dimension + + super(HyperSpectralFISTA, self).__init__(psf, **kwargs) + + def reset(self): + + # TODO set lipschitz constant correctly/differently? + + if self.is_torch: + if self._initial_est is not None: + self._image_est = self._initial_est + else: + # initial guess, half intensity image + psf_flat = self._psf.reshape(-1, self._psf_shape[3]) + pixel_start = ( + torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values + ) / 2 + # initialize image estimate as [Batch, Depth, Height, Width, Channels] + self._image_est = torch.ones_like(self._mask) * pixel_start + + # set step size as < 2 / lipschitz + Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3]) + H_flat = self._convolver._H.reshape(-1, self._psf_shape[3]) + self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values) + + else: + if self._initial_est is not None: + self._image_est = self._initial_est + else: + psf_flat = self._psf.reshape(-1, self._psf_shape[3]) + pixel_start = (np.max(psf_flat, axis=0) + np.min(psf_flat, axis=0)) / 2 + # initialize image estimate as [Batch, Depth, Height, Width, Channels] + self._image_est = np.ones_like(self._mask) * pixel_start + + # set step size as < 2 / lipschitz + Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3]) + H_flat = self._convolver._H.reshape(-1, self._psf_shape[3]) + self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0)) + + # # TODO how was his value determined? + # self._alpha = 1 / 4770.13 + + def _grad(self): + # make sure to sum on correct axis, and apply mask on correct dimensions + diff = ( + np.sum(self._mask * self._convolver.convolve(self._image_est), -1, keepdims=True) + - self._data + ) # (B, D, H, W, 1) + return self._convolver.deconvolve( + diff * self._mask + ) # (H, W, C) where C is number of hyperspectral channels + + def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs): # load data diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index ff1fc55c..5f5ce924 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -372,7 +372,6 @@ def set_data(self, data): assert np.all( self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1] ), "PSF and data shape mismatch" - if len(data.shape) == 3: self._data = data[None, None, ...] elif len(data.shape) == 4: @@ -569,6 +568,7 @@ def apply( for i in range(n_iter): self._update(i) + if self.compensation_branch is not None and i < self._n_iter - 1: self.compensation_branch_inputs.append(self._form_image()) diff --git a/lensless/recon/rfft_convolve.py b/lensless/recon/rfft_convolve.py index e7b9be74..ea057971 100644 --- a/lensless/recon/rfft_convolve.py +++ b/lensless/recon/rfft_convolve.py @@ -24,7 +24,7 @@ class RealFFTConvolve2D: - def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs): + def __init__(self, psf, dtype=None, pad=True, norm=None, rgb=None, **kwargs): """ Linear operator that performs convolution in Fourier domain, and assumes real-valued signals. @@ -82,18 +82,24 @@ def _crop(self, x): ] def _pad(self, v): + + shape = self._padded_shape.copy() + if v.shape[-1] != self._padded_shape[-1]: + # different number of channels in PSF and data + assert v.shape[-1] == 1 or self._padded_shape[-1] == 1 + shape[-1] = v.shape[-1] + if len(v.shape) == 5: batch_size = v.shape[0] - shape = [batch_size] + self._padded_shape - elif len(v.shape) == 4: - shape = self._padded_shape - else: + shape = [batch_size] + shape + elif len(v.shape) != 4: raise ValueError("Expected 4D or 5D tensor") if self.is_torch: vpad = torch.zeros(size=shape, dtype=v.dtype, device=v.device) else: vpad = np.zeros(shape).astype(v.dtype) + vpad[ ..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], : ] = v @@ -138,7 +144,7 @@ def convolve(self, x): self._padded_data = self._pad(x) else: if self.is_torch: - self._padded_data = x # .type(self.dtype).to(self._psf.device) + self._padded_data = x else: self._padded_data[:] = x # .astype(self.dtype) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 5ce95f7a..d58eca05 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -17,7 +17,7 @@ from torchvision.transforms import functional as F from lensless.hardware.trainable_mask import prep_trainable_mask, AdafruitLCD from lensless.utils.simulation import FarFieldSimulator -from lensless.utils.io import load_image, load_psf, save_image +from lensless.utils.io import load_image, load_psf, save_image,load_mask from lensless.utils.image import is_grayscale, resize, rgb2gray import re from lensless.hardware.utils import capture @@ -1271,6 +1271,7 @@ def __init__( split, n_files=None, psf=None, + mask=None, rotate=False, # just the lensless image flipud=False, flip_lensed=False, @@ -1409,11 +1410,11 @@ def __init__( if psf is not None: # download PSF from huggingface psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset") - psf, _ = load_psf( + psf = load_psf( psf_fp, shape=lensless.shape, return_float=True, - return_bg=True, + return_bg=False, flip=self.rotate, flip_ud=flipud, bg_pix=(0, 15), @@ -1424,6 +1425,10 @@ def __init__( if single_channel_psf: # replicate across three channels self.psf = self.psf.repeat(1, 1, 1, 3) + if mask is not None: + mask_fp = hf_hub_download(repo_id=huggingface_repo, filename=mask, repo_type="dataset") + mask = load_mask(mask_fp) + self.mask= torch.from_numpy(mask) elif "mask_label" in data_0: self.multimask = True @@ -1563,7 +1568,9 @@ def _get_images_pair(self, idx): # convert to float if lensless_np.dtype == np.uint8: lensless_np = lensless_np.astype(np.float32) / 255 + lensless_np = lensless_np / np.max(lensless_np) lensed_np = lensed_np.astype(np.float32) / 255 + lensed_np = lensed_np / np.max(lensed_np) else: # 16 bit lensless_np = lensless_np.astype(np.float32) / 65535 diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 5596befd..1d505193 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -8,7 +8,7 @@ import os.path import warnings - +import scipy import cv2 import numpy as np from PIL import Image @@ -18,6 +18,11 @@ from lensless.utils.plot import plot_image +def load_mask(fp): + mask = np.load(fp) + return np.expand_dims(mask, axis=0) + + def load_image( fp, verbose=False, @@ -121,6 +126,9 @@ def load_image( black_level = np.array(raw.black_level_per_channel[:3]).astype(np.float32) elif "npy" in fp or "npz" in fp: img = np.load(fp) + elif "mat" in fp: + mat = scipy.io.loadmat(fp) + img = mat["psf"][:, :, 0] else: img = cv2.imread(fp, cv2.IMREAD_UNCHANGED) @@ -202,7 +210,6 @@ def load_image( else: if dtype is None: dtype = original_dtype - img = img.astype(dtype) return img diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 76fbc367..b7571698 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -25,7 +25,7 @@ import pathlib as plib from lensless.eval.benchmark import benchmark import matplotlib.pyplot as plt -from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent +from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent,HyperSpectralFISTA from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, HFDataset from lensless.utils.io import save_image from lensless.utils.image import gamma_correction @@ -35,7 +35,7 @@ from torch.utils.data import Subset -@hydra.main(version_base=None, config_path="../../configs", config_name="benchmark") +@hydra.main(version_base=None, config_path="../../configs", config_name="benchmark_hyperspectral") def benchmark_recon(config): # set seed @@ -86,7 +86,7 @@ def benchmark_recon(config): _, benchmark_dataset = torch.utils.data.random_split( dataset, [train_size, test_size], generator=generator ) - elif dataset == "HFDataset": + elif dataset == "PolarLitis": split_test = "test" if config.huggingface.split_seed is not None: @@ -120,6 +120,7 @@ def benchmark_recon(config): huggingface_repo=config.huggingface.repo, cache_dir=config.huggingface.cache_dir, psf=config.huggingface.psf, + mask = config.huggingface.mask, n_files=n_files, split=split_test, display_res=config.huggingface.image_res, @@ -138,6 +139,8 @@ def benchmark_recon(config): psf = benchmark_dataset.psf[first_psf_key].to(device) else: psf = benchmark_dataset.psf.to(device) + mask = benchmark_dataset.mask.to(device) + else: raise ValueError(f"Dataset {dataset} not supported") @@ -190,6 +193,8 @@ def benchmark_recon(config): ) if algo == "FISTA": model_list.append(("FISTA", FISTA(psf, tk=config.fista.tk))) + if algo == "HyperSpectralFISTA": + model_list.append(("HyperSpectralFISTA", HyperSpectralFISTA(psf,mask, tk=config.fista.tk))) if algo == "GradientDescent": model_list.append(("GradientDescent", GradientDescent(psf))) if algo == "NesterovGradientDescent": @@ -243,7 +248,7 @@ def benchmark_recon(config): :2 ] # take first two in case multimask dataset ground_truth_np = ground_truth.cpu().numpy()[0] - lensless_np = lensless.cpu().numpy()[0] + lensless_np = lensless.cpu().numpy() if crop is not None: ground_truth_np = ground_truth_np[ diff --git a/scripts/recon/hyperspectral.py b/scripts/recon/hyperspectral.py new file mode 100644 index 00000000..be2adafc --- /dev/null +++ b/scripts/recon/hyperspectral.py @@ -0,0 +1,103 @@ +""" +Apply FISTA for hyperspectral data recovery. + +``` +python scripts/recon/hyperspectral.py +``` + +""" + +import hydra +from hydra.utils import to_absolute_path +import os +import numpy as np +import time +import pathlib as plib +import matplotlib.pyplot as plt +from lensless.utils.io import load_image +from lensless import ( + HyperSpectralFISTA, +) +import scipy + + +@hydra.main(version_base=None, config_path="../../configs", config_name="defaults_recon") +def gradient_descent( + config, +): + + # set paths + mask_fp = "/root/FORKS/LenslessPiCamNoa/data/mask.npy" + psf_fp = "/root/FORKS/LenslessPiCamNoa/data/psf.mat" + data_fp = "/root/FORKS/LenslessPiCamNoa/data/266_lensless.png" + + ### - put your paths + # mask_fp = None + # psf_fp = None + # data_fp = None + + # load mask and PSF + mask = np.load(mask_fp) + mask = np.expand_dims(mask, axis=0) + mask = mask.astype(np.float32) + + # load PSF + + mat = scipy.io.loadmat(psf_fp) + psf = mat["psf"][:, :, 0] + psf = psf.astype(np.float32) + psf = psf[10:260, 35 : 320 - 35] + psf = psf / np.linalg.norm(psf) + psf = np.expand_dims(psf, axis=0) # add depth + psf = np.expand_dims(psf, axis=-1) # add channels + + # load data + data = load_image( + data_fp, + return_float=True, + normalize=False, + dtype=np.float32, + ) + # -- add depth and channels dimensions + data = np.expand_dims(data, axis=0) + data = np.expand_dims(data, axis=-1) + + # apply FISTA + save = config["save"] + if save: + save = os.getcwd() + + start_time = time.time() + recon = HyperSpectralFISTA( + psf, + mask, + # norm=None, + norm="ortho", + ) + recon.set_data(data) + print(f"Setup time : {time.time() - start_time} s") + + start_time = time.time() + res = recon.apply( + n_iter=100, + disp_iter=20, + save=save, + gamma=1.0, + plot=False, + ) + print(f"Processing time : {time.time() - start_time} s") + + if config.torch: + img = res[0].cpu().numpy() + else: + img = res[0] + + if config["display"]["plot"]: + plt.show() + if save: + np.save(plib.Path(save) / "final_reconstruction.npy", img) + print(f"Files saved to : {save}") + + +if __name__ == "__main__": + gradient_descent()