optimize qwen2 memory (#11535)
MeouSker77 authored Jul 9, 2024
1 parent 2929eb2 · commit 99b2802
Showing 1 changed file with 3 additions and 13 deletions.
python/llm/src/ipex_llm/transformers/models/qwen2.py (16 changes: 3 additions & 13 deletions)
@@ -37,6 +37,7 @@
 # limitations under the License.
 #
 
+import os
 import math
 from typing import Optional, Tuple, Union, List
 
@@ -55,7 +56,7 @@
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers import logging
 

@@ -339,20 +340,9 @@ def merge_qkv(module: torch.nn.Module):
 
         del module.q_proj, module.k_proj, module.v_proj
 
-        # Qwen2 uses pre-computed rope table to accelerate rope,
-        # original `cos_cached` and `sin_cached` are added by `register_buffer`,
-        # so they will move to xpu during `model.to('xpu')`.
-        # But gpu fuse kernel doesn't need this rope table, only cpu needs them,
-        # so delete them then add them with `=`, so that they will be pinned on CPU,
-        # this can save about 0.5GB gpu memory usage when running Qwen2
-        if hasattr(module.rotary_emb, "cos_cached"):
-            cos_cached = module.rotary_emb.cos_cached
+        if os.environ.get("IPEX_LLM_LOW_MEM", None) == "1":
             del module.rotary_emb.cos_cached
-            module.rotary_emb.cos_cached = cos_cached
-        if hasattr(module.rotary_emb, "sin_cached"):
-            sin_cached = module.rotary_emb.sin_cached
             del module.rotary_emb.sin_cached
-            module.rotary_emb.sin_cached = sin_cached
 
 
 def padding_mlp(module: torch.nn.Module):
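The comment removed here documented the trick the old code relied on: `cos_cached` and `sin_cached` are registered as buffers, so `model.to('xpu')` copies them to GPU memory even though the fused GPU kernel never reads them; deleting the buffer and re-assigning the tensor with plain `=` turns it into an ordinary attribute that `Module.to(...)` ignores, keeping the rope table in host memory (about 0.5 GB of GPU memory saved for Qwen2). A minimal standalone sketch of that behaviour in plain PyTorch (not ipex-llm code; the class name and tensor shape are made up):

    import torch

    class Rope(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # a registered buffer is moved by Module.to(...)
            self.register_buffer("cos_cached", torch.randn(4096, 128), persistent=False)

    rope = Rope()
    cos = rope.cos_cached
    del rope.cos_cached    # remove the registered buffer
    rope.cos_cached = cos  # plain attribute now; Module.to(...) no longer touches it
    # after this, rope.to("xpu") would leave cos_cached in host memory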

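After this change the rope tables are no longer re-pinned unconditionally; they are deleted instead, and only when the user opts in through the IPEX_LLM_LOW_MEM environment variable. A hypothetical usage sketch, assuming the usual ipex-llm loading flow (the model id and load options are illustrative, not taken from this commit):

    import os
    os.environ["IPEX_LLM_LOW_MEM"] = "1"  # set before the model is loaded/optimized

    from ipex_llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-7B-Instruct",  # any Qwen2 checkpoint supported by ipex-llm
        load_in_4bit=True,
        trust_remote_code=True,
    )
    model = model.to("xpu")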