fix llama2 (intel-analytics#10710)
MeouSker77 authored Apr 9, 2024
1 parent e10040b commit 8f45e22
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1011,8 +1011,10 @@ def llama_attention_forward_4_36_quantized(
     kv_seq_len = key_states.shape[-2]
 
     if len(past_key_value.key_cache) <= self.layer_idx:
+        repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
+        repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+                                    repeated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             invalidInputError(
@@ -1038,7 +1040,7 @@ def llama_attention_forward_4_36_quantized(
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                              dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = torch.matmul(attn_weights, repeated_value_states)
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
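Context for the llama.py hunks: on the first pass of the quantized attention path (KV cache for this layer not yet populated), the old code multiplied query_states directly against key_states/value_states, which in grouped-query models carry only num_key_value_heads heads while query_states carries num_heads, so the shapes do not line up; the fix expands the KV tensors with repeat_kv first. The following is a minimal, self-contained sketch of that shape issue, not the ipex_llm code: the repeat_kv here mirrors the transformers-style helper, and the tensor sizes are made up for illustration.

# Sketch only: why repeat_kv is needed before the matmul when
# num_key_value_heads < num_heads (grouped-query attention).
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq_len, head_dim)
    # to (batch, num_kv_heads * n_rep, seq_len, head_dim).
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Hypothetical sizes for illustration.
bsz, num_heads, num_kv_heads, q_len, head_dim = 1, 32, 8, 16, 128
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_kv_heads, q_len, head_dim)

# torch.matmul(query_states, key_states.transpose(2, 3)) would fail here:
# the head dimensions (32 vs 8) cannot broadcast. After repeating each KV
# head num_heads // num_kv_heads times, the shapes match.
repeated_key_states = repeat_kv(key_states, num_heads // num_kv_heads)
attn_weights = torch.matmul(query_states, repeated_key_states.transpose(2, 3))
print(attn_weights.shape)  # torch.Size([1, 32, 16, 16])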
2 changes: 2 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/utils.py
@@ -395,6 +395,8 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
 
 def fp16_fusion_check(proj, x, training):
     # only use fp16 fusion on PVC inference
+    if not hasattr(proj, "qtype"):
+        return False
     if proj.qtype != ggml_tensor_qtype["fp16"]:
         return False
     if proj.weight_type != 2:
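The utils.py hunk guards fp16_fusion_check against projection modules that have no qtype attribute, so the check now returns False instead of raising AttributeError. Below is a minimal sketch of that guard pattern, not the ipex_llm implementation: a plain torch.nn.Linear stands in for a layer that was never converted, and the trailing checks are elided.

# Sketch only: the hasattr guard added to fp16_fusion_check.
import torch

def fp16_fusion_check_sketch(proj) -> bool:
    # Without this guard, `proj.qtype` raises AttributeError for layers
    # that carry no qtype attribute; with it, the check fails gracefully.
    if not hasattr(proj, "qtype"):
        return False
    # ... the remaining checks (qtype == fp16, weight_type == 2, ...) would follow.
    return True

proj = torch.nn.Linear(16, 16)  # stand-in for an unconverted projection
print(fp16_fusion_check_sketch(proj))  # False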
