From 797dbc48b8d97968b1a99299e1fd8fef379de038 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Thu, 23 May 2024 17:37:37 +0800 Subject: [PATCH] fix phi-2 and phi-3 convert (#11116) --- python/llm/src/ipex_llm/transformers/convert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 69e005af328..6db25be0526 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1522,7 +1522,7 @@ def safe_bmm_fwd(*args, **kwargs): from ipex_llm.transformers.models.starcoder2 import model_forward convert_forward(model, module.Starcoder2Attention, attention_forward) convert_forward(model, module.Starcoder2Model, model_forward) - elif model.config.model_type in ["phi3", "phi3_v"]: + elif model.config.model_type == "phi": # for phi-2 modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) @@ -1530,7 +1530,7 @@ def safe_bmm_fwd(*args, **kwargs): from ipex_llm.transformers.models.phi import model_forward convert_forward(model, module.PhiAttention, attention_forward) convert_forward(model, module.PhiModel, model_forward) - elif model.config.model_type == "phi3": + elif model.config.model_type in ["phi3", "phi3_v"]: # for phi-3 modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name)