
Commit

Merge branch 'refs/heads/yuya/update_mllama_use_attn_bias' into yuya/add_neva_scripts_and_tests
yaoyu-33 committed Nov 22, 2024
2 parents 16a6fc6 + 05f3cb6 commit 617c8ca
Showing 2 changed files with 17 additions and 10 deletions.
2 changes: 2 additions & 0 deletions nemo/collections/vlm/recipes/mllama_11b.py
@@ -134,6 +134,8 @@ def finetune_recipe(
         recipe.trainer.strategy.tensor_model_parallel_size = 2
         recipe.optim.config.lr = 2e-05
     elif peft_scheme.lower() == 'lora':
+        # pylint: disable=line-too-long
+        """Adapted from https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/configs/peft.py"""
         recipe.peft = run.Config(
             vlm.LoRA,
             freeze_vision_model=True,
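For context, a minimal sketch of how a recipe like this is typically built and launched with nemo_run. The module path follows the file changed above, but the argument values and the local executor are illustrative assumptions, not part of this commit:

    import nemo_run as run

    from nemo.collections.vlm.recipes.mllama_11b import finetune_recipe

    # Build the LoRA fine-tuning recipe; `name` and `peft_scheme` follow the
    # finetune_recipe signature referenced in the hunk header above (the
    # values here are assumptions for illustration).
    recipe = finetune_recipe(name="mllama_11b_lora", peft_scheme="lora")

    # Launch locally; a cluster executor (e.g. Slurm) would be used in practice.
    run.run(recipe, executor=run.LocalExecutor())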
25 changes: 15 additions & 10 deletions nemo/collections/vlm/recipes/mllama_90b.py
@@ -26,6 +26,7 @@
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
 from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
 from nemo.collections.vlm.mllama.data.mock import MockDataModule
+from nemo.utils.exp_manager import TimingCallback

 NAME = "mllama_90b"

@@ -46,7 +47,7 @@ def model() -> run.Config[pl.LightningModule]:
             >>> model_config = model()
             >>> print(model_config)
     """
-    return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig90B))
+    return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig90BInstruct))


 @run.cli.factory(target=llm.finetune, name=NAME)
@@ -107,6 +108,7 @@ def finetune_recipe(
         plugins=bf16_mixed(),
         strategy=strategy,
         val_check_interval=100,
+        callbacks=[run.Config(TimingCallback)],
     )

     recipe = run.Partial(
@@ -116,7 +118,7 @@
         data=run.Config(
             MockDataModule,
             seq_length=6404,  # encoder (vision) seq length
-            decoder_seq_length=512,  # decoder (llm) seq length
+            decoder_seq_length=2048,  # decoder (llm) seq length
             global_batch_size=16,
             micro_batch_size=2,
             vocab_size=128256,
@@ -125,23 +127,26 @@
         ),
         log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
         optim=distributed_fused_adam_with_cosine_annealing(max_lr=1e-4, min_lr=2.0e-07, warmup_steps=150),
-        resume=nemo_resume("meta-llama/Llama-3.2-90B-Vision"),
+        resume=nemo_resume("meta-llama/Llama-3.2-90B-Vision-Instruct"),
     )

     if peft_scheme is None or peft_scheme.lower() == 'none':
         raise ValueError("Full finetuning recipe for Llama-3.2-90B model will be supported soon.")
     elif peft_scheme.lower() == 'lora':
+        # pylint: disable=line-too-long
+        """Adapted from https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/configs/peft.py"""
         recipe.peft = run.Config(
             vlm.LoRA,
-            freeze_vision_model=False,
+            freeze_vision_model=True,
             target_modules=[
-                "*.language_model.*.linear_qkv",
-                "*.language_model.*.linear_q",
-                "*.language_model.*.linear_kv",
-                "*.language_model.*.linear_proj",
-                "*.language_model.*.linear_fc1",
-                "*.language_model.*.linear_fc2",
+                "linear_qkv",
+                "linear_q",
+                "linear_kv",
             ],
+            dim=8,
+            alpha=32,
+            dropout=0.05,
+            dropout_position="pre",
         )
         recipe.optim.config.lr = 1e-4
     else:
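Because the recipe fields above are run.Partial / run.Config objects, the values touched in this diff (decoder sequence length, LoRA rank and alpha) can also be overridden after the recipe is built, in the same attribute style the recipe itself uses (e.g. recipe.optim.config.lr = 1e-4). A minimal sketch; the override values are illustrative assumptions, not defaults from this commit:

    import nemo_run as run

    from nemo.collections.vlm.recipes.mllama_90b import finetune_recipe

    # Build the LoRA fine-tuning recipe for the 90B model (illustrative arguments).
    recipe = finetune_recipe(name="mllama_90b_lora", peft_scheme="lora")

    # Hypothetical overrides of fields changed in this commit.
    recipe.data.decoder_seq_length = 4096  # decoder (llm) seq length beyond the new 2048 default
    recipe.peft.dim = 16                   # larger LoRA rank than the dim=8 set above
    recipe.peft.alpha = 32                 # keep alpha as configured in the recipe

    run.run(recipe, executor=run.LocalExecutor())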
