Fix for chat stream tool choice (#51)
* partially revert changes from PR 50

* improve condition for setting name in tool_call_chunk

* fix mypy
Mateusz-Switala authored Feb 13, 2025
1 parent 54d67d2 commit c50be28
Showing 3 changed files with 22 additions and 3 deletions.
libs/ibm/langchain_ibm/chat_models.py (22 changes: 20 additions & 2 deletions)

@@ -274,6 +274,7 @@ def _convert_delta_to_message_chunk(
     _dict: Mapping[str, Any],
     default_class: Type[BaseMessageChunk],
     call_id: str,
+    is_first_tool_chunk: bool,
 ) -> BaseMessageChunk:
     id_ = call_id
     role = cast(str, _dict.get("role"))

@@ -290,8 +291,12 @@ def _convert_delta_to_message_chunk(
         try:
             tool_call_chunks = [
                 tool_call_chunk(
-                    name=rtc["function"].get("name"),
+                    name=rtc["function"].get("name")
+                    if is_first_tool_chunk or (rtc.get("id") is not None)
+                    else None,
                     args=rtc["function"].get("arguments"),
+                    # `id` is provided only for the first delta with unique tool_calls
+                    # (multiple tool calls scenario)
                     id=rtc.get("id"),
                     index=rtc["index"],
                 )
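
Why the name is suppressed on follow-up deltas: when LangChain accumulates
AIMessageChunks, tool-call chunks that share an index are merged by
concatenating their string fields, so repeating the name on every delta would
produce a duplicated function name. A minimal sketch of that merge behavior
(assumes langchain-core is installed; values are illustrative):

    # Minimal sketch: chunks sharing an index merge by concatenating string
    # fields, so only the first delta of a tool call should carry the name.
    from langchain_core.messages import AIMessageChunk
    from langchain_core.messages.tool import tool_call_chunk

    first = AIMessageChunk(
        content="",
        tool_call_chunks=[
            tool_call_chunk(name="get_weather", args='{"city"', id="call_1", index=0)
        ],
    )
    rest = AIMessageChunk(
        content="",
        tool_call_chunks=[
            # name=None here; repeating "get_weather" would merge into
            # "get_weatherget_weather"
            tool_call_chunk(name=None, args=': "nyc"}', id=None, index=0)
        ],
    )
    merged = first + rest
    print(merged.tool_calls)  # [{'name': 'get_weather', 'args': {'city': 'nyc'}, ...}]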
@@ -328,6 +333,7 @@ def _convert_chunk_to_generation_chunk(
     default_chunk_class: Type,
     base_generation_info: Optional[Dict],
     is_first_chunk: bool,
+    is_first_tool_chunk: bool,
 ) -> Optional[ChatGenerationChunk]:
     token_usage = chunk.get("usage")
     choices = chunk.get("choices", [])

@@ -348,7 +354,7 @@ def _convert_chunk_to_generation_chunk(
         return None

     message_chunk = _convert_delta_to_message_chunk(
-        choice["delta"], default_chunk_class, chunk["id"]
+        choice["delta"], default_chunk_class, chunk["id"], is_first_tool_chunk
     )
     generation_info = {**base_generation_info} if base_generation_info else {}
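
For orientation, the raw streamed chunk this function consumes appears to look
roughly like the following (shape inferred from the lookups above, such as
chunk["id"], choice["delta"], and the tool-call entries; values are invented):

    # Approximate shape of one raw watsonx streaming chunk, inferred from the
    # key lookups in the code above; all values are invented for illustration.
    raw_chunk = {
        "id": "chatcmpl-abc123",
        "usage": None,  # token usage typically arrives only on the final chunk
        "choices": [
            {
                "delta": {
                    "role": "assistant",
                    "tool_calls": [
                        {
                            "index": 0,
                            "id": "call_1",  # present only on a call's first delta
                            "function": {"name": "get_weather", "arguments": '{"ci'},
                        }
                    ],
                },
            }
        ],
    }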
@@ -722,6 +728,7 @@ def _stream(
         base_generation_info: dict = {}

         is_first_chunk = True
+        is_first_tool_chunk = True

         for chunk in self.watsonx_model.chat_stream(
             messages=message_dicts, **(kwargs | {"params": updated_params})

@@ -733,6 +740,7 @@ def _stream(
                 default_chunk_class,
                 base_generation_info if is_first_chunk else {},
                 is_first_chunk,
+                is_first_tool_chunk,
             )
             if generation_chunk is None:
                 continue
@@ -742,6 +750,16 @@ def _stream(
                 run_manager.on_llm_new_token(
                     generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                 )
+            if hasattr(generation_chunk.message, "tool_calls") and isinstance(
+                generation_chunk.message.tool_calls, list
+            ):
+                first_tool_call = (
+                    generation_chunk.message.tool_calls[0]
+                    if generation_chunk.message.tool_calls
+                    else None
+                )
+                if isinstance(first_tool_call, dict) and first_tool_call.get("name"):
+                    is_first_tool_chunk = False

             is_first_chunk = False
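
The flag therefore flips to False once an emitted chunk carries a complete
tool-call name and stays False for the rest of the stream. A toy trace of the
name-setting condition (delta shapes are assumed, not real watsonx payloads;
the real loop flips the flag from the emitted chunk's parsed tool_calls, as in
the diff above):

    # Toy trace of the condition from _convert_delta_to_message_chunk across a
    # two-tool-call stream; delta shapes are assumed for illustration.
    deltas = [
        {"index": 0, "id": "call_1",
         "function": {"name": "get_weather", "arguments": ""}},
        {"index": 0, "id": None,
         "function": {"name": "get_weather", "arguments": '{"city": "nyc"}'}},
        {"index": 1, "id": "call_2",
         "function": {"name": "get_time", "arguments": "{}"}},
    ]

    is_first_tool_chunk = True
    for rtc in deltas:
        name = (
            rtc["function"].get("name")
            if is_first_tool_chunk or (rtc.get("id") is not None)
            else None
        )
        print(name)  # get_weather, then None, then get_time
        if name:
            is_first_tool_chunk = False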
libs/ibm/tests/integration_tests/test_chat_models.py (2 changes: 1 addition & 1 deletion)

@@ -667,7 +667,7 @@ def get_weather(city: Literal["nyc"]) -> str:
         if ai_message is None:
             ai_message = chunk
         else:
-            ai_message += chunk
+            ai_message += chunk  # type: ignore[assignment]
         print(chunk.id, type(chunk.id))
         assert isinstance(chunk, AIMessageChunk)
         assert chunk.content == ""
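
The added # type: ignore[assignment] is the "fix mypy" item from the commit
message: BaseMessageChunk.__add__ is annotated as returning BaseMessageChunk,
so the augmented assignment widens a variable mypy inferred as
Optional[AIMessageChunk]. A condensed sketch of the pattern (model setup
elided; the stream prompt is illustrative):

    # Condensed sketch of the test pattern; `model` is a ChatWatsonx instance
    # with tools bound (setup elided). mypy narrows ai_message to
    # AIMessageChunk, while `+=` returns BaseMessageChunk, hence the ignore.
    ai_message = None
    for chunk in model.stream("what is the weather in nyc?"):
        if ai_message is None:
            ai_message = chunk
        else:
            ai_message += chunk  # type: ignore[assignment]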
Third changed file (1 change: 1 addition & 0 deletions)

@@ -60,6 +60,7 @@ def chat_model_params(self) -> dict:
             "url": URL,
             "apikey": WX_APIKEY,
             "project_id": WX_PROJECT_ID,
+            "temperature": 0,
         }

     @pytest.mark.xfail(reason="Supported for vision model.")
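
Setting temperature to 0 makes sampling effectively deterministic, which keeps
the standard tool-calling tests stable across runs. A hypothetical
instantiation using these fixture params (model_id is illustrative; the
URL/WX_APIKEY/WX_PROJECT_ID constants come from the test environment):

    # Hypothetical use of the fixture params; model_id is an assumption made
    # for illustration, not taken from this diff.
    from langchain_ibm import ChatWatsonx

    model = ChatWatsonx(
        model_id="ibm/granite-3-8b-instruct",  # assumed model id
        url=URL,
        apikey=WX_APIKEY,
        project_id=WX_PROJECT_ID,
        temperature=0,  # deterministic output for reproducible tests
    )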
