Skip to content

Commit

Permalink
Implement get_total_tokens; Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bonk1t committed Dec 5, 2024
1 parent a640b3f commit 733af7e
Show file tree
Hide file tree
Showing 12 changed files with 214 additions and 92 deletions.
11 changes: 10 additions & 1 deletion agency_swarm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from .agency import Agency
from .agents import Agent
from .tools import BaseTool
from .util import get_openai_client, llm_validator, set_openai_client, set_openai_key
from .util import (
get_openai_client,
get_usage_tracker,
llm_validator,
set_openai_client,
set_openai_key,
set_usage_tracker,
)
from .util.streaming import (
AgencyEventHandler,
AgencyEventHandlerWithTracking,
Expand All @@ -17,4 +24,6 @@
"set_openai_client",
"set_openai_key",
"llm_validator",
"set_usage_tracker",
"get_usage_tracker",
]
10 changes: 8 additions & 2 deletions agency_swarm/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from .cli.create_agent_template import create_agent_template
from .cli.import_agent import import_agent
from .files import get_file_purpose, get_tools
from .oai import get_openai_client, set_openai_client, set_openai_key
from .usage_tracking import AbstractTracker, LangfuseUsageTracker, SQLiteUsageTracker
from .oai import (
get_openai_client,
get_usage_tracker,
set_openai_client,
set_openai_key,
set_usage_tracker,
)
from .tracking import AbstractTracker, LangfuseUsageTracker, SQLiteUsageTracker
from .validators import llm_validator
4 changes: 2 additions & 2 deletions agency_swarm/util/oai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import httpx
from dotenv import load_dotenv

from agency_swarm.util.usage_tracking.tracker_factory import get_tracker
from agency_swarm.util.tracking.tracker_factory import get_tracker_by_name

load_dotenv()

Expand Down Expand Up @@ -32,7 +32,7 @@ def get_usage_tracker():
Returns:
AbstractTracker: The current usage tracker instance.
"""
return get_tracker(_usage_tracker)
return get_tracker_by_name(_usage_tracker)


def get_openai_client():
Expand Down
2 changes: 1 addition & 1 deletion agency_swarm/util/streaming/agency_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from openai.types.beta.threads.runs.run_step import RunStep

from agency_swarm.util.oai import get_usage_tracker
from agency_swarm.util.usage_tracking.abstract_tracker import AbstractTracker
from agency_swarm.util.tracking.abstract_tracker import AbstractTracker


class AgencyEventHandler(AssistantEventHandler, ABC):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from typing import Dict

from openai.types.beta.threads.runs.run_step import Usage

Expand All @@ -17,39 +16,37 @@ def track_usage(
Track token usage.
Args:
usage: Usage object containing token usage statistics
assistant_id: ID of the assistant that generated the usage
thread_id: ID of the thread that generated the usage
model: Model that generated the usage
usage (Usage): Object containing token usage statistics.
assistant_id (str): ID of the assistant that generated the usage.
thread_id (str): ID of the thread that generated the usage.
model (str): Model that generated the usage.
"""
pass

@abstractmethod
def get_total_tokens(self) -> Dict[str, int]:
def get_total_tokens(self) -> Usage:
"""
Get total token usage statistics.
Get total token usage statistics accumulated so far.
Returns:
Dictionary containing total token usage statistics
Usage: An object containing cumulative prompt, completion, and total tokens.
"""
pass

@abstractmethod
def close(self) -> None:
"""
Close the tracker. Called automatically when the tracker is garbage collected.
Close the tracker and release resources, if any.
"""
pass

def __del__(self):
self.close()

@classmethod
@abstractmethod
def get_observe_decorator(cls):
"""
Get the observe decorator for the tracker. Will be applied to the get_completion function.
Returns:
The observe decorator
Callable: The observe decorator.
"""
pass
60 changes: 60 additions & 0 deletions agency_swarm/util/tracking/langfuse_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from langfuse import Langfuse
from langfuse.decorators import observe
from openai.types.beta.threads.runs.run_step import Usage

from agency_swarm.util.tracking.abstract_tracker import AbstractTracker


class LangfuseUsageTracker(AbstractTracker):
def __init__(self):
self.client = Langfuse()

