From 2bea564dae4e086e7708f4d3f2b5118a38f02c7d Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Thu, 20 Jun 2024 15:47:43 +0000
Subject: [PATCH 1/5] switch instructlab dolomite

Signed-off-by: Yu Chin Fabian Lim
---
 requirements.txt                    | 2 +-
 src/instructlab/training/main_ds.py | 8 +++-----
 src/instructlab/training/utils.py   | 2 +-
 3 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 1bbec7e4..fb84950e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,7 +9,7 @@ datasets>=2.15.0
 numba
 numpy
 rich
-dolomite-engine @ git+https://github.com/ibm-granite/dolomite-engine.git@main
+instructlab-dolomite @ git+https://github.com/instructlab/GPTDolomite.git@main
 trl>=0.9.4
 peft
 pydantic>=2.7.0
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index eeb0c077..4d18ba7d 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -43,6 +43,9 @@
     setup_logger,
 )
 import instructlab.training.data_process as dp
+from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
+from instructlab.dolomite.enums import GradientCheckpointingMethod
+from instructlab.dolomite.gradient_checkpointing import apply_gradient_checkpointing
 
 
 def get_ds_config(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
@@ -88,8 +91,6 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
     )
 
     if args.is_granite:
-        # Third Party
-        from dolomite_engine.hf_models.models import GPTDolomiteForCausalLM
 
         model = GPTDolomiteForCausalLM.from_pretrained(
             args.model_name_or_path,
@@ -201,9 +202,6 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
     # granite gradient checkpointing is handled uniformly
     # for both lora and full here
     if args.is_granite:
-        # Third Party
-        from dolomite_engine.enums import GradientCheckpointingMethod
-        from dolomite_engine.gradient_checkpointing import apply_gradient_checkpointing
 
         block_name = model._no_split_modules[0]
         apply_gradient_checkpointing(
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 6feaa548..a52fe9a0 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -13,6 +13,7 @@
 import warnings
 
 # Third Party
+from instructlab.dolomite.hf_models import export_to_huggingface
 from rich.logging import RichHandler
 from torch import distributed as dist
 from torch.distributed import get_rank, is_initialized
@@ -539,7 +540,6 @@ def save_hf_format_ds(args, model, tokenizer, samples_seen, convert_granite=True
         from tempfile import TemporaryDirectory
 
         # Third Party
-        from dolomite_engine.hf_models import export_to_huggingface
         from safetensors.torch import save_file
 
         with TemporaryDirectory("w") as tmpdir:

From 8a2f5ce13c379e510877a9c44ab899a0330b25c7 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Fri, 21 Jun 2024 15:38:13 +0000
Subject: [PATCH 2/5] fix fmt

Signed-off-by: Yu Chin Fabian Lim
---
 src/instructlab/training/main_ds.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index 4d18ba7d..cb216ace 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -11,6 +11,9 @@
 # Third Party
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 from deepspeed.runtime.zero.utils import ZeRORuntimeException
+from instructlab.dolomite.enums import GradientCheckpointingMethod
+from instructlab.dolomite.gradient_checkpointing import apply_gradient_checkpointing
+from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
 from torch.distributed import ReduceOp, all_reduce
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, get_scheduler
@@ -43,9 +46,6 @@
     setup_logger,
 )
 import instructlab.training.data_process as dp
-from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
-from instructlab.dolomite.enums import GradientCheckpointingMethod
-from instructlab.dolomite.gradient_checkpointing import apply_gradient_checkpointing
 
 
 def get_ds_config(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions):
@@ -91,7 +91,6 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
     )
 
     if args.is_granite:
-
         model = GPTDolomiteForCausalLM.from_pretrained(
             args.model_name_or_path,
             attn_implementation="flash_attention_2",
@@ -202,7 +201,6 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
     # granite gradient checkpointing is handled uniformly
     # for both lora and full here
     if args.is_granite:
-
         block_name = model._no_split_modules[0]
         apply_gradient_checkpointing(
             model,

From 34ab9a24d9c951a9fd8acacd91c137402f9fe575 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Fri, 21 Jun 2024 15:40:16 +0000
Subject: [PATCH 3/5] ignore duplicate code check

Signed-off-by: Yu Chin Fabian Lim
---
 .pylintrc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.pylintrc b/.pylintrc
index 693a6bc1..77366973 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -471,7 +471,8 @@ disable=raw-checker-failed,
         dangerous-default-value,
         consider-using-generator,
         broad-exception-caught,
-        super-init-not-called
+        super-init-not-called,
+        duplicate-code
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option

From 5c5fc82df17915387a9ed5e97c38e52d8b8e089f Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Sat, 22 Jun 2024 07:34:27 +0000
Subject: [PATCH 4/5] moved granite checkpointing

Signed-off-by: Yu Chin Fabian Lim
---
 src/instructlab/training/main_ds.py |  4 +-
 src/instructlab/training/utils.py   | 60 +++++++++++++++++++++++++++++
 2 files changed, 61 insertions(+), 3 deletions(-)

diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index cb216ace..ffb9286f 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -11,8 +11,6 @@
 # Third Party
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
 from deepspeed.runtime.zero.utils import ZeRORuntimeException
-from instructlab.dolomite.enums import GradientCheckpointingMethod
-from instructlab.dolomite.gradient_checkpointing import apply_gradient_checkpointing
 from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
 from torch.distributed import ReduceOp, all_reduce
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, get_scheduler
@@ -36,6 +34,7 @@
 from instructlab.training.utils import (
     StreamablePopen,
     add_noisy_embeddings,
+    apply_gradient_checkpointing,
     convert_loss_to_reduce_sum,
     patch_target_module,
     prepare_peft_model,
@@ -204,7 +203,6 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
         block_name = model._no_split_modules[0]
         apply_gradient_checkpointing(
             model,
-            GradientCheckpointingMethod.block,
             block_name=block_name,
             use_reentrant=True,  # this should be the HF default mode
         )
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index a52fe9a0..43a5c4ae 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -1,4 +1,5 @@
 # Standard
+from functools import partial
 from pathlib import Path
 from typing import Any, List, Optional
 import importlib
@@ -17,6 +18,11 @@
 from rich.logging import RichHandler
 from torch import distributed as dist
 from torch.distributed import get_rank, is_initialized
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    CheckpointImpl,
+    apply_activation_checkpointing,
+    checkpoint_wrapper,
+)
 from torch.distributed.fsdp import FullStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import StateDictType
@@ -454,6 +460,60 @@ class UniversalCheckpointArgs:
     log_rank_0(f"Preparing universal checkpoint took {time.time() - start} seconds")
 
 
+# this function is for supporting gradient checkpointing for padding free
+# dolomite
+def apply_gradient_checkpointing(
+    model: torch.nn.Module,
+    **kwargs,
+) -> None:
+    def get_module_class_from_name(
+        model: torch.nn.Module, name: str
+    ) -> List[torch.nn.Module]:
+        modules_children = list(model.children())
+
+        if model.__class__.__name__ == name:
+            return model.__class__
+        elif len(modules_children) == 0:
+            return
+        else:
+            for child_module in modules_children:
+                module_class = get_module_class_from_name(child_module, name)
+                if module_class is not None:
+                    return module_class
+
+    def block_checkpointing(
+        model: torch.nn.Module,
+        block_name: str,
+        checkpoint_every: int = 1,
+        use_reentrant: bool = False,
+    ) -> None:
+        block_class = get_module_class_from_name(model, block_name)
+        block_idx = 0
+
+        def _whether_to_checkpoint(submodule: torch.nn.Module) -> bool:
+            nonlocal block_idx
+
+            if isinstance(submodule, block_class):
+                block_idx += 1
+                if (block_idx - 1) % checkpoint_every == 0:
+                    return True
+            return False
+
+        checkpoint_wrapper_function = checkpoint_wrapper
+        if use_reentrant:
+            checkpoint_wrapper_function = partial(
+                checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
+            )
+
+        apply_activation_checkpointing(
+            model,
+            checkpoint_wrapper_fn=checkpoint_wrapper_function,
+            check_fn=_whether_to_checkpoint,
+        )
+
+    block_checkpointing(model, **kwargs)
+
+
 def setup_logger(level="DEBUG"):
     logging.basicConfig(
         level=level, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]

From f07bee0a3bca4da463572e44198ef2ff75c204c8 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Sun, 23 Jun 2024 02:17:56 +0000
Subject: [PATCH 5/5] switch to pypi repository

Signed-off-by: Yu Chin Fabian Lim
---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index fb84950e..97ae805d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,7 +9,7 @@ datasets>=2.15.0
 numba
 numpy
 rich
-instructlab-dolomite @ git+https://github.com/instructlab/GPTDolomite.git@main
+instructlab-dolomite
 trl>=0.9.4
 peft
 pydantic>=2.7.0
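
Reviewer note (not part of the patch series): the sketch below shows how the apply_gradient_checkpointing helper that PATCH 4/5 moves into src/instructlab/training/utils.py is expected to be called, mirroring the call site in main_ds.py. Only the helper's import path, its signature, and the block_name/use_reentrant call pattern come from the patches themselves; the toy modules, block name, and tensor shapes are hypothetical stand-ins for illustration.

    # Illustrative sketch only: assumes torch and instructlab-training (with this
    # patch series applied) are installed. ToyBlock/ToyModel are hypothetical
    # stand-ins, not project code.
    import torch
    from torch import nn

    from instructlab.training.utils import apply_gradient_checkpointing


    class ToyBlock(nn.Module):
        """Stand-in for a transformer block of the kind _no_split_modules names."""

        def __init__(self, dim: int = 16):
            super().__init__()
            self.linear = nn.Linear(dim, dim)

        def forward(self, x):
            return torch.relu(self.linear(x))


    class ToyModel(nn.Module):
        """Stand-in for GPTDolomiteForCausalLM, which exposes _no_split_modules."""

        _no_split_modules = ["ToyBlock"]

        def __init__(self, n_blocks: int = 4):
            super().__init__()
            self.blocks = nn.Sequential(*[ToyBlock() for _ in range(n_blocks)])

        def forward(self, x):
            return self.blocks(x)


    model = ToyModel()

    # Same call shape as the updated main_ds.py: every submodule whose class name
    # matches block_name is wrapped in torch's activation-checkpointing wrapper.
    apply_gradient_checkpointing(
        model,
        block_name=model._no_split_modules[0],
        use_reentrant=True,  # the patch comment notes this is the HF default mode
    )

    # Forward/backward still work; activations inside each ToyBlock are recomputed
    # during the backward pass instead of being stored.
    x = torch.randn(2, 16, requires_grad=True)
    model(x).sum().backward()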