Showing 17 changed files with 215 additions and 8 deletions.
```diff
@@ -1,5 +1,5 @@
-from .bert_utterance import BertUtteranceModel
+from .utterance import BertUtteranceModel
 from .whisper import WhisperASRModel, WhisperFAModel
 from .bert_utterance import BertUtteranceModel
 from .utils import ASRAudioFile
 from .resolve import resolve
```
Empty file.
```diff
@@ -0,0 +1 @@
+/Users/houjun/Documents/Projects/batchalign2/batchalign
```
@@ -0,0 +1,68 @@
```yaml
# This YAML file is created for all types of offline speaker diarization inference tasks in the `<NeMo git root>/example/speaker_tasks/diarization` folder.
# The inference parameters for the VAD, speaker embedding extractor, clustering module, MSDD module, and ASR decoder are all included in this YAML file.
# All keys under the `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be used selectively for their own purposes and ignored if the module is not used.
# The configurations in this YAML file are suitable for telephone recordings involving 2~8 speakers per session and may not show the best performance under other acoustic conditions or dialogue types.
# An example line in an input manifest file (`.json` format):
# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"}
name: &name "ClusterDiarizer"

num_workers: 0
sample_rate: 16000
batch_size: 8
device: null # can specify a specific device, e.g. cuda:1 (defaults to cuda if available, else cpu)
verbose: True # enable additional logging

diarizer:
  manifest_filepath: ???
  out_dir: ???
  oracle_vad: False # If True, uses RTTM files provided in the manifest to get speech activity (VAD) timestamps
  collar: 0.25 # Collar value for scoring
  ignore_overlap: True # Consider or ignore overlapping segments while scoring

  vad:
    model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name
    external_vad_manifest: null # Use an external VAD's speech activity labels for speaker embedding extraction. Only one of model_path or external_vad_manifest should be set

    parameters: # Parameters tuned for CH109 (using the 11 multi-speaker sessions as the dev set)
      window_length_in_sec: 0.15 # Window length in seconds for VAD context input
      shift_length_in_sec: 0.01 # Shift length in seconds for generating frame-level VAD predictions
      smoothing: "median" # False or type of smoothing method (e.g. median)
      overlap: 0.5 # Overlap ratio for the overlapped mean/median smoothing filter
      onset: 0.1 # Onset threshold for detecting the beginning of speech
      offset: 0.1 # Offset threshold for detecting the end of speech
      pad_onset: 0.1 # Duration added before each speech segment
      pad_offset: 0 # Duration added after each speech segment
      min_duration_on: 0 # Threshold for deleting short non-speech segments
      min_duration_off: 0.2 # Threshold for deleting short speech segments
      filter_speech_first: True

  speaker_embeddings:
    model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn, or speakerverification_speakernet)
    parameters:
      window_length_in_sec: [10.5,7.25,5.0,3.75,2.5] # Window length(s) in seconds; either a number or a list, e.g. 1.5 or [1.5,1.0,0.5]
      shift_length_in_sec: [4.75,3.625,2.5,1.375,0.25] # Shift length(s) in seconds; either a number or a list, e.g. 0.75 or [0.75,0.5,0.25]
      multiscale_weights: [1,1,1,1,1] # Weight for each scale; null (for single scale) or a list matching the window/shift scale count, e.g. [0.33,0.33,0.33]
      save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if the clustering result is used by other models, such as `msdd_model`.

  clustering:
    parameters:
      oracle_num_speakers: True # If True, use the number of speakers provided in the manifest file.
      max_num_speakers: 8 # Max number of speakers per recording. Ignored if an oracle number of speakers is passed.
      enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated.
      max_rp_threshold: 0.25 # Determines the range of the p-value search: 0 < p <= max_rp_threshold.
      sparse_search_volume: 30 # The higher the number, the more values are examined, at the cost of more time.
      maj_vote_spk_count: False # If True, take a majority vote over multiple p-values to estimate the number of speakers.
      chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering.
      embeddings_per_chunk: 10000 # Number of embeddings per chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, roughly 40 minutes of audio)

  msdd_model:
    model_path: diar_msdd_telephonic # .nemo local model path or pretrained model name for the multiscale diarization decoder (MSDD)
    parameters:
      use_speaker_model_from_ckpt: True # If True, use the speaker embedding model from the checkpoint. If False, use the speaker embedding model given in this config.
      infer_batch_size: 25 # Batch size for MSDD inference.
      sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. Smaller values detect overlaps more generously.
      seq_eval_mode: False # If True, use the oracle number of speakers and evaluate F1 score on the given speaker sequences. Default is False.
      split_infer: True # If True, split the input audio into short sequences and compute cluster-average embeddings for inference.
      diar_window_length: 50 # Length of the split short sequences when split_infer is True.
      overlap_infer_spk_limit: 5 # If the estimated number of speakers is larger than this number, overlapping speech is not estimated.
```
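A note on the `???` values: OmegaConf treats `???` as a mandatory-but-missing value and raises `MissingMandatoryValue` if one is read before being assigned. A minimal sketch of consuming this config, assuming it has been saved as `config.yaml` (the script below resolves the path relative to the module instead):

```python
from omegaconf import OmegaConf

# hypothetical path, for illustration only
cfg = OmegaConf.load("config.yaml")

# `???` marks mandatory fields: reading them before assignment raises
# omegaconf.errors.MissingMandatoryValue, so fill them in first.
cfg.diarizer.manifest_filepath = "input_manifest.json"
cfg.diarizer.out_dir = "diar_out"
```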
@@ -0,0 +1,84 @@
```python
import os
import json
import copy
import glob
import tempfile
from pydub import AudioSegment
from omegaconf import OmegaConf
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
from nemo.collections.asr.modules.msdd_diarizer import MSDD_module
from batchalign.models.speaker.utils import conv_scale_weights

import torch

# compute device
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# override the MSDD implementation's scale-weight convolution with the patched version
MSDD_module.conv_scale_weights = conv_scale_weights

# let ops that are unsupported on Apple Silicon (MPS) fall back to CPU
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = str(1)

INPUT = "/Users/houjun/Documents/Projects/batchalign2/extern/Untitled.wav"
NUM_SPEAKERS = 2

def resolve_config():
    return os.path.join(os.path.abspath(os.path.join(__file__, os.pardir)), "config.yaml")

base = OmegaConf.load(resolve_config())

# get an input file
in_file = INPUT
# make a copy of the input config
config = copy.deepcopy(base)
# create a working directory and configure settings
with tempfile.TemporaryDirectory() as workdir:
    # create the mono file (the diarizer expects single-channel audio)
    sound = AudioSegment.from_file(in_file).set_channels(1)
    sound.export(os.path.join(workdir, "mono_file.wav"), format="wav")

    # create the manifest entry with the info we need
    meta = {
        "audio_filepath": os.path.join(workdir, "mono_file.wav"),
        "offset": 0,
        "duration": None,
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,
        "uem_filepath": None,
        "num_speakers": NUM_SPEAKERS
    }
    manifest_path = os.path.join(workdir, "input_manifest.json")
    with open(manifest_path, "w") as fp:
        json.dump(meta, fp)
        fp.write("\n")
    config.diarizer.manifest_filepath = manifest_path
    config.diarizer.out_dir = workdir
    config.device = DEVICE

    # initialize a diarizer and brrr
    msdd_model = NeuralDiarizer(cfg=config)
    msdd_model.diarize()

    # read output and return
    # https://github.com/MahmoudAshraf97/whisper-diarization/blob/main/diarize.py
    speaker_ts = []
    with open(os.path.join(workdir, "pred_rttms", "mono_file.rttm"), "r") as f:
        lines = f.readlines()
        for line in lines:
            # NeMo pads RTTM fields with extra spaces, so split(" ") yields
            # empty tokens; indices 5, 8, and 11 land on onset, duration, speaker
            line_list = line.split(" ")
            s = int(float(line_list[5]) * 1000)
            e = s + int(float(line_list[8]) * 1000)
            speaker_ts.append([s, e, int(line_list[11].split("_")[-1])])
    print(speaker_ts)

# breakpoint()

# conf =
# import torch
# torch.rand(1,10).to("mps:0")
```
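A quick sanity check of the index arithmetic above, using a hand-written RTTM line in NeMo's space-padded format (the line and numbers are illustrative, not output from this commit): because the onset and duration fields are padded with extra spaces, `split(" ")` produces empty tokens and pushes onset, duration, and speaker name to indices 5, 8, and 11.

```python
# hypothetical RTTM line, padded the way NeMo writes its fields
line = "SPEAKER mono_file 1   0.500   1.500 <NA> <NA> speaker_0 <NA> <NA>"
fields = line.split(" ")  # empty strings fill the padded gaps

start_ms = int(float(fields[5]) * 1000)            # 500
end_ms = start_ms + int(float(fields[8]) * 1000)   # 2000
speaker = int(fields[11].split("_")[-1])           # 0
print([start_ms, end_ms, speaker])                 # [500, 2000, 0]
```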
@@ -0,0 +1,44 @@
```python
import torch
import torch.nn.functional as F

def conv_scale_weights(self, ms_avg_embs_perm, ms_emb_seq_single):
    """
    Use multiple ConvNet layers to estimate the scale weights based on the cluster-average embedding and
    the input embedding sequence.

    Args:
        ms_avg_embs_perm (Tensor):
            Tensor containing cluster-average speaker embeddings for each scale.
            Shape: (batch_size, length, scale_n, emb_dim)
        ms_emb_seq_single (Tensor):
            Tensor containing multi-scale speaker embedding sequences, taken from the
            given audio stream input.
            Shape: (batch_size, length, scale_n, emb_dim)

    Returns:
        scale_weights (Tensor):
            Weight vectors that determine the weight of each scale.
            Shape: (batch_size, length, scale_n, num_spks)
    """
    ms_cnn_input_seq = torch.cat([ms_avg_embs_perm, ms_emb_seq_single], dim=2)
    ms_cnn_input_seq = ms_cnn_input_seq.unsqueeze(2).flatten(0, 1)

    conv_out = self.conv_forward(
        ms_cnn_input_seq, conv_module=self.conv[0], bn_module=self.conv_bn[0], first_layer=True
    )
    for conv_idx in range(1, self.conv_repeat + 1):
        conv_out = self.conv_forward(
            conv_input=conv_out,
            conv_module=self.conv[conv_idx],
            bn_module=self.conv_bn[conv_idx],
            first_layer=False,
        )

    # reshape rather than view, so non-contiguous conv output is handled
    lin_input_seq = conv_out.reshape(self.batch_size, self.length, self.cnn_output_ch * self.emb_dim)
    hidden_seq = self.conv_to_linear(lin_input_seq)
    hidden_seq = self.dropout(F.leaky_relu(hidden_seq))
    scale_weights = self.softmax(self.linear_to_weights(hidden_seq))
    scale_weights = scale_weights.unsqueeze(3).expand(-1, -1, -1, self.num_spks)
    return scale_weights
```
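To make the docstring shapes concrete, here is a small standalone sketch of the first two tensor manipulations (the dimensions are made up for illustration; no NeMo checkpoint is involved):

```python
import torch

# illustrative sizes only
batch_size, length, scale_n, emb_dim = 2, 7, 5, 192

ms_avg_embs_perm = torch.randn(batch_size, length, scale_n, emb_dim)
ms_emb_seq_single = torch.randn(batch_size, length, scale_n, emb_dim)

# cluster-average and input embeddings are stacked along the scale axis...
cat = torch.cat([ms_avg_embs_perm, ms_emb_seq_single], dim=2)
print(cat.shape)  # torch.Size([2, 7, 10, 192])

# ...then batch and length are flattened so every frame goes through the
# ConvNet independently, with a singleton channel dimension inserted
flat = cat.unsqueeze(2).flatten(0, 1)
print(flat.shape)  # torch.Size([14, 1, 10, 192])
```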
File renamed without changes.
File renamed without changes.
batchalign/models/bert_utterance/execute.py → batchalign/models/utterance/execute.py (4 changes: 2 additions & 2 deletions)
File renamed without changes.
File renamed without changes.
```diff
@@ -1,2 +1,2 @@
-from .whisper_asr import WhisperASRModel
-from .whisper_fa import WhisperFAModel
+from .infer_asr import WhisperASRModel
+from .infer_fa import WhisperFAModel
```
File renamed without changes.
File renamed without changes.