Skip to content

Commit

Permalink
update4
Browse files Browse the repository at this point in the history
  • Loading branch information
songhappy committed Jun 13, 2024
1 parent 9486c5b commit a1f10b5
Showing 1 changed file with 220 additions and 1 deletion.
221 changes: 220 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ def llama_attention_forward_4_41(
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
forward_function = llama_attention_forward_4_38_quantized
forward_function = llama_attention_forward_4_41_quantized
else:
forward_function = llama_attention_forward_4_41_original
return forward_function(
Expand All @@ -1021,6 +1021,225 @@ def llama_attention_forward_4_41(
)


def llama_attention_forward_4_41_quantized(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[List[torch.FloatTensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
"Please make sure use `attention_mask` instead.`"
)

bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len,
llama_decoding_fast_path_qtype_check) and no_tp
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
tmp_cache_k, tmp_cache_v = init_kv_cache(
bsz,
self.num_key_value_heads,
self.head_dim,
0,
1,
dtype=hidden_states.dtype,
device=device
)
import xe_linear
query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
self.q_proj.weight,
self.k_proj.weight,
self.v_proj.weight,
position_ids,
tmp_cache_k, tmp_cache_v,
self.q_proj.weight.qtype,
self.v_proj.weight.qtype,
0,
self.head_dim,
self.rotary_emb.base,)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
invalidInputError(
False,
f"The cache structure has changed since version v4.36."
f" If you are using {self.__class__.__name__} "
f"for auto-regressive decoding with k/v caching,"
f" please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
else:
if cache_position is not None:
# for transformers 4.38.0
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama2")
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")
kv_seq_len = key_states.shape[-2]

if len(past_key_value.key_cache) <= self.layer_idx:
repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
q_len, kv_seq_len, output_attentions):
attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
repeated_value_states,
attention_mask, cache_position,
bsz, q_len, kv_seq_len, self.head_dim,
self.num_heads)
else:
attn_weights = torch.matmul(query_states, repeated_key_states
.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if cache_position is not None:
# for transformers 4.38.0
causal_mask = attention_mask[:, :, :, : kv_seq_len]
attn_weights = attn_weights + causal_mask
else:
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
if attention_mask.size() != attn_mask_size:
invalidInputError(False,
f"Attention mask should be of size {attn_mask_size}, "
f"but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask

if kv_seq_len >= 2048 or bsz >= 64:
# for memory considerations, do not upcast attention to fp32
# for long sequences or large batches
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
else:
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, repeated_value_states)
if use_cache:
cache_kwargs = None
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
else:
cache_kwargs = None # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
kv_seq_len = key_states.shape[-2]
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
key_states = repeat_kv(key_states, self.num_key_value_groups)\
.to(device, dtype=query_states.dtype)
value_states = repeat_kv(value_states, self.num_key_value_groups)\
.to(device, dtype=query_states.dtype)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
attn_weights = attn_weights / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention weights should be of size"
f" {(bsz, self.num_heads, q_len, kv_seq_len)},"
f" but is {attn_weights.size()}"
)

if attention_mask is not None:
if cache_position is not None:
# for transformers 4.38.0
causal_mask = attention_mask[:, :, :, : kv_seq_len]
attn_weights = attn_weights + causal_mask
else:
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
if attention_mask.size() != attn_mask_size:
invalidInputError(False,
f"Attention mask should be of size {attn_mask_size}, "
f"but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask

if kv_seq_len >= 2048 or bsz >= 64:
# for memory considerations, do not upcast attention to fp32
# for long sequences or large batches
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
else:
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
else:
import xe_addons
if cache_position is not None:
new_attn_mask = attention_mask[:, :, :, 0:kv_seq_len]
else:
new_attn_mask = attention_mask
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, new_attn_mask)
attn_weights = None

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(
False,
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)},"
f" but is {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size
// self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i],
o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value

def llama_attention_forward_4_41_original(
self,
hidden_states: torch.Tensor,
Expand Down

0 comments on commit a1f10b5

Please sign in to comment.