diff --git a/tests/test_llms.py b/tests/test_llms.py index d8da345..8bd4974 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -317,7 +317,6 @@ async def test_text_image_message(self, model_name: str) -> None: ), "Expected content in message, but got None" assert "red" in result.messages[-1].content.lower() - # Test n = 1 @pytest.mark.parametrize( "model_name", [CILLMModelNames.ANTHROPIC.value, "gpt-3.5-turbo"] ) @@ -339,6 +338,13 @@ async def test_single_completion(self, model_name: str) -> None: assert len(result.messages) == 1 assert result.messages[0].content + model = self.MODEL_CLS(name=model_name, config={"n": 2}) + result = await model.call(messages, n=1) + assert isinstance(result, LLMResult) + assert result.messages + assert len(result.messages) == 1 + assert result.messages[0].content + @pytest.mark.asyncio @pytest.mark.vcr @pytest.mark.parametrize( @@ -365,6 +371,10 @@ async def test_multiple_completion(self, model_name: str, request) -> None: results = await model.call(messages, n=self.NUM_COMPLETIONS) assert len(results) == self.NUM_COMPLETIONS + model = self.MODEL_CLS(name=model_name, config={"n": 1}) + results = await model.call(messages, n=self.NUM_COMPLETIONS) + assert len(results) == self.NUM_COMPLETIONS + def test_json_schema_validation() -> None: # Invalid JSON