From 841f0b32b716548824cde29ffaf08177412550ff Mon Sep 17 00:00:00 2001 From: sgwhat Date: Wed, 24 Jul 2024 17:16:30 +0800 Subject: [PATCH 1/2] remove re-import ipex --- modules/text_generation.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index edb498441a..d5a1593ed9 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -377,10 +377,6 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(filtered_params) print() - if shared.args.device == "GPU": - import intel_extension_for_pytorch - shared.model = shared.model.to("xpu") - streamer = TextIteratorStreamer(shared.tokenizer, skip_prompt=True) t0 = time.time() From 3461a27b45e12a648ec23fa72a5abe76a20b8416 Mon Sep 17 00:00:00 2001 From: sgwhat Date: Wed, 24 Jul 2024 17:37:51 +0800 Subject: [PATCH 2/2] hot fix --- modules/models.py | 4 ++-- modules/text_generation.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/models.py b/modules/models.py index 55b14092f6..1e248a51c5 100644 --- a/modules/models.py +++ b/modules/models.py @@ -26,6 +26,8 @@ from modules.models_settings import get_model_metadata from modules.relative_imports import RelativeImport +from ipex_llm.transformers import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM + transformers.logging.set_verbosity_error() local_rank = None @@ -323,8 +325,6 @@ def AutoAWQ_loader(model_name): def ipex_llm_loader(model_name): - from ipex_llm.transformers import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM - path_to_model = Path(f'{shared.args.model_dir}/{model_name}') config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code) diff --git a/modules/text_generation.py b/modules/text_generation.py index d5a1593ed9..edb498441a 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -377,6 +377,10 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(filtered_params) print() + if shared.args.device == "GPU": + import intel_extension_for_pytorch + shared.model = shared.model.to("xpu") + streamer = TextIteratorStreamer(shared.tokenizer, skip_prompt=True) t0 = time.time()