Commit

fix weights update issue when training separation model
clement-pages committed Nov 19, 2024
1 parent 7d84f61 commit 7b4e2a4
Showing 3 changed files with 22 additions and 10 deletions.
4 changes: 4 additions & 0 deletions pyannote/audio/core/task.py
@@ -646,6 +646,10 @@ def setup(self, stage=None):
                 f"does not correspond to the cached one ({self.prepared_data['protocol']})"
             )

+    @property
+    def automatic_optimization(self) -> bool:
+        return self.model.automatic_optimization
+
     @property
     def specifications(self) -> Union[Specifications, Tuple[Specifications]]:
         # setup metadata on-demand the first time specifications are requested and missing
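For readers unfamiliar with the pattern: the new read-only property lets the task query the optimization mode of the model it is attached to, instead of carrying its own `finetune_wavlm` flag. A minimal sketch of that delegation, using toy stand-ins rather than the actual pyannote classes:

# Minimal sketch of the property delegation added above.
# `ToyModel` / `ToyTask` are simplified stand-ins, not pyannote classes.


class ToyModel:
    def __init__(self, wavlm_frozen: bool = False):
        # frozen WavLM -> Lightning's automatic optimization is enough
        self.automatic_optimization = wavlm_frozen


class ToyTask:
    def __init__(self, model: ToyModel):
        self.model = model

    @property
    def automatic_optimization(self) -> bool:
        # the task always mirrors the model's optimization mode
        return self.model.automatic_optimization


task = ToyTask(ToyModel(wavlm_frozen=True))
assert task.automatic_optimization is True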
19 changes: 16 additions & 3 deletions pyannote/audio/models/separation/ToTaToNet.py
@@ -22,6 +22,7 @@

 # AUTHOR: Joonas Kalda (github.com/joonaskalda)

+import contextlib
 from functools import lru_cache
 from typing import Optional

@@ -96,6 +97,8 @@ class ToTaToNet(Model):
         Number of separated sources. Defaults to 3.
     use_wavlm : bool, optional
         Whether to use the WavLM large model for feature extraction. Defaults to True.
+    wavlm_frozen : bool, optional
+        Whether to freeze the WavLM model. Defaults to False.
     gradient_clip_val : float, optional
         Gradient clipping value. Required when fine-tuning the WavLM model and thus using two different optimizers.
         Defaults to 5.0.
@@ -137,6 +140,7 @@ def __init__(
         task: Optional[Task] = None,
         n_sources: int = 3,
         use_wavlm: bool = True,
+        wavlm_frozen: bool = False,
         gradient_clip_val: float = 5.0,
     ):
         if not ASTEROID_IS_AVAILABLE:
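A hedged usage sketch of the new constructor argument. It assumes `ToTaToNet` is importable from `pyannote.audio.models.separation`, that the `asteroid` dependency is installed, and that downloading the WavLM checkpoint is acceptable; task and protocol wiring is left out.

# Hedged usage sketch of the new `wavlm_frozen` flag (protocol/task wiring elided).
from pyannote.audio.models.separation import ToTaToNet

# keep WavLM frozen: its forward pass runs under torch.no_grad()
# and Lightning's automatic optimization is used
frozen_model = ToTaToNet(wavlm_frozen=True)
assert frozen_model.automatic_optimization

# fine-tune WavLM (default): two optimizers and manual optimization
finetuned_model = ToTaToNet(wavlm_frozen=False, gradient_clip_val=5.0)
assert not finetuned_model.automatic_optimization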
@@ -158,7 +162,9 @@ def __init__(
         encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder)
         diar = merge_dict(self.DIAR_DEFAULTS, diar)
         self.use_wavlm = use_wavlm
-        self.save_hyperparameters("encoder_decoder", "linear", "dprnn", "diar")
+        self.save_hyperparameters(
+            "encoder_decoder", "linear", "dprnn", "diar", "wavlm_frozen"
+        )
         self.n_sources = n_sources

         if encoder_decoder["fb_name"] == "free":
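Registering `"wavlm_frozen"` with `save_hyperparameters` is what makes the flag readable later as `self.hparams["wavlm_frozen"]` (see the `forward` hunk below) and stores it in checkpoints. A minimal sketch of that round trip with a toy module, assuming the `lightning` package (substitute `pytorch_lightning` if that is the installed flavour):

# Minimal sketch of the save_hyperparameters / hparams round trip
# (toy module, not ToTaToNet).
import torch
from lightning.pytorch import LightningModule


class ToyModule(LightningModule):
    def __init__(self, wavlm_frozen: bool = False):
        super().__init__()
        # registered hyperparameters are exposed via self.hparams
        # and stored in checkpoints alongside the weights
        self.save_hyperparameters("wavlm_frozen")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # later code (e.g. forward) can read the flag back
        if self.hparams["wavlm_frozen"]:
            with torch.no_grad():
                return 2.0 * x
        return 2.0 * x


module = ToyModule(wavlm_frozen=True)
assert module.hparams["wavlm_frozen"] is True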
@@ -216,7 +222,8 @@ def __init__(
             ]
         )
         self.gradient_clip_val = gradient_clip_val
-        self.automatic_optimization = False
+        # manual optimization is needed only when wavlm is finetuned
+        self.automatic_optimization = wavlm_frozen

     @property
     def dimension(self) -> int:
@@ -321,7 +328,13 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
         bsz = waveforms.shape[0]
         tf_rep = self.encoder(waveforms)
         if self.use_wavlm:
-            wavlm_rep = self.wavlm(waveforms.squeeze(1)).last_hidden_state
+            context = (
+                torch.no_grad()
+                if self.hparams["wavlm_frozen"]
+                else contextlib.nullcontext()
+            )
+            with context:
+                wavlm_rep = self.wavlm(waveforms.squeeze(1)).last_hidden_state
             wavlm_rep = wavlm_rep.transpose(1, 2)
             wavlm_rep = wavlm_rep.repeat_interleave(self.wavlm_scaling, dim=-1)
             wavlm_rep = pad_x_to_y(wavlm_rep, tf_rep)
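Freezing is handled at forward time: when `wavlm_frozen` is set, the WavLM call runs under `torch.no_grad()`, otherwise under a no-op `contextlib.nullcontext()`, so the same code path serves both modes. A standalone sketch of the pattern, with a toy extractor in place of WavLM:

# Standalone sketch of the frozen/fine-tuned context switch used above
# (toy linear "extractor" instead of WavLM).
import contextlib

import torch
import torch.nn as nn

extractor = nn.Linear(10, 10)
frozen = True  # plays the role of hparams["wavlm_frozen"]

context = torch.no_grad() if frozen else contextlib.nullcontext()
with context:
    features = extractor(torch.randn(4, 10))

# frozen -> no graph is recorded, so no gradients will flow into `extractor`
assert features.requires_grad is (not frozen)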
9 changes: 2 additions & 7 deletions pyannote/audio/tasks/separation/PixIT.py
@@ -143,9 +143,6 @@ class PixIT(SegmentationTask):
         Defaults to AUROC (area under the ROC curve).
     separation_loss_weight : float, optional
         Scaling factor between diarization and separation losses. Defaults to 0.5.
-    finetune_wavlm : bool, optional
-        If True, the WavLM feature extractor will be fine-tuned during training.
-        Defaults to True.

     References
     ----------
@@ -175,7 +172,6 @@ def __init__(
         ] = None,  # deprecated in favor of `max_speakers_per_chunk``
         loss: Literal["bce", "mse"] = None,  # deprecated
         separation_loss_weight: float = 0.5,
-        finetune_wavlm: bool = True,
     ):
         if not ASTEROID_IS_AVAILABLE:
             raise ImportError(
@@ -224,7 +220,6 @@ def __init__(
         self.weight = weight
         self.separation_loss_weight = separation_loss_weight
         self.mixit_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True)
-        self.finetune_wavlm = finetune_wavlm

     def setup(self, stage=None):
         super().setup(stage)
@@ -973,7 +968,7 @@ def training_step(self, batch, batch_idx: int):
        """
        # finetuning wavlm with a smaller learning rate requires two optimizers
        # and manual gradient stepping
-        if self.finetune_wavlm:
+        if not self.automatic_optimization:
            wavlm_opt, rest_opt = self.model.optimizers()
            wavlm_opt.zero_grad()
            rest_opt.zero_grad()
@@ -1020,7 +1015,7 @@ def training_step(self, batch, batch_idx: int):
            logger=True,
        )

-        if self.finetune_wavlm:
+        if not self.automatic_optimization:
            self.model.manual_backward(loss)
            self.model.clip_gradients(
                wavlm_opt,
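When WavLM is fine-tuned, `automatic_optimization` is False and the training step has to drive both optimizers itself, as the PixIT code above does: zero the gradients, call `manual_backward`, clip, then step. A simplified, self-contained sketch of that pattern with a toy module and loss, assuming the `lightning` package; it is meant to be run under a Lightning `Trainer`, not called directly:

# Simplified sketch of the manual two-optimizer pattern used in PixIT's
# training_step (toy module and loss; not the actual pyannote objects).
import torch
import torch.nn as nn
from lightning.pytorch import LightningModule


class ToyManualOptimModule(LightningModule):
    def __init__(self, gradient_clip_val: float = 5.0):
        super().__init__()
        self.backbone = nn.Linear(10, 10)  # plays the role of WavLM
        self.head = nn.Linear(10, 1)       # plays the role of the separation head
        self.gradient_clip_val = gradient_clip_val
        # fine-tuning the backbone with its own learning rate -> manual optimization
        self.automatic_optimization = False

    def configure_optimizers(self):
        backbone_opt = torch.optim.Adam(self.backbone.parameters(), lr=1e-5)
        rest_opt = torch.optim.Adam(self.head.parameters(), lr=1e-3)
        return [backbone_opt, rest_opt]

    def training_step(self, batch, batch_idx):
        # `batch` is assumed to be a float tensor of shape (batch_size, 10)
        backbone_opt, rest_opt = self.optimizers()
        backbone_opt.zero_grad()
        rest_opt.zero_grad()

        loss = self.head(self.backbone(batch)).pow(2).mean()

        self.manual_backward(loss)
        self.clip_gradients(backbone_opt, gradient_clip_val=self.gradient_clip_val)
        self.clip_gradients(rest_opt, gradient_clip_val=self.gradient_clip_val)
        backbone_opt.step()
        rest_opt.step()
        return loss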
