From 2cd7e1e07e6949957f12e468cd7c463d6b85eaba Mon Sep 17 00:00:00 2001
From: Onur Yilmaz
Date: Thu, 21 Nov 2024 13:50:28 -0500
Subject: [PATCH 1/3] Fix DDP unused param error when TE is enabled

Signed-off-by: Onur Yilmaz
---
 examples/llm/sft/hf.py                           | 14 +++-----------
 .../llm/gpt/model/hf_auto_model_for_causal_lm.py |  9 +++++++++
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py
index 59b8b4ad3491..1956258aed89 100755
--- a/examples/llm/sft/hf.py
+++ b/examples/llm/sft/hf.py
@@ -19,7 +19,7 @@
 
 from nemo import lightning as nl
 from nemo.collections import llm
-from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated, te_accelerate
+from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated
 from nemo.lightning.pytorch.callbacks import ModelCallback
 
 
@@ -75,17 +75,9 @@ def squad(tokenizer) -> pl.LightningDataModule:
     grad_clip = None
     use_dist_samp = False
 
-    model = llm.HfAutoModelForCausalLM(args.model)
+    model = llm.HfAutoModelForCausalLM(model_name=args.model, model_accelerator=args.model_accelerator)
     tokenizer = model.tokenizer
 
-    callbacks = []
-    if args.model_accelerator:
-        if args.model_accelerator == "te":
-            model_transform = ModelCallback(
-                on_train_start=lambda model: te_accelerate(model, fp8_autocast=args.fp8_autocast)
-            )
-            callbacks.append(model_transform)
-
     llm.api.finetune(
         model=model,
         data=squad(tokenizer),
@@ -100,7 +92,7 @@
             accumulate_grad_batches=10,
             gradient_clip_val=grad_clip,
             use_distributed_sampler=use_dist_samp,
-            callbacks=callbacks,
+            callbacks=[],
             logger=wandb,
         ),
         optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
index c0f02d706ceb..a0301e472dff 100644
--- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
@@ -39,6 +39,7 @@ def __init__(
         tokenizer=None,
         loss_fn=masked_cross_entropy,
         model_transform=None,
+        model_accelerator=None,
         trust_remote_code=False,
     ):
         super().__init__()
@@ -50,6 +51,7 @@
         self.load_pretrained_weights = load_pretrained_weights
         self.is_hf_model = True
         self.model_transform = model_transform
+        self.model_accelerator = model_accelerator
         self.trust_remote_code = trust_remote_code
 
     @property
@@ -78,6 +80,13 @@ def configure_model(self):
 
             config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code)
             self.model = AutoModelForCausalLM.from_config(config, trust_remote_code=self.trust_remote_code)
+
+        if self.model_accelerator is not None:
+            if self.model_accelerator == "te":
+                from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
+
+                te_accelerate(self.model, fp8_autocast=False)
+
         self.model.train()
 
     def forward(self, input_ids, attention_mask=None, labels=None, loss_mask=None):

From f3785a1917bdd1bb9fb1b739c2e9e4157d45bfcb Mon Sep 17 00:00:00 2001
From: Onur Yilmaz
Date: Fri, 22 Nov 2024 15:43:29 -0500
Subject: [PATCH 2/3] Added partial function for te

Signed-off-by: Onur Yilmaz
---
 examples/llm/sft/hf.py                           | 9 ++++++++-
 .../llm/gpt/model/hf_auto_model_for_causal_lm.py | 5 +----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py
index 1956258aed89..286253f44f16 100755
--- a/examples/llm/sft/hf.py
+++ b/examples/llm/sft/hf.py
@@ -75,7 +75,14 @@ def squad(tokenizer) -> pl.LightningDataModule:
     grad_clip = None
     use_dist_samp = False
 
-    model = llm.HfAutoModelForCausalLM(model_name=args.model, model_accelerator=args.model_accelerator)
+    model_accelerator = None
+    if args.model_accelerator == "te":
+        from functools import partial
+        from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
+        model_accelerator = partial(te_accelerate, fp8_autocast=args.fp8_autocast)
+
+    from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
+    model = llm.HfAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator)
     tokenizer = model.tokenizer
 
     llm.api.finetune(
diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
index a0301e472dff..26e4604adc43 100644
--- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
@@ -82,10 +82,7 @@ def configure_model(self):
             self.model = AutoModelForCausalLM.from_config(config, trust_remote_code=self.trust_remote_code)
 
         if self.model_accelerator is not None:
-            if self.model_accelerator == "te":
-                from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
-
-                te_accelerate(self.model, fp8_autocast=False)
+            self.model_accelerator(self.model)
 
         self.model.train()
 

From 413c1f49d2b5896a8dadd1410b6ae62f4f8c55fc Mon Sep 17 00:00:00 2001
From: oyilmaz-nvidia
Date: Fri, 22 Nov 2024 20:45:27 +0000
Subject: [PATCH 3/3] Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia
---
 examples/llm/sft/hf.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py
index 286253f44f16..1d282312b130 100755
--- a/examples/llm/sft/hf.py
+++ b/examples/llm/sft/hf.py
@@ -79,9 +79,11 @@ def squad(tokenizer) -> pl.LightningDataModule:
     if args.model_accelerator == "te":
         from functools import partial
         from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
+
        model_accelerator = partial(te_accelerate, fp8_autocast=args.fp8_autocast)
 
     from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
+
     model = llm.HfAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator)
     tokenizer = model.tokenizer
 
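
Net effect of the series: Transformer Engine conversion moves out of the on_train_start callback, where DDP had already registered the original HF module's parameters (which then surface as unused-parameter errors once TE swaps the layers), and into configure_model() via a model_accelerator callable, so the substitution happens before distributed wrapping. A minimal sketch of the resulting usage under these patches; the checkpoint name is illustrative, not part of the change:

    from functools import partial

    from nemo.collections import llm
    from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate

    # Bind the TE options now; HfAutoModelForCausalLM invokes this callable on
    # the freshly built HF model inside configure_model(), i.e. before DDP
    # wraps it, so no stale parameters remain registered with the reducer.
    accelerator = partial(te_accelerate, fp8_autocast=False)

    model = llm.HfAutoModelForCausalLM(
        model_name="meta-llama/Llama-3.2-1B",  # illustrative model name
        model_accelerator=accelerator,
    )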