nemo2 peft merge (#11017)
* initial draft

Signed-off-by: HuiyingLi <[email protected]>

* refactor wip

Signed-off-by: HuiyingLi <[email protected]>

* refac v2 WIP

Signed-off-by: HuiyingLi <[email protected]>

* update address comments and add model dump

Signed-off-by: HuiyingLi <[email protected]>

* remove merge script

* move driver script

Signed-off-by: HuiyingLi <[email protected]>

* format

Signed-off-by: HuiyingLi <[email protected]>

* format

Signed-off-by: HuiyingLi <[email protected]>

* Apply isort and black reformatting

Signed-off-by: HuiyingLi <[email protected]>

* update with nemo2 main

Signed-off-by: HuiyingLi <[email protected]>

* Apply isort and black reformatting

Signed-off-by: HuiyingLi <[email protected]>

* cleanup import

Signed-off-by: HuiyingLi <[email protected]>

* merge api v3

Signed-off-by: Huiying Li <[email protected]>

* cleanup

Signed-off-by: Huiying Li <[email protected]>

* refac merge func to transform (by ChenCui)

Signed-off-by: Huiying Li <[email protected]>

* read base model from io instead of user input and bug fix

Signed-off-by: Huiying Li <[email protected]>

* Apply isort and black reformatting

Signed-off-by: HuiyingLi <[email protected]>

* add docstring

Signed-off-by: Huiying Li <[email protected]>

* refac

Signed-off-by: Huiying Li <[email protected]>

* add test

Signed-off-by: Huiying Li <[email protected]>

* Apply isort and black reformatting

Signed-off-by: HuiyingLi <[email protected]>

* add trainer to model

Signed-off-by: Huiying Li <[email protected]>

* add copyright

Signed-off-by: Huiying Li <[email protected]>

* clean up

Signed-off-by: Huiying Li <[email protected]>

---------

Signed-off-by: HuiyingLi <[email protected]>
Signed-off-by: HuiyingLi <[email protected]>
Signed-off-by: Huiying Li <[email protected]>
Co-authored-by: HuiyingLi <[email protected]>
HuiyingLi and HuiyingLi authored Nov 21, 2024
1 parent ffdccaf commit 3765580
Showing 5 changed files with 216 additions and 4 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -4319,7 +4319,18 @@ jobs:
          --mbs 1 \
          --model mistral \
          --dist-opt

  L2_NEMO_2_LoRA_MERGE:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NEMO_2_LoRA_MERGE') || needs.cicd-test-container-setup.outputs.all == 'true'
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        python tests/collections/llm/peft/lora_merge.py \
          --lora_checkpoint_path=/home/TestData/nemo2_ckpt/llama_lora_ci_checkpoint/ \
          --output_path=/tmp/nemo2_lora_merge/${{ github.run_id }}

  L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact:
    needs: [cicd-test-container-setup]
4 changes: 2 additions & 2 deletions nemo/collections/llm/peft/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.llm.peft.api import gpt_lora
from nemo.collections.llm.peft.api import gpt_lora, merge_lora
from nemo.collections.llm.peft.dora import DoRA
from nemo.collections.llm.peft.lora import LoRA

@@ -23,4 +23,4 @@
    "dora": DoRA,
}

__all__ = ["LoRA", "DoRA", "gpt_lora", "PEFT_STR2CLS"]
__all__ = ["LoRA", "DoRA", "gpt_lora", "PEFT_STR2CLS", "merge_lora"]
123 changes: 121 additions & 2 deletions nemo/collections/llm/peft/api.py
@@ -12,14 +12,133 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.llm.peft.lora import LoRA
import json
from pathlib import Path
from typing import Tuple, Union

import pytorch_lightning as pl
from megatron.core import dist_checkpointing
from pytorch_lightning.trainer.states import TrainerFn

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.llm.peft.lora import LoRA, LoRAMerge
from nemo.collections.llm.utils import factory
from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib, io
from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, ckpt_to_context_subdir
from nemo.lightning.io.pl import TrainerContext, ckpt_to_weights_subdir
from nemo.lightning.pytorch.callbacks import PEFT
from nemo.lightning.pytorch.callbacks.peft import PEFT
from nemo.lightning.pytorch.strategies.utils import RestoreConfig
from nemo.utils import logging


@factory
def gpt_lora() -> PEFT:
    return LoRA()


__all__ = ["gpt_lora"]
def merge_lora(
    lora_checkpoint_path: str,
    output_path: str,
) -> None:
    """
    Merges the LoRA adapter weights into the base model's weights.

    Python Usage:
    ```python
    if __name__ == '__main__':
        llm.peft.merge_lora(
            lora_checkpoint_path=your_lora_checkpoint_path,
            output_path=your_output_path,
        )
    ```

    Args:
        lora_checkpoint_path: The path to the LoRA checkpoint.
        output_path: The path to save the merged checkpoint.
    """
    from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed

    trainer = Trainer(
        devices=1,
        accelerator="cpu",
        strategy=MegatronStrategy(ddp="pytorch", setup_optimizers=False, plugins=bf16_mixed()),
    )

    model, lora = _load_base_model_and_lora(lora_checkpoint_path)
    _setup_trainer_and_restore_model_and_adapter(Path(lora_checkpoint_path), trainer, model, lora)

    lora_merge = LoRAMerge()
    merged_model = lora_merge(trainer.strategy.megatron_parallel)
    merged_weights = {k: v for k, v in merged_model.sharded_state_dict().items() if ".adapter." not in k}
    _save_merged_weight(output_path, merged_weights, model, trainer)


def _load_base_model_and_lora(lora_checkpoint_path: Path) -> Tuple[pl.LightningModule, LoRA]:
    model = io.load_context(ckpt_to_context_subdir(lora_checkpoint_path), "model")
    model.model_transform, model.__io__.model_transform = None, None
    model.config.bf16 = False
    lora: Union[io.TrainerContext, LoRA] = io.load_context(
        ckpt_to_context_subdir(lora_checkpoint_path), "model.model_transform"
    )
    assert isinstance(lora, LoRA), "LoRA config not found in checkpoint"
    return model, lora


def _setup_trainer_and_restore_model_and_adapter(
    lora_checkpoint_path: Path, trainer: Trainer, model: pl.LightningModule, lora: LoRA
) -> None:
    if (
        adapter_meta_path := ckpt_to_weights_subdir(lora_checkpoint_path, is_saving=False) / ADAPTER_META_FILENAME
    ).exists():
        with open(adapter_meta_path, "r") as f:
            metadata = json.load(f)
        restore_config = RestoreConfig(
            path=metadata["model_ckpt_path"],
            load_model_state=True,
            load_optim_state=False,
        )
    else:
        raise ValueError(f"Cannot find adapter meta file in {lora_checkpoint_path}")

    trainer.strategy.restore_config = restore_config
    trainer.strategy._setup_optimizers = False
    trainer.ckpt_path = None
    trainer.strategy.connect(model)
    trainer.strategy.setup_environment()

    if not model.state_dict():
        with _strategy_lib.megatron_cpu_init_context(model.config):
            model.configure_model()

    trainer.strategy.setup(trainer)  # load base model ckpt
    trainer.state.fn = TrainerFn.TESTING
    trainer.strategy.setup_megatron_parallel(trainer=trainer)
    trainer.strategy.trainer = trainer
    model.trainer = trainer

    lora(model)
    adapter_sharded_state_dict = {
        k: v for k, v in trainer.strategy.megatron_parallel.sharded_state_dict().items() if ".adapter." in k
    }
    adapter_state = trainer.strategy.checkpoint_io.load_checkpoint(
        ckpt_to_weights_subdir(lora_checkpoint_path, is_saving=False), sharded_state_dict=adapter_sharded_state_dict
    )
    trainer.strategy.load_model_state_dict(adapter_state, strict=False)


def _save_merged_weight(output_path: str, merged_weights: dict, model: pl.LightningModule, trainer: Trainer):
    weight_path = ckpt_to_weights_subdir(output_path, is_saving=True)
    Path(weight_path).mkdir(parents=True, exist_ok=True)
    dist_checkpointing.save(merged_weights, str(ckpt_to_weights_subdir(output_path, is_saving=True)))
    if hasattr(model.tokenizer, "save_pretrained"):
        model.tokenizer.save_pretrained("/tmp/nemo_tokenizer")
        model.tokenizer = AutoTokenizer("/tmp/nemo_tokenizer")
    if hasattr(trainer.model, "__io__") and hasattr(trainer.model.tokenizer, '__io__'):
        trainer.model.__io__.tokenizer = trainer.model.tokenizer.__io__
    TrainerContext.from_trainer(trainer).io_dump(ckpt_to_context_subdir(output_path), yaml_attrs=["model"])
    logging.info(f"Merged checkpoint saved to {output_path}")


__all__ = ["gpt_lora", "merge_lora"]
40 changes: 40 additions & 0 deletions nemo/collections/llm/peft/lora.py
@@ -227,3 +227,43 @@ def wildcard_match(pattern, key):
            )
            return AdapterParallelAdd(m, adapter)
        return m


class LoRAMerge(PEFT):
    """
    Implements the LoRA weight merge for parameter-efficient fine-tuning.

    Example:
    --------
        >>> from nemo.collections.llm.peft.lora import LoRAMerge
        >>> lora_merge = LoRAMerge()
        >>> merged_model = lora_merge(trainer.strategy.megatron_parallel)
    """

    @torch.no_grad()
    def transform(self, m: nn.Module, name=None, prefix=None):
        """
        Merges the LoRA adapter with the base model weights.

        Args:
            m (nn.Module): The module to apply LoRA merge to.
            name (str, optional): Name of the module to merge. Defaults to None.
            prefix (str, optional): Prefix for the module name. Defaults to None.

        Returns:
            nn.Module: The modified module with the LoRA adapter merged into the base model weights.
        """

        if not isinstance(m, AdapterParallelAdd):
            return m
        logging.info(f'merging {(prefix if prefix else "") + "." + (name if name else "")}')
        base_weight = m.to_wrap.weight
        lora_weight = (
            m.adapter.alpha
            / m.adapter.dim
            * m.adapter.linear_out.weight.to(base_weight.device)
            @ m.adapter.linear_in.weight.to(base_weight.device)
        )
        merged_weight = base_weight + lora_weight
        m.to_wrap.weight.data = merged_weight
        return m
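
To make the merge arithmetic in `LoRAMerge.transform` concrete, here is a minimal, self-contained sketch on plain tensors; the shapes and random values are illustrative only, whereas the real code operates on Megatron parallel linear modules wrapped in `AdapterParallelAdd`.

```python
import torch

torch.manual_seed(0)
d_out, d_in, rank, alpha = 16, 32, 4, 32

base_weight = torch.randn(d_out, d_in)  # frozen base weight W
lora_in = torch.randn(rank, d_in)       # adapter.linear_in weight (A)
lora_out = torch.randn(d_out, rank)     # adapter.linear_out weight (B)

# Same formula as LoRAMerge.transform: W' = W + (alpha / r) * B @ A
merged = base_weight + (alpha / rank) * lora_out @ lora_in
assert merged.shape == base_weight.shape
```

After the merge, a single linear layer with the merged weight produces the same outputs as base-plus-adapter at inference time, without the extra adapter matmuls.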
42 changes: 42 additions & 0 deletions tests/collections/llm/peft/lora_merge.py
@@ -0,0 +1,42 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from dataclasses import dataclass

from nemo.collections import llm


@dataclass
class Llama3ConfigCI(llm.Llama3Config8B):
    seq_length: int = 2048
    num_layers: int = 2
    hidden_size: int = 768
    ffn_hidden_size: int = 3072
    num_attention_heads: int = 8


def get_args():
    parser = argparse.ArgumentParser(description='Merge LoRA weights with base LLM')
    parser.add_argument('--lora_checkpoint_path', type=str, help="Path to finetuned LORA checkpoint")
    parser.add_argument('--output_path', type=str, help="Path to save merged checkpoint")
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()

    llm.peft.merge_lora(
        lora_checkpoint_path=args.lora_checkpoint_path,
        output_path=args.output_path,
    )
