Catch reasoning_content from litellm response #40

Merged: 9 commits, Jan 23, 2025
1 change: 1 addition & 0 deletions .mailmap
@@ -0,0 +1 @@
Mayk Caldas <[email protected]> maykcaldas <[email protected]>
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -75,7 +75,7 @@ repos:
- aiohttp
- coredis
- fhaviary[llm]>=0.14.0 # Match pyproject.toml
- litellm>=1.44,<1.57.2 # Match pyproject.toml
- litellm>1.59.3 # Match pyproject.toml
- limits
- numpy
- pydantic~=2.0,>=2.10.1,<2.10.2 # Match pyproject.toml
77 changes: 39 additions & 38 deletions llmclient/llms.py
@@ -5,23 +5,16 @@
import logging
from abc import ABC
from collections.abc import (
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Iterable,
Mapping,
)
from enum import StrEnum
from inspect import isasyncgenfunction, isawaitable, signature
from typing import (
Any,
ClassVar,
TypeAlias,
TypeVar,
cast,
)
from typing import Any, ClassVar, ParamSpec, TypeAlias, cast, overload

import litellm
from aviary.core import (
@@ -32,7 +25,6 @@
ToolSelector,
is_coroutine_callable,
)
from litellm.types.utils import ModelResponse, ModelResponseStream
from pydantic import (
BaseModel,
ConfigDict,
@@ -84,7 +76,7 @@ class CommonLLMNames(StrEnum):
)


def sum_logprobs(choice: litellm.utils.Choices | list) -> float | None:
def sum_logprobs(choice: litellm.utils.Choices | list[float]) -> float | None:
"""Calculate the sum of the log probabilities of an LLM completion (a Choices object).

Args:
@@ -193,14 +185,12 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult

async def acompletion_iter(
self, messages: list[Message], **kwargs
) -> AsyncGenerator[LLMResult]:
) -> AsyncIterable[LLMResult]:
"""Return an async generator that yields completions.

Only the last tuple will be non-zero.
"""
raise NotImplementedError
if False: # type: ignore[unreachable] # pylint: disable=using-constant-test
yield # Trick mypy: https://github.com/python/mypy/issues/5070#issuecomment-1050834495

def count_tokens(self, text: str) -> int:
return len(text) // 4 # gross approximation
@@ -323,7 +313,7 @@ async def call( # noqa: C901, PLR0915
)
sync_callbacks = [f for f in callbacks if not is_coroutine_callable(f)]
async_callbacks = [f for f in callbacks if is_coroutine_callable(f)]
stream_results = await self.acompletion_iter(messages, **chat_kwargs) # type: ignore[misc]
stream_results = await self.acompletion_iter(messages, **chat_kwargs)
text_result = []
async for result in stream_results:
if result.text:
@@ -340,7 +330,7 @@
for result in results:
usage = result.prompt_count, result.completion_count
if not sum(usage):
result.completion_count = self.count_tokens(result.text)
result.completion_count = self.count_tokens(cast(str, result.text))
result.seconds_to_last_token = (
asyncio.get_running_loop().time() - start_clock
)
@@ -368,25 +358,26 @@ async def call_single(
return results[0]


LLMModelOrChild = TypeVar("LLMModelOrChild", bound=LLMModel)
P = ParamSpec("P")


@overload
def rate_limited(
func: Callable[P, Coroutine[Any, Any, list[LLMResult]]],
) -> Callable[P, Coroutine[Any, Any, list[LLMResult]]]: ...


@overload
def rate_limited(
func: Callable[
[LLMModelOrChild, Any],
Awaitable[ModelResponse | ModelResponseStream | list[LLMResult]]
| AsyncIterable[LLMResult],
],
) -> Callable[
[LLMModelOrChild, Any],
Awaitable[list[LLMResult] | AsyncIterator[LLMResult]],
]:
func: Callable[P, AsyncIterable[LLMResult]],
) -> Callable[P, Coroutine[Any, Any, AsyncIterable[LLMResult]]]: ...


def rate_limited(func):
"""Decorator to rate limit relevant methods of an LLMModel."""

@functools.wraps(func)
async def wrapper(
self: LLMModelOrChild, *args: Any, **kwargs: Any
) -> list[LLMResult] | AsyncIterator[LLMResult]:
async def wrapper(self, *args, **kwargs):
if not hasattr(self, "check_rate_limit"):
raise NotImplementedError(
f"Model {self.name} must have a `check_rate_limit` method."
@@ -405,7 +396,7 @@ async def wrapper(
# portion before yielding
if isasyncgenfunction(func):

async def rate_limited_generator() -> AsyncGenerator[LLMResult, None]:
async def rate_limited_generator() -> AsyncIterable[LLMResult]:
async for item in func(self, *args, **kwargs):
token_count = 0
if isinstance(item, LLMResult):
@@ -417,9 +408,8 @@ async def rate_limited_generator() -> AsyncGenerator[LLMResult, None]:

return rate_limited_generator()

# We checked isasyncgenfunction above, so this must be a Awaitable
result = await cast(Awaitable[Any], func(self, *args, **kwargs))

# We checked isasyncgenfunction above, so this must be an Awaitable
result = await func(self, *args, **kwargs)
if func.__name__ == "acompletion" and isinstance(result, list):
await self.check_rate_limit(sum(r.completion_count for r in result))
return result
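
Aside (not part of the diff): the new overloads above lean on ParamSpec so the decorator keeps the wrapped method's parameters while giving list-returning coroutines and async generators distinct return types. Below is a generic, self-contained sketch of just that typing pattern; the names (traced, the str payloads) are illustrative and not from this PR.

import functools
from collections.abc import AsyncIterable, Callable, Coroutine
from inspect import isasyncgenfunction
from typing import Any, ParamSpec, overload

P = ParamSpec("P")

@overload
def traced(
    func: Callable[P, Coroutine[Any, Any, list[str]]],
) -> Callable[P, Coroutine[Any, Any, list[str]]]: ...

@overload
def traced(
    func: Callable[P, AsyncIterable[str]],
) -> Callable[P, Coroutine[Any, Any, AsyncIterable[str]]]: ...

def traced(func):
    # Runtime implementation stays untyped; the overloads above are what type
    # checkers use at call sites.
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        if isasyncgenfunction(func):
            # Hand the async generator back unawaited, as the PR's wrapper does.
            return func(*args, **kwargs)
        return await func(*args, **kwargs)
    return wrapper

The repo's actual decorator adds the token counting and check_rate_limit calls inside wrapper and rate_limited_generator, as shown in the hunks above and below.
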
@@ -552,7 +542,7 @@ async def check_rate_limit(self, token_count: float, **kwargs) -> None:
)

@rate_limited
async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult]: # type: ignore[override]
async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult]:
prompts = [m.model_dump(by_alias=True) for m in messages if m.content]
completions = await track_costs(self.router.acompletion)(
self.name, prompts, **kwargs
@@ -574,6 +564,15 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult
]
else:
output_messages = [Message(**completion.message.model_dump())]

reasoning_content = None
if hasattr(completion.message, "provider_specific_fields"):
provider_specific_fields = completion.message.provider_specific_fields
if isinstance(provider_specific_fields, dict):
reasoning_content = provider_specific_fields.get(
"reasoning_content", None
)

results.append(
LLMResult(
model=self.name,
@@ -584,14 +583,15 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult
prompt_count=completions.usage.prompt_tokens, # type: ignore[attr-defined]
completion_count=completions.usage.completion_tokens, # type: ignore[attr-defined]
system_fingerprint=completions.system_fingerprint,
reasoning_content=reasoning_content,
)
)
return results

@rate_limited
async def acompletion_iter( # type: ignore[override]
async def acompletion_iter(
self, messages: list[Message], **kwargs
) -> AsyncGenerator[LLMResult]:
) -> AsyncIterable[LLMResult]:
prompts = [m.model_dump(by_alias=True) for m in messages if m.content]
stream_completions = await track_costs_iter(self.router.acompletion)(
self.name,
@@ -601,7 +601,6 @@ async def acompletion_iter( # type: ignore[override]
**kwargs,
)
start_clock = asyncio.get_running_loop().time()
result = LLMResult(model=self.name, prompt=messages)
outputs = []
logprobs = []
role = None
@@ -612,14 +611,16 @@ async def acompletion_iter( # type: ignore[override]
logprobs.append(choice.logprobs.content[0].logprob or 0)
outputs.append(delta.content or "")
role = delta.role or role
# NOTE: litellm is not populating provider_specific_fields in streaming mode.
# TODO: Get reasoning_content when this issue is fixed
# https://github.com/BerriAI/litellm/issues/7942

text = "".join(outputs)
result = LLMResult(
model=self.name,
text=text,
prompt=messages,
messages=[Message(role=role, content=text)],
# TODO: Can we marginalize over all choices?
logprob=sum_logprobs(logprobs),
)

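Taken together, the llms.py changes mean a caller can read the reasoning trace straight off the LLMResult returned by acompletion. A hedged end-to-end sketch, assuming the class involved is the repo's LiteLLMModel with a name field, that aviary's Message defaults to a user role, and that the chosen provider actually returns reasoning_content (model name and prompt are illustrative):

import asyncio

from aviary.core import Message

from llmclient.llms import LiteLLMModel  # module/class names assumed from this diff's context

async def main() -> None:
    model = LiteLLMModel(name="deepseek/deepseek-reasoner")  # constructor kwargs assumed
    results = await model.acompletion([Message(content="Why is the sky blue?")])
    for result in results:
        print(result.text)               # the final answer
        print(result.reasoning_content)  # reasoning trace, or None when the
                                         # provider omits provider_specific_fields

asyncio.run(main())

In streaming mode (acompletion_iter) the trace is still absent, per the NOTE in the diff about litellm issue #7942.
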
3 changes: 3 additions & 0 deletions llmclient/types.py
@@ -79,6 +79,9 @@ class LLMResult(BaseModel):
logprob: float | None = Field(
default=None, description="Sum of logprobs in the completion."
)
reasoning_content: str | None = Field(
default=None, description="DeepSeek-R1 reasoning content from the LLM."
Contributor comment: Reasoning content may come from other models besides DeepSeek-R1. Maybe adjust to say "Reasoning content from LLMs such as DeepSeek-R1".
)

def __str__(self) -> str:
return self.text or ""
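For reference, a minimal sketch of the new field in isolation, assuming the remaining LLMResult fields keep their defaults (the values below are illustrative):

from llmclient.types import LLMResult

result = LLMResult(
    model="deepseek/deepseek-reasoner",
    text="The sky is blue because of Rayleigh scattering.",
    reasoning_content="Comparing scattering cross-sections across wavelengths...",
)

print(str(result))               # __str__ still returns only the answer text
print(result.reasoning_content)  # the reasoning trace; defaults to None

Keeping the default at None means existing callers and previously serialized results are unaffected by the new field.
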
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
"pydantic~=2.0,>=2.10.1,<2.10.2",
"tiktoken>=0.4.0",
"typing-extensions; python_version <= '3.11'", # for typing.override
'litellm; python_version < "3.13"', # NOTE: paper-qa==5.3 doesn't support 3.13 yet
'litellm>1.59.3; python_version < "3.13"', # NOTE: paper-qa==5.3 doesn't support 3.13 yet
Contributor comment: I think we can just get rid of this python_version thing and have one litellm>1.59.3.
'litellm>=1.49.1; python_version >= "3.13"', # For removal of imghdr
]
description = "A client to provide LLM responses for FutureHouse applications."
@@ -44,7 +44,7 @@ dev = [
"fh-llm-client[local]",
"fhaviary[xml]",
"ipython>=8", # Pin to keep recent
"litellm<1.57.2", # Pin for Router.acompletion typing break from https://github.com/BerriAI/litellm/pull/7594
"litellm>1.59.3", # Pin for deepseek support
Contributor comment: Let's just remove this litellm entry from dev; #32 added it.
"mypy>=1.8", # Pin for mutable-override
"pre-commit>=3.4", # Pin to keep recent
"pylint-pydantic",