def track_usage(
self, usage: Usage, assistant_id: str, thread_id: str, model: str
) -> None:
"""
Track usage by recording a generation event in Langfuse.
"""
self.client.generation(
model=model,
metadata={
"assistant_id": assistant_id,
"thread_id": thread_id,
},
usage={
"input": usage.prompt_tokens,
"output": usage.completion_tokens,
"total": usage.total_tokens,
"unit": "TOKENS",
},
)

def get_total_tokens(self) -> Usage:
"""
Retrieve total usage from Langfuse by summing over all recorded generations.
"""
generations = self.client.fetch_observations(type="GENERATION").data

prompt_tokens = 0
completion_tokens = 0
total_tokens = 0

for generation in generations:
if generation.usage:
prompt_tokens += generation.usage.input or 0
completion_tokens += generation.usage.output or 0
total_tokens += generation.usage.total or 0

return Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)

def close(self) -> None:
# Nothing to close
pass

@classmethod
def get_observe_decorator(cls):
return observe
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,26 @@

from openai.types.beta.threads.runs.run_step import Usage

from agency_swarm.util.usage_tracking.abstract_tracker import AbstractTracker
from agency_swarm.util.tracking.abstract_tracker import AbstractTracker


class SQLiteUsageTracker(AbstractTracker):
def __init__(self, db_path: str = "token_usage.db"):
"""
Initializes a SQLite-based usage tracker.
Args:
db_path (str): Path to the SQLite database file.
"""
self.conn = sqlite3.connect(db_path, check_same_thread=False)
self.lock = threading.Lock()
self._create_table()
self._closed = False

def _create_table(self):
def _create_table(self) -> None:
with self.conn:
self.conn.execute("""
self.conn.execute(
"""
CREATE TABLE IF NOT EXISTS token_usage (
id INTEGER PRIMARY KEY AUTOINCREMENT,
prompt_tokens INTEGER,
Expand All @@ -25,18 +33,21 @@ def _create_table(self):
model TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
""")
"""
)

def track_usage(
self, usage: Usage, assistant_id: str, thread_id: str, model: str
) -> None:
with self.lock:
if self._closed:
raise RuntimeError("Attempting to track usage on a closed tracker.")
with self.conn:
self.conn.execute(
"""
INSERT INTO token_usage (prompt_tokens, completion_tokens, total_tokens, assistant_id, thread_id, model)
VALUES (?, ?, ?, ?, ?, ?)
""",
""",
(
usage.prompt_tokens,
usage.completion_tokens,
Expand All @@ -49,11 +60,15 @@ def track_usage(

def get_total_tokens(self) -> Usage:
with self.lock:
if self._closed:
return Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
cursor = self.conn.cursor()
cursor.execute("""
cursor.execute(
"""
SELECT SUM(prompt_tokens), SUM(completion_tokens), SUM(total_tokens)
FROM token_usage
""")
"""
)
prompt, completion, total = cursor.fetchone()
return Usage(
prompt_tokens=prompt or 0,
Expand All @@ -62,9 +77,12 @@ def get_total_tokens(self) -> Usage:
)

def close(self) -> None:
self.conn.close()
with self.lock:
if not self._closed:
self.conn.close()
self._closed = True

@classmethod
def get_observe_decorator(cls):
# Return a noop decorator (decorator tracking is supported by other providers)
# Return a no-op decorator.
return lambda f: f
10 changes: 10 additions & 0 deletions agency_swarm/util/tracking/tracker_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Literal

from agency_swarm.util.tracking.langfuse_tracker import LangfuseUsageTracker
from agency_swarm.util.tracking.sqlite_tracker import SQLiteUsageTracker


def get_tracker_by_name(tracker_type: Literal["sqlite", "langfuse"] = "sqlite"):
if tracker_type == "langfuse":
return LangfuseUsageTracker()
return SQLiteUsageTracker()
36 changes: 0 additions & 36 deletions agency_swarm/util/usage_tracking/langfuse_tracker.py

This file was deleted.

10 changes: 0 additions & 10 deletions agency_swarm/util/usage_tracking/tracker_factory.py

This file was deleted.

Loading

0 comments on commit 733af7e

Please sign in to comment.