low_cpu_fsdp for rank0-only problem
ddlBoJack committed Dec 5, 2023
1 parent c3fc827 commit d5af718
Showing 1 changed file with 13 additions and 27 deletions.
40 changes: 13 additions & 27 deletions src/llama_recipes/models/slam_model.py
@@ -79,11 +79,11 @@ def setup_llm(train_config, model_config, **kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
+        # v = packaging.version.parse(torch.__version__)
+        # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        # if not verify_latest_nightly:
+        #     raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+        #                     "please install latest nightly.")
         rank = int(os.environ["RANK"])
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
@@ -95,30 +95,16 @@ def setup_llm(train_config, model_config, **kwargs):
         else:
             llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
             llama_config.use_cache = use_cache
-            with torch.device("meta"):
-                model = LlamaForCausalLM(llama_config)
+            # with torch.device("meta"):
+            model = LlamaForCausalLM(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta`

     else:
-        if train_config.enable_fsdp:
-            rank = int(os.environ["RANK"])
-            if rank == 0:
-                model = LlamaForCausalLM.from_pretrained(
-                    model_config.llm_path,
-                    load_in_8bit=True if train_config.quantization else None,
-                    device_map="auto" if train_config.quantization else None,
-                    use_cache=use_cache,
-                )
-            else:
-                llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
-                llama_config.use_cache = use_cache
-                model = LlamaForCausalLM(llama_config)
-        else:
-            model = LlamaForCausalLM.from_pretrained(
-                model_config.llm_path,
-                load_in_8bit=True if train_config.quantization else None,
-                device_map="auto" if train_config.quantization else None,
-                use_cache=use_cache,
-            )
+        model = LlamaForCausalLM.from_pretrained(
+            model_config.llm_path,
+            load_in_8bit=True if train_config.quantization else None,
+            device_map="auto" if train_config.quantization else None,
+            use_cache=use_cache,
+        )
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
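Note on the first hunk: the disabled guard comes from the upstream llama-recipes recipe and only passes on a PyTorch nightly build, so it always raises on a stable release such as torch 2.0.1. A minimal illustration of why, with the version strings assumed for the example rather than taken from this repository:

# Illustration only: how the disabled guard evaluates on stable vs. nightly PyTorch.
# Substitute torch.__version__ for the literal strings in real code.
from packaging import version

stable = version.parse("2.0.1")
print(stable.is_devrelease)   # False -> the original check would raise
print(stable.dev)             # None  -> no ".devYYYYMMDD" segment on a stable release

nightly = version.parse("2.1.0.dev20230701")
print(nightly.is_devrelease)  # True
print(nightly.dev)            # 20230701 -> satisfies the >= 20230701 comparison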
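Note on the second hunk: with the meta-device path removed, ranks other than 0 build a randomly initialized CPU copy of the model from LlamaConfig, while the pretrained weights still live only on rank 0. This rank0-only pattern relies on the later FSDP wrap to broadcast rank 0's weights to the other ranks. A sketch of that downstream step, assuming the usual sync_module_states usage; wrap_llm_with_fsdp is a hypothetical helper, not part of this file:

# Sketch (assumption): how a model loaded rank0-only is typically handed to FSDP.
# sync_module_states=True broadcasts rank 0's parameters and buffers at wrap
# time, overwriting the randomly initialized copies on the other ranks.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_llm_with_fsdp(model, local_rank):
    # device_id moves shards onto the local GPU as they are materialized.
    return FSDP(
        model,
        device_id=torch.device("cuda", local_rank),
        sync_module_states=True,
    )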
