Commit 0c86fc8

Add usage tracking

sidjha1 committed Nov 3, 2024
1 parent b093b6c commit 0c86fc8

Showing 3 changed files with 49 additions and 3 deletions.
16 changes: 15 additions & 1 deletion .github/tests/lm_tests.py
@@ -9,14 +9,28 @@
 lotus.logger.setLevel("DEBUG")
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def setup_gpt_models():
     # Setup GPT models
     gpt_4o_mini = LM(model="gpt-4o-mini")
     gpt_4o = LM(model="gpt-4o")
     return gpt_4o_mini, gpt_4o
 
 
+@pytest.fixture(autouse=True)
+def print_usage_after_each_test(setup_gpt_models):
+    yield  # this runs the test
+    gpt_4o_mini, gpt_4o = setup_gpt_models
+    print("\nUsage stats for gpt-4o-mini after test:")
+    gpt_4o_mini.print_total_usage()
+    print("\nUsage stats for gpt-4o after test:")
+    gpt_4o.print_total_usage()
+
+    # Reset stats
+    gpt_4o_mini.reset_stats()
+    gpt_4o.reset_stats()
+
+
 def test_filter_operation(setup_gpt_models):
     gpt_4o_mini, _ = setup_gpt_models
     lotus.settings.configure(lm=gpt_4o_mini)
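Two pytest features carry the test changes above: scope="session" builds the two LM instances once for the entire test run rather than once per test, and autouse=True plus yield turns the second fixture into teardown that runs after every test without being requested. A minimal standalone sketch of that pattern, with hypothetical fixture and test names:

import pytest


@pytest.fixture(scope="session")
def shared_resource():
    # Built once and reused by every test in the session.
    return object()


@pytest.fixture(autouse=True)
def report_after_each_test(shared_resource):
    yield  # the test body runs here
    # Teardown: runs after every test, even tests that never
    # request this fixture, because autouse=True applies it globally.
    print("per-test teardown ran")


def test_example(shared_resource):
    assert shared_resource is not None
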
26 changes: 24 additions & 2 deletions lotus/models/lm.py
@@ -1,12 +1,12 @@
 from typing import Any
 
 import numpy as np
-from litellm import batch_completion
+from litellm import batch_completion, completion_cost
 from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse
 from litellm.utils import token_counter
 from tokenizers import Tokenizer
 
-from lotus.types import LMOutput, LogprobsForCascade, LogprobsForFilterCascade
+from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade
 
 
 class LM:
@@ -25,6 +25,8 @@ def __init__(
         self.tokenizer = tokenizer
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
 
+        self.stats: LMStats = LMStats()
+
     def __call__(
         self, messages: list[dict[str, str]] | list[list[dict[str, str]]], **kwargs: dict[str, Any]
     ) -> LMOutput:
@@ -42,8 +44,17 @@ def __call__(
             [self._get_top_choice_logprobs(resp) for resp in responses] if kwargs_for_batch.get("logprobs") else None
         )
 
+        for resp in responses:
+            self._update_stats(resp)
+
         return LMOutput(outputs=outputs, logprobs=logprobs)
 
+    def _update_stats(self, response: ModelResponse):
+        self.stats.total_usage.prompt_tokens += response.usage.prompt_tokens
+        self.stats.total_usage.completion_tokens += response.usage.completion_tokens
+        self.stats.total_usage.total_tokens += response.usage.total_tokens
+        self.stats.total_usage.total_cost += completion_cost(completion_response=response)
+
     def _format_batch_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
         all_kwargs = {**self.kwargs, **kwargs}
         if all_kwargs.get("logprobs", False):
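completion_cost is litellm's price lookup: given a finished ModelResponse, it maps the usage token counts to the model's per-token pricing and returns a dollar amount, which _update_stats accumulates into total_cost. A quick standalone sketch of the call, assuming an API key is configured in the environment and using a placeholder prompt:

from litellm import completion, completion_cost

response = completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hi."}],
)
# Price of this single response, computed from its token usage.
print(f"${completion_cost(completion_response=response):.6f}")
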
@@ -121,3 +132,14 @@ def count_tokens(self, messages: list[dict[str, str]] | str) -> int:
             model=self.model,
             messages=messages,
         )
+
+    def print_total_usage(self):
+        print(f"Total cost: ${self.stats.total_usage.total_cost:.6f}")
+        print(f"Total prompt tokens: {self.stats.total_usage.prompt_tokens}")
+        print(f"Total completion tokens: {self.stats.total_usage.completion_tokens}")
+        print(f"Total tokens: {self.stats.total_usage.total_tokens}")
+
+    def reset_stats(self):
+        self.stats = LMStats(
+            total_usage=LMStats.TotalUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0, total_cost=0.0)
+        )
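Together these methods give LM a simple metering loop: run some calls, read the running totals, reset before the next experiment. A minimal usage sketch, assuming the import path follows the lotus/models/lm.py file layout and an API key is set in the environment:

from lotus.models.lm import LM  # import path assumed from the file layout

lm = LM(model="gpt-4o-mini")
lm([[{"role": "user", "content": "What is 2 + 2?"}]])  # a batch of one conversation

lm.print_total_usage()  # cumulative tokens and cost since the last reset
lm.reset_stats()        # zero the counters for the next run

Note that reset_stats spells out an all-zero TotalUsage, although the defaults in lotus/types.py make a bare LMStats() equivalent.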
10 changes: 10 additions & 0 deletions lotus/types.py
@@ -69,3 +69,13 @@ class LogprobsForCascade(BaseModel):
 
 class LogprobsForFilterCascade(LogprobsForCascade):
     true_probs: list[float]
+
+
+class LMStats(BaseModel):
+    class TotalUsage(BaseModel):
+        prompt_tokens: int = 0
+        completion_tokens: int = 0
+        total_tokens: int = 0
+        total_cost: float = 0.0
+
+    total_usage: TotalUsage = TotalUsage()
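
LMStats nests its counters in an inner TotalUsage model so every field starts at zero, and pydantic deep-copies mutable defaults, so separate LM instances never share a counter object. A tiny sketch of that behavior:

from lotus.types import LMStats

a, b = LMStats(), LMStats()
a.total_usage.prompt_tokens += 100  # mutate one instance...
print(b.total_usage.prompt_tokens)  # ...the other still reads 0
print(f"${a.total_usage.total_cost:.6f}")  # floats default to 0.0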
