From dbbea0ae1306ba32a129016de006e29f70c0a340 Mon Sep 17 00:00:00 2001
From: Yuwen Hu
Date: Wed, 12 Jun 2024 17:59:00 +0800
Subject: [PATCH] Fix bug that torch.ops.torch_ipex.matmul_bias_out cannot
 work on Linux MTL for short input

---
 python/llm/src/ipex_llm/transformers/low_bit_linear.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
index b429c3ce08e..86a689eebc1 100644
--- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py
+++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -813,14 +813,17 @@ def forward(self, x: torch.Tensor):
             self.weight.data = self.weight.data.to(x.dtype)
 
         if not self.use_esimd_kernel(x):
-            if get_ipex_version() < "2.1.10+xpu":
+            if get_ipex_version() < "2.1.10+xpu" \
+                    or get_xpu_device_type(x) not in ["arc", "flex", "pvc"]:
                 if self.weight_type == 2:
-                    self.weight = self.weight.transpose(0, 1).contiguous()
+                    self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
+                                                     requires_grad=False)
                     self.weight_type = 1
                 result = F.linear(x, self.weight, self.bias)
             else:
                 if self.weight_type == 1:
-                    self.weight = self.weight.transpose(0, 1).contiguous()
+                    self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
+                                                     requires_grad=False)
                     self.weight_type = 2
                 result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
             if self.mp_group is not None:
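
Note (illustration, not part of the patch): the hunk does two things. First, devices whose
get_xpu_device_type(x) is not "arc", "flex", or "pvc" (e.g. MTL on Linux) now take the plain
F.linear path instead of torch.ops.torch_ipex.matmul_bias_out. Second, the transposed weight is
re-wrapped in torch.nn.Parameter(..., requires_grad=False), because nn.Module.__setattr__ refuses
to bind a plain Tensor to an attribute that is already registered as a Parameter. A minimal
standalone sketch of that second point, using a stock nn.Linear rather than ipex-llm's class:

    # Why the nn.Parameter re-wrap is needed (standalone illustration).
    import torch

    lin = torch.nn.Linear(4, 4)

    try:
        # transpose() returns a plain Tensor; nn.Module.__setattr__ rejects it
        # for an attribute that is registered as a Parameter.
        lin.weight = lin.weight.transpose(0, 1).contiguous()
    except TypeError as e:
        print("rejected:", e)

    # Wrapping the tensor restores the Parameter type, so assignment succeeds;
    # requires_grad=False keeps the cached weight out of autograd, as in the patch.
    lin.weight = torch.nn.Parameter(lin.weight.transpose(0, 1).contiguous(),
                                    requires_grad=False)
    print(lin.weight.shape)  # torch.Size([4, 4])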