Skip to content

Commit

Permalink
fix tool_call error
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanbohan committed Dec 26, 2023
1 parent 1026d67 commit c65e0cb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 19 deletions.
20 changes: 13 additions & 7 deletions src/greptimeai/utils/openai/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,20 @@ def num_tokens_from_messages(
f"greptimeai doesn't support the computation of tokens for {model} at this time."
)
return 0

logger.debug(f"num_tokens_from_messages: {messages=}")
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
try:
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
except Exception as e:
logger.error(f"failed to calculate tokens for message for {e}, {messages=}")
return 0
return num_tokens


Expand Down
24 changes: 12 additions & 12 deletions tests/integrations/openai_tracker/test_chat_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import pytest

from greptimeai import collector

from ..database.db import get_trace_data_with_retry, truncate_tables
from . import sync_client
from ..database.db import truncate_tables, get_trace_data_with_retry


@pytest.fixture
Expand Down Expand Up @@ -172,13 +173,9 @@ def test_chat_completion_tool_call(_truncate_tables):
user_id = str(uuid.uuid4())
model = "gpt-3.5-turbo"

def get_lowercase_letters(letters: str) -> str:
return letters.lower()

resp = sync_client.chat.completions.create(
messages=[
{"role": "user", "content": "GREPTIMEAI"},
{"role": "function", "name": "get_lowercase_letters", "content": "letters"},
{"role": "user", "content": "what's the lower case of GREPTIMEAI ?"},
],
model=model,
user=user_id,
Expand All @@ -194,27 +191,30 @@ def get_lowercase_letters(letters: str) -> str:
"properties": {
"letters": {
"type": "string",
"description": "uppercase letters",
"description": "words or letters",
}
},
"required": ["letters"],
},
},
}
],
tool_choice="auto",
tool_choice={"type": "function", "function": {"name": "get_lowercase_letters"}},
)

assert (
resp.choices[0].message.content
and resp.choices[0].message.content.strip() == "greptimeai"
)
print(f"==== {resp=}")

tool_calls = resp.choices[0].message.tool_calls
assert tool_calls
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "get_lowercase_letters"

collector.otel._force_flush()

trace = get_trace_data_with_retry(user_id=user_id, retry=3)

assert trace is not None
print(f"==== {trace=}")

assert "greptimeai" == trace.get("resource_attributes", {}).get("service.name")
assert "openai_completion" == trace.get("span_name")
Expand Down

0 comments on commit c65e0cb

Please sign in to comment.