remove new_layout parameter (#10906)
MeouSker77 authored Apr 29, 2024
1 parent fbcd7bc commit d884c62
Showing 11 changed files with 25 additions and 38 deletions.
python/llm/src/ipex_llm/transformers/kv.py (9 changes: 2 additions & 7 deletions)

@@ -32,7 +32,6 @@ def update(
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]]=None,
- new_layout=False,
) -> Tuple[torch.Tensor, torch.Tensor]:

batch_size, num_heads, seq_len, head_dim = key_states.shape
@@ -50,18 +49,15 @@ def update(
k_cache, v_cache = init_fp8_kv_cache(
batch_size, num_heads, seq_len, head_dim,
device=key_states.device,
- new_layout=new_layout,
)
- k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
- new_layout=new_layout)
+ k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)

self.key_cache.append(k_cache)
self.value_cache.append(v_cache)
else:
k_cache = self.key_cache[layer_idx]
v_cache = self.value_cache[layer_idx]
- k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states,
- new_layout=new_layout)
+ k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_states, value_states)
self.key_cache[layer_idx] = k_cache
self.value_cache[layer_idx] = v_cache

@@ -77,7 +73,6 @@ def update(
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]]=None,
- new_layout=False,  # useless, just keep same as DynamicFp8Cache
) -> Tuple[torch.Tensor, torch.Tensor]:

batch_size, num_heads, seq_len, head_dim = key_states.shape
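
For callers, the only visible change in this file is that update() no longer accepts a new_layout keyword. A minimal call-site sketch, assuming the class shown above is DynamicFp8Cache from ipex_llm.transformers.kv, that it is constructible with no arguments like transformers' DynamicCache, and that an Intel XPU device with ipex_llm installed is available for the FP8 path:

import torch
from ipex_llm.transformers.kv import DynamicFp8Cache  # module shown in this diff; class name assumed

past_key_value = DynamicFp8Cache()  # assumed no-arg constructor, as in transformers' DynamicCache

# [batch_size, num_heads, seq_len, head_dim], matching the shape unpacked inside update()
key_states = torch.randn(1, 32, 16, 128, device="xpu")    # "xpu" assumes intel_extension_for_pytorch
value_states = torch.randn(1, 32, 16, 128, device="xpu")

# Before this commit: past_key_value.update(..., cache_kwargs, new_layout=True)
# After this commit the keyword is gone; passing it would raise a TypeError.
key_states, value_states = past_key_value.update(
    key_states, value_states, layer_idx=0, cache_kwargs=None
)
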
python/llm/src/ipex_llm/transformers/models/baichuan.py (6 changes: 3 additions & 3 deletions)

@@ -128,15 +128,15 @@ def baichuan_attention_forward_7b_quantized(
if use_cache:
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_heads, kv_seq_len, self.head_dim,
- device=device, new_layout=True
+ device=device
)
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
- value_states, new_layout=True)
+ value_states)
past_key_value = (key_states, value_states)
else:
k_cache, v_cache = past_key_value
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
- key_states, value_states, new_layout=True)
+ key_states, value_states)
kv_seq_len = key_states.shape[-2]
past_key_value = (key_states, value_states)
if query_states.size(2) != 1 or query_states.device.type != 'xpu':
python/llm/src/ipex_llm/transformers/models/baichuan2.py (4 changes: 2 additions & 2 deletions)

@@ -142,12 +142,12 @@ def baichuan_attention_forward_7b_quantized(
kv_seq_len = key_states.shape[-2]
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_heads, kv_seq_len, self.head_dim,
- device=device, new_layout=True
+ device=device
)
else:
k_cache, v_cache = past_key_value
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
- key_states, value_states, new_layout=True)
+ key_states, value_states)

past_key_value = (key_states, value_states) if use_cache else None

python/llm/src/ipex_llm/transformers/models/chatglm2.py (6 changes: 2 additions & 4 deletions)

@@ -280,17 +280,15 @@ def chatglm2_quantized_attention_forward_8eb45c(
n_kv_head,
seq_len,
head_dim,
- query_layer.device,
- new_layout=True)
+ query_layer.device)
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
else:
k_cache, v_cache = kv_cache
k_cache = k_cache.permute(1, 2, 0, 3)
v_cache = v_cache.permute(1, 2, 0, 3)
# k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]

- k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer,
- new_layout=True)
+ k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)

if attention_mask is not None:
attention_mask = ~attention_mask
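
The permute(1, 2, 0, 3) above converts chatglm2's legacy cache layout, [seq_len, bs, n_kv_head, head_dim], into the [bs, n_kv_head, seq_len, head_dim] layout noted in the comment and expected by append_fp8_kv_cache. A small standalone check of that reordering in plain PyTorch (illustrative shapes only):

import torch

seq_len, bs, n_kv_head, head_dim = 5, 1, 2, 64
legacy = torch.randn(seq_len, bs, n_kv_head, head_dim)   # old chatglm2 cache order
reordered = legacy.permute(1, 2, 0, 3)                   # -> [bs, n_kv_head, seq_len, head_dim]

print(reordered.shape)                                   # torch.Size([1, 2, 5, 64])
assert torch.equal(reordered[0, 1, 3], legacy[3, 0, 1])  # same element, new indexing
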
python/llm/src/ipex_llm/transformers/models/llama.py (10 changes: 4 additions & 6 deletions)

@@ -438,15 +438,15 @@ def llama_attention_forward_4_31_quantized(
if use_cache:
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
- device=query_states.device, new_layout=True
+ device=query_states.device
)
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
past_key_value = (key_states, value_states)
else:
k_cache, v_cache = past_key_value
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
- key_states, value_states, new_layout=True)
+ key_states, value_states)
kv_seq_len = key_states.shape[-2]
past_key_value = (key_states, value_states)

@@ -1067,13 +1067,11 @@ def llama_attention_forward_4_36_quantized(
if use_cache:
cache_kwargs = None
key_states, value_states = past_key_value.update(key_states, value_states,
- self.layer_idx, cache_kwargs,
- new_layout=True)
+ self.layer_idx, cache_kwargs)
else:
cache_kwargs = None # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
- self.layer_idx, cache_kwargs,
- new_layout=True)
+ self.layer_idx, cache_kwargs)
kv_seq_len = key_states.shape[-2]
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
python/llm/src/ipex_llm/transformers/models/mistral.py (10 changes: 4 additions & 6 deletions)

@@ -299,15 +299,15 @@ def mistral_attention_forward_quantized(
if use_cache:
k_cache, v_cache = init_fp8_kv_cache(
bsz, self.num_heads, kv_seq_len, self.head_dim,
- device=query_states.device, new_layout=True
+ device=query_states.device
)
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
past_key_value = (key_states, value_states)
else:
k_cache, v_cache = past_key_value
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache,
- key_states, value_states, new_layout=True)
+ key_states, value_states)
kv_seq_len = key_states.shape[-2]
past_key_value = (key_states, value_states)

@@ -680,13 +680,11 @@ def mistral_attention_forward_4_36_quantized(
if use_cache:
cache_kwargs = None
key_states, value_states = past_key_value.update(key_states, value_states,
- self.layer_idx, cache_kwargs,
- new_layout=True)
+ self.layer_idx, cache_kwargs)
else:
cache_kwargs = None # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
- self.layer_idx, cache_kwargs,
- new_layout=True)
+ self.layer_idx, cache_kwargs)
kv_seq_len = key_states.shape[-2]
if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
python/llm/src/ipex_llm/transformers/models/phi.py (2 changes: 1 addition & 1 deletion)

@@ -129,7 +129,7 @@ def attention_forward(
invalidInputError(past_key_value is not None,
"`past_key_value` cannot be None")
key_states, value_states = past_key_value.update(key_states, value_states,
- self.layer_idx, None, new_layout=True)
+ self.layer_idx, None)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
python/llm/src/ipex_llm/transformers/models/qwen.py (4 changes: 2 additions & 2 deletions)

@@ -446,7 +446,7 @@ def qwen_attention_forward_quantized(
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
k_cache, v_cache = init_fp8_kv_cache(
query.size(0), self.num_heads, kv_seq_len, self.head_dim,
- device=query.device, new_layout=True
+ device=query.device
)
key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)
else:
@@ -461,7 +461,7 @@ def qwen_attention_forward_quantized(
v_cache = v_cache.transpose(1, 2)
# k_cache and v_cache's shape: [bs, num_heads, context_length, head_dim]

- key, value = append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=True)
+ key, value = append_fp8_kv_cache(k_cache, v_cache, key, value)

attn_output, attn_weight = core_attn(
self, query, key, value, causal_mask, attention_mask, head_mask
python/llm/src/ipex_llm/transformers/models/qwen2.py (3 changes: 1 addition & 2 deletions)

@@ -358,8 +358,7 @@ def qwen2_attention_forward_quantized(
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
- self.layer_idx, cache_kwargs,
- new_layout=True)
+ self.layer_idx, cache_kwargs)

if q_len == 1 and query_states.device.type == 'xpu' and not self.training \
and not hidden_states.requires_grad:
python/llm/src/ipex_llm/transformers/models/starcoder2.py (2 changes: 1 addition & 1 deletion)

@@ -132,7 +132,7 @@ def attention_forward(
use_quantize_kv = use_quantize_kv_cache(self.o_proj, hidden_states)

key_states, value_states = past_key_value.update(key_states, value_states,
- self.layer_idx, None, new_layout=True)
+ self.layer_idx, None)

if use_quantize_kv and q_len == 1:
import linear_q4_0
python/llm/src/ipex_llm/transformers/models/utils.py (7 changes: 3 additions & 4 deletions)

@@ -96,30 +96,29 @@ def kv_cache_device_check(x: torch.Tensor) -> bool:
1 < x.size(0) and x.size(0) <= 8)


- def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device, new_layout=False):
+ def init_fp8_kv_cache(batch_size, num_heads, current_length, head_dim, device):
max_length = current_length + FP8_KV_ALLOC_LENGTH

k_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
dtype=torch.uint8, device=device)
k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
k_cache_storage.stride(), storage_offset=0)

- # ignore `new_layout`, will remove it in next PR
v_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim,
dtype=torch.uint8, device=device)
v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, head_dim),
v_cache_storage.stride(), storage_offset=0)
return k_cache, v_cache


- def append_fp8_kv_cache(k_cache, v_cache, key, value, new_layout=False):
+ def append_fp8_kv_cache(k_cache, v_cache, key, value):
batch_size, num_heads, cur_length, head_dim = k_cache.shape
new_length = cur_length + key.size(2)
new_size = (batch_size, num_heads, new_length, head_dim)

if k_cache.stride(1) < new_length * k_cache.size(3):
new_k_cache, new_v_cache = init_fp8_kv_cache(batch_size, num_heads, new_length,
- head_dim, key.device, new_layout)
+ head_dim, key.device)
new_k_cache = new_k_cache.as_strided(new_size, new_k_cache.stride(), storage_offset=0)
new_v_cache = new_v_cache.as_strided(new_size, new_v_cache.stride(), storage_offset=0)
new_k_cache[:, :, :cur_length, :] = k_cache
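
These two helpers implement a growable KV cache: a buffer with extra headroom is allocated up front and exposed through an as_strided view that is simply widened as new tokens arrive, so most appends avoid any copy. Below is a standalone sketch of that allocation pattern in plain PyTorch, using float tensors instead of the uint8 FP8 storage and an illustrative ALLOC_LENGTH in place of FP8_KV_ALLOC_LENGTH; it is an approximation of the technique, not the ipex_llm implementation itself.

import torch

ALLOC_LENGTH = 256  # illustrative headroom, stands in for FP8_KV_ALLOC_LENGTH

def init_cache(batch, heads, cur_len, head_dim, device="cpu"):
    # Over-allocate, then expose a zero-length view over the buffer.
    storage = torch.empty(batch, heads, cur_len + ALLOC_LENGTH, head_dim, device=device)
    return storage.as_strided((batch, heads, 0, head_dim), storage.stride(), storage_offset=0)

def append_cache(cache, new_tokens):
    batch, heads, cur_len, head_dim = cache.shape
    new_len = cur_len + new_tokens.size(2)
    if cache.stride(1) < new_len * head_dim:
        # Out of headroom: allocate a larger buffer and copy the old contents over.
        bigger = init_cache(batch, heads, new_len, head_dim, cache.device)
        bigger = bigger.as_strided((batch, heads, new_len, head_dim), bigger.stride(), storage_offset=0)
        bigger[:, :, :cur_len, :] = cache
        bigger[:, :, cur_len:, :] = new_tokens
        return bigger
    # Still room: widen the view over the same storage and fill the tail in place.
    cache = cache.as_strided((batch, heads, new_len, head_dim), cache.stride(), storage_offset=0)
    cache[:, :, cur_len:, :] = new_tokens
    return cache

k_cache = init_cache(1, 8, 16, 64)
k_cache = append_cache(k_cache, torch.randn(1, 8, 16, 64))  # prefill fits in the headroom
k_cache = append_cache(k_cache, torch.randn(1, 8, 1, 64))   # one decode step
print(k_cache.shape)                                        # torch.Size([1, 8, 17, 64])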
