Skip to content

Commit

Permalink
修复异步toolcall单测并发interrupt的问题
Browse files Browse the repository at this point in the history
  • Loading branch information
userpj committed Dec 20, 2024
1 parent 9a95c5e commit a740f39
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 17 deletions.
6 changes: 3 additions & 3 deletions python/core/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def __init__(self, secret_key=None, gateway="", gateway_v2=""):
self.session = AsyncInnerSession()

@staticmethod
def check_response_header(response: ClientResponse):
async def check_response_header(response: ClientResponse):
r"""check_response_header is a helper method for check head status .
:param response: requests.Response.
:rtype:
Expand All @@ -252,7 +252,7 @@ def check_response_header(response: ClientResponse):
if status_code == requests.codes.ok:
return
message = "request_id={} , http status code is {}, body is {}".format(
__class__.response_request_id(response), status_code, response.text
await __class__.response_request_id(response), status_code, await response.text()
)
if status_code == requests.codes.bad_request:
raise BadRequestException(message)
Expand All @@ -268,7 +268,7 @@ def check_response_header(response: ClientResponse):
raise BaseRPCException(message)

@staticmethod
def response_request_id(response: ClientResponse):
async def response_request_id(response: ClientResponse):
r"""response_request_id is a helper method to get the unique request id"""
return response.headers.get("X-Appbuilder-Request-Id", "")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def create_conversation(self) -> str:
response = await self.http_client.session.post(
url, headers=headers, json={"app_id": self.app_id}, timeout=None
)
self.http_client.check_response_header(response)
await self.http_client.check_response_header(response)
data = await response.json()
resp = data_class.CreateConversationResponse(**data)
return resp.conversation_id
Expand Down Expand Up @@ -116,8 +116,8 @@ async def run(
response = await self.http_client.session.post(
url, headers=headers, json=req.model_dump(), timeout=None
)
self.http_client.check_response_header(response)
request_id = self.http_client.response_request_id(response)
await self.http_client.check_response_header(response)
request_id = await self.http_client.response_request_id(response)
if stream:
client = AsyncSSEClient(response)
return Message(content=self._iterate_events(request_id, client.events()))
Expand Down Expand Up @@ -164,7 +164,7 @@ async def upload_local_file(self, conversation_id, local_file_path: str) -> str:
response = await self.http_client.session.post(
url, data=multipart_form_data, headers=headers
)
self.http_client.check_response_header(response)
await self.http_client.check_response_header(response)
data = await response.json()
resp = data_class.FileUploadResponse(**data)
return resp.id
Expand Down
6 changes: 3 additions & 3 deletions python/core/console/appbuilder_client/async_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ async def __async_run_process__(self):
while not self._is_complete:
if not self._need_tool_call:
res = await self._run()
self.__event_process__(res)
await self.__event_process__(res)
else:
res = await self._submit_tool_output()
self.__event_process__(res)
await self.__event_process__(res)
yield res
if self._need_tool_call and self._is_complete:
self.reset_state()
await self.reset_state()

async def __event_process__(self, run_response):
"""
Expand Down
15 changes: 8 additions & 7 deletions python/tests/test_async_appbuilder_client_toolcall.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ async def interrupt(self, run_context, run_response):
tool_call_id = tool_call.id
tool_res = self.get_current_weather(**tool_call.function.arguments)
# 蓝色打印
print("\033[1;34m", "-> 本地ToolCall结果: ", tool_res, "\033[0m\n")
print("\033[1;34m", "-> 本地ToolCallId: ", tool_call_id, "\033[0m")
print("\033[1;34m", "-> ToolCall结果: ", tool_res, "\033[0m\n")
tool_output.append(
{"tool_call_id": tool_call_id, "output": tool_res})
return tool_output
Expand All @@ -45,7 +46,7 @@ async def success(self, run_context, run_response):
run_response.answer, "\033[0m")


@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "")
# @unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_SERIAL", "")
class TestAgentRuntime(unittest.TestCase):
def setUp(self):
"""
Expand Down Expand Up @@ -92,9 +93,10 @@ def test_appbuilder_client_tool_call(self):
}
]

appbuilder.logger.setLoglevel("ERROR")
appbuilder.logger.setLoglevel("DEBUG")

async def agent_run(client, conversation_id, query):
async def agent_run(client, query):
conversation_id = await client.create_conversation()
with await client.run_with_handler(
conversation_id=conversation_id,
query=query,
Expand All @@ -105,11 +107,10 @@ async def agent_run(client, conversation_id, query):

async def agent_handle():
client = appbuilder.AsyncAppBuilderClient(self.app_id)
conversation_id = await client.create_conversation()
task1 = asyncio.create_task(
agent_run(client, conversation_id, "北京的天气怎么样"))
agent_run(client, "北京的天气怎么样"))
task2 = asyncio.create_task(
agent_run(client, conversation_id, "上海的天气怎么样"))
agent_run(client, "上海的天气怎么样"))
await asyncio.gather(task1, task2)

await client.http_client.session.close()
Expand Down

0 comments on commit a740f39

Please sign in to comment.