Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SpectralDifusserCam #150

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions configs/benchmark_hyperspectral.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# python scripts/eval/benchmark_recon.py
#Hydra config
hydra:
run:
dir: "benchmark/${now:%Y-%m-%d}/${now:%H-%M-%S}"
job:
chdir: True


dataset: PolarLitis # DiffuserCam, DigiCamCelebA, HFDataset
seed: 0
batchsize: 1 # must be 1 for iterative approaches

huggingface:
repo: "noakraicer/polarlitis"
cache_dir: null # where to read/write dataset. Defaults to `"~/.cache/huggingface/datasets"`.
psf: psf.mat
mask: mask.npy # null for simulating PSF
image_res: [250, 250] # used during measurement
rotate: False # if measurement is upside-down
flipud: False
flip_lensed: False # if rotate or flipud is True, apply to lensed

alignment:
top_left: null
height: null

downsample: 1
downsample_lensed: 2
split_seed: null
single_channel_psf: True

device: "cuda"
# numbers of iterations to benchmark
n_iter_range: [2000]
# number of files to benchmark
n_files: null # null for all files
#How much should the image be downsampled
downsample: 2
#algorithm to benchmark
algorithms: ["HyperSpectralFISTA"] #["ADMM", "ADMM_Monakhova2019", "FISTA", "GradientDescent", "NesterovGradientDescent"]

# baseline from Monakhova et al. 2019, https://arxiv.org/abs/1908.11502
baseline: "MONAKHOVA 100iter"

save_idx: [0, 1, 2, 3, 4] # provide index of files to save e.g. [1, 5, 10]
gamma_psf: 1.5 # gamma factor for PSF


# Hyperparameters
nesterov:
p: 0
mu: 0.9
fista:
tk: 1
admm:
mu1: 1e-6
mu2: 1e-5
mu3: 4e-5
tau: 0.0001


# for DigiCamCelebA
files:
test_size: 0.15
downsample: 1
celeba_root: /scratch/bezzam


# dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K
# psf: data/psf/adafruit_random_2mm_20231907.png
# vertical_shift: null
# horizontal_shift: null
# crop: null

dataset: /scratch/bezzam/celeba/celeba_adafruit_random_30cm_2mm_20231004_26K
psf: rpi_hq_adafruit_psf_2mm/raw_data_rgb.png
vertical_shift: -117
horizontal_shift: -25
crop:
vertical: [0, 525]
horizontal: [265, 695]

# for prepping ground truth data
#for simulated dataset
simulation:
grayscale: False
output_dim: null # should be set if no PSF is used
# random variations
object_height: 0.33 # [m], range for random height or scalar
flip: True # change the orientation of the object (from vertical to horizontal)
random_shift: False
random_vflip: 0.5
random_hflip: 0.5
random_rotate: False
# these distance parameters are typically fixed for a given PSF
# for DiffuserCam psf # for tape_rgb psf
# scene2mask: 10e-2 # scene2mask: 40e-2
# mask2sensor: 9e-3 # mask2sensor: 4e-3
# -- for CelebA
scene2mask: 0.25 # [m]
mask2sensor: 0.002 # [m]
deadspace: True # whether to account for deadspace for programmable mask
# see waveprop.devices
use_waveprop: False # for PSF simulation
sensor: "rpi_hq"
snr_db: 10
# simulate different sensor resolution
# output_dim: [24, 32] # [H, W] or null
# Downsampling for PSF
downsample: 8
# max val in simulated measured (quantized 8 bits)
quantize: False # must be False for differentiability
max_val: 255
1 change: 1 addition & 0 deletions lensless/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
NesterovGradientDescent,
FISTA,
GradientDescentUpdate,
HyperSpectralFISTA
)
from .recon.tikhonov import CodedApertureReconstruction
from .hardware.sensor import VirtualSensor, SensorOptions
Expand Down
71 changes: 71 additions & 0 deletions lensless/recon/gd.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,77 @@ def _update(self, iter):
self._xk = xk


class HyperSpectralFISTA(FISTA):
"""
Applying HyperSpectral FISTA as in: https://github.com/Waller-Lab/SpectralDiffuserCam

"""

def __init__(self, psf, mask, **kwargs):
"""

Parameters
----------
mask :
Hyperspectral mask

"""
# same PSF for all hyperspectral channels
assert psf.shape[-1] == 1
assert mask.shape[-3:-1] == psf.shape[-3:-1]
self._mask = mask[None, ...] # adding batch dimension

super(HyperSpectralFISTA, self).__init__(psf, **kwargs)

def reset(self):

# TODO set lipschitz constant correctly/differently?

if self.is_torch:
if self._initial_est is not None:
self._image_est = self._initial_est
else:
# initial guess, half intensity image
psf_flat = self._psf.reshape(-1, self._psf_shape[3])
pixel_start = (
torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values
) / 2
# initialize image estimate as [Batch, Depth, Height, Width, Channels]
self._image_est = torch.ones_like(self._mask) * pixel_start

# set step size as < 2 / lipschitz
Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3])
H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values)

else:
if self._initial_est is not None:
self._image_est = self._initial_est
else:
psf_flat = self._psf.reshape(-1, self._psf_shape[3])
pixel_start = (np.max(psf_flat, axis=0) + np.min(psf_flat, axis=0)) / 2
# initialize image estimate as [Batch, Depth, Height, Width, Channels]
self._image_est = np.ones_like(self._mask) * pixel_start

# set step size as < 2 / lipschitz
Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3])
H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0))

