From 99b2802d3b8108b029e751ea1684e160dd750b49 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 9 Jul 2024 17:14:01 +0800 Subject: [PATCH] optimize qewn2 memory (#11535) --- .../src/ipex_llm/transformers/models/qwen2.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py index de679a2266e..2bd9626e8ab 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py @@ -37,6 +37,7 @@ # limitations under the License. # +import os import math from typing import Optional, Tuple, Union, List @@ -55,7 +56,7 @@ from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.cache_utils import Cache, DynamicCache +from transformers.cache_utils import Cache from transformers import logging @@ -339,20 +340,9 @@ def merge_qkv(module: torch.nn.Module): del module.q_proj, module.k_proj, module.v_proj - # Qwen2 uses pre-computed rope table to accelerate rope, - # original `cos_cached` and `sin_cached` are added by `register_buffer`, - # so they will move to xpu during `model.to('xpu')`. - # But gpu fuse kernel doesn't need this rope table, only cpu needs them, - # so delete them then add them with `=`, so that they will be pinned on CPU, - # this can save about 0.5GB gpu memory usage when running Qwen2 - if hasattr(module.rotary_emb, "cos_cached"): - cos_cached = module.rotary_emb.cos_cached + if os.environ.get("IPEX_LLM_LOW_MEM", None) == "1": del module.rotary_emb.cos_cached - module.rotary_emb.cos_cached = cos_cached - if hasattr(module.rotary_emb, "sin_cached"): - sin_cached = module.rotary_emb.sin_cached del module.rotary_emb.sin_cached - module.rotary_emb.sin_cached = sin_cached def padding_mlp(module: torch.nn.Module):