diff --git a/llmclient/llms.py b/llmclient/llms.py
index 2f23255..d68be73 100644
--- a/llmclient/llms.py
+++ b/llmclient/llms.py
@@ -614,7 +614,7 @@ class MultipleCompletionLLMModel(BaseModel):
             "Configuration of the model:"
             "model is the name of the llm model to use,"
             "temperature is the sampling temperature, and"
-            "n is the number of completions to generate."
+            "n is the number of completions to generate by default."
         ),
     )
     encoding: Any | None = None
@@ -832,9 +832,6 @@ async def _call(  # noqa: C901, PLR0915
 
         return results
 
-    # TODO: Is it good practice to have this multiple interface?
-    # Users can just use `call` and we chat `n`
-    # or they can specifically call `call_single` or `call_multiple`
     async def call_single(
         self,
         messages: list[Message],
@@ -844,7 +841,25 @@ async def call_single(
         tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
         **chat_kwargs,
     ) -> LLMResult:
-        if chat_kwargs.get("n", 1) != 1 or self.config.get("n", 1) != 1:
+        """
+        Call the LLM with a list of messages and return a single completion result.
+
+        Args:
+            messages: A list of messages to send to the LLM.
+            callbacks: A list of callback functions to execute after the call.
+            output_type: The type of the output model.
+            tools: A list of tools to use during the call.
+            tool_choice: The tool or tool choice strategy to use.
+            **chat_kwargs: Additional keyword arguments for the chat.
+
+        Returns:
+            The result of the LLM call as an LLMResult object.
+
+        Raises:
+            ValueError: If the resolved value of `n` is not 1.
+        """
+        n = chat_kwargs.get("n", self.config.get("n", 1))
+        if n != 1:
             raise ValueError("n must be 1 for call_single.")
         return (
             await self._call(
@@ -861,17 +876,27 @@ async def call_multiple(
         tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
         **chat_kwargs,
     ) -> list[LLMResult]:
-        if 1 in {chat_kwargs.get("n", 1), self.config.get("n", 1)}:
-            if (
-                chat_kwargs.get("n")
-                and self.config.get("n")
-                and chat_kwargs.get("n") != self.config.get("n")
-            ):
-                raise ValueError(
-                    f"Incompatible number of completions requested. "
-                    f"Model's configuration n is {self.config['n']}, "
-                    f"but kwarg n={chat_kwargs['n']} was passed."
-                )
+        """
+        Call the LLM with a list of messages and return a list of completion results.
+
+        Args:
+            messages: A list of messages to send to the LLM.
+            callbacks: A list of callback functions to execute after receiving the response.
+            output_type: The type of the output model.
+            tools: A list of tools to use during the call.
+            tool_choice: The tool or tool choice strategy to use.
+            **chat_kwargs: Additional keyword arguments to pass to the chat function.
+
+        Returns:
+            A list of results from the LLM.
+
+        Note:
+            If the number of completions `n` resolves to 1, a warning is logged
+            indicating that the returned list will contain a single element.
+            `n` can be set in chat_kwargs or in the model's configuration.
+        """
+        n = chat_kwargs.get("n", self.config.get("n", 1))
+        if n == 1:
             logger.warning(
                 "n is 1 for call_multiple. It will return a list with a single element"
             )
@@ -913,14 +938,33 @@ async def call(
         n: int | None = None,
         **chat_kwargs,
     ) -> list[LLMResult] | LLMResult:
+        """
+        Call the LLM model with the given messages and configuration.
 
-        # Uses the LLMModel configuration unless specified in chat_kwargs
-        # If n is not specified anywhere, defaults to 1
+        Args:
+            messages: A list of messages to send to the language model.
+            callbacks: A list of callback functions to execute after receiving the response.
+            output_type: The type of the output model.
+            tools: A list of tools to use during the call.
+            tool_choice: The tool or tool choice strategy to use.
+            n: The number of completions to generate.
+                If n is not specified, the model's configuration is used.
+            **chat_kwargs: Additional keyword arguments to pass to the chat function.
+
+        Returns:
+            A list of LLMResult objects if multiple completions are requested (n > 1),
+            otherwise a single LLMResult object.
+
+        Raises:
+            ValueError: If the number of completions `n` is invalid.
+        """
         if not n or n <= 0:
             logger.info(
-                "Invalid n passed to the call function. Will get it from the model's configuration"
+                "Invalid number of completions `n` passed to the call function. "
+                "Will get it from the model's configuration."
             )
             n = self.config.get("n", 1)
+        chat_kwargs["n"] = n
         if n == 1:
             return await self.call_single(
                 messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
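
For reviewers, a minimal usage sketch of the interface after this change. This is not part of the diff: the model name, the message content, and the import paths for MultipleCompletionLLMModel and Message are assumptions for illustration.

# Sketch only: assumes `MultipleCompletionLLMModel` and `Message` are importable
# from the top-level `llmclient` package; adjust imports to the actual layout.
import asyncio

from llmclient import Message, MultipleCompletionLLMModel


async def main() -> None:
    model = MultipleCompletionLLMModel(config={"model": "gpt-4o-mini", "n": 2})
    messages = [Message(role="user", content="Name one prime number.")]

    # call() resolves n (explicit kwarg first, then the model's config) and
    # dispatches: a single LLMResult when n == 1, otherwise a list of LLMResults.
    results = await model.call(messages, n=2)
    print(len(results))

    # call_single() requires the resolved n to be 1, else it raises ValueError.
    single = await model.call_single(messages, n=1)
    print(single)

    # call_multiple() always returns a list; with n == 1 it logs a warning and
    # returns a single-element list.
    only_one = await model.call_multiple(messages, n=1)
    print(len(only_one))


asyncio.run(main())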