Overloaded MultipleCompletionLLMModel.call type #13

Merged Dec 9, 2024 · 23 commits · changes from 6 commits shown
Commits:

- 7146c91 Overloaded typing in MultipleCompletionLLMModel.call. It returns eit… (maykcaldas, Dec 5, 2024)
- 1dcda13 Improved logging for call_multiple (maykcaldas, Dec 5, 2024)
- 2847af7 removed deprecated check of n in kwargs (maykcaldas, Dec 5, 2024)
- 2eac4a6 Merge branch 'main' into over-mult (maykcaldas, Dec 6, 2024)
- 6fbf2f2 Added cassettes for TestMultipleCompletionLLMModel (maykcaldas, Dec 6, 2024)
- 5d3a3c9 Fix lint (maykcaldas, Dec 6, 2024)
- 3f650fc Implemented tests to check kwarg priority when calling (maykcaldas, Dec 9, 2024)
- 7edd613 Exposed missing classes (maykcaldas, Dec 9, 2024)
- bae8765 added embedding_model_factory (maykcaldas, Dec 9, 2024)
- 1e6eb78 Added documentation to call functions (maykcaldas, Dec 9, 2024)
- cb16d19 skip lint checking for argument with default value in test_llms (maykcaldas, Dec 9, 2024)
- 7966f9a Fixed pre-commit errors (maykcaldas, Dec 9, 2024)
- 9e91858 Reverted changes in uv.lock (maykcaldas, Dec 9, 2024)
- 29e4d91 Fixed line wrap in docstrings (maykcaldas, Dec 9, 2024)
- f8090bb reverting uv.lock (maykcaldas, Dec 9, 2024)
- 418fa3b removed the dependency on numpy. It is now a conditional dependency f… (maykcaldas, Dec 9, 2024)
- ba974e5 Merge branch 'main' into remove_numpy (maykcaldas, Dec 9, 2024)
- c34b02c Removed image group dependency (maykcaldas, Dec 9, 2024)
- 270948e Merge branch 'remove_numpy' of github.com:Future-House/llm-client int… (maykcaldas, Dec 9, 2024)
- 86d455d Fixed typos (maykcaldas, Dec 9, 2024)
- 7ef8f49 Removed overload from the multiple completion llm call (maykcaldas, Dec 9, 2024)
- 03ede77 Merge branch 'remove_numpy' into over-mult (maykcaldas, Dec 9, 2024)
- 7d196df Merge branch 'update_init' into over-mult (maykcaldas, Dec 9, 2024)
101 changes: 100 additions & 1 deletion llmclient/llms.py
@@ -16,9 +16,11 @@
from typing import (
Any,
ClassVar,
Literal,
Self,
TypeVar,
cast,
overload,
)

import litellm
@@ -658,7 +660,7 @@ async def achat_iter(self, messages: Iterable[Message], **kwargs) -> AsyncGenera
# > `required` means the model must call one or more tools.
TOOL_CHOICE_REQUIRED: ClassVar[str] = "required"

async def call( # noqa: C901, PLR0915
async def _call( # noqa: C901, PLR0915
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
@@ -829,3 +831,100 @@ async def call( # noqa: C901, PLR0915
result.seconds_to_last_token = end_clock - start_clock

return results

# TODO: Is it good practice to have these multiple interfaces?
# Users can just use `call` and we check `n`,
# or they can specifically call `call_single` or `call_multiple`
async def call_single(
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
**chat_kwargs,
) -> LLMResult:
if chat_kwargs.get("n", 1) != 1 or self.config.get("n", 1) != 1:
raise ValueError("n must be 1 for call_single.")
return (
await self._call(
messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
)
)[0]

async def call_multiple(
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
**chat_kwargs,
) -> list[LLMResult]:
if 1 in {chat_kwargs.get("n", 1), self.config.get("n", 1)}:
if (
chat_kwargs.get("n")
and self.config.get("n")
and chat_kwargs.get("n") != self.config.get("n")
):
raise ValueError(
f"Incompatible number of completions requested. "
f"Model's configuration n is {self.config['n']}, "
f"but kwarg n={chat_kwargs['n']} was passed."
)
logger.warning(
"n is 1 for call_multiple. It will return a list with a single element"
)
return await self._call(
messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
)
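# Illustrative usage sketch, not part of this diff: assuming `model` is an
# instantiated MultipleCompletionLLMModel and `messages` is a list[Message],
#   result = await model.call_single(messages)     # a single LLMResult
#   results = await model.call_multiple(messages)  # a list[LLMResult]; length follows the configured n
#   await model.call_single(messages, n=2)         # raises ValueError, since call_single requires n == 1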

@overload
async def call(
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
n: Literal[1] = 1,
**chat_kwargs,
) -> LLMResult: ...

@overload
async def call(
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
n: int | None = None,
**chat_kwargs,
) -> list[LLMResult]: ...
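For illustration (a sketch only; `model` is assumed to be a MultipleCompletionLLMModel instance and `messages` a list[Message]), these overloads let a type checker narrow the return type from the value of `n`:

single = await model.call(messages)        # default n matches the Literal[1] overload -> LLMResult
single = await model.call(messages, n=1)   # same overload -> LLMResult
many = await model.call(messages, n=5)     # falls to the int | None overload -> list[LLMResult]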
Contributor:
When would we expect someone to use these overloads instead of the dedicated methods call_single and call_multiple?

IMO, it would be easier to maintain just two methods:

async def call(self, ..., n: int) -> list[LLMResult]:
    assert n > 0
    ...

async def call_single(self, ...) -> LLMResult:
    return (await self.call(..., n=1))[0]

Contributor:
I like Sid's suggestion too. Also, let's add a docstring somewhere mentioning what n does on the "back end". Readers won't intuitively know what n means; it can refer to so many things.

On a related note, MultipleCompletionLLMModel.achat calls litellm.acompletion. Can we rename MultipleCompletionLLMModel.achat to be MultipleCompletionLLMModel.acompletion to standardize with the actual API endpoint ultimately being invoked?
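A minimal sketch of what such a docstring could look like (wording is illustrative, not from this PR), assuming `n` is ultimately forwarded through `achat` to `litellm.acompletion`:

async def call(self, messages: list[Message], ..., n: int | None = None, **chat_kwargs) -> list[LLMResult] | LLMResult:
    """Call the LLM and return the sampled completion(s).

    Args:
        n: Number of chat completion choices to sample in a single request.
            This value is passed through to the underlying litellm.acompletion
            call (the OpenAI `n` parameter), so one API request yields `n`
            independently sampled responses. If omitted, the model
            configuration's `n` is used, defaulting to 1.
    """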

Contributor:
It looks like we still have the overloads, call_single, and call_multiple here - can we reduce this to call and call_single?


async def call(
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
n: int | None = None,
**chat_kwargs,
) -> list[LLMResult] | LLMResult:

# Uses the n from the model configuration unless a positive n is passed explicitly.
# If n is not specified anywhere, defaults to 1.
if not n or n <= 0:
logger.info(
"Invalid n passed to the call function. Will get it from the model's configuration"
)
n = self.config.get("n", 1)
if n == 1:
return await self.call_single(
messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
)
return await self.call_multiple(
messages, callbacks, output_type, tools, tool_choice, **chat_kwargs
)
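Taken together, a hypothetical call site would behave as follows (sketch only; `model` and `messages` as assumed above, with no `n` pinned in the model configuration):

result = await model.call(messages)         # n resolves to 1 -> routed to call_single -> LLMResult
results = await model.call(messages, n=3)   # n > 1 -> routed to call_multiple -> list[LLMResult]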
@@ -0,0 +1,196 @@
interactions:
- request:
body:
'{"messages":[{"role":"system","content":"Respond with single words."},{"role":"user","content":"Hello,
how are you?"}],"model":"gpt-3.5-turbo","n":2}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "149"
content-type:
- application/json
host:
- api.openai.com
user-agent:
- AsyncOpenAI/Python 1.57.0
x-stainless-arch:
- arm64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.57.0
x-stainless-raw-response:
- "true"
x-stainless-retry-count:
- "1"
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.7
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAA9RTy2rDMBC8+yuEzklo3jS3QCCXXNoe+qIYWdrYamStKq1LS8i/FzkPOySFXnvR
YWZnNLsrbRPGuFZ8xrgsBMnSme48W6zuxvBAL5v55/3H06Ra4OOoWhXl82LJO1GB2TtIOqp6Ektn
gDTaPS09CILo2p8Oh6PhYHo7qYkSFZgoyx11h71xlyqfYfemPxgflAVqCYHP2GvCGGPb+owZrYIv
PmM3nSNSQggiBz47FTHGPZqIcBGCDiQs8U5DSrQEto69RFRtysO6CiJGs5UxB3x3ustg7jxm4cCf
8LW2OhSpBxHQRt9A6HjSEl800P83DSSMvdVLqc5icuexdJQSbsBGw8Fgb8ebZ9AiDxwhCdOCR50r
ZqkCEtqE1ki4FLIA1SibByAqpbFFtMd+meWa975tbfO/2DeElOAIVOo8KC3P+23KPMQ/8lvZacR1
YB6+A0GZrrXNwTuv6yXXm9wlPwAAAP//AwAh8pBrpAMAAA==
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8ed70040cbcdf99b-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Thu, 05 Dec 2024 21:06:36 GMT
Server:
- cloudflare
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "134"
openai-version:
- "2020-10-01"
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- "12000"
x-ratelimit-limit-tokens:
- "1000000"
x-ratelimit-remaining-requests:
- "11999"
x-ratelimit-remaining-tokens:
- "999953"
x-ratelimit-reset-requests:
- 5ms
x-ratelimit-reset-tokens:
- 2ms
x-request-id:
- req_1f88664946b9891fbc90796687f144c4
status:
code: 200
message: OK
- request:
body:
'{"messages":[{"role":"system","content":"Respond with single words."},{"role":"user","content":"Hello,
how are you?"}],"model":"gpt-3.5-turbo","n":2}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "149"
content-type:
- application/json
host:
- api.openai.com
user-agent:
- AsyncOpenAI/Python 1.57.0
x-stainless-arch:
- arm64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.57.0
x-stainless-raw-response:
- "true"
x-stainless-retry-count:
- "0"
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.7
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAA9RTTUsDMRC9768IOW9LP63tzaIIIqgH7UFkSZPZbTSbCcksWEr/u2T7sVtawauX
HN6b9/JmJtkkjHGt+IxxuRIkS2c6N8vbx5fF9Tisy7l5e13clfl0/vQwfl5P5o6nUYHLT5B0UHUl
ls4AabQ7WnoQBNG1PxkOR8PBZHpVEyUqMFFWOOoMu+MOVX6JnV5/MN4rV6glBD5j7wljjG3qM2a0
Cr75jPXSA1JCCKIAPjsWMcY9mohwEYIOJCzxtCElWgJbx75HVG3KQ14FEaPZypg9vj3eZbBwHpdh
zx/xXFsdVpkHEdBG30DoeNISnzXQ/zcNJIx91EupTmJy57F0lBF+gY2Gg8HOjjfPoEXuOUISpgWP
0gtmmQIS2oTWSLgUcgWqUTYPQFRKY4toj/08yyXvXdvaFn+xbwgpwRGozHlQWp7225R5iH/kt7Lj
iOvAPKwDQZnl2hbgndf1kutNbpMfAAAA//8DALEE5HikAwAA
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8ed700428d77f99b-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Thu, 05 Dec 2024 21:06:36 GMT
Server:
- cloudflare
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "114"
openai-version:
- "2020-10-01"
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- "12000"
x-ratelimit-limit-tokens:
- "1000000"
x-ratelimit-remaining-requests:
- "11999"
x-ratelimit-remaining-tokens:
- "999953"
x-ratelimit-reset-requests:
- 5ms
x-ratelimit-reset-tokens:
- 2ms
x-request-id:
- req_e32516fa5bb6ab11dda5155511280ea6
status:
code: 200
message: OK
version: 1