Use sdp when rest token seq_len > 1 in llama & mistral (for lookup & spec) (#10790)

* update sdp condition

* update

* fix

* update & test llama

* mistral

* fix style

* update

* fix style

* remove pvc constrain

* update ds on arc

* fix style
cyita authored Apr 24, 2024
1 parent 844e18b commit dc27b3b
Showing 3 changed files with 49 additions and 11 deletions.
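
In short, the commit lets the fused SDP (scaled dot-product attention) kernels handle a short run of "rest" tokens (q_len > 1 but shorter than the cached KV length), the shape produced by lookup and speculative decoding, instead of only the q_len == 1 case; the exact caps differ per kernel (the new fp16 helper limits the run to 32 tokens, the fp8 helper only requires q_len != k_len). A minimal, device-free sketch of that dispatch idea follows; pick_attention_path and max_rest are illustrative names, not identifiers from this commit.

# Illustrative sketch of the dispatch rule this commit moves to (not ipex_llm code):
# short "rest token" batches go to the fused SDP kernel, everything else falls
# back to the generic matmul attention path.
def pick_attention_path(q_len, kv_len, on_xpu, max_rest=32):
    if not on_xpu:
        return "matmul"      # the fused kernels are GPU (XPU) only
    if q_len == kv_len:
        return "matmul"      # prefill / first token: full-length query
    if q_len > max_rest:
        return "matmul"      # rest-token run too long for the fused kernel
    return "sdp"             # decoding with 1..max_rest new tokens

# e.g. speculative decoding verifying 4 draft tokens against a 512-token KV cache:
assert pick_attention_path(q_len=4, kv_len=512, on_xpu=True) == "sdp"
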
11 changes: 6 additions & 5 deletions python/llm/src/ipex_llm/transformers/models/llama.py
@@ -46,7 +46,8 @@
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
use_sdp_fp8
from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
from ipex_llm.transformers.models.utils import use_decoding_fast_path
from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -449,7 +450,7 @@ def llama_attention_forward_4_31_quantized(
kv_seq_len = key_states.shape[-2]
past_key_value = (key_states, value_states)

if query_states.size(2) != 1 or query_states.device.type != 'xpu':
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
# repeat k/v heads if n_kv_heads < n_heads
@@ -666,7 +667,7 @@ def llama_attention_forward_4_31_original(
is_causal=True)
attn_weights = None
elif not self.training and not hidden_states.requires_grad and \
use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask):
use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
@@ -1074,7 +1075,7 @@ def llama_attention_forward_4_36_quantized(
self.layer_idx, cache_kwargs,
new_layout=True)
kv_seq_len = key_states.shape[-2]
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
key_states = repeat_kv(key_states, self.num_key_value_groups)\
@@ -1342,7 +1343,7 @@ def llama_attention_forward_4_36_original(
is_causal=True)
attn_weights = None
elif not self.training and not hidden_states.requires_grad and \
use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
attn_output = attn_output.view(query_states.shape)
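
In both non-quantized llama paths above, the gate for calling linear_q4_0.sdp_fp16 is now use_new_esimd_sdp_fp16, so short rest-token batches (not just q_len == 1) reach the fused kernel. As a frame of reference, the computation that kernel accelerates is ordinary scaled dot-product attention; the snippet below is a stock-PyTorch stand-in (illustrative only, run in float32 so it works on CPU, and not the ESIMD kernel itself).

import torch
import torch.nn.functional as F

# Plain scaled dot-product attention over a short run of rest tokens,
# as a reference for what the fused sdp_fp16 kernel computes.
# Shapes follow the diff: [batch, num_heads, seq_len, head_dim].
def reference_sdp(query_states, key_states, value_states, attention_mask=None):
    return F.scaled_dot_product_attention(query_states, key_states, value_states,
                                          attn_mask=attention_mask)

q = torch.randn(1, 32, 4, 128)    # 4 rest tokens, head_dim = 128
k = torch.randn(1, 32, 512, 128)  # 512 cached tokens
v = torch.randn(1, 32, 512, 128)
print(reference_sdp(q, k, v).shape)  # torch.Size([1, 32, 4, 128])
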
11 changes: 6 additions & 5 deletions python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -52,7 +52,8 @@
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
is_enough_kv_cache_room_4_36
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
use_sdp_fp8
from ipex_llm.transformers.models.utils import use_decoding_fast_path
from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
@@ -310,7 +311,7 @@ def mistral_attention_forward_quantized(
kv_seq_len = key_states.shape[-2]
past_key_value = (key_states, value_states)

if query_states.size(2) != 1 or query_states.device.type != 'xpu':
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
@@ -503,7 +504,7 @@ def mistral_attention_forward_original(
attn_weights = None
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
# new fp16 sdp doesn't require repeat_kv
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
@@ -687,7 +688,7 @@ def mistral_attention_forward_4_36_quantized(
self.layer_idx, cache_kwargs,
new_layout=True)
kv_seq_len = key_states.shape[-2]
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
@@ -896,7 +897,7 @@ def mistral_attention_forward_4_36_original(
attn_weights = None
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
# new fp16 sdp doesn't require repeat_kv
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
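
The comment "new fp16 sdp doesn't require repeat_kv" in the mistral hunks refers to grouped-query attention: on the generic path, the key/value heads must first be tiled up to the number of query heads before the matmul, which the fused kernel avoids. A self-contained illustration of that expansion (modeled on the repeat_kv helper in transformers, written here from scratch; repeat_kv_demo is an illustrative name):

import torch

def repeat_kv_demo(hidden, n_rep):
    # Expand [batch, num_kv_heads, seq_len, head_dim] to
    # [batch, num_kv_heads * n_rep, seq_len, head_dim] by repeating each KV head.
    batch, num_kv_heads, seq_len, head_dim = hidden.shape
    if n_rep == 1:
        return hidden
    hidden = hidden[:, :, None, :, :].expand(batch, num_kv_heads, n_rep,
                                             seq_len, head_dim)
    return hidden.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

kv = torch.randn(1, 8, 512, 128)       # e.g. Mistral-7B: 8 KV heads
print(repeat_kv_demo(kv, 4).shape)     # torch.Size([1, 32, 512, 128]) -> 32 query heads
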
38 changes: 37 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/utils.py
@@ -325,7 +325,7 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
        # esimd_sdp only support head_dim = 128 now
        return False
    elif q_len != 1:
        # esimd_sdp only support rest token now
        # esimd_sdp only support rest token and q_len == 1 now
        return False
    elif k_len < 8:
        # esimd_sdp will cause wrong output when k_len < 8
@@ -363,6 +363,42 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
    return True


def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
    if query_states.device.type != "xpu":
        # esimd_sdp only support GPU now
        return False
    elif query_states.dtype != torch.float16:
        # esimd_sdp only has optimization for FP16 now
        return False
    elif head_dim != 128 and head_dim != 64:
        # esimd_sdp only support head_dim = 128 and 64 now
        return False
    elif q_len == k_len:
        # new sdp_fp16 only support rest token now
        return False
    elif q_len > 32:
        # Use new sdp_fp16 only when q_len <= 32
        return False

    device_name = torch.xpu.get_device_name(query_states.device.index)
    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
            and is_deepspeed_available:
        # It seems there is an issue in DeepSpeed AutoTP when multi-card inference,
        # Disable new sdp_fp16 for now
        return False

    return True


def use_sdp_fp8(q_len, k_len, query_states):
    if query_states.device.type != "xpu":
        return False
    if q_len == k_len:
        # sdp_fp8 only support rest token now
        return False
    return True


def mlp_fusion_check(x, qtype, training):
    invalidInputError(x.dim() == 2,
                      "Here input x's dim should be 2.")
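
Since the two new helpers are pure predicates over shapes, dtype, and device, their shape logic can be sanity-checked without an XPU. Below is a device-free mirror of the use_new_esimd_sdp_fp16 conditions for quick experimentation; the XPU and DeepSpeed/Arc multi-card checks are omitted because they depend on the runtime environment, and would_use_new_sdp_fp16 is an illustrative name, not part of the library.

# Device-free mirror of the shape/dtype conditions in use_new_esimd_sdp_fp16,
# useful for reasoning about which shapes take the fused path.
def would_use_new_sdp_fp16(q_len, k_len, head_dim, dtype="float16"):
    if dtype != "float16":
        return False         # kernel is optimized for FP16 only
    if head_dim not in (128, 64):
        return False         # only head_dim 128 and 64 are supported
    if q_len == k_len:
        return False         # prefill / first token is handled elsewhere
    if q_len > 32:
        return False         # only short rest-token runs use the fused kernel
    return True

assert would_use_new_sdp_fp16(q_len=1, k_len=256, head_dim=128)        # normal decode
assert would_use_new_sdp_fp16(q_len=8, k_len=256, head_dim=64)         # lookup / spec batch
assert not would_use_new_sdp_fp16(q_len=256, k_len=256, head_dim=128)  # prefill
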
