Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
harisreedhar committed Feb 27, 2025
1 parent 2e74010 commit 958e85f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 15 deletions.
21 changes: 9 additions & 12 deletions face_swapper/src/models/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,32 +70,29 @@ def forward(self, target_attributes : Attributes, output_attributes : Attributes


class ReconstructionLoss(nn.Module):
def __init__(self) -> None:
def __init__(self, embedder : nn.Module) -> None:
super().__init__()
self.embedder = embedder
self.mse_loss = nn.MSELoss()

def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
temp_tensors = []

for __source_tensor__, __target_tensor__ in zip(source_tensor, target_tensor):
temp_tensor = self.mse_loss(__source_tensor__, __target_tensor__) * torch.equal(__source_tensor__, __target_tensor__)
temp_tensors.append(temp_tensor)

reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
same_person = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
reconstruction_loss = (reconstruction_loss * same_person).mean() * 0.5
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()

reconstruction_loss = (reconstruction_loss + similarity) * 0.5
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
return reconstruction_loss, weighted_reconstruction_loss


class IdentityLoss(nn.Module):
def __init__(self) -> None:
def __init__(self, embedder : nn.Module) -> None:
super().__init__()
embedder_path = CONFIG.get('training.model', 'embedder_path')
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.embedder = embedder

def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
Expand Down
6 changes: 3 additions & 3 deletions face_swapper/src/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ def __init__(self) -> None:
super().__init__()
embedder_path = CONFIG.get('training.model', 'embedder_path')

self.embedder = torch.jit.load(embedder_path, map_location='cpu') # type:ignore[no-untyped-call]
self.generator = Generator()
self.discriminator = Discriminator()
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss()
self.attribute_loss = AttributeLoss()
self.reconstruction_loss = ReconstructionLoss()
self.identity_loss = IdentityLoss()
self.reconstruction_loss = ReconstructionLoss(self.embedder)
self.identity_loss = IdentityLoss(self.embedder)
self.pose_loss = PoseLoss()
self.gaze_loss = GazeLoss()
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.automatic_optimization = False

def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
Expand Down

0 comments on commit 958e85f

Please sign in to comment.