Commit 2bcb3c8
Merge pull request #7 from ddlBoJack/dev-mzy
low_cpu_fsdp for rank0-only problem
ddlBoJack authored Dec 5, 2023
2 parents 023a797 + d5af718 commit 2bcb3c8
Showing 2 changed files with 14 additions and 28 deletions.
2 changes: 1 addition & 1 deletion scripts/finetune_echat.sh
@@ -71,7 +71,7 @@ src/llama_recipes/pipeline/finetune.py \
 --freeze_encoder \
 --freeze_llm \
 --use_fp16 \
---enable_fsdp \
+--enable_fsdp --low_cpu_fsdp \
 --llm_name llama-2-7b-hf \
 --llm_path $llm_path \
 --encoder_name whisper \
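
The new --low_cpu_fsdp flag drives the loading change in slam_model.py below: when it is set together with --enable_fsdp, only rank 0 loads the pretrained checkpoint, and the other ranks build the model from its config (see the sketch after the second diff). A minimal sketch of the gate itself, assuming train_config carries boolean fields named after the CLI flags, as in llama-recipes-style configs:

    from transformers import LlamaForCausalLM

    # Sketch of the gate inside setup_llm; the config field names mirror the
    # CLI flags and match the diff below.
    def load_llm(train_config, model_config):
        if train_config.enable_fsdp and train_config.low_cpu_fsdp:
            return load_on_rank0_only(model_config)  # hypothetical helper; see the sketch below
        return LlamaForCausalLM.from_pretrained(model_config.llm_path)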
40 changes: 13 additions & 27 deletions src/llama_recipes/models/slam_model.py
@@ -79,11 +79,11 @@ def setup_llm(train_config, model_config, **kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
+        # v = packaging.version.parse(torch.__version__)
+        # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        # if not verify_latest_nightly:
+        #     raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+        #                     "please install latest nightly.")
         rank = int(os.environ["RANK"])
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
@@ -95,30 +95,16 @@ def setup_llm(train_config, model_config, **kwargs):
         else:
             llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
             llama_config.use_cache = use_cache
-            with torch.device("meta"):
-                model = LlamaForCausalLM(llama_config)
+            # with torch.device("meta"):
+            model = LlamaForCausalLM(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta`
 
     else:
-        if train_config.enable_fsdp:
-            rank = int(os.environ["RANK"])
-            if rank == 0:
-                model = LlamaForCausalLM.from_pretrained(
-                    model_config.llm_path,
-                    load_in_8bit=True if train_config.quantization else None,
-                    device_map="auto" if train_config.quantization else None,
-                    use_cache=use_cache,
-                )
-            else:
-                llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
-                llama_config.use_cache = use_cache
-                model = LlamaForCausalLM(llama_config)
-        else:
-            model = LlamaForCausalLM.from_pretrained(
-                model_config.llm_path,
-                load_in_8bit=True if train_config.quantization else None,
-                device_map="auto" if train_config.quantization else None,
-                use_cache=use_cache,
-            )
+        model = LlamaForCausalLM.from_pretrained(
+            model_config.llm_path,
+            load_in_8bit=True if train_config.quantization else None,
+            device_map="auto" if train_config.quantization else None,
+            use_cache=use_cache,
+        )
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
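
The rank0-only pattern above relies on the FSDP wrap that happens later in the pipeline: with sync_module_states=True, FSDP broadcasts rank 0's loaded parameters to the config-initialized replicas on the other ranks, so only rank 0 ever reads the checkpoint. Note that because this commit drops the torch.device("meta") context for torch 2.0.1 compatibility, the non-zero ranks still allocate full-size, randomly initialized CPU tensors; the saving is in checkpoint loading, not allocation. A minimal, self-contained sketch of the pattern, assuming a torchrun launch; "path/to/llama-2-7b-hf" stands in for the configured model path:

    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers import LlamaConfig, LlamaForCausalLM

    # Assumes torchrun has set RANK and LOCAL_RANK.
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    rank = int(os.environ["RANK"])

    if rank == 0:
        # Only rank 0 pays the cost of reading the checkpoint into CPU memory.
        model = LlamaForCausalLM.from_pretrained("path/to/llama-2-7b-hf", use_cache=False)
    else:
        # Other ranks build the architecture only; weights stay random until
        # FSDP broadcasts rank 0's state while wrapping.
        llama_config = LlamaConfig.from_pretrained("path/to/llama-2-7b-hf")
        llama_config.use_cache = False
        model = LlamaForCausalLM(llama_config)

    model = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,  # broadcast rank 0's weights to all ranks
    )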
