Skip to content

Commit

Permalink
fix tool_call error
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanbohan committed Dec 26, 2023
1 parent c65e0cb commit 63f403e
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 41 deletions.
18 changes: 12 additions & 6 deletions src/greptimeai/utils/openai/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,22 @@ def num_tokens_from_messages(
)
return 0

logger.debug(f"num_tokens_from_messages: {messages=}")
num_tokens = 0
try:
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
items = None
if isinstance(message, dict):
items = message.items()
elif hasattr(message, "model_dump"):
items = message.model_dump().items()

if items is not None:
for key, value in items:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|im_start|>assistant<|im_sep|>
except Exception as e:
logger.error(f"failed to calculate tokens for message for {e}, {messages=}")
return 0
Expand Down
135 changes: 100 additions & 35 deletions tests/integrations/openai_tracker/test_chat_parameters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import uuid
from typing import Any, Dict, List

import pytest
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam

from greptimeai import collector

Expand Down Expand Up @@ -172,49 +174,52 @@ def test_chat_completion_stop(_truncate_tables):
def test_chat_completion_tool_call(_truncate_tables):
user_id = str(uuid.uuid4())
model = "gpt-3.5-turbo"
messages: List[ChatCompletionMessageParam] = [
{"role": "user", "content": "What's the weather like in Boston today?"}
]
tools: List[ChatCompletionToolParam] = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state",
},
},
"required": ["location"],
},
},
}
]

## STEP 1
resp = sync_client.chat.completions.create(
messages=[
{"role": "user", "content": "what's the lower case of GREPTIMEAI ?"},
],
messages=messages,
model=model,
user=user_id,
seed=1,
tools=[
{
"type": "function",
"function": {
"description": "get lowercase letters",
"name": "get_lowercase_letters",
"parameters": {
"type": "object",
"properties": {
"letters": {
"type": "string",
"description": "words or letters",
}
},
"required": ["letters"],
},
},
}
],
tool_choice={"type": "function", "function": {"name": "get_lowercase_letters"}},
tools=tools,
)

print(f"==== {resp=}")
choice = resp.choices[0]
assert "tool_calls" == choice.finish_reason

tool_calls = resp.choices[0].message.tool_calls
assistant_message = choice.message
tool_calls = assistant_message.tool_calls
assert tool_calls
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "get_lowercase_letters"
tool_call = tool_calls[0]
assert "get_current_weather" == tool_call.function.name

collector.otel._force_flush()

trace = get_trace_data_with_retry(user_id=user_id, retry=3)
trace = get_trace_data_with_retry(user_id=user_id)

assert trace is not None
print(f"==== {trace=}")

assert "greptimeai" == trace.get("resource_attributes", {}).get("service.name")
assert "openai_completion" == trace.get("span_name")
Expand All @@ -236,9 +241,69 @@ def test_chat_completion_tool_call(_truncate_tables):
attrs = event.get("attributes")
assert attrs
assert attrs["choices"]
choice = attrs["choices"][0]
assert choice
assert choice.get("finish_reason") == "tool_calls"
message = choice.get("message")
assert message
assert message.get("content").strip() == "greptimeai"
choice_in_db: Dict[str, Any] = attrs["choices"][0]
assert choice_in_db
assert choice_in_db.get("finish_reason") == "tool_calls"

## STEP 2
tool_message = {
"content": "sunny",
"role": "tool",
"tool_call_id": tool_call.id,
}
messages.append(assistant_message.model_dump()) # type: ignore
messages.append(tool_message) # type: ignore
user_id = str(uuid.uuid4())

resp = sync_client.chat.completions.create(
messages=messages,
model=model,
user=user_id,
tools=tools,
)

choice = resp.choices[0]
assert "stop" == choice.finish_reason

content = choice.message.content
assert content
assert "sunny" in content
assert "weather" in content
assert "boston" in content.lower()

collector.otel._force_flush()

trace = get_trace_data_with_retry(user_id=user_id)

assert trace is not None

assert "greptimeai" == trace.get("resource_attributes", {}).get("service.name")
assert "openai_completion" == trace.get("span_name")
assert "openai" == trace.get("span_attributes", {}).get("source")

assert ["client.chat.completions.create", "end"] == [
event.get("name") for event in trace.get("span_events", [])
]

assert resp.model == trace.get("model")
assert resp.model.startswith(model)

assert resp.usage
assert resp.usage.prompt_tokens == trace.get("prompt_tokens")
assert resp.usage.completion_tokens == trace.get("completion_tokens")

for event in trace.get("span_events", []):
if event.get("name") == "end":
attrs = event.get("attributes")
assert attrs
assert attrs["choices"]

db_choice: Dict[str, Any] = attrs["choices"][0]
assert db_choice
assert db_choice.get("finish_reason") == "stop"

content = db_choice.get("message", {}).get("content")
assert content
assert "sunny" in content
assert "weather" in content
assert "boston" in content.lower()

0 comments on commit 63f403e

Please sign in to comment.