-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Global cost tracking #28
Changes from 5 commits
d9c9575
f8125c0
08a8b9d
dcfc7b0
814c193
b6a51f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,125 @@ | ||||||
import contextvars | ||||||
import logging | ||||||
from collections.abc import Awaitable, Callable | ||||||
from contextlib import asynccontextmanager | ||||||
from functools import wraps | ||||||
from typing import ParamSpec, TypeVar | ||||||
|
||||||
import litellm | ||||||
|
||||||
logger = logging.getLogger(__name__) | ||||||
|
||||||
|
||||||
TRACK_COSTS = contextvars.ContextVar[bool]("track_costs", default=False) | ||||||
REPORT_EVERY_USD = 1.0 | ||||||
|
||||||
|
||||||
def set_reporting_frequency(frequency: float): | ||||||
global REPORT_EVERY_USD # noqa: PLW0603 # pylint: disable=global-statement | ||||||
REPORT_EVERY_USD = frequency | ||||||
|
||||||
|
||||||
def track_costs_global(enabled: bool = True): | ||||||
TRACK_COSTS.set(enabled) | ||||||
|
||||||
|
||||||
@asynccontextmanager | ||||||
async def track_costs_ctx(enabled: bool = True): | ||||||
prev = TRACK_COSTS.get() | ||||||
TRACK_COSTS.set(enabled) | ||||||
try: | ||||||
yield | ||||||
finally: | ||||||
TRACK_COSTS.set(prev) | ||||||
|
||||||
|
||||||
class CostTracker: | ||||||
def __init__(self): | ||||||
self.lifetime_cost_usd = 0.0 | ||||||
self.last_report = 0.0 | ||||||
|
||||||
def record(self, response: litellm.ModelResponse): | ||||||
sidnarayanan marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
self.lifetime_cost_usd += litellm.cost_calculator.completion_cost( | ||||||
completion_response=response | ||||||
) | ||||||
|
||||||
if self.lifetime_cost_usd - self.last_report > REPORT_EVERY_USD: | ||||||
logger.info( | ||||||
f"Cumulative llmclient API call cost: ${self.lifetime_cost_usd:.8f}" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
We will eventually maybe rename from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was intentional wording - the cost tracker only tracks |
||||||
) | ||||||
self.last_report = self.lifetime_cost_usd | ||||||
|
||||||
|
||||||
GLOBAL_COST_TRACKER = CostTracker() | ||||||
|
||||||
|
||||||
TReturn = TypeVar("TReturn", bound=Awaitable) | ||||||
TParams = ParamSpec("TParams") | ||||||
|
||||||
|
||||||
def track_costs( | ||||||
sidnarayanan marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
func: Callable[TParams, TReturn], | ||||||
) -> Callable[TParams, TReturn]: | ||||||
async def wrapped_func(*args, **kwargs): | ||||||
response = await func(*args, **kwargs) | ||||||
if TRACK_COSTS.get(): | ||||||
GLOBAL_COST_TRACKER.record(response) | ||||||
return response | ||||||
|
||||||
return wrapped_func | ||||||
|
||||||
|
||||||
class TrackedStreamWrapper: | ||||||
"""Class that tracks costs as one iterates through the stream. | ||||||
|
||||||
Note that the following is not possible: | ||||||
``` | ||||||
async def wrap(func): | ||||||
resp: CustomStreamWrapper = await func() | ||||||
async for response in resp: | ||||||
yield response | ||||||
|
||||||
# This is ok | ||||||
async for resp in await litellm.acompletion(stream=True): | ||||||
print(resp | ||||||
sidnarayanan marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
|
||||||
# This is not, because we cannot await an AsyncGenerator | ||||||
async for resp in await wrap(litellm.acompletion(stream=True)): | ||||||
print(resp) | ||||||
``` | ||||||
|
||||||
In order for `track_costs_iter` to not change how users call functions, | ||||||
we introduce this class to wrap the stream. | ||||||
""" | ||||||
|
||||||
def __init__(self, stream: litellm.CustomStreamWrapper): | ||||||
self.stream = stream | ||||||
|
||||||
def __iter__(self): | ||||||
return self | ||||||
|
||||||
def __aiter__(self): | ||||||
return self | ||||||
|
||||||
def __next__(self): | ||||||
response = next(self.stream) | ||||||
if TRACK_COSTS.get(): | ||||||
GLOBAL_COST_TRACKER.record(response) | ||||||
return response | ||||||
|
||||||
async def __anext__(self): | ||||||
response = await self.stream.__anext__() | ||||||
if TRACK_COSTS.get(): | ||||||
GLOBAL_COST_TRACKER.record(response) | ||||||
return response | ||||||
|
||||||
|
||||||
def track_costs_iter( | ||||||
func: Callable[TParams, TReturn], | ||||||
) -> Callable[TParams, Awaitable[TrackedStreamWrapper]]: | ||||||
@wraps(func) | ||||||
async def wrapped_func(*args, **kwargs): | ||||||
return TrackedStreamWrapper(await func(*args, **kwargs)) | ||||||
|
||||||
return wrapped_func |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
import logging | ||
import shutil | ||
from collections.abc import Iterator | ||
from enum import StrEnum | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
|
@@ -73,3 +74,10 @@ def fixture_reset_log_levels(caplog) -> Iterator[None]: | |
logger = logging.getLogger(name) | ||
logger.setLevel(logging.NOTSET) | ||
logger.propagate = True | ||
|
||
|
||
class CILLMModelNames(StrEnum): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We ought to just put this in class CommonLLMNames(StrEnum):
# Use these for model defaults
OPENAI_GENERAL = "gpt-4o-2024-08-06" # Cheap, fast, and decent
# Use these in unit testing
OPENAI_TEST = "gpt-4o-mini-2024-07-18" # Cheap and not OpenAI's cutting edge
ANTHROPIC_TEST = "claude-3-haiku-20240307" # Cheap and not Anthropic's cutting edge Then both the app and unit tests will just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll leave that for another PR There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did this in #30 |
||
"""Models to use for generic CI testing.""" | ||
|
||
ANTHROPIC = "claude-3-haiku-20240307" # Cheap and not Anthropic's cutting edge | ||
OPENAI = "gpt-4o-mini-2024-07-18" # Cheap and not OpenAI's cutting edge |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Few comments here:
ClassVar
ofCostTracker
? We still get global state, but it's less awkward with theglobal
and setters/gettersThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm confused on the last bullet -
_USD
is the units - how would you describe it?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've renamed
set_reporting_frequency
->set_reporting_threshold
to remove the ambiguity.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And made both ClassVars.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually I went further and made them instance variables - no reason not to.