fix: make id to pass when multiple tools are passed #50

Merged Feb 4, 2025 (6 commits). The diff below shows changes from all commits.
20 changes: 3 additions & 17 deletions libs/ibm/langchain_ibm/chat_models.py
@@ -274,7 +274,6 @@ def _convert_delta_to_message_chunk(
     _dict: Mapping[str, Any],
     default_class: Type[BaseMessageChunk],
     call_id: str,
-    is_first_tool_chunk: bool,
 ) -> BaseMessageChunk:
     id_ = call_id
     role = cast(str, _dict.get("role"))
@@ -291,9 +290,9 @@ def _convert_delta_to_message_chunk(
         try:
             tool_call_chunks = [
                 tool_call_chunk(
-                    name=rtc["function"].get("name") if is_first_tool_chunk else None,
+                    name=rtc["function"].get("name"),
                     args=rtc["function"].get("arguments"),
-                    id=rtc.get("id") if is_first_tool_chunk else None,
+                    id=rtc.get("id"),
                     index=rtc["index"],
                 )
                 for rtc in raw_tool_calls
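
With `is_first_tool_chunk` gone, every streamed delta forwards the `name` and `id` that watsonx.ai reports, and LangChain keeps concurrent tool calls apart by `index` when chunks are merged. A minimal sketch of that merge behavior (illustration only, not code from this PR; assumes `langchain_core` is installed):

from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import tool_call_chunk

# Two tool calls streamed across three deltas; each call keeps its own id.
chunks = [
    AIMessageChunk(content="", tool_call_chunks=[
        tool_call_chunk(name="get_weather", args='{"city":', id="call_1", index=0),
    ]),
    AIMessageChunk(content="", tool_call_chunks=[
        tool_call_chunk(name=None, args=' "nyc"}', id=None, index=0),  # continuation of call 0
    ]),
    AIMessageChunk(content="", tool_call_chunks=[
        tool_call_chunk(name="search", args='{"query": "capital"}', id="call_2", index=1),
    ]),
]

acc = chunks[0]
for chunk in chunks[1:]:
    acc += chunk  # chunks sharing an index merge into one tool call

for tc in acc.tool_calls:
    print(tc["id"], tc["name"], tc["args"])
# call_1 get_weather {'city': 'nyc'}
# call_2 search {'query': 'capital'}

Before this fix, the second call's chunks arrived with name=None and id=None, leaving the merged tool call without the identity a downstream ToolMessage needs.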
@@ -329,7 +328,6 @@ def _convert_chunk_to_generation_chunk(
     default_chunk_class: Type,
     base_generation_info: Optional[Dict],
     is_first_chunk: bool,
-    is_first_tool_chunk: bool,
 ) -> Optional[ChatGenerationChunk]:
     token_usage = chunk.get("usage")
     choices = chunk.get("choices", [])
@@ -350,7 +348,7 @@ def _convert_chunk_to_generation_chunk(
         return None

     message_chunk = _convert_delta_to_message_chunk(
-        choice["delta"], default_chunk_class, chunk["id"], is_first_tool_chunk
+        choice["delta"], default_chunk_class, chunk["id"]
     )
     generation_info = {**base_generation_info} if base_generation_info else {}

@@ -724,7 +722,6 @@ def _stream(
         base_generation_info: dict = {}

         is_first_chunk = True
-        is_first_tool_chunk = True

         for chunk in self.watsonx_model.chat_stream(
             messages=message_dicts, **(kwargs | {"params": updated_params})
@@ -736,7 +733,6 @@ def _stream(
                 default_chunk_class,
                 base_generation_info if is_first_chunk else {},
                 is_first_chunk,
-                is_first_tool_chunk,
             )
             if generation_chunk is None:
                 continue
@@ -746,16 +742,6 @@ def _stream(
                 run_manager.on_llm_new_token(
                     generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                 )
-            if hasattr(generation_chunk.message, "tool_calls") and isinstance(
-                generation_chunk.message.tool_calls, list
-            ):
-                first_tool_call = (
-                    generation_chunk.message.tool_calls[0]
-                    if generation_chunk.message.tool_calls
-                    else None
-                )
-                if isinstance(first_tool_call, dict) and first_tool_call.get("name"):
-                    is_first_tool_chunk = False

             is_first_chunk = False
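For context, the `_stream` block deleted above flipped a single `is_first_tool_chunk` flag once any named tool call had been observed, which also suppressed the name and id of every later tool call in the same stream. A toy reproduction of that failure mode (hypothetical payloads, plain Python, not code from the repository):

# One boolean cannot distinguish "continuation of tool call 0" from
# "start of tool call 1", so the second call lost its name and id.
is_first_tool_chunk = True
streamed_tool_calls = [
    {"id": "chatcmpl-tool-1", "function": {"name": "get_weather"}, "index": 0},
    {"id": "chatcmpl-tool-2", "function": {"name": "search"}, "index": 1},
]
emitted = []
for rtc in streamed_tool_calls:  # each arriving in its own stream chunk
    emitted.append(
        {
            "name": rtc["function"].get("name") if is_first_tool_chunk else None,
            "id": rtc.get("id") if is_first_tool_chunk else None,
            "index": rtc["index"],
        }
    )
    if rtc["function"].get("name"):
        is_first_tool_chunk = False

print(emitted[1])  # {'name': None, 'id': None, 'index': 1} -- identity lost
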
103 changes: 99 additions & 4 deletions libs/ibm/tests/integration_tests/test_chat_models.py
@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Any, Optional
+from typing import Any, Optional, cast

 import pytest
 from ibm_watsonx_ai.foundation_models.schema import TextChatParameters  # type: ignore
@@ -429,7 +429,7 @@ class Person(BaseModel):

     with_tool = chat.bind_tools([Person], tool_choice=tool_choice)

-    result = with_tool.invoke("Erick, 27 years old")
+    result = with_tool.invoke("Erick, 27 years old. Make sure to use correct name")
     assert isinstance(result, AIMessage)
     assert result.content == ""  # should just be tool call
     tool_call = result.tool_calls[0]
@@ -627,11 +627,106 @@ class Person(BaseModel):
     assert "tool_calls" not in acc.additional_kwargs


+def test_streaming_multiple_tool_call() -> None:
+    chat = ChatWatsonx(
+        model_id=MODEL_ID_TOOL,
+        url=URL,  # type: ignore[arg-type]
+        project_id=WX_PROJECT_ID,
+        temperature=0,
+    )
+
+    from typing import Literal
+
+    from langchain_core.tools import tool
+
+    @tool("search")
+    def search(query: str) -> list[str]:
+        """Call to search the web for capital of countries"""
+        return ["capital of america is washington D.C."]
+
+    @tool("get_weather")
+    def get_weather(city: Literal["nyc"]) -> str:
+        """Use this to get weather information."""
+        if city == "nyc":
+            return "It might be cloudy in nyc"
+        else:
+            raise ValueError("Unknown city")
+
+    tools = [search, get_weather]
+    tools_name = {el.name for el in tools}
+
+    tool_llm = chat.bind_tools(tools)
+
+    stream_response = tool_llm.stream(
+        "What is the weather in the NY and what is capital of USA?"
+    )
+
+    ai_message = None
+
+    for chunk in stream_response:
+        if ai_message is None:
+            ai_message = chunk
+        else:
+            ai_message += chunk
+        print(chunk.id, type(chunk.id))
+        assert isinstance(chunk, AIMessageChunk)
+        assert chunk.content == ""
+
+    ai_message = cast(AIMessageChunk, ai_message)
+    assert ai_message.response_metadata.get("finish_reason") == "tool_calls"
+    assert ai_message.response_metadata.get("model_name") == MODEL_ID_TOOL
+    assert ai_message.id is not None
+
+    # additional_kwargs
+    assert ai_message.additional_kwargs is not None
+    assert "tool_calls" in ai_message.additional_kwargs
+    assert len(ai_message.additional_kwargs["tool_calls"]) == 2
+    assert {
+        el["function"]["name"] for el in ai_message.additional_kwargs["tool_calls"]
+    } == tools_name
+
+    # tool_calls
+    assert all(el["id"] is not None for el in ai_message.tool_calls)
+    assert all(el["type"] == "tool_call" for el in ai_message.tool_calls)
+    assert {el["name"] for el in ai_message.tool_calls} == tools_name
+
+    generated_tools_args = [{"city": "nyc"}, {"query": "capital of USA"}]
+    assert {list(el["args"].keys())[0] for el in ai_message.tool_calls} == {
+        list(el.keys())[0] for el in generated_tools_args
+    }
+
+    # tool_call_chunks
+    predicted_tool_call_chunks = []
+    for i, el in enumerate(ai_message.tool_calls):
+        el |= {"type": "tool_call_chunk"}  # type: ignore[typeddict-item]
+        el["args"] = json.dumps(el["args"])  # type: ignore[typeddict-item]
+        el |= {"index": i}  # type: ignore[misc]
+        predicted_tool_call_chunks.append(el)
+
+    assert ai_message.tool_call_chunks == predicted_tool_call_chunks
+    assert (
+        json.loads(
+            ai_message.additional_kwargs["tool_calls"][0]["function"]["arguments"]
+        )
+        == generated_tools_args[0]
+    )
+    assert (
+        json.loads(
+            ai_message.additional_kwargs["tool_calls"][1]["function"]["arguments"]
+        )
+        == generated_tools_args[1]
+    )
+
+    # TODO: these assertions should work once the usage field is fixed
+    # assert ai_message.usage_metadata is not None
+
+
 def test_structured_output() -> None:
     chat = ChatWatsonx(
         model_id=MODEL_ID_TOOL,
         url=URL,  # type: ignore[arg-type]
         project_id=WX_PROJECT_ID,
         temperature=0,
     )
     schema = {
         "title": "AnswerWithJustification",
@@ -856,7 +951,7 @@ def test_invoke_with_params_5() -> None:


 def test_init_and_invoke_with_params_1() -> None:
-    params_1 = None
+    params_1 = {"max_tokens": 11}
     chat = ChatWatsonx(
         model_id=MODEL_ID_TOOL,
         url=URL,  # type: ignore[arg-type]
@@ -867,7 +962,7 @@ def test_init_and_invoke_with_params_1() -> None:
     completion_tokens = resp.response_metadata.get("token_usage", {}).get(
         "completion_tokens"
     )
-    assert chat.params == {}
+    assert chat.params == params_1
     assert 7 < completion_tokens < 11