add sdp causal support in llama (#11705)
cyita authored Aug 2, 2024
1 parent 736a7ef commit 8d1e0bd
Showing 1 changed file with 14 additions and 0 deletions.
python/llm/src/ipex_llm/transformers/models/llama.py (14 additions, 0 deletions)
@@ -1497,6 +1497,13 @@ def llama_attention_forward_4_41_original(
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
@@ -2040,6 +2047,13 @@ def llama_attention_forward_4_38_original(
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import xe_addons
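For context, the added branch appears to route the causal-attention case to the fused xe_addons.sdp_causal kernel instead of the generic SDP path. Below is a minimal, hypothetical sketch of the equivalent computation in plain PyTorch; the tensor names and shapes are illustrative assumptions, not taken from the commit, and xe_addons (an Intel-XPU-specific extension) is not called here.

    # Illustrative stand-in for the causal SDP computation that the new
    # xe_addons.sdp_causal branch accelerates. Assumption: plain PyTorch
    # scaled_dot_product_attention with is_causal=True is a semantically
    # comparable reference for this sketch.
    import torch
    import torch.nn.functional as F

    batch, n_heads, q_len, head_dim = 1, 32, 16, 128  # illustrative sizes
    q = torch.randn(batch, n_heads, q_len, head_dim)
    k = torch.randn(batch, n_heads, q_len, head_dim)
    v = torch.randn(batch, n_heads, q_len, head_dim)

    # is_causal=True applies the lower-triangular mask, i.e. the "causal"
    # case that the commit adds a fused kernel for.
    attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(attn_output.shape)  # torch.Size([1, 32, 16, 128])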
