Global cost tracking (#28)
sidnarayanan authored Jan 2, 2025
1 parent 9490c04 commit ab90c31
Showing 10 changed files with 428 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -73,7 +73,7 @@ repos:
additional_dependencies:
- aiohttp
- coredis
- fhaviary[llm]>=0.8.2 # Match pyproject.toml
- fhaviary[llm]>=0.14.0 # Match pyproject.toml
- litellm>=1.44 # Match pyproject.toml
- limits
- numpy
4 changes: 4 additions & 0 deletions llmclient/__init__.py
@@ -3,6 +3,7 @@
EXTRA_TOKENS_FROM_USER_ROLE,
MODEL_COST_MAP,
)
from .cost_tracker import GLOBAL_COST_TRACKER, cost_tracking_ctx, enable_cost_tracking
from .embeddings import (
EmbeddingModel,
EmbeddingModes,
@@ -34,6 +35,7 @@
__all__ = [
"CHARACTERS_PER_TOKEN_ASSUMPTION",
"EXTRA_TOKENS_FROM_USER_ROLE",
"GLOBAL_COST_TRACKER",
"MODEL_COST_MAP",
"Chunk",
"Embeddable",
@@ -49,7 +51,9 @@
"SentenceTransformerEmbeddingModel",
"SparseEmbeddingModel",
"configure_llm_logs",
"cost_tracking_ctx",
"embedding_model_factory",
"enable_cost_tracking",
"sum_logprobs",
"validate_json_completion",
]
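
The package root now re-exports the cost-tracking API. A minimal sketch of the intended import surface, assuming an llmclient model or embedding call is made while tracking is enabled:

```
from llmclient import GLOBAL_COST_TRACKER, enable_cost_tracking

enable_cost_tracking()  # opt in for the current context (a contextvar, off by default)
# ... make llmclient LLM/embedding calls here ...
print(GLOBAL_COST_TRACKER.lifetime_cost_usd)  # cumulative USD across tracked calls
```
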
174 changes: 174 additions & 0 deletions llmclient/cost_tracker.py
@@ -0,0 +1,174 @@
import contextvars
import logging
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from functools import wraps
from typing import ParamSpec, TypeVar

import litellm

logger = logging.getLogger(__name__)


class CostTracker:
def __init__(self):
self.lifetime_cost_usd = 0.0
self.last_report = 0.0
# A contextvar so that different coroutines don't affect each other's cost tracking
self.enabled = contextvars.ContextVar[bool]("track_costs", default=False)
# Not a contextvar because I can't imagine a scenario where you'd want more fine-grained control
self.report_every_usd = 1.0

def record(
self, response: litellm.ModelResponse | litellm.types.utils.EmbeddingResponse
) -> None:
self.lifetime_cost_usd += litellm.cost_calculator.completion_cost(
completion_response=response
)

if self.lifetime_cost_usd - self.last_report > self.report_every_usd:
logger.info(
f"Cumulative llmclient API call cost: ${self.lifetime_cost_usd:.8f}"
)
self.last_report = self.lifetime_cost_usd


GLOBAL_COST_TRACKER = CostTracker()


def set_reporting_threshold(threshold_usd: float) -> None:
GLOBAL_COST_TRACKER.report_every_usd = threshold_usd


def enable_cost_tracking(enabled: bool = True) -> None:
GLOBAL_COST_TRACKER.enabled.set(enabled)


@asynccontextmanager
async def cost_tracking_ctx(enabled: bool = True):
prev = GLOBAL_COST_TRACKER.enabled.get()
GLOBAL_COST_TRACKER.enabled.set(enabled)
try:
yield
finally:
GLOBAL_COST_TRACKER.enabled.set(prev)


TReturn = TypeVar(
"TReturn",
bound=Awaitable[litellm.ModelResponse]
| Awaitable[litellm.types.utils.EmbeddingResponse],
)
TParams = ParamSpec("TParams")


def track_costs(
func: Callable[TParams, TReturn],
) -> Callable[TParams, TReturn]:
"""Automatically track API costs of a coroutine call.
Note that the costs will only be recorded if `enable_cost_tracking()` is called,
or if in a `cost_tracking_ctx()` context.
Usage:
```
@track_costs
async def api_call(...) -> litellm.ModelResponse:
...
```
Args:
func: A coroutine that returns a ModelResponse or EmbeddingResponse
Returns:
A wrapped coroutine with the same signature.
"""

async def wrapped_func(*args, **kwargs):
response = await func(*args, **kwargs)
if GLOBAL_COST_TRACKER.enabled.get():
GLOBAL_COST_TRACKER.record(response)
return response

return wrapped_func


class TrackedStreamWrapper:
"""Class that tracks costs as one iterates through the stream.
Note that the following is not possible:
```
async def wrap(func):
resp: CustomStreamWrapper = await func()
async for response in resp:
yield response
# This is ok
async for resp in await litellm.acompletion(stream=True):
print(resp)
# This is not, because we cannot await an AsyncGenerator
async for resp in await wrap(litellm.acompletion)(stream=True):
print(resp)
```
In order for `track_costs_iter` to not change how users call functions,
we introduce this class to wrap the stream.
"""

def __init__(self, stream: litellm.CustomStreamWrapper):
self.stream = stream

def __iter__(self):
return self

def __aiter__(self):
return self

def __next__(self):
response = next(self.stream)
if GLOBAL_COST_TRACKER.enabled.get():
GLOBAL_COST_TRACKER.record(response)
return response

async def __anext__(self):
response = await self.stream.__anext__()
if GLOBAL_COST_TRACKER.enabled.get():
GLOBAL_COST_TRACKER.record(response)
return response


def track_costs_iter(
func: Callable[TParams, Awaitable[litellm.CustomStreamWrapper]],
) -> Callable[TParams, Awaitable[TrackedStreamWrapper]]:
"""Automatically track API costs of a streaming coroutine.
The return type is changed to `TrackedStreamWrapper`, which can be iterated
through in the same way. The underlying litellm object is available at
`TrackedStreamWrapper.stream`.
Note that the costs will only be recorded if `enable_cost_tracking()` is called,
or if in a `cost_tracking_ctx()` context.
Usage:
```
@track_costs_iter
async def streaming_api_call(...) -> litellm.CustomStreamWrapper:
...
```
Args:
func: A coroutine that returns CustomStreamWrapper.
Returns:
A wrapped coroutine with the same arguments but with a
return type of TrackedStreamWrapper.
"""

@wraps(func)
async def wrapped_func(*args, **kwargs):
return TrackedStreamWrapper(await func(*args, **kwargs))

return wrapped_func
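
Taken together, a hedged end-to-end sketch of how this module is meant to be used; the `api_call` coroutine and the model name are illustrative placeholders, not part of this commit:

```
import asyncio

import litellm

from llmclient.cost_tracker import (
    GLOBAL_COST_TRACKER,
    cost_tracking_ctx,
    set_reporting_threshold,
    track_costs,
)


@track_costs
async def api_call(**kwargs) -> litellm.ModelResponse:
    # Placeholder: any coroutine returning a ModelResponse or EmbeddingResponse works.
    return await litellm.acompletion(**kwargs)


async def main() -> None:
    set_reporting_threshold(0.10)  # emit the cumulative-cost log line every $0.10
    async with cost_tracking_ctx():  # restores the previous enabled/disabled state on exit
        await api_call(
            model="gpt-4o-mini",  # illustrative model name
            messages=[{"role": "user", "content": "hello"}],
        )
    print(f"Spent ${GLOBAL_COST_TRACKER.lifetime_cost_usd:.6f} so far")


asyncio.run(main())
```
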
10 changes: 4 additions & 6 deletions llmclient/embeddings.py
@@ -10,6 +10,7 @@
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from llmclient.constants import CHARACTERS_PER_TOKEN_ASSUMPTION, MODEL_COST_MAP
from llmclient.cost_tracker import track_costs
from llmclient.rate_limiter import GLOBAL_LIMITER
from llmclient.utils import get_litellm_retrying_config

@@ -27,8 +28,7 @@ class EmbeddingModel(ABC, BaseModel):
config: dict[str, Any] = Field(
default_factory=dict,
description=(
"Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem"
" string for parsing"
"Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem string for parsing"
),
)

@@ -138,15 +138,14 @@ async def embed_documents(
N = len(texts)
embeddings = []
for i in range(0, N, batch_size):

await self.check_rate_limit(
sum(
len(t) / CHARACTERS_PER_TOKEN_ASSUMPTION
for t in texts[i : i + batch_size]
)
)

response = await litellm.aembedding(
response = await track_costs(litellm.aembedding)(
model=self.name,
input=texts[i : i + batch_size],
dimensions=self.ndim,
@@ -222,8 +221,7 @@ def __init__(self, **kwargs):
from sentence_transformers import SentenceTransformer
except ImportError as exc:
raise ImportError(
"Please install fh-llm-client[local] to use"
" SentenceTransformerEmbeddingModel."
"Please install fh-llm-client[local] to use SentenceTransformerEmbeddingModel."
) from exc

self._model = SentenceTransformer(self.name)
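
Since `embed_documents` now calls `track_costs(litellm.aembedding)`, embedding spend flows into the same global tracker. A small sketch using that wrapping directly; the model name is illustrative:

```
import asyncio

import litellm

from llmclient.cost_tracker import GLOBAL_COST_TRACKER, cost_tracking_ctx, track_costs


async def main() -> None:
    async with cost_tracking_ctx():
        # Mirrors what embed_documents() does internally for each batch.
        await track_costs(litellm.aembedding)(
            model="text-embedding-3-small",  # illustrative model name
            input=["a sentence to embed", "another sentence"],
        )
    print(GLOBAL_COST_TRACKER.lifetime_cost_usd)


asyncio.run(main())
```
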
40 changes: 20 additions & 20 deletions llmclient/llms.py
@@ -47,6 +47,7 @@
EXTRA_TOKENS_FROM_USER_ROLE,
IS_PYTHON_BELOW_312,
)
from llmclient.cost_tracker import TrackedStreamWrapper, track_costs, track_costs_iter
from llmclient.exceptions import JSONSchemaValidationError
from llmclient.prompts import default_system_prompt
from llmclient.rate_limiter import GLOBAL_LIMITER
@@ -258,7 +259,6 @@ async def run_prompt(
]
return await self._run_chat(messages, callbacks, name, system_prompt)
if self.llm_type == "completion":

return await self._run_completion(
prompt, data, callbacks, name, system_prompt
)
@@ -411,7 +411,6 @@ def rate_limited(
async def wrapper(
self: LLMModelOrChild, *args: Any, **kwargs: Any
) -> Chunk | AsyncIterator[Chunk] | AsyncIterator[LLMModelOrChild]:

if not hasattr(self, "check_rate_limit"):
raise NotImplementedError(
f"Model {self.name} must have a `check_rate_limit` method."
@@ -566,7 +565,9 @@ async def check_rate_limit(self, token_count: float, **kwargs) -> None:

@rate_limited
async def acomplete(self, prompt: str) -> Chunk: # type: ignore[override]
response = await self.router.atext_completion(model=self.name, prompt=prompt)
response = await track_costs(self.router.atext_completion)(
model=self.name, prompt=prompt
)
return Chunk(
text=response.choices[0].text,
prompt_tokens=response.usage.prompt_tokens,
@@ -577,7 +578,7 @@ async def acomplete(self, prompt: str) -> Chunk: # type: ignore[override]
async def acomplete_iter( # type: ignore[override]
self, prompt: str
) -> AsyncIterable[Chunk]:
completion = await self.router.atext_completion(
completion = await track_costs_iter(self.router.atext_completion)(
model=self.name,
prompt=prompt,
stream=True,
@@ -595,7 +596,7 @@ async def acomplete_iter( # type: ignore[override]
@rate_limited
async def achat(self, messages: list[Message]) -> Chunk: # type: ignore[override]
prompts = [m.model_dump(by_alias=True) for m in messages if m.content]
response = await self.router.acompletion(self.name, prompts)
response = await track_costs(self.router.acompletion)(self.name, prompts)
return Chunk(
text=cast(litellm.Choices, response.choices[0]).message.content,
prompt_tokens=response.usage.prompt_tokens, # type: ignore[attr-defined]
@@ -607,7 +608,7 @@ async def achat_iter( # type: ignore[override]
self, messages: list[Message]
) -> AsyncIterable[Chunk]:
prompts = [m.model_dump(by_alias=True) for m in messages if m.content]
completion = await self.router.acompletion(
completion = await track_costs_iter(self.router.acompletion)(
self.name,
prompts,
stream=True,
@@ -642,7 +643,7 @@ async def select_tool(
) -> ToolRequestMessage:
"""Shim to aviary.core.ToolSelector that supports tool schemae."""
tool_selector = ToolSelector(
model_name=self.name, acompletion=self.router.acompletion
model_name=self.name, acompletion=track_costs(self.router.acompletion)
)
return await tool_selector(*selection_args, **selection_kwargs)

@@ -688,22 +689,21 @@ def set_model_name(self) -> Self:
async def achat(
self, messages: Iterable[Message], **kwargs
) -> litellm.ModelResponse:
return await litellm.acompletion(
return await track_costs(litellm.acompletion)(
messages=[m.model_dump(by_alias=True) for m in messages],
**(self.config | kwargs),
)

async def achat_iter(self, messages: Iterable[Message], **kwargs) -> AsyncGenerator:
return cast(
AsyncGenerator,
await litellm.acompletion(
messages=[m.model_dump(by_alias=True) for m in messages],
stream=True,
stream_options={
"include_usage": True, # Included to get prompt token counts
},
**(self.config | kwargs),
),
async def achat_iter(
self, messages: Iterable[Message], **kwargs
) -> TrackedStreamWrapper:
return await track_costs_iter(litellm.acompletion)(
messages=[m.model_dump(by_alias=True) for m in messages],
stream=True,
stream_options={
"include_usage": True, # Included to get prompt token counts
},
**(self.config | kwargs),
)

# SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -816,7 +816,7 @@ async def call( # noqa: C901, PLR0915
results: list[LLMResult] = []

if callbacks is None:
completion: litellm.ModelResponse = await self.achat(prompt, **chat_kwargs)
completion = await self.achat(prompt, **chat_kwargs)
if output_type is not None:
validate_json_completion(completion, output_type)

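
For the streaming paths, callers now receive a `TrackedStreamWrapper` and iterate it exactly as they would the underlying litellm stream; each consumed chunk is recorded while tracking is enabled. A sketch against `litellm.acompletion` directly, with an illustrative model name:

```
import asyncio

import litellm

from llmclient.cost_tracker import cost_tracking_ctx, track_costs_iter


async def main() -> None:
    async with cost_tracking_ctx():
        stream = await track_costs_iter(litellm.acompletion)(
            model="gpt-4o-mini",  # illustrative model name
            messages=[{"role": "user", "content": "hello"}],
            stream=True,
            stream_options={"include_usage": True},  # needed for prompt token counts
        )
        async for chunk in stream:  # TrackedStreamWrapper supports async iteration
            if chunk.choices and chunk.choices[0].delta.content:
                print(chunk.choices[0].delta.content, end="")


asyncio.run(main())
```
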
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"coredis",
"fhaviary>=0.8.2", # For core namespace
"fhaviary>=0.14.0", # For multi-image support
"limits",
"pydantic~=2.0,>=2.10.1,<2.10.2",
"tiktoken>=0.4.0",
8 changes: 8 additions & 0 deletions tests/conftest.py
@@ -3,6 +3,7 @@
import logging
import shutil
from collections.abc import Iterator
from enum import StrEnum
from pathlib import Path
from typing import Any

@@ -73,3 +74,10 @@ def fixture_reset_log_levels(caplog) -> Iterator[None]:
logger = logging.getLogger(name)
logger.setLevel(logging.NOTSET)
logger.propagate = True


class CILLMModelNames(StrEnum):
"""Models to use for generic CI testing."""

ANTHROPIC = "claude-3-haiku-20240307" # Cheap and not Anthropic's cutting edge
OPENAI = "gpt-4o-mini-2024-07-18" # Cheap and not OpenAI's cutting edge