do not cache cos sin
yangw1234 committed Oct 2, 2023
1 parent 3c8e52a commit 41c4be1
Showing 1 changed file with 33 additions and 42 deletions.
75 changes: 33 additions & 42 deletions
python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -56,7 +56,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

KV_CACHE_ALLOC_BLOCK_LENGTH = 256


import linear_q4_0
def llama_attention_forward_4_31(
self,
hidden_states: torch.Tensor,
@@ -93,47 +93,25 @@ def llama_attention_forward_4_31(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

if query_states.device.type == "xpu" and position_ids is not None:

query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim)
value_states = value_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim)

kv_seq_len = key_states.shape[-3]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

if kv_seq_len > self.rotary_emb.max_seq_len_cached:
self.rotary_emb._set_cos_sin_cache(seq_len=kv_seq_len,
device=value_states.device,
dtype=value_states.dtype)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)

cos = self.rotary_emb.cos_cached[0, 0][position_ids].unsqueeze(2)
sin = self.rotary_emb.sin_cached[0, 0][position_ids].unsqueeze(2)

torch.ops.torch_ipex.apply_rotary_embedding(query_states, sin, cos, query_states)
torch.ops.torch_ipex.apply_rotary_embedding(key_states, sin, cos, key_states)

query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
else:
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len,
self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
# cos, sin, position_ids, "llama")
q_embed = torch.empty(query_states.shape, dtype=query_states.dtype, device=query_states.device)
k_embed = torch.empty(key_states.shape, dtype=key_states.dtype, device=key_states.device)

linear_q4_0.apply_rotary_embedding_half_qk(query_states, key_states, position_ids, q_embed, k_embed)
query_states = q_embed
key_states = k_embed

if past_key_value is not None:
# reuse k, v, self_attention
@@ -176,6 +154,19 @@ def llama_attention_forward_4_31(
dtype=hidden_states.dtype)
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
dtype=hidden_states.dtype)

# if attention_mask is None:
# attn_output = torch.nn.functional.scaled_dot_product_attention(query_states,
# key_states,
# value_states,
# attention_mask,
# is_causal=True,)
# else:
# attn_output = torch.nn.functional.scaled_dot_product_attention(query_states,
# key_states,
# value_states,
# attention_mask,
# is_causal=False)

attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
@@ -199,7 +190,7 @@ def llama_attention_forward_4_31(
dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
# attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
# if attn_output.size() != attn_output_size:
# invalidInputError(False,
# f"`attn_output` should be of size {attn_output_size},"
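
The commented-out block above the matmul/softmax attention hints at an alternative path through torch.nn.functional.scaled_dot_product_attention (PyTorch 2.x). A self-contained sketch of that pattern, assuming the tensor shapes used in this file and not taken from the repository, would be:

import torch
import torch.nn.functional as F

def sdpa_attention(query_states, key_states, value_states, attention_mask=None):
    # All inputs are (bsz, num_heads, seq_len, head_dim); key/value after repeat_kv.
    # attention_mask, if given, is the additive mask already built by the caller.
    q_len = query_states.shape[-2]
    if attention_mask is None:
        # Let SDPA build the causal mask itself.  Caveat: with a KV cache and
        # q_len == 1 the built-in causal mask is aligned to the top-left corner,
        # so is_causal=True is only safe during prefill (q_len > 1).
        return F.scaled_dot_product_attention(query_states, key_states, value_states,
                                              is_causal=q_len > 1)
    # An explicit additive mask and is_causal=True are mutually exclusive in SDPA,
    # so the masked path keeps is_causal=False.
    return F.scaled_dot_product_attention(query_states, key_states, value_states,
                                          attn_mask=attention_mask, is_causal=False)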