diff --git a/agents-api/agents_api/activities/excecute_api_call.py b/agents-api/agents_api/activities/excecute_api_call.py
index 4cc025c6f..e2e60f0a3 100644
--- a/agents-api/agents_api/activities/excecute_api_call.py
+++ b/agents-api/agents_api/activities/excecute_api_call.py
@@ -41,14 +41,22 @@ async def execute_api_call(
                 **request_args,
             )
 
+        # Turn 4xx/5xx responses into httpx.HTTPStatusError instead of returning them
+        response.raise_for_status()
         content_base64 = base64.b64encode(response.content).decode("ascii")
 
         response_dict = {
             "status_code": response.status_code,
             "headers": dict(response.headers),
             "content": content_base64,
-            "json": response.json(),
         }
+
+        # The body may not be JSON: report "json" as None rather than failing the call
+        try:
+            response_dict["json"] = response.json()
+        except ValueError as e:
+            response_dict["json"] = None
+            activity.logger.debug(f"Failed to parse JSON response: {e}")
 
         return response_dict
 
diff --git a/agents-api/agents_api/common/exceptions/tasks.py b/agents-api/agents_api/common/exceptions/tasks.py
index 81331234c..4dd37fac3 100644
--- a/agents-api/agents_api/common/exceptions/tasks.py
+++ b/agents-api/agents_api/common/exceptions/tasks.py
@@ -19,6 +19,8 @@
 import requests
 import temporalio.exceptions
 
+### FIXME: This should be the opposite: we should retry only on known errors
+
 # List of error types that should not be retried
 NON_RETRYABLE_ERROR_TYPES = (
     # Temporal-specific errors
@@ -56,8 +58,6 @@
     # HTTP and API-related errors
     fastapi.exceptions.HTTPException,
     fastapi.exceptions.RequestValidationError,
-    httpx.RequestError,
-    httpx.HTTPStatusError,
     #
     # Asynchronous programming errors
     asyncio.CancelledError,
@@ -102,6 +102,7 @@
 )
 
 
+### FIXME: This should be the opposite, i.e. `is_retryable_error` instead of `is_non_retryable_error`
 def is_non_retryable_error(error: BaseException) -> bool:
     """
     Determines if the given error is non-retryable.
@@ -115,4 +116,13 @@ def is_non_retryable_error(error: BaseException) -> bool:
     Returns:
         bool: True if the error is non-retryable, False otherwise.
     """
-    return isinstance(error, NON_RETRYABLE_ERROR_TYPES)
+    if isinstance(error, NON_RETRYABLE_ERROR_TYPES):
+        return True
+
+    # HTTP status errors with code 408, 429, 503 or 504 are reported as retryable
+    if isinstance(error, httpx.HTTPStatusError):
+        if error.response.status_code in (408, 429, 503, 504):
+            return False
+
+    # If we don't know about the error, we should not retry
+    return True
diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py
index f8a89cb62..267ca3097 100644
--- a/agents-api/tests/test_execution_workflow.py
+++ b/agents-api/tests/test_execution_workflow.py
@@ -1,6 +1,7 @@
 # Tests for task queries
 
 import asyncio
+import json
 from unittest.mock import patch
 
 import yaml
@@ -559,6 +560,77 @@ async def _(
         assert result["hello"] == data.input["test"]
 
 
+@test("workflow: tool call api_call test retry")
+async def _(
+    client=cozo_client,
+    developer_id=test_developer_id,
+    agent=test_agent,
+):
+    data = CreateExecutionRequest(input={"test": "input"})
+    status_codes_to_retry = ",".join(str(code) for code in (408, 429, 503, 504))
+
+    task = create_task(
+        developer_id=developer_id,
+        agent_id=agent.id,
+        data=CreateTaskRequest(
+            **{
+                "name": "test task",
+                "description": "test task about",
+                "input_schema": {"type": "object", "additionalProperties": True},
+                "tools": [
+                    {
+                        "type": "api_call",
+                        "name": "hello",
+                        "api_call": {
+                            "method": "GET",
+                            "url": f"https://httpbin.org/status/{status_codes_to_retry}",
+                        },
+                    }
+                ],
+                "main": [
+                    {
+                        "tool": "hello",
+                        "arguments": {
+                            "params": {"test": "_.test"},
+                        },
+                    },
+                ],
+            }
+        ),
+        client=client,
+    )
+
+    async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
+        execution, handle = await start_execution(
+            developer_id=developer_id,
+            task_id=task.id,
+            data=data,
+            client=client,
+        )
+
+        assert handle is not None
+        mock_run_task_execution_workflow.assert_called_once()
+
+        # Let it run for a bit. Errors (e.g. the 3s timeout) are ignored on purpose:
+        # only the recorded history is checked below.
+        result_task = asyncio.create_task(handle.result())
+        try:
+            await asyncio.wait_for(result_task, timeout=3)
+        except Exception:
+            result_task.cancel()
+
+        # Get the history
+        history = await handle.fetch_history()
+        events = [MessageToDict(e) for e in history.events]
+        assert len(events) > 0
+
+        # NOTE: super janky but works -- count history events that mention the activity
+        events_strings = [json.dumps(event) for event in events]
+        num_retries = sum("execute_api_call" in event for event in events_strings)
+
+        assert num_retries >= 2
+
+
 @test("workflow: tool call integration dummy")
 async def _(
     client=cozo_client,