diff --git a/GaussianProxy/utils/misc.py b/GaussianProxy/utils/misc.py index 69159e4..18cbccb 100644 --- a/GaussianProxy/utils/misc.py +++ b/GaussianProxy/utils/misc.py @@ -1,29 +1,30 @@ import time -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from functools import wraps from logging import Logger from pathlib import Path -from typing import Literal, Optional +from typing import Literal, Optional, overload import numpy as np import plotly.express as px import plotly.graph_objects as go import torch +import wandb from accelerate import Accelerator from accelerate.logging import MultiProcessAdapter from diffusers.configuration_utils import FrozenDict +from enlighten import Manager from numpy import ndarray from omegaconf import OmegaConf from PIL import Image from termcolor import colored from torch import Tensor from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip +from wandb.sdk.wandb_run import Run as WandBRun -import wandb from GaussianProxy.conf.training_conf import Config from GaussianProxy.utils.data import RandomRotationSquareSymmetry -from wandb.sdk.wandb_run import Run as WandBRun def create_repo_structure( @@ -136,7 +137,9 @@ def modify_args_for_debug( strat.nb_diffusion_timesteps = 5 if hasattr(strat, "nb_samples_to_gen_per_time"): assert hasattr(strat, "batch_size"), "Expected batch_size to be present" - setattr(strat, "nb_samples_to_gen_per_time", getattr(strat, "batch_size")) + # only change it if it's a hard-coded value + if isinstance(getattr(strat, "nb_samples_to_gen_per_time"), int): + setattr(strat, "nb_samples_to_gen_per_time", getattr(strat, "batch_size")) # TODO: update registered wandb config @@ -546,23 +549,59 @@ def filter_internal_fields(config_dict: FrozenDict | dict): RandomHorizontalFlip, RandomVerticalFlip, RandomRotationSquareSymmetry, + "RandomHorizontalFlip", + "RandomVerticalFlip", + "RandomRotationSquareSymmetry", ) -def generate_all_augs(img: torch.Tensor, transforms: list[type]) -> list[Tensor]: +@overload +def generate_all_augs(img: Tensor, transforms: list[type] | list[str]) -> list[Tensor]: ... + + +@overload +def generate_all_augs(img: Image.Image, transforms: list[type] | list[str]) -> list[Image.Image]: ... + + +def generate_all_augs( + img: Tensor | Image.Image, transforms: list[type] | list[str] +) -> list[Tensor] | list[Image.Image]: """Generate all augmentations of `img` based on `transforms`.""" - # checks - assert img.ndim == 3, f"Expected 3D image, got shape {img.shape}" + # general checks assert all(t in SUPPORTED_TRANSFORMS_TO_GENERATE for t in transforms), f"Unsupported transforms: {transforms}" + # choose backend + if isinstance(img, Tensor): + aug_imgs = _generate_all_augs_torch(img, transforms) + elif isinstance(img, Image.Image): + aug_imgs = _generate_all_augs_pil(img, transforms) + else: + raise TypeError(f"Unsupported image type: {type(img)}") + + assert len(aug_imgs) in ( + 1, + 2, + 4, + 8, + ), f"Expected 1, 2, 4, or 8 images at this point, got {len(aug_imgs)}" + + return aug_imgs + + +# TODO: factorize backends better + + +def _generate_all_augs_torch(img: Tensor, transforms: list[type] | list[str]) -> list[Tensor]: + """torch backend for `generate_all_augs`.""" + assert img.ndim == 3, f"Expected 3D image, got shape {img.shape}" # generate all possible augmentations aug_imgs = [img] - if RandomHorizontalFlip in transforms: + if RandomHorizontalFlip in transforms or "RandomHorizontalFlip" in transforms: aug_imgs.append(torch.flip(img, [2])) - if RandomVerticalFlip in transforms: + if RandomVerticalFlip in transforms or "RandomVerticalFlip" in transforms: for base_img_idx in range(len(aug_imgs)): aug_imgs.append(torch.flip(aug_imgs[base_img_idx], [1])) - if RandomRotationSquareSymmetry in transforms: + if RandomRotationSquareSymmetry in transforms or "RandomRotationSquareSymmetry" in transforms: if len(aug_imgs) == 4: # must have been flipped in both directions then aug_imgs += [ torch.rot90(aug_imgs[0], k=1, dims=(1, 2)), @@ -572,17 +611,85 @@ def generate_all_augs(img: torch.Tensor, transforms: list[type]) -> list[Tensor] torch.rot90(aug_imgs[1], k=1, dims=(1, 2)), torch.rot90(aug_imgs[1], k=3, dims=(1, 2)), ] - elif len(aug_imgs) in (1, 2): # simply perform all 4 rotations on each image + elif len(aug_imgs) in (1, 2): # simply perform all 3 rotations on each image for base_img_idx in range(len(aug_imgs)): for nb_rot in [1, 2, 3]: aug_imgs.append(torch.rot90(aug_imgs[base_img_idx], k=nb_rot, dims=(1, 2))) else: raise ValueError(f"Expected 1, 2, or 4 images at this point, got {len(aug_imgs)}") - assert len(aug_imgs) in ( - 1, - 2, - 4, - 8, - ), f"Expected 1, 2, 4, or 8 images at this point, got {len(aug_imgs)}" return aug_imgs + + +def _generate_all_augs_pil(img: Image.Image, transforms: list[type] | list[str]) -> list[Image.Image]: + """PIL backend for `generate_all_augs`.""" + # generate all possible augmentations + aug_imgs = [img] + if RandomHorizontalFlip in transforms or "RandomHorizontalFlip" in transforms: + aug_imgs.append(img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)) + if RandomVerticalFlip in transforms or "RandomVerticalFlip" in transforms: + for base_img_idx in range(len(aug_imgs)): + aug_imgs.append(aug_imgs[base_img_idx].transpose(Image.Transpose.FLIP_TOP_BOTTOM)) + if RandomRotationSquareSymmetry in transforms or "RandomRotationSquareSymmetry" in transforms: + if len(aug_imgs) == 4: # must have been flipped in both directions then + aug_imgs += [ + aug_imgs[0].transpose(Image.Transpose.ROTATE_90), + aug_imgs[0].transpose(Image.Transpose.ROTATE_270), + ] + aug_imgs += [ + aug_imgs[1].transpose(Image.Transpose.ROTATE_90), + aug_imgs[1].transpose(Image.Transpose.ROTATE_270), + ] + elif len(aug_imgs) in (1, 2): # simply perform all 3 rotations on each image + for base_img_idx in range(len(aug_imgs)): + aug_imgs += [ + aug_imgs[base_img_idx].transpose(Image.Transpose.ROTATE_90), + aug_imgs[base_img_idx].transpose(Image.Transpose.ROTATE_180), + aug_imgs[base_img_idx].transpose(Image.Transpose.ROTATE_270), + ] + else: + raise ValueError(f"Expected 1, 2, or 4 images at this point, got {len(aug_imgs)}") + + return aug_imgs + + +def hard_augment_dataset_all_square_symmetries( + dataset_path: Path, logger: MultiProcessAdapter, pbar_manager: Manager, pbar_pos: int, files_ext: str +): + """ + Save ("in-place") the 8 augmented versions of each image in the given `dataset_path`. + + ### Args: + - `dataset_path` (`Path`): The path to the dataset to augment. + + Each image will be augmented with the 8 square symmetries (Dih4) + and saved in the same subfolder with '_aug_' appended. + + This function should only be ran on main process. + """ + # Get all base images + all_base_imgs = list(dataset_path.rglob(f"*.{files_ext}")) + assert len(all_base_imgs) > 0, f"No '*.{files_ext}' images found in {dataset_path}" + logger.debug(f"Found {len(all_base_imgs)} base images in {dataset_path}") + + # Save augmented images + logger.debug("Writing augmented datasets to disk") + pbar = pbar_manager.counter(total=..., unit="base image", desc="Saving augmented images", position=pbar_pos) + + def aug_save_img(base_img_path: Path): + base_img = Image.open(base_img_path) + augs = generate_all_augs(base_img, [RandomRotationSquareSymmetry, RandomHorizontalFlip, RandomVerticalFlip]) + for aug_idx, aug in enumerate(augs[1:]): # skip the original image + save_path = base_img_path.parent / f"{base_img_path.stem}_aug{aug_idx}.{base_img_path.suffix}" + aug.save(save_path) + pbar.update() + + with ThreadPoolExecutor() as executor: + futures = [] + for base_img in all_base_imgs: + futures.append(executor.submit(aug_save_img, base_img)) + + for future in as_completed(futures): + future.result() # raises exception if any + + pbar.close() diff --git a/GaussianProxy/utils/training.py b/GaussianProxy/utils/training.py index fe60682..f84c06f 100644 --- a/GaussianProxy/utils/training.py +++ b/GaussianProxy/utils/training.py @@ -37,6 +37,7 @@ from GaussianProxy.utils.misc import ( StateLogger, get_evenly_spaced_timesteps, + hard_augment_dataset_all_square_symmetries, log_state, save_eval_artifacts_log_to_wandb, save_images_for_metrics_compute, @@ -125,6 +126,7 @@ class TimeDiffusion: # WARNING: best_metric_to_date will only be updated on the main process! best_metric_to_date: float = field(init=False) first_metrics_eval: bool = True # always reset to True even if resuming from a checkpoint + eval_on_start: bool = False # constant evaluation starting states _eval_noise: Tensor = field(init=False) _eval_video_times: Tensor = field(init=False) @@ -271,6 +273,7 @@ def fit( if ( self.cfg.evaluation.every_n_opt_steps is not None and self.global_optimization_step % self.cfg.evaluation.every_n_opt_steps == 0 + and (self.global_optimization_step != 0 or self.eval_on_start) ): self._evaluate( train_timestep_dataloaders, @@ -652,7 +655,7 @@ def _evaluate( # TODO: save & load a compiled artifact when torch.export is stabilized # 3. Run through evaluation strategies - for eval_strat in self.cfg.evaluation.strategies: + for eval_strat in self.cfg.evaluation.strategies: # TODO: match on type when config is updated if eval_strat.name == "SimpleGeneration": self._simple_gen( pbar_manager, @@ -713,6 +716,7 @@ def _evaluate( ) elif eval_strat.name == "MetricsComputation": + # TODO: index with class names true_data_classes_paths = {idx: "" for idx in range(len(test_dataloaders.keys()))} for cl_idx, dl in enumerate(test_dataloaders.values()): inner_ds = dl.dataset @@ -720,8 +724,27 @@ def _evaluate( raise ValueError( f"Expected a `BaseDataset` for the underlying dataset of the evaluation dataloader, got {type(inner_ds)}" ) - # assuming they are share the same parent... - true_data_classes_paths[cl_idx] = Path(inner_ds.samples[0]).parent.as_posix() + # assuming they all share the same parent... + this_class_path = Path(inner_ds.samples[0]).parent + if eval_strat.nb_samples_to_gen_per_time == "adapt half aug": # pyright: ignore[reportAttributeAccessIssue] + # use the hard augmented version of the dataset if it's not the one we are already training on + if "_hard_augmented" not in this_class_path.parent.name: + self.logger.debug( + f"Using hard augmented version of {this_class_path}: {this_class_path.parent.name}_hard_augmented" + ) + hard_augmented_dataset_path = ( + this_class_path.parent.with_name(this_class_path.parent.name + "_hard_augmented") + / this_class_path.name + ) + else: + hard_augmented_dataset_path = this_class_path + + assert ( + hard_augmented_dataset_path.exists() + ), f"hard augmented dataset does not exist at {hard_augmented_dataset_path} " + true_data_classes_paths[cl_idx] = hard_augmented_dataset_path.as_posix() + else: + true_data_classes_paths[cl_idx] = this_class_path.as_posix() self._metrics_computation( tmp_save_folder, pbar_manager, @@ -931,7 +954,6 @@ def _inv_regen( for video_t_idx, video_time in enumerate(torch.linspace(0, 1, self.cfg.evaluation.nb_video_timesteps)): image = inverted_gauss.clone() - self.logger.debug(f"Video timestep index {video_t_idx} / {self.cfg.evaluation.nb_video_timesteps - 1}") video_time_enc = inference_video_time_encoding.forward(video_time.item(), batch.shape[0]) for t in inference_scheduler.timesteps: @@ -1065,8 +1087,6 @@ def _iter_inv_regen( prev_video_time = 0 image = batch for video_t_idx, video_time in enumerate(video_times): - self.logger.debug(f"Video timestep index {video_t_idx} / {self.cfg.evaluation.nb_video_timesteps}") - # 2. Generate the inverted Gaussians inverted_gauss = image inversion_video_time = inference_video_time_encoding.forward(prev_video_time, batch.shape[0]) @@ -1241,7 +1261,6 @@ def _forward_noising( for video_t_idx, video_time in enumerate(torch.linspace(0, 1, self.cfg.evaluation.nb_video_timesteps)): image = slightly_noised_sample.clone() - self.logger.debug(f"Video timestep index {video_t_idx} / {self.cfg.evaluation.nb_video_timesteps}") video_time_enc = inference_video_time_encoding.forward(video_time.item(), batch.shape[0]) for t in inference_scheduler.timesteps[noise_timestep_idx:]: @@ -1308,23 +1327,31 @@ def _metrics_computation( Everything is distributed. """ - # 0. Preparations + ##### 0. Preparations # duplicate the scheduler to not mess with the training one inference_scheduler: DDIMScheduler = DDIMScheduler.from_config(self.dynamic.config) # pyright: ignore[reportAssignmentType] inference_scheduler.set_timesteps(eval_strat.nb_diffusion_timesteps) # Misc. - self.logger.info(f"Starting {eval_strat.name}") + self.logger.info(f"Starting {eval_strat.name}: {eval_strat}") self.logger.debug( f"Starting {eval_strat.name} on process ({self.accelerator.process_index})", main_process_only=False, ) + metrics_computation_folder = tmp_save_folder / "metrics_computation" + assert ( + len(self.empirical_dists_timesteps) == len(true_data_classes_paths) + ), f"Mismatch between number of timesteps and classes: empirical_dists_timesteps={self.empirical_dists_timesteps}, true_data_classes_paths={true_data_classes_paths}" + if eval_strat.nb_samples_to_gen_per_time == "adapt half aug": + assert all( + "hard_augmented" in true_data_classes_paths[key] for key in true_data_classes_paths + ), f"Expected all true data paths to be hard augmented, got:\n{true_data_classes_paths}" # use training time encodings eval_video_time = torch.tensor(self.empirical_dists_timesteps).to(self.accelerator.device) eval_video_time_enc = inference_video_time_encoding.forward(eval_video_time) - # 1. Generate the samples + ##### 1. Generate the samples # loop over training video times video_times_pbar = pbar_manager.counter( total=len(eval_video_time), @@ -1340,7 +1367,7 @@ def _metrics_computation( video_time_enc = video_time_enc.unsqueeze(0).repeat(eval_strat.batch_size, 1) # find how many samples to generate, batchify generation and distribute along processes - gen_dir = tmp_save_folder / "metrics_computation" / str(video_time_idx) + gen_dir = metrics_computation_folder / str(video_time_idx) gen_dir.mkdir(parents=True, exist_ok=True) this_proc_gen_batches = self._find_this_proc_this_time_batches_for_metrics_comp( eval_strat, @@ -1360,7 +1387,7 @@ def _metrics_computation( batches_pbar.refresh() # loop over generation batches - for batch_idx, batch_size in batches_pbar(enumerate(this_proc_gen_batches)): + for batch_size in batches_pbar(this_proc_gen_batches): gen_pbar = pbar_manager.counter( total=len(inference_scheduler.timesteps), position=4, @@ -1402,19 +1429,42 @@ def _metrics_computation( # no need to wait here then self.logger.info("Finished image generation") - # 2. Compute metrics + ##### 1.5 Augment the generated samples if applicable + if eval_strat.nb_samples_to_gen_per_time == "adapt half aug": + # on main process: + if self.accelerator.is_main_process: + # augment + extension = "png" # TODO: remove this hardcoded extension (by moving DatasetParams'params into the base DataSet class used in config, another TODO) + hard_augment_dataset_all_square_symmetries( + metrics_computation_folder, + self.logger, + pbar_manager, + 3, + extension, + ) + # check result + subdirs = [d for d in metrics_computation_folder.iterdir() if d.is_dir()] + nb_elems_per_class = { + class_path.name: len(list((metrics_computation_folder / class_path.name).glob(f"*.{extension}"))) + for class_path in subdirs + } + assert all( + nb_elems_per_class[cl_name] % 8 == 0 for cl_name in nb_elems_per_class + ), f"Expected number of elements to be a multiple of 8, got:\n{nb_elems_per_class}" + + self.accelerator.wait_for_everyone() + + ##### 2. Compute metrics # consistency of cache naming with below is important - metrics_caches: dict[str | int, Path] = { - "all_classes": tmp_save_folder / "metrics_computation" / self.cfg.dataset.name - } + metrics_caches: dict[str | int, Path] = {"all_classes": metrics_computation_folder / self.cfg.dataset.name} for video_time_idx in range(len(self.empirical_dists_timesteps)): - metrics_caches[video_time_idx] = ( - tmp_save_folder / "metrics_computation" / (self.cfg.dataset.name + "_class_" + str(video_time_idx)) + metrics_caches[video_time_idx] = metrics_computation_folder / ( + self.cfg.dataset.name + "_class_" + str(video_time_idx) ) # clear the dataset caches (on first eval of a run only...) # because the dataset might not be exactly the same that in the previous run, - # despite having the same cfg.dataset.name (used as ID): risk of invalid cache + # despite having the same cfg.dataset.name (used as ID): risk of invalid cache! if self.first_metrics_eval: self.first_metrics_eval = False # clear on main process @@ -1425,40 +1475,48 @@ def _metrics_computation( self.accelerator.wait_for_everyone() # TODO: weight tasks by number of samples - tasks = ["all_classes"] + list(range(len(self.empirical_dists_timesteps))) + # TODO: include "all_classes" in the tasks, but differentiation between using seen data only and all the available dataset + # tasks = ["all_classes"] + list(range(len(self.empirical_dists_timesteps))) + tasks = list(range(len(self.empirical_dists_timesteps))) tasks_for_this_process = tasks[self.accelerator.process_index :: self.accelerator.num_processes] self.logger.info("Computing metrics...") metrics_dict: dict[str, dict[str, float]] = {} for task in tasks_for_this_process: if task == "all_classes": - self.logger.debug(f"Computing metrics against true samples at {Path(self.cfg.dataset.path).as_posix()}") + self.logger.debug( + f"Computing metrics against true samples at {Path(self.cfg.dataset.path).as_posix()} on process {self.accelerator.process_index}", + main_process_only=False, + ) metrics = torch_fidelity.calculate_metrics( input1=Path(self.cfg.dataset.path).as_posix(), - input2=(tmp_save_folder / "metrics_computation").as_posix(), + input2=metrics_computation_folder.as_posix(), cuda=True, - batch_size=eval_strat.batch_size, # TODO: optimize - isc=True, + batch_size=eval_strat.batch_size * 4, # TODO: optimize + isc=False, fid=True, - prc=True, + prc=False, verbose=self.cfg.debug and self.accelerator.is_main_process, - cache_root=(tmp_save_folder / "metrics_computation").as_posix(), + cache_root=metrics_computation_folder.as_posix(), input1_cache_name=metrics_caches["all_classes"].name, samples_find_deep=True, ) else: assert isinstance(task, int) - self.logger.debug(f"Computing metrics against true samples at {true_data_classes_paths[task]}") + self.logger.debug( + f"Computing metrics against true samples at {true_data_classes_paths[task]} on process {self.accelerator.process_index}", + main_process_only=False, + ) metrics = torch_fidelity.calculate_metrics( input1=true_data_classes_paths[task], - input2=(tmp_save_folder / "metrics_computation" / str(task)).as_posix(), + input2=(metrics_computation_folder / str(task)).as_posix(), cuda=True, - batch_size=eval_strat.batch_size, # TODO: optimize - isc=True, + batch_size=eval_strat.batch_size * 4, # TODO: optimize + isc=False, fid=True, - prc=True, + prc=False, verbose=self.cfg.debug and self.accelerator.is_main_process, - cache_root=(tmp_save_folder / "metrics_computation").as_posix(), + cache_root=metrics_computation_folder.as_posix(), input1_cache_name=metrics_caches[task].name, ) metrics_dict[str(task)] = metrics @@ -1478,7 +1536,7 @@ def _metrics_computation( ) self.accelerator.wait_for_everyone() - # 3. Merge metrics from all processes & Log metrics + ##### 3. Merge metrics from all processes & Log metrics if self.accelerator.is_main_process: final_metrics_dict = {} for metrics_file in [f for f in tmp_save_folder.iterdir() if f.name.endswith("metrics_dict.pkl")]: @@ -1496,41 +1554,48 @@ def _metrics_computation( f"Logged metrics {final_metrics_dict}", ) - # 4. Check if best model to date + ##### 4. Check if best model to date if self.accelerator.is_main_process: # WARNING: only "smaller is better" metrics are supported! # WARNING: best_metric_to_date will only be updated on the main process! + if "all_classes" not in final_metrics_dict: # pyright: ignore[reportPossiblyUnboundVariable] + key = list(final_metrics_dict.keys())[0] # pyright: ignore[reportPossiblyUnboundVariable] + assert isinstance(key, str), f"Expected a string key, got {key} of type {type(key)}" + self.logger.warning( + f"No 'all_classes' key in final_metrics_dict, will update best_metric_to_date with the FID of the first class ({key})" + ) + else: + key = "all_classes" if ( - self.cfg.debug - or self.best_metric_to_date > final_metrics_dict["all_classes"]["frechet_inception_distance"] # pyright: ignore[reportPossiblyUnboundVariable] + self.cfg.debug or self.best_metric_to_date > final_metrics_dict[key]["frechet_inception_distance"] # pyright: ignore[reportPossiblyUnboundVariable] ): - self.best_metric_to_date = final_metrics_dict["all_classes"]["frechet_inception_distance"] # pyright: ignore[reportPossiblyUnboundVariable] + self.best_metric_to_date = final_metrics_dict[key]["frechet_inception_distance"] # pyright: ignore[reportPossiblyUnboundVariable] self.accelerator.log( {"training/best_metric_to_date": self.best_metric_to_date}, step=self.global_optimization_step ) - self.logger.info(f"Saving best model to date with all classes FID: {self.best_metric_to_date}") - self._save_pipeline() + self.logger.info(f"Saving best model to date with class='{key}' FID: {self.best_metric_to_date}") + self._save_pipeline(called_on_main_process_only=True) - # 5. Clean up (important because we reuse existing generated samples!) + ##### 5. Clean up (important because we reuse existing generated samples!) if self.accelerator.is_main_process: - shutil.rmtree(tmp_save_folder / "metrics_computation") + shutil.rmtree(metrics_computation_folder) self.logger.debug(f"Cleaned up metrics computation folder {tmp_save_folder / 'metrics_computation'}") - self.accelerator.wait_for_everyone() def _find_this_proc_this_time_batches_for_metrics_comp( self, eval_strat: MetricsComputation, - video_time_idx: int, + video_time_idx: int, # TODO: use class name instead of index true_data_classes_paths: dict[int, str], gen_dir: Path, ) -> list[int]: """ - Return the list of batch sizes for this process to generate, for this `video_time_idx`. + Return the list of batch sizes to generate, for a given `video_time_idx` and splitting between processes. `eval_strat.nb_samples_to_gen_per_time` can be: - an `int`: the number of samples to generate - `"adapt"`: generate as many samples as there are in the true data class - `"adapt half"`: generate half as many samples as there are in the true data class + - `"adapt half aug"`: generate half as many samples as there are in the true data class, then 8⨉ augment them (Dih4) """ # find total number of samples to generate for this video time if isinstance(eval_strat.nb_samples_to_gen_per_time, int): @@ -1540,11 +1605,17 @@ def _find_this_proc_this_time_batches_for_metrics_comp( self.logger.debug( f"Will generate {tot_nb_samples} samples for class n°{video_time_idx} at {true_data_classes_paths[video_time_idx]}" ) - elif eval_strat.nb_samples_to_gen_per_time == "adapt half": + elif eval_strat.nb_samples_to_gen_per_time.startswith("adapt half"): tot_nb_samples = len(list(Path(true_data_classes_paths[video_time_idx]).iterdir())) // 2 self.logger.debug( f"Will generate {tot_nb_samples} samples for class n°{video_time_idx} at {true_data_classes_paths[video_time_idx]}" ) + if eval_strat.nb_samples_to_gen_per_time == "adapt half aug": + self.logger.debug("Will augment samples 8⨉ after generation") + else: + assert ( + eval_strat.nb_samples_to_gen_per_time == "adapt half" + ), f"Expected 'adapt half' or 'adapt half aug' at this point, got {eval_strat.nb_samples_to_gen_per_time}" else: raise ValueError( f"Expected 'nb_samples_to_gen_per_time' to be an int, 'adapt', or 'adapt half', got {eval_strat.nb_samples_to_gen_per_time}" @@ -1579,12 +1650,12 @@ def _find_this_proc_this_time_batches_for_metrics_comp( return this_proc_gen_batches @log_state(state_logger) - def _save_pipeline(self): + def _save_pipeline(self, called_on_main_process_only: bool = False): """ Save the net, time encoder, and dynamic config to disk as an independent pretrained pipeline. Also save the current ResumingArgs to that same folder for later "model save" resuming. - Can be called by all processes (only main will actually save). + Can be called by all processes (only main will actually save), or by main only (then no barrier). """ # net self.accelerator.unwrap_model(self.net).save_pretrained( @@ -1606,7 +1677,9 @@ def _save_pipeline(self): } resuming_args = ResumingArgs.from_dict(training_info_for_resume) resuming_args.to_json(self.model_save_folder / ResumingArgs.json_state_filename) - self.accelerator.wait_for_everyone() + + if not called_on_main_process_only: + self.accelerator.wait_for_everyone() self.logger.info( f"Saved net, video time encoder, dynamic config, and resuming args to {self.model_save_folder}" diff --git a/scripts/hard_augment_datatset.py b/scripts/hard_augment_datatset.py new file mode 100644 index 0000000..8d8c967 --- /dev/null +++ b/scripts/hard_augment_datatset.py @@ -0,0 +1,81 @@ +import sys +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path +from time import sleep + +from PIL import Image +from rich.traceback import install +from tqdm.rich import tqdm + +sys.path.insert(0, "..") +from GaussianProxy.utils.misc import generate_all_augs + +install() + +###################################################### Arguments ###################################################### +DATASET_BASE_PATH = Path( + "/projects/static2dynamic/datasets/biotine/3_channels_min_99_perc_normalized_rgb_stacks_png/patches_255" +) +EXTENSION = "png" +DEBUG = True +TRANSFORMS = ["RandomHorizontalFlip", "RandomVerticalFlip", "RandomRotationSquareSymmetry"] + +######################################################### Info ######################################################## +print(f"Augmenting base dataset located at {DATASET_BASE_PATH}", flush=True) +print(f"Using the following transforms: {TRANSFORMS}", flush=True) +print("DEBUG:", DEBUG, flush=True) +sleep(3) + + +def ending(path: Path, n): + return Path(*path.parts[-n:]) + + +def augment_save_one_file(base_file: Path, aug_subdir_path: Path): + img = Image.open(base_file) + augs = generate_all_augs(img, TRANSFORMS) + for i, aug in enumerate(augs): + this_aug_save_path = aug_subdir_path / base_file.parent.name / f"{base_file.stem}_aug_{i}.{EXTENSION}" + if not DEBUG: + aug.save(this_aug_save_path) + else: + print(f"Would have saved patch {i} of {ending(base_file, 3)} to {ending(this_aug_save_path, 3)}") + + +if __name__ == "__main__": + # get names of subdirs / timestamps + subdirs_names = [x.name for x in DATASET_BASE_PATH.iterdir() if x.is_dir()] + print(f"Found {len(subdirs_names)} subdirectories: {subdirs_names}") + + # get all base files to augment + all_files = list(DATASET_BASE_PATH.glob(f"**/*.{EXTENSION}")) + print(f"Found {len(all_files)} base files in total") + + # create augmented subdirs in a adjacent dir to the base dataset one + aug_subdir_path = DATASET_BASE_PATH.with_name(DATASET_BASE_PATH.name + "_hard_augmented") + print(f"Saving augmented images at {aug_subdir_path}") + for subdir_name in subdirs_names: + (aug_subdir_path / subdir_name).mkdir(parents=True, exist_ok=True) + + # augment and save to disk + pbar = tqdm(total=len(all_files), desc="Saving augmented images") + + with ProcessPoolExecutor() as executor: + futures = {executor.submit(augment_save_one_file, file, aug_subdir_path): file for file in all_files} + for future in as_completed(futures): + try: + future.result() + except Exception as e: + raise Exception(f"Error processing file {futures[future]}") from e + pbar.update() + + pbar.close() + + # check result + for subdir_name in subdirs_names: + found_nb = len(list((aug_subdir_path / subdir_name).glob(f"*.{EXTENSION}"))) + expected_nb = 8 * len(list((DATASET_BASE_PATH / subdir_name).glob(f"*.{EXTENSION}"))) + assert ( + found_nb == expected_nb + ), f"Expected {expected_nb} files in {ending(aug_subdir_path / subdir_name, 2)}, found {found_nb}" + print("All checks passed") diff --git a/scripts/test_FID_augmentations.py b/scripts/test_FID_augmentations.py new file mode 100644 index 0000000..80f5dc3 --- /dev/null +++ b/scripts/test_FID_augmentations.py @@ -0,0 +1,161 @@ +""" +Script used to test the hypothesis that augmentations +can mimic different data splits. + +- take one single split on the dataset (vs 10 in the "evaluations" notebook) +- compute 10 times the FID between the two splits, with random augmentations each time +- compare obtained FIDs to the ones obtained in the "evaluations" notebook +""" + +import json +import random +import sys +from pathlib import Path +from pprint import pprint +from warnings import warn + +import seaborn as sns +import torch +import torch_fidelity +from torch.utils.data import Subset +from torchvision.transforms import Compose, ConvertImageDtype, RandomHorizontalFlip, RandomVerticalFlip +from tqdm.notebook import trange + +sys.path.insert(0, "..") +from GaussianProxy.utils.data import RandomRotationSquareSymmetry + +torch.set_grad_enabled(False) +sns.set_theme(context="paper") + +# Dataset +from my_conf.dataset.BBBC021_196_docetaxel_inference import BBBC021_196_docetaxel_inference as dataset # noqa: E402 + +assert dataset.dataset_params is not None +database_path = Path(dataset.path) +print(f"Using dataset {dataset.name} from {database_path}") +subdirs: list[Path] = [e for e in database_path.iterdir() if e.is_dir() and not e.name.startswith(".")] +subdirs.sort(key=dataset.dataset_params.sorting_func) +print(f"Found {len(subdirs)} classes: {', '.join(e.name for e in subdirs)}") + +# now split the dataset into 2 non-overlapping parts, respecting classes proportions... +is_flip_or_rotation = lambda t: isinstance(t, (RandomHorizontalFlip, RandomVerticalFlip, RandomRotationSquareSymmetry)) +flips_rot = [t for t in dataset.transforms.transforms if is_flip_or_rotation(t)] +transforms = Compose(flips_rot + [ConvertImageDtype(torch.uint8)]) +print(f"Using transforms:\n{transforms}") +nb_repeats = 10 +nb_elems_per_class: dict[str, int] = {} + +# create split and datasets once +ds1_elems = [] +ds2_elems = [] +for subdir in subdirs: + this_class_elems = list(subdir.glob(f"*.{dataset.dataset_params.file_extension}")) + nb_elems_per_class[subdir.name] = len(this_class_elems) + random.shuffle(this_class_elems) + new_ds1_elems = this_class_elems[: len(this_class_elems) // 2] + new_ds2_elems = this_class_elems[len(this_class_elems) // 2 :] + ds1_elems += new_ds1_elems + ds2_elems += new_ds2_elems + assert len(new_ds1_elems) + len(new_ds2_elems) == len( + this_class_elems + ), f"{len(new_ds1_elems)} + {len(new_ds2_elems)} != {len(this_class_elems)}" +assert abs(len(ds1_elems) - len(ds2_elems)) <= len(subdirs) +ds1 = dataset.dataset_params.dataset_class( + ds1_elems, + transforms, + dataset.expected_initial_data_range, +) +ds2 = dataset.dataset_params.dataset_class( + ds2_elems, + transforms, + dataset.expected_initial_data_range, +) +print("ds1:", ds1) +print("ds2:", ds2) + +nb_elems_per_class["all_classes"] = sum(nb_elems_per_class.values()) +print("nb_elems_per_class:", nb_elems_per_class) + + +# FID +# ## Compute train vs train FIDs +def compute_metrics(batch_size: int, metrics_save_path: Path): + eval_metrics = {} + + for exp_rep in trange(nb_repeats, unit="experiment repeat"): + metrics_dict: dict[str, dict[str, float]] = {} + metrics_dict["all_classes"] = torch_fidelity.calculate_metrics( + input1=ds1, + input2=ds2, + cuda=True, + batch_size=batch_size, + isc=True, + fid=True, + prc=True, + verbose=True, + samples_find_deep=True, + ) + # per-class + for subdir in subdirs: + this_class_ds1_idxes = [i for i, e in enumerate(ds1_elems) if e.parent == subdir] + this_class_ds2_idxes = [i for i, e in enumerate(ds2_elems) if e.parent == subdir] + ds1_this_cl = Subset(ds1, this_class_ds1_idxes) + ds2_this_cl = Subset(ds2, this_class_ds2_idxes) + assert abs(len(ds1_this_cl) - len(ds2_this_cl)) <= 1 + assert ( + len(ds1_this_cl) + len(ds2_this_cl) == nb_elems_per_class[subdir.name] + ), f"{len(ds1_this_cl)} + {len(ds2_this_cl)} != {nb_elems_per_class[subdir.name]}" + metrics_dict_cl = torch_fidelity.calculate_metrics( + input1=ds1_this_cl, + input2=ds2_this_cl, + cuda=True, + batch_size=batch_size, + isc=True, + fid=True, + prc=True, + verbose=True, + ) + metrics_dict[subdir.name] = metrics_dict_cl + + eval_metrics[exp_rep] = metrics_dict + + if metrics_save_path.exists(): + raise RuntimeError(f"File {metrics_save_path} already exists, not overwriting") + if not metrics_save_path.parent.exists(): + metrics_save_path.parent.mkdir(parents=True) + with open(metrics_save_path, "w") as f: + json.dump(eval_metrics, f) + + return eval_metrics + + +batch_size = 512 +metrics_save_path = Path(f"evaluations/{dataset.name}/eval_metrics_TEST_REPS_WITH_AUGS.json") +recompute = True + +if recompute: + inpt = input("Confirm recompute (y/[n]):") + if inpt != "y": + warn(f"Will not recompute but load from {metrics_save_path}") + with open(metrics_save_path, "r") as f: + eval_metrics = json.load(f) + else: + warn("Will recompute") + assert not metrics_save_path.exists(), f"Refusing to overwrite {metrics_save_path}" + eval_metrics = compute_metrics(batch_size, metrics_save_path) +else: + warn(f"Will not recompute but load from {metrics_save_path}") + with open(metrics_save_path, "r") as f: + eval_metrics = json.load(f) + +pprint(eval_metrics) + +# Extract class names and FID scores for training data vs training data +class_names = list(eval_metrics[0].keys()) +fid_scores_by_class_train = {class_name: [] for class_name in class_names} + +for exp_rep in eval_metrics.values(): + for class_name in class_names: + fid_scores_by_class_train[class_name].append(exp_rep[class_name]["frechet_inception_distance"]) + +pprint(fid_scores_by_class_train) diff --git a/scripts/test_FID_hard_augmentations.py b/scripts/test_FID_hard_augmentations.py new file mode 100644 index 0000000..4dec9df --- /dev/null +++ b/scripts/test_FID_hard_augmentations.py @@ -0,0 +1,258 @@ +""" +Script used to test the hypothesis that augmentations +can mimic different data splits. + +- take one single split on the dataset (vs 10 in the "evaluations" notebook) +- multiply 8 times each sample with the 8 square symmetries +- compute the FID between the two augmented splits +- save it in a json file +""" + +import concurrent.futures +import json +import random +import sys +from pathlib import Path +from pprint import pprint +from warnings import warn + +import seaborn as sns +import torch +import torch_fidelity +from PIL import Image +from rich.traceback import install +from torch import Tensor +from torchvision.transforms import Compose, ConvertImageDtype, RandomHorizontalFlip, RandomVerticalFlip +from tqdm import tqdm, trange + +sys.path.insert(0, "..") +sys.path.insert(0, "../GaussianProxy") +from GaussianProxy.utils.data import BaseDataset, RandomRotationSquareSymmetry +from GaussianProxy.utils.misc import generate_all_augs + +torch.set_grad_enabled(False) +sns.set_theme(context="paper") + +install() + +# Dataset to use +from my_conf.dataset.BBBC021_196_docetaxel_inference import BBBC021_196_docetaxel_inference as dataset # noqa: E402 + +database_path = Path(dataset.path) +assert dataset.dataset_params is not None +print(f"Using dataset {dataset.name} from {database_path}") +subdirs: list[Path] = [e for e in database_path.iterdir() if e.is_dir() and not e.name.startswith(".")] +subdirs.sort(key=dataset.dataset_params.sorting_func) +print(f"Found {len(subdirs)} classes: {', '.join(e.name for e in subdirs)}") + + +def augment_images(repeat_number: int): + # split nb_repeats times the dataset into 2 non-overlapping parts, respecting classes proportions... + transforms = Compose([ConvertImageDtype(torch.uint8)]) # no augmentations! hard-saved after + print(f"Using transforms:\n{transforms}") + nb_elems_per_class: dict[str, int] = {} + + # create split and datasets once + assert dataset.dataset_params is not None + ds1_elems = [] + ds2_elems = [] + for subdir in subdirs: + this_class_elems = list(subdir.glob(f"*.{dataset.dataset_params.file_extension}")) + nb_elems_per_class[subdir.name] = len(this_class_elems) + random.shuffle(this_class_elems) + new_ds1_elems = this_class_elems[: len(this_class_elems) // 2] + new_ds2_elems = this_class_elems[len(this_class_elems) // 2 :] + ds1_elems += new_ds1_elems + ds2_elems += new_ds2_elems + assert len(new_ds1_elems) + len(new_ds2_elems) == len( + this_class_elems + ), f"{len(new_ds1_elems)} + {len(new_ds2_elems)} != {len(this_class_elems)}" + assert abs(len(ds1_elems) - len(ds2_elems)) <= len(subdirs) + ds1: BaseDataset = dataset.dataset_params.dataset_class( + ds1_elems, + transforms, + dataset.expected_initial_data_range, + ) + ds2: BaseDataset = dataset.dataset_params.dataset_class( + ds2_elems, + transforms, + dataset.expected_initial_data_range, + ) + print("ds1:", ds1) + print("ds2:", ds2) + + nb_elems_per_class["all_classes"] = sum(nb_elems_per_class.values()) + print("nb_elems_per_class:", nb_elems_per_class) + + # Augment the datasets on disk + this_rep_base_save_path = TMP_AUG_DS_SAVE_PATH / f"repeat_{repeat_number}" + print(f"Will save augmented datasets to {this_rep_base_save_path}") + + def augment_to_imgs(elem: Tensor) -> list[Image.Image]: + list_augs_tensors = generate_all_augs( + elem, [RandomRotationSquareSymmetry, RandomHorizontalFlip, RandomVerticalFlip] + ) + assert all(aug_tensor.dtype == torch.uint8 for aug_tensor in list_augs_tensors), "Expected uint8 tensors" + assert all( + 0 <= aug_tensor.min() and aug_tensor.max() <= 255 for aug_tensor in list_augs_tensors + ), "Expected [0, 255] range" + list_augs_imgs = [Image.fromarray(aug_tensor.numpy().transpose(1, 2, 0)) for aug_tensor in list_augs_tensors] + return list_augs_imgs + + for split in ("split1", "split2"): + for subdir in subdirs: + if not (this_rep_base_save_path / split / subdir.name).exists(): + (this_rep_base_save_path / split / subdir.name).mkdir(parents=True) + + print("Writing augmented datasets to disk...") + pbar = tqdm(total=len(ds1) + len(ds2), unit="base image", desc="Saving augmented images", position=2) + + def process_element(split, ds, elem_idx, elem): + elem_path = Path(ds.samples[elem_idx]) + augs = augment_to_imgs(elem) + for aug_idx, aug in enumerate(augs): + save_path = this_rep_base_save_path / split / elem_path.parent.name / f"{elem_path.stem}_{aug_idx}.png" + assert save_path.parent.exists(), f"Parent {save_path.parent} does not exist" + assert not save_path.exists(), f"Refusing to overwrite {save_path}" + aug.save(save_path) + pbar.update() + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + for split, ds in (("split1", ds1), ("split2", ds2)): + for elem_idx, elem in enumerate(iter(ds)): + futures.append(executor.submit(process_element, split, ds, elem_idx, elem)) + + for future in concurrent.futures.as_completed(futures): + future.result() # raises exception if any + + pbar.close() + + +def check_augmented_datasets(nb_repeats: int): + for rep in range(nb_repeats): + nb_elems_split1_per_class = { + class_path.name: len( + list((TMP_AUG_DS_SAVE_PATH / f"repeat_{rep}" / "split1" / class_path.name).glob("*.png")) + ) + for class_path in subdirs + } + nb_elems_split2_per_class = { + class_path.name: len( + list((TMP_AUG_DS_SAVE_PATH / f"repeat_{rep}" / "split2" / class_path.name).glob("*.png")) + ) + for class_path in subdirs + } + assert all(nb_elems_split1_per_class.values()), "Expected non-zero number of elements" + for class_path in subdirs: + assert ( + nb_elems_split1_per_class[class_path.name] - nb_elems_split2_per_class[class_path.name] <= 8 + ), f"Expected same number of elements in both splits, got:{nb_elems_split1_per_class[class_path.name]} vs {nb_elems_split2_per_class[class_path.name]} for class {class_path.name}" + assert all( + nb_elems_this_split_per_class[cl_path.name] % 8 == 0 + for cl_path in subdirs + for nb_elems_this_split_per_class in (nb_elems_split1_per_class, nb_elems_split2_per_class) + ), f"Expected number of elements to be a multiple of 8, got:\n{nb_elems_split1_per_class}\nand:\n{nb_elems_split2_per_class}" + + +# FID +# ## Compute train vs train FIDs +def compute_metrics(batch_size: int, metrics_save_path: Path): + if metrics_save_path.exists(): + raise RuntimeError(f"File {metrics_save_path} already exists, not overwriting") + + eval_metrics: dict[str, dict[str, dict[str, float]]] = {} ## accumulate exp repeats here + + for repeat in [subdir.name for subdir in TMP_AUG_DS_SAVE_PATH.iterdir()]: + assert repeat.startswith("repeat_"), f"Unexpected directory {repeat}" + torch.cuda.empty_cache() + + metrics_dict: dict[str, dict[str, float]] = {} + metrics_dict["all_classes"] = torch_fidelity.calculate_metrics( + input1=(TMP_AUG_DS_SAVE_PATH / repeat / "split1").as_posix(), + input2=(TMP_AUG_DS_SAVE_PATH / repeat / "split2").as_posix(), + cuda=True, + batch_size=batch_size, + isc=False, + fid=True, + prc=False, + verbose=True, + samples_find_deep=True, + ) + # per-class + for subdir in subdirs: + metrics_dict_cl = torch_fidelity.calculate_metrics( + input1=(TMP_AUG_DS_SAVE_PATH / repeat / "split1" / subdir.name).as_posix(), + input2=(TMP_AUG_DS_SAVE_PATH / repeat / "split2" / subdir.name).as_posix(), + cuda=True, + batch_size=batch_size, + isc=False, + fid=True, + prc=False, + verbose=True, + ) + metrics_dict[subdir.name] = metrics_dict_cl + # save in common dict + eval_metrics[repeat] = metrics_dict + + if not metrics_save_path.parent.exists(): + metrics_save_path.parent.mkdir(parents=True) + with open(metrics_save_path, "w") as f: + json.dump(eval_metrics, f) + + return eval_metrics + + +TMP_AUG_DS_SAVE_PATH = Path(Path(__file__).parent, "tmp_augmented_datasets", dataset.name) +batch_size = 4096 +metrics_save_path = Path(Path(__file__).parent, "evaluations", dataset.name, "eval_metrics_TEST_HARD_AUGS.json") +reaugment = False +recompute = True +NB_REPEATS = 10 +print(f"Will save augmented datasets to {TMP_AUG_DS_SAVE_PATH}") +print(f"Will save metrics to {metrics_save_path}") +print("reaugment:", reaugment, "recompute:", recompute) + + +if __name__ == "__main__": + # augment datasets + if reaugment: + inpt = input("Confirm reaugment (y/[n]):") + if inpt != "y": + warn(f"Will not reaugment but resuse existing augmented datasets at {TMP_AUG_DS_SAVE_PATH}") + else: + warn(f"Will reaugment at {TMP_AUG_DS_SAVE_PATH}") + for rep in trange(NB_REPEATS, unit="experiment repeat"): + augment_images(rep) + else: + warn(f"Will not reaugment but resuse existing augmented datasets at {TMP_AUG_DS_SAVE_PATH}") + + # check augmented datasets + check_augmented_datasets(NB_REPEATS) + + # compute metrics + if recompute: + inpt = input("Confirm recompute (y/[n]):") + if inpt != "y": + warn(f"Will not recompute but load from {metrics_save_path}") + with open(metrics_save_path, "r") as f: + eval_metrics = json.load(f) + else: + warn("Will recompute") + assert not metrics_save_path.exists(), f"Refusing to overwrite {metrics_save_path}" + eval_metrics = compute_metrics(batch_size, metrics_save_path) + else: + warn(f"Will not recompute but load from {metrics_save_path}") + with open(metrics_save_path, "r") as f: + eval_metrics = json.load(f) + + pprint(eval_metrics) + + class_names = list(eval_metrics["repeat_0"].keys()) + fid_scores_by_class_train = {class_name: [] for class_name in class_names} + + for exp_rep in eval_metrics.values(): + for class_name in class_names: + fid_scores_by_class_train[class_name].append(exp_rep[class_name]["frechet_inception_distance"]) + + pprint(fid_scores_by_class_train)