From 2e38a40229fab67c1c728af1922adddbe9eee3ca Mon Sep 17 00:00:00 2001 From: ivy-lv11 Date: Fri, 31 May 2024 16:27:38 +0800 Subject: [PATCH] enable gpu --- .../langchain_community/llms/ipex_llm.py | 20 ++++++++++++++++--- .../integration_tests/llms/test_ipex_llm.py | 20 +++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/llms/ipex_llm.py b/libs/community/langchain_community/llms/ipex_llm.py index 0e41c305bb7a8..b0217baf1f22c 100644 --- a/libs/community/langchain_community/llms/ipex_llm.py +++ b/libs/community/langchain_community/llms/ipex_llm.py @@ -1,5 +1,5 @@ import logging -from typing import Any, List, Mapping, Optional +from typing import Any, List, Literal, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM @@ -46,6 +46,7 @@ def from_model_id( tokenizer_id: Optional[str] = None, load_in_4bit: bool = True, load_in_low_bit: Optional[str] = None, + device_map: Literal["cpu", "xpu"] = "cpu", **kwargs: Any, ) -> LLM: """ @@ -75,6 +76,7 @@ def from_model_id( low_bit_model=False, load_in_4bit=load_in_4bit, load_in_low_bit=load_in_low_bit, + device_map=device_map, model_kwargs=model_kwargs, kwargs=kwargs, ) @@ -86,6 +88,7 @@ def from_model_id_low_bit( model_kwargs: Optional[dict] = None, *, tokenizer_id: Optional[str] = None, + device_map: Literal["cpu", "xpu"] = "cpu", **kwargs: Any, ) -> LLM: """ @@ -109,6 +112,7 @@ def from_model_id_low_bit( low_bit_model=True, load_in_4bit=False, # not used for low-bit model load_in_low_bit=None, # not used for low-bit model + device_map=device_map, model_kwargs=model_kwargs, kwargs=kwargs, ) @@ -121,6 +125,7 @@ def _load_model( load_in_4bit: bool = False, load_in_low_bit: Optional[str] = None, low_bit_model: bool = False, + device_map: Literal["cpu", "xpu"] = "cpu", model_kwargs: Optional[dict] = None, kwargs: Optional[dict] = None, ) -> Any: @@ -189,6 +194,15 @@ def _load_model( 
model_kwargs=_model_kwargs, ) + # Validate device_map and move the loaded model onto the requested device + + if device_map not in ["cpu", "xpu"]: + raise ValueError( + "IpexLLM currently only supports device to be " + f"'cpu' or 'xpu', but you have: {device_map}." + ) + model.to(device_map) + return cls( model_id=model_id, model=model, @@ -237,7 +251,7 @@ def _call( if self.streaming: from transformers import TextStreamer - input_ids = self.tokenizer.encode(prompt, return_tensors="pt") + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device) streamer = TextStreamer( self.tokenizer, skip_prompt=True, skip_special_tokens=True ) @@ -263,7 +277,7 @@ def _call( text = self.tokenizer.decode(output[0], skip_special_tokens=True) return text else: - input_ids = self.tokenizer.encode(prompt, return_tensors="pt") + input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device) if stop is not None: from transformers.generation.stopping_criteria import ( StoppingCriteriaList, diff --git a/libs/community/tests/integration_tests/llms/test_ipex_llm.py b/libs/community/tests/integration_tests/llms/test_ipex_llm.py index 163458029c5d5..b153cde2b5868 100644 --- a/libs/community/tests/integration_tests/llms/test_ipex_llm.py +++ b/libs/community/tests/integration_tests/llms/test_ipex_llm.py @@ -86,3 +86,23 @@ def test_save_load_lowbit(model_id: str) -> None: ) output = loaded_llm.invoke("Hello!") assert isinstance(output, str) + +@skip_if_no_model_ids +@pytest.mark.parametrize( + "model_id", + model_ids_to_test, +) +def test_load_generate_gpu(model_id: str) -> None: + """Test valid generation with the model loaded on an Intel GPU (xpu).""" + llm = IpexLLM.from_model_id( + model_id=model_id, + model_kwargs={ + "temperature": 0, + "max_length": 16, + "trust_remote_code": True, + }, + device_map="xpu", + ) + output = llm.generate(["Hello!"]) + assert isinstance(output, LLMResult) + assert isinstance(output.generations, list) \ No newline at end of file