Skip to content

Commit

Permalink
fix(agents-api): Fix tests for workflows
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Dec 27, 2024
1 parent bf6f4d7 commit fbb33c4
Show file tree
Hide file tree
Showing 9 changed files with 1,483 additions and 1,401 deletions.
12 changes: 12 additions & 0 deletions agents-api/agents_api/activities/container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class State:
pass


class Container:
state: State

def __init__(self):
self.state = State()


container = Container()
19 changes: 15 additions & 4 deletions agents-api/agents_api/activities/execute_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
from beartype import beartype
from temporalio import activity

from ..app import lifespan
from ..autogen.openapi_model import BaseIntegrationDef
from ..clients import integrations
from ..common.exceptions.tools import IntegrationExecutionException
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..env import testing
from ..queries.tools import get_tool_args_from_metadata
from .container import container


@lifespan(container)
@beartype
async def execute_integration(
context: StepContext,
Expand All @@ -26,12 +29,20 @@ async def execute_integration(
agent_id = context.execution_input.agent.id
task_id = context.execution_input.task.id

merged_tool_args = get_tool_args_from_metadata(
developer_id=developer_id, agent_id=agent_id, task_id=task_id, arg_type="args"
merged_tool_args = await get_tool_args_from_metadata(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
arg_type="args",
connection_pool=container.state.postgres_pool,
)

merged_tool_setup = get_tool_args_from_metadata(
developer_id=developer_id, agent_id=agent_id, task_id=task_id, arg_type="setup"
merged_tool_setup = await get_tool_args_from_metadata(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
arg_type="setup",
connection_pool=container.state.postgres_pool,
)

arguments = (
Expand Down
8 changes: 7 additions & 1 deletion agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fastapi.background import BackgroundTasks
from temporalio import activity

from ..app import lifespan
from ..autogen.openapi_model import (
ChatInput,
CreateDocRequest,
Expand All @@ -21,12 +22,14 @@
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..env import testing
from ..queries.developers import get_developer
from .container import container
from .utils import get_handler

# For running synchronous code in the background
process_pool_executor = ProcessPoolExecutor()


@lifespan(container)
@beartype
async def execute_system(
context: StepContext,
Expand Down Expand Up @@ -89,7 +92,10 @@ async def execute_system(

# Handle chat operations
if system.operation == "chat" and system.resource == "session":
developer = await get_developer(developer_id=arguments.get("developer_id"))
developer = await get_developer(
developer_id=arguments["developer_id"],
connection_pool=container.state.postgres_pool,
)
session_id = arguments.get("session_id")
x_custom_api_key = arguments.get("x_custom_api_key", None)
chat_input = ChatInput(**arguments)
Expand Down
14 changes: 4 additions & 10 deletions agents-api/agents_api/activities/task_steps/pg_query_step.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,26 @@
from typing import Any

from async_lru import alru_cache
from beartype import beartype
from temporalio import activity

from ... import queries
from ...clients.pg import create_db_pool
from ...app import lifespan
from ...env import pg_dsn, testing
from ..container import container


@alru_cache(maxsize=1)
async def get_db_pool(dsn: str):
return await create_db_pool(dsn=dsn)


@lifespan(container)
@beartype
async def pg_query_step(
query_name: str,
values: dict[str, Any],
dsn: str = pg_dsn,
) -> Any:
pool = await get_db_pool(dsn=dsn)

(module_name, name) = query_name.split(".")

module = getattr(queries, module_name)
query = getattr(module, name)
return await query(**values, connection_pool=pool)
return await query(**values, connection_pool=container.state.postgres_pool)


# Note: This is here just for clarity. We could have just imported pg_query_step directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastapi import HTTPException
from temporalio import activity

from ...app import lifespan
from ...autogen.openapi_model import CreateTransitionRequest, Transition
from ...clients.temporal import get_workflow_handle
from ...common.protocol.tasks import ExecutionInput, StepContext
Expand All @@ -17,12 +18,14 @@
from ...queries.executions.create_execution_transition import (
create_execution_transition,
)
from ..container import container
from ..utils import RateLimiter

# Global rate limiter instance
rate_limiter = RateLimiter(max_requests=transition_requests_per_minute)


@lifespan(container)
@beartype
async def transition_step(
context: StepContext,
Expand Down Expand Up @@ -57,6 +60,7 @@ async def transition_step(
execution_id=context.execution_input.execution.id,
data=transition_info,
task_token=transition_info.task_token,
connection_pool=container.state.postgres_pool,
)

except Exception as e:
Expand Down
11 changes: 10 additions & 1 deletion agents-api/agents_api/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from contextlib import asynccontextmanager
from typing import Any, Protocol

from aiobotocore.session import get_session
from fastapi import APIRouter, FastAPI
Expand All @@ -10,9 +11,17 @@
from .env import api_prefix, hostname, protocol, public_port


class Assignable(Protocol):
def __setattr__(self, name: str, value: Any) -> None: ...


class ObjectWithState(Protocol):
state: Assignable


# TODO: This currently doesn't use .env variables, but we should move to using them
@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: FastAPI | ObjectWithState):
# INIT POSTGRES #
pg_dsn = os.environ.get("PG_DSN")

Expand Down
4 changes: 1 addition & 3 deletions agents-api/agents_api/worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from tenacity import after_log, retry, retry_if_exception_type, wait_fixed

from ..app import app, lifespan
from ..clients import temporal
from .worker import create_worker

Expand All @@ -37,8 +36,7 @@ async def main():
worker = create_worker(client)

# Start the worker to listen for and process tasks
async with lifespan(app):
await worker.run()
await worker.run()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def create_worker(client: Client) -> Any:
from ..workflows.task_execution import TaskExecutionWorkflow
from ..workflows.truncation import TruncationWorkflow

task_activity_names, task_activities = zip(*getmembers(task_steps, isfunction))
_task_activity_names, task_activities = zip(*getmembers(task_steps, isfunction))

# Initialize the worker with the specified task queue, workflows, and activities
worker = Worker(
Expand Down
Loading

0 comments on commit fbb33c4

Please sign in to comment.