Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overloaded MultipleCompletionLLMModel.call type #13

Merged
merged 23 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
7146c91
Overloaded typing in MultipleCompletionLLMModel.call. It returns eit…
maykcaldas Dec 5, 2024
1dcda13
Improved logging for call_multiple
maykcaldas Dec 5, 2024
2847af7
removed deprecated check of n in kwargs
maykcaldas Dec 5, 2024
2eac4a6
Merge branch 'main' into over-mult
maykcaldas Dec 6, 2024
6fbf2f2
Added cassettes for TestMultipleCompletionLLMModel
maykcaldas Dec 6, 2024
5d3a3c9
Fix lint
maykcaldas Dec 6, 2024
3f650fc
Implemented tests to check kwarg priority when calling
maykcaldas Dec 9, 2024
7edd613
Exposed missing classes
maykcaldas Dec 9, 2024
bae8765
added embedding_model_factory
maykcaldas Dec 9, 2024
1e6eb78
Added documentation to call functions
maykcaldas Dec 9, 2024
cb16d19
skip lint checking for argument with default value in test_llms
maykcaldas Dec 9, 2024
7966f9a
Fixed pre-commit errors
maykcaldas Dec 9, 2024
9e91858
Reverted changes in uv.lock
maykcaldas Dec 9, 2024
29e4d91
Fixed line wrap in docstrings
maykcaldas Dec 9, 2024
f8090bb
reverting uv.lock
maykcaldas Dec 9, 2024
418fa3b
removed the dependency on numpy. It is now a conditional dependency f…
maykcaldas Dec 9, 2024
ba974e5
Merge branch 'main' into remove_numpy
maykcaldas Dec 9, 2024
c34b02c
Removed image group dependency
maykcaldas Dec 9, 2024
270948e
Merge branch 'remove_numpy' of github.com:Future-House/llm-client int…
maykcaldas Dec 9, 2024
86d455d
Fixed typos
maykcaldas Dec 9, 2024
7ef8f49
Removed overload from the multiple completion llm call
maykcaldas Dec 9, 2024
03ede77
Merge branch 'remove_numpy' into over-mult
maykcaldas Dec 9, 2024
7d196df
Merge branch 'update_init' into over-mult
maykcaldas Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 145 additions & 2 deletions llmclient/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from typing import (
Any,
ClassVar,
Literal,
Self,
TypeVar,
cast,
overload,
)

import litellm
Expand Down Expand Up @@ -612,7 +614,7 @@ class MultipleCompletionLLMModel(BaseModel):
"Configuration of the model:"
"model is the name of the llm model to use,"
"temperature is the sampling temperature, and"
"n is the number of completions to generate."
"n is the number of completions to generate by default."
),
)
encoding: Any | None = None
Expand Down Expand Up @@ -658,7 +660,7 @@ async def achat_iter(self, messages: Iterable[Message], **kwargs) -> AsyncGenera
# > `required` means the model must call one or more tools.
TOOL_CHOICE_REQUIRED: ClassVar[str] = "required"

