Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agents-api): Fix tests for workflows #998

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions agents-api/agents_api/activities/container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class State:
pass


class Container:
state: State

def __init__(self):
self.state = State()


container = Container()
19 changes: 15 additions & 4 deletions agents-api/agents_api/activities/execute_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
from beartype import beartype
from temporalio import activity

from ..app import lifespan
from ..autogen.openapi_model import BaseIntegrationDef
from ..clients import integrations
from ..common.exceptions.tools import IntegrationExecutionException
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..env import testing
from ..queries.tools import get_tool_args_from_metadata
from .container import container


@lifespan(container)
@beartype
async def execute_integration(
context: StepContext,
Expand All @@ -26,12 +29,20 @@ async def execute_integration(
agent_id = context.execution_input.agent.id
task_id = context.execution_input.task.id

merged_tool_args = get_tool_args_from_metadata(
developer_id=developer_id, agent_id=agent_id, task_id=task_id, arg_type="args"
merged_tool_args = await get_tool_args_from_metadata(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
arg_type="args",
connection_pool=container.state.postgres_pool,
)

merged_tool_setup = get_tool_args_from_metadata(
developer_id=developer_id, agent_id=agent_id, task_id=task_id, arg_type="setup"
merged_tool_setup = await get_tool_args_from_metadata(
developer_id=developer_id,
agent_id=agent_id,
task_id=task_id,
arg_type="setup",
connection_pool=container.state.postgres_pool,
)

arguments = (
Expand Down
8 changes: 7 additions & 1 deletion agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fastapi.background import BackgroundTasks
from temporalio import activity

from ..app import lifespan
from ..autogen.openapi_model import (
ChatInput,
CreateDocRequest,
Expand All @@ -21,12 +22,14 @@
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..env import testing
from ..queries.developers import get_developer
from .container import container
from .utils import get_handler

# For running synchronous code in the background
process_pool_executor = ProcessPoolExecutor()


@lifespan(container)
@beartype
async def execute_system(
context: StepContext,
Expand Down Expand Up @@ -89,7 +92,10 @@ async def execute_system(

# Handle chat operations
if system.operation == "chat" and system.resource == "session":
developer = await get_developer(developer_id=arguments.get("developer_id"))
developer = await get_developer(
developer_id=arguments["developer_id"],
connection_pool=container.state.postgres_pool,
)
session_id = arguments.get("session_id")
x_custom_api_key = arguments.get("x_custom_api_key", None)
chat_input = ChatInput(**arguments)
Expand Down
14 changes: 4 additions & 10 deletions agents-api/agents_api/activities/task_steps/pg_query_step.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,26 @@
from typing import Any

from async_lru import alru_cache
from beartype import beartype
from temporalio import activity

from ... import queries
from ...clients.pg import create_db_pool
from ...app import lifespan
from ...env import pg_dsn, testing
from ..container import container


@alru_cache(maxsize=1)
async def get_db_pool(dsn: str):
return await create_db_pool(dsn=dsn)


@lifespan(container)
@beartype
async def pg_query_step(
query_name: str,
values: dict[str, Any],
dsn: str = pg_dsn,
) -> Any:
pool = await get_db_pool(dsn=dsn)

(module_name, name) = query_name.split(".")

module = getattr(queries, module_name)
query = getattr(module, name)
return await query(**values, connection_pool=pool)
return await query(**values, connection_pool=container.state.postgres_pool)


# Note: This is here just for clarity. We could have just imported pg_query_step directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastapi import HTTPException
from temporalio import activity

from ...app import lifespan
from ...autogen.openapi_model import CreateTransitionRequest, Transition
from ...clients.temporal import get_workflow_handle
from ...common.protocol.tasks import ExecutionInput, StepContext
Expand All @@ -17,12 +18,14 @@
from ...queries.executions.create_execution_transition import (
create_execution_transition,
)
from ..container import container
from ..utils import RateLimiter

# Global rate limiter instance
rate_limiter = RateLimiter(max_requests=transition_requests_per_minute)


@lifespan(container)
@beartype
async def transition_step(
context: StepContext,
Expand Down Expand Up @@ -57,6 +60,7 @@ async def transition_step(
execution_id=context.execution_input.execution.id,
data=transition_info,
task_token=transition_info.task_token,
connection_pool=container.state.postgres_pool,
)

except Exception as e:
Expand Down
11 changes: 10 additions & 1 deletion agents-api/agents_api/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from contextlib import asynccontextmanager
from typing import Any, Protocol

from aiobotocore.session import get_session
from fastapi import APIRouter, FastAPI
Expand All @@ -10,9 +11,17 @@
from .env import api_prefix, hostname, protocol, public_port


class Assignable(Protocol):
def __setattr__(self, name: str, value: Any) -> None: ...


class ObjectWithState(Protocol):
state: Assignable


# TODO: This currently doesn't use .env variables, but we should move to using them
@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: FastAPI | ObjectWithState):
# INIT POSTGRES #
pg_dsn = os.environ.get("PG_DSN")

Expand Down
4 changes: 1 addition & 3 deletions agents-api/agents_api/worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from tenacity import after_log, retry, retry_if_exception_type, wait_fixed

from ..app import app, lifespan
from ..clients import temporal
from .worker import create_worker

Expand All @@ -37,8 +36,7 @@ async def main():
worker = create_worker(client)

# Start the worker to listen for and process tasks
async with lifespan(app):
await worker.run()
await worker.run()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def create_worker(client: Client) -> Any:
from ..workflows.task_execution import TaskExecutionWorkflow
from ..workflows.truncation import TruncationWorkflow

task_activity_names, task_activities = zip(*getmembers(task_steps, isfunction))
_task_activity_names, task_activities = zip(*getmembers(task_steps, isfunction))

# Initialize the worker with the specified task queue, workflows, and activities
worker = Worker(
Expand Down
Loading
Loading