diff --git a/libs/community/tests/integration_tests/llms/test_bigdl_llm.py b/libs/community/tests/integration_tests/llms/test_bigdl_llm.py index c5198cd15fdc9..4162fbdb89ed9 100644 --- a/libs/community/tests/integration_tests/llms/test_bigdl_llm.py +++ b/libs/community/tests/integration_tests/llms/test_bigdl_llm.py @@ -6,10 +6,12 @@ from langchain_community.llms.bigdl_llm import BigdlLLM -model_ids_to_test = os.getenv('TEST_BIGDLLLM_MODEL_IDS') or "" -skip_if_no_model_ids = pytest.mark.skipif(not model_ids_to_test, - reason="TEST_BIGDLLLM_MODEL_IDS environment variable not set.") -model_ids_to_test = [model_id.strip() for model_id in model_ids_to_test.split(',')] +model_ids_to_test = os.getenv("TEST_BIGDLLLM_MODEL_IDS") or "" +skip_if_no_model_ids = pytest.mark.skipif( + not model_ids_to_test, + reason="TEST_BIGDLLLM_MODEL_IDS environment variable not set.", +) +model_ids_to_test = [model_id.strip() for model_id in model_ids_to_test.split(",")] @skip_if_no_model_ids @@ -17,7 +19,7 @@ "model_id", model_ids_to_test, ) -def test_call(model_id:str) -> None: +def test_call(model_id: str) -> None: """Test valid call to bigdl-llm.""" llm = BigdlLLM.from_model_id( model_id=model_id, @@ -32,7 +34,7 @@ def test_call(model_id:str) -> None: "model_id", model_ids_to_test, ) -def test_generate(model_id:str) -> None: +def test_generate(model_id: str) -> None: """Test valid call to bigdl-llm.""" llm = BigdlLLM.from_model_id( model_id=model_id, diff --git a/libs/community/tests/integration_tests/llms/test_ipex_llm.py b/libs/community/tests/integration_tests/llms/test_ipex_llm.py index 9f79ca4c36a91..c895be9f60ce1 100644 --- a/libs/community/tests/integration_tests/llms/test_ipex_llm.py +++ b/libs/community/tests/integration_tests/llms/test_ipex_llm.py @@ -6,10 +6,12 @@ from langchain_community.llms import IpexLLM -model_ids_to_test = os.getenv('TEST_IPEXLLM_MODEL_IDS') or "" -skip_if_no_model_ids = pytest.mark.skipif(not model_ids_to_test, - reason="TEST_IPEXLLM_MODEL_IDS environment variable not set.") -model_ids_to_test = [model_id.strip() for model_id in model_ids_to_test.split(',')] +model_ids_to_test = os.getenv("TEST_IPEXLLM_MODEL_IDS") or "" +skip_if_no_model_ids = pytest.mark.skipif( + not model_ids_to_test, reason="TEST_IPEXLLM_MODEL_IDS environment variable not set." +) +model_ids_to_test = [model_id.strip() for model_id in model_ids_to_test.split(",")] + def load_model(model_id: str) -> None: llm = IpexLLM.from_model_id( @@ -18,10 +20,11 @@ def load_model(model_id: str) -> None: ) return llm + def load_model_more_types(model_id: str, load_in_low_bit: str) -> None: llm = IpexLLM.from_model_id( model_id=model_id, - load_in_low_bit = load_in_low_bit, + load_in_low_bit=load_in_low_bit, model_kwargs={"temperature": 0, "max_length": 16, "trust_remote_code": True}, ) return llm @@ -51,7 +54,7 @@ def test_asym_int4(model_id: str) -> None: assert isinstance(output, str) -@skip_if_no_model_ids +@skip_if_no_model_ids @pytest.mark.parametrize( "model_id", model_ids_to_test, @@ -64,7 +67,7 @@ def test_generate(model_id: str) -> None: assert isinstance(output.generations, list) -@skip_if_no_model_ids +@skip_if_no_model_ids @pytest.mark.parametrize( "model_id", model_ids_to_test, @@ -77,9 +80,8 @@ def test_save_load_lowbit(model_id: str) -> None: del llm loaded_llm = IpexLLM.from_model_id_low_bit( model_id=saved_lowbit_path, - tokenizer_id = model_id, + tokenizer_id=model_id, model_kwargs={"temperature": 0, "max_length": 16, "trust_remote_code": True}, ) output = loaded_llm("Hello!") assert isinstance(output, str) - \ No newline at end of file