diff --git a/pyproject.toml b/pyproject.toml index ce990aa..832401a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dev-dependencies = [ "mypy>=1.7.1", "pymysql>=1.1.0", "pytest-asyncio>=0.23.2", + "types-PyMySQL>=1.1.0.1", ] [tool.rye.scripts] diff --git a/requirements-dev.lock b/requirements-dev.lock index d5af55c..c10c9d7 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -22,6 +22,7 @@ distro==1.8.0 exceptiongroup==1.1.3 frozenlist==1.4.0 googleapis-common-protos==1.61.0 +greenlet==3.0.2 h11==0.14.0 httpcore==1.0.2 httpx==0.25.1 @@ -62,6 +63,7 @@ tenacity==8.2.3 tiktoken==0.5.1 tomli==2.0.1 tqdm==4.66.1 +types-pymysql==1.1.0.1 typing-extensions==4.8.0 typing-inspect==0.9.0 urllib3==2.0.7 diff --git a/tests/integrations/database/db.py b/tests/integrations/database/db.py index a0694ba..c9ddc85 100644 --- a/tests/integrations/database/db.py +++ b/tests/integrations/database/db.py @@ -1,37 +1,43 @@ -import logging import os -from typing import Union, List +from typing import List, Union import pymysql +from greptimeai import logger + from .model import Tables -db = pymysql.connect( +connection = pymysql.connect( host=os.getenv("GREPTIMEAI_HOST"), user=os.getenv("GREPTIMEAI_USERNAME"), passwd=os.getenv("GREPTIMEAI_PASSWORD"), port=4002, db=os.getenv("GREPTIMEAI_DATABASE"), ) -cursor = db.cursor() +# cursor = connection.cursor() -trace_sql = "SELECT model,prompt_tokens,completion_tokens FROM %s WHERE user_id = '%s'" -truncate_sql = "TRUNCATE %s" +trace_sql = f"SELECT model, prompt_tokens, completion_tokens FROM {Tables.llm_trace} WHERE user_id = '%s'" def get_trace_data(user_id: str) -> List[Union[str, int]]: """ get trace data for llm trace by user_id - :param is_stream: :param user_id: :return: model, prompt_tokens, completion_tokens """ - cursor.execute(trace_sql % (Tables.llm_trace, user_id)) - trace = cursor.fetchone() - if trace is None: - raise Exception("trace data is None") - return list(trace) + with connection.cursor() as cursor: + cursor.execute("select * from llm_traces_preview_v01 limit 1") + trace = cursor.fetchone() + logger.info(f" {type(trace)=} {trace=}") + + with connection.cursor() as cursor: + cursor.execute(trace_sql % (user_id)) + trace = cursor.fetchone() + logger.info(f"{type(trace)=} {trace=}") + if trace is None: + raise Exception("trace data is None") + return list(trace) def truncate_tables(): @@ -51,8 +57,9 @@ def truncate_tables(): "llm_traces_preview_v01", ] try: - cursor.executemany(truncate_sql, tables) - db.commit() + with connection.cursor() as cursor: + cursor.executemany("TRUNCATE %s", tables) + connection.commit() except Exception as e: - logging.error(e) - db.rollback() + logger.error(e) + connection.rollback() diff --git a/tests/integrations/openai_tracker/__init__.py b/tests/integrations/openai_tracker/__init__.py index 68ecdd1..d9ce880 100644 --- a/tests/integrations/openai_tracker/__init__.py +++ b/tests/integrations/openai_tracker/__init__.py @@ -1,7 +1,6 @@ -from openai import AsyncOpenAI -from openai import OpenAI +from openai import AsyncOpenAI, OpenAI -from greptimeai import openai_patcher # type: ignore +from greptimeai import openai_patcher async_client = AsyncOpenAI() openai_patcher.setup(client=async_client) diff --git a/tests/integrations/openai_tracker/test_sync.py b/tests/integrations/openai_tracker/test_sync.py index c23e7ef..baa33d7 100644 --- a/tests/integrations/openai_tracker/test_sync.py +++ b/tests/integrations/openai_tracker/test_sync.py @@ -3,12 +3,10 @@ import pytest -from ..database.db import ( - get_trace_data, - truncate_tables, -) +from greptimeai.openai_patcher import _collector + +from ..database.db import get_trace_data, truncate_tables from ..openai_tracker import client -from greptimeai.openai_patcher import _collector # type: ignore @pytest.fixture @@ -33,11 +31,11 @@ def test_chat_completion(_truncate_tables): assert resp.choices[0].message.content == "2" _collector._collector._force_flush() - time.sleep(6) + time.sleep(1) trace = get_trace_data(user_id) assert resp.model == trace[0] - if resp.usage: - assert resp.usage.prompt_tokens == trace[1] - assert resp.usage.completion_tokens == trace[2] + assert resp.usage + assert resp.usage.prompt_tokens == trace[1] + assert resp.usage.completion_tokens == trace[2]