From ab90c3178a1a15957a127414d225710ed07376f1 Mon Sep 17 00:00:00 2001
From: Siddharth Narayanan
Date: Thu, 2 Jan 2025 13:08:26 -0600
Subject: [PATCH] Global cost tracking (#28)

---
 .pre-commit-config.yaml     |   2 +-
 llmclient/__init__.py       |   4 +
 llmclient/cost_tracker.py   | 174 ++++++++++++++++++++++++++++++
 llmclient/embeddings.py     |  10 +-
 llmclient/llms.py           |  40 +++----
 pyproject.toml              |   2 +-
 tests/conftest.py           |   8 ++
 tests/test_cost_tracking.py | 209 ++++++++++++++++++++++++++++++++++++
 tests/test_llms.py          |   5 +-
 uv.lock                     |  10 +-
 10 files changed, 428 insertions(+), 36 deletions(-)
 create mode 100644 llmclient/cost_tracker.py
 create mode 100644 tests/test_cost_tracking.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8a66be9..b6130a7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -73,7 +73,7 @@ repos:
         additional_dependencies:
           - aiohttp
           - coredis
-          - fhaviary[llm]>=0.8.2 # Match pyproject.toml
+          - fhaviary[llm]>=0.14.0 # Match pyproject.toml
          - litellm>=1.44 # Match pyproject.toml
          - limits
          - numpy
diff --git a/llmclient/__init__.py b/llmclient/__init__.py
index c5d2589..139597d 100644
--- a/llmclient/__init__.py
+++ b/llmclient/__init__.py
@@ -3,6 +3,7 @@
     EXTRA_TOKENS_FROM_USER_ROLE,
     MODEL_COST_MAP,
 )
+from .cost_tracker import GLOBAL_COST_TRACKER, cost_tracking_ctx, enable_cost_tracking
 from .embeddings import (
     EmbeddingModel,
     EmbeddingModes,
@@ -34,6 +35,7 @@
 __all__ = [
     "CHARACTERS_PER_TOKEN_ASSUMPTION",
     "EXTRA_TOKENS_FROM_USER_ROLE",
+    "GLOBAL_COST_TRACKER",
     "MODEL_COST_MAP",
     "Chunk",
     "Embeddable",
@@ -49,7 +51,9 @@
     "SentenceTransformerEmbeddingModel",
     "SparseEmbeddingModel",
     "configure_llm_logs",
+    "cost_tracking_ctx",
     "embedding_model_factory",
+    "enable_cost_tracking",
     "sum_logprobs",
     "validate_json_completion",
 ]
diff --git a/llmclient/cost_tracker.py b/llmclient/cost_tracker.py
new file mode 100644
index 0000000..961e099
--- /dev/null
+++ b/llmclient/cost_tracker.py
@@ -0,0 +1,174 @@
+import contextvars
+import logging
+from collections.abc import Awaitable, Callable
+from contextlib import asynccontextmanager
+from functools import wraps
+from typing import ParamSpec, TypeVar
+
+import litellm
+
+logger = logging.getLogger(__name__)
+
+
+class CostTracker:
+    def __init__(self):
+        self.lifetime_cost_usd = 0.0
+        self.last_report = 0.0
+        # A contextvar so that different coroutines don't affect each other's cost tracking
+        self.enabled = contextvars.ContextVar[bool]("track_costs", default=False)
+        # Not a contextvar because I can't imagine a scenario where you'd want more fine-grained control
+        self.report_every_usd = 1.0
+
+    def record(
+        self, response: litellm.ModelResponse | litellm.types.utils.EmbeddingResponse
+    ) -> None:
+        self.lifetime_cost_usd += litellm.cost_calculator.completion_cost(
+            completion_response=response
+        )
+
+        if self.lifetime_cost_usd - self.last_report > self.report_every_usd:
+            logger.info(
+                f"Cumulative llmclient API call cost: ${self.lifetime_cost_usd:.8f}"
+            )
+            self.last_report = self.lifetime_cost_usd
+
+
+GLOBAL_COST_TRACKER = CostTracker()
+
+
+def set_reporting_threshold(threshold_usd: float) -> None:
+    GLOBAL_COST_TRACKER.report_every_usd = threshold_usd
+
+
+def enable_cost_tracking(enabled: bool = True) -> None:
+    GLOBAL_COST_TRACKER.enabled.set(enabled)
+
+
+@asynccontextmanager
+async def cost_tracking_ctx(enabled: bool = True):
+    prev = GLOBAL_COST_TRACKER.enabled.get()
+    GLOBAL_COST_TRACKER.enabled.set(enabled)
+    try:
+        yield
+    finally:
+        GLOBAL_COST_TRACKER.enabled.set(prev)
+
+
+TReturn = TypeVar(
+    "TReturn",
+    bound=Awaitable[litellm.ModelResponse]
+    | Awaitable[litellm.types.utils.EmbeddingResponse],
+)
+TParams = ParamSpec("TParams")
+
+
+def track_costs(
+    func: Callable[TParams, TReturn],
+) -> Callable[TParams, TReturn]:
+    """Automatically track API costs of a coroutine call.
+
+    Note that the costs will only be recorded if `enable_cost_tracking()` is called,
+    or if in a `cost_tracking_ctx()` context.
+
+    Usage:
+    ```
+    @track_costs
+    async def api_call(...) -> litellm.ModelResponse:
+        ...
+    ```
+
+    Args:
+        func: A coroutine that returns a ModelResponse or EmbeddingResponse
+
+    Returns:
+        A wrapped coroutine with the same signature.
+    """
+
+    async def wrapped_func(*args, **kwargs):
+        response = await func(*args, **kwargs)
+        if GLOBAL_COST_TRACKER.enabled.get():
+            GLOBAL_COST_TRACKER.record(response)
+        return response
+
+    return wrapped_func
+
+
+class TrackedStreamWrapper:
+    """Class that tracks costs as one iterates through the stream.
+
+    Note that the following is not possible:
+    ```
+    async def wrap(func):
+        resp: CustomStreamWrapper = await func()
+        async for response in resp:
+            yield response
+
+
+    # This is ok
+    async for resp in await litellm.acompletion(stream=True):
+        print(resp)
+
+
+    # This is not, because we cannot await an AsyncGenerator
+    async for resp in await wrap(litellm.acompletion)(stream=True):
+        print(resp)
+    ```
+
+    In order for `track_costs_iter` to not change how users call functions,
+    we introduce this class to wrap the stream.
+    """
+
+    def __init__(self, stream: litellm.CustomStreamWrapper):
+        self.stream = stream
+
+    def __iter__(self):
+        return self
+
+    def __aiter__(self):
+        return self
+
+    def __next__(self):
+        response = next(self.stream)
+        if GLOBAL_COST_TRACKER.enabled.get():
+            GLOBAL_COST_TRACKER.record(response)
+        return response
+
+    async def __anext__(self):
+        response = await self.stream.__anext__()
+        if GLOBAL_COST_TRACKER.enabled.get():
+            GLOBAL_COST_TRACKER.record(response)
+        return response
+
+
+def track_costs_iter(
+    func: Callable[TParams, Awaitable[litellm.CustomStreamWrapper]],
+) -> Callable[TParams, Awaitable[TrackedStreamWrapper]]:
+    """Automatically track API costs of a streaming coroutine.
+
+    The return type is changed to `TrackedStreamWrapper`, which can be iterated
+    through in the same way. The underlying litellm object is available at
+    `TrackedStreamWrapper.stream`.
+
+    Note that the costs will only be recorded if `enable_cost_tracking()` is called,
+    or if in a `cost_tracking_ctx()` context.
+
+    Usage:
+    ```
+    @track_costs_iter
+    async def streaming_api_call(...) -> litellm.CustomStreamWrapper:
+        ...
+    ```
+
+    Args:
+        func: A coroutine that returns CustomStreamWrapper.
+
+    Returns:
+        A wrapped coroutine with the same arguments but with a
+        return type of TrackedStreamWrapper.
+    """
+
+    @wraps(func)
+    async def wrapped_func(*args, **kwargs):
+        return TrackedStreamWrapper(await func(*args, **kwargs))
+
+    return wrapped_func
diff --git a/llmclient/embeddings.py b/llmclient/embeddings.py
index e03de2e..5cb3464 100644
--- a/llmclient/embeddings.py
+++ b/llmclient/embeddings.py
@@ -10,6 +10,7 @@
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 from llmclient.constants import CHARACTERS_PER_TOKEN_ASSUMPTION, MODEL_COST_MAP
+from llmclient.cost_tracker import track_costs
 from llmclient.rate_limiter import GLOBAL_LIMITER
 from llmclient.utils import get_litellm_retrying_config
 
@@ -27,8 +28,7 @@ class EmbeddingModel(ABC, BaseModel):
     config: dict[str, Any] = Field(
         default_factory=dict,
         description=(
-            "Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem"
-            " string for parsing"
+            "Optional `rate_limit` key, value must be a RateLimitItem or RateLimitItem string for parsing"
         ),
     )
 
@@ -138,7 +138,6 @@ async def embed_documents(
         N = len(texts)
         embeddings = []
         for i in range(0, N, batch_size):
-
             await self.check_rate_limit(
                 sum(
                     len(t) / CHARACTERS_PER_TOKEN_ASSUMPTION
                     for t in texts[i : i + batch_size]
                 )
             )
-            response = await litellm.aembedding(
+            response = await track_costs(litellm.aembedding)(
                 model=self.name,
                 input=texts[i : i + batch_size],
                 dimensions=self.ndim,
@@ -222,8 +221,7 @@ def __init__(self, **kwargs):
         try:
             from sentence_transformers import SentenceTransformer
         except ImportError as exc:
             raise ImportError(
-                "Please install fh-llm-client[local] to use"
-                " SentenceTransformerEmbeddingModel."
+                "Please install fh-llm-client[local] to use SentenceTransformerEmbeddingModel."
             ) from exc
 
         self._model = SentenceTransformer(self.name)
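For illustration only (mirroring the new test added later in this patch): with `litellm.aembedding` now wrapped in `track_costs`, embedding calls accrue to the same global tracker whenever tracking is enabled.

```
from llmclient import cost_tracking_ctx
from llmclient.cost_tracker import GLOBAL_COST_TRACKER
from llmclient.embeddings import LiteLLMEmbeddingModel


async def embed_with_costs(texts: list[str]) -> None:
    before = GLOBAL_COST_TRACKER.lifetime_cost_usd
    async with cost_tracking_ctx():
        model = LiteLLMEmbeddingModel(name="text-embedding-3-small", ndim=8)
        await model.embed_documents(texts)
    # Each batch's aembedding response was recorded, so lifetime cost grew.
    assert GLOBAL_COST_TRACKER.lifetime_cost_usd > before
```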
diff --git a/llmclient/llms.py b/llmclient/llms.py
index e75b3f3..4a11a78 100644
--- a/llmclient/llms.py
+++ b/llmclient/llms.py
@@ -47,6 +47,7 @@
     EXTRA_TOKENS_FROM_USER_ROLE,
     IS_PYTHON_BELOW_312,
 )
+from llmclient.cost_tracker import TrackedStreamWrapper, track_costs, track_costs_iter
 from llmclient.exceptions import JSONSchemaValidationError
 from llmclient.prompts import default_system_prompt
 from llmclient.rate_limiter import GLOBAL_LIMITER
@@ -258,7 +259,6 @@ async def run_prompt(
             ]
             return await self._run_chat(messages, callbacks, name, system_prompt)
         if self.llm_type == "completion":
-
             return await self._run_completion(
                 prompt, data, callbacks, name, system_prompt
             )
@@ -411,7 +411,6 @@ def rate_limited(
     async def wrapper(
         self: LLMModelOrChild, *args: Any, **kwargs: Any
     ) -> Chunk | AsyncIterator[Chunk] | AsyncIterator[LLMModelOrChild]:
-
         if not hasattr(self, "check_rate_limit"):
             raise NotImplementedError(
                 f"Model {self.name} must have a `check_rate_limit` method."
@@ -566,7 +565,9 @@ async def check_rate_limit(self, token_count: float, **kwargs) -> None:
 
     @rate_limited
     async def acomplete(self, prompt: str) -> Chunk:  # type: ignore[override]
-        response = await self.router.atext_completion(model=self.name, prompt=prompt)
+        response = await track_costs(self.router.atext_completion)(
+            model=self.name, prompt=prompt
+        )
         return Chunk(
             text=response.choices[0].text,
             prompt_tokens=response.usage.prompt_tokens,
@@ -577,7 +578,7 @@
     async def acomplete_iter(  # type: ignore[override]
         self, prompt: str
     ) -> AsyncIterable[Chunk]:
-        completion = await self.router.atext_completion(
+        completion = await track_costs_iter(self.router.atext_completion)(
             model=self.name,
             prompt=prompt,
             stream=True,
@@ -595,7 +596,7 @@
     @rate_limited
     async def achat(self, messages: list[Message]) -> Chunk:  # type: ignore[override]
         prompts = [m.model_dump(by_alias=True) for m in messages if m.content]
-        response = await self.router.acompletion(self.name, prompts)
+        response = await track_costs(self.router.acompletion)(self.name, prompts)
         return Chunk(
             text=cast(litellm.Choices, response.choices[0]).message.content,
             prompt_tokens=response.usage.prompt_tokens,  # type: ignore[attr-defined]
@@ -607,7 +608,7 @@ async def achat_iter(  # type: ignore[override]
         self, messages: list[Message]
     ) -> AsyncIterable[Chunk]:
         prompts = [m.model_dump(by_alias=True) for m in messages if m.content]
-        completion = await self.router.acompletion(
+        completion = await track_costs_iter(self.router.acompletion)(
             self.name,
             prompts,
             stream=True,
@@ -642,7 +643,7 @@ async def select_tool(
     ) -> ToolRequestMessage:
         """Shim to aviary.core.ToolSelector that supports tool schemae."""
         tool_selector = ToolSelector(
-            model_name=self.name, acompletion=self.router.acompletion
+            model_name=self.name, acompletion=track_costs(self.router.acompletion)
         )
         return await tool_selector(*selection_args, **selection_kwargs)
 
@@ -688,22 +689,21 @@ def set_model_name(self) -> Self:
     async def achat(
         self, messages: Iterable[Message], **kwargs
     ) -> litellm.ModelResponse:
-        return await litellm.acompletion(
+        return await track_costs(litellm.acompletion)(
             messages=[m.model_dump(by_alias=True) for m in messages],
             **(self.config | kwargs),
         )
 
-    async def achat_iter(self, messages: Iterable[Message], **kwargs) -> AsyncGenerator:
-        return cast(
-            AsyncGenerator,
-            await litellm.acompletion(
-                messages=[m.model_dump(by_alias=True) for m in messages],
-                stream=True,
-                stream_options={
-                    "include_usage": True,  # Included to get prompt token counts
-                },
-                **(self.config | kwargs),
-            ),
+    async def achat_iter(
+        self, messages: Iterable[Message], **kwargs
+    ) -> TrackedStreamWrapper:
+        return await track_costs_iter(litellm.acompletion)(
+            messages=[m.model_dump(by_alias=True) for m in messages],
+            stream=True,
+            stream_options={
+                "include_usage": True,  # Included to get prompt token counts
+            },
+            **(self.config | kwargs),
         )
 
     # SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -816,7 +816,7 @@ async def call(  # noqa: C901, PLR0915
         results: list[LLMResult] = []
 
         if callbacks is None:
-            completion: litellm.ModelResponse = await self.achat(prompt, **chat_kwargs)
+            completion = await self.achat(prompt, **chat_kwargs)
             if output_type is not None:
                 validate_json_completion(completion, output_type)
 
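A sketch of the call-site impact (not part of the patch): the public methods keep their signatures, and the one visible change is that `MultipleCompletionLLMModel.achat_iter` now returns a `TrackedStreamWrapper` rather than a bare async generator, which still supports `async for`.

```
from aviary.core import Message

from llmclient import cost_tracking_ctx
from llmclient.llms import MultipleCompletionLLMModel


async def chat_with_tracking() -> None:
    async with cost_tracking_ctx():
        model = MultipleCompletionLLMModel(name="gpt-4o-mini")

        # Non-streaming: recorded via track_costs inside achat.
        await model.achat(
            messages=[Message(content="What are three things I should do today?")]
        )

        # Streaming: achat_iter returns a TrackedStreamWrapper; iterate as before,
        # and the raw litellm stream stays available at .stream if needed.
        async for _ in await model.achat_iter(
            messages=[Message(content="What are three things I should do today?")]
        ):
            pass
```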
diff --git a/pyproject.toml b/pyproject.toml
index 1018308..d879bf2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [
 ]
 dependencies = [
     "coredis",
-    "fhaviary>=0.8.2", # For core namespace
+    "fhaviary>=0.14.0", # For multi-image support
     "limits",
     "pydantic~=2.0,>=2.10.1,<2.10.2",
     "tiktoken>=0.4.0",
diff --git a/tests/conftest.py b/tests/conftest.py
index b0f5177..f99051d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,6 +3,7 @@
 import logging
 import shutil
 from collections.abc import Iterator
+from enum import StrEnum
 from pathlib import Path
 from typing import Any
 
@@ -73,3 +74,10 @@ def fixture_reset_log_levels(caplog) -> Iterator[None]:
         logger = logging.getLogger(name)
         logger.setLevel(logging.NOTSET)
         logger.propagate = True
+
+
+class CILLMModelNames(StrEnum):
+    """Models to use for generic CI testing."""
+
+    ANTHROPIC = "claude-3-haiku-20240307"  # Cheap and not Anthropic's cutting edge
+    OPENAI = "gpt-4o-mini-2024-07-18"  # Cheap and not OpenAI's cutting edge
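An illustrative aside (hypothetical test, not from the patch): because `CILLMModelNames` is a `StrEnum`, its members behave as plain strings, so tests can parametrize over cheap CI models without hard-coding names twice.

```
import pytest

from llmclient.llms import MultipleCompletionLLMModel

from .conftest import CILLMModelNames


@pytest.mark.parametrize("model_name", [CILLMModelNames.OPENAI.value])
def test_cheap_ci_model_name(model_name: str) -> None:
    # The enum member's value is the literal model identifier string.
    model = MultipleCompletionLLMModel(name=model_name)
    assert model.name == "gpt-4o-mini-2024-07-18"
```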
diff --git a/tests/test_cost_tracking.py b/tests/test_cost_tracking.py
new file mode 100644
index 0000000..e15a6e9
--- /dev/null
+++ b/tests/test_cost_tracking.py
@@ -0,0 +1,209 @@
+from contextlib import contextmanager
+from typing import Any
+
+import numpy as np
+import pytest
+from aviary.core import Message
+
+from llmclient import cost_tracking_ctx
+from llmclient.cost_tracker import GLOBAL_COST_TRACKER
+from llmclient.embeddings import LiteLLMEmbeddingModel
+from llmclient.llms import LiteLLMModel, MultipleCompletionLLMModel
+from llmclient.types import LLMResult
+
+from .conftest import VCR_DEFAULT_MATCH_ON, CILLMModelNames
+
+
+@contextmanager
+def assert_costs_increased():
+    """All tests in this file should increase accumulated costs."""
+    initial_cost = GLOBAL_COST_TRACKER.lifetime_cost_usd
+    yield
+    assert GLOBAL_COST_TRACKER.lifetime_cost_usd > initial_cost
+
+
+class TestLiteLLMEmbeddingCosts:
+    @pytest.mark.asyncio
+    async def test_embed_documents(self):
+        stub_texts = ["test1", "test2"]
+        with assert_costs_increased():
+            async with cost_tracking_ctx():
+                model = LiteLLMEmbeddingModel(name="text-embedding-3-small", ndim=8)
+                await model.embed_documents(stub_texts)
+
+
+class TestLiteLLMModel:
+    @pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON, "body"])
+    @pytest.mark.parametrize(
+        "config",
+        [
+            pytest.param(
+                {
+                    "model_name": "gpt-4o-mini",
+                    "model_list": [
+                        {
+                            "model_name": "gpt-4o-mini",
+                            "litellm_params": {
+                                "model": "gpt-4o-mini",
+                                "temperature": 0,
+                                "max_tokens": 56,
+                            },
+                        }
+                    ],
+                },
+                id="chat-model",
+            ),
+            pytest.param(
+                {
+                    "model_name": "gpt-3.5-turbo-instruct",
+                    "model_list": [
+                        {
+                            "model_name": "gpt-3.5-turbo-instruct",
+                            "litellm_params": {
+                                "model": "gpt-3.5-turbo-instruct",
+                                "temperature": 0,
+                                "max_tokens": 56,
+                            },
+                        }
+                    ],
+                },
+                id="completion-model",
+            ),
+        ],
+    )
+    @pytest.mark.asyncio
+    async def test_call(self, config: dict[str, Any]) -> None:
+        with assert_costs_increased():
+            async with cost_tracking_ctx():
+                llm = LiteLLMModel(name=config["model_name"], config=config)
+                messages = [
+                    Message(role="system", content="Respond with single words."),
+                    Message(
+                        role="user", content="What is the meaning of the universe?"
+                    ),
+                ]
+                await llm.call(messages)
+
+    @pytest.mark.asyncio
+    async def test_call_w_figure(self) -> None:
+        async def ac(x) -> None:
+            pass
+
+        async with cost_tracking_ctx():
+            with assert_costs_increased():
+                llm = LiteLLMModel(name="gpt-4o")
+                image = np.zeros((32, 32, 3), dtype=np.uint8)
+                image[:] = [255, 0, 0]
+                messages = [
+                    Message(
+                        role="system",
+                        content="You are a detective who investigate colors",
+                    ),
+                    Message.create_message(
+                        role="user",
+                        text="What color is this square? Show me your chain of reasoning.",
+                        images=image,
+                    ),
+                ]  # TODO: It's not decoding the image. It's trying to guess the color from the encoded image string.
+                await llm.call(messages)
+
+            with assert_costs_increased():
+                await llm.call(messages, [ac])
+
+    @pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON, "body"])
+    @pytest.mark.parametrize(
+        "config",
+        [
+            pytest.param(
+                {
+                    "model_list": [
+                        {
+                            "model_name": "gpt-4o-mini",
+                            "litellm_params": {
+                                "model": "gpt-4o-mini",
+                                "temperature": 0,
+                                "max_tokens": 56,
+                            },
+                        }
+                    ]
+                },
+                id="with-router",
+            ),
+            pytest.param(
+                {
+                    "pass_through_router": True,
+                    "router_kwargs": {"temperature": 0, "max_tokens": 56},
+                },
+                id="without-router",
+            ),
+        ],
+    )
+    @pytest.mark.asyncio
+    async def test_run_prompt(self, config: dict[str, Any]) -> None:
+        async with cost_tracking_ctx():
+            with assert_costs_increased():
+                llm = LiteLLMModel(name="gpt-4o-mini", config=config)
+
+                outputs = []
+
+                def accum(x) -> None:
+                    outputs.append(x)
+
+                await llm.run_prompt(
+                    prompt="The {animal} says",
+                    data={"animal": "duck"},
+                    system_prompt=None,
+                    callbacks=[accum],
+                )
+
+
+class TestMultipleCompletionLLMModel:
+    async def call_model(
+        self, model: MultipleCompletionLLMModel, *args, **kwargs
+    ) -> list[LLMResult]:
+        return await model.call(*args, **kwargs)
+
+    @pytest.mark.parametrize(
+        "model_name", ["gpt-3.5-turbo", CILLMModelNames.ANTHROPIC.value]
+    )
+    @pytest.mark.asyncio
+    async def test_achat(self, model_name: str) -> None:
+        async with cost_tracking_ctx():
+            with assert_costs_increased():
+                model = MultipleCompletionLLMModel(name=model_name)
+                await model.achat(
+                    messages=[
+                        Message(content="What are three things I should do today?"),
+                    ]
+                )
+
+            with assert_costs_increased():
+                async for _ in await model.achat_iter(
+                    messages=[
+                        Message(content="What are three things I should do today?"),
+                    ]
+                ):
+                    pass
+
+    @pytest.mark.parametrize("model_name", [CILLMModelNames.OPENAI.value])
+    @pytest.mark.asyncio
+    @pytest.mark.vcr
+    async def test_text_image_message(self, model_name: str) -> None:
+        async with cost_tracking_ctx():
+            with assert_costs_increased():
+                model = MultipleCompletionLLMModel(name=model_name, config={"n": 2})
+
+                # An RGB image of a red square
+                image = np.zeros((32, 32, 3), dtype=np.uint8)
+                # (255 red, 0 green, 0 blue) is maximum red in RGB
+                image[:] = [255, 0, 0]
+
+                await self.call_model(
+                    model,
+                    messages=[
+                        Message.create_message(
+                            text="What color is this square? Respond only with the color name.",
+                            images=image,
+                        )
+                    ],
+                )
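Beyond the scoped context manager exercised in these tests, the tracker can also be tuned or switched imperatively; a small sketch (the threshold value is arbitrary):

```
from llmclient import enable_cost_tracking
from llmclient.cost_tracker import set_reporting_threshold

# Emit the cumulative-cost INFO log every $0.10 of spend instead of the default $1.00.
set_reporting_threshold(0.10)

# Start recording in the current context, make some tracked calls, then stop again.
enable_cost_tracking()
...
enable_cost_tracking(enabled=False)
```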
diff --git a/tests/test_llms.py b/tests/test_llms.py
index b741cf3..ded4860 100644
--- a/tests/test_llms.py
+++ b/tests/test_llms.py
@@ -22,7 +22,6 @@
 
 
 class TestLiteLLMModel:
-
     @pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON, "body"])
     @pytest.mark.parametrize(
         "config",
@@ -89,7 +88,7 @@ async def test_call_w_figure(self) -> None:
                 Message.create_message(
                     role="user",
                     text="What color is this square? Show me your chain of reasoning.",
-                    image=image,
+                    images=image,
                 ),
             ]  # TODO: It's not decoding the image. It's trying to guess the color from the encoded image string.
             results = await llm.call(messages)
@@ -412,7 +411,7 @@ async def test_text_image_message(self, model_name: str) -> None:
             messages=[
                 Message.create_message(
                     text="What color is this square? Respond only with the color name.",
-                    image=image,
+                    images=image,
                 )
             ],
         )
diff --git a/uv.lock b/uv.lock
index bb99f48..6a093ef 100644
--- a/uv.lock
+++ b/uv.lock
@@ -567,7 +567,7 @@ wheels = [
 
 [[package]]
 name = "fh-llm-client"
-version = "0.0.4.dev6+g5a2deb7.d20241210"
+version = "0.0.8.dev2+gf8125c0.d20241231"
 source = { editable = "." }
 dependencies = [
     { name = "coredis" },
@@ -652,7 +652,7 @@ dev = [
 requires-dist = [
     { name = "coredis" },
     { name = "fh-llm-client", extras = ["local"], marker = "extra == 'dev'" },
-    { name = "fhaviary", specifier = ">=0.8.2" },
+    { name = "fhaviary", specifier = ">=0.14.0" },
     { name = "fhaviary", extras = ["xml"], marker = "extra == 'dev'" },
     { name = "ipython", marker = "extra == 'dev'", specifier = ">=8" },
     { name = "limits" },
@@ -688,16 +688,16 @@ dev = [{ name = "fh-llm-client", extras = ["dev"] }]
 
 [[package]]
 name = "fhaviary"
-version = "0.11.0"
+version = "0.14.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "docstring-parser" },
     { name = "httpx" },
     { name = "pydantic" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/48/11/5382ad08ed37f7ddb71f74c08e0d052e9bfc710caba370a02e5825f36d7d/fhaviary-0.11.0.tar.gz", hash = "sha256:2d69eac9c8736c910fcf93a058db876d55a34132464d7b2ff485310eb48e1276", size = 242294 }
+sdist = { url = "https://files.pythonhosted.org/packages/00/72/df04af3d1135b2bfb6af59f327c7759358a82c2ca7dad01156d93cddfdb1/fhaviary-0.14.0.tar.gz", hash = "sha256:99751f6484e28d33b585cd47c95142a589e814571509390e213e336cdd8aab7b", size = 311279 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/7f/6d/8382fe777f37c9183f167e5e2de34703340c8196b5adbdf2ab88211910a6/fhaviary-0.11.0-py3-none-any.whl", hash = "sha256:caff1b9d7dd8a0923f4acbd1ce9711a77672af4cbb9b93ecc6c77d465c4a2cad", size = 48921 },
+    { url = "https://files.pythonhosted.org/packages/bb/a1/dfb72d03d72606c1ffc0591bc73b7304dff5c4ca8606fa1e9e8cfbbce1b7/fhaviary-0.14.0-py3-none-any.whl", hash = "sha256:ce1b6950b9719cd321c5afe13edefb429319f54c3f5fb18396a57b22ad465822", size = 52312 },
 ]
 
 [package.optional-dependencies]