From 01a5c376bf329598535c8db4369d82b89b72b0fb Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Fri, 20 Dec 2024 16:00:34 +0000 Subject: [PATCH] Update experiments. --- configs/benchmark.yaml | 1 + configs/benchmark_diffusercam_mirflickr.yaml | 34 ++++-- configs/benchmark_digicam_celeba.yaml | 21 ++-- .../benchmark_digicam_mirflickr_multi.yaml | 10 +- configs/benchmark_digicam_mirflickr_pnp.yaml | 44 ++++++++ .../benchmark_digicam_mirflickr_single.yaml | 27 +++-- configs/benchmark_tapecam_mirflickr.yaml | 29 ++--- configs/recon_digicam_mirflickr.yaml | 3 +- configs/recon_digicam_mirflickr_err.yaml | 12 ++- configs/sim_digicam_psf.yaml | 2 +- configs/train_digicam_celeba.yaml | 3 +- configs/train_mirflickr_diffuser_sim.yaml | 12 +++ configs/train_unrolledADMM.yaml | 5 +- lensless/eval/benchmark.py | 46 ++++++-- lensless/recon/model_dict.py | 39 +++++++ lensless/recon/recon.py | 18 +++- lensless/recon/trainable_recon.py | 22 +++- lensless/utils/dataset.py | 49 ++++----- scripts/eval/benchmark_recon.py | 4 +- scripts/recon/digicam_mirflickr_psf_err.py | 102 +++++++++++++++--- scripts/recon/train_learning_based.py | 5 + 21 files changed, 375 insertions(+), 113 deletions(-) create mode 100644 configs/benchmark_digicam_mirflickr_pnp.yaml create mode 100644 configs/train_mirflickr_diffuser_sim.yaml diff --git a/configs/benchmark.yaml b/configs/benchmark.yaml index ecfcc87d..47c96009 100644 --- a/configs/benchmark.yaml +++ b/configs/benchmark.yaml @@ -44,6 +44,7 @@ baseline: "MONAKHOVA 100iter" save_idx: [0, 1, 2, 3, 4] # provide index of files to save e.g. [1, 5, 10] save_intermediate: False # save intermediate results, i.e. after pre-processor and after camera inversion +swap_channels: False # list of two RGB channels to swap, e.g. 
[0, 1] for swapping red and green gamma_psf: 1.5 # gamma factor for PSF diff --git a/configs/benchmark_diffusercam_mirflickr.yaml b/configs/benchmark_diffusercam_mirflickr.yaml index e4b97d23..0548bfbe 100644 --- a/configs/benchmark_diffusercam_mirflickr.yaml +++ b/configs/benchmark_diffusercam_mirflickr.yaml @@ -23,16 +23,30 @@ algorithms: [ # "ADMM", ## -- reconstructions trained on DiffuserCam measured - # "hf:diffusercam:mirflickr:U5+Unet8M", + "hf:diffusercam:mirflickr:U5+Unet8M", "hf:diffusercam:mirflickr:Unet8M+U5", - # "hf:diffusercam:mirflickr:TrainInv+Unet8M", - # "hf:diffusercam:mirflickr:MMCN4M+Unet4M", - # "hf:diffusercam:mirflickr:MWDN8M", - # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M", - # "hf:diffusercam:mirflickr:Unet4M+TrainInv+Unet4M", - # "hf:diffusercam:mirflickr:Unet2M+MMCN+Unet2M", - # "hf:diffusercam:mirflickr:Unet2M+MWDN6M", - # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", + "hf:diffusercam:mirflickr:TrainInv+Unet8M", + "hf:diffusercam:mirflickr:MMCN4M+Unet4M", + "hf:diffusercam:mirflickr:MWDN8M", + "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M", + "hf:diffusercam:mirflickr:Unet4M+TrainInv+Unet4M", + "hf:diffusercam:mirflickr:Unet2M+MMCN+Unet2M", + "hf:diffusercam:mirflickr:Unet2M+MWDN6M", + "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", + "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psfNN", + + # # -- benchmark PSF error + # "hf:diffusercam:mirflickr:U5+Unet8M_psf0dB", + # "hf:diffusercam:mirflickr:U5+Unet8M_psf-5dB", + # "hf:diffusercam:mirflickr:U5+Unet8M_psf-10dB", + # "hf:diffusercam:mirflickr:U5+Unet8M_psf-20dB", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psf-0dB", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psf-5dB", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psf-10dB", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psf-20dB", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psfNN_psf-0dB", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psfNN_psf-10dB", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psfNN_psf-20dB", # 
"hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam", # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam_post", @@ -52,7 +66,9 @@ algorithms: [ # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", # "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", + # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_psfNN", # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", + # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_psfNN", # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips", # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10", # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_aux1", diff --git a/configs/benchmark_digicam_celeba.yaml b/configs/benchmark_digicam_celeba.yaml index 17364cf5..16d028bd 100644 --- a/configs/benchmark_digicam_celeba.yaml +++ b/configs/benchmark_digicam_celeba.yaml @@ -6,22 +6,23 @@ defaults: dataset: HFDataset batchsize: 10 -device: "cuda:0" +device: "cuda:1" algorithms: [ # "ADMM", ## -- reconstructions trained on measured data - # "hf:digicam:celeba_26k:U5+Unet8M_wave", + "hf:digicam:celeba_26k:U5+Unet8M_wave", "hf:digicam:celeba_26k:Unet8M+U5_wave", - # "hf:digicam:celeba_26k:TrainInv+Unet8M_wave", - # "hf:digicam:celeba_26k:MWDN8M_wave", - # "hf:digicam:celeba_26k:MMCN4M+Unet4M_wave", - # "hf:digicam:celeba_26k:Unet2M+MWDN6M_wave", - # "hf:digicam:celeba_26k:Unet4M+TrainInv+Unet4M_wave", - # "hf:digicam:celeba_26k:Unet2M+MMCN+Unet2M_wave", - # "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", - # "hf:digicam:celeba_26k:Unet4M+U10+Unet4M_wave", + "hf:digicam:celeba_26k:TrainInv+Unet8M_wave", + "hf:digicam:celeba_26k:MWDN8M_wave", + "hf:digicam:celeba_26k:MMCN4M+Unet4M_wave", + "hf:digicam:celeba_26k:Unet2M+MWDN6M_wave", + "hf:digicam:celeba_26k:Unet4M+TrainInv+Unet4M_wave", + "hf:digicam:celeba_26k:Unet2M+MMCN+Unet2M_wave", + "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", + "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave_psfNN", + "hf:digicam:celeba_26k:Unet4M+U10+Unet4M_wave", # # -- 
reconstructions trained on other datasets/systems # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", diff --git a/configs/benchmark_digicam_mirflickr_multi.yaml b/configs/benchmark_digicam_mirflickr_multi.yaml index 1523a3c4..e7e2438a 100644 --- a/configs/benchmark_digicam_mirflickr_multi.yaml +++ b/configs/benchmark_digicam_mirflickr_multi.yaml @@ -25,11 +25,14 @@ algorithms: [ ## -- reconstructions trained on measured data "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave", - # "hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave", + "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_psfNN", + "hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave", "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_aux1", - # "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_flips", + "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_flips", + "hf:digicam:mirflickr_multi_25k:Unet8M_wave_v2", # ## -- reconstructions trained on other datasets/systems + # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_psfNN", # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", # "hf:tapecam:mirflickr:Unet4M+U10+Unet4M", # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", @@ -42,6 +45,9 @@ algorithms: [ # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M", # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_ft_flips", # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_ft_flips_rotate10", + # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_psfNN", + # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_psfNN", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psfNN", ] # # -- to only use output from unrolled diff --git a/configs/benchmark_digicam_mirflickr_pnp.yaml b/configs/benchmark_digicam_mirflickr_pnp.yaml new file mode 100644 index 00000000..cdd76a84 --- /dev/null +++ b/configs/benchmark_digicam_mirflickr_pnp.yaml @@ -0,0 +1,44 @@ +# python scripts/eval/benchmark_recon.py -cn benchmark_digicam_mirflickr_pnp +defaults: + - benchmark + - _self_ + + +dataset: HFDataset +batchsize: 1 
+device: "cuda:0" + +huggingface: + repo: "bezzam/DigiCam-Mirflickr-MultiMask-25K" + psf: null # null for simulating PSF + image_res: [900, 1200] # used during measurement + rotate: True # if measurement is upside-down + flipud: False + flip_lensed: False # if rotate or flipud is True, apply to lensed + alignment: + top_left: [80, 100] # height, width + height: 200 + downsample: 1 + +algorithms: [ + "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_psfNN", + "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_psfNN", +] + +pnp: + mu: 1e-3 # weight for distance from original model parameters + model_path: null # leave null to be overwritten in script + # lr: 1e-2 + # n_iter: 10 + lr: 3e-3 # learning rate for SGD + n_iter: 10 # number of iterations + + +save_idx: [1, 2, 4, 5, 9, 24, 33, 61] + +# simulating PSF +simulation: + use_waveprop: True + deadspace: True + scene2mask: 0.3 + mask2sensor: 0.002 diff --git a/configs/benchmark_digicam_mirflickr_single.yaml b/configs/benchmark_digicam_mirflickr_single.yaml index cffbd7f7..28bdaeca 100644 --- a/configs/benchmark_digicam_mirflickr_single.yaml +++ b/configs/benchmark_digicam_mirflickr_single.yaml @@ -25,18 +25,20 @@ algorithms: [ # "ADMM", # # -- reconstructions trained on measured data - # "hf:digicam:mirflickr_single_25k:U5+Unet8M_wave", + "hf:digicam:mirflickr_single_25k:U5+Unet8M_wave", "hf:digicam:mirflickr_single_25k:Unet8M+U5_wave", - # "hf:digicam:mirflickr_single_25k:TrainInv+Unet8M_wave", - # "hf:digicam:mirflickr_single_25k:MMCN4M+Unet4M_wave", - # "hf:digicam:mirflickr_single_25k:MWDN8M_wave", - # "hf:digicam:mirflickr_single_25k:Unet4M+TrainInv+Unet4M_wave", - # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", - # "hf:digicam:mirflickr_single_25k:Unet2M+MMCN+Unet2M_wave", - # "hf:digicam:mirflickr_single_25k:Unet2M+MWDN6M_wave", - # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", - # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips", - # 
"hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips_rotate10", + "hf:digicam:mirflickr_single_25k:TrainInv+Unet8M_wave", + "hf:digicam:mirflickr_single_25k:MMCN4M+Unet4M_wave", + "hf:digicam:mirflickr_single_25k:MWDN8M_wave", + "hf:digicam:mirflickr_single_25k:Unet4M+TrainInv+Unet4M_wave", + "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", + "hf:digicam:mirflickr_single_25k:Unet2M+MMCN+Unet2M_wave", + "hf:digicam:mirflickr_single_25k:Unet2M+MWDN6M_wave", + "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", + "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_psfNN", + "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips", + "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_flips_rotate10", + "hf:digicam:mirflickr_single_25k:Unet8M_wave_v2", # ## -- reconstructions trained on other datasets/systems # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", @@ -46,6 +48,9 @@ algorithms: [ # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M", # "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave", + # "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_psfNN", + # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_psfNN", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psfNN", ] save_idx: [1, 2, 4, 5, 9, 24, 33, 61] diff --git a/configs/benchmark_tapecam_mirflickr.yaml b/configs/benchmark_tapecam_mirflickr.yaml index db8e56cc..2ccc6fc5 100644 --- a/configs/benchmark_tapecam_mirflickr.yaml +++ b/configs/benchmark_tapecam_mirflickr.yaml @@ -5,7 +5,7 @@ defaults: dataset: HFDataset batchsize: 4 -device: "cuda:0" +device: "cuda:1" huggingface: repo: "bezzam/TapeCam-Mirflickr-25K" @@ -26,18 +26,19 @@ algorithms: [ # "ADMM", # -- reconstructions trained on measured data - # "hf:tapecam:mirflickr:U5+Unet8M", + "hf:tapecam:mirflickr:U5+Unet8M", "hf:tapecam:mirflickr:Unet8M+U5", - # "hf:tapecam:mirflickr:TrainInv+Unet8M", - # "hf:tapecam:mirflickr:MMCN4M+Unet4M", - # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", - # 
"hf:tapecam:mirflickr:Unet4M+TrainInv+Unet4M", - # "hf:tapecam:mirflickr:Unet2M+MMCN+Unet2M", - # "hf:tapecam:mirflickr:Unet4M+U10+Unet4M", - # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10", - # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_aux1", - # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips", - # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10", + "hf:tapecam:mirflickr:TrainInv+Unet8M", + "hf:tapecam:mirflickr:MMCN4M+Unet4M", + "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", + "hf:tapecam:mirflickr:Unet4M+TrainInv+Unet4M", + "hf:tapecam:mirflickr:Unet2M+MMCN+Unet2M", + "hf:tapecam:mirflickr:Unet4M+U10+Unet4M", + "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10", + "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_aux1", + "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips", + "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips_rotate10", + "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_psfNN", # # below models need `single_channel_psf = True` # "hf:tapecam:mirflickr:MWDN8M", @@ -59,7 +60,9 @@ algorithms: [ # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_ft_tapecam_pre", # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_digicam_multi_pre", # "hf:diffusercam:mirflickr_sim:Unet4M+U5+Unet4M_ft_digicam_multi", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M_psfNN", + # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave_psfNN", ] save_idx: [1, 2, 4, 5, 9] -n_iter_range: [100] # for ADM +n_iter_range: [100] # for ADMM diff --git a/configs/recon_digicam_mirflickr.yaml b/configs/recon_digicam_mirflickr.yaml index 6e1146c9..35fc639d 100644 --- a/configs/recon_digicam_mirflickr.yaml +++ b/configs/recon_digicam_mirflickr.yaml @@ -26,11 +26,12 @@ fn: null # if not null, download this file from https://huggingface.co/datasets # model: MMCN4M+Unet4M_wave # model: MWDN8M_wave # model: U5+Unet8M_wave -model: Unet8M+U5_wave +# model: Unet8M+U5_wave # model: Unet4M+TrainInv+Unet4M_wave # model: Unet2M+MMCN+Unet2M_wave # model: Unet4M+U5+Unet4M_wave # model: Unet4M+U10+Unet4M_wave +model: 
Unet4M+U5+Unet4M_wave_psfNN # # --- dataset: mirflickr_multi_25k # model: Unet4M+U5+Unet4M_wave diff --git a/configs/recon_digicam_mirflickr_err.yaml b/configs/recon_digicam_mirflickr_err.yaml index 2e08bbc7..6589ba6f 100644 --- a/configs/recon_digicam_mirflickr_err.yaml +++ b/configs/recon_digicam_mirflickr_err.yaml @@ -9,14 +9,20 @@ hf_repo: null # by default use one in model config # set model # -- for learning-based methods (comment if using ADMM) -model: Unet4M+U5+Unet4M_wave +model: Unet4M+U5+Unet4M_wave_psfNN # # -- for ADMM with fixed parameters # model: admm -# n_iter: 10 +n_iter: 10 device: cuda:1 save_idx: [1, 2, 4, 5, 9] n_files: null percent_pixels_wrong: [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] -flip: True # whether to flip mask values (True) or reset them (False) \ No newline at end of file +plot_vs_percent_wrong: False # whether to plot again percent wrong or correct +flip: False # whether to flip mask values (True) or reset them (False) + +compare_aes: [128, 256] # key lengths +digicam_ratio: 0.6 # approximate ratio of pixels that need to be correct +bit_depth: 8 +n_pixel: 1404 \ No newline at end of file diff --git a/configs/sim_digicam_psf.yaml b/configs/sim_digicam_psf.yaml index 3547d180..06b95fe3 100644 --- a/configs/sim_digicam_psf.yaml +++ b/configs/sim_digicam_psf.yaml @@ -5,7 +5,7 @@ hydra: use_torch: True dtype: float32 -torch_device: cuda +torch_device: cuda:1 requires_grad: False # if repo not provided, check for local file at `digicam.pattern` diff --git a/configs/train_digicam_celeba.yaml b/configs/train_digicam_celeba.yaml index c86dd45a..5f19cf1c 100644 --- a/configs/train_digicam_celeba.yaml +++ b/configs/train_digicam_celeba.yaml @@ -10,8 +10,9 @@ eval_disp_idx: [0, 2, 3, 4, 9] # Dataset files: dataset: bezzam/DigiCam-CelebA-26K - huggingface_psf: "psf_simulated.png" + huggingface_psf: "psf_simulated_waveprop.png" huggingface_dataset: True + cache_dir: /dev/shm split_seed: 0 test_size: 0.15 downsample: 2 diff --git 
a/configs/train_mirflickr_diffuser_sim.yaml b/configs/train_mirflickr_diffuser_sim.yaml new file mode 100644 index 00000000..c568f1d5 --- /dev/null +++ b/configs/train_mirflickr_diffuser_sim.yaml @@ -0,0 +1,12 @@ +# python scripts/recon/train_learning_based.py -cn train_mirflickr_diffuser_sim +defaults: + - train_mirflickr_diffuser + - _self_ + +torch_device: 'cuda:0' +device_ids: [0, 1, 2, 3] +eval_disp_idx: [0, 1, 3, 4, 8] + +# Dataset +files: + hf_simulated: True diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index ce4d851d..b1d61ad7 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -34,6 +34,7 @@ files: downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution downsample_lensed: 2 # only used if lensed if measured input_snr: null # adding shot noise at input (for measured dataset) at this SNR in dB + psf_snr: null # adding noise to PSF at this SNR in dB background_fp: null background_snr_range: null vertical_shift: null @@ -92,7 +93,7 @@ reconstruction: # processing PSF psf_network: False # False, or set number of channels for UnetRes, e.g. [8,16,32,64], with skip connection - psf_residual: True # if psf_network used, whether to use residual connection for original PSF estimate + psf_residual: False # if psf_network used, whether to use residual connection for original PSF estimate # background subtraction (if dataset has corresponding background images) direct_background_subtraction: False # True or False @@ -196,7 +197,7 @@ training: optimizer: type: AdamW # Adam, SGD... 
(Pytorch class) lr: 1e-4 - lr_step_epoch: True # True -> update LR at end of each epoch, False at the end of each mini-batch + lr_step_epoch: False # True -> update LR at end of each epoch, False at the end of each mini-batch cosine_decay_warmup: True # if set, cosine decay with warmup of 5% final_lr: False # if set, exponentially decay *to* this value exp_decay: False # if set, exponentially decay *with* this value diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 1ada6abf..86e4023f 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -43,6 +43,22 @@ def pnp_loss(y_est, y, recon, mu, original_param): return loss +def swap_color_channels(tensor, channels): + if len(channels) == 2: + color1 = tensor[..., channels[0]].copy() + tensor[..., channels[0]] = tensor[..., channels[1]] + tensor[..., channels[1]] = color1 + elif len(channels) == 3: + # reorder channels + color1 = tensor[..., channels[0]].copy() + color2 = tensor[..., channels[1]].copy() + color3 = tensor[..., channels[2]].copy() + tensor[..., 0] = color1 + tensor[..., 1] = color2 + tensor[..., 2] = color3 + return tensor + + def benchmark( model, dataset, @@ -61,6 +77,7 @@ def benchmark( epoch=None, use_background=True, pnp=None, + swap_channels=False, **kwargs, ): """ @@ -97,7 +114,7 @@ def benchmark( Returns ------- Dict[str, float] - A dictionnary containing the metrics name and average value + A dictionary containing the metrics name and average value """ assert isinstance(model._psf, torch.Tensor), "model need to be constructed with torch support" device = model._psf.device @@ -242,6 +259,7 @@ def benchmark( ) if output_intermediate: + psfs_out = prediction[3] pre_process_out = prediction[2] unrolled_out = prediction[1] prediction = prediction[0] @@ -257,13 +275,9 @@ def benchmark( prediction, axis=(-2, -1), flip_lr=flip_lr, flip_ud=flip_ud ) else: - prediction = model.forward( - batch=lensless, psfs=psfs, background=background, **kwargs - ) prediction, 
lensed = dataset.extract_roi( prediction, axis=(-2, -1), lensed=lensed, flip_lr=flip_lr, flip_ud=flip_ud ) - assert np.all(lensed.shape == prediction.shape) elif crop is not None: assert flip_lr is None and flip_ud is None prediction = prediction[ @@ -289,12 +303,32 @@ def benchmark( if save_intermediate: fp = os.path.join(output_dir, f"{_batch_idx}_inv.png") unrolled_out_np = unrolled_out.cpu().numpy()[i].squeeze() + # -- swap red and green channels + if swap_channels: + unrolled_out_np = swap_color_channels(unrolled_out_np, swap_channels) save_image(unrolled_out_np, fp=fp) fp = os.path.join(output_dir, f"{_batch_idx}_preproc.png") pre_process_out_np = pre_process_out.cpu().numpy()[i].squeeze() + # -- swap red and green channels + if swap_channels: + pre_process_out_np = swap_color_channels( + pre_process_out_np, swap_channels + ) save_image(pre_process_out_np, fp=fp) + if psfs_out is not None: + fp = os.path.join(output_dir, f"{_batch_idx}_psfs.png") + if psfs_out.shape[0] == 1: + psfs_out_np = psfs_out.cpu().numpy().squeeze() + else: + psfs_out_np = psfs_out.cpu().numpy()[i].squeeze() + + # -- swap red and green channels + if swap_channels: + psfs_out_np = swap_color_channels(psfs_out_np, swap_channels) + save_image(psfs_out_np, fp=fp) + if use_wandb: assert epoch is not None, "epoch must be provided for wandb logging" log_key = f"{_batch_idx}_{label}" if label is not None else f"{_batch_idx}" @@ -313,7 +347,7 @@ def benchmark( for metric in metrics: if metric == "ReconstructionError": metrics_values[metric] += model.reconstruction_error( - prediction=prediction_original, lensless=lensless + prediction=prediction_original, lensless=lensless, psfs=psfs ).tolist() else: try: diff --git a/lensless/recon/model_dict.py b/lensless/recon/model_dict.py index 7979011e..ab019df7 100644 --- a/lensless/recon/model_dict.py +++ b/lensless/recon/model_dict.py @@ -67,6 +67,19 @@ "Unet2M+MMCN+Unet2M": "bezzam/diffusercam-mirflickr-unet2M-mmcn-unet2M", "Unet4M+U20+Unet4M": 
"bezzam/diffusercam-mirflickr-unet4M-unrolled-admm20-unet4M", "Unet4M+U10+Unet4M": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm10-unet4M", + "Unet4M+U5+Unet4M_psfNN": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psfNN", + # training with PSF noise + "U5+Unet8M_psf0dB": "bezzam/diffusercam-mirflickr-unrolled-admm5-unet8M-psf0dB", + "U5+Unet8M_psf-5dB": "bezzam/diffusercam-mirflickr-unrolled-admm5-unet8M-psf-5dB", + "U5+Unet8M_psf-10dB": "bezzam/diffusercam-mirflickr-unrolled-admm5-unet8M-psf-10dB", + "U5+Unet8M_psf-20dB": "bezzam/diffusercam-mirflickr-unrolled-admm5-unet8M-psf-20dB", + "Unet4M+U5+Unet4M_psf-0dB": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psf-0dB", + "Unet4M+U5+Unet4M_psf-5dB": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psf-5dB", + "Unet4M+U5+Unet4M_psf-10dB": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psf-10dB", + "Unet4M+U5+Unet4M_psf-20dB": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psf-20dB", + "Unet4M+U5+Unet4M_psfNN_psf-0dB": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psfNN-psf-0dB", + "Unet4M+U5+Unet4M_psfNN_psf-10dB": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psfNN-psf-10dB", + "Unet4M+U5+Unet4M_psfNN_psf-20dB": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M-psfNN-psf-20dB", # training with noise "U5+Unet8M_10db": "bezzam/diffusercam-mirflickr-unrolled-admm5-unet8M-10db", "U5+Unet8M_40db": "bezzam/diffusercam-mirflickr-unrolled-admm5-unet8M-40db", @@ -113,6 +126,7 @@ "Unet2M+MMCN+Unet2M_wave": "bezzam/digicam-celeba-unet2M-mmcn-unet2M", "Unet4M+U5+Unet4M_wave": "bezzam/digicam-celeba-unet4M-unrolled-admm5-unet4M", "Unet4M+U10+Unet4M_wave": "bezzam/digicam-celeba-unet4M-unrolled-admm10-unet4M", + "Unet4M+U5+Unet4M_wave_psfNN": "bezzam/digicam-celeba-unet4M-unrolled-admm5-unet4M-wave-psfNN", }, "mirflickr_single_25k": { # simulated PSF (without waveprop, with deadspace) @@ -126,11 +140,13 @@ "U10_wave": 
"bezzam/digicam-mirflickr-single-25k-unrolled-admm10-wave", "U10+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm10-unet8M-wave", "Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unet8M-wave", + "Unet8M_wave_v2": "bezzam/digicam-mirflickr-single-25k-unet8M-wave-v2", "Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-wave", "TrainInv+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-trainable-inv-unet8M-wave", "U5+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm5-unet8M-wave", "Unet8M+U5_wave": "bezzam/digicam-mirflickr-single-25k-unet8M-unrolled-admm5-wave", "Unet4M+U5+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm5-unet4M-wave", + "Unet4M+U5+Unet4M_wave_psfNN": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm5-unet4M-wave-psfNN", "MWDN8M_wave": "bezzam/digicam-mirflickr-single-25k-mwdn-8M", "MMCN4M+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-mmcn-unet4M", "Unet2M+MMCN+Unet2M_wave": "bezzam/digicam-mirflickr-single-25k-unet2M-mmcn-unet2M-wave", @@ -152,10 +168,12 @@ "mirflickr_multi_25k": { # simulated PSFs (without waveprop, with deadspace) "Unet8M": "bezzam/digicam-mirflickr-multi-25k-unet8M", + "Unet8M_wave_v2": "bezzam/digicam-mirflickr-multi-25k-unet8M-wave-v2", "Unet4M+U10+Unet4M": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M", # simulated PSF (with waveprop, with deadspace) "Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M-wave", "Unet4M+U5+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm5-unet4M-wave", + "Unet4M+U5+Unet4M_wave_psfNN": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm5-unet4M-wave-psfNN", "Unet4M+U5+Unet4M_wave_aux1": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm5-unet4M-wave-aux1", "Unet4M+U5+Unet4M_wave_flips": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm5-unet4M-wave-flips", }, @@ -175,6 +193,7 @@ 
"Unet4M+U5+Unet4M_flips": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M-flips", "Unet4M+U5+Unet4M_flips_rotate10": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M-flips-rotate10", "Unet4M+U5+Unet4M_aux1": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M-aux1", + "Unet4M+U5+Unet4M_psfNN": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M-psfNN", }, }, "multilens": { @@ -299,6 +318,12 @@ def load_model( print("Loading checkpoint from : ", model_checkpoint) model_state_dict = torch.load(model_checkpoint, map_location=device) + if config["files"].get("psf_snr", None) is not None: + # overwrite PSF with noisy PSF used during training + psf_path = os.path.join(model_path, "psf.pt") + assert os.path.exists(psf_path), "PSF does not exist" + psf = torch.load(psf_path, map_location=device) + # load model pre_process = None post_process = None @@ -330,6 +355,18 @@ def load_model( down_subtraction=config["reconstruction"]["down_subtraction"], ) + # network for PSF + psf_network = None + if config["reconstruction"].get("psf_network", None): + # create UnetRes for PSF + psf_network, _ = create_process_network( + network="UnetRes", + depth=len(config["reconstruction"]["psf_network"]), + nc=config["reconstruction"]["psf_network"], + device=device, + device_ids=device_ids, + ) + if config["reconstruction"].get("init", None): init_model = config["reconstruction"]["init"] @@ -397,6 +434,8 @@ def load_model( skip_pre=skip_pre, skip_post=skip_post, compensation=config["reconstruction"].get("compensation", None), + psf_network=psf_network, + psf_residual=config["reconstruction"].get("psf_residual", False), compensation_residual=config["reconstruction"].get("compensation_residual", False), direct_background_subtraction=config["reconstruction"].get( "direct_background_subtraction", False diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index 9ef77d89..4734fa2e 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -604,7 +604,7 @@ 
def apply( else: return final_im - def reconstruction_error(self, prediction=None, lensless=None): + def reconstruction_error(self, prediction=None, lensless=None, psfs=None, normalize=True): """ Compute reconstruction error. @@ -627,7 +627,9 @@ def reconstruction_error(self, prediction=None, lensless=None): lensless = self._data # convolver = self._convolver - convolver = RealFFTConvolve2D(self._psf.to(prediction.device), **self._convolver_param) + if psfs is None: + psfs = self._psf + convolver = RealFFTConvolve2D(psfs.to(prediction.device), **self._convolver_param) if not convolver.pad: prediction = convolver._pad(prediction) Hx = convolver.convolve(prediction) @@ -635,9 +637,17 @@ def reconstruction_error(self, prediction=None, lensless=None): if not convolver.pad: Hx = convolver._crop(Hx) + # -- normalize + if normalize: + min_vals = torch.amin(Hx, dim=(-1, -2, -3), keepdim=True) + Hx = Hx - min_vals + max_vals = torch.amax(Hx, dim=(-1, -2, -3), keepdim=True) + Hx = Hx / max_vals + # don't reduce batch dimension if self.is_torch: - return torch.sum(torch.sqrt((Hx - lensless) ** 2), dim=(-1, -2, -3, -4)) / self._npix + # torch.mean((Hx - lensless) ** 2, dim=(-1, -2, -3, -4)) + return torch.sum((Hx - lensless) ** 2, dim=(-1, -2, -3, -4)) / self._npix else: - return np.sum(np.sqrt((Hx - lensless) ** 2), axis=(-1, -2, -3, -4)) / self._npix + return np.sum((Hx - lensless) ** 2, axis=(-1, -2, -3, -4)) / self._npix diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index 8903f589..7e2f5f38 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -329,7 +329,7 @@ def forward(self, batch, psfs=None, background=None): # set / transform PSFs if need be if self.psf_network is not None: if psfs is None: - psfs = self._psf + psfs = self._psf.to(self._data.device) if self.psf_residual: psfs = self.psf_network(psfs, self.psf_network_param).to(psfs.device) + psfs else: @@ -337,7 +337,7 @@ def forward(self, batch, 
psfs=None, background=None): if psfs is not None: # assert same shape - assert psfs.shape == batch.shape, "psfs must have the same shape as batch" + assert psfs.shape[-3:] == batch.shape[-3:], "psfs must have the same shape as batch" # -- update convolver self._convolver = RealFFTConvolve2D(psfs.to(self._data.device), **self._convolver_param) elif self._data.device != self._convolver._H.device: @@ -388,7 +388,7 @@ def forward(self, batch, psfs=None, background=None): final_est = image_est if self.return_intermediate: - return final_est, image_est, pre_processed + return final_est, image_est, pre_processed, psfs else: return final_est @@ -461,6 +461,17 @@ def apply( ).to(self._data.device) self._data = torch.clamp(self._data, 0, 1) + # transform PSF if need be + psf = None + if self.psf_network is not None: + psf = self._psf.to(self._data.device) + if self.psf_residual: + psf = self.psf_network(psf, self.psf_network_param).to(psf.device) + psf + else: + psf = self.psf_network(psf, self.psf_network_param).to(psf.device) + if psf is not None: # -- update convolver + self._convolver = RealFFTConvolve2D(psf.to(self._data.device), **self._convolver_param) + pre_processed_image = None if self.integrated_background_subtraction: # use preprocess for background subtraction @@ -516,7 +527,10 @@ def apply( plt.savefig(plib.Path(save) / "final.png") if output_intermediate: - return im, pre_post_process_image, pre_processed_image + return_items = im, pre_post_process_image, pre_processed_image + if self.psf_network is not None: + return_items += (psf,) + return return_items elif plot: return im, ax else: diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 7e72e663..900da859 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1410,6 +1410,7 @@ def __init__( downsample=1, downsample_lensed=1, input_snr=None, + psf_snr=None, display_res=None, sensor="rpi_hq", slm="adafruit", @@ -1569,6 +1570,23 @@ def __init__( # replicate across three 
channels self.psf = self.psf.repeat(1, 1, 1, 3) + if psf_snr is not None: + # # add noise to PSF + # self.psf = add_shot_noise(self.psf, psf_snr) + # add Gaussian noise to PSF + noise = torch.randn_like(self.psf) + noise_var = torch.var(noise.flatten()) + psf_var = torch.var(self.psf.flatten()) + noise *= torch.sqrt(psf_var / noise_var) / 10 ** (psf_snr / 20) + self.psf += noise + + # save PSF as torch tensor for loading model later on + torch.save(self.psf, "psf.pt") + + if save_psf: + # save viewable image of PSF + save_image(self.psf.squeeze().cpu().numpy(), f"{split}_psf.png") + elif "mask_label" in data_0: self.multimask = True mask_labels = [] @@ -1583,24 +1601,6 @@ def __init__( mask_vals = self.get_mask_vals(label) self.psf[label] = self.simulate_psf(mask_vals) - # mask_fp = hf_hub_download( - # repo_id=huggingface_repo, - # filename=f"masks/mask_{label}.npy", - # repo_type="dataset", - # ) - # mask_vals = np.load(mask_fp) - # mask = AdafruitLCD( - # initial_vals=torch.from_numpy(mask_vals.astype(np.float32)), - # sensor=sensor, - # slm=slm, - # downsample=downsample_fact, - # flipud=self.rotate or flipud, # TODO separate commands? - # use_waveprop=simulation_config.get("use_waveprop", False), - # scene2mask=simulation_config.get("scene2mask", None), - # mask2sensor=simulation_config.get("mask2sensor", None), - # deadspace=simulation_config.get("deadspace", True), - # ) - # self.psf[label] = mask.get_psf().detach() assert ( self.psf[label].shape[-3:-1] == lensless.shape[:2] @@ -1617,18 +1617,7 @@ def __init__( repo_id=huggingface_repo, filename="mask_pattern.npy", repo_type="dataset" ) mask_vals = np.load(mask_fp) - mask = AdafruitLCD( - initial_vals=torch.from_numpy(mask_vals.astype(np.float32)), - sensor=sensor, - slm=slm, - downsample=downsample_fact, - flipud=self.rotate or flipud, # TODO separate commands? 
- use_waveprop=simulation_config.get("use_waveprop", False), - scene2mask=simulation_config.get("scene2mask", None), - mask2sensor=simulation_config.get("mask2sensor", None), - deadspace=simulation_config.get("deadspace", True), - ) - self.psf = mask.get_psf().detach() + self.psf = self.simulate_psf(mask_vals) assert ( self.psf.shape[-3:-1] == lensless.shape[:2] ), "PSF shape should match lensless shape" diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 7dc21aee..2588b36a 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -299,8 +299,8 @@ def benchmark_recon(config): model_obj.eval() if pnp is not None: - print(f"Usinng parameterize and perturb (P&P) with {pnp} parameters...") pnp["model_path"] = model + print(f"Using parameterize and perturb (P&P) with {pnp} parameters...") result = benchmark( model_obj, @@ -312,6 +312,7 @@ def benchmark_recon(config): crop=crop, snr=config.snr, pnp=pnp, + swap_channels=config.swap_channels, ) results[model_name] = result @@ -344,6 +345,7 @@ def benchmark_recon(config): crop=crop, use_background=config.huggingface.use_background, snr=config.snr, + swap_channels=config.swap_channels, ) results[model_name][int(n_iter)] = result diff --git a/scripts/recon/digicam_mirflickr_psf_err.py b/scripts/recon/digicam_mirflickr_psf_err.py index 5b55740c..55995a05 100644 --- a/scripts/recon/digicam_mirflickr_psf_err.py +++ b/scripts/recon/digicam_mirflickr_psf_err.py @@ -14,6 +14,10 @@ from matplotlib import pyplot as plt +def key_to_ratio_correct(key_length, bit_depth, n_pixel): + return np.emath.logn(bit_depth, 2) * key_length / n_pixel + + @hydra.main( version_base=None, config_path="../../configs", config_name="recon_digicam_mirflickr_err" ) @@ -21,6 +25,7 @@ def apply_pretrained(config): device = config.device model_name = config.model percent_pixels_wrong = config.percent_pixels_wrong + key_to_ratio = {"DigiCam": config.digicam_ratio} if config.metrics_fp is not None: @@ 
-92,15 +97,6 @@ def apply_pretrained(config): return_mask_label=True, ) - # # create Dataset loader - # batch_size = 4 - # dataloader = torch.utils.data.DataLoader( - # dataset=test_set, - # batch_size=batch_size, - # shuffle=False, - # pin_memory=(device != "cpu"), - # ) - psf_norms = [] for mask_label in test_set.psf.keys(): psf_norms.append(np.mean(test_set.psf[mask_label].cpu().numpy().flatten() ** 2)) @@ -150,6 +146,8 @@ def apply_pretrained(config): if percent_wrong > 0: n_pixels = mask_vals.size + assert n_pixels == config.n_pixel + n_wrong_pixels = int(n_pixels * percent_wrong / 100) wrong_pixels = np.random.choice(n_pixels, n_wrong_pixels, replace=False) noisy_mask_vals = noisy_mask_vals.flatten() @@ -216,6 +214,17 @@ def apply_pretrained(config): # save if idx in config.save_idx: + + # PSF + from lensless.utils.image import gamma_correction + + psf_np = psf.cpu().numpy().squeeze() + fp = os.path.join(str(idx), f"psf_percentwrong{percent_wrong}.png") + psf_np /= np.max(psf_np) + psf_np = gamma_correction(psf_np, gamma=3) + save_image(psf_np, fp) + + # reconstruction img = recon.cpu().numpy().squeeze() alignment = test_set.alignment top_left = alignment["top_left"] @@ -238,21 +247,84 @@ def apply_pretrained(config): with open(f"{model_name}_metrics.json", "w") as f: json.dump(metrics_values, f, indent=4) - # plot each metrics vs percent_wrong + # plot each metrics + ## - config text size + plt.rcParams.update({"font.size": 24}) + ## - config line width + plt.rcParams.update({"lines.linewidth": 5}) + + key_lengths = config.compare_aes + for key_length in key_lengths: + key_to_ratio[key_length] = key_to_ratio_correct( + key_length, config.bit_depth, config.n_pixel + ) + if config.plot_vs_percent_wrong: + key_to_ratio[key_length] = 1 - key_to_ratio[key_length] + linestyles = ["--", "-.", ":"] + colors = ["red", "forestgreen", "purple", "green", "blue", "black"] + for k, v in metrics_values.items(): - plt.figure() - plt.xlabel("Percent pixels wrong [%]") + 
plt.figure(figsize=(6.7, 5)) + y_vals = np.mean(v, axis=1) + if config.plot_vs_percent_wrong: + plt.xlabel("Percent pixels wrong [%]") + x_vals = percent_pixels_wrong + else: + plt.xlabel("Percent pixels correct [%]") + x_vals = 100 - np.array(percent_pixels_wrong) + plt.xlim([0, 100]) if k == "psf_err": - plt.plot(percent_pixels_wrong, np.mean(v, axis=1) * 100) + plt.plot(x_vals, y_vals * 100) plt.ylabel("Relative PSF error [%]") + plt.ylim([0, 100]) + elif k == "PSNR": + plt.plot(x_vals, y_vals) + plt.ylabel("PSNR [dB]") + elif k == "LPIPS_Vgg": + plt.plot(x_vals, y_vals) + plt.ylabel("LPIPS") else: - plt.plot(percent_pixels_wrong, np.mean(v, axis=1)) + plt.plot(x_vals, y_vals) plt.ylabel(k) + if config.plot_vs_percent_wrong: + print(f"-- Metric {k} : ", y_vals[::-1]) + print("% wrong vals : ", x_vals[::-1]) + else: + print(f"-- Metric {k} : ", y_vals[::-1]) + print("% corrects vals : ", x_vals[::-1]) + + # plot keys + # for idx, key_length in enumerate(key_lengths): + for idx, _key in enumerate(list(key_to_ratio.keys())): + if isinstance(_key, int): + label = f"AES-{_key}" + else: + label = _key + plt.axvline( + key_to_ratio[_key] * 100, color=colors[idx], linestyle=linestyles[idx], label=label + ) + # label along x axis + # plt.text( + # key_to_ratio[key_length] * 100, + # plt.ylim()[-1], + # # f"{key_length} bits ({100 * key_to_ratio[key_length]:.2f}%)", + # f"{key_length} bits", + # rotation=45, + # verticalalignment="bottom", + # ) + + # legend bottom right + if k == "PSNR" or k == "SSIM": + plt.legend(loc="lower right") + else: + plt.legend(loc="upper right") + # save plot # - tight + plt.grid() plt.tight_layout() - plt.savefig(f"{k}_{model_name}.png") + plt.savefig(f"{k}_{model_name}.png", bbox_inches="tight") if __name__ == "__main__": diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index 826c27d7..1ee3f24a 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -259,6 
+259,7 @@ def train_learned(config): if config.files.background_fp is not None else None, input_snr=config.files.input_snr, + psf_snr=config.files.psf_snr, ) test_set = HFDataset( @@ -286,6 +287,10 @@ def train_learned(config): input_snr=config.files.input_snr, ) + if config.files.psf_snr is not None: + # overwrite test set PSF with train set PSF + test_set.psf = train_set.psf + if train_set.multimask: # get first PSF for initialization if device_ids is not None: