From b1ad3a74d5a14eff2458f876bffdfc2d61bc18b0 Mon Sep 17 00:00:00 2001
From: Eric Bezzam
Date: Fri, 14 Jun 2024 09:26:25 +0000
Subject: [PATCH] Add support for preproc aux and initializing from HF model.

---
 configs/train_unrolledADMM.yaml       | 12 +++-
 lensless/eval/benchmark.py            | 36 ++++++++++--
 lensless/recon/model_dict.py          | 26 ++++++++
 lensless/recon/recon.py               | 17 +++---
 lensless/recon/trainable_recon.py     | 18 +++---
 lensless/recon/unrolled_admm.py       |  3 +-
 lensless/recon/utils.py               | 85 ++++++++++++++++++++-------
 lensless/utils/dataset.py             |  3 +
 recon_requirements.txt                |  3 +-
 scripts/recon/train_learning_based.py | 33 +++++++++--
 10 files changed, 182 insertions(+), 54 deletions(-)

diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml
index d21a20c2..2a55a834 100644
--- a/configs/train_unrolledADMM.yaml
+++ b/configs/train_unrolledADMM.yaml
@@ -20,6 +20,7 @@ files:
   dataset: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM
   huggingface_dataset: True
   huggingface_psf: psf.tiff
+  single_channel_psf: False   # whether to sum all PSF channels into one

   # -- train/test split
   split_seed: null # if null use train/test split from dataset
@@ -67,7 +68,11 @@ reconstruction:
   # Method: unrolled_admm, unrolled_fista, trainable_inv
   method: unrolled_admm
   skip_unrolled: False
-  init_processors: null # model name
+
+  # initialize pre- and post-processors with "init_processors"
+  # -- for a HuggingFace model use "hf:camera:dataset:model_name"
+  # -- for a local model use "local:model_path"
+  init_processors: null
   init_pre: True    # if `init_processors`, set pre-processor if available
   init_post: True   # if `init_processors`, set post-processor if available
@@ -149,7 +154,6 @@ simulation:
   max_val: 255

 #Training
-
 training:
   batch_size: 8
   epoch: 25
@@ -174,4 +178,6 @@ loss: 'l2'
 # set lpips to false to deactivate. Otherwise, give the weight for the loss (the main loss l2/l1 always having a weight of 1)
 lpips: 1.0
 unrolled_output_factor: False # whether to account for unrolled output in loss (there must be a post-processor)
-pre_proc_aux: False # factor for auxiliary pre-processor loss to promote measurement consistency -> ||pre_proc(y) - Ax||
\ No newline at end of file
+# factor for auxiliary pre-processor loss to promote measurement consistency -> ||pre_proc(y) - A * camera_inversion(y)||
+# -- use the camera-inversion output so that the loss does not include enhancements / coloring by the post-processor
+pre_proc_aux: False
\ No newline at end of file
diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py
index 8a86f32c..6b4e0529 100644
--- a/lensless/eval/benchmark.py
+++ b/lensless/eval/benchmark.py
@@ -36,6 +36,7 @@ def benchmark(
     save_idx=None,
     output_dir=None,
     unrolled_output_factor=False,
+    pre_process_aux=False,
     return_average=True,
     snr=None,
     use_wandb=False,
@@ -106,6 +107,8 @@ def benchmark(
         for key in output_metrics:
             if key != "ReconstructionError":
                 metrics_values[key + "_unrolled"] = []
+    if pre_process_aux:
+        metrics_values["ReconstructionError_PreProc"] = []

     # loop over batches
     dataloader = DataLoader(dataset, batch_size=batchsize, pin_memory=(device != "cpu"))
@@ -141,14 +144,18 @@ def benchmark(
                 model._set_psf(psfs[0])
             model.set_data(lensless)
             prediction = model.apply(
-                plot=False, save=False, output_intermediate=unrolled_output_factor, **kwargs
+                plot=False,
+                save=False,
+                output_intermediate=unrolled_output_factor or pre_process_aux,
+                **kwargs,
             )

         else:
             prediction = model.forward(lensless, psfs, **kwargs)

-        if unrolled_output_factor:
-            unrolled_out = prediction[-1]
+        if unrolled_output_factor or pre_process_aux:
+            pre_process_out = prediction[2]
+            unrolled_out = prediction[1]
             prediction = prediction[0]
         prediction_original = prediction.clone()
@@ -245,7 +252,17 @@ def benchmark(
             unrolled_out = unrolled_out.reshape(-1, *unrolled_out.shape[-3:]).movedim(-1, -3)

             # -- extraction region of interest
-            if crop is not None:
+            if hasattr(dataset, "alignment"):
+                if dataset.alignment is not None:
+                    unrolled_out = dataset.extract_roi(unrolled_out, axis=(-2, -1))
+                else:
+                    unrolled_out = dataset.extract_roi(
+                        unrolled_out,
+                        axis=(-2, -1),
+                        # lensed=lensed  # lensed already extracted before
+                    )
+                assert np.all(lensed.shape == unrolled_out.shape)
+            elif crop is not None:
                 unrolled_out = unrolled_out[
                     ...,
                     crop["vertical"][0] : crop["vertical"][1],
@@ -288,13 +305,20 @@ def benchmark(
                 else:
                     metrics_values[metric + "_unrolled"].append(vals.item())

+        # compute metrics for pre-processed output
+        if pre_process_aux:
+            metrics_values["ReconstructionError_PreProc"] += model.reconstruction_error(
+                prediction=prediction_original, lensless=pre_process_out
+            ).tolist()
+
         model.reset()
         idx += batchsize

     # average metrics
     if return_average:
-        for metric in metrics:
-            if "MSE" in metric or "ReconstructionError" in metric or "LPIPS" in metric:
+        for metric in metrics_values.keys():
+            if "MSE" in metric or "LPIPS" in metric:
+                # handled differently because these metrics are accumulated per batch
                 metrics_values[metric] = np.sum(metrics_values[metric]) / len(dataset)
             else:
                 metrics_values[metric] = np.mean(metrics_values[metric])
diff --git a/lensless/recon/model_dict.py b/lensless/recon/model_dict.py
index 8440a149..b3d4536a 100644
--- a/lensless/recon/model_dict.py
+++ b/lensless/recon/model_dict.py
@@ -15,6 +15,7 @@
 from lensless.recon.trainable_inversion import TrainableInversion
 from lensless.hardware.trainable_mask import prep_trainable_mask
 import yaml
+from lensless.recon.multi_wiener import MultiWiener
 from huggingface_hub import snapshot_download
 from collections import OrderedDict
@@ -100,6 +101,17 @@
             "Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M-wave",
         },
     },
+    "tapecam": {
+        "mirflickr": {
+            "U5+Unet8M": "bezzam/tapecam-mirflickr-unrolled-admm5-unet8M",
+            "TrainInv+Unet8M": "bezzam/tapecam-mirflickr-trainable-inv-unet8M",
+            "MMCN4M+Unet4M": "bezzam/tapecam-mirflickr-mmcn-unet4M",
+            "MWDN8M": "bezzam/tapecam-mirflickr-mwdn-8M",
+            "Unet4M+TrainInv+Unet4M": "bezzam/tapecam-mirflickr-unet4M-trainable-inv-unet4M",
+            "Unet4M+U5+Unet4M": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M",
+            "Unet2M+MMCN+Unet2M": "bezzam/tapecam-mirflickr-unet2M-mmcn-unet2M",
+        },
+    },
 }
@@ -225,6 +237,10 @@ def load_model(
             if "nc" in config["reconstruction"]["post_process"].keys()
             else None,
             device=device,
+            # concatenate compensation-branch output if specified in the config
+            concatenate_compensation=True
+            if config["reconstruction"].get("compensation", None) is not None
+            else False,
         )

     if config["reconstruction"]["method"] == "unrolled_admm":
@@ -237,6 +253,7 @@ def load_model(
             legacy_denoiser=legacy_denoiser,
             skip_pre=skip_pre,
             skip_post=skip_post,
+            compensation=config["reconstruction"].get("compensation", None),
         )
     elif config["reconstruction"]["method"] == "trainable_inv":
         recon = TrainableInversion(
@@ -248,6 +265,15 @@ def load_model(
             skip_pre=skip_pre,
             skip_post=skip_post,
         )
+    elif config["reconstruction"]["method"] == "multi_wiener":
+        recon = MultiWiener(
+            in_channels=3,
+            out_channels=3,
+            psf=psf,
+            psf_channels=3,
+            nc=config["reconstruction"]["multi_wiener"]["nc"],
+        )
+        recon.to(device)

     if mask is not None:
         psf_learned = torch.nn.Parameter(psf_learned)
diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py
index d58c5906..ca172ff8 100644
--- a/lensless/recon/recon.py
+++ b/lensless/recon/recon.py
@@ -255,6 +255,7 @@ def __init__(
         assert len(psf.shape) == 4, "PSF must be 4D: (depth, height, width, channels)."
         assert psf.shape[3] == 3 or psf.shape[3] == 1, "PSF must either be RGB (3) or grayscale (1)"
         self._psf = psf
+        self._npix = np.prod(self._psf.shape)
         self._n_iter = n_iter

         self._psf_shape = np.array(self._psf.shape)
@@ -611,22 +612,18 @@ def reconstruction_error(self, prediction=None, lensless=None):
         if lensless is None:
             lensless = self._data

-        convolver = self._convolver
+        # rebuild the convolver on the prediction's device
+        convolver = RealFFTConvolve2D(self._psf.to(prediction.device), **self._convolver_param)
         if not convolver.pad:
             prediction = convolver._pad(prediction)

-        Fx = convolver.convolve(prediction)
-        Fy = lensless
+        Hx = convolver.convolve(prediction)

         if not convolver.pad:
-            Fx = convolver._crop(Fx)
+            Hx = convolver._crop(Hx)

         # don't reduce batch dimension
         if self.is_torch:
-            return torch.sum(torch.sqrt((Fx - Fy) ** 2), dim=(-1, -2, -3, -4)) / np.prod(
-                prediction.shape[1:]
-            )
+            return torch.sum(torch.sqrt((Hx - lensless) ** 2), dim=(-1, -2, -3, -4)) / self._npix
         else:
-            return np.sum(np.sqrt((Fx - Fy) ** 2), axis=(-1, -2, -3, -4)) / np.prod(
-                prediction.shape[1:]
-            )
+            return np.sum(np.sqrt((Hx - lensless) ** 2), axis=(-1, -2, -3, -4)) / self._npix
diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py
index 9c38c810..7243e05e 100644
--- a/lensless/recon/trainable_recon.py
+++ b/lensless/recon/trainable_recon.py
@@ -54,7 +54,7 @@ def __init__(
         skip_unrolled=False,
         skip_pre=False,
         skip_post=False,
-        return_unrolled_output=False,
+        return_intermediate=False,
         legacy_denoiser=False,
         compensation=None,
         **kwargs,
@@ -100,7 +100,7 @@ def __init__(
         self.skip_unrolled = skip_unrolled
         self.skip_pre = skip_pre
         self.skip_post = skip_post
-        self.return_unrolled_output = return_unrolled_output
+        self.return_intermediate = return_intermediate
         self.compensation_branch = compensation
         if compensation is not None:
             from lensless.recon.utils import CompensationBranch
@@ -112,11 +112,12 @@ def __init__(
                 len(compensation) == n_iter
             ), "compensation_nc must have the same length as n_iter"
             self.compensation_branch = CompensationBranch(compensation)
+            self.compensation_branch = self.compensation_branch.to(self._psf.device)

-        if self.return_unrolled_output:
+        if self.return_intermediate:
             assert (
-                post_process is not None
-            ), "If return_unrolled_output is True, post_process must be defined."
+                post_process is not None or pre_process is not None
+            ), "If return_intermediate is True, post_process or pre_process must be defined."
         if self.skip_unrolled:
             assert (
                 post_process is not None or pre_process is not None
@@ -246,6 +247,7 @@ def forward(self, batch, psfs=None):
             device_before = self._data.device
             self._data = self.pre_process(self._data, self.pre_process_param)
             self._data = self._data.to(device_before)
+        pre_processed = self._data

         self.reset(batch_size=batch_size)
@@ -273,8 +275,8 @@ def forward(self, batch, psfs=None):
         else:
             final_est = image_est

-        if self.return_unrolled_output:
-            return final_est, image_est
+        if self.return_intermediate:
+            return final_est, image_est, pre_processed
         else:
             return final_est
@@ -365,7 +367,7 @@ def apply(
             plt.savefig(plib.Path(save) / "final.png")

         if output_intermediate:
-            return im, pre_processed_image, pre_post_process_image
+            return im, pre_post_process_image, pre_processed_image
         elif plot:
             return im, ax
         else:
diff --git a/lensless/recon/unrolled_admm.py b/lensless/recon/unrolled_admm.py
index d30eb6b9..174d8ce6 100644
--- a/lensless/recon/unrolled_admm.py
+++ b/lensless/recon/unrolled_admm.py
@@ -235,5 +235,6 @@ def _update(self, iter):

     def _form_image(self):
         image = self._convolver._crop(self._image_est)
-        image = torch.clamp(image, min=0)
+        image = torch.clip(image, min=0.0)
         return image
diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py
index e78961c4..930c4bea 100644
--- a/lensless/recon/utils.py
+++ b/lensless/recon/utils.py
@@ -418,6 +418,7 @@ def __init__(
         crop=None,
         clip_grad=1.0,
         unrolled_output_factor=False,
+        pre_proc_aux=False,
         extra_eval_sets=None,
         use_wandb=False,
         # for adding components during training
@@ -613,6 +614,14 @@ def __init__(
             assert self.post_process_unfreeze is None
             assert self.post_process_freeze is None

+        # -- adding pre-processed output to loss
+        self.pre_proc_aux = pre_proc_aux
+        if self.pre_proc_aux:
+            assert self.pre_process is not None
+            assert self.pre_process_delay is None
+            assert self.pre_process_unfreeze is None
+            assert self.pre_process_freeze is None
+
         # optimizer
         self.clip_grad_norm = clip_grad
         self.optimizer_config = optimizer
@@ -641,6 +650,10 @@ def __init__(
             # -- add unrolled metrics
             for key in ["MSE", "MAE", "LPIPS_Vgg", "LPIPS_Alex", "PSNR", "SSIM"]:
                 self.metrics[key + "_unrolled"] = []
+        if self.pre_proc_aux:
+            self.metrics[
+                "ReconstructionError_PreProc"
+            ] = []  # reconstruction error of ||pre_proc(y) - A * camera_inversion(y)||
         if metric_for_best_model is not None:
             assert metric_for_best_model in self.metrics.keys()
         if extra_eval_sets is not None:
@@ -748,9 +761,8 @@ def train_epoch(self, data_loader):
             # forward pass
             # torch.autograd.set_detect_anomaly(True)  # for debugging
             y_pred = self.recon.forward(batch=X, psfs=psfs)
-            if self.unrolled_output_factor:
-                unrolled_out = y_pred[1]
-                y_pred = y_pred[0]
+            if self.unrolled_output_factor or self.pre_proc_aux:
+                y_pred, camera_inv_out, pre_proc_out = y_pred[0], y_pred[1], y_pred[2]

             # normalizing each output
             eps = 1e-12
@@ -762,18 +774,20 @@ def train_epoch(self, data_loader):
             y = y / y_max

             # convert to CHW for loss and remove depth
-            y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3)
+            y_pred_crop = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3)
             y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3)

             # extraction region of interest for loss
             if hasattr(self.train_dataset, "alignment"):
                 if self.train_dataset.alignment is not None:
-                    y_pred = self.train_dataset.extract_roi(y_pred, axis=(-2, -1))
+                    y_pred_crop = self.train_dataset.extract_roi(y_pred_crop, axis=(-2, -1))
                 else:
-                    y_pred, y = self.train_dataset.extract_roi(y_pred, axis=(-2, -1), lensed=y)
+                    y_pred_crop, y = self.train_dataset.extract_roi(
+                        y_pred_crop, axis=(-2, -1), lensed=y
+                    )
             elif self.crop is not None:
-                y_pred = y_pred[
+                y_pred_crop = y_pred_crop[
                     ...,
                     self.crop["vertical"][0] : self.crop["vertical"][1],
                     self.crop["horizontal"][0] : self.crop["horizontal"][1],
@@ -784,19 +798,19 @@ def train_epoch(self, data_loader):
                     self.crop["horizontal"][0] : self.crop["horizontal"][1],
                 ]

-            loss_v = self.Loss(y_pred, y)
+            loss_v = self.Loss(y_pred_crop, y)

             # add LPIPS loss
             if self.lpips:
-                if y_pred.shape[1] == 1:
+                if y_pred_crop.shape[1] == 1:
                     # if only one channel, repeat for LPIPS
-                    y_pred = y_pred.repeat(1, 3, 1, 1)
+                    y_pred_crop = y_pred_crop.repeat(1, 3, 1, 1)
                     y = y.repeat(1, 3, 1, 1)

                 # value for LPIPS needs to be in range [-1, 1]
                 loss_v = loss_v + self.lpips * torch.mean(
-                    self.Loss_lpips(2 * y_pred - 1, 2 * y - 1)
+                    self.Loss_lpips(2 * y_pred_crop - 1, 2 * y - 1)
                 )
             if self.use_mask and self.l1_mask:
                 for p in self.mask.parameters():
@@ -805,37 +819,65 @@ def train_epoch(self, data_loader):

             if self.unrolled_output_factor:
                 # -- normalize
-                unrolled_out_max = torch.amax(unrolled_out, dim=(-1, -2, -3), keepdim=True) + eps
-                unrolled_out = unrolled_out / unrolled_out_max
+                unrolled_out_max = torch.amax(camera_inv_out, dim=(-1, -2, -3), keepdim=True) + eps
+                camera_inv_out_norm = camera_inv_out / unrolled_out_max

                 # -- convert to CHW for loss and remove depth
-                unrolled_out = unrolled_out.reshape(-1, *unrolled_out.shape[-3:]).movedim(-1, -3)
+                camera_inv_out_norm = camera_inv_out_norm.reshape(
+                    -1, *camera_inv_out.shape[-3:]
+                ).movedim(-1, -3)

                 # -- extraction region of interest for loss
-                if self.crop is not None:
-                    unrolled_out = unrolled_out[
+                if hasattr(self.train_dataset, "alignment"):
+                    if self.train_dataset.alignment is not None:
+                        camera_inv_out_norm = self.train_dataset.extract_roi(
+                            camera_inv_out_norm, axis=(-2, -1)
+                        )
+                    else:
+                        camera_inv_out_norm = self.train_dataset.extract_roi(
+                            camera_inv_out_norm,
+                            axis=(-2, -1),
+                            # y=y  # lensed already extracted before
+                        )
+                    assert np.all(y.shape == camera_inv_out_norm.shape)
+                elif self.crop is not None:
+                    camera_inv_out_norm = camera_inv_out_norm[
                         ...,
                         self.crop["vertical"][0] : self.crop["vertical"][1],
                         self.crop["horizontal"][0] : self.crop["horizontal"][1],
                     ]

                 # -- compute unrolled output loss
-                loss_unrolled = self.Loss(unrolled_out, y)
+                loss_unrolled = self.Loss(camera_inv_out_norm, y)

                 # -- add LPIPS loss
                 if self.lpips:
-                    if unrolled_out.shape[1] == 1:
+                    if camera_inv_out_norm.shape[1] == 1:
                         # if only one channel, repeat for LPIPS
-                        unrolled_out = unrolled_out.repeat(1, 3, 1, 1)
+                        camera_inv_out_norm = camera_inv_out_norm.repeat(1, 3, 1, 1)

                     # value for LPIPS needs to be in range [-1, 1]
                     loss_unrolled = loss_unrolled + self.lpips * torch.mean(
-                        self.Loss_lpips(2 * unrolled_out - 1, 2 * y - 1)
+                        self.Loss_lpips(2 * camera_inv_out_norm - 1, 2 * y - 1)
                     )

                 # -- add unrolled loss to total loss
                 loss_v = loss_v + self.unrolled_output_factor * loss_unrolled

+            if self.pre_proc_aux:
+                # -- normalize
+                unrolled_out_max = torch.amax(camera_inv_out, dim=(-1, -2, -3), keepdim=True) + eps
+                camera_inv_out_norm = camera_inv_out / unrolled_out_max
+
+                err = torch.mean(
+                    self.recon.reconstruction_error(
+                        prediction=camera_inv_out_norm,
+                        lensless=pre_proc_out,
+                    )
+                )
+                loss_v = loss_v + self.pre_proc_aux * err
+
             # backward pass
             loss_v.backward()
@@ -923,6 +965,7 @@ def evaluate(self, mean_loss, epoch, disp=None):
                 output_dir=output_dir,
                 crop=self.crop,
                 unrolled_output_factor=self.unrolled_output_factor,
+                pre_process_aux=self.pre_proc_aux,
                 use_wandb=self.use_wandb,
                 epoch=epoch,
             )
@@ -949,6 +992,8 @@ def evaluate(self, mean_loss, epoch, disp=None):
                 if self.lpips is not None:
                     unrolled_loss += self.lpips * current_metrics["LPIPS_Vgg_unrolled"]
                 eval_loss += self.unrolled_output_factor * unrolled_loss
+                if self.pre_proc_aux:
+                    eval_loss += self.pre_proc_aux * current_metrics["ReconstructionError_PreProc"]
             else:
                 eval_loss = current_metrics[self.metrics["metric_for_best_model"]]
diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py
index e8311f26..638216c7 100644
--- a/lensless/utils/dataset.py
+++ b/lensless/utils/dataset.py
@@ -1041,6 +1041,7 @@ def __init__(
         simulate_lensless=False,
         force_rgb=False,
         cache_dir=None,
+        single_channel_psf=False,
         **kwargs,
     ):
         """
@@ -1102,6 +1103,7 @@ def __init__(
         self.downsample_lensless = downsample
         self.downsample_lensed = downsample_lensed
         lensless = np.array(data_0["lensless"])
+
         if self.downsample_lensless != 1.0:
             lensless = resize(lensless, factor=1 / self.downsample_lensless)
         if psf is None:
@@ -1159,6 +1161,7 @@ def __init__(
             flip_ud=flipud,
             bg_pix=(0, 15),
             force_rgb=force_rgb,
+            single_psf=single_channel_psf,
         )
         self.psf = torch.from_numpy(psf)
diff --git a/recon_requirements.txt b/recon_requirements.txt
index beb67296..30b6b31b 100644
--- a/recon_requirements.txt
+++ b/recon_requirements.txt
@@ -3,7 +3,8 @@ lpips==0.1.4
 pylops==1.18.0
 scikit-image>=0.19.0rc0
 click>=8.0.1
-waveprop>=0.0.10    # for simulation
+# waveprop>=0.0.10    # for simulation
+waveprop @ git+https://github.com/ebezzam/waveprop.git
 slm_controller @ git+https://github.com/ebezzam/slm-controller.git

 # Library for learning algorithm
diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py
index 350614f3..a3ce60f1 100644
--- a/scripts/recon/train_learning_based.py
+++ b/scripts/recon/train_learning_based.py
@@ -209,6 +209,7 @@ def train_learned(config):
         train_set = HFDataset(
             huggingface_repo=config.files.dataset,
             psf=config.files.huggingface_psf,
+            single_channel_psf=config.files.single_channel_psf,
             split=split_train,
             display_res=config.files.image_res,
             rotate=config.files.rotate,
@@ -224,6 +225,7 @@ def train_learned(config):
         test_set = HFDataset(
             huggingface_repo=config.files.dataset,
             psf=config.files.huggingface_psf,
+            single_channel_psf=config.files.single_channel_psf,
             split=split_test,
             display_res=config.files.image_res,
             rotate=config.files.rotate,
@@ -393,10 +395,23 @@ def train_learned(config):

     # initialize pre- and post-processor with another model
     if config.reconstruction.init_processors is not None:
-        from lensless.recon.model_dict import load_model, model_dict
+        from lensless.recon.model_dict import load_model, download_model
+
+        if config.reconstruction.init_processors.startswith("hf:"):
+            param = config.reconstruction.init_processors.split(":")
+            camera = param[1]
+            dataset = param[2]
+            model_name = param[3]
+            model_path = download_model(camera=camera, dataset=dataset, model=model_name)
+
+        elif config.reconstruction.init_processors.startswith("local:"):
+            model_path = config.reconstruction.init_processors.split(":", 1)[1]
+
+        else:
+            raise ValueError(
+                f"{config.reconstruction.init_processors} is not a supported model specifier, "
+                "use 'hf:camera:dataset:model_name' or 'local:model_path'"
+            )

         model_orig = load_model(
-            model_dict["diffusercam"]["mirflickr"][config.reconstruction.init_processors],
+            model_path=model_path,
             psf=psf,
             device=device,
         )
@@ -430,7 +445,9 @@ def train_learned(config):
             pre_process=pre_process if pre_proc_delay is None else None,
             post_process=post_process if post_proc_delay is None else None,
             skip_unrolled=config.reconstruction.skip_unrolled,
-            return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
+            return_intermediate=config.unrolled_output_factor > 0 or config.pre_proc_aux > 0,
             compensation=config.reconstruction.compensation,
         )
     elif config.reconstruction.method == "unrolled_admm":
@@ -444,16 +461,21 @@ def train_learned(config):
             pre_process=pre_process if pre_proc_delay is None else None,
             post_process=post_process if post_proc_delay is None else None,
             skip_unrolled=config.reconstruction.skip_unrolled,
-            return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
+            return_intermediate=config.unrolled_output_factor > 0 or config.pre_proc_aux > 0,
             compensation=config.reconstruction.compensation,
         )
     elif config.reconstruction.method == "trainable_inv":
+        assert config.trainable_mask.mask_type == "TrainablePSF"
         recon = TrainableInversion(
             psf,
             K=config.reconstruction.trainable_inv.K,
             pre_process=pre_process if pre_proc_delay is None else None,
             post_process=post_process if post_proc_delay is None else None,
-            return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
+            return_intermediate=config.unrolled_output_factor > 0 or config.pre_proc_aux > 0,
         )
     elif config.reconstruction.method == "multi_wiener":
         recon = MultiWiener(
@@ -516,6 +538,7 @@ def train_learned(config):
         post_process_unfreeze=config.reconstruction.post_process.unfreeze,
         clip_grad=config.training.clip_grad,
         unrolled_output_factor=config.unrolled_output_factor,
+        pre_proc_aux=config.pre_proc_aux,
        extra_eval_sets=extra_eval_sets if config.files.extra_eval is not None else None,
        use_wandb=True if config.wandb_project is not None else False,
    )
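
Note on the auxiliary pre-processor loss configured by pre_proc_aux above: it promotes
measurement consistency, ||pre_proc(y) - A * camera_inversion(y)||, and is computed with
reconstruction_error on the normalized camera-inversion output. Below is a minimal
standalone sketch of that computation; the convolve callable (forward model A) and npix
(the normalization constant, np.prod(psf.shape) in recon.py) are hypothetical stand-ins
for the pieces wired up inside the Trainer.

    import torch

    def pre_proc_aux_loss(camera_inv_out, pre_proc_out, convolve, npix, eps=1e-12):
        # normalize the camera-inversion output, as done in Trainer.train_epoch
        x = camera_inv_out / (torch.amax(camera_inv_out, dim=(-1, -2, -3), keepdim=True) + eps)
        # re-simulate the measurement by applying the forward model A (PSF convolution)
        Hx = convolve(x)
        # reconstruction_error uses torch.sqrt((Hx - y) ** 2), i.e. an element-wise
        # absolute value, summed over (depth, height, width, channels) and divided by npix
        err = torch.sum(torch.abs(Hx - pre_proc_out), dim=(-1, -2, -3, -4)) / npix
        return err.mean()  # average over the batch

    # sketch of the total loss: loss_v = loss_v + pre_proc_aux * pre_proc_aux_loss(...)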
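
Usage note for the new init_processors option: the string is resolved as below (a sketch
mirroring scripts/recon/train_learning_based.py; the model specifier is only an example,
and psf / device are assumed to be defined as in that script).

    from lensless.recon.model_dict import download_model, load_model

    spec = "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M"  # hf:camera:dataset:model_name
    if spec.startswith("hf:"):
        _, camera, dataset, model_name = spec.split(":")
        model_path = download_model(camera=camera, dataset=dataset, model=model_name)
    elif spec.startswith("local:"):
        model_path = spec.split(":", 1)[1]  # local:model_path

    model_orig = load_model(model_path=model_path, psf=psf, device=device)
    # its pre-/post-processor weights are then copied into the new model,
    # according to the init_pre / init_post flags in the config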