Skip to content

Commit

Permalink
update test format
Browse files Browse the repository at this point in the history
  • Loading branch information
shane-huang committed Apr 24, 2024
1 parent 7c133b7 commit 59cfeb6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
14 changes: 8 additions & 6 deletions libs/community/tests/integration_tests/llms/test_bigdl_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@

from langchain_community.llms.bigdl_llm import BigdlLLM

model_ids_to_test = os.getenv('TEST_BIGDLLLM_MODEL_IDS') or ""
skip_if_no_model_ids = pytest.mark.skipif(not model_ids_to_test,
reason="TEST_BIGDLLLM_MODEL_IDS environment variable not set.")
model_ids_to_test = [model_id.strip() for model_id in model_ids_to_test.split(',')]
model_ids_to_test = os.getenv("TEST_BIGDLLLM_MODEL_IDS") or ""
skip_if_no_model_ids = pytest.mark.skipif(
not model_ids_to_test,
reason="TEST_BIGDLLLM_MODEL_IDS environment variable not set.",
)
model_ids_to_test = [model_id.strip() for model_id in model_ids_to_test.split(",")]


@skip_if_no_model_ids
@pytest.mark.parametrize(
"model_id",
model_ids_to_test,
)
def test_call(model_id:str) -> None:
def test_call(model_id: str) -> None:
"""Test valid call to bigdl-llm."""
llm = BigdlLLM.from_model_id(
model_id=model_id,
Expand All @@ -32,7 +34,7 @@ def test_call(model_id:str) -> None:
"model_id",
model_ids_to_test,
)
def test_generate(model_id:str) -> None:
def test_generate(model_id: str) -> None:
"""Test valid call to bigdl-llm."""
llm = BigdlLLM.from_model_id(
model_id=model_id,
Expand Down
20 changes: 11 additions & 9 deletions libs/community/tests/integration_tests/llms/test_ipex_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

from langchain_community.llms import IpexLLM

model_ids_to_test = os.getenv('TEST_IPEXLLM_MODEL_IDS') or ""
skip_if_no_model_ids = pytest.mark.skipif(not model_ids_to_test,
reason="TEST_IPEXLLM_MODEL_IDS environment variable not set.")
model_ids_to_test = [model_id.strip() for model_id in model_ids_to_test.split(',')]
model_ids_to_test = os.getenv("TEST_IPEXLLM_MODEL_IDS") or ""
skip_if_no_model_ids = pytest.mark.skipif(
not model_ids_to_test, reason="TEST_IPEXLLM_MODEL_IDS environment variable not set."
)
model_ids_to_test = [model_id.strip() for model_id in model_ids_to_test.split(",")]


def load_model(model_id: str) -> None:
llm = IpexLLM.from_model_id(
Expand All @@ -18,10 +20,11 @@ def load_model(model_id: str) -> None:
)
return llm


def load_model_more_types(model_id: str, load_in_low_bit: str) -> None:
llm = IpexLLM.from_model_id(
model_id=model_id,
load_in_low_bit = load_in_low_bit,
load_in_low_bit=load_in_low_bit,
model_kwargs={"temperature": 0, "max_length": 16, "trust_remote_code": True},
)
return llm
Expand Down Expand Up @@ -51,7 +54,7 @@ def test_asym_int4(model_id: str) -> None:
assert isinstance(output, str)


@skip_if_no_model_ids
@skip_if_no_model_ids
@pytest.mark.parametrize(
"model_id",
model_ids_to_test,
Expand All @@ -64,7 +67,7 @@ def test_generate(model_id: str) -> None:
assert isinstance(output.generations, list)


@skip_if_no_model_ids
@skip_if_no_model_ids
@pytest.mark.parametrize(
"model_id",
model_ids_to_test,
Expand All @@ -77,9 +80,8 @@ def test_save_load_lowbit(model_id: str) -> None:
del llm
loaded_llm = IpexLLM.from_model_id_low_bit(
model_id=saved_lowbit_path,
tokenizer_id = model_id,
tokenizer_id=model_id,
model_kwargs={"temperature": 0, "max_length": 16, "trust_remote_code": True},
)
output = loaded_llm("Hello!")
assert isinstance(output, str)

0 comments on commit 59cfeb6

Please sign in to comment.