optimize llama 3.2 rope (#12128)
MeouSker77 authored Sep 26, 2024
1 parent 584c348 commit a266528
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/llama32.py
@@ -48,6 +48,7 @@
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache

@@ -111,6 +112,12 @@ def llama_model_forward(
     # create position embeddings to be shared across the decoder layers
     position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
+    # IPEX-LLM OPT start: use fused rope
+    if (should_use_fuse_rope(hidden_states, position_ids, False)
+            and self.rotary_emb.rope_type == "llama3"):
+        position_embeddings = self.rotary_emb.inv_freq
+    # IPEX-LLM OPT end
+
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
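
For reference: the branch above swaps the shared (cos, sin) tuple for the rotary embedding's raw inv_freq tensor whenever the fused-rope path is usable. That is sufficient because, under the standard Hugging Face RoPE convention, the (cos, sin) tables are a pure function of inv_freq and position_ids, and the fused kernel receives both. A minimal sketch of that relationship, not IPEX-LLM code, assuming the usual HF shapes and ignoring any rope-type attention scaling:

import torch

def cos_sin_from_inv_freq(inv_freq: torch.Tensor, position_ids: torch.Tensor):
    # inv_freq: [head_dim // 2], position_ids: [batch, seq_len]
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]  # [batch, seq_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)                             # [batch, seq_len, head_dim]
    return emb.cos(), emb.sin()                                         # what the unfused path consumes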
@@ -179,11 +186,16 @@ def llama_attention_forward(
                                                         self.num_key_value_heads,
                                                         self.num_key_value_heads], dim=1)
 
-    if position_embeddings is None:
-        cos, sin = self.rotary_emb(value_states, position_ids)
+    if isinstance(position_embeddings, torch.Tensor):
+        import xe_addons
+        inv_freq = position_embeddings
+        xe_addons.rotary_half_inplaced(inv_freq, position_ids, query_states, key_states)
     else:
-        cos, sin = position_embeddings
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        if position_embeddings is None:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+        else:
+            cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
         key_states, value_states = past_key_value.update(key_states, value_states,

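For context, the work routed to xe_addons.rotary_half_inplaced is the standard "rotate half" RoPE application that transformers' apply_rotary_pos_emb performs out of place on the query and key states; the fused XPU kernel presumably computes the same rotation in place in one call. A compact reference sketch, assuming the usual [batch, num_heads, seq_len, head_dim] layout:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_reference(q, k, cos, sin, unsqueeze_dim=1):
    # q, k: [batch, num_heads, seq_len, head_dim]; cos, sin: [batch, seq_len, head_dim]
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed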