Skip to content

Commit

Permalink
Add option to train on simulated, and test on measured.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jul 11, 2024
1 parent 230af02 commit 4d8cd85
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 25 deletions.
1 change: 1 addition & 0 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ files:
huggingface_dataset: True
huggingface_psf: psf.tiff
single_channel_psf: False # whether to sum all PSF channels into one
hf_simulated: False

# -- train/test split
split_seed: null # if null use train/test split from dataset
Expand Down
104 changes: 100 additions & 4 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from lensless.hardware.utils import capture
from lensless.hardware.utils import display
from lensless.hardware.slm import set_programmable_mask, adafruit_sub2full
from datasets import load_dataset, load_from_disk
from datasets import load_dataset
from lensless.recon.rfft_convolve import RealFFTConvolve2D
from huggingface_hub import hf_hub_download
import cv2
from lensless.hardware.sensor import sensor_dict, SensorParam
Expand Down Expand Up @@ -1019,6 +1020,101 @@ def _get_images_pair(self, idx):
return lensless, lensed


class HFSimulated(DualDataset):
    """
    Hugging Face dataset wrapper where the lensless measurements are *simulated*
    by convolving the lensed (ground-truth) images with a measured PSF.

    This is used for seeing how simulated lensless images compare with real ones.
    """

    def __init__(
        self,
        huggingface_repo,
        split,
        n_files=None,
        psf=None,
        downsample=1,
        cache_dir=None,
        single_channel_psf=False,
        flipud=False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        huggingface_repo : str
            Hugging Face repository ID of the dataset.
        split : str or :class:`datasets.Dataset`
            Split to load (e.g. ``"train"``, ``"test"``) or an already-loaded
            ``Dataset`` object.
        n_files : int, optional
            Only use the first ``n_files`` examples (only applied when ``split``
            is a string).
        psf : str
            Filename of the PSF within the repository (required).
        downsample : float
            Downsampling factor applied to the stored lensless resolution.
        cache_dir : str, optional
            Directory for caching the downloaded dataset.
        single_channel_psf : bool
            Whether to sum all PSF channels into one (then replicated to three
            channels for simulation).
        flipud : bool
            Whether to flip the lensed images (and PSF) up-down.
        """
        if psf is None:
            # fail early with a clear message rather than deep inside hf_hub_download
            raise ValueError("A PSF filename within the repository must be provided.")

        if isinstance(split, str):
            if n_files is not None:
                split = f"{split}[0:{n_files}]"
            self.dataset = load_dataset(huggingface_repo, split=split, cache_dir=cache_dir)
        elif isinstance(split, Dataset):
            self.dataset = split
        else:
            raise ValueError("split should be a string or a Dataset object")

        # deduce simulation resolution from the first image
        data_0 = self.dataset[0]
        self.downsample = downsample
        # -- use lensless data just for shape, but lensed data is used in simulation
        lensless = np.array(data_0["lensless"])
        self.lensless_shape = np.array(lensless.shape[:2]) // self.downsample

        # download PSF from Hugging Face and resize it to the simulation shape
        self.multimask = False
        psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset")
        psf, _ = load_psf(
            psf_fp,
            shape=self.lensless_shape,
            return_float=True,
            return_bg=True,
            flip_ud=flipud,
            bg_pix=(0, 15),
            single_psf=single_channel_psf,
        )
        self.psf = torch.from_numpy(psf)
        if single_channel_psf:
            # replicate the summed channel across three channels
            self.psf = self.psf.repeat(1, 1, 1, 3)

        # convolver used to simulate lensless measurements from lensed images
        self.convolver = RealFFTConvolve2D(self.psf)
        self.crop = None
        self.random_flip = None
        self.flipud = flipud

        super(HFSimulated, self).__init__(**kwargs)

    def __len__(self):
        return len(self.dataset)

    def _get_images_pair(self, idx):
        """
        Return ``(lensless, lensed)`` where the lensless image is simulated by
        convolving the lensed (ground-truth) image with the PSF.
        """

        # load lensed (ground-truth) image
        lensed_np = np.array(self.dataset[idx]["lensed"])
        if self.flipud:
            # .copy(): np.flipud returns a negatively-strided view, which
            # torch.from_numpy cannot handle when no resize happens below
            lensed_np = np.flipud(lensed_np).copy()

        # convert to float in [0, 1]
        if lensed_np.dtype == np.uint8:
            lensed_np = lensed_np.astype(np.float32) / 255
        else:
            # assumes 16-bit data otherwise -- TODO confirm against datasets used
            lensed_np = lensed_np.astype(np.float32) / 65535

        # resize if necessary to match the (downsampled) lensless resolution
        if (self.lensless_shape != np.array(lensed_np.shape[:2])).any():
            lensed_np = resize(
                lensed_np, shape=self.lensless_shape, interpolation=cv2.INTER_NEAREST
            )
        lensed = torch.from_numpy(lensed_np)

        # simulate lensless measurement with convolution
        lensed = lensed.unsqueeze(0)  # add batch dimension
        lensless = self.convolver.convolve(lensed)
        if lensless.max() > 1:
            # normalize so values stay in [0, 1]
            print("CLIPPING!")
            lensless /= lensless.max()

        return lensless, lensed


class HFDataset(DualDataset):
def __init__(
self,
Expand Down Expand Up @@ -1163,7 +1259,7 @@ def __init__(
shape=lensless.shape,
return_float=True,
return_bg=True,
flip=rotate,
flip=self.rotate,
flip_ud=flipud,
bg_pix=(0, 15),
force_rgb=force_rgb,
Expand Down Expand Up @@ -1195,7 +1291,7 @@ def __init__(
sensor=sensor,
slm=slm,
downsample=downsample_fact,
flipud=rotate or flipud, # TODO separate commands?
flipud=self.rotate or flipud, # TODO separate commands?
use_waveprop=simulation_config.get("use_waveprop", False),
scene2mask=simulation_config.get("scene2mask", None),
mask2sensor=simulation_config.get("mask2sensor", None),
Expand Down Expand Up @@ -1223,7 +1319,7 @@ def __init__(
sensor=sensor,
slm=slm,
downsample=downsample_fact,
flipud=rotate or flipud, # TODO separate commands?
flipud=self.rotate or flipud, # TODO separate commands?
use_waveprop=simulation_config.get("use_waveprop", False),
scene2mask=simulation_config.get("scene2mask", None),
mask2sensor=simulation_config.get("mask2sensor", None),
Expand Down
58 changes: 37 additions & 21 deletions scripts/recon/train_learning_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
HFDataset,
MyDataParallel,
simulate_dataset,
HFSimulated,
)
from torch.utils.data import Subset
from lensless.recon.utils import create_process_network
Expand Down Expand Up @@ -211,26 +212,41 @@ def train_learned(config):
dataset, [train_size, test_size], generator=generator
)

train_set = HFDataset(
huggingface_repo=config.files.dataset,
cache_dir=config.files.cache_dir,
psf=config.files.huggingface_psf,
single_channel_psf=config.files.single_channel_psf,
split=split_train,
display_res=config.files.image_res,
rotate=config.files.rotate,
flipud=config.files.flipud,
flip_lensed=config.files.flip_lensed,
downsample=config.files.downsample,
downsample_lensed=config.files.downsample_lensed,
alignment=config.alignment,
save_psf=config.files.save_psf,
n_files=config.files.n_files,
simulation_config=config.simulation,
force_rgb=config.files.force_rgb,
simulate_lensless=config.files.simulate_lensless,
random_flip=config.files.random_flip,
)
if config.files.hf_simulated:
# simulate lensless by using measured PSF
train_set = HFSimulated(
huggingface_repo=config.files.dataset,
split=split_train,
n_files=config.files.n_files,
psf=config.files.huggingface_psf,
downsample=config.files.downsample,
cache_dir=config.files.cache_dir,
single_channel_psf=config.files.single_channel_psf,
flipud=config.files.flipud,
)

else:
train_set = HFDataset(
huggingface_repo=config.files.dataset,
cache_dir=config.files.cache_dir,
psf=config.files.huggingface_psf,
single_channel_psf=config.files.single_channel_psf,
split=split_train,
display_res=config.files.image_res,
rotate=config.files.rotate,
flipud=config.files.flipud,
flip_lensed=config.files.flip_lensed,
downsample=config.files.downsample,
downsample_lensed=config.files.downsample_lensed,
alignment=config.alignment,
save_psf=config.files.save_psf,
n_files=config.files.n_files,
simulation_config=config.simulation,
force_rgb=config.files.force_rgb,
simulate_lensless=config.files.simulate_lensless,
random_flip=config.files.random_flip,
)

test_set = HFDataset(
huggingface_repo=config.files.dataset,
cache_dir=config.files.cache_dir,
Expand All @@ -249,8 +265,8 @@ def train_learned(config):
simulation_config=config.simulation,
force_rgb=config.files.force_rgb,
simulate_lensless=False, # in general evaluate on measured (set to False)
# random_flip=config.files.random_flip, # shouldn't flip test set, just for testing
)

if train_set.multimask:
# get first PSF for initialization
if device_ids is not None:
Expand Down

0 comments on commit 4d8cd85

Please sign in to comment.