Skip to content

Commit

Permalink
Adjust some namings
Browse files Browse the repository at this point in the history
  • Loading branch information
henryruhs committed Mar 1, 2025
1 parent 3b97f6e commit 3a155ab
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions face_swapper/src/models/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,11 @@ def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
same_person = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
is_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8

reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
reconstruction_loss = (reconstruction_loss * same_person).mean() * 0.5
reconstruction_loss = (reconstruction_loss * is_similar_identity).mean() * 0.5

data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
reconstruction_loss = (reconstruction_loss + similarity) * 0.5
Expand Down

0 comments on commit 3a155ab

Please sign in to comment.