diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 029c6d22..7dd4d036 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -21,12 +21,13 @@ Added - ``lensless.utils.dataset.simulate_dataset`` for simulating a dataset given a mask/PSF. - Support for training/testing with multiple mask patterns in the dataset. - Multi-GPU support for training. -- DigiCam dataset which interfaces with Hugging Face. +- Dataset which interfaces with Hugging Face (``lensless.utils.dataset.HFDataset``). - Scripts for authentication. - DigiCam support for Telegram demo. - DiffuserCamMirflickr Hugging Face API. - Fallback for normalization if data not in 8bit range (``lensless.utils.io.save_image``). - Add utilities for fabricating masks with 3D printing (``lensless.hardware.fabrication``). +- WandB support. Changed ~~~~~~~ @@ -151,7 +152,7 @@ Added - Option to warm-start reconstruction algorithm with ``initial_est``. - TrainableReconstructionAlgorithm class inherited from ReconstructionAlgorithm and torch.module for use with pytorch autograd and optimizers. - Unrolled version of FISTA and ADMM as TrainableReconstructionAlgorithm with learnable parameters. -- ``train_unrolled.py`` script for training unrolled algorithms. +- ``train_learning_based.py`` script for training unrolled algorithms. - ``benchmark_recon.py`` script for benchmarking and comparing reconstruction algorithms. - Added ``reconstruction_error`` to ``ReconstructionAlgorithm`` . - Added support for npy/npz image in load_image. diff --git a/README.rst b/README.rst index 419295ad..532e0940 100644 --- a/README.rst +++ b/README.rst @@ -45,7 +45,7 @@ The toolkit includes: * Measurement scripts (`link `__). * Dataset preparation and loading tools, with `Hugging Face `__ integration (`slides `__ on uploading a dataset to Hugging Face with `this script `__). * `Reconstruction algorithms `__ (e.g. FISTA, ADMM, unrolled algorithms, trainable inversion, pre- and post-processors). -* `Training script `__ for learning-based reconstruction. +* `Training script `__ for learning-based reconstruction. * `Pre-trained models `__ that can be loaded from `Hugging Face `__, for example in `this script `__. * Mask `design `__ and `fabrication `__ tools. * `Simulation tools `__. diff --git a/configs/fine-tune_PSF.yaml b/configs/fine-tune_PSF.yaml index c7ff09c9..d0835cba 100644 --- a/configs/fine-tune_PSF.yaml +++ b/configs/fine-tune_PSF.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn fine-tune_PSF +# python scripts/recon/train_learning_based.py -cn fine-tune_PSF defaults: - train_unrolledADMM - _self_ @@ -12,25 +12,7 @@ trainable_mask: #Training training: - save_every: 10 - epoch: 50 - crop_preloss: False + save_every: 1 # to see how PSF evolves display: gamma: 2.2 - -reconstruction: - method: unrolled_admm - - pre_process: - network: UnetRes - depth: 2 - post_process: - network: DruNet - depth: 4 - -optimizer: - slow_start: 0.01 - -loss: l2 -lpips: 1.0 diff --git a/configs/train_celeba_digicam_hitl.yaml b/configs/train_celeba_digicam_hitl.yaml index 8129973b..046c5962 100644 --- a/configs/train_celeba_digicam_hitl.yaml +++ b/configs/train_celeba_digicam_hitl.yaml @@ -1,7 +1,7 @@ # Learn mask with HITL training by setting measure configuration (set to null for learning in simulation) # # EXAMPLE COMMAND: -# python scripts/recon/train_unrolled.py -cn train_celeba_digicam_hitl measure.rpi_username=USERNAME measure.rpi_hostname=HOSTNAME files.vertical_shift=SHIFT +# python scripts/recon/train_learning_based.py -cn train_celeba_digicam_hitl measure.rpi_username=USERNAME measure.rpi_hostname=HOSTNAME files.vertical_shift=SHIFT defaults: - train_celeba_digicam diff --git a/configs/train_celeba_digicam_mask.yaml b/configs/train_celeba_digicam_mask.yaml index 8dfd7f73..ba34ed46 100644 --- a/configs/train_celeba_digicam_mask.yaml +++ b/configs/train_celeba_digicam_mask.yaml @@ -1,5 +1,5 @@ # fine-tune mask for PSF, but don't re-simulate -# python scripts/recon/train_unrolled.py -cn train_celeba_digicam_mask +# python scripts/recon/train_learning_based.py -cn train_celeba_digicam_mask defaults: - train_celeba_digicam - _self_ diff --git a/configs/train_coded_aperture.yaml b/configs/train_coded_aperture.yaml index ea39b6ab..a0889435 100644 --- a/configs/train_coded_aperture.yaml +++ b/configs/train_coded_aperture.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn train_coded_aperture +# python scripts/recon/train_learning_based.py -cn train_coded_aperture defaults: - train_unrolledADMM - _self_ diff --git a/configs/train_digicam_celeba.yaml b/configs/train_digicam_celeba.yaml index 4a7d5028..b2724dc9 100644 --- a/configs/train_digicam_celeba.yaml +++ b/configs/train_digicam_celeba.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn train_digicam_singlemask +# python scripts/recon/train_learning_based.py -cn train_digicam_celeba defaults: - train_unrolledADMM - _self_ @@ -13,6 +13,7 @@ files: huggingface_psf: "psf_simulated.png" huggingface_dataset: True split_seed: 0 + test_size: 0.15 downsample: 2 rotate: True # if measurement is upside-down save_psf: False @@ -34,14 +35,14 @@ alignment: random_vflip: False random_hflip: False quantize: False - # shifting when there is no files.downsample + # shifting when there is no files to downsample vertical_shift: -117 horizontal_shift: -25 training: batch_size: 4 epoch: 25 - eval_batch_size: 4 + eval_batch_size: 16 crop_preloss: True reconstruction: diff --git a/configs/train_digicam_multimask.yaml b/configs/train_digicam_multimask.yaml index 6011f5f0..e05dda06 100644 --- a/configs/train_digicam_multimask.yaml +++ b/configs/train_digicam_multimask.yaml @@ -1,17 +1,18 @@ -# python scripts/recon/train_unrolled.py -cn train_digicam_multimask +# python scripts/recon/train_learning_based.py -cn train_digicam_multimask defaults: - train_unrolledADMM - _self_ - torch_device: 'cuda:0' device_ids: [0, 1, 2, 3] eval_disp_idx: [1, 2, 4, 5, 9] + # Dataset files: dataset: bezzam/DigiCam-Mirflickr-MultiMask-25K huggingface_dataset: True + huggingface_psf: null downsample: 1 # TODO: these parameters should be in the dataset? image_res: [900, 1200] # used during measurement @@ -55,4 +56,5 @@ reconstruction: post_process: network : UnetRes # UnetRes or DruNet or null depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet - nc: [32,64,116,128] \ No newline at end of file + nc: [32,64,116,128] + diff --git a/configs/train_digicam_singlemask.yaml b/configs/train_digicam_singlemask.yaml index 69b4e3a2..932d68a8 100644 --- a/configs/train_digicam_singlemask.yaml +++ b/configs/train_digicam_singlemask.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn train_digicam_singlemask +# python scripts/recon/train_learning_based.py -cn train_digicam_singlemask defaults: - train_unrolledADMM - _self_ @@ -11,12 +11,13 @@ eval_disp_idx: [1, 2, 4, 5, 9] files: dataset: bezzam/DigiCam-Mirflickr-SingleMask-25K huggingface_dataset: True + huggingface_psf: null downsample: 1 + # TODO: these parameters should be in the dataset? image_res: [900, 1200] # used during measurement rotate: True # if measurement is upside-down save_psf: False - # extra_eval: null extra_eval: multimask: huggingface_repo: bezzam/DigiCam-Mirflickr-MultiMask-25K @@ -26,6 +27,7 @@ files: topright: [80, 100] # height, width height: 200 +# TODO: these parameters should be in the dataset? alignment: # when there is no downsampling topright: [80, 100] # height, width diff --git a/configs/train_pre-post-processing.yaml b/configs/train_pre-post-processing.yaml deleted file mode 100644 index 86b95e86..00000000 --- a/configs/train_pre-post-processing.yaml +++ /dev/null @@ -1,24 +0,0 @@ -# python scripts/recon/train_unrolled.py -cn train_pre-post-processing -defaults: - - train_unrolledADMM - - _self_ - -reconstruction: - method: unrolled_admm - - pre_process: - network: UnetRes - depth: 2 - post_process: - network: DruNet - depth: 4 - -training: - epoch: 50 - crop_preloss: False - -optimizer: - slow_start: 0.01 - -loss: l2 -lpips: 1.0 diff --git a/configs/train_psf_from_scratch.yaml b/configs/train_psf_from_scratch.yaml index 82586751..8e1b0543 100644 --- a/configs/train_psf_from_scratch.yaml +++ b/configs/train_psf_from_scratch.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn train_psf_from_scratch +# python scripts/recon/train_learning_based.py -cn train_psf_from_scratch defaults: - train_unrolledADMM - _self_ @@ -6,6 +6,10 @@ defaults: # Train Dataset files: dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" + huggingface_dataset: False + n_files: 1000 + test_size: 0.15 + celeba_root: /scratch/bezzam downsample: 8 @@ -24,8 +28,6 @@ simulation: object_height: 0.30 training: - crop_preloss: False # crop region for computing loss - batch_size: 8 - epoch: 25 + batch_size: 2 eval_batch_size: 16 save_every: 5 diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index d4998e11..47fba326 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -1,39 +1,52 @@ -# python scripts/recon/train_unrolled.py +# python scripts/recon/train_learning_based.py hydra: job: chdir: True # change to output folder +wandb_project: lensless seed: 0 start_delay: null # Dataset files: - dataset: /scratch/bezzam/DiffuserCam_mirflickr/dataset # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" - celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html - psf: data/psf/diffusercam_psf.tiff - diffusercam_psf: True - - huggingface_dataset: null - huggingface_psf: null + # -- using local dataset + # dataset: /scratch/bezzam/DiffuserCam_mirflickr/dataset # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" + # celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html + # psf: data/psf/diffusercam_psf.tiff + # diffusercam_psf: True + + # -- using huggingface dataset + dataset: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM + huggingface_dataset: True + huggingface_psf: psf.tiff + + # -- train/test split split_seed: null # if null use train/test split from dataset - n_files: null # null to use all for both train/test - downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution - test_size: 0.15 + test_size: null + # -- processing parameters + downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution + downsample_lensed: 2 input_snr: null # adding shot noise at input (for measured dataset) at this SNR in dB vertical_shift: null horizontal_shift: null + rotate: False + save_psf: False crop: null # vertical: null # horizontal: null image_res: null # for measured data, what resolution used at screen - extra_eval: null # dict of extra datasets to evaluate on +alignment: null +# topright: null # height, width + # height: null + torch: True torch_device: 'cuda' +device_ids: null # for multi-gpu set list, e.g. [0, 1, 2, 3] measure: null # if measuring data on-the-fly # test set example to visualize at the end of every epoch @@ -130,14 +143,13 @@ simulation: training: batch_size: 8 - epoch: 50 + epoch: 25 eval_batch_size: 10 metric_for_best_model: null # e.g. LPIPS_Vgg, null does test loss save_every: null #In case of instable training skip_NAN: True clip_grad: 1.0 - crop_preloss: False # crop region for computing loss, files.crop should be set optimizer: diff --git a/configs/train_unrolled_pre_post.yaml b/configs/train_unrolled_pre_post.yaml new file mode 100644 index 00000000..82a8b794 --- /dev/null +++ b/configs/train_unrolled_pre_post.yaml @@ -0,0 +1,14 @@ +# python scripts/recon/train_learning_based.py -cn train_unrolled_pre_post +defaults: + - train_unrolledADMM + - _self_ + +reconstruction: + method: unrolled_admm + + pre_process: + network: UnetRes + depth: 2 + post_process: + network: UnetRes + depth: 2 diff --git a/docs/source/dataset.rst b/docs/source/dataset.rst index ad21defb..0a8c503b 100644 --- a/docs/source/dataset.rst +++ b/docs/source/dataset.rst @@ -19,6 +19,26 @@ or measured). :special-members: __init__, __len__ +Measured dataset objects +------------------------ + +.. autoclass:: lensless.utils.dataset.HFDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.MeasuredDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.MeasuredDatasetSimulatedOriginal + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset + :members: + :special-members: __init__ + + Simulated dataset objects ------------------------- @@ -43,19 +63,3 @@ mask / PSF. .. autoclass:: lensless.utils.dataset.SimulatedDatasetTrainableMask :members: :special-members: __init__ - - -Measured dataset objects ------------------------- - -.. autoclass:: lensless.utils.dataset.MeasuredDataset - :members: - :special-members: __init__ - -.. autoclass:: lensless.utils.dataset.MeasuredDatasetSimulatedOriginal - :members: - :special-members: __init__ - -.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset - :members: - :special-members: __init__ diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 3bb7e25b..8df388c1 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -13,6 +13,7 @@ from tqdm import tqdm import os import numpy as np +import wandb try: import torch @@ -37,6 +38,9 @@ def benchmark( unrolled_output_factor=False, return_average=True, snr=None, + use_wandb=False, + label=None, + epoch=None, **kwargs, ): """ @@ -116,6 +120,9 @@ def benchmark( if dataset.multimask: lensless, lensed, psfs = batch psfs = psfs.to(device) + else: + lensless, lensed = batch + psfs = None else: lensless, lensed = batch psfs = None @@ -176,7 +183,15 @@ def benchmark( prediction_np = prediction.cpu().numpy()[i] # switch to [H, W, C] for saving prediction_np = np.moveaxis(prediction_np, 0, -1) - save_image(prediction_np, fp=os.path.join(output_dir, f"{_batch_idx}.png")) + fp = os.path.join(output_dir, f"{_batch_idx}.png") + save_image(prediction_np, fp=fp) + + if use_wandb: + assert epoch is not None, "epoch must be provided for wandb logging" + log_key = ( + f"{_batch_idx}_{label}" if label is not None else f"{_batch_idx}" + ) + wandb.log({log_key: wandb.Image(fp)}, step=epoch) # normalization prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True) @@ -198,24 +213,27 @@ def benchmark( .item() ) else: - if "LPIPS" in metric: - if prediction.shape[1] == 1: - # LPIPS needs 3 channels - metrics_values[metric].append( - metrics[metric]( - prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) + try: + if "LPIPS" in metric: + if prediction.shape[1] == 1: + # LPIPS needs 3 channels + metrics_values[metric].append( + metrics[metric]( + prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) + ) + .cpu() + .item() + ) + else: + metrics_values[metric].append( + metrics[metric](prediction, lensed).cpu().item() ) - .cpu() - .item() - ) else: metrics_values[metric].append( metrics[metric](prediction, lensed).cpu().item() ) - else: - metrics_values[metric].append( - metrics[metric](prediction, lensed).cpu().item() - ) + except Exception as e: + print(f"Error in metric {metric}: {e}") # compute metrics for unrolled output if unrolled_output_factor: diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 7c6a60fb..80349001 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -6,7 +6,7 @@ # Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# - +import wandb import json import math import numpy as np @@ -302,6 +302,7 @@ def __init__( clip_grad=1.0, unrolled_output_factor=False, extra_eval_sets=None, + use_wandb=False, # for adding components during training pre_process=None, pre_process_delay=None, @@ -382,6 +383,8 @@ def __init__( """ global print + self.use_wandb = use_wandb + self.device = recon._psf.device self.logger = logger if self.logger is not None: @@ -440,6 +443,7 @@ def __init__( self.simulated_dataset_trainable_mask = True self.mask = mask + self.gamma = gamma if mask is not None: assert isinstance(mask, TrainableMask) self.use_mask = True @@ -449,11 +453,18 @@ def __init__( # save original PSF psf_np = self.mask.get_psf().detach().cpu().numpy()[0, ...] psf_np = psf_np.squeeze() # remove (potential) singleton color channel - np.save(os.path.join("psf_original.npy"), psf_np) - save_image(psf_np, os.path.join("psf_original.png")) + np.save("psf_original.npy", psf_np) + fp = "psf_original.png" + save_image(psf_np, fp) + plot_image(psf_np, gamma=self.gamma) + fp_plot = "psf_original_plot.png" + plt.savefig(fp_plot) + + if self.use_wandb: + wandb.log({"psf": wandb.Image(fp)}, step=0) + wandb.log({"psf_plot": wandb.Image(fp_plot)}, step=0) self.l1_mask = l1_mask - self.gamma = gamma # loss if loss == "l2": @@ -490,6 +501,7 @@ def __init__( self.optimizer_config = optimizer self.set_optimizer() + # metrics self.metrics = { "LOSS": [], # train loss "LOSS_TEST": [], # test loss @@ -798,10 +810,14 @@ def evaluate(self, mean_loss, epoch, disp=None): output_dir=output_dir, crop=self.crop, unrolled_output_factor=self.unrolled_output_factor, + use_wandb=self.use_wandb, + epoch=epoch, ) # update metrics with current metrics self.metrics["LOSS"].append(mean_loss) + if self.use_wandb: + wandb.log({"LOSS": mean_loss}, step=epoch) for key in current_metrics: self.metrics[key].append(current_metrics[key]) @@ -824,8 +840,11 @@ def evaluate(self, mean_loss, epoch, disp=None): eval_loss = current_metrics[self.metrics["metric_for_best_model"]] self.metrics["LOSS_TEST"].append(eval_loss) + if self.use_wandb: + wandb.log({"LOSS_TEST": eval_loss}, step=epoch) # add extra evaluation sets + extra_metrics_epoch = {} if self.extra_eval_sets is not None: for eval_set in self.extra_eval_sets: @@ -852,6 +871,9 @@ def evaluate(self, mean_loss, epoch, disp=None): output_dir=output_dir, crop=self.crop, unrolled_output_factor=self.unrolled_output_factor, + use_wandb=self.use_wandb, + label=eval_set, + epoch=epoch, ) # add metrics to dictionary @@ -860,12 +882,19 @@ def evaluate(self, mean_loss, epoch, disp=None): self.metrics[eval_set][key] = [extra_metrics[key]] else: self.metrics[eval_set][key].append(extra_metrics[key]) + extra_metrics_epoch[f"{eval_set}_{key}"] = extra_metrics[key] # set back PSF to original in case changed # TODO: cleaner way? if not self.train_multimask: self.recon._set_psf(self.train_dataset.psf.to(self.device)) + # log metrics to wandb + if self.use_wandb: + wandb.log(current_metrics, step=epoch) + if self.extra_eval_sets is not None: + wandb.log(extra_metrics_epoch, step=epoch) + return eval_loss def on_epoch_end(self, mean_loss, save_pt, epoch, disp=None): @@ -929,7 +958,7 @@ def train(self, n_epoch=1, save_pt=None, disp=None): start_time = time.time() - self.evaluate(-1, epoch=0, disp=disp) + self.evaluate(mean_loss=1, epoch=0, disp=disp) for epoch in range(n_epoch): # add extra components (if specified) @@ -997,9 +1026,16 @@ def save(self, epoch, path="recon", include_optimizer=False): psf_np = self.mask.get_psf().detach().cpu().numpy()[0, ...] psf_np = psf_np.squeeze() # remove (potential) singleton color channel np.save(os.path.join(path, f"psf_epoch{epoch}.npy"), psf_np) - save_image(psf_np, os.path.join(path, f"psf_epoch{epoch}.png")) + fp = os.path.join(path, f"psf_epoch{epoch}.png") + save_image(psf_np, fp) plot_image(psf_np, gamma=self.gamma) - plt.savefig(os.path.join(path, f"psf_epoch{epoch}_plot.png")) + fp_plot = os.path.join(path, f"psf_epoch{epoch}_plot.png") + plt.savefig(fp_plot) + + if self.use_wandb and epoch != "BEST": + wandb.log({"psf": wandb.Image(fp)}, step=epoch) + wandb.log({"psf_plot": wandb.Image(fp_plot)}, step=epoch) + if epoch == "BEST": # save difference with original PSF psf_original = np.load("psf_original.npy") diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 803d4166..5a01e770 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1016,42 +1016,85 @@ def _get_images_pair(self, idx): return lensless, lensed -class DigiCam(DualDataset): +class HFDataset(DualDataset): def __init__( self, huggingface_repo, split, + n_files=None, psf=None, + rotate=False, # just the lensless image + downsample=1, + downsample_lensed=1, display_res=None, sensor="rpi_hq", slm="adafruit", - rotate=False, - downsample=1, alignment=None, - save_psf=False, - simulation_config=None, return_mask_label=False, + save_psf=False, **kwargs, ): + """ + Wrapper for lensless datasets on Hugging Face. + + Parameters + ---------- + huggingface_repo : str + Hugging Face repository ID. + split : str or :py:class:`torch.utils.data.Dataset` + Split of the dataset to use: 'train', 'test', or 'all'. If a Dataset object is given, it is used directly. + n_files : int, optional + Number of files to load from the dataset, by default None, namely all. + psf : str, optional + File name of the PSF at the repository. If None, it is assumed that there is a mask pattern from which the PSF is simulated, by default None. + rotate : bool, optional + If True, lensless images and PSF are rotated 180 degrees. Lensed/original image is not rotated! By default False. + downsample : float, optional + Downsample factor of the lensless images, by default 1. + downsample_lensed : float, optional + Downsample factor of the lensed images, by default 1. + display_res : tuple, optional + Resolution of images when displayed on screen during measurement. + sensor : str, optional + If `psf` not provided, the sensor to use for the PSF simulation, by default "rpi_hq". + slm : str, optional + If `psf` not provided, the SLM to use for the PSF simulation, by default "adafruit". + alignment : dict, optional + Alignment parameters between lensless and lensed data. + If "topright", "height", and "width" are provided, the region-of-interest from the reconstruction of ``lensless`` is extracted and ``lensed`` is reshaped to match. + If "crop" is provided, the region-of-interest is extracted from the simulated lensed image, namely a ``simulation`` configuration should be provided within ``alignment``. + return_mask_label : bool, optional + If multimask dataset, return the mask label (True) or the corresponding PSF (False). + save_psf : bool, optional + If multimask dataset, save the simulated PSFs. + + """ if isinstance(split, str): + if n_files is not None: + split = f"{split}[0:{n_files}]" self.dataset = load_dataset(huggingface_repo, split=split) elif isinstance(split, Dataset): self.dataset = split else: raise ValueError("split should be a string or a Dataset object") + self.rotate = rotate self.display_res = display_res self.return_mask_label = return_mask_label - # deduce downsampling factor from measurement + # deduce downsampling factor from the first image data_0 = self.dataset[0] self.downsample_lensless = downsample + self.downsample_lensed = downsample_lensed lensless = np.array(data_0["lensless"]) if self.downsample_lensless != 1.0: lensless = resize(lensless, factor=1 / self.downsample_lensless) - sensor_res = sensor_dict[sensor][SensorParam.RESOLUTION] - downsample_fact = min(sensor_res / lensless.shape[:2]) + if psf is None: + sensor_res = sensor_dict[sensor][SensorParam.RESOLUTION] + downsample_fact = min(sensor_res / lensless.shape[:2]) + else: + downsample_fact = 1 # deduce recon shape from original image self.alignment = None @@ -1071,6 +1114,7 @@ def __init__( # preparing ground-truth as simulated measurement of original elif "crop" in alignment: + assert "simulation" in alignment, "Simulation config should be provided" self.crop = dict(alignment["crop"].copy()) self.crop["vertical"][0] = int(self.crop["vertical"][0] / downsample) self.crop["vertical"][1] = int(self.crop["vertical"][1] / downsample) @@ -1085,7 +1129,7 @@ def __init__( psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset") psf, _ = load_psf( psf_fp, - downsample=downsample_fact, + shape=lensless.shape, return_float=True, return_bg=True, flip=rotate, @@ -1161,7 +1205,7 @@ def __init__( if "horizontal_shift" in simulation_config: self.horizontal_shift = int(simulation_config["horizontal_shift"] / downsample) - super(DigiCam, self).__init__(**kwargs) + super(HFDataset, self).__init__(**kwargs) def __len__(self): return len(self.dataset) @@ -1213,6 +1257,12 @@ def _get_images_pair(self, idx): lensed = resize( lensed_np, shape=(*self.display_res, 3), interpolation=cv2.INTER_NEAREST ) + elif self.downsample_lensed != 1.0: + lensed = resize( + lensed_np, + factor=1 / self.downsample_lensed, + interpolation=cv2.INTER_NEAREST, + ) return lensless, lensed @@ -1265,7 +1315,7 @@ def simulate_dataset(config, generator=None): Parameters ---------- config : omegaconf.DictConfig - Configuration, e.g. from Hydra. See ``scripts/recon/train_unrolled.py`` for an example that uses this function. + Configuration, e.g. from Hydra. See ``scripts/recon/train_learning_based.py`` for an example that uses this function. generator : torch.Generator, optional Random number generator, by default ``None``. """ diff --git a/lensless/utils/io.py b/lensless/utils/io.py index ffd62d91..47fd94f4 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -348,7 +348,7 @@ def load_psf( bg = np.array(bg) # resize - if downsample != 1: + if downsample != 1 or shape is not None: psf = resize(psf, shape=shape, factor=1 / downsample) if single_psf: diff --git a/recon_requirements.txt b/recon_requirements.txt index 0a2ff942..78dd418d 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -9,4 +9,6 @@ waveprop>=0.0.10 # for simulation torch >= 2.0.0 torchvision torchmetrics -lpips \ No newline at end of file +lpips +wandb +datasets \ No newline at end of file diff --git a/scripts/data/authenticate.py b/scripts/data/authenticate.py index 14f1d97b..9f71819c 100644 --- a/scripts/data/authenticate.py +++ b/scripts/data/authenticate.py @@ -29,7 +29,7 @@ """ -from lensless.utils.dataset import DigiCam +from lensless.utils.dataset import HFDataset import torch from lensless import ADMM from lensless.utils.image import rgb2gray @@ -67,14 +67,14 @@ def authen(config): # load multimask dataset if split == "all": - train_set = DigiCam( + train_set = HFDataset( huggingface_repo=huggingface_repo, split="train", rotate=rotate, downsample=downsample, return_mask_label=True, ) - test_set = DigiCam( + test_set = HFDataset( huggingface_repo=huggingface_repo, split="test", rotate=rotate, @@ -114,7 +114,7 @@ def authen(config): file_idx += list(np.arange(n_train_psf) + i * n_train_psf + test_files_offet) else: - all_set = DigiCam( + all_set = HFDataset( huggingface_repo=huggingface_repo, split=split, rotate=rotate, diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 1e45971d..ece0bcfa 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -26,7 +26,7 @@ from lensless.eval.benchmark import benchmark import matplotlib.pyplot as plt from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent -from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, DigiCam +from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, HFDataset from lensless.utils.io import save_image import torch @@ -85,7 +85,7 @@ def benchmark_recon(config): dataset, [train_size, test_size], generator=generator ) elif dataset == "DigiCamHF": - benchmark_dataset = DigiCam( + benchmark_dataset = HFDataset( huggingface_repo=config.huggingface.repo, split="test", display_res=config.huggingface.image_res, diff --git a/scripts/recon/dataset.py b/scripts/recon/dataset.py index e14f4ecd..906508db 100644 --- a/scripts/recon/dataset.py +++ b/scripts/recon/dataset.py @@ -35,7 +35,7 @@ from tqdm import tqdm from joblib import Parallel, delayed import numpy as np -from lensless.utils.dataset import DiffuserCamMirflickrHF, DigiCam +from lensless.utils.dataset import DiffuserCamMirflickrHF, HFDataset from lensless.eval.metric import psnr, lpips from lensless.utils.image import resize @@ -47,7 +47,7 @@ def recon_dataset(config): if config.dataset == "diffusercam": dataset = DiffuserCamMirflickrHF(split=config.split, downsample=config.downsample) else: - dataset = DigiCam( + dataset = HFDataset( huggingface_repo=config.dataset, split=config.split, downsample=config.downsample, diff --git a/scripts/recon/digicam_mirflickr.py b/scripts/recon/digicam_mirflickr.py index 88a6a036..60411fd0 100644 --- a/scripts/recon/digicam_mirflickr.py +++ b/scripts/recon/digicam_mirflickr.py @@ -3,7 +3,7 @@ import torch from lensless import ADMM from lensless.utils.plot import plot_image -from lensless.utils.dataset import DigiCam +from lensless.utils.dataset import HFDataset import os from lensless.utils.io import save_image import time @@ -35,7 +35,7 @@ def apply_pretrained(config): model_config = yaml.safe_load(stream) # load dataset - test_set = DigiCam( + test_set = HFDataset( huggingface_repo=model_config["files"]["dataset"], psf=model_config["files"]["huggingface_psf"] if "huggingface_psf" in model_config["files"] diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_learning_based.py similarity index 89% rename from scripts/recon/train_unrolled.py rename to scripts/recon/train_learning_based.py index 95057aca..9ad7a016 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_learning_based.py @@ -1,6 +1,6 @@ # ############################################################################# -# train_unrolled.py -# ================= +# train_learning_based.py +# ======================= # Authors : # Yohann PERRON [yohann.perron@gmail.com] # Eric BEZZAM [ebezzam@gmail.com] @@ -10,28 +10,25 @@ Train unrolled version of reconstruction algorithm. ``` -python scripts/recon/train_unrolled.py +python scripts/recon/train_learning_based.py ``` By default it uses the configuration from the file `configs/train_unrolledADMM.yaml`. To train pre- and post-processing networks, use the following command: ``` -python scripts/recon/train_unrolled.py -cn train_pre-post-processing +python scripts/recon/train_learning_based.py -cn train_unrolled_pre_post ``` To fine-tune the DiffuserCam PSF, use the following command: ``` -python scripts/recon/train_unrolled.py -cn fine-tune_PSF +python scripts/recon/train_learning_based.py -cn fine-tune_PSF ``` -To train a PSF from scratch with a simulated dataset, use the following command: -``` -python scripts/recon/train_unrolled.py -cn train_psf_from_scratch -``` """ +import wandb import logging import hydra from hydra.utils import get_original_cwd @@ -43,7 +40,7 @@ from lensless.utils.dataset import ( DiffuserCamMirflickr, DigiCamCelebA, - DigiCam, + HFDataset, MyDataParallel, simulate_dataset, ) @@ -60,7 +57,16 @@ @hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") -def train_unrolled(config): +def train_learned(config): + + if config.wandb_project is not None: + # start a new wandb run to track this script + wandb.init( + # set the wandb project where this run will be logged + project=config.wandb_project, + # track hyperparameters and run metadata + config=dict(config), + ) # set seed seed = config.seed @@ -83,13 +89,15 @@ def train_unrolled(config): use_cuda = False if "cuda" in config.torch_device and torch.cuda.is_available(): # if config.torch_device == "cuda" and torch.cuda.is_available(): - log.info("Using GPU for training.") + log.info(f"Using GPU for training. Main device : {config.torch_device}") device = config.torch_device use_cuda = True else: log.info("Using CPU for training.") device = "cpu" device_ids = config.device_ids + if device_ids is not None: + log.info(f"Using multiple GPUs : {device_ids}") # load dataset and create dataloader train_set = None @@ -98,7 +106,7 @@ def train_unrolled(config): crop = None alignment = None # very similar to crop, TODO: should switch to this approach mask = None - if "DiffuserCam" in config.files.dataset: + if "DiffuserCam" in config.files.dataset and config.files.huggingface_dataset is False: original_path = os.path.join(get_original_cwd(), config.files.dataset) psf_path = os.path.join(get_original_cwd(), config.files.psf) @@ -122,15 +130,6 @@ def train_unrolled(config): # -- if learning mask mask = prep_trainable_mask(config, dataset.psf) - if mask is not None: - # plot initial PSF - psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] - if config.trainable_mask.grayscale: - psf_np = psf_np[:, :, -1] - - save_image(psf_np, os.path.join(save, "psf_initial.png")) - plot_image(psf_np, gamma=config.display.gamma) - plt.savefig(os.path.join(save, "psf_initial_plot.png")) psf = dataset.psf @@ -167,17 +166,7 @@ def train_unrolled(config): mask = prep_trainable_mask(config, dataset.psf, downsample=downsample) if mask is not None: - # plot initial PSF - with torch.no_grad(): - psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] - if config.trainable_mask.grayscale: - psf_np = psf_np[:, :, -1] - - save_image(psf_np, os.path.join(save, "psf_initial.png")) - plot_image(psf_np, gamma=config.display.gamma) - plt.savefig(os.path.join(save, "psf_initial_plot.png")) - - # save original PSF as well + # save original PSF psf_meas = dataset.psf.detach().cpu().numpy()[0, ...] plot_image(psf_meas, gamma=config.display.gamma) plt.savefig(os.path.join(save, "psf_meas_plot.png")) @@ -200,8 +189,13 @@ def train_unrolled(config): generator = torch.Generator().manual_seed(seed) # - combine train and test into single dataset - train_dataset = load_dataset(config.files.dataset, split="train") - test_dataset = load_dataset(config.files.dataset, split="test") + train_split = "train" + test_split = "test" + if config.files.n_files is not None: + train_split = f"train[:{config.files.n_files}]" + test_split = f"test[:{config.files.n_files}]" + train_dataset = load_dataset(config.files.dataset, split=train_split) + test_dataset = load_dataset(config.files.dataset, split=test_split) dataset = concatenate_datasets([test_dataset, train_dataset]) # - split into train and test @@ -211,29 +205,36 @@ def train_unrolled(config): dataset, [train_size, test_size], generator=generator ) - train_set = DigiCam( + train_set = HFDataset( huggingface_repo=config.files.dataset, psf=config.files.huggingface_psf, split=split_train, display_res=config.files.image_res, rotate=config.files.rotate, downsample=config.files.downsample, + downsample_lensed=config.files.downsample_lensed, alignment=config.alignment, save_psf=config.files.save_psf, + n_files=config.files.n_files, ) - test_set = DigiCam( + test_set = HFDataset( huggingface_repo=config.files.dataset, psf=config.files.huggingface_psf, split=split_test, display_res=config.files.image_res, rotate=config.files.rotate, downsample=config.files.downsample, + downsample_lensed=config.files.downsample_lensed, alignment=config.alignment, save_psf=config.files.save_psf, + n_files=config.files.n_files, ) if train_set.multimask: # get first PSF for initialization - first_psf_key = list(train_set.psf.keys())[device_ids[0]] + if device_ids is not None: + first_psf_key = list(train_set.psf.keys())[device_ids[0]] + else: + first_psf_key = list(train_set.psf.keys())[0] psf = train_set.psf[first_psf_key].to(device) else: psf = train_set.psf.to(device) @@ -244,14 +245,6 @@ def train_unrolled(config): mask = prep_trainable_mask(config, psf) if mask is not None: assert not train_set.multimask - # plot initial PSF - psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] - if config.trainable_mask.grayscale: - psf_np = psf_np[:, :, -1] - - save_image(psf_np, os.path.join(save, "psf_initial.png")) - plot_image(psf_np, gamma=config.display.gamma) - plt.savefig(os.path.join(save, "psf_initial_plot.png")) else: @@ -259,6 +252,11 @@ def train_unrolled(config): psf = train_set.psf crop = train_set.crop + if not hasattr(train_set, "multimask"): + train_set.multimask = False + if not hasattr(test_set, "multimask"): + test_set.multimask = False + assert train_set is not None # if not hasattr(test_set, "psfs"): # assert psf is not None @@ -275,9 +273,10 @@ def train_unrolled(config): extra_eval_sets = dict() for eval_set in config.files.extra_eval: - extra_eval_sets[eval_set] = DigiCam( + extra_eval_sets[eval_set] = HFDataset( split="test", downsample=config.files.downsample, # needs to be same size + n_files=config.files.n_files, **config.files.extra_eval[eval_set], ) @@ -492,6 +491,7 @@ def train_unrolled(config): clip_grad=config.training.clip_grad, unrolled_output_factor=config.unrolled_output_factor, extra_eval_sets=extra_eval_sets if config.files.extra_eval is not None else None, + use_wandb=True if config.wandb_project is not None else False, ) trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=config.eval_disp_idx) @@ -500,4 +500,4 @@ def train_unrolled(config): if __name__ == "__main__": - train_unrolled() + train_learned() diff --git a/setup.py b/setup.py index d1ab6d68..c3d0fc2c 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ "rawpy>=0.16.0", # less than python 3.12 "paramiko>=3.2.0", "hydra-core", + "slm_controller @ git+https://github.com/ebezzam/slm-controller.git" ], extra_requires={"dev": ["pudb", "black"]}, )