Skip to content

Commit

Permalink
Merge pull request #39 from facefusion/add-albumentations-second-try
Browse files Browse the repository at this point in the history
Add albumentations second try
  • Loading branch information
henryruhs authored Mar 4, 2025
2 parents 907a2d0 + d4aeec3 commit b808360
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
41 changes: 35 additions & 6 deletions face_swapper/src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import random

import albumentations
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io, transforms
Expand Down Expand Up @@ -37,16 +38,12 @@ def compose_transforms(self) -> transforms:
[
transforms.ToPILImage(),
transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
AugmentTransform(),
transforms.ToTensor(),
transforms.Lambda(self.warp_tensor),
WarpTransform(self.warp_template),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

def warp_tensor(self, temp_tensor : Tensor) -> Tensor:
return warp_tensor(temp_tensor.unsqueeze(0), self.warp_template).squeeze(0)

def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(self.file_paths)
source_tensor = io.read_image(source_path)
Expand All @@ -69,3 +66,35 @@ def prepare_same_batch(self, source_path : str) -> Batch:
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
return source_tensor, target_tensor


class AugmentTransform:
def __init__(self) -> None:
self.transforms = self.compose_transforms()

def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.numpy().transpose(1, 2, 0)
return self.transforms(temp_tensor).get('image')

@staticmethod
def compose_transforms() -> albumentations.Compose:
return albumentations.Compose(
[
albumentations.RandomBrightnessContrast(p = 0.3),
albumentations.OneOf(
[
albumentations.MotionBlur(p = 0.1),
albumentations.MedianBlur(p = 0.1)
], p = 0.3),
albumentations.ColorJitter(p = 0.1),
])


class WarpTransform:
def __init__(self, warp_template : WarpTemplate) -> None:
self.warp_template = warp_template

def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.unsqueeze(0)
return warp_tensor(temp_tensor, self.warp_template).squeeze(0)

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
--extra-index-url https://download.pytorch.org/whl/cu124
albumentations==2.0.5
lightning==2.5.0
onnx==1.17.0
onnxruntime==1.20.1
Expand Down

0 comments on commit b808360

Please sign in to comment.