Global cost tracking #28

Merged · 6 commits · Jan 2, 2025
Changes from 5 commits
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -73,7 +73,7 @@ repos:
additional_dependencies:
- aiohttp
- coredis
- fhaviary[llm]>=0.8.2 # Match pyproject.toml
- fhaviary[llm]>=0.14.0 # Match pyproject.toml
- litellm>=1.44 # Match pyproject.toml
- limits
- numpy
3 changes: 3 additions & 0 deletions llmclient/__init__.py
@@ -3,6 +3,7 @@
EXTRA_TOKENS_FROM_USER_ROLE,
MODEL_COST_MAP,
)
from .cost_tracker import track_costs_ctx, track_costs_global
from .embeddings import (
EmbeddingModel,
EmbeddingModes,
@@ -51,5 +52,7 @@
"configure_llm_logs",
"embedding_model_factory",
"sum_logprobs",
"track_costs_ctx",
"track_costs_global",
"validate_json_completion",
]
125 changes: 125 additions & 0 deletions llmclient/cost_tracker.py
@@ -0,0 +1,125 @@
import contextvars
import logging
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from functools import wraps
from typing import ParamSpec, TypeVar

import litellm

logger = logging.getLogger(__name__)


TRACK_COSTS = contextvars.ContextVar[bool]("track_costs", default=False)
REPORT_EVERY_USD = 1.0
Contributor:

A few comments here:

  • What do you think of making it a ClassVar of CostTracker? We still get global state, but it's less awkward than the module-level global with setters/getters.
  • Can you make this name more intuitive, and add units to it? It's unclear whether:
    • it's a frequency (Hz), or
    • it's a dollar threshold (USD).

Contributor Author:

I'm confused by the last bullet - `_USD` is the units - how would you describe it?

Contributor Author:

I've renamed `set_reporting_frequency` -> `set_reporting_threshold` to remove the ambiguity.

Contributor Author:

And made both ClassVars.

Contributor Author:

Actually I went further and made them instance variables - no reason not to.
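For reference, a minimal sketch of what that instance-variable version could look like (illustrative only - it follows the discussion above, not the follow-up commit, which is outside the 5-commit view shown here):

```python
import logging

import litellm

logger = logging.getLogger(__name__)


class CostTracker:
    def __init__(self, report_every_usd: float = 1.0):
        self.lifetime_cost_usd = 0.0
        self.last_report = 0.0
        # Dollar threshold between cumulative-cost log lines, per instance.
        self.report_every_usd = report_every_usd

    def set_reporting_threshold(self, threshold_usd: float) -> None:
        self.report_every_usd = threshold_usd

    def record(self, response: litellm.ModelResponse) -> None:
        self.lifetime_cost_usd += litellm.cost_calculator.completion_cost(
            completion_response=response
        )
        if self.lifetime_cost_usd - self.last_report > self.report_every_usd:
            logger.info(
                f"Cumulative llmclient API call cost: ${self.lifetime_cost_usd:.8f}"
            )
            self.last_report = self.lifetime_cost_usd
```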



def set_reporting_frequency(frequency: float):
global REPORT_EVERY_USD # noqa: PLW0603 # pylint: disable=global-statement
REPORT_EVERY_USD = frequency


def track_costs_global(enabled: bool = True):
TRACK_COSTS.set(enabled)


@asynccontextmanager
async def track_costs_ctx(enabled: bool = True):
prev = TRACK_COSTS.get()
TRACK_COSTS.set(enabled)
try:
yield
finally:
TRACK_COSTS.set(prev)


class CostTracker:
def __init__(self):
self.lifetime_cost_usd = 0.0
self.last_report = 0.0

def record(self, response: litellm.ModelResponse):
self.lifetime_cost_usd += litellm.cost_calculator.completion_cost(
completion_response=response
)

if self.lifetime_cost_usd - self.last_report > REPORT_EVERY_USD:
logger.info(
f"Cumulative llmclient API call cost: ${self.lifetime_cost_usd:.8f}"
Contributor:

Suggested change:
- f"Cumulative llmclient API call cost: ${self.lifetime_cost_usd:.8f}"
+ f"Cumulative client API call cost: ${self.lifetime_cost_usd:.8f}"

We will eventually maybe rename from llmclient, and maybe do things besides just LLMs (e.g. embeddings), so let's just be generally worded here.

Contributor Author:

This was intentional wording - the cost tracker only tracks llmclient costs right now.

)
self.last_report = self.lifetime_cost_usd


GLOBAL_COST_TRACKER = CostTracker()


TReturn = TypeVar("TReturn", bound=Awaitable)
TParams = ParamSpec("TParams")


def track_costs(
func: Callable[TParams, TReturn],
) -> Callable[TParams, TReturn]:
async def wrapped_func(*args, **kwargs):
response = await func(*args, **kwargs)
if TRACK_COSTS.get():
GLOBAL_COST_TRACKER.record(response)
return response

return wrapped_func


class TrackedStreamWrapper:
"""Class that tracks costs as one iterates through the stream.

Note that the following is not possible:
```
async def wrap(func):
resp: CustomStreamWrapper = await func()
async for response in resp:
yield response

# This is ok
async for resp in await litellm.acompletion(stream=True):
print(resp)


# This is not, because we cannot await an AsyncGenerator
async for resp in await wrap(litellm.acompletion(stream=True)):
print(resp)
```

In order for `track_costs_iter` to not change how users call functions,
we introduce this class to wrap the stream.
"""

def __init__(self, stream: litellm.CustomStreamWrapper):
self.stream = stream

def __iter__(self):
return self

def __aiter__(self):
return self

def __next__(self):
response = next(self.stream)
if TRACK_COSTS.get():
GLOBAL_COST_TRACKER.record(response)
return response

async def __anext__(self):
response = await self.stream.__anext__()
if TRACK_COSTS.get():
GLOBAL_COST_TRACKER.record(response)
return response


def track_costs_iter(
func: Callable[TParams, TReturn],
) -> Callable[TParams, Awaitable[TrackedStreamWrapper]]:
@wraps(func)
async def wrapped_func(*args, **kwargs):
return TrackedStreamWrapper(await func(*args, **kwargs))

return wrapped_func
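For orientation, a minimal usage sketch of the new helpers (assumptions: the model name and messages are placeholders, and `track_costs`/`track_costs_iter` are imported from `llmclient.cost_tracker`, since only `track_costs_global` and `track_costs_ctx` are re-exported from the package root):

```python
import asyncio

import litellm

from llmclient import track_costs_ctx, track_costs_global
from llmclient.cost_tracker import GLOBAL_COST_TRACKER, track_costs, track_costs_iter


async def main() -> None:
    track_costs_global()  # enable cost tracking (sets the TRACK_COSTS context variable)

    # Non-streaming: track_costs wraps the coroutine function and records
    # the response's cost in GLOBAL_COST_TRACKER.
    response = await track_costs(litellm.acompletion)(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)

    # Streaming: track_costs_iter returns a TrackedStreamWrapper, so the
    # call site keeps the familiar `async for chunk in await ...` shape.
    stream = await track_costs_iter(litellm.acompletion)(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    async for chunk in stream:
        pass  # consume the stream; each chunk's cost is recorded

    # Alternatively, scope tracking to a block:
    async with track_costs_ctx():
        await track_costs(litellm.acompletion)(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "Hello"}],
        )

    print(f"Total tracked cost: ${GLOBAL_COST_TRACKER.lifetime_cost_usd:.6f}")


asyncio.run(main())
```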
10 changes: 4 additions & 6 deletions llmclient/embeddings.py
@@ -10,6 +10,7 @@
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from llmclient.constants import CHARACTERS_PER_TOKEN_ASSUMPTION, MODEL_COST_MAP
from llmclient.cost_tracker import track_costs
from llmclient.rate_limiter import GLOBAL_LIMITER
from llmclient.utils import get_litellm_retrying_config

@@ -27,8 +28,7 @@ class EmbeddingModel(ABC, BaseModel):
config: dict[str, Any] = Field(
default_factory=dict,
description=(
"Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem"
" string for parsing"
"Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem string for parsing"
),
)

@@ -138,15 +138,14 @@ async def embed_documents(
N = len(texts)
embeddings = []
for i in range(0, N, batch_size):

await self.check_rate_limit(
sum(
len(t) / CHARACTERS_PER_TOKEN_ASSUMPTION
for t in texts[i : i + batch_size]
)
)

response = await litellm.aembedding(
response = await track_costs(litellm.aembedding)(
model=self.name,
input=texts[i : i + batch_size],
dimensions=self.ndim,
@@ -222,8 +221,7 @@ def __init__(self, **kwargs):
from sentence_transformers import SentenceTransformer
except ImportError as exc:
raise ImportError(
"Please install fh-llm-client[local] to use"
" SentenceTransformerEmbeddingModel."
"Please install fh-llm-client[local] to use SentenceTransformerEmbeddingModel."
) from exc

self._model = SentenceTransformer(self.name)
40 changes: 20 additions & 20 deletions llmclient/llms.py
@@ -47,6 +47,7 @@
EXTRA_TOKENS_FROM_USER_ROLE,
IS_PYTHON_BELOW_312,
)
from llmclient.cost_tracker import TrackedStreamWrapper, track_costs, track_costs_iter
from llmclient.exceptions import JSONSchemaValidationError
from llmclient.prompts import default_system_prompt
from llmclient.rate_limiter import GLOBAL_LIMITER
@@ -258,7 +259,6 @@ async def run_prompt(
]
return await self._run_chat(messages, callbacks, name, system_prompt)
if self.llm_type == "completion":

return await self._run_completion(
prompt, data, callbacks, name, system_prompt
)
@@ -411,7 +411,6 @@ def rate_limited(
async def wrapper(
self: LLMModelOrChild, *args: Any, **kwargs: Any
) -> Chunk | AsyncIterator[Chunk] | AsyncIterator[LLMModelOrChild]:

if not hasattr(self, "check_rate_limit"):
raise NotImplementedError(
f"Model {self.name} must have a `check_rate_limit` method."
@@ -566,7 +565,9 @@ async def check_rate_limit(self, token_count: float, **kwargs) -> None:

@rate_limited
async def acomplete(self, prompt: str) -> Chunk: # type: ignore[override]
response = await self.router.atext_completion(model=self.name, prompt=prompt)
response = await track_costs(self.router.atext_completion)(
model=self.name, prompt=prompt
)
return Chunk(
text=response.choices[0].text,
prompt_tokens=response.usage.prompt_tokens,
@@ -577,7 +578,7 @@ async def acomplete(self, prompt: str) -> Chunk:  # type: ignore[override]
async def acomplete_iter( # type: ignore[override]
self, prompt: str
) -> AsyncIterable[Chunk]:
completion = await self.router.atext_completion(
completion = await track_costs_iter(self.router.atext_completion)(
model=self.name,
prompt=prompt,
stream=True,
@@ -595,7 +596,7 @@ async def acomplete_iter(  # type: ignore[override]
@rate_limited
async def achat(self, messages: list[Message]) -> Chunk: # type: ignore[override]
prompts = [m.model_dump(by_alias=True) for m in messages if m.content]
response = await self.router.acompletion(self.name, prompts)
response = await track_costs(self.router.acompletion)(self.name, prompts)
return Chunk(
text=cast(litellm.Choices, response.choices[0]).message.content,
prompt_tokens=response.usage.prompt_tokens, # type: ignore[attr-defined]
@@ -607,7 +608,7 @@ async def achat_iter(  # type: ignore[override]
self, messages: list[Message]
) -> AsyncIterable[Chunk]:
prompts = [m.model_dump(by_alias=True) for m in messages if m.content]
completion = await self.router.acompletion(
completion = await track_costs_iter(self.router.acompletion)(
self.name,
prompts,
stream=True,
@@ -642,7 +643,7 @@ async def select_tool(
) -> ToolRequestMessage:
"""Shim to aviary.core.ToolSelector that supports tool schemae."""
tool_selector = ToolSelector(
model_name=self.name, acompletion=self.router.acompletion
model_name=self.name, acompletion=track_costs(self.router.acompletion)
)
return await tool_selector(*selection_args, **selection_kwargs)

@@ -688,22 +689,21 @@ def set_model_name(self) -> Self:
async def achat(
self, messages: Iterable[Message], **kwargs
) -> litellm.ModelResponse:
return await litellm.acompletion(
return await track_costs(litellm.acompletion)(
messages=[m.model_dump(by_alias=True) for m in messages],
**(self.config | kwargs),
)

async def achat_iter(self, messages: Iterable[Message], **kwargs) -> AsyncGenerator:
return cast(
AsyncGenerator,
await litellm.acompletion(
messages=[m.model_dump(by_alias=True) for m in messages],
stream=True,
stream_options={
"include_usage": True, # Included to get prompt token counts
},
**(self.config | kwargs),
),
async def achat_iter(
self, messages: Iterable[Message], **kwargs
) -> TrackedStreamWrapper:
return await track_costs_iter(litellm.acompletion)(
messages=[m.model_dump(by_alias=True) for m in messages],
stream=True,
stream_options={
"include_usage": True, # Included to get prompt token counts
},
**(self.config | kwargs),
)

# SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -816,7 +816,7 @@ async def call(  # noqa: C901, PLR0915
results: list[LLMResult] = []

if callbacks is None:
completion: litellm.ModelResponse = await self.achat(prompt, **chat_kwargs)
completion = await self.achat(prompt, **chat_kwargs)
if output_type is not None:
validate_json_completion(completion, output_type)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"coredis",
"fhaviary>=0.8.2", # For core namespace
"fhaviary>=0.14.0", # For multi-image support
"limits",
"pydantic~=2.0,>=2.10.1,<2.10.2",
"tiktoken>=0.4.0",
8 changes: 8 additions & 0 deletions tests/conftest.py
@@ -3,6 +3,7 @@
import logging
import shutil
from collections.abc import Iterator
from enum import StrEnum
from pathlib import Path
from typing import Any

@@ -73,3 +74,10 @@ def fixture_reset_log_levels(caplog) -> Iterator[None]:
logger = logging.getLogger(name)
logger.setLevel(logging.NOTSET)
logger.propagate = True


class CILLMModelNames(StrEnum):
Contributor:

We ought to just put this in llmclient as one enum:

class CommonLLMNames(StrEnum):
    # Use these for model defaults
    OPENAI_GENERAL = "gpt-4o-2024-08-06"  # Cheap, fast, and decent

    # Use these in unit testing
    OPENAI_TEST = "gpt-4o-mini-2024-07-18"  # Cheap and not OpenAI's cutting edge
    ANTHROPIC_TEST = "claude-3-haiku-20240307"  # Cheap and not Anthropic's cutting edge

Then both the app and unit tests will just use CommonLLMNames.

Contributor Author:

I'll leave that for another PR.

Contributor:

I did this in #30.

"""Models to use for generic CI testing."""

ANTHROPIC = "claude-3-haiku-20240307" # Cheap and not Anthropic's cutting edge
OPENAI = "gpt-4o-mini-2024-07-18" # Cheap and not OpenAI's cutting edge