From c3fc8273c6ecc5ed19b554e8e5c91af45f6d5ebd Mon Sep 17 00:00:00 2001
From: ddlBoJack
Date: Tue, 5 Dec 2023 11:16:49 +0800
Subject: [PATCH 1/2] low_cpu_fsdp for rank0-only problem

---
 scripts/finetune_echat.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/finetune_echat.sh b/scripts/finetune_echat.sh
index 978aa22c..60444499 100644
--- a/scripts/finetune_echat.sh
+++ b/scripts/finetune_echat.sh
@@ -71,7 +71,7 @@ src/llama_recipes/pipeline/finetune.py \
 --freeze_encoder \
 --freeze_llm \
 --use_fp16 \
---enable_fsdp \
+--enable_fsdp --low_cpu_fsdp \
 --llm_name llama-2-7b-hf \
 --llm_path $llm_path \
 --encoder_name whisper \

From d5af718960f5490f0c64af9b68a5a6ec3eb2f703 Mon Sep 17 00:00:00 2001
From: ddlBoJack
Date: Tue, 5 Dec 2023 11:17:47 +0800
Subject: [PATCH 2/2] low_cpu_fsdp for rank0-only problem

---
 src/llama_recipes/models/slam_model.py | 40 +++++++++-----------------
 1 file changed, 13 insertions(+), 27 deletions(-)

diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py
index dc1ee149..11fb2db2 100644
--- a/src/llama_recipes/models/slam_model.py
+++ b/src/llama_recipes/models/slam_model.py
@@ -79,11 +79,11 @@ def setup_llm(train_config, model_config, **kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
+        # v = packaging.version.parse(torch.__version__)
+        # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        # if not verify_latest_nightly:
+        #     raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+        #                     "please install latest nightly.")
         rank = int(os.environ["RANK"])
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
@@ -95,30 +95,16 @@ def setup_llm(train_config, model_config, **kwargs):
         else:
             llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
             llama_config.use_cache = use_cache
-            with torch.device("meta"):
-                model = LlamaForCausalLM(llama_config)
+            # with torch.device("meta"):
+            model = LlamaForCausalLM(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta`
     else:
-        if train_config.enable_fsdp:
-            rank = int(os.environ["RANK"])
-            if rank == 0:
-                model = LlamaForCausalLM.from_pretrained(
-                    model_config.llm_path,
-                    load_in_8bit=True if train_config.quantization else None,
-                    device_map="auto" if train_config.quantization else None,
-                    use_cache=use_cache,
-                )
-            else:
-                llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
-                llama_config.use_cache = use_cache
-                model = LlamaForCausalLM(llama_config)
-        else:
-            model = LlamaForCausalLM.from_pretrained(
-                model_config.llm_path,
-                load_in_8bit=True if train_config.quantization else None,
-                device_map="auto" if train_config.quantization else None,
-                use_cache=use_cache,
-            )
+        model = LlamaForCausalLM.from_pretrained(
+            model_config.llm_path,
+            load_in_8bit=True if train_config.quantization else None,
+            device_map="auto" if train_config.quantization else None,
+            use_cache=use_cache,
+        )
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
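
Note: the rank0-only load in patch 2 relies on FSDP broadcasting rank 0's weights to the
other ranks when the model is wrapped. The sketch below is illustrative only and is not
part of the patches; it assumes `model` was built as in patch 2 (pretrained weights on
rank 0, config-initialized weights on the other ranks), and the actual wrapping call in
the training pipeline may differ.

    import os
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    rank = int(os.environ["RANK"])

    model = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        # Broadcast rank 0's pretrained parameters and buffers to every other rank,
        # overwriting the config-initialized copies created on non-zero ranks.
        sync_module_states=True,
        # Only needed when non-zero ranks built the model on the meta device; with the
        # meta path commented out (torch 2.0.1), this can simply be left as None.
        param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda")))
        if rank != 0 else None,
    )

With the meta-device construction disabled, each non-zero rank still allocates a full CPU
copy of the weights while the model object is built, so the memory saving is smaller than
with `meta`, but sync_module_states still ensures all ranks end up with rank 0's
pretrained weights.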