Set mistral fuse rope to false except fp6 & fp16 (#11765)
* set mistral fuse rope to false except fp6 & fp16

* lint

* lint

---------

Co-authored-by: ATMxsp01 <[email protected]>
ATMxsp01 authored Aug 12, 2024
1 parent 8db3405 commit 1b05cab
Showing 2 changed files with 18 additions and 0 deletions.
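In short, every Mistral attention forward in ipex_llm now disables fused RoPE unless q_proj is quantized as fp6 or fp16. Below is a minimal sketch of the guard pattern the commit applies; the helper name apply_fuse_rope_guard is hypothetical (introduced here only for illustration), and it assumes q_proj is a low-bit linear layer exposing a qtype attribute, as in the diff that follows.

# Hypothetical helper (not part of the repo) illustrating the guard this
# commit adds in each Mistral attention forward: fused RoPE is only kept
# when q_proj is quantized as fp6 or fp16; otherwise the unfused rotary
# embedding path is used.
from ipex_llm.transformers.low_bit_linear import FP6, FP16

def apply_fuse_rope_guard(q_proj, use_fuse_rope: bool) -> bool:
    # q_proj: a low-bit linear layer exposing a .qtype attribute.
    if q_proj.qtype not in [FP6, FP16]:
        return False
    return use_fuse_rope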
1 change: 1 addition & 0 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -73,6 +73,7 @@
 MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
 FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
 FP6 = ggml_tensor_qtype["fp6"]
+FP16 = ggml_tensor_qtype["fp16"]
 IQ2_XXS = ggml_tensor_qtype["gguf_iq2_xxs"]
 IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"]
 Q2_K = ggml_tensor_qtype["q2_k"]
17 changes: 17 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -63,6 +63,8 @@
 except ImportError:
     Cache = Tuple[torch.Tensor]

+from ipex_llm.transformers.low_bit_linear import FP6, FP16
+
 import os

 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
@@ -271,6 +273,9 @@ def mistral_attention_forward_quantized(
     original_dtype = hidden_states.dtype

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    if self.q_proj.qtype not in [FP6, FP16]:
+        use_fuse_rope = False
+
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -476,6 +481,9 @@ def mistral_attention_forward_original(
     original_dtype = hidden_states.dtype

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    if self.q_proj.qtype not in [FP6, FP16]:
+        use_fuse_rope = False
+
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -699,6 +707,9 @@ def mistral_attention_forward_4_36_quantized(
     original_dtype = hidden_states.dtype

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    if self.q_proj.qtype not in [FP6, FP16]:
+        use_fuse_rope = False
+
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
                                                   seq_len=q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
@@ -917,6 +928,9 @@ def mistral_attention_forward_4_36_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    if self.q_proj.qtype not in [FP6, FP16]:
+        use_fuse_rope = False
+
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
                                                   self.layer_idx,
                                                   q_len)
@@ -1175,6 +1189,9 @@ def mistral_attention_forward_4_39_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    if self.q_proj.qtype not in [FP6, FP16]:
+        use_fuse_rope = False
+
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
                                                   q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
