From 05f3cb63e381af4a899e8ef2acfc55f4c703afb2 Mon Sep 17 00:00:00 2001
From: yaoyu-33
Date: Fri, 22 Nov 2024 11:34:44 -0800
Subject: [PATCH] update mllama 90b recipe

Signed-off-by: yaoyu-33
---
 nemo/collections/vlm/recipes/mllama_90b.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/vlm/recipes/mllama_90b.py b/nemo/collections/vlm/recipes/mllama_90b.py
index 787cc54483ec..12e0329fc6dd 100644
--- a/nemo/collections/vlm/recipes/mllama_90b.py
+++ b/nemo/collections/vlm/recipes/mllama_90b.py
@@ -26,6 +26,7 @@
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
 from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
 from nemo.collections.vlm.mllama.data.mock import MockDataModule
+from nemo.utils.exp_manager import TimingCallback
 
 NAME = "mllama_90b"
 
@@ -46,7 +47,7 @@ def model() -> run.Config[pl.LightningModule]:
             >>> model_config = model()
             >>> print(model_config)
     """
-    return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig90B))
+    return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig90BInstruct))
 
 
 @run.cli.factory(target=llm.finetune, name=NAME)
@@ -107,6 +108,7 @@ def finetune_recipe(
         plugins=bf16_mixed(),
         strategy=strategy,
         val_check_interval=100,
+        callbacks=[run.Config(TimingCallback)],
     )
 
     recipe = run.Partial(
@@ -116,7 +118,7 @@ def finetune_recipe(
         data=run.Config(
             MockDataModule,
             seq_length=6404,  # encoder (vision) seq length
-            decoder_seq_length=512,  # decoder (llm) seq length
+            decoder_seq_length=2048,  # decoder (llm) seq length
             global_batch_size=16,
             micro_batch_size=2,
             vocab_size=128256,
@@ -125,7 +127,7 @@ def finetune_recipe(
         ),
         log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
         optim=distributed_fused_adam_with_cosine_annealing(max_lr=1e-4, min_lr=2.0e-07, warmup_steps=150),
-        resume=nemo_resume("meta-llama/Llama-3.2-90B-Vision"),
+        resume=nemo_resume("meta-llama/Llama-3.2-90B-Vision-Instruct"),
     )
 
     if peft_scheme is None or peft_scheme.lower() == 'none':
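
Not part of the patch: a minimal sketch of how the updated recipe might be
launched with NeMo-Run, for reviewers who want to try the change. The import
path, the finetune_recipe factory, and the name/peft_scheme parameters come
from the file being patched; the "mllama_90b_finetune" run name and the
LocalExecutor choice are illustrative assumptions, not taken from this diff.

    import nemo_run as run

    from nemo.collections.vlm.recipes.mllama_90b import finetune_recipe

    if __name__ == "__main__":
        # Build the fine-tuning recipe; after this patch it resumes from the
        # Instruct checkpoint, uses a 2048-token decoder sequence length, and
        # reports per-step timings via TimingCallback.
        recipe = finetune_recipe(name="mllama_90b_finetune", peft_scheme="lora")

        # Run locally as a smoke test; a cluster executor (e.g. Slurm) would be
        # substituted for real multi-node training (assumption, not in this diff).
        run.run(recipe, executor=run.LocalExecutor())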