Fix bug that torch.ops.torch_ipex.matmul_bias_out cannot work on Linux MTL for short input (#11292)
Oscilloscope98 authored Jun 12, 2024
1 parent b61f6e3 commit 8edcdeb
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -813,14 +813,17 @@ def forward(self, x: torch.Tensor):
             self.weight.data = self.weight.data.to(x.dtype)
 
         if not self.use_esimd_kernel(x):
-            if get_ipex_version() < "2.1.10+xpu":
+            if get_ipex_version() < "2.1.10+xpu" \
+                    or get_xpu_device_type(x) not in ["arc", "flex", "pvc"]:
                 if self.weight_type == 2:
-                    self.weight = self.weight.transpose(0, 1).contiguous()
+                    self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
+                                                     requires_grad=False)
                     self.weight_type = 1
                 result = F.linear(x, self.weight, self.bias)
             else:
                 if self.weight_type == 1:
-                    self.weight = self.weight.transpose(0, 1).contiguous()
+                    self.weight = torch.nn.Parameter(self.weight.transpose(0, 1).contiguous(),
+                                                     requires_grad=False)
                     self.weight_type = 2
                 result = torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
             if self.mp_group is not None:
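For context, below is a minimal standalone sketch of the dispatch this commit introduces in forward. The helper name linear_with_fallback and the ipex_version/device_type parameters are hypothetical stand-ins for the module's get_ipex_version() and get_xpu_device_type(x); only the condition and the two code paths mirror the diff.

import torch
import torch.nn.functional as F

def linear_with_fallback(x, weight, bias, ipex_version, device_type):
    # Hypothetical stand-in for the dispatch in forward above;
    # ipex_version / device_type replace get_ipex_version() and
    # get_xpu_device_type(x) from ipex_llm.
    if ipex_version < "2.1.10+xpu" or device_type not in ["arc", "flex", "pvc"]:
        # Older IPEX, or a device such as Linux MTL (Meteor Lake) where the
        # fused kernel fails for short input: fall back to stock F.linear,
        # which expects weight of shape (out_features, in_features)
        # (weight_type == 1 in the diff above).
        return F.linear(x, weight, bias)
    # Arc / Flex / PVC on IPEX >= 2.1.10+xpu: use the fused XPU kernel,
    # which takes the weight in the transposed layout
    # (weight_type == 2 in the diff above).
    return torch.ops.torch_ipex.matmul_bias_out(x, weight, bias)

The diff's other change, wrapping the transposed weight in torch.nn.Parameter(..., requires_grad=False) rather than assigning a bare tensor, is needed because nn.Module refuses to assign a plain Tensor to an attribute registered as a Parameter; re-wrapping keeps self.weight a valid Parameter after each layout swap.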