Overloaded MultipleCompletionLLMModel.call type #13
Merged
Changes from 13 commits
Commits
23 commits
7146c91 Overloaded typying in MultipleCompletionLLMModel.call. It returns eit… (maykcaldas)
1dcda13 Improved logging for call_multiple (maykcaldas)
2847af7 removed deprecated check of n in kwargs (maykcaldas)
2eac4a6 Merge branch 'main' into over-mult (maykcaldas)
6fbf2f2 Added cassets for TestMultipleCompletionLLMModel (maykcaldas)
5d3a3c9 Fix lint (maykcaldas)
3f650fc Implemented tests to check kwarg priority when calling (maykcaldas)
7edd613 Exposed missing classes (maykcaldas)
bae8765 added embedding_model_factory (maykcaldas)
1e6eb78 Added documentation to call functions (maykcaldas)
cb16d19 skip lint checking for argument with default value in test_llms (maykcaldas)
7966f9a Fixed pre-commit errors (maykcaldas)
9e91858 Reverted changes in uv.lock (maykcaldas)
29e4d91 Fixed line wrap in docstrings (maykcaldas)
f8090bb reverting uv.lock (maykcaldas)
418fa3b removed the dependency on numpy. It is now a conditional dependency f… (maykcaldas)
ba974e5 Merge branch 'main' into remove_numpy (maykcaldas)
c34b02c Removed image group dependency (maykcaldas)
270948e Merge branch 'remove_numpy' of github.com:Future-House/llm-client int… (maykcaldas)
86d455d Fixed typos (maykcaldas)
7ef8f49 Removed overload from the multiple completion llm call (maykcaldas)
03ede77 Merge branch 'remove_numpy' into over-mult (maykcaldas)
7d196df Merge branch 'update_init' into over-mult (maykcaldas)
@@ -16,9 +16,11 @@
 from typing import (
     Any,
     ClassVar,
+    Literal,
     Self,
     TypeVar,
     cast,
+    overload,
 )
 
 import litellm
@@ -612,7 +614,7 @@ class MultipleCompletionLLMModel(BaseModel):
             "Configuration of the model:"
             "model is the name of the llm model to use,"
             "temperature is the sampling temperature, and"
-            "n is the number of completions to generate."
+            "n is the number of completions to generate by default."
         ),
     )
     encoding: Any | None = None
@@ -658,7 +660,7 @@ async def achat_iter(self, messages: Iterable[Message], **kwargs) -> AsyncGenera
     # > `required` means the model must call one or more tools.
     TOOL_CHOICE_REQUIRED: ClassVar[str] = "required"
 
-    async def call(  # noqa: C901, PLR0915
+    async def _call(  # noqa: C901, PLR0915
         self,
         messages: list[Message],
         callbacks: list[Callable] | None = None,
@@ -829,3 +831,144 @@ async def call( # noqa: C901, PLR0915
             result.seconds_to_last_token = end_clock - start_clock
 
         return results
+
+    async def call_single(
+        self,
+        messages: list[Message],
+        callbacks: list[Callable] | None = None,
+        output_type: type[BaseModel] | None = None,
+        tools: list[Tool] | None = None,
+        tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
+        **chat_kwargs,
+    ) -> LLMResult:
+        """
+        Calls the LLM with a list of messages and returns a single completion result.
+
+        Args:
+            messages: A list of messages to send to the LLM.
+            callbacks: A list of callback functions to execute after the call.
+            output_type: The type of the output model.
+            tools: A list of tools to use during the call.
+            tool_choice: The tool or tool choice to use.
+            **chat_kwargs: Additional keyword arguments for the chat.
+
+        Returns:
+            The result of the LLM call as a LLMResult object.
+
+        Raises:
+            ValueError: If the value of 'n' is not 1.
+        """
+        n = chat_kwargs.get("n", self.config.get("n", 1))
+        if n != 1:
+            raise ValueError("n must be 1 for call_single.")
+        return (
+            await self._call(
+                messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
+            )
+        )[0]
+
+    async def call_multiple(
+        self,
+        messages: list[Message],
+        callbacks: list[Callable] | None = None,
+        output_type: type[BaseModel] | None = None,
+        tools: list[Tool] | None = None,
+        tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
+        **chat_kwargs,
+    ) -> list[LLMResult]:
+        """
+        Calls the LLM with a list of messages and returns a list of completion results.
+
+        Args:
+            messages: A list of messages to send to the LLM.
+            callbacks: A list of callback functions to execute after receiving the response.
+            output_type: The type of the output model.
+            tools: A list of tools to use during the call.
+            tool_choice: The tool or tool choice strategy to use.
+            **chat_kwargs: Additional keyword arguments to pass to the chat function.
+
+        Returns:
+            A list of results from the LLM.
+
+        Raises:
+            Warning: If the number of completions (`n`) requested is set to 1,
+                a warning is logged indicating that the returned list will contain a single element.
+                `n` can be set in chat_kwargs or in the model's configuration.
+        """
+        n = chat_kwargs.get("n", self.config.get("n", 1))
+        if n == 1:
+            logger.warning(
+                "n is 1 for call_multiple. It will return a list with a single element"
+            )
+        return await self._call(
+            messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
+        )
+
+    @overload
+    async def call(
+        self,
+        messages: list[Message],
+        callbacks: list[Callable] | None = None,
+        output_type: type[BaseModel] | None = None,
+        tools: list[Tool] | None = None,
+        tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
+        n: Literal[1] = 1,
+        **chat_kwargs,
+    ) -> LLMResult: ...
+
+    @overload
+    async def call(
+        self,
+        messages: list[Message],
+        callbacks: list[Callable] | None = None,
+        output_type: type[BaseModel] | None = None,
+        tools: list[Tool] | None = None,
+        tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
+        n: int | None = None,
+        **chat_kwargs,
+    ) -> list[LLMResult]: ...
+
+    async def call(
+        self,
+        messages: list[Message],
+        callbacks: list[Callable] | None = None,
+        output_type: type[BaseModel] | None = None,
+        tools: list[Tool] | None = None,
+        tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
+        n: int | None = None,
+        **chat_kwargs,
+    ) -> list[LLMResult] | LLMResult:
+        """
+        Call the LLM model with the given messages and configuration.
+
+        Args:
+            messages: A list of messages to send to the language model.
+            callbacks: A list of callback functions to execute after receiving the response.
+            output_type: The type of the output model.
+            tools: A list of tools to use during the call.
+            tool_choice: The tool or tool identifier to use.
+            n: An integer argument that specifies the number of completions to generate.
+                If n is not specified, the model's configuration is used.
+            **chat_kwargs: Additional keyword arguments to pass to the chat function.
+
+        Returns:
+            A list of LLMResult objects if multiple completions are requested (n>1),
+            otherwise a single LLMResult object.
+
+        Raises:
+            ValueError: If the number of completions `n` is invalid.
+        """
+        if not n or n <= 0:
+            logger.info(
+                "Invalid number of completions `n` requested to the call function. "
+                "Will get it from the model's configuration."
+            )
[Review comment] We should raise an error if
+            n = self.config.get("n", 1)
+        chat_kwargs["n"] = n
+        if n == 1:
+            return await self.call_single(
+                messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
+            )
+        return await self.call_multiple(
+            messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
+        )
196 changes: 196 additions & 0 deletions
tests/cassettes/TestMultipleCompletionLLMModel.test_multiple_completion[gpt-3.5-turbo].yaml
@@ -0,0 +1,196 @@
+interactions:
+- request:
+    body:
+      '{"messages":[{"role":"system","content":"Respond with single words."},{"role":"user","content":"Hello,
+      how are you?"}],"model":"gpt-3.5-turbo","n":2}'
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - "149"
+      content-type:
+      - application/json
+      host:
+      - api.openai.com
+      user-agent:
+      - AsyncOpenAI/Python 1.57.0
+      x-stainless-arch:
+      - arm64
+      x-stainless-async:
+      - async:asyncio
+      x-stainless-lang:
+      - python
+      x-stainless-os:
+      - MacOS
+      x-stainless-package-version:
+      - 1.57.0
+      x-stainless-raw-response:
+      - "true"
+      x-stainless-retry-count:
+      - "1"
+      x-stainless-runtime:
+      - CPython
+      x-stainless-runtime-version:
+      - 3.12.7
+    method: POST
+    uri: https://api.openai.com/v1/chat/completions
+  response:
+    body:
+      string: !!binary |
+        H4sIAAAAAAAAA9RTy2rDMBC8+yuEzklo3jS3QCCXXNoe+qIYWdrYamStKq1LS8i/FzkPOySFXnvR
+        YWZnNLsrbRPGuFZ8xrgsBMnSme48W6zuxvBAL5v55/3H06Ra4OOoWhXl82LJO1GB2TtIOqp6Ektn
+        gDTaPS09CILo2p8Oh6PhYHo7qYkSFZgoyx11h71xlyqfYfemPxgflAVqCYHP2GvCGGPb+owZrYIv
+        PmM3nSNSQggiBz47FTHGPZqIcBGCDiQs8U5DSrQEto69RFRtysO6CiJGs5UxB3x3ustg7jxm4cCf
+        8LW2OhSpBxHQRt9A6HjSEl800P83DSSMvdVLqc5icuexdJQSbsBGw8Fgb8ebZ9AiDxwhCdOCR50r
+        ZqkCEtqE1ki4FLIA1SibByAqpbFFtMd+meWa975tbfO/2DeElOAIVOo8KC3P+23KPMQ/8lvZacR1
+        YB6+A0GZrrXNwTuv6yXXm9wlPwAAAP//AwAh8pBrpAMAAA==
+    headers:
+      CF-Cache-Status:
+      - DYNAMIC
+      CF-RAY:
+      - 8ed70040cbcdf99b-SJC
+      Connection:
+      - keep-alive
+      Content-Encoding:
+      - gzip
+      Content-Type:
+      - application/json
+      Date:
+      - Thu, 05 Dec 2024 21:06:36 GMT
+      Server:
+      - cloudflare
+      Transfer-Encoding:
+      - chunked
+      X-Content-Type-Options:
+      - nosniff
+      access-control-expose-headers:
+      - X-Request-ID
+      alt-svc:
+      - h3=":443"; ma=86400
+      openai-organization:
+      - future-house-xr4tdh
+      openai-processing-ms:
+      - "134"
+      openai-version:
+      - "2020-10-01"
+      strict-transport-security:
+      - max-age=31536000; includeSubDomains; preload
+      x-ratelimit-limit-requests:
+      - "12000"
+      x-ratelimit-limit-tokens:
+      - "1000000"
+      x-ratelimit-remaining-requests:
+      - "11999"
+      x-ratelimit-remaining-tokens:
+      - "999953"
+      x-ratelimit-reset-requests:
+      - 5ms
+      x-ratelimit-reset-tokens:
+      - 2ms
+      x-request-id:
+      - req_1f88664946b9891fbc90796687f144c4
+    status:
+      code: 200
+      message: OK
+- request:
+    body:
+      '{"messages":[{"role":"system","content":"Respond with single words."},{"role":"user","content":"Hello,
+      how are you?"}],"model":"gpt-3.5-turbo","n":2}'
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - "149"
+      content-type:
+      - application/json
+      host:
+      - api.openai.com
+      user-agent:
+      - AsyncOpenAI/Python 1.57.0
+      x-stainless-arch:
+      - arm64
+      x-stainless-async:
+      - async:asyncio
+      x-stainless-lang:
+      - python
+      x-stainless-os:
+      - MacOS
+      x-stainless-package-version:
+      - 1.57.0
+      x-stainless-raw-response:
+      - "true"
+      x-stainless-retry-count:
+      - "0"
+      x-stainless-runtime:
+      - CPython
+      x-stainless-runtime-version:
+      - 3.12.7
+    method: POST
+    uri: https://api.openai.com/v1/chat/completions
+  response:
+    body:
+      string: !!binary |
+        H4sIAAAAAAAAA9RTTUsDMRC9768IOW9LP63tzaIIIqgH7UFkSZPZbTSbCcksWEr/u2T7sVtawauX
+        HN6b9/JmJtkkjHGt+IxxuRIkS2c6N8vbx5fF9Tisy7l5e13clfl0/vQwfl5P5o6nUYHLT5B0UHUl
+        ls4AabQ7WnoQBNG1PxkOR8PBZHpVEyUqMFFWOOoMu+MOVX6JnV5/MN4rV6glBD5j7wljjG3qM2a0
+        Cr75jPXSA1JCCKIAPjsWMcY9mohwEYIOJCzxtCElWgJbx75HVG3KQ14FEaPZypg9vj3eZbBwHpdh
+        zx/xXFsdVpkHEdBG30DoeNISnzXQ/zcNJIx91EupTmJy57F0lBF+gY2Gg8HOjjfPoEXuOUISpgWP
+        0gtmmQIS2oTWSLgUcgWqUTYPQFRKY4toj/08yyXvXdvaFn+xbwgpwRGozHlQWp7225R5iH/kt7Lj
+        iOvAPKwDQZnl2hbgndf1kutNbpMfAAAA//8DALEE5HikAwAA
+    headers:
+      CF-Cache-Status:
+      - DYNAMIC
+      CF-RAY:
+      - 8ed700428d77f99b-SJC
+      Connection:
+      - keep-alive
+      Content-Encoding:
+      - gzip
+      Content-Type:
+      - application/json
+      Date:
+      - Thu, 05 Dec 2024 21:06:36 GMT
+      Server:
+      - cloudflare
+      Transfer-Encoding:
+      - chunked
+      X-Content-Type-Options:
+      - nosniff
+      access-control-expose-headers:
+      - X-Request-ID
+      alt-svc:
+      - h3=":443"; ma=86400
+      openai-organization:
+      - future-house-xr4tdh
+      openai-processing-ms:
+      - "114"
+      openai-version:
+      - "2020-10-01"
+      strict-transport-security:
+      - max-age=31536000; includeSubDomains; preload
+      x-ratelimit-limit-requests:
+      - "12000"
+      x-ratelimit-limit-tokens:
+      - "1000000"
+      x-ratelimit-remaining-requests:
+      - "11999"
+      x-ratelimit-remaining-tokens:
+      - "999953"
+      x-ratelimit-reset-requests:
+      - 5ms
+      x-ratelimit-reset-tokens:
+      - 2ms
+      x-request-id:
+      - req_e32516fa5bb6ab11dda5155511280ea6
+    status:
+      code: 200
+      message: OK
+version: 1
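The response bodies in this cassette are stored as `!!binary` (base64) and carry `Content-Encoding: gzip`, so they are not readable as-is. A small sketch of how such a body could be inspected follows. The helper names are hypothetical, and the sample payload is made up for the round trip; it is not the recorded response.

```python
import base64
import gzip
import json


def decode_cassette_body(b64_text: str) -> dict:
    """Decode a base64-encoded, gzip-compressed JSON body into a dict."""
    # b64decode ignores newlines by default, so a multi-line blob can be passed joined as-is.
    compressed = base64.b64decode(b64_text)
    return json.loads(gzip.decompress(compressed))


def encode_cassette_body(payload: dict) -> str:
    """Inverse helper, used here only to build sample input."""
    raw = json.dumps(payload).encode("utf-8")
    return base64.b64encode(gzip.compress(raw)).decode("ascii")


# Made-up payload for illustration, not the recorded response above.
sample = {"choices": [{"index": 0}, {"index": 1}]}
round_tripped = decode_cassette_body(encode_cassette_body(sample))
```

The same `decode_cassette_body` call should work on the `string` value of a recorded response once its lines are joined, since those blobs begin with `H4sI`, the base64 form of the gzip magic bytes.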
When would we expect someone to use these overloads instead of the dedicated methods `call_single` and `call_multiple`? IMO, it would be easier to maintain just two methods:

I like Sid's suggestion too. Also, let's make a docstring somewhere mentioning what `n` does on the "back end". Readers won't intuitively know what `n` means, it can refer to so many things.

On a related note, `MultipleCompletionLLMModel.achat` calls `litellm.acompletion`. Can we rename `MultipleCompletionLLMModel.achat` to be `MultipleCompletionLLMModel.acompletion` to standardize with the actual API endpoint ultimately being invoked?
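As one possible shape for the docstring requested here, the sketch below pulls the `chat_kwargs.get("n", self.config.get("n", 1))` lookup from the diff into a documented helper. `resolve_n` is a hypothetical name, not part of the PR. The description of `n` relies on the recorded request body, which sends `"n":2` to the chat completions endpoint.

```python
from typing import Any


def resolve_n(chat_kwargs: dict[str, Any], config: dict[str, Any]) -> int:
    """Return the number of completions to request.

    `n` is forwarded in the chat-completions request and asks the provider
    for that many independent completions ("choices") of the same prompt.
    A per-call value in `chat_kwargs` wins over the model configuration;
    if neither sets `n`, the fallback is 1.
    """
    return chat_kwargs.get("n", config.get("n", 1))
```

For example, `resolve_n({"n": 3}, {"n": 2})` picks the per-call value, while `resolve_n({}, {"n": 2})` falls back to the configuration.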

It looks like we still have the overloads, `call_single`, and `call_multiple` here - can we reduce this to `call` and `call_single`?
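A rough sketch of what the reduced surface suggested here might look like, under the assumption that `call` always returns a list and `call_single` enforces `n == 1` and unwraps it. `SketchModel` and its string "completions" are hypothetical placeholders, not the merged implementation.

```python
import asyncio


class SketchModel:
    """Hypothetical model exposing only `call` and `call_single`."""

    def __init__(self, n: int = 1) -> None:
        self.config = {"n": n}

    async def call(self, prompt: str, **chat_kwargs) -> list[str]:
        # Always return a list, whatever `n` resolves to.
        n = chat_kwargs.get("n", self.config.get("n", 1))
        if n < 1:
            raise ValueError("n must be a positive integer.")
        return [f"{prompt} [{i}]" for i in range(n)]  # fake completions

    async def call_single(self, prompt: str, **chat_kwargs) -> str:
        # Enforce n == 1, then unwrap the only element.
        n = chat_kwargs.get("n", self.config.get("n", 1))
        if n != 1:
            raise ValueError("n must be 1 for call_single.")
        return (await self.call(prompt, **chat_kwargs))[0]
```

With this split, each method has a single return type, so no `@overload` declarations are needed to keep the annotations accurate.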