compute metrics on hard augmented datasets + fix distributed barrier bug
Thomas Boyer committed Dec 17, 2024
1 parent a394785 commit bef1946
Showing 5 changed files with 746 additions and 66 deletions.
143 changes: 125 additions & 18 deletions GaussianProxy/utils/misc.py
@@ -1,29 +1,30 @@
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from functools import wraps
from logging import Logger
from pathlib import Path
from typing import Literal, Optional
from typing import Literal, Optional, overload

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
import wandb
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter
from diffusers.configuration_utils import FrozenDict
from enlighten import Manager
from numpy import ndarray
from omegaconf import OmegaConf
from PIL import Image
from termcolor import colored
from torch import Tensor
from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip
from wandb.sdk.wandb_run import Run as WandBRun

import wandb
from GaussianProxy.conf.training_conf import Config
from GaussianProxy.utils.data import RandomRotationSquareSymmetry
from wandb.sdk.wandb_run import Run as WandBRun


def create_repo_structure(
@@ -136,7 +137,9 @@ def modify_args_for_debug(
strat.nb_diffusion_timesteps = 5
if hasattr(strat, "nb_samples_to_gen_per_time"):
assert hasattr(strat, "batch_size"), "Expected batch_size to be present"
setattr(strat, "nb_samples_to_gen_per_time", getattr(strat, "batch_size"))
# only change it if it's a hard-coded value
if isinstance(getattr(strat, "nb_samples_to_gen_per_time"), int):
setattr(strat, "nb_samples_to_gen_per_time", getattr(strat, "batch_size"))
# TODO: update registered wandb config


@@ -546,23 +549,59 @@ def filter_internal_fields(config_dict: FrozenDict | dict):
RandomHorizontalFlip,
RandomVerticalFlip,
RandomRotationSquareSymmetry,
"RandomHorizontalFlip",
"RandomVerticalFlip",
"RandomRotationSquareSymmetry",
)


def generate_all_augs(img: torch.Tensor, transforms: list[type]) -> list[Tensor]:
@overload
def generate_all_augs(img: Tensor, transforms: list[type] | list[str]) -> list[Tensor]: ...


@overload
def generate_all_augs(img: Image.Image, transforms: list[type] | list[str]) -> list[Image.Image]: ...


def generate_all_augs(
img: Tensor | Image.Image, transforms: list[type] | list[str]
) -> list[Tensor] | list[Image.Image]:
"""Generate all augmentations of `img` based on `transforms`."""
# checks
assert img.ndim == 3, f"Expected 3D image, got shape {img.shape}"
# general checks
assert all(t in SUPPORTED_TRANSFORMS_TO_GENERATE for t in transforms), f"Unsupported transforms: {transforms}"

# choose backend
if isinstance(img, Tensor):
aug_imgs = _generate_all_augs_torch(img, transforms)
elif isinstance(img, Image.Image):
aug_imgs = _generate_all_augs_pil(img, transforms)
else:
raise TypeError(f"Unsupported image type: {type(img)}")

assert len(aug_imgs) in (
1,
2,
4,
8,
), f"Expected 1, 2, 4, or 8 images at this point, got {len(aug_imgs)}"

return aug_imgs
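
For reference, a minimal usage sketch of the dispatcher above (not part of this commit; it assumes `generate_all_augs` is importable from GaussianProxy.utils.misc, the file shown in this diff). Thanks to the @overload declarations, type checkers infer list[Tensor] for a Tensor input and list[Image.Image] for a PIL input:

import torch
from PIL import Image
from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip

from GaussianProxy.utils.misc import generate_all_augs

tensor_img = torch.rand(3, 64, 64)  # (C, H, W)
tensor_augs = generate_all_augs(tensor_img, [RandomHorizontalFlip, RandomVerticalFlip])
assert len(tensor_augs) == 4  # identity, hflip, vflip, hflip+vflip

pil_img = Image.new("RGB", (64, 64))
pil_augs = generate_all_augs(pil_img, ["RandomHorizontalFlip", "RandomVerticalFlip"])
assert len(pil_augs) == 4 and all(isinstance(a, Image.Image) for a in pil_augs)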


# TODO: factorize backends better


def _generate_all_augs_torch(img: Tensor, transforms: list[type] | list[str]) -> list[Tensor]:
"""torch backend for `generate_all_augs`."""
assert img.ndim == 3, f"Expected 3D image, got shape {img.shape}"
# generate all possible augmentations
aug_imgs = [img]
if RandomHorizontalFlip in transforms:
if RandomHorizontalFlip in transforms or "RandomHorizontalFlip" in transforms:
aug_imgs.append(torch.flip(img, [2]))
if RandomVerticalFlip in transforms:
if RandomVerticalFlip in transforms or "RandomVerticalFlip" in transforms:
for base_img_idx in range(len(aug_imgs)):
aug_imgs.append(torch.flip(aug_imgs[base_img_idx], [1]))
if RandomRotationSquareSymmetry in transforms:
if RandomRotationSquareSymmetry in transforms or "RandomRotationSquareSymmetry" in transforms:
if len(aug_imgs) == 4: # must have been flipped in both directions then
aug_imgs += [
torch.rot90(aug_imgs[0], k=1, dims=(1, 2)),
@@ -572,17 +611,85 @@ def generate_all_augs(img: torch.Tensor, transforms: list[type]) -> list[Tensor]
torch.rot90(aug_imgs[1], k=1, dims=(1, 2)),
torch.rot90(aug_imgs[1], k=3, dims=(1, 2)),
]
elif len(aug_imgs) in (1, 2): # simply perform all 4 rotations on each image
elif len(aug_imgs) in (1, 2): # simply perform all 3 rotations on each image
for base_img_idx in range(len(aug_imgs)):
for nb_rot in [1, 2, 3]:
aug_imgs.append(torch.rot90(aug_imgs[base_img_idx], k=nb_rot, dims=(1, 2)))
else:
raise ValueError(f"Expected 1, 2, or 4 images at this point, got {len(aug_imgs)}")

assert len(aug_imgs) in (
1,
2,
4,
8,
), f"Expected 1, 2, 4, or 8 images at this point, got {len(aug_imgs)}"
return aug_imgs


def _generate_all_augs_pil(img: Image.Image, transforms: list[type] | list[str]) -> list[Image.Image]:
"""PIL backend for `generate_all_augs`."""
# generate all possible augmentations
aug_imgs = [img]
if RandomHorizontalFlip in transforms or "RandomHorizontalFlip" in transforms:
aug_imgs.append(img.transpose(Image.Transpose.FLIP_LEFT_RIGHT))
if RandomVerticalFlip in transforms or "RandomVerticalFlip" in transforms:
for base_img_idx in range(len(aug_imgs)):
aug_imgs.append(aug_imgs[base_img_idx].transpose(Image.Transpose.FLIP_TOP_BOTTOM))
if RandomRotationSquareSymmetry in transforms or "RandomRotationSquareSymmetry" in transforms:
if len(aug_imgs) == 4: # must have been flipped in both directions then
aug_imgs += [
aug_imgs[0].transpose(Image.Transpose.ROTATE_90),
aug_imgs[0].transpose(Image.Transpose.ROTATE_270),
]
aug_imgs += [
aug_imgs[1].transpose(Image.Transpose.ROTATE_90),
aug_imgs[1].transpose(Image.Transpose.ROTATE_270),
]
elif len(aug_imgs) in (1, 2): # simply perform all 3 rotations on each image
for base_img_idx in range(len(aug_imgs)):
aug_imgs += [
aug_imgs[base_img_idx].transpose(Image.Transpose.ROTATE_90),
aug_imgs[base_img_idx].transpose(Image.Transpose.ROTATE_180),
aug_imgs[base_img_idx].transpose(Image.Transpose.ROTATE_270),
]
else:
raise ValueError(f"Expected 1, 2, or 4 images at this point, got {len(aug_imgs)}")

return aug_imgs
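
As a quick illustration of the Dih4 orbit these two backends enumerate (an illustrative sketch, not part of the commit; it assumes the import paths shown in this diff), a generic image yields 8 pairwise-distinct augmentations:

import torch
from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip

from GaussianProxy.utils.data import RandomRotationSquareSymmetry
from GaussianProxy.utils.misc import generate_all_augs

img = torch.rand(3, 32, 32)  # a random image has no square symmetry almost surely
augs = generate_all_augs(img, [RandomHorizontalFlip, RandomVerticalFlip, RandomRotationSquareSymmetry])
assert len(augs) == 8  # full Dih4 orbit: 4 rotations x 2 flips
for i in range(len(augs)):
    for j in range(i + 1, len(augs)):
        assert not torch.equal(augs[i], augs[j])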


def hard_augment_dataset_all_square_symmetries(
dataset_path: Path, logger: MultiProcessAdapter, pbar_manager: Manager, pbar_pos: int, files_ext: str
):
"""
Save ("in-place") the 8 augmented versions of each image in the given `dataset_path`.
### Args:
- `dataset_path` (`Path`): The path to the dataset to augment.
Each image will be augmented with the 8 square symmetries (Dih4)
and saved in the same subfolder with '_aug<idx>' appended to the stem.
This function should only be run on the main process.
"""
# Get all base images
all_base_imgs = list(dataset_path.rglob(f"*.{files_ext}"))
assert len(all_base_imgs) > 0, f"No '*.{files_ext}' images found in {dataset_path}"
logger.debug(f"Found {len(all_base_imgs)} base images in {dataset_path}")

# Save augmented images
logger.debug("Writing augmented datasets to disk")
pbar = pbar_manager.counter(total=len(all_base_imgs), unit="base image", desc="Saving augmented images", position=pbar_pos)

def aug_save_img(base_img_path: Path):
base_img = Image.open(base_img_path)
augs = generate_all_augs(base_img, [RandomRotationSquareSymmetry, RandomHorizontalFlip, RandomVerticalFlip])
for aug_idx, aug in enumerate(augs[1:]): # skip the original image
save_path = base_img_path.parent / f"{base_img_path.stem}_aug{aug_idx}{base_img_path.suffix}"  # suffix already includes the dot
aug.save(save_path)
pbar.update()

with ThreadPoolExecutor() as executor:
futures = []
for base_img in all_base_imgs:
futures.append(executor.submit(aug_save_img, base_img))

for future in as_completed(futures):
future.result() # raises exception if any

pbar.close()
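
A typical call site for this helper could look like the following (an illustrative sketch, not code from this commit; `accelerator` and `logger` are assumed to be created by the caller, and the dataset path and file extension are placeholders). Guarding on the main process and then waiting at a barrier ensures the other ranks only proceed once the augmented files exist on disk:

from pathlib import Path

import enlighten

from GaussianProxy.utils.misc import hard_augment_dataset_all_square_symmetries

pbar_manager = enlighten.get_manager()
if accelerator.is_main_process:  # the docstring requires the main process only
    hard_augment_dataset_all_square_symmetries(
        dataset_path=Path("path/to/dataset"),
        logger=logger,
        pbar_manager=pbar_manager,
        pbar_pos=0,
        files_ext="png",
    )
accelerator.wait_for_everyone()  # barrier: all ranks sync after augmentation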
