refactor gemma to reduce old fuse rope usage (#12215)
MeouSker77 authored Oct 16, 2024
1 parent 9104a16 commit a4a7586
Showing 2 changed files with 145 additions and 343 deletions.
python/llm/src/ipex_llm/transformers/convert.py (34 changes: 11 additions & 23 deletions)
@@ -1017,6 +1017,10 @@ def _optimize_pre(model, qtype=None):
         model.apply(pre_process_attn_and_mlp)
     if model.config.model_type == "internvl_chat":
         _optimize_pre(model.language_model, qtype=qtype)
+    if model.config.model_type == "gemma":
+        from ipex_llm.transformers.models.gemma import merge_qkv, pre_compute_inv_freq
+        model.apply(merge_qkv)
+        model.apply(pre_compute_inv_freq)
     if model.config.model_type == "gemma2":
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
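
The two hooks applied in this hunk run once at load time, before any forward pass. A minimal sketch of the pattern, with assumed attribute names and layer types (the real helpers live in ipex_llm.transformers.models.gemma and may differ in detail):

    # Sketch of the load-time hooks applied above, under assumed attribute
    # names; the actual ipex_llm helpers may differ in detail.
    import torch
    from torch import nn

    def merge_qkv(module: nn.Module):
        # Hypothetical: fuse the separate q/k/v projections of an attention
        # layer into one linear, so a single GEMM yields all three tensors.
        if type(module).__name__ == "GemmaAttention":
            q, k, v = module.q_proj, module.k_proj, module.v_proj
            qkv = nn.Linear(q.in_features,
                            q.out_features + k.out_features + v.out_features,
                            bias=False)
            qkv.weight.data = torch.cat([q.weight, k.weight, v.weight], dim=0)
            module.qkv_proj = qkv
            del module.q_proj, module.k_proj, module.v_proj

    def pre_compute_inv_freq(module: nn.Module):
        # Hypothetical: materialize the rotary inverse frequencies once at
        # load time instead of rebuilding them inside every forward call.
        if type(module).__name__ == "GemmaRotaryEmbedding":
            dim, base = module.dim, module.base
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
            module.register_buffer("inv_freq", inv_freq, persistent=False)

model.apply(fn) visits every submodule recursively, so each hook only needs to pattern-match the layer type it cares about.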
@@ -1846,32 +1850,16 @@ def _optimize_post(model, lightweight_bmm=False):
                         module.MistralMLP,
                         llama_mlp_forward)
     elif model.config.model_type == "gemma":
-        invalidInputError(version.parse(trans_version) >= version.parse("4.38.0"),
-                          "Please upgrade transformers to 4.38.0 or higher version "
-                          "to run Mixtral models.")
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        if version.parse(trans_version) >= version.parse("4.39.0"):
-            from ipex_llm.transformers.models.gemma import gemma_attention_forward_4_39
-            convert_forward(model,
-                            module.GemmaAttention,
-                            gemma_attention_forward_4_39
-                            )
-        else:
-            from ipex_llm.transformers.models.gemma import gemma_attention_forward
-            convert_forward(model,
-                            module.GemmaAttention,
-                            gemma_attention_forward,
-                            )
+        from ipex_llm.transformers.models.gemma import gemma_model_forward
+        from ipex_llm.transformers.models.gemma import gemma_attention_forward
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
-        from ipex_llm.transformers.models.gemma import gemma_mlp_forward
-        convert_forward(model,
-                        module.GemmaRMSNorm,
-                        gemma_rms_norm_forward)
-        convert_forward(model,
-                        module.GemmaMLP,
-                        gemma_mlp_forward)
-
+        from ipex_llm.transformers.models.common import mlp_gelu_forward
+        convert_forward(model, module.GemmaModel, gemma_model_forward)
+        convert_forward(model, module.GemmaAttention, gemma_attention_forward)
+        convert_forward(model, module.GemmaRMSNorm, gemma_rms_norm_forward)
+        convert_forward(model, module.GemmaMLP, mlp_gelu_forward)
     elif model.config.model_type == "gemma2":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
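
For context, convert_forward is the monkey-patching primitive this whole block relies on. A minimal sketch of the idea, an assumption rather than the actual ipex_llm implementation:

    # Sketch (an assumption, not the actual ipex_llm code) of the
    # convert_forward pattern used above: rebind forward on every instance
    # of the target class so it picks up the optimized implementation.
    import types

    def convert_forward(model, target_class, new_forward):
        for module in model.modules():
            if isinstance(module, target_class):
                # Bind new_forward to this instance so `self` resolves to it.
                module.forward = types.MethodType(new_forward, module)

Pointing module.GemmaMLP at the shared mlp_gelu_forward from models.common, instead of a gemma-specific gemma_mlp_forward, presumably works because Gemma's MLP is the standard gated-GELU (gate/up/down projection) block that other models in the repo already share.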