diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py
index aa1c46d..32a715e 100644
--- a/face_swapper/src/dataset.py
+++ b/face_swapper/src/dataset.py
@@ -2,6 +2,8 @@
 import os
 import random
 
+import albumentations
+import torch
 from torch import Tensor
 from torch.utils.data import Dataset
 from torchvision import io, transforms
@@ -16,6 +18,7 @@ def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_mode
 		self.warp_template = warp_template
 		self.batch_mode = batch_mode
 		self.batch_ratio = batch_ratio
+		self.augmentations = self.compose_augmentations()
 		self.transforms = self.compose_transforms()
 
 	def __getitem__(self, index : int) -> Batch:
@@ -32,31 +35,48 @@ def __getitem__(self, index : int) -> Batch:
 	def __len__(self) -> int:
 		return len(self.file_paths)
 
+	def compose_augmentations(self) -> albumentations.Compose:
+		return albumentations.Compose(
+		[
+			albumentations.RandomBrightnessContrast(p = 0.3),
+			albumentations.OneOf([
+				albumentations.MotionBlur(p = 0.1),
+				albumentations.MedianBlur(p = 0.1)
+			], p = 0.3),
+			albumentations.ColorJitter(p = 0.1),
+			albumentations.ToTensorV2()
+		])
+
 	def compose_transforms(self) -> transforms:
 		return transforms.Compose(
 		[
-			transforms.ToPILImage(),
-			transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
-			transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
-			transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
-			transforms.ToTensor(),
+			transforms.Resize(256),
 			transforms.Lambda(self.warp_tensor),
 			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
 		])
 
+	def apply_augmentations(self, input_tensor : Tensor) -> Tensor:
+		input_frame = input_tensor.numpy().transpose(1, 2, 0)
+		output_tensor = self.augmentations(image = input_frame)['image']
+		output_tensor = output_tensor.to(torch.float32) / 255
+		return output_tensor
+
 	def warp_tensor(self, temp_tensor : Tensor) -> Tensor:
 		return warp_tensor(temp_tensor.unsqueeze(0), self.warp_template).squeeze(0)
 
 	def prepare_different_batch(self, source_path : str) -> Batch:
 		target_path = random.choice(self.file_paths)
 		source_tensor = io.read_image(source_path)
+		source_tensor = self.apply_augmentations(source_tensor)
 		source_tensor = self.transforms(source_tensor)
 		target_tensor = io.read_image(target_path)
+		target_tensor = self.apply_augmentations(target_tensor)
 		target_tensor = self.transforms(target_tensor)
 		return source_tensor, target_tensor
 
 	def prepare_equal_batch(self, source_path : str) -> Batch:
 		source_tensor = io.read_image(source_path)
+		source_tensor = self.apply_augmentations(source_tensor)
 		source_tensor = self.transforms(source_tensor)
 		return source_tensor, source_tensor
@@ -65,7 +85,9 @@ def prepare_same_batch(self, source_path : str) -> Batch:
 		target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
 		target_path = os.path.join(target_directory_path, target_file_name_and_extension)
 		source_tensor = io.read_image(source_path)
+		source_tensor = self.apply_augmentations(source_tensor)
 		source_tensor = self.transforms(source_tensor)
 		target_tensor = io.read_image(target_path)
+		target_tensor = self.apply_augmentations(target_tensor)
 		target_tensor = self.transforms(target_tensor)
 		return source_tensor, target_tensor