diff --git a/libs/community/langchain_community/llms/ipex_llm.py b/libs/community/langchain_community/llms/ipex_llm.py index 6e1bcc1693dee..d6173dff1f260 100644 --- a/libs/community/langchain_community/llms/ipex_llm.py +++ b/libs/community/langchain_community/llms/ipex_llm.py @@ -1,5 +1,5 @@ import logging -from typing import Any, List, Literal, Mapping, Optional +from typing import Any, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM @@ -46,7 +46,6 @@ def from_model_id( tokenizer_id: Optional[str] = None, load_in_4bit: bool = True, load_in_low_bit: Optional[str] = None, - device_map: Literal["cpu", "xpu"] = "cpu", **kwargs: Any, ) -> LLM: """ @@ -76,7 +75,6 @@ def from_model_id( low_bit_model=False, load_in_4bit=load_in_4bit, load_in_low_bit=load_in_low_bit, - device_map=device_map, model_kwargs=model_kwargs, kwargs=kwargs, ) @@ -88,7 +86,6 @@ def from_model_id_low_bit( model_kwargs: Optional[dict] = None, *, tokenizer_id: Optional[str] = None, - device_map: Literal["cpu", "xpu"] = "cpu", **kwargs: Any, ) -> LLM: """ @@ -112,7 +109,6 @@ def from_model_id_low_bit( low_bit_model=True, load_in_4bit=False, # not used for low-bit model load_in_low_bit=None, # not used for low-bit model - device_map=device_map, model_kwargs=model_kwargs, kwargs=kwargs, ) @@ -125,7 +121,6 @@ def _load_model( load_in_4bit: bool = False, load_in_low_bit: Optional[str] = None, low_bit_model: bool = False, - device_map: Literal["cpu", "xpu"] = "cpu", model_kwargs: Optional[dict] = None, kwargs: Optional[dict] = None, ) -> Any: @@ -147,6 +142,16 @@ def _load_model( kwargs = kwargs or {} _tokenizer_id = tokenizer_id or model_id + # Set "cpu" as default device + if "device" not in model_kwargs: + model_kwargs["device"] = "cpu" + + if model_kwargs["device"] not in ["cpu", "xpu"]: + raise ValueError( + "IpexLLMBgeEmbeddings currently only supports device to be " + f"'cpu' or 'xpu', but you have: {model_kwargs['device']}." + ) + device = model_kwargs.pop("device") try: tokenizer = AutoTokenizer.from_pretrained(_tokenizer_id, **_model_kwargs) @@ -194,14 +199,7 @@ def _load_model( model_kwargs=_model_kwargs, ) - # Set "cpu" as default device - - if device_map not in ["cpu", "xpu"]: - raise ValueError( - "IpexLLM currently only supports device to be " - f"'cpu' or 'xpu', but you have: {device_map}." - ) - model.to(device_map) + model.to(device) return cls( model_id=model_id, @@ -252,7 +250,7 @@ def _call( from transformers import TextStreamer input_ids = self.tokenizer.encode(prompt, return_tensors="pt") - input_ids.to(self.model.device) + input_ids = input_ids.to(self.model.device) streamer = TextStreamer( self.tokenizer, skip_prompt=True, skip_special_tokens=True ) @@ -279,7 +277,7 @@ def _call( return text else: input_ids = self.tokenizer.encode(prompt, return_tensors="pt") - input_ids.to(self.model.device) + input_ids = input_ids.to(self.model.device) if stop is not None: from transformers.generation.stopping_criteria import ( StoppingCriteriaList, diff --git a/libs/community/tests/integration_tests/llms/test_ipex_llm.py b/libs/community/tests/integration_tests/llms/test_ipex_llm.py index b6cf5ce7d1478..9ec9095e7c949 100644 --- a/libs/community/tests/integration_tests/llms/test_ipex_llm.py +++ b/libs/community/tests/integration_tests/llms/test_ipex_llm.py @@ -13,12 +13,18 @@ not model_ids_to_test, reason="TEST_IPEXLLM_MODEL_IDS environment variable not set." ) model_ids_to_test = [model_id.strip() for model_id in model_ids_to_test.split(",")] # type: ignore +device = os.getenv("TEST_IPEXLLM_BGE_EMBEDDING_MODEL_DEVICE") or "cpu" def load_model(model_id: str) -> Any: llm = IpexLLM.from_model_id( model_id=model_id, - model_kwargs={"temperature": 0, "max_length": 16, "trust_remote_code": True}, + model_kwargs={ + "temperature": 0, + "max_length": 16, + "trust_remote_code": True, + "device": device, + }, ) return llm @@ -87,24 +93,3 @@ def test_save_load_lowbit(model_id: str) -> None: ) output = loaded_llm.invoke("Hello!") assert isinstance(output, str) - - -@skip_if_no_model_ids -@pytest.mark.parametrize( - "model_id", - model_ids_to_test, -) -def test_load_generate_gpu(model_id: str) -> None: - """Test valid call.""" - llm = IpexLLM.from_model_id( - model_id=model_id, - model_kwargs={ - "temperature": 0, - "max_length": 16, - "trust_remote_code": True, - }, - device_map="xpu", - ) - output = llm.generate(["Hello!"]) - assert isinstance(output, LLMResult) - assert isinstance(output.generations, list)