From 9752ffe979c203b685100bdcd8ef6327f08f4c86 Mon Sep 17 00:00:00 2001 From: Cengguang Zhang Date: Fri, 26 Apr 2024 18:47:35 +0800 Subject: [PATCH] LLM: update split qkv native sdp. (#10895) * LLM: update split qkv native sdp. * fix typo. --- .../ipex_llm/transformers/models/chatglm2.py | 17 ++++++----------- .../src/ipex_llm/transformers/models/llama.py | 8 +++----- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py index 787ed9cef5e..ed39c00e60e 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py @@ -258,16 +258,14 @@ def chatglm2_quantized_attention_forward_8eb45c( query_split = torch.split(query_layer, block_size, dim=1) key_split = torch.split(key, block_size, dim=1) value_split = torch.split(value, block_size, dim=1) - context_layer = torch.empty(batch_size, n_head, seq_len, - head_dim, dtype=key.dtype).to(query_layer.device) - idx = 0 + results = [] for q, k, v in zip(query_split, key_split, value_split): if attention_mask is None: result = F.scaled_dot_product_attention(q, k, v, is_causal=True) else: result = F.scaled_dot_product_attention(q, k, v, attention_mask) - context_layer[:, idx:idx+q.shape[1], :, :] = result - idx = idx + q.shape[1] + results.append(result) + context_layer = torch.cat(results, dim=1) else: if attention_mask is None: context_layer = F.scaled_dot_product_attention(query_layer, key, @@ -541,14 +539,11 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1) key_split = torch.split(key_layer, block_size, dim=1) value_split = torch.split(value_layer, block_size, dim=1) - batch_size, n_head, seq_len, head_dim = query_layer.shape - context_layer = torch.empty(batch_size, n_head, seq_len, - head_dim, dtype=key_layer.dtype).to(query_layer.device) - idx = 0 + results = [] for q, k, v in zip(query_split, key_split, value_split): result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype) - context_layer[:, idx:idx+q.shape[1], :, :] = result - idx = idx + q.shape[1] + results.append(result) + context_layer = torch.cat(results, dim=1) else: context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype), key_layer, diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py index 4e879603e77..c9a043617d0 100644 --- a/python/llm/src/ipex_llm/transformers/models/llama.py +++ b/python/llm/src/ipex_llm/transformers/models/llama.py @@ -1423,8 +1423,7 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask, query_split = torch.split(query.to(key.dtype), block_size, dim=1) key_split = torch.split(key.transpose(2, 3), block_size, dim=1) value_split = torch.split(value, block_size, dim=1) - attn_output = torch.empty(bsz, num_heads, q_len, head_dim).to(query.device) - idx = 0 + attn_outputs = [] for q, k, v in zip(query_split, key_split, value_split): attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim) block_actual_size = attn_weights_split.size(1) @@ -1442,9 +1441,8 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask, f"but is {attention_mask.size()}") attn_weights_split = attn_weights_split + attention_mask attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1) - attn_weights_split = torch.matmul(attn_weights_split, v) - attn_output[:, idx:idx+block_actual_size, :, :] = attn_weights_split - idx = idx + block_actual_size + attn_outputs.append(torch.matmul(attn_weights_split, v)) + attn_output = torch.cat(attn_outputs, dim=1) return attn_output.to(key.dtype), None