diff --git a/agency_swarm/__init__.py b/agency_swarm/__init__.py index 4f34e47d..837a8afb 100644 --- a/agency_swarm/__init__.py +++ b/agency_swarm/__init__.py @@ -1,7 +1,14 @@ from .agency import Agency from .agents import Agent from .tools import BaseTool -from .util import get_openai_client, llm_validator, set_openai_client, set_openai_key +from .util import ( + get_openai_client, + get_usage_tracker, + llm_validator, + set_openai_client, + set_openai_key, + set_usage_tracker, +) from .util.streaming import ( AgencyEventHandler, AgencyEventHandlerWithTracking, @@ -17,4 +24,6 @@ "set_openai_client", "set_openai_key", "llm_validator", + "set_usage_tracker", + "get_usage_tracker", ] diff --git a/agency_swarm/util/__init__.py b/agency_swarm/util/__init__.py index c8b2762c..5b993bd5 100644 --- a/agency_swarm/util/__init__.py +++ b/agency_swarm/util/__init__.py @@ -1,6 +1,12 @@ from .cli.create_agent_template import create_agent_template from .cli.import_agent import import_agent from .files import get_file_purpose, get_tools -from .oai import get_openai_client, set_openai_client, set_openai_key -from .usage_tracking import AbstractTracker, LangfuseUsageTracker, SQLiteUsageTracker +from .oai import ( + get_openai_client, + get_usage_tracker, + set_openai_client, + set_openai_key, + set_usage_tracker, +) +from .tracking import AbstractTracker, LangfuseUsageTracker, SQLiteUsageTracker from .validators import llm_validator diff --git a/agency_swarm/util/oai.py b/agency_swarm/util/oai.py index 3570d5a6..e0321df1 100644 --- a/agency_swarm/util/oai.py +++ b/agency_swarm/util/oai.py @@ -4,7 +4,7 @@ import httpx from dotenv import load_dotenv -from agency_swarm.util.usage_tracking.tracker_factory import get_tracker +from agency_swarm.util.tracking.tracker_factory import get_tracker_by_name load_dotenv() @@ -32,7 +32,7 @@ def get_usage_tracker(): Returns: AbstractTracker: The current usage tracker instance. """ - return get_tracker(_usage_tracker) + return get_tracker_by_name(_usage_tracker) def get_openai_client(): diff --git a/agency_swarm/util/streaming/agency_event_handler.py b/agency_swarm/util/streaming/agency_event_handler.py index 5f6c35b2..f8f47ce9 100644 --- a/agency_swarm/util/streaming/agency_event_handler.py +++ b/agency_swarm/util/streaming/agency_event_handler.py @@ -5,7 +5,7 @@ from openai.types.beta.threads.runs.run_step import RunStep from agency_swarm.util.oai import get_usage_tracker -from agency_swarm.util.usage_tracking.abstract_tracker import AbstractTracker +from agency_swarm.util.tracking.abstract_tracker import AbstractTracker class AgencyEventHandler(AssistantEventHandler, ABC): diff --git a/agency_swarm/util/usage_tracking/__init__.py b/agency_swarm/util/tracking/__init__.py similarity index 100% rename from agency_swarm/util/usage_tracking/__init__.py rename to agency_swarm/util/tracking/__init__.py diff --git a/agency_swarm/util/usage_tracking/abstract_tracker.py b/agency_swarm/util/tracking/abstract_tracker.py similarity index 56% rename from agency_swarm/util/usage_tracking/abstract_tracker.py rename to agency_swarm/util/tracking/abstract_tracker.py index 177989e9..7eaf02ee 100644 --- a/agency_swarm/util/usage_tracking/abstract_tracker.py +++ b/agency_swarm/util/tracking/abstract_tracker.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Dict from openai.types.beta.threads.runs.run_step import Usage @@ -17,39 +16,37 @@ def track_usage( Track token usage. Args: - usage: Usage object containing token usage statistics - assistant_id: ID of the assistant that generated the usage - thread_id: ID of the thread that generated the usage - model: Model that generated the usage + usage (Usage): Object containing token usage statistics. + assistant_id (str): ID of the assistant that generated the usage. + thread_id (str): ID of the thread that generated the usage. + model (str): Model that generated the usage. """ pass @abstractmethod - def get_total_tokens(self) -> Dict[str, int]: + def get_total_tokens(self) -> Usage: """ - Get total token usage statistics. + Get total token usage statistics accumulated so far. Returns: - Dictionary containing total token usage statistics + Usage: An object containing cumulative prompt, completion, and total tokens. """ pass @abstractmethod def close(self) -> None: """ - Close the tracker. Called automatically when the tracker is garbage collected. + Close the tracker and release resources, if any. """ pass - def __del__(self): - self.close() - @classmethod + @abstractmethod def get_observe_decorator(cls): """ Get the observe decorator for the tracker. Will be applied to the get_completion function. Returns: - The observe decorator + Callable: The observe decorator. """ pass diff --git a/agency_swarm/util/tracking/langfuse_tracker.py b/agency_swarm/util/tracking/langfuse_tracker.py new file mode 100644 index 00000000..e17fcd5d --- /dev/null +++ b/agency_swarm/util/tracking/langfuse_tracker.py @@ -0,0 +1,60 @@ +from langfuse import Langfuse +from langfuse.decorators import observe +from openai.types.beta.threads.runs.run_step import Usage + +from agency_swarm.util.tracking.abstract_tracker import AbstractTracker + + +class LangfuseUsageTracker(AbstractTracker): + def __init__(self): + self.client = Langfuse() + + def track_usage( + self, usage: Usage, assistant_id: str, thread_id: str, model: str + ) -> None: + """ + Track usage by recording a generation event in Langfuse. + """ + self.client.generation( + model=model, + metadata={ + "assistant_id": assistant_id, + "thread_id": thread_id, + }, + usage={ + "input": usage.prompt_tokens, + "output": usage.completion_tokens, + "total": usage.total_tokens, + "unit": "TOKENS", + }, + ) + + def get_total_tokens(self) -> Usage: + """ + Retrieve total usage from Langfuse by summing over all recorded generations. + """ + generations = self.client.fetch_observations(type="GENERATION").data + + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + + for generation in generations: + if generation.usage: + prompt_tokens += generation.usage.input or 0 + completion_tokens += generation.usage.output or 0 + total_tokens += generation.usage.total or 0 + + return Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + + def close(self) -> None: + # Nothing to close + pass + + @classmethod + def get_observe_decorator(cls): + return observe diff --git a/agency_swarm/util/usage_tracking/sqlite_tracker.py b/agency_swarm/util/tracking/sqlite_tracker.py similarity index 70% rename from agency_swarm/util/usage_tracking/sqlite_tracker.py rename to agency_swarm/util/tracking/sqlite_tracker.py index 6e63ee15..4bac5243 100644 --- a/agency_swarm/util/usage_tracking/sqlite_tracker.py +++ b/agency_swarm/util/tracking/sqlite_tracker.py @@ -3,18 +3,26 @@ from openai.types.beta.threads.runs.run_step import Usage -from agency_swarm.util.usage_tracking.abstract_tracker import AbstractTracker +from agency_swarm.util.tracking.abstract_tracker import AbstractTracker class SQLiteUsageTracker(AbstractTracker): def __init__(self, db_path: str = "token_usage.db"): + """ + Initializes a SQLite-based usage tracker. + + Args: + db_path (str): Path to the SQLite database file. + """ self.conn = sqlite3.connect(db_path, check_same_thread=False) self.lock = threading.Lock() self._create_table() + self._closed = False - def _create_table(self): + def _create_table(self) -> None: with self.conn: - self.conn.execute(""" + self.conn.execute( + """ CREATE TABLE IF NOT EXISTS token_usage ( id INTEGER PRIMARY KEY AUTOINCREMENT, prompt_tokens INTEGER, @@ -25,18 +33,21 @@ def _create_table(self): model TEXT, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP ) - """) + """ + ) def track_usage( self, usage: Usage, assistant_id: str, thread_id: str, model: str ) -> None: with self.lock: + if self._closed: + raise RuntimeError("Attempting to track usage on a closed tracker.") with self.conn: self.conn.execute( """ INSERT INTO token_usage (prompt_tokens, completion_tokens, total_tokens, assistant_id, thread_id, model) VALUES (?, ?, ?, ?, ?, ?) - """, + """, ( usage.prompt_tokens, usage.completion_tokens, @@ -49,11 +60,15 @@ def track_usage( def get_total_tokens(self) -> Usage: with self.lock: + if self._closed: + return Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0) cursor = self.conn.cursor() - cursor.execute(""" + cursor.execute( + """ SELECT SUM(prompt_tokens), SUM(completion_tokens), SUM(total_tokens) FROM token_usage - """) + """ + ) prompt, completion, total = cursor.fetchone() return Usage( prompt_tokens=prompt or 0, @@ -62,9 +77,12 @@ def get_total_tokens(self) -> Usage: ) def close(self) -> None: - self.conn.close() + with self.lock: + if not self._closed: + self.conn.close() + self._closed = True @classmethod def get_observe_decorator(cls): - # Return a noop decorator (decorator tracking is supported by other providers) + # Return a no-op decorator. return lambda f: f diff --git a/agency_swarm/util/tracking/tracker_factory.py b/agency_swarm/util/tracking/tracker_factory.py new file mode 100644 index 00000000..42d55b47 --- /dev/null +++ b/agency_swarm/util/tracking/tracker_factory.py @@ -0,0 +1,10 @@ +from typing import Literal + +from agency_swarm.util.tracking.langfuse_tracker import LangfuseUsageTracker +from agency_swarm.util.tracking.sqlite_tracker import SQLiteUsageTracker + + +def get_tracker_by_name(tracker_type: Literal["sqlite", "langfuse"] = "sqlite"): + if tracker_type == "langfuse": + return LangfuseUsageTracker() + return SQLiteUsageTracker() diff --git a/agency_swarm/util/usage_tracking/langfuse_tracker.py b/agency_swarm/util/usage_tracking/langfuse_tracker.py deleted file mode 100644 index c3f0c5b6..00000000 --- a/agency_swarm/util/usage_tracking/langfuse_tracker.py +++ /dev/null @@ -1,36 +0,0 @@ -from langfuse import Langfuse -from langfuse.decorators import observe -from openai.types.beta.threads.runs.run_step import Usage - -from agency_swarm.util.usage_tracking.abstract_tracker import AbstractTracker - - -class LangfuseUsageTracker(AbstractTracker): - def track_usage( - self, usage: Usage, assistant_id: str, thread_id: str, model: str - ) -> None: - langfuse = Langfuse() - langfuse.generation( - model=model, - metadata={ - "assistant_id": assistant_id, - "thread_id": thread_id, - }, - usage={ - "input": usage.prompt_tokens, - "output": usage.completion_tokens, - "total": usage.total_tokens, - "unit": "TOKENS", - }, - ) - - def get_total_tokens(self) -> Usage: - # TODO: Implement this - return Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0) - - def close(self) -> None: - pass - - @classmethod - def get_observe_decorator(cls): - return observe diff --git a/agency_swarm/util/usage_tracking/tracker_factory.py b/agency_swarm/util/usage_tracking/tracker_factory.py deleted file mode 100644 index 5bcc3bd7..00000000 --- a/agency_swarm/util/usage_tracking/tracker_factory.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Literal - -from agency_swarm.util.usage_tracking.langfuse_tracker import LangfuseUsageTracker -from agency_swarm.util.usage_tracking.sqlite_tracker import SQLiteUsageTracker - - -def get_tracker(tracker_type: Literal["sqlite", "langfuse"] = "sqlite"): - if tracker_type == "langfuse": - return LangfuseUsageTracker() - return SQLiteUsageTracker() diff --git a/tests/test_usage_tracking.py b/tests/test_usage_tracking.py index 1c45b42d..9fac3359 100644 --- a/tests/test_usage_tracking.py +++ b/tests/test_usage_tracking.py @@ -1,55 +1,123 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from openai.types.beta.threads.runs.run_step import Usage from agency_swarm.util.oai import ( _get_openai_module, - set_openai_client, set_usage_tracker, ) -from agency_swarm.util.usage_tracking import LangfuseUsageTracker, SQLiteUsageTracker +from agency_swarm.util.tracking import LangfuseUsageTracker, SQLiteUsageTracker @pytest.fixture def sqlite_tracker(): - return SQLiteUsageTracker(":memory:") + tracker = SQLiteUsageTracker(":memory:") + yield tracker + tracker.close() @pytest.fixture def langfuse_tracker(): - return LangfuseUsageTracker(api_key="test_key", project_id="test_project") + tracker = LangfuseUsageTracker() + yield tracker + tracker.close() def test_sqlite_track_and_get_total_tokens(sqlite_tracker): usage = Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15) - sqlite_tracker.track_usage(usage) + sqlite_tracker.track_usage(usage, "test_assistant", "test_thread", "gpt-4o") totals = sqlite_tracker.get_total_tokens() - assert totals.model_dump() == usage.model_dump() + assert totals == usage -@patch("requests.post") -def test_langfuse_track_usage(mock_post, langfuse_tracker): +def test_sqlite_multiple_entries(sqlite_tracker): + # Insert multiple usage entries + usages = [ + Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + Usage(prompt_tokens=20, completion_tokens=10, total_tokens=30), + ] + for u in usages: + sqlite_tracker.track_usage(u, "assistant", "thread", "gpt-4o") + + totals = sqlite_tracker.get_total_tokens() + # Expected totals: prompt=30, completion=15, total=45 + assert totals == Usage(prompt_tokens=30, completion_tokens=15, total_tokens=45) + + +@patch("agency_swarm.util.tracking.langfuse_tracker.Langfuse") +def test_langfuse_track_usage(mock_langfuse, langfuse_tracker): + # Create mock instance and set it as the client + mock_langfuse_instance = MagicMock() + mock_langfuse_instance.generation = MagicMock() + mock_langfuse.return_value = mock_langfuse_instance + langfuse_tracker.client = mock_langfuse_instance # Set the mocked client + usage = Usage(prompt_tokens=20, completion_tokens=10, total_tokens=30) - mock_post.return_value.status_code = 200 - langfuse_tracker.track_usage(usage) + langfuse_tracker.track_usage( + usage=usage, + assistant_id="test_assistant", + thread_id="test_thread", + model="gpt-4o", + ) - mock_post.assert_called_once_with( - f"https://api.langfuse.com/projects/test_project/token-usage", - json=usage.model_dump(), - headers={ - "Authorization": "Bearer test_key", - "Content-Type": "application/json", + mock_langfuse_instance.generation.assert_called_once_with( + model="gpt-4o", + metadata={ + "assistant_id": "test_assistant", + "thread_id": "test_thread", + }, + usage={ + "input": 20, + "output": 10, + "total": 30, + "unit": "TOKENS", }, ) -def test_langfuse_get_total_tokens(langfuse_tracker): +@patch("agency_swarm.util.tracking.langfuse_tracker.Langfuse") +def test_langfuse_get_total_tokens_empty(mock_langfuse, langfuse_tracker): + # Mock the fetch_observations method to return an empty list + mock_langfuse_instance = MagicMock() + mock_langfuse_instance.fetch_observations.return_value = MagicMock(data=[]) + mock_langfuse.return_value = mock_langfuse_instance + langfuse_tracker.client = mock_langfuse_instance # Set the mocked client + totals = langfuse_tracker.get_total_tokens() assert totals == Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0) +@patch("agency_swarm.util.tracking.langfuse_tracker.Langfuse") +def test_langfuse_get_total_tokens_multiple(mock_langfuse, langfuse_tracker): + # Mock multiple generations + mock_generation1 = MagicMock() + mock_generation1.usage.input = 10 + mock_generation1.usage.output = 5 + mock_generation1.usage.total = 15 + + mock_generation2 = MagicMock() + mock_generation2.usage.input = 20 + mock_generation2.usage.output = 10 + mock_generation2.usage.total = 30 + + mock_langfuse_instance = MagicMock() + mock_langfuse_instance.fetch_observations.return_value = MagicMock( + data=[mock_generation1, mock_generation2] + ) + mock_langfuse.return_value = mock_langfuse_instance + langfuse_tracker.client = mock_langfuse_instance # Set the mocked client + + totals = langfuse_tracker.get_total_tokens() + # Expected totals: prompt=30, completion=15, total=45 + assert totals == Usage(prompt_tokens=30, completion_tokens=15, total_tokens=45) + + +def test_get_observe_decorator(langfuse_tracker): + assert callable(langfuse_tracker.get_observe_decorator()) + + if __name__ == "__main__": from dotenv import load_dotenv