# # TODO how was his value determined?
# self._alpha = 1 / 4770.13

def _grad(self):
# make sure to sum on correct axis, and apply mask on correct dimensions
diff = (
np.sum(self._mask * self._convolver.convolve(self._image_est), -1, keepdims=True)
- self._data
) # (B, D, H, W, 1)
return self._convolver.deconvolve(
diff * self._mask
) # (H, W, C) where C is number of hyperspectral channels


def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs):

# load data
Expand Down
2 changes: 1 addition & 1 deletion lensless/recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,6 @@ def set_data(self, data):
assert np.all(
self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1]
), "PSF and data shape mismatch"

if len(data.shape) == 3:
self._data = data[None, None, ...]
elif len(data.shape) == 4:
Expand Down Expand Up @@ -569,6 +568,7 @@ def apply(

for i in range(n_iter):
self._update(i)

if self.compensation_branch is not None and i < self._n_iter - 1:
self.compensation_branch_inputs.append(self._form_image())

Expand Down
18 changes: 12 additions & 6 deletions lensless/recon/rfft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


class RealFFTConvolve2D:
def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs):
def __init__(self, psf, dtype=None, pad=True, norm=None, rgb=None, **kwargs):
"""
Linear operator that performs convolution in Fourier domain, and assumes
real-valued signals.
Expand Down Expand Up @@ -82,18 +82,24 @@ def _crop(self, x):
]

def _pad(self, v):

shape = self._padded_shape.copy()
if v.shape[-1] != self._padded_shape[-1]:
# different number of channels in PSF and data
assert v.shape[-1] == 1 or self._padded_shape[-1] == 1
shape[-1] = v.shape[-1]

if len(v.shape) == 5:
batch_size = v.shape[0]
shape = [batch_size] + self._padded_shape
elif len(v.shape) == 4:
shape = self._padded_shape
else:
shape = [batch_size] + shape
elif len(v.shape) != 4:
raise ValueError("Expected 4D or 5D tensor")

if self.is_torch:
vpad = torch.zeros(size=shape, dtype=v.dtype, device=v.device)
else:
vpad = np.zeros(shape).astype(v.dtype)

vpad[
..., self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1], :
] = v
Expand Down Expand Up @@ -138,7 +144,7 @@ def convolve(self, x):
self._padded_data = self._pad(x)
else:
if self.is_torch:
self._padded_data = x # .type(self.dtype).to(self._psf.device)
self._padded_data = x
else:
self._padded_data[:] = x # .astype(self.dtype)

Expand Down
13 changes: 10 additions & 3 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torchvision.transforms import functional as F
from lensless.hardware.trainable_mask import prep_trainable_mask, AdafruitLCD
from lensless.utils.simulation import FarFieldSimulator
from lensless.utils.io import load_image, load_psf, save_image
from lensless.utils.io import load_image, load_psf, save_image,load_mask
from lensless.utils.image import is_grayscale, resize, rgb2gray
import re
from lensless.hardware.utils import capture
Expand Down Expand Up @@ -1271,6 +1271,7 @@ def __init__(
split,
n_files=None,
psf=None,
mask=None,
rotate=False, # just the lensless image
flipud=False,
flip_lensed=False,
Expand Down Expand Up @@ -1409,11 +1410,11 @@ def __init__(
if psf is not None:
# download PSF from huggingface
psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset")
psf, _ = load_psf(
psf = load_psf(
psf_fp,
shape=lensless.shape,
return_float=True,
return_bg=True,
return_bg=False,
flip=self.rotate,
flip_ud=flipud,
bg_pix=(0, 15),
Expand All @@ -1424,6 +1425,10 @@ def __init__(
if single_channel_psf:
# replicate across three channels
self.psf = self.psf.repeat(1, 1, 1, 3)
if mask is not None:
mask_fp = hf_hub_download(repo_id=huggingface_repo, filename=mask, repo_type="dataset")
mask = load_mask(mask_fp)
self.mask= torch.from_numpy(mask)

elif "mask_label" in data_0:
self.multimask = True
Expand Down Expand Up @@ -1563,7 +1568,9 @@ def _get_images_pair(self, idx):
# convert to float
if lensless_np.dtype == np.uint8:
lensless_np = lensless_np.astype(np.float32) / 255
lensless_np = lensless_np / np.max(lensless_np)
lensed_np = lensed_np.astype(np.float32) / 255
lensed_np = lensed_np / np.max(lensed_np)
else:
# 16 bit
lensless_np = lensless_np.astype(np.float32) / 65535
Expand Down
11 changes: 9 additions & 2 deletions lensless/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import os.path
import warnings

import scipy
import cv2
import numpy as np
from PIL import Image
Expand All @@ -18,6 +18,11 @@
from lensless.utils.plot import plot_image


def load_mask(fp):
mask = np.load(fp)
return np.expand_dims(mask, axis=0)


def load_image(
fp,
verbose=False,
Expand Down Expand Up @@ -121,6 +126,9 @@ def load_image(
black_level = np.array(raw.black_level_per_channel[:3]).astype(np.float32)
elif "npy" in fp or "npz" in fp:
img = np.load(fp)
elif "mat" in fp:
mat = scipy.io.loadmat(fp)
img = mat["psf"][:, :, 0]
else:
img = cv2.imread(fp, cv2.IMREAD_UNCHANGED)

Expand Down Expand Up @@ -202,7 +210,6 @@ def load_image(
else:
if dtype is None:
dtype = original_dtype
img = img.astype(dtype)

return img

Expand Down
Loading
Loading