allow using more than two optimizers
In the same way as is done for the speaker diarization task.
clement-pages committed Nov 20, 2024
1 parent ca2a5d4 commit c3a0313
Showing 1 changed file with 14 additions and 19 deletions.
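
For context before reading the diff below: Lightning hands back the configured optimizers as a list, and hard-coded tuple unpacking of that list only works when there are exactly two of them, whereas iterating over it works for any number. The following is a minimal standalone sketch in plain PyTorch, not code from pyannote; the three linear layers and their learning rates are invented for illustration.

import torch

# Three hypothetical parameter groups standing in for "WavLM", "separation" and "the rest".
modules = [torch.nn.Linear(4, 4) for _ in range(3)]
optimizers = [
    torch.optim.Adam(module.parameters(), lr=lr)
    for module, lr in zip(modules, (1e-5, 1e-3, 1e-3))
]

# Hard-coded unpacking assumes exactly two optimizers and breaks here:
#     wavlm_opt, rest_opt = optimizers  # ValueError: too many values to unpack
# Iterating over the list handles any number of optimizers:
for optimizer in optimizers:
    optimizer.zero_grad()
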
pyannote/audio/tasks/separation/PixIT.py: 33 changes (14 additions, 19 deletions)
@@ -966,20 +966,14 @@ def training_step(self, batch, batch_idx: int):
         loss : {str: torch.tensor}
             {"loss": loss}
         """
-        # finetuning wavlm with a smaller learning rate requires two optimizers
-        # and manual gradient stepping
-        if not self.automatic_optimization:
-            wavlm_opt, rest_opt = self.model.optimizers()
-            wavlm_opt.zero_grad()
-            rest_opt.zero_grad()
 
         (
             seg_loss,
             separation_loss,
             diarization,
             permutated_diarization,
             target,
         ) = self.common_step(batch)
 
         self.model.log(
             "loss/train/separation",
             separation_loss,
@@ -1015,20 +1009,21 @@ def training_step(self, batch, batch_idx: int):
             logger=True,
         )
 
+        # using multiple optimizers requires manual optimization
         if not self.automatic_optimization:
+            optimizers = self.model.optimizers()
+            for optimizer in optimizers:
+                optimizer.zero_grad()
+
             self.model.manual_backward(loss)
-            self.model.clip_gradients(
-                wavlm_opt,
-                gradient_clip_val=self.model.gradient_clip_val,
-                gradient_clip_algorithm="norm",
-            )
-            self.model.clip_gradients(
-                rest_opt,
-                gradient_clip_val=self.model.gradient_clip_val,
-                gradient_clip_algorithm="norm",
-            )
-            wavlm_opt.step()
-            rest_opt.step()
+            for optimizer in optimizers:
+                self.model.clip_gradients(
+                    optimizer,
+                    gradient_clip_val=self.model.gradient_clip_val,
+                    gradient_clip_algorithm="norm",
+                )
+                optimizer.step()
 
         return {"loss": loss}
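
For completeness, a standalone sketch of the full pattern this commit adopts: a LightningModule with automatic optimization disabled, a configure_optimizers that returns more than two optimizers, and a training_step that zeroes, backpropagates, clips and steps each optimizer in a loop, mirroring the diff above. This is an illustration only, not pyannote code; the toy architecture, the learning rates and the gradient_clip_val of 5.0 are assumptions.

import torch
import pytorch_lightning as pl


class MultiOptimizerToy(pl.LightningModule):
    """Toy module that drives an arbitrary number of optimizers manually."""

    def __init__(self, num_blocks: int = 3, lr: float = 1e-3):
        super().__init__()
        # Manual optimization is required as soon as several optimizers are used.
        self.automatic_optimization = False
        self.blocks = torch.nn.ModuleList(
            torch.nn.Linear(8, 8) for _ in range(num_blocks)
        )
        self.head = torch.nn.Linear(8, 1)
        self.lr = lr

    def forward(self, x):
        for block in self.blocks:
            x = torch.relu(block(x))
        return self.head(x)

    def configure_optimizers(self):
        # One optimizer per block plus one for the head: more than two in total.
        optimizers = [
            torch.optim.Adam(block.parameters(), lr=self.lr) for block in self.blocks
        ]
        optimizers.append(torch.optim.Adam(self.head.parameters(), lr=self.lr / 10))
        return optimizers

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)

        # Same generalized loop as in the diff above: zero, backward, clip, step.
        optimizers = self.optimizers()
        for optimizer in optimizers:
            optimizer.zero_grad()

        self.manual_backward(loss)

        for optimizer in optimizers:
            self.clip_gradients(
                optimizer, gradient_clip_val=5.0, gradient_clip_algorithm="norm"
            )
            optimizer.step()

        self.log("loss/train", loss, prog_bar=True)
        return {"loss": loss}

Fitted with a plain pl.Trainer and any DataLoader yielding (x, y) batches of shapes (batch, 8) and (batch, 1), this runs the same per-optimizer loop that the PixIT task now uses, regardless of how many optimizers configure_optimizers returns.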
