beginning work on diarizer
Jemoka committed Mar 31, 2024
1 parent a43724d commit 52ab387
Showing 17 changed files with 215 additions and 8 deletions.
4 changes: 2 additions & 2 deletions batchalign/models/__init__.py
@@ -1,5 +1,5 @@
-from .bert_utterance import BertUtteranceModel
+from .utterance import BertUtteranceModel
 from .whisper import WhisperASRModel, WhisperFAModel
-from .bert_utterance import BertUtteranceModel
 from .utils import ASRAudioFile
 from .resolve import resolve

Empty file.
1 change: 1 addition & 0 deletions batchalign/models/speaker/batchalign
68 changes: 68 additions & 0 deletions batchalign/models/speaker/config.yaml
@@ -0,0 +1,68 @@
# This YAML file is created for all types of offline speaker diarization inference tasks in the `<NeMo git root>/example/speaker_tasks/diarization` folder.
# The inference parameters for the VAD, speaker embedding extractor, clustering module, MSDD module, and ASR decoder are all included in this YAML file.
# All the keys under the `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be used selectively, and any module that is not used can be ignored.
# The configuration in this YAML file is suitable for telephone recordings involving 2~8 speakers per session and may not perform best under other acoustic conditions or dialogue types.
# An example line in an input manifest file (`.json` format):
# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"}
name: &name "ClusterDiarizer"

num_workers: 0
sample_rate: 16000
batch_size: 8
device: null # can specify a device, e.g. cuda:1 (defaults to cuda if available, else cpu)
verbose: True # enable additional logging

diarizer:
  manifest_filepath: ???
  out_dir: ???
  oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps
  collar: 0.25 # Collar value for scoring
  ignore_overlap: True # Consider or ignore overlap segments while scoring

  vad:
    model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name
    external_vad_manifest: null # Use an external VAD's speech activity labels for speaker embedding extraction. Only one of model_path or external_vad_manifest should be set

    parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set)
      window_length_in_sec: 0.15 # Window length in sec for VAD context input
      shift_length_in_sec: 0.01 # Shift length in sec for generating frame-level VAD predictions
      smoothing: "median" # False or type of smoothing method (e.g. median)
      overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter
      onset: 0.1 # Onset threshold for detecting the beginning of speech
      offset: 0.1 # Offset threshold for detecting the end of speech
      pad_onset: 0.1 # Duration added before each speech segment
      pad_offset: 0 # Duration added after each speech segment
      min_duration_on: 0 # Threshold for removing short non-speech gaps
      min_duration_off: 0.2 # Threshold for removing short speech segments
      filter_speech_first: True

  speaker_embeddings:
    model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet)
    parameters:
      window_length_in_sec: [10.5,7.25,5.0,3.75,2.5] # Window length(s) in sec: either a number or a list, e.g. 1.5 or [1.5,1.0,0.5]
      shift_length_in_sec: [4.75,3.625,2.5,1.375,0.25] # Shift length(s) in sec: either a number or a list, e.g. 0.75 or [0.75,0.5,0.25]
      multiscale_weights: [1,1,1,1,1] # Weight for each scale: null (for single scale) or a list matching the window/shift scale count, e.g. [0.33,0.33,0.33]
      save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if the clustering result is used by other models, such as `msdd_model`.

  clustering:
    parameters:
      oracle_num_speakers: True # If True, use the number of speakers provided in the manifest file.
      max_num_speakers: 8 # Max number of speakers per recording. Ignored if an oracle number of speakers is passed.
      enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated.
      max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
      sparse_search_volume: 30 # Higher values examine more p-values at the cost of run time.
      maj_vote_spk_count: False # If True, take a majority vote over multiple p-values to estimate the number of speakers.
      chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering.
      embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio)

  msdd_model:
    model_path: diar_msdd_telephonic # .nemo local model path or pretrained model name for the multiscale diarization decoder (MSDD)
    parameters:
      use_speaker_model_from_ckpt: True # If True, use the speaker embedding model in the checkpoint. If False, use the speaker embedding model provided in this config.
      infer_batch_size: 25 # Batch size for MSDD inference.
      sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. Smaller values are more generous in detecting overlap.
      seq_eval_mode: False # If True, use the oracle number of speakers and evaluate the F1 score for the given speaker sequences. Default is False.
      split_infer: True # If True, break the input audio into short sequences and calculate cluster-average embeddings for inference.
      diar_window_length: 50 # Length of the split short sequences when split_infer is True.
      overlap_infer_spk_limit: 5 # If the estimated number of speakers is larger than this number, overlap speech is not estimated.
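Note: the VAD `onset`/`offset` pair above implements a hysteresis: frame probabilities must rise above `onset` to open a speech segment and fall below `offset` to close it. A minimal sketch of that thresholding idea (an illustration only, not NeMo's actual implementation):

def binarize(probs, onset=0.1, offset=0.1):
    """Turn per-frame speech probabilities into (start, end) frame segments."""
    segments, start, active = [], None, False
    for i, p in enumerate(probs):
        if not active and p >= onset:
            start, active = i, True
        elif active and p < offset:
            segments.append((start, i))
            active = False
    if active:
        segments.append((start, len(probs)))
    return segments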

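Note: `???` is OmegaConf's marker for a mandatory value, so `manifest_filepath` and `out_dir` must be set programmatically before the diarizer is built, as `infer.py` below does. A minimal sketch, assuming the paths:

from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
# reading a `???` field before it is set raises MissingMandatoryValue
cfg.diarizer.manifest_filepath = "input_manifest.json"
cfg.diarizer.out_dir = "out"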
84 changes: 84 additions & 0 deletions batchalign/models/speaker/infer.py
@@ -0,0 +1,84 @@
import os
import json
import copy
import glob
import tempfile
from pydub import AudioSegment
from omegaconf import OmegaConf
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
from nemo.collections.asr.modules.msdd_diarizer import MSDD_module
from batchalign.models.speaker.utils import conv_scale_weights

import torch

# compute device
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# override msdd implementation
MSDD_module.conv_scale_weights = conv_scale_weights

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = str(1)

INPUT = "/Users/houjun/Documents/Projects/batchalign2/extern/Untitled.wav"
NUM_SPEAKERS = 2

def resolve_config():
    return os.path.join(os.path.abspath(os.path.join(__file__, os.pardir)), "config.yaml")

base = OmegaConf.load(resolve_config())

# the input file
in_file = INPUT
# make a copy of the base config
config = copy.deepcopy(base)
# create a working directory and configure settings
with tempfile.TemporaryDirectory() as workdir:
    # downmix the input audio to a mono file
    sound = AudioSegment.from_file(in_file).set_channels(1)
    sound.export(os.path.join(workdir, "mono_file.wav"), format="wav")

    # create a one-entry manifest with the info we need
    meta = {
        "audio_filepath": os.path.join(workdir, "mono_file.wav"),
        "offset": 0,
        "duration": None,
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,
        "uem_filepath": None,
        "num_speakers": NUM_SPEAKERS
    }
    manifest_path = os.path.join(workdir, "input_manifest.json")
    with open(manifest_path, "w") as fp:
        json.dump(meta, fp)
        fp.write("\n")

    config.diarizer.manifest_filepath = manifest_path
    config.diarizer.out_dir = workdir
    config.device = DEVICE

    # initialize a diarizer and run inference
    msdd_model = NeuralDiarizer(cfg=config)
    msdd_model.diarize()

    # read the output RTTM; parsing follows
    # https://github.com/MahmoudAshraf97/whisper-diarization/blob/main/diarize.py
    speaker_ts = []
    with open(os.path.join(workdir, "pred_rttms", "mono_file.rttm"), "r") as f:
        lines = f.readlines()
        for line in lines:
            # NeMo pads RTTM fields with extra spaces, so split(" ") yields empty
            # strings; indices 5, 8, and 11 land on the start time, the duration,
            # and the speaker label ("speaker_N") respectively
            line_list = line.split(" ")
            s = int(float(line_list[5]) * 1000)
            e = s + int(float(line_list[8]) * 1000)
            speaker_ts.append([s, e, int(line_list[11].split("_")[-1])])
    print(speaker_ts)



# breakpoint()



# conf =
# import torch

# torch.rand(1,10).to("mps:0")
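Note: each `speaker_ts` entry is a `[start_ms, end_ms, speaker_index]` triple. A hypothetical helper (not part of this commit) showing how a timestamp might be mapped back to a speaker:

def speaker_at(ms, speaker_ts):
    """Return the speaker index active at `ms` milliseconds, or None."""
    for start, end, spk in speaker_ts:
        if start <= ms <= end:
            return spk
    return None

# e.g., with two speakers, speaker_at(1500, speaker_ts) might return 0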
44 changes: 44 additions & 0 deletions batchalign/models/speaker/utils.py
@@ -0,0 +1,44 @@
import torch
import torch.nn.functional as F

def conv_scale_weights(self, ms_avg_embs_perm, ms_emb_seq_single):
    """
    Use multiple Convnet layers to estimate the scale weights based on the cluster-average embedding and
    input embedding sequence.

    Args:
        ms_avg_embs_perm (Tensor):
            Tensor containing cluster-average speaker embeddings for each scale.
            Shape: (batch_size, length, scale_n, emb_dim)
        ms_emb_seq_single (Tensor):
            Tensor containing multi-scale speaker embedding sequences, taken from the
            given audio stream input.
            Shape: (batch_size, length, scale_n, emb_dim)

    Returns:
        scale_weights (Tensor):
            Weight vectors that determine the weight of each scale.
            Shape: (batch_size, length, scale_n, num_spks)
    """
    ms_cnn_input_seq = torch.cat([ms_avg_embs_perm, ms_emb_seq_single], dim=2)
    ms_cnn_input_seq = ms_cnn_input_seq.unsqueeze(2).flatten(0, 1)

    conv_out = self.conv_forward(
        ms_cnn_input_seq, conv_module=self.conv[0], bn_module=self.conv_bn[0], first_layer=True
    )
    for conv_idx in range(1, self.conv_repeat + 1):
        conv_out = self.conv_forward(
            conv_input=conv_out,
            conv_module=self.conv[conv_idx],
            bn_module=self.conv_bn[conv_idx],
            first_layer=False,
        )

    # reshape (rather than the upstream .view) so a non-contiguous conv output is handled
    lin_input_seq = conv_out.reshape(self.batch_size, self.length, self.cnn_output_ch * self.emb_dim)
    hidden_seq = self.conv_to_linear(lin_input_seq)
    hidden_seq = self.dropout(F.leaky_relu(hidden_seq))
    scale_weights = self.softmax(self.linear_to_weights(hidden_seq))
    scale_weights = scale_weights.unsqueeze(3).expand(-1, -1, -1, self.num_spks)
    return scale_weights
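Note: the apparent functional change from NeMo's upstream `conv_scale_weights` is `reshape` in place of `view`: `view` requires contiguous memory, while `reshape` copies when necessary, which matters on some backends (e.g. MPS). A quick illustration:

import torch

x = torch.randn(2, 3, 4).transpose(0, 1)  # transpose makes x non-contiguous
# x.view(3, 8) would raise a RuntimeError here
y = x.reshape(3, 8)  # reshape silently copies instead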

2 changes: 1 addition & 1 deletion batchalign/models/training/run.py
@@ -1,4 +1,4 @@
-from batchalign.models.bert_utterance.execute import utterance
+from batchalign.models.utterance.execute import utterance
import logging as L

import rich_click as click
File renamed without changes.
File renamed without changes.
@@ -1,6 +1,6 @@
from batchalign.models.training.utils import *
-from batchalign.models.bert_utterance.prep import prep as P
-from batchalign.models.bert_utterance.train import train as T
+from batchalign.models.utterance.prep import prep as P
+from batchalign.models.utterance.train import train as T

import rich_click as click

File renamed without changes.
File renamed without changes.
@@ -16,7 +16,7 @@
from transformers import DataCollatorForTokenClassification

# import our dataset
-from batchalign.models.bert_utterance.dataset import TOKENS, UtteranceBoundaryDataset
+from batchalign.models.utterance.dataset import TOKENS, UtteranceBoundaryDataset

# tqdm
from tqdm import tqdm
4 changes: 2 additions & 2 deletions batchalign/models/whisper/__init__.py
@@ -1,2 +1,2 @@
-from .whisper_asr import WhisperASRModel
-from .whisper_fa import WhisperFAModel
+from .infer_asr import WhisperASRModel
+from .infer_fa import WhisperFAModel
File renamed without changes.
File renamed without changes.
10 changes: 10 additions & 0 deletions setup.py
@@ -59,6 +59,16 @@ def read(fname):
    'train': [
        'wandb~=0.16',
        'accelerate~=0.27',
    ],
    'diarize': [
        "nemo-toolkit[asr]==1.21.0",
        "omegaconf~=2.3.0",
        "pydub~=0.25.1"
        # "youtokentome~=1.0.6",
        # "hydra-core~=1.3.2",
        # "inflect>=7.0.0",
        # "webdataset~=0.2.86",
        # "editdistance~=0.2.86",
    ]
},
include_package_data=True,
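Note: since these are setuptools extras, the diarization stack stays optional; it would presumably be installed with `pip install 'batchalign[diarize]'` (assuming the distribution is named `batchalign`).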
