
Commit

add glm_sdpa back to fix chatglm-6b (#11313)
MeouSker77 authored Jun 14, 2024
1 parent 7f65836 commit 91965b5
Showing 1 changed file with 44 additions and 1 deletion.
python/llm/src/ipex_llm/transformers/models/chatglm.py
@@ -23,7 +23,7 @@
 import torch.nn.functional as F
 from typing import Optional, Tuple
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from ipex_llm.transformers.models.chatglm2 import glm_sdpa
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 
 
 def rotate_half(x):
@@ -39,6 +39,49 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
     q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
     return q, k
 
+
+def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
+    if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
+        context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
+                                                       key,
+                                                       value,
+                                                       attention_mask,
+                                                       is_causal=is_causal).to(key.dtype)
+    else:
+        # attention_mask is not None only when past_key_value is not None and q_len > 1
+        if attention_mask is not None:
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query.dtype,
+                                    device=query.device)
+            attention_mask = ~attention_mask
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+        elif is_causal:
+            L, S = query.size(-2), key.size(-2)
+            attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias.to(key.dtype)
+        else:
+            attn_bias = None
+        if use_sdp(query.shape[2], key.shape[2],
+                   query.shape[-1], query):
+            import xe_addons
+            attn_output = xe_addons.sdp(query, key, value, attn_bias)
+            context_layer = attn_output.view(query.shape)
+        else:
+            head_dim = query.size(-1)
+            attn = torch.matmul(query.to(key.dtype) / math.sqrt(head_dim),
+                                key.transpose(2, 3))
+            if attn_bias is not None:
+                attn += attn_bias
+            attn = F.softmax(attn, dim=-1,
+                             dtype=torch.float32).to(value.dtype)
+            context_layer = torch.matmul(attn, value)
+    return context_layer
+
+
 import os
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
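For context, here is a minimal usage sketch, not part of the commit: it assumes ipex_llm is installed, uses ChatGLM-6B-like shapes (32 heads, head_dim 128, [batch, heads, seq, head_dim] layout, all values here are assumptions), and runs on CPU, where glm_sdpa should fall through to the F.scaled_dot_product_attention branch.

# Hypothetical decode-step call; shapes and values are illustrative only.
import torch
from ipex_llm.transformers.models.chatglm import glm_sdpa

batch, heads, q_len, kv_len, head_dim = 1, 32, 1, 64, 128
query = torch.randn(batch, heads, q_len, head_dim)   # current token only
key = torch.randn(batch, heads, kv_len, head_dim)    # cached keys
value = torch.randn(batch, heads, kv_len, head_dim)  # cached values

# q_len == 1, so no mask is needed and is_causal stays False.
out = glm_sdpa(query, key, value, attention_mask=None, is_causal=False)
print(out.shape)  # torch.Size([1, 32, 1, 128])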

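The xe_addons.sdp path only exists on Intel GPU builds, but the causal-bias construction and the matmul/softmax fallback can be checked in plain PyTorch. Building the bias as an additive tensor of zeros and -inf lets the same tensor feed either the fused SDP kernel or the manual path. A small CPU-only sketch (my own, not from the commit) that mirrors the is_causal branch and the final else branch and compares them against F.scaled_dot_product_attention:

# Check that the fallback's scaled matmul + bias + softmax matches
# F.scaled_dot_product_attention for the causal case.
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
query = torch.randn(1, 2, 4, 8)   # [batch, heads, q_len, head_dim]
key = torch.randn(1, 2, 4, 8)
value = torch.randn(1, 2, 4, 8)

# Causal bias, built the same way as the is_causal branch above.
L, S = query.size(-2), key.size(-2)
attn_bias = torch.zeros(L, S)
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

# Manual path: scaled matmul, additive bias, softmax, weighted sum.
head_dim = query.size(-1)
attn = torch.matmul(query / math.sqrt(head_dim), key.transpose(2, 3)) + attn_bias
attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value.dtype)
manual = torch.matmul(attn, value)

reference = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(torch.allclose(manual, reference, atol=1e-6))  # expected: True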
