Skip to content

Commit

Permalink
use new fused layer norm
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 committed Dec 16, 2024
1 parent 5ae0006 commit 4011c18
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 22 deletions.
7 changes: 3 additions & 4 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,10 +1296,9 @@ def _optimize_post(model, lightweight_bmm=False):
trans_version = transformers.__version__

# convert all nn.LayerNorm
from ipex_llm.transformers.models.bloom import bloom_layer_norm_forward
convert_forward(model,
nn.LayerNorm,
bloom_layer_norm_forward)
from ipex_llm.transformers.models.common import layer_norm_forward
convert_forward(model, nn.LayerNorm, layer_norm_forward)

from ipex_llm.transformers.models.llama import llama_rms_norm_forward
from ipex_llm.transformers.models.llama import llama_mlp_forward

Expand Down
17 changes: 0 additions & 17 deletions python/llm/src/ipex_llm/transformers/models/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,6 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
return out


def bloom_layer_norm_forward(self, hidden_states):
if use_fused_layer_norm(hidden_states, self.training):
import xe_addons
result = xe_addons.fused_layer_norm(hidden_states,
[self.weight.size(0)],
self.weight,
self.bias,
self.eps)
# if nelement == 0, means fused norm failed, go back to python implement.
if result.nelement != 0:
return result
input_dtype = hidden_states.dtype
result = F.layer_norm(hidden_states.to(self.weight.dtype),
self.normalized_shape, self.weight, self.bias, self.eps)
return result.to(input_dtype)


def bloom_attention_forward(
self,
hidden_states: torch.Tensor,
Expand Down
17 changes: 16 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.


import math
import torch
from typing import List

Expand Down Expand Up @@ -159,7 +160,7 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
else:
eps = self.epsilon

if hidden_states.device.type == 'xpu':
if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
import xe_addons
x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
output = xe_addons.rms_norm(weight, x_2d, eps)
Expand All @@ -169,3 +170,17 @@ def rms_norm_forward(self, hidden_states: torch.Tensor):
variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + eps)
return weight * hidden_states.to(input_dtype)


def layer_norm_forward(self, hidden_states: torch.Tensor):
if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
import xe_addons
hidden_size = math.prod(self.normalized_shape)
x_2d = hidden_states.reshape(-1, hidden_size).contiguous()
output = xe_addons.layer_norm(x_2d, self.weight, self.bias, self.eps)
return output.reshape(hidden_states.shape)
else:
return torch.nn.functional.layer_norm(
input, self.normalized_shape,
self.weight, self.bias, self.eps
)

0 comments on commit 4011c18

Please sign in to comment.