[WIP] Support llama2 with transformers==4.38.0 (#11024)
* support llama2 with transformers==4.38.0

* add support for quantize_qkv

* add original support for 4.38.0 for now

* code style fix
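
For context on the version gate in the diff below: the convert code compares the installed transformers version using `packaging.version` before choosing which forward implementations to patch in. A minimal sketch of that check, assuming `packaging` is available (transformers itself depends on it); the `chosen` strings are illustrative labels, not real API:

```python
# Sketch of the version dispatch added in convert.py (illustrative only).
from packaging import version
import transformers

trans_version = transformers.__version__
if version.parse(trans_version) >= version.parse("4.38.0"):
    # 4.38.0+: only LlamaAttention is patched for now;
    # LlamaModel patching is still a Todo in this commit.
    chosen = "llama_attention_forward_4_38_original"
elif version.parse(trans_version) >= version.parse("4.36.0"):
    # 4.36.0 - 4.37.x: patch both LlamaModel and LlamaAttention.
    chosen = "llama_attention_forward_4_38 and llama_model_forward_4_36"
else:
    # 4.31.0 - 4.35.2 path in convert.py.
    chosen = "legacy llama forwards"
print(f"transformers {trans_version}: would patch {chosen}")
```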
sgwhat authored May 15, 2024
1 parent 686f603 commit 9942a4b
Showing 3 changed files with 123 additions and 64 deletions.
26 changes: 17 additions & 9 deletions python/llm/src/ipex_llm/transformers/convert.py
```diff
@@ -961,16 +961,24 @@ def _optimize_post(model, lightweight_bmm=False):
                 llama_decoder_forward)
             if version.parse(trans_version) >= version.parse("4.36.0"):
                 # transformers version >= 4.36.0
-                from ipex_llm.transformers.models.llama import llama_attention_forward_4_36
+                from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
                 from ipex_llm.transformers.models.llama import llama_model_forward_4_36
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaAttention,
-                    llama_attention_forward_4_36, )
-                convert_forward(
-                    model,
-                    transformers.models.llama.modeling_llama.LlamaModel,
-                    llama_model_forward_4_36)
+                if version.parse(trans_version) >= version.parse("4.38.0"):
+                    from ipex_llm.transformers.models.llama import llama_attention_forward_4_38_original
+                    # Todo: support llama_model_forward with transformers version >= 4.38.0
+                    convert_forward(
+                        model,
+                        transformers.models.llama.modeling_llama.LlamaAttention,
+                        llama_attention_forward_4_38_original)
+                else:
+                    convert_forward(
+                        model,
+                        transformers.models.llama.modeling_llama.LlamaModel,
+                        llama_model_forward_4_36)
+                    convert_forward(
+                        model,
+                        transformers.models.llama.modeling_llama.LlamaAttention,
+                        llama_attention_forward_4_38)
             else:
                 # transformers version between 4.31.0 - 4.35.2
                 convert_forward(
```
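
The `convert_forward` helper invoked throughout this hunk is defined elsewhere in convert.py; conceptually it walks the model and rebinds `forward` on every submodule of the target class. A minimal sketch, assuming a standard PyTorch module tree (an illustration, not the actual ipex_llm implementation):

```python
import types

import torch.nn as nn


def convert_forward(model: nn.Module, target_class: type, new_forward) -> None:
    # Rebind `forward` on every submodule matching target_class, so that
    # e.g. LlamaAttention instances run the optimized attention function.
    for module in model.modules():
        if isinstance(module, target_class):
            # MethodType binds the function to the module, so `self`
            # inside new_forward resolves to that module instance.
            module.forward = types.MethodType(new_forward, module)
```

With this pattern the model's weights and structure are untouched; only the Python-level forward computation is swapped, which is why the patch can be gated purely on the installed transformers version string.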
