From 2924921a3fee011cc722f98d94197a7f5bed73ee Mon Sep 17 00:00:00 2001 From: Thomas Boyer Date: Sun, 29 Dec 2024 18:55:00 +0100 Subject: [PATCH] use coherent time keys --- GaussianProxy/utils/data.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/GaussianProxy/utils/data.py b/GaussianProxy/utils/data.py index 5f3beaa..1051acd 100644 --- a/GaussianProxy/utils/data.py +++ b/GaussianProxy/utils/data.py @@ -231,7 +231,7 @@ def setup_dataloaders( sorting_func=lambda subdir: int(subdir.name.split("_")[1]), dataset_class=TIFFDataset, ) - case "chromaLive6h_3ch_png_patches_380px": + case "chromaLive6h_3ch_png_patches_380px" | "chromaLive6h_3ch_png_patches_380px_hard_aug": ds_params = DatasetParams( file_extension="png", key_transform=str, @@ -281,13 +281,14 @@ def _dataset_builder( # use time-unpaired data if asked for # (does not assume all videos span the same/total time range) video_ids_times: dict[str, dict[TimeKey, Path]] = {} # dict[video_id, dict[time, file]] - time_key_type = type(list(files_dict_per_time.keys())[0]) + time_key_to_time_id: dict[TimeKey, set[str]] = {time_key: set() for time_key in files_dict_per_time.keys()} if cfg.training.unpaired_data: logger.warning("Building time-unpaired dataset") # fill dict for time, files in files_dict_per_time.items(): for f in files: - video_id, time = extract_video_id(f.stem, time_key_type) + video_id, time_id = extract_video_id(f.stem) + time_key_to_time_id[time].add(time_id) if video_id not in video_ids_times: video_ids_times[video_id] = {time: f} else: @@ -295,6 +296,10 @@ def _dataset_builder( time not in video_ids_times[video_id] ), f"Found multiple files at time {time} for video {video_id}: {f} and {video_ids_times[video_id][time]}" video_ids_times[video_id][time] = f + # check 1-to-1 mapping between found time ids and time keys + assert all( + len(time_key_to_time_id[time_key]) == 1 for time_key in files_dict_per_time.keys() + ), f"Found multiple time ids for some time keys: time_key_to_time_id={time_key_to_time_id}" # select one time at random for each video_id unpaired_files_dict_per_time: dict[TimeKey, list[Path]] = {} for video_id, times_files_d in video_ids_times.items(): @@ -420,7 +425,7 @@ def _dataset_builder( return train_dataloaders_dict, test_dataloaders_dict -def extract_video_id(filename: str, time_key_type: type[TimeKey]) -> tuple[str, TimeKey]: +def extract_video_id(filename: str) -> tuple[str, str]: """ Ugly helper to extract video_id and time from a filename, using hard-coded rules. @@ -435,8 +440,7 @@ def extract_video_id(filename: str, time_key_type: type[TimeKey]) -> tuple[str, else: raise ValueError(f"Could not extract time from filename {filename}") - t = time_key_type(time) - return video_id, t + return video_id, time def remove_flips_and_rotations_from_transforms(transforms: Compose):