update weights optimization #1791

Merged
merged 13 commits on Nov 26, 2024
8 changes: 8 additions & 0 deletions pyannote/audio/core/task.py
@@ -646,6 +646,14 @@ def setup(self, stage=None):
f"does not correspond to the cached one ({self.prepared_data['protocol']})"
)

@property
def automatic_optimization(self) -> bool:
return self.model.automatic_optimization

@automatic_optimization.setter
def automatic_optimization(self, automatic_optimisation: bool) -> None:
self.model.automatic_optimization = automatic_optimisation

@property
def specifications(self) -> Union[Specifications, Tuple[Specifications]]:
# setup metadata on-demand the first time specifications are requested and missing
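The property/setter pair added to Task makes the task a thin proxy for the model's Lightning `automatic_optimization` flag, so reading or writing it on either object stays consistent. Below is a minimal sketch of that read-through/write-through pattern; ToyModel and ToyTask are hypothetical stand-ins, not pyannote classes.

# Minimal sketch of the delegation pattern above; ToyModel / ToyTask are
# hypothetical stand-ins, not pyannote classes.
class ToyModel:
    def __init__(self) -> None:
        # Lightning-style flag owned by the model
        self.automatic_optimization = True


class ToyTask:
    def __init__(self, model: ToyModel) -> None:
        self.model = model

    @property
    def automatic_optimization(self) -> bool:
        # read-through: always reflects the model's current setting
        return self.model.automatic_optimization

    @automatic_optimization.setter
    def automatic_optimization(self, value: bool) -> None:
        # write-through: flipping the task flag flips the model flag
        self.model.automatic_optimization = value


task = ToyTask(ToyModel())
task.automatic_optimization = False
assert task.model.automatic_optimization is False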
11 changes: 5 additions & 6 deletions pyannote/audio/models/segmentation/SSeRiouSS.py
@@ -127,6 +127,9 @@ def __init__(
self.wav2vec_weights = nn.Parameter(
data=torch.ones(wav2vec_num_layers), requires_grad=True
)

for param in self.wav2vec.parameters():
param.requires_grad = not wav2vec_frozen

lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
lstm["batch_first"] = True
@@ -300,13 +303,9 @@ def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
None if self.hparams.wav2vec_layer < 0 else self.hparams.wav2vec_layer
)

        context = (
            torch.no_grad() if self.hparams.wav2vec_frozen else contextlib.nullcontext()
        )
        with context:
            outputs, _ = self.wav2vec.extract_features(
                waveforms.squeeze(1), num_layers=num_layers
            )
        outputs, _ = self.wav2vec.extract_features(
            waveforms.squeeze(1), num_layers=num_layers
        )

if num_layers is None:
outputs = torch.stack(outputs, dim=-1) @ F.softmax(
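In SSeRiouSS, freezing the wav2vec backbone now happens once at construction time via `requires_grad`, instead of wrapping every forward pass in `torch.no_grad()`. A hedged sketch of that pattern, with a stand-in linear "feature extractor" instead of the real torchaudio wav2vec 2.0 model:

import torch
import torch.nn as nn

# Sketch of the freezing pattern introduced above; the linear layers stand in
# for the wav2vec backbone and the trainable LSTM head.
feature_extractor = nn.Linear(16, 16)
head = nn.Linear(16, 2)
wav2vec_frozen = True  # mirrors the `wav2vec_frozen` constructor argument

for param in feature_extractor.parameters():
    # requires_grad=False keeps these weights out of the backward pass and out
    # of any optimizer, replacing the old torch.no_grad() context in forward()
    param.requires_grad = not wav2vec_frozen

x = torch.randn(8, 16)
loss = head(feature_extractor(x)).sum()
loss.backward()

assert all(p.grad is None for p in feature_extractor.parameters())  # frozen
assert all(p.grad is not None for p in head.parameters())  # still trained

Toggling `requires_grad` makes freezing a property of the parameters rather than of the forward pass, so the same `forward()` serves both the frozen and fine-tuned configurations.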
12 changes: 10 additions & 2 deletions pyannote/audio/models/separation/ToTaToNet.py
@@ -96,6 +96,8 @@ class ToTaToNet(Model):
Number of separated sources. Defaults to 3.
use_wavlm : bool, optional
Whether to use the WavLM large model for feature extraction. Defaults to True.
wavlm_frozen : bool, optional
Whether to freeze the WavLM model. Defaults to False.
gradient_clip_val : float, optional
Gradient clipping value. Required when fine-tuning the WavLM model and thus using two different optimizers.
Defaults to 5.0.
@@ -137,6 +139,7 @@ def __init__(
task: Optional[Task] = None,
n_sources: int = 3,
use_wavlm: bool = True,
wavlm_frozen: bool = False,
gradient_clip_val: float = 5.0,
):
if not ASTEROID_IS_AVAILABLE:
@@ -158,7 +161,9 @@ def __init__(
encoder_decoder = merge_dict(self.ENCODER_DECODER_DEFAULTS, encoder_decoder)
diar = merge_dict(self.DIAR_DEFAULTS, diar)
self.use_wavlm = use_wavlm
self.save_hyperparameters("encoder_decoder", "linear", "dprnn", "diar")
self.save_hyperparameters(
"encoder_decoder", "linear", "dprnn", "diar", "wavlm_frozen"
)
self.n_sources = n_sources

if encoder_decoder["fb_name"] == "free":
@@ -173,6 +178,8 @@ def __init__(

if self.use_wavlm:
self.wavlm = AutoModel.from_pretrained("microsoft/wavlm-large")
for param in self.wavlm.parameters():
param.requires_grad = not wavlm_frozen
downsampling_factor = 1
for conv_layer in self.wavlm.feature_extractor.conv_layers:
if isinstance(conv_layer.conv, nn.Conv1d):
@@ -216,7 +223,8 @@ def __init__(
]
)
self.gradient_clip_val = gradient_clip_val
self.automatic_optimization = False
# manual optimization is only needed when WavLM is fine-tuned
self.automatic_optimization = wavlm_frozen

@property
def dimension(self) -> int:
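ToTaToNet now ties everything to the new `wavlm_frozen` flag: WavLM parameters are frozen via `requires_grad`, and manual optimization is only enabled when WavLM is fine-tuned, since (per the docstring) fine-tuning uses two optimizers with different learning rates. The following is a hypothetical sketch of that coupling, not ToTaToNet's actual code; the learning rates and the exact optimizer split are illustration-only assumptions.

import torch
from torch import nn

# Hypothetical sketch (not ToTaToNet's actual code) of how `wavlm_frozen`
# can drive both freezing and the choice between one and two optimizers.
class SketchSeparator(nn.Module):
    def __init__(self, wavlm_frozen: bool = False, lr: float = 1e-3, wavlm_lr: float = 1e-5):
        super().__init__()
        self.wavlm = nn.Linear(16, 16)      # stand-in for the WavLM encoder
        self.separator = nn.Linear(16, 3)   # stand-in for the rest of the network
        for param in self.wavlm.parameters():
            param.requires_grad = not wavlm_frozen
        # a single optimizer works fine with automatic optimization;
        # fine-tuning WavLM with its own learning rate needs two optimizers,
        # hence manual optimization when wavlm_frozen is False
        self.automatic_optimization = wavlm_frozen
        self.wavlm_frozen = wavlm_frozen
        self.lr, self.wavlm_lr = lr, wavlm_lr

    def configure_optimizers(self):
        if self.wavlm_frozen:
            return torch.optim.Adam(self.separator.parameters(), lr=self.lr)
        return [
            torch.optim.Adam(self.wavlm.parameters(), lr=self.wavlm_lr),
            torch.optim.Adam(self.separator.parameters(), lr=self.lr),
        ]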
10 changes: 5 additions & 5 deletions pyannote/audio/tasks/segmentation/speaker_diarization.py
@@ -398,10 +398,6 @@ def collate_y(self, batch) -> torch.Tensor:

return torch.from_numpy(np.stack(collated_y))

@property
def automatic_optimization(self):
return self.model.automatic_optimization

def training_step(self, batch, batch_idx: int):
"""Compute permutation-invariant segmentation loss

@@ -469,8 +465,11 @@ def training_step(self, batch, batch_idx: int):
logger=True,
)

if not self.model.automatic_optimization:
if not self.automatic_optimization:
optimizers = self.model.optimizers()
optimizers = (
[optimizers] if not isinstance(optimizers, list) else optimizers
)
for optimizer in optimizers:
optimizer.zero_grad()

@@ -673,4 +672,5 @@ def progress_hook(completed: Optional[int] = None, total: Optional[int] = None):

if __name__ == "__main__":
import typer

typer.run(evaluate)
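In the diarization training step, Lightning's `optimizers()` returns a bare optimizer when the model has only one and a list when it has several, which is why the new code normalizes the return value before looping. A small sketch of that idiom, with strings standing in for optimizer objects:

# Sketch of the normalization idiom used in training_step above; the strings
# stand in for actual optimizer objects.
def as_optimizer_list(optimizers):
    return [optimizers] if not isinstance(optimizers, list) else optimizers

assert as_optimizer_list("single_opt") == ["single_opt"]                # one optimizer
assert as_optimizer_list(["wavlm_opt", "rest_opt"]) == ["wavlm_opt", "rest_opt"]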
43 changes: 18 additions & 25 deletions pyannote/audio/tasks/separation/PixIT.py
@@ -143,9 +143,6 @@ class PixIT(SegmentationTask):
Defaults to AUROC (area under the ROC curve).
separation_loss_weight : float, optional
Scaling factor between diarization and separation losses. Defaults to 0.5.
finetune_wavlm : bool, optional
If True, the WavLM feature extractor will be fine-tuned during training.
Defaults to True.

References
----------
@@ -175,7 +172,6 @@ def __init__(
] = None, # deprecated in favor of `max_speakers_per_chunk``
loss: Literal["bce", "mse"] = None, # deprecated
separation_loss_weight: float = 0.5,
finetune_wavlm: bool = True,
):
if not ASTEROID_IS_AVAILABLE:
raise ImportError(
@@ -224,7 +220,6 @@ def __init__(
self.weight = weight
self.separation_loss_weight = separation_loss_weight
self.mixit_loss = MixITLossWrapper(multisrc_neg_sisdr, generalized=True)
self.finetune_wavlm = finetune_wavlm

def setup(self, stage=None):
super().setup(stage)
@@ -971,20 +966,14 @@ def training_step(self, batch, batch_idx: int):
loss : {str: torch.tensor}
{"loss": loss}
"""
# finetuning wavlm with a smaller learning rate requires two optimizers
# and manual gradient stepping
if self.finetune_wavlm:
wavlm_opt, rest_opt = self.model.optimizers()
wavlm_opt.zero_grad()
rest_opt.zero_grad()

(
seg_loss,
separation_loss,
diarization,
permutated_diarization,
target,
) = self.common_step(batch)

self.model.log(
"loss/train/separation",
separation_loss,
Expand Down Expand Up @@ -1020,20 +1009,24 @@ def training_step(self, batch, batch_idx: int):
logger=True,
)

        if self.finetune_wavlm:
            self.model.manual_backward(loss)
            self.model.clip_gradients(
                wavlm_opt,
                gradient_clip_val=self.model.gradient_clip_val,
                gradient_clip_algorithm="norm",
            )
            self.model.clip_gradients(
                rest_opt,
                gradient_clip_val=self.model.gradient_clip_val,
                gradient_clip_algorithm="norm",
            )
            wavlm_opt.step()
            rest_opt.step()
        # using multiple optimizers requires manual optimization
        if not self.automatic_optimization:
            optimizers = self.model.optimizers()
            optimizers = (
                [optimizers] if not isinstance(optimizers, list) else optimizers
            )
            for optimizer in optimizers:
                optimizer.zero_grad()

            self.model.manual_backward(loss)

            for optimizer in optimizers:
                self.model.clip_gradients(
                    optimizer,
                    gradient_clip_val=self.model.gradient_clip_val,
                    gradient_clip_algorithm="norm",
                )
                optimizer.step()

return {"loss": loss}

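The PixIT training step no longer hard-codes the `(wavlm_opt, rest_opt)` pair: whenever manual optimization is active, it zero-grads, backpropagates, clips, and steps whatever optimizers the model exposes. Below is a hedged sketch of that generalized sequence, with plain torch objects standing in for the LightningModule and its `manual_backward` / `clip_gradients` helpers.

import torch
from torch import nn

# Hedged sketch of the generalized manual-optimization step above; the two
# sub-modules stand in for WavLM and the rest of the network.
backbone = nn.Linear(4, 4)
head = nn.Linear(4, 1)
optimizers = [
    torch.optim.Adam(backbone.parameters(), lr=1e-5),  # slow "backbone" optimizer
    torch.optim.Adam(head.parameters(), lr=1e-3),      # faster "rest" optimizer
]
gradient_clip_val = 5.0

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(head(backbone(x)), y)

for optimizer in optimizers:
    optimizer.zero_grad()

loss.backward()  # stands in for self.model.manual_backward(loss)

for optimizer in optimizers:
    # stands in for self.model.clip_gradients(optimizer, ..., "norm")
    params = [p for group in optimizer.param_groups for p in group["params"]]
    torch.nn.utils.clip_grad_norm_(params, max_norm=gradient_clip_val)
    optimizer.step()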