Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
harisreedhar committed Feb 27, 2025
1 parent c44ba03 commit 0e25157
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions face_swapper/src/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def configure_optimizers(self) -> Tuple[OptimizerConfig, OptimizerConfig]:
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(generator_optimizer, T_max = 10, eta_min = 1e-6)
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(discriminator_optimizer, T_max = 10, eta_min = 1e-6)
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)

generator_config =\
{
Expand Down Expand Up @@ -82,6 +82,7 @@ def training_step(self, batch : Batch, batch_index : int) -> Tensor:
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)

self.toggle_optimizer(generator_optimizer)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
Expand All @@ -93,20 +94,17 @@ def training_step(self, batch : Batch, batch_index : int) -> Tensor:
generator_optimizer.zero_grad()
self.manual_backward(generator_loss)
generator_optimizer.step()
self.untoggle_optimizer(generator_optimizer)

generator_scheduler = self.lr_schedulers()[0]
generator_scheduler.step()

self.toggle_optimizer(discriminator_optimizer)
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)

discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_loss)
discriminator_optimizer.step()

discriminator_scheduler = self.lr_schedulers()[1]
discriminator_scheduler.step()
self.untoggle_optimizer(discriminator_optimizer)

if self.global_step % preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
Expand Down

0 comments on commit 0e25157

Please sign in to comment.