update weights optimization #1791

Merged (13 commits, Nov 26, 2024)
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -32,11 +32,18 @@ Clipping and speaker/source alignment issues in speech separation pipeline have
- feat(utils): add `hidden` option to `ProgressHook`
- feat(utils): add `FilterByNumberOfSpeakers` protocol files filter

### Improvements

- improve(model): improve WavLM (un)freezing support for `SSeRiouSS` architecture ([@clement-pages](https://github.com/clement-pages/))
- improve(task): improve `SpeakerDiarization` training with manual optimization ([@clement-pages](https://github.com/clement-pages/))

### Fixes

- fix(model): improve WavLM (un)freezing support for `ToTaToNet` architecture ([@clement-pages](https://github.com/clement-pages/))
- fix(separation): fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/))
- fix(separation): fix alignment between separated sources and diarization ([@Lebourdais](https://github.com/Lebourdais/) and [@clement-pages](https://github.com/clement-pages/))
- fix(separation): prevent leakage removal collar from being applied to diarization ([@clement-pages](https://github.com/clement-pages/))
- fix(separation): fix `PixIT` training with manual optimization ([@clement-pages](https://github.com/clement-pages/))
- fix(doc): fix link to pytorch ([@emmanuel-ferdman](https://github.com/emmanuel-ferdman/))
- fix(task): fix corner case with small (<9) number of validation samples ([@antoinelaurent](https://github.com/antoinelaurent/))

8 changes: 8 additions & 0 deletions pyannote/audio/core/task.py
@@ -646,6 +646,14 @@ def setup(self, stage=None):
f"does not correspond to the cached one ({self.prepared_data['protocol']})"
)

@property
def automatic_optimization(self) -> bool:
return self.model.automatic_optimization

@automatic_optimization.setter
def automatic_optimization(self, automatic_optimisation: bool) -> None:
self.model.automatic_optimization = automatic_optimisation

@property
def specifications(self) -> Union[Specifications, Tuple[Specifications]]:
# setup metadata on-demand the first time specifications are requested and missing
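The property added above makes the task read and write the optimization mode straight through to its model, so the two can never disagree. A minimal sketch of that delegation pattern, using stand-in classes rather than the real pyannote.audio `Task`/`Model` (names and constructors here are illustrative only):

```python
# Minimal sketch of the property delegation above; the Toy* classes are
# stand-ins, not the actual pyannote.audio Task/Model API.
class ToyModel:
    def __init__(self) -> None:
        self.automatic_optimization = True


class ToyTask:
    def __init__(self, model: ToyModel) -> None:
        self.model = model

    @property
    def automatic_optimization(self) -> bool:
        # read through to the model so task and model always agree
        return self.model.automatic_optimization

    @automatic_optimization.setter
    def automatic_optimization(self, value: bool) -> None:
        # writing on the task updates the model as well
        self.model.automatic_optimization = value


model = ToyModel()
task = ToyTask(model)
task.automatic_optimization = False
assert model.automatic_optimization is False
```

Keeping the flag on the model as the single source of truth is what lets `ToTaToNet` decide the mode in its constructor while tasks such as `SpeakerDiarization` and `PixIT` simply query it.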
11 changes: 5 additions & 6 deletions pyannote/audio/models/segmentation/SSeRiouSS.py
@@ -127,6 +127,9 @@ def __init__(
self.wav2vec_weights = nn.Parameter(
data=torch.ones(wav2vec_num_layers), requires_grad=True
)

for param in self.wav2vec.parameters():
param.requires_grad = not wav2vec_frozen

lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
lstm["batch_first"] = True
@@ -300,13 +303,9 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer
)

-context = (
-    torch.no_grad() if self.hparams.wav2vec_frozen else contextlib.nullcontext()
+outputs, _ = self.wav2vec.extract_features(
+    waveforms.squeeze(1), num_layers=num_layers
)
-with context:
-    outputs, _ = self.wav2vec.extract_features(
-        waveforms.squeeze(1), num_layers=num_layers
-    )

if num_layers is None:
outputs = torch.stack(outputs, dim=-1) @ F.softmax(
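With freezing now handled in `__init__` by flipping `requires_grad` on the wav2vec parameters, the `torch.no_grad()` wrapper around feature extraction becomes unnecessary. A small sketch contrasting the two freezing styles, with a stand-in `nn.Linear` in place of the actual wav2vec/WavLM backbone:

```python
# Sketch of the two freezing styles; `backbone` is a stand-in module,
# not the actual wav2vec/WavLM model.
import contextlib

import torch
import torch.nn as nn

backbone = nn.Linear(16, 16)
frozen = True

# old style: parameters keep requires_grad=True, but the forward pass is
# wrapped in no_grad() so no graph is built for it
context = torch.no_grad() if frozen else contextlib.nullcontext()
with context:
    _ = backbone(torch.randn(2, 16))

# new style (this PR): decide once at construction time; the forward pass
# needs no special context, and frozen parameters simply never get gradients
for param in backbone.parameters():
    param.requires_grad = not frozen
_ = backbone(torch.randn(2, 16))
```

Deciding once at construction time also means the same forward code path is used whether the backbone is frozen or fine-tuned, which is what allows the context manager to be dropped from `forward`.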
12 changes: 10 additions & 2 deletions pyannote/audio/models/separation/ToTaToNet.py
@@ -96,6 +96,8 @@ class ToTaToNet(Model):
Number of separated sources. Defaults to 3.
use_wavlm : bool, optional
Whether to use the WavLM large model for feature extraction. Defaults to True.
wavlm_frozen : bool, optional
Whether to freeze the WavLM model. Defaults to False.
gradient_clip_val : float, optional
Gradient clipping value. Required when fine-tuning the WavLM model and thus using two different optimizers.
Defaults to 5.0.
@@ -137,6 +139,7 @@ def __init__(
task: Optional[Task] = None,
n_sources: int = 3,
use_wavlm: bool = True,
wavlm_frozen: bool = False,
gradient_clip_val: float = 5.0,
):
if not ASTEROID_IS_AVAILABLE:
@@ -158,7 +161,9 @@ def __init__(
encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder)
diar = merge_dict(self.DIAR_DEFAULTS, diar)
self.use_wavlm = use_wavlm
-self.save_hyperparameters("encoder_decoder", "linear", "dprnn", "diar")
+self.save_hyperparameters(
+    "encoder_decoder", "linear", "dprnn", "diar", "wavlm_frozen"
+)
self.n_sources = n_sources

if encoder_decoder["fb_name"] == "free":
@@ -173,6 +178,8 @@ def __init__(

if self.use_wavlm:
self.wavlm = AutoModel.from_pretrained("microsoft/wavlm-large")
for param in self.wavlm.parameters():
param.requires_grad = not wavlm_frozen
downsampling_factor = 1
for conv_layer in self.wavlm.feature_extractor.conv_layers:
if isinstance(conv_layer.conv, nn.Conv1d):
@@ -216,7 +223,8 @@ def __init__(
]
)
self.gradient_clip_val = gradient_clip_val
-self.automatic_optimization = False
+# manual optimization is needed only when wavlm is finetuned
+self.automatic_optimization = wavlm_frozen

@property
def dimension(self) -> int:
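Because the new `wavlm_frozen` flag now drives both `requires_grad` and the optimization mode, freezing WavLM leaves a single optimizer and automatic optimization, while fine-tuning it switches the model to manual optimization (two optimizers with different learning rates, plus explicit gradient clipping). A hedged usage sketch; it assumes `ToTaToNet` is importable as below, that the asteroid dependency is installed, and that downloading the WavLM checkpoint on first use is acceptable:

```python
# Hedged usage sketch of the new flag; requires the asteroid extra and
# downloads microsoft/wavlm-large on first instantiation.
from pyannote.audio.models.separation import ToTaToNet

# WavLM frozen: its parameters get requires_grad=False, one optimizer is
# enough, so Lightning's automatic optimization stays on.
frozen = ToTaToNet(wavlm_frozen=True)
assert frozen.automatic_optimization

# WavLM fine-tuned (default): two optimizers are expected, which requires
# manual optimization in the training step.
finetuned = ToTaToNet(wavlm_frozen=False)
assert not finetuned.automatic_optimization
```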
8 changes: 3 additions & 5 deletions pyannote/audio/tasks/segmentation/speaker_diarization.py
@@ -398,10 +398,6 @@ def collate_y(self, batch) -> torch.Tensor:

return torch.from_numpy(np.stack(collated_y))

-@property
-def automatic_optimization(self):
-    return self.model.automatic_optimization

def training_step(self, batch, batch_idx: int):
"""Compute permutation-invariant segmentation loss

@@ -469,8 +465,9 @@ def training_step(self, batch, batch_idx: int):
logger=True,
)

-if not self.model.automatic_optimization:
+if not self.automatic_optimization:
    optimizers = self.model.optimizers()
+    optimizers = optimizers if isinstance(optimizers, list) else [optimizers]
for optimizer in optimizers:
optimizer.zero_grad()

@@ -673,4 +670,5 @@ def progress_hook(completed: Optional[int] = None, total: Optional[int] = None):

if __name__ == "__main__":
import typer

typer.run(evaluate)
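The added `isinstance` line is there because Lightning's `optimizers()` returns a bare optimizer when only one is configured and a list when there are several; without normalizing to a list, the single-optimizer case would fail in the `for` loop. A small plain-PyTorch illustration of the same normalization, with stand-in optimizers and no Lightning involved:

```python
# Stand-in illustration of normalizing "one optimizer or a list of them"
# before iterating; not the actual Lightning objects.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
one = torch.optim.SGD(params, lr=0.1)            # what optimizers() returns with a single optimizer
many = [one, torch.optim.Adam(params, lr=0.01)]  # what it returns with several

for returned in (one, many):
    optimizers = returned if isinstance(returned, list) else [returned]
    for optimizer in optimizers:
        optimizer.zero_grad()  # safe in both cases
```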
41 changes: 16 additions & 25 deletions pyannote/audio/tasks/separation/PixIT.py
@@ -143,9 +143,6 @@ class PixIT(SegmentationTask):
Defaults to AUROC (area under the ROC curve).
separation_loss_weight : float, optional
Scaling factor between diarization and separation losses. Defaults to 0.5.
-finetune_wavlm : bool, optional
-    If True, the WavLM feature extractor will be fine-tuned during training.
-    Defaults to True.

References
----------
@@ -175,7 +172,6 @@ def __init__(
] = None, # deprecated in favor of `max_speakers_per_chunk``
loss: Literal["bce", "mse"] = None, # deprecated
separation_loss_weight: float = 0.5,
-finetune_wavlm: bool = True,
):
if not ASTEROID_IS_AVAILABLE:
raise ImportError(
@@ -224,7 +220,6 @@ def __init__(
self.weight = weight
self.separation_loss_weight = separation_loss_weight
self.mixit_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True)
-self.finetune_wavlm = finetune_wavlm

def setup(self, stage=None):
super().setup(stage)
@@ -971,20 +966,14 @@ def training_step(self, batch, batch_idx: int):
loss : {str: torch.tensor}
{"loss": loss}
"""
-# finetuning wavlm with a smaller learning rate requires two optimizers
-# and manual gradient stepping
-if self.finetune_wavlm:
-    wavlm_opt, rest_opt = self.model.optimizers()
-    wavlm_opt.zero_grad()
-    rest_opt.zero_grad()

(
seg_loss,
separation_loss,
diarization,
permutated_diarization,
target,
) = self.common_step(batch)

self.model.log(
"loss/train/separation",
separation_loss,
@@ -1020,20 +1009,22 @@
logger=True,
)

-if self.finetune_wavlm:
+# using multiple optimizers requires manual optimization
+if not self.automatic_optimization:
+    optimizers = self.model.optimizers()
+    optimizers = optimizers if isinstance(optimizers, list) else [optimizers]
+    for optimizer in optimizers:
+        optimizer.zero_grad()

    self.model.manual_backward(loss)
-    self.model.clip_gradients(
-        wavlm_opt,
-        gradient_clip_val=self.model.gradient_clip_val,
-        gradient_clip_algorithm="norm",
-    )
-    self.model.clip_gradients(
-        rest_opt,
-        gradient_clip_val=self.model.gradient_clip_val,
-        gradient_clip_algorithm="norm",
-    )
-    wavlm_opt.step()
-    rest_opt.step()

+    for optimizer in optimizers:
+        self.model.clip_gradients(
+            optimizer,
+            gradient_clip_val=self.model.gradient_clip_val,
+            gradient_clip_algorithm="norm",
+        )
+        optimizer.step()

return {"loss": loss}

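The rewritten step replaces the hard-coded `wavlm_opt`/`rest_opt` pair with a loop over whatever `optimizers()` returns, so the zero_grad / `manual_backward` / clip / step sequence no longer cares how many optimizers exist. A self-contained sketch of that manual-optimization pattern as a generic `LightningModule` (not the actual `PixIT` task; the model, loss, and learning rates are placeholders):

```python
# Generic manual-optimization sketch mirroring the loop above; the module,
# loss, and learning rates are placeholders, not the real PixIT setup.
import pytorch_lightning as pl
import torch


class ManualOptimizationSketch(pl.LightningModule):
    def __init__(self, gradient_clip_val: float = 5.0):
        super().__init__()
        self.backbone = torch.nn.Linear(8, 8)  # stand-in for a pretrained encoder
        self.head = torch.nn.Linear(8, 1)      # stand-in for task-specific layers
        self.gradient_clip_val = gradient_clip_val
        self.automatic_optimization = False    # opt in to manual optimization

    def configure_optimizers(self):
        # two optimizers with different learning rates, as when fine-tuning a
        # pretrained backbone alongside freshly initialized layers
        return [
            torch.optim.Adam(self.backbone.parameters(), lr=1e-5),
            torch.optim.Adam(self.head.parameters(), lr=1e-3),
        ]

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.head(self.backbone(x)), y)

        # optimizers() returns one optimizer or a list, depending on how many
        # were configured; normalize to a list before looping
        optimizers = self.optimizers()
        optimizers = optimizers if isinstance(optimizers, list) else [optimizers]

        for optimizer in optimizers:
            optimizer.zero_grad()
        self.manual_backward(loss)
        for optimizer in optimizers:
            self.clip_gradients(
                optimizer,
                gradient_clip_val=self.gradient_clip_val,
                gradient_clip_algorithm="norm",
            )
            optimizer.step()

        return {"loss": loss}
```

Gradient clipping is called explicitly here because the Trainer-level `gradient_clip_val` is not applied under manual optimization, which is also why `ToTaToNet` keeps its own `gradient_clip_val` attribute.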