diff --git a/src/llama_recipes/models/slam_model.py b/src/llama_recipes/models/slam_model.py
index dc1ee149..11fb2db2 100644
--- a/src/llama_recipes/models/slam_model.py
+++ b/src/llama_recipes/models/slam_model.py
@@ -79,11 +79,11 @@ def setup_llm(train_config, model_config, **kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
+        # v = packaging.version.parse(torch.__version__)
+        # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        # if not verify_latest_nightly:
+        #     raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+        #                     "please install latest nightly.")
         rank = int(os.environ["RANK"])
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
@@ -95,30 +95,16 @@ def setup_llm(train_config, model_config, **kwargs):
         else:
             llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
             llama_config.use_cache = use_cache
-            with torch.device("meta"):
-                model = LlamaForCausalLM(llama_config)
+            # with torch.device("meta"):
+            model = LlamaForCausalLM(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta`
     else:
-        if train_config.enable_fsdp:
-            rank = int(os.environ["RANK"])
-            if rank == 0:
-                model = LlamaForCausalLM.from_pretrained(
-                    model_config.llm_path,
-                    load_in_8bit=True if train_config.quantization else None,
-                    device_map="auto" if train_config.quantization else None,
-                    use_cache=use_cache,
-                )
-            else:
-                llama_config = LlamaConfig.from_pretrained(model_config.llm_path)
-                llama_config.use_cache = use_cache
-                model = LlamaForCausalLM(llama_config)
-        else:
-            model = LlamaForCausalLM.from_pretrained(
-                model_config.llm_path,
-                load_in_8bit=True if train_config.quantization else None,
-                device_map="auto" if train_config.quantization else None,
-                use_cache=use_cache,
-            )
+        model = LlamaForCausalLM.from_pretrained(
+            model_config.llm_path,
+            load_in_8bit=True if train_config.quantization else None,
+            device_map="auto" if train_config.quantization else None,
+            use_cache=use_cache,
+        )
 
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
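
Reviewer note, not part of the patch: instead of commenting out the nightly check and the meta-device init entirely, both could be gated on the installed torch version, keeping the low-CPU-memory path for users whose torch does support it. A minimal sketch under assumptions the patch does not state: the `build_model` helper name and the "2.1 or newer" cutoff are hypothetical (the patch only says torch 2.0.1 lacks support), and `llama_config` / `LlamaForCausalLM` are the names already used in this file.

    # Sketch only: keep the memory-saving meta-device init when torch is new
    # enough, and fall back to a plain CPU init otherwise. The (2, 1) cutoff
    # is an assumption, not something the patch asserts.
    from packaging import version
    import torch
    from transformers import LlamaForCausalLM

    def build_model(llama_config):
        if version.parse(torch.__version__).release >= (2, 1):
            # Meta tensors allocate no real storage; FSDP can materialize
            # the weights later, so non-zero ranks stay cheap on CPU RAM.
            with torch.device("meta"):
                return LlamaForCausalLM(llama_config)
        # Older torch: ordinary init, which uses real memory on every rank.
        return LlamaForCausalLM(llama_config)

A runtime check like this degrades gracefully where the patch's commented-out check would either hard-fail (old behavior) or silently skip the optimization (new behavior).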