diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py index aa1c46d..a0f00e7 100644 --- a/face_swapper/src/dataset.py +++ b/face_swapper/src/dataset.py @@ -2,6 +2,7 @@ import os import random +import albumentations from torch import Tensor from torch.utils.data import Dataset from torchvision import io, transforms @@ -37,16 +38,12 @@ def compose_transforms(self) -> transforms: [ transforms.ToPILImage(), transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC), - transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1), - transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)), + AugmentTransform(), transforms.ToTensor(), - transforms.Lambda(self.warp_tensor), + WarpTransform(self.warp_template), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) - def warp_tensor(self, temp_tensor : Tensor) -> Tensor: - return warp_tensor(temp_tensor.unsqueeze(0), self.warp_template).squeeze(0) - def prepare_different_batch(self, source_path : str) -> Batch: target_path = random.choice(self.file_paths) source_tensor = io.read_image(source_path) @@ -69,3 +66,35 @@ def prepare_same_batch(self, source_path : str) -> Batch: target_tensor = io.read_image(target_path) target_tensor = self.transforms(target_tensor) return source_tensor, target_tensor + + +class AugmentTransform: + def __init__(self) -> None: + self.transforms = self.compose_transforms() + + def __call__(self, input_tensor : Tensor) -> Tensor: + temp_tensor = input_tensor.numpy().transpose(1, 2, 0) + return self.transforms(temp_tensor).get('image') + + @staticmethod + def compose_transforms() -> albumentations.Compose: + return albumentations.Compose( + [ + albumentations.RandomBrightnessContrast(p = 0.3), + albumentations.OneOf( + [ + albumentations.MotionBlur(p = 0.1), + albumentations.MedianBlur(p = 0.1) + ], p = 0.3), + albumentations.ColorJitter(p = 0.1), + ]) + + +class WarpTransform: + def __init__(self, warp_template : WarpTemplate) -> None: + self.warp_template = warp_template + + def __call__(self, input_tensor : Tensor) -> Tensor: + temp_tensor = input_tensor.unsqueeze(0) + return warp_tensor(temp_tensor, self.warp_template).squeeze(0) + diff --git a/requirements.txt b/requirements.txt index d635148..103dce8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ --extra-index-url https://download.pytorch.org/whl/cu124 +albumentations==2.0.5 lightning==2.5.0 onnx==1.17.0 onnxruntime==1.20.1