diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py
index c6c7663c3..a9a7cae44 100644
--- a/agents-api/agents_api/activities/embed_docs.py
+++ b/agents-api/agents_api/activities/embed_docs.py
@@ -7,13 +7,11 @@
 from temporalio import activity

 from ..clients import cozo, litellm
-from ..common.storage_handler import auto_blob_store
 from ..env import testing
 from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
 from .types import EmbedDocsPayload


-@auto_blob_store(deep=True)
 @beartype
 async def embed_docs(
     payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100
diff --git a/agents-api/agents_api/activities/excecute_api_call.py b/agents-api/agents_api/activities/excecute_api_call.py
index 09a33aaa8..2167aaead 100644
--- a/agents-api/agents_api/activities/excecute_api_call.py
+++ b/agents-api/agents_api/activities/excecute_api_call.py
@@ -6,7 +6,6 @@
 from temporalio import activity

 from ..autogen.openapi_model import ApiCallDef
-from ..common.storage_handler import auto_blob_store
 from ..env import testing

@@ -20,7 +19,6 @@ class RequestArgs(TypedDict):
     headers: Optional[dict[str, str]]


-@auto_blob_store(deep=True)
 @beartype
 async def execute_api_call(
     api_call: ApiCallDef,
diff --git a/agents-api/agents_api/activities/execute_integration.py b/agents-api/agents_api/activities/execute_integration.py
index 3316ad6f5..d058553c4 100644
--- a/agents-api/agents_api/activities/execute_integration.py
+++ b/agents-api/agents_api/activities/execute_integration.py
@@ -7,12 +7,10 @@
 from ..clients import integrations
 from ..common.exceptions.tools import IntegrationExecutionException
 from ..common.protocol.tasks import ExecutionInput, StepContext
-from ..common.storage_handler import auto_blob_store
 from ..env import testing
 from ..models.tools import get_tool_args_from_metadata


-@auto_blob_store(deep=True)
 @beartype
 async def execute_integration(
     context: StepContext,
diff --git a/agents-api/agents_api/activities/execute_system.py b/agents-api/agents_api/activities/execute_system.py
index 590849080..647327a8a 100644
--- a/agents-api/agents_api/activities/execute_system.py
+++ b/agents-api/agents_api/activities/execute_system.py
@@ -19,16 +19,14 @@
     VectorDocSearchRequest,
 )
 from ..common.protocol.tasks import ExecutionInput, StepContext
-from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote
 from ..env import testing
-from ..queries.developer import get_developer
+from ..queries.developers import get_developer
 from .utils import get_handler

 # For running synchronous code in the background
 process_pool_executor = ProcessPoolExecutor()


-@auto_blob_store(deep=True)
 @beartype
 async def execute_system(
     context: StepContext,
@@ -37,9 +35,6 @@ async def execute_system(
     """Execute a system call with the appropriate handler and transformed arguments."""
     arguments: dict[str, Any] = system.arguments or {}

-    if set(arguments.keys()) == {"bucket", "key"}:
-        arguments = await load_from_blob_store_if_remote(arguments)
-
     if not isinstance(context.execution_input, ExecutionInput):
         raise TypeError("Expected ExecutionInput type for context.execution_input")
diff --git a/agents-api/agents_api/activities/sync_items_remote.py b/agents-api/agents_api/activities/sync_items_remote.py
index d71a5c566..14751c2b6 100644
--- a/agents-api/agents_api/activities/sync_items_remote.py
+++ b/agents-api/agents_api/activities/sync_items_remote.py
@@ -9,20 +9,16 @@
 @beartype
 async def save_inputs_remote_fn(inputs: list[Any]) -> list[Any | RemoteObject]:
-    from ..common.storage_handler import store_in_blob_store_if_large
+    from ..common.interceptors import offload_if_large

-    return await asyncio.gather(
-        *[store_in_blob_store_if_large(input) for input in inputs]
-    )
+    return await asyncio.gather(*[offload_if_large(input) for input in inputs])


 @beartype
 async def load_inputs_remote_fn(inputs: list[Any | RemoteObject]) -> list[Any]:
-    from ..common.storage_handler import load_from_blob_store_if_remote
+    from ..common.interceptors import load_if_remote

-    return await asyncio.gather(
-        *[load_from_blob_store_if_remote(input) for input in inputs]
-    )
+    return await asyncio.gather(*[load_if_remote(input) for input in inputs])


 save_inputs_remote = activity.defn(name="save_inputs_remote")(save_inputs_remote_fn)
diff --git a/agents-api/agents_api/activities/task_steps/base_evaluate.py b/agents-api/agents_api/activities/task_steps/base_evaluate.py
index d87b961d3..3bb04e390 100644
--- a/agents-api/agents_api/activities/task_steps/base_evaluate.py
+++ b/agents-api/agents_api/activities/task_steps/base_evaluate.py
@@ -13,7 +13,6 @@
 from temporalio import activity  # noqa: E402
 from thefuzz import fuzz  # noqa: E402

-from ...common.storage_handler import auto_blob_store  # noqa: E402
 from ...env import testing  # noqa: E402
 from ..utils import get_evaluator  # noqa: E402

@@ -63,7 +62,6 @@ def _recursive_evaluate(expr, evaluator: SimpleEval):
             raise ValueError(f"Invalid expression: {expr}")


-@auto_blob_store(deep=True)
 @beartype
 async def base_evaluate(
     exprs: Any,
diff --git a/agents-api/agents_api/activities/task_steps/cozo_query_step.py b/agents-api/agents_api/activities/task_steps/cozo_query_step.py
index 16e9a53d8..8d28d83c9 100644
--- a/agents-api/agents_api/activities/task_steps/cozo_query_step.py
+++ b/agents-api/agents_api/activities/task_steps/cozo_query_step.py
@@ -4,11 +4,9 @@
 from temporalio import activity

 from ... import models
-from ...common.storage_handler import auto_blob_store
 from ...env import testing


-@auto_blob_store(deep=True)
 @beartype
 async def cozo_query_step(
     query_name: str,
diff --git a/agents-api/agents_api/activities/task_steps/evaluate_step.py b/agents-api/agents_api/activities/task_steps/evaluate_step.py
index 904ec3b9d..08fa6cd55 100644
--- a/agents-api/agents_api/activities/task_steps/evaluate_step.py
+++ b/agents-api/agents_api/activities/task_steps/evaluate_step.py
@@ -5,11 +5,9 @@
 from ...activities.utils import simple_eval_dict
 from ...common.protocol.tasks import StepContext, StepOutcome
-from ...common.storage_handler import auto_blob_store
 from ...env import testing


-@auto_blob_store(deep=True)
 @beartype
 async def evaluate_step(
     context: StepContext,
diff --git a/agents-api/agents_api/activities/task_steps/for_each_step.py b/agents-api/agents_api/activities/task_steps/for_each_step.py
index f51c1ef76..ca84eb75d 100644
--- a/agents-api/agents_api/activities/task_steps/for_each_step.py
+++ b/agents-api/agents_api/activities/task_steps/for_each_step.py
@@ -6,12 +6,10 @@
     StepContext,
     StepOutcome,
 )
-from ...common.storage_handler import auto_blob_store
 from ...env import testing
 from .base_evaluate import base_evaluate


-@auto_blob_store(deep=True)
 @beartype
 async def for_each_step(context: StepContext) -> StepOutcome:
     try:
diff --git a/agents-api/agents_api/activities/task_steps/get_value_step.py b/agents-api/agents_api/activities/task_steps/get_value_step.py
index ca38bc4fe..feeb71bbf 100644
--- a/agents-api/agents_api/activities/task_steps/get_value_step.py
+++ b/agents-api/agents_api/activities/task_steps/get_value_step.py
@@ -2,13 +2,12 @@
 from temporalio import activity

 from ...common.protocol.tasks import StepContext, StepOutcome
-from ...common.storage_handler import auto_blob_store
 from ...env import testing

-
 # TODO: We should use this step to query the parent workflow and get the value from the workflow context
 # SCRUM-1
-@auto_blob_store(deep=True)
+
+
 @beartype
 async def get_value_step(
     context: StepContext,
diff --git a/agents-api/agents_api/activities/task_steps/if_else_step.py b/agents-api/agents_api/activities/task_steps/if_else_step.py
index cf3764199..ec4368640 100644
--- a/agents-api/agents_api/activities/task_steps/if_else_step.py
+++ b/agents-api/agents_api/activities/task_steps/if_else_step.py
@@ -6,12 +6,10 @@
     StepContext,
     StepOutcome,
 )
-from ...common.storage_handler import auto_blob_store
 from ...env import testing
 from .base_evaluate import base_evaluate


-@auto_blob_store(deep=True)
 @beartype
 async def if_else_step(context: StepContext) -> StepOutcome:
     # NOTE: This activity is only for logging, so we just evaluate the expression
diff --git a/agents-api/agents_api/activities/task_steps/log_step.py b/agents-api/agents_api/activities/task_steps/log_step.py
index 28fea2dae..f54018683 100644
--- a/agents-api/agents_api/activities/task_steps/log_step.py
+++ b/agents-api/agents_api/activities/task_steps/log_step.py
@@ -6,12 +6,10 @@
     StepContext,
     StepOutcome,
 )
-from ...common.storage_handler import auto_blob_store
 from ...common.utils.template import render_template
 from ...env import testing


-@auto_blob_store(deep=True)
 @beartype
 async def log_step(context: StepContext) -> StepOutcome:
     # NOTE: This activity is only for logging, so we just evaluate the expression
diff --git a/agents-api/agents_api/activities/task_steps/map_reduce_step.py b/agents-api/agents_api/activities/task_steps/map_reduce_step.py
index 872988bb4..c39bace20 100644
--- a/agents-api/agents_api/activities/task_steps/map_reduce_step.py
+++ b/agents-api/agents_api/activities/task_steps/map_reduce_step.py
@@ -8,12 +8,10 @@
     StepContext,
     StepOutcome,
 )
-from ...common.storage_handler import auto_blob_store
 from ...env import testing
 from .base_evaluate import base_evaluate


-@auto_blob_store(deep=True)
 @beartype
 async def map_reduce_step(context: StepContext) -> StepOutcome:
     try:
diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py
index cf8b169d5..47560cadd 100644
--- a/agents-api/agents_api/activities/task_steps/prompt_step.py
+++ b/agents-api/agents_api/activities/task_steps/prompt_step.py
@@ -8,7 +8,6 @@
     litellm,  # We dont directly import `acompletion` so we can mock it
 )
 from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome
-from ...common.storage_handler import auto_blob_store
 from ...common.utils.template import render_template
 from ...env import debug
 from .base_evaluate import base_evaluate
@@ -62,7 +61,6 @@ def format_tool(tool: Tool) -> dict:


 @activity.defn
-@auto_blob_store(deep=True)
 @beartype
 async def prompt_step(context: StepContext) -> StepOutcome:
     # Get context data
diff --git a/agents-api/agents_api/activities/task_steps/raise_complete_async.py b/agents-api/agents_api/activities/task_steps/raise_complete_async.py
index 640d6ae4e..bbf27c500 100644
--- a/agents-api/agents_api/activities/task_steps/raise_complete_async.py
+++ b/agents-api/agents_api/activities/task_steps/raise_complete_async.py
@@ -6,12 +6,10 @@
 from ...autogen.openapi_model import CreateTransitionRequest
 from ...common.protocol.tasks import StepContext
-from ...common.storage_handler import auto_blob_store
 from .transition_step import original_transition_step


 @activity.defn
-@auto_blob_store(deep=True)
 @beartype
 async def raise_complete_async(context: StepContext, output: Any) -> None:
     activity_info = activity.info()
diff --git a/agents-api/agents_api/activities/task_steps/return_step.py b/agents-api/agents_api/activities/task_steps/return_step.py
index 08ac20de4..f15354536 100644
--- a/agents-api/agents_api/activities/task_steps/return_step.py
+++ b/agents-api/agents_api/activities/task_steps/return_step.py
@@ -6,12 +6,10 @@
     StepContext,
     StepOutcome,
 )
-from ...common.storage_handler import auto_blob_store
 from ...env import testing
 from .base_evaluate import base_evaluate


-@auto_blob_store(deep=True)
 @beartype
 async def return_step(context: StepContext) -> StepOutcome:
     try:
diff --git a/agents-api/agents_api/activities/task_steps/set_value_step.py b/agents-api/agents_api/activities/task_steps/set_value_step.py
index 1c97b6551..96db5d0d1 100644
--- a/agents-api/agents_api/activities/task_steps/set_value_step.py
+++ b/agents-api/agents_api/activities/task_steps/set_value_step.py
@@ -5,13 +5,12 @@
 from ...activities.utils import simple_eval_dict
 from ...common.protocol.tasks import StepContext, StepOutcome
-from ...common.storage_handler import auto_blob_store
 from ...env import testing

-
 # TODO: We should use this step to signal to the parent workflow and set the value on the workflow context
 # SCRUM-2
-@auto_blob_store(deep=True)
+
+
 @beartype
 async def set_value_step(
     context: StepContext,
diff --git a/agents-api/agents_api/activities/task_steps/switch_step.py b/agents-api/agents_api/activities/task_steps/switch_step.py
index 6a95e98d2..100d8020a 100644
--- a/agents-api/agents_api/activities/task_steps/switch_step.py
+++ b/agents-api/agents_api/activities/task_steps/switch_step.py
@@ -6,12 +6,10 @@
     StepContext,
     StepOutcome,
 )
-from ...common.storage_handler import auto_blob_store
 from ...env import testing
 from ..utils import get_evaluator


-@auto_blob_store(deep=True)
 @beartype
 async def switch_step(context: StepContext) -> StepOutcome:
     try:
diff --git a/agents-api/agents_api/activities/task_steps/tool_call_step.py b/agents-api/agents_api/activities/task_steps/tool_call_step.py
index 5725a75d1..a2d7fd7c2 100644
--- a/agents-api/agents_api/activities/task_steps/tool_call_step.py
+++ b/agents-api/agents_api/activities/task_steps/tool_call_step.py
@@ -11,7 +11,6 @@
     StepContext,
     StepOutcome,
 )
-from ...common.storage_handler import auto_blob_store

 # FIXME: This shouldn't be here.
@@ -47,7 +46,6 @@ def construct_tool_call(


 @activity.defn
-@auto_blob_store(deep=True)
 @beartype
 async def tool_call_step(context: StepContext) -> StepOutcome:
     assert isinstance(context.current_step, ToolCallStep)
diff --git a/agents-api/agents_api/activities/task_steps/transition_step.py b/agents-api/agents_api/activities/task_steps/transition_step.py
index 44046a5e7..11c7befb5 100644
--- a/agents-api/agents_api/activities/task_steps/transition_step.py
+++ b/agents-api/agents_api/activities/task_steps/transition_step.py
@@ -8,7 +8,6 @@
 from ...autogen.openapi_model import CreateTransitionRequest, Transition
 from ...clients.temporal import get_workflow_handle
 from ...common.protocol.tasks import ExecutionInput, StepContext
-from ...common.storage_handler import load_from_blob_store_if_remote
 from ...env import (
     temporal_activity_after_retry_timeout,
     testing,
@@ -48,11 +47,6 @@ async def transition_step(
             TaskExecutionWorkflow.set_last_error, LastErrorInput(last_error=None)
         )

-    # Load output from blob store if it is a remote object
-    transition_info.output = await load_from_blob_store_if_remote(
-        transition_info.output
-    )
-
     if not isinstance(context.execution_input, ExecutionInput):
         raise TypeError("Expected ExecutionInput type for context.execution_input")
diff --git a/agents-api/agents_api/activities/task_steps/wait_for_input_step.py b/agents-api/agents_api/activities/task_steps/wait_for_input_step.py
index ad6eeb63e..a3cb00f67 100644
--- a/agents-api/agents_api/activities/task_steps/wait_for_input_step.py
+++ b/agents-api/agents_api/activities/task_steps/wait_for_input_step.py
@@ -3,12 +3,10 @@
 from ...autogen.openapi_model import WaitForInputStep
 from ...common.protocol.tasks import StepContext, StepOutcome
-from ...common.storage_handler import auto_blob_store
 from ...env import testing
 from .base_evaluate import base_evaluate


-@auto_blob_store(deep=True)
 @beartype
 async def wait_for_input_step(context: StepContext) -> StepOutcome:
     try:
diff --git a/agents-api/agents_api/activities/task_steps/yield_step.py b/agents-api/agents_api/activities/task_steps/yield_step.py
index 199008703..18e5383cc 100644
--- a/agents-api/agents_api/activities/task_steps/yield_step.py
+++ b/agents-api/agents_api/activities/task_steps/yield_step.py
@@ -5,12 +5,10 @@
 from ...autogen.openapi_model import TransitionTarget, YieldStep
 from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome
-from ...common.storage_handler import auto_blob_store
 from ...env import testing
 from .base_evaluate import base_evaluate


-@auto_blob_store(deep=True)
 @beartype
 async def yield_step(context: StepContext) -> StepOutcome:
     try:
diff --git a/agents-api/agents_api/app.py b/agents-api/agents_api/app.py
index 735dfc8c0..ced41decb 100644
--- a/agents-api/agents_api/app.py
+++ b/agents-api/agents_api/app.py
@@ -1,7 +1,5 @@
-import json
 from contextlib import asynccontextmanager

-import asyncpg
 from fastapi import FastAPI
 from prometheus_fastapi_instrumentator import Instrumentator

@@ -11,9 +9,13 @@
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    app.state.postgres_pool = await create_db_pool()
+    if not app.state.postgres_pool:
+        app.state.postgres_pool = await create_db_pool()
+
     yield
-    await app.state.postgres_pool.close()
+
+    if app.state.postgres_pool:
+        await app.state.postgres_pool.close()


 app: FastAPI = FastAPI(
diff --git a/agents-api/agents_api/autogen/Entries.py b/agents-api/agents_api/autogen/Entries.py
index de37e77d8..867b10192 100644
--- a/agents-api/agents_api/autogen/Entries.py
+++ b/agents-api/agents_api/autogen/Entries.py
@@ -52,6 +52,7 @@ class BaseEntry(BaseModel):
     ]
     tokenizer: str
     token_count: int
+    model: str = "gpt-4o-mini"
     tool_calls: (
         list[
             ChosenFunctionCall
diff --git a/agents-api/agents_api/autogen/Sessions.py b/agents-api/agents_api/autogen/Sessions.py
index 460fd25ce..e2a9ce164 100644
--- a/agents-api/agents_api/autogen/Sessions.py
+++ b/agents-api/agents_api/autogen/Sessions.py
@@ -31,6 +31,10 @@ class CreateSessionRequest(BaseModel):
     """
     A specific situation that sets the background for this session
     """
+    system_template: str | None = None
+    """
+    System prompt for this session
+    """
     render_templates: StrictBool = True
     """
     Render system and assistant message content as jinja templates
@@ -51,6 +55,10 @@ class CreateSessionRequest(BaseModel):
     If a tool call is made, the tool's output will be sent back to the model as the model's input.
     If a tool call is not made, the model's output will be returned as is.
     """
+    forward_tool_calls: StrictBool = False
+    """
+    Whether to forward tool calls to the model
+    """
     recall_options: RecallOptions | None = None
     metadata: dict[str, Any] | None = None

@@ -67,6 +75,10 @@ class PatchSessionRequest(BaseModel):
     """
     A specific situation that sets the background for this session
     """
+    system_template: str | None = None
+    """
+    System prompt for this session
+    """
     render_templates: StrictBool = True
     """
     Render system and assistant message content as jinja templates
@@ -87,6 +99,10 @@ class PatchSessionRequest(BaseModel):
     If a tool call is made, the tool's output will be sent back to the model as the model's input.
     If a tool call is not made, the model's output will be returned as is.
     """
+    forward_tool_calls: StrictBool = False
+    """
+    Whether to forward tool calls to the model
+    """
     recall_options: RecallOptionsUpdate | None = None
     metadata: dict[str, Any] | None = None

@@ -121,6 +137,10 @@ class Session(BaseModel):
     """
     A specific situation that sets the background for this session
     """
+    system_template: str | None = None
+    """
+    System prompt for this session
+    """
     summary: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None
     """
     Summary (null at the beginning) - generated automatically after every interaction
@@ -145,6 +165,10 @@ class Session(BaseModel):
     If a tool call is made, the tool's output will be sent back to the model as the model's input.
     If a tool call is not made, the model's output will be returned as is.
""" + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] metadata: dict[str, Any] | None = None @@ -197,6 +221,10 @@ class UpdateSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -217,6 +245,10 @@ class UpdateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -240,6 +272,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -260,6 +296,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index d19684cee..d809e0a35 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -14,7 +14,6 @@ model_validator, ) -from ..common.storage_handler import RemoteObject from ..common.utils.datetime import utcnow from .Agents import * from .Chat import * @@ -358,7 +357,7 @@ def validate_subworkflows(self): class SystemDef(SystemDef): - arguments: dict[str, Any] | None | RemoteObject = None + arguments: dict[str, Any] | None = None class CreateTransitionRequest(Transition): @@ -400,6 +399,7 @@ def from_model_input( source=source, tokenizer=tokenizer["type"], token_count=token_count, + model=model, **kwargs, ) diff --git a/agents-api/agents_api/clients/async_s3.py b/agents-api/agents_api/clients/async_s3.py index 0cd5235ee..b6ba76d8b 100644 --- a/agents-api/agents_api/clients/async_s3.py +++ b/agents-api/agents_api/clients/async_s3.py @@ -16,6 +16,7 @@ ) +@alru_cache(maxsize=1024) async def list_buckets() -> list[str]: session = get_session() diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index da2d7f6fa..cd2178d95 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -1,3 +1,4 @@ +import asyncio from datetime import timedelta from uuid import UUID @@ -12,9 +13,9 @@ from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig from ..autogen.openapi_model import TransitionTarget +from ..common.interceptors import offload_if_large from ..common.protocol.tasks import ExecutionInput from ..common.retry_policies import DEFAULT_RETRY_POLICY -from ..common.storage_handler import store_in_blob_store_if_large from 
..env import ( temporal_client_cert, temporal_metrics_bind_host, @@ -96,8 +97,10 @@ async def run_task_execution_workflow( client = client or (await get_client()) execution_id = execution_input.execution.id execution_id_key = SearchAttributeKey.for_keyword("CustomStringField") - execution_input.arguments = await store_in_blob_store_if_large( - execution_input.arguments + + old_args = execution_input.arguments + execution_input.arguments = await asyncio.gather( + *[offload_if_large(arg) for arg in old_args] ) return await client.start_workflow( diff --git a/agents-api/agents_api/common/interceptors.py b/agents-api/agents_api/common/interceptors.py index 40600a818..bfd64c374 100644 --- a/agents-api/agents_api/common/interceptors.py +++ b/agents-api/agents_api/common/interceptors.py @@ -4,8 +4,12 @@ certain types of errors that are known to be non-retryable. """ -from typing import Optional, Type +import asyncio +import sys +from functools import wraps +from typing import Any, Awaitable, Callable, Optional, Sequence, Type +from temporalio import workflow from temporalio.activity import _CompleteAsyncError as CompleteAsyncError from temporalio.exceptions import ApplicationError, FailureError, TemporalError from temporalio.service import RPCError @@ -23,7 +27,97 @@ ReadOnlyContextError, ) -from .exceptions.tasks import is_retryable_error +with workflow.unsafe.imports_passed_through(): + from ..env import blob_store_cutoff_kb, use_blob_store_for_temporal + from .exceptions.tasks import is_retryable_error + from .protocol.remote import RemoteObject + +# Common exceptions that should be re-raised without modification +PASSTHROUGH_EXCEPTIONS = ( + ContinueAsNewError, + ReadOnlyContextError, + NondeterminismError, + RPCError, + CompleteAsyncError, + TemporalError, + FailureError, + ApplicationError, +) + + +def is_too_large(result: Any) -> bool: + return sys.getsizeof(result) > blob_store_cutoff_kb * 1024 + + +async def load_if_remote[T](arg: T | RemoteObject[T]) -> T: + if use_blob_store_for_temporal and isinstance(arg, RemoteObject): + return await arg.load() + + return arg + + +async def offload_if_large[T](result: T) -> T: + if use_blob_store_for_temporal and is_too_large(result): + return await RemoteObject.from_value(result) + + return result + + +def offload_to_blob_store[S, T]( + func: Callable[[S, ExecuteActivityInput | ExecuteWorkflowInput], Awaitable[T]], +) -> Callable[ + [S, ExecuteActivityInput | ExecuteWorkflowInput], Awaitable[T | RemoteObject[T]] +]: + @wraps(func) + async def wrapper( + self, + input: ExecuteActivityInput | ExecuteWorkflowInput, + ) -> T | RemoteObject[T]: + # Load all remote arguments from the blob store + args: Sequence[Any] = input.args + + if use_blob_store_for_temporal: + input.args = await asyncio.gather(*[load_if_remote(arg) for arg in args]) + + # Execute the function + result = await func(self, input) + + # Save the result to the blob store if necessary + return await offload_if_large(result) + + return wrapper + + +async def handle_execution_with_errors[I, T]( + execution_fn: Callable[[I], Awaitable[T]], + input: I, +) -> T: + """ + Common error handling logic for both activities and workflows. 
+
+    Args:
+        execution_fn: Async function to execute with error handling
+        input: Input to the execution function
+
+    Returns:
+        The result of the execution function
+
+    Raises:
+        ApplicationError: For non-retryable errors
+        Any other exception: For retryable errors
+    """
+    try:
+        return await execution_fn(input)
+    except PASSTHROUGH_EXCEPTIONS:
+        raise
+    except BaseException as e:
+        if not is_retryable_error(e):
+            raise ApplicationError(
+                str(e),
+                type=type(e).__name__,
+                non_retryable=True,
+            )
+        raise


 class CustomActivityInterceptor(ActivityInboundInterceptor):
@@ -35,95 +129,45 @@ class CustomActivityInterceptor(ActivityInboundInterceptor):
     as non-retryable errors.
     """

-    async def execute_activity(self, input: ExecuteActivityInput):
+    @offload_to_blob_store
+    async def execute_activity(self, input: ExecuteActivityInput) -> Any:
         """
-        🎭 The Activity Whisperer: Handles activity execution with style and grace
-
-        This is like a safety net for your activities - catching errors and deciding
-        their fate with the wisdom of a fortune cookie.
+        Handles activity execution by intercepting errors and determining their retry behavior.
         """
-        try:
-            return await super().execute_activity(input)
-        except (
-            ContinueAsNewError,  # When you need a fresh start
-            ReadOnlyContextError,  # When someone tries to write in a museum
-            NondeterminismError,  # When chaos theory kicks in
-            RPCError,  # When computers can't talk to each other
-            CompleteAsyncError,  # When async goes wrong
-            TemporalError,  # When time itself rebels
-            FailureError,  # When failure is not an option, but happens anyway
-            ApplicationError,  # When the app says "nope"
-        ):
-            raise
-        except BaseException as e:
-            if not is_retryable_error(e):
-                # If it's not retryable, we wrap it in a nice bow (ApplicationError)
-                # and mark it as non-retryable to prevent further attempts
-                raise ApplicationError(
-                    str(e),
-                    type=type(e).__name__,
-                    non_retryable=True,
-                )
-            # For retryable errors, we'll let Temporal retry with backoff
-            # Default retry policy ensures at least 2 retries
-            raise
+        return await handle_execution_with_errors(
+            super().execute_activity,
+            input,
+        )


 class CustomWorkflowInterceptor(WorkflowInboundInterceptor):
     """
-    🎪 The Workflow Circus Ringmaster
+    Custom interceptor for Temporal workflows.

-    This interceptor is like a circus ringmaster - keeping all the workflow acts
-    running smoothly and catching any lions (errors) that escape their cages.
+    Handles workflow execution errors and determines their retry behavior.
     """

-    async def execute_workflow(self, input: ExecuteWorkflowInput):
+    @offload_to_blob_store
+    async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:
         """
-        🎪 The Main Event: Workflow Execution Extravaganza!
-
-        Watch as we gracefully handle errors like a trapeze artist catching their partner!
+        Executes workflows and handles error cases appropriately.
         """
-        try:
-            return await super().execute_workflow(input)
-        except (
-            ContinueAsNewError,  # The show must go on!
-            ReadOnlyContextError,  # No touching, please!
-            NondeterminismError,  # When butterflies cause hurricanes
-            RPCError,  # Lost in translation
-            CompleteAsyncError,  # Async said "bye" too soon
-            TemporalError,  # Time is relative, errors are absolute
-            FailureError,  # Task failed successfully
-            ApplicationError,  # App.exe has stopped working
-        ):
-            raise
-        except BaseException as e:
-            if not is_retryable_error(e):
-                # Pack the error in a nice box with a "do not retry" sticker
-                raise ApplicationError(
-                    str(e),
-                    type=type(e).__name__,
-                    non_retryable=True,
-                )
-            # Let it retry - everyone deserves a second (or third) chance!
-            raise
+        return await handle_execution_with_errors(
+            super().execute_workflow,
+            input,
+        )


 class CustomInterceptor(Interceptor):
     """
-    🎭 The Grand Interceptor: Master of Ceremonies
-
-    This is like the backstage manager of a theater - making sure both the
-    activity actors and workflow directors have their interceptor costumes on.
+    Main interceptor class that provides both activity and workflow interceptors.
     """

     def intercept_activity(
         self, next: ActivityInboundInterceptor
     ) -> ActivityInboundInterceptor:
         """
-        🎬 Activity Interceptor Factory: Where the magic begins!
-
-        Creating custom activity interceptors faster than a caffeinated barista
-        makes espresso shots.
+        Creates and returns a custom activity interceptor.
         """
         return CustomActivityInterceptor(super().intercept_activity(next))

@@ -131,9 +175,6 @@ def workflow_interceptor_class(
         self, input: WorkflowInterceptorClassInput
     ) -> Optional[Type[WorkflowInboundInterceptor]]:
         """
-        🎪 Workflow Interceptor Class Selector
-
-        Like a matchmaker for workflows and their interceptors - a match made in
-        exception handling heaven!
+        Returns the custom workflow interceptor class.
         """
         return CustomWorkflowInterceptor
diff --git a/agents-api/agents_api/common/protocol/remote.py b/agents-api/agents_api/common/protocol/remote.py
index ce2a2a63a..86add1949 100644
--- a/agents-api/agents_api/common/protocol/remote.py
+++ b/agents-api/agents_api/common/protocol/remote.py
@@ -1,91 +1,34 @@
 from dataclasses import dataclass
-from typing import Any
+from typing import Generic, Self, Type, TypeVar, cast

-from temporalio import activity, workflow
+from temporalio import workflow

 with workflow.unsafe.imports_passed_through():
-    from pydantic import BaseModel
-
+    from ...clients import async_s3
     from ...env import blob_store_bucket
+    from ...worker.codec import deserialize, serialize


-@dataclass
-class RemoteObject:
-    key: str
-    bucket: str = blob_store_bucket
-
-
-class BaseRemoteModel(BaseModel):
-    _remote_cache: dict[str, Any]
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    def __init__(self, **data: Any):
-        super().__init__(**data)
-        self._remote_cache = {}
-
-    async def load_item(self, item: Any | RemoteObject) -> Any:
-        if not activity.in_activity():
-            return item
-
-        from ..storage_handler import load_from_blob_store_if_remote
-
-        return await load_from_blob_store_if_remote(item)
+T = TypeVar("T")

-    async def save_item(self, item: Any) -> Any:
-        if not activity.in_activity():
-            return item

-        from ..storage_handler import store_in_blob_store_if_large
-
-        return await store_in_blob_store_if_large(item)
-
-    async def get_attribute(self, name: str) -> Any:
-        if name.startswith("_"):
-            return super().__getattribute__(name)
-
-        try:
-            value = super().__getattribute__(name)
-        except AttributeError:
-            raise AttributeError(
-                f"'{type(self).__name__}' object has no attribute '{name}'"
-            )
-
-        if isinstance(value, RemoteObject):
-            cache = super().__getattribute__("_remote_cache")
-            if name in cache:
-                return cache[name]
-
-            loaded_data = await self.load_item(value)
-            cache[name] = loaded_data
-            return loaded_data
-
-        return value
-
-    async def set_attribute(self, name: str, value: Any) -> None:
-        if name.startswith("_"):
-            super().__setattr__(name, value)
-            return
+@dataclass
+class RemoteObject(Generic[T]):
+    _type: Type[T]
+    key: str
+    bucket: str

-        stored_value = await self.save_item(value)
-        super().__setattr__(name, stored_value)
+    @classmethod
+    async def from_value(cls, x: T) -> Self:
+        await async_s3.setup()

-        if isinstance(stored_value, RemoteObject):
-            cache = self.__dict__.get("_remote_cache", {})
-            cache.pop(name, None)
+        serialized = serialize(x)

-    async def load_all(self) -> None:
-        for name in self.model_fields_set:
-            await self.get_attribute(name)
+        key = await async_s3.add_object_with_hash(serialized)
+        return RemoteObject[T](key=key, bucket=blob_store_bucket, _type=type(x))

-    async def unload_attribute(self, name: str) -> None:
-        if name in self._remote_cache:
-            data = self._remote_cache.pop(name)
-            remote_obj = await self.save_item(data)
-            super().__setattr__(name, remote_obj)
+    async def load(self) -> T:
+        await async_s3.setup()

-    async def unload_all(self) -> "BaseRemoteModel":
-        for name in list(self._remote_cache.keys()):
-            await self.unload_attribute(name)
-        return self
+        fetched = await async_s3.get_object(self.key)
+        return cast(self._type, deserialize(fetched))
diff --git a/agents-api/agents_api/common/protocol/sessions.py b/agents-api/agents_api/common/protocol/sessions.py
index 121afe702..3b04178e1 100644
--- a/agents-api/agents_api/common/protocol/sessions.py
+++ b/agents-api/agents_api/common/protocol/sessions.py
@@ -103,7 +103,7 @@ def get_active_tools(self) -> list[Tool]:

         return active_toolset.tools

-    def get_chat_environment(self) -> dict[str, dict | list[dict]]:
+    def get_chat_environment(self) -> dict[str, dict | list[dict] | None]:
         """
         Get the chat environment from the session data.
""" diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index 430a62f36..f3bb81d07 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -1,9 +1,8 @@ -import asyncio from typing import Annotated, Any, Literal from uuid import UUID from beartype import beartype -from temporalio import activity, workflow +from temporalio import workflow from temporalio.exceptions import ApplicationError with workflow.unsafe.imports_passed_through(): @@ -33,8 +32,6 @@ Workflow, WorkflowStep, ) - from ...common.storage_handler import load_from_blob_store_if_remote - from .remote import BaseRemoteModel, RemoteObject # TODO: Maybe we should use a library for this @@ -146,16 +143,16 @@ class ExecutionInput(BaseModel): task: TaskSpecDef agent: Agent agent_tools: list[Tool | CreateToolRequest] - arguments: dict[str, Any] | RemoteObject + arguments: dict[str, Any] # Not used at the moment user: User | None = None session: Session | None = None -class StepContext(BaseRemoteModel): - execution_input: ExecutionInput | RemoteObject - inputs: list[Any] | RemoteObject +class StepContext(BaseModel): + execution_input: ExecutionInput + inputs: list[Any] cursor: TransitionTarget @computed_field @@ -242,17 +239,9 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]: return dump | execution_input - async def prepare_for_step( - self, *args, include_remote: bool = True, **kwargs - ) -> dict[str, Any]: + async def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]: current_input = self.current_input inputs = self.inputs - if activity.in_activity() and include_remote: - await self.load_all() - inputs = await asyncio.gather( - *[load_from_blob_store_if_remote(input) for input in inputs] - ) - current_input = await load_from_blob_store_if_remote(current_input) # Merge execution inputs into the dump dict dump = self.model_dump(*args, **kwargs) diff --git a/agents-api/agents_api/common/storage_handler.py b/agents-api/agents_api/common/storage_handler.py deleted file mode 100644 index 42beef270..000000000 --- a/agents-api/agents_api/common/storage_handler.py +++ /dev/null @@ -1,226 +0,0 @@ -import asyncio -import sys -from datetime import timedelta -from functools import wraps -from typing import Any, Callable - -from pydantic import BaseModel -from temporalio import workflow - -from ..activities.sync_items_remote import load_inputs_remote -from ..clients import async_s3 -from ..common.protocol.remote import BaseRemoteModel, RemoteObject -from ..common.retry_policies import DEFAULT_RETRY_POLICY -from ..env import ( - blob_store_cutoff_kb, - debug, - temporal_heartbeat_timeout, - temporal_schedule_to_close_timeout, - testing, - use_blob_store_for_temporal, -) -from ..worker.codec import deserialize, serialize - - -async def store_in_blob_store_if_large(x: Any) -> RemoteObject | Any: - if not use_blob_store_for_temporal: - return x - - await async_s3.setup() - - serialized = serialize(x) - data_size = sys.getsizeof(serialized) - - if data_size > blob_store_cutoff_kb * 1024: - key = await async_s3.add_object_with_hash(serialized) - return RemoteObject(key=key) - - return x - - -async def load_from_blob_store_if_remote(x: Any | RemoteObject) -> Any: - if not use_blob_store_for_temporal: - return x - - await async_s3.setup() - - if isinstance(x, RemoteObject): - fetched = await async_s3.get_object(x.key) - return deserialize(fetched) - - elif isinstance(x, dict) and set(x.keys()) == {"bucket", "key"}: - 
fetched = await async_s3.get_object(x["key"]) - return deserialize(fetched) - - return x - - -# Decorator that automatically does two things: -# 1. store in blob store if the output of a function is large -# 2. load from blob store if the input is a RemoteObject - - -def auto_blob_store(f: Callable | None = None, *, deep: bool = False) -> Callable: - def auto_blob_store_decorator(f: Callable) -> Callable: - async def load_args( - args: list | tuple, kwargs: dict[str, Any] - ) -> tuple[list | tuple, dict[str, Any]]: - new_args = await asyncio.gather( - *[load_from_blob_store_if_remote(arg) for arg in args] - ) - kwargs_keys, kwargs_values = list(zip(*kwargs.items())) or ([], []) - new_kwargs = await asyncio.gather( - *[load_from_blob_store_if_remote(v) for v in kwargs_values] - ) - new_kwargs = dict(zip(kwargs_keys, new_kwargs)) - - if deep: - args = new_args - kwargs = new_kwargs - - new_args = [] - - for arg in args: - if isinstance(arg, list): - new_args.append( - await asyncio.gather( - *[load_from_blob_store_if_remote(item) for item in arg] - ) - ) - elif isinstance(arg, dict): - keys, values = list(zip(*arg.items())) or ([], []) - values = await asyncio.gather( - *[load_from_blob_store_if_remote(value) for value in values] - ) - new_args.append(dict(zip(keys, values))) - - elif isinstance(arg, BaseRemoteModel): - new_args.append(await arg.unload_all()) - - elif isinstance(arg, BaseModel): - for field in arg.model_fields.keys(): - if isinstance(getattr(arg, field), RemoteObject): - setattr( - arg, - field, - await load_from_blob_store_if_remote( - getattr(arg, field) - ), - ) - elif isinstance(getattr(arg, field), list): - setattr( - arg, - field, - await asyncio.gather( - *[ - load_from_blob_store_if_remote(item) - for item in getattr(arg, field) - ] - ), - ) - elif isinstance(getattr(arg, field), BaseRemoteModel): - setattr( - arg, - field, - await getattr(arg, field).unload_all(), - ) - - new_args.append(arg) - - else: - new_args.append(arg) - - new_kwargs = {} - - for k, v in kwargs.items(): - if isinstance(v, list): - new_kwargs[k] = await asyncio.gather( - *[load_from_blob_store_if_remote(item) for item in v] - ) - - elif isinstance(v, dict): - keys, values = list(zip(*v.items())) or ([], []) - values = await asyncio.gather( - *[load_from_blob_store_if_remote(value) for value in values] - ) - new_kwargs[k] = dict(zip(keys, values)) - - elif isinstance(v, BaseRemoteModel): - new_kwargs[k] = await v.unload_all() - - elif isinstance(v, BaseModel): - for field in v.model_fields.keys(): - if isinstance(getattr(v, field), RemoteObject): - setattr( - v, - field, - await load_from_blob_store_if_remote( - getattr(v, field) - ), - ) - elif isinstance(getattr(v, field), list): - setattr( - v, - field, - await asyncio.gather( - *[ - load_from_blob_store_if_remote(item) - for item in getattr(v, field) - ] - ), - ) - elif isinstance(getattr(v, field), BaseRemoteModel): - setattr( - v, - field, - await getattr(v, field).unload_all(), - ) - new_kwargs[k] = v - - else: - new_kwargs[k] = v - - return new_args, new_kwargs - - async def unload_return_value(x: Any | BaseRemoteModel) -> Any: - if isinstance(x, BaseRemoteModel): - await x.unload_all() - - return await store_in_blob_store_if_large(x) - - @wraps(f) - async def async_wrapper(*args, **kwargs) -> Any: - new_args, new_kwargs = await load_args(args, kwargs) - output = await f(*new_args, **new_kwargs) - - return await unload_return_value(output) - - return async_wrapper if use_blob_store_for_temporal else f - - return 
auto_blob_store_decorator(f) if f else auto_blob_store_decorator - - -def auto_blob_store_workflow(f: Callable) -> Callable: - @wraps(f) - async def wrapper(*args, **kwargs) -> Any: - keys = kwargs.keys() - values = [kwargs[k] for k in keys] - - loaded = await workflow.execute_activity( - load_inputs_remote, - args=[[*args, *values]], - schedule_to_close_timeout=timedelta( - seconds=60 if debug or testing else temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) - - loaded_args = loaded[: len(args)] - loaded_kwargs = dict(zip(keys, loaded[len(args) :])) - - result = await f(*loaded_args, **loaded_kwargs) - - return result - - return wrapper if use_blob_store_for_temporal else f diff --git a/agents-api/agents_api/env.py b/agents-api/agents_api/env.py index 48623b771..7baa24653 100644 --- a/agents-api/agents_api/env.py +++ b/agents-api/agents_api/env.py @@ -36,8 +36,8 @@ # Blob Store # ---------- -use_blob_store_for_temporal: bool = ( - env.bool("USE_BLOB_STORE_FOR_TEMPORAL", default=False) if not testing else False +use_blob_store_for_temporal: bool = testing or env.bool( + "USE_BLOB_STORE_FOR_TEMPORAL", default=False ) blob_store_bucket: str = env.str("BLOB_STORE_BUCKET", default="agents-api") @@ -66,6 +66,8 @@ default="postgres://postgres:postgres@0.0.0.0:5432/postgres?sslmode=disable", ) +query_timeout: float = env.float("QUERY_TIMEOUT", default=90.0) + # Auth # ---- diff --git a/agents-api/agents_api/queries/__init__.py b/agents-api/agents_api/queries/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/agents-api/agents_api/queries/agents/__init__.py b/agents-api/agents_api/queries/agents/__init__.py index ebd169040..c0712c47c 100644 --- a/agents-api/agents_api/queries/agents/__init__.py +++ b/agents-api/agents_api/queries/agents/__init__.py @@ -19,3 +19,13 @@ from .list_agents import list_agents from .patch_agent import patch_agent from .update_agent import update_agent + +__all__ = [ + "create_agent", + "create_or_update_agent", + "delete_agent", + "get_agent", + "list_agents", + "patch_agent", + "update_agent", +] diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 0ee250336..76c96f46b 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -3,12 +3,9 @@ It includes functions to construct and execute SQL queries for inserting new agent records. 
""" -from typing import Any, TypeVar from uuid import UUID from beartype import beartype -from fastapi import HTTPException -from pydantic import ValidationError from sqlglot import parse_one from uuid_extensions import uuid7 @@ -16,16 +13,12 @@ from ...metrics.counters import increase_counter from ..utils import ( generate_canonical_name, - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" INSERT INTO agents ( developer_id, agent_id, @@ -49,9 +42,7 @@ $9 ) RETURNING *; -""" - -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) # @rewrap_exceptions( @@ -138,4 +129,7 @@ async def create_agent( default_settings, ] - return query, params + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index e2b3fc525..ef3a0abe5 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -3,28 +3,21 @@ It constructs and executes SQL queries to insert a new agent or update an existing agent's details based on agent ID and developer ID. """ -from typing import Any, TypeVar from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest from ...metrics.counters import increase_counter from ..utils import ( generate_canonical_name, - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" INSERT INTO agents ( developer_id, agent_id, @@ -48,9 +41,7 @@ $9 ) RETURNING *; -""" - -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) # @rewrap_exceptions( @@ -113,4 +104,7 @@ async def create_or_update_agent( default_settings, ] - return (query, params) + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index 0a47bc0eb..3527f3611 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -3,28 +3,20 @@ It constructs and executes SQL queries to remove agent records and associated data. 
""" -from typing import Any, TypeVar from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow -from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" WITH deleted_docs AS ( DELETE FROM docs WHERE developer_id = $1 @@ -44,13 +36,10 @@ DELETE FROM agents WHERE agent_id = $2 AND developer_id = $1 RETURNING developer_id, agent_id; -""" - - -# Convert the list of queries into a single query string -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) +# @rewrap_exceptions( # @rewrap_exceptions( # { # psycopg_errors.ForeignKeyViolation: partialclass( @@ -66,7 +55,6 @@ one=True, transform=lambda d: {**d, "id": d["agent_id"], "deleted_at": utcnow()}, ) -@increase_counter("delete_agent") @pg_query @beartype async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]: @@ -83,4 +71,7 @@ async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list # Note: We swap the parameter order because the queries use $1 for developer_id and $2 for agent_id params = [developer_id, agent_id] - return (query, params) + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index a9893d747..a731300fa 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -3,24 +3,19 @@ It constructs and executes SQL queries to fetch agent details based on agent ID and developer ID. """ -from typing import Any, TypeVar from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import Agent -from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, ) -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" SELECT agent_id, developer_id, @@ -37,12 +32,7 @@ agents WHERE agent_id = $2 AND developer_id = $1; -""" - -query = parse_one(raw_query).sql(pretty=True) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") +""").sql(pretty=True) # @rewrap_exceptions( @@ -56,7 +46,6 @@ # # TODO: Add more exceptions # ) @wrap_in_class(Agent, one=True, transform=lambda d: {"id": d["agent_id"], **d}) -@increase_counter("get_agent") @pg_query @beartype async def get_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]: @@ -71,4 +60,7 @@ async def get_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]: tuple[list[str], dict]: A tuple containing the SQL query and its parameters. """ - return (query, [developer_id, agent_id]) + return ( + agent_query, + [developer_id, agent_id], + ) diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 3613268c5..87a0c942d 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -3,26 +3,19 @@ It constructs and executes SQL queries to fetch a list of agents based on developer ID with pagination. 
""" -from typing import Any, Literal, TypeVar +from typing import Any, Literal from uuid import UUID from beartype import beartype from fastapi import HTTPException -from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import Agent -from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - +# Define the raw SQL query raw_query = """ SELECT agent_id, @@ -58,7 +51,6 @@ # # TODO: Add more exceptions # ) @wrap_in_class(Agent, transform=lambda d: {"id": d["agent_id"], **d}) -@increase_counter("list_agents") @pg_query @beartype async def list_agents( @@ -90,7 +82,7 @@ async def list_agents( # Build metadata filter clause if needed - final_query = raw_query.format( + agent_query = raw_query.format( metadata_filter_query="AND metadata @> $6::jsonb" if metadata_filter else "" ) @@ -105,4 +97,7 @@ async def list_agents( if metadata_filter: params.append(metadata_filter) - return final_query, params + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index d2a172838..69a5a6ca5 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -3,27 +3,20 @@ It constructs and executes SQL queries to update specific fields of an agent based on agent ID and developer ID. """ -from typing import Any, TypeVar from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" UPDATE agents SET name = CASE @@ -48,9 +41,7 @@ END WHERE agent_id = $2 AND developer_id = $1 RETURNING *; -""" - -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) # @rewrap_exceptions( @@ -95,4 +86,7 @@ async def patch_agent( data.default_settings.model_dump() if data.default_settings else None, ] - return query, params + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index d03994e9c..f28e28264 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -3,27 +3,20 @@ It constructs and executes SQL queries to replace an agent's details based on agent ID and developer ID. 
""" -from typing import Any, TypeVar from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" UPDATE agents SET metadata = $3, @@ -33,9 +26,7 @@ default_settings = $7::jsonb WHERE agent_id = $2 AND developer_id = $1 RETURNING *; -""" - -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) # @rewrap_exceptions( @@ -80,4 +71,7 @@ async def update_agent( data.default_settings.model_dump() if data.default_settings else {}, ] - return (query, params) + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/developers/__init__.py b/agents-api/agents_api/queries/developers/__init__.py index b3964aba4..c3d1d4bbb 100644 --- a/agents-api/agents_api/queries/developers/__init__.py +++ b/agents-api/agents_api/queries/developers/__init__.py @@ -20,3 +20,10 @@ from .get_developer import get_developer from .patch_developer import patch_developer from .update_developer import update_developer + +__all__ = [ + "create_developer", + "get_developer", + "patch_developer", + "update_developer", +] diff --git a/agents-api/agents_api/queries/developers/create_developer.py b/agents-api/agents_api/queries/developers/create_developer.py index 7ee845fbf..bed6371c4 100644 --- a/agents-api/agents_api/queries/developers/create_developer.py +++ b/agents-api/agents_api/queries/developers/create_developer.py @@ -1,16 +1,21 @@ from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from uuid_extensions import uuid7 from ...common.protocol.developers import Developer from ..utils import ( + partialclass, pg_query, + rewrap_exceptions, wrap_in_class, ) -query = parse_one(""" +# Define the raw SQL query +developer_query = parse_one(""" INSERT INTO developers ( developer_id, email, @@ -19,22 +24,25 @@ settings ) VALUES ( - $1, - $2, - $3, - $4, - $5::jsonb + $1, -- developer_id + $2, -- email + $3, -- active + $4, -- tags + $5::jsonb -- settings ) RETURNING *; """).sql(pretty=True) -# @rewrap_exceptions( -# { -# QueryException: partialclass(HTTPException, status_code=403), -# ValidationError: partialclass(HTTPException, status_code=500), -# } -# ) +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) @wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) @pg_query @beartype @@ -49,6 +57,6 @@ async def create_developer( developer_id = str(developer_id or uuid7()) return ( - query, + developer_query, [developer_id, email, active, tags or [], settings or {}], ) diff --git a/agents-api/agents_api/queries/developers/get_developer.py b/agents-api/agents_api/queries/developers/get_developer.py index 38302ab3b..373a2fb36 100644 --- a/agents-api/agents_api/queries/developers/get_developer.py +++ b/agents-api/agents_api/queries/developers/get_developer.py @@ -3,13 +3,14 @@ from typing import Any, TypeVar from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException -from pydantic import 
ValidationError from sqlglot import parse_one from ...common.protocol.developers import Developer from ..utils import ( + partialclass, pg_query, rewrap_exceptions, wrap_in_class, @@ -18,18 +19,24 @@ # TODO: Add verify_developer verify_developer = None -query = parse_one("SELECT * FROM developers WHERE developer_id = $1").sql(pretty=True) +# Define the raw SQL query +developer_query = parse_one(""" +SELECT * FROM developers WHERE developer_id = $1 -- developer_id +""").sql(pretty=True) ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") -# @rewrap_exceptions( -# { -# QueryException: partialclass(HTTPException, status_code=403), -# ValidationError: partialclass(HTTPException, status_code=500), -# } -# ) +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) @wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) @pg_query @beartype @@ -40,6 +47,6 @@ async def get_developer( developer_id = str(developer_id) return ( - query, + developer_query, [developer_id], ) diff --git a/agents-api/agents_api/queries/developers/patch_developer.py b/agents-api/agents_api/queries/developers/patch_developer.py index 49edfe370..af2ddb1f8 100644 --- a/agents-api/agents_api/queries/developers/patch_developer.py +++ b/agents-api/agents_api/queries/developers/patch_developer.py @@ -1,28 +1,36 @@ from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from ...common.protocol.developers import Developer from ..utils import ( + partialclass, pg_query, + rewrap_exceptions, wrap_in_class, ) -query = parse_one(""" +# Define the raw SQL query +developer_query = parse_one(""" UPDATE developers -SET email = $1, active = $2, tags = tags || $3, settings = settings || $4 -WHERE developer_id = $5 +SET email = $1, active = $2, tags = tags || $3, settings = settings || $4 -- settings +WHERE developer_id = $5 -- developer_id RETURNING *; """).sql(pretty=True) -# @rewrap_exceptions( -# { -# QueryException: partialclass(HTTPException, status_code=403), -# ValidationError: partialclass(HTTPException, status_code=500), -# } -# ) +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) @wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) @pg_query @beartype @@ -37,6 +45,6 @@ async def patch_developer( developer_id = str(developer_id) return ( - query, + developer_query, [email, active, tags or [], settings or {}, developer_id], ) diff --git a/agents-api/agents_api/queries/developers/update_developer.py b/agents-api/agents_api/queries/developers/update_developer.py index 8350d45a0..d41b333d5 100644 --- a/agents-api/agents_api/queries/developers/update_developer.py +++ b/agents-api/agents_api/queries/developers/update_developer.py @@ -1,15 +1,20 @@ from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from ...common.protocol.developers import Developer from ..utils import ( + partialclass, pg_query, + rewrap_exceptions, wrap_in_class, ) -query = parse_one(""" +# Define the raw SQL query +developer_query = parse_one(""" UPDATE developers SET email = $1, active = $2, tags = $3, settings = $4 WHERE developer_id = $5 @@ -17,12 +22,15 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# 
QueryException: partialclass(HTTPException, status_code=403), -# ValidationError: partialclass(HTTPException, status_code=500), -# } -# ) +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) @wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) @pg_query @beartype @@ -37,6 +45,6 @@ async def update_developer( developer_id = str(developer_id) return ( - query, + developer_query, [email, active, tags or [], settings or {}, developer_id], ) diff --git a/agents-api/agents_api/queries/entries/__init__.py b/agents-api/agents_api/queries/entries/__init__.py new file mode 100644 index 000000000..e6db0efed --- /dev/null +++ b/agents-api/agents_api/queries/entries/__init__.py @@ -0,0 +1,21 @@ +""" +The `entry` module provides SQL query functions for managing entries +in the TimescaleDB database. This includes operations for: + +- Creating new entries +- Deleting entries +- Retrieving entry history +- Listing entries with filtering and pagination +""" + +from .create_entries import create_entries +from .delete_entries import delete_entries +from .get_history import get_history +from .list_entries import list_entries + +__all__ = [ + "create_entries", + "delete_entries", + "get_history", + "list_entries", +] diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py new file mode 100644 index 000000000..33dcda984 --- /dev/null +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -0,0 +1,183 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateEntryRequest, Entry, Relation +from ...common.utils.datetime import utcnow +from ...common.utils.messages import content_to_json +from ...metrics.counters import increase_counter +from ..utils import pg_query, wrap_in_class + +# Query for checking if the session exists +session_exists_query = """ +SELECT EXISTS ( + SELECT 1 FROM sessions + WHERE session_id = $1 AND developer_id = $2 +) AS exists; +""" + +# Define the raw SQL query for creating entries +entry_query = """ +INSERT INTO entries ( + session_id, + entry_id, + source, + role, + event_type, + name, + content, + tool_call_id, + tool_calls, + model, + token_count, + created_at, + timestamp +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) +RETURNING *; +""" + +# Define the raw SQL query for creating entry relations +entry_relation_query = """ +INSERT INTO entry_relations ( + session_id, + head, + relation, + tail, +) VALUES ($1, $2, $3, $4, $5) +RETURNING *; +""" + + +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="Entry already exists", +# ), +# asyncpg.NotNullViolationError: partialclass( +# HTTPException, +# status_code=400, +# detail="Not null violation", +# ), +# asyncpg.NoDataFoundError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# } +# ) +@wrap_in_class( + Entry, + transform=lambda d: { + "id": UUID(d.pop("entry_id")), + **d, + }, +) +@increase_counter("create_entries") +@pg_query +@beartype +async def create_entries( + *, + developer_id: UUID, + session_id: UUID, + data: 
list[CreateEntryRequest], +) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: + # Convert the data to a list of dictionaries + data_dicts = [item.model_dump(mode="json") for item in data] + + # Prepare the parameters for the query + params = [] + + for item in data_dicts: + params.append( + [ + session_id, # $1 + item.pop("id", None) or str(uuid7()), # $2 + item.get("source"), # $3 + item.get("role"), # $4 + item.get("event_type") or "message.create", # $5 + item.get("name"), # $6 + content_to_json(item.get("content") or {}), # $7 + item.get("tool_call_id"), # $8 + content_to_json(item.get("tool_calls") or {}), # $9 + item.get("model"), # $10 + item.get("token_count"), # $11 + item.get("created_at") or utcnow(), # $12 + utcnow(), # $13 + ] + ) + + return [ + ( + session_exists_query, + [session_id, developer_id], + "fetch", + ), + ( + entry_query, + params, + "fetchmany", + ), + ] + + +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="Entry already exists", +# ), +# } +# ) +@wrap_in_class(Relation) +@increase_counter("add_entry_relations") +@pg_query +@beartype +async def add_entry_relations( + *, + developer_id: UUID, + session_id: UUID, + data: list[Relation], +) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: + # Convert the data to a list of dictionaries + data_dicts = [item.model_dump(mode="json") for item in data] + + # Prepare the parameters for the query + params = [] + + for item in data_dicts: + params.append( + [ + item.get("session_id"), # $1 + item.get("head"), # $2 + item.get("relation"), # $3 + item.get("tail"), # $4 + ] + ) + + return [ + ( + session_exists_query, + [session_id, developer_id], + "fetchrow", + ), + ( + entry_relation_query, + params, + "fetchmany", + ), + ] diff --git a/agents-api/agents_api/queries/entries/delete_entries.py b/agents-api/agents_api/queries/entries/delete_entries.py new file mode 100644 index 000000000..628ef9011 --- /dev/null +++ b/agents-api/agents_api/queries/entries/delete_entries.py @@ -0,0 +1,130 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype +from sqlglot import parse_one + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ...metrics.counters import increase_counter +from ..utils import pg_query, wrap_in_class + +# Define the raw SQL query for deleting entries with a developer check +delete_entry_query = parse_one(""" +DELETE FROM entries +USING developers +WHERE entries.session_id = $1 -- session_id + AND developers.developer_id = $2 -- developer_id + +RETURNING entries.session_id as session_id; +""").sql(pretty=True) + +# Define the raw SQL query for deleting entries with a developer check +delete_entry_relations_query = parse_one(""" +DELETE FROM entry_relations +WHERE entry_relations.session_id = $1 -- session_id +""").sql(pretty=True) + +# Define the raw SQL query for deleting entries with a developer check +delete_entry_relations_by_ids_query = parse_one(""" +DELETE FROM entry_relations +WHERE entry_relations.session_id = $1 -- session_id + AND (entry_relations.head = ANY($2) -- entry_ids + OR entry_relations.tail = ANY($2)) -- entry_ids +""").sql(pretty=True) + +# Define the raw SQL query for deleting entries by entry_ids with a developer check +delete_entry_by_ids_query = parse_one(""" +DELETE FROM entries +USING 
developers +WHERE entries.entry_id = ANY($1) -- entry_ids + AND developers.developer_id = $2 -- developer_id + AND entries.session_id = $3 -- session_id + +RETURNING entries.entry_id as entry_id; +""").sql(pretty=True) + +# Add a session_exists_query similar to create_entries.py +session_exists_query = """ +SELECT EXISTS ( + SELECT 1 + FROM sessions + WHERE session_id = $1 + AND developer_id = $2 +); +""" + + +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified session or developer does not exist.", +# ), +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="The specified session has already been deleted.", +# ), +# } +# ) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@increase_counter("delete_entries_for_session") +@pg_query +@beartype +async def delete_entries_for_session( + *, + developer_id: UUID, + session_id: UUID, +) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: + """Delete all entries for a given session.""" + return [ + (session_exists_query, [session_id, developer_id], "fetch"), + (delete_entry_relations_query, [session_id], "fetchmany"), + (delete_entry_query, [session_id, developer_id], "fetchmany"), + ] + + +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified entries, session, or developer does not exist.", +# ), +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="One or more specified entries have already been deleted.", +# ), +# } +# ) +@wrap_in_class( + ResourceDeletedResponse, + transform=lambda d: { + "id": d["entry_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@increase_counter("delete_entries") +@pg_query +@beartype +async def delete_entries( + *, developer_id: UUID, session_id: UUID, entry_ids: list[UUID] +) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: + """Delete specific entries by their IDs.""" + return [ + (session_exists_query, [session_id, developer_id], "fetch"), + (delete_entry_relations_by_ids_query, [session_id, entry_ids], "fetchmany"), + (delete_entry_by_ids_query, [entry_ids, developer_id, session_id], "fetchmany"), + ] diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py new file mode 100644 index 000000000..b0b767c08 --- /dev/null +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -0,0 +1,72 @@ +from uuid import UUID + +from beartype import beartype +from sqlglot import parse_one + +from ...autogen.openapi_model import History +from ..utils import pg_query, wrap_in_class + +# Define the raw SQL query for getting history with a developer check +history_query = parse_one(""" +SELECT + e.entry_id as id, -- entry_id + e.session_id, -- session_id + e.role, -- role + e.name, -- name + e.content, -- content + e.source, -- source + e.token_count, -- token_count + e.created_at, -- created_at + e.timestamp, -- timestamp + e.tool_calls, -- tool_calls + e.tool_call_id -- tool_call_id +FROM entries e +JOIN developers d ON d.developer_id = $3 +WHERE e.session_id = $1 +AND e.source = ANY($2) +ORDER BY e.created_at; +""").sql(pretty=True) + + +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), 
+# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# } +# ) +@wrap_in_class( + History, + one=True, + transform=lambda d: { + **d, + "relations": [ + { + "head": r["head"], + "relation": r["relation"], + "tail": r["tail"], + } + for r in d.pop("relations") + ], + "entries": d.pop("entries"), + }, +) +@pg_query +@beartype +async def get_history( + *, + developer_id: UUID, + session_id: UUID, + allowed_sources: list[str] = ["api_request", "api_response"], +) -> tuple[str, list]: + return ( + history_query, + [session_id, allowed_sources, developer_id], + ) diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py new file mode 100644 index 000000000..a6c355f53 --- /dev/null +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -0,0 +1,118 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import Entry +from ...metrics.counters import increase_counter +from ..utils import pg_query, wrap_in_class + +# Query for checking if the session exists +session_exists_query = """ +SELECT CASE + WHEN EXISTS ( + SELECT 1 FROM sessions + WHERE session_id = $1 AND developer_id = $2 + ) + THEN TRUE + ELSE (SELECT NULL::boolean WHERE FALSE) -- This raises a NO_DATA_FOUND error +END; +""" + +list_entries_query = """ +SELECT + e.entry_id as id, + e.session_id, + e.role, + e.name, + e.content, + e.source, + e.token_count, + e.created_at, + e.timestamp, + e.event_type, + e.tool_call_id, + e.tool_calls, + e.model +FROM entries e +JOIN developers d ON d.developer_id = $5 +LEFT JOIN entry_relations er ON er.head = e.entry_id AND er.session_id = e.session_id +WHERE e.session_id = $1 +AND e.source = ANY($2) +AND (er.relation IS NULL OR er.relation != ALL($6)) +ORDER BY e.{sort_by} {direction} -- safe to interpolate +LIMIT $3 +OFFSET $4; +""" + + +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="Entry already exists", +# ), +# asyncpg.NotNullViolationError: partialclass( +# HTTPException, +# status_code=400, +# detail="Entry is required", +# ), +# asyncpg.NoDataFoundError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# } +# ) +@wrap_in_class(Entry) +@increase_counter("list_entries") +@pg_query +@beartype +async def list_entries( + *, + developer_id: UUID, + session_id: UUID, + allowed_sources: list[str] = ["api_request", "api_response"], + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "timestamp"] = "timestamp", + direction: Literal["asc", "desc"] = "asc", + exclude_relations: list[str] = [], +) -> list[tuple[str, list] | tuple[str, list, str]]: + if limit < 1 or limit > 1000: + raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000") + if offset < 0: + raise HTTPException(status_code=400, detail="Offset must be non-negative") + + query = list_entries_query.format( + sort_by=sort_by, + direction=direction, + ) + + # Parameters for the entry query + entry_params = [ + session_id, # $1 + allowed_sources, # $2 + limit, # $3 + offset, # $4 + developer_id, # $5 + exclude_relations, # $6 + ] + return [ + ( + session_exists_query, + [session_id, developer_id], + "fetchrow", + ), + ( + query, + 
entry_params, + ), + ] diff --git a/agents-api/agents_api/queries/sessions/__init__.py b/agents-api/agents_api/queries/sessions/__init__.py new file mode 100644 index 000000000..d0f64ea5e --- /dev/null +++ b/agents-api/agents_api/queries/sessions/__init__.py @@ -0,0 +1,30 @@ +""" +The `sessions` module within the `queries` package provides SQL query functions for managing sessions +in the PostgreSQL database. This includes operations for: + +- Creating new sessions +- Updating existing sessions +- Retrieving session details +- Listing sessions with filtering and pagination +- Deleting sessions +""" + +from .count_sessions import count_sessions +from .create_or_update_session import create_or_update_session +from .create_session import create_session +from .delete_session import delete_session +from .get_session import get_session +from .list_sessions import list_sessions +from .patch_session import patch_session +from .update_session import update_session + +__all__ = [ + "count_sessions", + "create_or_update_session", + "create_session", + "delete_session", + "get_session", + "list_sessions", + "patch_session", + "update_session", +] diff --git a/agents-api/agents_api/queries/sessions/count_sessions.py b/agents-api/agents_api/queries/sessions/count_sessions.py new file mode 100644 index 000000000..2abdf22e5 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/count_sessions.py @@ -0,0 +1,55 @@ +"""This module contains functions for querying session data from the PostgreSQL database.""" + +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query outside the function +raw_query = """ +SELECT COUNT(session_id) as count +FROM sessions +WHERE developer_id = $1; +""" + +# Parse and optimize the query +query = parse_one(raw_query).sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) +@wrap_in_class(dict, one=True) +@increase_counter("count_sessions") +@pg_query +@beartype +async def count_sessions( + *, + developer_id: UUID, +) -> tuple[str, list]: + """ + Counts sessions from the PostgreSQL database. + Uses the index on developer_id for efficient counting. + + Args: + developer_id (UUID): The developer's ID to filter sessions by. + + Returns: + tuple[str, list]: SQL query and parameters. 
+    """
+
+    return (
+        query,
+        [developer_id],
+    )
diff --git a/agents-api/agents_api/queries/sessions/create_or_update_session.py b/agents-api/agents_api/queries/sessions/create_or_update_session.py
new file mode 100644
index 000000000..3c4dbf66e
--- /dev/null
+++ b/agents-api/agents_api/queries/sessions/create_or_update_session.py
@@ -0,0 +1,154 @@
+from uuid import UUID
+
+import asyncpg
+from beartype import beartype
+from fastapi import HTTPException
+from sqlglot import parse_one
+
+from ...autogen.openapi_model import (
+    CreateOrUpdateSessionRequest,
+    ResourceUpdatedResponse,
+)
+from ...metrics.counters import increase_counter
+from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class
+
+# Define the raw SQL queries
+session_query = parse_one("""
+INSERT INTO sessions (
+    developer_id,
+    session_id,
+    situation,
+    system_template,
+    metadata,
+    render_templates,
+    token_budget,
+    context_overflow,
+    forward_tool_calls,
+    recall_options
+)
+VALUES (
+    $1,
+    $2,
+    $3,
+    $4,
+    $5,
+    $6,
+    $7,
+    $8,
+    $9,
+    $10
+)
+ON CONFLICT (developer_id, session_id) DO UPDATE SET
+    situation = EXCLUDED.situation,
+    system_template = EXCLUDED.system_template,
+    metadata = EXCLUDED.metadata,
+    render_templates = EXCLUDED.render_templates,
+    token_budget = EXCLUDED.token_budget,
+    context_overflow = EXCLUDED.context_overflow,
+    forward_tool_calls = EXCLUDED.forward_tool_calls,
+    recall_options = EXCLUDED.recall_options
+RETURNING *;
+""").sql(pretty=True)
+
+# Clear stale participant lookups once, then re-insert one row per participant.
+# (Attaching the DELETE as a CTE of the per-participant INSERT would run it on
+# every executemany iteration and wipe the rows inserted for earlier
+# participants, leaving only the last one.)
+lookup_delete_query = parse_one("""
+DELETE FROM session_lookup
+WHERE developer_id = $1 AND session_id = $2;
+""").sql(pretty=True)
+
+lookup_query = parse_one("""
+INSERT INTO session_lookup (
+    developer_id,
+    session_id,
+    participant_type,
+    participant_id
+)
+VALUES ($1, $2, $3, $4);
+""").sql(pretty=True)
+
+
+@rewrap_exceptions(
+    {
+        asyncpg.ForeignKeyViolationError: partialclass(
+            HTTPException,
+            status_code=404,
+            detail="The specified developer or participant does not exist.",
+        ),
+        asyncpg.UniqueViolationError: partialclass(
+            HTTPException,
+            status_code=409,
+            detail="A session with this ID already exists.",
+        ),
+    }
+)
+@wrap_in_class(
+    ResourceUpdatedResponse,
+    one=True,
+    transform=lambda d: {
+        "id": d["session_id"],
+        "updated_at": d["updated_at"],
+    },
+)
+@increase_counter("create_or_update_session")
+@pg_query(return_index=0)
+@beartype
+async def create_or_update_session(
+    *,
+    developer_id: UUID,
+    session_id: UUID,
+    data: CreateOrUpdateSessionRequest,
+) -> list[tuple[str, list] | tuple[str, list, str]]:
+    """
+    Constructs SQL queries to create or update a session and its participant lookups.
+
+    Args:
+        developer_id (UUID): The developer's UUID
+        session_id (UUID): The session's UUID
+        data (CreateOrUpdateSessionRequest): Session data to insert or update
+
+    Returns:
+        list[tuple[str, list] | tuple[str, list, str]]: SQL queries and their parameters
+    """
+    # Handle participants
+    users = data.users or ([data.user] if data.user else [])
+    agents = data.agents or ([data.agent] if data.agent else [])
+
+    if not agents:
+        raise HTTPException(
+            status_code=400,
+            detail="At least one agent must be provided",
+        )
+
+    if data.agent and data.agents:
+        raise HTTPException(
+            status_code=400,
+            detail="Only one of 'agent' or 'agents' should be provided",
+        )
+
+    # Prepare participant arrays for lookup query
+    participant_types = ["user"] * len(users) + ["agent"] * len(agents)
+    participant_ids = [str(u) for u in users] + [str(a) for a in agents]
+
+    # Prepare session parameters
+    session_params = [
+        developer_id,  # $1
+        session_id,  # $2
+        data.situation,  # $3
+        data.system_template,  # $4
+        data.metadata or {},  # $5
+        data.render_templates,  # $6
+        data.token_budget,  # $7
+        data.context_overflow,  # $8
+        data.forward_tool_calls,  # $9
+        data.recall_options or {},  # $10
+    ]
+
+    # Prepare lookup parameters
+    lookup_params = []
+    for participant_type, participant_id in zip(participant_types, participant_ids):
+        lookup_params.append(
+            [developer_id, session_id, participant_type, participant_id]
+        )
+
+    return [
+        (session_query, session_params, "fetch"),
+        (lookup_delete_query, [developer_id, session_id], "fetch"),
+        (lookup_query, lookup_params, "fetchmany"),
+    ]
diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py
new file mode 100644
index 000000000..058462cf8
--- /dev/null
+++ b/agents-api/agents_api/queries/sessions/create_session.py
@@ -0,0 +1,141 @@
+from uuid import UUID
+
+import asyncpg
+from beartype import beartype
+from fastapi import HTTPException
+from sqlglot import parse_one
+from uuid_extensions import uuid7
+
+from ...autogen.openapi_model import (
+    CreateSessionRequest,
+    Session,
+)
+from ...metrics.counters import increase_counter
+from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class
+
+# Define the raw SQL queries
+session_query = parse_one("""
+INSERT INTO sessions (
+    developer_id,
+    session_id,
+    situation,
+    system_template,
+    metadata,
+    render_templates,
+    token_budget,
+    context_overflow,
+    forward_tool_calls,
+    recall_options
+)
+VALUES (
+    $1, -- developer_id
+    $2, -- session_id
+    $3, -- situation
+    $4, -- system_template
+    $5, -- metadata
+    $6, -- render_templates
+    $7, -- token_budget
+    $8, -- context_overflow
+    $9, -- forward_tool_calls
+    $10 -- recall_options
+)
+RETURNING *;
+""").sql(pretty=True)
+
+lookup_query = parse_one("""
+INSERT INTO session_lookup (
+    developer_id,
+    session_id,
+    participant_type,
+    participant_id
+)
+VALUES ($1, $2, $3, $4);
+""").sql(pretty=True)
+
+
+@rewrap_exceptions(
+    {
+        asyncpg.ForeignKeyViolationError: partialclass(
+            HTTPException,
+            status_code=404,
+            detail="The specified developer or participant does not exist.",
+        ),
+        asyncpg.UniqueViolationError: partialclass(
+            HTTPException,
+            status_code=409,
+            detail="A session with this ID already exists.",
+        ),
+    }
+)
+@wrap_in_class(
+    Session,
+    one=True,
+    transform=lambda d: {
+        **d,
+        "id": d["session_id"],
+    },
+)
+@increase_counter("create_session")
+@pg_query(return_index=0)
+@beartype
+async def create_session(
+    *,
+    developer_id: UUID,
+    session_id: UUID | None = None,
+    data: CreateSessionRequest,
+) -> list[tuple[str, list] | tuple[str, 
list, str]]: + """ + Constructs SQL queries to create a new session and its participant lookups. + + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + data (CreateSessionRequest): Session creation data + + Returns: + list[tuple[str, list] | tuple[str, list, str]]: SQL queries and their parameters + """ + # Handle participants + users = data.users or ([data.user] if data.user else []) + agents = data.agents or ([data.agent] if data.agent else []) + session_id = session_id or uuid7() + + if not agents: + raise HTTPException( + status_code=400, + detail="At least one agent must be provided", + ) + + if data.agent and data.agents: + raise HTTPException( + status_code=400, + detail="Only one of 'agent' or 'agents' should be provided", + ) + + # Prepare participant arrays for lookup query + participant_types = ["user"] * len(users) + ["agent"] * len(agents) + participant_ids = [str(u) for u in users] + [str(a) for a in agents] + + # Prepare session parameters + session_params = [ + developer_id, # $1 + session_id, # $2 + data.situation, # $3 + data.system_template, # $4 + data.metadata or {}, # $5 + data.render_templates, # $6 + data.token_budget, # $7 + data.context_overflow, # $8 + data.forward_tool_calls, # $9 + data.recall_options or {}, # $10 + ] + + # Prepare lookup parameters as a list of parameter lists + lookup_params = [] + for ptype, pid in zip(participant_types, participant_ids): + lookup_params.append([developer_id, session_id, ptype, pid]) + + return [ + (session_query, session_params, "fetch"), + (lookup_query, lookup_params, "fetchmany"), + ] diff --git a/agents-api/agents_api/queries/sessions/delete_session.py b/agents-api/agents_api/queries/sessions/delete_session.py new file mode 100644 index 000000000..2e3234fe2 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/delete_session.py @@ -0,0 +1,69 @@ +"""This module contains the implementation for deleting sessions from the PostgreSQL database.""" + +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL queries +lookup_query = parse_one(""" +DELETE FROM session_lookup +WHERE developer_id = $1 AND session_id = $2; +""").sql(pretty=True) + +session_query = parse_one(""" +DELETE FROM sessions +WHERE developer_id = $1 AND session_id = $2 +RETURNING session_id; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + } +) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@increase_counter("delete_session") +@pg_query +@beartype +async def delete_session( + *, + developer_id: UUID, + session_id: UUID, +) -> list[tuple[str, list]]: + """ + Constructs SQL queries to delete a session and its participant lookups. 
+ + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID to delete + + Returns: + list[tuple[str, list]]: List of SQL queries and their parameters + """ + params = [developer_id, session_id] + + return [ + (lookup_query, params), # Delete from lookup table first due to FK constraint + (session_query, params), # Then delete from sessions table + ] diff --git a/agents-api/agents_api/queries/sessions/get_session.py b/agents-api/agents_api/queries/sessions/get_session.py new file mode 100644 index 000000000..1f704539e --- /dev/null +++ b/agents-api/agents_api/queries/sessions/get_session.py @@ -0,0 +1,83 @@ +"""This module contains functions for retrieving session data from the PostgreSQL database.""" + +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import Session +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +raw_query = """ +WITH session_participants AS ( + SELECT + sl.session_id, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'agent') as agents, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'user') as users + FROM session_lookup sl + WHERE sl.developer_id = $1 AND sl.session_id = $2 + GROUP BY sl.session_id +) +SELECT + s.session_id as id, + s.developer_id, + s.situation, + s.system_template, + s.metadata, + s.render_templates, + s.token_budget, + s.context_overflow, + s.forward_tool_calls, + s.recall_options, + s.created_at, + s.updated_at, + sp.agents, + sp.users +FROM sessions s +LEFT JOIN session_participants sp ON s.session_id = sp.session_id +WHERE s.developer_id = $1 AND s.session_id = $2; +""" + +# Parse and optimize the query +query = parse_one(raw_query).sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, status_code=404, detail="Session not found" + ), + } +) +@wrap_in_class(Session, one=True) +@increase_counter("get_session") +@pg_query +@beartype +async def get_session( + *, + developer_id: UUID, + session_id: UUID, +) -> tuple[str, list]: + """ + Constructs SQL query to retrieve a session and its participants. 
+ + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + + Returns: + tuple[str, list]: SQL query and parameters + """ + return ( + query, + [developer_id, session_id], + ) diff --git a/agents-api/agents_api/queries/sessions/list_sessions.py b/agents-api/agents_api/queries/sessions/list_sessions.py new file mode 100644 index 000000000..3aabaf32d --- /dev/null +++ b/agents-api/agents_api/queries/sessions/list_sessions.py @@ -0,0 +1,106 @@ +"""This module contains functions for querying session data from the PostgreSQL database.""" + +from typing import Any, Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import Session +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +raw_query = """ +WITH session_participants AS ( + SELECT + sl.session_id, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'agent') as agents, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'user') as users + FROM session_lookup sl + WHERE sl.developer_id = $1 + GROUP BY sl.session_id +) +SELECT + s.session_id as id, + s.developer_id, + s.situation, + s.system_template, + s.metadata, + s.render_templates, + s.token_budget, + s.context_overflow, + s.forward_tool_calls, + s.recall_options, + s.created_at, + s.updated_at, + sp.agents, + sp.users +FROM sessions s +LEFT JOIN session_participants sp ON s.session_id = sp.session_id +WHERE s.developer_id = $1 + AND ($5::jsonb IS NULL OR s.metadata @> $5::jsonb) +ORDER BY + CASE WHEN $3 = 'created_at' AND $4 = 'desc' THEN s.created_at END DESC, + CASE WHEN $3 = 'created_at' AND $4 = 'asc' THEN s.created_at END ASC, + CASE WHEN $3 = 'updated_at' AND $4 = 'desc' THEN s.updated_at END DESC, + CASE WHEN $3 = 'updated_at' AND $4 = 'asc' THEN s.updated_at END ASC +LIMIT $2 OFFSET $6; +""" + +# Parse and optimize the query +# query = parse_one(raw_query).sql(pretty=True) +query = raw_query + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, status_code=404, detail="No sessions found" + ), + } +) +@wrap_in_class(Session) +@increase_counter("list_sessions") +@pg_query +@beartype +async def list_sessions( + *, + developer_id: UUID, + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", + metadata_filter: dict[str, Any] = {}, +) -> tuple[str, list]: + """ + Lists sessions from the PostgreSQL database based on the provided filters. 
+ + Args: + developer_id (UUID): The developer's UUID + limit (int): Maximum number of sessions to return + offset (int): Number of sessions to skip + sort_by (str): Field to sort by ('created_at' or 'updated_at') + direction (str): Sort direction ('asc' or 'desc') + metadata_filter (dict): Dictionary of metadata fields to filter by + + Returns: + tuple[str, list]: SQL query and parameters + """ + return ( + query, + [ + developer_id, # $1 + limit, # $2 + sort_by, # $3 + direction, # $4 + metadata_filter or None, # $5 + offset, # $6 + ], + ) diff --git a/agents-api/agents_api/queries/sessions/patch_session.py b/agents-api/agents_api/queries/sessions/patch_session.py new file mode 100644 index 000000000..7d526ae1a --- /dev/null +++ b/agents-api/agents_api/queries/sessions/patch_session.py @@ -0,0 +1,89 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import PatchSessionRequest, ResourceUpdatedResponse +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL queries +# Build dynamic SET clause based on provided fields +session_query = parse_one(""" +WITH updated_session AS ( + UPDATE sessions + SET + situation = COALESCE($3, situation), + system_template = COALESCE($4, system_template), + metadata = sessions.metadata || $5, + render_templates = COALESCE($6, render_templates), + token_budget = COALESCE($7, token_budget), + context_overflow = COALESCE($8, context_overflow), + forward_tool_calls = COALESCE($9, forward_tool_calls), + recall_options = sessions.recall_options || $10 + WHERE + developer_id = $1 + AND session_id = $2 + RETURNING * +) +SELECT * FROM updated_session; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or participant does not exist.", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + } +) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: {"id": d["session_id"], "updated_at": d["updated_at"]}, +) +@increase_counter("patch_session") +@pg_query +@beartype +async def patch_session( + *, + developer_id: UUID, + session_id: UUID, + data: PatchSessionRequest, +) -> list[tuple[str, list]]: + """ + Constructs SQL queries to patch a session and its participant lookups. 
+ + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + data (PatchSessionRequest): Session patch data + + Returns: + list[tuple[str, list]]: List of SQL queries and their parameters + """ + + # Extract fields from data, using None for unset fields + session_params = [ + developer_id, # $1 + session_id, # $2 + data.situation, # $3 + data.system_template, # $4 + data.metadata or {}, # $5 + data.render_templates, # $6 + data.token_budget, # $7 + data.context_overflow, # $8 + data.forward_tool_calls, # $9 + data.recall_options or {}, # $10 + ] + + return [(session_query, session_params)] diff --git a/agents-api/agents_api/queries/sessions/update_session.py b/agents-api/agents_api/queries/sessions/update_session.py new file mode 100644 index 000000000..7c58d10e6 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/update_session.py @@ -0,0 +1,89 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateSessionRequest +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL queries +session_query = parse_one(""" +UPDATE sessions +SET + situation = $3, + system_template = $4, + metadata = $5, + render_templates = $6, + token_budget = $7, + context_overflow = $8, + forward_tool_calls = $9, + recall_options = $10 +WHERE + developer_id = $1 + AND session_id = $2 +RETURNING *; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or participant does not exist.", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + } +) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], + "updated_at": d["updated_at"], + }, +) +@increase_counter("update_session") +@pg_query +@beartype +async def update_session( + *, + developer_id: UUID, + session_id: UUID, + data: UpdateSessionRequest, +) -> list[tuple[str, list]]: + """ + Constructs SQL queries to update a session and its participant lookups. 
+ + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + data (UpdateSessionRequest): Session update data + + Returns: + list[tuple[str, list]]: List of SQL queries and their parameters + """ + # Prepare session parameters + session_params = [ + developer_id, # $1 + session_id, # $2 + data.situation, # $3 + data.system_template, # $4 + data.metadata or {}, # $5 + data.render_templates, # $6 + data.token_budget, # $7 + data.context_overflow, # $8 + data.forward_tool_calls, # $9 + data.recall_options or {}, # $10 + ] + + return [ + (session_query, session_params), + ] diff --git a/agents-api/agents_api/queries/users/create_or_update_user.py b/agents-api/agents_api/queries/users/create_or_update_user.py index d2be71bb4..965ae4ce4 100644 --- a/agents-api/agents_api/queries/users/create_or_update_user.py +++ b/agents-api/agents_api/queries/users/create_or_update_user.py @@ -4,14 +4,13 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import CreateOrUpdateUserRequest, User from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -# Optimize the raw query by using COALESCE for metadata to avoid explicit check -raw_query = """ +# Define the raw SQL query for creating or updating a user +user_query = parse_one(""" INSERT INTO users ( developer_id, user_id, @@ -20,21 +19,18 @@ metadata ) VALUES ( - $1, - $2, - $3, - $4, - $5 + $1, -- developer_id + $2, -- user_id + $3, -- name + $4, -- about + $5::jsonb -- metadata ) ON CONFLICT (developer_id, user_id) DO UPDATE SET name = EXCLUDED.name, about = EXCLUDED.about, metadata = EXCLUDED.metadata RETURNING *; -""" - -# Add index hint for better performance -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) @rewrap_exceptions( @@ -51,7 +47,14 @@ ), } ) -@wrap_in_class(User, one=True, transform=lambda d: {**d, "id": d["user_id"]}) +@wrap_in_class( + User, + one=True, + transform=lambda d: { + **d, + "id": d["user_id"], + }, +) @increase_counter("create_or_update_user") @pg_query @beartype @@ -73,14 +76,14 @@ async def create_or_update_user( HTTPException: If developer doesn't exist (404) or on unique constraint violation (409) """ params = [ - developer_id, - user_id, - data.name, - data.about, - data.metadata or {}, + developer_id, # $1 + user_id, # $2 + data.name, # $3 + data.about, # $4 + data.metadata or {}, # $5 ] return ( - query, + user_query, params, ) diff --git a/agents-api/agents_api/queries/users/create_user.py b/agents-api/agents_api/queries/users/create_user.py index 66e8bcc27..8f35a646c 100644 --- a/agents-api/agents_api/queries/users/create_user.py +++ b/agents-api/agents_api/queries/users/create_user.py @@ -4,7 +4,6 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from uuid_extensions import uuid7 from ...autogen.openapi_model import CreateUserRequest, User @@ -12,7 +11,7 @@ from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +user_query = parse_one(""" INSERT INTO users ( developer_id, user_id, @@ -21,17 +20,14 @@ metadata ) VALUES ( - $1, - $2, - $3, - $4, - $5 + $1, -- developer_id + $2, -- user_id + $3, -- name + $4, -- about + $5::jsonb -- metadata ) RETURNING *; -""" - -# Parse and optimize the query -query = 
parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) @rewrap_exceptions( @@ -48,7 +44,14 @@ ), } ) -@wrap_in_class(User, one=True, transform=lambda d: {**d, "id": d["user_id"]}) +@wrap_in_class( + User, + one=True, + transform=lambda d: { + **d, + "id": d["user_id"], + }, +) @increase_counter("create_user") @pg_query @beartype @@ -72,14 +75,14 @@ async def create_user( user_id = user_id or uuid7() params = [ - developer_id, - user_id, - data.name, - data.about, - data.metadata or {}, + developer_id, # $1 + user_id, # $2 + data.name, # $3 + data.about, # $4 + data.metadata or {}, # $5 ] return ( - query, + user_query, params, ) diff --git a/agents-api/agents_api/queries/users/delete_user.py b/agents-api/agents_api/queries/users/delete_user.py index 520c8d695..86bcc0b26 100644 --- a/agents-api/agents_api/queries/users/delete_user.py +++ b/agents-api/agents_api/queries/users/delete_user.py @@ -4,18 +4,17 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow -from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +delete_query = parse_one(""" WITH deleted_data AS ( - DELETE FROM user_files - WHERE developer_id = $1 AND user_id = $2 + DELETE FROM user_files -- user_files + WHERE developer_id = $1 -- developer_id + AND user_id = $2 -- user_id ), deleted_docs AS ( DELETE FROM user_docs @@ -24,10 +23,7 @@ DELETE FROM users WHERE developer_id = $1 AND user_id = $2 RETURNING user_id, developer_id; -""" - -# Parse and optimize the query -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) @rewrap_exceptions( @@ -36,15 +32,24 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", - ) + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified user does not exist.", + ), } ) @wrap_in_class( ResourceDeletedResponse, one=True, - transform=lambda d: {**d, "id": d["user_id"], "deleted_at": utcnow()}, + transform=lambda d: { + **d, + "id": d["user_id"], + "deleted_at": utcnow(), + "jobs": [], + }, ) -@increase_counter("delete_user") @pg_query @beartype async def delete_user(*, developer_id: UUID, user_id: UUID) -> tuple[str, list]: @@ -61,6 +66,6 @@ async def delete_user(*, developer_id: UUID, user_id: UUID) -> tuple[str, list]: """ return ( - query, + delete_query, [developer_id, user_id], ) diff --git a/agents-api/agents_api/queries/users/get_user.py b/agents-api/agents_api/queries/users/get_user.py index 6989c8edb..2b71f9192 100644 --- a/agents-api/agents_api/queries/users/get_user.py +++ b/agents-api/agents_api/queries/users/get_user.py @@ -4,29 +4,24 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import User -from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +user_query = parse_one(""" SELECT - user_id as id, - developer_id, - name, - about, - metadata, - created_at, - updated_at + user_id as id, -- user_id + developer_id, -- developer_id + name, -- name + about, -- about + metadata, -- metadata + created_at, -- created_at + 
updated_at -- updated_at FROM users WHERE developer_id = $1 AND user_id = $2; -""" - -# Parse and optimize the query -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) @rewrap_exceptions( @@ -35,11 +30,15 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", - ) + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified user does not exist.", + ), } ) @wrap_in_class(User, one=True) -@increase_counter("get_user") @pg_query @beartype async def get_user(*, developer_id: UUID, user_id: UUID) -> tuple[str, list]: @@ -56,6 +55,6 @@ async def get_user(*, developer_id: UUID, user_id: UUID) -> tuple[str, list]: """ return ( - query, + user_query, [developer_id, user_id], ) diff --git a/agents-api/agents_api/queries/users/list_users.py b/agents-api/agents_api/queries/users/list_users.py index 7f3677eab..0f0818135 100644 --- a/agents-api/agents_api/queries/users/list_users.py +++ b/agents-api/agents_api/queries/users/list_users.py @@ -4,24 +4,21 @@ import asyncpg from beartype import beartype from fastapi import HTTPException -from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import User -from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +user_query = """ WITH filtered_users AS ( SELECT - user_id as id, - developer_id, - name, - about, - metadata, - created_at, - updated_at + user_id as id, -- user_id + developer_id, -- developer_id + name, -- name + about, -- about + metadata, -- metadata + created_at, -- created_at + updated_at -- updated_at FROM users WHERE developer_id = $1 AND ($4::jsonb IS NULL OR metadata @> $4) @@ -37,9 +34,6 @@ OFFSET $3; """ -# Parse and optimize the query -# query = parse_one(raw_query).sql(pretty=True) - @rewrap_exceptions( { @@ -47,11 +41,15 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", - ) + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified user does not exist.", + ), } ) @wrap_in_class(User) -@increase_counter("list_users") @pg_query @beartype async def list_users( @@ -84,15 +82,15 @@ async def list_users( raise HTTPException(status_code=400, detail="Offset must be non-negative") params = [ - developer_id, - limit, - offset, + developer_id, # $1 + limit, # $2 + offset, # $3 metadata_filter, # Will be NULL if not provided - sort_by, - direction, + sort_by, # $4 + direction, # $5 ] return ( - raw_query, + user_query, params, ) diff --git a/agents-api/agents_api/queries/users/patch_user.py b/agents-api/agents_api/queries/users/patch_user.py index 971e96b81..c55ee31b7 100644 --- a/agents-api/agents_api/queries/users/patch_user.py +++ b/agents-api/agents_api/queries/users/patch_user.py @@ -4,42 +4,38 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import PatchUserRequest, ResourceUpdatedResponse from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +user_query = parse_one(""" UPDATE users SET name = CASE - WHEN $3::text IS NOT NULL THEN $3 + WHEN $3::text IS NOT NULL THEN $3 -- name ELSE name END, about = CASE - WHEN $4::text IS NOT NULL 
THEN $4 + WHEN $4::text IS NOT NULL THEN $4 -- about ELSE about END, metadata = CASE - WHEN $5::jsonb IS NOT NULL THEN metadata || $5 + WHEN $5::jsonb IS NOT NULL THEN metadata || $5 -- metadata ELSE metadata END WHERE developer_id = $1 AND user_id = $2 RETURNING - user_id as id, - developer_id, - name, - about, - metadata, - created_at, - updated_at; -""" - -# Parse and optimize the query -query = parse_one(raw_query).sql(pretty=True) + user_id as id, -- user_id + developer_id, -- developer_id + name, -- name + about, -- about + metadata, -- metadata + created_at, -- created_at + updated_at; -- updated_at +""").sql(pretty=True) @rewrap_exceptions( @@ -48,7 +44,12 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", - ) + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified user does not exist.", + ), } ) @wrap_in_class(ResourceUpdatedResponse, one=True) @@ -71,11 +72,14 @@ async def patch_user( tuple[str, list]: SQL query and parameters """ params = [ - developer_id, - user_id, - data.name, # Will be NULL if not provided - data.about, # Will be NULL if not provided - data.metadata, # Will be NULL if not provided + developer_id, # $1 + user_id, # $2 + data.name, # $3. Will be NULL if not provided + data.about, # $4. Will be NULL if not provided + data.metadata, # $5. Will be NULL if not provided ] - return query, params + return ( + user_query, + params, + ) diff --git a/agents-api/agents_api/queries/users/update_user.py b/agents-api/agents_api/queries/users/update_user.py index 1fffdebe7..91572e15d 100644 --- a/agents-api/agents_api/queries/users/update_user.py +++ b/agents-api/agents_api/queries/users/update_user.py @@ -4,26 +4,22 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateUserRequest from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +user_query = parse_one(""" UPDATE users SET - name = $3, - about = $4, - metadata = $5 -WHERE developer_id = $1 -AND user_id = $2 + name = $3, -- name + about = $4, -- about + metadata = $5 -- metadata +WHERE developer_id = $1 -- developer_id +AND user_id = $2 -- user_id RETURNING * -""" - -# Parse and optimize the query -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) @rewrap_exceptions( @@ -32,7 +28,12 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", - ) + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified user does not exist.", + ), } ) @wrap_in_class( @@ -67,6 +68,6 @@ async def update_user( ] return ( - query, + user_query, params, ) diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index 1bd72dd5b..0c20ca59e 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -5,15 +5,27 @@ import socket import time from functools import partialmethod, wraps -from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeVar, cast +from typing import ( + Any, + Awaitable, + Callable, + Literal, + NotRequired, + ParamSpec, + Type, + TypeVar, + cast, +) import asyncpg -import pandas as pd from asyncpg import Record +from beartype import beartype from fastapi import HTTPException 
 from pydantic import BaseModel
+from typing_extensions import TypedDict

 from ..app import app
+from ..env import query_timeout

 P = ParamSpec("P")
 T = TypeVar("T")
@@ -50,13 +62,72 @@ class NewCls(cls):
     return NewCls


+class AsyncPGFetchArgs(TypedDict):
+    query: str
+    args: list[Any]
+    timeout: NotRequired[float | None]
+
+
+type SQLQuery = str
+type FetchMethod = Literal["fetch", "fetchmany", "fetchrow"]
+type PGQueryArgs = tuple[SQLQuery, list[Any]] | tuple[SQLQuery, list[Any], FetchMethod]
+type PreparedPGQueryArgs = tuple[FetchMethod, AsyncPGFetchArgs]
+type BatchedPreparedPGQueryArgs = list[PreparedPGQueryArgs]
+
+
+@beartype
+def prepare_pg_query_args(
+    query_args: PGQueryArgs | list[PGQueryArgs],
+) -> BatchedPreparedPGQueryArgs:
+    batch = []
+    query_args = [query_args] if isinstance(query_args, tuple) else query_args
+
+    for query_arg in query_args:
+        match query_arg:
+            case (query, variables) | (query, variables, "fetch"):
+                batch.append(
+                    (
+                        "fetch",
+                        AsyncPGFetchArgs(
+                            query=query, args=variables, timeout=query_timeout
+                        ),
+                    )
+                )
+            case (query, variables, "fetchmany"):
+                batch.append(
+                    (
+                        "fetchmany",
+                        AsyncPGFetchArgs(
+                            query=query, args=[variables], timeout=query_timeout
+                        ),
+                    )
+                )
+            case (query, variables, "fetchrow"):
+                batch.append(
+                    (
+                        "fetchrow",
+                        AsyncPGFetchArgs(
+                            query=query, args=variables, timeout=query_timeout
+                        ),
+                    )
+                )
+            case _:
+                raise ValueError("Invalid query arguments")
+
+    return batch
+
+
+@beartype
 def pg_query(
-    func: Callable[P, tuple[str | list[str | None], dict]] | None = None,
+    func: Callable[P, PGQueryArgs | list[PGQueryArgs]] | None = None,
     debug: bool | None = None,
     only_on_error: bool = False,
     timeit: bool = False,
-):
-    def pg_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):
+    return_index: int = -1,
+) -> Callable[..., Callable[P, list[Record]]] | Callable[P, list[Record]]:
+    def pg_query_dec(
+        func: Callable[P, PGQueryArgs | list[PGQueryArgs]],
+    ) -> Callable[..., Callable[P, list[Record]]]:
         """
-        Decorator that wraps a function that takes arbitrary arguments, and
-        returns a (query string, variables) tuple.
+        Decorator that wraps a function that takes arbitrary arguments, and
+        returns a (query string, variables) tuple or a list of such tuples,
+        each optionally extended with a fetch method ("fetch", "fetchmany"
+        or "fetchrow").
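Since the new contract is only spelled out by the type aliases above, here is a minimal usage sketch of a conforming query function; the `notes` table and `list_notes` function are hypothetical illustrations, not part of this PR:

    # Hypothetical sketch only, assuming the decorator behaves as defined above.
    from uuid import UUID

    from beartype import beartype

    from agents_api.queries.utils import pg_query


    @pg_query
    @beartype
    async def list_notes(
        *, developer_id: UUID
    ) -> list[tuple[str, list] | tuple[str, list, str]]:
        return [
            # A trailing "fetchrow" makes the wrapper raise NoDataFoundError
            # when no row comes back or the "bool" column is NULL -- i.e. a
            # guard query (CASE without ELSE yields NULL when the developer
            # is missing).
            (
                "SELECT CASE WHEN EXISTS (SELECT 1 FROM developers WHERE developer_id = $1) THEN TRUE END AS bool",
                [developer_id],
                "fetchrow",
            ),
            # A plain (sql, params) tuple defaults to conn.fetch(...); with
            # return_index=-1 (the default) this last result set is returned.
            ("SELECT * FROM notes WHERE developer_id = $1", [developer_id]),
        ]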
@@ -67,46 +138,47 @@ def pg_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):

         from pprint import pprint

-        # from tenacity import (
-        #     retry,
-        #     retry_if_exception,
-        #     stop_after_attempt,
-        #     wait_exponential,
-        # )
-
-        # TODO: Remove all tenacity decorators
-        # @retry(
-        #     stop=stop_after_attempt(4),
-        #     wait=wait_exponential(multiplier=1, min=4, max=10),
-        #     # retry=retry_if_exception(is_resource_busy),
-        # )
         @wraps(func)
         async def wrapper(
             *args: P.args,
             connection_pool: asyncpg.Pool | None = None,
             **kwargs: P.kwargs,
         ) -> list[Record]:
-            query, variables = await func(*args, **kwargs)
+            query_args = await func(*args, **kwargs)
+            batch = prepare_pg_query_args(query_args)

-            not only_on_error and debug and print(query)
-            not only_on_error and debug and pprint(
-                dict(
-                    variables=variables,
-                )
-            )
+            not only_on_error and debug and pprint(batch)

             # Run the query
+            pool = (
+                connection_pool
+                if connection_pool is not None
+                else cast(asyncpg.Pool, app.state.postgres_pool)
+            )

             try:
-                pool = (
-                    connection_pool
-                    if connection_pool is not None
-                    else cast(asyncpg.Pool, app.state.postgres_pool)
-                )
                 async with pool.acquire() as conn:
                     async with conn.transaction():
                         start = timeit and time.perf_counter()
-                        results: list[Record] = await conn.fetch(query, *variables)
+                        all_results = []
+
+                        for method_name, payload in batch:
+                            method = getattr(conn, method_name)
+
+                            query = payload["query"]
+                            args = payload["args"]
+                            timeout = payload.get("timeout")
+
+                            # fetchrow returns a single Record or None, so the
+                            # missing-row check must not call len() on None.
+                            results: list[Record] | Record | None = await method(
+                                query, *args, timeout=timeout
+                            )
+                            all_results.append(results)
+
+                            if method_name == "fetchrow" and (
+                                results is None or results.get("bool") is None
+                            ):
+                                raise asyncpg.NoDataFoundError
+
                         end = timeit and time.perf_counter()

                         timeit and print(
@@ -115,8 +187,7 @@ async def wrapper(
             except Exception as e:
                 if only_on_error and debug:
-                    print(query)
-                    pprint(variables)
+                    pprint(batch)

                 debug and print(repr(e))
                 connection_error = isinstance(
@@ -132,13 +203,11 @@ async def wrapper(

                 raise

-            not only_on_error and debug and pprint(
-                dict(
-                    results=[dict(result.items()) for result in results],
-                )
-            )
+            # Return results from specified index
+            results_to_return = all_results[return_index] if all_results else []
+            not only_on_error and debug and pprint(results_to_return)

-            return results
+            return results_to_return

         # Set the wrapped function as an attribute of the wrapper,
         # forwards the __wrapped__ attribute if it exists.
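For reference, a sketch of what `prepare_pg_query_args` hands this wrapper for each accepted tuple shape (queries and values are hypothetical); the extra wrapping for "fetchmany" is what lets the wrapper splat `payload["args"]` uniformly:

    # Hypothetical sketch only.
    batch = prepare_pg_query_args([
        ("SELECT * FROM users WHERE user_id = $1", ["u1"]),
        # -> ("fetch", {"query": ..., "args": ["u1"], "timeout": query_timeout})
        ("SELECT 1 AS bool", [], "fetchrow"),
        # -> ("fetchrow", {"query": ..., "args": [], "timeout": query_timeout})
        ("INSERT INTO tags VALUES ($1)", [["a"], ["b"]], "fetchmany"),
        # -> ("fetchmany", {"query": ..., "args": [[["a"], ["b"]]], "timeout": query_timeout})
    ])
    # getattr(conn, "fetchmany")(query, *args, timeout=...) then receives the
    # row-parameter list [["a"], ["b"]] as a single positional argument, which
    # matches asyncpg's Connection.fetchmany(query, args) signature.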
@@ -156,8 +225,7 @@ def wrap_in_class( cls: Type[ModelT] | Callable[..., ModelT], one: bool = False, transform: Callable[[dict], dict] | None = None, - _kind: str | None = None, -): +) -> Callable[..., Callable[..., ModelT | list[ModelT]]]: def _return_data(rec: list[Record]): data = [dict(r.items()) for r in rec] @@ -172,7 +240,9 @@ def _return_data(rec: list[Record]): objs: list[ModelT] = [cls(**item) for item in map(transform, data)] return objs - def decorator(func: Callable[P, pd.DataFrame | Awaitable[pd.DataFrame]]): + def decorator( + func: Callable[P, list[Record] | Awaitable[list[Record]]], + ) -> Callable[P, ModelT | list[ModelT]]: @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> ModelT | list[ModelT]: return _return_data(func(*args, **kwargs)) @@ -199,7 +269,7 @@ def rewrap_exceptions( Type[BaseException] | Callable[[BaseException], BaseException], ], /, -): +) -> Callable[..., Callable[P, T | Awaitable[T]]]: def _check_error(error): nonlocal mapping @@ -219,14 +289,16 @@ def _check_error(error): raise new_error from error - def decorator(func: Callable[P, T | Awaitable[T]]): + def decorator( + func: Callable[P, T | Awaitable[T]], + ) -> Callable[..., Callable[P, T | Awaitable[T]]]: @wraps(func) async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: try: result: T = await func(*args, **kwargs) except BaseException as error: _check_error(error) - raise + raise error return result @@ -236,7 +308,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: result: T = func(*args, **kwargs) except BaseException as error: _check_error(error) - raise + raise error return result diff --git a/agents-api/agents_api/routers/healthz/check_health.py b/agents-api/agents_api/routers/healthz/check_health.py new file mode 100644 index 000000000..5a466ba39 --- /dev/null +++ b/agents-api/agents_api/routers/healthz/check_health.py @@ -0,0 +1,19 @@ +import logging +from uuid import UUID + +from ...models.agent.list_agents import list_agents as list_agents_query +from .router import router + + +@router.get("/healthz", tags=["healthz"]) +async def check_health() -> dict: + try: + # Check if the database is reachable + list_agents_query( + developer_id=UUID("00000000-0000-0000-0000-000000000000"), + ) + except Exception as e: + logging.error("An error occurred while checking health: %s", str(e)) + return {"status": "error", "message": "An internal error has occurred."} + + return {"status": "ok"} diff --git a/agents-api/agents_api/web.py b/agents-api/agents_api/web.py index b354f97bf..a04a7fc66 100644 --- a/agents-api/agents_api/web.py +++ b/agents-api/agents_api/web.py @@ -9,7 +9,7 @@ import sentry_sdk import uvicorn import uvloop -from fastapi import APIRouter, Depends, FastAPI, Request, status +from fastapi import APIRouter, FastAPI, Request, status from fastapi.exceptions import HTTPException, RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse @@ -20,7 +20,6 @@ from .app import app from .common.exceptions import BaseCommonException -from .dependencies.auth import get_api_key from .env import api_prefix, hostname, protocol, public_port, sentry_dsn from .exceptions import PromptTooBigError diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 6ea9239df..a76c13975 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -15,7 +15,7 @@ from 
...activities.excecute_api_call import execute_api_call from ...activities.execute_integration import execute_integration from ...activities.execute_system import execute_system - from ...activities.sync_items_remote import load_inputs_remote, save_inputs_remote + from ...activities.sync_items_remote import save_inputs_remote from ...autogen.openapi_model import ( ApiCallDef, BaseIntegrationDef, @@ -214,16 +214,6 @@ async def run( # 3. Then, based on the outcome and step type, decide what to do next workflow.logger.info(f"Processing outcome for step {context.cursor.step}") - [outcome] = await workflow.execute_activity( - load_inputs_remote, - args=[[outcome]], - schedule_to_close_timeout=timedelta( - seconds=60 if debug or testing else temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) - # Init state state = None diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index 1d68322f5..b2df640a7 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -19,11 +19,9 @@ ExecutionInput, StepContext, ) - from ...common.storage_handler import auto_blob_store_workflow from ...env import task_max_parallelism, temporal_heartbeat_timeout -@auto_blob_store_workflow async def continue_as_child( execution_input: ExecutionInput, start: TransitionTarget, @@ -50,7 +48,6 @@ async def continue_as_child( ) -@auto_blob_store_workflow async def execute_switch_branch( *, context: StepContext, @@ -84,7 +81,6 @@ async def execute_switch_branch( ) -@auto_blob_store_workflow async def execute_if_else_branch( *, context: StepContext, @@ -123,7 +119,6 @@ async def execute_if_else_branch( ) -@auto_blob_store_workflow async def execute_foreach_step( *, context: StepContext, @@ -161,7 +156,6 @@ async def execute_foreach_step( return results -@auto_blob_store_workflow async def execute_map_reduce_step( *, context: StepContext, @@ -209,7 +203,6 @@ async def execute_map_reduce_step( return result -@auto_blob_store_workflow async def execute_map_reduce_step_parallel( *, context: StepContext, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index fa00c98e3..e1d286c9c 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -1,24 +1,14 @@ -import json import random import string -import time from uuid import UUID -import asyncpg from fastapi.testclient import TestClient -from temporalio.client import WorkflowHandle from uuid_extensions import uuid7 from ward import fixture from agents_api.autogen.openapi_model import ( CreateAgentRequest, - CreateDocRequest, - CreateExecutionRequest, - CreateFileRequest, CreateSessionRequest, - CreateTaskRequest, - CreateToolRequest, - CreateTransitionRequest, CreateUserRequest, ) from agents_api.clients.pg import create_db_pool @@ -36,15 +26,13 @@ # from agents_api.queries.execution.create_temporal_lookup import create_temporal_lookup # from agents_api.queries.files.create_file import create_file # from agents_api.queries.files.delete_file import delete_file -# from agents_api.queries.session.create_session import create_session -# from agents_api.queries.session.delete_session import delete_session +from agents_api.queries.sessions.create_session import create_session + # from agents_api.queries.task.create_task import create_task # from agents_api.queries.task.delete_task import delete_task # 
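Circling back to the new check_health route introduced earlier in this diff: a minimal smoke check, assuming the healthz router is mounted on agents_api.web:app the same way the other routers are.

from fastapi.testclient import TestClient

from agents_api.web import app  # assumes the healthz router is mounted here

def smoke_test_healthz() -> None:
    client = TestClient(app)
    response = client.get("/healthz")
    # check_health returns {"status": "ok"} when list_agents succeeds for the
    # all-zero developer UUID, and a generic {"status": "error", ...} payload
    # when the underlying exception is caught and logged. Either way the
    # route itself answers 200, so callers must inspect the body.
    assert response.status_code == 200
    print(response.json())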
from agents_api.queries.tools.create_tools import create_tools # from agents_api.queries.tools.delete_tool import delete_tool from agents_api.queries.users.create_user import create_user - -# from agents_api.queries.users.delete_user import delete_user from agents_api.web import app from .utils import ( @@ -67,11 +55,10 @@ def pg_dsn(): @fixture(scope="global") def test_developer_id(): if not multi_tenant_mode: - yield UUID(int=0) - return + return UUID(int=0) developer_id = uuid7() - yield developer_id + return developer_id # @fixture(scope="global") @@ -98,8 +85,7 @@ async def test_developer(dsn=pg_dsn, developer_id=test_developer_id): connection_pool=pool, ) - yield developer - await pool.close() + return developer @fixture(scope="test") @@ -109,7 +95,7 @@ def patch_embed_acompletion(): yield embed, acompletion -@fixture(scope="global") +@fixture(scope="test") async def test_agent(dsn=pg_dsn, developer=test_developer): pool = await create_db_pool(dsn=dsn) @@ -118,18 +104,16 @@ async def test_agent(dsn=pg_dsn, developer=test_developer): data=CreateAgentRequest( model="gpt-4o-mini", name="test agent", - canonical_name=f"test_agent_{str(int(time.time()))}", about="test agent about", metadata={"test": "test"}, ), connection_pool=pool, ) - yield agent - await pool.close() + return agent -@fixture(scope="global") +@fixture(scope="test") async def test_user(dsn=pg_dsn, developer=test_developer): pool = await create_db_pool(dsn=dsn) @@ -142,8 +126,7 @@ async def test_user(dsn=pg_dsn, developer=test_developer): connection_pool=pool, ) - yield user - await pool.close() + return user @fixture(scope="test") @@ -167,22 +150,27 @@ async def test_new_developer(dsn=pg_dsn, email=random_email): return developer -# @fixture(scope="global") -# async def test_session( -# dsn=pg_dsn, -# developer_id=test_developer_id, -# test_user=test_user, -# test_agent=test_agent, -# ): -# async with get_pg_client(dsn=dsn) as client: -# session = await create_session( -# developer_id=developer_id, -# data=CreateSessionRequest( -# agent=test_agent.id, user=test_user.id, metadata={"test": "test"} -# ), -# client=client, -# ) -# yield session +@fixture(scope="test") +async def test_session( + dsn=pg_dsn, + developer_id=test_developer_id, + test_user=test_user, + test_agent=test_agent, +): + pool = await create_db_pool(dsn=dsn) + + session = await create_session( + developer_id=developer_id, + data=CreateSessionRequest( + agent=test_agent.id, + user=test_user.id, + metadata={"test": "test"}, + system_template="test system template", + ), + connection_pool=pool, + ) + + return session # @fixture(scope="global") @@ -349,38 +337,49 @@ async def test_new_developer(dsn=pg_dsn, email=random_email): # "type": "function", # } -# async with get_pg_client(dsn=dsn) as client: -# [tool, *_] = await create_tools( +# [tool, *_] = await create_tools( +# developer_id=developer_id, +# agent_id=agent.id, +# data=[CreateToolRequest(**tool)], +# connection_pool=pool, +# ) +# yield tool + +# # Cleanup +# try: +# await delete_tool( # developer_id=developer_id, -# agent_id=agent.id, -# data=[CreateToolRequest(**tool)], -# client=client, +# tool_id=tool.id, +# connection_pool=pool, # ) -# yield tool +# finally: +# await pool.close() -# @fixture(scope="global") -# def client(dsn=pg_dsn): -# client = TestClient(app=app) -# client.state.pg_client = get_pg_client(dsn=dsn) -# return client +@fixture(scope="global") +async def client(dsn=pg_dsn): + pool = await create_db_pool(dsn=dsn) + client = TestClient(app=app) + client.state.postgres_pool = pool + 
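The fixture edits above all follow one pattern: yield plus explicit pool.close() teardown becomes a plain return (ward then runs no teardown phase at all), and test_agent/test_user narrow from global to test scope so every test sees fresh rows. Reduced to a self-contained sketch with illustrative names:

from ward import fixture, test

@fixture(scope="test")
async def per_test_row():
    # Plain `return` instead of `yield`: nothing runs after the test, so the
    # fixture must not own resources needing cleanup (the shared pool is
    # owned by the global `client` fixture instead).
    return {"name": "fresh row for exactly one test"}

@test("ward injects fixtures via default arguments")
async def _(row=per_test_row):
    assert row["name"].startswith("fresh")

One consequence worth noting: the postgres pool now appears to have a single owner, the global client fixture, instead of every fixture opening and closing its own.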
return client -# @fixture(scope="global") -# def make_request(client=client, developer_id=test_developer_id): -# def _make_request(method, url, **kwargs): -# headers = kwargs.pop("headers", {}) -# headers = { -# **headers, -# api_key_header_name: api_key, -# } -# if multi_tenant_mode: -# headers["X-Developer-Id"] = str(developer_id) +@fixture(scope="global") +async def make_request(client=client, developer_id=test_developer_id): + def _make_request(method, url, **kwargs): + headers = kwargs.pop("headers", {}) + headers = { + **headers, + api_key_header_name: api_key, + } + + if multi_tenant_mode: + headers["X-Developer-Id"] = str(developer_id) -# return client.request(method, url, headers=headers, **kwargs) + return client.request(method, url, headers=headers, **kwargs) -# return _make_request + return _make_request @fixture(scope="global") diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index 56a07ed03..85d10f6ea 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -1,7 +1,5 @@ # Tests for agent queries -from uuid import UUID -import asyncpg from uuid_extensions import uuid7 from ward import raises, test @@ -43,7 +41,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): ) -@test("query: create agent with instructions sql") +@test("query: create or update agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that an agent can be successfully created or updated.""" diff --git a/agents-api/tests/test_developer_queries.py b/agents-api/tests/test_developer_queries.py index d360a7dc2..eedc07dd2 100644 --- a/agents-api/tests/test_developer_queries.py +++ b/agents-api/tests/test_developer_queries.py @@ -4,7 +4,6 @@ from ward import raises, test from agents_api.clients.pg import create_db_pool -from agents_api.common.protocol.developers import Developer from agents_api.queries.developers.create_developer import create_developer from agents_api.queries.developers.get_developer import ( get_developer, diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 220b8d232..f5b9d8d56 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -1,89 +1,62 @@ -# """ -# This module contains tests for entry queries against the CozoDB database. -# It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. -# """ +""" +This module contains tests for entry queries against the PostgreSQL database. +It verifies the functionality of adding, retrieving, and processing entries as defined in the schema.
+""" -# # Tests for entry queries +from fastapi import HTTPException +from uuid_extensions import uuid7 +from ward import raises, test -# import time +from agents_api.autogen.openapi_model import CreateEntryRequest +from agents_api.clients.pg import create_db_pool +from agents_api.queries.entries import create_entries, list_entries +from tests.fixtures import pg_dsn, test_developer # , test_session -# from ward import test +MODEL = "gpt-4o-mini" -# from agents_api.autogen.openapi_model import CreateEntryRequest -# from agents_api.queries.entry.create_entries import create_entries -# from agents_api.queries.entry.delete_entries import delete_entries -# from agents_api.queries.entry.get_history import get_history -# from agents_api.queries.entry.list_entries import list_entries -# from agents_api.queries.session.get_session import get_session -# from tests.fixtures import cozo_client, test_developer_id, test_session -# MODEL = "gpt-4o-mini" +@test("query: create entry no session") +async def _(dsn=pg_dsn, developer=test_developer): + """Test the addition of a new entry to the database.""" + pool = await create_db_pool(dsn=dsn) + test_entry = CreateEntryRequest.from_model_input( + model=MODEL, + role="user", + source="internal", + content="test entry content", + ) -# @test("query: create entry") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# """ -# Tests the addition of a new entry to the database. -# Verifies that the entry can be successfully added using the create_entries function. -# """ + with raises(HTTPException) as exc_info: + await create_entries( + developer_id=developer.id, + session_id=uuid7(), + data=[test_entry], + connection_pool=pool, + ) + assert exc_info.raised.status_code == 404 -# test_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# source="internal", -# content="test entry content", -# ) - -# create_entries( -# developer_id=developer_id, -# session_id=session.id, -# data=[test_entry], -# mark_session_as_updated=False, -# client=client, -# ) - - -# @test("query: create entry, update session") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# """ -# Tests the addition of a new entry to the database. -# Verifies that the entry can be successfully added using the create_entries function. -# """ - -# test_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# source="internal", -# content="test entry content", -# ) - -# # TODO: We should make sessions.updated_at also a updated_at_ms field to avoid this sleep -# time.sleep(1) -# create_entries( -# developer_id=developer_id, -# session_id=session.id, -# data=[test_entry], -# mark_session_as_updated=True, -# client=client, -# ) +@test("query: list entries no session") +async def _(dsn=pg_dsn, developer=test_developer): + """Test the retrieval of entries from the database.""" -# updated_session = get_session( -# developer_id=developer_id, -# session_id=session.id, -# client=client, -# ) + pool = await create_db_pool(dsn=dsn) -# assert updated_session.updated_at > session.updated_at + with raises(HTTPException) as exc_info: + await list_entries( + developer_id=developer.id, + session_id=uuid7(), + connection_pool=pool, + ) + assert exc_info.raised.status_code == 404 # @test("query: get entries") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# """ -# Tests the retrieval of entries from the database. 
-# Verifies that entries matching specific criteria can be successfully retrieved. -# """ +# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session +# """Test the retrieval of entries from the database.""" +# pool = await create_db_pool(dsn=dsn) # test_entry = CreateEntryRequest.from_model_input( # model=MODEL, # role="user", @@ -98,30 +71,31 @@ # source="internal", # ) -# create_entries( -# developer_id=developer_id, -# session_id=session.id, +# await create_entries( +# developer_id=TEST_DEVELOPER_ID, +# session_id=SESSION_ID, # data=[test_entry, internal_entry], -# client=client, +# connection_pool=pool, # ) -# result = list_entries( -# developer_id=developer_id, -# session_id=session.id, -# client=client, +# result = await list_entries( +# developer_id=TEST_DEVELOPER_ID, +# session_id=SESSION_ID, +# connection_pool=pool, # ) -# # Asserts that only one entry is retrieved, matching the session_id. + +# # Assert that only one entry is retrieved, matching the session_id. # assert len(result) == 1 +# assert isinstance(result[0], Entry) +# assert result is not None # @test("query: get history") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# """ -# Tests the retrieval of entries from the database. -# Verifies that entries matching specific criteria can be successfully retrieved. -# """ +# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session +# """Test the retrieval of entry history from the database.""" +# pool = await create_db_pool(dsn=dsn) # test_entry = CreateEntryRequest.from_model_input( # model=MODEL, # role="user", @@ -136,31 +110,31 @@ # source="internal", # ) -# create_entries( +# await create_entries( # developer_id=developer_id, -# session_id=session.id, +# session_id=SESSION_ID, # data=[test_entry, internal_entry], -# client=client, +# connection_pool=pool, # ) -# result = get_history( +# result = await get_history( # developer_id=developer_id, -# session_id=session.id, -# client=client, +# session_id=SESSION_ID, +# connection_pool=pool, # ) -# # Asserts that only one entry is retrieved, matching the session_id. +# # Assert that entries are retrieved and have valid IDs. +# assert result is not None +# assert isinstance(result, History) # assert len(result.entries) > 0 # assert result.entries[0].id # @test("query: delete entries") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# """ -# Tests the deletion of entries from the database. -# Verifies that entries can be successfully deleted using the delete_entries function. 
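If the test_session fixture settles, the commented get history test above could be revived along these lines. The agents_api.queries.entries home for get_history and its History return model are assumptions here, extrapolated from how create_entries and list_entries were migrated:

from agents_api.autogen.openapi_model import CreateEntryRequest
from agents_api.clients.pg import create_db_pool
from agents_api.queries.entries import create_entries  # get_history assumed to live here too

async def history_roundtrip(dsn, developer_id, session_id):
    pool = await create_db_pool(dsn=dsn)
    entry = CreateEntryRequest.from_model_input(
        model="gpt-4o-mini",
        role="user",
        source="internal",
        content="test entry content",
    )
    await create_entries(
        developer_id=developer_id,
        session_id=session_id,
        data=[entry],
        connection_pool=pool,
    )
    # history = await get_history(developer_id=developer_id,
    #                             session_id=session_id,
    #                             connection_pool=pool)
    # assert history.entries and history.entries[0].id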
-# """ +# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session +# """Test the deletion of entries from the database.""" +# pool = await create_db_pool(dsn=dsn) # test_entry = CreateEntryRequest.from_model_input( # model=MODEL, # role="user", @@ -175,27 +149,29 @@ # source="internal", # ) -# created_entries = create_entries( +# created_entries = await create_entries( # developer_id=developer_id, -# session_id=session.id, +# session_id=SESSION_ID, # data=[test_entry, internal_entry], -# client=client, +# connection_pool=pool, # ) -# entry_ids = [entry.id for entry in created_entries] +# entry_ids = [entry.id for entry in created_entries] -# delete_entries( -# developer_id=developer_id, -# session_id=session.id, -# entry_ids=entry_ids, -# client=client, -# ) +# await delete_entries( +# developer_id=developer_id, +# session_id=SESSION_ID, +# entry_ids=[UUID("123e4567-e89b-12d3-a456-426614174002")], +# connection_pool=pool, +# ) -# result = list_entries( -# developer_id=developer_id, -# session_id=session.id, -# client=client, -# ) +# result = await list_entries( +# developer_id=developer_id, +# session_id=SESSION_ID, +# connection_pool=pool, +# ) -# # Asserts that no entries are retrieved after deletion. -# assert all(id not in [entry.id for entry in result] for id in entry_ids) +# Assert that no entries are retrieved after deletion. +# assert all(id not in [entry.id for entry in result] for id in entry_ids) +# assert len(result) == 0 +# assert result is not None diff --git a/agents-api/tests/test_messages_truncation.py b/agents-api/tests/test_messages_truncation.py index 39cc02c2c..1a6c344e6 100644 --- a/agents-api/tests/test_messages_truncation.py +++ b/agents-api/tests/test_messages_truncation.py @@ -1,4 +1,3 @@ -# from uuid import uuid4 # from uuid_extensions import uuid7 # from ward import raises, test diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index e8ec40367..7926a391f 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -1,160 +1,253 @@ -# # Tests for session queries - -# from uuid_extensions import uuid7 -# from ward import test - -# from agents_api.autogen.openapi_model import ( -# CreateOrUpdateSessionRequest, -# CreateSessionRequest, -# Session, -# ) -# from agents_api.queries.session.count_sessions import count_sessions -# from agents_api.queries.session.create_or_update_session import create_or_update_session -# from agents_api.queries.session.create_session import create_session -# from agents_api.queries.session.delete_session import delete_session -# from agents_api.queries.session.get_session import get_session -# from agents_api.queries.session.list_sessions import list_sessions -# from tests.fixtures import ( -# cozo_client, -# test_agent, -# test_developer_id, -# test_session, -# test_user, -# ) - -# MODEL = "gpt-4o-mini" - - -# @test("query: create session") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, user=test_user -# ): -# create_session( -# developer_id=developer_id, -# data=CreateSessionRequest( -# users=[user.id], -# agents=[agent.id], -# situation="test session about", -# ), -# client=client, -# ) - - -# @test("query: create session no user") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# create_session( -# developer_id=developer_id, -# data=CreateSessionRequest( -# agents=[agent.id], -# situation="test session about", -# ), -# client=client, -# ) - - -# 
@test("query: get session not exists") -# def _(client=cozo_client, developer_id=test_developer_id): -# session_id = uuid7() - -# try: -# get_session( -# session_id=session_id, -# developer_id=developer_id, -# client=client, -# ) -# except Exception: -# pass -# else: -# assert False, "Session should not exist" - - -# @test("query: get session exists") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# result = get_session( -# session_id=session.id, -# developer_id=developer_id, -# client=client, -# ) - -# assert result is not None -# assert isinstance(result, Session) - - -# @test("query: delete session") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# session = create_session( -# developer_id=developer_id, -# data=CreateSessionRequest( -# agent=agent.id, -# situation="test session about", -# ), -# client=client, -# ) - -# delete_session( -# session_id=session.id, -# developer_id=developer_id, -# client=client, -# ) - -# try: -# get_session( -# session_id=session.id, -# developer_id=developer_id, -# client=client, -# ) -# except Exception: -# pass - -# else: -# assert False, "Session should not exist" - - -# @test("query: list sessions") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# result = list_sessions( -# developer_id=developer_id, -# client=client, -# ) - -# assert isinstance(result, list) -# assert len(result) > 0 - - -# @test("query: count sessions") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# result = count_sessions( -# developer_id=developer_id, -# client=client, -# ) - -# assert isinstance(result, dict) -# assert result["count"] > 0 - - -# @test("query: create or update session") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, user=test_user -# ): -# session_id = uuid7() - -# create_or_update_session( -# session_id=session_id, -# developer_id=developer_id, -# data=CreateOrUpdateSessionRequest( -# users=[user.id], -# agents=[agent.id], -# situation="test session about", -# ), -# client=client, -# ) - -# result = get_session( -# session_id=session_id, -# developer_id=developer_id, -# client=client, -# ) - -# assert result is not None -# assert isinstance(result, Session) -# assert result.id == session_id +""" +This module contains tests for SQL query generation functions in the sessions module. +Tests verify the SQL queries without actually executing them against a database. 
+""" + +from uuid_extensions import uuid7 +from ward import raises, test + +from agents_api.autogen.openapi_model import ( + CreateOrUpdateSessionRequest, + CreateSessionRequest, + PatchSessionRequest, + ResourceDeletedResponse, + ResourceUpdatedResponse, + Session, + UpdateSessionRequest, +) +from agents_api.clients.pg import create_db_pool +from agents_api.queries.sessions import ( + count_sessions, + create_or_update_session, + create_session, + delete_session, + get_session, + list_sessions, + patch_session, + update_session, +) +from tests.fixtures import ( + pg_dsn, + test_agent, + test_developer_id, + test_session, + test_user, +) + + +@test("query: create session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user +): + """Test that a session can be successfully created.""" + + pool = await create_db_pool(dsn=dsn) + session_id = uuid7() + data = CreateSessionRequest( + users=[user.id], + agents=[agent.id], + system_template="test system template", + ) + result = await create_session( + developer_id=developer_id, + session_id=session_id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, Session), f"Result is not a Session, {result}" + assert result.id == session_id + + +@test("query: create or update session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user +): + """Test that a session can be successfully created or updated.""" + + pool = await create_db_pool(dsn=dsn) + session_id = uuid7() + data = CreateOrUpdateSessionRequest( + users=[user.id], + agents=[agent.id], + system_template="test system template", + ) + result = await create_or_update_session( + developer_id=developer_id, + session_id=session_id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + assert result.id == session_id + assert result.updated_at is not None + + +@test("query: get session exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test retrieving an existing session.""" + + pool = await create_db_pool(dsn=dsn) + result = await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, Session) + assert result.id == session.id + + +@test("query: get session does not exist") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test retrieving a non-existent session.""" + + session_id = uuid7() + pool = await create_db_pool(dsn=dsn) + with raises(Exception): + await get_session( + session_id=session_id, + developer_id=developer_id, + connection_pool=pool, + ) + + +@test("query: list sessions") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test listing sessions with default pagination.""" + + pool = await create_db_pool(dsn=dsn) + result = await list_sessions( + developer_id=developer_id, + limit=10, + offset=0, + connection_pool=pool, + ) + + assert isinstance(result, list) + assert len(result) >= 1 + assert any(s.id == session.id for s in result) + + +@test("query: list sessions with filters") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test listing sessions with specific filters.""" + + pool = await create_db_pool(dsn=dsn) + result = await list_sessions( + developer_id=developer_id, + limit=10, + offset=0, + connection_pool=pool, + ) + + assert isinstance(result, list) 
+ assert len(result) >= 1 + assert all( + s.situation == session.situation for s in result + ), f"Result is not a list of sessions, {result}, {session.situation}" + + +@test("query: count sessions") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test counting the number of sessions for a developer.""" + + pool = await create_db_pool(dsn=dsn) + count = await count_sessions( + developer_id=developer_id, + connection_pool=pool, + ) + + assert isinstance(count, dict) + assert count["count"] >= 1 + + +@test("query: update session sql") +async def _( + dsn=pg_dsn, + developer_id=test_developer_id, + session=test_session, + agent=test_agent, + user=test_user, +): + """Test that an existing session's information can be successfully updated.""" + + pool = await create_db_pool(dsn=dsn) + data = UpdateSessionRequest( + token_budget=1000, + forward_tool_calls=True, + system_template="updated system template", + ) + result = await update_session( + session_id=session.id, + developer_id=developer_id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + assert result.updated_at > session.created_at + + updated_session = await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + assert updated_session.forward_tool_calls is True + + +@test("query: patch session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent +): + """Test that a session can be successfully patched.""" + + pool = await create_db_pool(dsn=dsn) + data = PatchSessionRequest( + metadata={"test": "metadata"}, + ) + result = await patch_session( + developer_id=developer_id, + session_id=session.id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + assert result.updated_at > session.created_at + + patched_session = await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + assert patched_session.situation == session.situation + assert patched_session.metadata == {"test": "metadata"} + + +@test("query: delete session sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test that a session can be successfully deleted.""" + + pool = await create_db_pool(dsn=dsn) + delete_result = await delete_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + + assert delete_result is not None + assert isinstance(delete_result, ResourceDeletedResponse) + + with raises(Exception): + await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) diff --git a/agents-api/tests/test_user_queries.py b/agents-api/tests/test_user_queries.py index cbe7e0353..002532816 100644 --- a/agents-api/tests/test_user_queries.py +++ b/agents-api/tests/test_user_queries.py @@ -5,7 +5,6 @@ from uuid import UUID -import asyncpg from uuid_extensions import uuid7 from ward import raises, test diff --git a/agents-api/tests/utils.py b/agents-api/tests/utils.py index 990a1015e..a4f98ac80 100644 --- a/agents-api/tests/utils.py +++ b/agents-api/tests/utils.py @@ -1,5 +1,4 @@ import asyncio -import json import logging import subprocess from contextlib import asynccontextmanager, contextmanager @@ -7,7 +6,6 @@ from typing import Any, Dict, Optional from unittest.mock import patch -import asyncpg from botocore import exceptions from fastapi.testclient import 
TestClient from litellm.types.utils import ModelResponse diff --git a/integrations-service/integrations/autogen/Entries.py b/integrations-service/integrations/autogen/Entries.py index de37e77d8..867b10192 100644 --- a/integrations-service/integrations/autogen/Entries.py +++ b/integrations-service/integrations/autogen/Entries.py @@ -52,6 +52,7 @@ class BaseEntry(BaseModel): ] tokenizer: str token_count: int + model: str = "gpt-4o-mini" tool_calls: ( list[ ChosenFunctionCall diff --git a/integrations-service/integrations/autogen/Sessions.py b/integrations-service/integrations/autogen/Sessions.py index 460fd25ce..e2a9ce164 100644 --- a/integrations-service/integrations/autogen/Sessions.py +++ b/integrations-service/integrations/autogen/Sessions.py @@ -31,6 +31,10 @@ class CreateSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -51,6 +55,10 @@ class CreateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -67,6 +75,10 @@ class PatchSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -87,6 +99,10 @@ class PatchSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptionsUpdate | None = None metadata: dict[str, Any] | None = None @@ -121,6 +137,10 @@ class Session(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ summary: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None """ Summary (null at the beginning) - generated automatically after every interaction @@ -145,6 +165,10 @@ class Session(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] metadata: dict[str, Any] | None = None @@ -197,6 +221,10 @@ class UpdateSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -217,6 +245,10 @@ class UpdateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. 
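Taken together, the autogen changes in this file add the same two knobs to every session payload shape. A small construction sketch, assuming these generated Pydantic v2 models import as laid out above and that the remaining fields validate from their defaults; the field values are illustrative:

from integrations.autogen.Sessions import CreateSessionRequest

req = CreateSessionRequest(
    situation="customer support chat",
    system_template="You are a terse assistant.",  # new, nullable, default None
    forward_tool_calls=True,                       # new flag, default False
)
print(req.model_dump(exclude_none=True))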
If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -240,6 +272,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -260,6 +296,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None diff --git a/memory-store/migrations/000009_sessions.up.sql b/memory-store/migrations/000009_sessions.up.sql index 082f3823c..75b5fde9a 100644 --- a/memory-store/migrations/000009_sessions.up.sql +++ b/memory-store/migrations/000009_sessions.up.sql @@ -7,8 +7,7 @@ CREATE TABLE IF NOT EXISTS sessions ( situation TEXT, system_template TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - -- NOTE: Derived from entries - -- updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, metadata JSONB NOT NULL DEFAULT '{}'::JSONB, render_templates BOOLEAN NOT NULL DEFAULT TRUE, token_budget INTEGER, diff --git a/memory-store/migrations/000015_entries.up.sql b/memory-store/migrations/000015_entries.up.sql index 9985e4c41..c104091a2 100644 --- a/memory-store/migrations/000015_entries.up.sql +++ b/memory-store/migrations/000015_entries.up.sql @@ -1,7 +1,7 @@ BEGIN; -- Create chat_role enum -CREATE TYPE chat_role AS ENUM('user', 'assistant', 'tool', 'system'); +CREATE TYPE chat_role AS ENUM('user', 'assistant', 'tool', 'system', 'developer'); -- Create entries table CREATE TABLE IF NOT EXISTS entries ( @@ -85,4 +85,20 @@ OR UPDATE ON entries FOR EACH ROW EXECUTE FUNCTION optimized_update_token_count_after (); -COMMIT; \ No newline at end of file +-- Add trigger to update parent session's updated_at +CREATE OR REPLACE FUNCTION update_session_updated_at() +RETURNS TRIGGER AS $$ +BEGIN + UPDATE sessions + SET updated_at = CURRENT_TIMESTAMP + WHERE session_id = NEW.session_id; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER trg_update_session_updated_at +AFTER INSERT OR UPDATE ON entries +FOR EACH ROW +EXECUTE FUNCTION update_session_updated_at(); + +COMMIT; diff --git a/memory-store/migrations/000016_entry_relations.up.sql b/memory-store/migrations/000016_entry_relations.up.sql index c61c7cd24..bcdb7fb72 100644 --- a/memory-store/migrations/000016_entry_relations.up.sql +++ b/memory-store/migrations/000016_entry_relations.up.sql @@ -31,25 +31,29 @@ CREATE INDEX idx_entry_relations_components ON entry_relations (session_id, head CREATE INDEX idx_entry_relations_leaf ON entry_relations (session_id, relation, is_leaf); -CREATE -OR REPLACE FUNCTION enforce_leaf_nodes () RETURNS TRIGGER AS $$ +CREATE OR REPLACE FUNCTION auto_update_leaf_status() RETURNS TRIGGER AS $$ BEGIN - IF NEW.is_leaf THEN - -- Ensure no other relations point to this leaf node as a head - IF EXISTS ( - SELECT 1 FROM entry_relations - WHERE tail 
= NEW.head AND session_id = NEW.session_id - ) THEN - RAISE EXCEPTION 'Cannot assign relations to a leaf node.'; - END IF; - END IF; + -- Set is_leaf = false for any existing rows that will now have this new relation as a child + UPDATE entry_relations + SET is_leaf = false + WHERE session_id = NEW.session_id + AND tail = NEW.head; + + -- Set is_leaf for the new row based on whether it has any children + NEW.is_leaf := NOT EXISTS ( + SELECT 1 + FROM entry_relations + WHERE session_id = NEW.session_id + AND head = NEW.tail + ); + RETURN NEW; END; $$ LANGUAGE plpgsql; -CREATE TRIGGER trg_enforce_leaf_nodes BEFORE INSERT -OR -UPDATE ON entry_relations FOR EACH ROW -EXECUTE FUNCTION enforce_leaf_nodes (); +CREATE TRIGGER trg_auto_update_leaf_status +BEFORE INSERT OR UPDATE ON entry_relations +FOR EACH ROW +EXECUTE FUNCTION auto_update_leaf_status(); COMMIT; \ No newline at end of file diff --git a/typespec/entries/models.tsp b/typespec/entries/models.tsp index 7f8c8b9fa..d7eae55e7 100644 --- a/typespec/entries/models.tsp +++ b/typespec/entries/models.tsp @@ -107,6 +107,7 @@ model BaseEntry { tokenizer: string; token_count: uint16; + "model": string = "gpt-4o-mini"; /** Tool calls generated by the model. */ tool_calls?: ChosenToolCall[] | null = null; diff --git a/typespec/sessions/models.tsp b/typespec/sessions/models.tsp index f15453a5f..720625f3b 100644 --- a/typespec/sessions/models.tsp +++ b/typespec/sessions/models.tsp @@ -63,6 +63,9 @@ model Session { /** A specific situation that sets the background for this session */ situation: string = defaultSessionSystemMessage; + /** System prompt for this session */ + system_template: string | null = null; + /** Summary (null at the beginning) - generated automatically after every interaction */ @visibility("read") summary: string | null = null; @@ -83,6 +86,9 @@ model Session { * If a tool call is not made, the model's output will be returned as is. */ auto_run_tools: boolean = false; + /** Whether to forward tool calls to the model */ + forward_tool_calls: boolean = false; + recall_options?: RecallOptions | null = null; ...HasId; diff --git a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml index 0a12aac74..d4835a695 100644 --- a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml +++ b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml @@ -3064,6 +3064,7 @@ components: - source - tokenizer - token_count + - model - timestamp properties: role: @@ -3307,6 +3308,9 @@ components: token_count: type: integer format: uint16 + model: + type: string + default: gpt-4o-mini tool_calls: type: array items: @@ -3757,10 +3761,12 @@ components: required: - id - situation + - system_template - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls properties: id: $ref: '#/components/schemas/Common.uuid' @@ -3836,6 +3842,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -3861,6 +3872,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. 
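Back in migration 000016 above: the rewritten trigger trades a hard failure for bookkeeping. Where enforce_leaf_nodes raised whenever a new relation pointed at a leaf, auto_update_leaf_status now clears is_leaf on any row whose tail matches the new head and computes the new row's own flag from whether it has children. A sketch of the observable behavior via asyncpg; the 'precedes' relation value and the DSN are illustrative, the table and column names come from the migration:

import asyncpg

async def demo(dsn: str, session_id, a, b, c) -> None:
    conn = await asyncpg.connect(dsn)
    insert = (
        "INSERT INTO entry_relations (session_id, head, relation, tail)"
        " VALUES ($1, $2, 'precedes', $3)"
    )
    # a -> b: b has no children yet, so the trigger sets is_leaf = true.
    await conn.execute(insert, session_id, a, b)
    # b -> c: instead of raising (the old behavior), the trigger flips the
    # a -> b row to is_leaf = false and computes is_leaf for the new row.
    await conn.execute(insert, session_id, b, c)
    await conn.close()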
default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -3876,10 +3891,12 @@ components: type: object required: - situation + - system_template - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls properties: user: allOf: @@ -3953,6 +3970,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -3978,6 +4000,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4092,6 +4118,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -4117,6 +4148,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4185,11 +4220,13 @@ components: type: object required: - situation + - system_template - summary - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls - id - created_at - updated_at @@ -4250,6 +4287,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null summary: type: string nullable: true @@ -4281,6 +4323,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4356,10 +4402,12 @@ components: type: object required: - situation + - system_template - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls properties: situation: type: string @@ -4417,6 +4465,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -4442,6 +4495,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf:
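End to end, the schema additions in this generated spec mean a session create body can carry both new knobs explicitly; anything omitted falls back to the defaults recorded above (system_template: null, forward_tool_calls: false). For instance, with an illustrative agent id:

create_session_body = {
    "agent": "00000000-0000-0000-0000-000000000000",  # illustrative UUID
    "situation": "demo",
    "system_template": "You are a terse assistant.",
    "forward_tool_calls": True,
}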