Enable fused layernorm (intel#9614)
* bloom layernorm

* fix

* layernorm

* fix

* fix

* fix

* style fix

* fix

* replace nn.LayerNorm
qiuxin2012 authored Dec 11, 2023
1 parent e99f7ac commit b16a93f
Showing 2 changed files with 19 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/llm/src/bigdl/llm/transformers/convert.py
@@ -392,6 +392,12 @@ def _optimize_post(model, lightweight_bmm=False):
        # todo implement 4.28.0 ~ 4.30.2
        pass

    # convert all nn.LayerNorm
    from bigdl.llm.transformers.models.bloom import bloom_layer_norm_forward
    convert_forward(model,
                    nn.LayerNorm,
                    bloom_layer_norm_forward)

    if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
        if model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio'):
            # chatglm2-6b-32k
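The convert_forward helper itself is not part of this diff; a minimal sketch of what such a helper could look like, assuming it simply rebinds the forward method of every submodule of the target class (everything here other than convert_forward, nn.LayerNorm, and bloom_layer_norm_forward is illustrative, not the library's actual implementation):

import torch.nn as nn

def convert_forward(model: nn.Module, target_cls, new_forward):
    # Sketch only: walk every submodule and, for each instance of target_cls
    # (here nn.LayerNorm), rebind its forward to new_forward while leaving
    # the module's weight, bias, and eps untouched.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = new_forward.__get__(module, type(module))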
13 changes: 13 additions & 0 deletions python/llm/src/bigdl/llm/transformers/models/bloom.py
@@ -62,6 +62,19 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
    return out


def bloom_layer_norm_forward(self, hidden_states):
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
        # Inference on an XPU device: call the fused layer norm kernel from
        # the native linear_q4_0 extension instead of F.layer_norm.
        import linear_q4_0
        hidden_states = linear_q4_0.fused_layer_norm(hidden_states,
                                                     [self.weight.size(0)],
                                                     self.weight,
                                                     self.bias,
                                                     self.eps)
        return hidden_states
    else:
        # Training, autograd, or non-XPU devices: fall back to the stock implementation.
        return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)


def bloom_attention_forward(
    self,
    hidden_states: torch.Tensor,
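On non-XPU devices, or whenever gradients are required, the patched forward takes the F.layer_norm branch, so replacing nn.LayerNorm's forward is numerically a no-op there. A small sanity check of that fallback path, with toy shapes chosen only for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

ln = nn.LayerNorm(64)          # hypothetical hidden size
x = torch.randn(2, 8, 64)      # (batch, sequence, hidden)

# The fallback path is exactly F.layer_norm with the module's own
# parameters, so it must match the stock nn.LayerNorm output.
expected = ln(x)
fallback = F.layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps)
assert torch.allclose(expected, fallback)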