async def call( # noqa: C901, PLR0915
async def _call( # noqa: C901, PLR0915
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
Expand Down Expand Up @@ -829,3 +831,144 @@ async def call( # noqa: C901, PLR0915
result.seconds_to_last_token = end_clock - start_clock

return results

async def call_single(
    self,
    messages: list[Message],
    callbacks: list[Callable] | None = None,
    output_type: type[BaseModel] | None = None,
    tools: list[Tool] | None = None,
    tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
    **chat_kwargs,
) -> LLMResult:
    """
    Request exactly one completion from the LLM and return it unwrapped.

    The effective number of completions `n` is looked up in `chat_kwargs`
    first and in the model configuration second, defaulting to 1 when
    neither provides it.

    Args:
        messages: A list of messages to send to the LLM.
        callbacks: A list of callback functions, forwarded to `_call`.
        output_type: The type of the output model.
        tools: A list of tools to use during the call.
        tool_choice: The tool or tool choice to use.
        **chat_kwargs: Additional keyword arguments, forwarded to `_call`.

    Returns:
        The first element of the list returned by `_call`.

    Raises:
        ValueError: If the effective value of `n` is not 1.
    """
    configured_n = self.config.get("n", 1)
    effective_n = chat_kwargs.get("n", configured_n)
    if effective_n != 1:
        raise ValueError("n must be 1 for call_single.")

    completions = await self._call(
        messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
    )
    # `_call` always hands back a list; this method exposes its only entry.
    return completions[0]

async def call_multiple(
    self,
    messages: list[Message],
    callbacks: list[Callable] | None = None,
    output_type: type[BaseModel] | None = None,
    tools: list[Tool] | None = None,
    tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
    **chat_kwargs,
) -> list[LLMResult]:
    """
    Request completions from the LLM and return all of them as a list.

    The effective number of completions `n` is looked up in `chat_kwargs`
    first and in the model configuration second, defaulting to 1 when
    neither provides it. When it is 1, a warning is logged (nothing is
    raised) because the returned list will hold a single element.

    Args:
        messages: A list of messages to send to the LLM.
        callbacks: A list of callback functions, forwarded to `_call`.
        output_type: The type of the output model.
        tools: A list of tools to use during the call.
        tool_choice: The tool or tool choice strategy to use.
        **chat_kwargs: Additional keyword arguments, forwarded to `_call`.

    Returns:
        The list of results produced by `_call`.
    """
    configured_n = self.config.get("n", 1)
    if chat_kwargs.get("n", configured_n) == 1:
        logger.warning(
            "n is 1 for call_multiple. It will return a list with a single element"
        )

    completions = await self._call(
        messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
    )
    return completions

# Typing-only stub (no runtime body): declares that `call` with the literal
# `n=1` is typed as returning a single LLMResult.
# NOTE(review): `n` carries a default here, so a call that omits `n` also
# matches this stub, while the implementation shown in this diff falls back to
# the model config's `n`, which may be greater than 1 -- confirm the declared
# return type is intended for that case.
@overload
async def call(
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
n: Literal[1] = 1,
**chat_kwargs,
) -> LLMResult: ...

# Typing-only stub (no runtime body): declares that `call` with any other
# `int` (or `None`) for `n` is typed as returning a list of LLMResult.
@overload
async def call(
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
n: int | None = None,
**chat_kwargs,
) -> list[LLMResult]: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would we expect someone to use these overloads instead of the dedicated methods call_single and call_multiple?

IMO, it would be easier to maintain just two methods:

async def call(self, ..., n: int) -> list[LLMResult]:
    assert n > 0
    ...

async def call_single(self, ...) -> LLMResult:
    return self.call(..., n=1)[0]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like Sid's suggestion too. Also, let's make a docstring somewhere mentioning what n does on the "back end". Readers won't intuitively know what n means, it can refer to so many things.

On a related note, MultipleCompletionLLMModel.achat calls litellm.acompletion. Can we rename MultipleCompletionLLMModel.achat to be MultipleCompletionLLMModel.acompletion to standardize with the actual API endpoint ultimately being invoked?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we still have the overloads, call_single, and call_multiple here - can we reduce this to call and call_single?


async def call(
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
n: int | None = None,
**chat_kwargs,
) -> list[LLMResult] | LLMResult:
"""
Call the LLM model with the given messages and configuration.

Args:
messages: A list of messages to send to the language model.
callbacks: A list of callback functions to execute after receiving the response.
output_type: The type of the output model.
tools: A list of tools to use during the call.
tool_choice: The tool or tool identifier to use.
n: An integer argument that specifies the number of completions to generate.
If n is not specified, the model's configuration is used.
**chat_kwargs: Additional keyword arguments to pass to the chat function.

Returns:
A list of LLMResult objects if multiple completions are requested (n>1),
otherwise a single LLMResult object.

Raises:
ValueError: If the number of completions `n` is invalid.
"""
if not n or n <= 0:
logger.info(
"Invalid number of completions `n` requested to the call function. "
"Will get it from the model's configuration."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should raise an error if n<=0. And I don't think we need to emit a logging message if n is unspecified, since that will be a common case

n = self.config.get("n", 1)
chat_kwargs["n"] = n
if n == 1:
return await self.call_single(
messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
)
return await self.call_multiple(
messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
interactions:
- request:
body:
'{"messages":[{"role":"system","content":"Respond with single words."},{"role":"user","content":"Hello,
how are you?"}],"model":"gpt-3.5-turbo","n":2}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "149"
content-type:
- application/json
host:
- api.openai.com
user-agent:
- AsyncOpenAI/Python 1.57.0
x-stainless-arch:
- arm64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.57.0
x-stainless-raw-response:
- "true"
x-stainless-retry-count:
- "1"
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.7
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAA9RTy2rDMBC8+yuEzklo3jS3QCCXXNoe+qIYWdrYamStKq1LS8i/FzkPOySFXnvR
YWZnNLsrbRPGuFZ8xrgsBMnSme48W6zuxvBAL5v55/3H06Ra4OOoWhXl82LJO1GB2TtIOqp6Ektn
gDTaPS09CILo2p8Oh6PhYHo7qYkSFZgoyx11h71xlyqfYfemPxgflAVqCYHP2GvCGGPb+owZrYIv
PmM3nSNSQggiBz47FTHGPZqIcBGCDiQs8U5DSrQEto69RFRtysO6CiJGs5UxB3x3ustg7jxm4cCf
8LW2OhSpBxHQRt9A6HjSEl800P83DSSMvdVLqc5icuexdJQSbsBGw8Fgb8ebZ9AiDxwhCdOCR50r
ZqkCEtqE1ki4FLIA1SibByAqpbFFtMd+meWa975tbfO/2DeElOAIVOo8KC3P+23KPMQ/8lvZacR1
YB6+A0GZrrXNwTuv6yXXm9wlPwAAAP//AwAh8pBrpAMAAA==
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8ed70040cbcdf99b-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Thu, 05 Dec 2024 21:06:36 GMT
Server:
- cloudflare
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "134"
openai-version:
- "2020-10-01"
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- "12000"
x-ratelimit-limit-tokens:
- "1000000"
x-ratelimit-remaining-requests:
- "11999"
x-ratelimit-remaining-tokens:
- "999953"
x-ratelimit-reset-requests:
- 5ms
x-ratelimit-reset-tokens:
- 2ms
x-request-id:
- req_1f88664946b9891fbc90796687f144c4
status:
code: 200
message: OK
- request:
body:
'{"messages":[{"role":"system","content":"Respond with single words."},{"role":"user","content":"Hello,
how are you?"}],"model":"gpt-3.5-turbo","n":2}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "149"
content-type:
- application/json
host:
- api.openai.com
user-agent:
- AsyncOpenAI/Python 1.57.0
x-stainless-arch:
- arm64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.57.0
x-stainless-raw-response:
- "true"
x-stainless-retry-count:
- "0"
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.7
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAA9RTTUsDMRC9768IOW9LP63tzaIIIqgH7UFkSZPZbTSbCcksWEr/u2T7sVtawauX
HN6b9/JmJtkkjHGt+IxxuRIkS2c6N8vbx5fF9Tisy7l5e13clfl0/vQwfl5P5o6nUYHLT5B0UHUl
ls4AabQ7WnoQBNG1PxkOR8PBZHpVEyUqMFFWOOoMu+MOVX6JnV5/MN4rV6glBD5j7wljjG3qM2a0
Cr75jPXSA1JCCKIAPjsWMcY9mohwEYIOJCzxtCElWgJbx75HVG3KQ14FEaPZypg9vj3eZbBwHpdh
zx/xXFsdVpkHEdBG30DoeNISnzXQ/zcNJIx91EupTmJy57F0lBF+gY2Gg8HOjjfPoEXuOUISpgWP
0gtmmQIS2oTWSLgUcgWqUTYPQFRKY4toj/08yyXvXdvaFn+xbwgpwRGozHlQWp7225R5iH/kt7Lj
iOvAPKwDQZnl2hbgndf1kutNbpMfAAAA//8DALEE5HikAwAA
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8ed700428d77f99b-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Thu, 05 Dec 2024 21:06:36 GMT
Server:
- cloudflare
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "114"
openai-version:
- "2020-10-01"
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- "12000"
x-ratelimit-limit-tokens:
- "1000000"
x-ratelimit-remaining-requests:
- "11999"
x-ratelimit-remaining-tokens:
- "999953"
x-ratelimit-reset-requests:
- 5ms
x-ratelimit-reset-tokens:
- 2ms
x-request-id:
- req_e32516fa5bb6ab11dda5155511280ea6
status:
code: 200
message: OK
version: 1
Loading
Loading