From ba7a68255bb2be0d449f7c63ed43178f78e188fd Mon Sep 17 00:00:00 2001
From: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com>
Date: Fri, 22 Nov 2024 21:22:25 +0200
Subject: [PATCH 1/3] mlm conversion & tiktokenizer support (#11349)

* mlm conversion fix

Signed-off-by: dimapihtar

* add tiktoken support for nemotron -> hf

Signed-off-by: dimapihtar

* additional params

Signed-off-by: dimapihtar

* Apply isort and black reformatting

Signed-off-by: dimapihtar

* add ci test for mlm conversion

Signed-off-by: dimapihtar

* add ci test for mlm ckpt conversion

Signed-off-by: dimapihtar

* remove extra if statement

Signed-off-by: dimapihtar

* fix typo

Signed-off-by: dimapihtar

* fix if statement

Signed-off-by: dimapihtar

* fix paths

Signed-off-by: dimapihtar

* update paths

Signed-off-by: dimapihtar

---------

Signed-off-by: dimapihtar
Signed-off-by: dimapihtar
Co-authored-by: dimapihtar
---
 .github/workflows/cicd-main.yml           | 121 ++++++++++++++++++
 .../megatron_ckpt_to_nemo.py              |   9 +-
 nemo/collections/nlp/models/nlp_model.py  |  17 ++-
 .../convert_nemotron_nemo_to_hf.py        |  38 +++++-
 4 files changed, 182 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index 49c6c55ca778..b82bbc65cfc1 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -2109,6 +2109,126 @@ jobs:
   #        }
   #      }
 
+  L2_Megatron_LM_To_NeMo_Conversion:
+    needs: [cicd-test-container-setup]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Megatron_LM_To_NeMo_Conversion') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 Megatron-LM/pretrain_gpt.py \
+          --mock-data \
+          --distributed-timeout-minutes 60 \
+          --use-mcore-models \
+          --no-mmap-bin-files \
+          --untie-embeddings-and-output-weights \
+          --disable-bias-linear \
+          --train-samples 80 \
+          --init-method-std 0.014 \
+          --position-embedding-type rope \
+          --rotary-base 1000000 \
+          --rotary-percent 1.0 \
+          --squared-relu \
+          --num-layers 4 \
+          --hidden-size 384 \
+          --num-attention-heads 8 \
+          --group-query-attention \
+          --num-query-groups 8 \
+          --ffn-hidden-size 1536 \
+          --kv-channels 128 \
+          --normalization RMSNorm \
+          --attention-dropout 0.0 \
+          --hidden-dropout 0.0 \
+          --exit-duration-in-mins 5750 \
+          --tensor-model-parallel-size 1 \
+          --pipeline-model-parallel-size 1 \
+          --seq-length 8192 \
+          --max-position-embeddings 8192 \
+          --micro-batch-size 1 \
+          --global-batch-size 8 \
+          --lr 6e-4 \
+          --min-lr 6e-6 \
+          --weight-decay 0.1 \
+          --clip-grad 1.0 \
+          --lr-decay-style cosine \
+          --log-interval 1 \
+          --eval-iters 1 \
+          --eval-interval 10 \
+          --tokenizer-type GPT2BPETokenizer \
+          --tokenizer-model /home/TestData/nlp/gpt2_tokenizer \
+          --vocab-file /home/TestData/nlp/gpt2_tokenizer/vocab.json \
+          --merge-file /home/TestData/nlp/gpt2_tokenizer/merges.txt \
+          --save /tmp/mlm_conversion_ckpt \
+          --save-interval 10 \
+          --ckpt-format torch_dist \
+          --ckpt-fully-parallel-save \
+          --ckpt-fully-parallel-load \
+          --async-save \
+          --ckpt-assume-constant-structure \
+          --timing-log-option minmax \
+          --log-params-norm \
+          --log-num-zeros-in-grad \
+          --log-throughput \
+          --bf16 \
+          --adam-beta1 0.9 \
+          --adam-beta2 0.95 \
+          --use-distributed-optimizer \
+          --overlap-grad-reduce \
+          --overlap-param-gather \
+          --manual-gc \
+          --num-workers 2
+
+        python examples/nlp/language_modeling/megatron_gpt_pretraining.py \
+          model.data.data_impl=mock \
+          model.data.data_prefix=[] \
+          model.skip_train=True \
+          model.transformer_engine=True \
+          model.use_flash_attention=False \
+          model.normalization=rmsnorm \
+          model.num_layers=4 \
+          model.hidden_size=384 \
+          model.ffn_hidden_size=1536 \
+          model.num_attention_heads=8 \
+          model.num_query_groups=8 \
+          model.bias=False \
+          model.bias_activation_fusion=False \
+          model.bias_dropout_add_fusion=True \
+          model.masked_softmax_fusion=True \
+          model.encoder_seq_length=8192 \
+          model.max_position_embeddings=8192 \
+          model.data.seq_length=8192 \
+          model.activation=squared-relu \
+          model.transformer_block_type=True \
+          model.micro_batch_size=1 \
+          model.global_batch_size=8 \
+          ++model.rotary_base=1000000 \
+          model.rotary_percentage=1.0 \
+          model.apply_query_key_layer_scaling=False \
+          ++model.group_query_attention=True \
+          model.apply_rope_fusion=True \
+          model.kv_channels=128 \
+          ++model.bert_binary_head=True \
+          ++model.position_embedding_type=rope \
+          ++model.add_position_embedding=True \
+          trainer.limit_val_batches=1 \
+          exp_manager.exp_dir=/tmp/nemo_conversion_ckpt
+
+        python -m torch.distributed.launch --nproc_per_node=1 examples/nlp/language_modeling/megatron_ckpt_to_nemo.py \
+          --checkpoint_folder /tmp/mlm_conversion_ckpt \
+          --checkpoint_name iter_0000010 \
+          --nemo_file_path /tmp/mlm_to_nemo_test.nemo \
+          --tensor_model_parallel_size 1 \
+          --pipeline_model_parallel_size 1 \
+          --gpus_per_node 1 \
+          --model_type gpt \
+          --hparams_file /tmp/nemo_conversion_ckpt/megatron_gpt/version_0/hparams.yaml \
+          --convert_mlm
+
+      AFTER_SCRIPT: |
+        rm -rf /tmp/nemo_conversion_ckpt
+        rm -rf /tmp/mlm_conversion_ckpt
+        rm -rf /tmp/mlm_to_nemo_test.nemo
+
   L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2:
     needs: [cicd-test-container-setup]
     uses: ./.github/workflows/_test_template.yml
@@ -4432,6 +4552,7 @@ jobs:
       - L2_RAG_Pipeline_Generating
       - L2_Megatron_GPT_Pretraining_and_Resume_Training_TP2
       - L2_Megatron_GPT_Skip_Train
+      - L2_Megatron_LM_To_NeMo_Conversion
       - L2_Megatron_GPT_with_Rope_Pretraining_and_Resume_Training_TP2
       - L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2
      - L2_Megatron_GPT_with_Drop_Optimizer_States_TP2

diff --git a/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py
index b46f8f459ff0..4b9fab987dc7 100644
--- a/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py
+++ b/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py
@@ -112,6 +112,11 @@ def get_args():
         choices=['32-true', '16-mixed', 'bf16-mixed'],
         help="Precision value for the trainer that matches with precision of the ckpt",
     )
+    parser.add_argument(
+        "--convert_mlm",
+        action="store_true",
+        help="Use this flag to convert megatron-lm checkpoints.",
+    )
     args = parser.parse_args()
     return args
 
@@ -195,7 +200,9 @@ def convert(local_rank, rank, world_size, args):
     )
 
     if args.model_type == 'gpt':
-        model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
+        model = MegatronGPTModel.load_from_checkpoint(
+            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer, load_mlm=args.convert_mlm
+        )
     elif args.model_type == 'sft':
         model = MegatronGPTSFTModel.load_from_checkpoint(
             checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
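Note: the new --convert_mlm flag above is forwarded to load_from_checkpoint as the load_mlm keyword, which the nlp_model.py hunk below consumes. A minimal standalone sketch of the key remapping that kwarg triggers (the key names here are illustrative, not taken from the patch):

    # NeMo's sharded state dict prefixes module keys with 'model.'; a raw
    # Megatron-LM torch_dist checkpoint stores them without that prefix,
    # so the first occurrence of 'model.' is stripped from every key.
    sharded_state_dict = {
        "model.decoder.layers.0.self_attention.linear_qkv.weight": "sharded-tensor-0",
        "model.embedding.word_embeddings.weight": "sharded-tensor-1",
    }
    mlm_sharded_state_dict = {k.replace("model.", "", 1): v for k, v in sharded_state_dict.items()}
    assert "decoder.layers.0.self_attention.linear_qkv.weight" in mlm_sharded_state_dict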
diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py
index 0c61b085bc7f..6a87eb28723c 100644
--- a/nemo/collections/nlp/models/nlp_model.py
+++ b/nemo/collections/nlp/models/nlp_model.py
@@ -397,7 +397,22 @@ def dummy():
                 model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
             model.trainer.strategy.setup_environment()
             sharded_state_dict = model.sharded_state_dict()
-            checkpoint['state_dict'] = sharded_state_dict
+            if kwargs.get("load_mlm", False):
+                mlm_sharded_state_dict = {}
+                for k, v in sharded_state_dict.items():
+                    # Remove 'model.' from the sharded_state_dict keys
+                    new_key = k.replace('model.', '', 1)
+
+                    # Update the key attribute of the ShardedTensor value
+                    new_value = v
+                    if hasattr(v, 'key'):
+                        new_value.key = v.key.replace('model.', '', 1)
+
+                    # Add the updated key-value pair to the new dictionary
+                    mlm_sharded_state_dict[new_key] = new_value
+                checkpoint['state_dict'] = mlm_sharded_state_dict
+            else:
+                checkpoint['state_dict'] = sharded_state_dict
             # load the checkpoint from disk
             checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
             # restore the weights

diff --git a/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py
index 392e3628ccdb..2f66773f8724 100644
--- a/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py
+++ b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py
@@ -21,7 +21,7 @@
 import torch
 from lightning.pytorch import Trainer
 from transformers import LlamaTokenizer, PreTrainedTokenizerFast
-from transformers.convert_slow_tokenizer import LlamaConverter
+from transformers.convert_slow_tokenizer import LlamaConverter, TikTokenConverter
 
 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
@@ -130,6 +130,20 @@ def convert_hf_config(nemo_config, tokenizer, vocab_size, dtype, hf_output_path,
     json.dump(hf_config, open(f"{hf_output_path}/config.json", "w"), indent=2)
 
 
+def convert_tiktoken(vocab_file) -> None:
+    with open(vocab_file, 'r') as f:
+        vocab = json.load(f)
+    os.remove(vocab_file)
+
+    lines = []
+    for line in vocab:
+        lines.append(f"{line['token_bytes']} {line['rank']}")
+
+    for line in lines:
+        with open(vocab_file, 'a') as f:
+            f.write(line + '\n')
+
+
 def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None:
     """
     Convert NeMo weights to HF weights
@@ -323,6 +337,28 @@ def extract_nemotron_tokenizer(nemo_file, model_config, output_hf_path, nemo_tok
         )
         tokenizer.save_pretrained(output_hf_path)
         logging.info(f"Setencepiece tokenizer has been saved to {output_tokenizer}")
+    elif tokenizer_cfg.library == "tiktoken":
+        tokenizer_fn = tokenizer_cfg.model[5:]
+        special_tokens = ["<unk>", "<s>", "</s>"]
+        import tarfile
+
+        archive = tarfile.open(nemo_file, "r")
+        tokenizer_filename = "./" + tokenizer_fn  # exclude 'nemo:' prefix
+        archive.extract(tokenizer_filename, output_hf_path)
+        archive.close()
+        vocab_file = os.path.join(output_hf_path, tokenizer_fn)
+        convert_tiktoken(vocab_file)
+        converted_tokenizer = TikTokenConverter(
+            vocab_file=vocab_file, additional_special_tokens=special_tokens
+        ).converted()
+        os.remove(vocab_file)
+        tokenizer = PreTrainedTokenizerFast(
+            tokenizer_object=converted_tokenizer,
+            model_input_names=["input_ids", "attention_mask"],
+            bos_token="<s>",
+            eos_token="</s>",
+        )
+        tokenizer.save_pretrained(output_hf_path)
     elif isinstance(nemo_tokenizer, AutoTokenizer):
         nemo_tokenizer.tokenizer.save_pretrained(output_hf_path)
         logging.info(f"HF AutoTokenizer has been saved to {output_hf_path}")
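Note: convert_tiktoken above rewrites the JSON vocab shipped inside the .nemo tarball into the one-entry-per-line text form that TikTokenConverter consumes. A minimal sketch of the same transformation, under the assumption that each JSON entry carries a base64 token_bytes string and an integer rank (the file names and entries here are made up):

    import json

    # Hypothetical tiktoken vocab as it might ship in the .nemo archive.
    vocab = [
        {"token_bytes": "aGVsbG8=", "rank": 0},  # base64 of b"hello"
        {"token_bytes": "IHdvcmxk", "rank": 1},  # base64 of b" world"
    ]
    with open("vocab.json", "w") as f:
        json.dump(vocab, f)

    # Rewrite as "<token_bytes> <rank>" lines, mirroring convert_tiktoken.
    with open("vocab.json") as f:
        entries = json.load(f)
    with open("vocab.txt", "w") as f:
        for entry in entries:
            f.write(f"{entry['token_bytes']} {entry['rank']}\n")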
From 7ec58fab6fe990efb6abf18b68bc2eceffbbd457 Mon Sep 17 00:00:00 2001
From: Pablo Garay
Date: Fri, 22 Nov 2024 11:25:39 -0800
Subject: [PATCH 2/3] nit: remove non-strictly needed lines

---
 .github/workflows/cicd-main.yml | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index b82bbc65cfc1..a4b2baa59550 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -2224,11 +2224,6 @@ jobs:
           --hparams_file /tmp/nemo_conversion_ckpt/megatron_gpt/version_0/hparams.yaml \
           --convert_mlm
 
-      AFTER_SCRIPT: |
-        rm -rf /tmp/nemo_conversion_ckpt
-        rm -rf /tmp/mlm_conversion_ckpt
-        rm -rf /tmp/mlm_to_nemo_test.nemo
-
   L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2:
     needs: [cicd-test-container-setup]
     uses: ./.github/workflows/_test_template.yml

From 86315db0bfa2edee49000e36cdc9bb7fedb1fa8d Mon Sep 17 00:00:00 2001
From: tango4j
Date: Fri, 22 Nov 2024 22:58:55 +0000
Subject: [PATCH 3/3] Apply isort and black reformatting

Signed-off-by: tango4j

---
 .../asr/parts/utils/manifest_utils.py     | 20 +++--
 tests/collections/asr/test_diar_utils.py  | 84 +++++++++++++++----
 2 files changed, 78 insertions(+), 26 deletions(-)

diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py
index 55c8e3dbb8c3..e05108c509ff 100644
--- a/nemo/collections/asr/parts/utils/manifest_utils.py
+++ b/nemo/collections/asr/parts/utils/manifest_utils.py
@@ -67,11 +67,11 @@ def get_ctm_line(
 ) -> str:
     """
     Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document.
-
-    CTM Format:
+
+    CTM Format:
         <SOURCE><SP><CHANNEL><SP><BEG_TIME><SP><DURATION><SP><TOKEN><SP><CONF><SP><TYPE_OF_TOKEN><SP><SPEAKER><NEWLINE>
-
-    Reference:
+
+    Reference:
         https://web.archive.org/web/20170119114252/http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf
 
     Args:
@@ -80,11 +80,11 @@ def get_ctm_line(
         start_time (float): <BEG_TIME> is the begin time of the word, which we refer to as `start_time` in NeMo.
         duration (float): <DURATION> is duration of the word
         token (str): <TOKEN> Token or word for the current entry
-        conf (float): <CONF> is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data)
-            when no confidence is computed and in the reference data.
+        conf (float): <CONF> is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data)
+            when no confidence is computed and in the reference data.
         type_of_token (str): <TYPE> is the token type. The legal values of <TYPE> are “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore”
         speaker (str): <SPEAKER> is a string identifier for the speaker who uttered the token. This should be “null” for non-speech tokens and “unknown” when
-            the speaker has not been determined.
+            the speaker has not been determined.
         NA_token (str, optional): A token for <NA>. Defaults to '<NA>'.
         output_precision (int, optional): The precision of the output floating point number. Defaults to 3.
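Note: for readers unfamiliar with CTM, a line under the format documented in the hunk above could look like the following (every field value is made up):

    session_1 1 13.021 0.330 hello 0.950 lex speaker_0

The fields are source, channel, begin time, duration, token, confidence, token type, and speaker.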
@@ -368,7 +368,11 @@ def create_segment_manifest(
     segments_manifest_file = write_rttm2manifest(AUDIO_RTTM_MAP, segment_manifest_path, deci)
     subsegments_manifest_file = subsegment_manifest_path
     segments_manifest_to_subsegments_manifest(
-        segments_manifest_file, subsegments_manifest_file, window, shift, min_subsegment_duration,
+        segments_manifest_file,
+        subsegments_manifest_file,
+        window,
+        shift,
+        min_subsegment_duration,
     )
     subsegments_dict = get_subsegment_dict(subsegments_manifest_file, window, shift, deci)
     write_truncated_subsegments(input_manifest_dict, subsegments_dict, output_manifest_path, step_count, deci)

diff --git a/tests/collections/asr/test_diar_utils.py b/tests/collections/asr/test_diar_utils.py
index a72313923a66..cb364675fcf4 100644
--- a/tests/collections/asr/test_diar_utils.py
+++ b/tests/collections/asr/test_diar_utils.py
@@ -82,8 +82,7 @@ def matrix(mat, use_tensor=True, dtype=torch.long):
 
 
 def generate_orthogonal_embs(total_spks, perturb_sigma, emb_dim):
-    """Generate a set of artificial orthogonal embedding vectors from random numbers
-    """
+    """Generate a set of artificial orthogonal embedding vectors from random numbers"""
     gaus = torch.randn(emb_dim, emb_dim)
     _svd = torch.linalg.svd(gaus)
     orth = _svd[0] @ _svd[2]
@@ -130,8 +129,7 @@ def generate_toy_data(
 
 
 class TestDiarizationSequneceUtilFunctions:
-    """Tests diarization and speaker-task related utils.
-    """
+    """Tests diarization and speaker-task related utils."""
 
     @pytest.mark.unit
     @pytest.mark.parametrize("Y", [[3, 3, 3, 4, 4, 5], [100, 100, 100, 104, 104, 1005]])
@@ -278,7 +276,10 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity):
         em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=10)
         em_s, ts_s = split_input_data(em, ts, mc)
         merged_embs, merged_clus_labels, _ = run_reducer(
-            pre_embs=em_s[-1], target_spk_idx=target_speaker_index, merge_quantity=merge_quantity, pre_clus_labels=gt,
+            pre_embs=em_s[-1],
+            target_spk_idx=target_speaker_index,
+            merge_quantity=merge_quantity,
+            pre_clus_labels=gt,
         )
         assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0]
 
@@ -287,7 +288,11 @@
     @pytest.mark.parametrize("pcl", [torch.tensor([0] * 70 + [1] * 32)])
     @pytest.mark.parametrize("mspb", [25])
     def test_merge_scheduler_2clus(self, ntbr, pcl, mspb):
-        class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,)
+        class_target_vol = get_merge_quantity(
+            num_to_be_removed=ntbr,
+            pre_clus_labels=pcl,
+            min_count_per_cluster=mspb,
+        )
         assert all(class_target_vol == torch.tensor([3, 0]))
 
     @pytest.mark.unit
@@ -295,7 +300,11 @@
     @pytest.mark.parametrize("pcl", [torch.tensor([0] * 80 + [1] * 35 + [2] * 32)])
     @pytest.mark.parametrize("mspb", [0, 25])
     def test_merge_scheduler_3clus(self, ntbr, pcl, mspb):
-        class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,)
+        class_target_vol = get_merge_quantity(
+            num_to_be_removed=ntbr,
+            pre_clus_labels=pcl,
+            min_count_per_cluster=mspb,
+        )
         assert all(class_target_vol == torch.tensor([3, 0, 0]))
 
     @pytest.mark.unit
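Note: the call-site hunks in this commit all follow one mechanical rule of black, its "magic trailing comma": a call whose argument list already ends in a trailing comma is exploded to one argument per line. An illustrative before/after using one of the calls shown above:

    # Before: the trailing comma after mspb makes black keep the call multi-line.
    class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,)

    # After black reformatting:
    class_target_vol = get_merge_quantity(
        num_to_be_removed=ntbr,
        pre_clus_labels=pcl,
        min_count_per_cluster=mspb,
    )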
@@ -303,7 +312,11 @@
     @pytest.mark.parametrize("pcl", [torch.tensor([2] * 70 + [0] * 32 + [1] * 27 + [3] * 3)])
     @pytest.mark.parametrize("mspb", [3, 10])
     def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb):
-        class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,)
+        class_target_vol = get_merge_quantity(
+            num_to_be_removed=ntbr,
+            pre_clus_labels=pcl,
+            min_count_per_cluster=mspb,
+        )
         assert all(class_target_vol == torch.tensor([18, 13, 56, 0]))
 
     @pytest.mark.unit
@@ -311,7 +324,11 @@
     @pytest.mark.parametrize("pcl", [torch.tensor([0] * 5 + [1] * 4 + [2] * 3)])
     @pytest.mark.parametrize("mspb", [0, 2])
     def test_merge_scheduler_3clus(self, ntbr, pcl, mspb):
-        class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,)
+        class_target_vol = get_merge_quantity(
+            num_to_be_removed=ntbr,
+            pre_clus_labels=pcl,
+            min_count_per_cluster=mspb,
+        )
         assert all(class_target_vol == torch.tensor([2, 1, 0]))
 
     @pytest.mark.unit
@@ -319,7 +336,11 @@
     @pytest.mark.parametrize("pcl", [torch.tensor([0] * 7 + [1] * 5 + [2] * 3 + [3] * 5)])
     @pytest.mark.parametrize("mspb", [2])
     def test_merge_scheduler_3clus_repeat(self, ntbr, pcl, mspb):
-        class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,)
+        class_target_vol = get_merge_quantity(
+            num_to_be_removed=ntbr,
+            pre_clus_labels=pcl,
+            min_count_per_cluster=mspb,
+        )
         assert all(class_target_vol == torch.tensor([2, 0, 0, 0]))
 
@@ -414,13 +435,21 @@ def test_is_overlap_false(self, rangeA, rangeB):
     @pytest.mark.parametrize("x", [1.0, 2.3456])
     @pytest.mark.parametrize("decimals", [1, 2, 3, 4])
     def test_fl2int(self, x, decimals):
-        assert fl2int(x, decimals) == round(x * 10 ** decimals, 0)
+        assert fl2int(x, decimals) == round(x * 10**decimals, 0)
 
     @pytest.mark.unit
     @pytest.mark.parametrize("x", [1234])
-    @pytest.mark.parametrize("decimals", [1, 2, 3, 4,])
+    @pytest.mark.parametrize(
+        "decimals",
+        [
+            1,
+            2,
+            3,
+            4,
+        ],
+    )
     def test_int2fl(self, x, decimals):
-        assert abs(int2fl(x, decimals) - round(x / (10 ** decimals), decimals)) < (10 ** -(decimals + 1))
+        assert abs(int2fl(x, decimals) - round(x / (10**decimals), decimals)) < (10 ** -(decimals + 1))
 
     @pytest.mark.unit
     def test_merge_float_intervals_edge_margin_test(self):
@@ -462,7 +491,11 @@ def test_get_speech_labels_for_update(self):
         vad_timestamps = torch.tensor([[0.9600, 4.8400]])
         cursor_for_old_segments = 1.0
         speech_labels_for_update, cumulative_speech_labels = get_speech_labels_for_update(
-            frame_start, buffer_end, cumulative_speech_labels, vad_timestamps, cursor_for_old_segments,
+            frame_start,
+            buffer_end,
+            cumulative_speech_labels,
+            vad_timestamps,
+            cursor_for_old_segments,
         )
         assert (speech_labels_for_update - torch.tensor([[1.0000, 3.7600]])).sum() < 1e-8
         assert (cumulative_speech_labels - torch.tensor([[0.9600, 4.8400]])).sum() < 1e-8
@@ -532,7 +565,10 @@ def test_tensor_to_list(self, source_range_list):
     @pytest.mark.unit
     @pytest.mark.parametrize(
         "buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate",
-        [(0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000),],
+        [
+            (0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000),
+            (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000),
+        ],
     )
     def test_get_online_segments_from_slices(
         self, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate
@@ -665,7 +701,13 @@ def test_offline_speaker_clustering_cpu(self, n_spks, total_sec, SSV, perturb_si
     @pytest.mark.parametrize("SSV, enhanced_count_thres, min_samples_for_nmesc", [(5, 40, 6)])
     @pytest.mark.parametrize("seed", [0])
     def test_offline_speaker_clustering_very_short_cpu(
-        self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed,
+        self,
+        n_spks,
+        spk_dur,
+        SSV,
+        enhanced_count_thres,
+        min_samples_for_nmesc,
+        seed,
     ):
         em, ts, mc, mw, spk_ts, gt = generate_toy_data(
             n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed
         )
@@ -697,7 +739,13 @@ def test_offline_speaker_clustering_very_short_cpu(
     @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(1, 5, 40, 6)])
     @pytest.mark.parametrize("seed", [0])
     def test_offline_speaker_clustering_very_short_gpu(
-        self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed,
+        self,
+        n_spks,
+        spk_dur,
+        SSV,
+        enhanced_count_thres,
+        min_samples_for_nmesc,
+        seed,
     ):
         em, ts, mc, mw, spk_ts, gt = generate_toy_data(
             n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed
         )
@@ -908,7 +956,7 @@ def test_linear_sum_assignment_algorithm_cost_matrix(self, cost_matrix):
         Test the linear sum assignment algorithm with a cost matrix
         Compare with the scipy implementation and make sure the final cost is the same.
 
-        NOTE: There could be multiple solutions with the same cost in linear sum assignment problem.
+        NOTE: There could be multiple solutions with the same cost in linear sum assignment problem.
         This test only checks if the cost is the same.
         """
         row_ind_nm, col_ind_nm = nemo_linear_sum_assignment(cost_matrix)
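Note: the last hunk's docstring makes a point worth illustrating: a linear sum assignment problem can admit several optimal index assignments that share one optimal cost, so the test compares costs rather than indices. A small self-contained check along those lines (the matrix values are made up):

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    # Tiny cost matrix; compare total costs, not the index vectors themselves,
    # since distinct assignments may share the same minimal cost.
    cost_matrix = np.array([[4, 1, 3], [2, 0, 5], [3, 2, 2]])
    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    total_cost = cost_matrix[row_ind, col_ind].sum()
    print(total_cost)  # 5 for this matrix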