From 5cb86bb61c3acfc45d1801f58c8f5b39d81d3259 Mon Sep 17 00:00:00 2001 From: lalalapotter Date: Wed, 10 Jul 2024 09:47:03 +0800 Subject: [PATCH 1/2] LLM: unify memory optimization env variables. --- python/llm/src/ipex_llm/transformers/convert.py | 10 +++++++--- python/llm/src/ipex_llm/transformers/models/llama.py | 2 ++ python/llm/src/ipex_llm/transformers/models/mistral.py | 2 ++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index c419eebf65b..31405577240 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -327,9 +327,13 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, optimize_lm_head = False if is_lm_head(name, model_config, out_features): model_type = getattr(model_config, "model_type", None) - if model_type in ["gptj", "llama", "qwen2"] and \ - os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1": - optimize_lm_head = True + if model_type in ["gptj", "llama", "qwen2"]: + if os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) is not None: + optimize_lm_head = os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1" + elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None: + optimize_lm_head = os.environ.get("IPEX_LLM_LOW_MEM", None) == "1" + else: + optimize_lm_head = False with init_empty_weights(): new_linear = None is_gptq = is_gptq_linear(module) diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py index ac81c5a4dc7..11425b3939e 100644 --- a/python/llm/src/ipex_llm/transformers/models/llama.py +++ b/python/llm/src/ipex_llm/transformers/models/llama.py @@ -286,6 +286,8 @@ def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, out if not output_attentions: if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None: return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1" + elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None: + return os.environ.get("IPEX_LLM_LOW_MEM", None) == "1" elif query_states.dtype == torch.float16 and \ query_states.shape[2] >= 6800: # split tensor for memory block limitation diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py index 3f282295873..1891f982395 100644 --- a/python/llm/src/ipex_llm/transformers/models/mistral.py +++ b/python/llm/src/ipex_llm/transformers/models/mistral.py @@ -92,6 +92,8 @@ def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, out if not output_attentions: if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None: return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1" + elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None: + return os.environ.get("IPEX_LLM_LOW_MEM", None) == "1" elif query_states.dtype == torch.float16 and \ query_states.shape[2] >= 6300: # split tensor for memory block limitation From 256a4e1fedd18c3e7af04392f4bde80bc0bc04cf Mon Sep 17 00:00:00 2001 From: lalalapotter Date: Thu, 11 Jul 2024 10:52:45 +0800 Subject: [PATCH 2/2] fix comments. --- python/llm/src/ipex_llm/transformers/convert.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 31405577240..42a9acfd7a2 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -332,8 +332,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, optimize_lm_head = os.environ.get("IPEX_LLM_LAST_LM_HEAD", None) == "1" elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None: optimize_lm_head = os.environ.get("IPEX_LLM_LOW_MEM", None) == "1" - else: - optimize_lm_head = False with init_empty_weights(): new_linear = None is_gptq = is_gptq_linear(module)