diff --git a/batchalign/models/__init__.py b/batchalign/models/__init__.py
index b687527..97258d9 100644
--- a/batchalign/models/__init__.py
+++ b/batchalign/models/__init__.py
@@ -1,5 +1,5 @@
-from .bert_utterance import BertUtteranceModel
+from .utterance import BertUtteranceModel
 from .whisper import WhisperASRModel, WhisperFAModel
-from .bert_utterance import BertUtteranceModel
 from .utils import ASRAudioFile
 from .resolve import resolve
+
diff --git a/batchalign/models/speaker/__init__.py b/batchalign/models/speaker/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/batchalign/models/speaker/batchalign b/batchalign/models/speaker/batchalign
new file mode 120000
index 0000000..19aca6e
--- /dev/null
+++ b/batchalign/models/speaker/batchalign
@@ -0,0 +1 @@
+/Users/houjun/Documents/Projects/batchalign2/batchalign
\ No newline at end of file
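The hunk above also drops the duplicated BertUtteranceModel import, and batchalign/models/__init__.py keeps re-exporting the same public names after the rename, so external imports are unaffected. A quick illustration (not part of the diff):

    from batchalign.models import BertUtteranceModel, WhisperASRModel, WhisperFAModel, ASRAudioFile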
diff --git a/batchalign/models/speaker/config.yaml b/batchalign/models/speaker/config.yaml
new file mode 100644
index 0000000..6dbc64b
--- /dev/null
+++ b/batchalign/models/speaker/config.yaml
@@ -0,0 +1,68 @@
+# This YAML file is created for all types of offline speaker diarization inference tasks in `/example/speaker_tasks/diarization` folder.
+# The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file.
+# All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used.
+# The configurations in this YAML file are suitable for telephone recordings involving 2~8 speakers in a session and may not show the best performance on other types of acoustic conditions or dialogues.
+# An example line in an input manifest file (`.json` format):
+# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"}
+name: &name "ClusterDiarizer"
+
+num_workers: 0
+sample_rate: 16000
+batch_size: 8
+device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu)
+verbose: True # enable additional logging
+
+diarizer:
+  manifest_filepath: ???
+  out_dir: ???
+  oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps
+  collar: 0.25 # Collar value for scoring
+  ignore_overlap: True # Consider or ignore overlap segments while scoring
+
+  vad:
+    model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name
+    external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set
+
+    parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set)
+      window_length_in_sec: 0.15 # Window length in sec for VAD context input
+      shift_length_in_sec: 0.01 # Shift length in sec for generating frame-level VAD prediction
+      smoothing: "median" # False or type of smoothing method (eg: median)
+      overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter
+      onset: 0.1 # Onset threshold for detecting the beginning and end of a speech
+      offset: 0.1 # Offset threshold for detecting the end of a speech
+      pad_onset: 0.1 # Adding durations before each speech segment
+      pad_offset: 0 # Adding durations after each speech segment
+      min_duration_on: 0 # Threshold for small non_speech deletion
+      min_duration_off: 0.2 # Threshold for short speech segment deletion
+      filter_speech_first: True
+
+  speaker_embeddings:
+    model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet)
+    parameters:
+      window_length_in_sec: [10.5,7.25,5.0,3.75,2.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5]
+      shift_length_in_sec: [4.75,3.625,2.5,1.375,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25]
+      multiscale_weights: [1,1,1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33]
+      save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`.
+
+  clustering:
+    parameters:
+      oracle_num_speakers: True # If True, use num of speakers value provided in manifest file.
+      max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored.
+      enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated.
+      max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold.
+      sparse_search_volume: 30 # The higher the number, the more values will be examined with more time.
+      maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers.
+      chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering.
+      embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio)
+
+  msdd_model:
+    model_path: diar_msdd_telephonic # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD)
+    parameters:
+      use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used.
+      infer_batch_size: 25 # Batch size for MSDD inference.
+      sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps.
+      seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False.
+      split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference.
+      diar_window_length: 50 # The length of split short sequence when split_infer is True.
+      overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated.
+
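This file follows the schema NeMo's offline diarization recipes expect; the new infer.py below loads it with OmegaConf and only fills in the per-run fields (manifest_filepath, out_dir, device). A minimal sketch of that pattern, with placeholder paths rather than values from the diff:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("batchalign/models/speaker/config.yaml")
    cfg.diarizer.manifest_filepath = "/tmp/diarize/input_manifest.json"  # placeholder
    cfg.diarizer.out_dir = "/tmp/diarize"                                # placeholder
    cfg.device = "cpu"                                                   # or "cuda" / "mps"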
diff --git a/batchalign/models/speaker/infer.py b/batchalign/models/speaker/infer.py
new file mode 100644
index 0000000..aafe1b3
--- /dev/null
+++ b/batchalign/models/speaker/infer.py
@@ -0,0 +1,84 @@
+import os
+import json
+import copy
+import glob
+import tempfile
+from pydub import AudioSegment
+from omegaconf import OmegaConf
+from nemo.collections.asr.models.msdd_models import NeuralDiarizer
+from nemo.collections.asr.modules.msdd_diarizer import MSDD_module
+from batchalign.models.speaker.utils import conv_scale_weights
+
+import torch
+
+# compute device
+DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+# override msdd implementation
+MSDD_module.conv_scale_weights = conv_scale_weights
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = str(1)
+
+INPUT = "/Users/houjun/Documents/Projects/batchalign2/extern/Untitled.wav"
+NUM_SPEAKERS = 2
+
+def resolve_config():
+    return os.path.join(os.path.abspath(os.path.join(__file__, os.pardir)), "config.yaml")
+
+base = OmegaConf.load(resolve_config())
+
+# get the input file
+in_file = INPUT
+# make a copy of the input config
+config = copy.deepcopy(base)
+# create a working directory and configure settings
+with tempfile.TemporaryDirectory() as workdir:
+    # create the mono file
+    sound = AudioSegment.from_file(in_file).set_channels(1)
+    sound.export(os.path.join(workdir, "mono_file.wav"), format="wav")
+
+    # create configuration with the info we need
+    meta = {
+        "audio_filepath": os.path.join(workdir, "mono_file.wav"),
+        "offset": 0,
+        "duration": None,
+        "label": "infer",
+        "text": "-",
+        "rttm_filepath": None,
+        "uem_filepath": None,
+        "num_speakers": NUM_SPEAKERS
+    }
+    manifest_path = os.path.join(workdir, "input_manifest.json")
+    with open(manifest_path, "w") as fp:
+        json.dump(meta, fp)
+        fp.write("\n")
+    config.diarizer.manifest_filepath = manifest_path
+    config.diarizer.out_dir = workdir
+    config.device = DEVICE
+
+    # initialize a diarizer and brrr
+    msdd_model = NeuralDiarizer(cfg=config)
+    msdd_model.diarize()
+
+    # read output and return
+    # https://github.com/MahmoudAshraf97/whisper-diarization/blob/main/diarize.py
+    speaker_ts = []
+    with open(os.path.join(workdir, "pred_rttms", "mono_file.rttm"), "r") as f:
+        lines = f.readlines()
+        for line in lines:
+            line_list = line.split(" ")
+            s = int(float(line_list[5]) * 1000)
+            e = s + int(float(line_list[8]) * 1000)
+            speaker_ts.append([s, e, int(line_list[11].split("_")[-1])])
+    print(speaker_ts)
+
+
+
+# breakpoint()
+
+
+
+# conf =
+# import torch
+
+# torch.rand(1,10).to("mps:0")
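The RTTM parse at the end mirrors the whisper-diarization script it links to: splitting on a single space and indexing fields 5, 8, and 11 appears to rely on the exact column padding NeMo writes into pred_rttms. A whitespace-agnostic sketch of the same extraction, using the standard RTTM column order (SPEAKER <file> <chan> <onset> <duration> <ortho> <stype> <name> <conf> <slat>), would be:

    def parse_rttm(path):
        # returns [start_ms, end_ms, speaker_index] for each RTTM SPEAKER line
        segments = []
        with open(path) as f:
            for line in f:
                fields = line.split()
                if not fields or fields[0] != "SPEAKER":
                    continue
                start = int(float(fields[3]) * 1000)          # turn onset in seconds -> ms
                end = start + int(float(fields[4]) * 1000)    # onset + duration
                segments.append([start, end, int(fields[7].split("_")[-1])])
        return segments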
diff --git a/batchalign/models/speaker/utils.py b/batchalign/models/speaker/utils.py
new file mode 100644
index 0000000..25ca43f
--- /dev/null
+++ b/batchalign/models/speaker/utils.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn.functional as F
+
+def conv_scale_weights(self, ms_avg_embs_perm, ms_emb_seq_single):
+    """
+    Use multiple Convnet layers to estimate the scale weights based on the cluster-average embedding and
+    input embedding sequence.
+
+    Args:
+        ms_avg_embs_perm (Tensor):
+            Tensor containing cluster-average speaker embeddings for each scale.
+            Shape: (batch_size, length, scale_n, emb_dim)
+        ms_emb_seq_single (Tensor):
+            Tensor containing multi-scale speaker embedding sequences. ms_emb_seq_single is input from the
+            given audio stream input.
+            Shape: (batch_size, length, num_spks, emb_dim)
+
+    Returns:
+        scale_weights (Tensor):
+            Weight vectors that determine the weight of each scale.
+            Shape: (batch_size, length, num_spks, emb_dim)
+    """
+    ms_cnn_input_seq = torch.cat([ms_avg_embs_perm, ms_emb_seq_single], dim=2)
+    ms_cnn_input_seq = ms_cnn_input_seq.unsqueeze(2).flatten(0, 1)
+
+    conv_out = self.conv_forward(
+        ms_cnn_input_seq, conv_module=self.conv[0], bn_module=self.conv_bn[0], first_layer=True
+    )
+    for conv_idx in range(1, self.conv_repeat + 1):
+        conv_out = self.conv_forward(
+            conv_input=conv_out,
+            conv_module=self.conv[conv_idx],
+            bn_module=self.conv_bn[conv_idx],
+            first_layer=False,
+        )
+
+    # reshape / view
+    lin_input_seq = conv_out.reshape(self.batch_size, self.length, self.cnn_output_ch * self.emb_dim)
+    hidden_seq = self.conv_to_linear(lin_input_seq)
+    hidden_seq = self.dropout(F.leaky_relu(hidden_seq))
+    scale_weights = self.softmax(self.linear_to_weights(hidden_seq))
+    scale_weights = scale_weights.unsqueeze(3).expand(-1, -1, -1, self.num_spks)
+    return scale_weights
+
diff --git a/batchalign/models/training/run.py b/batchalign/models/training/run.py
index 5dd12d1..472e501 100644
--- a/batchalign/models/training/run.py
+++ b/batchalign/models/training/run.py
@@ -1,4 +1,4 @@
-from batchalign.models.bert_utterance.execute import utterance
+from batchalign.models.utterance.execute import utterance
 
 import logging as L
 import rich_click as click
diff --git a/batchalign/models/bert_utterance/__init__.py b/batchalign/models/utterance/__init__.py
similarity index 100%
rename from batchalign/models/bert_utterance/__init__.py
rename to batchalign/models/utterance/__init__.py
diff --git a/batchalign/models/bert_utterance/dataset.py b/batchalign/models/utterance/dataset.py
similarity index 100%
rename from batchalign/models/bert_utterance/dataset.py
rename to batchalign/models/utterance/dataset.py
diff --git a/batchalign/models/bert_utterance/execute.py b/batchalign/models/utterance/execute.py
similarity index 91%
rename from batchalign/models/bert_utterance/execute.py
rename to batchalign/models/utterance/execute.py
index c1733b8..ca58330 100644
--- a/batchalign/models/bert_utterance/execute.py
+++ b/batchalign/models/utterance/execute.py
@@ -1,6 +1,6 @@
 from batchalign.models.training.utils import *
-from batchalign.models.bert_utterance.prep import prep as P
-from batchalign.models.bert_utterance.train import train as T
+from batchalign.models.utterance.prep import prep as P
+from batchalign.models.utterance.train import train as T
 
 import rich_click as click
diff --git a/batchalign/models/bert_utterance/infer.py b/batchalign/models/utterance/infer.py
similarity index 100%
rename from batchalign/models/bert_utterance/infer.py
rename to batchalign/models/utterance/infer.py
diff --git a/batchalign/models/bert_utterance/prep.py b/batchalign/models/utterance/prep.py
similarity index 100%
rename from batchalign/models/bert_utterance/prep.py
rename to batchalign/models/utterance/prep.py
diff --git a/batchalign/models/bert_utterance/train.py b/batchalign/models/utterance/train.py
similarity index 98%
rename from batchalign/models/bert_utterance/train.py
rename to batchalign/models/utterance/train.py
index 9c1ac03..97e057f 100644
--- a/batchalign/models/bert_utterance/train.py
+++ b/batchalign/models/utterance/train.py
@@ -16,7 +16,7 @@ from transformers import DataCollatorForTokenClassification
 
 # import our dataset
-from batchalign.models.bert_utterance.dataset import TOKENS, UtteranceBoundaryDataset
+from batchalign.models.utterance.dataset import TOKENS, UtteranceBoundaryDataset
 
 # tqdm
 from tqdm import tqdm
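Downstream code that still imports the old bert_utterance package path needs the same one-line change as the in-tree call sites above, for example (illustrative only, using names that appear in this diff):

    # before: from batchalign.models.bert_utterance.execute import utterance
    from batchalign.models.utterance.execute import utterance
    from batchalign.models.utterance.dataset import TOKENS, UtteranceBoundaryDataset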
diff --git a/batchalign/models/whisper/__init__.py b/batchalign/models/whisper/__init__.py
index e822fa7..872751e 100644
--- a/batchalign/models/whisper/__init__.py
+++ b/batchalign/models/whisper/__init__.py
@@ -1,2 +1,2 @@
-from .whisper_asr import WhisperASRModel
-from .whisper_fa import WhisperFAModel
+from .infer_asr import WhisperASRModel
+from .infer_fa import WhisperFAModel
diff --git a/batchalign/models/whisper/whisper_asr.py b/batchalign/models/whisper/infer_asr.py
similarity index 100%
rename from batchalign/models/whisper/whisper_asr.py
rename to batchalign/models/whisper/infer_asr.py
diff --git a/batchalign/models/whisper/whisper_fa.py b/batchalign/models/whisper/infer_fa.py
similarity index 100%
rename from batchalign/models/whisper/whisper_fa.py
rename to batchalign/models/whisper/infer_fa.py
diff --git a/setup.py b/setup.py
index fbe4267..f48ab39 100644
--- a/setup.py
+++ b/setup.py
@@ -59,6 +59,16 @@ def read(fname):
         'train': [
             'wandb~=0.16',
             'accelerate~=0.27',
+        ],
+        'diarize': [
+            "nemo-toolkit[asr]==1.21.0",
+            "omegaconf~=2.3.0",
+            "pydub~=0.25.1"
+            # "youtokentome~=1.0.6",
+            # "hydra-core~=1.3.2",
+            # "inflect>=7.0.0",
+            # "webdataset~=0.2.86",
+            # "editdistance~=0.2.86",
         ]
     },
     include_package_data=True,
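With the new extras group, the diarization dependencies would be pulled in via something like `pip install 'batchalign[diarize]'` (or combined with other extras, e.g. `pip install 'batchalign[train,diarize]'`), assuming the published package name stays `batchalign`; the commented-out pins remain in the diff for reference and are not installed.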