diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml
index e3e34607..f2c8b7a8 100644
--- a/configs/train_unrolledADMM.yaml
+++ b/configs/train_unrolledADMM.yaml
@@ -23,6 +23,7 @@ files:
   huggingface_dataset: True
   huggingface_psf: psf.tiff
   single_channel_psf: False  # whether to sum all PSF channels into one
+  hf_simulated: False  # whether to simulate lensless measurements from the lensed images and the measured PSF

   # -- train/test split
   split_seed: null  # if null use train/test split from dataset
diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py
index 42ed13f0..e96451d9 100644
--- a/lensless/utils/dataset.py
+++ b/lensless/utils/dataset.py
@@ -23,7 +23,8 @@ from lensless.hardware.utils import capture
 from lensless.hardware.utils import display
 from lensless.hardware.slm import set_programmable_mask, adafruit_sub2full
-from datasets import load_dataset, load_from_disk
+from datasets import load_dataset
+from lensless.recon.rfft_convolve import RealFFTConvolve2D
 from huggingface_hub import hf_hub_download
 import cv2
 from lensless.hardware.sensor import sensor_dict, SensorParam
@@ -1019,6 +1020,101 @@ def _get_images_pair(self, idx):

         return lensless, lensed


+class HFSimulated(DualDataset):
+    def __init__(
+        self,
+        huggingface_repo,
+        split,
+        n_files=None,
+        psf=None,
+        downsample=1,
+        cache_dir=None,
+        single_channel_psf=False,
+        flipud=False,
+        **kwargs,
+    ):
+        """
+        Wrapper for Hugging Face datasets, where lensless images are simulated from the lensed ones.
+
+        This is used to see how simulated lensless measurements compare with real ones.
+        """
+
+        if isinstance(split, str):
+            if n_files is not None:
+                split = f"{split}[0:{n_files}]"
+            self.dataset = load_dataset(huggingface_repo, split=split, cache_dir=cache_dir)
+        elif isinstance(split, Dataset):
+            self.dataset = split
+        else:
+            raise ValueError("split should be a string or a Dataset object")
+
+        # deduce target lensless shape from the first image
+        data_0 = self.dataset[0]
+        self.downsample = downsample
+        # -- lensless data is only used for its shape; lensed data is used in the simulation
+        lensless = np.array(data_0["lensless"])
+        self.lensless_shape = np.array(lensless.shape[:2]) // self.downsample
+
+        # download PSF from Hugging Face
+        # TODO: handle psf=None (a PSF file is currently assumed to be provided)
+        self.multimask = False
+        psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset")
+        psf, _ = load_psf(
+            psf_fp,
+            shape=self.lensless_shape,
+            return_float=True,
+            return_bg=True,
+            flip_ud=flipud,
+            bg_pix=(0, 15),
+            single_psf=single_channel_psf,
+        )
+        self.psf = torch.from_numpy(psf)
+        if single_channel_psf:
+            # replicate across three channels
+            self.psf = self.psf.repeat(1, 1, 1, 3)
+
+        # convolver object to simulate lensless measurements with the PSF
+        self.convolver = RealFFTConvolve2D(self.psf)
+        self.crop = None
+        self.random_flip = None
+        self.flipud = flipud
+
+        super(HFSimulated, self).__init__(**kwargs)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def _get_images_pair(self, idx):
+
+        # load lensed image
+        lensed_np = np.array(self.dataset[idx]["lensed"])
+        if self.flipud:
+            lensed_np = np.flipud(lensed_np)
+
+        # convert to float
+        if lensed_np.dtype == np.uint8:
+            lensed_np = lensed_np.astype(np.float32) / 255
+        else:
+            # 16 bit
+            lensed_np = lensed_np.astype(np.float32) / 65535
+
+        # resize if necessary
+        if (self.lensless_shape != np.array(lensed_np.shape[:2])).any():
+            lensed_np = resize(
+                lensed_np, shape=self.lensless_shape, interpolation=cv2.INTER_NEAREST
+            )
+        lensed = torch.from_numpy(lensed_np)
+
+        # simulate lensless measurement with convolution
+        lensed = lensed.unsqueeze(0)  # add batch dimension
+        lensless = self.convolver.convolve(lensed)
+        if lensless.max() > 1:
+            # rescale to [0, 1] to avoid clipping
+            print("CLIPPING!")
+            lensless /= lensless.max()
+
+        return lensless, lensed
+
+
 class HFDataset(DualDataset):
     def __init__(
         self,
@@ -1163,7 +1259,7 @@ def __init__(
                 shape=lensless.shape,
                 return_float=True,
                 return_bg=True,
-                flip=rotate,
+                flip=self.rotate,
                 flip_ud=flipud,
                 bg_pix=(0, 15),
                 force_rgb=force_rgb,
@@ -1195,7 +1291,7 @@ def __init__(
                 sensor=sensor,
                 slm=slm,
                 downsample=downsample_fact,
-                flipud=rotate or flipud,  # TODO separate commands?
+                flipud=self.rotate or flipud,  # TODO separate commands?
                 use_waveprop=simulation_config.get("use_waveprop", False),
                 scene2mask=simulation_config.get("scene2mask", None),
                 mask2sensor=simulation_config.get("mask2sensor", None),
@@ -1223,7 +1319,7 @@ def __init__(
                 sensor=sensor,
                 slm=slm,
                 downsample=downsample_fact,
-                flipud=rotate or flipud,  # TODO separate commands?
+                flipud=self.rotate or flipud,  # TODO separate commands?
                 use_waveprop=simulation_config.get("use_waveprop", False),
                 scene2mask=simulation_config.get("scene2mask", None),
                 mask2sensor=simulation_config.get("mask2sensor", None),
diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py
index 52fd091c..f570e475 100644
--- a/scripts/recon/train_learning_based.py
+++ b/scripts/recon/train_learning_based.py
@@ -44,6 +44,7 @@
     HFDataset,
     MyDataParallel,
     simulate_dataset,
+    HFSimulated,
 )
 from torch.utils.data import Subset
 from lensless.recon.utils import create_process_network
@@ -211,26 +212,41 @@ def train_learned(config):
             dataset, [train_size, test_size], generator=generator
         )

-        train_set = HFDataset(
-            huggingface_repo=config.files.dataset,
-            cache_dir=config.files.cache_dir,
-            psf=config.files.huggingface_psf,
-            single_channel_psf=config.files.single_channel_psf,
-            split=split_train,
-            display_res=config.files.image_res,
-            rotate=config.files.rotate,
-            flipud=config.files.flipud,
-            flip_lensed=config.files.flip_lensed,
-            downsample=config.files.downsample,
-            downsample_lensed=config.files.downsample_lensed,
-            alignment=config.alignment,
-            save_psf=config.files.save_psf,
-            n_files=config.files.n_files,
-            simulation_config=config.simulation,
-            force_rgb=config.files.force_rgb,
-            simulate_lensless=config.files.simulate_lensless,
-            random_flip=config.files.random_flip,
-        )
+        if config.files.hf_simulated:
+            # simulate lensless measurements by convolving lensed images with the measured PSF
+            train_set = HFSimulated(
+                huggingface_repo=config.files.dataset,
+                split=split_train,
+                n_files=config.files.n_files,
+                psf=config.files.huggingface_psf,
+                downsample=config.files.downsample,
+                cache_dir=config.files.cache_dir,
+                single_channel_psf=config.files.single_channel_psf,
+                flipud=config.files.flipud,
+            )
+
+        else:
+            train_set = HFDataset(
+                huggingface_repo=config.files.dataset,
+                cache_dir=config.files.cache_dir,
+                psf=config.files.huggingface_psf,
+                single_channel_psf=config.files.single_channel_psf,
+                split=split_train,
+                display_res=config.files.image_res,
+                rotate=config.files.rotate,
+                flipud=config.files.flipud,
+                flip_lensed=config.files.flip_lensed,
+                downsample=config.files.downsample,
+                downsample_lensed=config.files.downsample_lensed,
+                alignment=config.alignment,
+                save_psf=config.files.save_psf,
+                n_files=config.files.n_files,
+                simulation_config=config.simulation,
+                force_rgb=config.files.force_rgb,
+                simulate_lensless=config.files.simulate_lensless,
+                random_flip=config.files.random_flip,
+            )
+
         test_set = HFDataset(
             huggingface_repo=config.files.dataset,
             cache_dir=config.files.cache_dir,
@@ -249,8 +265,8 @@ def train_learned(config):
             simulation_config=config.simulation,
             force_rgb=config.files.force_rgb,
             simulate_lensless=False,  # in general evaluate on measured (set to False)
-            # random_flip=config.files.random_flip,  # shouldn't flip test set, just for testing
         )
+
         if train_set.multimask:
             # get first PSF for initialization
             if device_ids is not None:
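
Usage note: a minimal sketch of enabling the new option, assuming the script's usual Hydra entry point (which loads configs/train_unrolledADMM.yaml by default) so that config fields can be overridden from the command line:

    # train on simulated lensless measurements (lensed images convolved with the measured PSF);
    # the test set is still loaded with HFDataset, i.e. evaluation stays on measured data
    python scripts/recon/train_learning_based.py files.hf_simulated=True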