Skip to content

Commit

Permalink
use coherent time keys
Browse files (browse the repository at this point in the history)
  • Loading branch information
Thomas Boyer committed Dec 29, 2024
1 parent ac95151 commit 2924921
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions GaussianProxy/utils/data.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -231,7 +231,7 @@ def setup_dataloaders(
sorting_func=lambda subdir: int(subdir.name.split("_")[1]),
dataset_class=TIFFDataset,
)
case "chromaLive6h_3ch_png_patches_380px":
case "chromaLive6h_3ch_png_patches_380px" | "chromaLive6h_3ch_png_patches_380px_hard_aug":
ds_params = DatasetParams(
file_extension="png",
key_transform=str,
Expand Down Expand Up @@ -281,20 +281,25 @@ def _dataset_builder(
# use time-unpaired data if asked for
# (does not assume all videos span the same/total time range)
video_ids_times: dict[str, dict[TimeKey, Path]] = {} # dict[video_id, dict[time, file]]
time_key_type = type(list(files_dict_per_time.keys())[0])
time_key_to_time_id: dict[TimeKey, set[str]] = {time_key: set() for time_key in files_dict_per_time.keys()}
if cfg.training.unpaired_data:
logger.warning("Building time-unpaired dataset")
# fill dict
for time, files in files_dict_per_time.items():
for f in files:
video_id, time = extract_video_id(f.stem, time_key_type)
video_id, time_id = extract_video_id(f.stem)
time_key_to_time_id[time].add(time_id)
if video_id not in video_ids_times:
video_ids_times[video_id] = {time: f}
else:
assert (
time not in video_ids_times[video_id]
), f"Found multiple files at time {time} for video {video_id}: {f} and {video_ids_times[video_id][time]}"
video_ids_times[video_id][time] = f
# check 1-to-1 mapping between found time ids and time keys
assert all(
len(time_key_to_time_id[time_key]) == 1 for time_key in files_dict_per_time.keys()
), f"Found multiple time ids for some time keys: time_key_to_time_id={time_key_to_time_id}"
# select one time at random for each video_id
unpaired_files_dict_per_time: dict[TimeKey, list[Path]] = {}
for video_id, times_files_d in video_ids_times.items():
Expand Down Expand Up @@ -420,7 +425,7 @@ def _dataset_builder(
return train_dataloaders_dict, test_dataloaders_dict


def extract_video_id(filename: str, time_key_type: type[TimeKey]) -> tuple[str, TimeKey]:
def extract_video_id(filename: str) -> tuple[str, str]:
"""
Ugly helper to extract video_id and time from a filename, using hard-coded rules.
Expand All @@ -435,8 +440,7 @@ def extract_video_id(filename: str, time_key_type: type[TimeKey]) -> tuple[str,
else:
raise ValueError(f"Could not extract time from filename {filename}")

t = time_key_type(time)
return video_id, t
return video_id, time


def remove_flips_and_rotations_from_transforms(transforms: Compose):
Expand Down

0 comments on commit 2924921

Please sign in to comment.