Skip to content

Commit

Permalink
[NPU] Fix minicpm on MTL (#12599)
Browse files Browse the repository at this point in the history
  • Loading branch information
plusbang authored Dec 24, 2024
1 parent ad2dc96 commit 45f8f72
Showing 1 changed file with 6 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import numpy as np
from typing import Optional, Any, List
import numpy.typing as npt
import os

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -492,7 +493,11 @@ def rotate_half(self, x, *, num_heads, seq_len, head_dim):
def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids,
num_heads, seq_len, head_dim):
if position_ids is not None:
position_ids = self.squeeze(position_ids)
if os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or\
os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1":
position_ids = self.reshape(position_ids, [-1])
else:
position_ids = self.squeeze(position_ids)
cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
cos = self.unsqueeze(cos, [1])
Expand Down

0 comments on commit 45f8f72

Please sign in to comment.