From 5c24276fc4819ac889dec3ca672b6aaead208fd6 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Wed, 8 Jan 2025 17:39:17 +0800 Subject: [PATCH] fix custom kernel registration (#12674) --- .../llm/src/ipex_llm/transformers/xpu_ops.py | 44 +++++++++++-------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/xpu_ops.py b/python/llm/src/ipex_llm/transformers/xpu_ops.py index 6ee00e1ce49..9b9c740ea1e 100644 --- a/python/llm/src/ipex_llm/transformers/xpu_ops.py +++ b/python/llm/src/ipex_llm/transformers/xpu_ops.py @@ -20,9 +20,9 @@ import xe_addons -@torch.library.register_fake("ipex_llm::forward_new") -def _(x, weight, qtype, input_size): - return torch.empty_like(x) +# @torch.library.register_fake("ipex_llm::forward_new") +# def _(x, weight, qtype, input_size): +# return ??? # @torch.library.register_fake("ipex_llm::dequant") @@ -32,32 +32,38 @@ def _(x, weight, qtype, input_size): @torch.library.register_fake("ipex_llm::mlp_forward_xpu") def _(x, weight1, weight2, batch_size, state_size, output_size, act_type, qtype): - return torch.empty_like(x) + return torch.empty([batch_size, output_size], + dtype=x.dtype, device=x.device) -# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4") -# def _(time_decay, time_first, key, value, num_state, den_state, max_state) - # return ??? +@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v4") +def _(time_decay, time_first, key, value, num_state, den_state, max_state): + return torch.empty_like(key) -# @torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5") -# def _(time_decay, time_first, receptance, key, value, state) - # return ??? +@torch.library.register_fake("ipex_llm::rwkv_linear_attention_v5") +def _(time_decay, time_first, receptance, key, value, state): + bsz, n_heads, seq_len, head_dim = key.shape + return torch.empty([bsz, seq_len, n_heads, head_dim], + dtype=key.dtype, device=key.device) -# @torch.library.register_fake("ipex_llm::rwkv_time_shift") -# def _(hidden, shifted, mix): - # return ??? +@torch.library.register_fake("ipex_llm::rwkv_time_shift") +def _(hidden, shifted, mix): + bsz, seq_len, hidden_size = hidden.shape + return torch.empty([mix.size(0), bsz, seq_len, hidden_size], + dtype=hidden.dtype, device=hidden.device) -# @torch.library.register_fake("ipex_llm::dequantize_rows") -# def _(x, weight, qtype, state_size, output_size): - # return ??? +@torch.library.register_fake("ipex_llm::dequantize_rows") +def _(x, weight, qtype, state_size, output_size): + return torch.empty([x.size(0), x.size(1), state_size], + dtype=torch.float, device=weight.device) -@torch.library.register_fake("ipex_llm::batch_forward") -def _(x, weight, qtype): - return torch.empty_like(x) +# @torch.library.register_fake("ipex_llm::batch_forward") +# def _(x, weight, qtype): +# return ??? @torch.library.register_fake("ipex_llm::sdp")