Commit: enable gpu

ivy-lv11 committed May 31, 2024
1 parent 242eeb5 commit 2e38a40
Showing 2 changed files with 37 additions and 3 deletions.
libs/community/langchain_community/llms/ipex_llm.py (17 additions, 3 deletions)
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, List, Mapping, Optional
+from typing import Any, List, Mapping, Optional, Literal
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
@@ -46,6 +46,7 @@ def from_model_id(
         tokenizer_id: Optional[str] = None,
         load_in_4bit: bool = True,
         load_in_low_bit: Optional[str] = None,
+        device_map: Literal['cpu','xpu'] = 'cpu',
         **kwargs: Any,
     ) -> LLM:
         """
@@ -75,6 +76,7 @@ def from_model_id(
             low_bit_model=False,
             load_in_4bit=load_in_4bit,
             load_in_low_bit=load_in_low_bit,
+            device_map=device_map,
             model_kwargs=model_kwargs,
             kwargs=kwargs,
         )
@@ -86,6 +88,7 @@ def from_model_id_low_bit(
         model_kwargs: Optional[dict] = None,
         *,
         tokenizer_id: Optional[str] = None,
+        device_map: Literal['cpu','xpu'] = 'cpu',
         **kwargs: Any,
     ) -> LLM:
         """
@@ -109,6 +112,7 @@ def from_model_id_low_bit(
             low_bit_model=True,
             load_in_4bit=False,  # not used for low-bit model
             load_in_low_bit=None,  # not used for low-bit model
+            device_map=device_map,
             model_kwargs=model_kwargs,
             kwargs=kwargs,
         )
@@ -121,6 +125,7 @@ def _load_model(
         load_in_4bit: bool = False,
         load_in_low_bit: Optional[str] = None,
         low_bit_model: bool = False,
+        device_map: Literal['cpu','xpu'] = "cpu",
         model_kwargs: Optional[dict] = None,
         kwargs: Optional[dict] = None,
     ) -> Any:
@@ -189,6 +194,15 @@ def _load_model(
                 model_kwargs=_model_kwargs,
             )
 
+        # Set "cpu" as default device
+
+        if device_map not in ["cpu", "xpu"]:
+            raise ValueError(
+                "IpexLLM currently only supports device to be "
+                f"'cpu' or 'xpu', but you have: {device_map}."
+            )
+        model.to(device_map)
+
         return cls(
             model_id=model_id,
             model=model,
@@ -237,7 +251,7 @@ def _call(
         if self.streaming:
             from transformers import TextStreamer
 
-            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
             streamer = TextStreamer(
                 self.tokenizer, skip_prompt=True, skip_special_tokens=True
             )
@@ -263,7 +277,7 @@ def _call(
             text = self.tokenizer.decode(output[0], skip_special_tokens=True)
             return text
         else:
-            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
             if stop is not None:
                 from transformers.generation.stopping_criteria import (
                     StoppingCriteriaList,
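
Taken together, the changes to ipex_llm.py add a device_map argument (restricted to 'cpu' or 'xpu') to from_model_id, from_model_id_low_bit, and _load_model, move the loaded model onto the requested device, and send the encoded prompt to self.model.device before generation. A minimal usage sketch based on these new signatures follows; the model id and generation kwargs are illustrative placeholders, and an XPU-enabled ipex-llm installation with Intel GPU drivers is assumed:

from langchain_community.llms import IpexLLM

# Load the model directly onto an Intel GPU; omitting device_map keeps the "cpu" default.
llm = IpexLLM.from_model_id(
    model_id="lmsys/vicuna-7b-v1.5",  # placeholder model id
    model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
    device_map="xpu",
)
print(llm.invoke("What is IPEX-LLM?"))
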
libs/community/tests/integration_tests/llms/test_ipex_llm.py (20 additions, 0 deletions)
@@ -86,3 +86,23 @@ def test_save_load_lowbit(model_id: str) -> None:
     )
     output = loaded_llm.invoke("Hello!")
     assert isinstance(output, str)
+
+@skip_if_no_model_ids
+@pytest.mark.parametrize(
+    "model_id",
+    model_ids_to_test,
+)
+def test_load_generate_gpu(model_id: str) -> None:
+    """Test valid call."""
+    llm = IpexLLM.from_model_id(
+        model_id=model_id,
+        model_kwargs={
+            "temperature": 0,
+            "max_length": 16,
+            "trust_remote_code": True,
+        },
+        device_map="xpu",
+    )
+    output = llm.generate(["Hello!"])
+    assert isinstance(output, LLMResult)
+    assert isinstance(output.generations, list)
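
The new integration test mirrors the existing CPU tests but routes the model to the XPU via device_map. The same argument is also exposed on from_model_id_low_bit, so a previously saved low-bit checkpoint can be reloaded onto the GPU; a hedged sketch, with the checkpoint folder and tokenizer id as placeholders:

from langchain_community.llms import IpexLLM

# Reload an existing low-bit checkpoint and place it on the Intel GPU.
llm_lowbit = IpexLLM.from_model_id_low_bit(
    model_id="/tmp/saved_lowbit_model",   # placeholder path to a saved low-bit model
    tokenizer_id="lmsys/vicuna-7b-v1.5",  # placeholder tokenizer id
    model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
    device_map="xpu",
)
print(llm_lowbit.invoke("Hello!"))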
