diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5f428f8..a2a2e60 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,6 +21,7 @@ jobs: ci: runs-on: ubuntu-latest strategy: + max-parallel: 1 matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] @@ -41,10 +42,17 @@ jobs: - name: Install Dev dependencies run: | - pip install langchain openai pytest ruff mypy + pip install langchain openai pytest ruff mypy pymysql pytest-asyncio types-PyMySQL # TODO(yuanbohan): code coverage with pytest-cov - name: Test with pytest + env: + GREPTIMEAI_HOST: ${{ secrets.GREPTIMEAI_HOST }} + GREPTIMEAI_USERNAME: ${{ secrets.GREPTIMEAI_USERNAME }} + GREPTIMEAI_PASSWORD: ${{ secrets.GREPTIMEAI_PASSWORD }} + GREPTIMEAI_DATABASE: ${{ secrets.GREPTIMEAI_DATABASE }} + GREPTIMEAI_TOKEN: ${{ secrets.GREPTIMEAI_TOKEN }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | pytest tests/ -v @@ -56,3 +64,7 @@ jobs: - name: Static Type Checking with mypy run: | mypy $(git ls-files '*.py') --check-untyped-defs + +concurrency: + group: ${{ github.repository }} + cancel-in-progress: true diff --git a/pyproject.toml b/pyproject.toml index f759bb9..ce990aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,8 @@ dev-dependencies = [ "langchain>=0.0.27", "openai>=1.3.5", "mypy>=1.7.1", + "pymysql>=1.1.0", + "pytest-asyncio>=0.23.2", ] [tool.rye.scripts] diff --git a/requirements-dev.lock b/requirements-dev.lock index 09e0d31..d5af55c 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -22,7 +22,6 @@ distro==1.8.0 exceptiongroup==1.1.3 frozenlist==1.4.0 googleapis-common-protos==1.61.0 -greenlet==3.0.1 h11==0.14.0 httpcore==1.0.2 httpx==0.25.1 @@ -47,16 +46,18 @@ opentelemetry-sdk==1.20.0 opentelemetry-semantic-conventions==0.41b0 packaging==23.2 pluggy==1.3.0 -protobuf==4.24.4 +protobuf==4.25.1 pydantic==2.4.2 pydantic-core==2.10.1 +pymysql==1.1.0 pytest==7.4.3 +pytest-asyncio==0.23.2 pyyaml==6.0.1 regex==2023.10.3 requests==2.31.0 
ruff==0.1.3 sniffio==1.3.0 -sqlalchemy==2.0.22 +sqlalchemy==2.0.23 tenacity==8.2.3 tiktoken==0.5.1 tomli==2.0.1 diff --git a/src/greptimeai/openai_patcher.py b/src/greptimeai/openai_patcher.py index 214f68f..19b5086 100644 --- a/src/greptimeai/openai_patcher.py +++ b/src/greptimeai/openai_patcher.py @@ -18,6 +18,8 @@ ) from greptimeai.patcher.openai_patcher.retry import _RetryPatcher +_collector: Collector = None # type: ignore + def setup( host: str = "", @@ -37,20 +39,21 @@ def setup( token: if None or empty string, GREPTIMEAI_TOKEN environment variable will be used. client: if None, then openai module-level client will be patched. """ - collector = Collector( + global _collector + _collector = Collector( service_name="openai", host=host, database=database, token=token ) patchers: List[Patcher] = [ - _AudioPatcher(collector=collector, client=client), - _ChatCompletionPatcher(collector=collector, client=client), - _CompletionPatcher(collector=collector, client=client), - _EmbeddingPatcher(collector=collector, client=client), - _FilePatcher(collector=collector, client=client), - _FineTuningPatcher(collector=collector, client=client), - _ImagePatcher(collector=collector, client=client), - _ModelPatcher(collector=collector, client=client), - _ModerationPatcher(collector=collector, client=client), - _RetryPatcher(collector=collector, client=client), + _AudioPatcher(collector=_collector, client=client), + _ChatCompletionPatcher(collector=_collector, client=client), + _CompletionPatcher(collector=_collector, client=client), + _EmbeddingPatcher(collector=_collector, client=client), + _FilePatcher(collector=_collector, client=client), + _FineTuningPatcher(collector=_collector, client=client), + _ImagePatcher(collector=_collector, client=client), + _ModelPatcher(collector=_collector, client=client), + _ModerationPatcher(collector=_collector, client=client), + _RetryPatcher(collector=_collector, client=client), ] for patcher in patchers: diff --git a/tests/__init__.py 
b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/collection/__init__.py b/tests/collection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integrations/__init__.py b/tests/integrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integrations/database/__init__.py b/tests/integrations/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integrations/database/db.py b/tests/integrations/database/db.py new file mode 100644 index 0000000..a0694ba --- /dev/null +++ b/tests/integrations/database/db.py @@ -0,0 +1,58 @@ +import logging +import os +from typing import Union, List + +import pymysql + +from .model import Tables + +db = pymysql.connect( +    host=os.getenv("GREPTIMEAI_HOST"), +    user=os.getenv("GREPTIMEAI_USERNAME"), +    passwd=os.getenv("GREPTIMEAI_PASSWORD"), +    port=4002, +    db=os.getenv("GREPTIMEAI_DATABASE"), +) +cursor = db.cursor() + +trace_sql = "SELECT model,prompt_tokens,completion_tokens FROM %s WHERE user_id = '%s'" +truncate_sql = "TRUNCATE %s" + + +def get_trace_data(user_id: str) -> List[Union[str, int]]: +    """ +    get trace data for llm trace by user_id +    :param user_id: +    :return: model, prompt_tokens, completion_tokens +    """ + +    cursor.execute(trace_sql % (Tables.llm_trace, user_id)) +    trace = cursor.fetchone() +    if trace is None: +        raise Exception("trace data is None") +    return list(trace) + + +def truncate_tables(): +    """ +    truncate all tables +    :return: +    """ +    tables = [ +        "llm_completion_tokens", +        "llm_completion_tokens_cost", +        "llm_errors", +        "llm_prompt_tokens", +        "llm_prompt_tokens_cost", +        "llm_request_duration_ms_bucket", +        "llm_request_duration_ms_count", +        "llm_request_duration_ms_sum", +        "llm_traces_preview_v01", +    ] +    try: +        # table names cannot be bound parameters, so interpolate per table +        for table in tables: +            cursor.execute(truncate_sql % table) +        db.commit() +    except Exception as e: +        logging.error(e) +        db.rollback() diff --git a/tests/integrations/database/model.py
b/tests/integrations/database/model.py new file mode 100644 index 0000000..6fcb3f8 --- /dev/null +++ b/tests/integrations/database/model.py @@ -0,0 +1,88 @@ +class LlmTrace(object): + table_name = "llm_traces_preview_v01" + + trace_id: str + span_id: str + parent_span_id: str + resource_attributes: str + scope_name: str + scope_version: str + scope_attributes: str + trace_state: str + span_name: str + span_kind: str + span_status_code: str + span_status_message: str + span_attributes: str + span_events: str + span_links: str + start: float + end: float + user_id: str + model: str + prompt_tokens: int + prompt_cost: float + completion_tokens: int + completion_cost: float + greptime_value: str + greptime_timestamp: float + + +class LlmPromptToken(object): + table_name = "llm_prompt_tokens" + + telemetry_sdk_language: str + telemetry_sdk_name: str + telemetry_sdk_version: str + service_name: str + span_name: str + model: str + greptime_value: str + greptime_timestamp: float + + +class LlmPromptTokenCost(object): + table_name = "llm_prompt_tokens_cost" + + telemetry_sdk_language: str + telemetry_sdk_name: str + telemetry_sdk_version: str + service_name: str + span_name: str + model: str + greptime_value: str + greptime_timestamp: float + + +class LlmCompletionToken(object): + table_name = "llm_completion_tokens" + + telemetry_sdk_language: str + telemetry_sdk_name: str + telemetry_sdk_version: str + service_name: str + span_name: str + model: str + greptime_value: str + greptime_timestamp: float + + +class LlmCompletionTokenCost(object): + table_name = "llm_completion_tokens_cost" + + telemetry_sdk_language: str + telemetry_sdk_name: str + telemetry_sdk_version: str + service_name: str + span_name: str + model: str + greptime_value: str + greptime_timestamp: float + + +class Tables(object): + llm_trace = "llm_traces_preview_v01" + llm_prompt_tokens = "llm_prompt_tokens" + llm_prompt_tokens_cost = "llm_prompt_tokens_cost" + llm_completion_tokens = 
"llm_completion_tokens" + llm_completion_tokens_cost = "llm_completion_tokens_cost" diff --git a/tests/integrations/openai_tracker/__init__.py b/tests/integrations/openai_tracker/__init__.py new file mode 100644 index 0000000..68ecdd1 --- /dev/null +++ b/tests/integrations/openai_tracker/__init__.py @@ -0,0 +1,10 @@ +from openai import AsyncOpenAI +from openai import OpenAI + +from greptimeai import openai_patcher # type: ignore + +async_client = AsyncOpenAI() +openai_patcher.setup(client=async_client) + +client = OpenAI() +openai_patcher.setup(client=client) diff --git a/tests/integrations/openai_tracker/test_sync.py b/tests/integrations/openai_tracker/test_sync.py new file mode 100644 index 0000000..c23e7ef --- /dev/null +++ b/tests/integrations/openai_tracker/test_sync.py @@ -0,0 +1,43 @@ +import time +import uuid + +import pytest + +from ..database.db import ( + get_trace_data, + truncate_tables, +) +from ..openai_tracker import client +from greptimeai.openai_patcher import _collector # type: ignore + + +@pytest.fixture +def _truncate_tables(): + truncate_tables() + yield + + +def test_chat_completion(_truncate_tables): + user_id = str(uuid.uuid4()) + resp = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": "1+1=", + } + ], + model="gpt-3.5-turbo", + user=user_id, + seed=1, + ) + assert resp.choices[0].message.content == "2" + + _collector._collector._force_flush() + time.sleep(6) + trace = get_trace_data(user_id) + + assert resp.model == trace[0] + + if resp.usage: + assert resp.usage.prompt_tokens == trace[1] + assert resp.usage.completion_tokens == trace[2] diff --git a/tests/langchain/__init__.py b/tests/langchain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29