diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py
index 59b8b4ad3491..1d282312b130 100755
--- a/examples/llm/sft/hf.py
+++ b/examples/llm/sft/hf.py
@@ -19,7 +19,7 @@
 
 from nemo import lightning as nl
 from nemo.collections import llm
-from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated, te_accelerate
+from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated
 from nemo.lightning.pytorch.callbacks import ModelCallback
 
 
@@ -75,16 +75,17 @@ def squad(tokenizer) -> pl.LightningDataModule:
     grad_clip = None
     use_dist_samp = False
 
-    model = llm.HfAutoModelForCausalLM(args.model)
-    tokenizer = model.tokenizer
+    model_accelerator = None
+    if args.model_accelerator == "te":
+        from functools import partial
+        from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
 
-    callbacks = []
-    if args.model_accelerator:
-        if args.model_accelerator == "te":
-            model_transform = ModelCallback(
-                on_train_start=lambda model: te_accelerate(model, fp8_autocast=args.fp8_autocast)
-            )
-            callbacks.append(model_transform)
+        model_accelerator = partial(te_accelerate, fp8_autocast=args.fp8_autocast)
+
+    from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
+
+    model = llm.HfAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator)
+    tokenizer = model.tokenizer
 
     llm.api.finetune(
         model=model,
@@ -100,7 +101,7 @@ def squad(tokenizer) -> pl.LightningDataModule:
             accumulate_grad_batches=10,
             gradient_clip_val=grad_clip,
             use_distributed_sampler=use_dist_samp,
-            callbacks=callbacks,
+            callbacks=[],
             logger=wandb,
         ),
         optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
index c0f02d706ceb..26e4604adc43 100644
--- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
@@ -39,6 +39,7 @@ def __init__(
         tokenizer=None,
         loss_fn=masked_cross_entropy,
         model_transform=None,
+        model_accelerator=None,
         trust_remote_code=False,
     ):
         super().__init__()
@@ -50,6 +51,7 @@
         self.load_pretrained_weights = load_pretrained_weights
         self.is_hf_model = True
         self.model_transform = model_transform
+        self.model_accelerator = model_accelerator
         self.trust_remote_code = trust_remote_code
 
     @property
@@ -78,6 +80,10 @@ def configure_model(self):
 
             config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code)
             self.model = AutoModelForCausalLM.from_config(config, trust_remote_code=self.trust_remote_code)
+
+        if self.model_accelerator is not None:
+            self.model_accelerator(self.model)
+
         self.model.train()
 
     def forward(self, input_ids, attention_mask=None, labels=None, loss_mask=None):
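
For reference, a minimal sketch of how the new `model_accelerator` hook can be used outside the bundled SFT example. This assumes Transformer Engine is installed so `te_accelerate` imports cleanly; the checkpoint name and the manual `configure_model()` call are illustrative only (the trainer normally triggers model construction).

```python
from functools import partial

from nemo.collections import llm
from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated, te_accelerate

# Bind the TE options up front; configure_model() will invoke this callable
# on the freshly built HF module via self.model_accelerator(self.model).
accelerator = partial(te_accelerate, fp8_autocast=False)

model = llm.HfAutoModelForCausalLM(
    model_name="gpt2",  # placeholder checkpoint for illustration
    model_accelerator=accelerator,
)

model.configure_model()                # normally called by the trainer
assert is_te_accelerated(model.model)  # sanity check that TE layers were swapped in
```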