From 21688b7271a0aa15b097c952fc0ce5895e6bc133 Mon Sep 17 00:00:00 2001
From: James Braza
Date: Thu, 23 Jan 2025 15:55:33 -0800
Subject: [PATCH] Fixing LiteLLM `Router.acompletion` typing issue (#43)

---
 llmclient/llms.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/llmclient/llms.py b/llmclient/llms.py
index f028f64..02fc94a 100644
--- a/llmclient/llms.py
+++ b/llmclient/llms.py
@@ -553,7 +553,11 @@ async def check_rate_limit(self, token_count: float, **kwargs) -> None:
 
     @rate_limited
     async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult]:
-        prompts = [m.model_dump(by_alias=True) for m in messages if m.content]
+        # cast is necessary for LiteLLM typing bug: https://github.com/BerriAI/litellm/issues/7641
+        prompts = cast(
+            list[litellm.types.llms.openai.AllMessageValues],
+            [m.model_dump(by_alias=True) for m in messages if m.content],
+        )
         completions = await track_costs(self.router.acompletion)(
             self.name, prompts, **kwargs
         )
@@ -602,7 +606,11 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult
     async def acompletion_iter(
         self, messages: list[Message], **kwargs
     ) -> AsyncIterable[LLMResult]:
-        prompts = [m.model_dump(by_alias=True) for m in messages if m.content]
+        # cast is necessary for LiteLLM typing bug: https://github.com/BerriAI/litellm/issues/7641
+        prompts = cast(
+            list[litellm.types.llms.openai.AllMessageValues],
+            [m.model_dump(by_alias=True) for m in messages if m.content],
+        )
         stream_completions = await track_costs_iter(self.router.acompletion)(
             self.name,
             prompts,