Skip to content

Commit

Permalink
fix(agents-api): Fix retry policy + fix db pool for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahmad-mtos committed Jan 8, 2025
1 parent f3a190a commit 259a1c4
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 36 deletions.
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/excecute_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def execute_api_call(
if activity.in_activity():
activity.logger.error(f"Error in execute_api_call: {e}")

return StepOutcome(error=str(e))
raise


mock_execute_api_call = execute_api_call
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/execute_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def execute_integration(
)
activity.logger.error(f"Error in execute_integration {integration_str}: {e}")

return StepOutcome(error=str(e))
raise


mock_execute_integration = execute_integration
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ async def execute_system(
except BaseException as e:
if activity.in_activity():
activity.logger.error(f"Error in execute_system_call: {e}")
return StepOutcome(error=str(e))
raise


def _create_search_request(arguments: dict) -> Any:
Expand Down
24 changes: 13 additions & 11 deletions agents-api/agents_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from scalar_fastapi import get_scalar_api_reference

from .clients.pg import create_db_pool
from .env import api_prefix, hostname, protocol, public_port

from .env import api_prefix, hostname, protocol, public_port, testing

class State(Protocol):
postgres_pool: Pool | None
Expand All @@ -32,12 +31,14 @@ async def lifespan(*containers: FastAPI | ObjectWithState):
pg_dsn = os.environ.get("PG_DSN")

global pool
if not pool:
if not pool and not testing:
pool = await create_db_pool(pg_dsn)
local_pool = await create_db_pool(pg_dsn) if testing else None
pools = [pool, local_pool]

for container in containers:
if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
container.state.postgres_pool = pool
container.state.postgres_pool = pools[testing]

# INIT S3 #
s3_access_key = os.environ.get("S3_ACCESS_KEY")
Expand All @@ -57,13 +58,14 @@ async def lifespan(*containers: FastAPI | ObjectWithState):
try:
yield
finally:
# # CLOSE POSTGRES #
# for container in containers:
# if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
# pool = getattr(container.state, "postgres_pool", None)
# if pool:
# await pool.close()
# container.state.postgres_pool = None
# CLOSE POSTGRES #
if testing:
for container in containers:
if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
pools[testing] = getattr(container.state, "postgres_pool", None)
if pools[testing]:
await pools[testing].close()
container.state.postgres_pool = None

# CLOSE S3 #
for container in containers:
Expand Down
49 changes: 27 additions & 22 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,16 +203,18 @@ async def run(
workflow.logger.debug(f"Step {context.cursor.step} completed successfully")

except Exception as e:
if isinstance(e, CancelledError) or (
isinstance(e, ActivityError) and isinstance(e.__cause__, CancelledError)
):
workflow.logger.info("Workflow cancelled")
await transition(context, type="cancelled", output=None)
raise
workflow.logger.error(f"Error in step {context.cursor.step}: {e!s}")
await transition(context, type="error", output=str(e))
msg = f"Activity {activity} threw error: {e}"
raise ApplicationError(msg) from e
match e:
case CancelledError() | ActivityError(__cause__=CancelledError()):
workflow.logger.info("Workflow cancelled")
await transition(context, type="cancelled", output=None)
case ApplicationError(_non_retryable=True):
workflow.logger.error(f"Error in step {context.cursor.step}: {e!s}")
await transition(context, type="error", output=str(e))
case ActivityError(__cause__=ApplicationError(_non_retryable=True)):
workflow.logger.error(f"Error in step {context.cursor.step}: {e!s}")
await transition(context, type="error", output=str(e.__cause__))

raise

# ---

Expand Down Expand Up @@ -693,18 +695,21 @@ async def run(
)

except Exception as e:
if isinstance(e, CancelledError) or (
isinstance(e, ActivityError) and isinstance(e.__cause__, CancelledError)
):
workflow.logger.info("Workflow cancelled")
await transition(context, type="cancelled", output=None)
raise
workflow.logger.error(f"Unhandled error: {e!s}")
await transition(
context, type="error", output=str(e), last_error=self.last_error
)
msg = "Workflow encountered an error"
raise ApplicationError(msg) from e
match e:
case CancelledError() | ActivityError(__cause__=CancelledError()):
workflow.logger.info("Workflow cancelled")
await transition(context, type="cancelled", output=None)
case ApplicationError(_non_retryable=True):
workflow.logger.error(f"Unhandled error: {e!s}")
await transition(
context, type="error", output=str(e), last_error=self.last_error
)
case ActivityError(__cause__=ApplicationError(_non_retryable=True)):
workflow.logger.error(f"Unhandled error: {e!s}")
await transition(
context, type="error", output=str(e.__cause__), last_error=self.last_error
)
raise

previous_inputs.append(final_output)

Expand Down

0 comments on commit 259a1c4

Please sign in to comment.