Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for k-means clustering #1774

Merged
merged 7 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

### New features

- feat: add support for `k-means` clustering
- feat: add `"hidden"` option to `ProgressHook`
- feat: add `FilterByNumberOfSpeakers` protocol files filter

### Fixes

- fix: fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/))


## Version 3.3.2 (2024-09-11)

### Fixes
Expand Down
82 changes: 77 additions & 5 deletions pyannote/audio/pipelines/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

"""Clustering pipelines"""


import random
from enum import Enum
from typing import Optional, Tuple
Expand All @@ -35,6 +34,7 @@
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans

from pyannote.audio.core.io import AudioFile
from pyannote.audio.pipelines.utils import oracle_segmentation
Expand Down Expand Up @@ -264,8 +264,8 @@ def __call__(

train_clusters = self.cluster(
train_embeddings,
min_clusters,
max_clusters,
min_clusters=min_clusters,
max_clusters=max_clusters,
num_clusters=num_clusters,
)

Expand Down Expand Up @@ -298,6 +298,8 @@ class AgglomerativeClustering(BaseClustering):
Minimum cluster size
"""

expects_num_clusters: bool = False

def __init__(
self,
metric: str = "cosine",
Expand All @@ -321,8 +323,8 @@ def __init__(
def cluster(
self,
embeddings: np.ndarray,
min_clusters: int,
max_clusters: int,
min_clusters: Optional[int] = None,
max_clusters: Optional[int] = None,
num_clusters: Optional[int] = None,
):
"""
Expand Down Expand Up @@ -471,9 +473,78 @@ def cluster(
return clusters


class KMeansClustering(BaseClustering):
"""KMeans clustering

Parameters
----------
metric : {"cosine", "euclidean"}, optional
Distance metric to use. Defaults to "cosine".

Hyper-parameters
----------------
None
"""

expects_num_clusters: bool = True

def __init__(
self,
metric: str = "cosine",
):
if metric not in ["cosine", "euclidean"]:
raise ValueError(
f"Unsupported metric: {metric}. Must be 'cosine' or 'euclidean'."
)

super().__init__(metric=metric)

def cluster(
self,
embeddings: np.ndarray,
min_clusters: Optional[int] = None,
max_clusters: Optional[int] = None,
num_clusters: Optional[int] = None,
):
"""Perform KMeans clustering

Parameters
----------
embeddings : (num_embeddings, dimension) array
Embeddings
num_clusters : int, optional
Expected number of clusters.

Returns
-------
clusters : (num_embeddings, ) array
0-indexed cluster indices.
"""

if num_clusters is None:
raise ValueError("`num_clusters` must be provided.")

num_embeddings, _ = embeddings.shape
if num_embeddings < num_clusters:
# one cluster per embedding as int
return np.arange(num_embeddings, dtype=np.int32)

# unit-normalize embeddings to use 'euclidean' distance
if self.metric == "cosine":
with np.errstate(divide="ignore", invalid="ignore"):
embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True)

# perform Kmeans clustering
return KMeans(
n_clusters=num_clusters, n_init=3, random_state=42, copy_x=False
).fit_predict(embeddings)


class OracleClustering(BaseClustering):
"""Oracle clustering"""

expects_num_clusters: bool = True

def __call__(
self,
embeddings: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -558,4 +629,5 @@ def __call__(

class Clustering(Enum):
AgglomerativeClustering = AgglomerativeClustering
KMeansClustering = KMeansClustering
OracleClustering = OracleClustering
20 changes: 18 additions & 2 deletions pyannote/audio/pipelines/speaker_diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import math
import textwrap
import warnings
from typing import Callable, Optional, Text, Union
from typing import Callable, Mapping, Optional, Text, Union

import numpy as np
import torch
Expand All @@ -45,6 +45,7 @@
SpeakerDiarizationMixin,
get_model,
)
from pyannote.audio.pipelines.utils.diarization import set_num_speakers
from pyannote.audio.utils.signal import binarize


Expand Down Expand Up @@ -177,6 +178,8 @@ def __init__(
)
self.clustering = Klustering.value(metric=metric)

self._expects_num_speakers = self.clustering.expects_num_clusters

@property
def segmentation_batch_size(self) -> int:
return self._segmentation.batch_size
Expand Down Expand Up @@ -469,12 +472,25 @@ def apply(
# setup hook (e.g. for debugging purposes)
hook = self.setup_hook(file, hook=hook)

num_speakers, min_speakers, max_speakers = self.set_num_speakers(
num_speakers, min_speakers, max_speakers = set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)

# when using KMeans clustering (or equivalent), the number of speakers must
# be provided alongside the audio file. also, during pipeline training, we
# infer the number of speakers from the reference annotation to avoid the
# pipeline complaining about missing number of speakers.
if self._expects_num_speakers and num_speakers is None:
if isinstance(file, Mapping) and "annotation" in file:
num_speakers = len(file["annotation"].labels())

else:
raise ValueError(
f"num_speakers must be provided when using {self.klustering} clustering"
)

segmentations = self.get_segmentations(file, hook=hook)
hook("segmentation", segmentations)
# shape: (num_chunks, num_frames, local_num_speakers)
Expand Down
3 changes: 2 additions & 1 deletion pyannote/audio/pipelines/speech_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
SpeakerDiarizationMixin,
get_model,
)
from pyannote.audio.pipelines.utils.diarization import set_num_speakers
from pyannote.audio.utils.signal import binarize


Expand Down Expand Up @@ -489,7 +490,7 @@ def apply(
# setup hook (e.g. for debugging purposes)
hook = self.setup_hook(file, hook=hook)

num_speakers, min_speakers, max_speakers = self.set_num_speakers(
num_speakers, min_speakers, max_speakers = set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
Expand Down
58 changes: 43 additions & 15 deletions pyannote/audio/pipelines/utils/diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,44 @@
from pyannote.audio.utils.signal import Binarize


# TODO: move to dedicated module
def set_num_speakers(
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
):
"""Validate number of speakers

Parameters
----------
num_speakers : int, optional
Number of speakers.
min_speakers : int, optional
Minimum number of speakers.
max_speakers : int, optional
Maximum number of speakers.

Returns
-------
num_speakers : int or None
min_speakers : int
max_speakers : int or np.inf
"""

# override {min|max}_num_speakers by num_speakers when available
min_speakers = num_speakers or min_speakers or 1
max_speakers = num_speakers or max_speakers or np.inf

if min_speakers > max_speakers:
raise ValueError(
f"min_speakers must be smaller than (or equal to) max_speakers "
f"(here: min_speakers={min_speakers:g} and max_speakers={max_speakers:g})."
)
if min_speakers == max_speakers:
num_speakers = min_speakers

return num_speakers, min_speakers, max_speakers


class SpeakerDiarizationMixin:
"""Defines a bunch of methods common to speaker diarization pipelines"""

Expand All @@ -58,20 +95,11 @@ def set_num_speakers(
min_speakers : int
max_speakers : int or np.inf
"""

# override {min|max}_num_speakers by num_speakers when available
min_speakers = num_speakers or min_speakers or 1
max_speakers = num_speakers or max_speakers or np.inf

if min_speakers > max_speakers:
raise ValueError(
f"min_speakers must be smaller than (or equal to) max_speakers "
f"(here: min_speakers={min_speakers:g} and max_speakers={max_speakers:g})."
)
if min_speakers == max_speakers:
num_speakers = min_speakers

return num_speakers, min_speakers, max_speakers
return set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)

@staticmethod
def optimal_mapping(
Expand Down
48 changes: 44 additions & 4 deletions pyannote/audio/utils/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
# SOFTWARE.

from functools import partial
from typing import Optional

import torchaudio
from pyannote.core import Annotation
from pyannote.database import FileFinder, Protocol, get_annotated
from pyannote.database.protocol import SpeakerVerificationProtocol

Expand Down Expand Up @@ -62,14 +64,12 @@ def check_protocol(protocol: Protocol) -> Protocol:

# does protocol provide audio keys?
if "audio" not in file:

if "waveform" in file:
if "sample_rate" not in file:
msg = f'Protocol {protocol.name} provides audio with "waveform" key but is missing a "sample_rate" key.'
raise ValueError(msg)

else:

file_finder = FileFinder()
try:
_ = file_finder(file)
Expand All @@ -90,7 +90,6 @@ def check_protocol(protocol: Protocol) -> Protocol:
print(msg)

if "waveform" not in file and "torchaudio.info" not in file:

# use soundfile when available (it usually is faster than ffmpeg for getting info)
backends = (
torchaudio.list_audio_backends()
Expand All @@ -107,7 +106,6 @@ def check_protocol(protocol: Protocol) -> Protocol:
print(msg)

if "annotated" not in file:

if "duration" not in file:
protocol.preprocessors["duration"] = get_duration

Expand Down Expand Up @@ -143,3 +141,45 @@ def check_protocol(protocol: Protocol) -> Protocol:
}

return protocol, checks


class FilterByNumberOfSpeakers:
"""Filter files based on the number of speakers

Note
----
Always returns True if `current_file` does not have an "annotation" key.

"""

def __init__(
self,
num_speakers: Optional[int] = None,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
):
from pyannote.audio.pipelines.utils.diarization import set_num_speakers

self.num_speakers, self.min_speakers, self.max_speakers = set_num_speakers(
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
)

def __call__(self, current_file: dict) -> bool:
if "annotation" not in current_file:
return True

annotation: Annotation = current_file["annotation"]
num_speakers: int = len(annotation.labels())

if self.num_speakers is not None and self.num_speakers != num_speakers:
return False

if self.min_speakers is not None and self.min_speakers > num_speakers:
return False

if self.max_speakers is not None and self.max_speakers < num_speakers:
return False

return True
Loading