Skip to content

Commit

Permalink
Added documentation to call functions
Browse files Browse the repository at this point in the history
  • Loading branch information
maykcaldas committed Dec 9, 2024
1 parent 3f650fc commit 1e6eb78
Showing 1 changed file with 63 additions and 19 deletions.
82 changes: 63 additions & 19 deletions llmclient/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ class MultipleCompletionLLMModel(BaseModel):
"Configuration of the model:"
"model is the name of the llm model to use,"
"temperature is the sampling temperature, and"
"n is the number of completions to generate."
"n is the number of completions to generate by default."
),
)
encoding: Any | None = None
Expand Down Expand Up @@ -832,9 +832,6 @@ async def _call( # noqa: C901, PLR0915

return results

# TODO: Is it good practice to have this multiple interface?
# Users can just use `call` and we chat `n`
# or they can specifically call `call_single` or `call_multiple`
async def call_single(
self,
messages: list[Message],
Expand All @@ -844,7 +841,25 @@ async def call_single(
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
**chat_kwargs,
) -> LLMResult:
if chat_kwargs.get("n", 1) != 1 or self.config.get("n", 1) != 1:
"""
Calls the LLM with a list of messages and returns a single completion result.
Args:
messages: A list of messages to send to the LLM.
callbacks: A list of callback functions to execute after the call.
output_type: The type of the output model.
tools: A list of tools to use during the call.
tool_choice: The tool or tool choice to use.
**chat_kwargs: Additional keyword arguments for the chat.
Returns:
The result of the LLM call as a LLMResult object.
Raises:
ValueError: If the value of 'n' is not 1.
"""
n = chat_kwargs.get("n", self.config.get("n", 1))
if n != 1:
raise ValueError("n must be 1 for call_single.")
return (
await self._call(
Expand All @@ -861,17 +876,27 @@ async def call_multiple(
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
**chat_kwargs,
) -> list[LLMResult]:
if 1 in {chat_kwargs.get("n", 1), self.config.get("n", 1)}:
if (
chat_kwargs.get("n")
and self.config.get("n")
and chat_kwargs.get("n") != self.config.get("n")
):
raise ValueError(
f"Incompatible number of completions requested. "
f"Model's configuration n is {self.config['n']}, "
f"but kwarg n={chat_kwargs['n']} was passed."
)
"""
Calls the LLM with a list of messages and returns a list of completion results.
Args:
messages: A list of messages to send to the LLM.
callbacks: A list of callback functions to execute after receiving the response.
output_type: The type of the output model.
tools: A list of tools to use during the call.
tool_choice: The tool or tool choice strategy to use.
**chat_kwargs: Additional keyword arguments to pass to the chat function.
Returns:
A list of results from the LLM.
Raises:
Warning: If the number of completions (`n`) requested is set to 1,
a warning is logged indicating that the returned list will contain a single element.
`n` can be set in chat_kargs or in the model's configuration.
"""
n = chat_kwargs.get("n", self.config.get("n", 1))
if n == 1:
logger.warning(
"n is 1 for call_multiple. It will return a list with a single element"
)
Expand Down Expand Up @@ -913,14 +938,33 @@ async def call(
n: int | None = None,
**chat_kwargs,
) -> list[LLMResult] | LLMResult:
"""
Call the LLM model with the given messages and configuration.
# Uses the LLMModel configuration unless specified in chat_kwargs
# If n is not specified anywhere, defaults to 1
Args:
messages: A list of messages to send to the language model.
callbacks: A list of callback functions to execute after receiving the response.
output_type: The type of the output model.
tools: A list of tools to use during the call.
tool_choice: The tool or tool identifier to use.
n: An integer argument that specifies the number of completions to generate.
If n is not specified, the model's configuration is used.
**chat_kwargs: Additional keyword arguments to pass to the chat function.
Returns:
A list of LLMResult objects if multiple completions are requested (n>1),
otherwise a single LLMResult object.
Raises:
ValueError: If the number of completions `n` is invalid.
"""
if not n or n <= 0:
logger.info(
"Invalid n passed to the call function. Will get it from the model's configuration"
"Invalid number of completions `n` requested to the call function. "
"Will get it from the model's configuration."
)
n = self.config.get("n", 1)
chat_kwargs["n"] = n
if n == 1:
return await self.call_single(
messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
Expand Down

0 comments on commit 1e6eb78

Please sign in to comment.