Skip to content

Commit

Permalink
Merge pull request #40 from facefusion/expression-loss
Browse files (browse the repository at this point in the history)
Add expression loss
  • Loading branch information
harisreedhar authored Mar 4, 2025
2 parents b808360 + 2b3d4e9 commit 6520e9b
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 21 deletions.
3 changes: 2 additions & 1 deletion face_swapper/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ adversarial_weight = 1.5
attribute_weight = 10
reconstruction_weight = 20
identity_weight = 20
pose_weight = 0
gaze_weight = 0
pose_weight = 0
expression_weight = 0
```

```
Expand Down
3 changes: 2 additions & 1 deletion face_swapper/config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ adversarial_weight =
attribute_weight =
reconstruction_weight =
identity_weight =
pose_weight =
gaze_weight =
pose_weight =
expression_weight =

[training.trainer]
learning_rate =
Expand Down
7 changes: 3 additions & 4 deletions face_swapper/src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def __len__(self) -> int:
def compose_transforms(self) -> transforms:
return transforms.Compose(
[
AugmentTransform(),
transforms.ToPILImage(),
transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
AugmentTransform(),
transforms.ToTensor(),
WarpTransform(self.warp_template),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
Expand Down
Expand Up @@ -74,7 +74,7 @@ def __init__(self) -> None:

def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.numpy().transpose(1, 2, 0)
return self.transforms(temp_tensor).get('image')
return self.transforms(image = temp_tensor).get('image')

@staticmethod
def compose_transforms() -> albumentations.Compose:
Expand All @@ -86,7 +86,7 @@ def compose_transforms() -> albumentations.Compose:
albumentations.MotionBlur(p = 0.1),
albumentations.MedianBlur(p = 0.1)
], p = 0.3),
albumentations.ColorJitter(p = 0.1),
albumentations.ColorJitter(p = 0.1)
])


Expand All @@ -97,4 +97,3 @@ def __init__(self, warp_template : WarpTemplate) -> None:
def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.unsqueeze(0)
return warp_tensor(temp_tensor, self.warp_template).squeeze(0)

34 changes: 23 additions & 11 deletions face_swapper/src/models/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,31 +106,43 @@ def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tenso
return identity_loss, weighted_identity_loss


class PoseLoss(nn.Module):
def __init__(self, motion_extractor : MotionExtractorModule) -> None:
class MotionLoss(nn.Module):
def __init__(self, motion_extractor : MotionExtractorModule):
super().__init__()
self.motion_extractor = motion_extractor
self.mse_loss = nn.MSELoss()

def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, ...]:
target_poses, target_expression = self.get_motions(target_tensor)
output_poses, output_expression = self.get_motions(output_tensor)
pose_loss, weighted_pose_loss = self.calc_pose_loss(target_poses, output_poses)
expression_loss, weighted_expression_loss = self.calc_expression_loss(target_expression, output_expression)
return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss

def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Tensor, Tensor]:
pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
output_motion_features = self.get_motion_features(output_tensor)
target_motion_features = self.get_motion_features(target_tensor)
temp_tensors = []

for target_motion_feature, output_motion_feature in zip(target_motion_features, output_motion_features):
temp_tensor = self.mse_loss(target_motion_feature, output_motion_feature)
for target_pose, output_pose in zip(target_poses, output_poses):
temp_tensor = self.mse_loss(target_pose, output_pose)
temp_tensors.append(temp_tensor)

pose_loss = torch.stack(temp_tensors).mean()
weighted_pose_loss = pose_loss * pose_weight
return pose_loss, weighted_pose_loss

def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]:
expression_weight = CONFIG.getfloat('training.losses', 'expression_weight')
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
weighted_expression_loss = expression_loss * expression_weight
return expression_loss, weighted_expression_loss

def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
input_tensor = (input_tensor + 1) * 0.5
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(input_tensor)
pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
return translation, scale, rotation
pose = translation, scale, rotation, motion_points
return pose, expression


class GazeLoss(nn.Module):
Expand All @@ -144,7 +156,7 @@ def __init__(self, gazer : GazerModule) -> None:
transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])
])

def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
output_pitch_tensor, output_yaw_tensor = self.detect_gaze(output_tensor)
target_pitch_tensor, target_yaw_tensor = self.detect_gaze(target_tensor)
Expand Down
8 changes: 4 additions & 4 deletions face_swapper/src/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MotionLoss, ReconstructionLoss
from .types import Batch, BatchMode, Embedding, OptimizerConfig, WarpTemplate

warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
Expand All @@ -44,7 +44,7 @@ def __init__(self) -> None:
self.attribute_loss = AttributeLoss()
self.reconstruction_loss = ReconstructionLoss(self.embedder)
self.identity_loss = IdentityLoss(self.embedder)
self.pose_loss = PoseLoss(self.motion_extractor)
self.motion_loss = MotionLoss(self.motion_extractor)
self.gaze_loss = GazeLoss(self.gazer)
self.automatic_optimization = False

Expand Down
Expand Up @@ -95,9 +95,9 @@ def training_step(self, batch : Batch, batch_index : int) -> Tensor:
attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
pose_loss, weighted_pose_loss = self.pose_loss(target_tensor, generator_output_tensor)
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss

generator_optimizer.zero_grad()
self.manual_backward(generator_loss)
Expand Down

0 comments on commit 6520e9b

Please sign in to comment.