Skip to content

Commit

Permalink
fix tool_call error
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanbohan committed Dec 26, 2023
1 parent 1026d67 commit 3382f9a
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tests/integrations/openai_tracker/test_chat_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import pytest

from greptimeai import collector

from ..database.db import get_trace_data_with_retry, truncate_tables
from . import sync_client
from ..database.db import truncate_tables, get_trace_data_with_retry


@pytest.fixture
Expand Down Expand Up @@ -172,13 +173,9 @@ def test_chat_completion_tool_call(_truncate_tables):
user_id = str(uuid.uuid4())
model = "gpt-3.5-turbo"

def get_lowercase_letters(letters: str) -> str:
return letters.lower()

resp = sync_client.chat.completions.create(
messages=[
{"role": "user", "content": "GREPTIMEAI"},
{"role": "function", "name": "get_lowercase_letters", "content": "letters"},
{"role": "user", "content": "what's the lower case of GREPTIMEAI ?"},
],
model=model,
user=user_id,
Expand All @@ -194,27 +191,30 @@ def get_lowercase_letters(letters: str) -> str:
"properties": {
"letters": {
"type": "string",
"description": "uppercase letters",
"description": "words or letters",
}
},
"required": ["letters"],
},
},
}
],
tool_choice="auto",
tool_choice={"type": "function", "function": {"name": "get_lowercase_letters"}},
)

assert (
resp.choices[0].message.content
and resp.choices[0].message.content.strip() == "greptimeai"
)
print(f"==== {resp=}")

tool_calls = resp.choices[0].message.tool_calls
assert tool_calls
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "get_lowercase_letters"

collector.otel._force_flush()

trace = get_trace_data_with_retry(user_id=user_id, retry=3)

assert trace is not None
print(f"==== {trace=}")

assert "greptimeai" == trace.get("resource_attributes", {}).get("service.name")
assert "openai_completion" == trace.get("span_name")
Expand Down

0 comments on commit 3382f9a

Please sign in to comment.