From 034d7691a4f075088b265a0ed7c27d3ca0d3ba15 Mon Sep 17 00:00:00 2001 From: Mateusz Switala Date: Fri, 31 Jan 2025 11:35:07 +0100 Subject: [PATCH] improve formatting --- libs/ibm/langchain_ibm/chat_models.py | 8 --- .../integration_tests/test_chat_models.py | 51 ++++++++++++------- libs/ibm/tests/integration_tests/test_llms.py | 48 ++++++++--------- 3 files changed, 58 insertions(+), 49 deletions(-) diff --git a/libs/ibm/langchain_ibm/chat_models.py b/libs/ibm/langchain_ibm/chat_models.py index 82cddf5..d309af4 100644 --- a/libs/ibm/langchain_ibm/chat_models.py +++ b/libs/ibm/langchain_ibm/chat_models.py @@ -742,14 +742,6 @@ def _stream( run_manager.on_llm_new_token( generation_chunk.text, chunk=generation_chunk, logprobs=logprobs ) - if hasattr(generation_chunk.message, "tool_calls") and isinstance( - generation_chunk.message.tool_calls, list - ): - first_tool_call = ( - generation_chunk.message.tool_calls[0] - if generation_chunk.message.tool_calls - else None - ) is_first_chunk = False diff --git a/libs/ibm/tests/integration_tests/test_chat_models.py b/libs/ibm/tests/integration_tests/test_chat_models.py index be3f315..f3ce95a 100644 --- a/libs/ibm/tests/integration_tests/test_chat_models.py +++ b/libs/ibm/tests/integration_tests/test_chat_models.py @@ -1,6 +1,6 @@ import json import os -from typing import Any, Optional +from typing import Any, Optional, cast import pytest from ibm_watsonx_ai.foundation_models.schema import TextChatParameters # type: ignore @@ -632,19 +632,20 @@ def test_streaming_multiple_tool_call() -> None: model_id=MODEL_ID_TOOL, url=URL, # type: ignore[arg-type] project_id=WX_PROJECT_ID, - temperature=0 + temperature=0, ) - from langchain_core.tools import tool from typing import Literal + from langchain_core.tools import tool + @tool("search") - def search(query: str): + def search(query: str) -> list[str]: """Call to search the web for capital of countries""" return ["capital of america is washington D.C."] @tool("get_weather") - def get_weather(city: Literal["nyc"]): + def get_weather(city: Literal["nyc"]) -> str: """Use this to get weather information.""" if city == "nyc": return "It might be cloudy in nyc" @@ -656,7 +657,9 @@ def get_weather(city: Literal["nyc"]): tool_llm = chat.bind_tools(tools) - stream_response = tool_llm.stream("What is the weather in the NY and what is capital of USA?") + stream_response = tool_llm.stream( + "What is the weather in the NY and what is capital of USA?" + ) ai_message = None @@ -669,15 +672,18 @@ def get_weather(city: Literal["nyc"]): assert isinstance(chunk, AIMessageChunk) assert chunk.content == "" - assert ai_message.response_metadata.get('finish_reason') == 'tool_calls' - assert ai_message.response_metadata.get('model_name') == MODEL_ID_TOOL + ai_message = cast(AIMessageChunk, ai_message) + assert ai_message.response_metadata.get("finish_reason") == "tool_calls" + assert ai_message.response_metadata.get("model_name") == MODEL_ID_TOOL assert ai_message.id is not None # additional_kwargs assert ai_message.additional_kwargs is not None assert "tool_calls" in ai_message.additional_kwargs assert len(ai_message.additional_kwargs["tool_calls"]) == 2 - assert {el["function"]["name"] for el in ai_message.additional_kwargs["tool_calls"]} == tools_name + assert { + el["function"]["name"] for el in ai_message.additional_kwargs["tool_calls"] + } == tools_name # tool_calls assert all({el["id"] is not None for el in ai_message.tool_calls}) @@ -685,22 +691,33 @@ def get_weather(city: Literal["nyc"]): assert {el["name"] for el in ai_message.tool_calls} == tools_name generated_tools_args = [{"city": "nyc"}, {"query": "capital of USA"}] - assert {list(el["args"].keys())[0] for el in ai_message.tool_calls} == {list(el.keys())[0] for el in generated_tools_args} - + assert {list(el["args"].keys())[0] for el in ai_message.tool_calls} == { + list(el.keys())[0] for el in generated_tools_args + } # tool_call_chunks predicted_tool_call_chunks = [] for i, el in enumerate(ai_message.tool_calls): - el |= {'type': 'tool_call_chunk'} - el['args'] = json.dumps(el['args']) - el |= {"index": i} + el |= {"type": "tool_call_chunk"} # type: ignore[typeddict-item] + el["args"] = json.dumps(el["args"]) # type: ignore[typeddict-item] + el |= {"index": i} # type: ignore[misc] predicted_tool_call_chunks.append(el) assert ai_message.tool_call_chunks == predicted_tool_call_chunks - assert json.loads(ai_message.additional_kwargs["tool_calls"][0]["function"]["arguments"]) == generated_tools_args[0] - assert json.loads(ai_message.additional_kwargs["tool_calls"][1]["function"]["arguments"]) == generated_tools_args[1] + assert ( + json.loads( + ai_message.additional_kwargs["tool_calls"][0]["function"]["arguments"] + ) + == generated_tools_args[0] + ) + assert ( + json.loads( + ai_message.additional_kwargs["tool_calls"][1]["function"]["arguments"] + ) + == generated_tools_args[1] + ) - #TODO: these tests should works when usage field will be fixed + # TODO: these tests should works when usage field will be fixed # assert ai_message.usage_metadata is not None diff --git a/libs/ibm/tests/integration_tests/test_llms.py b/libs/ibm/tests/integration_tests/test_llms.py index 5e74751..e16db26 100644 --- a/libs/ibm/tests/integration_tests/test_llms.py +++ b/libs/ibm/tests/integration_tests/test_llms.py @@ -233,14 +233,14 @@ def test_watsonxllm_stream() -> None: linked_text_stream = "" for chunk in stream_response: - assert isinstance( - chunk, str - ), f"chunk expect type '{str}', actual '{type(chunk)}'" + assert isinstance(chunk, str), ( + f"chunk expect type '{str}', actual '{type(chunk)}'" + ) linked_text_stream += chunk print(f"Linked text stream: {linked_text_stream}") - assert ( - response == linked_text_stream - ), "Linked text stream are not the same as generated text" + assert response == linked_text_stream, ( + "Linked text stream are not the same as generated text" + ) def test_watsonxllm_stream_with_kwargs() -> None: @@ -252,9 +252,9 @@ def test_watsonxllm_stream_with_kwargs() -> None: stream_response = watsonxllm.stream("What color sunflower is?", raw_response=True) for chunk in stream_response: - assert isinstance( - chunk, str - ), f"chunk expect type '{str}', actual '{type(chunk)}'" + assert isinstance(chunk, str), ( + f"chunk expect type '{str}', actual '{type(chunk)}'" + ) def test_watsonxllm_stream_with_params() -> None: @@ -276,14 +276,14 @@ def test_watsonxllm_stream_with_params() -> None: linked_text_stream = "" for chunk in stream_response: - assert isinstance( - chunk, str - ), f"chunk expect type '{str}', actual '{type(chunk)}'" + assert isinstance(chunk, str), ( + f"chunk expect type '{str}', actual '{type(chunk)}'" + ) linked_text_stream += chunk print(f"Linked text stream: {linked_text_stream}") - assert ( - response == linked_text_stream - ), "Linked text stream are not the same as generated text" + assert response == linked_text_stream, ( + "Linked text stream are not the same as generated text" + ) def test_watsonxllm_stream_with_params_2() -> None: @@ -300,9 +300,9 @@ def test_watsonxllm_stream_with_params_2() -> None: stream_response = watsonxllm.stream("What color sunflower is?", params=parameters) for chunk in stream_response: - assert isinstance( - chunk, str - ), f"chunk expect type '{str}', actual '{type(chunk)}'" + assert isinstance(chunk, str), ( + f"chunk expect type '{str}', actual '{type(chunk)}'" + ) print(chunk) @@ -323,9 +323,9 @@ def test_watsonxllm_stream_with_params_3() -> None: stream_response = watsonxllm.stream("What color sunflower is?", params=parameters_2) for chunk in stream_response: - assert isinstance( - chunk, str - ), f"chunk expect type '{str}', actual '{type(chunk)}'" + assert isinstance(chunk, str), ( + f"chunk expect type '{str}', actual '{type(chunk)}'" + ) print(chunk) @@ -346,9 +346,9 @@ def test_watsonxllm_stream_with_params_4() -> None: stream_response = watsonxllm.stream("What color sunflower is?", **parameters_2) # type: ignore[arg-type] for chunk in stream_response: - assert isinstance( - chunk, str - ), f"chunk expect type '{str}', actual '{type(chunk)}'" + assert isinstance(chunk, str), ( + f"chunk expect type '{str}', actual '{type(chunk)}'" + ) print(chunk)