Ability to bypass usage of litellm.Router (#563)
jamesbraza authored Oct 11, 2024
1 parent 34721a3 commit 471ef8f
Showing 8 changed files with 910 additions and 173 deletions.
5 changes: 3 additions & 2 deletions paperqa/litqa.py
@@ -16,7 +16,7 @@

from paperqa.llms import LiteLLMModel, LLMModel
from paperqa.prompts import EVAL_PROMPT_TEMPLATE, QA_PROMPT_TEMPLATE
from paperqa.settings import make_default_litellm_router_settings
from paperqa.settings import make_default_litellm_model_list_settings
from paperqa.types import Answer

if TYPE_CHECKING:
@@ -139,7 +139,8 @@ def from_question(

        if isinstance(eval_model, str):
            eval_model = LiteLLMModel(
                name=eval_model, config=make_default_litellm_router_settings(eval_model)
                name=eval_model,
                config=make_default_litellm_model_list_settings(eval_model),
            )

        async def llm_from_answer(answer: Answer | str) -> LitQAEvaluation:
94 changes: 64 additions & 30 deletions paperqa/llms.py
@@ -16,7 +16,7 @@
from enum import StrEnum
from inspect import isasyncgenfunction, signature
from sys import version_info
from typing import Any, TypeVar
from typing import Any, TypeVar, cast

import litellm
import numpy as np
@@ -82,7 +82,10 @@ class EmbeddingModel(ABC, BaseModel):
    name: str
    config: dict[str, Any] = Field(
        default_factory=dict,
        description="Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem string for parsing",
        description=(
            "Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem"
            " string for parsing"
        ),
    )

    async def check_rate_limit(self, token_count: float, **kwargs) -> None:
@@ -107,7 +110,12 @@ class LiteLLMEmbeddingModel(EmbeddingModel):
    name: str = Field(default="text-embedding-3-small")
    config: dict[str, Any] = Field(
        default_factory=dict,  # See below field_validator for injection of kwargs
        description="Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem string for parsing",
        description=(
            "The optional `rate_limit` key's value must be a RateLimitItem or"
            " RateLimitItem string for parsing. The optional `kwargs` key is keyword"
            " arguments to pass to the litellm.aembedding function. Note that LiteLLM's"
            " Router is not used here."
        ),
    )

    @field_validator("config")
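A minimal sketch of the config shape the new description above implies; the rate-limit string and the `dimensions` keyword are illustrative only (`dimensions` standing in for any keyword argument litellm.aembedding accepts):

from paperqa.llms import LiteLLMEmbeddingModel

embedding_model = LiteLLMEmbeddingModel(
    name="text-embedding-3-small",
    config={
        "rate_limit": "30000 per 1 minute",  # parsed into a limits.RateLimitItem
        "kwargs": {"dimensions": 256},  # forwarded to litellm.aembedding
    },
)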
@@ -510,21 +518,38 @@ def get_litellm_retrying_config(timeout: float = 60.0) -> dict[str, Any]:
    return {"num_retries": 3, "timeout": timeout}


class LiteLLMModel(LLMModel):
    """A wrapper around the litellm library.
class PassThroughRouter(litellm.Router):
    """Router that is just a wrapper on LiteLLM's normal free functions."""

    `config` should have two high level keys:
        `model_list`: stores a list of all model configurations
            (see https://docs.litellm.ai/docs/routing)
        `router_kwargs`: kwargs for the Router class
        `rate_limit`: (Optional) dictionary keyed by model group name
            with values of type limits.RateLimitItem (in tokens / minute)
            or valid limits.RateLimitItem string for parsing
    def __init__(self, **kwargs):
        self._default_kwargs = kwargs

    This way users can specify routing strategies, retries, etc.
    """
    async def atext_completion(self, *args, **kwargs):
        return await litellm.atext_completion(*args, **(self._default_kwargs | kwargs))

    config: dict = Field(default_factory=dict)
    async def acompletion(self, *args, **kwargs):
        return await litellm.acompletion(*args, **(self._default_kwargs | kwargs))


class LiteLLMModel(LLMModel):
    """A wrapper around the litellm library."""

    config: dict = Field(
        default_factory=dict,
        description=(
            "Configuration of this model containing several important keys. The"
            " optional `model_list` key stores a list of all model configurations"
            " (SEE: https://docs.litellm.ai/docs/routing). The optional"
            " `router_kwargs` key is keyword arguments to pass to the Router class."
            " Inclusion of a key `pass_through_router` with a truthy value will lead"
            " to not using LiteLLM's Router, instead using just LiteLLM's free"
            f" functions (see {PassThroughRouter.__name__}). Rate limiting applies"
            " regardless of `pass_through_router` being present. The optional"
            " `rate_limit` key is a dictionary keyed by model group name with values"
            " of type limits.RateLimitItem (in tokens / minute) or valid"
            " limits.RateLimitItem string for parsing."
        ),
    )
    name: str = "gpt-4o-mini"
    _router: litellm.Router | None = None

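A hedged sketch of the bypass path described in the config description above; the model name, retry values, and rate-limit string are placeholders, and the default `model_list` is assumed to be filled in by the validator below:

from paperqa.llms import LiteLLMModel

model = LiteLLMModel(
    name="gpt-4o-mini",
    config={
        "pass_through_router": True,  # truthy value: use PassThroughRouter, not litellm.Router
        "router_kwargs": {"num_retries": 3, "timeout": 60.0},  # forwarded to the litellm free functions
        "rate_limit": {"gpt-4o-mini": "30000 per 1 minute"},  # still enforced either way
    },
)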
@@ -548,9 +573,12 @@ def maybe_set_config_attribute(cls, data: dict[str, Any]) -> dict[str, Any]:
            } | data.get("config", {})

        if "router_kwargs" not in data.get("config", {}):
            data["config"]["router_kwargs"] = get_litellm_retrying_config() | {
                "retry_after": 5
            }
            if data.get("config", {}).get("pass_through_router"):
                data["config"]["router_kwargs"] = get_litellm_retrying_config()
            else:
                data["config"]["router_kwargs"] = get_litellm_retrying_config() | {
                    "retry_after": 5
                }

        # we only support one "model name" for now, here we validate
        model_list = data["config"]["model_list"]
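For reference, when no `router_kwargs` are supplied, the branches above resolve to get_litellm_retrying_config()'s {"num_retries": 3, "timeout": 60.0} in the pass-through case, plus {"retry_after": 5} when litellm.Router is used.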
@@ -562,7 +590,7 @@ def maybe_set_config_attribute(cls, data: dict[str, Any]) -> dict[str, Any]:
        # pylint: disable-next=possibly-used-before-assignment
        _DeploymentTypedDictValidator.validate_python(model_list)
        if "config" in data and len({m["model_name"] for m in model_list}) > 1:
            raise ValueError("Only one model name per router is supported for now.")
            raise ValueError("Only one model name per model list is supported for now.")
        return data

    def __getstate__(self):
@@ -573,12 +601,15 @@ def __getstate__(self):
        return state

    @property
    def router(self):
    def router(self) -> litellm.Router:
        if self._router is None:
            self._router = litellm.Router(
                model_list=self.config["model_list"],
                **self.config.get("router_kwargs", {}),
            )
            router_kwargs: dict = self.config.get("router_kwargs", {})
            if self.config.get("pass_through_router"):
                self._router = PassThroughRouter(**router_kwargs)
            else:
                self._router = litellm.Router(
                    model_list=self.config["model_list"], **router_kwargs
                )
        return self._router

    async def check_rate_limit(self, token_count: float, **kwargs) -> None:
@@ -622,19 +653,22 @@ async def acomplete_iter(  # type: ignore[override]
    async def achat(  # type: ignore[override]
        self, messages: Iterable[dict[str, str]]
    ) -> Chunk:
        response = await self.router.acompletion(self.name, messages)
        response = await self.router.acompletion(self.name, list(messages))
        return Chunk(
            text=response.choices[0].message.content,
            prompt_tokens=response.usage.prompt_tokens,
            completion_tokens=response.usage.completion_tokens,
            text=cast(litellm.Choices, response.choices[0]).message.content,
            prompt_tokens=response.usage.prompt_tokens,  # type: ignore[attr-defined]
            completion_tokens=response.usage.completion_tokens,  # type: ignore[attr-defined]
        )

    @rate_limited
    async def achat_iter(  # type: ignore[override]
        self, messages: Iterable[dict[str, str]]
    ) -> AsyncIterable[Chunk]:
        completion = await self.router.acompletion(
            self.name, messages, stream=True, stream_options={"include_usage": True}
            self.name,
            list(messages),
            stream=True,
            stream_options={"include_usage": True},
        )
        async for chunk in completion:
            yield Chunk(
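A hedged usage sketch of the non-streaming chat path above; the model name and message are placeholders, and achat returns the Chunk constructed above:

import asyncio

from paperqa.llms import LiteLLMModel

async def main() -> None:
    model = LiteLLMModel(name="gpt-4o-mini", config={"pass_through_router": True})
    chunk = await model.achat([{"role": "user", "content": "Say hello."}])
    print(chunk.text, chunk.prompt_tokens, chunk.completion_tokens)

asyncio.run(main())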
@@ -797,5 +831,5 @@ def embedding_model_factory(embedding: str, **kwargs) -> EmbeddingModel:
    if embedding == "sparse":
        return SparseEmbeddingModel(**kwargs)
    if kwargs:  # Only override the default config if there are actually kwargs
        kwargs = {"config": kwargs}
    return LiteLLMEmbeddingModel(name=embedding, config=kwargs)
    return LiteLLMEmbeddingModel(name=embedding, **kwargs)
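A hedged call sketch, assuming the factory wraps any extra keyword arguments into `config` as the lines above suggest; the rate-limit value is a placeholder:

embedding_model = embedding_model_factory(
    "text-embedding-3-small",
    rate_limit="30000 per 1 minute",  # wrapped into config={"rate_limit": ...}
)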
12 changes: 8 additions & 4 deletions paperqa/settings.py
@@ -466,7 +466,9 @@ def _deprecated_field(self) -> Self:
        return self


def make_default_litellm_router_settings(llm: str, temperature: float = 0.0) -> dict:
def make_default_litellm_model_list_settings(
    llm: str, temperature: float = 0.0
) -> dict:
    """Settings matching "model_list" schema here: https://docs.litellm.ai/docs/routing."""
    return {
        "model_list": [
@@ -674,21 +676,23 @@ def get_llm(self) -> LiteLLMModel:
        return LiteLLMModel(
            name=self.llm,
            config=self.llm_config
            or make_default_litellm_router_settings(self.llm, self.temperature),
            or make_default_litellm_model_list_settings(self.llm, self.temperature),
        )

    def get_summary_llm(self) -> LiteLLMModel:
        return LiteLLMModel(
            name=self.summary_llm,
            config=self.summary_llm_config
            or make_default_litellm_router_settings(self.summary_llm, self.temperature),
            or make_default_litellm_model_list_settings(
                self.summary_llm, self.temperature
            ),
        )

    def get_agent_llm(self) -> LiteLLMModel:
        return LiteLLMModel(
            name=self.agent.agent_llm,
            config=self.agent.agent_llm_config
            or make_default_litellm_router_settings(
            or make_default_litellm_model_list_settings(
                self.agent.agent_llm, self.temperature
            ),
        )
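A hedged sketch tying the settings to the new bypass; constructing Settings with these keyword fields is assumed from the attribute names used above:

from paperqa.settings import Settings

settings = Settings(
    llm="gpt-4o-mini",
    llm_config={"pass_through_router": True},  # skip litellm.Router for the main LLM
)
llm = settings.get_llm()  # a LiteLLMModel backed by PassThroughRouter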

Some generated files are not rendered by default.