Skip to content

Commit

Permalink
ci
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanbohan committed Dec 14, 2023
1 parent 7dba4a5 commit d18b8ab
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 28 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dev-dependencies = [
"mypy>=1.7.1",
"pymysql>=1.1.0",
"pytest-asyncio>=0.23.2",
"types-PyMySQL>=1.1.0.1",
]

[tool.rye.scripts]
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ distro==1.8.0
exceptiongroup==1.1.3
frozenlist==1.4.0
googleapis-common-protos==1.61.0
greenlet==3.0.2
h11==0.14.0
httpcore==1.0.2
httpx==0.25.1
Expand Down Expand Up @@ -62,6 +63,7 @@ tenacity==8.2.3
tiktoken==0.5.1
tomli==2.0.1
tqdm==4.66.1
types-pymysql==1.1.0.1
typing-extensions==4.8.0
typing-inspect==0.9.0
urllib3==2.0.7
Expand Down
39 changes: 23 additions & 16 deletions tests/integrations/database/db.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,43 @@
import logging
import os
from typing import Union, List
from typing import List, Union

import pymysql

from greptimeai import logger

from .model import Tables

db = pymysql.connect(
connection = pymysql.connect(
host=os.getenv("GREPTIMEAI_HOST"),
user=os.getenv("GREPTIMEAI_USERNAME"),
passwd=os.getenv("GREPTIMEAI_PASSWORD"),
port=4002,
db=os.getenv("GREPTIMEAI_DATABASE"),
)
cursor = db.cursor()
# cursor = connection.cursor()

trace_sql = "SELECT model,prompt_tokens,completion_tokens FROM %s WHERE user_id = '%s'"
truncate_sql = "TRUNCATE %s"
trace_sql = f"SELECT model, prompt_tokens, completion_tokens FROM {Tables.llm_trace} WHERE user_id = '%s'"


def get_trace_data(user_id: str) -> List[Union[str, int]]:
"""
get trace data for llm trace by user_id
:param is_stream:
:param user_id:
:return: model, prompt_tokens, completion_tokens
"""

cursor.execute(trace_sql % (Tables.llm_trace, user_id))
trace = cursor.fetchone()
if trace is None:
raise Exception("trace data is None")
return list(trace)
with connection.cursor() as cursor:
cursor.execute("select * from llm_traces_preview_v01 limit 1")
trace = cursor.fetchone()
logger.info(f" {type(trace)=} {trace=}")

with connection.cursor() as cursor:
cursor.execute(trace_sql % (user_id))
trace = cursor.fetchone()
logger.info(f"{type(trace)=} {trace=}")
if trace is None:
raise Exception("trace data is None")
return list(trace)


def truncate_tables():
Expand All @@ -51,8 +57,9 @@ def truncate_tables():
"llm_traces_preview_v01",
]
try:
cursor.executemany(truncate_sql, tables)
db.commit()
with connection.cursor() as cursor:
cursor.executemany("TRUNCATE %s", tables)
connection.commit()
except Exception as e:
logging.error(e)
db.rollback()
logger.error(e)
connection.rollback()
5 changes: 2 additions & 3 deletions tests/integrations/openai_tracker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from openai import AsyncOpenAI
from openai import OpenAI
from openai import AsyncOpenAI, OpenAI

from greptimeai import openai_patcher # type: ignore
from greptimeai import openai_patcher

async_client = AsyncOpenAI()
openai_patcher.setup(client=async_client)
Expand Down
16 changes: 7 additions & 9 deletions tests/integrations/openai_tracker/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

import pytest

from ..database.db import (
get_trace_data,
truncate_tables,
)
from greptimeai.openai_patcher import _collector

from ..database.db import get_trace_data, truncate_tables
from ..openai_tracker import client
from greptimeai.openai_patcher import _collector # type: ignore


@pytest.fixture
Expand All @@ -33,11 +31,11 @@ def test_chat_completion(_truncate_tables):
assert resp.choices[0].message.content == "2"

_collector._collector._force_flush()
time.sleep(6)
time.sleep(1)
trace = get_trace_data(user_id)

assert resp.model == trace[0]

if resp.usage:
assert resp.usage.prompt_tokens == trace[1]
assert resp.usage.completion_tokens == trace[2]
assert resp.usage
assert resp.usage.prompt_tokens == trace[1]
assert resp.usage.completion_tokens == trace[2]

0 comments on commit d18b8ab

Please sign in to comment.