From 3f4a22f773209410fcd3f5c8bb13e78f49a1aedf Mon Sep 17 00:00:00 2001 From: Lebourdais Date: Wed, 23 Oct 2024 16:50:30 +0200 Subject: [PATCH] fix: fix alignment between diarization and sources --- CHANGELOG.md | 1 + pyannote/audio/pipelines/speech_separation.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 100ef7278..95e099a4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ ### Fixes - fix: fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/)) +- fix: fix alignment between separated sources and diarization when the diarization reference is available ([@Lebourdais](https://github.com/Lebourdais/)) ## Version 3.3.2 (2024-09-11) diff --git a/pyannote/audio/pipelines/speech_separation.py b/pyannote/audio/pipelines/speech_separation.py index dacb637b1..a129ea7a4 100644 --- a/pyannote/audio/pipelines/speech_separation.py +++ b/pyannote/audio/pipelines/speech_separation.py @@ -124,7 +124,7 @@ class SpeechSeparation(SpeakerDiarizationMixin, Pipeline): def __init__( self, - segmentation: PipelineModel = None, + segmentation: PipelineModel = "pyannote/separation-ami-1.0", segmentation_step: float = 0.1, embedding: PipelineModel = "speechbrain/spkrec-ecapa-voxceleb@5c0be3875fda05e81f3c004ed8c7c06be308de1e", embedding_exclude_overlap: bool = False, @@ -698,6 +698,13 @@ def apply( # strings and integers when reference is available and some hypothesis # speakers are not present in the reference) + # re-order sources so that they match + # the order given by diarization.labels() + inverse_mapping = {label: index for index, label in mapping.items()} + sources.data = sources.data[ + :, [inverse_mapping[label] for label in diarization.labels()] + ] + if not return_embeddings: return diarization, sources @@ -717,7 +724,6 @@ def apply( # re-order centroids so that they match # the order given by diarization.labels() - inverse_mapping = {label: index for index, 
label in mapping.items()} centroids = centroids[ [inverse_mapping[label] for label in diarization.labels()] ]