
Commit

changes
harisreedhar committed Mar 3, 2025
1 parent 907a2d0 commit 1eb4330
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions face_swapper/src/dataset.py
@@ -2,6 +2,8 @@
 import os
 import random
 
+import albumentations
+import torch
 from torch import Tensor
 from torch.utils.data import Dataset
 from torchvision import io, transforms
@@ -16,6 +18,7 @@ def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_mode
 		self.warp_template = warp_template
 		self.batch_mode = batch_mode
 		self.batch_ratio = batch_ratio
+		self.augmentations = self.compose_augmentations()
 		self.transforms = self.compose_transforms()
 
 	def __getitem__(self, index : int) -> Batch:
@@ -32,31 +35,48 @@ def __getitem__(self, index : int) -> Batch:
 	def __len__(self) -> int:
 		return len(self.file_paths)
 
+	def compose_augmentations(self) -> albumentations.Compose:
+		return albumentations.Compose(
+		[
+			albumentations.RandomBrightnessContrast(p = 0.3),
+			albumentations.OneOf([
+				albumentations.MotionBlur(p = 0.1),
+				albumentations.MedianBlur(p = 0.1)
+			], p = 0.3),
+			albumentations.ColorJitter(p = 0.1),
+			albumentations.ToTensorV2()
+		])
+
 	def compose_transforms(self) -> transforms:
 		return transforms.Compose(
 		[
-			transforms.ToPILImage(),
-			transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
-			transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
-			transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
-			transforms.ToTensor(),
+			transforms.Resize(256),
 			transforms.Lambda(self.warp_tensor),
 			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
 		])
 
+	def apply_augmentations(self, input_tensor : Tensor) -> Tensor:
+		input_frame = input_tensor.numpy().transpose(1, 2, 0)
+		output_tensor = self.augmentations(image = input_frame)['image']
+		output_tensor = output_tensor.to(torch.float32) / 255
+		return output_tensor
+
 	def warp_tensor(self, temp_tensor : Tensor) -> Tensor:
 		return warp_tensor(temp_tensor.unsqueeze(0), self.warp_template).squeeze(0)
 
 	def prepare_different_batch(self, source_path : str) -> Batch:
 		target_path = random.choice(self.file_paths)
 		source_tensor = io.read_image(source_path)
+		source_tensor = self.apply_augmentations(source_tensor)
 		source_tensor = self.transforms(source_tensor)
 		target_tensor = io.read_image(target_path)
+		target_tensor = self.apply_augmentations(target_tensor)
 		target_tensor = self.transforms(target_tensor)
 		return source_tensor, target_tensor
 
 	def prepare_equal_batch(self, source_path : str) -> Batch:
 		source_tensor = io.read_image(source_path)
+		source_tensor = self.apply_augmentations(source_tensor)
 		source_tensor = self.transforms(source_tensor)
 		return source_tensor, source_tensor
 
@@ -65,7 +85,9 @@ def prepare_same_batch(self, source_path : str) -> Batch:
 		target_file_name_and_extension = random.choice(os.listdir(target_directory_path))
 		target_path = os.path.join(target_directory_path, target_file_name_and_extension)
 		source_tensor = io.read_image(source_path)
+		source_tensor = self.apply_augmentations(source_tensor)
 		source_tensor = self.transforms(source_tensor)
 		target_tensor = io.read_image(target_path)
+		target_tensor = self.apply_augmentations(target_tensor)
 		target_tensor = self.transforms(target_tensor)
 		return source_tensor, target_tensor
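For reference, below is a minimal standalone sketch of the per-image flow this commit sets up: decode the image, run it through the albumentations pipeline (brightness/contrast, an optional blur, color jitter, ToTensorV2), then finish with the remaining torchvision steps. The pipeline parameters are copied from the diff; the image path is a placeholder, the warp step is omitted because it needs a WarpTemplate from the rest of the repository, and importing ToTensorV2 from albumentations.pytorch (rather than the top-level module used in the diff) is an assumption about the installed albumentations version.

# Minimal sketch of the augmentation + transform flow introduced by this commit.
# Assumptions: albumentations with the pytorch extra is installed, 'face.jpg'
# is a placeholder path, and the warp_tensor step is skipped here.
import albumentations
import torch
from albumentations.pytorch import ToTensorV2
from torchvision import io, transforms

augmentations = albumentations.Compose(
[
	albumentations.RandomBrightnessContrast(p = 0.3),
	albumentations.OneOf([
		albumentations.MotionBlur(p = 0.1),
		albumentations.MedianBlur(p = 0.1)
	], p = 0.3),
	albumentations.ColorJitter(p = 0.1),
	ToTensorV2()
])

input_tensor = io.read_image('face.jpg')                      # uint8 tensor, CHW
input_frame = input_tensor.numpy().transpose(1, 2, 0)         # albumentations expects HWC numpy
output_tensor = augmentations(image = input_frame)['image']   # ToTensorV2 returns a CHW uint8 tensor
output_tensor = output_tensor.to(torch.float32) / 255         # scale to [0, 1] before the torchvision steps
output_tensor = transforms.Resize(256)(output_tensor)
output_tensor = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(output_tensor)

In the dataset itself, the resized tensor additionally passes through transforms.Lambda(self.warp_tensor) between the resize and the normalization, which is where the warp template comes in.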
