Commit
LLM: update rms related usage to support ipex 2.1 new api (intel#9466)
* update rms related usage

* fix style
rnwang04 authored Nov 16, 2023
1 parent 404c342 commit 3d4f147
Showing 3 changed files with 28 additions and 8 deletions.
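
The gist of the change: on XPU, each patched RMSNorm forward now dispatches on the installed IPEX version, keeping the legacy torch.ops.torch_ipex.rms_norm op for IPEX <= 2.0.110 and switching to the IPEX 2.1 fast_rms_norm op, whose signature adds an explicit bias (passed as None) and epsilon. A condensed sketch of the shared pattern (rms_norm_xpu is a hypothetical helper name; the epsilon attribute differs per model: self.epsilon for Baichuan2, self.eps for ChatGLM2, self.variance_epsilon for LLaMA):

import torch
from bigdl.llm.transformers.models.llama import get_ipex_version

def rms_norm_xpu(hidden_states, weight, eps):
    # Hypothetical condensation of the three per-model forwards in this commit.
    if get_ipex_version() <= "2.0.110+xpu":
        # Pre-2.1 op: takes only the normalized shape and the weight.
        output, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                  [weight.size(0)], weight)
    else:
        # IPEX 2.1 op: also takes a bias (None here) and an explicit epsilon.
        output, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                       [weight.size(0)],
                                                       weight, None, eps)
    return output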
13 changes: 11 additions & 2 deletions python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -30,6 +30,7 @@
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from transformers.utils import logging, ContextManagers
+from bigdl.llm.transformers.models.llama import get_ipex_version
 logger = logging.get_logger(__name__)

 try:
@@ -47,8 +48,16 @@

 def baichuan_13b_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                         [self.weight.size(0)], self.weight)
+        if get_ipex_version() <= "2.0.110+xpu":
+            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                             [self.weight.size(0)],
+                                                             self.weight)
+        else:
+            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                                  [self.weight.size(0)],
+                                                                  self.weight,
+                                                                  None,
+                                                                  self.epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
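The diff truncates the non-XPU fallback branch just below; for orientation, a standard fp32 RMSNorm of the kind this fallback computes, as a self-contained sketch (rms_norm_fallback, weight, and eps are illustrative names, not lines from this commit):

import torch

def rms_norm_fallback(hidden_states, weight, eps):
    # Standard RMSNorm: compute in float32, then cast back to the input dtype.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return (weight * hidden_states).to(input_dtype)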
13 changes: 11 additions & 2 deletions python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -22,6 +22,7 @@
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.llama import get_ipex_version


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -77,8 +78,16 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:

 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                         [self.weight.size(0)], self.weight)
+        if get_ipex_version() <= "2.0.110+xpu":
+            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                             [self.weight.size(0)], self.weight)
+        else:
+            # for ipex >= 2.1
+            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                               [self.weight.size(0)],
+                                                               self.weight,
+                                                               None,  # bias
+                                                               self.eps)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
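These patched forwards are meant to replace the models' own RMSNorm forward methods at load time; the replacement logic lives in bigdl.llm's model-conversion code, outside this diff. A hypothetical illustration of such a binding:

import types
from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward

def patch_rms_norm(norm_module):
    # Rebind the module's forward to the optimized implementation (illustrative only).
    norm_module.forward = types.MethodType(chatglm_rms_norm_forward, norm_module)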
10 changes: 6 additions & 4 deletions python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -75,13 +75,15 @@ def get_ipex_version():

 def llama_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() == "2.0.110+xpu":
+        if get_ipex_version() <= "2.0.110+xpu":
             hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                              [self.weight.size(0)], self.weight)
         else:
-            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                             [self.weight.size(0)], self.weight,
-                                                             self.variance_epsilon)
+            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                                  [self.weight.size(0)],
+                                                                  self.weight,
+                                                                  None,
+                                                                  self.variance_epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
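The hunk header shows the version gate sits alongside get_ipex_version() in llama.py, which baichuan2.py and chatglm2.py now import. Its body is not part of this diff; a minimal sketch of what such a helper could return (an assumption, not the committed implementation):

def get_ipex_version():
    # Sketch: report the installed IPEX version string, e.g. "2.1.10+xpu".
    import intel_extension_for_pytorch as ipex
    return ipex.__version__

Note that the <= gate is a plain lexicographic string comparison; it happens to order "2.0.110+xpu" before the 2.1 wheels' version strings, but it is not a general-purpose version comparison.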
