update tests and api docs and fix format
shane-huang committed Apr 24, 2024
1 parent 2a72111 commit 7c133b7
Showing 4 changed files with 78 additions and 25 deletions.
8 changes: 3 additions & 5 deletions docs/docs/integrations/llms/ipex_llm.ipynb
@@ -99,7 +99,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6a436f7eb96849409a337312a93b9dd3",
"model_id": "897501860fe4452b836f816c72d955dd",
"version_major": 2,
"version_minor": 0
},
@@ -114,7 +114,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-04-24 17:20:10,884 - INFO - Converting the current model to sym_int4 format......\n"
"2024-04-24 21:20:12,461 - INFO - Converting the current model to sym_int4 format......\n"
]
}
],
@@ -184,8 +184,6 @@
"outputs": [],
"source": [
"saved_lowbit_model_path = \"./vicuna-7b-1.5-low-bit\" # path to save low-bit model\n",
"\n",
"\n",
"llm.model.save_low_bit(saved_lowbit_model_path)\n",
"del llm"
]
@@ -207,7 +205,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-04-24 17:20:43,387 - INFO - Converting the current model to sym_int4 format......\n"
"2024-04-24 21:20:35,874 - INFO - Converting the current model to sym_int4 format......\n"
]
}
],
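For orientation, here is a minimal sketch of the save-and-reload flow this notebook cell demonstrates, assuming (as the new `tokenizer_id` parameter suggests) that tokenizer files are not stored with the low-bit checkpoint. The model id and save path follow the notebook; the exact `model_kwargs` are illustrative:

```python
from langchain_community.llms.ipex_llm import IpexLLM

# First load converts the checkpoint to low-bit (sym_int4) format.
llm = IpexLLM.from_model_id(
    model_id="lmsys/vicuna-7b-v1.5",
    model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
)

saved_lowbit_model_path = "./vicuna-7b-1.5-low-bit"  # path to save low-bit model
llm.model.save_low_bit(saved_lowbit_model_path)
del llm

# Later sessions reload the low-bit weights directly; tokenizer_id points
# back at the original repo because only the model weights were saved.
llm_lowbit = IpexLLM.from_model_id_low_bit(
    model_id=saved_lowbit_model_path,
    tokenizer_id="lmsys/vicuna-7b-v1.5",
    model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
)
```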
33 changes: 29 additions & 4 deletions libs/community/langchain_community/llms/bigdl_llm.py
@@ -22,6 +22,9 @@ class BigdlLLM(IpexLLM):
def from_model_id(
cls,
model_id: str,
tokenizer_id: Optional[str] = None,
load_in_4bit: bool = True,
load_in_low_bit: Optional[str] = None,
model_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> LLM:
@@ -31,6 +34,8 @@ def from_model_id(
Args:
model_id: Path for the huggingface repo id to be downloaded or
the huggingface checkpoint folder.
tokenizer_id: Path for the huggingface repo id to be downloaded or
the huggingface checkpoint folder which contains the tokenizer.
model_kwargs: Keyword arguments to pass to the model and tokenizer.
kwargs: Extra arguments to pass to the model and tokenizer.
@@ -52,12 +57,27 @@
"Please install it with `pip install --pre --upgrade bigdl-llm[all]`."
)

if load_in_low_bit is not None:
logger.warning(
"""`load_in_low_bit` option is not supported in BigdlLLM and
is ignored. For more data types support with `load_in_low_bit`,
use IpexLLM instead."""
)

if not load_in_4bit:
raise ValueError(
"BigdlLLM only supports loading in 4-bit mode, "
"i.e. load_in_4bit = True. "
"Please install it with `pip install --pre --upgrade bigdl-llm[all]`."
)

_model_kwargs = model_kwargs or {}
_tokenizer_id = tokenizer_id or model_id

try:
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(_tokenizer_id, **_model_kwargs)
except Exception:
tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)
tokenizer = LlamaTokenizer.from_pretrained(_tokenizer_id, **_model_kwargs)

try:
model = AutoModelForCausalLM.from_pretrained(
@@ -85,6 +105,7 @@ def from_model_id
def from_model_id_low_bit(
cls,
model_id: str,
tokenizer_id: Optional[str] = None,
model_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> LLM:
@@ -94,6 +115,8 @@ def from_model_id_low_bit(
Args:
model_id: Path for the bigdl-llm transformers low-bit model folder.
tokenizer_id: Path for the huggingface repo id or local model folder
which contains the tokenizer.
model_kwargs: Keyword arguments to pass to the model and tokenizer.
kwargs: Extra arguments to pass to the model and tokenizer.
@@ -117,10 +140,12 @@ def from_model_id_low_bit(
)

_model_kwargs = model_kwargs or {}
_tokenizer_id = tokenizer_id or model_id

try:
tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(_tokenizer_id, **_model_kwargs)
except Exception:
tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)
tokenizer = LlamaTokenizer.from_pretrained(_tokenizer_id, **_model_kwargs)

try:
model = AutoModelForCausalLM.load_low_bit(model_id, **_model_kwargs)
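Taken together, the `BigdlLLM` changes add a `tokenizer_id` that falls back to `model_id`, log and ignore `load_in_low_bit`, and reject `load_in_4bit=False`. A minimal usage sketch under those semantics (model id borrowed from the integration tests, kwargs illustrative):

```python
from langchain_community.llms.bigdl_llm import BigdlLLM

# BigdlLLM always loads weights in 4-bit (sym_int4) format:
# passing load_in_low_bit only logs a warning, and
# load_in_4bit=False raises a ValueError.
llm = BigdlLLM.from_model_id(
    model_id="lmsys/vicuna-7b-v1.5",
    tokenizer_id="lmsys/vicuna-7b-v1.5",  # optional; defaults to model_id
    model_kwargs={"temperature": 0, "max_length": 16, "trust_remote_code": True},
)
print(llm.invoke("Hello!"))
```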
36 changes: 24 additions & 12 deletions libs/community/langchain_community/llms/ipex_llm.py
@@ -42,7 +42,7 @@ def from_model_id(
cls,
model_id: str,
tokenizer_id: Optional[str] = None,
load_in_4bit: Optional[bool] = True,
load_in_4bit: bool = True,
load_in_low_bit: Optional[str] = None,
model_kwargs: Optional[dict] = None,
**kwargs: Any,
@@ -53,6 +53,13 @@
Args:
model_id: Path for the huggingface repo id to be downloaded or
the huggingface checkpoint folder.
tokenizer_id: Path for the huggingface repo id to be downloaded or
the huggingface checkpoint folder which contains the tokenizer.
load_in_4bit: "Whether to load model in 4bit.
Unused if `load_in_low_bit` is not None.
load_in_low_bit: Which low bit precisions to use when loading model.
Example values: 'sym_int4', 'asym_int4', 'fp4', 'nf4', 'fp8', etc.
Overrides `load_in_4bit` if specified.
model_kwargs: Keyword arguments to pass to the model and tokenizer.
kwargs: Extra arguments to pass to the model and tokenizer.
@@ -85,6 +92,8 @@ def from_model_id_low_bit(
Args:
model_id: Path for the ipex-llm transformers low-bit model folder.
tokenizer_id: Path for the huggingface repo id or local model folder
which contains the tokenizer.
model_kwargs: Keyword arguments to pass to the model and tokenizer.
kwargs: Extra arguments to pass to the model and tokenizer.
@@ -105,11 +114,11 @@
@classmethod
def _load_model(
cls,
model_id,
model_id: str,
tokenizer_id: Optional[str] = None,
low_bit_model: Optional[bool] = False,
load_in_4bit: Optional[bool] = True,
load_in_4bit: bool = False,
load_in_low_bit: Optional[str] = None,
low_bit_model: bool = False,
model_kwargs: Optional[dict] = None,
kwargs: Optional[dict] = None,
) -> Any:
@@ -137,13 +146,22 @@ def _load_model(
except Exception:
tokenizer = LlamaTokenizer.from_pretrained(_tokenizer_id, **_model_kwargs)

# remove trust_remote_code from model_kwargs so it is not passed in twice
if "trust_remote_code" in _model_kwargs:
_model_kwargs = {
k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
}

# Load the model with AutoModelForCausalLM, falling back to AutoModel on failure.
load_kwargs = {"use_cache": True, "trust_remote_code": True}
load_kwargs = {
"use_cache": True,
"trust_remote_code": True,
}

if not low_bit_model:
if load_in_low_bit is not None:
load_function_name = "from_pretrained"
load_kwargs["load_in_low_bit"] = load_in_low_bit
load_kwargs["load_in_low_bit"] = load_in_low_bit # type: ignore
else:
load_function_name = "from_pretrained"
load_kwargs["load_in_4bit"] = load_in_4bit
@@ -169,12 +187,6 @@
model_kwargs=_model_kwargs,
)

# restore model_kwargs
if "trust_remote_code" in _model_kwargs:
_model_kwargs = {
k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
}

return cls(
model_id=model_id,
model=model,
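On the `IpexLLM` side, both quantization switches now flow through `_load_model`, with `load_in_low_bit` taking precedence over `load_in_4bit`. A sketch of the two entry points under those semantics (model id and kwargs illustrative):

```python
from langchain_community.llms.ipex_llm import IpexLLM

# Default path: load_in_4bit=True gives sym_int4 quantization.
llm_int4 = IpexLLM.from_model_id(
    model_id="lmsys/vicuna-7b-v1.5",
    model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
)

# Explicit precision: load_in_low_bit overrides load_in_4bit.
llm_fp8 = IpexLLM.from_model_id(
    model_id="lmsys/vicuna-7b-v1.5",
    load_in_low_bit="fp8",  # or 'sym_int4', 'asym_int4', 'fp4', 'nf4', ...
    model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
)
```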
26 changes: 22 additions & 4 deletions libs/community/tests/integration_tests/llms/test_bigdl_llm.py
@@ -1,23 +1,41 @@
"""Test BigdlLLM"""
import os

import pytest
from langchain_core.outputs import LLMResult

from langchain_community.llms.bigdl_llm import BigdlLLM

model_ids_to_test = os.getenv("TEST_BIGDLLLM_MODEL_IDS") or ""
skip_if_no_model_ids = pytest.mark.skipif(
not model_ids_to_test, reason="TEST_BIGDLLLM_MODEL_IDS environment variable not set."
)
model_ids_to_test = [model_id.strip() for model_id in model_ids_to_test.split(",")]


def test_call() -> None:
@skip_if_no_model_ids
@pytest.mark.parametrize(
"model_id",
model_ids_to_test,
)
def test_call(model_id: str) -> None:
"""Test valid call to bigdl-llm."""
llm = BigdlLLM.from_model_id(
model_id="lmsys/vicuna-7b-v1.5",
model_id=model_id,
model_kwargs={"temperature": 0, "max_length": 16, "trust_remote_code": True},
)
output = llm("Hello!")
assert isinstance(output, str)


def test_generate() -> None:
@skip_if_no_model_ids
@pytest.mark.parametrize(
"model_id",
model_ids_to_test,
)
def test_generate(model_id: str) -> None:
"""Test valid call to bigdl-llm."""
llm = BigdlLLM.from_model_id(
model_id="lmsys/vicuna-7b-v1.5",
model_id=model_id,
model_kwargs={"temperature": 0, "max_length": 16, "trust_remote_code": True},
)
output = llm.generate(["Hello!"])
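These integration tests now run once per model id supplied in TEST_BIGDLLLM_MODEL_IDS and are skipped when it is unset. One illustrative way to drive them from Python follows; pytest can equally be invoked from a shell with the variable exported, and the path assumes the libs/community working directory:

```python
import os
import subprocess

# Hypothetical runner: pass the model ids through the environment.
env = dict(os.environ, TEST_BIGDLLLM_MODEL_IDS="lmsys/vicuna-7b-v1.5")
subprocess.run(
    ["pytest", "tests/integration_tests/llms/test_bigdl_llm.py"],
    env=env,
    check=True,
)
```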
