LLM: fix baichuan7b quantize kv abnormal output. (intel-analytics#10504)
* fix abnormal output.

* fix style.

* fix style.
lalalapotter authored Mar 22, 2024
1 parent ab47c02 commit cb524e5
Showing 1 changed file with 3 additions and 2 deletions.
python/llm/src/bigdl/llm/transformers/models/baichuan.py (3 additions, 2 deletions)
@@ -128,7 +128,8 @@ def baichuan_attention_forward_7b_quantized(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
                 device=device, new_layout=True
             )
-            key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
+            key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
+                                                           value_states, new_layout=True)
             past_key_value = (key_states, value_states)
         else:
             k_cache, v_cache = past_key_value
@@ -185,7 +186,7 @@ def baichuan_attention_forward_7b_quantized(
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
 
 
 def baichuan_attention_forward_7b_origin(
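Why this plausibly fixes the abnormal output: the caches in this path are initialized by the fp8 (quantized) KV-cache code, so appending new key/value states with the full-precision append_kv_cache helper presumably leaves the cache holding payloads its dequantizer cannot decode, while append_fp8_kv_cache quantizes on write. The second hunk casts attn_output back to hidden_states.dtype, since attention computed against a dequantized cache can come back in a wider dtype (e.g. fp32) than the fp16 model expects downstream. Below is a minimal, self-contained sketch of both failure modes; it simulates fp8 with int8 plus a per-tensor scale, and the helper names quantize_sim_fp8 / dequantize_sim_fp8 are invented for illustration, not BigDL's actual implementation.

import torch

def quantize_sim_fp8(x: torch.Tensor):
    # Per-tensor scale into an 8-bit payload (int8 stands in for real fp8).
    scale = x.abs().amax().to(torch.float32).clamp(min=1e-8) / 127.0
    q = (x.to(torch.float32) / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize_sim_fp8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Invert the quantization; a small rounding error is expected.
    return q.to(torch.float32) * scale

# Hunk 1: a quantized cache needs a quantize-on-write append. Routing raw
# float states through a full-precision helper would leave the cache with
# bytes the dequantizer cannot interpret, hence garbled generations.
key_states = torch.randn(1, 8, 16, 64, dtype=torch.float16)
q_key, k_scale = quantize_sim_fp8(key_states)             # quantize on append
restored = dequantize_sim_fp8(q_key, k_scale).half()
print("max round-trip error:", (key_states - restored).abs().max().item())

# Hunk 2: attention against the dequantized (fp32) cache returns fp32,
# while downstream layers expect hidden_states.dtype (fp16 here).
hidden_states = torch.randn(1, 16, 512, dtype=torch.float16)
attn_output = torch.randn(1, 16, 512, dtype=torch.float32)  # fp32 result
attn_output = attn_output.to(hidden_states.dtype)           # mirrors the added cast
assert attn_output.dtype == hidden_states.dtype

Keeping quantization inside the append helper makes the cache's encoding an internal detail of the cache itself, so the attention code only ever sees tensors in one consistent representation.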
