diff --git a/.github/workflows/lint-agents-api-pr.yml b/.github/workflows/lint-agents-api-pr.yml index dc5767314..5850441ef 100644 --- a/.github/workflows/lint-agents-api-pr.yml +++ b/.github/workflows/lint-agents-api-pr.yml @@ -23,6 +23,11 @@ jobs: uses: astral-sh/setup-uv@v4 with: enable-cache: true + + - name: Install Go migrate + uses: jaxxstorm/action-install-gh-release@v1.10.0 + with: # Grab the latest version + repo: golang-migrate/migrate - name: Set up python and install dependencies run: | diff --git a/.github/workflows/test-agents-api-pr.yml b/.github/workflows/test-agents-api-pr.yml index 04016f034..80f736a87 100644 --- a/.github/workflows/test-agents-api-pr.yml +++ b/.github/workflows/test-agents-api-pr.yml @@ -23,6 +23,11 @@ jobs: uses: astral-sh/setup-uv@v4 with: enable-cache: true + + - name: Install Go migrate + uses: jaxxstorm/action-install-gh-release@v1.10.0 + with: # Grab the latest version + repo: golang-migrate/migrate - name: Set up python and install dependencies run: | diff --git a/.github/workflows/typecheck-agents-api-pr.yml b/.github/workflows/typecheck-agents-api-pr.yml index b9e543c34..3569d65b4 100644 --- a/.github/workflows/typecheck-agents-api-pr.yml +++ b/.github/workflows/typecheck-agents-api-pr.yml @@ -31,6 +31,11 @@ jobs: uses: astral-sh/setup-uv@v4 with: enable-cache: true + + - name: Install Go migrate + uses: jaxxstorm/action-install-gh-release@v1.10.0 + with: # Grab the latest version + repo: golang-migrate/migrate - name: Set up python and install dependencies run: | diff --git a/.gitignore b/.gitignore index 0adb06f10..591aabab1 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ ngrok* */node_modules/ .aider* .vscode/ +schema.sql diff --git a/agents-api/.gitignore b/agents-api/.gitignore index 33217a796..c2e19f143 100644 --- a/agents-api/.gitignore +++ b/agents-api/.gitignore @@ -1,5 +1,4 @@ # Local database files -cozo.db temporal.db *.bak *.dat diff --git a/agents-api/Dockerfile b/agents-api/Dockerfile index 
54ae6b576..3408c38b5 100644 --- a/agents-api/Dockerfile +++ b/agents-api/Dockerfile @@ -30,4 +30,4 @@ COPY . ./ ENV PYTHONUNBUFFERED=1 ENV GUNICORN_CMD_ARGS="--capture-output --enable-stdio-inheritance" -ENTRYPOINT ["uv", "run", "gunicorn", "agents_api.web:app", "-c", "gunicorn_conf.py"] +ENTRYPOINT ["uv", "run", "--offline", "--no-sync", "gunicorn", "agents_api.web:app", "-c", "gunicorn_conf.py"] diff --git a/agents-api/Dockerfile.migration b/agents-api/Dockerfile.migration deleted file mode 100644 index 78f60c16b..000000000 --- a/agents-api/Dockerfile.migration +++ /dev/null @@ -1,22 +0,0 @@ -# syntax=docker/dockerfile:1 -# check=error=true - -FROM python:3.13-slim - -ENV PYTHONUNBUFFERED=1 -ENV POETRY_CACHE_DIR=/tmp/poetry_cache - -WORKDIR /app - -RUN pip install --no-cache-dir --upgrade cozo-migrate - -COPY . ./ -ENV COZO_HOST="http://cozo:9070" - -# Expected environment variables: -# COZO_AUTH_TOKEN="myauthkey" - -SHELL ["/bin/bash", "-c"] -ENTRYPOINT \ - cozo-migrate -e http -h $COZO_HOST --auth $COZO_AUTH_TOKEN init \ - ; cozo-migrate -e http -h $COZO_HOST --auth $COZO_AUTH_TOKEN -d ./migrations apply -ay diff --git a/agents-api/Dockerfile.worker b/agents-api/Dockerfile.worker index 88f30e2d2..34538a27d 100644 --- a/agents-api/Dockerfile.worker +++ b/agents-api/Dockerfile.worker @@ -30,4 +30,4 @@ COPY . 
./ ENV PYTHONUNBUFFERED=1 ENV GUNICORN_CMD_ARGS="--capture-output --enable-stdio-inheritance" -ENTRYPOINT ["uv", "run", "python", "-m", "agents_api.worker"] +ENTRYPOINT ["uv", "run", "--offline", "--no-sync", "python", "-m", "agents_api.worker"] diff --git a/agents-api/agents_api/activities/container.py b/agents-api/agents_api/activities/container.py new file mode 100644 index 000000000..09bb14882 --- /dev/null +++ b/agents-api/agents_api/activities/container.py @@ -0,0 +1,12 @@ +class State: + pass + + +class Container: + state: State + + def __init__(self): + self.state = State() + + +container = Container() diff --git a/agents-api/agents_api/activities/demo.py b/agents-api/agents_api/activities/demo.py index f6d63f206..ba2babf43 100644 --- a/agents-api/agents_api/activities/demo.py +++ b/agents-api/agents_api/activities/demo.py @@ -1,5 +1,3 @@ -from typing import Callable - from temporalio import activity from ..env import testing @@ -7,13 +5,14 @@ async def demo_activity(a: int, b: int) -> int: # Should throw an error if testing is not enabled - raise Exception("This should not be called in production") + msg = "This should not be called in production" + raise Exception(msg) async def mock_demo_activity(a: int, b: int) -> int: return a + b -demo_activity: Callable[[int, int], int] = activity.defn(name="demo_activity")( +demo_activity = activity.defn(name="demo_activity")( demo_activity if not testing else mock_demo_activity ) diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py deleted file mode 100644 index c6c7663c3..000000000 --- a/agents-api/agents_api/activities/embed_docs.py +++ /dev/null @@ -1,75 +0,0 @@ -import asyncio -import operator -from functools import reduce -from itertools import batched - -from beartype import beartype -from temporalio import activity - -from ..clients import cozo, litellm -from ..common.storage_handler import auto_blob_store -from ..env import testing -from 
..models.docs.embed_snippets import embed_snippets as embed_snippets_query -from .types import EmbedDocsPayload - - -@auto_blob_store(deep=True) -@beartype -async def embed_docs( - payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100 -) -> None: - # Create batches of both indices and snippets together - indexed_snippets = list(enumerate(payload.content)) - # Batch snippets into groups of max_batch_size for parallel processing - batched_indexed_snippets = list(batched(indexed_snippets, max_batch_size)) - # Get embedding instruction and title from payload, defaulting to empty strings - embed_instruction: str = payload.embed_instruction or "" - title: str = payload.title or "" - - # Helper function to embed a batch of snippets - async def embed_batch(indexed_batch): - # Split indices and snippets for the batch - batch_indices, batch_snippets = zip(*indexed_batch) - embeddings = await litellm.aembedding( - inputs=[ - ((title + "\n\n" + snippet) if title else snippet).strip() - for snippet in batch_snippets - ], - embed_instruction=embed_instruction, - ) - return list(zip(batch_indices, embeddings)) - - # Gather embeddings with their corresponding indices - indexed_embeddings = reduce( - operator.add, - await asyncio.gather( - *[embed_batch(batch) for batch in batched_indexed_snippets] - ), - ) - - # Split indices and embeddings after all batches are processed - indices, embeddings = zip(*indexed_embeddings) - - # Convert to lists since embed_snippets_query expects list types - indices = list(indices) - embeddings = list(embeddings) - - embed_snippets_query( - developer_id=payload.developer_id, - doc_id=payload.doc_id, - snippet_indices=indices, - embeddings=embeddings, - client=cozo_client or cozo.get_cozo_client(), - ) - - -async def mock_embed_docs( - payload: EmbedDocsPayload, cozo_client=None, max_batch_size=100 -) -> None: - # Does nothing - return None - - -embed_docs = activity.defn(name="embed_docs")( - embed_docs if not testing else 
mock_embed_docs -) diff --git a/agents-api/agents_api/activities/excecute_api_call.py b/agents-api/agents_api/activities/excecute_api_call.py index 09a33aaa8..5ed6cddc1 100644 --- a/agents-api/agents_api/activities/excecute_api_call.py +++ b/agents-api/agents_api/activities/excecute_api_call.py @@ -1,26 +1,24 @@ import base64 -from typing import Any, Optional, TypedDict, Union +from typing import Any, TypedDict import httpx from beartype import beartype from temporalio import activity from ..autogen.openapi_model import ApiCallDef -from ..common.storage_handler import auto_blob_store from ..env import testing class RequestArgs(TypedDict): - content: Optional[str] - data: Optional[dict[str, Any]] - json_: Optional[dict[str, Any]] - cookies: Optional[dict[str, str]] - params: Optional[Union[str, dict[str, Any]]] - url: Optional[str] - headers: Optional[dict[str, str]] + content: str | None + data: dict[str, Any] | None + json_: dict[str, Any] | None + cookies: dict[str, str] | None + params: str | dict[str, Any] | None + url: str | None + headers: dict[str, str] | None -@auto_blob_store(deep=True) @beartype async def execute_api_call( api_call: ApiCallDef, diff --git a/agents-api/agents_api/activities/execute_integration.py b/agents-api/agents_api/activities/execute_integration.py index 3316ad6f5..7356916db 100644 --- a/agents-api/agents_api/activities/execute_integration.py +++ b/agents-api/agents_api/activities/execute_integration.py @@ -3,16 +3,17 @@ from beartype import beartype from temporalio import activity +from ..app import lifespan from ..autogen.openapi_model import BaseIntegrationDef from ..clients import integrations from ..common.exceptions.tools import IntegrationExecutionException from ..common.protocol.tasks import ExecutionInput, StepContext -from ..common.storage_handler import auto_blob_store from ..env import testing -from ..models.tools import get_tool_args_from_metadata +from ..queries.tools import get_tool_args_from_metadata +from .container 
import container -@auto_blob_store(deep=True) +@lifespan(container) @beartype async def execute_integration( context: StepContext, @@ -22,23 +23,30 @@ async def execute_integration( setup: dict[str, Any] = {}, ) -> Any: if not isinstance(context.execution_input, ExecutionInput): - raise TypeError("Expected ExecutionInput type for context.execution_input") + msg = "Expected ExecutionInput type for context.execution_input" + raise TypeError(msg) developer_id = context.execution_input.developer_id agent_id = context.execution_input.agent.id task_id = context.execution_input.task.id - merged_tool_args = get_tool_args_from_metadata( - developer_id=developer_id, agent_id=agent_id, task_id=task_id, arg_type="args" + merged_tool_args = await get_tool_args_from_metadata( + developer_id=developer_id, + agent_id=agent_id, + task_id=task_id, + arg_type="args", + connection_pool=container.state.postgres_pool, ) - merged_tool_setup = get_tool_args_from_metadata( - developer_id=developer_id, agent_id=agent_id, task_id=task_id, arg_type="setup" + merged_tool_setup = await get_tool_args_from_metadata( + developer_id=developer_id, + agent_id=agent_id, + task_id=task_id, + arg_type="setup", + connection_pool=container.state.postgres_pool, ) - arguments = ( - merged_tool_args.get(tool_name, {}) | (integration.arguments or {}) | arguments - ) + arguments = merged_tool_args.get(tool_name, {}) | (integration.arguments or {}) | arguments setup = merged_tool_setup.get(tool_name, {}) | (integration.setup or {}) | setup @@ -53,10 +61,7 @@ async def execute_integration( arguments=arguments, ) - if ( - "error" in integration_service_response - and integration_service_response["error"] - ): + if integration_service_response.get("error"): raise IntegrationExecutionException( integration=integration, error=integration_service_response["error"], @@ -69,9 +74,7 @@ async def execute_integration( integration_str = integration.provider + ( "." 
+ integration.method if integration.method else "" ) - activity.logger.error( - f"Error in execute_integration {integration_str}: {e}" - ) + activity.logger.error(f"Error in execute_integration {integration_str}: {e}") raise diff --git a/agents-api/agents_api/activities/execute_system.py b/agents-api/agents_api/activities/execute_system.py index 8d85a2639..802b2900a 100644 --- a/agents-api/agents_api/activities/execute_system.py +++ b/agents-api/agents_api/activities/execute_system.py @@ -6,10 +6,10 @@ from beartype import beartype from box import Box, BoxList -from fastapi import HTTPException from fastapi.background import BackgroundTasks from temporalio import activity +from ..app import app, lifespan from ..autogen.openapi_model import ( ChatInput, CreateDocRequest, @@ -20,16 +20,16 @@ VectorDocSearchRequest, ) from ..common.protocol.tasks import ExecutionInput, StepContext -from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote from ..env import testing -from ..models.developer import get_developer +from ..queries.developers import get_developer +from .container import container from .utils import get_handler # For running synchronous code in the background process_pool_executor = ProcessPoolExecutor() -@auto_blob_store(deep=True) +@lifespan(app, container) # Both are needed because we are using the routes @beartype async def execute_system( context: StepContext, @@ -38,11 +38,9 @@ async def execute_system( """Execute a system call with the appropriate handler and transformed arguments.""" arguments: dict[str, Any] = system.arguments or {} - if set(arguments.keys()) == {"bucket", "key"}: - arguments = await load_from_blob_store_if_remote(arguments) - if not isinstance(context.execution_input, ExecutionInput): - raise TypeError("Expected ExecutionInput type for context.execution_input") + msg = "Expected ExecutionInput type for context.execution_input" + raise TypeError(msg) arguments["developer_id"] = 
context.execution_input.developer_id @@ -95,7 +93,10 @@ async def execute_system( # Handle chat operations if system.operation == "chat" and system.resource == "session": - developer = get_developer(developer_id=arguments.get("developer_id")) + developer = await get_developer( + developer_id=arguments["developer_id"], + connection_pool=container.state.postgres_pool, + ) session_id = arguments.get("session_id") x_custom_api_key = arguments.get("x_custom_api_key", None) chat_input = ChatInput(**arguments) @@ -131,9 +132,7 @@ async def execute_system( # Run the synchronous function in another process loop = asyncio.get_running_loop() - return await loop.run_in_executor( - process_pool_executor, partial(handler, **arguments) - ) + return await loop.run_in_executor(process_pool_executor, partial(handler, **arguments)) except BaseException as e: if activity.in_activity(): activity.logger.error(f"Error in execute_system_call: {e}") @@ -151,19 +150,20 @@ def _create_search_request(arguments: dict) -> Any: confidence=arguments.pop("confidence", 0.5), limit=arguments.get("limit", 10), ) - elif "text" in arguments: + if "text" in arguments: return TextOnlyDocSearchRequest( text=arguments.pop("text"), mmr_strength=arguments.pop("mmr_strength", 0), limit=arguments.get("limit", 10), ) - elif "vector" in arguments: + if "vector" in arguments: return VectorDocSearchRequest( vector=arguments.pop("vector"), mmr_strength=arguments.pop("mmr_strength", 0), confidence=arguments.pop("confidence", 0.7), limit=arguments.get("limit", 10), ) + return None # Keep the existing mock and activity definition diff --git a/agents-api/agents_api/activities/mem_mgmt.py b/agents-api/agents_api/activities/mem_mgmt.py deleted file mode 100644 index 7cd4a7d6b..000000000 --- a/agents-api/agents_api/activities/mem_mgmt.py +++ /dev/null @@ -1,192 +0,0 @@ -from textwrap import dedent -from typing import Callable -from uuid import UUID - -from beartype import beartype -from temporalio import activity - -from 
..autogen.openapi_model import InputChatMLMessage -from ..clients import litellm -from .types import MemoryManagementTaskArgs - -example_previous_memory = """ -Speaker 1: Composes and listens to music. Likes to buy basketball shoes but doesn't wear them often. -""".strip() - -example_dialog_context = """ -Speaker 1: Did you find a place to donate your shoes? -Speaker 2: I did! I was driving to the grocery store the other day, when I noticed a bin labeled "Donation for Shoes and Clothing." It was easier than I thought! How about you? Why do you have so many pairs of sandals? -Speaker 1: I don't understand myself! When I look them online I just have the urge to buy them, even when I know I don't need them. This addiction is getting worse and worse. -Speaker 2: I completely agree that buying shoes can become an addiction! Are there any ways you can make money from home while waiting for a job offer from a call center? -Speaker 1: Well I already got the job so I just need to learn using the software. When I was still searching for jobs, we actually do a yard sale to sell many of my random items that are never used and clearly aren't needed either. -Speaker 2: Congratulations on getting the job! I know it'll help you out so much. And of course, maybe I should turn to yard sales as well, for they can be a great way to make some extra cash! -Speaker 1: Do you have another job or do you compose music for a living? How does your shopping addiction go? -Speaker 2: As a matter of fact, I do have another job in addition to composing music. I'm actually a music teacher at a private school, and on the side, I compose music for friends and family. As far as my shopping addiction goes, it's getting better. I promised myself that I wouldn't buy myself any more shoes this year! -Speaker 1: Ah, I remember the time I promised myself the same thing on not buying random things anymore, never work so far. Good luck with yours! -Speaker 2: Thanks! I need the good luck wishes. 
I've been avoiding malls and shopping outlets. Maybe you can try the same! -Speaker 1: I can avoid them physically, but with my job enable me sitting in front of my computer for a long period of time, I already turn the shopping addiction into online-shopping addiction. lol. Wish me luck! -Speaker 2: Sure thing! You know, and speaking of spending time before a computer, I need to look up information about Precious Moments figurines. I'd still like to know what they are! -""".strip() - -example_updated_memory = """ -Speaker 1: -- Enjoys composing and listening to music. -- Recently got a job that requires the use of specialized software. -- Displays a shopping addiction, particularly for shoes, that has transitioned to online-shopping due to job nature. -- Previously attempted to mitigate shopping addiction without success. -- Had organized a yard sale to sell unused items when job searching. - -Speaker 2: -- Also enjoys buying shoes and admits to it being addictive. -- Works as a music teacher at a private school in addition to composing music. -- Takes active measures to control his shopping addiction, including avoiding malls. -- Is interested in Precious Moments figurines. -""".strip() - - -def make_prompt( - args: MemoryManagementTaskArgs, - max_turns: int = 10, - num_sentences: int = 10, -): - # Unpack - dialog = args.dialog - previous_memories = args.previous_memories - - # Template - template = dedent( - """\ - **Instructions** - You are an advanced AI language model with the ability to store and update a memory to keep track of key personality information for people. You will receive a memory and a dialogue between two people. - - Your goal is to update the memory by incorporating the new personality information for both participants while ensuring that the memory does not exceed {num_sentences} sentences. - - To successfully update the memory, follow these steps: - - 1. 
Carefully analyze the existing memory and extract the key personality information of the participants from it. - 2. Consider the dialogue provided to identify any new or changed personality traits of either participant that need to be incorporated into the memory. - 3. Combine the old and new personality information to create an updated representation of the participants' traits. - 4. Structure the updated memory in a clear and concise manner, ensuring that it does not exceed {num_sentences} sentences. - 5. Pay attention to the relevance and importance of the personality information, focusing on capturing the most significant aspects while maintaining the overall coherence of the memory. - - Remember, the memory should serve as a reference point to maintain continuity in the dialogue and help accurately set context in future conversations based on the personality traits of the participants. - - **Test Example** - [[Previous Memory]] - {example_previous_memory} - - [[Dialogue Context]] - {example_dialog_context} - - [[Updated Memory]] - {example_updated_memory} - - **Actual Run** - [[Previous Memory]] - {previous_memory} - - [[Dialogue Context]] - {dialog_context} - - [[Updated Memory]] - """ - ).strip() - - # Filter dialog (keep only user and assistant sections) - dialog = [entry for entry in dialog if entry.role != "system"] - - # Truncate to max_turns - dialog = dialog[-max_turns:] - - # Prepare dialog context - dialog_context = "\n".join( - [ - f'{e.name or ("User" if e.role == "user" else "Assistant")}: {e.content}' - for e in dialog - ] - ) - - prompt = template.format( - dialog_context=dialog_context, - previous_memory="\n".join(previous_memories), - num_sentences=num_sentences, - example_dialog_context=example_dialog_context, - example_previous_memory=example_previous_memory, - example_updated_memory=example_updated_memory, - ) - - return prompt - - -async def run_prompt( - dialog: list[InputChatMLMessage], - session_id: UUID, - previous_memories: list[str] 
= [], - model: str = "gpt-4o", - max_tokens: int = 400, - temperature: float = 0.4, - parser: Callable[[str], str] = lambda x: x, -) -> str: - prompt = make_prompt( - MemoryManagementTaskArgs( - session_id=session_id, - model=model, - dialog=dialog, - previous_memories=previous_memories, - ) - ) - - response = await litellm.acompletion( - model=model, - messages=[ - { - "content": prompt, - "role": "user", - } - ], - max_tokens=max_tokens, - temperature=temperature, - stop=["<", "<|"], - stream=False, - ) - - content = response.choices[0].message.content - - return parser(content.strip() if content is not None else "") - - -@activity.defn -@beartype -async def mem_mgmt( - dialog: list[InputChatMLMessage], - session_id: UUID, - previous_memories: list[str] = [], -) -> None: - # session_id = UUID(session_id) - # entries = [ - # Entry(**row) - # for _, row in client.run( - # get_toplevel_entries_query(session_id=session_id) - # ).iterrows() - # ] - - # assert len(entries) > 0, "no need to summarize on empty entries list" - - await run_prompt(dialog, session_id, previous_memories) - - # new_entry = Entry( - # session_id=session_id, - # source="summarizer", - # role="system", - # name="information", - # content=response, - # timestamp=entries[-1].timestamp + 0.01, - # ) - - # client.run( - # entries_summarization_query( - # session_id=session_id, - # new_entry=new_entry, - # old_entry_ids=[e.id for e in entries], - # ) - # ) diff --git a/agents-api/agents_api/activities/mem_rating.py b/agents-api/agents_api/activities/mem_rating.py deleted file mode 100644 index c681acbc3..000000000 --- a/agents-api/agents_api/activities/mem_rating.py +++ /dev/null @@ -1,100 +0,0 @@ -from textwrap import dedent -from typing import Callable - -from beartype import beartype -from temporalio import activity - -from ..clients import litellm -from .types import MemoryRatingTaskArgs - - -def make_prompt(args: MemoryRatingTaskArgs): - # Unpack - memory = args.memory - - # Template - template = 
dedent( - """\ - Importance distinguishes mundane from core memories, by assigning a higher score to those memory objects that the agent believes to be important. For instance, a mundane event such as eating breakfast in one’s room would yield a low importance score, whereas a breakup with one’s significant other would yield a high score. There are again many possible implementations of an importance score; we find that directly asking the language model to output an integer score is effective. - - On the scale of 1 to 10, where 1 is purely mundane (e.g., brushing teeth, making bed) and 10 is extremely poignant (e.g., a break up, college acceptance), rate the likely poignancy of the following piece of memory. - - [[Format to follow]] - Memory: - Thought: - Rating: - - [[Hypothetical Example]] - Memory: buying groceries at The Willows Market and Pharmacy - Thought: Grocery shopping is a routine task that most people engage in regularly. While there may be some significance attached to it—for instance, if it's part of a new diet plan or if you're buying groceries for a special occasion—in general, it is unlikely to be a memory that carries substantial emotional weight or has a long-lasting impact on one's life. However, there can be some variance; a mundane grocery trip could become more important if you bump into an old friend or make a particularly interesting discovery (e.g., a new favorite food). But in the absence of such circumstances, the poignancy would be quite low. 
- Rating: 2 - - [[Actual run]] - Memory: {memory} - """ - ).strip() - - prompt = template.format(memory=memory) - - return prompt - - -async def run_prompt( - memory: str, - model: str = "gpt-4o", - max_tokens: int = 400, - temperature: float = 0.1, - parser: Callable[[str], str] = lambda x: x, -) -> str: - prompt = make_prompt(MemoryRatingTaskArgs(memory=memory)) - - response = await litellm.acompletion( - model=model, - messages=[ - { - "content": prompt, - "role": "user", - } - ], - max_tokens=max_tokens, - temperature=temperature, - stop=["<", "<|"], - stream=False, - ) - - content = response.choices[0].message.content - - return parser(content.strip() if content is not None else "") - - -@activity.defn -@beartype -async def mem_rating(memory: str) -> None: - # session_id = UUID(session_id) - # entries = [ - # Entry(**row) - # for _, row in client.run( - # get_toplevel_entries_query(session_id=session_id) - # ).iterrows() - # ] - - # assert len(entries) > 0, "no need to summarize on empty entries list" - - await run_prompt(memory=memory) - - # new_entry = Entry( - # session_id=session_id, - # source="summarizer", - # role="system", - # name="information", - # content=response, - # timestamp=entries[-1].timestamp + 0.01, - # ) - - # client.run( - # entries_summarization_query( - # session_id=session_id, - # new_entry=new_entry, - # old_entry_ids=[e.id for e in entries], - # ) - # ) diff --git a/agents-api/agents_api/activities/summarization.py b/agents-api/agents_api/activities/summarization.py deleted file mode 100644 index aa7fa4740..000000000 --- a/agents-api/agents_api/activities/summarization.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python3 - - -import pandas as pd -from beartype import beartype -from temporalio import activity - -# from agents_api.models.entry.entries_summarization import ( -# entries_summarization_query, -# get_toplevel_entries_query, -# ) - - -# TODO: Implement entry summarization queries -# SCRUM-3 -def 
entries_summarization_query(*args, **kwargs) -> pd.DataFrame: - return pd.DataFrame() - - -def get_toplevel_entries_query(*args, **kwargs) -> pd.DataFrame: - return pd.DataFrame() - - -# TODO: Implement entry summarization activities -# SCRUM-4 - - -@activity.defn -@beartype -async def summarization(session_id: str) -> None: - raise NotImplementedError() - - # session_id = UUID(session_id) - # entries = [] - # entities_entry_ids = [] - # for _, row in get_toplevel_entries_query(session_id=session_id).iterrows(): - # if row["role"] == "system" and row.get("name") == "entities": - # entities_entry_ids.append(UUID(row["entry_id"], version=4)) - # else: - # entries.append(row) - - # assert len(entries) > 0, "no need to summarize on empty entries list" - - # summarized, entities = await asyncio.gather( - # summarize_messages(entries, model=summarization_model_name), - # get_entities(entries, model=summarization_model_name), - # ) - # trimmed_messages = await trim_messages(summarized, model=summarization_model_name) - # ts_delta = (entries[1]["timestamp"] - entries[0]["timestamp"]) / 2 - # new_entities_entry = Entry( - # session_id=session_id, - # source="summarizer", - # role="system", - # name="entities", - # content=entities["content"], - # timestamp=entries[0]["timestamp"] + ts_delta, - # ) - - # entries_summarization_query( - # session_id=session_id, - # new_entry=new_entities_entry, - # old_entry_ids=entities_entry_ids, - # ) - - # trimmed_map = { - # m["index"]: m["content"] for m in trimmed_messages if m.get("index") is not None - # } - - # for idx, msg in enumerate(summarized): - # new_entry = Entry( - # session_id=session_id, - # source="summarizer", - # role="system", - # name="information", - # content=trimmed_map.get(idx, msg["content"]), - # timestamp=entries[-1]["timestamp"] + 0.01, - # ) - - # entries_summarization_query( - # session_id=session_id, - # new_entry=new_entry, - # old_entry_ids=[ - # UUID(entries[idx - 1]["entry_id"], version=4) - # for idx 
in msg["summarizes"] - # ], - # ) diff --git a/agents-api/agents_api/activities/sync_items_remote.py b/agents-api/agents_api/activities/sync_items_remote.py index d71a5c566..14751c2b6 100644 --- a/agents-api/agents_api/activities/sync_items_remote.py +++ b/agents-api/agents_api/activities/sync_items_remote.py @@ -9,20 +9,16 @@ @beartype async def save_inputs_remote_fn(inputs: list[Any]) -> list[Any | RemoteObject]: - from ..common.storage_handler import store_in_blob_store_if_large + from ..common.interceptors import offload_if_large - return await asyncio.gather( - *[store_in_blob_store_if_large(input) for input in inputs] - ) + return await asyncio.gather(*[offload_if_large(input) for input in inputs]) @beartype async def load_inputs_remote_fn(inputs: list[Any | RemoteObject]) -> list[Any]: - from ..common.storage_handler import load_from_blob_store_if_remote + from ..common.interceptors import load_if_remote - return await asyncio.gather( - *[load_from_blob_store_if_remote(input) for input in inputs] - ) + return await asyncio.gather(*[load_if_remote(input) for input in inputs]) save_inputs_remote = activity.defn(name="save_inputs_remote")(save_inputs_remote_fn) diff --git a/agents-api/agents_api/activities/task_steps/__init__.py b/agents-api/agents_api/activities/task_steps/__init__.py index 573884629..363a4d5d0 100644 --- a/agents-api/agents_api/activities/task_steps/__init__.py +++ b/agents-api/agents_api/activities/task_steps/__init__.py @@ -1,13 +1,13 @@ # ruff: noqa: F401, F403, F405 from .base_evaluate import base_evaluate -from .cozo_query_step import cozo_query_step from .evaluate_step import evaluate_step from .for_each_step import for_each_step from .get_value_step import get_value_step from .if_else_step import if_else_step from .log_step import log_step from .map_reduce_step import map_reduce_step +from .pg_query_step import pg_query_step from .prompt_step import prompt_step from .raise_complete_async import raise_complete_async from .return_step 
import return_step diff --git a/agents-api/agents_api/activities/task_steps/base_evaluate.py b/agents-api/agents_api/activities/task_steps/base_evaluate.py index d87b961d3..dcf43c0ee 100644 --- a/agents-api/agents_api/activities/task_steps/base_evaluate.py +++ b/agents-api/agents_api/activities/task_steps/base_evaluate.py @@ -9,13 +9,11 @@ # Increase the max string length to 2048000 simpleeval.MAX_STRING_LENGTH = 2048000 -from simpleeval import NameNotDefined, SimpleEval # noqa: E402 -from temporalio import activity # noqa: E402 -from thefuzz import fuzz # noqa: E402 +from simpleeval import NameNotDefined, SimpleEval +from temporalio import activity +from thefuzz import fuzz -from ...common.storage_handler import auto_blob_store # noqa: E402 -from ...env import testing # noqa: E402 -from ..utils import get_evaluator # noqa: E402 +from ..utils import get_evaluator class EvaluateError(Exception): @@ -30,7 +28,7 @@ def __init__(self, error, expression, values): # Catch a possible misspell in a variable name if isinstance(error, NameNotDefined): misspelledName = error_message.split("'")[1] - for variableName in values.keys(): + for variableName in values: if fuzz.ratio(variableName, misspelledName) >= 90.0: message += f"\nDid you mean '{variableName}' instead of '{misspelledName}'?" 
super().__init__(message) @@ -46,9 +44,7 @@ def _recursive_evaluate(expr, evaluator: SimpleEval): evaluate_error = EvaluateError(e, expr, evaluator.names) variables_accessed = { - name: value - for name, value in evaluator.names.items() - if name in expr + name: value for name, value in evaluator.names.items() if name in expr } activity.logger.error( @@ -60,10 +56,11 @@ def _recursive_evaluate(expr, evaluator: SimpleEval): elif isinstance(expr, dict): return {k: _recursive_evaluate(v, evaluator) for k, v in expr.items()} else: - raise ValueError(f"Invalid expression: {expr}") + msg = f"Invalid expression: {expr}" + raise ValueError(msg) -@auto_blob_store(deep=True) +@activity.defn @beartype async def base_evaluate( exprs: Any, @@ -84,15 +81,14 @@ async def base_evaluate( try: ast.parse(v) except Exception as e: - raise ValueError(f"Invalid lambda: {v}") from e + msg = f"Invalid lambda: {v}" + raise ValueError(msg) from e # Eval the lambda and add it to the extra lambdas extra_lambdas[k] = eval(v) # Turn the nested dict values from pydantic to dicts where possible - values = { - k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in values.items() - } + values = {k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in values.items()} # frozen_box doesn't work coz we need some mutability in the values values = Box(values, frozen_box=False, conversion_box=True) @@ -100,14 +96,4 @@ async def base_evaluate( evaluator: SimpleEval = get_evaluator(names=values, extra_functions=extra_lambdas) # Recursively evaluate the expression - result = _recursive_evaluate(exprs, evaluator) - return result - - -# Note: This is here just for clarity. 
We could have just imported base_evaluate directly -# They do the same thing, so we dont need to mock the base_evaluate function -mock_base_evaluate = base_evaluate - -base_evaluate = activity.defn(name="base_evaluate")( - base_evaluate if not testing else mock_base_evaluate -) + return _recursive_evaluate(exprs, evaluator) diff --git a/agents-api/agents_api/activities/task_steps/cozo_query_step.py b/agents-api/agents_api/activities/task_steps/cozo_query_step.py deleted file mode 100644 index 16e9a53d8..000000000 --- a/agents-api/agents_api/activities/task_steps/cozo_query_step.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Any - -from beartype import beartype -from temporalio import activity - -from ... import models -from ...common.storage_handler import auto_blob_store -from ...env import testing - - -@auto_blob_store(deep=True) -@beartype -async def cozo_query_step( - query_name: str, - values: dict[str, Any], -) -> Any: - (module_name, name) = query_name.split(".") - - module = getattr(models, module_name) - query = getattr(module, name) - return query(**values) - - -# Note: This is here just for clarity. 
We could have just imported cozo_query_step directly -# They do the same thing, so we dont need to mock the cozo_query_step function -mock_cozo_query_step = cozo_query_step - -cozo_query_step = activity.defn(name="cozo_query_step")( - cozo_query_step if not testing else mock_cozo_query_step -) diff --git a/agents-api/agents_api/activities/task_steps/evaluate_step.py b/agents-api/agents_api/activities/task_steps/evaluate_step.py index 904ec3b9d..595a2e8ad 100644 --- a/agents-api/agents_api/activities/task_steps/evaluate_step.py +++ b/agents-api/agents_api/activities/task_steps/evaluate_step.py @@ -5,11 +5,9 @@ from ...activities.utils import simple_eval_dict from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store -from ...env import testing -@auto_blob_store(deep=True) +@activity.defn @beartype async def evaluate_step( context: StepContext, @@ -17,28 +15,13 @@ async def evaluate_step( override_expr: dict[str, str] | None = None, ) -> StepOutcome: try: - expr = ( - override_expr - if override_expr is not None - else context.current_step.evaluate - ) + expr = override_expr if override_expr is not None else context.current_step.evaluate - values = await context.prepare_for_step(include_remote=True) | additional_values + values = await context.prepare_for_step() | additional_values output = simple_eval_dict(expr, values) - result = StepOutcome(output=output) - - return result + return StepOutcome(output=output) except BaseException as e: activity.logger.error(f"Error in evaluate_step: {e}") return StepOutcome(error=str(e) or repr(e)) - - -# Note: This is here just for clarity. 
We could have just imported evaluate_step directly -# They do the same thing, so we dont need to mock the evaluate_step function -mock_evaluate_step = evaluate_step - -evaluate_step = activity.defn(name="evaluate_step")( - evaluate_step if not testing else mock_evaluate_step -) diff --git a/agents-api/agents_api/activities/task_steps/for_each_step.py b/agents-api/agents_api/activities/task_steps/for_each_step.py index f51c1ef76..4c8495ad3 100644 --- a/agents-api/agents_api/activities/task_steps/for_each_step.py +++ b/agents-api/agents_api/activities/task_steps/for_each_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store -from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) +@activity.defn @beartype async def for_each_step(context: StepContext) -> StepOutcome: try: @@ -25,12 +23,3 @@ async def for_each_step(context: StepContext) -> StepOutcome: except BaseException as e: activity.logger.error(f"Error in for_each_step: {e}") return StepOutcome(error=str(e)) - - -# Note: This is here just for clarity. 
We could have just imported if_else_step directly -# They do the same thing, so we dont need to mock the if_else_step function -mock_if_else_step = for_each_step - -for_each_step = activity.defn(name="for_each_step")( - for_each_step if not testing else mock_if_else_step -) diff --git a/agents-api/agents_api/activities/task_steps/get_value_step.py b/agents-api/agents_api/activities/task_steps/get_value_step.py index ca38bc4fe..47118833b 100644 --- a/agents-api/agents_api/activities/task_steps/get_value_step.py +++ b/agents-api/agents_api/activities/task_steps/get_value_step.py @@ -2,25 +2,16 @@ from temporalio import activity from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store -from ...env import testing - # TODO: We should use this step to query the parent workflow and get the value from the workflow context # SCRUM-1 -@auto_blob_store(deep=True) + + +@activity.defn @beartype async def get_value_step( context: StepContext, ) -> StepOutcome: key: str = context.current_step.get # noqa: F841 - raise NotImplementedError("Not implemented yet") - - -# Note: This is here just for clarity. 
We could have just imported get_value_step directly -# They do the same thing, so we dont need to mock the get_value_step function -mock_get_value_step = get_value_step - -get_value_step = activity.defn(name="get_value_step")( - get_value_step if not testing else mock_get_value_step -) + msg = "Not implemented yet" + raise NotImplementedError(msg) diff --git a/agents-api/agents_api/activities/task_steps/if_else_step.py b/agents-api/agents_api/activities/task_steps/if_else_step.py index cf3764199..b10ec843b 100644 --- a/agents-api/agents_api/activities/task_steps/if_else_step.py +++ b/agents-api/agents_api/activities/task_steps/if_else_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store -from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) +@activity.defn @beartype async def if_else_step(context: StepContext) -> StepOutcome: # NOTE: This activity is only for logging, so we just evaluate the expression @@ -23,18 +21,8 @@ async def if_else_step(context: StepContext) -> StepOutcome: output = await base_evaluate(expr, await context.prepare_for_step()) output: bool = bool(output) - result = StepOutcome(output=output) - return result + return StepOutcome(output=output) except BaseException as e: activity.logger.error(f"Error in if_else_step: {e}") return StepOutcome(error=str(e)) - - -# Note: This is here just for clarity. 
We could have just imported if_else_step directly -# They do the same thing, so we dont need to mock the if_else_step function -mock_if_else_step = if_else_step - -if_else_step = activity.defn(name="if_else_step")( - if_else_step if not testing else mock_if_else_step -) diff --git a/agents-api/agents_api/activities/task_steps/log_step.py b/agents-api/agents_api/activities/task_steps/log_step.py index 28fea2dae..a19e88ab3 100644 --- a/agents-api/agents_api/activities/task_steps/log_step.py +++ b/agents-api/agents_api/activities/task_steps/log_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...common.utils.template import render_template -from ...env import testing -@auto_blob_store(deep=True) +@activity.defn @beartype async def log_step(context: StepContext) -> StepOutcome: # NOTE: This activity is only for logging, so we just evaluate the expression @@ -22,20 +20,12 @@ async def log_step(context: StepContext) -> StepOutcome: template: str = context.current_step.log output = await render_template( template, - await context.prepare_for_step(include_remote=True), + await context.prepare_for_step(), skip_vars=["developer_id"], ) - result = StepOutcome(output=output) - return result + return StepOutcome(output=output) except BaseException as e: activity.logger.error(f"Error in log_step: {e}") return StepOutcome(error=str(e)) - - -# Note: This is here just for clarity. 
We could have just imported log_step directly -# They do the same thing, so we dont need to mock the log_step function -mock_log_step = log_step - -log_step = activity.defn(name="log_step")(log_step if not testing else mock_log_step) diff --git a/agents-api/agents_api/activities/task_steps/map_reduce_step.py b/agents-api/agents_api/activities/task_steps/map_reduce_step.py index 872988bb4..600f98615 100644 --- a/agents-api/agents_api/activities/task_steps/map_reduce_step.py +++ b/agents-api/agents_api/activities/task_steps/map_reduce_step.py @@ -8,12 +8,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store -from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) +@activity.defn @beartype async def map_reduce_step(context: StepContext) -> StepOutcome: try: @@ -28,12 +26,3 @@ async def map_reduce_step(context: StepContext) -> StepOutcome: except BaseException as e: logging.error(f"Error in map_reduce_step: {e}") return StepOutcome(error=str(e)) - - -# Note: This is here just for clarity. We could have just imported if_else_step directly -# They do the same thing, so we dont need to mock the if_else_step function -mock_if_else_step = map_reduce_step - -map_reduce_step = activity.defn(name="map_reduce_step")( - map_reduce_step if not testing else mock_if_else_step -) diff --git a/agents-api/agents_api/activities/task_steps/pg_query_step.py b/agents-api/agents_api/activities/task_steps/pg_query_step.py new file mode 100644 index 000000000..2c081cb15 --- /dev/null +++ b/agents-api/agents_api/activities/task_steps/pg_query_step.py @@ -0,0 +1,24 @@ +from typing import Any + +from beartype import beartype +from temporalio import activity + +from ... 
import queries +from ...app import lifespan +from ...env import pg_dsn +from ..container import container + + +@activity.defn +@lifespan(container) +@beartype +async def pg_query_step( + query_name: str, + values: dict[str, Any], + dsn: str = pg_dsn, +) -> Any: + (module_name, name) = query_name.split(".") + + module = getattr(queries, module_name) + query = getattr(module, name) + return await query(**values, connection_pool=container.state.postgres_pool) diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index cf8b169d5..0824b9733 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -8,7 +8,6 @@ litellm, # We dont directly import `acompletion` so we can mock it ) from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...common.utils.template import render_template from ...env import debug from .base_evaluate import base_evaluate @@ -28,7 +27,7 @@ def format_tool(tool: Tool) -> dict: } # For other tool types, we need to translate them to the OpenAI function tool format - formatted = { + return { "type": "function", "function": {"name": tool.name, "description": tool.description}, } @@ -55,39 +54,29 @@ def format_tool(tool: Tool) -> dict: # elif tool.type == "api_call": # raise NotImplementedError("API call tools are not supported") - return formatted - EVAL_PROMPT_PREFIX = "$_ " @activity.defn -@auto_blob_store(deep=True) @beartype async def prompt_step(context: StepContext) -> StepOutcome: # Get context data prompt: str | list[dict] = context.current_step.model_dump()["prompt"] - context_data: dict = await context.prepare_for_step(include_remote=True) + context_data: dict = await context.prepare_for_step() # If the prompt is a string and starts with $_ then we need to evaluate it - should_evaluate_prompt = 
isinstance(prompt, str) and prompt.startswith( - EVAL_PROMPT_PREFIX - ) + should_evaluate_prompt = isinstance(prompt, str) and prompt.startswith(EVAL_PROMPT_PREFIX) if should_evaluate_prompt: - prompt = await base_evaluate( - prompt[len(EVAL_PROMPT_PREFIX) :].strip(), context_data - ) + prompt = await base_evaluate(prompt[len(EVAL_PROMPT_PREFIX) :].strip(), context_data) - if not isinstance(prompt, (str, list)): - raise ApplicationError( - "Invalid prompt expression, expected a string or list" - ) + if not isinstance(prompt, str | list): + msg = "Invalid prompt expression, expected a string or list" + raise ApplicationError(msg) # Wrap the prompt in a list if it is not already - prompt = ( - prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}] - ) + prompt = prompt if isinstance(prompt, list) else [{"role": "user", "content": prompt}] # Render template messages if we didn't evaluate the prompt if not should_evaluate_prompt: @@ -99,7 +88,8 @@ async def prompt_step(context: StepContext) -> StepOutcome: ) if not isinstance(context.execution_input, ExecutionInput): - raise TypeError("Expected ExecutionInput type for context.execution_input") + msg = "Expected ExecutionInput type for context.execution_input" + raise TypeError(msg) # Get settings and run llm agent_default_settings: dict = ( @@ -109,9 +99,7 @@ async def prompt_step(context: StepContext) -> StepOutcome: ) agent_model: str = ( - context.execution_input.agent.model - if context.execution_input.agent.model - else "gpt-4o" + context.execution_input.agent.model if context.execution_input.agent.model else "gpt-4o" ) excluded_keys = [ @@ -204,11 +192,13 @@ async def prompt_step(context: StepContext) -> StepOutcome: if context.current_step.unwrap: if len(response.choices) != 1: - raise ApplicationError("Only one choice is supported") + msg = "Only one choice is supported" + raise ApplicationError(msg) choice = response.choices[0] if choice.finish_reason == "tool_calls": - raise 
ApplicationError("Tool calls cannot be unwrapped") + msg = "Tool calls cannot be unwrapped" + raise ApplicationError(msg) return StepOutcome( output=choice.message.content, @@ -228,7 +218,8 @@ async def prompt_step(context: StepContext) -> StepOutcome: original_tool = tools_mapping.get(call_name) if not original_tool: - raise ApplicationError(f"Tool {call_name} not found") + msg = f"Tool {call_name} not found" + raise ApplicationError(msg) if original_tool.type == "function": continue diff --git a/agents-api/agents_api/activities/task_steps/raise_complete_async.py b/agents-api/agents_api/activities/task_steps/raise_complete_async.py index 640d6ae4e..bbf27c500 100644 --- a/agents-api/agents_api/activities/task_steps/raise_complete_async.py +++ b/agents-api/agents_api/activities/task_steps/raise_complete_async.py @@ -6,12 +6,10 @@ from ...autogen.openapi_model import CreateTransitionRequest from ...common.protocol.tasks import StepContext -from ...common.storage_handler import auto_blob_store from .transition_step import original_transition_step @activity.defn -@auto_blob_store(deep=True) @beartype async def raise_complete_async(context: StepContext, output: Any) -> None: activity_info = activity.info() diff --git a/agents-api/agents_api/activities/task_steps/return_step.py b/agents-api/agents_api/activities/task_steps/return_step.py index 08ac20de4..05fe0ce16 100644 --- a/agents-api/agents_api/activities/task_steps/return_step.py +++ b/agents-api/agents_api/activities/task_steps/return_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store -from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) +@activity.defn @beartype async def return_step(context: StepContext) -> StepOutcome: try: @@ -20,18 +18,8 @@ async def return_step(context: StepContext) -> StepOutcome: exprs: dict[str, str] = context.current_step.return_ output = await base_evaluate(exprs, await 
context.prepare_for_step()) - result = StepOutcome(output=output) - return result + return StepOutcome(output=output) except BaseException as e: activity.logger.error(f"Error in log_step: {e}") return StepOutcome(error=str(e)) - - -# Note: This is here just for clarity. We could have just imported return_step directly -# They do the same thing, so we dont need to mock the return_step function -mock_return_step = return_step - -return_step = activity.defn(name="return_step")( - return_step if not testing else mock_return_step -) diff --git a/agents-api/agents_api/activities/task_steps/set_value_step.py b/agents-api/agents_api/activities/task_steps/set_value_step.py index 1c97b6551..031c6eb44 100644 --- a/agents-api/agents_api/activities/task_steps/set_value_step.py +++ b/agents-api/agents_api/activities/task_steps/set_value_step.py @@ -5,13 +5,12 @@ from ...activities.utils import simple_eval_dict from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store -from ...env import testing - # TODO: We should use this step to signal to the parent workflow and set the value on the workflow context # SCRUM-2 -@auto_blob_store(deep=True) + + +@activity.defn @beartype async def set_value_step( context: StepContext, @@ -23,19 +22,8 @@ async def set_value_step( values = await context.prepare_for_step() | additional_values output = simple_eval_dict(expr, values) - result = StepOutcome(output=output) - - return result + return StepOutcome(output=output) except BaseException as e: activity.logger.error(f"Error in set_value_step: {e}") return StepOutcome(error=str(e) or repr(e)) - - -# Note: This is here just for clarity. 
We could have just imported set_value_step directly -# They do the same thing, so we dont need to mock the set_value_step function -mock_set_value_step = set_value_step - -set_value_step = activity.defn(name="set_value_step")( - set_value_step if not testing else mock_set_value_step -) diff --git a/agents-api/agents_api/activities/task_steps/switch_step.py b/agents-api/agents_api/activities/task_steps/switch_step.py index 6a95e98d2..b39791b6b 100644 --- a/agents-api/agents_api/activities/task_steps/switch_step.py +++ b/agents-api/agents_api/activities/task_steps/switch_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store -from ...env import testing from ..utils import get_evaluator -@auto_blob_store(deep=True) +@activity.defn @beartype async def switch_step(context: StepContext) -> StepOutcome: try: @@ -30,16 +28,8 @@ async def switch_step(context: StepContext) -> StepOutcome: output = i break - result = StepOutcome(output=output) - return result + return StepOutcome(output=output) except BaseException as e: activity.logger.error(f"Error in switch_step: {e}") return StepOutcome(error=str(e)) - - -mock_switch_step = switch_step - -switch_step = activity.defn(name="switch_step")( - switch_step if not testing else mock_switch_step -) diff --git a/agents-api/agents_api/activities/task_steps/tool_call_step.py b/agents-api/agents_api/activities/task_steps/tool_call_step.py index 5725a75d1..2745414d8 100644 --- a/agents-api/agents_api/activities/task_steps/tool_call_step.py +++ b/agents-api/agents_api/activities/task_steps/tool_call_step.py @@ -11,7 +11,6 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store # FIXME: This shouldn't be here. @@ -25,9 +24,7 @@ def generate_call_id(): # FIXME: This shouldn't be here, and shouldn't be done this way. Should be refactored. 
-def construct_tool_call( - tool: CreateToolRequest | Tool, arguments: dict, call_id: str -) -> dict: +def construct_tool_call(tool: CreateToolRequest | Tool, arguments: dict, call_id: str) -> dict: return { tool.type: { "arguments": arguments, @@ -47,7 +44,6 @@ def construct_tool_call( @activity.defn -@auto_blob_store(deep=True) @beartype async def tool_call_step(context: StepContext) -> StepOutcome: assert isinstance(context.current_step, ToolCallStep) @@ -58,7 +54,8 @@ async def tool_call_step(context: StepContext) -> StepOutcome: tool = next((t for t in tools if t.name == tool_name), None) if tool is None: - raise ApplicationError(f"Tool {tool_name} not found in the toolset") + msg = f"Tool {tool_name} not found in the toolset" + raise ApplicationError(msg) arguments = await base_evaluate( context.current_step.arguments, await context.prepare_for_step() diff --git a/agents-api/agents_api/activities/task_steps/transition_step.py b/agents-api/agents_api/activities/task_steps/transition_step.py index 44046a5e7..f5eb60a10 100644 --- a/agents-api/agents_api/activities/task_steps/transition_step.py +++ b/agents-api/agents_api/activities/task_steps/transition_step.py @@ -5,66 +5,48 @@ from fastapi import HTTPException from temporalio import activity +from ...app import lifespan from ...autogen.openapi_model import CreateTransitionRequest, Transition from ...clients.temporal import get_workflow_handle from ...common.protocol.tasks import ExecutionInput, StepContext -from ...common.storage_handler import load_from_blob_store_if_remote -from ...env import ( - temporal_activity_after_retry_timeout, - testing, - transition_requests_per_minute, -) +from ...env import temporal_activity_after_retry_timeout from ...exceptions import LastErrorInput, TooManyRequestsError -from ...models.execution.create_execution_transition import ( - create_execution_transition_async, +from ...queries.executions.create_execution_transition import ( + create_execution_transition, ) -from ..utils 
import RateLimiter - -# Global rate limiter instance -rate_limiter = RateLimiter(max_requests=transition_requests_per_minute) +from ..container import container +@lifespan(container) @beartype async def transition_step( context: StepContext, transition_info: CreateTransitionRequest, last_error: BaseException | None = None, ) -> Transition: - # Check rate limit first - if not await rate_limiter.acquire(): - raise TooManyRequestsError( - f"Rate limit exceeded. Maximum {transition_requests_per_minute} requests per minute allowed." - ) - from ...workflows.task_execution import TaskExecutionWorkflow activity_info = activity.info() wf_handle = await get_workflow_handle(handle_id=activity_info.workflow_id) - # TODO: Filter by last_error type if last_error is not None: await asyncio.sleep(temporal_activity_after_retry_timeout) await wf_handle.signal( TaskExecutionWorkflow.set_last_error, LastErrorInput(last_error=None) ) - # Load output from blob store if it is a remote object - transition_info.output = await load_from_blob_store_if_remote( - transition_info.output - ) - if not isinstance(context.execution_input, ExecutionInput): - raise TypeError("Expected ExecutionInput type for context.execution_input") + msg = "Expected ExecutionInput type for context.execution_input" + raise TypeError(msg) # Create transition try: - transition = await create_execution_transition_async( + transition = await create_execution_transition( developer_id=context.execution_input.developer_id, execution_id=context.execution_input.execution.id, - task_id=context.execution_input.task.id, data=transition_info, task_token=transition_info.task_token, - update_execution_status=True, + connection_pool=container.state.postgres_pool, ) except Exception as e: @@ -78,9 +60,7 @@ async def transition_step( return transition +# NOTE: Here because needed by a different step original_transition_step = transition_step -mock_transition_step = transition_step -transition_step = 
activity.defn(name="transition_step")( - transition_step if not testing else mock_transition_step -) +transition_step = activity.defn(transition_step) diff --git a/agents-api/agents_api/activities/task_steps/wait_for_input_step.py b/agents-api/agents_api/activities/task_steps/wait_for_input_step.py index ad6eeb63e..267da3195 100644 --- a/agents-api/agents_api/activities/task_steps/wait_for_input_step.py +++ b/agents-api/agents_api/activities/task_steps/wait_for_input_step.py @@ -3,12 +3,10 @@ from ...autogen.openapi_model import WaitForInputStep from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store -from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) +@activity.defn @beartype async def wait_for_input_step(context: StepContext) -> StepOutcome: try: @@ -17,16 +15,8 @@ async def wait_for_input_step(context: StepContext) -> StepOutcome: exprs = context.current_step.wait_for_input.info output = await base_evaluate(exprs, await context.prepare_for_step()) - result = StepOutcome(output=output) - return result + return StepOutcome(output=output) except BaseException as e: activity.logger.error(f"Error in wait_for_input_step: {e}") return StepOutcome(error=str(e)) - - -mock_wait_for_input_step = wait_for_input_step - -wait_for_input_step = activity.defn(name="wait_for_input_step")( - wait_for_input_step if not testing else mock_wait_for_input_step -) diff --git a/agents-api/agents_api/activities/task_steps/yield_step.py b/agents-api/agents_api/activities/task_steps/yield_step.py index 199008703..d0b9e6f29 100644 --- a/agents-api/agents_api/activities/task_steps/yield_step.py +++ b/agents-api/agents_api/activities/task_steps/yield_step.py @@ -1,31 +1,28 @@ -from typing import Callable - from beartype import beartype from temporalio import activity from ...autogen.openapi_model import TransitionTarget, YieldStep from ...common.protocol.tasks import ExecutionInput, 
StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store -from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) +@activity.defn @beartype async def yield_step(context: StepContext) -> StepOutcome: try: assert isinstance(context.current_step, YieldStep) if not isinstance(context.execution_input, ExecutionInput): - raise TypeError("Expected ExecutionInput type for context.execution_input") + msg = "Expected ExecutionInput type for context.execution_input" + raise TypeError(msg) all_workflows = context.execution_input.task.workflows workflow = context.current_step.workflow exprs = context.current_step.arguments - assert workflow in [ - wf.name for wf in all_workflows - ], f"Workflow {workflow} not found in task" + assert workflow in [wf.name for wf in all_workflows], ( + f"Workflow {workflow} not found in task" + ) # Evaluate the expressions in the arguments arguments = await base_evaluate(exprs, await context.prepare_for_step()) @@ -41,12 +38,3 @@ async def yield_step(context: StepContext) -> StepOutcome: except BaseException as e: activity.logger.error(f"Error in yield_step: {e}") return StepOutcome(error=str(e)) - - -# Note: This is here just for clarity. 
We could have just imported yield_step directly -# They do the same thing, so we dont need to mock the yield_step function -mock_yield_step: Callable[[StepContext], StepOutcome] = yield_step - -yield_step: Callable[[StepContext], StepOutcome] = activity.defn(name="yield_step")( - yield_step if not testing else mock_yield_step -) diff --git a/agents-api/agents_api/activities/truncation.py b/agents-api/agents_api/activities/truncation.py deleted file mode 100644 index afdb43da4..000000000 --- a/agents-api/agents_api/activities/truncation.py +++ /dev/null @@ -1,60 +0,0 @@ -from uuid import UUID - -from beartype import beartype -from temporalio import activity - -from ..autogen.openapi_model import Entry - -# from agents_api.models.entry.entries_summarization import get_toplevel_entries_query - -# TODO: Reimplement truncation queries -# SCRUM-5 - - -def get_extra_entries(messages: list[Entry], token_count_threshold: int) -> list[UUID]: - raise NotImplementedError() - - if not len(messages): - return messages - - _token_cnt, _offset = 0, 0 - # if messages[0].role == Role.system: - # token_cnt, offset = messages[0].token_count, 1 - - # for m in reversed(messages[offset:]): - # token_cnt += m.token_count - # if token_cnt < token_count_threshold: - # continue - # else: - # result.append(m.id) - - # return result - - -# TODO: Reimplement truncation activities -# SCRUM-6 -@activity.defn -@beartype -async def truncation(session_id: str, token_count_threshold: int) -> None: - session_id = UUID(session_id) - - # delete_entries( - # get_extra_entries( - # [ - # Entry( - # entry_id=row["entry_id"], - # session_id=session_id, - # source=row["source"], - # role=Role(row["role"]), - # name=row["name"], - # content=row["content"], - # created_at=row["created_at"], - # timestamp=row["timestamp"], - # ) - # for _, row in get_toplevel_entries_query( - # session_id=session_id - # ).iterrows() - # ], - # token_count_threshold, - # ), - # ) diff --git 
a/agents-api/agents_api/activities/types.py b/agents-api/agents_api/activities/types.py deleted file mode 100644 index c2af67936..000000000 --- a/agents-api/agents_api/activities/types.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Literal -from uuid import UUID - -from pydantic import BaseModel - -from ..autogen.openapi_model import InputChatMLMessage - - -class MemoryManagementTaskArgs(BaseModel): - session_id: UUID - model: str - dialog: list[InputChatMLMessage] - previous_memories: list[str] = [] - - -class MemoryManagementTask(BaseModel): - name: Literal["memory_management.v1"] - args: MemoryManagementTaskArgs - - -class MemoryRatingTaskArgs(BaseModel): - memory: str - - -class MemoryRatingTask(BaseModel): - name: Literal["memory_rating.v1"] - args: MemoryRatingTaskArgs - - -class EmbedDocsPayload(BaseModel): - developer_id: UUID - doc_id: UUID - content: list[str] - embed_instruction: str | None - title: str | None = None - include_title: bool = False # Need to be a separate parameter for the activity diff --git a/agents-api/agents_api/activities/utils.py b/agents-api/agents_api/activities/utils.py index d9ad1840c..a9d4a11f2 100644 --- a/agents-api/agents_api/activities/utils.py +++ b/agents-api/agents_api/activities/utils.py @@ -1,8 +1,6 @@ import asyncio import base64 import datetime as dt -import functools -import itertools import json import math import random @@ -10,11 +8,11 @@ import string import time import urllib.parse -import zoneinfo from collections import deque +from collections.abc import Callable from dataclasses import dataclass from threading import Lock as ThreadLock -from typing import Any, Callable, ParamSpec, TypeVar +from typing import Any, ParamSpec, TypeVar import re2 from beartype import beartype @@ -24,21 +22,105 @@ from ..common.nlp import nlp from ..common.utils import yaml +# Security limits +MAX_STRING_LENGTH = 1_000_000 # 1MB +MAX_COLLECTION_SIZE = 10_000 +MAX_RANGE_SIZE = 1_000_000 + T = TypeVar("T") R = TypeVar("R") P = 
ParamSpec("P") +def safe_range(*args): + result = range(*args) + if len(result) > MAX_RANGE_SIZE: + msg = f"Range size exceeds maximum of {MAX_RANGE_SIZE}" + raise ValueError(msg) + return result + + +def safe_json_loads(s: str): + if len(s) > MAX_STRING_LENGTH: + msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}" + raise ValueError(msg) + return json.loads(s) + + +def safe_yaml_load(s: str): + if len(s) > MAX_STRING_LENGTH: + msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}" + raise ValueError(msg) + return yaml.load(s) + + +def safe_base64_decode(s: str) -> str: + if len(s) > MAX_STRING_LENGTH: + msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}" + raise ValueError(msg) + try: + return base64.b64decode(s).decode("utf-8") + except Exception as e: + msg = f"Invalid base64 string: {e}" + raise ValueError(msg) + + +def safe_base64_encode(s: str) -> str: + if len(s) > MAX_STRING_LENGTH: + msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}" + raise ValueError(msg) + return base64.b64encode(s.encode("utf-8")).decode("utf-8") + + +def safe_random_choice(seq): + if len(seq) > MAX_COLLECTION_SIZE: + msg = f"Sequence exceeds maximum size of {MAX_COLLECTION_SIZE}" + raise ValueError(msg) + return random.choice(seq) + + +def safe_random_sample(population, k): + if len(population) > MAX_COLLECTION_SIZE: + msg = f"Population exceeds maximum size of {MAX_COLLECTION_SIZE}" + raise ValueError(msg) + if k > MAX_COLLECTION_SIZE: + msg = f"Sample size exceeds maximum of {MAX_COLLECTION_SIZE}" + raise ValueError(msg) + if k > len(population): + msg = "Sample size cannot exceed population size" + raise ValueError(msg) + return random.sample(population, k) + + def chunk_doc(string: str) -> list[str]: """ Chunk a string into sentences. 
""" + if len(string) > MAX_STRING_LENGTH: + msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}" + raise ValueError(msg) doc = nlp(string) return [" ".join([sent.text for sent in chunk]) for chunk in doc._.chunks] -# TODO: We need to make sure that we dont expose any security issues +def safe_extract_json(string: str): + if len(string) > MAX_STRING_LENGTH: + msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}" + raise ValueError(msg) + # Check if the string contains JSON code block markers + if "```json" in string: + extracted_string = string[ + string.find("```json") + 7 : string.find("```", string.find("```json") + 7) + ] + else: + # If no markers, try to parse the whole string as JSON + extracted_string = string + return json.loads(extracted_string) + + +# Restricted set of allowed functions ALLOWED_FUNCTIONS = { + # Basic Python builtins "abs": abs, "all": all, "any": any, @@ -46,32 +128,34 @@ def chunk_doc(string: str) -> list[str]: "dict": dict, "enumerate": enumerate, "float": float, - "frozenset": frozenset, "int": int, "len": len, "list": list, "map": map, "max": max, "min": min, - "range": range, "round": round, "set": set, "str": str, "sum": sum, "tuple": tuple, - "reduce": functools.reduce, "zip": zip, - "search_regex": lambda pattern, string: re2.search(pattern, string), - "load_json": json.loads, - "load_yaml": yaml.load, + # Safe versions of potentially dangerous functions + "range": safe_range, + "load_json": safe_json_loads, + "load_yaml": safe_yaml_load, "dump_json": json.dumps, "dump_yaml": yaml.dump, + "extract_json": safe_extract_json, + # Regex and NLP functions (using re2 which is safe against ReDoS) + "search_regex": lambda pattern, string: re2.search(pattern, string), "match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)), "nlp": nlp.__call__, "chunk_doc": chunk_doc, } +# Safe regex operations (using re2) class stdlib_re: fullmatch = re2.fullmatch search = re2.search @@ -84,59 +168,19 @@ class 
stdlib_re: subn = re2.subn +# Safe JSON operations class stdlib_json: - loads = json.loads + loads = safe_json_loads dumps = json.dumps +# Safe YAML operations class stdlib_yaml: - load = yaml.load + load = safe_yaml_load dump = yaml.dump -class stdlib_time: - strftime = time.strftime - strptime = time.strptime - time = time - - -class stdlib_random: - choice = random.choice - choices = random.choices - sample = random.sample - shuffle = random.shuffle - randrange = random.randrange - randint = random.randint - random = random.random - - -class stdlib_itertools: - accumulate = itertools.accumulate - - -class stdlib_functools: - partial = functools.partial - reduce = functools.reduce - - -class stdlib_base64: - b64encode = base64.b64encode - b64decode = base64.b64decode - - -class stdlib_urllib: - class parse: - urlparse = urllib.parse.urlparse - urlencode = urllib.parse.urlencode - unquote = urllib.parse.unquote - quote = urllib.parse.quote - parse_qs = urllib.parse.parse_qs - parse_qsl = urllib.parse.parse_qsl - urlsplit = urllib.parse.urlsplit - urljoin = urllib.parse.urljoin - unwrap = urllib.parse.unwrap - - +# Safe string constants class stdlib_string: ascii_letters = string.ascii_letters ascii_lowercase = string.ascii_lowercase @@ -149,14 +193,11 @@ class stdlib_string: printable = string.printable -class stdlib_zoneinfo: - ZoneInfo = zoneinfo.ZoneInfo - - +# Safe datetime operations class stdlib_datetime: class timezone: class utc: - utc = dt.timezone.utc + utc = dt.UTC class datetime: now = dt.datetime.now @@ -168,6 +209,7 @@ class datetime: timedelta = dt.timedelta +# Safe math operations class stdlib_math: sqrt = math.sqrt exp = math.exp @@ -191,6 +233,7 @@ class stdlib_math: e = math.e +# Safe statistics operations class stdlib_statistics: mean = statistics.mean stdev = statistics.stdev @@ -202,21 +245,57 @@ class stdlib_statistics: quantiles = statistics.quantiles +# Safe base64 operations +class stdlib_base64: + b64encode = safe_base64_encode + 
b64decode = safe_base64_decode + + +# Safe URL parsing operations +class stdlib_urllib: + class parse: + # Safe URL parsing operations that don't touch filesystem/network + urlparse = urllib.parse.urlparse + urlencode = urllib.parse.urlencode + unquote = urllib.parse.unquote + quote = urllib.parse.quote + parse_qs = urllib.parse.parse_qs + parse_qsl = urllib.parse.parse_qsl + urlsplit = urllib.parse.urlsplit + + +# Safe random operations +class stdlib_random: + # Limit to safe operations with bounded inputs + choice = safe_random_choice + sample = safe_random_sample + # Safe bounded random number generators + randint = random.randint # Already bounded by integer limits + random = random.random # Always returns 0.0 to 1.0 + + +# Safe time operations +class stdlib_time: + # Time formatting/parsing operations + strftime = time.strftime + strptime = time.strptime + # Current time (safe, no side effects) + time = time.time + + +# Restricted stdlib with only safe operations stdlib = { "re": stdlib_re, "json": stdlib_json, "yaml": stdlib_yaml, - "time": stdlib_time, - "random": stdlib_random, - "itertools": stdlib_itertools, - "functools": stdlib_functools, - "base64": stdlib_base64, - "urllib": stdlib_urllib, "string": stdlib_string, - "zoneinfo": stdlib_zoneinfo, "datetime": stdlib_datetime, "math": stdlib_math, "statistics": stdlib_statistics, + "base64": stdlib_base64, + "urllib": stdlib_urllib, + "random": stdlib_random, + "time": stdlib_time, } constants = { @@ -231,18 +310,33 @@ class stdlib_statistics: def get_evaluator( names: dict[str, Any], extra_functions: dict[str, Callable] | None = None ) -> SimpleEval: + if len(names) > MAX_COLLECTION_SIZE: + msg = f"Too many variables (max {MAX_COLLECTION_SIZE})" + raise ValueError(msg) + evaluator = EvalWithCompoundTypes( names=names | stdlib | constants, functions=ALLOWED_FUNCTIONS | (extra_functions or {}), ) + # Add maximum execution time + evaluator.TIMEOUT = 1.0 # 1 second timeout + return evaluator @beartype def 
simple_eval_dict(exprs: dict[str, str], values: dict[str, Any]) -> dict[str, Any]: - evaluator = get_evaluator(names=values) + if len(exprs) > MAX_COLLECTION_SIZE: + msg = f"Too many expressions (max {MAX_COLLECTION_SIZE})" + raise ValueError(msg) + + for v in exprs.values(): + if len(v) > MAX_STRING_LENGTH: + msg = f"Expression exceeds maximum length of {MAX_STRING_LENGTH}" + raise ValueError(msg) + evaluator = get_evaluator(names=values) return {k: evaluator.eval(v) for k, v in exprs.items()} @@ -277,9 +371,7 @@ def filtered_handler(*args, **kwargs): # Remove problematic parameters filtered_handler.__signature__ = sig.replace( - parameters=[ - p for p in sig.parameters.values() if p.name not in parameters_to_exclude - ] + parameters=[p for p in sig.parameters.values() if p.name not in parameters_to_exclude] ) return filtered_handler @@ -296,28 +388,28 @@ def get_handler(system: SystemDef) -> Callable: The base handler function. """ - from ..models.agent.create_agent import create_agent as create_agent_query - from ..models.agent.delete_agent import delete_agent as delete_agent_query - from ..models.agent.get_agent import get_agent as get_agent_query - from ..models.agent.list_agents import list_agents as list_agents_query - from ..models.agent.update_agent import update_agent as update_agent_query - from ..models.docs.delete_doc import delete_doc as delete_doc_query - from ..models.docs.list_docs import list_docs as list_docs_query - from ..models.session.create_session import create_session as create_session_query - from ..models.session.delete_session import delete_session as delete_session_query - from ..models.session.get_session import get_session as get_session_query - from ..models.session.list_sessions import list_sessions as list_sessions_query - from ..models.session.update_session import update_session as update_session_query - from ..models.task.create_task import create_task as create_task_query - from ..models.task.delete_task import delete_task as 
delete_task_query - from ..models.task.get_task import get_task as get_task_query - from ..models.task.list_tasks import list_tasks as list_tasks_query - from ..models.task.update_task import update_task as update_task_query - from ..models.user.create_user import create_user as create_user_query - from ..models.user.delete_user import delete_user as delete_user_query - from ..models.user.get_user import get_user as get_user_query - from ..models.user.list_users import list_users as list_users_query - from ..models.user.update_user import update_user as update_user_query + from ..queries.agents.create_agent import create_agent as create_agent_query + from ..queries.agents.delete_agent import delete_agent as delete_agent_query + from ..queries.agents.get_agent import get_agent as get_agent_query + from ..queries.agents.list_agents import list_agents as list_agents_query + from ..queries.agents.update_agent import update_agent as update_agent_query + from ..queries.docs.delete_doc import delete_doc as delete_doc_query + from ..queries.docs.list_docs import list_docs as list_docs_query + from ..queries.sessions.create_session import create_session as create_session_query + from ..queries.sessions.delete_session import delete_session as delete_session_query + from ..queries.sessions.get_session import get_session as get_session_query + from ..queries.sessions.list_sessions import list_sessions as list_sessions_query + from ..queries.sessions.update_session import update_session as update_session_query + from ..queries.tasks.create_task import create_task as create_task_query + from ..queries.tasks.delete_task import delete_task as delete_task_query + from ..queries.tasks.get_task import get_task as get_task_query + from ..queries.tasks.list_tasks import list_tasks as list_tasks_query + from ..queries.tasks.update_task import update_task as update_task_query + from ..queries.users.create_user import create_user as create_user_query + from ..queries.users.delete_user 
import delete_user as delete_user_query + from ..queries.users.get_user import get_user as get_user_query + from ..queries.users.list_users import list_users as list_users_query + from ..queries.users.update_user import update_user as update_user_query from ..routers.docs.create_doc import create_agent_doc, create_user_doc from ..routers.docs.search_docs import search_agent_docs, search_user_docs from ..routers.sessions.chat import chat @@ -390,9 +482,8 @@ def get_handler(system: SystemDef) -> Callable: return delete_task_query case _: - raise NotImplementedError( - f"System call not implemented for {system.resource}.{system.operation}" - ) + msg = f"System call not implemented for {system.resource}.{system.operation}" + raise NotImplementedError(msg) @dataclass diff --git a/agents-api/agents_api/app.py b/agents-api/agents_api/app.py new file mode 100644 index 000000000..c977491bc --- /dev/null +++ b/agents-api/agents_api/app.py @@ -0,0 +1,116 @@ +import os +from contextlib import asynccontextmanager +from typing import Any, Protocol + +from aiobotocore.session import get_session +from fastapi import APIRouter, FastAPI +from prometheus_fastapi_instrumentator import Instrumentator +from scalar_fastapi import get_scalar_api_reference + +from .clients.pg import create_db_pool +from .env import api_prefix, hostname, protocol, public_port + + +class Assignable(Protocol): + def __setattr__(self, name: str, value: Any) -> None: ... 
+ + +class ObjectWithState(Protocol): + state: Assignable + + +# TODO: This currently doesn't use env.py, we should move to using them +@asynccontextmanager +async def lifespan(*containers: list[FastAPI | ObjectWithState]): + # INIT POSTGRES # + pg_dsn = os.environ.get("PG_DSN") + + for container in containers: + if not getattr(container.state, "postgres_pool", None): + container.state.postgres_pool = await create_db_pool(pg_dsn) + + # INIT S3 # + s3_access_key = os.environ.get("S3_ACCESS_KEY") + s3_secret_key = os.environ.get("S3_SECRET_KEY") + s3_endpoint = os.environ.get("S3_ENDPOINT") + + for container in containers: + if not getattr(container.state, "s3_client", None): + session = get_session() + container.state.s3_client = await session.create_client( + "s3", + aws_access_key_id=s3_access_key, + aws_secret_access_key=s3_secret_key, + endpoint_url=s3_endpoint, + ).__aenter__() + + try: + yield + finally: + # CLOSE POSTGRES # + for container in containers: + if getattr(container.state, "postgres_pool", None): + await container.state.postgres_pool.close() + container.state.postgres_pool = None + + # CLOSE S3 # + for container in containers: + if getattr(container.state, "s3_client", None): + await container.state.s3_client.close() + container.state.s3_client = None + + +app: FastAPI = FastAPI( + docs_url="/swagger", + openapi_prefix=api_prefix, + redoc_url=None, + title="Julep Agents API", + description="API for Julep Agents", + version="0.4.0", + terms_of_service="https://www.julep.ai/terms", + contact={ + "name": "Julep", + "url": "https://www.julep.ai", + "email": "developers@julep.ai", + }, + root_path=api_prefix, + lifespan=lifespan, +) + +# Enable metrics +Instrumentator().instrument(app).expose(app, include_in_schema=False) + + +# Create a new router for the docs +scalar_router = APIRouter() + + +@scalar_router.get("/docs", include_in_schema=False) +async def scalar_html(): + return get_scalar_api_reference( + openapi_url=app.openapi_url[1:], # Remove 
leading '/' + title=app.title, + servers=[{"url": f"{protocol}://{hostname}:{public_port}{api_prefix}"}], + ) + + +# Add the docs_router without dependencies +app.include_router(scalar_router) + + +# TODO: Implement correct content-length validation (using streaming and chunked transfer encoding) +# NOTE: This relies on client reporting the correct content-length header +# @app.middleware("http") +# async def validate_content_length( +# request: Request, +# call_next: Callable[[Request], Coroutine[Any, Any, Response]], +# ): +# content_length = request.headers.get("content-length") + +# if not content_length: +# return Response(status_code=411, content="Content-Length header is required") + +# if int(content_length) > max_payload_size: +# return Response(status_code=413, content="Payload too large") + +# return await call_next(request) diff --git a/agents-api/agents_api/autogen/Agents.py b/agents-api/agents_api/autogen/Agents.py index 5dab2c7b2..7390b6338 100644 --- a/agents-api/agents_api/autogen/Agents.py +++ b/agents-api/agents_api/autogen/Agents.py @@ -25,16 +25,17 @@ class Agent(BaseModel): """ When this resource was updated as UTC date-time """ - name: Annotated[ - str, - Field( - max_length=120, - pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$", - ), - ] = "" + name: Annotated[str, Field(max_length=255, min_length=1)] """ Name of the agent """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + Canonical name of the agent + """ about: str = "" """ About the agent @@ -62,16 +63,17 @@ class CreateAgentRequest(BaseModel): populate_by_name=True, ) metadata: dict[str, Any] | None = None - name: Annotated[ - str, - Field( - max_length=120, - 
pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$", - ), - ] = "" + name: Annotated[str, Field(max_length=255, min_length=1)] """ Name of the agent """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + Canonical name of the agent + """ about: str = "" """ About the agent @@ -96,16 +98,17 @@ class CreateOrUpdateAgentRequest(CreateAgentRequest): ) id: UUID metadata: dict[str, Any] | None = None - name: Annotated[ - str, - Field( - max_length=120, - pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$", - ), - ] = "" + name: Annotated[str, Field(max_length=255, min_length=1)] """ Name of the agent """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + Canonical name of the agent + """ about: str = "" """ About the agent @@ -133,16 +136,17 @@ class PatchAgentRequest(BaseModel): populate_by_name=True, ) metadata: dict[str, Any] | None = None - name: Annotated[ - str, - Field( - max_length=120, - pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$", - ), - ] = "" + name: Annotated[str | None, Field(max_length=255, min_length=1)] = None """ Name of the agent """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + Canonical name of the agent + """ about: str = "" """ About the agent @@ -170,16 +174,17 @@ class UpdateAgentRequest(BaseModel): populate_by_name=True, ) metadata: dict[str, Any] | None = None - name: Annotated[ - str, - Field( - max_length=120, - 
pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$", - ), - ] = "" + name: Annotated[str, Field(max_length=255, min_length=1)] """ Name of the agent """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + Canonical name of the agent + """ about: str = "" """ About the agent diff --git a/agents-api/agents_api/autogen/Chat.py b/agents-api/agents_api/autogen/Chat.py index 042f9164d..13dcc9532 100644 --- a/agents-api/agents_api/autogen/Chat.py +++ b/agents-api/agents_api/autogen/Chat.py @@ -59,9 +59,7 @@ class BaseChatResponse(BaseModel): """ Background job IDs that may have been spawned from this interaction. """ - docs: Annotated[ - list[DocReference], Field(json_schema_extra={"readOnly": True}) - ] = [] + docs: Annotated[list[DocReference], Field(json_schema_extra={"readOnly": True})] = [] """ Documents referenced for this request (for citation purposes). """ @@ -134,21 +132,15 @@ class CompetionUsage(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - completion_tokens: Annotated[ - int | None, Field(json_schema_extra={"readOnly": True}) - ] = None + completion_tokens: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None """ Number of tokens in the generated completion """ - prompt_tokens: Annotated[ - int | None, Field(json_schema_extra={"readOnly": True}) - ] = None + prompt_tokens: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None """ Number of tokens in the prompt """ - total_tokens: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = ( - None - ) + total_tokens: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None """ Total number of tokens used in the request (prompt + completion) """ @@ -429,9 +421,9 @@ class MessageModel(BaseModel): """ Tool calls generated by the model. 
""" - created_at: Annotated[ - AwareDatetime | None, Field(json_schema_extra={"readOnly": True}) - ] = None + created_at: Annotated[AwareDatetime | None, Field(json_schema_extra={"readOnly": True})] = ( + None + ) """ When this resource was created as UTC date-time """ @@ -576,9 +568,9 @@ class ChatInput(ChatInputData): """ Modify the likelihood of specified tokens appearing in the completion """ - response_format: ( - SimpleCompletionResponseFormat | SchemaCompletionResponseFormat | None - ) = None + response_format: SimpleCompletionResponseFormat | SchemaCompletionResponseFormat | None = ( + None + ) """ Response format (set to `json_object` to restrict output to JSON) """ @@ -672,9 +664,9 @@ class ChatSettings(DefaultChatSettings): """ Modify the likelihood of specified tokens appearing in the completion """ - response_format: ( - SimpleCompletionResponseFormat | SchemaCompletionResponseFormat | None - ) = None + response_format: SimpleCompletionResponseFormat | SchemaCompletionResponseFormat | None = ( + None + ) """ Response format (set to `json_object` to restrict output to JSON) """ diff --git a/agents-api/agents_api/autogen/Docs.py b/agents-api/agents_api/autogen/Docs.py index ffed27c1d..28a421ba5 100644 --- a/agents-api/agents_api/autogen/Docs.py +++ b/agents-api/agents_api/autogen/Docs.py @@ -73,6 +73,24 @@ class Doc(BaseModel): """ Embeddings for the document """ + modality: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Modality of the document + """ + language: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Language of the document + """ + embedding_model: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Embedding model used for the document + """ + embedding_dimensions: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = ( + None + ) + """ + Dimensions of the embedding model + """ class DocOwner(BaseModel): diff --git 
a/agents-api/agents_api/autogen/Entries.py b/agents-api/agents_api/autogen/Entries.py index de37e77d8..867b10192 100644 --- a/agents-api/agents_api/autogen/Entries.py +++ b/agents-api/agents_api/autogen/Entries.py @@ -52,6 +52,7 @@ class BaseEntry(BaseModel): ] tokenizer: str token_count: int + model: str = "gpt-4o-mini" tool_calls: ( list[ ChosenFunctionCall diff --git a/agents-api/agents_api/autogen/Executions.py b/agents-api/agents_api/autogen/Executions.py index 5ccc57e83..36a36b7a5 100644 --- a/agents-api/agents_api/autogen/Executions.py +++ b/agents-api/agents_api/autogen/Executions.py @@ -181,8 +181,6 @@ class Transition(TransitionEvent): ) execution_id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] current: Annotated[TransitionTarget, Field(json_schema_extra={"readOnly": True})] - next: Annotated[ - TransitionTarget | None, Field(json_schema_extra={"readOnly": True}) - ] + next: Annotated[TransitionTarget | None, Field(json_schema_extra={"readOnly": True})] id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] metadata: dict[str, Any] | None = None diff --git a/agents-api/agents_api/autogen/Sessions.py b/agents-api/agents_api/autogen/Sessions.py index 460fd25ce..20c9885b1 100644 --- a/agents-api/agents_api/autogen/Sessions.py +++ b/agents-api/agents_api/autogen/Sessions.py @@ -27,9 +27,13 @@ class CreateSessionRequest(BaseModel): Agent ID of agent associated with this session """ agents: list[UUID] | None = None - situation: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if tools -%}\nTools:{{NEWLINE}}\n {%- for tool in tools -%}\n - {{tool.name + NEWLINE}}\n {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%}\n {%- endfor -%}\n{{NEWLINE+NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + situation: str | None = None """ - A specific situation that sets the background for this session + Session situation + """ + system_template: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE}}\n\n{%- if session.situation -%}\nSituation: {{session.situation}}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + """ + A specific system prompt template that sets the background for this session """ render_templates: StrictBool = True """ @@ -51,6 +55,10 @@ class CreateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -63,9 +71,13 @@ class PatchSessionRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - situation: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if tools -%}\nTools:{{NEWLINE}}\n {%- for tool in tools -%}\n - {{tool.name + NEWLINE}}\n {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%}\n {%- endfor -%}\n{{NEWLINE+NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + situation: str | None = None + """ + Session situation """ - A specific situation that sets the background for this session + system_template: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE}}\n\n{%- if session.situation -%}\nSituation: {{session.situation}}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + """ + A specific system prompt template that sets the background for this session """ render_templates: StrictBool = True """ @@ -87,6 +99,10 @@ class PatchSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptionsUpdate | None = None metadata: dict[str, Any] | None = None @@ -117,9 +133,13 @@ class Session(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - situation: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if tools -%}\nTools:{{NEWLINE}}\n {%- for tool in tools -%}\n - {{tool.name + NEWLINE}}\n {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%}\n {%- endfor -%}\n{{NEWLINE+NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + situation: str | None = None + """ + Session situation + """ + system_template: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE}}\n\n{%- if session.situation -%}\nSituation: {{session.situation}}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' """ - A specific situation that sets the background for this session + A specific system prompt template that sets the background for this session """ summary: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None """ @@ -145,6 +165,10 @@ class Session(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] metadata: dict[str, Any] | None = None @@ -193,9 +217,13 @@ class UpdateSessionRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - situation: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if tools -%}\nTools:{{NEWLINE}}\n {%- for tool in tools -%}\n - {{tool.name + NEWLINE}}\n {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%}\n {%- endfor -%}\n{{NEWLINE+NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + situation: str | None = None + """ + Session situation """ - A specific situation that sets the background for this session + system_template: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE}}\n\n{%- if session.situation -%}\nSituation: {{session.situation}}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + """ + A specific system prompt template that sets the background for this session """ render_templates: StrictBool = True """ @@ -217,6 +245,10 @@ class UpdateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -236,9 +268,13 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): Agent ID of agent associated with this session """ agents: list[UUID] | None = None - situation: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if tools -%}\nTools:{{NEWLINE}}\n {%- for tool in tools -%}\n - {{tool.name + NEWLINE}}\n {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%}\n {%- endfor -%}\n{{NEWLINE+NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + situation: str | None = None + """ + Session situation + """ + system_template: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE}}\n\n{%- if session.situation -%}\nSituation: {{session.situation}}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' """ - A specific situation that sets the background for this session + A specific system prompt template that sets the background for this session """ render_templates: StrictBool = True """ @@ -260,6 +296,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None diff --git a/agents-api/agents_api/autogen/Tasks.py b/agents-api/agents_api/autogen/Tasks.py index b9212d8cb..ebc3a4b84 100644 --- a/agents-api/agents_api/autogen/Tasks.py +++ b/agents-api/agents_api/autogen/Tasks.py @@ -161,8 +161,21 @@ class CreateTaskRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - name: str + name: Annotated[str, Field(max_length=255, min_length=1)] + """ + The name of the task. 
+ """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + The canonical name of the task. + """ description: str = "" + """ + The description of the task. + """ main: Annotated[ list[ EvaluateStep @@ -206,9 +219,7 @@ class ErrorWorkflowStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["error"], Field(json_schema_extra={"readOnly": True})] = ( - "error" - ) + kind_: Annotated[Literal["error"], Field(json_schema_extra={"readOnly": True})] = "error" """ The kind of step """ @@ -226,9 +237,9 @@ class EvaluateStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["evaluate"], Field(json_schema_extra={"readOnly": True}) - ] = "evaluate" + kind_: Annotated[Literal["evaluate"], Field(json_schema_extra={"readOnly": True})] = ( + "evaluate" + ) """ The kind of step """ @@ -294,9 +305,9 @@ class ForeachStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["foreach"], Field(json_schema_extra={"readOnly": True}) - ] = "foreach" + kind_: Annotated[Literal["foreach"], Field(json_schema_extra={"readOnly": True})] = ( + "foreach" + ) """ The kind of step """ @@ -332,9 +343,7 @@ class GetStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["get"], Field(json_schema_extra={"readOnly": True})] = ( - "get" - ) + kind_: Annotated[Literal["get"], Field(json_schema_extra={"readOnly": True})] = "get" """ The kind of step """ @@ -352,9 +361,9 @@ class IfElseWorkflowStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["if_else"], Field(json_schema_extra={"readOnly": True}) - ] = "if_else" + kind_: Annotated[Literal["if_else"], Field(json_schema_extra={"readOnly": True})] = ( + "if_else" + ) """ The kind of step """ @@ -476,9 +485,7 @@ class LogStep(BaseModel): model_config = ConfigDict( 
populate_by_name=True, ) - kind_: Annotated[Literal["log"], Field(json_schema_extra={"readOnly": True})] = ( - "log" - ) + kind_: Annotated[Literal["log"], Field(json_schema_extra={"readOnly": True})] = "log" """ The kind of step """ @@ -496,9 +503,9 @@ class Main(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["map_reduce"], Field(json_schema_extra={"readOnly": True}) - ] = "map_reduce" + kind_: Annotated[Literal["map_reduce"], Field(json_schema_extra={"readOnly": True})] = ( + "map_reduce" + ) """ The kind of step """ @@ -510,15 +517,7 @@ class Main(BaseModel): """ The variable to iterate over """ - map: ( - EvaluateStep - | ToolCallStep - | PromptStep - | GetStep - | SetStep - | LogStep - | YieldStep - ) + map: EvaluateStep | ToolCallStep | PromptStep | GetStep | SetStep | LogStep | YieldStep """ The steps to run for each iteration """ @@ -586,9 +585,9 @@ class ParallelStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["parallel"], Field(json_schema_extra={"readOnly": True}) - ] = "parallel" + kind_: Annotated[Literal["parallel"], Field(json_schema_extra={"readOnly": True})] = ( + "parallel" + ) """ The kind of step """ @@ -598,13 +597,7 @@ class ParallelStep(BaseModel): """ parallel: Annotated[ list[ - EvaluateStep - | ToolCallStep - | PromptStep - | GetStep - | SetStep - | LogStep - | YieldStep + EvaluateStep | ToolCallStep | PromptStep | GetStep | SetStep | LogStep | YieldStep ], Field(max_length=100), ] @@ -650,7 +643,21 @@ class PatchTaskRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, ) + name: Annotated[str | None, Field(max_length=255, min_length=1)] = None + """ + The name of the task. + """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + The canonical name of the task. + """ description: str = "" + """ + The description of the task. 
+ """ main: Annotated[ list[ EvaluateStep @@ -733,9 +740,7 @@ class PromptStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["prompt"], Field(json_schema_extra={"readOnly": True})] = ( - "prompt" - ) + kind_: Annotated[Literal["prompt"], Field(json_schema_extra={"readOnly": True})] = "prompt" """ The kind of step """ @@ -827,9 +832,7 @@ class ReturnStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["return"], Field(json_schema_extra={"readOnly": True})] = ( - "return" - ) + kind_: Annotated[Literal["return"], Field(json_schema_extra={"readOnly": True})] = "return" """ The kind of step """ @@ -850,9 +853,7 @@ class SetStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["set"], Field(json_schema_extra={"readOnly": True})] = ( - "set" - ) + kind_: Annotated[Literal["set"], Field(json_schema_extra={"readOnly": True})] = "set" """ The kind of step """ @@ -892,9 +893,7 @@ class SleepStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["sleep"], Field(json_schema_extra={"readOnly": True})] = ( - "sleep" - ) + kind_: Annotated[Literal["sleep"], Field(json_schema_extra={"readOnly": True})] = "sleep" """ The kind of step """ @@ -924,9 +923,7 @@ class SwitchStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["switch"], Field(json_schema_extra={"readOnly": True})] = ( - "switch" - ) + kind_: Annotated[Literal["switch"], Field(json_schema_extra={"readOnly": True})] = "switch" """ The kind of step """ @@ -966,8 +963,21 @@ class Task(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - name: str + name: Annotated[str, Field(max_length=255, min_length=1)] + """ + The name of the task. 
+ """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + The canonical name of the task. + """ description: str = "" + """ + The description of the task. + """ main: Annotated[ list[ EvaluateStep @@ -1020,9 +1030,7 @@ class TaskTool(CreateToolRequest): model_config = ConfigDict( populate_by_name=True, ) - inherited: Annotated[StrictBool, Field(json_schema_extra={"readOnly": True})] = ( - False - ) + inherited: Annotated[StrictBool, Field(json_schema_extra={"readOnly": True})] = False """ Read-only: Whether the tool was inherited or not. Only applies within tasks. """ @@ -1032,9 +1040,9 @@ class ToolCallStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["tool_call"], Field(json_schema_extra={"readOnly": True}) - ] = "tool_call" + kind_: Annotated[Literal["tool_call"], Field(json_schema_extra={"readOnly": True})] = ( + "tool_call" + ) """ The kind of step """ @@ -1057,9 +1065,7 @@ class ToolCallStep(BaseModel): dict[ str, dict[str, list[str] | dict[str, str] | list[dict[str, str]] | str] - | list[ - dict[str, list[str] | dict[str, str] | list[dict[str, str]] | str] - ] + | list[dict[str, list[str] | dict[str, str] | list[dict[str, str]] | str]] | str, ] ] @@ -1124,7 +1130,21 @@ class UpdateTaskRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, ) + name: Annotated[str, Field(max_length=255, min_length=1)] + """ + The name of the task. + """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + The canonical name of the task. + """ description: str = "" + """ + The description of the task. 
+ """ main: Annotated[ list[ EvaluateStep @@ -1178,9 +1198,9 @@ class WaitForInputStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["wait_for_input"], Field(json_schema_extra={"readOnly": True}) - ] = "wait_for_input" + kind_: Annotated[Literal["wait_for_input"], Field(json_schema_extra={"readOnly": True})] = ( + "wait_for_input" + ) """ The kind of step """ @@ -1198,9 +1218,7 @@ class YieldStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["yield"], Field(json_schema_extra={"readOnly": True})] = ( - "yield" - ) + kind_: Annotated[Literal["yield"], Field(json_schema_extra={"readOnly": True})] = "yield" """ The kind of step """ @@ -1214,8 +1232,7 @@ class YieldStep(BaseModel): VALIDATION: Should resolve to a defined subworkflow. """ arguments: ( - dict[str, list[str] | dict[str, str] | list[dict[str, str]] | str] - | Literal["_"] + dict[str, list[str] | dict[str, str] | list[dict[str, str]] | str] | Literal["_"] ) = "_" """ The input parameters for the subworkflow (defaults to last step output) diff --git a/agents-api/agents_api/autogen/Tools.py b/agents-api/agents_api/autogen/Tools.py index d872674af..229a866bb 100644 --- a/agents-api/agents_api/autogen/Tools.py +++ b/agents-api/agents_api/autogen/Tools.py @@ -561,9 +561,7 @@ class BrowserbaseGetSessionConnectUrlArguments(BrowserbaseGetSessionArguments): pass -class BrowserbaseGetSessionConnectUrlArgumentsUpdate( - BrowserbaseGetSessionArgumentsUpdate -): +class BrowserbaseGetSessionConnectUrlArgumentsUpdate(BrowserbaseGetSessionArgumentsUpdate): pass @@ -571,9 +569,7 @@ class BrowserbaseGetSessionLiveUrlsArguments(BrowserbaseGetSessionArguments): pass -class BrowserbaseGetSessionLiveUrlsArgumentsUpdate( - BrowserbaseGetSessionArgumentsUpdate -): +class BrowserbaseGetSessionLiveUrlsArgumentsUpdate(BrowserbaseGetSessionArgumentsUpdate): pass @@ -1806,9 +1802,9 @@ class SystemDefUpdate(BaseModel): model_config = ConfigDict( 
populate_by_name=True, ) - resource: ( - Literal["agent", "user", "task", "execution", "doc", "session", "job"] | None - ) = None + resource: Literal["agent", "user", "task", "execution", "doc", "session", "job"] | None = ( + None + ) """ Resource is the name of the resource to use """ @@ -2366,9 +2362,7 @@ class BrowserbaseCompleteSessionIntegrationDef(BaseBrowserbaseIntegrationDef): arguments: BrowserbaseCompleteSessionArguments | None = None -class BrowserbaseCompleteSessionIntegrationDefUpdate( - BaseBrowserbaseIntegrationDefUpdate -): +class BrowserbaseCompleteSessionIntegrationDefUpdate(BaseBrowserbaseIntegrationDefUpdate): """ browserbase complete session integration definition """ @@ -2494,9 +2488,7 @@ class BrowserbaseGetSessionConnectUrlIntegrationDef(BaseBrowserbaseIntegrationDe arguments: BrowserbaseGetSessionConnectUrlArguments | None = None -class BrowserbaseGetSessionConnectUrlIntegrationDefUpdate( - BaseBrowserbaseIntegrationDefUpdate -): +class BrowserbaseGetSessionConnectUrlIntegrationDefUpdate(BaseBrowserbaseIntegrationDefUpdate): """ browserbase get session connect url integration definition """ @@ -2544,9 +2536,7 @@ class BrowserbaseGetSessionLiveUrlsIntegrationDef(BaseBrowserbaseIntegrationDef) arguments: BrowserbaseGetSessionLiveUrlsArguments | None = None -class BrowserbaseGetSessionLiveUrlsIntegrationDefUpdate( - BaseBrowserbaseIntegrationDefUpdate -): +class BrowserbaseGetSessionLiveUrlsIntegrationDefUpdate(BaseBrowserbaseIntegrationDefUpdate): """ browserbase get session live urls integration definition """ diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index d19684cee..ffcf9caf9 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -1,6 +1,6 @@ # ruff: noqa: F401, F403, F405 import ast -from typing import Annotated, Any, Generic, Literal, Self, Type, TypeVar, get_args +from typing import Annotated, Any, Generic, 
Self, TypeVar, get_args from uuid import UUID import jinja2 @@ -14,7 +14,6 @@ model_validator, ) -from ..common.storage_handler import RemoteObject from ..common.utils.datetime import utcnow from .Agents import * from .Chat import * @@ -126,7 +125,7 @@ def validate_python_expression(expr: str) -> tuple[bool, str]: ast.parse(expr) return True, "" except SyntaxError as e: - return False, f"SyntaxError in '{expr}': {str(e)}" + return False, f"SyntaxError in '{expr}': {e!s}" def validate_jinja_template(template: str) -> tuple[bool, str]: @@ -146,7 +145,7 @@ def validate_jinja_template(template: str) -> tuple[bool, str]: ) return True, "" except jinja2.exceptions.TemplateSyntaxError as e: - return False, f"TemplateSyntaxError in '{template}': {str(e)}" + return False, f"TemplateSyntaxError in '{template}': {e!s}" @field_validator("evaluate") @@ -154,7 +153,8 @@ def validate_evaluate_expressions(cls, v): for key, expr in v.items(): is_valid, error = validate_python_expression(expr) if not is_valid: - raise ValueError(f"Invalid Python expression in key '{key}': {error}") + msg = f"Invalid Python expression in key '{key}': {error}" + raise ValueError(msg) return v @@ -168,9 +168,8 @@ def validate_arguments(cls, v): if isinstance(expr, str): is_valid, error = validate_python_expression(expr) if not is_valid: - raise ValueError( - f"Invalid Python expression in arguments key '{key}': {error}" - ) + msg = f"Invalid Python expression in arguments key '{key}': {error}" + raise ValueError(msg) return v @@ -183,15 +182,15 @@ def validate_prompt(cls, v): if isinstance(v, str): is_valid, error = validate_jinja_template(v) if not is_valid: - raise ValueError(f"Invalid Jinja template in prompt: {error}") + msg = f"Invalid Jinja template in prompt: {error}" + raise ValueError(msg) elif isinstance(v, list): for item in v: if "content" in item: is_valid, error = validate_jinja_template(item["content"]) if not is_valid: - raise ValueError( - f"Invalid Jinja template in prompt content: 
{error}" - ) + msg = f"Invalid Jinja template in prompt content: {error}" + raise ValueError(msg) return v @@ -204,7 +203,8 @@ def validate_set_expressions(cls, v): for key, expr in v.items(): is_valid, error = validate_python_expression(expr) if not is_valid: - raise ValueError(f"Invalid Python expression in set key '{key}': {error}") + msg = f"Invalid Python expression in set key '{key}': {error}" + raise ValueError(msg) return v @@ -215,7 +215,8 @@ def validate_set_expressions(cls, v): def validate_log_template(cls, v): is_valid, error = validate_jinja_template(v) if not is_valid: - raise ValueError(f"Invalid Jinja template in log: {error}") + msg = f"Invalid Jinja template in log: {error}" + raise ValueError(msg) return v @@ -227,9 +228,8 @@ def validate_return_expressions(cls, v): for key, expr in v.items(): is_valid, error = validate_python_expression(expr) if not is_valid: - raise ValueError( - f"Invalid Python expression in return key '{key}': {error}" - ) + msg = f"Invalid Python expression in return key '{key}': {error}" + raise ValueError(msg) return v @@ -242,9 +242,8 @@ def validate_yield_arguments(cls, v): for key, expr in v.items(): is_valid, error = validate_python_expression(expr) if not is_valid: - raise ValueError( - f"Invalid Python expression in yield arguments key '{key}': {error}" - ) + msg = f"Invalid Python expression in yield arguments key '{key}': {error}" + raise ValueError(msg) return v @@ -255,7 +254,8 @@ def validate_yield_arguments(cls, v): def validate_if_expression(cls, v): is_valid, error = validate_python_expression(v) if not is_valid: - raise ValueError(f"Invalid Python expression in if condition: {error}") + msg = f"Invalid Python expression in if condition: {error}" + raise ValueError(msg) return v @@ -266,7 +266,8 @@ def validate_if_expression(cls, v): def validate_over_expression(cls, v): is_valid, error = validate_python_expression(v) if not is_valid: - raise ValueError(f"Invalid Python expression in over: {error}") + msg = 
f"Invalid Python expression in over: {error}" + raise ValueError(msg) return v @@ -275,7 +276,8 @@ def validate_reduce_expression(cls, v): if v is not None: is_valid, error = validate_python_expression(v) if not is_valid: - raise ValueError(f"Invalid Python expression in reduce: {error}") + msg = f"Invalid Python expression in reduce: {error}" + raise ValueError(msg) return v @@ -288,20 +290,16 @@ def validate_reduce_expression(cls, v): _CreateTaskRequest = CreateTaskRequest -CreateTaskRequest.model_config = ConfigDict( - **{ - **_CreateTaskRequest.model_config, - "extra": "allow", - } -) +CreateTaskRequest.model_config = ConfigDict(**{ + **_CreateTaskRequest.model_config, + "extra": "allow", +}) @model_validator(mode="after") def validate_subworkflows(self): subworkflows = { - k: v - for k, v in self.model_dump().items() - if k not in _CreateTaskRequest.model_fields + k: v for k, v in self.model_dump().items() if k not in _CreateTaskRequest.model_fields } for workflow_name, workflow_definition in subworkflows.items(): @@ -309,7 +307,8 @@ def validate_subworkflows(self): WorkflowType.model_validate(workflow_definition) setattr(self, workflow_name, WorkflowType(workflow_definition)) except Exception as e: - raise ValueError(f"Invalid subworkflow '{workflow_name}': {str(e)}") + msg = f"Invalid subworkflow '{workflow_name}': {e!s}" + raise ValueError(msg) return self @@ -358,7 +357,7 @@ def validate_subworkflows(self): class SystemDef(SystemDef): - arguments: dict[str, Any] | None | RemoteObject = None + arguments: dict[str, Any] | None = None class CreateTransitionRequest(Transition): @@ -373,13 +372,11 @@ class CreateTransitionRequest(Transition): class CreateEntryRequest(BaseEntry): - timestamp: Annotated[ - float, Field(ge=0.0, default_factory=lambda: utcnow().timestamp()) - ] + timestamp: Annotated[float, Field(ge=0.0, default_factory=lambda: utcnow().timestamp())] @classmethod def from_model_input( - cls: Type[Self], + cls: type[Self], model: str, *, role: 
ChatMLRole, @@ -400,6 +397,7 @@ def from_model_input( source=source, tokenizer=tokenizer["type"], token_count=token_count, + model=model, **kwargs, ) @@ -467,12 +465,10 @@ class PartialTaskSpecDef(TaskSpecDef): class Task(_Task): - model_config = ConfigDict( - **{ - **_Task.model_config, - "extra": "allow", - } - ) + model_config = ConfigDict(**{ + **_Task.model_config, + "extra": "allow", + }) # Patch some models to allow extra fields @@ -506,21 +502,17 @@ class Task(_Task): class PatchTaskRequest(_PatchTaskRequest): - model_config = ConfigDict( - **{ - **_PatchTaskRequest.model_config, - "extra": "allow", - } - ) + model_config = ConfigDict(**{ + **_PatchTaskRequest.model_config, + "extra": "allow", + }) _UpdateTaskRequest = UpdateTaskRequest class UpdateTaskRequest(_UpdateTaskRequest): - model_config = ConfigDict( - **{ - **_UpdateTaskRequest.model_config, - "extra": "allow", - } - ) + model_config = ConfigDict(**{ + **_UpdateTaskRequest.model_config, + "extra": "allow", + }) diff --git a/agents-api/agents_api/clients/__init__.py b/agents-api/agents_api/clients/__init__.py index 43a17ab08..1d2ac2cdb 100644 --- a/agents-api/agents_api/clients/__init__.py +++ b/agents-api/agents_api/clients/__init__.py @@ -1,9 +1,6 @@ """ The `clients` module contains client classes and functions for interacting with various external services and APIs, abstracting the complexity of HTTP requests and API interactions to provide a simplified interface for the rest of the application. -- `cozo.py`: Handles communication with the Cozo service, facilitating operations such as retrieving product information. -- `embed.py`: Manages requests to an Embedding Service for text embedding functionalities. -- `openai.py`: Facilitates interaction with OpenAI's API for natural language processing tasks. +- `pg.py`: Handles communication with the PostgreSQL service, facilitating operations such as retrieving product information. 
- `temporal.py`: Provides functionality for connecting to Temporal workflows, enabling asynchronous task execution and management. -- `worker/__init__.py` and related files: Describe the role of the worker service client in sending tasks to be processed by an external worker service, focusing on memory management and other computational tasks. """ diff --git a/agents-api/agents_api/clients/async_s3.py b/agents-api/agents_api/clients/async_s3.py index 0cd5235ee..d58f96140 100644 --- a/agents-api/agents_api/clients/async_s3.py +++ b/agents-api/agents_api/clients/async_s3.py @@ -3,114 +3,84 @@ with workflow.unsafe.imports_passed_through(): import botocore - from aiobotocore.session import get_session from async_lru import alru_cache from xxhash import xxh3_64_hexdigest as xxhash_key from ..env import ( blob_store_bucket, blob_store_cutoff_kb, - s3_access_key, - s3_endpoint, - s3_secret_key, ) -async def list_buckets() -> list[str]: - session = get_session() +@alru_cache(maxsize=1) +async def setup(): + from ..app import app - async with session.create_client( - "s3", - endpoint_url=s3_endpoint, - aws_access_key_id=s3_access_key, - aws_secret_access_key=s3_secret_key, - ) as client: - data = await client.list_buckets() - buckets = [bucket["Name"] for bucket in data["Buckets"]] - return buckets + if not app.state.s3_client: + msg = "S3 client not initialized" + raise RuntimeError(msg) + client = app.state.s3_client -@alru_cache(maxsize=1) -async def setup(): - session = get_session() - - async with session.create_client( - "s3", - aws_access_key_id=s3_access_key, - aws_secret_access_key=s3_secret_key, - endpoint_url=s3_endpoint, - ) as client: - if blob_store_bucket not in await list_buckets(): + try: + await client.head_bucket(Bucket=blob_store_bucket) + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "404": await client.create_bucket(Bucket=blob_store_bucket) + else: + raise e + + return client + + +@alru_cache(maxsize=1024) +async def 
list_buckets() -> list[str]: + client = await setup() + + data = await client.list_buckets() + return [bucket["Name"] for bucket in data["Buckets"]] @alru_cache(maxsize=10_000) async def exists(key: str) -> bool: - session = get_session() - - async with session.create_client( - "s3", - aws_access_key_id=s3_access_key, - aws_secret_access_key=s3_secret_key, - endpoint_url=s3_endpoint, - ) as client: - try: - await client.head_object(Bucket=blob_store_bucket, Key=key) - return True - except botocore.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "404": - return False - else: - raise e + client = await setup() + + try: + await client.head_object(Bucket=blob_store_bucket, Key=key) + return True + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "404": + return False + raise e @beartype async def add_object(key: str, body: bytes, replace: bool = False) -> None: - session = get_session() + client = await setup() - async with session.create_client( - "s3", - aws_access_key_id=s3_access_key, - aws_secret_access_key=s3_secret_key, - endpoint_url=s3_endpoint, - ) as client: - if replace: - await client.put_object(Bucket=blob_store_bucket, Key=key, Body=body) - return + if replace: + await client.put_object(Bucket=blob_store_bucket, Key=key, Body=body) + return - if await exists(key): - return + if await exists(key): + return - await client.put_object(Bucket=blob_store_bucket, Key=key, Body=body) + await client.put_object(Bucket=blob_store_bucket, Key=key, Body=body) @alru_cache(maxsize=256 * 1024 // max(1, blob_store_cutoff_kb)) # 256mb in cache @beartype async def get_object(key: str) -> bytes: - session = get_session() + client = await setup() - async with session.create_client( - "s3", - aws_access_key_id=s3_access_key, - aws_secret_access_key=s3_secret_key, - endpoint_url=s3_endpoint, - ) as client: - response = await client.get_object(Bucket=blob_store_bucket, Key=key) - body = await response["Body"].read() - return body 
+ response = await client.get_object(Bucket=blob_store_bucket, Key=key) + return await response["Body"].read() @beartype async def delete_object(key: str) -> None: - session = get_session() - - async with session.create_client( - "s3", - aws_access_key_id=s3_access_key, - aws_secret_access_key=s3_secret_key, - endpoint_url=s3_endpoint, - ) as client: - await client.delete_object(Bucket=blob_store_bucket, Key=key) + client = await setup() + await client.delete_object(Bucket=blob_store_bucket, Key=key) @beartype diff --git a/agents-api/agents_api/clients/cozo.py b/agents-api/agents_api/clients/cozo.py deleted file mode 100644 index 285bae8b2..000000000 --- a/agents-api/agents_api/clients/cozo.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Dict - -from pycozo.client import Client -from pycozo_async import Client as AsyncClient - -from ..env import cozo_auth, cozo_host -from ..web import app - -options: Dict[str, str] = {"host": cozo_host} -if cozo_auth: - options.update({"auth": cozo_auth}) - - -def get_cozo_client() -> Client: - client = getattr(app.state, "cozo_client", Client("http", options=options)) - if not hasattr(app.state, "cozo_client"): - app.state.cozo_client = client - - return client - - -def get_async_cozo_client() -> AsyncClient: - client = getattr( - app.state, "async_cozo_client", AsyncClient("http", options=options) - ) - if not hasattr(app.state, "async_cozo_client"): - app.state.async_cozo_client = client - - return client diff --git a/agents-api/agents_api/clients/integrations.py b/agents-api/agents_api/clients/integrations.py index cb66c293a..aa33bd25f 100644 --- a/agents-api/agents_api/clients/integrations.py +++ b/agents-api/agents_api/clients/integrations.py @@ -1,11 +1,11 @@ -from typing import Any, List +from typing import Any from beartype import beartype from httpx import AsyncClient from ..env import integration_service_url -__all__: List[str] = ["run_integration_service"] +__all__: list[str] = ["run_integration_service"] 
@beartype diff --git a/agents-api/agents_api/clients/litellm.py b/agents-api/agents_api/clients/litellm.py index bbf743919..7a3dc8c77 100644 --- a/agents-api/agents_api/clients/litellm.py +++ b/agents-api/agents_api/clients/litellm.py @@ -1,17 +1,10 @@ from functools import wraps -from typing import List, Literal +from typing import Literal -import litellm from beartype import beartype -from litellm import ( - acompletion as _acompletion, -) -from litellm import ( - aembedding as _aembedding, -) -from litellm import ( - get_supported_openai_params, -) +from litellm import acompletion as _acompletion +from litellm import aembedding as _aembedding +from litellm import get_supported_openai_params from litellm.utils import CustomStreamWrapper, ModelResponse from ..env import ( @@ -21,10 +14,7 @@ litellm_url, ) -__all__: List[str] = ["acompletion"] - -# TODO: Should check if this is really needed -litellm.drop_params = True +__all__: list[str] = ["acompletion"] def patch_litellm_response( @@ -39,9 +29,11 @@ def patch_litellm_response( if choice.finish_reason == "eos": choice.finish_reason = "stop" - elif isinstance(model_response, CustomStreamWrapper): - if model_response.received_finish_reason == "eos": - model_response.received_finish_reason = "stop" + elif ( + isinstance(model_response, CustomStreamWrapper) + and model_response.received_finish_reason == "eos" + ): + model_response.received_finish_reason = "stop" return model_response @@ -49,19 +41,18 @@ def patch_litellm_response( @wraps(_acompletion) @beartype async def acompletion( - *, model: str, messages: list[dict], custom_api_key: None | str = None, **kwargs + *, model: str, messages: list[dict], custom_api_key: str | None = None, **kwargs ) -> ModelResponse | CustomStreamWrapper: if not custom_api_key: - model = f"openai/{model}" # FIXME: This is for litellm + model = f"openai/{model}" # This is needed for litellm supported_params = get_supported_openai_params(model) settings = {k: v for k, v in 
kwargs.items() if k in supported_params} - # FIXME: This is a hotfix for Mistral API, which expects a different message format + # NOTE: This is a fix for Mistral API, which expects a different message format if model[7:].startswith("mistral"): messages = [ - {"role": message["role"], "content": message["content"]} - for message in messages + {"role": message["role"], "content": message["content"]} for message in messages ] model_response = await _acompletion( @@ -72,9 +63,7 @@ async def acompletion( api_key=custom_api_key or litellm_master_key, ) - model_response = patch_litellm_response(model_response) - - return model_response + return patch_litellm_response(model_response) @wraps(_aembedding) @@ -86,25 +75,27 @@ async def aembedding( embed_instruction: str | None = None, dimensions: int = embedding_dimensions, join_inputs: bool = False, - custom_api_key: None | str = None, + custom_api_key: str | None = None, **settings, ) -> list[list[float]]: # Temporarily commented out (causes errors when using voyage/voyage-3) # if not custom_api_key: - # model = f"openai/{model}" # FIXME: This is for litellm - - if isinstance(inputs, str): - input = [inputs] - else: - input = ["\n\n".join(inputs)] if join_inputs else inputs + # model = f"openai/{model}" # FIXME: Is this still needed for litellm? 
+ + input = ( + [inputs] + if isinstance(inputs, str) + else ["\n\n".join(inputs)] + if join_inputs + else inputs + ) if embed_instruction: - input = [embed_instruction] + input + input = [embed_instruction, *input] response = await _aembedding( model=model, input=input, - # dimensions=dimensions, # FIXME: litellm doesn't support dimensions correctly api_base=None if custom_api_key else litellm_url, api_key=custom_api_key or litellm_master_key, drop_params=True, @@ -113,7 +104,8 @@ async def aembedding( embedding_list: list[dict[Literal["embedding"], list[float]]] = response.data - # FIXME: Truncation should be handled by litellm - result = [embedding["embedding"][:dimensions] for embedding in embedding_list] - - return result + return [ + item["embedding"][:dimensions] + for item in embedding_list + if len(item["embedding"]) > dimensions + ] diff --git a/agents-api/agents_api/clients/pg.py b/agents-api/agents_api/clients/pg.py new file mode 100644 index 000000000..5fcce419c --- /dev/null +++ b/agents-api/agents_api/clients/pg.py @@ -0,0 +1,18 @@ +import json + +import asyncpg + +from ..env import pg_dsn + + +async def _init_conn(conn): + await conn.set_type_codec( + "jsonb", + encoder=json.dumps, + decoder=json.loads, + schema="pg_catalog", + ) + + +async def create_db_pool(dsn: str | None = None): + return await asyncpg.create_pool(dsn if dsn is not None else pg_dsn, init=_init_conn) diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index da2d7f6fa..325427c96 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -12,9 +12,9 @@ from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig from ..autogen.openapi_model import TransitionTarget +from ..common.interceptors import offload_if_large from ..common.protocol.tasks import ExecutionInput from ..common.retry_policies import DEFAULT_RETRY_POLICY -from ..common.storage_handler import 
store_in_blob_store_if_large from ..env import ( temporal_client_cert, temporal_metrics_bind_host, @@ -91,14 +91,15 @@ async def run_task_execution_workflow( from ..workflows.task_execution import TaskExecutionWorkflow start: TransitionTarget = start or TransitionTarget(workflow="main", step=0) - previous_inputs: list[dict] = previous_inputs or [] client = client or (await get_client()) execution_id = execution_input.execution.id execution_id_key = SearchAttributeKey.for_keyword("CustomStringField") - execution_input.arguments = await store_in_blob_store_if_large( - execution_input.arguments - ) + + old_args = execution_input.arguments + execution_input.arguments = await offload_if_large(old_args) + + previous_inputs: list[dict] = previous_inputs or [execution_input.arguments] return await client.start_workflow( TaskExecutionWorkflow.run, @@ -107,11 +108,9 @@ async def run_task_execution_workflow( id=str(job_id), run_timeout=timedelta(days=31), retry_policy=DEFAULT_RETRY_POLICY, - search_attributes=TypedSearchAttributes( - [ - SearchAttributePair(execution_id_key, str(execution_id)), - ] - ), + search_attributes=TypedSearchAttributes([ + SearchAttributePair(execution_id_key, str(execution_id)), + ]), ) @@ -122,8 +121,6 @@ async def get_workflow_handle( ): client = client or (await get_client()) - handle = client.get_workflow_handle( + return client.get_workflow_handle( handle_id, ) - - return handle diff --git a/agents-api/agents_api/clients/worker/__init__.py b/agents-api/agents_api/clients/worker/__init__.py deleted file mode 100644 index 53f598ba2..000000000 --- a/agents-api/agents_api/clients/worker/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -This module provides functionality for interacting with an external worker service. It includes utilities for creating and sending tasks, such as memory management tasks, to be processed by the service. The module leverages asynchronous HTTP requests via the `httpx` library to communicate with the worker service. 
Types for structuring task data are defined in `types.py`. -""" diff --git a/agents-api/agents_api/clients/worker/types.py b/agents-api/agents_api/clients/worker/types.py deleted file mode 100644 index 3bf063083..000000000 --- a/agents-api/agents_api/clients/worker/types.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Literal -from uuid import UUID - -from pydantic import BaseModel - -from agents_api.autogen.openapi_model import ( - InputChatMLMessage, -) - - -class MemoryManagementTaskArgs(BaseModel): - session_id: UUID - model: str - dialog: list[InputChatMLMessage] - previous_memories: list[str] = [] - - -class MemoryManagementTask(BaseModel): - name: Literal["memory_management.v1"] - args: MemoryManagementTaskArgs - - -class MemoryDensityTaskArgs(BaseModel): - memory: str - - -class MemoryDensityTask(BaseModel): - name: Literal["memory_density.v1"] - args: MemoryDensityTaskArgs - - -class MemoryRatingTaskArgs(BaseModel): - memory: str - - -class MemoryRatingTask(BaseModel): - name: Literal["memory_rating.v1"] - args: MemoryRatingTaskArgs - - -CombinedTask = MemoryManagementTask | MemoryDensityTask | MemoryRatingTask diff --git a/agents-api/agents_api/clients/worker/worker.py b/agents-api/agents_api/clients/worker/worker.py deleted file mode 100644 index 8befa3080..000000000 --- a/agents-api/agents_api/clients/worker/worker.py +++ /dev/null @@ -1,21 +0,0 @@ -import httpx - -from ...env import temporal_worker_url -from .types import ( - MemoryManagementTask, - MemoryManagementTaskArgs, -) - - -async def add_summarization_task(data: MemoryManagementTaskArgs): - async with httpx.AsyncClient(timeout=30) as client: - data = MemoryManagementTask( - name="memory_management.v1", - args=data, - ) - - await client.post( - f"{temporal_worker_url}/task", - headers={"Content-Type": "json"}, - data=data.model_dump_json(), - ) diff --git a/agents-api/agents_api/common/exceptions/agents.py b/agents-api/agents_api/common/exceptions/agents.py index e58f25104..042b34ee0 
100644 --- a/agents-api/agents_api/common/exceptions/agents.py +++ b/agents-api/agents_api/common/exceptions/agents.py @@ -8,8 +8,6 @@ class BaseAgentException(BaseCommonException): """Base exception class for all agent-related exceptions.""" - pass - class AgentNotFoundError(BaseAgentException): """ @@ -22,7 +20,7 @@ class AgentNotFoundError(BaseAgentException): def __init__(self, developer_id: UUID | str, agent_id: UUID | str): # Initialize the exception with a message indicating the missing agent and developer ID. super().__init__( - f"Agent {str(agent_id)} not found for developer {str(developer_id)}", + f"Agent {agent_id!s} not found for developer {developer_id!s}", http_code=404, ) @@ -37,9 +35,7 @@ class AgentToolNotFoundError(BaseAgentException): def __init__(self, agent_id: UUID | str, tool_id: UUID | str): # Initialize the exception with a message indicating the missing tool and agent ID. - super().__init__( - f"Tool {str(tool_id)} not found for agent {str(agent_id)}", http_code=404 - ) + super().__init__(f"Tool {tool_id!s} not found for agent {agent_id!s}", http_code=404) class AgentDocNotFoundError(BaseAgentException): @@ -52,9 +48,7 @@ class AgentDocNotFoundError(BaseAgentException): def __init__(self, agent_id: UUID | str, doc_id: UUID | str): # Initialize the exception with a message indicating the missing document and agent ID. 
- super().__init__( - f"Doc {str(doc_id)} not found for agent {str(agent_id)}", http_code=404 - ) + super().__init__(f"Doc {doc_id!s} not found for agent {agent_id!s}", http_code=404) class AgentModelNotValid(BaseAgentException): diff --git a/agents-api/agents_api/common/exceptions/sessions.py b/agents-api/agents_api/common/exceptions/sessions.py index 6e9941d43..6df811c77 100644 --- a/agents-api/agents_api/common/exceptions/sessions.py +++ b/agents-api/agents_api/common/exceptions/sessions.py @@ -16,8 +16,6 @@ class BaseSessionException(BaseCommonException): This class serves as a base for all session-related exceptions, allowing for a structured exception handling approach specific to session operations. """ - pass - class SessionNotFoundError(BaseSessionException): """ @@ -32,6 +30,6 @@ class SessionNotFoundError(BaseSessionException): def __init__(self, developer_id: UUID | str, session_id: UUID | str): super().__init__( - f"Session {str(session_id)} not found for developer {str(developer_id)}", + f"Session {session_id!s} not found for developer {developer_id!s}", http_code=404, ) diff --git a/agents-api/agents_api/common/exceptions/tools.py b/agents-api/agents_api/common/exceptions/tools.py index 2ea126505..118a4355c 100644 --- a/agents-api/agents_api/common/exceptions/tools.py +++ b/agents-api/agents_api/common/exceptions/tools.py @@ -9,8 +9,6 @@ class BaseToolsException(BaseCommonException): """Base exception for tools-related errors.""" - pass - class IntegrationExecutionException(BaseToolsException): """Exception raised when an error occurs during an integration execution.""" diff --git a/agents-api/agents_api/common/exceptions/users.py b/agents-api/agents_api/common/exceptions/users.py index cf4e995ad..2be87aea2 100644 --- a/agents-api/agents_api/common/exceptions/users.py +++ b/agents-api/agents_api/common/exceptions/users.py @@ -12,8 +12,6 @@ class BaseUserException(BaseCommonException): This class serves as a parent for all user-related exceptions to 
facilitate catching errors specific to user operations. """ - pass - class UserNotFoundError(BaseUserException): """ @@ -26,7 +24,7 @@ class UserNotFoundError(BaseUserException): def __init__(self, developer_id: UUID | str, user_id: UUID | str): # Construct an error message indicating the user and developer involved in the error. super().__init__( - f"User {str(user_id)} not found for developer {str(developer_id)}", + f"User {user_id!s} not found for developer {developer_id!s}", http_code=404, ) @@ -41,6 +39,4 @@ class UserDocNotFoundError(BaseUserException): def __init__(self, user_id: UUID | str, doc_id: UUID | str): # Construct an error message indicating the document and user involved in the error. - super().__init__( - f"Doc {str(doc_id)} not found for user {str(user_id)}", http_code=404 - ) + super().__init__(f"Doc {doc_id!s} not found for user {user_id!s}", http_code=404) diff --git a/agents-api/agents_api/common/interceptors.py b/agents-api/agents_api/common/interceptors.py index 40600a818..3a1ac9481 100644 --- a/agents-api/agents_api/common/interceptors.py +++ b/agents-api/agents_api/common/interceptors.py @@ -4,8 +4,13 @@ certain types of errors that are known to be non-retryable. 
""" -from typing import Optional, Type +import asyncio +import sys +from collections.abc import Awaitable, Callable, Sequence +from functools import wraps +from typing import Any +from temporalio import workflow from temporalio.activity import _CompleteAsyncError as CompleteAsyncError from temporalio.exceptions import ApplicationError, FailureError, TemporalError from temporalio.service import RPCError @@ -23,7 +28,95 @@ ReadOnlyContextError, ) -from .exceptions.tasks import is_retryable_error +with workflow.unsafe.imports_passed_through(): + from ..env import blob_store_cutoff_kb, use_blob_store_for_temporal + from .exceptions.tasks import is_retryable_error + from .protocol.remote import RemoteObject + +# Common exceptions that should be re-raised without modification +PASSTHROUGH_EXCEPTIONS = ( + ContinueAsNewError, + ReadOnlyContextError, + NondeterminismError, + RPCError, + CompleteAsyncError, + TemporalError, + FailureError, + ApplicationError, +) + + +def is_too_large(result: Any) -> bool: + return sys.getsizeof(result) > blob_store_cutoff_kb * 1024 + + +async def load_if_remote[T](arg: T | RemoteObject[T]) -> T: + if use_blob_store_for_temporal and isinstance(arg, RemoteObject): + return await arg.load() + + return arg + + +async def offload_if_large[T](result: T) -> T: + if use_blob_store_for_temporal and is_too_large(result): + return await RemoteObject.from_value(result) + + return result + + +def offload_to_blob_store[S, T]( + func: Callable[[S, ExecuteActivityInput | ExecuteWorkflowInput], Awaitable[T]], +) -> Callable[[S, ExecuteActivityInput | ExecuteWorkflowInput], Awaitable[T | RemoteObject[T]]]: + @wraps(func) + async def wrapper( + self, + input: ExecuteActivityInput | ExecuteWorkflowInput, + ) -> T | RemoteObject[T]: + # Load all remote arguments from the blob store + args: Sequence[Any] = input.args + + if use_blob_store_for_temporal: + input.args = await asyncio.gather(*[load_if_remote(arg) for arg in args]) + + # Execute the function + result 
= await func(self, input) + + # Save the result to the blob store if necessary + return await offload_if_large(result) + + return wrapper + + +async def handle_execution_with_errors[I, T]( + execution_fn: Callable[[I], Awaitable[T]], + input: I, +) -> T: + """ + Common error handling logic for both activities and workflows. + + Args: + execution_fn: Async function to execute with error handling + input: Input to the execution function + + Returns: + The result of the execution function + + Raises: + ApplicationError: For non-retryable errors + Any other exception: For retryable errors + """ + try: + return await execution_fn(input) + except PASSTHROUGH_EXCEPTIONS: + raise + except BaseException as e: + if not is_retryable_error(e): + raise ApplicationError( + str(e), + type=type(e).__name__, + non_retryable=True, + ) + raise class CustomActivityInterceptor(ActivityInboundInterceptor): @@ -35,105 +128,52 @@ class CustomActivityInterceptor(ActivityInboundInterceptor): as non-retryable errors. """ - async def execute_activity(self, input: ExecuteActivityInput): + @offload_to_blob_store + async def execute_activity(self, input: ExecuteActivityInput) -> Any: """ - 🎭 The Activity Whisperer: Handles activity execution with style and grace - - This is like a safety net for your activities - catching errors and deciding - their fate with the wisdom of a fortune cookie. + Handles activity execution by intercepting errors and determining their retry behavior. 
""" - try: - return await super().execute_activity(input) - except ( - ContinueAsNewError, # When you need a fresh start - ReadOnlyContextError, # When someone tries to write in a museum - NondeterminismError, # When chaos theory kicks in - RPCError, # When computers can't talk to each other - CompleteAsyncError, # When async goes wrong - TemporalError, # When time itself rebels - FailureError, # When failure is not an option, but happens anyway - ApplicationError, # When the app says "nope" - ): - raise - except BaseException as e: - if not is_retryable_error(e): - # If it's not retryable, we wrap it in a nice bow (ApplicationError) - # and mark it as non-retryable to prevent further attempts - raise ApplicationError( - str(e), - type=type(e).__name__, - non_retryable=True, - ) - # For retryable errors, we'll let Temporal retry with backoff - # Default retry policy ensures at least 2 retries - raise + return await handle_execution_with_errors( + super().execute_activity, + input, + ) class CustomWorkflowInterceptor(WorkflowInboundInterceptor): """ - 🎪 The Workflow Circus Ringmaster + Custom interceptor for Temporal workflows. - This interceptor is like a circus ringmaster - keeping all the workflow acts - running smoothly and catching any lions (errors) that escape their cages. + Handles workflow execution errors and determines their retry behavior. """ - async def execute_workflow(self, input: ExecuteWorkflowInput): + @offload_to_blob_store + async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: """ - 🎪 The Main Event: Workflow Execution Extravaganza! - - Watch as we gracefully handle errors like a trapeze artist catching their partner! + Executes workflows and handles error cases appropriately. """ - try: - return await super().execute_workflow(input) - except ( - ContinueAsNewError, # The show must go on! - ReadOnlyContextError, # No touching, please! 
- NondeterminismError, # When butterflies cause hurricanes - RPCError, # Lost in translation - CompleteAsyncError, # Async said "bye" too soon - TemporalError, # Time is relative, errors are absolute - FailureError, # Task failed successfully - ApplicationError, # App.exe has stopped working - ): - raise - except BaseException as e: - if not is_retryable_error(e): - # Pack the error in a nice box with a "do not retry" sticker - raise ApplicationError( - str(e), - type=type(e).__name__, - non_retryable=True, - ) - # Let it retry - everyone deserves a second (or third) chance! - raise + return await handle_execution_with_errors( + super().execute_workflow, + input, + ) class CustomInterceptor(Interceptor): """ - 🎭 The Grand Interceptor: Master of Ceremonies - - This is like the backstage manager of a theater - making sure both the - activity actors and workflow directors have their interceptor costumes on. + Main interceptor class that provides both activity and workflow interceptors. """ def intercept_activity( self, next: ActivityInboundInterceptor ) -> ActivityInboundInterceptor: """ - 🎬 Activity Interceptor Factory: Where the magic begins! - - Creating custom activity interceptors faster than a caffeinated barista - makes espresso shots. + Creates and returns a custom activity interceptor. """ return CustomActivityInterceptor(super().intercept_activity(next)) def workflow_interceptor_class( self, input: WorkflowInterceptorClassInput - ) -> Optional[Type[WorkflowInboundInterceptor]]: + ) -> type[WorkflowInboundInterceptor] | None: """ - 🎪 Workflow Interceptor Class Selector - - Like a matchmaker for workflows and their interceptors - a match made in - exception handling heaven! + Returns the custom workflow interceptor class. 
""" return CustomWorkflowInterceptor diff --git a/agents-api/agents_api/common/nlp.py b/agents-api/agents_api/common/nlp.py index 58b26c50b..00ba3d881 100644 --- a/agents-api/agents_api/common/nlp.py +++ b/agents-api/agents_api/common/nlp.py @@ -142,7 +142,7 @@ def find_proximity_groups( # Initialize Union-Find with path compression and union by rank parent = {kw: kw for kw in keywords} - rank = {kw: 0 for kw in keywords} + rank = dict.fromkeys(keywords, 0) def find(u: str) -> str: if parent[u] != u: @@ -277,9 +277,7 @@ def batch_paragraphs_to_custom_queries( list[list[str]]: A list where each element is a list of queries for a paragraph. """ results = [] - for doc in nlp.pipe( - paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process - ): + for doc in nlp.pipe(paragraphs, disable=["lemmatizer", "textcat"], n_process=n_process): queries = [] for sent in doc.sents: sent_doc = sent.as_doc() diff --git a/agents-api/agents_api/common/protocol/remote.py b/agents-api/agents_api/common/protocol/remote.py index ce2a2a63a..0b6c7bf80 100644 --- a/agents-api/agents_api/common/protocol/remote.py +++ b/agents-api/agents_api/common/protocol/remote.py @@ -1,91 +1,34 @@ from dataclasses import dataclass -from typing import Any +from typing import Generic, Self, TypeVar, cast -from temporalio import activity, workflow +from temporalio import workflow with workflow.unsafe.imports_passed_through(): - from pydantic import BaseModel - + from ...clients import async_s3 from ...env import blob_store_bucket + from ...worker.codec import deserialize, serialize -@dataclass -class RemoteObject: - key: str - bucket: str = blob_store_bucket - - -class BaseRemoteModel(BaseModel): - _remote_cache: dict[str, Any] - - class Config: - arbitrary_types_allowed = True - - def __init__(self, **data: Any): - super().__init__(**data) - self._remote_cache = {} - - async def load_item(self, item: Any | RemoteObject) -> Any: - if not activity.in_activity(): - return item - - from 
..storage_handler import load_from_blob_store_if_remote - - return await load_from_blob_store_if_remote(item) +T = TypeVar("T") - async def save_item(self, item: Any) -> Any: - if not activity.in_activity(): - return item - from ..storage_handler import store_in_blob_store_if_large - - return await store_in_blob_store_if_large(item) - - async def get_attribute(self, name: str) -> Any: - if name.startswith("_"): - return super().__getattribute__(name) - - try: - value = super().__getattribute__(name) - except AttributeError: - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'" - ) - - if isinstance(value, RemoteObject): - cache = super().__getattribute__("_remote_cache") - if name in cache: - return cache[name] - - loaded_data = await self.load_item(value) - cache[name] = loaded_data - return loaded_data - - return value - - async def set_attribute(self, name: str, value: Any) -> None: - if name.startswith("_"): - super().__setattr__(name, value) - return +@dataclass +class RemoteObject(Generic[T]): + _type: type[T] + key: str + bucket: str - stored_value = await self.save_item(value) - super().__setattr__(name, stored_value) + @classmethod + async def from_value(cls, x: T) -> Self: + await async_s3.setup() - if isinstance(stored_value, RemoteObject): - cache = self.__dict__.get("_remote_cache", {}) - cache.pop(name, None) + serialized = serialize(x) - async def load_all(self) -> None: - for name in self.model_fields_set: - await self.get_attribute(name) + key = await async_s3.add_object_with_hash(serialized) + return RemoteObject[T](key=key, bucket=blob_store_bucket, _type=type(x)) - async def unload_attribute(self, name: str) -> None: - if name in self._remote_cache: - data = self._remote_cache.pop(name) - remote_obj = await self.save_item(data) - super().__setattr__(name, remote_obj) + async def load(self) -> T: + await async_s3.setup() - async def unload_all(self) -> "BaseRemoteModel": - for name in 
list(self._remote_cache.keys()): - await self.unload_attribute(name) - return self + fetched = await async_s3.get_object(self.key) + return cast(self._type, deserialize(fetched)) diff --git a/agents-api/agents_api/common/protocol/sessions.py b/agents-api/agents_api/common/protocol/sessions.py index 121afe702..3b0e9098c 100644 --- a/agents-api/agents_api/common/protocol/sessions.py +++ b/agents-api/agents_api/common/protocol/sessions.py @@ -28,8 +28,6 @@ class SessionSettings(AgentDefaultSettings): Currently, it does not extend the base class with additional properties. """ - pass - class SessionData(BaseModel): """ @@ -75,17 +73,13 @@ def merge_settings(self, chat_input: ChatInput) -> ChatSettings: active_agent = self.get_active_agent() default_settings: AgentDefaultSettings | None = active_agent.default_settings - default_settings: dict = ( - default_settings and default_settings.model_dump() or {} - ) + default_settings: dict = (default_settings and default_settings.model_dump()) or {} - self.settings = settings = ChatSettings( - **{ - "model": active_agent.model, - **default_settings, - **request_settings, - } - ) + self.settings = settings = ChatSettings(**{ + "model": active_agent.model, + **default_settings, + **request_settings, + }) return settings @@ -103,20 +97,21 @@ def get_active_tools(self) -> list[Tool]: return active_toolset.tools - def get_chat_environment(self) -> dict[str, dict | list[dict]]: + def get_chat_environment(self) -> dict[str, dict | list[dict] | None]: """ Get the chat environment from the session data. 
""" current_agent = self.get_active_agent() tools = self.get_active_tools() settings: ChatSettings | None = self.settings - settings: dict = settings and settings.model_dump() or {} + settings: dict = (settings and settings.model_dump()) or {} return { "session": self.session.model_dump(), "agents": [agent.model_dump() for agent in self.agents], "current_agent": current_agent.model_dump(), "agent": current_agent.model_dump(), + "user": self.users[0].model_dump() if len(self.users) > 0 else None, "users": [user.model_dump() for user in self.users], "settings": settings, "tools": [tool.model_dump() for tool in tools], @@ -136,7 +131,8 @@ def make_session( match (len(agents), len(users)): case (0, _): - raise ValueError("At least one agent must be provided.") + msg = "At least one agent must be provided." + raise ValueError(msg) case (1, 0): cls = SingleAgentNoUserSession participants = {"agent": agents[0]} diff --git a/agents-api/agents_api/common/protocol/state_machine.py b/agents-api/agents_api/common/protocol/state_machine.py new file mode 100644 index 000000000..ac3636456 --- /dev/null +++ b/agents-api/agents_api/common/protocol/state_machine.py @@ -0,0 +1,206 @@ +from collections.abc import Generator +from contextlib import contextmanager +from enum import StrEnum +from uuid import UUID + +from pydantic import BaseModel, Field + +from ...autogen.openapi_model import TransitionTarget + + +class TransitionType(StrEnum): + """Enum for transition types in the workflow.""" + + INIT = "init" + INIT_BRANCH = "init_branch" + WAIT = "wait" + RESUME = "resume" + STEP = "step" + FINISH = "finish" + FINISH_BRANCH = "finish_branch" + ERROR = "error" + CANCELLED = "cancelled" + + +class ExecutionStatus(StrEnum): + """Enum for execution statuses.""" + + QUEUED = "queued" + STARTING = "starting" + RUNNING = "running" + AWAITING_INPUT = "awaiting_input" + SUCCEEDED = "succeeded" + FAILED = "failed" + CANCELLED = "cancelled" + + +class StateTransitionError(Exception): + """Raised 
when an invalid state transition is attempted.""" + + +class ExecutionState(BaseModel): + """Model representing the current state of an execution.""" + + status: ExecutionStatus + transition_type: TransitionType | None = None + current_target: TransitionTarget | None = None + next_target: TransitionTarget | None = None + execution_id: UUID + metadata: dict = Field(default_factory=dict) + + +# Valid transitions from each state +_valid_transitions: dict[TransitionType | None, list[TransitionType]] = { + None: [ + TransitionType.INIT, + TransitionType.WAIT, + TransitionType.ERROR, + TransitionType.STEP, + TransitionType.CANCELLED, + TransitionType.INIT_BRANCH, + TransitionType.FINISH, + ], + TransitionType.INIT: [ + TransitionType.WAIT, + TransitionType.ERROR, + TransitionType.STEP, + TransitionType.CANCELLED, + TransitionType.INIT_BRANCH, + TransitionType.FINISH, + ], + TransitionType.INIT_BRANCH: [ + TransitionType.WAIT, + TransitionType.ERROR, + TransitionType.STEP, + TransitionType.CANCELLED, + TransitionType.INIT_BRANCH, + TransitionType.FINISH_BRANCH, + TransitionType.FINISH, + ], + TransitionType.WAIT: [ + TransitionType.RESUME, + TransitionType.STEP, + TransitionType.CANCELLED, + TransitionType.FINISH, + TransitionType.FINISH_BRANCH, + ], + TransitionType.RESUME: [ + TransitionType.WAIT, + TransitionType.ERROR, + TransitionType.CANCELLED, + TransitionType.STEP, + TransitionType.FINISH, + TransitionType.FINISH_BRANCH, + TransitionType.INIT_BRANCH, + ], + TransitionType.STEP: [ + TransitionType.WAIT, + TransitionType.ERROR, + TransitionType.CANCELLED, + TransitionType.STEP, + TransitionType.FINISH, + TransitionType.FINISH_BRANCH, + TransitionType.INIT_BRANCH, + ], + TransitionType.FINISH_BRANCH: [ + TransitionType.WAIT, + TransitionType.ERROR, + TransitionType.CANCELLED, + TransitionType.STEP, + TransitionType.FINISH, + TransitionType.INIT_BRANCH, + ], + # Terminal states + TransitionType.FINISH: [], + TransitionType.ERROR: [], + TransitionType.CANCELLED: [], +} 
+ +# Mapping from transition types to execution statuses +_transition_to_status: dict[TransitionType | None, ExecutionStatus] = { + None: ExecutionStatus.QUEUED, + TransitionType.INIT: ExecutionStatus.STARTING, + TransitionType.INIT_BRANCH: ExecutionStatus.RUNNING, + TransitionType.WAIT: ExecutionStatus.AWAITING_INPUT, + TransitionType.RESUME: ExecutionStatus.RUNNING, + TransitionType.STEP: ExecutionStatus.RUNNING, + TransitionType.FINISH: ExecutionStatus.SUCCEEDED, + TransitionType.FINISH_BRANCH: ExecutionStatus.RUNNING, + TransitionType.ERROR: ExecutionStatus.FAILED, + TransitionType.CANCELLED: ExecutionStatus.CANCELLED, +} + + +class ExecutionStateMachine: + """ + A state machine for managing execution state transitions with validation. + Uses context managers for safe state transitions. + """ + + def __init__(self, execution_id: UUID): + """Initialize the state machine with QUEUED status.""" + self.state = ExecutionState( + status=ExecutionStatus.QUEUED, + execution_id=execution_id, + ) + + def _validate_transition(self, new_type: TransitionType) -> bool: + """Validate if a transition is allowed from the current state.""" + return new_type in _valid_transitions[self.state.transition_type] + + @contextmanager + def transition_to( + self, + transition_type: TransitionType, + current_target: TransitionTarget | None = None, + next_target: TransitionTarget | None = None, + metadata: dict | None = None, + ) -> Generator[ExecutionState, None, None]: + """ + Context manager for safely transitioning to a new state. 
+ + Args: + transition_type: The type of transition to perform + current_target: The current workflow target + next_target: The next workflow target + metadata: Optional metadata for the transition + + Raises: + StateTransitionError: If the transition is invalid + """ + if not self._validate_transition(transition_type): + msg = f"Invalid transition from {self.state.transition_type} to {transition_type}" + raise StateTransitionError(msg) + + # Store previous state for rollback + previous_state = self.state.model_copy(deep=True) + + try: + # Update the state + self.state.transition_type = transition_type + self.state.status = _transition_to_status[transition_type] + self.state.current_target = current_target + self.state.next_target = next_target + if metadata: + self.state.metadata.update(metadata) + + yield self.state + + except Exception as e: + # Rollback on error + self.state = previous_state + msg = f"Transition failed: {e!s}" + raise StateTransitionError(msg) from e + + @property + def is_terminal(self) -> bool: + """Check if the current state is terminal.""" + return ( + self.state.transition_type is not None + and not _valid_transitions[self.state.transition_type] + ) + + @property + def current_status(self) -> ExecutionStatus: + """Get the current execution status.""" + return self.state.status diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index 430a62f36..85bf00cb6 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -1,9 +1,8 @@ -import asyncio from typing import Annotated, Any, Literal from uuid import UUID from beartype import beartype -from temporalio import activity, workflow +from temporalio import workflow from temporalio.exceptions import ApplicationError with workflow.unsafe.imports_passed_through(): @@ -33,8 +32,6 @@ Workflow, WorkflowStep, ) - from ...common.storage_handler import load_from_blob_store_if_remote - from 
.remote import BaseRemoteModel, RemoteObject # TODO: Maybe we should use a library for this @@ -48,7 +45,7 @@ # finish_branch -> wait | error | cancelled | step | finish | init_branch # error -> -## Mermaid Diagram +# Mermaid Diagram # ```mermaid # --- # title: Execution state machine @@ -142,20 +139,20 @@ class PartialTransition(create_partial_model(CreateTransitionRequest)): class ExecutionInput(BaseModel): developer_id: UUID - execution: Execution - task: TaskSpecDef + execution: Execution | None = None + task: TaskSpecDef | None = None agent: Agent agent_tools: list[Tool | CreateToolRequest] - arguments: dict[str, Any] | RemoteObject + arguments: dict[str, Any] # Not used at the moment user: User | None = None session: Session | None = None -class StepContext(BaseRemoteModel): - execution_input: ExecutionInput | RemoteObject - inputs: list[Any] | RemoteObject +class StepContext(BaseModel): + execution_input: ExecutionInput + inputs: list[Any] cursor: TransitionTarget @computed_field @@ -170,12 +167,9 @@ def tools(self) -> list[Tool | CreateToolRequest]: ) if step_tools != "all": - if not all( - tool and isinstance(tool, CreateToolRequest) for tool in step_tools - ): - raise ApplicationError( - "Invalid tools for step (ToolRef not supported yet)" - ) + if not all(tool and isinstance(tool, CreateToolRequest) for tool in step_tools): + msg = "Invalid tools for step (ToolRef not supported yet)" + raise ApplicationError(msg) return step_tools @@ -184,18 +178,14 @@ def tools(self) -> list[Tool | CreateToolRequest]: for tool in task.tools: tool_def = tool.model_dump() task_tools.append( - CreateToolRequest( - **{tool_def["type"]: tool_def.pop("spec"), **tool_def} - ) + CreateToolRequest(**{tool_def["type"]: tool_def.pop("spec"), **tool_def}) ) if not task.inherit_tools: return task_tools # Remove duplicates from agent_tools - filtered_tools = [ - t for t in agent_tools if t.name not in map(lambda x: x.name, task.tools) - ] + filtered_tools = [t for t in agent_tools if 
t.name not in (x.name for x in task.tools)] return filtered_tools + task_tools @@ -218,8 +208,7 @@ def current_workflow(self) -> Annotated[Workflow, Field(exclude=True)]: @computed_field @property def current_step(self) -> Annotated[WorkflowStep, Field(exclude=True)]: - step = self.current_workflow.steps[self.cursor.step] - return step + return self.current_workflow.steps[self.cursor.step] @computed_field @property @@ -242,17 +231,9 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]: return dump | execution_input - async def prepare_for_step( - self, *args, include_remote: bool = True, **kwargs - ) -> dict[str, Any]: + async def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]: current_input = self.current_input inputs = self.inputs - if activity.in_activity() and include_remote: - await self.load_all() - inputs = await asyncio.gather( - *[load_from_blob_store_if_remote(input) for input in inputs] - ) - current_input = await load_from_blob_store_if_remote(current_input) # Merge execution inputs into the dump dict dump = self.model_dump(*args, **kwargs) @@ -277,7 +258,9 @@ class StepOutcome(BaseModel): def task_to_spec( task: Task | CreateTaskRequest | UpdateTaskRequest | PatchTaskRequest, **model_opts ) -> TaskSpecDef | PartialTaskSpecDef: - task_data = task.model_dump(**model_opts, exclude={"task_id", "id", "agent_id"}) + task_data = task.model_dump( + **model_opts, exclude={"version", "developer_id", "task_id", "id", "agent_id"} + ) if "tools" in task_data: del task_data["tools"] @@ -315,8 +298,8 @@ def spec_to_task_data(spec: dict) -> dict: workflows = spec.pop("workflows") workflows_dict = {workflow["name"]: workflow["steps"] for workflow in workflows} - tools = spec.pop("tools", []) - tools = [{tool["type"]: tool.pop("spec"), **tool} for tool in tools] + tools = spec.pop("tools", []) or [] + tools = [{tool["type"]: tool.pop("spec"), **tool} for tool in tools if tool] return { "id": task_id, diff --git 
a/agents-api/agents_api/common/storage_handler.py b/agents-api/agents_api/common/storage_handler.py deleted file mode 100644 index 42beef270..000000000 --- a/agents-api/agents_api/common/storage_handler.py +++ /dev/null @@ -1,226 +0,0 @@ -import asyncio -import sys -from datetime import timedelta -from functools import wraps -from typing import Any, Callable - -from pydantic import BaseModel -from temporalio import workflow - -from ..activities.sync_items_remote import load_inputs_remote -from ..clients import async_s3 -from ..common.protocol.remote import BaseRemoteModel, RemoteObject -from ..common.retry_policies import DEFAULT_RETRY_POLICY -from ..env import ( - blob_store_cutoff_kb, - debug, - temporal_heartbeat_timeout, - temporal_schedule_to_close_timeout, - testing, - use_blob_store_for_temporal, -) -from ..worker.codec import deserialize, serialize - - -async def store_in_blob_store_if_large(x: Any) -> RemoteObject | Any: - if not use_blob_store_for_temporal: - return x - - await async_s3.setup() - - serialized = serialize(x) - data_size = sys.getsizeof(serialized) - - if data_size > blob_store_cutoff_kb * 1024: - key = await async_s3.add_object_with_hash(serialized) - return RemoteObject(key=key) - - return x - - -async def load_from_blob_store_if_remote(x: Any | RemoteObject) -> Any: - if not use_blob_store_for_temporal: - return x - - await async_s3.setup() - - if isinstance(x, RemoteObject): - fetched = await async_s3.get_object(x.key) - return deserialize(fetched) - - elif isinstance(x, dict) and set(x.keys()) == {"bucket", "key"}: - fetched = await async_s3.get_object(x["key"]) - return deserialize(fetched) - - return x - - -# Decorator that automatically does two things: -# 1. store in blob store if the output of a function is large -# 2. 
load from blob store if the input is a RemoteObject - - -def auto_blob_store(f: Callable | None = None, *, deep: bool = False) -> Callable: - def auto_blob_store_decorator(f: Callable) -> Callable: - async def load_args( - args: list | tuple, kwargs: dict[str, Any] - ) -> tuple[list | tuple, dict[str, Any]]: - new_args = await asyncio.gather( - *[load_from_blob_store_if_remote(arg) for arg in args] - ) - kwargs_keys, kwargs_values = list(zip(*kwargs.items())) or ([], []) - new_kwargs = await asyncio.gather( - *[load_from_blob_store_if_remote(v) for v in kwargs_values] - ) - new_kwargs = dict(zip(kwargs_keys, new_kwargs)) - - if deep: - args = new_args - kwargs = new_kwargs - - new_args = [] - - for arg in args: - if isinstance(arg, list): - new_args.append( - await asyncio.gather( - *[load_from_blob_store_if_remote(item) for item in arg] - ) - ) - elif isinstance(arg, dict): - keys, values = list(zip(*arg.items())) or ([], []) - values = await asyncio.gather( - *[load_from_blob_store_if_remote(value) for value in values] - ) - new_args.append(dict(zip(keys, values))) - - elif isinstance(arg, BaseRemoteModel): - new_args.append(await arg.unload_all()) - - elif isinstance(arg, BaseModel): - for field in arg.model_fields.keys(): - if isinstance(getattr(arg, field), RemoteObject): - setattr( - arg, - field, - await load_from_blob_store_if_remote( - getattr(arg, field) - ), - ) - elif isinstance(getattr(arg, field), list): - setattr( - arg, - field, - await asyncio.gather( - *[ - load_from_blob_store_if_remote(item) - for item in getattr(arg, field) - ] - ), - ) - elif isinstance(getattr(arg, field), BaseRemoteModel): - setattr( - arg, - field, - await getattr(arg, field).unload_all(), - ) - - new_args.append(arg) - - else: - new_args.append(arg) - - new_kwargs = {} - - for k, v in kwargs.items(): - if isinstance(v, list): - new_kwargs[k] = await asyncio.gather( - *[load_from_blob_store_if_remote(item) for item in v] - ) - - elif isinstance(v, dict): - keys, values = 
list(zip(*v.items())) or ([], []) - values = await asyncio.gather( - *[load_from_blob_store_if_remote(value) for value in values] - ) - new_kwargs[k] = dict(zip(keys, values)) - - elif isinstance(v, BaseRemoteModel): - new_kwargs[k] = await v.unload_all() - - elif isinstance(v, BaseModel): - for field in v.model_fields.keys(): - if isinstance(getattr(v, field), RemoteObject): - setattr( - v, - field, - await load_from_blob_store_if_remote( - getattr(v, field) - ), - ) - elif isinstance(getattr(v, field), list): - setattr( - v, - field, - await asyncio.gather( - *[ - load_from_blob_store_if_remote(item) - for item in getattr(v, field) - ] - ), - ) - elif isinstance(getattr(v, field), BaseRemoteModel): - setattr( - v, - field, - await getattr(v, field).unload_all(), - ) - new_kwargs[k] = v - - else: - new_kwargs[k] = v - - return new_args, new_kwargs - - async def unload_return_value(x: Any | BaseRemoteModel) -> Any: - if isinstance(x, BaseRemoteModel): - await x.unload_all() - - return await store_in_blob_store_if_large(x) - - @wraps(f) - async def async_wrapper(*args, **kwargs) -> Any: - new_args, new_kwargs = await load_args(args, kwargs) - output = await f(*new_args, **new_kwargs) - - return await unload_return_value(output) - - return async_wrapper if use_blob_store_for_temporal else f - - return auto_blob_store_decorator(f) if f else auto_blob_store_decorator - - -def auto_blob_store_workflow(f: Callable) -> Callable: - @wraps(f) - async def wrapper(*args, **kwargs) -> Any: - keys = kwargs.keys() - values = [kwargs[k] for k in keys] - - loaded = await workflow.execute_activity( - load_inputs_remote, - args=[[*args, *values]], - schedule_to_close_timeout=timedelta( - seconds=60 if debug or testing else temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) - - loaded_args = loaded[: len(args)] - loaded_kwargs = dict(zip(keys, loaded[len(args) :])) - - result = await 
f(*loaded_args, **loaded_kwargs) - - return result - - return wrapper if use_blob_store_for_temporal else f diff --git a/agents-api/agents_api/common/utils/__init__.py b/agents-api/agents_api/common/utils/__init__.py index 891594c02..fbe7d490c 100644 --- a/agents-api/agents_api/common/utils/__init__.py +++ b/agents-api/agents_api/common/utils/__init__.py @@ -1,7 +1,7 @@ """ The `utils` module within the `agents-api` project offers a collection of utility functions designed to support various aspects of the application. This includes: -- `cozo.py`: Utilities for interacting with the Cozo API client, including data mutation processes. +- `pg.py`: Utilities for interacting with the PostgreSQL API client, including data mutation processes. - `datetime.py`: Functions for handling date and time operations, ensuring consistent use of time zones and formats across the application. - `json.py`: Custom JSON utilities, including a custom JSON encoder for handling specific object types like UUIDs, and a utility function for JSON serialization with support for default values for None objects. diff --git a/agents-api/agents_api/common/utils/cozo.py b/agents-api/agents_api/common/utils/cozo.py deleted file mode 100644 index f5567dc4a..000000000 --- a/agents-api/agents_api/common/utils/cozo.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 - -"""This module provides utility functions for interacting with the Cozo API client, including data mutation processes.""" - -from types import SimpleNamespace -from uuid import UUID - -from beartype import beartype -from pycozo import Client - -# Define a mock client for testing purposes, simulating Cozo API client behavior. -_fake_client: SimpleNamespace = SimpleNamespace() -# Lambda function to process and mutate data dictionaries using the Cozo client's internal method. This is a workaround to access protected member functions for testing. 
-_fake_client._process_mutate_data_dict = lambda data: ( - Client._process_mutate_data_dict(_fake_client, data) -) - -# Public interface to process and mutate data using the Cozo client. It wraps the client's internal processing method for external use. -cozo_process_mutate_data = _fake_client._process_mutate_data = lambda data: ( - Client._process_mutate_data(_fake_client, data) -) - - -@beartype -def uuid_int_list_to_uuid4(data: list[int]) -> UUID: - return UUID(bytes=b"".join([i.to_bytes(1, "big") for i in data])) diff --git a/agents-api/agents_api/common/utils/datetime.py b/agents-api/agents_api/common/utils/datetime.py index bec5581c1..ce68bc801 100644 --- a/agents-api/agents_api/common/utils/datetime.py +++ b/agents-api/agents_api/common/utils/datetime.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 -from datetime import datetime, timezone +from datetime import UTC, datetime def utcnow(): - return datetime.now(timezone.utc) + return datetime.now(UTC) diff --git a/agents-api/agents_api/common/utils/db_exceptions.py b/agents-api/agents_api/common/utils/db_exceptions.py new file mode 100644 index 000000000..c40c72bba --- /dev/null +++ b/agents-api/agents_api/common/utils/db_exceptions.py @@ -0,0 +1,192 @@ +""" +Common database exception handling utilities. 
+""" + +import inspect +import socket +from collections.abc import Callable +from functools import partialmethod, wraps + +import asyncpg +import pydantic +from fastapi import HTTPException + + +def partialclass(cls, *args, **kwargs): + cls_signature = inspect.signature(cls) + bound = cls_signature.bind_partial(*args, **kwargs) + + # The `updated=()` argument is necessary to avoid a TypeError when using @wraps for a class + @wraps(cls, updated=()) + class NewCls(cls): + __init__ = partialmethod(cls.__init__, *bound.args, **bound.kwargs) + + return NewCls + + +def common_db_exceptions( + resource_name: str, + operations: list[str] | None = None, +) -> dict[ + type[BaseException] | Callable[[BaseException], bool], + type[BaseException] | Callable[[BaseException], BaseException], +]: + """ + Returns a mapping of common database exceptions to appropriate HTTP exceptions. + This is commonly used with the @rewrap_exceptions decorator. + + Args: + resource_name (str): The name of the resource being operated on (e.g. "agent", "task", "user") + operations (list[str] | None, optional): List of operations being performed. + Used to customize error messages. Defaults to None. 
+ + Returns: + dict: A mapping of database exceptions to HTTP exceptions + """ + + # Helper to format operation-specific messages + def get_operation_message(base_msg: str) -> str: + if not operations: + return base_msg + op_str = " or ".join(operations) + return f"{base_msg} during {op_str}" + + exceptions = { + # Foreign key violations - usually means a referenced resource doesn't exist + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail=get_operation_message( + f"The specified {resource_name} or its dependencies do not exist" + ), + ), + # Unique constraint violations - usually means a resource with same unique key exists + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail=get_operation_message( + f"A {resource_name} with these unique properties already exists" + ), + ), + # Check constraint violations - usually means invalid data that violates DB constraints + asyncpg.CheckViolationError: partialclass( + HTTPException, + status_code=400, + detail=get_operation_message( + f"The provided {resource_name} data violates one or more constraints" + ), + ), + # Data type/format errors + asyncpg.DataError: partialclass( + HTTPException, + status_code=400, + detail=get_operation_message(f"Invalid {resource_name} data provided"), + ), + # No rows found for update/delete operations + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail=get_operation_message(f"{resource_name.title()} not found"), + ), + # Connection errors (timeouts, etc) + socket.gaierror: partialclass( + HTTPException, + status_code=429, + detail="Resource busy. 
Please try again later.", + ), + # Invalid text representation + asyncpg.InvalidTextRepresentationError: partialclass( + HTTPException, + status_code=400, + detail=get_operation_message(f"Invalid text format in {resource_name} data"), + ), + # Numeric value out of range + asyncpg.NumericValueOutOfRangeError: partialclass( + HTTPException, + status_code=400, + detail=get_operation_message( + f"Numeric value in {resource_name} data is out of allowed range" + ), + ), + # String data right truncation + asyncpg.StringDataRightTruncationError: partialclass( + HTTPException, + status_code=400, + detail=get_operation_message( + f"Text data in {resource_name} is too long for the field" + ), + ), + # Not null violation + asyncpg.NotNullViolationError: partialclass( + HTTPException, + status_code=400, + detail=get_operation_message(f"Required {resource_name} field cannot be null"), + ), + # Python standard exceptions + ValueError: partialclass( + HTTPException, + status_code=400, + detail=get_operation_message(f"Invalid value provided for {resource_name}"), + ), + TypeError: partialclass( + HTTPException, + status_code=400, + detail=get_operation_message(f"Invalid type for {resource_name}"), + ), + AttributeError: partialclass( + HTTPException, + status_code=404, + detail=get_operation_message(f"Required attribute not found for {resource_name}"), + ), + KeyError: partialclass( + HTTPException, + status_code=404, + detail=get_operation_message(f"Required key not found for {resource_name}"), + ), + AssertionError: partialclass( + HTTPException, + status_code=404, + detail=get_operation_message(f"No {resource_name} found"), + ), + # Pydantic validation errors + pydantic.ValidationError: lambda e: partialclass( + HTTPException, + status_code=422, + detail={ + "message": get_operation_message(f"Validation failed for {resource_name}"), + "errors": [ + { + "loc": list(error["loc"]), + "msg": error["msg"], + "type": error["type"], + } + for error in e.errors() + ], + }, + )(e), + } + 
+ # Add operation-specific exceptions + if operations: + if "delete" in operations: + exceptions.update({ + # Handle cases where deletion is blocked by dependent records + lambda e: isinstance(e, asyncpg.ForeignKeyViolationError) + and "still referenced" in str(e): partialclass( + HTTPException, + status_code=409, + detail=f"Cannot delete {resource_name} because it is still referenced by other records", + ), + }) + + if "update" in operations: + exceptions.update({ + # Handle cases where update would affect multiple rows + asyncpg.CardinalityViolationError: partialclass( + HTTPException, + status_code=409, + detail=f"Update would affect multiple {resource_name} records", + ), + }) + + return exceptions diff --git a/agents-api/agents_api/common/utils/debug.py b/agents-api/agents_api/common/utils/debug.py index c250f7ad7..a7ba13664 100644 --- a/agents-api/agents_api/common/utils/debug.py +++ b/agents-api/agents_api/common/utils/debug.py @@ -17,7 +17,7 @@ def wrapper(*args, **kwargs): print("Traceback:") traceback.print_exc() - breakpoint() - raise + breakpoint() # noqa: T100 + raise exc return wrapper diff --git a/agents-api/agents_api/common/utils/template.py b/agents-api/agents_api/common/utils/template.py index 5bde8cab6..c6bd245e2 100644 --- a/agents-api/agents_api/common/utils/template.py +++ b/agents-api/agents_api/common/utils/template.py @@ -1,5 +1,4 @@ -import re -from typing import List, TypeVar +from typing import TypeVar from beartype import beartype from jinja2.sandbox import ImmutableSandboxedEnvironment @@ -8,7 +7,7 @@ from ...activities.utils import ALLOWED_FUNCTIONS, constants, stdlib -__all__: List[str] = [ +__all__: list[str] = [ "render_template", ] @@ -27,13 +26,6 @@ for k, v in (constants | stdlib | ALLOWED_FUNCTIONS).items(): jinja_env.globals[k] = v -simple_jinja_regex = re.compile(r"{{|{%.+}}|%}", re.DOTALL) - - -# TODO: This does not work for some reason -def is_simple_jinja(template_string: str) -> bool: - return 
simple_jinja_regex.search(template_string) is None - # Funcs @beartype @@ -43,7 +35,6 @@ async def render_template_string( check: bool = False, ) -> str: # Parse template - # TODO: Check that the string is indeed a jinjd template template = jinja_env.from_string(template_string) # If check is required, get required vars from template and validate variables @@ -52,8 +43,7 @@ async def render_template_string( validate(instance=variables, schema=schema) # Render - rendered = await template.render_async(**variables) - return rendered + return await template.render_async(**variables) # A render function that can render arbitrarily nested lists of dicts @@ -73,8 +63,7 @@ async def render_template_nested( return await render_template_string(input, variables, check) case dict(): return { - k: await render_template_nested(v, variables, check) - for k, v in input.items() + k: await render_template_nested(v, variables, check) for k, v in input.items() } case list(): return [await render_template_nested(v, variables, check) for v in input] diff --git a/agents-api/agents_api/common/utils/types.py b/agents-api/agents_api/common/utils/types.py index 6bf9cd502..6ec093b84 100644 --- a/agents-api/agents_api/common/utils/types.py +++ b/agents-api/agents_api/common/utils/types.py @@ -1,22 +1,14 @@ -from typing import Type - from beartype.vale import Is from beartype.vale._core._valecore import BeartypeValidator from pydantic import BaseModel -def dict_like(pydantic_model_class: Type[BaseModel]) -> BeartypeValidator: - required_fields_set: set[str] = set( - [ - field - for field, info in pydantic_model_class.model_fields.items() - if info.is_required() - ] - ) +def dict_like(pydantic_model_class: type[BaseModel]) -> BeartypeValidator: + required_fields_set: set[str] = { + field for field, info in pydantic_model_class.model_fields.items() if info.is_required() + } - validator = Is[ + return Is[ lambda x: isinstance(x, pydantic_model_class) or required_fields_set.issubset(set(x.keys())) 
] - - return validator diff --git a/agents-api/agents_api/dependencies/auth.py b/agents-api/agents_api/dependencies/auth.py index e5e22995b..4da49c26e 100644 --- a/agents-api/agents_api/dependencies/auth.py +++ b/agents-api/agents_api/dependencies/auth.py @@ -16,8 +16,5 @@ async def get_api_key( user_api_key = (user_api_key or "").replace("Bearer ", "").strip() if user_api_key != api_key: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Could not validate API KEY" - ) - else: - return user_api_key + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate API KEY") + return user_api_key diff --git a/agents-api/agents_api/dependencies/content_length.py b/agents-api/agents_api/dependencies/content_length.py new file mode 100644 index 000000000..3fe8b6781 --- /dev/null +++ b/agents-api/agents_api/dependencies/content_length.py @@ -0,0 +1,7 @@ +from fastapi import Header + +from ..env import max_payload_size + + +async def valid_content_length(content_length: int = Header(..., lt=max_payload_size)): + return content_length diff --git a/agents-api/agents_api/dependencies/developer_id.py b/agents-api/agents_api/dependencies/developer_id.py index e71df35d7..efaec0e5a 100644 --- a/agents-api/agents_api/dependencies/developer_id.py +++ b/agents-api/agents_api/dependencies/developer_id.py @@ -5,7 +5,7 @@ from ..common.protocol.developers import Developer from ..env import multi_tenant_mode -from ..models.developer.get_developer import get_developer, verify_developer +from ..queries.developers.get_developer import get_developer from .exceptions import InvalidHeaderFormat @@ -16,15 +16,15 @@ async def get_developer_id( return UUID("00000000-0000-0000-0000-000000000000") if not x_developer_id: - raise InvalidHeaderFormat("X-Developer-Id header required") + msg = "X-Developer-Id header required" + raise InvalidHeaderFormat(msg) if isinstance(x_developer_id, str): try: x_developer_id = UUID(x_developer_id, version=4) except ValueError as e: - 
raise InvalidHeaderFormat("X-Developer-Id must be a valid UUID") from e - - verify_developer(developer_id=x_developer_id) + msg = "X-Developer-Id must be a valid UUID" + raise InvalidHeaderFormat(msg) from e return x_developer_id @@ -33,20 +33,18 @@ async def get_developer_data( x_developer_id: Annotated[UUID | None, Header(include_in_schema=False)] = None, ) -> Developer: if not multi_tenant_mode: - assert ( - not x_developer_id - ), "X-Developer-Id header not allowed in multi-tenant mode" - return get_developer(developer_id=UUID("00000000-0000-0000-0000-000000000000")) + assert not x_developer_id, "X-Developer-Id header not allowed in multi-tenant mode" + return await get_developer(developer_id=UUID("00000000-0000-0000-0000-000000000000")) if not x_developer_id: - raise InvalidHeaderFormat("X-Developer-Id header required") + msg = "X-Developer-Id header required" + raise InvalidHeaderFormat(msg) if isinstance(x_developer_id, str): try: x_developer_id = UUID(x_developer_id, version=4) except ValueError as e: - raise InvalidHeaderFormat("X-Developer-Id must be a valid UUID") from e - - developer = get_developer(developer_id=x_developer_id) + msg = "X-Developer-Id must be a valid UUID" + raise InvalidHeaderFormat(msg) from e - return developer + return await get_developer(developer_id=x_developer_id) diff --git a/agents-api/agents_api/dependencies/query_filter.py b/agents-api/agents_api/dependencies/query_filter.py index 73e099225..841274912 100644 --- a/agents-api/agents_api/dependencies/query_filter.py +++ b/agents-api/agents_api/dependencies/query_filter.py @@ -1,4 +1,5 @@ -from typing import Annotated, Any, Callable +from collections.abc import Callable +from typing import Annotated, Any from fastapi import Query, Request from pydantic import BaseModel, ConfigDict @@ -38,9 +39,7 @@ def create_filter_extractor( def extract_filters( request: Request, - metadata_filter: Annotated[ - MetadataFilter, Query(default_factory=MetadataFilter) - ], + metadata_filter: 
Annotated[MetadataFilter, Query(default_factory=MetadataFilter)], ) -> MetadataFilter: """ Extracts query parameters that start with the specified prefix and returns them as a dictionary. diff --git a/agents-api/agents_api/env.py b/agents-api/agents_api/env.py index 2e7173b17..a5b37aaae 100644 --- a/agents-api/agents_api/env.py +++ b/agents-api/agents_api/env.py @@ -5,7 +5,7 @@ import random from pprint import pprint -from typing import Any, Dict +from typing import Any from environs import Env @@ -25,6 +25,10 @@ hostname: str = env.str("AGENTS_API_HOSTNAME", default="localhost") public_port: int = env.int("AGENTS_API_PUBLIC_PORT", default=80) api_prefix: str = env.str("AGENTS_API_PREFIX", default="") +max_payload_size: int = env.int( + "AGENTS_API_MAX_PAYLOAD_SIZE", + default=50 * 1024 * 1024, # 50MB +) # Tasks # ----- @@ -36,8 +40,8 @@ # Blob Store # ---------- -use_blob_store_for_temporal: bool = ( - env.bool("USE_BLOB_STORE_FOR_TEMPORAL", default=False) if not testing else False +use_blob_store_for_temporal: bool = testing or env.bool( + "USE_BLOB_STORE_FOR_TEMPORAL", default=False ) blob_store_bucket: str = env.str("BLOB_STORE_BUCKET", default="agents-api") @@ -47,17 +51,15 @@ s3_secret_key: str | None = env.str("S3_SECRET_KEY", default=None) -# Cozo +# PostgreSQL # ---- -cozo_host: str = env.str("COZO_HOST", default="http://127.0.0.1:9070") -cozo_auth: str = env.str("COZO_AUTH_TOKEN", default=None) -summarization_model_name: str = env.str( - "SUMMARIZATION_MODEL_NAME", default="gpt-4-turbo" -) -do_verify_developer: bool = env.bool("DO_VERIFY_DEVELOPER", default=True) -do_verify_developer_owns_resource: bool = env.bool( - "DO_VERIFY_DEVELOPER_OWNS_RESOURCE", default=True +pg_dsn: str = env.str( + "PG_DSN", + default="postgres://postgres:postgres@0.0.0.0:5432/postgres?sslmode=disable", ) +summarization_model_name: str = env.str("SUMMARIZATION_MODEL_NAME", default="gpt-4-turbo") + +query_timeout: float = env.float("QUERY_TIMEOUT", default=90.0) # Auth @@ -81,18 
+83,14 @@ # Embedding service # ----------------- -embedding_model_id: str = env.str( - "EMBEDDING_MODEL_ID", default="Alibaba-NLP/gte-large-en-v1.5" -) +embedding_model_id: str = env.str("EMBEDDING_MODEL_ID", default="Alibaba-NLP/gte-large-en-v1.5") embedding_dimensions: int = env.int("EMBEDDING_DIMENSIONS", default=1024) # Integration service # ------------------- -integration_service_url: str = env.str( - "INTEGRATION_SERVICE_URL", default="http://0.0.0.0:8000" -) +integration_service_url: str = env.str("INTEGRATION_SERVICE_URL", default="http://0.0.0.0:8000") # Temporal @@ -107,9 +105,7 @@ "TEMPORAL_SCHEDULE_TO_CLOSE_TIMEOUT", default=3600 ) temporal_heartbeat_timeout: int = env.int("TEMPORAL_HEARTBEAT_TIMEOUT", default=900) -temporal_metrics_bind_host: str = env.str( - "TEMPORAL_METRICS_BIND_HOST", default="0.0.0.0" -) +temporal_metrics_bind_host: str = env.str("TEMPORAL_METRICS_BIND_HOST", default="0.0.0.0") temporal_metrics_bind_port: int = env.int("TEMPORAL_METRICS_BIND_PORT", default=14000) temporal_activity_after_retry_timeout: int = env.int( "TEMPORAL_ACTIVITY_AFTER_RETRY_TIMEOUT", default=30 @@ -140,29 +136,27 @@ def _parse_optional_int(val: str | None) -> int | None: ) # Consolidate environment variables -environment: Dict[str, Any] = dict( - debug=debug, - multi_tenant_mode=multi_tenant_mode, - cozo_host=cozo_host, - cozo_auth=cozo_auth, - sentry_dsn=sentry_dsn, - temporal_endpoint=temporal_endpoint, - temporal_task_queue=temporal_task_queue, - api_key=api_key, - api_key_header_name=api_key_header_name, - hostname=hostname, - api_prefix=api_prefix, - temporal_worker_url=temporal_worker_url, - temporal_namespace=temporal_namespace, - embedding_model_id=embedding_model_id, - use_blob_store_for_temporal=use_blob_store_for_temporal, - blob_store_bucket=blob_store_bucket, - blob_store_cutoff_kb=blob_store_cutoff_kb, - s3_endpoint=s3_endpoint, - s3_access_key=s3_access_key, - s3_secret_key=s3_secret_key, - testing=testing, -) +environment: dict[str, Any] = 
{ + "debug": debug, + "multi_tenant_mode": multi_tenant_mode, + "sentry_dsn": sentry_dsn, + "temporal_endpoint": temporal_endpoint, + "temporal_task_queue": temporal_task_queue, + "api_key": api_key, + "api_key_header_name": api_key_header_name, + "hostname": hostname, + "api_prefix": api_prefix, + "temporal_worker_url": temporal_worker_url, + "temporal_namespace": temporal_namespace, + "embedding_model_id": embedding_model_id, + "use_blob_store_for_temporal": use_blob_store_for_temporal, + "blob_store_bucket": blob_store_bucket, + "blob_store_cutoff_kb": blob_store_cutoff_kb, + "s3_endpoint": s3_endpoint, + "s3_access_key": s3_access_key, + "s3_secret_key": s3_secret_key, + "testing": testing, +} if debug or testing: # Print the loaded environment variables for debugging purposes. diff --git a/agents-api/agents_api/exceptions.py b/agents-api/agents_api/exceptions.py index 615958a87..f6fcc4741 100644 --- a/agents-api/agents_api/exceptions.py +++ b/agents-api/agents_api/exceptions.py @@ -49,3 +49,12 @@ class FailedEncodingSentinel: """Sentinel object returned when failed to encode payload.""" payload_data: bytes + + +class QueriesBaseException(AgentsBaseException): + pass + + +class InvalidSQLQuery(QueriesBaseException): + def __init__(self, query_name: str): + super().__init__(f"invalid query: {query_name}") diff --git a/agents-api/agents_api/metrics/counters.py b/agents-api/agents_api/metrics/counters.py index f80236bf7..f34662d91 100644 --- a/agents-api/agents_api/metrics/counters.py +++ b/agents-api/agents_api/metrics/counters.py @@ -1,6 +1,7 @@ import inspect +from collections.abc import Awaitable, Callable from functools import wraps -from typing import Awaitable, Callable, ParamSpec, TypeVar +from typing import ParamSpec, TypeVar from prometheus_client import Counter diff --git a/agents-api/agents_api/model_registry.py b/agents-api/agents_api/model_registry.py index 0120cc205..4c20f56ab 100644 --- a/agents-api/agents_api/model_registry.py +++ 
b/agents-api/agents_api/model_registry.py @@ -2,9 +2,7 @@ Model Registry maintains a list of supported models and their configs. """ -from typing import Dict - -GPT4_MODELS: Dict[str, int] = { +GPT4_MODELS: dict[str, int] = { # stable model names: # resolves to gpt-4-0314 before 2023-06-27, # resolves to gpt-4-0613 after @@ -27,7 +25,7 @@ "gpt-4-32k-0314": 32768, } -TURBO_MODELS: Dict[str, int] = { +TURBO_MODELS: dict[str, int] = { # stable model names: # resolves to gpt-3.5-turbo-0301 before 2023-06-27, # resolves to gpt-3.5-turbo-0613 until 2023-12-11, @@ -48,14 +46,14 @@ "gpt-3.5-turbo-0301": 4096, } -GPT3_5_MODELS: Dict[str, int] = { +GPT3_5_MODELS: dict[str, int] = { "text-davinci-003": 4097, "text-davinci-002": 4097, # instruct models "gpt-3.5-turbo-instruct": 4096, } -GPT3_MODELS: Dict[str, int] = { +GPT3_MODELS: dict[str, int] = { "text-ada-001": 2049, "text-babbage-001": 2040, "text-curie-001": 2049, @@ -66,14 +64,14 @@ } -DISCONTINUED_MODELS: Dict[str, int] = { +DISCONTINUED_MODELS: dict[str, int] = { "code-davinci-002": 8001, "code-davinci-001": 8001, "code-cushman-002": 2048, "code-cushman-001": 2048, } -CLAUDE_MODELS: Dict[str, int] = { +CLAUDE_MODELS: dict[str, int] = { "claude-instant-1": 100000, "claude-instant-1.2": 100000, "claude-2": 100000, @@ -84,14 +82,14 @@ "claude-3-haiku-20240307": 180000, } -OPENAI_MODELS: Dict[str, int] = { +OPENAI_MODELS: dict[str, int] = { **GPT4_MODELS, **TURBO_MODELS, **GPT3_5_MODELS, **GPT3_MODELS, } -LOCAL_MODELS: Dict[str, int] = { +LOCAL_MODELS: dict[str, int] = { "gpt-4o": 32768, "gpt-4o-awq": 32768, "TinyLlama/TinyLlama_v1.1": 2048, @@ -100,13 +98,13 @@ "OpenPipe/Hermes-2-Theta-Llama-3-8B-32k": 32768, } -LOCAL_MODELS_WITH_TOOL_CALLS: Dict[str, int] = { +LOCAL_MODELS_WITH_TOOL_CALLS: dict[str, int] = { "OpenPipe/Hermes-2-Theta-Llama-3-8B-32k": 32768, "julep-ai/Hermes-2-Theta-Llama-3-8B": 8192, } -OLLAMA_MODELS: Dict[str, int] = { +OLLAMA_MODELS: dict[str, int] = { "llama2": 4096, } -CHAT_MODELS: Dict[str, int] = 
{**GPT4_MODELS, **TURBO_MODELS, **CLAUDE_MODELS} +CHAT_MODELS: dict[str, int] = {**GPT4_MODELS, **TURBO_MODELS, **CLAUDE_MODELS} diff --git a/agents-api/agents_api/models/__init__.py b/agents-api/agents_api/models/__init__.py deleted file mode 100644 index e59b5b01c..000000000 --- a/agents-api/agents_api/models/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -The `models` module of the agents API is designed to encapsulate all data interactions with the CozoDB database. It provides a structured way to perform CRUD (Create, Read, Update, Delete) operations and other specific data manipulations across various entities such as agents, documents, entries, sessions, tools, and users. - -Each sub-module within this module corresponds to a specific entity and contains functions and classes that implement datalog queries for interacting with the database. These interactions include creating new records, updating existing ones, retrieving data for specific conditions, and deleting records. The operations are crucial for the functionality of the agents API, enabling it to manage and process data effectively for each entity. - -This module also integrates with the `common` module for exception handling and utility functions, ensuring robust error management and providing reusable components for data processing and query construction. -""" - -# ruff: noqa: F401, F403, F405 - -from . import agent as agent -from . import developer as developer -from . import docs as docs -from . import entry as entry -from . import execution as execution -from . import files as files -from . import session as session -from . import task as task -from . import tools as tools -from . 
import user as user diff --git a/agents-api/agents_api/models/agent/__init__.py b/agents-api/agents_api/models/agent/__init__.py deleted file mode 100644 index 2beaf8166..000000000 --- a/agents-api/agents_api/models/agent/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -The `agent` module within the `agents-api` package provides a comprehensive suite of functionalities for managing agents in the CozoDB database. This includes: - -- Creating new agents and their associated tools. -- Updating existing agents and their settings. -- Retrieving details about specific agents or a list of agents. -- Deleting agents from the database. - -Additionally, the module supports operations related to agent tools, including creating, updating, and patching tools associated with agents. - -This module serves as the backbone for agent management within the CozoDB ecosystem, facilitating a wide range of operations necessary for the effective handling of agent data. -""" - -# ruff: noqa: F401, F403, F405 - -from .create_agent import create_agent -from .create_or_update_agent import create_or_update_agent -from .delete_agent import delete_agent -from .get_agent import get_agent -from .list_agents import list_agents -from .patch_agent import patch_agent -from .update_agent import update_agent diff --git a/agents-api/agents_api/models/agent/create_agent.py b/agents-api/agents_api/models/agent/create_agent.py deleted file mode 100644 index a9f0bfb8f..000000000 --- a/agents-api/agents_api/models/agent/create_agent.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -This module contains the functionality for creating agents in the CozoDB database. -It includes functions to construct and execute datalog queries for inserting new agent records. 
-""" - -from typing import Any, TypeVar -from uuid import UUID, uuid4 - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Agent, CreateAgentRequest -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - lambda e: isinstance(e, QueryException) - and "asserted to return some results, but returned none" - in str(e): lambda *_: HTTPException( - detail="Developer not found. Please ensure the provided auth token (which refers to your developer_id) is valid and the developer has the necessary permissions to create an agent.", - status_code=403, - ), - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. This likely means the data provided is of an incorrect type (e.g., string instead of integer). 
Please review the input and try again.", - ), - } -) -@wrap_in_class( - Agent, - one=True, - transform=lambda d: {"id": UUID(d.pop("agent_id")), **d}, - _kind="inserted", -) -@cozo_query -@increase_counter("create_agent") -@beartype -def create_agent( - *, - developer_id: UUID, - agent_id: UUID | None = None, - data: CreateAgentRequest, -) -> tuple[list[str], dict]: - """ - Constructs and executes a datalog query to create a new agent in the database. - - Parameters: - agent_id (UUID | None): The unique identifier for the agent. - developer_id (UUID): The unique identifier for the developer creating the agent. - data (CreateAgentRequest): The data for the new agent. - - Returns: - Agent: The newly created agent record. - """ - - agent_id = agent_id or uuid4() - - # Extract the agent data from the payload - data.metadata = data.metadata or {} - data.default_settings = data.default_settings or {} - - data.instructions = ( - data.instructions - if isinstance(data.instructions, list) - else [data.instructions] - ) - - agent_data = data.model_dump() - default_settings = agent_data.pop("default_settings") - - settings_cols, settings_vals = cozo_process_mutate_data( - { - **default_settings, - "agent_id": str(agent_id), - } - ) - - # Create default agent settings - # Construct a query to insert default settings for the new agent - default_settings_query = f""" - ?[{settings_cols}] <- $settings_vals - - :insert agent_default_settings {{ - {settings_cols} - }} - """ - # create the agent - # Construct a query to insert the new agent record into the agents table - agent_query = """ - ?[agent_id, developer_id, model, name, about, metadata, instructions, created_at, updated_at] <- [ - [$agent_id, $developer_id, $model, $name, $about, $metadata, $instructions, now(), now()] - ] - - :insert agents { - developer_id, - agent_id => - model, - name, - about, - metadata, - instructions, - created_at, - updated_at, - } - :returning - """ - - queries = [ - 
verify_developer_id_query(developer_id), - default_settings_query, - agent_query, - ] - - return ( - queries, - { - "settings_vals": settings_vals, - "agent_id": str(agent_id), - "developer_id": str(developer_id), - **agent_data, - }, - ) diff --git a/agents-api/agents_api/models/agent/create_or_update_agent.py b/agents-api/agents_api/models/agent/create_or_update_agent.py deleted file mode 100644 index 9a1feb717..000000000 --- a/agents-api/agents_api/models/agent/create_or_update_agent.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -This module contains the functionality for creating agents in the CozoDB database. -It includes functions to construct and execute datalog queries for inserting new agent records. -""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. This likely means the data provided is of an incorrect type (e.g., string instead of integer). 
Please review the input and try again.", - ), - } -) -@wrap_in_class( - Agent, one=True, transform=lambda d: {"id": UUID(d.pop("agent_id")), **d} -) -@cozo_query -@increase_counter("create_or_update_agent") -@beartype -def create_or_update_agent( - *, - developer_id: UUID, - agent_id: UUID, - data: CreateOrUpdateAgentRequest, -) -> tuple[list[str | None], dict]: - """ - Constructs and executes a datalog query to create a new agent in the database. - - Parameters: - agent_id (UUID): The unique identifier for the agent. - developer_id (UUID): The unique identifier for the developer creating the agent. - name (str): The name of the agent. - about (str): A description of the agent. - instructions (list[str], optional): A list of instructions for using the agent. Defaults to an empty list. - model (str, optional): The model identifier for the agent. Defaults to "gpt-4o". - metadata (dict, optional): A dictionary of metadata for the agent. Defaults to an empty dict. - default_settings (dict, optional): A dictionary of default settings for the agent. Defaults to an empty dict. - client (CozoClient, optional): The CozoDB client instance to use for the query. Defaults to a preconfigured client instance. - - Returns: - Agent: The newly created agent record. 
- """ - - # Extract the agent data from the payload - data.metadata = data.metadata or {} - data.instructions = ( - data.instructions - if isinstance(data.instructions, list) - else [data.instructions] - ) - data.default_settings = data.default_settings or {} - - agent_data = data.model_dump() - default_settings = ( - data.default_settings.model_dump(exclude_none=True) - if data.default_settings - else {} - ) - - settings_cols, settings_vals = cozo_process_mutate_data( - { - **default_settings, - "agent_id": str(agent_id), - } - ) - - # TODO: remove this - ### # Create default agent settings - ### # Construct a query to insert default settings for the new agent - ### default_settings_query = f""" - ### %if {{ - ### len[count(agent_id)] := - ### *agent_default_settings{{agent_id}}, - ### agent_id = to_uuid($agent_id) - - ### ?[should_create] := len[count], count > 0 - ### }} - ### %then {{ - ### ?[{settings_cols}] <- $settings_vals - - ### :put agent_default_settings {{ - ### {settings_cols} - ### }} - ### }} - ### """ - - # FIXME: This create or update query will overwrite the settings - # Need to find a way to only run the insert query if the agent_default_settings - - # Create default agent settings - # Construct a query to insert default settings for the new agent - default_settings_query = f""" - ?[{settings_cols}] <- $settings_vals - - :put agent_default_settings {{ - {settings_cols} - }} - """ - - # create the agent - # Construct a query to insert the new agent record into the agents table - agent_query = """ - input[agent_id, developer_id, model, name, about, metadata, instructions, updated_at] <- [ - [$agent_id, $developer_id, $model, $name, $about, $metadata, $instructions, now()] - ] - - ?[agent_id, developer_id, model, name, about, metadata, instructions, created_at, updated_at] := - input[_agent_id, developer_id, model, name, about, metadata, instructions, updated_at], - *agents{ - agent_id, - developer_id, - created_at, - }, - agent_id = 
to_uuid(_agent_id), - - ?[agent_id, developer_id, model, name, about, metadata, instructions, created_at, updated_at] := - input[_agent_id, developer_id, model, name, about, metadata, instructions, updated_at], - not *agents{ - agent_id, - developer_id, - }, created_at = now(), - agent_id = to_uuid(_agent_id), - - :put agents { - developer_id, - agent_id => - model, - name, - about, - metadata, - instructions, - created_at, - updated_at, - } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - default_settings_query, - agent_query, - ] - - return ( - queries, - { - "settings_vals": settings_vals, - "agent_id": str(agent_id), - "developer_id": str(developer_id), - **agent_data, - }, - ) diff --git a/agents-api/agents_api/models/agent/delete_agent.py b/agents-api/agents_api/models/agent/delete_agent.py deleted file mode 100644 index 60de66292..000000000 --- a/agents-api/agents_api/models/agent/delete_agent.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -This module contains the implementation of the delete_agent_query function, which is responsible for deleting an agent and its related default settings from the CozoDB database. -""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceDeletedResponse -from ...common.utils.datetime import utcnow -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - lambda e: isinstance(e, QueryException) - and "Developer does not own resource" - in e.resp["display"]: lambda *_: HTTPException( - detail="The specified developer does not own the requested resource. 
Please verify the ownership or check if the developer ID is correct.", - status_code=404, - ), - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. This likely means the data provided is of an incorrect type (e.g., string instead of integer). Please review the input and try again.", - ), - } -) -@wrap_in_class( - ResourceDeletedResponse, - one=True, - transform=lambda d: { - "id": UUID(d.pop("agent_id")), - "deleted_at": utcnow(), - "jobs": [], - }, - _kind="deleted", -) -@cozo_query -@beartype -def delete_agent(*, developer_id: UUID, agent_id: UUID) -> tuple[list[str], dict]: - """ - Constructs and returns a datalog query for deleting an agent and its default settings from the database. - - Parameters: - developer_id (UUID): The UUID of the developer owning the agent. - agent_id (UUID): The UUID of the agent to be deleted. - client (CozoClient, optional): An instance of the CozoClient to execute the query. - - Returns: - ResourceDeletedResponse: The response indicating the deletion of the agent. 
- """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - """ - # Delete docs - ?[owner_id, owner_type, doc_id] := - *docs{ - owner_type, - owner_id, - doc_id, - }, - owner_id = to_uuid($agent_id), - owner_type = "agent" - - :delete docs { - owner_type, - owner_id, - doc_id - } - :returning - """, - """ - # Delete tools - ?[agent_id, tool_id] := - *tools{ - agent_id, - tool_id, - }, agent_id = to_uuid($agent_id) - - :delete tools { - agent_id, - tool_id - } - :returning - """, - """ - # Delete default agent settings - ?[agent_id] <- [[$agent_id]] - - :delete agent_default_settings { - agent_id - } - :returning - """, - """ - # Delete the agent - ?[agent_id, developer_id] <- [[$agent_id, $developer_id]] - - :delete agents { - developer_id, - agent_id - } - :returning - """, - ] - - return (queries, {"agent_id": str(agent_id), "developer_id": str(developer_id)}) diff --git a/agents-api/agents_api/models/agent/get_agent.py b/agents-api/agents_api/models/agent/get_agent.py deleted file mode 100644 index 008e39454..000000000 --- a/agents-api/agents_api/models/agent/get_agent.py +++ /dev/null @@ -1,117 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Agent -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - lambda e: isinstance(e, QueryException) - and "Developer not found" in str(e): lambda *_: HTTPException( - detail="Developer does not exist", status_code=403 - ), - lambda e: isinstance(e, QueryException) - and "Developer does not own resource" - in e.resp["display"]: lambda *_: 
HTTPException( - detail="Developer does not own resource", status_code=404 - ), - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(Agent, one=True) -@cozo_query -@beartype -def get_agent(*, developer_id: UUID, agent_id: UUID) -> tuple[list[str], dict]: - """ - Fetches agent details and default settings from the database. - - This function constructs and executes a datalog query to retrieve information about a specific agent, including its default settings, based on the provided agent_id and developer_id. - - Parameters: - developer_id (UUID): The unique identifier for the developer. - agent_id (UUID): The unique identifier for the agent. - client (CozoClient, optional): The database client used to execute the query. - - Returns: - Agent - """ - # Constructing a datalog query to retrieve agent details and default settings. - # The query uses input parameters for agent_id and developer_id to filter the results. - # It joins the 'agents' and 'agent_default_settings' relations to fetch comprehensive details. 
- get_query = """ - input[agent_id] <- [[to_uuid($agent_id)]] - - ?[ - id, - model, - name, - about, - created_at, - updated_at, - metadata, - default_settings, - instructions, - ] := input[id], - *agents { - developer_id: to_uuid($developer_id), - agent_id: id, - model, - name, - about, - created_at, - updated_at, - metadata, - instructions, - }, - *agent_default_settings { - agent_id: id, - frequency_penalty, - presence_penalty, - length_penalty, - repetition_penalty, - top_p, - temperature, - min_p, - preset, - }, - default_settings = { - "frequency_penalty": frequency_penalty, - "presence_penalty": presence_penalty, - "length_penalty": length_penalty, - "repetition_penalty": repetition_penalty, - "top_p": top_p, - "temperature": temperature, - "min_p": min_p, - "preset": preset, - } - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - get_query, - ] - - # Execute the constructed datalog query using the provided CozoClient. - # The result is returned as a pandas DataFrame. 
- return (queries, {"agent_id": str(agent_id), "developer_id": str(developer_id)}) diff --git a/agents-api/agents_api/models/agent/list_agents.py b/agents-api/agents_api/models/agent/list_agents.py deleted file mode 100644 index 882b6c8c6..000000000 --- a/agents-api/agents_api/models/agent/list_agents.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Agent -from ...common.utils import json -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(Agent) -@cozo_query -@beartype -def list_agents( - *, - developer_id: UUID, - limit: int = 100, - offset: int = 0, - sort_by: Literal["created_at", "updated_at"] = "created_at", - direction: Literal["asc", "desc"] = "desc", - metadata_filter: dict[str, Any] = {}, -) -> tuple[list[str], dict]: - """ - Constructs and executes a datalog query to list agents from the 'cozodb' database. - - Parameters: - developer_id: UUID of the developer. - limit: Maximum number of agents to return. - offset: Number of agents to skip before starting to collect the result set. - metadata_filter: Dictionary to filter agents based on metadata. - client: Instance of CozoClient to execute the query. - """ - # Transforms the metadata_filter dictionary into a string representation for the datalog query. 
- metadata_filter_str = ", ".join( - [ - f"metadata->{json.dumps(k)} == {json.dumps(v)}" - for k, v in metadata_filter.items() - ] - ) - - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - # Datalog query to retrieve agent information based on filters, sorted by creation date in descending order. - queries = [ - verify_developer_id_query(developer_id), - f""" - input[developer_id] <- [[to_uuid($developer_id)]] - - ?[ - id, - model, - name, - about, - created_at, - updated_at, - metadata, - default_settings, - instructions, - ] := input[developer_id], - *agents {{ - developer_id, - agent_id: id, - model, - name, - about, - created_at, - updated_at, - metadata, - instructions, - }}, - *agent_default_settings {{ - agent_id: id, - frequency_penalty, - presence_penalty, - length_penalty, - repetition_penalty, - top_p, - temperature, - min_p, - preset, - }}, - default_settings = {{ - "frequency_penalty": frequency_penalty, - "presence_penalty": presence_penalty, - "length_penalty": length_penalty, - "repetition_penalty": repetition_penalty, - "top_p": top_p, - "temperature": temperature, - "min_p": min_p, - "preset": preset, - }}, - {metadata_filter_str} - - :limit $limit - :offset $offset - :sort {sort} - """, - ] - - return ( - queries, - {"developer_id": str(developer_id), "limit": limit, "offset": offset}, - ) diff --git a/agents-api/agents_api/models/agent/patch_agent.py b/agents-api/agents_api/models/agent/patch_agent.py deleted file mode 100644 index 99d4e3553..000000000 --- a/agents-api/agents_api/models/agent/patch_agent.py +++ /dev/null @@ -1,132 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse -from ...common.utils.cozo import cozo_process_mutate_data -from ...common.utils.datetime import utcnow -from 
...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: {"id": d["agent_id"], "jobs": [], **d}, - _kind="inserted", -) -@cozo_query -@increase_counter("patch_agent") -@beartype -def patch_agent( - *, - agent_id: UUID, - developer_id: UUID, - data: PatchAgentRequest, -) -> tuple[list[str], dict]: - """Patches agent data based on provided updates. - - Parameters: - agent_id (UUID): The unique identifier for the agent. - developer_id (UUID): The unique identifier for the developer. - default_settings (dict, optional): Default settings to apply to the agent. - **update_data: Arbitrary keyword arguments representing data to update. - - Returns: - ResourceUpdatedResponse: The updated agent data. - """ - update_data = data.model_dump(exclude_unset=True) - - # Construct the query for updating agent information in the database. 
- # Agent update query - metadata = update_data.pop("metadata", {}) or {} - default_settings = update_data.pop("default_settings", {}) or {} - agent_update_cols, agent_update_vals = cozo_process_mutate_data( - { - **{k: v for k, v in update_data.items() if v is not None}, - "agent_id": str(agent_id), - "developer_id": str(developer_id), - "updated_at": utcnow().timestamp(), - } - ) - - update_query = f""" - # update the agent - input[{agent_update_cols}] <- $agent_update_vals - - ?[{agent_update_cols}, metadata] := - input[{agent_update_cols}], - *agents {{ - agent_id: to_uuid($agent_id), - metadata: md, - }}, - metadata = concat(md, $metadata) - - :update agents {{ - {agent_update_cols}, - metadata, - }} - :returning - """ - - # Construct the query for updating agent's default settings in the database. - # Settings update query - settings_cols, settings_vals = cozo_process_mutate_data( - { - **default_settings, - "agent_id": str(agent_id), - } - ) - - settings_update_query = f""" - # update the agent settings - ?[{settings_cols}] <- $settings_vals - - :update agent_default_settings {{ - {settings_cols} - }} - """ - - # Combine agent and settings update queries if default settings are provided. 
- # Combine the queries - queries = [update_query] - - if len(default_settings) != 0: - queries.insert(0, settings_update_query) - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - *queries, - ] - - return ( - queries, - { - "agent_update_vals": agent_update_vals, - "settings_vals": settings_vals, - "metadata": metadata, - "agent_id": str(agent_id), - }, - ) diff --git a/agents-api/agents_api/models/agent/update_agent.py b/agents-api/agents_api/models/agent/update_agent.py deleted file mode 100644 index b36f687eb..000000000 --- a/agents-api/agents_api/models/agent/update_agent.py +++ /dev/null @@ -1,149 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: {"id": d["agent_id"], "jobs": [], **d}, - _kind="inserted", -) -@cozo_query -@increase_counter("update_agent") -@beartype -def update_agent( - *, - agent_id: UUID, - developer_id: UUID, - data: UpdateAgentRequest, -) -> tuple[list[str], dict]: - """ - Constructs and executes a datalog query to update an agent and its default settings in the 'cozodb' database. 
- - Parameters: - agent_id (UUID): The unique identifier of the agent to be updated. - developer_id (UUID): The unique identifier of the developer associated with the agent. - data (UpdateAgentRequest): The request payload containing the updated agent data. - client (CozoClient, optional): The database client used to execute the query. Defaults to a pre-configured client instance. - - Returns: - ResourceUpdatedResponse: The updated agent data. - """ - default_settings = ( - data.default_settings.model_dump(exclude_none=True) - if data.default_settings - else {} - ) - update_data = data.model_dump() - - # Remove default settings from the agent update data - update_data.pop("default_settings", None) - - agent_id = str(agent_id) - developer_id = str(developer_id) - update_data["instructions"] = update_data.get("instructions", []) - update_data["instructions"] = ( - update_data["instructions"] - if isinstance(update_data["instructions"], list) - else [update_data["instructions"]] - ) - - # Construct the agent update part of the query with dynamic columns and values based on `update_data`. - # Agent update query - agent_update_cols, agent_update_vals = cozo_process_mutate_data( - { - **{k: v for k, v in update_data.items() if v is not None}, - "agent_id": agent_id, - "developer_id": developer_id, - } - ) - - update_query = f""" - # update the agent - input[{agent_update_cols}] <- $agent_update_vals - original[created_at] := *agents{{ - developer_id: to_uuid($developer_id), - agent_id: to_uuid($agent_id), - created_at, - }}, - - ?[created_at, updated_at, {agent_update_cols}] := - input[{agent_update_cols}], - original[created_at], - updated_at = now(), - - :put agents {{ - created_at, - updated_at, - {agent_update_cols} - }} - :returning - """ - - # Construct the settings update part of the query if `default_settings` are provided. 
- # Settings update query - settings_cols, settings_vals = cozo_process_mutate_data( - { - **default_settings, - "agent_id": agent_id, - } - ) - - settings_update_query = f""" - # update the agent settings - ?[{settings_cols}] <- $settings_vals - - :put agent_default_settings {{ - {settings_cols} - }} - """ - - # Combine agent and settings update queries into a single query string. - # Combine the queries - queries = [update_query] - - if len(default_settings) != 0: - queries.insert(0, settings_update_query) - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - *queries, - ] - - return ( - queries, - { - "agent_update_vals": agent_update_vals, - "settings_vals": settings_vals, - "agent_id": agent_id, - "developer_id": developer_id, - }, - ) diff --git a/agents-api/agents_api/models/chat/get_cached_response.py b/agents-api/agents_api/models/chat/get_cached_response.py deleted file mode 100644 index 368c88567..000000000 --- a/agents-api/agents_api/models/chat/get_cached_response.py +++ /dev/null @@ -1,15 +0,0 @@ -from beartype import beartype - -from ..utils import cozo_query - - -@cozo_query -@beartype -def get_cached_response(key: str) -> tuple[str, dict]: - query = """ - input[key] <- [[$key]] - ?[key, value] := input[key], *session_cache{key, value} - :limit 1 - """ - - return (query, {"key": key}) diff --git a/agents-api/agents_api/models/chat/prepare_chat_context.py b/agents-api/agents_api/models/chat/prepare_chat_context.py deleted file mode 100644 index f77686d7a..000000000 --- a/agents-api/agents_api/models/chat/prepare_chat_context.py +++ /dev/null @@ -1,143 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...common.protocol.sessions import ChatContext, make_session -from 
..session.prepare_session_data import prepare_session_data -from ..utils import ( - cozo_query, - fix_uuid_if_present, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ChatContext, - one=True, - transform=lambda d: { - **d, - "session": make_session( - agents=[a["id"] for a in d["agents"]], - users=[u["id"] for u in d["users"]], - **d["session"], - ), - "toolsets": [ - { - **ts, - "tools": [ - { - tool["type"]: tool.pop("spec"), - **tool, - } - for tool in map(fix_uuid_if_present, ts["tools"]) - ], - } - for ts in d["toolsets"] - ], - }, -) -@cozo_query -@beartype -def prepare_chat_context( - *, - developer_id: UUID, - session_id: UUID, -) -> tuple[list[str], dict]: - """ - Executes a complex query to retrieve memory context based on session ID. 
- """ - - [*_, session_data_query], sd_vars = prepare_session_data.__wrapped__( - developer_id=developer_id, session_id=session_id - ) - - session_data_fields = ("session", "agents", "users") - - session_data_query += """ - :create _session_data_json { - agents: [Json], - users: [Json], - session: Json, - } - """ - - toolsets_query = """ - input[session_id] <- [[to_uuid($session_id)]] - - tools_by_agent[agent_id, collect(tool)] := - input[session_id], - *session_lookup{ - session_id, - participant_id: agent_id, - participant_type: "agent", - }, - - *tools { agent_id, tool_id, name, type, spec, description, updated_at, created_at }, - tool = { - "id": tool_id, - "name": name, - "type": type, - "spec": spec, - "description": description, - "updated_at": updated_at, - "created_at": created_at, - } - - agent_toolsets[collect(toolset)] := - tools_by_agent[agent_id, tools], - toolset = { - "agent_id": agent_id, - "tools": tools, - } - - ?[toolsets] := - agent_toolsets[toolsets] - - :create _toolsets_json { - toolsets: [Json], - } - """ - - combine_query = f""" - ?[{', '.join(session_data_fields)}, toolsets] := - *_session_data_json {{ {', '.join(session_data_fields)} }}, - *_toolsets_json {{ toolsets }} - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - session_data_query, - toolsets_query, - combine_query, - ] - - return ( - queries, - { - "session_id": str(session_id), - **sd_vars, - }, - ) diff --git a/agents-api/agents_api/models/chat/set_cached_response.py b/agents-api/agents_api/models/chat/set_cached_response.py deleted file mode 100644 index 8625f3f1b..000000000 --- a/agents-api/agents_api/models/chat/set_cached_response.py +++ /dev/null @@ -1,19 +0,0 @@ -from beartype import beartype - -from ..utils import cozo_query - - -@cozo_query -@beartype -def set_cached_response(key: str, value: dict) -> tuple[str, dict]: - query = """ - ?[key, 
value] <- [[$key, $value]] - - :insert session_cache { - key => value - } - - :returning - """ - - return (query, {"key": key, "value": value}) diff --git a/agents-api/agents_api/models/developer/get_developer.py b/agents-api/agents_api/models/developer/get_developer.py deleted file mode 100644 index 0ae5421aa..000000000 --- a/agents-api/agents_api/models/developer/get_developer.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Module for retrieving document snippets from the CozoDB based on document IDs.""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...common.protocol.developers import Developer -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions({QueryException: partialclass(HTTPException, status_code=401)}) -@cozo_query -@beartype -def verify_developer( - *, - developer_id: UUID, -) -> tuple[str, dict]: - return (verify_developer_id_query(developer_id), {}) - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=403), - ValidationError: partialclass(HTTPException, status_code=500), - } -) -@wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) -@cozo_query -@beartype -def get_developer( - *, - developer_id: UUID, -) -> tuple[str, dict]: - developer_id = str(developer_id) - - query = """ - input[developer_id] <- [[to_uuid($developer_id)]] - ?[ - developer_id, - email, - active, - tags, - settings, - created_at, - updated_at, - ] := - input[developer_id], - *developers { - developer_id, - email, - active, - tags, - settings, - created_at, - updated_at, - } - - :limit 1 - """ - - return (query, {"developer_id": developer_id}) diff --git a/agents-api/agents_api/models/docs/create_doc.py 
b/agents-api/agents_api/models/docs/create_doc.py deleted file mode 100644 index 3b9c8c9f7..000000000 --- a/agents-api/agents_api/models/docs/create_doc.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import Any, Literal, TypeVar -from uuid import UUID, uuid4 - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import CreateDocRequest, Doc -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Doc, - one=True, - transform=lambda d: { - "id": UUID(d["doc_id"]), - **d, - }, -) -@cozo_query -@increase_counter("create_doc") -@beartype -def create_doc( - *, - developer_id: UUID, - owner_type: Literal["user", "agent"], - owner_id: UUID, - doc_id: UUID | None = None, - data: CreateDocRequest, -) -> tuple[list[str], dict]: - """ - Constructs and executes a datalog query to create a new document and its associated snippets in the 'cozodb' database. - - Parameters: - owner_type (Literal["user", "agent"]): The type of the owner of the document. - owner_id (UUID): The UUID of the document owner. - doc_id (UUID): The UUID of the document to be created. - data (CreateDocRequest): The content of the document. 
- """ - - doc_id = str(doc_id or uuid4()) - owner_id = str(owner_id) - - if isinstance(data.content, str): - data.content = [data.content] - - data.metadata = data.metadata or {} - - doc_data = data.model_dump() - doc_data.pop("embed_instruction", None) - content = doc_data.pop("content") - - doc_data["owner_type"] = owner_type - doc_data["owner_id"] = owner_id - doc_data["doc_id"] = doc_id - - doc_cols, doc_rows = cozo_process_mutate_data(doc_data) - - snippet_cols, snippet_rows = "", [] - - # Process each content snippet and prepare data for the datalog query. - for snippet_idx, snippet in enumerate(content): - snippet_cols, new_snippet_rows = cozo_process_mutate_data( - dict( - doc_id=doc_id, - index=snippet_idx, - content=snippet, - ) - ) - - snippet_rows += new_snippet_rows - - create_snippets_query = f""" - ?[{snippet_cols}] <- $snippet_rows - - :create _snippets {{ {snippet_cols} }} - }} {{ - ?[{snippet_cols}] <- $snippet_rows - :insert snippets {{ {snippet_cols} }} - :returning - """ - - # Construct the datalog query for creating the document and its snippets. - create_doc_query = f""" - ?[{doc_cols}] <- $doc_rows - - :create _docs {{ {doc_cols} }} - }} {{ - ?[{doc_cols}] <- $doc_rows - :insert docs {{ {doc_cols} }} - :returning - }} {{ - snippet_rows[collect(content)] := - *_snippets {{ - content - }} - - ?[{doc_cols}, content, created_at] := - *_docs {{ {doc_cols} }}, - snippet_rows[content], - created_at = now() - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, f"{owner_type}s", **{f"{owner_type}_id": owner_id} - ), - create_snippets_query, - create_doc_query, - ] - - # Execute the constructed datalog query and return the results as a DataFrame. 
- return ( - queries, - { - "doc_rows": doc_rows, - "snippet_rows": snippet_rows, - }, - ) diff --git a/agents-api/agents_api/models/docs/delete_doc.py b/agents-api/agents_api/models/docs/delete_doc.py deleted file mode 100644 index c02705756..000000000 --- a/agents-api/agents_api/models/docs/delete_doc.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceDeletedResponse -from ...common.utils.datetime import utcnow -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceDeletedResponse, - one=True, - transform=lambda d: { - "id": UUID(d.pop("doc_id")), - "deleted_at": utcnow(), - "jobs": [], - }, - _kind="deleted", -) -@cozo_query -@beartype -def delete_doc( - *, - developer_id: UUID, - owner_id: UUID, - owner_type: str, - doc_id: UUID, -) -> tuple[list[str], dict]: - """Constructs and returns a datalog query for deleting documents and associated information snippets. - - This function targets the 'cozodb' database, allowing for the removal of documents and their related information snippets based on the provided document ID and owner (user or agent). - - Parameters: - doc_id (UUID): The UUID of the document to be deleted. - client (CozoClient): An instance of the CozoClient to execute the query. - - Returns: - pd.DataFrame: The result of the executed datalog query. 
- """ - # Convert UUID parameters to string format for use in the datalog query - doc_id = str(doc_id) - owner_id = str(owner_id) - - # The following query is divided into two main parts: - # 1. Deleting information snippets associated with the document - # 2. Deleting the document itself - delete_snippets_query = """ - # This section constructs the subquery for identifying and deleting all information snippets associated with the given document ID. - # Delete snippets - input[doc_id] <- [[to_uuid($doc_id)]] - ?[doc_id, index] := - input[doc_id], - *snippets { - doc_id, - index, - } - - :delete snippets { - doc_id, - index - } - """ - - delete_doc_query = """ - # Delete the docs - ?[doc_id, owner_type, owner_id] <- [[ to_uuid($doc_id), $owner_type, to_uuid($owner_id) ]] - - :delete docs { doc_id, owner_type, owner_id } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, f"{owner_type}s", **{f"{owner_type}_id": owner_id} - ), - delete_snippets_query, - delete_doc_query, - ] - - return (queries, {"doc_id": doc_id, "owner_type": owner_type, "owner_id": owner_id}) diff --git a/agents-api/agents_api/models/docs/embed_snippets.py b/agents-api/agents_api/models/docs/embed_snippets.py deleted file mode 100644 index 8d8ae1e62..000000000 --- a/agents-api/agents_api/models/docs/embed_snippets.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Module for embedding documents in the cozodb database. 
Contains functions to update document embeddings.""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceUpdatedResponse -from ...common.utils.cozo import cozo_process_mutate_data -from ...common.utils.datetime import utcnow -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: {"id": d["doc_id"], "updated_at": utcnow(), "jobs": []}, - _kind="inserted", -) -@cozo_query -@beartype -def embed_snippets( - *, - developer_id: UUID, - doc_id: UUID, - snippet_indices: list[int] | tuple[int, ...], - embeddings: list[list[float]], - embedding_size: int = 1024, -) -> tuple[list[str], dict]: - """Embeds document snippets in the cozodb database. - - Parameters: - doc_id (UUID): The unique identifier for the document. - snippet_indices (list[int]): Indices of the snippets in the document. - embeddings (list[list[float]]): Embedding vectors for the snippets. - """ - - doc_id = str(doc_id) - - # Ensure the number of snippet indices matches the number of embeddings. - assert len(snippet_indices) == len(embeddings) - assert all(len(embedding) == embedding_size for embedding in embeddings) - assert min(snippet_indices) >= 0 - - # Ensure all embeddings are non-zero. - assert all(sum(embedding) for embedding in embeddings) - - # Create a list of records to update the document snippet embeddings in the database. 
- records = [ - {"doc_id": doc_id, "index": snippet_idx, "embedding": embedding} - for snippet_idx, embedding in zip(snippet_indices, embeddings) - ] - - cols, vals = cozo_process_mutate_data(records) - - # Ensure that index is present in the records. - check_indices_query = f""" - ?[index] := - *snippets {{ - doc_id: $doc_id, - index, - }}, - index > {max(snippet_indices)} - - :limit 1 - :assert none - """ - - # Define the datalog query for updating document snippet embeddings in the database. - embed_query = f""" - ?[{cols}] <- $vals - - :update snippets {{ {cols} }} - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - check_indices_query, - embed_query, - ] - - return (queries, {"vals": vals, "doc_id": doc_id}) diff --git a/agents-api/agents_api/models/docs/get_doc.py b/agents-api/agents_api/models/docs/get_doc.py deleted file mode 100644 index d47cc80a8..000000000 --- a/agents-api/agents_api/models/docs/get_doc.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Module for retrieving document snippets from the CozoDB based on document IDs.""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Doc -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - lambda e: isinstance(e, AssertionError) - and "Expected one result" in repr(e): partialclass( - HTTPException, status_code=404 - ), - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Doc, - one=True, - transform=lambda d: { - "content": [s[1] for s in sorted(d["snippet_data"], key=lambda x: 
x[0])], - "embeddings": [ - s[2] - for s in sorted(d["snippet_data"], key=lambda x: x[0]) - if s[2] is not None - ], - **d, - }, -) -@cozo_query -@beartype -def get_doc( - *, - developer_id: UUID, - doc_id: UUID, -) -> tuple[list[str], dict]: - """ - Retrieves snippets of documents by their ID from the CozoDB. - - Parameters: - doc_id (UUID): The unique identifier of the document. - client (CozoClient, optional): The CozoDB client instance. Defaults to a pre-configured client. - - Returns: - pd.DataFrame: A DataFrame containing the document snippets and related metadata. - """ - - doc_id = str(doc_id) - - get_query = """ - input[doc_id] <- [[to_uuid($doc_id)]] - snippets[collect(snippet_data)] := - input[doc_id], - *snippets { - doc_id, - index, - content, - embedding, - }, - snippet_data = [index, content, embedding] - - ?[ - id, - title, - snippet_data, - created_at, - metadata, - ] := input[id], - *docs { - doc_id: id, - title, - created_at, - metadata, - }, - snippets[snippet_data] - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - get_query, - ] - - return (queries, {"doc_id": doc_id}) diff --git a/agents-api/agents_api/models/docs/list_docs.py b/agents-api/agents_api/models/docs/list_docs.py deleted file mode 100644 index dd389d58c..000000000 --- a/agents-api/agents_api/models/docs/list_docs.py +++ /dev/null @@ -1,141 +0,0 @@ -"""This module contains functions for querying document-related data from the 'cozodb' database using datalog queries.""" - -import json -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Doc -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") 
- - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Doc, - transform=lambda d: { - "content": [s[1] for s in sorted(d["snippet_data"], key=lambda x: x[0])], - "embeddings": [ - s[2] - for s in sorted(d["snippet_data"], key=lambda x: x[0]) - if s[2] is not None - ], - **d, - }, -) -@cozo_query -@beartype -def list_docs( - *, - developer_id: UUID, - owner_type: Literal["user", "agent"], - owner_id: UUID, - limit: int = 100, - offset: int = 0, - sort_by: Literal["created_at"] = "created_at", - direction: Literal["asc", "desc"] = "desc", - metadata_filter: dict[str, Any] = {}, - include_without_embeddings: bool = False, -) -> tuple[list[str], dict]: - """ - Constructs and returns a datalog query for listing documents and their associated information snippets. - - Parameters: - developer_id (UUID): The unique identifier of the developer associated with the documents. - owner_id (UUID): The unique identifier of the owner (user or agent) associated with the documents. - owner_type (Literal["user", "agent"]): The type of owner associated with the documents. - limit (int): The maximum number of documents to return. - offset (int): The number of documents to skip before returning the results. - sort_by (Literal["created_at"]): The field to sort the documents by. - direction (Literal["asc", "desc"]): The direction to sort the documents in. - metadata_filter (dict): A dictionary of metadata filters to apply to the documents. - include_without_embeddings (bool): Whether to include documents without embeddings in the results. - - Returns: - Doc[] - """ - - # Transforms the metadata_filter dictionary into a string representation for the datalog query. 
- metadata_filter_str = ", ".join( - [ - f"metadata->{json.dumps(k)} == {json.dumps(v)}" - for k, v in metadata_filter.items() - ] - ) - - owner_id = str(owner_id) - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - get_query = f""" - snippets[id, collect(snippet_data)] := - *snippets {{ - doc_id: id, - index, - content, - embedding, - }}, - {"" if include_without_embeddings else "not is_null(embedding),"} - snippet_data = [index, content, embedding] - - ?[ - owner_type, - id, - title, - snippet_data, - created_at, - metadata, - ] := - owner_type = $owner_type, - owner_id = to_uuid($owner_id), - *docs {{ - owner_type, - owner_id, - doc_id: id, - title, - created_at, - metadata, - }}, - snippets[id, snippet_data], - {metadata_filter_str} - - :limit $limit - :offset $offset - :sort {sort} - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, f"{owner_type}s", **{f"{owner_type}_id": owner_id} - ), - get_query, - ] - - return ( - queries, - { - "owner_id": owner_id, - "owner_type": owner_type, - "limit": limit, - "offset": offset, - }, - ) diff --git a/agents-api/agents_api/models/docs/search_docs_by_embedding.py b/agents-api/agents_api/models/docs/search_docs_by_embedding.py deleted file mode 100644 index 992e12f9d..000000000 --- a/agents-api/agents_api/models/docs/search_docs_by_embedding.py +++ /dev/null @@ -1,369 +0,0 @@ -"""This module contains functions for searching documents in the CozoDB based on embedding queries.""" - -import json -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import DocReference -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", 
bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - DocReference, - transform=lambda d: { - "owner": { - "id": d["owner_id"], - "role": d["owner_type"], - }, - "metadata": d.get("metadata", {}), - **d, - }, -) -@cozo_query -@beartype -def search_docs_by_embedding( - *, - developer_id: UUID, - owners: list[tuple[Literal["user", "agent"], UUID]], - query_embedding: list[float], - k: int = 3, - confidence: float = 0.5, - ef: int = 50, - embedding_size: int = 1024, - ann_threshold: int = 1_000_000, - metadata_filter: dict[str, Any] = {}, -) -> tuple[str, dict]: - """ - Searches for document snippets in CozoDB by embedding query. - - Parameters: - owner_type (Literal["user", "agent"]): The type of the owner of the documents. - owner_id (UUID): The unique identifier of the owner. - query_embedding (list[float]): The embedding vector of the query. - k (int, optional): The number of nearest neighbors to retrieve. Defaults to 3. - confidence (float, optional): The confidence threshold for filtering results. Defaults to 0.8. - mmr_lambda (float, optional): The lambda parameter for MMR. Defaults to 0.25. - embedding_size (int): Embedding vector length - metadata_filter (dict[str, Any]): Dictionary to filter agents based on metadata. 
- """ - - assert len(query_embedding) == embedding_size - assert sum(query_embedding) - - metadata_filter_str = ", ".join( - [ - f"metadata->{json.dumps(k)} == {json.dumps(v)}" - for k, v in metadata_filter.items() - ] - ) - - owners: list[list[str]] = [ - [owner_type, str(owner_id)] for owner_type, owner_id in owners - ] - - # Calculate the search radius based on confidence level - radius: float = 1.0 - confidence - - determine_knn_ann_query = f""" - owners[owner_type, owner_id] <- $owners - snippet_counter[count(item)] := - owners[owner_type, owner_id_str], - owner_id = to_uuid(owner_id_str), - *docs {{ - owner_type, - owner_id, - doc_id: item, - metadata, - }} - {', ' + metadata_filter_str if metadata_filter_str.strip() else ''} - - ?[use_ann] := - snippet_counter[count], - count > {ann_threshold}, - use_ann = true - - :limit 1 - :create _determine_knn_ann {{ - use_ann - }} - """ - - # Construct the datalog query for searching document snippets - search_query = f""" - # %debug _determine_knn_ann - %if {{ - ?[use_ann] := *_determine_knn_ann{{ use_ann }} - }} - - %then {{ - owners[owner_type, owner_id] <- $owners - input[ - owner_type, - owner_id, - query_embedding, - ] := - owners[owner_type, owner_id_str], - owner_id = to_uuid(owner_id_str), - query_embedding = vec($query_embedding) - - # Search for documents by owner - ?[ - doc_id, - index, - title, - content, - distance, - embedding, - metadata, - ] := - # Get input values - input[owner_type, owner_id, query], - - # Restrict the search to all documents that match the owner - *docs {{ - owner_type, - owner_id, - doc_id, - title, - metadata, - }}, - - # Search for snippets in the embedding space - ~snippets:embedding_space {{ - doc_id, - index, - content - | - query: query, - k: {k}, - ef: {ef}, - radius: {radius}, - bind_distance: distance, - bind_vector: embedding, - }} - - :sort distance - :limit {k} - - :create _search_result {{ - doc_id, - index, - title, - content, - distance, - embedding, - metadata, - }} 
- }} - - %else {{ - owners[owner_type, owner_id] <- $owners - input[ - owner_type, - owner_id, - query_embedding, - ] := - owners[owner_type, owner_id_str], - owner_id = to_uuid(owner_id_str), - query_embedding = vec($query_embedding) - - # Search for documents by owner - ?[ - doc_id, - index, - title, - content, - distance, - embedding, - metadata, - ] := - # Get input values - input[owner_type, owner_id, query], - - # Restrict the search to all documents that match the owner - *docs {{ - owner_type, - owner_id, - doc_id, - title, - metadata, - }}, - - # Search for snippets in the embedding space - *snippets {{ - doc_id, - index, - content, - embedding, - }}, - !is_null(embedding), - distance = cos_dist(query, embedding), - distance <= {radius} - - :sort distance - :limit {k} - - :create _search_result {{ - doc_id, - index, - title, - content, - distance, - embedding, - metadata, - }} - }} - %end - """ - - normal_interim_query = f""" - owners[owner_type, owner_id] <- $owners - - ?[ - owner_type, - owner_id, - doc_id, - snippet_data, - distance, - title, - embedding, - metadata, - ] := - owners[owner_type, owner_id_str], - owner_id = to_uuid(owner_id_str), - *_search_result{{ doc_id, index, title, content, distance, embedding, metadata }}, - snippet_data = [index, content] - - :sort distance - :limit {k} - - :create _interim {{ - owner_type, - owner_id, - doc_id, - snippet_data, - distance, - title, - embedding, - metadata, - }} - """ - - collect_query = """ - n[ - doc_id, - owner_type, - owner_id, - unique(snippet_data), - distance, - title, - embedding, - metadata, - ] := - *_interim { - owner_type, - owner_id, - doc_id, - snippet_data, - distance, - title, - embedding, - metadata, - } - - m[ - doc_id, - owner_type, - owner_id, - snippet, - distance, - title, - metadata, - ] := - n[ - doc_id, - owner_type, - owner_id, - snippet_data, - distance, - title, - embedding, - metadata, - ], - snippet = { - "index": snippet_datum->0, - "content": snippet_datum->1, - 
"embedding": embedding, - }, - snippet_datum in snippet_data - - ?[ - id, - owner_type, - owner_id, - snippet, - distance, - title, - metadata, - ] := m[ - id, - owner_type, - owner_id, - snippet, - distance, - title, - metadata, - ] - - :sort distance - """ - - verify_query = "}\n\n{".join( - [ - verify_developer_id_query(developer_id), - *[ - verify_developer_owns_resource_query( - developer_id, f"{owner_type}s", **{f"{owner_type}_id": owner_id} - ) - for owner_type, owner_id in owners - ], - ] - ) - - query = f""" - {{ {verify_query} }} - {{ {determine_knn_ann_query} }} - {search_query} - {{ {normal_interim_query} }} - {{ {collect_query} }} - """ - - return ( - query, - { - "owners": owners, - "query_embedding": query_embedding, - }, - ) diff --git a/agents-api/agents_api/models/docs/search_docs_by_text.py b/agents-api/agents_api/models/docs/search_docs_by_text.py deleted file mode 100644 index ac1a9f54f..000000000 --- a/agents-api/agents_api/models/docs/search_docs_by_text.py +++ /dev/null @@ -1,206 +0,0 @@ -"""This module contains functions for searching documents in the CozoDB based on embedding queries.""" - -import json -import re -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import DocReference -from ...common.nlp import paragraph_to_custom_queries -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - DocReference, - transform=lambda d: { - "owner": { - 
"id": d["owner_id"], - "role": d["owner_type"], - }, - "metadata": d.get("metadata", {}), - **d, - }, -) -@cozo_query -@beartype -def search_docs_by_text( - *, - developer_id: UUID, - owners: list[tuple[Literal["user", "agent"], UUID]], - query: str, - k: int = 3, - metadata_filter: dict[str, Any] = {}, -) -> tuple[list[str], dict]: - """ - Searches for document snippets in CozoDB by embedding query. - - Parameters: - owners (list[tuple[Literal["user", "agent"], UUID]]): The type of the owner of the documents. - query (str): The query string. - k (int, optional): The number of nearest neighbors to retrieve. Defaults to 3. - metadata_filter (dict[str, Any]): Dictionary to filter agents based on metadata. - """ - metadata_filter_str = ", ".join( - [ - f"metadata->{json.dumps(k)} == {json.dumps(v)}" - for k, v in metadata_filter.items() - ] - ) - - owners: list[list[str]] = [ - [owner_type, str(owner_id)] for owner_type, owner_id in owners - ] - - # See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts - fts_queries = paragraph_to_custom_queries(query) or [ - re.sub(r"[^\w\s\-_]+", "", query) - ] - - # Construct the datalog query for searching document snippets - search_query = f""" - owners[owner_type, owner_id] <- $owners - input[ - owner_type, - owner_id, - ] := - owners[owner_type, owner_id_str], - owner_id = to_uuid(owner_id_str) - - candidate[doc_id] := - input[owner_type, owner_id], - *docs {{ - owner_type, - owner_id, - doc_id, - metadata, - }} - {', ' + metadata_filter_str if metadata_filter_str.strip() else ''} - - # search_result[ - # doc_id, - # snippet_data, - # distance, - # ] := - # candidate[doc_id], - # ~snippets:lsh {{ - # doc_id, - # index, - # content - # | - # query: $query, - # k: {k}, - # }}, - # distance = 10000000, # Very large distance to depict no valid distance - # snippet_data = [index, content] - - search_result[ - doc_id, - snippet_data, - distance, - ] := - candidate[doc_id], - ~snippets:fts {{ - doc_id, - index, - 
content - | - query: query, - k: {k}, - score_kind: 'tf_idf', - bind_score: score, - }}, - query in $fts_queries, - distance = -score, - snippet_data = [index, content] - - m[ - doc_id, - snippet, - distance, - title, - owner_type, - owner_id, - metadata, - ] := - candidate[doc_id], - *docs {{ - owner_type, - owner_id, - doc_id, - title, - metadata, - }}, - search_result [ - doc_id, - snippet_data, - distance, - ], - snippet = {{ - "index": snippet_data->0, - "content": snippet_data->1, - }} - - - ?[ - id, - owner_type, - owner_id, - snippet, - distance, - title, - metadata, - ] := - candidate[id], - input[owner_type, owner_id], - m[ - id, - snippet, - distance, - title, - owner_type, - owner_id, - metadata, - ] - - # Sort the results by distance to find the closest matches - :sort distance - :limit {k} - """ - - queries = [ - verify_developer_id_query(developer_id), - *[ - verify_developer_owns_resource_query( - developer_id, f"{owner_type}s", **{f"{owner_type}_id": owner_id} - ) - for owner_type, owner_id in owners - ], - search_query, - ] - - return ( - queries, - {"owners": owners, "query": query, "fts_queries": fts_queries}, - ) diff --git a/agents-api/agents_api/models/docs/search_docs_hybrid.py b/agents-api/agents_api/models/docs/search_docs_hybrid.py deleted file mode 100644 index c43f8c97b..000000000 --- a/agents-api/agents_api/models/docs/search_docs_hybrid.py +++ /dev/null @@ -1,138 +0,0 @@ -"""This module contains functions for searching documents in the CozoDB based on embedding queries.""" - -from statistics import mean, stdev -from typing import Any, Literal -from uuid import UUID - -from beartype import beartype - -from ...autogen.openapi_model import DocReference -from ..utils import run_concurrently -from .search_docs_by_embedding import search_docs_by_embedding -from .search_docs_by_text import search_docs_by_text - - -# Distribution based score normalization -# 
https://medium.com/plain-simple-software/distribution-based-score-fusion-dbsf-a-new-approach-to-vector-search-ranking-f87c37488b18 -def dbsf_normalize(scores: list[float]) -> list[float]: - """ - Scores scaled using minmax scaler with our custom feature range - (extremes indicated as 3 standard deviations from the mean) - """ - if len(scores) < 2: - return scores - - sd = stdev(scores) - if sd == 0: - return scores - - m = mean(scores) - m3d = 3 * sd + m - m_3d = m - 3 * sd - - return [(s - m_3d) / (m3d - m_3d) for s in scores] - - -def dbsf_fuse( - text_results: list[DocReference], - embedding_results: list[DocReference], - alpha: float = 0.7, # Weight of the embedding search results (this is a good default) -) -> list[DocReference]: - """ - Weighted reciprocal-rank fusion of text and embedding search results - """ - all_docs = {doc.id: doc for doc in text_results + embedding_results} - - text_scores: dict[UUID, float] = { - doc.id: -(doc.distance or 0.0) for doc in text_results - } - - # Because these are cosine distances, we need to invert them - embedding_scores: dict[UUID, float] = { - doc.id: 1.0 - doc.distance for doc in embedding_results - } - - # normalize the scores - text_scores_normalized = dbsf_normalize(list(text_scores.values())) - text_scores = { - doc_id: score - for doc_id, score in zip(text_scores.keys(), text_scores_normalized) - } - - embedding_scores_normalized = dbsf_normalize(list(embedding_scores.values())) - embedding_scores = { - doc_id: score - for doc_id, score in zip(embedding_scores.keys(), embedding_scores_normalized) - } - - # Combine the scores - text_weight: float = 1 - alpha - embedding_weight: float = alpha - - combined_scores = [] - - for id in all_docs.keys(): - text_score = text_weight * text_scores.get(id, 0) - embedding_score = embedding_weight * embedding_scores.get(id, 0) - - combined_scores.append((id, text_score + embedding_score)) - - # Sort by the combined score - combined_scores = sorted(combined_scores, key=lambda 
x: x[1], reverse=True) - - # Rank the results - ranked_results = [] - for id, score in combined_scores: - doc = all_docs[id].model_copy() - doc.distance = 1.0 - score - ranked_results.append(doc) - - return ranked_results - - -@beartype -def search_docs_hybrid( - *, - developer_id: UUID, - owners: list[tuple[Literal["user", "agent"], UUID]], - query: str, - query_embedding: list[float], - k: int = 3, - alpha: float = 0.7, # Weight of the embedding search results (this is a good default) - embed_search_options: dict = {}, - text_search_options: dict = {}, - metadata_filter: dict[str, Any] = {}, -) -> list[DocReference]: - # Parallelize the text and embedding search queries - fns = [ - search_docs_by_text if bool(query.strip()) else lambda: [], - search_docs_by_embedding if bool(sum(query_embedding)) else lambda: [], - ] - - kwargs_list = [ - { - "developer_id": developer_id, - "owners": owners, - "query": query, - "k": k, - "metadata_filter": metadata_filter, - **text_search_options, - } - if bool(query.strip()) - else {}, - { - "developer_id": developer_id, - "owners": owners, - "query_embedding": query_embedding, - "k": k, - "metadata_filter": metadata_filter, - **embed_search_options, - } - if bool(sum(query_embedding)) - else {}, - ] - - results = run_concurrently(fns, kwargs_list=kwargs_list) - text_results, embedding_results = results - - return dbsf_fuse(text_results, embedding_results, alpha)[:k] diff --git a/agents-api/agents_api/models/entry/__init__.py b/agents-api/agents_api/models/entry/__init__.py deleted file mode 100644 index 32231c364..000000000 --- a/agents-api/agents_api/models/entry/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -The `entry` module is responsible for managing entries related to agents' activities and interactions within the 'cozodb' database. It provides a comprehensive set of functionalities for adding, deleting, summarizing, and retrieving entries, as well as processing them to retrieve memory context based on embeddings. 
- -Key functionalities include: -- Adding entries to the database. -- Deleting entries from the database based on session IDs. -- Summarizing entries and managing their relationships. -- Retrieving entries from the database, including top-level entries and entries based on session IDs. -- Processing entries to retrieve memory context based on embeddings. - -The module utilizes pandas DataFrames for handling query results and integrates with the CozoClient for database operations, ensuring efficient and effective management of entries. -""" - -# ruff: noqa: F401, F403, F405 - -from .create_entries import create_entries -from .delete_entries import delete_entries -from .get_history import get_history -from .list_entries import list_entries diff --git a/agents-api/agents_api/models/entry/create_entries.py b/agents-api/agents_api/models/entry/create_entries.py deleted file mode 100644 index a8671a6dd..000000000 --- a/agents-api/agents_api/models/entry/create_entries.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID, uuid4 - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import CreateEntryRequest, Entry, Relation -from ...common.utils.cozo import cozo_process_mutate_data -from ...common.utils.datetime import utcnow -from ...common.utils.messages import content_to_json -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - mark_session_updated_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( 
- Entry, - transform=lambda d: { - "id": UUID(d.pop("entry_id")), - **d, - }, - _kind="inserted", -) -@cozo_query -@increase_counter("create_entries") -@beartype -def create_entries( - *, - developer_id: UUID, - session_id: UUID, - data: list[CreateEntryRequest], - mark_session_as_updated: bool = True, -) -> tuple[list[str], dict]: - developer_id = str(developer_id) - session_id = str(session_id) - - data_dicts = [item.model_dump(mode="json") for item in data] - - for item in data_dicts: - item["content"] = content_to_json(item["content"] or []) - item["session_id"] = session_id - item["entry_id"] = item.pop("id", None) or str(uuid4()) - item["created_at"] = (item.get("created_at") or utcnow()).timestamp() - - cols, rows = cozo_process_mutate_data(data_dicts) - - # Construct a datalog query to insert the processed entries into the 'cozodb' database. - # Refer to the schema for the 'entries' relation in the README.md for column names and types. - create_query = f""" - ?[{cols}] <- $rows - - :insert entries {{ - {cols} - }} - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - mark_session_updated_query(developer_id, session_id) - if mark_session_as_updated - else "", - create_query, - ] - - return (queries, {"rows": rows}) - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(Relation, _kind="inserted") -@cozo_query -@beartype -def add_entry_relations( - *, - developer_id: UUID, - data: list[Relation], -) -> tuple[list[str], dict]: - developer_id = str(developer_id) - - data_dicts = [item.model_dump(mode="json") for item in data] - cols, rows = cozo_process_mutate_data(data_dicts) - - create_query = f""" - ?[{cols}] <- $rows - - :insert relations {{ - {cols} - }} 
- - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - create_query, - ] - - return (queries, {"rows": rows}) diff --git a/agents-api/agents_api/models/entry/delete_entries.py b/agents-api/agents_api/models/entry/delete_entries.py deleted file mode 100644 index c98b6c7d2..000000000 --- a/agents-api/agents_api/models/entry/delete_entries.py +++ /dev/null @@ -1,153 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceDeletedResponse -from ...common.utils.datetime import utcnow -from ..utils import ( - cozo_query, - mark_session_updated_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - IndexError: partialclass(HTTPException, status_code=404), - } -) -@wrap_in_class( - ResourceDeletedResponse, - one=True, - transform=lambda d: { - "id": UUID(d.pop("session_id")), # Only return session cleared - "deleted_at": utcnow(), - "jobs": [], - }, - _kind="deleted", -) -@cozo_query -@beartype -def delete_entries_for_session( - *, developer_id: UUID, session_id: UUID, mark_session_as_updated: bool = True -) -> tuple[list[str], dict]: - """ - Constructs and returns a datalog query for deleting entries associated with a given session ID from the 'cozodb' database. - - Parameters: - session_id (UUID): The unique identifier of the session whose entries are to be deleted. 
- """ - - delete_query = """ - input[session_id] <- [[ - to_uuid($session_id), - ]] - - ?[ - session_id, - entry_id, - source, - role, - ] := input[session_id], - *entries{ - session_id, - entry_id, - source, - role, - } - - :delete entries { - session_id, - entry_id, - source, - role, - } - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - mark_session_updated_query(developer_id, session_id) - if mark_session_as_updated - else "", - delete_query, - ] - - return (queries, {"session_id": str(session_id)}) - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceDeletedResponse, - transform=lambda d: { - "id": UUID(d.pop("entry_id")), - "deleted_at": utcnow(), - "jobs": [], - }, -) -@cozo_query -@beartype -def delete_entries( - *, developer_id: UUID, session_id: UUID, entry_ids: list[UUID] -) -> tuple[list[str], dict]: - delete_query = """ - input[entry_id_str] <- $entry_ids - - ?[ - entry_id, - session_id, - source, - role, - ] := - input[entry_id_str], - entry_id = to_uuid(entry_id_str), - *entries { - session_id, - entry_id, - source, - role, - } - - :delete entries { - session_id, - entry_id, - source, - role, - } - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - delete_query, - ] - - return (queries, {"entry_ids": [[str(id)] for id in entry_ids]}) diff --git a/agents-api/agents_api/models/entry/get_history.py b/agents-api/agents_api/models/entry/get_history.py deleted file mode 100644 index 4be23804e..000000000 --- a/agents-api/agents_api/models/entry/get_history.py +++ /dev/null @@ -1,150 +0,0 @@ -from typing import Any, 
TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import History -from ...common.utils.cozo import uuid_int_list_to_uuid4 as fix_uuid -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - History, - one=True, - transform=lambda d: { - "relations": [ - { - # This is needed because cozo has a bug: - # https://github.com/cozodb/cozo/issues/269 - "head": fix_uuid(r["head"]), - "relation": r["relation"], - "tail": fix_uuid(r["tail"]), - } - for r in d.pop("relations") - ], - # TODO: Remove this once we sort the entries in the cozo query - # Sort entries by created_at - "entries": sorted(d.pop("entries"), key=lambda entry: entry["created_at"]), - **d, - }, -) -@cozo_query -@beartype -def get_history( - *, - developer_id: UUID, - session_id: UUID, - allowed_sources: list[str] = ["api_request", "api_response"], -) -> tuple[list[str], dict]: - developer_id = str(developer_id) - session_id = str(session_id) - - history_query = """ - session_entries[collect(entry)] := - *entries { - session_id, - entry_id, - role, - name, - content, - source, - token_count, - tokenizer, - created_at, - tool_calls, - timestamp, - tool_call_id, - }, - source in $allowed_sources, - session_id = to_uuid($session_id), - entry = { - "session_id": session_id, - "id": entry_id, - "role": role, - "name": name, - "content": content, - "source": source, - "token_count": token_count, - "tokenizer": tokenizer, - "created_at": created_at, - 
"timestamp": timestamp, - "tool_calls": tool_calls, - "tool_call_id": tool_call_id, - } - - session_relations[unique(item)] := - session_id = to_uuid($session_id), - *entries { - session_id, - entry_id: head - }, - - *relations { - head, - relation, - tail - }, - - item = { - "head": head, - "relation": relation, - "tail": tail - } - - session_relations[unique(item)] := - session_id = to_uuid($session_id), - *entries { - session_id, - entry_id: tail - }, - - *relations { - head, - relation, - tail - }, - - item = { - "head": head, - "relation": relation, - "tail": tail - } - - ?[entries, relations, session_id, created_at] := - session_entries[entries], - session_relations[relations], - session_id = to_uuid($session_id), - created_at = now() - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - history_query, - ] - - return (queries, {"session_id": session_id, "allowed_sources": allowed_sources}) diff --git a/agents-api/agents_api/models/entry/list_entries.py b/agents-api/agents_api/models/entry/list_entries.py deleted file mode 100644 index d3081a9b0..000000000 --- a/agents-api/agents_api/models/entry/list_entries.py +++ /dev/null @@ -1,112 +0,0 @@ -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Entry -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) 
-@wrap_in_class(Entry) -@cozo_query -@beartype -def list_entries( - *, - developer_id: UUID, - session_id: UUID, - allowed_sources: list[str] = ["api_request", "api_response"], - limit: int = -1, - offset: int = 0, - sort_by: Literal["created_at", "timestamp"] = "timestamp", - direction: Literal["asc", "desc"] = "asc", - exclude_relations: list[str] = [], -) -> tuple[list[str], dict]: - """ - Constructs and executes a query to retrieve entries from the 'cozodb' database. - """ - - developer_id = str(developer_id) - session_id = str(session_id) - - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - exclude_relations_query = """ - not *relations { - relation, - tail: id, - }, - relation in $exclude_relations, - # !is_in(relation, $exclude_relations), - """ - - list_query = f""" - ?[ - session_id, - id, - role, - name, - content, - source, - token_count, - tokenizer, - created_at, - timestamp, - ] := *entries {{ - session_id, - entry_id: id, - role, - name, - content, - source, - token_count, - tokenizer, - created_at, - timestamp, - }}, - {exclude_relations_query if exclude_relations else ''} - source in $allowed_sources, - session_id = to_uuid($session_id), - - :sort {sort} - """ - - if limit > 0: - list_query += f"\n:limit {limit}" - list_query += f"\n:offset {offset}" - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - list_query, - ] - - return ( - queries, - { - "session_id": session_id, - "allowed_sources": allowed_sources, - "exclude_relations": exclude_relations, - }, - ) diff --git a/agents-api/agents_api/models/execution/__init__.py b/agents-api/agents_api/models/execution/__init__.py deleted file mode 100644 index abd3c7e47..000000000 --- a/agents-api/agents_api/models/execution/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# ruff: noqa: F401, F403, F405 - -from .count_executions import count_executions -from .create_execution import 
create_execution -from .create_execution_transition import ( - create_execution_transition, - create_execution_transition_async, -) -from .get_execution import get_execution -from .get_execution_transition import get_execution_transition -from .list_execution_transitions import list_execution_transitions -from .list_executions import list_executions -from .lookup_temporal_data import lookup_temporal_data -from .prepare_execution_input import prepare_execution_input -from .update_execution import update_execution diff --git a/agents-api/agents_api/models/execution/count_executions.py b/agents-api/agents_api/models/execution/count_executions.py deleted file mode 100644 index d130f0359..000000000 --- a/agents-api/agents_api/models/execution/count_executions.py +++ /dev/null @@ -1,61 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(dict, one=True) -@cozo_query -@beartype -def count_executions( - *, - developer_id: UUID, - task_id: UUID, -) -> tuple[list[str], dict]: - count_query = """ - input[task_id] <- [[to_uuid($task_id)]] - - counter[count(id)] := - input[task_id], - *executions:task_id_execution_id_idx { - task_id, - execution_id: id, - } - - ?[count] := counter[count] - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, - "tasks", - task_id=task_id, - parents=[("agents", 
"agent_id")], - ), - count_query, - ] - - return (queries, {"task_id": str(task_id)}) diff --git a/agents-api/agents_api/models/execution/create_execution.py b/agents-api/agents_api/models/execution/create_execution.py deleted file mode 100644 index 832532d6d..000000000 --- a/agents-api/agents_api/models/execution/create_execution.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import Annotated, Any, TypeVar -from uuid import UUID, uuid4 - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import CreateExecutionRequest, Execution -from ...common.utils.cozo import cozo_process_mutate_data -from ...common.utils.types import dict_like -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) -from .constants import OUTPUT_UNNEST_KEY - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Execution, - one=True, - transform=lambda d: {"id": d["execution_id"], **d}, - _kind="inserted", -) -@cozo_query -@increase_counter("create_execution") -@beartype -def create_execution( - *, - developer_id: UUID, - task_id: UUID, - execution_id: UUID | None = None, - data: Annotated[CreateExecutionRequest | dict, dict_like(CreateExecutionRequest)], -) -> tuple[list[str], dict]: - execution_id = execution_id or uuid4() - - developer_id = str(developer_id) - task_id = str(task_id) - execution_id = str(execution_id) - - if isinstance(data, CreateExecutionRequest): - data.metadata = data.metadata or {} - execution_data = data.model_dump() - else: - data["metadata"] = 
data.get("metadata", {}) - execution_data = data - - if execution_data["output"] is not None and not isinstance( - execution_data["output"], dict - ): - execution_data["output"] = {OUTPUT_UNNEST_KEY: execution_data["output"]} - - columns, values = cozo_process_mutate_data( - { - **execution_data, - "task_id": task_id, - "execution_id": execution_id, - } - ) - - insert_query = f""" - ?[{columns}] <- $values - - :insert executions {{ - {columns} - }} - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, - "tasks", - task_id=task_id, - parents=[("agents", "agent_id")], - ), - insert_query, - ] - - return (queries, {"values": values}) diff --git a/agents-api/agents_api/models/execution/create_execution_transition.py b/agents-api/agents_api/models/execution/create_execution_transition.py deleted file mode 100644 index 59a63ed09..000000000 --- a/agents-api/agents_api/models/execution/create_execution_transition.py +++ /dev/null @@ -1,258 +0,0 @@ -from uuid import UUID, uuid4 - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ( - CreateTransitionRequest, - Transition, - UpdateExecutionRequest, -) -from ...common.protocol.tasks import transition_to_execution_status, valid_transitions -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - cozo_query_async, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) -from .update_execution import update_execution - - -@beartype -def _create_execution_transition( - *, - developer_id: UUID, - execution_id: UUID, - data: CreateTransitionRequest, - # Only one of these needed - transition_id: UUID | None = None, - task_token: str | None = None, - # Only 
required for updating the execution status as well - update_execution_status: bool = False, - task_id: UUID | None = None, -) -> tuple[list[str | None], dict]: - transition_id = transition_id or uuid4() - data.metadata = data.metadata or {} - data.execution_id = execution_id - - # Dump to json - if isinstance(data.output, list): - data.output = [ - item.model_dump(mode="json") if hasattr(item, "model_dump") else item - for item in data.output - ] - - elif hasattr(data.output, "model_dump"): - data.output = data.output.model_dump(mode="json") - - # TODO: This is a hack to make sure the transition is valid - # (parallel transitions are whack, we should do something better) - is_parallel = data.current.workflow.startswith("PAR:") - - # Prepare the transition data - transition_data = data.model_dump(exclude_unset=True, exclude={"id"}) - - # Parse the current and next targets - validate_transition_targets(data) - current_target = transition_data.pop("current") - next_target = transition_data.pop("next") - - transition_data["current"] = (current_target["workflow"], current_target["step"]) - transition_data["next"] = next_target and ( - next_target["workflow"], - next_target["step"], - ) - - columns, transition_values = cozo_process_mutate_data( - { - **transition_data, - "task_token": str(task_token), # Converting to str for JSON serialisation - "transition_id": str(transition_id), - "execution_id": str(execution_id), - } - ) - - # Make sure the transition is valid - check_last_transition_query = f""" - valid_transition[start, end] <- [ - {", ".join(f'["{start}", "{end}"]' for start, ends in valid_transitions.items() for end in ends)} - ] - - last_transition_type[min_cost(type_created_at)] := - *transitions:execution_id_type_created_at_idx {{ - execution_id: to_uuid("{str(execution_id)}"), - type, - created_at, - }}, - type_created_at = [type, -created_at] - - matched[collect(last_type)] := - last_transition_type[data], - last_type_data = first(data), - last_type = 
if(is_null(last_type_data), "init", last_type_data), - valid_transition[last_type, $next_type] - - ?[valid] := - matched[prev_transitions], - found = length(prev_transitions), - valid = if($next_type == "init", found == 0, found > 0), - assert(valid, "Invalid transition"), - - :limit 1 - """ - - # Prepare the insert query - insert_query = f""" - ?[{columns}] <- $transition_values - - :insert transitions {{ - {columns} - }} - - :returning - """ - - validate_status_query, update_execution_query, update_execution_params = ( - "", - "", - {}, - ) - - if update_execution_status: - assert ( - task_id is not None - ), "task_id is required for updating the execution status" - - # Prepare the execution update query - [*_, validate_status_query, update_execution_query], update_execution_params = ( - update_execution.__wrapped__( - developer_id=developer_id, - task_id=task_id, - execution_id=execution_id, - data=UpdateExecutionRequest( - status=transition_to_execution_status[data.type] - ), - output=data.output if data.type != "error" else None, - error=str(data.output) - if data.type == "error" and data.output - else None, - ) - ) - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, - "executions", - execution_id=execution_id, - parents=[("agents", "agent_id"), ("tasks", "task_id")], - ), - validate_status_query if not is_parallel else None, - update_execution_query if not is_parallel else None, - check_last_transition_query if not is_parallel else None, - insert_query, - ] - - return ( - queries, - { - "transition_values": transition_values, - "next_type": data.type, - "valid_transitions": valid_transitions, - **update_execution_params, - }, - ) - - -def validate_transition_targets(data: CreateTransitionRequest) -> None: - # Make sure the current/next targets are valid - match data.type: - case "finish_branch": - pass # TODO: Implement - case "finish" | "error" | "cancelled": - pass - - ### FIXME: HACK: Fix this 
and uncomment - - ### assert ( - ### data.next is None - ### ), "Next target must be None for finish/finish_branch/error/cancelled" - - case "init_branch" | "init": - assert ( - data.next and data.current.step == data.next.step == 0 - ), "Next target must be same as current for init_branch/init and step 0" - - case "wait": - assert data.next is None, "Next target must be None for wait" - - case "resume" | "step": - assert data.next is not None, "Next target must be provided for resume/step" - - if data.next.workflow == data.current.workflow: - assert ( - data.next.step > data.current.step - ), "Next step must be greater than current" - - case _: - raise ValueError(f"Invalid transition type: {data.type}") - - -create_execution_transition = rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -)( - wrap_in_class( - Transition, - transform=lambda d: { - **d, - "id": d["transition_id"], - "current": {"workflow": d["current"][0], "step": d["current"][1]}, - "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]}, - }, - one=True, - _kind="inserted", - )( - cozo_query( - increase_counter("create_execution_transition")( - _create_execution_transition - ) - ) - ) -) - -create_execution_transition_async = rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -)( - wrap_in_class( - Transition, - transform=lambda d: { - **d, - "id": d["transition_id"], - "current": {"workflow": d["current"][0], "step": d["current"][1]}, - "next": d["next"] and {"workflow": d["next"][0], "step": d["next"][1]}, - }, - one=True, - _kind="inserted", - )( - cozo_query_async( - increase_counter("create_execution_transition_async")( - _create_execution_transition - ) 
- ) - ) -) diff --git a/agents-api/agents_api/models/execution/create_temporal_lookup.py b/agents-api/agents_api/models/execution/create_temporal_lookup.py deleted file mode 100644 index e47a505db..000000000 --- a/agents-api/agents_api/models/execution/create_temporal_lookup.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError -from temporalio.client import WorkflowHandle - -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, -) - -T = TypeVar("T") - - -@rewrap_exceptions( - { - AssertionError: partialclass(HTTPException, status_code=404), - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@cozo_query -@increase_counter("create_temporal_lookup") -@beartype -def create_temporal_lookup( - *, - developer_id: UUID, - execution_id: UUID, - workflow_handle: WorkflowHandle, -) -> tuple[list[str], dict]: - developer_id = str(developer_id) - execution_id = str(execution_id) - - temporal_columns, temporal_values = cozo_process_mutate_data( - { - "execution_id": execution_id, - "id": workflow_handle.id, - "run_id": workflow_handle.run_id, - "first_execution_run_id": workflow_handle.first_execution_run_id, - "result_run_id": workflow_handle.result_run_id, - } - ) - - temporal_executions_lookup_query = f""" - ?[{temporal_columns}] <- $temporal_values - - :insert temporal_executions_lookup {{ - {temporal_columns} - }} - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, - "executions", - 
execution_id=execution_id, - parents=[("agents", "agent_id"), ("tasks", "task_id")], - ), - temporal_executions_lookup_query, - ] - - return (queries, {"temporal_values": temporal_values}) diff --git a/agents-api/agents_api/models/execution/get_execution.py b/agents-api/agents_api/models/execution/get_execution.py deleted file mode 100644 index db0279b1f..000000000 --- a/agents-api/agents_api/models/execution/get_execution.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Execution -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - wrap_in_class, -) -from .constants import OUTPUT_UNNEST_KEY - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - AssertionError: partialclass(HTTPException, status_code=404), - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Execution, - one=True, - transform=lambda d: { - **d, - "output": d["output"][OUTPUT_UNNEST_KEY] - if isinstance(d["output"], dict) and OUTPUT_UNNEST_KEY in d["output"] - else d["output"], - }, -) -@cozo_query -@beartype -def get_execution( - *, - execution_id: UUID, -) -> tuple[str, dict]: - # Executions are allowed direct GET access if they have execution_id - - # NOTE: Do not remove outer curly braces - query = """ - { - input[execution_id] <- [[to_uuid($execution_id)]] - - ?[id, task_id, status, input, output, error, session_id, metadata, created_at, updated_at] := - input[execution_id], - *executions { - task_id, - execution_id, - status, - input, - output, - error, - session_id, - metadata, - created_at, - updated_at, - }, - id = execution_id - - :limit 1 - } - """ 
- - return ( - query, - { - "execution_id": str(execution_id), - }, - ) diff --git a/agents-api/agents_api/models/execution/get_execution_transition.py b/agents-api/agents_api/models/execution/get_execution_transition.py deleted file mode 100644 index e2b38789a..000000000 --- a/agents-api/agents_api/models/execution/get_execution_transition.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Transition -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - AssertionError: partialclass(HTTPException, status_code=500), - } -) -@wrap_in_class(Transition, one=True) -@cozo_query -@beartype -def get_execution_transition( - *, - developer_id: UUID, - transition_id: UUID | None = None, - task_token: str | None = None, -) -> tuple[list[str], dict]: - # At least one of `transition_id` or `task_token` must be provided - assert ( - transition_id or task_token - ), "At least one of `transition_id` or `task_token` must be provided." 
- - if transition_id: - transition_id = str(transition_id) - filter = "id = to_uuid($transition_id)" - - else: - filter = "task_token = $task_token" - - get_query = """ - ?[id, type, current, next, output, metadata, updated_at, created_at] := - *transitions { - transition_id: id, - type, - current: current_tuple, - next: next_tuple, - output, - metadata, - updated_at, - created_at, - }, - current = {"workflow": current_tuple->0, "step": current_tuple->1}, - next = if( - is_null(next_tuple), - null, - {"workflow": next_tuple->0, "step": next_tuple->1}, - ) - - :limit 1 - """ - - get_query += filter - - queries = [ - verify_developer_id_query(developer_id), - get_query, - ] - - return (queries, {"task_token": task_token, "transition_id": transition_id}) diff --git a/agents-api/agents_api/models/execution/get_paused_execution_token.py b/agents-api/agents_api/models/execution/get_paused_execution_token.py deleted file mode 100644 index 6c32c7692..000000000 --- a/agents-api/agents_api/models/execution/get_paused_execution_token.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - AssertionError: partialclass(HTTPException, status_code=500), - } -) -@wrap_in_class(dict, one=True) -@cozo_query -@beartype -def get_paused_execution_token( - *, - developer_id: UUID, - execution_id: UUID, -) -> tuple[list[str], dict]: - execution_id = str(execution_id) - - check_status_query = """ - ?[execution_id, 
status] := - *executions:execution_id_status_idx { - execution_id, - status, - }, - execution_id = to_uuid($execution_id), - status = "awaiting_input" - - :limit 1 - :assert some - """ - - get_query = """ - ?[task_token, created_at, metadata] := - execution_id = to_uuid($execution_id), - *executions { - execution_id, - }, - *transitions { - execution_id, - created_at, - task_token, - type, - metadata, - }, - type = "wait" - - :sort -created_at - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - check_status_query, - get_query, - ] - - return (queries, {"execution_id": execution_id}) diff --git a/agents-api/agents_api/models/execution/get_temporal_workflow_data.py b/agents-api/agents_api/models/execution/get_temporal_workflow_data.py deleted file mode 100644 index 8b1bf4604..000000000 --- a/agents-api/agents_api/models/execution/get_temporal_workflow_data.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(dict, one=True) -@cozo_query -@beartype -def get_temporal_workflow_data( - *, - execution_id: UUID, -) -> tuple[str, dict]: - # Executions are allowed direct GET access if they have execution_id - - query = """ - input[execution_id] <- [[to_uuid($execution_id)]] - - ?[id, run_id, result_run_id, first_execution_run_id] := - input[execution_id], - *temporal_executions_lookup { - execution_id, - id, - run_id, - result_run_id, - first_execution_run_id, - } - - :limit 1 - 
""" - - return ( - query, - { - "execution_id": str(execution_id), - }, - ) diff --git a/agents-api/agents_api/models/execution/list_execution_transitions.py b/agents-api/agents_api/models/execution/list_execution_transitions.py deleted file mode 100644 index 8931676f6..000000000 --- a/agents-api/agents_api/models/execution/list_execution_transitions.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Transition -from ..utils import cozo_query, partialclass, rewrap_exceptions, wrap_in_class - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(Transition) -@cozo_query -@beartype -def list_execution_transitions( - *, - execution_id: UUID, - limit: int = 100, - offset: int = 0, - sort_by: Literal["created_at", "updated_at"] = "created_at", - direction: Literal["asc", "desc"] = "desc", -) -> tuple[str, dict]: - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - query = f""" - ?[id, execution_id, type, current, next, output, metadata, updated_at, created_at] := - *transitions {{ - execution_id, - transition_id: id, - type, - current: current_tuple, - next: next_tuple, - output, - metadata, - updated_at, - created_at, - }}, - current = {{"workflow": current_tuple->0, "step": current_tuple->1}}, - next = if( - is_null(next_tuple), - null, - {{"workflow": next_tuple->0, "step": next_tuple->1}}, - ), - execution_id = to_uuid($execution_id) - - :limit $limit - :offset $offset - :sort {sort} - """ - - return ( - query, - { - "execution_id": str(execution_id), - "limit": limit, - "offset": 
offset, - }, - ) diff --git a/agents-api/agents_api/models/execution/list_executions.py b/agents-api/agents_api/models/execution/list_executions.py deleted file mode 100644 index 64add074f..000000000 --- a/agents-api/agents_api/models/execution/list_executions.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Execution -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) -from .constants import OUTPUT_UNNEST_KEY - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Execution, - transform=lambda d: { - **d, - "output": d["output"][OUTPUT_UNNEST_KEY] - if isinstance(d.get("output"), dict) and OUTPUT_UNNEST_KEY in d["output"] - else d.get("output"), - }, -) -@cozo_query -@beartype -def list_executions( - *, - developer_id: UUID, - task_id: UUID, - limit: int = 100, - offset: int = 0, - sort_by: Literal["created_at", "updated_at"] = "created_at", - direction: Literal["asc", "desc"] = "desc", -) -> tuple[list[str], dict]: - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - list_query = f""" - input[task_id] <- [[to_uuid($task_id)]] - - ?[ - id, - task_id, - status, - input, - output, - session_id, - metadata, - created_at, - updated_at, - ] := input[task_id], - *executions {{ - task_id, - execution_id: id, - status, - input, - output, - session_id, - metadata, - created_at, - updated_at, - }} - - :limit {limit} - :offset {offset} - :sort {sort} - """ - - queries = [ - 
verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, - "tasks", - task_id=task_id, - parents=[("agents", "agent_id")], - ), - list_query, - ] - - return (queries, {"task_id": str(task_id), "limit": limit, "offset": offset}) diff --git a/agents-api/agents_api/models/execution/lookup_temporal_data.py b/agents-api/agents_api/models/execution/lookup_temporal_data.py deleted file mode 100644 index 35f09129b..000000000 --- a/agents-api/agents_api/models/execution/lookup_temporal_data.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(dict, one=True) -@cozo_query -@beartype -def lookup_temporal_data( - *, - developer_id: UUID, - execution_id: UUID, -) -> tuple[list[str], dict]: - developer_id = str(developer_id) - execution_id = str(execution_id) - - temporal_query = """ - ?[id] := - execution_id = to_uuid($execution_id), - *temporal_executions_lookup { - id, execution_id, run_id, first_execution_run_id, result_run_id - } - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, - "executions", - execution_id=execution_id, - parents=[("agents", "agent_id"), ("tasks", "task_id")], - ), - temporal_query, - ] - - return ( - queries, - { - "execution_id": str(execution_id), - }, - ) diff --git 
a/agents-api/agents_api/models/execution/prepare_execution_input.py b/agents-api/agents_api/models/execution/prepare_execution_input.py deleted file mode 100644 index 5e841b9f2..000000000 --- a/agents-api/agents_api/models/execution/prepare_execution_input.py +++ /dev/null @@ -1,223 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...common.protocol.tasks import ExecutionInput -from ..agent.get_agent import get_agent -from ..task.get_task import get_task -from ..tools.list_tools import list_tools -from ..utils import ( - cozo_query, - fix_uuid_if_present, - make_cozo_json_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) -from .get_execution import get_execution - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - AssertionError: lambda e: HTTPException( - status_code=429, - detail=str(e), - headers={"x-should-retry": "true"}, - ), - } -) -@wrap_in_class( - ExecutionInput, - one=True, - transform=lambda d: { - **d, - "task": { - "tools": [*map(fix_uuid_if_present, d["task"].pop("tools"))], - **d["task"], - }, - "agent_tools": [ - {tool["type"]: tool.pop("spec"), **tool} - for tool in map(fix_uuid_if_present, d["tools"]) - ], - }, -) -@cozo_query -@beartype -def prepare_execution_input( - *, - developer_id: UUID, - task_id: UUID, - execution_id: UUID, -) -> tuple[list[str], dict]: - execution_query, execution_params = get_execution.__wrapped__( - execution_id=execution_id - ) - - # Remove the outer curly braces - execution_query = execution_query.strip()[1:-1] - - execution_fields = ( - "id", - 
"task_id", - "status", - "input", - "session_id", - "metadata", - "created_at", - "updated_at", - ) - execution_query += f""" - :create _execution {{ - {", ".join(execution_fields)} - }} - """ - - task_query, task_params = get_task.__wrapped__( - developer_id=developer_id, task_id=task_id - ) - - # Remove the outer curly braces - task_query = task_query[-1].strip() - - task_fields = ( - "id", - "agent_id", - "name", - "description", - "input_schema", - "tools", - "inherit_tools", - "workflows", - "created_at", - "updated_at", - "metadata", - ) - task_query += f""" - :create _task {{ - {", ".join(task_fields)} - }} - """ - - dummy_agent_id = UUID(int=0) - - [*_, agent_query], agent_params = get_agent.__wrapped__( - developer_id=developer_id, - agent_id=dummy_agent_id, # We will replace this with value from the query - ) - agent_params.pop("agent_id") - agent_query = agent_query.replace( - "<- [[to_uuid($agent_id)]]", ":= *_task { agent_id }" - ) - - agent_fields = ( - "id", - "name", - "model", - "about", - "metadata", - "default_settings", - "instructions", - "created_at", - "updated_at", - ) - - agent_query += f""" - :create _agent {{ - {", ".join(agent_fields)} - }} - """ - - [*_, tools_query], tools_params = list_tools.__wrapped__( - developer_id=developer_id, - agent_id=dummy_agent_id, # We will replace this with value from the query - ) - tools_params.pop("agent_id") - tools_query = tools_query.replace( - "<- [[to_uuid($agent_id)]]", ":= *_task { agent_id }" - ) - - tools_fields = ( - "id", - "agent_id", - "name", - "type", - "spec", - "description", - "created_at", - "updated_at", - ) - tools_query += f""" - :create _tools {{ - {", ".join(tools_fields)} - }} - """ - - combine_query = f""" - collected_tools[collect(tool)] := - *_tools {{ {', '.join(tools_fields)} }}, - tool = {{ {make_cozo_json_query(tools_fields)} }} - - agent_json[agent] := - *_agent {{ {', '.join(agent_fields)} }}, - agent = {{ {make_cozo_json_query(agent_fields)} }} - - task_json[task] := 
- *_task {{ {', '.join(task_fields)} }}, - task = {{ {make_cozo_json_query(task_fields)} }} - - execution_json[execution] := - *_execution {{ {', '.join(execution_fields)} }}, - execution = {{ {make_cozo_json_query(execution_fields)} }} - - ?[developer_id, execution, task, agent, user, session, tools, arguments] := - developer_id = to_uuid($developer_id), - - agent_json[agent], - task_json[task], - execution_json[execution], - collected_tools[tools], - - # TODO: Enable these later - user = null, - session = null, - arguments = execution->"input" - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "tasks", task_id=task_id, parents=[("agents", "agent_id")] - ), - execution_query, - task_query, - agent_query, - tools_query, - combine_query, - ] - - return ( - queries, - { - "developer_id": str(developer_id), - "task_id": str(task_id), - "execution_id": str(execution_id), - **execution_params, - **task_params, - **agent_params, - **tools_params, - }, - ) diff --git a/agents-api/agents_api/models/execution/update_execution.py b/agents-api/agents_api/models/execution/update_execution.py deleted file mode 100644 index f33368412..000000000 --- a/agents-api/agents_api/models/execution/update_execution.py +++ /dev/null @@ -1,130 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ( - ResourceUpdatedResponse, - UpdateExecutionRequest, -) -from ...common.protocol.tasks import ( - valid_previous_statuses as valid_previous_statuses_map, -) -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - 
wrap_in_class, -) -from .constants import OUTPUT_UNNEST_KEY - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: {"id": d["execution_id"], **d}, - _kind="inserted", -) -@cozo_query -@increase_counter("update_execution") -@beartype -def update_execution( - *, - developer_id: UUID, - task_id: UUID, - execution_id: UUID, - data: UpdateExecutionRequest, - output: dict | Any | None = None, - error: str | None = None, -) -> tuple[list[str], dict]: - developer_id = str(developer_id) - task_id = str(task_id) - execution_id = str(execution_id) - - valid_previous_statuses: list[str] | None = valid_previous_statuses_map.get( - data.status, None - ) - - execution_data: dict = data.model_dump(exclude_none=True) - - if output is not None and not isinstance(output, dict): - output: dict = {OUTPUT_UNNEST_KEY: output} - - columns, values = cozo_process_mutate_data( - { - **execution_data, - "task_id": task_id, - "execution_id": execution_id, - "output": output, - "error": error, - } - ) - - validate_status_query = """ - valid_status[count(status)] := - *executions { - status, - execution_id: to_uuid($execution_id), - task_id: to_uuid($task_id), - }, - status in $valid_previous_statuses - - ?[num] := - valid_status[num], - assert(num > 0, 'Invalid status') - - :limit 1 - """ - - update_query = f""" - input[{columns}] <- $values - ?[{columns}, updated_at] := - input[{columns}], - updated_at = now() - - :update executions {{ - updated_at, - {columns} - }} - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, - "executions", - execution_id=execution_id, - parents=[("agents", "agent_id"), ("tasks", 
"task_id")], - ), - validate_status_query if valid_previous_statuses is not None else "", - update_query, - ] - - return ( - queries, - { - "values": values, - "valid_previous_statuses": valid_previous_statuses, - "execution_id": str(execution_id), - "task_id": task_id, - }, - ) diff --git a/agents-api/agents_api/models/files/__init__.py b/agents-api/agents_api/models/files/__init__.py deleted file mode 100644 index 444c0a6eb..000000000 --- a/agents-api/agents_api/models/files/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .create_file import create_file as create_file -from .delete_file import delete_file as delete_file -from .get_file import get_file as get_file diff --git a/agents-api/agents_api/models/files/create_file.py b/agents-api/agents_api/models/files/create_file.py deleted file mode 100644 index 224597180..000000000 --- a/agents-api/agents_api/models/files/create_file.py +++ /dev/null @@ -1,121 +0,0 @@ -""" -This module contains the functionality for creating a new user in the CozoDB database. -It defines a query for inserting user data into the 'users' relation. -""" - -import base64 -import hashlib -from typing import Any, TypeVar -from uuid import UUID, uuid4 - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import CreateFileRequest, File -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - lambda e: isinstance(e, QueryException) - and "asserted to return some results, but returned none" - in str(e): lambda *_: HTTPException( - detail="Developer not found. 
Please ensure the provided auth token (which refers to your developer_id) is valid and the developer has the necessary permissions to create an agent.", - status_code=403, - ), - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. This likely means the data provided is of an incorrect type (e.g., string instead of integer). Please review the input and try again.", - ), - } -) -@wrap_in_class( - File, - one=True, - transform=lambda d: { - **d, - "id": d["file_id"], - "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", - }, - _kind="inserted", -) -@cozo_query -@increase_counter("create_file") -@beartype -def create_file( - *, - developer_id: UUID, - file_id: UUID | None = None, - data: CreateFileRequest, -) -> tuple[list[str], dict]: - """ - Constructs and executes a datalog query to create a new file in the CozoDB database. - - Parameters: - user_id (UUID): The unique identifier for the user. - developer_id (UUID): The unique identifier for the developer creating the file. 
- """ - - file_id = file_id or uuid4() - file_data = data.model_dump(exclude={"content"}) - - content_bytes = base64.b64decode(data.content) - size = len(content_bytes) - hash = hashlib.sha256(content_bytes).hexdigest() - - create_query = """ - # Then create the file - ?[file_id, developer_id, name, description, mime_type, size, hash] <- [ - [to_uuid($file_id), to_uuid($developer_id), $name, $description, $mime_type, $size, $hash] - ] - - :insert files { - developer_id, - file_id => - name, - description, - mime_type, - size, - hash, - } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - create_query, - ] - - return ( - queries, - { - "file_id": str(file_id), - "developer_id": str(developer_id), - "size": size, - "hash": hash, - **file_data, - }, - ) diff --git a/agents-api/agents_api/models/files/delete_file.py b/agents-api/agents_api/models/files/delete_file.py deleted file mode 100644 index 053402e2f..000000000 --- a/agents-api/agents_api/models/files/delete_file.py +++ /dev/null @@ -1,97 +0,0 @@ -""" -This module contains the implementation of the delete_user_query function, which is responsible for deleting an user and its related default settings from the CozoDB database. 
-""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceDeletedResponse -from ...common.utils.datetime import utcnow -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - lambda e: isinstance(e, QueryException) - and "Developer does not exist" in str(e): lambda *_: HTTPException( - detail="The specified developer does not exist.", - status_code=403, - ), - lambda e: isinstance(e, QueryException) - and "Developer does not own resource" - in e.resp["display"]: lambda *_: HTTPException( - detail="The specified developer does not own the requested resource. Please verify the ownership or check if the developer ID is correct.", - status_code=404, - ), - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. This likely means the data provided is of an incorrect type (e.g., string instead of integer). 
Please review the input and try again.", - ), - } -) -@wrap_in_class( - ResourceDeletedResponse, - one=True, - transform=lambda d: { - "id": UUID(d.pop("file_id")), - "deleted_at": utcnow(), - "jobs": [], - }, - _kind="deleted", -) -@cozo_query -@beartype -def delete_file(*, developer_id: UUID, file_id: UUID) -> tuple[list[str], dict]: - """ - Constructs and returns a datalog query for deleting an file from the database. - - Parameters: - developer_id (UUID): The UUID of the developer owning the file. - file_id (UUID): The UUID of the file to be deleted. - client (CozoClient, optional): An instance of the CozoClient to execute the query. - - Returns: - ResourceDeletedResponse: The response indicating the deletion of the user. - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "files", file_id=file_id), - """ - ?[file_id, developer_id] <- [[$file_id, $developer_id]] - - :delete files { - developer_id, - file_id - } - :returning - """, - ] - - return (queries, {"file_id": str(file_id), "developer_id": str(developer_id)}) diff --git a/agents-api/agents_api/models/files/get_file.py b/agents-api/agents_api/models/files/get_file.py deleted file mode 100644 index f3b85c2f7..000000000 --- a/agents-api/agents_api/models/files/get_file.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import File -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - lambda e: isinstance(e, QueryException) - and "Developer does not exist" in str(e): lambda *_: HTTPException( - detail="The specified developer does 
not exist.", - status_code=403, - ), - lambda e: isinstance(e, QueryException) - and "Developer does not own resource" - in e.resp["display"]: lambda *_: HTTPException( - detail="The specified developer does not own the requested resource. Please verify the ownership or check if the developer ID is correct.", - status_code=404, - ), - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. This likely means the data provided is of an incorrect type (e.g., string instead of integer). Please review the input and try again.", - ), - } -) -@wrap_in_class( - File, - one=True, - transform=lambda d: { - **d, - "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", - }, -) -@cozo_query -@beartype -def get_file( - *, - developer_id: UUID, - file_id: UUID, -) -> tuple[list[str], dict]: - """ - Retrieves a file by their unique identifier. - - - Parameters: - developer_id (UUID): The unique identifier of the developer associated with the file. - file_id (UUID): The unique identifier of the file to retrieve. - - Returns: - File: The retrieved file. - """ - - # Convert UUIDs to strings for query compatibility. 
- file_id = str(file_id) - developer_id = str(developer_id) - - get_query = """ - input[developer_id, file_id] <- [[to_uuid($developer_id), to_uuid($file_id)]] - - ?[ - id, - name, - description, - mime_type, - size, - hash, - created_at, - ] := input[developer_id, id], - *files { - file_id: id, - developer_id, - name, - description, - mime_type, - size, - hash, - created_at, - } - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "files", file_id=file_id), - get_query, - ] - - return (queries, {"developer_id": developer_id, "file_id": file_id}) diff --git a/agents-api/agents_api/models/session/__init__.py b/agents-api/agents_api/models/session/__init__.py deleted file mode 100644 index bf80c9f4b..000000000 --- a/agents-api/agents_api/models/session/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -"""The session module is responsible for managing session data in the 'cozodb' database. It provides functionalities to create, retrieve, list, update, and delete session information. This module utilizes the `CozoClient` for database interactions, ensuring that sessions are uniquely identified and managed through UUIDs. - -Key functionalities include: -- Creating new sessions with specific metadata. -- Retrieving session information based on developer and session IDs. -- Listing all sessions with optional filters for pagination and metadata. -- Updating session data, including situation, summary, and metadata. -- Deleting sessions and their associated data from the database. 
- -This module plays a crucial role in the application by facilitating the management of session data, which is essential for tracking and analyzing user interactions and behaviors within the system.""" - -# ruff: noqa: F401, F403, F405 - -from .count_sessions import count_sessions -from .create_or_update_session import create_or_update_session -from .create_session import create_session -from .delete_session import delete_session -from .get_session import get_session -from .list_sessions import list_sessions -from .patch_session import patch_session -from .prepare_session_data import prepare_session_data -from .update_session import update_session diff --git a/agents-api/agents_api/models/session/count_sessions.py b/agents-api/agents_api/models/session/count_sessions.py deleted file mode 100644 index 3599cc2fb..000000000 --- a/agents-api/agents_api/models/session/count_sessions.py +++ /dev/null @@ -1,64 +0,0 @@ -"""This module contains functions for querying session data from the 'cozodb' database.""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(dict, one=True) -@cozo_query -@beartype -def count_sessions( - *, - developer_id: UUID, -) -> tuple[list[str], dict]: - """ - Counts sessions from the 'cozodb' database. - - Parameters: - developer_id (UUID): The developer's ID to filter sessions by. 
- """ - - count_query = """ - input[developer_id] <- [[ - to_uuid($developer_id), - ]] - - counter[count(id)] := - input[developer_id], - *sessions{ - developer_id, - session_id: id, - } - - ?[count] := counter[count] - """ - - queries = [ - verify_developer_id_query(developer_id), - count_query, - ] - - return (queries, {"developer_id": str(developer_id)}) diff --git a/agents-api/agents_api/models/session/create_or_update_session.py b/agents-api/agents_api/models/session/create_or_update_session.py deleted file mode 100644 index e34a63ca5..000000000 --- a/agents-api/agents_api/models/session/create_or_update_session.py +++ /dev/null @@ -1,158 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ( - CreateOrUpdateSessionRequest, - ResourceUpdatedResponse, -) -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - AssertionError: partialclass(HTTPException, status_code=400), - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: { - "id": d["session_id"], - "updated_at": d.pop("updated_at")[0], - "jobs": [], - **d, - }, -) -@cozo_query -@increase_counter("create_or_update_session") -@beartype -def create_or_update_session( - *, - session_id: UUID, - developer_id: UUID, - data: CreateOrUpdateSessionRequest, -) -> tuple[list[str], dict]: - data.metadata = 
data.metadata or {} - session_data = data.model_dump(exclude={"auto_run_tools", "disable_cache"}) - - user = session_data.pop("user") - agent = session_data.pop("agent") - users = session_data.pop("users") - agents = session_data.pop("agents") - - # Only one of agent or agents should be provided. - if agent and agents: - raise ValueError("Only one of 'agent' or 'agents' should be provided.") - - agents = agents or ([agent] if agent else []) - assert len(agents) > 0, "At least one agent must be provided." - - # Users are zero or more, so we default to an empty list if not provided. - if not (user or users): - users = [] - - else: - users = users or [user] - - participants = [ - *[("user", str(user)) for user in users], - *[("agent", str(agent)) for agent in agents], - ] - - # Construct the datalog query for creating a new session and its lookup. - clear_lookup_query = """ - input[session_id] <- [[$session_id]] - ?[session_id, participant_id, participant_type] := - input[session_id], - *session_lookup { - session_id, - participant_type, - participant_id, - }, - - :delete session_lookup { - session_id, - participant_type, - participant_id, - } - """ - - lookup_query = """ - # This section creates a new session lookup to ensure uniqueness and manage session metadata. - session[session_id] <- [[$session_id]] - participants[participant_type, participant_id] <- $participants - ?[session_id, participant_id, participant_type] := - session[session_id], - participants[participant_type, participant_id], - - :put session_lookup { - session_id, - participant_id, - participant_type, - } - """ - - session_update_cols, session_update_vals = cozo_process_mutate_data( - {k: v for k, v in session_data.items() if v is not None} - ) - - # Construct the datalog query for creating or updating session information. 
- update_query = f""" - input[{session_update_cols}] <- $session_update_vals - ids[session_id, developer_id] <- [[to_uuid($session_id), to_uuid($developer_id)]] - - ?[{session_update_cols}, session_id, developer_id] := - input[{session_update_cols}], - ids[session_id, developer_id], - - :put sessions {{ - {session_update_cols}, session_id, developer_id - }} - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - *[ - verify_developer_owns_resource_query( - developer_id, - f"{participant_type}s", - **{f"{participant_type}_id": participant_id}, - ) - for participant_type, participant_id in participants - ], - clear_lookup_query, - lookup_query, - update_query, - ] - - return ( - queries, - { - "session_update_vals": session_update_vals, - "session_id": str(session_id), - "developer_id": str(developer_id), - "participants": participants, - }, - ) diff --git a/agents-api/agents_api/models/session/create_session.py b/agents-api/agents_api/models/session/create_session.py deleted file mode 100644 index ce804399d..000000000 --- a/agents-api/agents_api/models/session/create_session.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -This module contains the functionality for creating a new session in the 'cozodb' database. -It constructs and executes a datalog query to insert session data. 
-""" - -from typing import Any, TypeVar -from uuid import UUID, uuid4 - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import CreateSessionRequest, Session -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - AssertionError: partialclass(HTTPException, status_code=400), - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Session, - one=True, - transform=lambda d: { - "id": UUID(d.pop("session_id")), - "updated_at": (d.pop("updated_at")[0]), - **d, - }, - _kind="inserted", -) -@cozo_query -@increase_counter("create_session") -@beartype -def create_session( - *, - developer_id: UUID, - session_id: UUID | None = None, - data: CreateSessionRequest, -) -> tuple[list[str], dict]: - """ - Constructs and executes a datalog query to create a new session in the database. - """ - - session_id = session_id or uuid4() - - data.metadata = data.metadata or {} - session_data = data.model_dump(exclude={"auto_run_tools", "disable_cache"}) - - user = session_data.pop("user") - agent = session_data.pop("agent") - users = session_data.pop("users") - agents = session_data.pop("agents") - - # Only one of agent or agents should be provided. - if agent and agents: - raise ValueError("Only one of 'agent' or 'agents' should be provided.") - - agents = agents or ([agent] if agent else []) - assert len(agents) > 0, "At least one agent must be provided." - - # Users are zero or more, so we default to an empty list if not provided. 
- if not (user or users): - users = [] - - else: - users = users or [user] - - participants = [ - *[("user", str(user)) for user in users], - *[("agent", str(agent)) for agent in agents], - ] - - # Construct the datalog query for creating a new session and its lookup. - lookup_query = """ - # This section creates a new session lookup to ensure uniqueness and manage session metadata. - session[session_id] <- [[$session_id]] - participants[participant_type, participant_id] <- $participants - ?[session_id, participant_id, participant_type] := - session[session_id], - participants[participant_type, participant_id], - - :insert session_lookup { - session_id, - participant_id, - participant_type, - } - """ - - create_query = """ - # Insert the new session data into the 'session' table with the specified columns. - ?[session_id, developer_id, situation, metadata, render_templates, token_budget, context_overflow] <- [[ - $session_id, - $developer_id, - $situation, - $metadata, - $render_templates, - $token_budget, - $context_overflow, - ]] - - :insert sessions { - developer_id, - session_id, - situation, - metadata, - render_templates, - token_budget, - context_overflow, - } - # Specify the data to return after the query execution, typically the newly created session's ID. - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - *[ - verify_developer_owns_resource_query( - developer_id, - f"{participant_type}s", - **{f"{participant_type}_id": participant_id}, - ) - for participant_type, participant_id in participants - ], - lookup_query, - create_query, - ] - - # Execute the constructed query with the provided parameters and return the result. 
- return ( - queries, - { - "session_id": str(session_id), - "developer_id": str(developer_id), - "participants": participants, - **session_data, - }, - ) diff --git a/agents-api/agents_api/models/session/delete_session.py b/agents-api/agents_api/models/session/delete_session.py deleted file mode 100644 index 81f8e1f7c..000000000 --- a/agents-api/agents_api/models/session/delete_session.py +++ /dev/null @@ -1,125 +0,0 @@ -"""This module contains the implementation for deleting sessions from the 'cozodb' database using datalog queries.""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceDeletedResponse -from ...common.utils.datetime import utcnow -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceDeletedResponse, - one=True, - transform=lambda d: { - "id": UUID(d.pop("session_id")), - "deleted_at": utcnow(), - "jobs": [], - }, - _kind="deleted", -) -@cozo_query -@beartype -def delete_session( - *, - developer_id: UUID, - session_id: UUID, -) -> tuple[list[str], dict]: - """ - Deletes a session and its related data from the 'cozodb' database. - - Parameters: - developer_id (UUID): The unique identifier for the developer. - session_id (UUID): The unique identifier for the session to be deleted. - - Returns: - ResourceDeletedResponse: The response indicating the deletion of the session. 
- """ - session_id = str(session_id) - developer_id = str(developer_id) - - # Constructs and executes a datalog query to delete the specified session and its associated data based on the session_id and developer_id. - delete_lookup_query = """ - # Convert session_id to UUID format - input[session_id] <- [[ - to_uuid($session_id), - ]] - - # Select sessions based on the session_id provided - ?[ - session_id, - participant_id, - participant_type, - ] := - input[session_id], - *session_lookup{ - session_id, - participant_id, - participant_type, - } - - # Delete entries from session_lookup table matching the criteria - :delete session_lookup { - session_id, - participant_id, - participant_type, - } - """ - - delete_query = """ - # Convert developer_id and session_id to UUID format - input[developer_id, session_id] <- [[ - to_uuid($developer_id), - to_uuid($session_id), - ]] - - # Select sessions based on the developer_id and session_id provided - ?[developer_id, session_id, updated_at] := - input[developer_id, session_id], - *sessions { - developer_id, - session_id, - updated_at, - } - - # Delete entries from sessions table matching the criteria - :delete sessions { - developer_id, - session_id, - updated_at, - } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - delete_lookup_query, - delete_query, - ] - - return (queries, {"session_id": session_id, "developer_id": developer_id}) diff --git a/agents-api/agents_api/models/session/get_session.py b/agents-api/agents_api/models/session/get_session.py deleted file mode 100644 index f99f2524c..000000000 --- a/agents-api/agents_api/models/session/get_session.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from 
...common.protocol.sessions import make_session -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(make_session, one=True) -@cozo_query -@beartype -def get_session( - *, - developer_id: UUID, - session_id: UUID, -) -> tuple[list[str], dict]: - """ - Constructs and executes a datalog query to retrieve session information from the 'cozodb' database. - - Parameters: - developer_id (UUID): The developer's unique identifier. - session_id (UUID): The session's unique identifier. - """ - session_id = str(session_id) - developer_id = str(developer_id) - - # This query retrieves session information by using `input` to pass parameters, - get_query = """ - input[developer_id, session_id] <- [[ - to_uuid($developer_id), - to_uuid($session_id), - ]] - - participants[collect(participant_id), participant_type] := - input[_, session_id], - *session_lookup{ - session_id, - participant_id, - participant_type, - } - - # We have to do this dance because users can be zero or more - users_p[users] := - participants[users, "user"] - - users_p[users] := - not participants[_, "user"], - users = [] - - ?[ - agents, - users, - id, - situation, - summary, - updated_at, - created_at, - metadata, - render_templates, - token_budget, - context_overflow, - recall_options, - forward_tool_calls, - ] := input[developer_id, id], - users_p[users], - participants[agents, "agent"], - *sessions{ - developer_id, - session_id: id, - situation, - summary, - created_at, - updated_at: validity, - metadata, - render_templates, - token_budget, - context_overflow, - recall_options, - forward_tool_calls, - @ 
"END" - }, - updated_at = to_int(validity) - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - get_query, - ] - - return (queries, {"session_id": session_id, "developer_id": developer_id}) diff --git a/agents-api/agents_api/models/session/list_sessions.py b/agents-api/agents_api/models/session/list_sessions.py deleted file mode 100644 index 4adb84a6c..000000000 --- a/agents-api/agents_api/models/session/list_sessions.py +++ /dev/null @@ -1,131 +0,0 @@ -"""This module contains functions for querying session data from the 'cozodb' database.""" - -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...common.protocol.sessions import make_session -from ...common.utils import json -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(make_session) -@cozo_query -@beartype -def list_sessions( - *, - developer_id: UUID, - limit: int = 100, - offset: int = 0, - sort_by: Literal["created_at", "updated_at"] = "created_at", - direction: Literal["asc", "desc"] = "desc", - metadata_filter: dict[str, Any] = {}, -) -> tuple[list[str], dict]: - """ - Lists sessions from the 'cozodb' database based on the provided filters. - - Parameters: - developer_id (UUID): The developer's ID to filter sessions by. - limit (int): The maximum number of sessions to return. - offset (int): The offset from which to start listing sessions. 
- metadata_filter (dict[str, Any]): A dictionary of metadata fields to filter sessions by. - """ - metadata_filter_str = ", ".join( - [ - f"metadata->{json.dumps(k)} == {json.dumps(v)}" - for k, v in metadata_filter.items() - ] - ) - - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - list_query = f""" - input[developer_id] <- [[ - to_uuid($developer_id), - ]] - - participants[collect(participant_id), participant_type, session_id] := - *session_lookup{{ - session_id, - participant_id, - participant_type, - }} - - # We have to do this dance because users can be zero or more - users_p[users, session_id] := - participants[users, "user", session_id] - - users_p[users, session_id] := - not participants[_, "user", session_id], - users = [] - - ?[ - agents, - users, - id, - situation, - summary, - updated_at, - created_at, - metadata, - token_budget, - context_overflow, - recall_options, - forward_tool_calls, - ] := - input[developer_id], - *sessions{{ - developer_id, - session_id: id, - situation, - summary, - created_at, - updated_at: validity, - metadata, - token_budget, - context_overflow, - recall_options, - forward_tool_calls, - @ "END" - }}, - users_p[users, id], - participants[agents, "agent", id], - updated_at = to_int(validity), - {metadata_filter_str} - - :limit $limit - :offset $offset - :sort {sort} - """ - - # Datalog query to retrieve agent information based on filters, sorted by creation date in descending order. - queries = [ - verify_developer_id_query(developer_id), - list_query, - ] - - # Execute the datalog query and return the results as a pandas DataFrame. 
- return ( - queries, - {"developer_id": str(developer_id), "limit": limit, "offset": offset}, - ) diff --git a/agents-api/agents_api/models/session/patch_session.py b/agents-api/agents_api/models/session/patch_session.py deleted file mode 100644 index 4a119a684..000000000 --- a/agents-api/agents_api/models/session/patch_session.py +++ /dev/null @@ -1,127 +0,0 @@ -"""This module contains functions for patching session data in the 'cozodb' database using datalog queries.""" - -from typing import Any, List, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import PatchSessionRequest, ResourceUpdatedResponse -from ...common.utils.cozo import cozo_process_mutate_data -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -_fields: List[str] = [ - "situation", - "summary", - "created_at", - "session_id", - "developer_id", -] - - -# TODO: Add support for updating `render_templates` field - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: { - "id": d["session_id"], - "updated_at": d.pop("updated_at")[0], - "jobs": [], - **d, - }, - _kind="inserted", -) -@cozo_query -@beartype -def patch_session( - *, - session_id: UUID, - developer_id: UUID, - data: PatchSessionRequest, -) -> tuple[list[str], dict]: - """ - Patch session data in the 'cozodb' database. - - Parameters: - session_id (UUID): The unique identifier for the session to be updated. 
- developer_id (UUID): The unique identifier for the developer making the update. - data (PatchSessionRequest): The request payload containing the updates to apply. - """ - - update_data = data.model_dump(exclude_unset=True) - metadata = update_data.pop("metadata", {}) or {} - - session_update_cols, session_update_vals = cozo_process_mutate_data( - {k: v for k, v in update_data.items() if v is not None} - ) - - # Prepare lists of columns for the query. - session_update_cols_lst = session_update_cols.split(",") - all_fields_lst = list(set(session_update_cols_lst).union(set(_fields))) - all_fields = ", ".join(all_fields_lst) - rest_fields = ", ".join( - list( - set(all_fields_lst) - - set([k for k, v in update_data.items() if v is not None]) - ) - ) - - # Construct the datalog query for updating session information. - update_query = f""" - input[{session_update_cols}] <- $session_update_vals - ids[session_id, developer_id] <- [[to_uuid($session_id), to_uuid($developer_id)]] - - ?[{all_fields}, metadata, updated_at] := - input[{session_update_cols}], - ids[session_id, developer_id], - *sessions{{ - {rest_fields}, metadata: md, @ "END" - }}, - updated_at = 'ASSERT', - metadata = concat(md, $metadata), - - :put sessions {{ - {all_fields}, metadata, updated_at - }} - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - update_query, - ] - - return ( - queries, - { - "session_update_vals": session_update_vals, - "session_id": str(session_id), - "developer_id": str(developer_id), - "metadata": metadata, - }, - ) diff --git a/agents-api/agents_api/models/session/prepare_session_data.py b/agents-api/agents_api/models/session/prepare_session_data.py deleted file mode 100644 index 83ee0c219..000000000 --- a/agents-api/agents_api/models/session/prepare_session_data.py +++ /dev/null @@ -1,235 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - 
-from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...common.protocol.sessions import SessionData, make_session -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - SessionData, - one=True, - transform=lambda d: { - "session": make_session( - **d["session"], - agents=[a["id"] for a in d["agents"]], - users=[u["id"] for u in d["users"]], - ), - }, -) -@cozo_query -@beartype -def prepare_session_data( - *, - developer_id: UUID, - session_id: UUID, -) -> tuple[list[str], dict]: - """Constructs and executes a datalog query to retrieve session data from the 'cozodb' database. - - Parameters: - developer_id (UUID): The developer's unique identifier. - session_id (UUID): The session's unique identifier. 
- """ - session_id = str(session_id) - developer_id = str(developer_id) - - # This query retrieves session information by using `input` to pass parameters, - get_query = """ - input[session_id, developer_id] <- [[ - to_uuid($session_id), - to_uuid($developer_id), - ]] - - participants[collect(participant_id), participant_type] := - input[session_id, developer_id], - *session_lookup{ - session_id, - participant_id, - participant_type, - } - - agents[agent_ids] := participants[agent_ids, "agent"] - - # We have to do this dance because users can be zero or more - users[user_ids] := - participants[user_ids, "user"] - - users[user_ids] := - not participants[_, "user"], - user_ids = [] - - settings_data[agent_id, settings] := - *agent_default_settings { - agent_id, - frequency_penalty, - presence_penalty, - length_penalty, - repetition_penalty, - top_p, - temperature, - min_p, - preset, - }, - settings = { - "frequency_penalty": frequency_penalty, - "presence_penalty": presence_penalty, - "length_penalty": length_penalty, - "repetition_penalty": repetition_penalty, - "top_p": top_p, - "temperature": temperature, - "min_p": min_p, - "preset": preset, - } - - agent_data[collect(record)] := - input[session_id, developer_id], - agents[agent_ids], - agent_id in agent_ids, - *agents{ - developer_id, - agent_id, - model, - name, - about, - created_at, - updated_at, - metadata, - instructions, - }, - settings_data[agent_id, default_settings], - record = { - "id": agent_id, - "name": name, - "model": model, - "about": about, - "created_at": created_at, - "updated_at": updated_at, - "metadata": metadata, - "default_settings": default_settings, - "instructions": instructions, - } - - # Version where we don't have default settings - agent_data[collect(record)] := - input[session_id, developer_id], - agents[agent_ids], - agent_id in agent_ids, - *agents{ - developer_id, - agent_id, - model, - name, - about, - created_at, - updated_at, - metadata, - instructions, - }, - not 
settings_data[agent_id, _], - record = { - "id": agent_id, - "name": name, - "model": model, - "about": about, - "created_at": created_at, - "updated_at": updated_at, - "metadata": metadata, - "default_settings": {}, - "instructions": instructions, - } - - user_data[collect(record)] := - input[session_id, developer_id], - users[user_ids], - user_id in user_ids, - *users{ - developer_id, - user_id, - name, - about, - created_at, - updated_at, - metadata, - }, - record = { - "id": user_id, - "name": name, - "about": about, - "created_at": created_at, - "updated_at": updated_at, - "metadata": metadata, - } - - session_data[record] := - input[session_id, developer_id], - *sessions{ - developer_id, - session_id, - situation, - summary, - created_at, - updated_at: validity, - metadata, - render_templates, - token_budget, - context_overflow, - @ "END" - }, - updated_at = to_int(validity), - record = { - "id": session_id, - "situation": situation, - "summary": summary, - "created_at": created_at, - "updated_at": updated_at, - "metadata": metadata, - "render_templates": render_templates, - "token_budget": token_budget, - "context_overflow": context_overflow, - } - - ?[ - agents, - users, - session, - ] := - session_data[session], - user_data[users], - agent_data[agents] - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - get_query, - ] - - return ( - queries, - {"developer_id": developer_id, "session_id": session_id}, - ) diff --git a/agents-api/agents_api/models/session/update_session.py b/agents-api/agents_api/models/session/update_session.py deleted file mode 100644 index cc8b61f16..000000000 --- a/agents-api/agents_api/models/session/update_session.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import Any, List, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import 
QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateSessionRequest -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -_fields: List[str] = [ - "situation", - "summary", - "metadata", - "created_at", - "session_id", - "developer_id", -] - -# TODO: Add support for updating `render_templates` field - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: { - "id": d["session_id"], - "updated_at": d.pop("updated_at")[0], - "jobs": [], - **d, - }, - _kind="inserted", -) -@cozo_query -@increase_counter("update_session") -@beartype -def update_session( - *, - session_id: UUID, - developer_id: UUID, - data: UpdateSessionRequest, -) -> tuple[list[str], dict]: - """ - Updates a session with the provided data. - - Parameters: - session_id (UUID): The unique identifier of the session to update. - developer_id (UUID): The unique identifier of the developer associated with the session. - data (UpdateSessionRequest): The data to update the session with. - - Returns: - ResourceUpdatedResponse: The updated session. - """ - - update_data = data.model_dump(exclude_unset=True) - - session_update_cols, session_update_vals = cozo_process_mutate_data( - {k: v for k, v in update_data.items() if v is not None} - ) - - # Prepare lists of columns for the query. 
- session_update_cols_lst = session_update_cols.split(",") - all_fields_lst = list(set(session_update_cols_lst).union(set(_fields))) - all_fields = ", ".join(all_fields_lst) - rest_fields = ", ".join( - list( - set(all_fields_lst) - - set([k for k, v in update_data.items() if v is not None]) - ) - ) - - # Construct the datalog query for updating session information. - update_query = f""" - input[{session_update_cols}] <- $session_update_vals - ids[session_id, developer_id] <- [[to_uuid($session_id), to_uuid($developer_id)]] - - ?[{all_fields}, updated_at] := - input[{session_update_cols}], - ids[session_id, developer_id], - *sessions{{ - {rest_fields}, @ "END" - }}, - updated_at = 'ASSERT' - - :put sessions {{ - {all_fields}, updated_at - }} - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - update_query, - ] - - return ( - queries, - { - "session_update_vals": session_update_vals, - "session_id": str(session_id), - "developer_id": str(developer_id), - }, - ) diff --git a/agents-api/agents_api/models/task/__init__.py b/agents-api/agents_api/models/task/__init__.py deleted file mode 100644 index 2eaff3ab3..000000000 --- a/agents-api/agents_api/models/task/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# ruff: noqa: F401, F403, F405 - -from .create_or_update_task import create_or_update_task -from .create_task import create_task -from .delete_task import delete_task -from .get_task import get_task -from .list_tasks import list_tasks -from .patch_task import patch_task -from .update_task import update_task diff --git a/agents-api/agents_api/models/task/create_or_update_task.py b/agents-api/agents_api/models/task/create_or_update_task.py deleted file mode 100644 index 1f615a3ad..000000000 --- a/agents-api/agents_api/models/task/create_or_update_task.py +++ /dev/null @@ -1,109 +0,0 @@ -""" -This module contains the functionality for creating a new 
Task in the 'cozodb` database. -It constructs and executes a datalog query to insert Task data. -""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ( - CreateOrUpdateTaskRequest, - ResourceUpdatedResponse, -) -from ...common.protocol.tasks import task_to_spec -from ...common.utils.cozo import cozo_process_mutate_data -from ...common.utils.datetime import utcnow -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: { - "id": d["task_id"], - "jobs": [], - "updated_at": d["updated_at_ms"][0] / 1000, - **d, - }, -) -@cozo_query -@increase_counter("create_or_update_task") -@beartype -def create_or_update_task( - *, - developer_id: UUID, - agent_id: UUID, - task_id: UUID, - data: CreateOrUpdateTaskRequest, -) -> tuple[list[str], dict]: - developer_id = str(developer_id) - agent_id = str(agent_id) - task_id = str(task_id) - - data.metadata = data.metadata or {} - data.input_schema = data.input_schema or {} - - task_data = task_to_spec(data).model_dump(exclude_none=True, mode="json") - task_data.pop("task_id", None) - task_data["created_at"] = utcnow().timestamp() - - columns, values = cozo_process_mutate_data(task_data) - - update_query = f""" - input[{columns}] <- $values - ids[agent_id, task_id] := - agent_id = to_uuid($agent_id), - task_id = to_uuid($task_id) - - 
?[updated_at_ms, agent_id, task_id, {columns}] := - ids[agent_id, task_id], - input[{columns}], - updated_at_ms = [floor(now() * 1000), true] - - :put tasks {{ - agent_id, - task_id, - updated_at_ms, - {columns}, - }} - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - update_query, - ] - - return ( - queries, - { - "values": values, - "agent_id": agent_id, - "task_id": task_id, - }, - ) diff --git a/agents-api/agents_api/models/task/create_task.py b/agents-api/agents_api/models/task/create_task.py deleted file mode 100644 index ab68a5b0c..000000000 --- a/agents-api/agents_api/models/task/create_task.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -This module contains the functionality for creating a new Task in the 'cozodb` database. -It constructs and executes a datalog query to insert Task data. -""" - -from typing import Any, TypeVar -from uuid import UUID, uuid4 - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ( - CreateTaskRequest, - ResourceCreatedResponse, -) -from ...common.protocol.tasks import task_to_spec -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceCreatedResponse, - one=True, - transform=lambda d: { - "id": d["task_id"], - "jobs": [], - "created_at": d["created_at"], - **d, - }, -) 
-@cozo_query -@increase_counter("create_task") -@beartype -def create_task( - *, - developer_id: UUID, - agent_id: UUID, - task_id: UUID | None = None, - data: CreateTaskRequest, -) -> tuple[list[str], dict]: - """ - Creates a new task. - - Parameters: - developer_id (UUID): The unique identifier of the developer associated with the task. - agent_id (UUID): The unique identifier of the agent associated with the task. - task_id (UUID | None): The unique identifier of the task. If not provided, a new UUID will be generated. - data (CreateTaskRequest): The data to create the task with. - - Returns: - ResourceCreatedResponse: The created task. - """ - - data.metadata = data.metadata or {} - data.input_schema = data.input_schema or {} - - task_id = task_id or uuid4() - task_spec = task_to_spec(data) - - # Prepares the update data by filtering out None values and adding user_id and developer_id. - columns, values = cozo_process_mutate_data( - { - **task_spec.model_dump(exclude_none=True, mode="json"), - "task_id": str(task_id), - "agent_id": str(agent_id), - } - ) - - create_query = f""" - input[{columns}] <- $values - ?[{columns}, updated_at_ms, created_at] := - input[{columns}], - updated_at_ms = [floor(now() * 1000), true], - created_at = now(), - - :insert tasks {{ - {columns}, - updated_at_ms, - created_at, - }} - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - create_query, - ] - - return ( - queries, - { - "agent_id": str(agent_id), - "values": values, - }, - ) diff --git a/agents-api/agents_api/models/task/delete_task.py b/agents-api/agents_api/models/task/delete_task.py deleted file mode 100644 index 10c377a25..000000000 --- a/agents-api/agents_api/models/task/delete_task.py +++ /dev/null @@ -1,91 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client 
import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceDeletedResponse -from ...common.utils.datetime import utcnow -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceDeletedResponse, - one=True, - transform=lambda d: { - "id": UUID(d.pop("task_id")), - "jobs": [], - "deleted_at": utcnow(), - **d, - }, - _kind="deleted", -) -@cozo_query -@beartype -def delete_task( - *, - developer_id: UUID, - agent_id: UUID, - task_id: UUID, -) -> tuple[list[str], dict]: - """ - Deletes a task. - - Parameters: - developer_id (UUID): The unique identifier of the developer associated with the task. - agent_id (UUID): The unique identifier of the agent associated with the task. - task_id (UUID): The unique identifier of the task to delete. - - Returns: - ResourceDeletedResponse: The deleted task. 
- """ - - delete_query = """ - input[agent_id, task_id] <- [[ - to_uuid($agent_id), - to_uuid($task_id), - ]] - - ?[agent_id, task_id, updated_at_ms] := - input[agent_id, task_id], - *tasks{ - agent_id, - task_id, - updated_at_ms, - } - - :delete tasks { - agent_id, - task_id, - updated_at_ms, - } - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - delete_query, - ] - - return (queries, {"agent_id": str(agent_id), "task_id": str(task_id)}) diff --git a/agents-api/agents_api/models/task/get_task.py b/agents-api/agents_api/models/task/get_task.py deleted file mode 100644 index 460fdc38b..000000000 --- a/agents-api/agents_api/models/task/get_task.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...common.protocol.tasks import spec_to_task -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(spec_to_task, one=True) -@cozo_query -@beartype -def get_task( - *, - developer_id: UUID, - task_id: UUID, -) -> tuple[list[str], dict]: - """ - Retrieves a task by its unique identifier. - - Parameters: - developer_id (UUID): The unique identifier of the developer associated with the task. - task_id (UUID): The unique identifier of the task to retrieve. - - Returns: - Task | CreateTaskRequest: The retrieved task. 
- """ - - get_query = """ - input[task_id] <- [[to_uuid($task_id)]] - - task_data[ - task_id, - agent_id, - name, - description, - input_schema, - tools, - inherit_tools, - workflows, - created_at, - updated_at, - metadata, - ] := - input[task_id], - *tasks { - agent_id, - task_id, - updated_at_ms, - name, - description, - input_schema, - tools, - inherit_tools, - workflows, - created_at, - metadata, - @ 'END' - }, - updated_at = to_int(updated_at_ms) / 1000 - - ?[ - id, - agent_id, - name, - description, - input_schema, - tools, - inherit_tools, - workflows, - created_at, - updated_at, - metadata, - ] := - task_data[ - id, - agent_id, - name, - description, - input_schema, - tools, - inherit_tools, - workflows, - created_at, - updated_at, - metadata, - ] - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "tasks", task_id=task_id, parents=[("agents", "agent_id")] - ), - get_query, - ] - - return (queries, {"task_id": str(task_id)}) diff --git a/agents-api/agents_api/models/task/list_tasks.py b/agents-api/agents_api/models/task/list_tasks.py deleted file mode 100644 index d873e817e..000000000 --- a/agents-api/agents_api/models/task/list_tasks.py +++ /dev/null @@ -1,130 +0,0 @@ -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...common.protocol.tasks import spec_to_task -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) 
-@wrap_in_class(spec_to_task) -@cozo_query -@beartype -def list_tasks( - *, - developer_id: UUID, - agent_id: UUID, - limit: int = 100, - offset: int = 0, - sort_by: Literal["created_at", "updated_at"] = "created_at", - direction: Literal["asc", "desc"] = "desc", -) -> tuple[list[str], dict]: - """ - Lists tasks for a given agent. - - Parameters: - developer_id (UUID): The unique identifier of the developer associated with the tasks. - agent_id (UUID): The unique identifier of the agent associated with the tasks. - limit (int): The maximum number of tasks to return. - offset (int): The number of tasks to skip before returning the results. - sort_by (Literal["created_at", "updated_at"]): The field to sort the tasks by. - direction (Literal["asc", "desc"]): The direction to sort the tasks in. - - Returns: - Task[] | CreateTaskRequest[]: The list of tasks. - """ - - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - list_query = f""" - input[agent_id] <- [[to_uuid($agent_id)]] - - task_data[ - task_id, - agent_id, - name, - description, - input_schema, - tools, - inherit_tools, - workflows, - created_at, - updated_at, - metadata, - ] := - input[agent_id], - *tasks {{ - agent_id, - task_id, - updated_at_ms, - name, - description, - input_schema, - tools, - inherit_tools, - workflows, - created_at, - metadata, - @ 'END' - }}, - updated_at = to_int(updated_at_ms) / 1000 - - ?[ - task_id, - agent_id, - name, - description, - input_schema, - tools, - inherit_tools, - workflows, - created_at, - updated_at, - metadata, - ] := - task_data[ - task_id, - agent_id, - name, - description, - input_schema, - tools, - inherit_tools, - workflows, - created_at, - updated_at, - metadata, - ] - - :limit $limit - :offset $offset - :sort {sort} - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - list_query, - ] - - return (queries, {"agent_id": str(agent_id), "limit": limit, 
"offset": offset}) diff --git a/agents-api/agents_api/models/task/patch_task.py b/agents-api/agents_api/models/task/patch_task.py deleted file mode 100644 index 178b9daa3..000000000 --- a/agents-api/agents_api/models/task/patch_task.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -This module contains the functionality for creating a new Task in the 'cozodb` database. -It constructs and executes a datalog query to insert Task data. -""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import PatchTaskRequest, ResourceUpdatedResponse, TaskSpec -from ...common.protocol.tasks import task_to_spec -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: { - "id": d["task_id"], - "jobs": [], - "updated_at": d["updated_at_ms"][0] / 1000, - **d, - }, - _kind="inserted", -) -@cozo_query -@increase_counter("patch_task") -@beartype -def patch_task( - *, - developer_id: UUID, - agent_id: UUID, - task_id: UUID, - data: PatchTaskRequest, -) -> tuple[list[str], dict]: - developer_id = str(developer_id) - agent_id = str(agent_id) - task_id = str(task_id) - - data.input_schema = data.input_schema or {} - task_data = task_to_spec(data, exclude_none=True, exclude_unset=True).model_dump( - exclude_none=True, exclude_unset=True - ) - 
task_data.pop("task_id", None) - - assert len(task_data), "No data provided to update task" - metadata = task_data.pop("metadata", {}) - columns, values = cozo_process_mutate_data(task_data) - - all_columns = list(TaskSpec.model_fields.keys()) - all_columns.remove("id") - all_columns.remove("main") - - missing_columns = ( - set(all_columns) - - set(columns.split(",")) - - {"metadata", "created_at", "updated_at"} - ) - missing_columns_str = ",".join(missing_columns) - - patch_query = f""" - input[{columns}] <- $values - ids[agent_id, task_id] := - agent_id = to_uuid($agent_id), - task_id = to_uuid($task_id) - - original[created_at, metadata, {missing_columns_str}] := - ids[agent_id, task_id], - *tasks{{ - agent_id, - task_id, - created_at, - metadata, - {missing_columns_str}, - }} - - ?[created_at, updated_at_ms, agent_id, task_id, metadata, {columns}, {missing_columns_str}] := - ids[agent_id, task_id], - input[{columns}], - original[created_at, _metadata, {missing_columns_str}], - updated_at_ms = [floor(now() * 1000), true], - metadata = _metadata ++ $metadata - - :put tasks {{ - agent_id, - task_id, - created_at, - updated_at_ms, - metadata, - {columns}, {missing_columns_str} - }} - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - patch_query, - ] - - return ( - queries, - { - "values": values, - "agent_id": agent_id, - "task_id": task_id, - "metadata": metadata, - }, - ) diff --git a/agents-api/agents_api/models/task/update_task.py b/agents-api/agents_api/models/task/update_task.py deleted file mode 100644 index cd98d85d5..000000000 --- a/agents-api/agents_api/models/task/update_task.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -This module contains the functionality for creating a new Task in the 'cozodb` database. -It constructs and executes a datalog query to insert Task data. 
-""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateTaskRequest -from ...common.protocol.tasks import task_to_spec -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: { - "id": d["task_id"], - "jobs": [], - "updated_at": d["updated_at_ms"][0] / 1000, - **d, - }, -) -@cozo_query -@increase_counter("update_task") -@beartype -def update_task( - *, - developer_id: UUID, - agent_id: UUID, - task_id: UUID, - data: UpdateTaskRequest, -) -> tuple[list[str], dict]: - """ - Updates a task. - - Parameters: - developer_id (UUID): The unique identifier of the developer associated with the task. - agent_id (UUID): The unique identifier of the agent associated with the task. - task_id (UUID): The unique identifier of the task to update. - data (UpdateTaskRequest): The data to update the task with. - - Returns: - ResourceUpdatedResponse: The updated task. 
- """ - - developer_id = str(developer_id) - agent_id = str(agent_id) - task_id = str(task_id) - - data.metadata = data.metadata or {} - data.input_schema = data.input_schema or {} - - task_data = task_to_spec(data, exclude_none=True, exclude_unset=True).model_dump( - exclude_none=True, exclude_unset=True - ) - task_data.pop("task_id", None) - - columns, values = cozo_process_mutate_data(task_data) - - update_query = f""" - input[{columns}] <- $values - ids[agent_id, task_id] := - agent_id = to_uuid($agent_id), - task_id = to_uuid($task_id) - - original[created_at] := - ids[agent_id, task_id], - *tasks{{ - agent_id, - task_id, - created_at, - }} - - ?[created_at, updated_at_ms, agent_id, task_id, {columns}] := - ids[agent_id, task_id], - input[{columns}], - original[created_at], - updated_at_ms = [floor(now() * 1000), true] - - :put tasks {{ - agent_id, - task_id, - created_at, - updated_at_ms, - {columns}, - }} - - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - update_query, - ] - - return ( - queries, - { - "values": values, - "agent_id": agent_id, - "task_id": task_id, - }, - ) diff --git a/agents-api/agents_api/models/tools/create_tools.py b/agents-api/agents_api/models/tools/create_tools.py deleted file mode 100644 index 9b2be387a..000000000 --- a/agents-api/agents_api/models/tools/create_tools.py +++ /dev/null @@ -1,135 +0,0 @@ -"""This module contains functions for creating tools in the CozoDB database.""" - -from typing import Any, TypeVar -from uuid import UUID, uuid4 - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import CreateToolRequest, Tool -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - 
verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - AssertionError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Tool, - transform=lambda d: { - "id": UUID(d.pop("tool_id")), - d["type"]: d.pop("spec"), - **d, - }, - _kind="inserted", -) -@cozo_query -@increase_counter("create_tools") -@beartype -def create_tools( - *, - developer_id: UUID, - agent_id: UUID, - data: list[CreateToolRequest], - ignore_existing: bool = False, -) -> tuple[list[str], dict]: - """ - Constructs a datalog query for inserting tool records into the 'agent_functions' relation in the CozoDB. - - Parameters: - agent_id (UUID): The unique identifier for the agent. - data (list[CreateToolRequest]): A list of function definitions to be inserted. 
- - Returns: - list[Tool] - """ - - assert all( - getattr(tool, tool.type) is not None - for tool in data - if hasattr(tool, tool.type) - ), "Tool spec must be passed" - - tools_data = [ - [ - str(agent_id), - str(uuid4()), - tool.type, - tool.name, - getattr(tool, tool.type) and getattr(tool, tool.type).model_dump(), - tool.description if hasattr(tool, "description") else None, - ] - for tool in data - ] - - ensure_tool_name_unique_query = """ - input[agent_id, tool_id, type, name, spec, description] <- $records - ?[tool_id] := - input[agent_id, _, type, name, _, _], - *tools{ - agent_id: to_uuid(agent_id), - tool_id, - type, - name, - spec, - description, - } - - :limit 1 - :assert none - """ - - # Datalog query for inserting new tool records into the 'tools' relation - create_query = """ - input[agent_id, tool_id, type, name, spec, description] <- $records - - # Do not add duplicate - ?[agent_id, tool_id, type, name, spec, description] := - input[agent_id, tool_id, type, name, spec, description], - not *tools{ - agent_id: to_uuid(agent_id), - type, - name, - } - - :insert tools { - agent_id, - tool_id, - type, - name, - spec, - description, - } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - create_query, - ] - - if not ignore_existing: - queries.insert( - -1, - ensure_tool_name_unique_query, - ) - - return (queries, {"records": tools_data}) diff --git a/agents-api/agents_api/models/tools/delete_tool.py b/agents-api/agents_api/models/tools/delete_tool.py deleted file mode 100644 index c79cdfd29..000000000 --- a/agents-api/agents_api/models/tools/delete_tool.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import 
ResourceDeletedResponse -from ...common.utils.datetime import utcnow -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceDeletedResponse, - one=True, - transform=lambda d: {"id": d["tool_id"], "deleted_at": utcnow(), "jobs": [], **d}, - _kind="deleted", -) -@cozo_query -@beartype -def delete_tool( - *, - developer_id: UUID, - agent_id: UUID, - tool_id: UUID, -) -> tuple[list[str], dict]: - agent_id = str(agent_id) - tool_id = str(tool_id) - - delete_query = """ - # Delete function - ?[tool_id, agent_id] <- [[ - to_uuid($tool_id), - to_uuid($agent_id), - ]] - - :delete tools { - tool_id, - agent_id, - } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - delete_query, - ] - - return (queries, {"tool_id": tool_id, "agent_id": agent_id}) diff --git a/agents-api/agents_api/models/tools/get_tool.py b/agents-api/agents_api/models/tools/get_tool.py deleted file mode 100644 index 465fd2efe..000000000 --- a/agents-api/agents_api/models/tools/get_tool.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Tool -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - 
-@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Tool, - transform=lambda d: { - "id": UUID(d.pop("tool_id")), - d["type"]: d.pop("spec"), - **d, - }, - one=True, -) -@cozo_query -@beartype -def get_tool( - *, - developer_id: UUID, - agent_id: UUID, - tool_id: UUID, -) -> tuple[list[str], dict]: - agent_id = str(agent_id) - tool_id = str(tool_id) - - get_query = """ - input[agent_id, tool_id] <- [[to_uuid($agent_id), to_uuid($tool_id)]] - - ?[ - agent_id, - tool_id, - type, - name, - spec, - updated_at, - created_at, - ] := input[agent_id, tool_id], - *tools { - agent_id, - tool_id, - name, - type, - spec, - updated_at, - created_at, - } - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - get_query, - ] - - return (queries, {"agent_id": agent_id, "tool_id": tool_id}) diff --git a/agents-api/agents_api/models/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/models/tools/get_tool_args_from_metadata.py deleted file mode 100644 index 2cdb92cb9..000000000 --- a/agents-api/agents_api/models/tools/get_tool_args_from_metadata.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import Literal -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - - -def tool_args_for_task( - *, - developer_id: UUID, - agent_id: UUID, - task_id: UUID, - tool_type: Literal["integration", "api_call"] = "integration", - arg_type: Literal["args", "setup"] = "args", -) -> tuple[list[str], dict]: - agent_id = 
str(agent_id) - task_id = str(task_id) - - get_query = f""" - input[agent_id, task_id] <- [[to_uuid($agent_id), to_uuid($task_id)]] - - ?[values] := - input[agent_id, task_id], - *tasks {{ - task_id, - metadata: task_metadata, - }}, - *agents {{ - agent_id, - metadata: agent_metadata, - }}, - task_{arg_type} = get(task_metadata, "x-{tool_type}-{arg_type}", {{}}), - agent_{arg_type} = get(agent_metadata, "x-{tool_type}-{arg_type}", {{}}), - - # Right values overwrite left values - # See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat - values = concat(agent_{arg_type}, task_{arg_type}), - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "tasks", task_id=task_id, parents=[("agents", "agent_id")] - ), - get_query, - ] - - return (queries, {"agent_id": agent_id, "task_id": task_id}) - - -def tool_args_for_session( - *, - developer_id: UUID, - session_id: UUID, - agent_id: UUID, - arg_type: Literal["args", "setup"] = "args", - tool_type: Literal["integration", "api_call"] = "integration", -) -> tuple[list[str], dict]: - session_id = str(session_id) - - get_query = f""" - input[session_id, agent_id] <- [[to_uuid($session_id), to_uuid($agent_id)]] - - ?[values] := - input[session_id, agent_id], - *sessions {{ - session_id, - metadata: session_metadata, - }}, - *agents {{ - agent_id, - metadata: agent_metadata, - }}, - session_{arg_type} = get(session_metadata, "x-{tool_type}-{arg_type}", {{}}), - agent_{arg_type} = get(agent_metadata, "x-{tool_type}-{arg_type}", {{}}), - - # Right values overwrite left values - # See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat - values = concat(agent_{arg_type}, session_{arg_type}), - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - get_query, - ] - - return (queries, {"agent_id": agent_id, 
"session_id": session_id}) - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(dict, transform=lambda x: x["values"], one=True) -@cozo_query -@beartype -def get_tool_args_from_metadata( - *, - developer_id: UUID, - agent_id: UUID, - session_id: UUID | None = None, - task_id: UUID | None = None, - tool_type: Literal["integration", "api_call"] = "integration", - arg_type: Literal["args", "setup", "headers"] = "args", -) -> tuple[list[str], dict]: - common: dict = dict( - developer_id=developer_id, - agent_id=agent_id, - tool_type=tool_type, - arg_type=arg_type, - ) - - match session_id, task_id: - case (None, task_id) if task_id is not None: - return tool_args_for_task( - **common, - task_id=task_id, - ) - - case (session_id, None) if session_id is not None: - return tool_args_for_session( - **common, - session_id=session_id, - ) - - case (_, _): - raise ValueError("Either session_id or task_id must be provided") diff --git a/agents-api/agents_api/models/tools/list_tools.py b/agents-api/agents_api/models/tools/list_tools.py deleted file mode 100644 index 727bf8028..000000000 --- a/agents-api/agents_api/models/tools/list_tools.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Tool -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, 
status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Tool, - transform=lambda d: { - d["type"]: { - **d.pop("spec"), - "name": d["name"], - "description": d["description"], - }, - **d, - }, -) -@cozo_query -@beartype -def list_tools( - *, - developer_id: UUID, - agent_id: UUID, - limit: int = 100, - offset: int = 0, - sort_by: Literal["created_at", "updated_at"] = "created_at", - direction: Literal["asc", "desc"] = "desc", -) -> tuple[list[str], dict]: - agent_id = str(agent_id) - - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - list_query = f""" - input[agent_id] <- [[to_uuid($agent_id)]] - - ?[ - agent_id, - id, - name, - type, - spec, - description, - updated_at, - created_at, - ] := input[agent_id], - *tools {{ - agent_id, - tool_id: id, - name, - type, - spec, - description, - updated_at, - created_at, - }} - - :limit $limit - :offset $offset - :sort {sort} - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - list_query, - ] - - return ( - queries, - {"agent_id": agent_id, "limit": limit, "offset": offset}, - ) diff --git a/agents-api/agents_api/models/tools/patch_tool.py b/agents-api/agents_api/models/tools/patch_tool.py deleted file mode 100644 index bc49b8121..000000000 --- a/agents-api/agents_api/models/tools/patch_tool.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import PatchToolRequest, ResourceUpdatedResponse -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, 
-) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: {"id": d["tool_id"], "jobs": [], **d}, - _kind="inserted", -) -@cozo_query -@increase_counter("patch_tool") -@beartype -def patch_tool( - *, developer_id: UUID, agent_id: UUID, tool_id: UUID, data: PatchToolRequest -) -> tuple[list[str], dict]: - """ - Execute the datalog query and return the results as a DataFrame - Updates the tool information for a given agent and tool ID in the 'cozodb' database. - - Parameters: - agent_id (UUID): The unique identifier of the agent. - tool_id (UUID): The unique identifier of the tool to be updated. - data (PatchToolRequest): The request payload containing the updated tool information. - - Returns: - ResourceUpdatedResponse: The updated tool data. 
- """ - - agent_id = str(agent_id) - tool_id = str(tool_id) - - # Extract the tool data from the payload - patch_data = data.model_dump(exclude_none=True) - - # Assert that only one of the tool type fields is present - tool_specs = [ - (tool_type, patch_data.get(tool_type)) - for tool_type in ["function", "integration", "system", "api_call"] - if patch_data.get(tool_type) is not None - ] - - assert len(tool_specs) <= 1, "Invalid tool update" - tool_type, tool_spec = tool_specs[0] if tool_specs else (None, None) - - if tool_type is not None: - patch_data["type"] = patch_data.get("type", tool_type) - assert patch_data["type"] == tool_type, "Invalid tool update" - - tool_spec = tool_spec or {} - if tool_spec: - del patch_data[tool_type] - - tool_cols, tool_vals = cozo_process_mutate_data( - { - **patch_data, - "agent_id": agent_id, - "tool_id": tool_id, - } - ) - - # Construct the datalog query for updating the tool information - patch_query = f""" - input[{tool_cols}] <- $input - - ?[{tool_cols}, spec, updated_at] := - *tools {{ - agent_id: to_uuid($agent_id), - tool_id: to_uuid($tool_id), - spec: old_spec, - }}, - input[{tool_cols}], - spec = concat(old_spec, $spec), - updated_at = now() - - :update tools {{ {tool_cols}, spec, updated_at }} - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - patch_query, - ] - - return ( - queries, - dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id), - ) diff --git a/agents-api/agents_api/models/tools/update_tool.py b/agents-api/agents_api/models/tools/update_tool.py deleted file mode 100644 index ef700a5f6..000000000 --- a/agents-api/agents_api/models/tools/update_tool.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import 
ValidationError - -from ...autogen.openapi_model import ( - ResourceUpdatedResponse, - UpdateToolRequest, -) -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: {"id": d["tool_id"], "jobs": [], **d}, - _kind="inserted", -) -@cozo_query -@increase_counter("update_tool") -@beartype -def update_tool( - *, - developer_id: UUID, - agent_id: UUID, - tool_id: UUID, - data: UpdateToolRequest, - **kwargs, -) -> tuple[list[str], dict]: - agent_id = str(agent_id) - tool_id = str(tool_id) - - # Extract the tool data from the payload - update_data = data.model_dump(exclude_none=True) - - # Assert that only one of the tool type fields is present - tool_specs = [ - (tool_type, update_data.get(tool_type)) - for tool_type in ["function", "integration", "system", "api_call"] - if update_data.get(tool_type) is not None - ] - - assert len(tool_specs) <= 1, "Invalid tool update" - tool_type, tool_spec = tool_specs[0] if tool_specs else (None, None) - - if tool_type is not None: - update_data["type"] = update_data.get("type", tool_type) - assert update_data["type"] == tool_type, "Invalid tool update" - - update_data["spec"] = tool_spec - del update_data[tool_type] - - tool_cols, tool_vals = cozo_process_mutate_data( - { - **update_data, - "agent_id": agent_id, - "tool_id": tool_id, - } - ) - - # Construct the datalog query for updating the tool information - patch_query = f""" - input[{tool_cols}] <- $input - - 
?[{tool_cols}, created_at, updated_at] := - *tools {{ - agent_id: to_uuid($agent_id), - tool_id: to_uuid($tool_id), - created_at - }}, - input[{tool_cols}], - updated_at = now() - - :put tools {{ {tool_cols}, created_at, updated_at }} - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - patch_query, - ] - - return ( - queries, - dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id), - ) diff --git a/agents-api/agents_api/models/user/__init__.py b/agents-api/agents_api/models/user/__init__.py deleted file mode 100644 index 5ae76865f..000000000 --- a/agents-api/agents_api/models/user/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -This module is responsible for managing user data in the CozoDB database. It provides functionalities to create, retrieve, list, and update user information. - -Functions: -- create_user_query: Creates a new user in the CozoDB database, accepting parameters such as user ID, developer ID, name, about, and optional metadata. -- get_user_query: Retrieves a user's information from the CozoDB database by their user ID and developer ID. -- list_users_query: Lists users associated with a specific developer, with support for pagination and metadata-based filtering. -- patch_user_query: Updates a user's information in the CozoDB database, allowing for changes to fields such as name, about, and metadata. 
-""" - -# ruff: noqa: F401, F403, F405 - -from .create_or_update_user import create_or_update_user -from .create_user import create_user -from .get_user import get_user -from .list_users import list_users -from .patch_user import patch_user -from .update_user import update_user diff --git a/agents-api/agents_api/models/user/create_or_update_user.py b/agents-api/agents_api/models/user/create_or_update_user.py deleted file mode 100644 index 3e9b1f3a6..000000000 --- a/agents-api/agents_api/models/user/create_or_update_user.py +++ /dev/null @@ -1,125 +0,0 @@ -""" -This module contains the functionality for creating users in the CozoDB database. -It includes functions to construct and execute datalog queries for inserting new user records. -""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import CreateOrUpdateUserRequest, User -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. This likely means the data provided is of an incorrect type (e.g., string instead of integer). 
Please review the input and try again.", - ), - } -) -@wrap_in_class(User, one=True, transform=lambda d: {"id": UUID(d.pop("user_id")), **d}) -@cozo_query -@increase_counter("create_or_update_user") -@beartype -def create_or_update_user( - *, - developer_id: UUID, - user_id: UUID, - data: CreateOrUpdateUserRequest, -) -> tuple[list[str], dict]: - """ - Constructs and executes a datalog query to create a new user in the database. - - Parameters: - user_id (UUID): The unique identifier for the user. - developer_id (UUID): The unique identifier for the developer creating the user. - name (str): The name of the user. - about (str): A description of the user. - metadata (dict, optional): A dictionary of metadata for the user. Defaults to an empty dict. - client (CozoClient, optional): The CozoDB client instance to use for the query. Defaults to a preconfigured client instance. - - Returns: - User: The newly created user record. - """ - - # Extract the user data from the payload - data.metadata = data.metadata or {} - - user_data = data.model_dump() - - # Create the user - # Construct a query to insert the new user record into the users table - user_query = """ - input[user_id, developer_id, name, about, metadata, updated_at] <- [ - [$user_id, $developer_id, $name, $about, $metadata, now()] - ] - - ?[user_id, developer_id, name, about, metadata, created_at, updated_at] := - input[_user_id, developer_id, name, about, metadata, updated_at], - *users{ - developer_id, - user_id, - created_at, - }, - user_id = to_uuid(_user_id), - - ?[user_id, developer_id, name, about, metadata, created_at, updated_at] := - input[_user_id, developer_id, name, about, metadata, updated_at], - not *users{ - developer_id, - user_id, - }, created_at = now(), - user_id = to_uuid(_user_id), - - :put users { - developer_id, - user_id => - name, - about, - metadata, - created_at, - updated_at, - } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - user_query, - ] - - 
return ( - queries, - { - "user_id": str(user_id), - "developer_id": str(developer_id), - **user_data, - }, - ) diff --git a/agents-api/agents_api/models/user/create_user.py b/agents-api/agents_api/models/user/create_user.py deleted file mode 100644 index ba96bd2b5..000000000 --- a/agents-api/agents_api/models/user/create_user.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -This module contains the functionality for creating a new user in the CozoDB database. -It defines a query for inserting user data into the 'users' relation. -""" - -from typing import Any, TypeVar -from uuid import UUID, uuid4 - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import CreateUserRequest, User -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - lambda e: isinstance(e, QueryException) - and "asserted to return some results, but returned none" - in str(e): lambda *_: HTTPException( - detail="Developer not found. Please ensure the provided auth token (which refers to your developer_id) is valid and the developer has the necessary permissions to create an agent.", - status_code=403, - ), - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. 
This likely means the data provided is of an incorrect type (e.g., string instead of integer). Please review the input and try again.", - ), - } -) -@wrap_in_class( - User, - one=True, - transform=lambda d: {"id": UUID(d.pop("user_id")), **d}, - _kind="inserted", -) -@cozo_query -@increase_counter("create_user") -@beartype -def create_user( - *, - developer_id: UUID, - user_id: UUID | None = None, - data: CreateUserRequest, -) -> tuple[list[str], dict]: - """ - Constructs and executes a datalog query to create a new user in the CozoDB database. - - Parameters: - user_id (UUID): The unique identifier for the user. - developer_id (UUID): The unique identifier for the developer creating the user. - name (str): The name of the user. - about (str): A brief description about the user. - metadata (dict, optional): Additional metadata about the user. Defaults to an empty dict. - client (CozoClient, optional): The CozoDB client instance to run the query. Defaults to a pre-configured client instance. - - Returns: - pd.DataFrame: A DataFrame containing the result of the query execution. 
- """ - - user_id = user_id or uuid4() - data.metadata = data.metadata or {} - user_data = data.model_dump() - - create_query = """ - # Then create the user - ?[user_id, developer_id, name, about, metadata] <- [ - [to_uuid($user_id), to_uuid($developer_id), $name, $about, $metadata] - ] - - :insert users { - developer_id, - user_id => - name, - about, - metadata, - } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - create_query, - ] - - return ( - queries, - { - "user_id": str(user_id), - "developer_id": str(developer_id), - **user_data, - }, - ) diff --git a/agents-api/agents_api/models/user/delete_user.py b/agents-api/agents_api/models/user/delete_user.py deleted file mode 100644 index 7f08316be..000000000 --- a/agents-api/agents_api/models/user/delete_user.py +++ /dev/null @@ -1,116 +0,0 @@ -""" -This module contains the implementation of the delete_user_query function, which is responsible for deleting an user and its related default settings from the CozoDB database. -""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceDeletedResponse -from ...common.utils.datetime import utcnow -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - lambda e: isinstance(e, QueryException) - and "Developer does not exist" in str(e): lambda *_: HTTPException( - detail="The specified developer does not exist.", - status_code=403, - ), - lambda e: isinstance(e, QueryException) - and "Developer does not own resource" - in e.resp["display"]: lambda *_: HTTPException( - detail="The specified developer does not own the requested resource. 
Please verify the ownership or check if the developer ID is correct.", - status_code=404, - ), - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. This likely means the data provided is of an incorrect type (e.g., string instead of integer). Please review the input and try again.", - ), - } -) -@wrap_in_class( - ResourceDeletedResponse, - one=True, - transform=lambda d: { - "id": UUID(d.pop("user_id")), - "deleted_at": utcnow(), - "jobs": [], - }, - _kind="deleted", -) -@cozo_query -@beartype -def delete_user(*, developer_id: UUID, user_id: UUID) -> tuple[list[str], dict]: - """ - Constructs and returns a datalog query for deleting an user and its default settings from the database. - - Parameters: - developer_id (UUID): The UUID of the developer owning the user. - user_id (UUID): The UUID of the user to be deleted. - client (CozoClient, optional): An instance of the CozoClient to execute the query. - - Returns: - ResourceDeletedResponse: The response indicating the deletion of the user. 
- """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "users", user_id=user_id), - """ - # Delete docs - ?[owner_type, owner_id, doc_id] := - *docs{ - owner_id, - owner_type, - doc_id, - }, - owner_id = to_uuid($user_id), - owner_type = "user" - - :delete docs { - owner_type, - owner_id, - doc_id - } - :returning - """, - """ - # Delete the user - ?[user_id, developer_id] <- [[$user_id, $developer_id]] - - :delete users { - developer_id, - user_id - } - :returning - """, - ] - - return (queries, {"user_id": str(user_id), "developer_id": str(developer_id)}) diff --git a/agents-api/agents_api/models/user/get_user.py b/agents-api/agents_api/models/user/get_user.py deleted file mode 100644 index 69b3da883..000000000 --- a/agents-api/agents_api/models/user/get_user.py +++ /dev/null @@ -1,107 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import User -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - lambda e: isinstance(e, QueryException) - and "Developer does not exist" in str(e): lambda *_: HTTPException( - detail="The specified developer does not exist.", - status_code=403, - ), - lambda e: isinstance(e, QueryException) - and "Developer does not own resource" - in e.resp["display"]: lambda *_: HTTPException( - detail="The specified developer does not own the requested resource. Please verify the ownership or check if the developer ID is correct.", - status_code=404, - ), - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. 
This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. This likely means the data provided is of an incorrect type (e.g., string instead of integer). Please review the input and try again.", - ), - } -) -@wrap_in_class(User, one=True) -@cozo_query -@beartype -def get_user( - *, - developer_id: UUID, - user_id: UUID, -) -> tuple[list[str], dict]: - """ - Retrieves a user by their unique identifier. - - - Parameters: - developer_id (UUID): The unique identifier of the developer associated with the user. - user_id (UUID): The unique identifier of the user to retrieve. - - Returns: - User: The retrieved user. - """ - - # Convert UUIDs to strings for query compatibility. 
- user_id = str(user_id) - developer_id = str(developer_id) - - get_query = """ - input[developer_id, user_id] <- [[to_uuid($developer_id), to_uuid($user_id)]] - - ?[ - id, - name, - about, - created_at, - updated_at, - metadata, - ] := input[developer_id, id], - *users { - user_id: id, - developer_id, - name, - about, - created_at, - updated_at, - metadata, - } - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "users", user_id=user_id), - get_query, - ] - - return (queries, {"developer_id": developer_id, "user_id": user_id}) diff --git a/agents-api/agents_api/models/user/list_users.py b/agents-api/agents_api/models/user/list_users.py deleted file mode 100644 index f1e06adf4..000000000 --- a/agents-api/agents_api/models/user/list_users.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import User -from ...common.utils import json -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. 
This likely means the data provided is of an incorrect type (e.g., string instead of integer). Please review the input and try again.", - ), - } -) -@wrap_in_class(User) -@cozo_query -@beartype -def list_users( - *, - developer_id: UUID, - limit: int = 100, - offset: int = 0, - sort_by: Literal["created_at", "updated_at"] = "created_at", - direction: Literal["asc", "desc"] = "desc", - metadata_filter: dict[str, Any] = {}, -) -> tuple[list[str], dict]: - """ - Queries the 'cozodb' database to list users associated with a specific developer. - - Parameters: - developer_id (UUID): The unique identifier of the developer. - limit (int): The maximum number of users to return. Defaults to 100. - offset (int): The number of users to skip before starting to collect the result set. Defaults to 0. - sort_by (Literal["created_at", "updated_at"]): The field to sort the users by. Defaults to "created_at". - direction (Literal["asc", "desc"]): The direction to sort the users in. Defaults to "desc". - metadata_filter (dict[str, Any]): A dictionary representing filters to apply on user metadata. - - Returns: - pd.DataFrame: A DataFrame containing the queried user data. - """ - # Construct a filter string for the metadata based on the provided dictionary. - metadata_filter_str = ", ".join( - [ - f"metadata->{json.dumps(k)} == {json.dumps(v)}" - for k, v in metadata_filter.items() - ] - ) - - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - # Define the datalog query for retrieving user information based on the specified filters and sorting them by creation date in descending order. 
- list_query = f""" - input[developer_id] <- [[to_uuid($developer_id)]] - - ?[ - id, - name, - about, - created_at, - updated_at, - metadata, - ] := - input[developer_id], - *users {{ - user_id: id, - developer_id, - name, - about, - created_at, - updated_at, - metadata, - }}, - {metadata_filter_str} - - :limit $limit - :offset $offset - :sort {sort} - """ - - queries = [ - verify_developer_id_query(developer_id), - list_query, - ] - - # Execute the datalog query with the specified parameters and return the results as a DataFrame. - return ( - queries, - {"developer_id": str(developer_id), "limit": limit, "offset": offset}, - ) diff --git a/agents-api/agents_api/models/user/patch_user.py b/agents-api/agents_api/models/user/patch_user.py deleted file mode 100644 index bd3fc0246..000000000 --- a/agents-api/agents_api/models/user/patch_user.py +++ /dev/null @@ -1,121 +0,0 @@ -"""Module for generating datalog queries to update user information in the 'cozodb' database.""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import PatchUserRequest, ResourceUpdatedResponse -from ...common.utils.cozo import cozo_process_mutate_data -from ...common.utils.datetime import utcnow -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. 
This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. This likely means the data provided is of an incorrect type (e.g., string instead of integer). Please review the input and try again.", - ), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: {"id": d["user_id"], "jobs": [], **d}, - _kind="inserted", -) -@cozo_query -@increase_counter("patch_user") -@beartype -def patch_user( - *, - developer_id: UUID, - user_id: UUID, - data: PatchUserRequest, -) -> tuple[list[str], dict]: - """ - Generates a datalog query for updating a user's information. - - Parameters: - developer_id (UUID): The UUID of the developer. - user_id (UUID): The UUID of the user to be updated. - **update_data: Arbitrary keyword arguments representing the data to be updated. - - Returns: - tuple[str, dict]: A pandas DataFrame containing the results of the query execution. - """ - - update_data = data.model_dump(exclude_unset=True) - - # Prepare data for mutation by filtering out None values and adding system-generated fields. - metadata = update_data.pop("metadata", {}) or {} - user_update_cols, user_update_vals = cozo_process_mutate_data( - { - **{k: v for k, v in update_data.items() if v is not None}, - "user_id": str(user_id), - "developer_id": str(developer_id), - "updated_at": utcnow().timestamp(), - } - ) - - # Construct the datalog query for updating user information. 
- update_query = f""" - # update the user - input[{user_update_cols}] <- $user_update_vals - - ?[{user_update_cols}, metadata] := - input[{user_update_cols}], - *users:developer_id_metadata_user_id_idx {{ - developer_id: to_uuid($developer_id), - user_id: to_uuid($user_id), - metadata: md, - }}, - metadata = concat(md, $metadata) - - :update users {{ - {user_update_cols}, metadata - }} - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "users", user_id=user_id), - update_query, - ] - - return ( - queries, - { - "user_update_vals": user_update_vals, - "metadata": metadata, - "user_id": str(user_id), - "developer_id": str(developer_id), - }, - ) diff --git a/agents-api/agents_api/models/user/update_user.py b/agents-api/agents_api/models/user/update_user.py deleted file mode 100644 index 68e6e6c25..000000000 --- a/agents-api/agents_api/models/user/update_user.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateUserRequest -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass( - HTTPException, - status_code=400, - detail="A database query failed to return the expected results. This might occur if the requested resource doesn't exist or your query parameters are incorrect.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. 
Please check the provided data for missing or incorrect fields, and ensure it matches the required format.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. This likely means the data provided is of an incorrect type (e.g., string instead of integer). Please review the input and try again.", - ), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: {"id": d["user_id"], "jobs": [], **d}, - _kind="inserted", -) -@cozo_query -@increase_counter("update_user") -@beartype -def update_user( - *, developer_id: UUID, user_id: UUID, data: UpdateUserRequest -) -> tuple[list[str], dict]: - """ - Updates user information in the 'cozodb' database. - - Parameters: - developer_id (UUID): The developer's unique identifier. - user_id (UUID): The user's unique identifier. - client (CozoClient): The Cozo database client instance. - **update_data: Arbitrary keyword arguments representing the data to update. - - Returns: - pd.DataFrame: A DataFrame containing the result of the update operation. - """ - user_id = str(user_id) - developer_id = str(developer_id) - update_data = data.model_dump() - - # Prepares the update data by filtering out None values and adding user_id and developer_id. - user_update_cols, user_update_vals = cozo_process_mutate_data( - { - **{k: v for k, v in update_data.items() if v is not None}, - "user_id": user_id, - "developer_id": developer_id, - } - ) - - # Constructs the update operation for the user, setting new values and updating 'updated_at'. - update_query = f""" - # update the user - # This line updates the user's information based on the provided columns and values. 
- input[{user_update_cols}] <- $user_update_vals - original[created_at] := *users{{ - developer_id: to_uuid($developer_id), - user_id: to_uuid($user_id), - created_at, - }}, - - ?[created_at, updated_at, {user_update_cols}] := - input[{user_update_cols}], - original[created_at], - updated_at = now(), - - :put users {{ - created_at, - updated_at, - {user_update_cols} - }} - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "users", user_id=user_id), - update_query, - ] - - return ( - queries, - { - "user_update_vals": user_update_vals, - "developer_id": developer_id, - "user_id": user_id, - }, - ) diff --git a/agents-api/agents_api/models/utils.py b/agents-api/agents_api/models/utils.py deleted file mode 100644 index fc3f4e9b9..000000000 --- a/agents-api/agents_api/models/utils.py +++ /dev/null @@ -1,577 +0,0 @@ -import concurrent.futures -import inspect -import re -import time -from functools import partialmethod, wraps -from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeVar -from uuid import UUID - -import pandas as pd -from fastapi import HTTPException -from httpcore import ConnectError, NetworkError, TimeoutException -from httpx import ConnectError as HttpxConnectError -from httpx import RequestError -from pydantic import BaseModel -from requests.exceptions import ConnectionError, Timeout - -from ..common.utils.cozo import uuid_int_list_to_uuid4 -from ..env import do_verify_developer, do_verify_developer_owns_resource - -P = ParamSpec("P") -T = TypeVar("T") -ModelT = TypeVar("ModelT", bound=BaseModel) - - -def fix_uuid( - item: dict[str, Any], attr_regex: str = r"^(?:id|.*_id)$" -) -> dict[str, Any]: - # find the attributes that are ids - id_attrs = [ - attr for attr in item.keys() if re.match(attr_regex, attr) and item[attr] - ] - - if not id_attrs: - return item - - fixed = { - **item, - **{ - attr: uuid_int_list_to_uuid4(item[attr]) - for attr in id_attrs - if 
isinstance(item[attr], list) - }, - } - - return fixed - - -def fix_uuid_list( - items: list[dict[str, Any]], attr_regex: str = r"^(?:id|.*_id)$" -) -> list[dict[str, Any]]: - fixed = list(map(lambda item: fix_uuid(item, attr_regex), items)) - return fixed - - -def fix_uuid_if_present(item: Any, attr_regex: str = r"^(?:id|.*_id)$") -> Any: - match item: - case [dict(), *_]: - return fix_uuid_list(item, attr_regex) - - case dict(): - return fix_uuid(item, attr_regex) - - case _: - return item - - -def partialclass(cls, *args, **kwargs): - cls_signature = inspect.signature(cls) - bound = cls_signature.bind_partial(*args, **kwargs) - - # The `updated=()` argument is necessary to avoid a TypeError when using @wraps for a class - @wraps(cls, updated=()) - class NewCls(cls): - __init__ = partialmethod(cls.__init__, *bound.args, **bound.kwargs) - - return NewCls - - -def mark_session_updated_query(developer_id: UUID | str, session_id: UUID | str) -> str: - return f""" - input[developer_id, session_id] <- [[ - to_uuid("{str(developer_id)}"), - to_uuid("{str(session_id)}"), - ]] - - ?[ - developer_id, - session_id, - situation, - summary, - created_at, - metadata, - render_templates, - token_budget, - context_overflow, - updated_at, - ] := - input[developer_id, session_id], - *sessions {{ - session_id, - situation, - summary, - created_at, - metadata, - render_templates, - token_budget, - context_overflow, - @ 'END' - }}, - updated_at = [floor(now()), true] - - :put sessions {{ - developer_id, - session_id, - situation, - summary, - created_at, - metadata, - render_templates, - token_budget, - context_overflow, - updated_at, - }} - """ - - -def verify_developer_id_query(developer_id: UUID | str) -> str: - if not do_verify_developer: - return "?[exists] := exists = true" - - return f""" - matched[count(developer_id)] := - *developers{{ - developer_id, - }}, developer_id = to_uuid("{str(developer_id)}") - - ?[exists] := - matched[num], - exists = num > 0, - assert(exists, 
"Developer does not exist") - - :limit 1 - """ - - -def verify_developer_owns_resource_query( - developer_id: UUID | str, - resource: str, - parents: list[tuple[str, str]] | None = None, - **resource_id, -) -> str: - if not do_verify_developer_owns_resource: - return "?[exists] := exists = true" - - parents = parents or [] - resource_id_key, resource_id_value = next(iter(resource_id.items())) - - parents.append((resource, resource_id_key)) - parent_keys = ["developer_id", *map(lambda x: x[1], parents)] - - rule_head = f""" - found[count({resource_id_key})] := - developer_id = to_uuid("{str(developer_id)}"), - {resource_id_key} = to_uuid("{str(resource_id_value)}"), - """ - - rule_body = "" - for parent_key, (relation, key) in zip(parent_keys, parents): - rule_body += f""" - *{relation}{{ - {parent_key}, - {key}, - }}, - """ - - assertion = f""" - ?[exists] := - found[num], - exists = num > 0, - assert(exists, "Developer does not own resource {resource} with {resource_id_key} {resource_id_value}") - - :limit 1 - """ - - rule = rule_head + rule_body + assertion - return rule - - -def make_cozo_json_query(fields): - return ", ".join(f'"{field}": {field}' for field in fields).strip() - - -def cozo_query( - func: Callable[P, tuple[str | list[str | None], dict]] | None = None, - debug: bool | None = None, - only_on_error: bool = False, - timeit: bool = False, -): - def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]): - """ - Decorator that wraps a function that takes arbitrary arguments, and - returns a (query string, variables) tuple. - - The wrapped function should additionally take a client keyword argument - and then run the query using the client, returning a DataFrame. 
- """ - - from pprint import pprint - - from tenacity import ( - retry, - retry_if_exception, - stop_after_attempt, - wait_exponential, - ) - - def is_resource_busy(e: Exception) -> bool: - return ( - isinstance(e, HTTPException) - and e.status_code == 429 - and not getattr(e, "cozo_offline", False) - ) - - @retry( - stop=stop_after_attempt(4), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception(is_resource_busy), - ) - @wraps(func) - def wrapper(*args: P.args, client=None, **kwargs: P.kwargs) -> pd.DataFrame: - queries, variables = func(*args, **kwargs) - - if isinstance(queries, str): - query = queries - else: - queries = [str(query) for query in queries if query] - query = "}\n\n{\n".join(queries) - query = f"{{ {query} }}" - - not only_on_error and debug and print(query) - not only_on_error and debug and pprint( - dict( - variables=variables, - ) - ) - - # Run the query - from ..clients import cozo - - try: - client = client or cozo.get_cozo_client() - - start = timeit and time.perf_counter() - result = client.run(query, variables) - end = timeit and time.perf_counter() - - timeit and print(f"Cozo query time: {end - start:.2f} seconds") - - except Exception as e: - if only_on_error and debug: - print(query) - pprint(variables) - - debug and print(repr(e)) - - pretty_error = repr(e).lower() - cozo_busy = ("busy" in pretty_error) or ( - "when executing against relation '_" in pretty_error - ) - cozo_offline = isinstance(e, ConnectionError) and ( - ("connection refused" in pretty_error) - or ("name or service not known" in pretty_error) - ) - connection_error = isinstance( - e, - ( - ConnectionError, - Timeout, - TimeoutException, - NetworkError, - RequestError, - ), - ) - - if cozo_busy or connection_error or cozo_offline: - exc = HTTPException( - status_code=429, detail="Resource busy. Please try again later." 
- ) - exc.cozo_offline = cozo_offline - raise exc from e - - raise - - # Need to fix the UUIDs in the result - result = result.map(fix_uuid_if_present) - - not only_on_error and debug and pprint( - dict( - result=result.to_dict(orient="records"), - ) - ) - - return result - - # Set the wrapped function as an attribute of the wrapper, - # forwards the __wrapped__ attribute if it exists. - setattr(wrapper, "__wrapped__", getattr(func, "__wrapped__", func)) - - return wrapper - - if func is not None and callable(func): - return cozo_query_dec(func) - - return cozo_query_dec - - -def cozo_query_async( - func: Callable[ - P, - tuple[str | list[str | None], dict] - | Awaitable[tuple[str | list[str | None], dict]], - ] - | None = None, - debug: bool | None = None, - only_on_error: bool = False, - timeit: bool = False, -): - def cozo_query_dec( - func: Callable[ - P, tuple[str | list[Any], dict] | Awaitable[tuple[str | list[Any], dict]] - ], - ): - """ - Decorator that wraps a function that takes arbitrary arguments, and - returns a (query string, variables) tuple. - - The wrapped function should additionally take a client keyword argument - and then run the query using the client, returning a DataFrame. 
- """ - - from pprint import pprint - - from tenacity import ( - retry, - retry_if_exception, - stop_after_attempt, - wait_exponential, - ) - - def is_resource_busy(e: Exception) -> bool: - return ( - isinstance(e, HTTPException) - and e.status_code == 429 - and not getattr(e, "cozo_offline", False) - ) - - @retry( - stop=stop_after_attempt(6), - wait=wait_exponential(multiplier=1.2, min=3, max=10), - retry=retry_if_exception(is_resource_busy), - reraise=True, - ) - @wraps(func) - async def wrapper( - *args: P.args, client=None, **kwargs: P.kwargs - ) -> pd.DataFrame: - if inspect.iscoroutinefunction(func): - queries, variables = await func(*args, **kwargs) - else: - queries, variables = func(*args, **kwargs) - - if isinstance(queries, str): - query = queries - else: - queries = [str(query) for query in queries if query] - query = "}\n\n{\n".join(queries) - query = f"{{ {query} }}" - - not only_on_error and debug and print(query) - not only_on_error and debug and pprint( - dict( - variables=variables, - ) - ) - - # Run the query - from ..clients import cozo - - try: - client = client or cozo.get_async_cozo_client() - - start = timeit and time.perf_counter() - result = await client.run(query, variables) - end = timeit and time.perf_counter() - - timeit and print(f"Cozo query time: {end - start:.2f} seconds") - - except Exception as e: - if only_on_error and debug: - print(query) - pprint(variables) - - debug and print(repr(e)) - - pretty_error = repr(e).lower() - cozo_busy = ("busy" in pretty_error) or ( - "when executing against relation '_" in pretty_error - ) - cozo_offline = ( - isinstance(e, ConnectError) - or isinstance(e, HttpxConnectError) - and ( - ("all connection attempts failed" in pretty_error) - or ("name or service not known" in pretty_error) - ) - ) - connection_error = isinstance( - e, - ( - ConnectError, - HttpxConnectError, - TimeoutException, - NetworkError, - RequestError, - ), - ) - - if cozo_busy or connection_error or cozo_offline: - exc = 
HTTPException( - status_code=429, detail="Resource busy. Please try again later." - ) - exc.cozo_offline = cozo_offline - raise exc from e - - raise - - # Need to fix the UUIDs in the result - result = result.map(fix_uuid_if_present) - - not only_on_error and debug and pprint( - dict( - result=result.to_dict(orient="records"), - ) - ) - - return result - - # Set the wrapped function as an attribute of the wrapper, - # forwards the __wrapped__ attribute if it exists. - setattr(wrapper, "__wrapped__", getattr(func, "__wrapped__", func)) - - return wrapper - - if func is not None and callable(func): - return cozo_query_dec(func) - - return cozo_query_dec - - -def wrap_in_class( - cls: Type[ModelT] | Callable[..., ModelT], - one: bool = False, - transform: Callable[[dict], dict] | None = None, - _kind: str | None = None, -): - def _return_data(df: pd.DataFrame): - # Convert df to list of dicts - if _kind: - df = df[df["_kind"] == _kind] - - data = df.to_dict(orient="records") - - nonlocal transform - transform = transform or (lambda x: x) - - if one: - assert len(data) >= 1, "Expected one result, got none" - obj: ModelT = cls(**transform(data[0])) - return obj - - objs: list[ModelT] = [cls(**item) for item in map(transform, data)] - return objs - - def decorator(func: Callable[P, pd.DataFrame | Awaitable[pd.DataFrame]]): - @wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> ModelT | list[ModelT]: - return _return_data(func(*args, **kwargs)) - - @wraps(func) - async def async_wrapper( - *args: P.args, **kwargs: P.kwargs - ) -> ModelT | list[ModelT]: - return _return_data(await func(*args, **kwargs)) - - # Set the wrapped function as an attribute of the wrapper, - # forwards the __wrapped__ attribute if it exists. 
- setattr(wrapper, "__wrapped__", getattr(func, "__wrapped__", func)) - setattr(async_wrapper, "__wrapped__", getattr(func, "__wrapped__", func)) - - return async_wrapper if inspect.iscoroutinefunction(func) else wrapper - - return decorator - - -def rewrap_exceptions( - mapping: dict[ - Type[BaseException] | Callable[[BaseException], bool], - Type[BaseException] | Callable[[BaseException], BaseException], - ], - /, -): - def _check_error(error): - nonlocal mapping - - for check, transform in mapping.items(): - should_catch = ( - isinstance(error, check) if isinstance(check, type) else check(error) - ) - - if should_catch: - new_error = ( - transform(str(error)) - if isinstance(transform, type) - else transform(error) - ) - - setattr(new_error, "__cause__", error) - - raise new_error from error - - def decorator(func: Callable[P, T | Awaitable[T]]): - @wraps(func) - async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - try: - result: T = await func(*args, **kwargs) - except BaseException as error: - _check_error(error) - raise - - return result - - @wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - try: - result: T = func(*args, **kwargs) - except BaseException as error: - _check_error(error) - raise - - return result - - # Set the wrapped function as an attribute of the wrapper, - # forwards the __wrapped__ attribute if it exists. 
- setattr(wrapper, "__wrapped__", getattr(func, "__wrapped__", func)) - setattr(async_wrapper, "__wrapped__", getattr(func, "__wrapped__", func)) - - return async_wrapper if inspect.iscoroutinefunction(func) else wrapper - - return decorator - - -def run_concurrently( - fns: list[Callable[..., Any]], - *, - args_list: list[tuple] = [], - kwargs_list: list[dict] = [], -) -> list[Any]: - args_list = args_list or [tuple()] * len(fns) - kwargs_list = kwargs_list or [dict()] * len(fns) - - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit(fn, *args, **kwargs) - for fn, args, kwargs in zip(fns, args_list, kwargs_list) - ] - - return [future.result() for future in concurrent.futures.as_completed(futures)] diff --git a/agents-api/agents_api/prompt_assets/sys_prompt.yml b/agents-api/agents_api/prompt_assets/sys_prompt.yml deleted file mode 100644 index 0aad05160..000000000 --- a/agents-api/agents_api/prompt_assets/sys_prompt.yml +++ /dev/null @@ -1,35 +0,0 @@ -Role: | - You are a function calling AI agent with self-recursion. - You can call only one function at a time and analyse data you get from function response. - You are provided with function signatures within XML tags. - The current date is: {date}. -Objective: | - You may use agentic frameworks for reasoning and planning to help with user query. - Please call a function and wait for function results to be provided to you in the next iteration. - Don't make assumptions about what values to plug into function arguments. - Once you have called a function, results will be fed back to you within XML tags. - Don't make assumptions about tool results if XML tags are not present since function hasn't been executed yet. - Analyze the data once you get the results and call another function. - At each iteration please continue adding the your analysis to previous summary. 
- Your final response should directly answer the user query with an anlysis or summary of the results of function calls. -Tools: | - Here are the available tools: - {{agent.tools}} - If the provided function signatures doesn't have the function you must call, you may write executable python code in markdown syntax and call code_interpreter() function as follows: - - {{"arguments": {{"code_markdown": , "name": "code_interpreter"}}}} - - Make sure that the json object above with code markdown block is parseable with json.loads() and the XML block with XML ElementTree. -Schema: | - Use the following pydantic model json schema for each tool call you will make: - {schema} -Instructions: | - At the very first turn you don't have so you shouldn't not make up the results. - Please keep a running summary with analysis of previous function results and summaries from previous iterations. - Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10. - Calling multiple functions at once can overload the system and increase cost so call one function at a time please. - If you plan to continue with analysis, always call another function. - For each function call return a valid json object (using doulbe quotes) with function name and arguments within XML tags as follows: - - {{"arguments": , "name": }} - diff --git a/agents-api/agents_api/queries/__init__.py b/agents-api/agents_api/queries/__init__.py new file mode 100644 index 000000000..4b00a644d --- /dev/null +++ b/agents-api/agents_api/queries/__init__.py @@ -0,0 +1,20 @@ +""" +The `queries` module of the agents API is designed to encapsulate all data interactions with the PostgreSQL database. It provides a structured way to perform CRUD (Create, Read, Update, Delete) operations and other specific data manipulations across various entities such as agents, documents, entries, sessions, tools, and users. 
+ +Each sub-module within this module corresponds to a specific entity and contains functions and classes that implement SQL queries for interacting with the database. These interactions include creating new records, updating existing ones, retrieving data for specific conditions, and deleting records. The operations are crucial for the functionality of the agents API, enabling it to manage and process data effectively for each entity. + +This module also integrates with the `common` module for exception handling and utility functions, ensuring robust error management and providing reusable components for data processing and query construction. +""" + +# ruff: noqa: F401, F403, F405 + +from . import agents as agents +from . import developers as developers +from . import docs as docs +from . import entries as entries +from . import executions as executions +from . import files as files +from . import sessions as sessions +from . import tasks as tasks +from . import tools as tools +from . import users as users diff --git a/agents-api/agents_api/queries/agents/__init__.py b/agents-api/agents_api/queries/agents/__init__.py new file mode 100644 index 000000000..c0712c47c --- /dev/null +++ b/agents-api/agents_api/queries/agents/__init__.py @@ -0,0 +1,31 @@ +""" +The `agent` module within the `queries` package provides a comprehensive suite of SQL query functions for managing agents in the PostgreSQL database. This includes: + +- Creating new agents +- Updating existing agents +- Retrieving details about specific agents +- Listing agents with filtering and pagination +- Deleting agents from the database + +Each function in this module constructs and returns SQL queries along with their parameters for database operations. 
+""" + +# ruff: noqa: F401, F403, F405 + +from .create_agent import create_agent +from .create_or_update_agent import create_or_update_agent +from .delete_agent import delete_agent +from .get_agent import get_agent +from .list_agents import list_agents +from .patch_agent import patch_agent +from .update_agent import update_agent + +__all__ = [ + "create_agent", + "create_or_update_agent", + "delete_agent", + "get_agent", + "list_agents", + "patch_agent", + "update_agent", +] diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py new file mode 100644 index 000000000..380e20798 --- /dev/null +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -0,0 +1,99 @@ +""" +This module contains the functionality for creating agents in the PostgreSQL database. +It includes functions to construct and execute SQL queries for inserting new agent records. +""" + +from uuid import UUID + +from beartype import beartype +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateAgentRequest, ResourceCreatedResponse +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import generate_canonical_name, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +agent_query = """ +INSERT INTO agents ( + developer_id, + agent_id, + canonical_name, + name, + about, + instructions, + model, + metadata, + default_settings +) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9 +) +RETURNING *; +""" + + +@rewrap_exceptions(common_db_exceptions("agent", ["create"])) +@wrap_in_class( + ResourceCreatedResponse, + one=True, + transform=lambda d: {"id": d["agent_id"], "created_at": d["created_at"]}, +) +@increase_counter("create_agent") +@pg_query +@beartype +async def create_agent( + *, + developer_id: UUID, + agent_id: UUID | None = None, + data: CreateAgentRequest, +) -> tuple[str, list]: + """ + 
Constructs and executes a SQL query to create a new agent in the database. + + Parameters: + agent_id (UUID | None): The unique identifier for the agent. + developer_id (UUID): The unique identifier for the developer creating the agent. + data (CreateAgentRequest): The data for the new agent. + + Returns: + tuple[str, dict]: SQL query and parameters for creating the agent. + """ + agent_id = agent_id or uuid7() + + # Ensure instructions is a list + data.instructions = ( + data.instructions if isinstance(data.instructions, list) else [data.instructions] + ) + + # Convert default_settings to dict if it exists + default_settings = data.default_settings.model_dump() if data.default_settings else {} + + # Set default values + data.metadata = data.metadata or {} + data.canonical_name = data.canonical_name or generate_canonical_name() + + params = [ + developer_id, + agent_id, + data.canonical_name, + data.name, + data.about, + data.instructions, + data.model, + data.metadata, + default_settings, + ] + + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py new file mode 100644 index 000000000..d65e0e9fc --- /dev/null +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -0,0 +1,110 @@ +""" +This module contains the functionality for creating or updating agents in the PostgreSQL database. +It constructs and executes SQL queries to insert a new agent or update an existing agent's details based on agent ID and developer ID. 
+""" + +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import generate_canonical_name, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +agent_query = """ +WITH existing_agent AS ( + SELECT canonical_name + FROM agents + WHERE developer_id = $1 AND agent_id = $2 +) +INSERT INTO agents ( + developer_id, + agent_id, + canonical_name, + name, + about, + instructions, + model, + metadata, + default_settings +) +VALUES ( + $1, -- developer_id + $2, -- agent_id + COALESCE( -- canonical_name + (SELECT canonical_name FROM existing_agent), + $3 + ), + $4, -- name + $5, -- about + $6, -- instructions + $7, -- model + $8, -- metadata + $9 -- default_settings +) +ON CONFLICT (developer_id, agent_id) DO UPDATE SET + canonical_name = EXCLUDED.canonical_name, + name = EXCLUDED.name, + about = EXCLUDED.about, + instructions = EXCLUDED.instructions, + model = EXCLUDED.model, + metadata = EXCLUDED.metadata, + default_settings = EXCLUDED.default_settings +RETURNING *; +""" + + +@rewrap_exceptions(common_db_exceptions("agent", ["create", "update"])) +@wrap_in_class( + Agent, + one=True, + transform=lambda d: {**d, "id": d["agent_id"]}, +) +@increase_counter("create_or_update_agent") +@pg_query +@beartype +async def create_or_update_agent( + *, agent_id: UUID, developer_id: UUID, data: CreateOrUpdateAgentRequest +) -> tuple[str, list]: + """ + Constructs the SQL queries to create a new agent or update an existing agent's details. + + Args: + agent_id (UUID): The UUID of the agent to create or update. + developer_id (UUID): The UUID of the developer owning the agent. + agent_data (Dict[str, Any]): A dictionary containing agent fields to insert or update. + + Returns: + tuple[list[str], dict]: A tuple containing the list of SQL queries and their parameters. 
+ """ + + # Ensure instructions is a list + data.instructions = ( + data.instructions if isinstance(data.instructions, list) else [data.instructions] + ) + + # Convert default_settings to dict if it exists + default_settings = data.default_settings.model_dump() if data.default_settings else {} + + # Set default values + data.metadata = data.metadata or {} + data.canonical_name = data.canonical_name or generate_canonical_name() + + params = [ + developer_id, + agent_id, + data.canonical_name, + data.name, + data.about, + data.instructions, + data.model, + data.metadata, + default_settings, + ] + + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py new file mode 100644 index 000000000..6b3e85eb5 --- /dev/null +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -0,0 +1,84 @@ +""" +This module contains the functionality for deleting agents from the PostgreSQL database. +It constructs and executes SQL queries to remove agent records and associated data. 
+""" + +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +agent_query = """ +WITH deleted_file_owners AS ( + DELETE FROM file_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 +), +deleted_doc_owners AS ( + DELETE FROM doc_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 +), +deleted_files AS ( + DELETE FROM files + WHERE developer_id = $1 + AND file_id IN ( + SELECT file_id FROM file_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 + ) +), +deleted_docs AS ( + DELETE FROM docs + WHERE developer_id = $1 + AND doc_id IN ( + SELECT doc_id FROM doc_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 + ) +), +deleted_tools AS ( + DELETE FROM tools + WHERE agent_id = $2 AND developer_id = $1 +) +DELETE FROM agents +WHERE agent_id = $2 AND developer_id = $1 +RETURNING developer_id, agent_id; +""" + + +@rewrap_exceptions(common_db_exceptions("agent", ["delete"])) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: {**d, "id": d["agent_id"], "deleted_at": utcnow()}, +) +@pg_query +@beartype +async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]: + """ + Constructs the SQL query to delete an agent and its related settings. + + Args: + agent_id (UUID): The UUID of the agent to be deleted. + developer_id (UUID): The UUID of the developer owning the agent. + + Returns: + tuple[str, list]: A tuple containing the SQL query and its parameters. 
+ """ + # Note: We swap the parameter order because the queries use $1 for developer_id and $2 for agent_id + params = [developer_id, agent_id] + + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py new file mode 100644 index 000000000..cdf33b7a2 --- /dev/null +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -0,0 +1,62 @@ +""" +This module contains the functionality for retrieving a single agent from the PostgreSQL database. +It constructs and executes SQL queries to fetch agent details based on agent ID and developer ID. +""" + +from typing import Literal +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import Agent +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +agent_query = """ +SELECT + agent_id, + developer_id, + name, + canonical_name, + about, + instructions, + model, + metadata, + default_settings, + created_at, + updated_at +FROM + agents +WHERE + agent_id = $2 AND developer_id = $1; +""" + + +@rewrap_exceptions(common_db_exceptions("agent", ["get"])) +@wrap_in_class( + Agent, + one=True, + transform=lambda d: {**d, "id": d["agent_id"]}, +) +@pg_query +@beartype +async def get_agent( + *, agent_id: UUID, developer_id: UUID +) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: + """ + Constructs the SQL query to retrieve an agent's details. + + Args: + agent_id (UUID): The UUID of the agent to retrieve. + developer_id (UUID): The UUID of the developer owning the agent. + + Returns: + tuple[list[str], dict]: A tuple containing the SQL query and its parameters. 
+ """ + + return ( + agent_query, + [developer_id, agent_id], + "fetchrow", + ) diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py new file mode 100644 index 000000000..c3e780b04 --- /dev/null +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -0,0 +1,95 @@ +""" +This module contains the functionality for listing agents from the PostgreSQL database. +It constructs and executes SQL queries to fetch a list of agents based on developer ID with pagination. +""" + +from typing import Any, Literal +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import Agent +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +raw_query = """ +SELECT + agent_id, + developer_id, + name, + canonical_name, + about, + instructions, + model, + metadata, + default_settings, + created_at, + updated_at +FROM agents +WHERE developer_id = $1 {metadata_filter_query} +ORDER BY + CASE WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at END ASC NULLS LAST, + CASE WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at END DESC NULLS LAST, + CASE WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at END ASC NULLS LAST, + CASE WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at END DESC NULLS LAST +LIMIT $2 OFFSET $3; +""" + + +@rewrap_exceptions(common_db_exceptions("agent", ["list"])) +@wrap_in_class( + Agent, + transform=lambda d: {**d, "id": d["agent_id"]}, +) +@pg_query +@beartype +async def list_agents( + *, + developer_id: UUID, + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", + metadata_filter: dict[str, Any] = {}, +) -> tuple[str, list]: + """ + Constructs query to list agents for a developer with pagination. 
+ + Args: + developer_id: UUID of the developer + limit: Maximum number of records to return + offset: Number of records to skip + sort_by: Field to sort by + direction: Sort direction ('asc' or 'desc') + metadata_filter: Optional metadata filters + + Returns: + Tuple of (query, params) + """ + # Validate sort direction + if direction.lower() not in ["asc", "desc"]: + raise HTTPException(status_code=400, detail="Invalid sort direction") + + # Build metadata filter clause if needed + + agent_query = raw_query.format( + metadata_filter_query="AND metadata @> $6::jsonb" if metadata_filter else "" + ) + + params = [ + developer_id, + limit, + offset, + sort_by, + direction, + ] + + if metadata_filter: + params.append(metadata_filter) + + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py new file mode 100644 index 000000000..324ee2eee --- /dev/null +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -0,0 +1,80 @@ +""" +This module contains the functionality for partially updating an agent in the PostgreSQL database. +It constructs and executes SQL queries to update specific fields of an agent based on agent ID and developer ID. 
+""" + +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +agent_query = """ +UPDATE agents +SET + name = CASE + WHEN $3::text IS NOT NULL THEN $3 + ELSE name + END, + about = CASE + WHEN $4::text IS NOT NULL THEN $4 + ELSE about + END, + metadata = CASE + WHEN $5::jsonb IS NOT NULL THEN metadata || $5 + ELSE metadata + END, + model = CASE + WHEN $6::text IS NOT NULL THEN $6 + ELSE model + END, + default_settings = CASE + WHEN $7::jsonb IS NOT NULL THEN $7 + ELSE default_settings + END +WHERE agent_id = $2 AND developer_id = $1 +RETURNING *; +""" + + +@rewrap_exceptions(common_db_exceptions("agent", ["patch"])) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: {"id": d["agent_id"], **d}, +) +@increase_counter("patch_agent") +@pg_query +@beartype +async def patch_agent( + *, agent_id: UUID, developer_id: UUID, data: PatchAgentRequest +) -> tuple[str, list]: + """ + Constructs the SQL query to partially update an agent's details. + + Args: + agent_id (UUID): The UUID of the agent to update. + developer_id (UUID): The UUID of the developer owning the agent. + data (PatchAgentRequest): A dictionary of fields to update. + + Returns: + tuple[str, list]: A tuple containing the SQL query and its parameters. 
+ """ + params = [ + developer_id, + agent_id, + data.name, + data.about, + data.metadata, + data.model, + data.default_settings.model_dump() if data.default_settings else None, + ] + + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py new file mode 100644 index 000000000..69c0fa9f0 --- /dev/null +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -0,0 +1,65 @@ +""" +This module contains the functionality for fully updating an agent in the PostgreSQL database. +It constructs and executes SQL queries to replace an agent's details based on agent ID and developer ID. +""" + +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +agent_query = """ +UPDATE agents +SET + metadata = $3, + name = $4, + about = $5, + model = $6, + default_settings = $7::jsonb +WHERE agent_id = $2 AND developer_id = $1 +RETURNING *; +""" + + +@rewrap_exceptions(common_db_exceptions("agent", ["update"])) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: {"id": d["agent_id"], **d}, +) +@increase_counter("update_agent") +@pg_query +@beartype +async def update_agent( + *, agent_id: UUID, developer_id: UUID, data: UpdateAgentRequest +) -> tuple[str, list]: + """ + Constructs the SQL query to fully update an agent's details. + + Args: + agent_id (UUID): The UUID of the agent to update. + developer_id (UUID): The UUID of the developer owning the agent. + data (UpdateAgentRequest): A dictionary containing all agent fields to update. + + Returns: + tuple[str, list]: A tuple containing the SQL query and its parameters. 
+ """ + params = [ + developer_id, + agent_id, + data.metadata or {}, + data.name, + data.about, + data.model, + data.default_settings.model_dump() if data.default_settings else {}, + ] + + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/models/chat/__init__.py b/agents-api/agents_api/queries/chat/__init__.py similarity index 92% rename from agents-api/agents_api/models/chat/__init__.py rename to agents-api/agents_api/queries/chat/__init__.py index 428b72572..2c05b4f8b 100644 --- a/agents-api/agents_api/models/chat/__init__.py +++ b/agents-api/agents_api/queries/chat/__init__.py @@ -17,6 +17,4 @@ # ruff: noqa: F401, F403, F405 from .gather_messages import gather_messages -from .get_cached_response import get_cached_response from .prepare_chat_context import prepare_chat_context -from .set_cached_response import set_cached_response diff --git a/agents-api/agents_api/models/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py similarity index 71% rename from agents-api/agents_api/models/chat/gather_messages.py rename to agents-api/agents_api/queries/chat/gather_messages.py index 28dc6607f..039624762 100644 --- a/agents-api/agents_api/models/chat/gather_messages.py +++ b/agents-api/agents_api/queries/chat/gather_messages.py @@ -3,33 +3,28 @@ from beartype import beartype from fastapi import HTTPException -from pycozo.client import QueryException from pydantic import ValidationError -from ...autogen.openapi_model import ChatInput, DocReference, History +from ...autogen.openapi_model import ChatInput, DocReference, History, Session from ...clients import litellm from ...common.protocol.developers import Developer from ...common.protocol.sessions import ChatContext +from ...common.utils.db_exceptions import common_db_exceptions, partialclass from ..docs.search_docs_by_embedding import search_docs_by_embedding from ..docs.search_docs_by_text import search_docs_by_text from ..docs.search_docs_hybrid import 
search_docs_hybrid -from ..entry.get_history import get_history -from ..session.get_session import get_session -from ..utils import ( - partialclass, - rewrap_exceptions, -) +from ..entries.get_history import get_history +from ..sessions.get_session import get_session +from ..utils import rewrap_exceptions T = TypeVar("T") -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) +@rewrap_exceptions({ + ValidationError: partialclass(HTTPException, status_code=400), + TypeError: partialclass(HTTPException, status_code=400), + **common_db_exceptions("history", ["get"]), +}) @beartype async def gather_messages( *, @@ -37,6 +32,7 @@ async def gather_messages( session_id: UUID, chat_context: ChatContext, chat_input: ChatInput, + connection_pool=None, ) -> tuple[list[dict], list[DocReference]]: new_raw_messages = [msg.model_dump(mode="json") for msg in chat_input.messages] recall = chat_input.recall @@ -44,10 +40,11 @@ async def gather_messages( assert len(new_raw_messages) > 0 # Get the session history - history: History = get_history( + history: History = await get_history( developer_id=developer.id, session_id=session_id, allowed_sources=["api_request", "api_response", "tool_response", "summarizer"], + connection_pool=connection_pool, ) # Keep leaf nodes only @@ -71,18 +68,17 @@ async def gather_messages( return past_messages, [] # Get recall options - session = get_session( + session: Session = await get_session( developer_id=developer.id, session_id=session_id, + connection_pool=connection_pool, ) recall_options = session.recall_options # search the last `search_threshold` messages search_messages = [ msg - for msg in (past_messages + new_raw_messages)[ - -(recall_options.num_search_messages) : - ] + for msg in (past_messages + new_raw_messages)[-(recall_options.num_search_messages) :] if 
isinstance(msg["content"], str) and msg["role"] in ["user", "assistant"] ] @@ -91,12 +87,9 @@ async def gather_messages( # FIXME: This should only search text messages and not embed if text is empty # Search matching docs - embed_text = "\n\n".join( - [ - f"{msg.get('name') or msg['role']}: {msg['content']}" - for msg in search_messages - ] - ).strip() + embed_text = "\n\n".join([ + f"{msg.get('name') or msg['role']}: {msg['content']}" for msg in search_messages + ]).strip() [query_embedding, *_] = await litellm.aembedding( # Truncate on the left to keep the last `search_query_chars` characters @@ -106,9 +99,7 @@ async def gather_messages( ) # Truncate on the right to take only the first `search_query_chars` characters - query_text = search_messages[-1]["content"].strip()[ - : recall_options.max_query_length - ] + query_text = search_messages[-1]["content"].strip()[: recall_options.max_query_length] # List all the applicable owners to search docs from active_agent_id = chat_context.get_active_agent().id @@ -119,23 +110,26 @@ async def gather_messages( doc_references: list[DocReference] = [] match recall_options.mode: case "vector": - doc_references: list[DocReference] = search_docs_by_embedding( + doc_references: list[DocReference] = await search_docs_by_embedding( developer_id=developer.id, owners=owners, query_embedding=query_embedding, + connection_pool=connection_pool, ) case "hybrid": - doc_references: list[DocReference] = search_docs_hybrid( + doc_references: list[DocReference] = await search_docs_hybrid( developer_id=developer.id, owners=owners, - query=query_text, - query_embedding=query_embedding, + text_query=query_text, + embedding=query_embedding, + connection_pool=connection_pool, ) case "text": - doc_references: list[DocReference] = search_docs_by_text( + doc_references: list[DocReference] = await search_docs_by_text( developer_id=developer.id, owners=owners, query=query_text, + connection_pool=connection_pool, ) return past_messages, doc_references 
diff --git a/agents-api/agents_api/queries/chat/prepare_chat_context.py b/agents-api/agents_api/queries/chat/prepare_chat_context.py new file mode 100644 index 000000000..4c964d1b3 --- /dev/null +++ b/agents-api/agents_api/queries/chat/prepare_chat_context.py @@ -0,0 +1,174 @@ +from typing import Any, TypeVar +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException +from pydantic import ValidationError + +from ...common.protocol.sessions import ChatContext, make_session +from ...common.utils.datetime import utcnow +from ...common.utils.db_exceptions import common_db_exceptions, partialclass +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +ModelT = TypeVar("ModelT", bound=Any) +T = TypeVar("T") + + +sql_query = """ +SELECT * FROM +( + SELECT jsonb_agg(u) AS users FROM ( + SELECT + session_lookup.participant_id, + users.user_id AS id, + users.developer_id, + users.name, + users.about, + users.created_at, + users.updated_at, + users.metadata + FROM session_lookup + INNER JOIN users ON session_lookup.participant_id = users.user_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'user' + ) u +) AS users, +( + SELECT jsonb_agg(a) AS agents FROM ( + SELECT + session_lookup.participant_id, + agents.agent_id AS id, + agents.developer_id, + agents.canonical_name, + agents.name, + agents.about, + agents.instructions, + agents.model, + agents.created_at, + agents.updated_at, + agents.metadata, + agents.default_settings + FROM session_lookup + INNER JOIN agents ON session_lookup.participant_id = agents.agent_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'agent' + ) a +) AS agents, +( + SELECT to_jsonb(s) AS session FROM ( + SELECT + sessions.session_id AS id, + sessions.developer_id, + sessions.situation, + sessions.system_template, + sessions.created_at, + sessions.updated_at, + sessions.metadata, + 
sessions.render_templates, + sessions.token_budget, + sessions.context_overflow, + sessions.forward_tool_calls, + sessions.recall_options + FROM sessions + WHERE + developer_id = $1 AND + session_id = $2 + LIMIT 1 + ) s +) AS session, +( + SELECT jsonb_agg(r) AS toolsets FROM ( + SELECT + session_lookup.participant_id, + tools.tool_id as id, + tools.developer_id, + tools.agent_id, + tools.task_id, + tools.type, + tools.name, + tools.description, + tools.spec, + tools.updated_at, + tools.created_at + FROM session_lookup + INNER JOIN tools ON session_lookup.participant_id = tools.agent_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'agent' + ) r +) AS toolsets""" + + +def _transform(d): + toolsets = {} + + # Default to empty lists when users/agents are not present + d["users"] = d.get("users") or [] + d["agents"] = d.get("agents") or [] + + for tool in d.get("toolsets") or []: + if not tool: + continue + + agent_id = tool["agent_id"] + if agent_id in toolsets: + toolsets[agent_id].append(tool) + else: + toolsets[agent_id] = [tool] + + d["session"]["updated_at"] = utcnow() + d["users"] = d.get("users") or [] + + return { + **d, + "session": make_session( + agents=[a["id"] for a in d.get("agents") or []], + users=[u["id"] for u in d.get("users") or []], + **d["session"], + ), + "toolsets": [ + { + "agent_id": agent_id, + "tools": [ + { + tool["type"]: tool.pop("spec"), + **tool, + } + for tool in tools + ], + } + for agent_id, tools in toolsets.items() + ], + } + + +@rewrap_exceptions({ + ValidationError: partialclass(HTTPException, status_code=400), + TypeError: partialclass(HTTPException, status_code=400), + **common_db_exceptions("chat", ["get"]), +}) +@wrap_in_class( + ChatContext, + one=True, + transform=_transform, +) +@pg_query +@beartype +async def prepare_chat_context( + *, + developer_id: UUID, + session_id: UUID, +) -> tuple[str, list]: + """ + Executes a complex query to retrieve memory context 
based on session ID. + """ + + return ( + sql_query, + [developer_id, session_id], + ) diff --git a/agents-api/agents_api/models/developer/__init__.py b/agents-api/agents_api/queries/developers/__init__.py similarity index 80% rename from agents-api/agents_api/models/developer/__init__.py rename to agents-api/agents_api/queries/developers/__init__.py index a7117c06b..c3d1d4bbb 100644 --- a/agents-api/agents_api/models/developer/__init__.py +++ b/agents-api/agents_api/queries/developers/__init__.py @@ -16,4 +16,14 @@ # ruff: noqa: F401, F403, F405 -from .get_developer import get_developer, verify_developer +from .create_developer import create_developer +from .get_developer import get_developer +from .patch_developer import patch_developer +from .update_developer import update_developer + +__all__ = [ + "create_developer", + "get_developer", + "patch_developer", + "update_developer", +] diff --git a/agents-api/agents_api/queries/developers/create_developer.py b/agents-api/agents_api/queries/developers/create_developer.py new file mode 100644 index 000000000..6a581a136 --- /dev/null +++ b/agents-api/agents_api/queries/developers/create_developer.py @@ -0,0 +1,51 @@ +from uuid import UUID + +from beartype import beartype +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import ResourceCreatedResponse +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +developer_query = """ +INSERT INTO developers ( + developer_id, + email, + active, + tags, + settings +) +VALUES ( + $1, -- developer_id + $2, -- email + $3, -- active + $4, -- tags + $5::jsonb -- settings +) +RETURNING *; +""" + + +@rewrap_exceptions(common_db_exceptions("developer", ["create"])) +@wrap_in_class( + ResourceCreatedResponse, + one=True, + transform=lambda d: {**d, "id": d["developer_id"], "created_at": d["created_at"]}, +) +@pg_query +@beartype +async def create_developer( + *, + 
email: str, + active: bool = True, + tags: list[str] | None = None, + settings: dict | None = None, + developer_id: UUID | None = None, +) -> tuple[str, list]: + developer_id = str(developer_id or uuid7()) + + return ( + developer_query, + [developer_id, email, active, tags or [], settings or {}], + ) diff --git a/agents-api/agents_api/queries/developers/get_developer.py b/agents-api/agents_api/queries/developers/get_developer.py new file mode 100644 index 000000000..95470d880 --- /dev/null +++ b/agents-api/agents_api/queries/developers/get_developer.py @@ -0,0 +1,35 @@ +""" +Module for retrieving developer information from the PostgreSQL database. +""" + +from uuid import UUID + +from beartype import beartype + +from ...common.protocol.developers import Developer +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +developer_query = """ +SELECT * FROM developers WHERE developer_id = $1 -- developer_id +""" + + +@rewrap_exceptions(common_db_exceptions("developer", ["get"])) +@wrap_in_class( + Developer, + one=True, + transform=lambda d: {**d, "id": d["developer_id"]}, +) +@pg_query +@beartype +async def get_developer( + *, + developer_id: UUID, +) -> tuple[str, list]: + developer_id = str(developer_id) + + return ( + developer_query, + [developer_id], + ) diff --git a/agents-api/agents_api/queries/developers/patch_developer.py b/agents-api/agents_api/queries/developers/patch_developer.py new file mode 100644 index 000000000..39f694377 --- /dev/null +++ b/agents-api/agents_api/queries/developers/patch_developer.py @@ -0,0 +1,35 @@ +from uuid import UUID + +from beartype import beartype + +from ...common.protocol.developers import Developer +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +developer_query = """ +UPDATE developers +SET email = $1, active = $2, tags = tags || $3, 
settings = settings || $4 -- settings +WHERE developer_id = $5 -- developer_id +RETURNING *; +""" + + +@rewrap_exceptions(common_db_exceptions("developer", ["patch"])) +@wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) +@pg_query +@beartype +async def patch_developer( + *, + developer_id: UUID, + email: str, + active: bool = True, + tags: list[str] | None = None, + settings: dict | None = None, +) -> tuple[str, list]: + developer_id = str(developer_id) + + return ( + developer_query, + [email, active, tags or [], settings or {}, developer_id], + ) diff --git a/agents-api/agents_api/queries/developers/update_developer.py b/agents-api/agents_api/queries/developers/update_developer.py new file mode 100644 index 000000000..e76ec9cca --- /dev/null +++ b/agents-api/agents_api/queries/developers/update_developer.py @@ -0,0 +1,35 @@ +from uuid import UUID + +from beartype import beartype + +from ...common.protocol.developers import Developer +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +developer_query = """ +UPDATE developers +SET email = $1, active = $2, tags = $3, settings = $4 +WHERE developer_id = $5 +RETURNING *; +""" + + +@rewrap_exceptions(common_db_exceptions("developer", ["update"])) +@wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) +@pg_query +@beartype +async def update_developer( + *, + developer_id: UUID, + email: str, + active: bool = True, + tags: list[str] | None = None, + settings: dict | None = None, +) -> tuple[str, list]: + developer_id = str(developer_id) + + return ( + developer_query, + [email, active, tags or [], settings or {}, developer_id], + ) diff --git a/agents-api/agents_api/models/docs/__init__.py b/agents-api/agents_api/queries/docs/__init__.py similarity index 80% rename from agents-api/agents_api/models/docs/__init__.py rename to 
agents-api/agents_api/queries/docs/__init__.py index 0ba3db0d4..3862131bb 100644 --- a/agents-api/agents_api/models/docs/__init__.py +++ b/agents-api/agents_api/queries/docs/__init__.py @@ -8,6 +8,9 @@ - Listing documents based on various criteria, including ownership and metadata filters. - Deleting documents by their unique identifiers. - Embedding document snippets for retrieval purposes. +- Searching documents by text. +- Searching documents by hybrid text and embedding. +- Searching documents by embedding. The module interacts with other parts of the application, such as the agents and users modules, to provide a comprehensive document management system. Its role is crucial in enabling document search, retrieval, and management features within the context of agents and users. @@ -18,8 +21,18 @@ from .create_doc import create_doc from .delete_doc import delete_doc -from .embed_snippets import embed_snippets from .get_doc import get_doc from .list_docs import list_docs from .search_docs_by_embedding import search_docs_by_embedding from .search_docs_by_text import search_docs_by_text +from .search_docs_hybrid import search_docs_hybrid + +__all__ = [ + "create_doc", + "delete_doc", + "get_doc", + "list_docs", + "search_docs_by_embedding", + "search_docs_by_text", + "search_docs_hybrid", +] diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py new file mode 100644 index 000000000..16d8810d6 --- /dev/null +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -0,0 +1,162 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse +from ...common.utils.datetime import utcnow +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Base 
INSERT for docs +doc_query = """ +INSERT INTO docs ( + developer_id, + doc_id, + title, + content, + index, + modality, + embedding_model, + embedding_dimensions, + language, + metadata +) +VALUES ( + $1, -- developer_id + $2, -- doc_id + $3, -- title + $4, -- content + $5, -- index + $6, -- modality + $7, -- embedding_model + $8, -- embedding_dimensions + $9, -- language + $10 -- metadata (JSONB) +) +""" + +# Owner association query for doc_owners +doc_owner_query = """ +INSERT INTO doc_owners (developer_id, doc_id, owner_type, owner_id) +VALUES ($1, $2, $3, $4) +ON CONFLICT DO NOTHING +RETURNING *; +""" + + +@rewrap_exceptions(common_db_exceptions("doc", ["create"])) +@wrap_in_class( + ResourceCreatedResponse, + one=True, + transform=lambda d: { + "id": d["doc_id"], + "jobs": [], + "created_at": utcnow(), + **d, + }, +) +@increase_counter("create_doc") +@pg_query +@beartype +async def create_doc( + *, + developer_id: UUID, + doc_id: UUID | None = None, + data: CreateDocRequest, + owner_type: Literal["user", "agent"], + owner_id: UUID, + modality: Literal["text", "image", "mixed"] | None = "text", + embedding_model: str | None = "voyage-3", + embedding_dimensions: int | None = 1024, + language: str | None = "english", + index: int | None = 0, +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: + """ + Insert a new doc record into Timescale and associate it with an owner. + + Parameters: + developer_id (UUID): The ID of the developer. + doc_id (UUID | None): Optional custom UUID for the document. If not provided, one will be generated. + data (CreateDocRequest): The data for the document. + owner_type (Literal["user", "agent"]): The type of the owner (required). + owner_id (UUID): The ID of the owner (required). + modality (Literal["text", "image", "mixed"]): The modality of the documents. + embedding_model (str): The model used for embedding. + embedding_dimensions (int): The dimensions of the embedding. 
+ language (str): The language of the documents. + index (int): The index of the documents. + + Returns: + list[tuple[str, list] | tuple[str, list, str]]: SQL query and parameters for creating the document. + """ + queries = [] + + # Generate a UUID if not provided + current_doc_id = uuid7() if doc_id is None else doc_id + + # Check if content is a list + if isinstance(data.content, list): + final_params_doc = [] + final_params_owner = [] + + for idx, content in enumerate(data.content): + doc_params = [ + developer_id, + current_doc_id, + data.title, + content, + idx, + modality, + embedding_model, + embedding_dimensions, + language, + data.metadata or {}, + ] + final_params_doc.append(doc_params) + + owner_params = [ + developer_id, + current_doc_id, + owner_type, + owner_id, + ] + final_params_owner.append(owner_params) + + # Add the doc query for each content + queries.append((doc_query, final_params_doc, "fetchmany")) + + # Add the owner query + queries.append((doc_owner_query, final_params_owner, "fetchmany")) + + else: + # Create the doc record + doc_params = [ + developer_id, + current_doc_id, + data.title, + data.content, + index, + modality, + embedding_model, + embedding_dimensions, + language, + data.metadata or {}, + ] + + owner_params = [ + developer_id, + current_doc_id, + owner_type, + owner_id, + ] + + # Add the doc query for single content + queries.append((doc_query, doc_params, "fetch")) + + # Add the owner query + queries.append((doc_owner_query, owner_params, "fetch")) + + return queries diff --git a/agents-api/agents_api/queries/docs/delete_doc.py b/agents-api/agents_api/queries/docs/delete_doc.py new file mode 100644 index 000000000..f29659013 --- /dev/null +++ b/agents-api/agents_api/queries/docs/delete_doc.py @@ -0,0 +1,71 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from 
...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Delete doc query +delete_doc_query = """ +DELETE FROM docs +WHERE developer_id = $1 + AND doc_id = $2 + AND EXISTS ( + SELECT 1 FROM doc_owners + WHERE developer_id = $1 + AND doc_id = $2 + AND owner_type = $3 + AND owner_id = $4 + ) +RETURNING doc_id; +""" + +delete_doc_owners_query = """ +DELETE FROM doc_owners +WHERE developer_id = $1 + AND doc_id = $2 + AND owner_type = $3 + AND owner_id = $4 +RETURNING doc_id; +""" + + +@rewrap_exceptions(common_db_exceptions("doc", ["delete"])) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["doc_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@pg_query +@beartype +async def delete_doc( + *, + developer_id: UUID, + doc_id: UUID, + owner_type: Literal["user", "agent"], + owner_id: UUID, +) -> list[tuple[str, list]]: + """ + Deletes a doc (and associated doc_owners) for the given developer and doc_id. + If owner_type/owner_id is specified, only remove doc if that matches. + + Parameters: + developer_id (UUID): The ID of the developer. + doc_id (UUID): The ID of the document. + owner_type (Literal["user", "agent"]): The type of the owner of the documents. + owner_id (UUID): The ID of the owner of the documents. + + Returns: + tuple[str, list]: SQL query and parameters for deleting the document. 
+ """ + return [ + (delete_doc_query, [developer_id, doc_id, owner_type, owner_id]), + (delete_doc_owners_query, [developer_id, doc_id, owner_type, owner_id]), + ] diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py new file mode 100644 index 000000000..c742a3054 --- /dev/null +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -0,0 +1,86 @@ +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import Doc +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Update the query to use DISTINCT ON to prevent duplicates +doc_with_embedding_query = """ +WITH doc_data AS ( + SELECT + d.doc_id, + d.developer_id, + d.title, + array_agg(d.content ORDER BY d.index) as content, + array_agg(d.index ORDER BY d.index) as indices, + array_agg(e.embedding ORDER BY d.index) as embeddings, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at + FROM docs d + LEFT JOIN docs_embeddings e + ON d.doc_id = e.doc_id + WHERE d.developer_id = $1 + AND d.doc_id = $2 + GROUP BY + d.doc_id, + d.developer_id, + d.title, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at +) +SELECT * FROM doc_data; +""" + + +def transform_get_doc(d: dict) -> dict: + content = d["content"][0] if len(d["content"]) == 1 else d["content"] + + embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"] + if embeddings and all((e is None) for e in embeddings): + embeddings = None + + return { + **d, + "id": d["doc_id"], + "content": content, + "embeddings": embeddings, + } + + +@rewrap_exceptions(common_db_exceptions("doc", ["get"])) +@wrap_in_class( + Doc, + one=True, + transform=transform_get_doc, +) +@pg_query +@beartype +async def get_doc( + *, + developer_id: UUID, + doc_id: UUID, +) -> tuple[str, list]: + """ + Fetch a 
single doc with its embedding, grouping all content chunks and embeddings. + + Parameters: + developer_id (UUID): The ID of the developer. + doc_id (UUID): The ID of the document. + + Returns: + tuple[str, list]: SQL query and parameters for fetching the document. + """ + return ( + doc_with_embedding_query, + [developer_id, doc_id], + ) diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py new file mode 100644 index 000000000..60c0118a8 --- /dev/null +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -0,0 +1,147 @@ +""" +This module contains the functionality for listing documents from the PostgreSQL database. +It constructs and executes SQL queries to fetch document details based on various filters. +""" + +from typing import Any, Literal +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import Doc +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Base query for listing docs with aggregated content and embeddings +base_docs_query = """ +WITH doc_data AS ( + SELECT + d.doc_id, + d.developer_id, + d.title, + array_agg(d.content ORDER BY d.index) as content, + array_agg(d.index ORDER BY d.index) as indices, + array_agg(CASE WHEN $2 THEN NULL ELSE e.embedding END ORDER BY d.index) as embeddings, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at + FROM docs d + JOIN doc_owners doc_own + ON d.developer_id = doc_own.developer_id + AND d.doc_id = doc_own.doc_id + LEFT JOIN docs_embeddings e + ON d.doc_id = e.doc_id + WHERE d.developer_id = $1 + AND doc_own.owner_type = $3 + AND doc_own.owner_id = $4 + GROUP BY + d.doc_id, + d.developer_id, + d.title, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at +) +SELECT * FROM doc_data +""" + + +def 
transform_list_docs(d: dict) -> dict: + content = d["content"][0] if len(d["content"]) == 1 else d["content"] + + embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"] + + # try: + # # Embeddings are retreived as a string, so we need to evaluate it + # embeddings = ast.literal_eval(embeddings) + # except Exception as e: + # msg = f"Error evaluating embeddings: {e}" + # raise ValueError(msg) + + if embeddings and all((e is None) for e in embeddings): + embeddings = None + + return { + **d, + "id": d["doc_id"], + "content": content, + "embeddings": embeddings, + } + + +@rewrap_exceptions(common_db_exceptions("doc", ["list"])) +@wrap_in_class( + Doc, + one=False, + transform=transform_list_docs, +) +@pg_query +@beartype +async def list_docs( + *, + developer_id: UUID, + owner_id: UUID, + owner_type: Literal["user", "agent"], + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", + metadata_filter: dict[str, Any] = {}, + include_without_embeddings: bool = False, +) -> tuple[str, list]: + """ + Lists docs with pagination and sorting, aggregating content chunks and embeddings. + + Parameters: + developer_id (UUID): The ID of the developer. + owner_id (UUID): The ID of the owner of the documents (required). + owner_type (Literal["user", "agent"]): The type of the owner of the documents (required). + limit (int): The number of documents to return. + offset (int): The number of documents to skip. + sort_by (Literal["created_at", "updated_at"]): The field to sort by. + direction (Literal["asc", "desc"]): The direction to sort by. + metadata_filter (dict[str, Any]): The metadata filter to apply. + include_without_embeddings (bool): Whether to include documents without embeddings. + + Returns: + tuple[str, list]: SQL query and parameters for listing the documents. + + Raises: + HTTPException: If invalid parameters are provided. 
+ """ + if direction.lower() not in ["asc", "desc"]: + raise HTTPException(status_code=400, detail="Invalid sort direction") + + if sort_by not in ["created_at", "updated_at"]: + raise HTTPException(status_code=400, detail="Invalid sort field") + + if limit > 100 or limit < 1: + raise HTTPException(status_code=400, detail="Limit must be between 1 and 100") + + if offset < 0: + raise HTTPException(status_code=400, detail="Offset must be >= 0") + + # Start with the base query + query = base_docs_query + params = [developer_id, include_without_embeddings, owner_type, owner_id] + + # Add metadata filtering + if metadata_filter: + for key, value in metadata_filter.items(): + query += f" AND metadata->>'{key}' = ${len(params) + 1}" + params.append(value) + + # Add sorting and pagination + query += ( + f" ORDER BY {sort_by} {direction} LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}" + ) + params.extend([limit, offset]) + + return query, params diff --git a/agents-api/agents_api/models/docs/mmr.py b/agents-api/agents_api/queries/docs/mmr.py similarity index 87% rename from agents-api/agents_api/models/docs/mmr.py rename to agents-api/agents_api/queries/docs/mmr.py index d214e8c04..26f1f5aa1 100644 --- a/agents-api/agents_api/models/docs/mmr.py +++ b/agents-api/agents_api/queries/docs/mmr.py @@ -1,11 +1,10 @@ from __future__ import annotations import logging -from typing import Union import numpy as np -Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray] +Matrix = list[list[float]] | list[np.ndarray] | np.ndarray logger = logging.getLogger(__name__) @@ -35,18 +34,14 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: x = np.array(x) y = np.array(y) if x.shape[1] != y.shape[1]: - msg = ( - f"Number of columns in X and Y must be the same. X has shape {x.shape} " - f"and Y has shape {y.shape}." - ) + msg = f"Number of columns in X and Y must be the same. X has shape {x.shape} and Y has shape {y.shape}." 
raise ValueError(msg) try: import simsimd as simd # type: ignore x = np.array(x, dtype=np.float32) y = np.array(y, dtype=np.float32) - z = 1 - np.array(simd.cdist(x, y, metric="cosine")) - return z + return 1 - np.array(simd.cdist(x, y, metric="cosine")) except ImportError: logger.debug( "Unable to import simsimd, defaulting to NumPy implementation. If you want " @@ -98,9 +93,7 @@ def maximal_marginal_relevance( if i in idxs: continue redundant_score = max(similarity_to_selected[i]) - equation_score = ( - lambda_mult * query_score - (1 - lambda_mult) * redundant_score - ) + equation_score = lambda_mult * query_score - (1 - lambda_mult) * redundant_score if equation_score > best_score: best_score = equation_score idx_to_add = i diff --git a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py new file mode 100644 index 000000000..fb5110a56 --- /dev/null +++ b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py @@ -0,0 +1,80 @@ +from typing import Any, Literal +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import DocReference +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class +from .utils import transform_to_doc_reference + +# Raw query for vector search +search_docs_by_embedding_query = """ +SELECT * FROM search_by_vector( + $1, -- developer_id + $2::vector(1024), -- query_embedding + $3::text[], -- owner_types + $4::uuid[], -- owner_ids + $5, -- k + $6, -- confidence + $7 -- metadata_filter +) +""" + + +@rewrap_exceptions(common_db_exceptions("doc", ["search"])) +@wrap_in_class( + DocReference, + transform=transform_to_doc_reference, +) +@pg_query +@beartype +async def search_docs_by_embedding( + *, + developer_id: UUID, + query_embedding: list[float], + k: int = 10, + owners: list[tuple[Literal["user", "agent"], UUID]], + 
confidence: float = 0.5, + metadata_filter: dict[str, Any] = {}, +) -> tuple[str, list]: + """ + Vector-based doc search: + + Parameters: + developer_id (UUID): The ID of the developer. + query_embedding (List[float]): The vector to query. + k (int): The number of results to return. + owners (list[tuple[Literal["user", "agent"], UUID]]): List of (owner_type, owner_id) tuples. + confidence (float): The confidence threshold for the search. + metadata_filter (dict): Metadata filter criteria. + + Returns: + tuple[str, list]: SQL query and parameters for searching the documents. + """ + if k < 1: + raise HTTPException(status_code=400, detail="k must be >= 1") + + if not query_embedding: + raise HTTPException(status_code=400, detail="Empty embedding provided") + + # Convert query_embedding to a string + query_embedding_str = f"[{', '.join(map(str, query_embedding))}]" + + # Extract owner types and IDs + owner_types: list[str] = [owner[0] for owner in owners] + owner_ids: list[str] = [str(owner[1]) for owner in owners] + + return ( + search_docs_by_embedding_query, + [ + developer_id, + query_embedding_str, + owner_types, + owner_ids, + k, + confidence, + metadata_filter, + ], + ) diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py new file mode 100644 index 000000000..77fb3a0e6 --- /dev/null +++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py @@ -0,0 +1,75 @@ +from typing import Any, Literal +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import DocReference +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class +from .utils import transform_to_doc_reference + +# Raw query for text search +search_docs_text_query = """ +SELECT * FROM search_by_text( + $1, -- developer_id + $2, -- query + $3, -- owner_types + $4, -- owner_ids + $5, -- 
search_language + $6, -- k + $7 -- metadata_filter +) +""" + + +@rewrap_exceptions(common_db_exceptions("doc", ["search"])) +@wrap_in_class( + DocReference, + transform=transform_to_doc_reference, +) +@pg_query +@beartype +async def search_docs_by_text( + *, + developer_id: UUID, + owners: list[tuple[Literal["user", "agent"], UUID]], + query: str, + k: int = 3, + metadata_filter: dict[str, Any] = {}, + search_language: str | None = "english", +) -> tuple[str, list]: + """ + Full-text search on docs using the search_tsv column. + + Parameters: + developer_id (UUID): The ID of the developer. + query (str): The text to search for. + owners (list[tuple[Literal["user", "agent"], UUID]]): List of (owner_type, owner_id) tuples. + k (int): Maximum number of results to return. + search_language (str): Language for text search (default: "english"). + metadata_filter (dict): Metadata filter criteria. + connection_pool (asyncpg.Pool): Database connection pool. + + Returns: + tuple[str, list]: SQL query and parameters for searching the documents. 
+ """ + if k < 1: + raise HTTPException(status_code=400, detail="k must be >= 1") + + # Extract owner types and IDs + owner_types: list[str] = [owner[0] for owner in owners] + owner_ids: list[str] = [str(owner[1]) for owner in owners] + + return ( + search_docs_text_query, + [ + developer_id, + query, + owner_types, + owner_ids, + search_language, + k, + metadata_filter, + ], + ) diff --git a/agents-api/agents_api/queries/docs/search_docs_hybrid.py b/agents-api/agents_api/queries/docs/search_docs_hybrid.py new file mode 100644 index 000000000..5c09b802c --- /dev/null +++ b/agents-api/agents_api/queries/docs/search_docs_hybrid.py @@ -0,0 +1,98 @@ +from typing import Any, Literal +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import DocReference +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import ( + pg_query, + rewrap_exceptions, + wrap_in_class, +) +from .utils import transform_to_doc_reference + +# Raw query for hybrid search +search_docs_hybrid_query = """ +SELECT * FROM search_hybrid( + $1, -- developer_id + $2, -- text_query + $3::vector(1024), -- embedding + $4::text[], -- owner_types + $5::uuid[], -- owner_ids + $6, -- k + $7, -- alpha + $8, -- confidence + $9, -- metadata_filter + $10 -- search_language +) +""" + + +@rewrap_exceptions(common_db_exceptions("doc", ["search"])) +@wrap_in_class( + DocReference, + transform=transform_to_doc_reference, +) +@pg_query +@beartype +async def search_docs_hybrid( + developer_id: UUID, + owners: list[tuple[Literal["user", "agent"], UUID]], + text_query: str = "", + embedding: list[float] | None = None, + k: int = 10, + alpha: float = 0.5, + metadata_filter: dict[str, Any] = {}, + search_language: str = "english", + confidence: float = 0.5, +) -> tuple[str, list]: + """ + Hybrid text-and-embedding doc search. We get top-K from each approach, + then fuse them client-side. Adjust concurrency or approach as you like. 
+ + Parameters: + developer_id (UUID): The unique identifier for the developer. + text_query (str): The text query to search for. + embedding (List[float]): The embedding to search for. + k (int): The number of results to return. + alpha (float): The weight for the embedding results. + owners (list[tuple[Literal["user", "agent"], UUID]]): List of (owner_type, owner_id) tuples. + metadata_filter (dict[str, Any]): Metadata filter criteria. + + Returns: + tuple[str, list]: The SQL query and parameters for the search. + """ + + if k < 1: + raise HTTPException(status_code=400, detail="k must be >= 1") + + if not text_query and not embedding: + raise HTTPException(status_code=400, detail="Empty query provided") + + if not embedding: + raise HTTPException(status_code=400, detail="Empty embedding provided") + + # Convert query_embedding to a string + embedding_str = f"[{', '.join(map(str, embedding))}]" + + # Extract owner types and IDs + owner_types: list[str] = [owner[0] for owner in owners] + owner_ids: list[str] = [str(owner[1]) for owner in owners] + + return ( + search_docs_hybrid_query, + [ + developer_id, + text_query, + embedding_str, + owner_types, + owner_ids, + k, + alpha, + confidence, + metadata_filter, + search_language, + ], + ) diff --git a/agents-api/agents_api/queries/docs/utils.py b/agents-api/agents_api/queries/docs/utils.py new file mode 100644 index 000000000..4d1cbaf45 --- /dev/null +++ b/agents-api/agents_api/queries/docs/utils.py @@ -0,0 +1,35 @@ +import ast + + +def transform_to_doc_reference(d: dict) -> dict: + id = d.pop("doc_id") + content = d.pop("content") + index = d.pop("index") + + embedding = d.pop("embedding") + + try: + # Embeddings are retrieved as a string, so we need to evaluate it + embedding = ast.literal_eval(embedding) + except Exception as e: + msg = f"Error evaluating embeddings: {e}" + raise ValueError(msg) + + owner = { + "id": d.pop("owner_id"), + "role": d.pop("owner_type"), + } + snippet = { + "content": content, + "index": index, + "embedding":
embedding, + } + metadata = d.pop("metadata") + + return { + "id": id, + "owner": owner, + "snippet": snippet, + "metadata": metadata, + **d, + } diff --git a/agents-api/agents_api/queries/entries/__init__.py b/agents-api/agents_api/queries/entries/__init__.py new file mode 100644 index 000000000..e6db0efed --- /dev/null +++ b/agents-api/agents_api/queries/entries/__init__.py @@ -0,0 +1,21 @@ +""" +The `entry` module provides SQL query functions for managing entries +in the TimescaleDB database. This includes operations for: + +- Creating new entries +- Deleting entries +- Retrieving entry history +- Listing entries with filtering and pagination +""" + +from .create_entries import create_entries +from .delete_entries import delete_entries +from .get_history import get_history +from .list_entries import list_entries + +__all__ = [ + "create_entries", + "delete_entries", + "get_history", + "list_entries", +] diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py new file mode 100644 index 000000000..48e32dafd --- /dev/null +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -0,0 +1,177 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype +from litellm.utils import _select_tokenizer as select_tokenizer +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import ( + CreateEntryRequest, + Relation, + ResourceCreatedResponse, +) +from ...common.utils.datetime import utcnow +from ...common.utils.db_exceptions import common_db_exceptions +from ...common.utils.messages import content_to_json +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Query for checking if the session exists +session_exists_query = """ +SELECT EXISTS ( + SELECT 1 FROM sessions + WHERE session_id = $1 AND developer_id = $2 +) AS exists; +""" + +# Define the raw SQL query for creating entries 
+entry_query = """ +INSERT INTO entries ( + session_id, + entry_id, + source, + role, + event_type, + name, + content, + tool_call_id, + tool_calls, + model, + token_count, + tokenizer, + created_at, + timestamp +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) +RETURNING *; +""" + +# Define the raw SQL query for creating entry relations +entry_relation_query = """ +INSERT INTO entry_relations ( + session_id, + head, + relation, + tail +) VALUES ($1, $2, $3, $4) +RETURNING *; +""" + + +@rewrap_exceptions(common_db_exceptions("entry", ["create"])) +@wrap_in_class( + ResourceCreatedResponse, + transform=lambda d: { + "id": d.pop("entry_id"), + "created_at": d.pop("created_at"), + **d, + }, +) +@increase_counter("create_entries") +@pg_query +@beartype +async def create_entries( + *, + developer_id: UUID, + session_id: UUID, + data: list[CreateEntryRequest], +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: + """ + Create entries in a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + data (list[CreateEntryRequest]): The list of entries to create. + + Returns: + list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for creating the entries. 
+ """ + # Convert the data to a list of dictionaries + data_dicts = [item.model_dump(mode="json") for item in data] + + params = [ + [ + session_id, # $1 + item.pop("id", None) or uuid7(), # $2 + item.get("source"), # $3 + item.get("role"), # $4 + item.get("event_type") or "message.create", # $5 + item.get("name"), # $6 + content_to_json(item.get("content") or {}), # $7 + item.get("tool_call_id"), # $8 + content_to_json(item.get("tool_calls") or {}), # $9 + item.get("model"), # $10 + item.get("token_count"), # $11 + select_tokenizer(item.get("model"))["type"], # $12 + item.get("created_at") or utcnow(), # $13 + utcnow().timestamp(), # $14 + ] + for item in data_dicts + ] + + return [ + ( + session_exists_query, + [session_id, developer_id], + "fetchrow", + ), + ( + entry_query, + params, + "fetchmany", + ), + ] + + +@rewrap_exceptions(common_db_exceptions("entry_relation", ["create"])) +@wrap_in_class(Relation) +@increase_counter("add_entry_relations") +@pg_query +@beartype +async def add_entry_relations( + *, + developer_id: UUID, + session_id: UUID, + data: list[Relation], +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: + """ + Add relations between entries in a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + data (list[Relation]): The list of relations to add. + + Returns: + list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for adding the relations. 
+ """ + # Convert the data to a list of dictionaries + data_dicts = [item.model_dump(mode="json") for item in data] + + # Prepare the parameters for the query + # $1 + # $2 + # $3 + # $4 + params = [ + [ + session_id, # $1 + item.get("head"), # $2 + item.get("relation"), # $3 + item.get("tail"), # $4 + ] + for item in data_dicts + ] + + return [ + ( + session_exists_query, + [session_id, developer_id], + "fetchrow", + ), + ( + entry_relation_query, + params, + "fetchmany", + ), + ] diff --git a/agents-api/agents_api/queries/entries/delete_entries.py b/agents-api/agents_api/queries/entries/delete_entries.py new file mode 100644 index 000000000..c47e9e758 --- /dev/null +++ b/agents-api/agents_api/queries/entries/delete_entries.py @@ -0,0 +1,121 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for deleting entries with a developer check +delete_entry_query = """ +DELETE FROM entries +USING developers +WHERE entries.session_id = $1 -- session_id + AND developers.developer_id = $2 -- developer_id + +RETURNING entries.session_id as session_id; +""" + +# Define the raw SQL query for deleting all entry relations of a session +delete_entry_relations_query = """ +DELETE FROM entry_relations +WHERE entry_relations.session_id = $1 -- session_id +""" + +# Define the raw SQL query for deleting the entry relations whose head or tail is one of the given entry_ids +delete_entry_relations_by_ids_query = """ +DELETE FROM entry_relations +WHERE entry_relations.session_id = $1 -- session_id + AND (entry_relations.head = ANY($2) -- entry_ids + OR entry_relations.tail = ANY($2)) -- entry_ids +""" + +# Define the raw SQL query for deleting entries by entry_ids with a developer
check +delete_entry_by_ids_query = """ +DELETE FROM entries +USING developers +WHERE entries.entry_id = ANY($1) -- entry_ids + AND developers.developer_id = $2 -- developer_id + AND entries.session_id = $3 -- session_id + +RETURNING entries.entry_id as entry_id; +""" + +# Add a session_exists_query similar to create_entries.py +session_exists_query = """ +SELECT EXISTS ( + SELECT 1 + FROM sessions + WHERE session_id = $1 + AND developer_id = $2 +); +""" + + +@rewrap_exceptions(common_db_exceptions("entry", ["delete"])) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@increase_counter("delete_entries_for_session") +@pg_query +@beartype +async def delete_entries_for_session( + *, + developer_id: UUID, + session_id: UUID, +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: + """Delete all entries for a given session.""" + return [ + (session_exists_query, [session_id, developer_id], "fetchrow"), + (delete_entry_relations_query, [session_id], "fetchmany"), + (delete_entry_query, [session_id, developer_id], "fetchmany"), + ] + + +@rewrap_exceptions(common_db_exceptions("entry", ["delete"])) +@wrap_in_class( + ResourceDeletedResponse, + transform=lambda d: { + "id": d["entry_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@increase_counter("delete_entries") +@pg_query +@beartype +async def delete_entries( + *, developer_id: UUID, session_id: UUID, entry_ids: list[UUID] +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: + """Delete specific entries by their IDs. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + entry_ids (list[UUID]): The IDs of the entries to delete. + + Returns: + list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: SQL query and parameters for deleting the entries. 
+ """ + return [ + ( + session_exists_query, + [session_id, developer_id], + "fetchrow", + ), + (delete_entry_relations_by_ids_query, [session_id, entry_ids], "fetch"), + ( + delete_entry_by_ids_query, + [entry_ids, developer_id, session_id], + "fetch", + ), + ] diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py new file mode 100644 index 000000000..be4eebb5d --- /dev/null +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -0,0 +1,95 @@ +import json +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import History +from ...common.utils.datetime import utcnow +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for getting history with a developer check and relations +history_query = """ +WITH entries AS ( + SELECT + e.entry_id AS id, + e.session_id, + e.role, + e.name, + e.content, + e.source, + e.token_count, + e.created_at, + e.timestamp, + e.tool_calls, + e.tool_call_id, + e.tokenizer + FROM entries e + JOIN developers d ON d.developer_id = $3 + WHERE e.session_id = $1 + AND e.source = ANY($2) +), +relations AS ( + SELECT + er.head, + er.relation, + er.tail + FROM entry_relations er + WHERE er.session_id = $1 +) +SELECT + (SELECT json_agg(e) FROM entries e) AS entries, + (SELECT json_agg(r) FROM relations r) AS relations, + $1::uuid AS session_id +""" + + +def _transform(d): + return { + "entries": [ + { + **entry, + } + for entry in json.loads(d.get("entries") or "[]") + ], + "relations": [ + { + "head": r["head"], + "relation": r["relation"], + "tail": r["tail"], + } + for r in (d.get("relations") or []) + ], + "session_id": d.get("session_id"), + "created_at": utcnow(), + } + + +@rewrap_exceptions(common_db_exceptions("history", ["get"])) +@wrap_in_class( + History, + one=True, + transform=_transform, +) +@pg_query +@beartype +async def 
get_history( + *, + developer_id: UUID, + session_id: UUID, + allowed_sources: list[str] = ["api_request", "api_response"], +) -> tuple[str, list] | tuple[str, list, str]: + """Get the history of a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + allowed_sources (list[str]): The sources to include in the history. + + Returns: + tuple[str, list] | tuple[str, list, str]: SQL query and parameters for getting the history. + """ + return ( + history_query, + [session_id, allowed_sources, developer_id], + ) diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py new file mode 100644 index 000000000..de4714ee0 --- /dev/null +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -0,0 +1,109 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import Entry +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Query for checking if the session exists +session_exists_query = """ +SELECT EXISTS ( + SELECT 1 FROM sessions + WHERE session_id = $1 AND developer_id = $2 +) AS exists; +""" + +list_entries_query = """ +SELECT + e.entry_id as id, + e.session_id, + e.role, + e.name, + e.content, + e.source, + e.token_count, + e.created_at, + e.timestamp, + e.event_type, + e.tool_call_id, + e.tool_calls, + e.model, + e.tokenizer +FROM entries e +JOIN developers d ON d.developer_id = $5 +LEFT JOIN entry_relations er ON er.head = e.entry_id AND er.session_id = e.session_id +WHERE e.session_id = $1 +AND e.source = ANY($2) +AND (er.relation IS NULL OR er.relation != ALL($6)) +ORDER BY e.{sort_by} {direction} -- safe to interpolate +LIMIT $3 +OFFSET $4; +""" + + +@rewrap_exceptions(common_db_exceptions("entry", ["list"])) 
+@wrap_in_class(Entry) +@increase_counter("list_entries") +@pg_query +@beartype +async def list_entries( + *, + developer_id: UUID, + session_id: UUID, + allowed_sources: list[str] = ["api_request", "api_response"], + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "timestamp"] = "timestamp", + direction: Literal["asc", "desc"] = "asc", + exclude_relations: list[str] = [], +) -> list[tuple[str, list] | tuple[str, list, str]]: + """List entries in a session. + + Parameters: + developer_id (UUID): The ID of the developer. + session_id (UUID): The ID of the session. + allowed_sources (list[str]): The sources to include in the history. + limit (int): The number of entries to return. + offset (int): The number of entries to skip. + sort_by (Literal["created_at", "timestamp"]): The field to sort by. + direction (Literal["asc", "desc"]): The direction to sort by. + exclude_relations (list[str]): The relations to exclude. + + Returns: + tuple[str, list] | tuple[str, list, str]: SQL query and parameters for listing the entries. 
+ """ + if limit < 1 or limit > 1000: + raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000") + if offset < 0: + raise HTTPException(status_code=400, detail="Offset must be non-negative") + + query = list_entries_query.format( + sort_by=sort_by, + direction=direction, + ) + + # Parameters for the entry query + entry_params = [ + session_id, # $1 + allowed_sources, # $2 + limit, # $3 + offset, # $4 + developer_id, # $5 + exclude_relations, # $6 + ] + return [ + ( + session_exists_query, + [session_id, developer_id], + "fetchrow", + ), + ( + query, + entry_params, + "fetch", + ), + ] diff --git a/agents-api/agents_api/queries/executions/__init__.py b/agents-api/agents_api/queries/executions/__init__.py new file mode 100644 index 000000000..dd5efd23b --- /dev/null +++ b/agents-api/agents_api/queries/executions/__init__.py @@ -0,0 +1,33 @@ +# ruff: noqa: F401, F403, F405 + +""" +The `execution` module provides SQL query functions for managing executions +in the TimescaleDB database. 
This includes operations for: + +- Creating new executions +- Deleting executions +- Retrieving execution history +- Listing executions with filtering and pagination +""" + +from .count_executions import count_executions +from .create_execution import create_execution +from .create_execution_transition import create_execution_transition +from .get_execution import get_execution +from .get_execution_transition import get_execution_transition +from .list_execution_transitions import list_execution_transitions +from .list_executions import list_executions +from .lookup_temporal_data import lookup_temporal_data +from .prepare_execution_input import prepare_execution_input + +__all__ = [ + "count_executions", + "create_execution", + "create_execution_transition", + "get_execution", + "get_execution_transition", + "list_execution_transitions", + "list_executions", + "lookup_temporal_data", + "prepare_execution_input", +] diff --git a/agents-api/agents_api/models/execution/constants.py b/agents-api/agents_api/queries/executions/constants.py similarity index 100% rename from agents-api/agents_api/models/execution/constants.py rename to agents-api/agents_api/queries/executions/constants.py diff --git a/agents-api/agents_api/queries/executions/count_executions.py b/agents-api/agents_api/queries/executions/count_executions.py new file mode 100644 index 000000000..76d765450 --- /dev/null +++ b/agents-api/agents_api/queries/executions/count_executions.py @@ -0,0 +1,41 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype + +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Query to count executions for a given task +execution_count_query = """ +SELECT COUNT(*) FROM latest_executions +WHERE + developer_id = $1 + AND task_id = $2; +""" + + +@rewrap_exceptions(common_db_exceptions("execution", ["count"])) +@wrap_in_class(dict, one=True) +@pg_query +@beartype +async 
def count_executions( + *, + developer_id: UUID, + task_id: UUID, +) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: + """ + Count the number of executions for a given task. + + Parameters: + developer_id (UUID): The ID of the developer. + task_id (UUID): The ID of the task. + + Returns: + tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for counting executions. + """ + return ( + execution_count_query, + [developer_id, task_id], + "fetchrow", + ) diff --git a/agents-api/agents_api/queries/executions/create_execution.py b/agents-api/agents_api/queries/executions/create_execution.py new file mode 100644 index 000000000..49eacb8e6 --- /dev/null +++ b/agents-api/agents_api/queries/executions/create_execution.py @@ -0,0 +1,96 @@ +from typing import Annotated +from uuid import UUID + +from beartype import beartype +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateExecutionRequest, Execution +from ...common.utils.datetime import utcnow +from ...common.utils.db_exceptions import common_db_exceptions +from ...common.utils.types import dict_like +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class +from .constants import OUTPUT_UNNEST_KEY + +create_execution_query = """ +INSERT INTO executions +( + developer_id, + task_id, + execution_id, + input, + metadata, + task_version +) +VALUES +( + $1, + $2, + $3, + $4, + $5, + 1 +) +RETURNING *; +""" + + +@rewrap_exceptions(common_db_exceptions("execution", ["create"])) +@wrap_in_class( + Execution, + one=True, + transform=lambda d: { + "id": d["execution_id"], + "status": "queued", + "updated_at": utcnow(), + **d, + }, +) +@increase_counter("create_execution") +@pg_query +@beartype +async def create_execution( + *, + developer_id: UUID, + task_id: UUID, + execution_id: UUID | None = None, + data: Annotated[CreateExecutionRequest | dict, dict_like(CreateExecutionRequest)], +) -> tuple[str, 
list]: + """ + Create a new execution. + + Parameters: + developer_id (UUID): The ID of the developer. + task_id (UUID): The ID of the task. + execution_id (UUID | None): The ID of the execution. + data (CreateExecutionRequest | dict): The data for the execution. + + Returns: + tuple[str, list]: SQL query and parameters for creating the execution. + """ + execution_id = execution_id or uuid7() + + developer_id = str(developer_id) + task_id = str(task_id) + execution_id = str(execution_id) + + if isinstance(data, CreateExecutionRequest): + data.metadata = data.metadata or {} + execution_data = data.model_dump() + else: + data["metadata"] = data.get("metadata", {}) + execution_data = data + + if execution_data["output"] is not None and not isinstance(execution_data["output"], dict): + execution_data["output"] = {OUTPUT_UNNEST_KEY: execution_data["output"]} + + return ( + create_execution_query, + [ + developer_id, + task_id, + execution_id, + execution_data["input"], + execution_data["metadata"], + ], + ) diff --git a/agents-api/agents_api/queries/executions/create_execution_transition.py b/agents-api/agents_api/queries/executions/create_execution_transition.py new file mode 100644 index 000000000..e9f037f5f --- /dev/null +++ b/agents-api/agents_api/queries/executions/create_execution_transition.py @@ -0,0 +1,164 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import ( + CreateTransitionRequest, + Transition, +) +from ...common.utils.datetime import utcnow +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Query to create a transition +create_execution_transition_query = """ +INSERT INTO transitions +( + execution_id, + transition_id, + type, + step_definition, + step_label, + current_step, + next_step, + output, + task_token, + 
metadata +) +VALUES +( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9, + $10 +) +RETURNING *; +""" + + +# FIXME: Remove this function +def validate_transition_targets(data: CreateTransitionRequest) -> None: + # Make sure the current/next targets are valid + match data.type: + case "finish_branch": + pass # TODO: Implement + case "finish" | "error" | "cancelled": + pass + + # FIXME: HACK: Fix this and uncomment + + # assert ( + # data.next is None + # ), "Next target must be None for finish/finish_branch/error/cancelled" + + case "init_branch" | "init": + assert data.next and data.current.step == data.next.step == 0, ( + "Next target must be same as current for init_branch/init and step 0" + ) + + case "wait": + assert data.next is None, "Next target must be None for wait" + + case "resume" | "step": + assert data.next is not None, "Next target must be provided for resume/step" + + if data.next.workflow == data.current.workflow: + assert data.next.step > data.current.step, ( + "Next step must be greater than current" + ) + + case _: + msg = f"Invalid transition type: {data.type}" + raise ValueError(msg) + + +@rewrap_exceptions(common_db_exceptions("transition", ["create"])) +@wrap_in_class( + Transition, + transform=lambda d: { + **d, + "id": d["transition_id"], + "current": {"workflow": d["current_step"][0], "step": d["current_step"][1]}, + "next": d["next_step"] and {"workflow": d["next_step"][0], "step": d["next_step"][1]}, + "updated_at": utcnow(), + }, + one=True, +) +@increase_counter("create_execution_transition") +@pg_query +@beartype +async def create_execution_transition( + *, + developer_id: UUID, + execution_id: UUID, + data: CreateTransitionRequest, + # Only one of these needed + transition_id: UUID | None = None, + task_token: str | None = None, +) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: + """ + Create a new execution transition. + + Parameters: + developer_id (UUID): The ID of the developer. 
+ execution_id (UUID): The ID of the execution. + data (CreateTransitionRequest): The data for the transition. + transition_id (UUID | None): The ID of the transition. + task_token (str | None): The task token. + + Returns: + tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for creating the transition. + """ + transition_id = transition_id or uuid7() + data.metadata = data.metadata or {} + data.execution_id = execution_id + + # Dump to json + if isinstance(data.output, list): + data.output = [ + item.model_dump(mode="json") if hasattr(item, "model_dump") else item + for item in data.output + ] + + elif hasattr(data.output, "model_dump"): + data.output = data.output.model_dump(mode="json") + + # Prepare the transition data + transition_data = data.model_dump(exclude_unset=True, exclude={"id"}) + + # Parse the current and next targets + validate_transition_targets(data) + current_target = transition_data.pop("current") + next_target = transition_data.pop("next") + + transition_data["current"] = (current_target["workflow"], current_target["step"]) + transition_data["next"] = next_target and ( + next_target["workflow"], + next_target["step"], + ) + + return ( + create_execution_transition_query, + [ + execution_id, + transition_id, + data.type, + {}, + None, + transition_data["current"], + transition_data["next"], + data.output, + task_token, + data.metadata, + ], + "fetchrow", + ) diff --git a/agents-api/agents_api/queries/executions/create_temporal_lookup.py b/agents-api/agents_api/queries/executions/create_temporal_lookup.py new file mode 100644 index 000000000..be77e20c1 --- /dev/null +++ b/agents-api/agents_api/queries/executions/create_temporal_lookup.py @@ -0,0 +1,62 @@ +from uuid import UUID + +from beartype import beartype +from temporalio.client import WorkflowHandle + +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, 
@rewrap_exceptions(common_db_exceptions("temporal_execution", ["create"]))
@increase_counter("create_temporal_lookup")
@pg_query
@beartype
async def create_temporal_lookup(
    *,
    execution_id: UUID,
    workflow_handle: WorkflowHandle,
) -> tuple[str, list]:
    """
    Build the insert that records the workflow identifiers of an execution.

    Parameters:
        execution_id (UUID): The ID of the execution.
        workflow_handle (WorkflowHandle): Handle whose `id`, `run_id`,
            `first_execution_run_id` and `result_run_id` are stored.

    Returns:
        tuple[str, list]: SQL query and parameters for creating the temporal lookup.
    """
    # Column order of the INSERT: execution_id, id, run_id,
    # first_execution_run_id, result_run_id
    handle_columns = (
        workflow_handle.id,
        workflow_handle.run_id,
        workflow_handle.first_execution_run_id,
        workflow_handle.result_run_id,
    )

    return create_temporal_lookup_query, [str(execution_id), *handle_columns]
isinstance(d["output"], dict) and OUTPUT_UNNEST_KEY in d["output"] + else d["output"], + }, +) +@pg_query +@beartype +async def get_execution( + *, + execution_id: UUID, +) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: + """ + Get an execution by its ID. + + Parameters: + execution_id (UUID): The ID of the execution. + + Returns: + tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for getting the execution. + """ + return ( + get_execution_query, + [execution_id], + "fetchrow", + ) diff --git a/agents-api/agents_api/queries/executions/get_execution_transition.py b/agents-api/agents_api/queries/executions/get_execution_transition.py new file mode 100644 index 000000000..d8c23d3f0 --- /dev/null +++ b/agents-api/agents_api/queries/executions/get_execution_transition.py @@ -0,0 +1,67 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import Transition +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Query to get an execution transition +get_execution_transition_query = """ +SELECT * FROM transitions +WHERE + transition_id = $1 + OR task_token = $2 +LIMIT 1; +""" + + +def _transform(d): + current_step = d.pop("current_step") + next_step = d.pop("next_step", None) + + return { + "current": { + "workflow": current_step[0], + "step": current_step[1], + }, + "next": {"workflow": next_step[0], "step": next_step[1]} + if next_step is not None + else None, + **d, + } + + +@rewrap_exceptions(common_db_exceptions("transition", ["get"])) +@wrap_in_class(Transition, one=True, transform=_transform) +@pg_query +@beartype +async def get_execution_transition( + *, + transition_id: UUID | None = None, + task_token: str | None = None, +) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: + """ + Get an execution transition by its ID or task token. 
@rewrap_exceptions(common_db_exceptions("execution", ["get_paused_execution_token"]))
@wrap_in_class(dict, one=True)
@pg_query
@beartype
async def get_paused_execution_token(
    *,
    execution_id: UUID,
) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]:
    """
    Build the lookup for the paused-execution token of an execution.

    The query selects the newest `wait` row of the execution from
    `latest_transitions`.

    Parameters:
        execution_id (UUID): The ID of the execution.

    Returns:
        tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for getting a paused execution token.
    """
    # The ID is passed to the driver in its string form
    query_args = [str(execution_id)]

    return get_paused_execution_token_query, query_args, "fetchrow"
@rewrap_exceptions({
    asyncpg.InvalidRowCountInLimitClauseError: partialclass(
        HTTPException, status_code=400, detail="Invalid limit clause"
    ),
    asyncpg.InvalidRowCountInResultOffsetClauseError: partialclass(
        HTTPException, status_code=400, detail="Invalid offset clause"
    ),
    **common_db_exceptions("transition", ["list"]),
})
@wrap_in_class(
    Transition,
    transform=_transform,
)
@pg_query(debug=True)
@beartype
async def list_execution_transitions(
    *,
    execution_id: UUID,
    transition_id: UUID | None = None,
    limit: int = 100,
    offset: int = 0,
    sort_by: Literal["created_at"] = "created_at",
    direction: Literal["asc", "desc"] = "desc",
) -> tuple[str, list]:
    """
    Build the query that lists the transitions of an execution.

    When `transition_id` is given, the single-transition query is used
    instead and the paging/sorting arguments are ignored.

    Parameters:
        execution_id (UUID): The ID of the execution.
        transition_id (UUID | None): Restrict the result to this transition.
        limit (int): The number of transitions to return.
        offset (int): The number of transitions to skip.
        sort_by (Literal["created_at"]): The field to sort by.
        direction (Literal["asc", "desc"]): The direction to sort by.

    Returns:
        tuple[str, list]: SQL query and parameters for listing execution transitions.
    """
    execution_key = str(execution_id)

    if transition_id is None:
        # Placeholder order in the list query: $1 execution, $2 limit,
        # $3 offset, $4 sort field, $5 direction
        paging_args = [execution_key, limit, offset, sort_by, direction]
        return list_execution_transitions_query, paging_args

    return get_execution_transition_query, [execution_key, str(transition_id)]
@rewrap_exceptions({
    asyncpg.InvalidRowCountInLimitClauseError: partialclass(
        HTTPException, status_code=400, detail="Invalid limit clause"
    ),
    asyncpg.InvalidRowCountInResultOffsetClauseError: partialclass(
        HTTPException, status_code=400, detail="Invalid offset clause"
    ),
    **common_db_exceptions("execution", ["list"]),
})
@wrap_in_class(
    Execution,
    transform=lambda d: {
        "id": d.pop("execution_id"),
        **d,
        "output": d["output"][OUTPUT_UNNEST_KEY]
        if isinstance(d.get("output"), dict) and OUTPUT_UNNEST_KEY in d["output"]
        else d.get("output"),
    },
)
@pg_query
@beartype
async def list_executions(
    *,
    developer_id: UUID,
    task_id: UUID,
    limit: int = 100,
    offset: int = 0,
    sort_by: Literal["created_at", "updated_at"] = "created_at",
    direction: Literal["asc", "desc"] = "desc",
) -> tuple[str, list]:
    """
    Build the paginated, sorted listing query for the executions of a task.

    Parameters:
        developer_id (UUID): The ID of the developer.
        task_id (UUID): The ID of the task.
        limit (int): The number of executions to return (1 to 100).
        offset (int): The number of executions to skip (0 or more).
        sort_by (Literal["created_at", "updated_at"]): The field to sort by.
        direction (Literal["asc", "desc"]): The direction to sort by.

    Returns:
        tuple[str, list]: SQL query and parameters for listing executions.

    Raises:
        HTTPException: With status 400 when `sort_by`, `limit` or `offset`
            is out of range.
    """
    # Reject bad paging/sorting input before anything reaches the database
    if sort_by not in ("created_at", "updated_at"):
        raise HTTPException(status_code=400, detail="Invalid sort field")

    if not 1 <= limit <= 100:
        raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")

    if offset < 0:
        raise HTTPException(status_code=400, detail="Offset must be >= 0")

    # Placeholder order: $1 developer, $2 task, $3 sort field, $4 direction,
    # $5 limit, $6 offset
    query_args = [developer_id, task_id, sort_by, direction, limit, offset]

    return list_executions_query, query_args
@rewrap_exceptions(common_db_exceptions("execution_data", ["get"]))
@wrap_in_class(
    ExecutionInput,
    one=True,
    transform=lambda d: {
        **d,
        # "task": {
        #     "tools": d["tools"],
        #     **d["task"],
        # },
        "developer_id": d["agent"]["developer_id"],
        "agent": {
            "id": d["agent"]["agent_id"],
            **d["agent"],
        },
        "agent_tools": [
            {tool["type"]: tool.pop("spec"), **tool} for tool in d["tools"] if tool is not None
        ],
        "arguments": d["execution"]["input"],
        "execution": {
            "id": d["execution"]["execution_id"],
            **d["execution"],
        },
    },
)
@pg_query
@beartype
async def prepare_execution_input(
    *,
    developer_id: UUID,
    task_id: UUID,
    execution_id: UUID,
) -> tuple[str, list]:
    """
    Build the query that gathers the agent, tools and execution rows for a run.

    Parameters:
        developer_id (UUID): The ID of the developer.
        task_id (UUID): The ID of the task.
        execution_id (UUID): The ID of the execution.

    Returns:
        tuple[str, list]: SQL query and parameters for preparing the execution input.
    """
    # $1 developer, $2 task, $3 execution -- each sent in string form
    id_args = [str(value) for value in (developer_id, task_id, execution_id)]

    return prepare_execution_input_query, id_args
# Add error handling decorator
@rewrap_exceptions(common_db_exceptions("file", ["create"]))
@wrap_in_class(
    File,
    one=True,
    transform=lambda d: {
        **d,
        "id": d["file_id"],
        "hash": d["hash"].hex(),
        "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE",
    },
)
@increase_counter("create_file")
@pg_query
@beartype
async def create_file(
    *,
    developer_id: UUID,
    file_id: UUID | None = None,
    data: CreateFileRequest,
    owner_type: Literal["user", "agent"] | None = None,
    owner_id: UUID | None = None,
) -> list[tuple[str, list] | tuple[str, list, str]]:
    """
    Build the queries that insert a file and, optionally, its owner link.

    Parameters:
        developer_id (UUID): The unique identifier for the developer.
        file_id (UUID | None): ID for the file; a new UUIDv7 is generated when omitted.
        data (CreateFileRequest): The file data to insert. `content` is base64 text.
        owner_type (Literal["user", "agent"] | None): Optional type of owner.
        owner_id (UUID | None): Optional ID of the owner.

    Returns:
        list[tuple[str, list] | tuple[str, list, str]]: The file insert,
        followed by the owner insert when both owner arguments are given.
    """
    file_id = file_id or uuid7()

    # Size and SHA-256 digest are computed over the decoded payload
    raw_content = base64.b64decode(data.content)
    digest = hashlib.sha256(raw_content).digest()

    insert_file = (
        file_query,
        [
            developer_id,
            file_id,
            data.name,
            data.description,
            data.mime_type,
            len(raw_content),
            digest,
        ],
    )

    # Without complete owner information only the file itself is created
    if not (owner_type and owner_id):
        return [insert_file]

    insert_owner = (file_owner_query, [developer_id, file_id, owner_type, owner_id])
    return [insert_file, insert_owner]
@rewrap_exceptions(common_db_exceptions("file", ["delete"]))
@wrap_in_class(
    ResourceDeletedResponse,
    one=True,
    transform=lambda d: {
        "id": d["file_id"],
        "deleted_at": utcnow(),
        "jobs": [],
    },
)
@increase_counter("delete_file")
@pg_query
@beartype
async def delete_file(
    *,
    developer_id: UUID,
    file_id: UUID,
    owner_type: Literal["user", "agent"] | None = None,
    owner_id: UUID | None = None,
) -> tuple[str, list]:
    """
    Build the query that removes a file together with its ownership rows.

    Args:
        developer_id: The developer's UUID
        file_id: The file's UUID
        owner_type: Optional type of owner ("user" or "agent")
        owner_id: Optional UUID of the owner

    Returns:
        tuple[str, list]: SQL query and parameters
    """
    # $1 developer, $2 file, $3 owner type, $4 owner id
    delete_args = [developer_id, file_id, owner_type, owner_id]

    return delete_file_query, delete_args
@rewrap_exceptions(common_db_exceptions("file", ["list"]))
@wrap_in_class(
    File,
    one=False,
    transform=lambda d: {
        **d,
        "id": d["file_id"],
        "hash": d["hash"].hex(),
        "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE",
    },
)
@pg_query
@beartype
async def list_files(
    *,
    developer_id: UUID,
    owner_id: UUID | None = None,
    owner_type: Literal["user", "agent"] | None = None,
    limit: int = 100,
    offset: int = 0,
    sort_by: Literal["created_at", "updated_at"] = "created_at",
    direction: Literal["asc", "desc"] = "desc",
) -> tuple[str, list]:
    """
    Lists files with optional owner filtering, pagination, and sorting.

    Args:
        developer_id: The UUID of the developer owning the files
        owner_id: Optional UUID of the owner to filter by
        owner_type: Optional type of owner ("user" or "agent") to filter by
        limit: Maximum number of files to return (1 to 100)
        offset: Number of files to skip (0 or more)
        sort_by: Column of the files table to sort by
        direction: Sort direction

    Returns:
        tuple[str, list]: SQL query and parameters

    Raises:
        HTTPException: With status 400 when a sorting or paging argument
            is invalid.
    """
    # Validate parameters. Only the normalised direction and the
    # whitelisted column name are ever interpolated into the SQL text.
    sort_direction = direction.lower()
    if sort_direction not in ("asc", "desc"):
        raise HTTPException(status_code=400, detail="Invalid sort direction")

    if sort_by not in ("created_at", "updated_at"):
        raise HTTPException(status_code=400, detail="Invalid sort field")

    if limit > 100 or limit < 1:
        raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")

    if offset < 0:
        raise HTTPException(status_code=400, detail="Offset must be non-negative")

    # Start with the base query
    query = base_files_query
    params = [developer_id]

    # Add owner filtering
    if owner_type and owner_id:
        query += " AND fo.owner_type = $2 AND fo.owner_id = $3"
        params.extend([owner_type, owner_id])

    # Add sorting and pagination. The sort column is qualified with the
    # `f` alias because the base query joins file_owners; an unqualified
    # name would be ambiguous if that table exposes a same-named column.
    limit_slot = len(params) + 1
    offset_slot = len(params) + 2
    query += (
        f" ORDER BY f.{sort_by} {sort_direction} LIMIT ${limit_slot} OFFSET ${offset_slot}"
    )
    params.extend([limit, offset])

    return query, params
This includes operations for: + +- Creating new sessions +- Updating existing sessions +- Retrieving session details +- Listing sessions with filtering and pagination +- Deleting sessions +""" + +from .count_sessions import count_sessions +from .create_or_update_session import create_or_update_session +from .create_session import create_session +from .delete_session import delete_session +from .get_session import get_session +from .list_sessions import list_sessions +from .patch_session import patch_session +from .update_session import update_session + +__all__ = [ + "count_sessions", + "create_or_update_session", + "create_session", + "delete_session", + "get_session", + "list_sessions", + "patch_session", + "update_session", +] diff --git a/agents-api/agents_api/queries/sessions/count_sessions.py b/agents-api/agents_api/queries/sessions/count_sessions.py new file mode 100644 index 000000000..eff0d8d29 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/count_sessions.py @@ -0,0 +1,42 @@ +"""This module contains functions for querying session data from the PostgreSQL database.""" + +from uuid import UUID + +from beartype import beartype + +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +session_query = """ +SELECT COUNT(*) +FROM sessions +WHERE developer_id = $1; +""" + + +@rewrap_exceptions(common_db_exceptions("session", ["count"])) +@wrap_in_class(dict, one=True) +@increase_counter("count_sessions") +@pg_query +@beartype +async def count_sessions( + *, + developer_id: UUID, +) -> tuple[str, list]: + """ + Counts sessions from the PostgreSQL database. + Uses the index on developer_id for efficient counting. + + Args: + developer_id (UUID): The developer's ID to filter sessions by. + + Returns: + tuple[str, list]: SQL query and parameters. 
@rewrap_exceptions(common_db_exceptions("session", ["create", "update"]))
@wrap_in_class(
    ResourceUpdatedResponse,
    one=True,
    transform=lambda d: {"id": d["session_id"], "updated_at": d["updated_at"]},
)
@increase_counter("create_or_update_session")
@pg_query(return_index=0)
@beartype
async def create_or_update_session(
    *,
    developer_id: UUID,
    session_id: UUID,
    data: CreateOrUpdateSessionRequest,
) -> list[tuple[str, list] | tuple[str, list, str]]:
    """
    Build the upsert for a session plus the inserts for its participants.

    Args:
        developer_id (UUID): The developer's UUID
        session_id (UUID): The session's UUID
        data (CreateOrUpdateSessionRequest): Session data to insert or update

    Returns:
        list[tuple[str, list] | tuple[str, list, str]]: The session upsert
        followed by the participant lookup insert, each with its parameters
        and fetch method

    Raises:
        HTTPException: With status 400 when no agent is given, or when both
            `agent` and `agents` are given.
    """
    # The plural fields win; the singular ones are a one-element fallback
    users = data.users or ([data.user] if data.user else [])
    agents = data.agents or ([data.agent] if data.agent else [])

    if not agents:
        raise HTTPException(
            status_code=400,
            detail="At least one agent must be provided",
        )

    if data.agent and data.agents:
        raise HTTPException(
            status_code=400,
            detail="Only one of 'agent' or 'agents' should be provided",
        )

    recall_options = data.recall_options.model_dump() if data.recall_options else {}

    # Positional values for $1..$10 of the session upsert
    session_params = [
        developer_id,
        session_id,
        data.situation,
        data.system_template,
        data.metadata or {},
        data.render_templates,
        data.token_budget,
        data.context_overflow,
        data.forward_tool_calls,
        recall_options,
    ]

    # One lookup row per participant: users first, then agents
    participants = [("user", str(user)) for user in users]
    participants += [("agent", str(agent)) for agent in agents]
    lookup_params = [
        [developer_id, session_id, kind, participant_id]
        for kind, participant_id in participants
    ]

    return [
        (session_query, session_params, "fetch"),
        (lookup_query, lookup_params, "fetchmany"),
    ]
+ + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + data (CreateSessionRequest): Session creation data + + Returns: + list[tuple[str, list] | tuple[str, list, str]]: SQL queries and their parameters + """ + # Handle participants + users = data.users or ([data.user] if data.user else []) + agents = data.agents or ([data.agent] if data.agent else []) + session_id = session_id or uuid7() + + if not agents: + raise HTTPException( + status_code=400, + detail="At least one agent must be provided", + ) + + if data.agent and data.agents: + raise HTTPException( + status_code=400, + detail="Only one of 'agent' or 'agents' should be provided", + ) + + # Prepare participant arrays for lookup query + participant_types = ["user"] * len(users) + ["agent"] * len(agents) + participant_ids = [str(u) for u in users] + [str(a) for a in agents] + + # Prepare session parameters + session_params = [ + developer_id, # $1 + session_id, # $2 + data.situation, # $3 + data.system_template, # $4 + data.metadata or {}, # $5 + data.render_templates, # $6 + data.token_budget, # $7 + data.context_overflow, # $8 + data.forward_tool_calls, # $9 + data.recall_options.model_dump() if data.recall_options else {}, # $10 + ] + + # Prepare lookup parameters as a list of parameter lists + lookup_params = [] + for ptype, pid in zip(participant_types, participant_ids): + lookup_params.append([developer_id, session_id, ptype, pid]) + + return [ + (session_query, session_params, "fetch"), + (lookup_query, lookup_params, "fetchmany"), + ] diff --git a/agents-api/agents_api/queries/sessions/delete_session.py b/agents-api/agents_api/queries/sessions/delete_session.py new file mode 100644 index 000000000..fe2e384f4 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/delete_session.py @@ -0,0 +1,59 @@ +"""This module contains the implementation for deleting sessions from the PostgreSQL database.""" + +from uuid import UUID + +from beartype import beartype + +from 
...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL queries +lookup_query = """ +DELETE FROM session_lookup +WHERE developer_id = $1 AND session_id = $2; +""" + +session_query = """ +DELETE FROM sessions +WHERE developer_id = $1 AND session_id = $2 +RETURNING session_id AS id; +""" + + +@rewrap_exceptions(common_db_exceptions("session", ["delete"])) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + **d, + "deleted_at": utcnow(), + "jobs": [], + }, +) +@increase_counter("delete_session") +@pg_query +@beartype +async def delete_session( + *, + developer_id: UUID, + session_id: UUID, +) -> list[tuple[str, list]]: + """ + Constructs SQL queries to delete a session and its participant lookups. + + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID to delete + + Returns: + list[tuple[str, list]]: List of SQL queries and their parameters + """ + params = [developer_id, session_id] + + return [ + (lookup_query, params), # Delete from lookup table first due to FK constraint + (session_query, params), # Then delete from sessions table + ] diff --git a/agents-api/agents_api/queries/sessions/get_session.py b/agents-api/agents_api/queries/sessions/get_session.py new file mode 100644 index 000000000..d7b261534 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/get_session.py @@ -0,0 +1,67 @@ +"""This module contains functions for retrieving session data from the PostgreSQL database.""" + +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import Session +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, 
wrap_in_class + +# Define the raw SQL query +query = """ +WITH session_participants AS ( + SELECT + sl.session_id, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'agent') as agents, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'user') as users + FROM session_lookup sl + WHERE sl.developer_id = $1 AND sl.session_id = $2 + GROUP BY sl.session_id +) +SELECT + s.session_id as id, + s.developer_id, + s.situation, + s.system_template, + s.metadata, + s.render_templates, + s.token_budget, + s.context_overflow, + s.forward_tool_calls, + s.recall_options, + s.created_at, + s.updated_at, + sp.agents, + sp.users +FROM sessions s +LEFT JOIN session_participants sp ON s.session_id = sp.session_id +WHERE s.developer_id = $1 AND s.session_id = $2; +""" + + +@rewrap_exceptions(common_db_exceptions("session", ["get"])) +@wrap_in_class(Session, one=True) +@increase_counter("get_session") +@pg_query +@beartype +async def get_session( + *, + developer_id: UUID, + session_id: UUID, +) -> tuple[str, list]: + """ + Constructs SQL query to retrieve a session and its participants. 
+ + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + + Returns: + tuple[str, list]: SQL query and parameters + """ + return ( + query, + [developer_id, session_id], + ) diff --git a/agents-api/agents_api/queries/sessions/list_sessions.py b/agents-api/agents_api/queries/sessions/list_sessions.py new file mode 100644 index 000000000..08d919ed3 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/list_sessions.py @@ -0,0 +1,90 @@ +"""This module contains functions for querying session data from the PostgreSQL database.""" + +from typing import Any, Literal +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import Session +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +session_query = """ +WITH session_participants AS ( + SELECT + sl.session_id, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'agent') as agents, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'user') as users + FROM session_lookup sl + WHERE sl.developer_id = $1 + GROUP BY sl.session_id +) +SELECT + s.session_id as id, + s.developer_id, + s.situation, + s.system_template, + s.metadata, + s.render_templates, + s.token_budget, + s.context_overflow, + s.forward_tool_calls, + s.recall_options, + s.created_at, + s.updated_at, + sp.agents, + sp.users +FROM sessions s +LEFT JOIN session_participants sp ON s.session_id = sp.session_id +WHERE s.developer_id = $1 + AND ($5::jsonb IS NULL OR s.metadata @> $5::jsonb) +ORDER BY + CASE WHEN $3 = 'created_at' AND $4 = 'desc' THEN s.created_at END DESC, + CASE WHEN $3 = 'created_at' AND $4 = 'asc' THEN s.created_at END ASC, + CASE WHEN $3 = 'updated_at' AND $4 = 'desc' THEN s.updated_at END DESC, + CASE WHEN $3 = 'updated_at' AND $4 = 'asc' THEN s.updated_at END ASC +LIMIT $2 OFFSET $6; 
+""" + + +@rewrap_exceptions(common_db_exceptions("session", ["list"])) +@wrap_in_class(Session) +@increase_counter("list_sessions") +@pg_query +@beartype +async def list_sessions( + *, + developer_id: UUID, + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", + metadata_filter: dict[str, Any] = {}, +) -> tuple[str, list]: + """ + Lists sessions from the PostgreSQL database based on the provided filters. + + Args: + developer_id (UUID): The developer's UUID + limit (int): Maximum number of sessions to return + offset (int): Number of sessions to skip + sort_by (str): Field to sort by ('created_at' or 'updated_at') + direction (str): Sort direction ('asc' or 'desc') + metadata_filter (dict): Dictionary of metadata fields to filter by + + Returns: + tuple[str, list]: SQL query and parameters + """ + return ( + session_query, + [ + developer_id, # $1 + limit, # $2 + sort_by, # $3 + direction, # $4 + metadata_filter or None, # $5 + offset, # $6 + ], + ) diff --git a/agents-api/agents_api/queries/sessions/patch_session.py b/agents-api/agents_api/queries/sessions/patch_session.py new file mode 100644 index 000000000..fe6848959 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/patch_session.py @@ -0,0 +1,71 @@ +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import PatchSessionRequest, ResourceUpdatedResponse +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL queries +# Build dynamic SET clause based on provided fields +session_query = """ +UPDATE sessions +SET + situation = COALESCE($3, situation), + system_template = COALESCE($4, system_template), + metadata = sessions.metadata || $5, + render_templates = COALESCE($6, render_templates), + token_budget = COALESCE($7, token_budget), 
+ context_overflow = COALESCE($8, context_overflow), + forward_tool_calls = COALESCE($9, forward_tool_calls), + recall_options = sessions.recall_options || $10 +WHERE + developer_id = $1 + AND session_id = $2 +RETURNING * +""" + + +@rewrap_exceptions(common_db_exceptions("session", ["patch"])) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: {"id": d["session_id"], "updated_at": d["updated_at"]}, +) +@increase_counter("patch_session") +@pg_query +@beartype +async def patch_session( + *, + developer_id: UUID, + session_id: UUID, + data: PatchSessionRequest, +) -> list[tuple[str, list]]: + """ + Constructs SQL queries to patch a session and its participant lookups. + + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + data (PatchSessionRequest): Session patch data + + Returns: + list[tuple[str, list]]: List of SQL queries and their parameters + """ + + # Extract fields from data, using None for unset fields + session_params = [ + developer_id, # $1 + session_id, # $2 + data.situation, # $3 + data.system_template, # $4 + data.metadata or {}, # $5 + data.render_templates, # $6 + data.token_budget, # $7 + data.context_overflow, # $8 + data.forward_tool_calls, # $9 + data.recall_options.model_dump() if data.recall_options else {}, # $10 + ] + + return [(session_query, session_params)] diff --git a/agents-api/agents_api/queries/sessions/update_session.py b/agents-api/agents_api/queries/sessions/update_session.py new file mode 100644 index 000000000..6ad90bef3 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/update_session.py @@ -0,0 +1,72 @@ +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateSessionRequest +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query 
+session_query = """ +UPDATE sessions +SET + situation = $3, + system_template = $4, + metadata = $5, + render_templates = $6, + token_budget = $7, + context_overflow = $8, + forward_tool_calls = $9, + recall_options = $10 +WHERE developer_id = $1 AND session_id = $2 +RETURNING *; +""" + + +@rewrap_exceptions(common_db_exceptions("session", ["update"])) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], + "updated_at": d["updated_at"], + }, +) +@increase_counter("update_session") +@pg_query +@beartype +async def update_session( + *, + developer_id: UUID, + session_id: UUID, + data: UpdateSessionRequest, +) -> list[tuple[str, list]]: + """ + Constructs SQL queries to update a session and its participant lookups. + + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + data (UpdateSessionRequest): Session update data + + Returns: + list[tuple[str, list]]: List of SQL queries and their parameters + """ + # Prepare session parameters + session_params = [ + developer_id, # $1 + session_id, # $2 + data.situation, # $3 + data.system_template, # $4 + data.metadata or {}, # $5 + data.render_templates, # $6 + data.token_budget, # $7 + data.context_overflow, # $8 + data.forward_tool_calls, # $9 + data.recall_options.model_dump() if data.recall_options else {}, # $10 + ] + + return [ + (session_query, session_params), + ] diff --git a/agents-api/agents_api/queries/tasks/__init__.py b/agents-api/agents_api/queries/tasks/__init__.py new file mode 100644 index 000000000..63b4bed22 --- /dev/null +++ b/agents-api/agents_api/queries/tasks/__init__.py @@ -0,0 +1,28 @@ +""" +The `task` module within the `queries` package provides SQL query functions for managing tasks +in the TimescaleDB database. 
This includes operations for: + +- Creating new tasks +- Updating existing tasks +- Retrieving task details +- Listing tasks with filtering and pagination +- Deleting tasks +""" + +from .create_or_update_task import create_or_update_task +from .create_task import create_task +from .delete_task import delete_task +from .get_task import get_task +from .list_tasks import list_tasks +from .patch_task import patch_task +from .update_task import update_task + +__all__ = [ + "create_or_update_task", + "create_task", + "delete_task", + "get_task", + "list_tasks", + "patch_task", + "update_task", +] diff --git a/agents-api/agents_api/queries/tasks/create_or_update_task.py b/agents-api/agents_api/queries/tasks/create_or_update_task.py new file mode 100644 index 000000000..11c1924c0 --- /dev/null +++ b/agents-api/agents_api/queries/tasks/create_or_update_task.py @@ -0,0 +1,225 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateOrUpdateTaskRequest, ResourceUpdatedResponse +from ...common.protocol.tasks import task_to_spec +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import generate_canonical_name, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for creating or updating a task +tools_query = """ +INSERT INTO tools ( + developer_id, + agent_id, + task_id, + tool_id, + type, + name, + description, + spec +) +VALUES ( + $1, -- developer_id + $2, -- agent_id + $3, -- task_id + $4, -- tool_id + $5, -- type + $6, -- name + $7, -- description + $8 -- spec +) +ON CONFLICT (agent_id, task_id, name) DO UPDATE SET + type = EXCLUDED.type, + description = EXCLUDED.description, + spec = EXCLUDED.spec +RETURNING *; +""" + +# Define the raw SQL query for creating or updating a task +task_query = """ +WITH current_version AS ( + SELECT COALESCE( + (SELECT MAX("version") + 
FROM tasks + WHERE developer_id = $1 + AND task_id = $4), + 0 + ) + 1 as next_version, + COALESCE( + (SELECT canonical_name + FROM tasks + WHERE developer_id = $1 AND task_id = $4 + ORDER BY version DESC + LIMIT 1), + $2 + ) as effective_canonical_name + FROM (SELECT 1) as dummy +) +INSERT INTO tasks ( + "version", + developer_id, + canonical_name, + agent_id, + task_id, + name, + description, + inherit_tools, + input_schema, + metadata +) +SELECT + next_version, -- version + $1, -- developer_id + effective_canonical_name, -- canonical_name + $3, -- agent_id + $4, -- task_id + $5, -- name + $6, -- description + $7, -- inherit_tools + $8::jsonb, -- input_schema + $9::jsonb -- metadata +FROM current_version +ON CONFLICT (developer_id, task_id, "version") DO UPDATE SET + version = tasks.version + 1, + name = EXCLUDED.name, + description = EXCLUDED.description, + inherit_tools = EXCLUDED.inherit_tools, + input_schema = EXCLUDED.input_schema, + metadata = EXCLUDED.metadata +RETURNING *, (SELECT next_version FROM current_version) as next_version; +""" + +# Define the raw SQL query for inserting workflows +workflows_query = """ +WITH version AS ( + SELECT COALESCE(MAX("version"), 0) as current_version + FROM tasks + WHERE developer_id = $1 + AND task_id = $2 +) +INSERT INTO workflows ( + developer_id, + task_id, + "version", + name, + step_idx, + step_type, + step_definition +) +SELECT + $1, -- developer_id + $2, -- task_id + current_version, -- version + $3, -- name + $4, -- step_idx + $5, -- step_type + $6 -- step_definition +FROM version; +""" + + +@rewrap_exceptions(common_db_exceptions("task", ["create_or_update"])) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: { + "id": d["task_id"], + "updated_at": d["updated_at"].timestamp(), + **d, + }, +) +@increase_counter("create_or_update_task") +@pg_query(return_index=0) +@beartype +async def create_or_update_task( + *, + developer_id: UUID, + agent_id: UUID, + task_id: UUID, + data: 
CreateOrUpdateTaskRequest, +) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: + """ + Constructs an SQL query to create or update a task. + + Args: + developer_id (UUID): The UUID of the developer. + agent_id (UUID): The UUID of the agent. + task_id (UUID): The UUID of the task. + data (CreateOrUpdateTaskRequest): The task data to insert or update. + + Returns: + list[tuple[str, list, Literal["fetch", "fetchmany"]]]: List of SQL queries and parameters. + + Raises: + HTTPException: If developer/agent doesn't exist (404) or on unique constraint violation (409) + """ + + # Generate canonical name from task name if not provided + canonical_name = data.canonical_name or generate_canonical_name() + + # Version will be determined by the CTE + task_params = [ + developer_id, # $1 + canonical_name, # $2 + agent_id, # $3 + task_id, # $4 + data.name, # $5 + data.description, # $6 + data.inherit_tools, # $7 + data.input_schema or {}, # $8 + data.metadata or {}, # $9 + ] + + # Prepare tool parameters for the tools table + tool_params = [ + [ + developer_id, + agent_id, + task_id, + uuid7(), # tool_id + tool.type, + tool.name, + tool.description, + getattr(tool, tool.type) + and getattr(tool, tool.type).model_dump(mode="json"), # spec + ] + for tool in data.tools or [] + ] + + # Generate workflows from task data using task_to_spec + workflows_spec = task_to_spec(data).model_dump(mode="json") + workflow_params = [] + for workflow in workflows_spec.get("workflows", []): + workflow_name = workflow.get("name") + steps = workflow.get("steps", []) + for step_idx, step in enumerate(steps): + workflow_params.append([ + developer_id, # $1 + task_id, # $2 + workflow_name, # $3 + step_idx, # $4 + step["kind_"], # $5 + step, # $6 + ]) + + return [ + ( + task_query, + task_params, + "fetch", + ), + ( + tools_query, + tool_params, + "fetchmany", + ), + ( + workflows_query, + workflow_params, + "fetchmany", + ), + ] diff --git a/agents-api/agents_api/queries/tasks/create_task.py 
b/agents-api/agents_api/queries/tasks/create_task.py new file mode 100644 index 000000000..c96732c68 --- /dev/null +++ b/agents-api/agents_api/queries/tasks/create_task.py @@ -0,0 +1,193 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateTaskRequest, ResourceCreatedResponse +from ...common.protocol.tasks import task_to_spec +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import ( + generate_canonical_name, + pg_query, + rewrap_exceptions, + wrap_in_class, +) + +# Define the raw SQL query for creating or updating a task +tools_query = """ +INSERT INTO tools ( + developer_id, + agent_id, + task_id, + tool_id, + type, + name, + description, + spec +) +VALUES ( + $1, -- developer_id + $2, -- agent_id + $3, -- task_id + $4, -- tool_id + $5, -- type + $6, -- name + $7, -- description + $8 -- spec +) +""" + +task_query = """ +INSERT INTO tasks ( + "version", + developer_id, + agent_id, + task_id, + name, + canonical_name, + description, + inherit_tools, + input_schema, + metadata +) +VALUES ( + 1, -- version + $1, -- developer_id + $2, -- agent_id + $3, -- task_id + $4, -- name + $5, -- canonical_name + $6, -- description + $7, -- inherit_tools + $8::jsonb, -- input_schema + $9::jsonb -- metadata +) +RETURNING * +""" + +# Define the raw SQL query for inserting workflows +workflows_query = """ +INSERT INTO workflows ( + developer_id, + task_id, + "version", + name, + step_idx, + step_type, + step_definition +) +VALUES ( + $1, -- developer_id + $2, -- task_id + $3, -- version + $4, -- name + $5, -- step_idx + $6, -- step_type + $7 -- step_definition +) +""" + + +@rewrap_exceptions(common_db_exceptions("task", ["create"])) +@wrap_in_class( + ResourceCreatedResponse, + one=True, + transform=lambda d: { + "id": d["task_id"], + "jobs": [], + # "updated_at": 
d["updated_at"].timestamp(), + **d, + }, +) +@increase_counter("create_task") +@pg_query(return_index=0) +@beartype +async def create_task( + *, + developer_id: UUID, + agent_id: UUID, + task_id: UUID | None = None, + data: CreateTaskRequest, +) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: + """ + Constructs SQL queries to create or update a task along with its associated tools and workflows. + + Args: + developer_id (UUID): The UUID of the developer. + agent_id (UUID): The UUID of the agent. + task_id (UUID, optional): The UUID of the task. If not provided, a new UUID is generated. + data (CreateTaskRequest): The task data to insert or update. + + Returns: + tuple[str, list]: SQL query and parameters. + + Raises: + HTTPException: If developer/agent doesn't exist (404) or on unique constraint violation (409) + """ + task_id = task_id or uuid7() + + # Insert parameters for the tasks table + task_params = [ + developer_id, # $1 + agent_id, # $2 + task_id, # $3 + data.name, # $4 + data.canonical_name or generate_canonical_name(), # $5 + data.description, # $6 + data.inherit_tools, # $7 + data.input_schema or {}, # $8 + data.metadata or {}, # $9 + ] + + # Prepare tool parameters for the tools table + tool_params = [ + [ + developer_id, + agent_id, + task_id, + uuid7(), # tool_id + tool.type, + tool.name, + tool.description, + getattr(tool, tool.type) + and getattr(tool, tool.type).model_dump(mode="json"), # spec + ] + for tool in data.tools or [] + ] + + # Generate workflows from task data using task_to_spec + workflows_spec = task_to_spec(data).model_dump(mode="json") + workflow_params = [] + for workflow in workflows_spec.get("workflows", []): + workflow_name = workflow.get("name") + steps = workflow.get("steps", []) + for step_idx, step in enumerate(steps): + workflow_params.append([ + developer_id, # $1 + task_id, # $2 + 1, # $3 (version) + workflow_name, # $4 + step_idx, # $5 + step["kind_"], # $6 + step, # $7 + ]) + + return [ + ( + task_query, + 
task_params, + "fetch", + ), + ( + tools_query, + tool_params, + "fetchmany", + ), + ( + workflows_query, + workflow_params, + "fetchmany", + ), + ] diff --git a/agents-api/agents_api/queries/tasks/delete_task.py b/agents-api/agents_api/queries/tasks/delete_task.py new file mode 100644 index 000000000..bb2907618 --- /dev/null +++ b/agents-api/agents_api/queries/tasks/delete_task.py @@ -0,0 +1,58 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for deleting workflows +workflow_query = """ +DELETE FROM workflows +WHERE developer_id = $1 AND task_id = $2; +""" + +# Define the raw SQL query for deleting tasks +task_query = """ +DELETE FROM tasks +WHERE developer_id = $1 AND task_id = $2 +RETURNING task_id; +""" + + +@rewrap_exceptions(common_db_exceptions("task", ["delete"])) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["task_id"], + "deleted_at": utcnow(), + }, +) +@pg_query +@beartype +async def delete_task( + *, + developer_id: UUID, + task_id: UUID, +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: + """ + Deletes a task by its unique identifier along with its associated workflows. + + Parameters: + developer_id (UUID): The unique identifier of the developer associated with the task. + task_id (UUID): The unique identifier of the task to delete. + + Returns: + tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query, parameters, and fetch method. 
+ + Raises: + HTTPException: If developer/agent doesn't exist (404) or on unique constraint violation (409) + """ + + return [ + (workflow_query, [developer_id, task_id], "fetch"), + (task_query, [developer_id, task_id], "fetchrow"), + ] diff --git a/agents-api/agents_api/queries/tasks/get_task.py b/agents-api/agents_api/queries/tasks/get_task.py new file mode 100644 index 000000000..0089c6719 --- /dev/null +++ b/agents-api/agents_api/queries/tasks/get_task.py @@ -0,0 +1,78 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype + +from ...common.protocol.tasks import spec_to_task +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for getting a task +get_task_query = """ +SELECT + t.*, + COALESCE( + jsonb_agg( + DISTINCT jsonb_build_object( + 'name', w.name, + 'steps', ( + SELECT jsonb_agg(step_definition ORDER BY step_idx) + FROM workflows w2 + WHERE w2.developer_id = w.developer_id + AND w2.task_id = w.task_id + AND w2.version = w.version + AND w2.name = w.name + ) + ) + ) FILTER (WHERE w.name IS NOT NULL), + '[]'::jsonb + ) as workflows, + COALESCE( + jsonb_agg(tl) FILTER (WHERE tl IS NOT NULL), + '[]'::jsonb + ) as tools +FROM + tasks t +LEFT JOIN + workflows w ON t.developer_id = w.developer_id AND t.task_id = w.task_id AND t.version = w.version +LEFT JOIN + tools tl ON t.developer_id = tl.developer_id AND t.task_id = tl.task_id +WHERE + t.developer_id = $1 AND t.task_id = $2 + AND t.version = ( + SELECT MAX(version) + FROM tasks + WHERE developer_id = $1 AND task_id = $2 + ) +GROUP BY t.developer_id, t.task_id, t.canonical_name, t.agent_id, t.version; +""" + + +@rewrap_exceptions(common_db_exceptions("task", ["get"])) +@wrap_in_class(spec_to_task, one=True) +@pg_query +@beartype +async def get_task( + *, + developer_id: UUID, + task_id: UUID, +) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: + """ + 
# Raw SQL for listing an agent's tasks together with their workflows.
# `{metadata_filter_query}` is a str.format placeholder filled in by list_tasks;
# the sort column/direction are selected at run time through $5 / $6.
list_tasks_query = """
SELECT
    t.*,
    COALESCE(
        jsonb_agg(
            CASE WHEN w.name IS NOT NULL THEN
                jsonb_build_object(
                    'name', w.name,
                    'steps', jsonb_build_array(w.step_definition)
                )
            END
        ) FILTER (WHERE w.name IS NOT NULL),
        '[]'::jsonb
    ) as workflows
FROM
    tasks t
LEFT JOIN
    workflows w ON t.developer_id = w.developer_id AND t.task_id = w.task_id AND t.version = w.version
WHERE
    t.developer_id = $1
    AND t.agent_id = $2
    {metadata_filter_query}
GROUP BY t.developer_id, t.task_id, t.canonical_name, t.agent_id, t.version
ORDER BY
    CASE WHEN $5 = 'created_at' AND $6 = 'asc' THEN t.created_at END ASC NULLS LAST,
    CASE WHEN $5 = 'created_at' AND $6 = 'desc' THEN t.created_at END DESC NULLS LAST,
    CASE WHEN $5 = 'updated_at' AND $6 = 'asc' THEN t.updated_at END ASC NULLS LAST,
    CASE WHEN $5 = 'updated_at' AND $6 = 'desc' THEN t.updated_at END DESC NULLS LAST
LIMIT $3 OFFSET $4;
"""


@rewrap_exceptions(common_db_exceptions("task", ["list"]))
@wrap_in_class(spec_to_task)
@pg_query
@beartype
async def list_tasks(
    *,
    developer_id: UUID,
    agent_id: UUID,
    limit: int = 100,
    offset: int = 0,
    sort_by: Literal["created_at", "updated_at"] = "created_at",
    direction: Literal["asc", "desc"] = "desc",
    metadata_filter: dict[str, Any] | None = None,
) -> tuple[str, list]:
    """
    Builds the query that lists an agent's tasks with pagination and sorting.

    Parameters:
        developer_id (UUID): The unique identifier of the developer.
        agent_id (UUID): The unique identifier of the agent.
        limit (int): Maximum number of records to return, 1-100 (default: 100)
        offset (int): Number of records to skip, >= 0 (default: 0)
        sort_by (str): Field to sort by ("created_at" or "updated_at")
        direction (str): Sort direction ("asc" or "desc")
        metadata_filter (dict | None): Optional metadata containment filter;
            None or an empty dict means "no filter".

    Returns:
        tuple[str, list]: SQL query and parameters.

    Raises:
        HTTPException: 400 if direction, limit or offset is out of range.
    """
    if direction.lower() not in ["asc", "desc"]:
        raise HTTPException(status_code=400, detail="Invalid sort direction")

    if limit > 100 or limit < 1:
        raise HTTPException(status_code=400, detail="Limit must be between 1 and 100")

    if offset < 0:
        raise HTTPException(status_code=400, detail="Offset must be non-negative")

    # Only reference $7 when a filter is actually supplied, so the number of
    # placeholders always matches the number of parameters.  The column is
    # qualified with `t.` so it cannot become ambiguous with the joined table.
    query = list_tasks_query.format(
        metadata_filter_query="AND t.metadata @> $7::jsonb" if metadata_filter else ""
    )

    params = [
        developer_id,  # $1
        agent_id,  # $2
        limit,  # $3
        offset,  # $4
        sort_by,  # $5
        direction,  # $6
    ]

    if metadata_filter:
        params.append(metadata_filter)  # $7

    return (query, params)
# Patch is implemented as "insert a new version": the latest existing row is
# read and every column the request left NULL falls back to that row's value.
#
# The CTE must yield at most ONE row.  Grouping by the mutable columns (as the
# previous revision did) produces one row per distinct historical combination
# of name/metadata/..., which makes the INSERT try to write several versions
# at once and collide with existing ones.  Selecting the newest row avoids it.
patch_task_query = """
WITH current_version AS (
    SELECT "version" as current_version,
           canonical_name as existing_canonical_name,
           metadata as existing_metadata,
           name as existing_name,
           description as existing_description,
           inherit_tools as existing_inherit_tools,
           input_schema as existing_input_schema
    FROM tasks
    WHERE developer_id = $1
      AND task_id = $3
    ORDER BY "version" DESC
    LIMIT 1
)
INSERT INTO tasks (
    "version",
    developer_id, -- $1
    canonical_name, -- $2
    task_id, -- $3
    agent_id, -- $4
    metadata, -- $5
    name, -- $6
    description, -- $7
    inherit_tools, -- $8
    input_schema -- $9
)
SELECT
    current_version + 1, -- version
    $1, -- developer_id
    COALESCE($2, existing_canonical_name), -- canonical_name
    $3, -- task_id
    $4, -- agent_id
    COALESCE($5::jsonb, existing_metadata), -- metadata
    COALESCE($6, existing_name), -- name
    COALESCE($7, existing_description), -- description
    COALESCE($8, existing_inherit_tools), -- inherit_tools
    COALESCE($9::jsonb, existing_input_schema) -- input_schema
FROM current_version
RETURNING *;
"""

# When main is None - just copy existing workflows with new version.
# NOTE(review): this assumes it runs after patch_task_query, so MAX(version)
# is already the new version and MAX(version) - 1 the one to copy from.
copy_workflows_query = """
WITH current_version AS (
    SELECT MAX(version) - 1 as current_version
    FROM tasks
    WHERE developer_id = $1 AND task_id = $2
)
INSERT INTO workflows (
    developer_id,
    task_id,
    version,
    name,
    step_idx,
    step_type,
    step_definition
)
SELECT
    developer_id,
    task_id,
    (SELECT current_version + 1 FROM current_version), -- new version
    name,
    step_idx,
    step_type,
    step_definition
FROM workflows
WHERE developer_id = $1
AND task_id = $2
AND version = (SELECT current_version FROM current_version)
"""

# When main is provided - create new workflows (existing query)
new_workflows_query = """
WITH current_version AS (
    SELECT COALESCE(MAX(version), 0) - 1 as next_version
    FROM tasks
    WHERE developer_id = $1 AND task_id = $2
)
INSERT INTO workflows (
    developer_id,
    task_id,
    version,
    name,
    step_idx,
    step_type,
    step_definition
)
SELECT
    $1, -- developer_id
    $2, -- task_id
    next_version + 1, -- version
    $3, -- name
    $4, -- step_idx
    $5, -- step_type
    $6 -- step_definition
FROM current_version
"""


@rewrap_exceptions(common_db_exceptions("task", ["patch"]))
@wrap_in_class(
    ResourceUpdatedResponse,
    one=True,
    transform=lambda d: {"id": d["task_id"], "updated_at": utcnow()},
)
@increase_counter("patch_task")
@pg_query(return_index=0)
@beartype
async def patch_task(
    *,
    developer_id: UUID,
    task_id: UUID,
    agent_id: UUID,
    data: PatchTaskRequest,
) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]:
    """
    Builds the queries that patch a task and its workflows as a new version.
    Only the fields provided in the request are changed; the rest are carried
    over from the latest existing version.

    Parameters:
        developer_id (UUID): The unique identifier of the developer.
        task_id (UUID): The unique identifier of the task to update.
        agent_id (UUID): The unique identifier of the agent.
        data (PatchTaskRequest): The partial update data.

    Returns:
        list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]:
            The task insert followed by the workflow insert.
    """
    # Falsy values are sent as NULL so COALESCE keeps the stored value.
    patch_task_params = [
        developer_id,  # $1
        data.canonical_name,  # $2
        task_id,  # $3
        agent_id,  # $4
        data.metadata or None,  # $5
        data.name or None,  # $6
        data.description or None,  # $7
        data.inherit_tools,  # $8
        data.input_schema,  # $9
    ]

    if data.main is None:
        # No new workflow definition: carry the previous version's steps over.
        workflow_query = copy_workflows_query
        workflow_params = [[developer_id, task_id]]  # Only need these params
    else:
        workflow_query = new_workflows_query
        workflows_spec = task_to_spec(data).model_dump(mode="json")
        # One parameter row per step of every workflow.
        workflow_params = [
            [
                developer_id,  # $1
                task_id,  # $2
                workflow.get("name"),  # $3
                step_idx,  # $4
                step["kind_"],  # $5
                step,  # $6
            ]
            for workflow in workflows_spec.get("workflows", [])
            for step_idx, step in enumerate(workflow.get("steps", []))
        ]

    return [
        (
            patch_task_query,
            patch_task_params,
            "fetchrow",
        ),
        (
            workflow_query,
            workflow_params,
            "fetchmany",
        ),
    ]
# Update is implemented as "insert a new version" of the task.
#
# The CTE must yield at most ONE row.  Grouping by canonical_name (as the
# previous revision did) yields one row per distinct historical canonical
# name, so a task whose canonical name ever changed would trigger several
# inserts.  Selecting only the newest row avoids that.
update_task_query = """
WITH current_version AS (
    SELECT "version" as current_version,
           canonical_name as existing_canonical_name
    FROM tasks
    WHERE developer_id = $1
      AND task_id = $3
    ORDER BY "version" DESC
    LIMIT 1
)
INSERT INTO tasks (
    "version",
    developer_id, -- $1
    canonical_name, -- $2
    task_id, -- $3
    agent_id, -- $4
    metadata, -- $5
    name, -- $6
    description, -- $7
    inherit_tools, -- $8
    input_schema -- $9
)
SELECT
    current_version + 1, -- version
    $1, -- developer_id
    COALESCE($2, existing_canonical_name), -- canonical_name
    $3, -- task_id
    $4, -- agent_id
    $5::jsonb, -- metadata
    $6, -- name
    $7, -- description
    $8, -- inherit_tools
    $9::jsonb -- input_schema
FROM current_version
RETURNING *;
"""

# Inserts the workflow steps for the task's current MAX(version).
# NOTE(review): presumably executed after update_task_query, so that
# MAX(version) is the version that was just inserted - confirm in pg_query.
workflows_query = """
WITH version AS (
    SELECT COALESCE(MAX(version), 0) as current_version
    FROM tasks
    WHERE developer_id = $1 AND task_id = $2
)
INSERT INTO workflows (
    developer_id,
    task_id,
    version,
    name,
    step_idx,
    step_type,
    step_definition
)
SELECT
    $1, -- developer_id
    $2, -- task_id
    current_version, -- version (from CTE)
    $3, -- name
    $4, -- step_idx
    $5, -- step_type
    $6 -- step_definition
FROM version;
"""


@rewrap_exceptions(common_db_exceptions("task", ["update"]))
@wrap_in_class(
    ResourceUpdatedResponse,
    one=True,
    transform=lambda d: {
        "id": d["task_id"],
        "updated_at": utcnow(),
        "jobs": [],
    },
)
@increase_counter("update_task")
@pg_query(return_index=0)
@beartype
async def update_task(
    *,
    developer_id: UUID,
    task_id: UUID,
    agent_id: UUID,
    data: UpdateTaskRequest,
) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]:
    """
    Builds the queries that replace a task and its workflows as a new version.

    Parameters:
        developer_id (UUID): The unique identifier of the developer.
        task_id (UUID): The unique identifier of the task to update.
        agent_id (UUID): The unique identifier of the agent.
        data (UpdateTaskRequest): The update data.

    Returns:
        list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]:
            The task insert followed by the workflow insert.
    """
    # Unlike patch, missing metadata / input_schema are stored as empty objects.
    update_task_params = [
        developer_id,  # $1
        data.canonical_name,  # $2
        task_id,  # $3
        agent_id,  # $4
        data.metadata or {},  # $5
        data.name,  # $6
        data.description,  # $7
        data.inherit_tools,  # $8
        data.input_schema or {},  # $9
    ]

    # One parameter row per step of every workflow in the task spec.
    workflows_spec = task_to_spec(data).model_dump(mode="json")
    workflow_params = [
        [
            developer_id,  # $1
            task_id,  # $2
            workflow.get("name"),  # $3
            step_idx,  # $4
            step["kind_"],  # $5
            step,  # $6
        ]
        for workflow in workflows_spec.get("workflows", [])
        for step_idx, step in enumerate(workflow.get("steps", []))
    ]

    return [
        (
            update_task_query,
            update_task_params,
            "fetchrow",
        ),
        (
            workflows_query,
            workflow_params,
            "fetchmany",
        ),
    ]
# Define the raw SQL query for creating tools.
# The NOT EXISTS guard skips the insert when the agent already has a tool
# with the same name.
tools_query = """INSERT INTO tools
(
    developer_id,
    agent_id,
    tool_id,
    type,
    name,
    spec,
    description
)
SELECT
    $1,
    $2,
    $3,
    $4,
    $5,
    $6,
    $7
WHERE NOT EXISTS (
    SELECT null FROM tools
    WHERE (agent_id, name) = ($2, $5)
)
RETURNING *
"""


@rewrap_exceptions(common_db_exceptions("tool", ["create"]))
@wrap_in_class(
    Tool,
    transform=lambda d: {
        "id": d.pop("tool_id"),
        d["type"]: d.pop("spec"),
        **d,
    },
)
@increase_counter("create_tools")
@pg_query
@beartype
async def create_tools(
    *,
    developer_id: UUID,
    agent_id: UUID,
    data: list[CreateToolRequest],
) -> tuple[str, list, str]:
    """
    Builds the batched insert for one or more tools of an agent.

    Parameters:
        developer_id (UUID): The unique identifier for the developer.
        agent_id (UUID): The unique identifier for the agent.
        data (list[CreateToolRequest]): The tool definitions to insert.

    Returns:
        tuple[str, list, str]: SQL query, one parameter row per tool, and the
            fetch method ("fetchmany").
    """

    # Every tool that has an attribute named after its type must carry a spec.
    assert all(
        getattr(tool, tool.type) is not None for tool in data if hasattr(tool, tool.type)
    ), "Tool spec must be passed"

    rows = []
    for tool in data:
        rows.append([
            developer_id,  # $1
            str(agent_id),  # $2
            str(uuid7()),  # $3 - fresh id per tool
            tool.type,  # $4
            tool.name,  # $5
            getattr(tool, tool.type) and getattr(tool, tool.type).model_dump(),  # $6
            getattr(tool, "description", None),  # $7
        ])

    return (tools_query, rows, "fetchmany")
# Define the raw SQL query for deleting a tool
tools_query = """
DELETE FROM
    tools
WHERE
    developer_id = $1 AND
    agent_id = $2 AND
    tool_id = $3
RETURNING *
"""


@rewrap_exceptions(common_db_exceptions("tool", ["delete"]))
@wrap_in_class(
    ResourceDeletedResponse,
    one=True,
    transform=lambda d: {"id": d["tool_id"], "deleted_at": utcnow(), "jobs": [], **d},
)
@pg_query
@beartype
async def delete_tool(
    *,
    developer_id: UUID,
    agent_id: UUID,
    tool_id: UUID,
) -> tuple[str, list]:
    """
    Builds the query that deletes a single tool of an agent.

    Parameters:
        developer_id (UUID): The unique identifier of the developer.
        agent_id (UUID): The unique identifier of the agent.
        tool_id (UUID): The unique identifier of the tool to delete.

    Returns:
        tuple[str, list]: SQL query and parameters (ids passed as strings).
    """
    params = [str(developer_id), str(agent_id), str(tool_id)]
    return (tools_query, params)
# Define the raw SQL query for getting a tool
tools_query = """
SELECT * FROM tools
WHERE
    developer_id = $1 AND
    agent_id = $2 AND
    tool_id = $3
LIMIT 1
"""


@rewrap_exceptions(common_db_exceptions("tool", ["get"]))
@wrap_in_class(
    Tool,
    transform=lambda d: {
        "id": d.pop("tool_id"),
        d["type"]: d.pop("spec"),
        **d,
    },
    one=True,
)
@pg_query
@beartype
async def get_tool(
    *,
    developer_id: UUID,
    agent_id: UUID,
    tool_id: UUID,
) -> tuple[str, list]:
    """
    Builds the query that fetches a single tool of an agent.

    Parameters:
        developer_id (UUID): The unique identifier of the developer.
        agent_id (UUID): The unique identifier of the agent.
        tool_id (UUID): The unique identifier of the tool to fetch.

    Returns:
        tuple[str, list]: SQL query and parameters (ids passed as strings).
    """
    params = [str(developer_id), str(agent_id), str(tool_id)]
    return (tools_query, params)
= 'x-api_call-setup' then metadata->'x-api_call-setup' END AS agents_md + FROM agents + WHERE agent_id = $1 AND developer_id = $4 LIMIT 1 +) AS agents_md, +( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md + FROM tasks + WHERE task_id = $2 AND developer_id = $4 LIMIT 1 +) AS tasks_md""" + +# Define the raw SQL query for getting tool args from metadata for a session +tool_args_for_session_query = """ +SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM ( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS agents_md + FROM agents + WHERE agent_id = $1 AND developer_id = $4 LIMIT 1 +) AS agents_md, +( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md + FROM sessions + WHERE session_id = $2 AND developer_id = $4 LIMIT 1 +) AS sessions_md""" + + +@rewrap_exceptions(common_db_exceptions("tool_metadata", ["get"])) +@wrap_in_class(dict, transform=lambda x: x["values"], one=True) +@pg_query +@beartype +async def get_tool_args_from_metadata( + *, + developer_id: UUID, + agent_id: UUID, + session_id: UUID | None = None, + task_id: UUID | None = None, + tool_type: Literal["integration", "api_call"] = "integration", + arg_type: Literal["args", "setup", "headers"] = "args", +) -> tuple[str, list]: + match 
# Define the raw SQL query for listing tools.
# The sort column/direction are selected at run time through $5 / $6.
tools_query = """
SELECT * FROM tools
WHERE
    developer_id = $1 AND
    agent_id = $2
ORDER BY
    CASE WHEN $5 = 'created_at' AND $6 = 'desc' THEN tools.created_at END DESC NULLS LAST,
    CASE WHEN $5 = 'created_at' AND $6 = 'asc' THEN tools.created_at END ASC NULLS LAST,
    CASE WHEN $5 = 'updated_at' AND $6 = 'desc' THEN tools.updated_at END DESC NULLS LAST,
    CASE WHEN $5 = 'updated_at' AND $6 = 'asc' THEN tools.updated_at END ASC NULLS LAST
LIMIT $3 OFFSET $4;
"""


@rewrap_exceptions(common_db_exceptions("tool", ["list"]))
@wrap_in_class(
    Tool,
    transform=lambda d: {
        d["type"]: {
            **d.pop("spec"),
            "name": d["name"],
            "description": d["description"],
        },
        "id": d.pop("tool_id"),
        **d,
    },
)
@pg_query
@beartype
async def list_tools(
    *,
    developer_id: UUID,
    agent_id: UUID,
    limit: int = 100,
    offset: int = 0,
    sort_by: Literal["created_at", "updated_at"] = "created_at",
    direction: Literal["asc", "desc"] = "desc",
) -> tuple[str, list]:
    """
    Builds the query that lists an agent's tools with pagination and sorting.

    Parameters:
        developer_id (UUID): The unique identifier of the developer.
        agent_id (UUID): The unique identifier of the agent.
        limit (int): Maximum number of records to return (default: 100)
        offset (int): Number of records to skip (default: 0)
        sort_by (str): Field to sort by ("created_at" or "updated_at")
        direction (str): Sort direction ("asc" or "desc")

    Returns:
        tuple[str, list]: SQL query and parameters.
    """
    params = [
        str(developer_id),  # $1
        str(agent_id),  # $2
        limit,  # $3
        offset,  # $4
        sort_by,  # $5
        direction,  # $6
    ]
    return (tools_query, params)
# Define the raw SQL query for patching a tool.
# Every column uses COALESCE, so a NULL parameter keeps the stored value.
tools_query = """
WITH updated_tools AS (
    UPDATE tools
    SET
        type = COALESCE($4, type),
        name = COALESCE($5, name),
        description = COALESCE($6, description),
        spec = COALESCE($7, spec)
    WHERE
        developer_id = $1 AND
        agent_id = $2 AND
        tool_id = $3
    RETURNING *
)
SELECT * FROM updated_tools;
"""


@rewrap_exceptions(common_db_exceptions("tool", ["patch"]))
@wrap_in_class(
    ResourceUpdatedResponse,
    one=True,
    transform=lambda d: {"id": d["tool_id"], "jobs": [], **d},
)
@increase_counter("patch_tool")
@pg_query
@beartype
async def patch_tool(
    *, developer_id: UUID, agent_id: UUID, tool_id: UUID, data: PatchToolRequest
) -> tuple[str, list]:
    """
    Builds the query that partially updates a tool of an agent.

    Parameters:
        developer_id (UUID): The unique identifier of the developer.
        agent_id (UUID): The unique identifier of the agent.
        tool_id (UUID): The unique identifier of the tool to be updated.
        data (PatchToolRequest): The request payload; fields left unset keep
            their stored value.

    Returns:
        tuple[str, list]: SQL query and parameters.

    Raises:
        AssertionError: If more than one tool spec is supplied, or the
            supplied `type` disagrees with the supplied spec.
    """

    developer_id = str(developer_id)
    agent_id = str(agent_id)
    tool_id = str(tool_id)

    # Extract the tool data from the payload
    patch_data = data.model_dump(exclude_none=True)

    # At most one of the tool type fields may be present
    tool_specs = [
        (candidate, patch_data[candidate])
        for candidate in ("function", "integration", "system", "api_call")
        if patch_data.get(candidate) is not None
    ]

    assert len(tool_specs) <= 1, "Invalid tool update"
    tool_type, tool_spec = tool_specs[0] if tool_specs else (None, None)

    if tool_type is not None:
        # An explicit `type` in the payload must agree with the spec's type.
        assert patch_data.get("type", tool_type) == tool_type, "Invalid tool update"

    # tool_spec stays None when no spec was sent.  Substituting {} here would
    # make COALESCE($7, spec) overwrite the stored spec with an empty object
    # on every patch that only touches name/description.
    return (
        tools_query,
        [
            developer_id,  # $1
            agent_id,  # $2
            tool_id,  # $3
            tool_type,  # $4
            data.name,  # $5
            data.description,  # $6
            tool_spec,  # $7
        ],
    )
# Define the raw SQL query for updating a tool (full replacement of the
# type / name / description / spec columns).
tools_query = """
UPDATE tools
SET
    type = $4,
    name = $5,
    description = $6,
    spec = $7
WHERE
    developer_id = $1 AND
    agent_id = $2 AND
    tool_id = $3
RETURNING *;
"""


@rewrap_exceptions(common_db_exceptions("tool", ["update"]))
@wrap_in_class(
    ResourceUpdatedResponse,
    one=True,
    transform=lambda d: {"id": d["tool_id"], "jobs": [], **d},
)
@increase_counter("update_tool")
@pg_query
@beartype
async def update_tool(
    *,
    developer_id: UUID,
    agent_id: UUID,
    tool_id: UUID,
    data: UpdateToolRequest,
    **kwargs,
) -> tuple[str, list]:
    """
    Builds the query that replaces a tool of an agent.

    Parameters:
        developer_id (UUID): The unique identifier of the developer.
        agent_id (UUID): The unique identifier of the agent.
        tool_id (UUID): The unique identifier of the tool to be updated.
        data (UpdateToolRequest): The new tool definition.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        tuple[str, list]: SQL query and parameters.

    Raises:
        AssertionError: If more than one tool spec is supplied, or the
            supplied `type` disagrees with the supplied spec.
    """
    developer_id = str(developer_id)
    agent_id = str(agent_id)
    tool_id = str(tool_id)

    # Extract the tool data from the payload
    update_data = data.model_dump(exclude_none=True)

    # At most one of the tool type fields may be present
    tool_specs = [
        (candidate, update_data[candidate])
        for candidate in ("function", "integration", "system", "api_call")
        if update_data.get(candidate) is not None
    ]

    assert len(tool_specs) <= 1, "Invalid tool update"
    tool_type, tool_spec = tool_specs[0] if tool_specs else (None, None)

    if tool_type is not None:
        # An explicit `type` in the payload must agree with the spec's type.
        assert update_data.get("type", tool_type) == tool_type, "Invalid tool update"

    # The dumped payload is only used for validation above; it is not mutated
    # any further, which also avoids deleting a key that may not exist when
    # no spec was supplied.
    return (
        tools_query,
        [
            developer_id,  # $1
            agent_id,  # $2
            tool_id,  # $3
            tool_type,  # $4
            data.name,  # $5
            data.description,  # $6
            tool_spec,  # $7
        ],
    )
# Define the raw SQL query for creating or updating a user (upsert on the
# (developer_id, user_id) key).
user_query = """
INSERT INTO users (
    developer_id,
    user_id,
    name,
    about,
    metadata
)
VALUES (
    $1, -- developer_id
    $2, -- user_id
    $3, -- name
    $4, -- about
    $5::jsonb -- metadata
)
ON CONFLICT (developer_id, user_id) DO UPDATE SET
    name = EXCLUDED.name,
    about = EXCLUDED.about,
    metadata = EXCLUDED.metadata
RETURNING *;
"""


@rewrap_exceptions(common_db_exceptions("user", ["create_or_update"]))
@wrap_in_class(
    User,
    one=True,
    transform=lambda d: {
        **d,
        "id": d["user_id"],
    },
)
@increase_counter("create_or_update_user")
@pg_query
@beartype
async def create_or_update_user(
    *, developer_id: UUID, user_id: UUID, data: CreateOrUpdateUserRequest
) -> tuple[str, list]:
    """
    Builds the upsert query for a user.

    Args:
        developer_id (UUID): The UUID of the developer.
        user_id (UUID): The UUID of the user.
        data (CreateOrUpdateUserRequest): The user data to insert or update.

    Returns:
        tuple[str, list]: SQL query and parameters.
    """
    # Missing metadata is stored as an empty object rather than NULL.
    metadata = data.metadata or {}

    return (
        user_query,
        [
            developer_id,  # $1
            user_id,  # $2
            data.name,  # $3
            data.about,  # $4
            metadata,  # $5
        ],
    )
# Define the raw SQL query outside the function
user_query = """
INSERT INTO users (
    developer_id,
    user_id,
    name,
    about,
    metadata
)
VALUES (
    $1, -- developer_id
    $2, -- user_id
    $3, -- name
    $4, -- about
    $5::jsonb -- metadata
)
RETURNING *;
"""


@rewrap_exceptions(common_db_exceptions("user", ["create"]))
@wrap_in_class(
    User,
    one=True,
    transform=lambda d: {
        **d,
        "id": d["user_id"],
    },
)
@increase_counter("create_user")
@pg_query
@beartype
async def create_user(
    *,
    developer_id: UUID,
    user_id: UUID | None = None,
    data: CreateUserRequest,
) -> tuple[str, list]:
    """
    Builds the insert query for a new user.

    Args:
        developer_id (UUID): The UUID of the developer creating the user.
        user_id (UUID, optional): The UUID for the new user. If None, a UUIDv7
            is generated.
        data (CreateUserRequest): The user data to insert.

    Returns:
        tuple[str, list]: A tuple containing the SQL query and its parameters.
    """
    if not user_id:
        user_id = uuid7()

    return (
        user_query,
        [
            developer_id,  # $1
            user_id,  # $2
            data.name,  # $3
            data.about,  # $4
            data.metadata or {},  # $5 - empty object when no metadata given
        ],
    )
# Define the raw SQL query outside the function.
# A single statement removes the user's file/doc ownership rows, the files and
# docs referenced by those ownership rows, and finally the user itself.
delete_query = """
WITH deleted_file_owners AS (
    DELETE FROM file_owners
    WHERE developer_id = $1
    AND owner_type = 'user'
    AND owner_id = $2
),
deleted_doc_owners AS (
    DELETE FROM doc_owners
    WHERE developer_id = $1
    AND owner_type = 'user'
    AND owner_id = $2
),
deleted_files AS (
    DELETE FROM files
    WHERE developer_id = $1
    AND file_id IN (
        SELECT file_id FROM file_owners
        WHERE developer_id = $1
        AND owner_type = 'user'
        AND owner_id = $2
    )
),
deleted_docs AS (
    DELETE FROM docs
    WHERE developer_id = $1
    AND doc_id IN (
        SELECT doc_id FROM doc_owners
        WHERE developer_id = $1
        AND owner_type = 'user'
        AND owner_id = $2
    )
)
DELETE FROM users
WHERE developer_id = $1 AND user_id = $2
RETURNING user_id, developer_id;
"""


@rewrap_exceptions(common_db_exceptions("user", ["delete"]))
@wrap_in_class(
    ResourceDeletedResponse,
    one=True,
    transform=lambda d: {
        **d,
        "id": d["user_id"],
        "deleted_at": utcnow(),
        "jobs": [],
    },
)
@pg_query
@beartype
async def delete_user(*, developer_id: UUID, user_id: UUID) -> tuple[str, list]:
    """
    Builds the query that deletes a user together with its owned files/docs.

    Args:
        developer_id (UUID): The developer's UUID
        user_id (UUID): The user's UUID

    Returns:
        tuple[str, list]: SQL query and parameters
    """
    params = [developer_id, user_id]  # $1, $2
    return (delete_query, params)
developer_id: UUID, user_id: UUID) -> tuple[str, list]: + """ + Constructs optimized SQL query to delete a user and related data. + Uses primary key for efficient deletion. + + Args: + developer_id (UUID): The developer's UUID + user_id (UUID): The user's UUID + + Returns: + tuple[str, list]: SQL query and parameters + """ + + return ( + delete_query, + [developer_id, user_id], + ) diff --git a/agents-api/agents_api/queries/users/get_user.py b/agents-api/agents_api/queries/users/get_user.py new file mode 100644 index 000000000..1570f6476 --- /dev/null +++ b/agents-api/agents_api/queries/users/get_user.py @@ -0,0 +1,49 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import User +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query outside the function +user_query = """ +SELECT + user_id as id, -- user_id + developer_id, -- developer_id + name, -- name + about, -- about + metadata, -- metadata + created_at, -- created_at + updated_at -- updated_at +FROM users +WHERE developer_id = $1 +AND user_id = $2; +""" + + +@rewrap_exceptions(common_db_exceptions("user", ["get"])) +@wrap_in_class(User, one=True) +@pg_query +@beartype +async def get_user( + *, developer_id: UUID, user_id: UUID +) -> tuple[str, list, Literal["fetchrow", "fetchmany", "fetch"]]: + """ + Constructs an optimized SQL query to retrieve a user's details. + Uses the primary key index (developer_id, user_id) for efficient lookup. + + Args: + developer_id (UUID): The UUID of the developer. + user_id (UUID): The UUID of the user to retrieve. + + Returns: + tuple[str, list, str]: SQL query, parameters, and fetch mode. 
+ """ + + return ( + user_query, + [developer_id, user_id], + "fetchrow", + ) diff --git a/agents-api/agents_api/queries/users/list_users.py b/agents-api/agents_api/queries/users/list_users.py new file mode 100644 index 000000000..6edba899d --- /dev/null +++ b/agents-api/agents_api/queries/users/list_users.py @@ -0,0 +1,83 @@ +from typing import Literal +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import User +from ...common.utils.db_exceptions import common_db_exceptions +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query outside the function +user_query = """ +WITH filtered_users AS ( + SELECT + user_id as id, -- user_id + developer_id, -- developer_id + name, -- name + about, -- about + metadata, -- metadata + created_at, -- created_at + updated_at -- updated_at + FROM users + WHERE developer_id = $1 + AND ($4::jsonb IS NULL OR metadata @> $4) +) +SELECT * +FROM filtered_users +ORDER BY + CASE WHEN $5 = 'created_at' AND $6 = 'asc' THEN created_at END ASC NULLS LAST, + CASE WHEN $5 = 'created_at' AND $6 = 'desc' THEN created_at END DESC NULLS LAST, + CASE WHEN $5 = 'updated_at' AND $6 = 'asc' THEN updated_at END ASC NULLS LAST, + CASE WHEN $5 = 'updated_at' AND $6 = 'desc' THEN updated_at END DESC NULLS LAST +LIMIT $2 +OFFSET $3; +""" + + +@rewrap_exceptions(common_db_exceptions("user", ["list"])) +@wrap_in_class(User) +@pg_query +@beartype +async def list_users( + *, + developer_id: UUID, + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", + metadata_filter: dict | None = None, +) -> tuple[str, list]: + """ + Constructs an optimized SQL query for listing users with pagination and filtering. + Uses indexes on developer_id and metadata for efficient querying. 
+ + Args: + developer_id (UUID): The developer's UUID + limit (int): Maximum number of records to return + offset (int): Number of records to skip + sort_by (str): Field to sort by + direction (str): Sort direction + metadata_filter (dict, optional): Metadata-based filters + + Returns: + tuple[str, list]: SQL query and parameters + """ + if limit < 1 or limit > 1000: + raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000") + if offset < 0: + raise HTTPException(status_code=400, detail="Offset must be non-negative") + + params = [ + developer_id, # $1 + limit, # $2 + offset, # $3 + metadata_filter, # Will be NULL if not provided + sort_by, # $4 + direction, # $5 + ] + + return ( + user_query, + params, + ) diff --git a/agents-api/agents_api/queries/users/patch_user.py b/agents-api/agents_api/queries/users/patch_user.py new file mode 100644 index 000000000..b8dd6ad27 --- /dev/null +++ b/agents-api/agents_api/queries/users/patch_user.py @@ -0,0 +1,70 @@ +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import PatchUserRequest, ResourceUpdatedResponse +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query outside the function +user_query = """ +UPDATE users +SET + name = CASE + WHEN $3::text IS NOT NULL THEN $3 -- name + ELSE name + END, + about = CASE + WHEN $4::text IS NOT NULL THEN $4 -- about + ELSE about + END, + metadata = CASE + WHEN $5::jsonb IS NOT NULL THEN metadata || $5 -- metadata + ELSE metadata + END +WHERE developer_id = $1 +AND user_id = $2 +RETURNING + user_id as id, -- user_id + developer_id, -- developer_id + name, -- name + about, -- about + metadata, -- metadata + created_at, -- created_at + updated_at; -- updated_at +""" + + +@rewrap_exceptions(common_db_exceptions("user", ["patch"])) +@wrap_in_class(ResourceUpdatedResponse, 
one=True) +@increase_counter("patch_user") +@pg_query +@beartype +async def patch_user( + *, developer_id: UUID, user_id: UUID, data: PatchUserRequest +) -> tuple[str, list]: + """ + Constructs an optimized SQL query for partial user updates. + Uses primary key for efficient update and jsonb_merge for metadata. + + Args: + developer_id (UUID): The developer's UUID + user_id (UUID): The user's UUID + data (PatchUserRequest): Partial update data + + Returns: + tuple[str, list]: SQL query and parameters + """ + params = [ + developer_id, # $1 + user_id, # $2 + data.name, # $3. Will be NULL if not provided + data.about, # $4. Will be NULL if not provided + data.metadata, # $5. Will be NULL if not provided + ] + + return ( + user_query, + params, + ) diff --git a/agents-api/agents_api/queries/users/update_user.py b/agents-api/agents_api/queries/users/update_user.py new file mode 100644 index 000000000..89822a202 --- /dev/null +++ b/agents-api/agents_api/queries/users/update_user.py @@ -0,0 +1,58 @@ +from uuid import UUID + +from beartype import beartype + +from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateUserRequest +from ...common.utils.db_exceptions import common_db_exceptions +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query outside the function +user_query = """ +UPDATE users +SET + name = $3, -- name + about = $4, -- about + metadata = $5 -- metadata +WHERE developer_id = $1 -- developer_id +AND user_id = $2 -- user_id +RETURNING * +""" + + +@rewrap_exceptions(common_db_exceptions("user", ["update"])) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: {**d, "id": d["user_id"]}, +) +@increase_counter("update_user") +@pg_query +@beartype +async def update_user( + *, developer_id: UUID, user_id: UUID, data: UpdateUserRequest +) -> tuple[str, list]: + """ + Constructs an optimized SQL query to update a user's details. 
+ Uses primary key for efficient update. + + Args: + developer_id (UUID): The developer's UUID + user_id (UUID): The user's UUID + data (UpdateUserRequest): Updated user data + + Returns: + tuple[str, list]: SQL query and parameters + """ + params = [ + developer_id, + user_id, + data.name, + data.about, + data.metadata or {}, + ] + + return ( + user_query, + params, + ) diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py new file mode 100644 index 000000000..aa4b9ae20 --- /dev/null +++ b/agents-api/agents_api/queries/utils.py @@ -0,0 +1,300 @@ +import concurrent.futures +import inspect +import socket +import time +from collections.abc import Awaitable, Callable +from functools import wraps +from typing import ( + Any, + Literal, + NotRequired, + ParamSpec, + TypeVar, + cast, +) + +import asyncpg +import namer +from asyncpg import Record +from beartype import beartype +from fastapi import HTTPException +from pydantic import BaseModel +from typing_extensions import TypedDict + +from ..app import app +from ..env import query_timeout + +P = ParamSpec("P") +T = TypeVar("T") +ModelT = TypeVar("ModelT", bound=BaseModel) + + +def generate_canonical_name() -> str: + """Generate canonical name""" + + categories: list[str] = ["astronomy", "physics", "scientists", "math"] + return namer.generate(separator="_", suffix_length=3, category=categories) + + +class AsyncPGFetchArgs(TypedDict): + query: str + args: list[Any] + timeout: NotRequired[float | None] + + +type SQLQuery = str +type FetchMethod = Literal["fetch", "fetchmany", "fetchrow"] +type PGQueryArgs = tuple[SQLQuery, list[Any]] | tuple[SQLQuery, list[Any], FetchMethod] +type PreparedPGQueryArgs = tuple[FetchMethod, AsyncPGFetchArgs] +type BatchedPreparedPGQueryArgs = list[PreparedPGQueryArgs] + + +@beartype +def prepare_pg_query_args( + query_args: PGQueryArgs | list[PGQueryArgs], +) -> BatchedPreparedPGQueryArgs: + batch = [] + query_args = [query_args] if 
isinstance(query_args, tuple) else query_args + + for query_arg in query_args: + match query_arg: + case (query, variables) | (query, variables, "fetch"): + batch.append(( + "fetch", + AsyncPGFetchArgs(query=query, args=variables, timeout=query_timeout), + )) + case (query, variables, "fetchmany"): + batch.append(( + "fetchmany", + AsyncPGFetchArgs(query=query, args=[variables], timeout=query_timeout), + )) + case (query, variables, "fetchrow"): + batch.append(( + "fetchrow", + AsyncPGFetchArgs(query=query, args=variables, timeout=query_timeout), + )) + case _: + msg = "Invalid query arguments" + raise ValueError(msg) + + return batch + + +@beartype +def pg_query( + func: Callable[P, PGQueryArgs | list[PGQueryArgs]] | None = None, + debug: bool | None = None, + only_on_error: bool = False, + timeit: bool = False, + return_index: int = -1, +) -> Callable[..., Callable[P, list[Record]]] | Callable[P, list[Record]]: + def pg_query_dec( + func: Callable[P, PGQueryArgs | list[PGQueryArgs]], + ) -> Callable[..., Callable[P, list[Record]]]: + """ + Decorator that wraps a function that takes arbitrary arguments, and + returns a (query string, variables) tuple. + + The wrapped function should additionally take a client keyword argument + and then run the query using the client, returning a Record. 
+ """ + + from pprint import pprint + + @wraps(func) + async def wrapper( + *args: P.args, + connection_pool: asyncpg.Pool | None = None, + **kwargs: P.kwargs, + ) -> list[Record]: + query_args = await func(*args, **kwargs) + batch = prepare_pg_query_args(query_args) + + not only_on_error and debug and pprint(batch) + + # Run the query + pool = ( + connection_pool + if connection_pool is not None + else cast(asyncpg.Pool, app.state.postgres_pool) + ) + + try: + async with pool.acquire() as conn, conn.transaction(): + start = timeit and time.perf_counter() + all_results = [] + + for method_name, payload in batch: + method = getattr(conn, method_name) + + query = payload["query"] + args = payload["args"] + timeout = payload.get("timeout") + + results: list[Record] = await method(query, *args, timeout=timeout) + if method_name == "fetchrow": + results = ( + [results] + if results is not None + and results.get("bool", False) is not None + and results.get("exists", True) is not False + else [] + ) + + if method_name == "fetchrow" and len(results) == 0: + msg = "No data found" + raise asyncpg.NoDataFoundError(msg) + + all_results.append(results) + + end = timeit and time.perf_counter() + + timeit and print(f"PostgreSQL query time: {end - start:.2f} seconds") + + except Exception as e: + if only_on_error and debug: + pprint(batch) + + debug and print(repr(e)) + connection_error = isinstance( + e, + (socket.gaierror), + ) + + if connection_error: + exc = HTTPException( + status_code=429, detail="Resource busy. Please try again later." + ) + raise exc from e + + raise + + # Return results from specified index + results_to_return = all_results[return_index] if all_results else [] + not only_on_error and debug and pprint(results_to_return) + + return results_to_return + + # Set the wrapped function as an attribute of the wrapper, + # forwards the __wrapped__ attribute if it exists. 
+ setattr(wrapper, "__wrapped__", getattr(func, "__wrapped__", func)) + + return wrapper + + if func is not None and callable(func): + return pg_query_dec(func) + + return pg_query_dec + + +def wrap_in_class( + cls: type[ModelT] | Callable[..., ModelT], + one: bool = False, + transform: Callable[[dict], dict] | None = None, +) -> Callable[..., Callable[..., ModelT | list[ModelT]]]: + def _return_data(rec: list[Record]): + data = [dict(r.items()) for r in rec] + + nonlocal transform + transform = transform or (lambda x: x) + + if one: + assert len(data) == 1, f"Expected one result, got {len(data)}" + obj: ModelT = cls(**transform(data[0])) + return obj + + objs: list[ModelT] = [cls(**item) for item in map(transform, data)] + return objs + + def decorator( + func: Callable[P, list[Record] | Awaitable[list[Record]]], + ) -> Callable[P, ModelT | list[ModelT]]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> ModelT | list[ModelT]: + return _return_data(func(*args, **kwargs)) + + @wraps(func) + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> ModelT | list[ModelT]: + return _return_data(await func(*args, **kwargs)) + + # Set the wrapped function as an attribute of the wrapper, + # forwards the __wrapped__ attribute if it exists. 
+ setattr(wrapper, "__wrapped__", getattr(func, "__wrapped__", func)) + setattr(async_wrapper, "__wrapped__", getattr(func, "__wrapped__", func)) + + return async_wrapper if inspect.iscoroutinefunction(func) else wrapper + + return decorator + + +def rewrap_exceptions( + mapping: dict[ + type[BaseException] | Callable[[BaseException], bool], + type[BaseException] | Callable[[BaseException], BaseException], + ], + /, +) -> Callable[..., Callable[P, T | Awaitable[T]]]: + def _check_error(error): + nonlocal mapping + + for check, transform in mapping.items(): + should_catch = isinstance(error, check) if isinstance(check, type) else check(error) + + if should_catch: + new_error = ( + transform(str(error)) if isinstance(transform, type) else transform(error) + ) + + setattr(new_error, "__cause__", error) + + raise new_error from error + + def decorator( + func: Callable[P, T | Awaitable[T]], + ) -> Callable[..., Callable[P, T | Awaitable[T]]]: + @wraps(func) + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + try: + result: T = await func(*args, **kwargs) + except BaseException as error: + _check_error(error) + raise error + + return result + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + try: + result: T = func(*args, **kwargs) + except BaseException as error: + _check_error(error) + raise error + + return result + + # Set the wrapped function as an attribute of the wrapper, + # forwards the __wrapped__ attribute if it exists. 
+ setattr(wrapper, "__wrapped__", getattr(func, "__wrapped__", func)) + setattr(async_wrapper, "__wrapped__", getattr(func, "__wrapped__", func)) + + return async_wrapper if inspect.iscoroutinefunction(func) else wrapper + + return decorator + + +def run_concurrently( + fns: list[Callable[..., Any]], + *, + args_list: list[tuple] = [], + kwargs_list: list[dict] = [], +) -> list[Any]: + args_list = args_list or [()] * len(fns) + kwargs_list = kwargs_list or [{}] * len(fns) + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(fn, *args, **kwargs) + for fn, args, kwargs in zip(fns, args_list, kwargs_list) + ] + + return [future.result() for future in concurrent.futures.as_completed(futures)] diff --git a/agents-api/agents_api/rec_sum/data.py b/agents-api/agents_api/rec_sum/data.py index 23474c995..76bc9f966 100644 --- a/agents-api/agents_api/rec_sum/data.py +++ b/agents-api/agents_api/rec_sum/data.py @@ -5,21 +5,21 @@ module_directory: Path = Path(__file__).parent -with open(f"{module_directory}/entities_example_chat.json", "r") as _f: +with open(f"{module_directory}/entities_example_chat.json") as _f: entities_example_chat: Any = json.load(_f) -with open(f"{module_directory}/trim_example_chat.json", "r") as _f: +with open(f"{module_directory}/trim_example_chat.json") as _f: trim_example_chat: Any = json.load(_f) -with open(f"{module_directory}/trim_example_result.json", "r") as _f: +with open(f"{module_directory}/trim_example_result.json") as _f: trim_example_result: Any = json.load(_f) -with open(f"{module_directory}/summarize_example_chat.json", "r") as _f: +with open(f"{module_directory}/summarize_example_chat.json") as _f: summarize_example_chat: Any = json.load(_f) -with open(f"{module_directory}/summarize_example_result.json", "r") as _f: +with open(f"{module_directory}/summarize_example_result.json") as _f: summarize_example_result: Any = json.load(_f) diff --git a/agents-api/agents_api/rec_sum/entities.py 
b/agents-api/agents_api/rec_sum/entities.py index 01b29951b..c316173a1 100644 --- a/agents-api/agents_api/rec_sum/entities.py +++ b/agents-api/agents_api/rec_sum/entities.py @@ -9,7 +9,7 @@ from .utils import chatml, get_names_from_session ############## -## Entities ## +# Entities ## ############## entities_example_plan: str = """\ @@ -77,10 +77,7 @@ async def get_entities( assert "" in result["content"] result["content"] = ( - result["content"] - .split("")[-1] - .replace("", "") - .strip() + result["content"].split("")[-1].replace("", "").strip() ) result["role"] = "system" result["name"] = "entities" diff --git a/agents-api/agents_api/rec_sum/summarize.py b/agents-api/agents_api/rec_sum/summarize.py index 46a6662a3..700733a22 100644 --- a/agents-api/agents_api/rec_sum/summarize.py +++ b/agents-api/agents_api/rec_sum/summarize.py @@ -1,5 +1,4 @@ import json -from typing import List from tenacity import retry, stop_after_attempt @@ -8,7 +7,7 @@ from .utils import add_indices, chatml, get_names_from_session ########## -## summarize ## +# summarize ## ########## summarize_example_plan: str = """\ @@ -35,9 +34,7 @@ - VERY IMPORTANT: Add the indices of messages that are being summarized so that those messages can then be removed from the session otherwise, there'll be no way to identify which messages to remove. See example for more details.""" -def make_summarize_prompt( - session, user="a user", assistant="gpt-4-turbo", **_ -) -> List[str]: +def make_summarize_prompt(session, user="a user", assistant="gpt-4-turbo", **_) -> list[str]: return [ f"You are given a session history of a chat between {user or 'a user'} and {assistant or 'gpt-4-turbo'}. The session is formatted in the ChatML JSON format (from OpenAI).\n\n{summarize_instructions}\n\n\n{json.dumps(add_indices(summarize_example_chat), indent=2)}\n\n\n\n{summarize_example_plan}\n\n\n\n{json.dumps(summarize_example_result, indent=2)}\n", f"Begin! 
Write the summarized messages as a json list just like the example above. First write your plan inside and then your answer between . Don't forget to add the indices of the messages being summarized alongside each summary.\n\n\n{json.dumps(add_indices(session), indent=2)}\n\n", @@ -57,10 +54,7 @@ async def summarize_messages( offset = 0 # Remove the system prompt if present - if ( - chat_session[0]["role"] == "system" - and chat_session[0].get("name") != "entities" - ): + if chat_session[0]["role"] == "system" and chat_session[0].get("name") != "entities": chat_session = chat_session[1:] # The indices are not matched up correctly @@ -85,12 +79,10 @@ async def summarize_messages( .strip() ) - assert all((msg.get("summarizes") is not None for msg in summarized_messages)) + assert all(msg.get("summarizes") is not None for msg in summarized_messages) # Correct offset - summarized_messages = [ + return [ {**msg, "summarizes": [i + offset for i in msg["summarizes"]]} for msg in summarized_messages ] - - return summarized_messages diff --git a/agents-api/agents_api/rec_sum/trim.py b/agents-api/agents_api/rec_sum/trim.py index ee4025ea0..5ffadecfc 100644 --- a/agents-api/agents_api/rec_sum/trim.py +++ b/agents-api/agents_api/rec_sum/trim.py @@ -1,5 +1,4 @@ import json -from typing import List from tenacity import retry, stop_after_attempt @@ -8,7 +7,7 @@ from .utils import add_indices, chatml, get_names_from_session ########## -## Trim ## +# Trim ## ########## trim_example_plan: str = """\ @@ -33,7 +32,7 @@ # It is important to make keep the tone, setting and flow of the conversation consistent while trimming the messages. -def make_trim_prompt(session, user="a user", assistant="gpt-4-turbo", **_) -> List[str]: +def make_trim_prompt(session, user="a user", assistant="gpt-4-turbo", **_) -> list[str]: return [ f"You are given a session history of a chat between {user or 'a user'} and {assistant or 'gpt-4-turbo'}. 
The session is formatted in the ChatML JSON format (from OpenAI).\n\n{trim_instructions}\n\n\n{json.dumps(add_indices(trim_example_chat), indent=2)}\n\n\n\n{trim_example_plan}\n\n\n\n{json.dumps(trim_example_result, indent=2)}\n", f"Begin! Write the trimmed messages as a json list. First write your plan inside and then your answer between .\n\n\n{json.dumps(add_indices(session), indent=2)}\n\n", @@ -66,9 +65,7 @@ async def trim_messages( result["content"].split("")[-1].replace("", "").strip() ) - assert all((msg.get("index") is not None for msg in trimmed_messages)) + assert all(msg.get("index") is not None for msg in trimmed_messages) # Correct offset - trimmed_messages = [{**msg, "index": msg["index"]} for msg in trimmed_messages] - - return trimmed_messages + return [{**msg, "index": msg["index"]} for msg in trimmed_messages] diff --git a/agents-api/agents_api/rec_sum/utils.py b/agents-api/agents_api/rec_sum/utils.py index c674a4d44..4816b4308 100644 --- a/agents-api/agents_api/rec_sum/utils.py +++ b/agents-api/agents_api/rec_sum/utils.py @@ -1,19 +1,19 @@ ########### -## Utils ## +# Utils ## ########### -from typing import Any, Dict, List, TypeVar +from typing import Any, TypeVar _T2 = TypeVar("_T2") class chatml: @staticmethod - def make(content, role="system", name: _T2 = None, **_) -> Dict[str, _T2]: + def make(content, role="system", name: _T2 = None, **_) -> dict[str, _T2]: return { key: value - for key, value in dict(role=role, name=name, content=content).items() + for key, value in {"role": role, "name": name, "content": content}.items() if value is not None } @@ -46,14 +46,12 @@ def entities(content) -> Any: return chatml.system(content, name="entity") -def add_indices(list_of_dicts, idx_name="index") -> List[dict]: +def add_indices(list_of_dicts, idx_name="index") -> list[dict]: return [{idx_name: i, **msg} for i, msg in enumerate(list_of_dicts)] -def get_names_from_session(session) -> Dict[str, Any]: +def get_names_from_session(session) -> dict[str, 
Any]: return { - role: next( - (msg.get("name", None) for msg in session if msg["role"] == role), None - ) + role: next((msg.get("name", None) for msg in session if msg["role"] == role), None) for role in {"user", "assistant", "system"} } diff --git a/agents-api/agents_api/routers/agents/__init__.py b/agents-api/agents_api/routers/agents/__init__.py index 2eadecb3d..95354363c 100644 --- a/agents-api/agents_api/routers/agents/__init__.py +++ b/agents-api/agents_api/routers/agents/__init__.py @@ -1,15 +1,19 @@ # ruff: noqa: F401 from .create_agent import create_agent -from .create_agent_tool import create_agent_tool + +# from .create_agent_tool import create_agent_tool from .create_or_update_agent import create_or_update_agent from .delete_agent import delete_agent -from .delete_agent_tool import delete_agent_tool + +# from .delete_agent_tool import delete_agent_tool from .get_agent_details import get_agent_details -from .list_agent_tools import list_agent_tools + +# from .list_agent_tools import list_agent_tools from .list_agents import list_agents from .patch_agent import patch_agent -from .patch_agent_tool import patch_agent_tool + +# from .patch_agent_tool import patch_agent_tool from .router import router from .update_agent import update_agent -from .update_agent_tool import update_agent_tool +# from .update_agent_tool import update_agent_tool diff --git a/agents-api/agents_api/routers/agents/create_agent.py b/agents-api/agents_api/routers/agents/create_agent.py index 2e1c4df0a..f630d5251 100644 --- a/agents-api/agents_api/routers/agents/create_agent.py +++ b/agents-api/agents_api/routers/agents/create_agent.py @@ -9,7 +9,7 @@ ResourceCreatedResponse, ) from ...dependencies.developer_id import get_developer_id -from ...models.agent.create_agent import create_agent as create_agent_query +from ...queries.agents.create_agent import create_agent as create_agent_query from .router import router @@ -19,7 +19,7 @@ async def create_agent( data: CreateAgentRequest, ) -> 
ResourceCreatedResponse: # TODO: Validate model name - agent = create_agent_query( + agent = await create_agent_query( developer_id=x_developer_id, data=data, ) diff --git a/agents-api/agents_api/routers/agents/create_agent_tool.py b/agents-api/agents_api/routers/agents/create_agent_tool.py index 21b8e175a..74e98b3f9 100644 --- a/agents-api/agents_api/routers/agents/create_agent_tool.py +++ b/agents-api/agents_api/routers/agents/create_agent_tool.py @@ -4,13 +4,12 @@ from fastapi import Depends from starlette.status import HTTP_201_CREATED -import agents_api.models as models - from ...autogen.openapi_model import ( CreateToolRequest, ResourceCreatedResponse, ) from ...dependencies.developer_id import get_developer_id +from ...queries.tools.create_tools import create_tools as create_tools_query from .router import router @@ -20,7 +19,7 @@ async def create_agent_tool( x_developer_id: Annotated[UUID, Depends(get_developer_id)], data: CreateToolRequest, ) -> ResourceCreatedResponse: - tool = models.tools.create_tools( + tool = await create_tools_query( developer_id=x_developer_id, agent_id=agent_id, data=[data], diff --git a/agents-api/agents_api/routers/agents/create_or_update_agent.py b/agents-api/agents_api/routers/agents/create_or_update_agent.py index 2dcbcd599..fd2fc124c 100644 --- a/agents-api/agents_api/routers/agents/create_or_update_agent.py +++ b/agents-api/agents_api/routers/agents/create_or_update_agent.py @@ -4,13 +4,14 @@ from fastapi import Depends from starlette.status import HTTP_201_CREATED -import agents_api.models as models - from ...autogen.openapi_model import ( CreateOrUpdateAgentRequest, ResourceCreatedResponse, ) from ...dependencies.developer_id import get_developer_id +from ...queries.agents.create_or_update_agent import ( + create_or_update_agent as create_or_update_agent_query, +) from .router import router @@ -21,7 +22,7 @@ async def create_or_update_agent( x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> 
ResourceCreatedResponse: # TODO: Validate model name - agent = models.agent.create_or_update_agent( + agent = await create_or_update_agent_query( developer_id=x_developer_id, agent_id=agent_id, data=data, diff --git a/agents-api/agents_api/routers/agents/delete_agent.py b/agents-api/agents_api/routers/agents/delete_agent.py index 03fcd56a0..3acb56aa2 100644 --- a/agents-api/agents_api/routers/agents/delete_agent.py +++ b/agents-api/agents_api/routers/agents/delete_agent.py @@ -6,7 +6,7 @@ from ...autogen.openapi_model import ResourceDeletedResponse from ...dependencies.developer_id import get_developer_id -from ...models.agent.delete_agent import delete_agent as delete_agent_query +from ...queries.agents.delete_agent import delete_agent as delete_agent_query from .router import router @@ -14,4 +14,4 @@ async def delete_agent( agent_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)] ) -> ResourceDeletedResponse: - return delete_agent_query(developer_id=x_developer_id, agent_id=agent_id) + return await delete_agent_query(developer_id=x_developer_id, agent_id=agent_id) diff --git a/agents-api/agents_api/routers/agents/delete_agent_tool.py b/agents-api/agents_api/routers/agents/delete_agent_tool.py index 772116d64..6f82e0768 100644 --- a/agents-api/agents_api/routers/agents/delete_agent_tool.py +++ b/agents-api/agents_api/routers/agents/delete_agent_tool.py @@ -5,7 +5,7 @@ from ...autogen.openapi_model import ResourceDeletedResponse from ...dependencies.developer_id import get_developer_id -from ...models.tools.delete_tool import delete_tool +from ...queries.tools.delete_tool import delete_tool as delete_tool_query from .router import router @@ -15,7 +15,7 @@ async def delete_agent_tool( tool_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> ResourceDeletedResponse: - return delete_tool( + return await delete_tool_query( developer_id=x_developer_id, agent_id=agent_id, tool_id=tool_id, diff --git 
a/agents-api/agents_api/routers/agents/get_agent_details.py b/agents-api/agents_api/routers/agents/get_agent_details.py index 3d684368e..30f7d3a34 100644 --- a/agents-api/agents_api/routers/agents/get_agent_details.py +++ b/agents-api/agents_api/routers/agents/get_agent_details.py @@ -5,7 +5,7 @@ from ...autogen.openapi_model import Agent from ...dependencies.developer_id import get_developer_id -from ...models.agent.get_agent import get_agent as get_agent_query +from ...queries.agents.get_agent import get_agent as get_agent_query from .router import router @@ -14,4 +14,4 @@ async def get_agent_details( agent_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> Agent: - return get_agent_query(developer_id=x_developer_id, agent_id=agent_id) + return await get_agent_query(developer_id=x_developer_id, agent_id=agent_id) diff --git a/agents-api/agents_api/routers/agents/list_agent_tools.py b/agents-api/agents_api/routers/agents/list_agent_tools.py index 59d1a6ade..7712cbf26 100644 --- a/agents-api/agents_api/routers/agents/list_agent_tools.py +++ b/agents-api/agents_api/routers/agents/list_agent_tools.py @@ -5,7 +5,7 @@ from ...autogen.openapi_model import ListResponse, Tool from ...dependencies.developer_id import get_developer_id -from ...models.tools.list_tools import list_tools as list_tools_query +from ...queries.tools.list_tools import list_tools as list_tools_query from .router import router @@ -20,7 +20,7 @@ async def list_agent_tools( ) -> ListResponse[Tool]: # FIXME: list agent tools is returning an empty list # SCRUM-22 - tools = list_tools_query( + tools = await list_tools_query( agent_id=agent_id, developer_id=x_developer_id, limit=limit, diff --git a/agents-api/agents_api/routers/agents/list_agents.py b/agents-api/agents_api/routers/agents/list_agents.py index b96bec089..f3b74f7a4 100644 --- a/agents-api/agents_api/routers/agents/list_agents.py +++ b/agents-api/agents_api/routers/agents/list_agents.py @@ -6,7 +6,7 @@ from 
...autogen.openapi_model import Agent, ListResponse from ...dependencies.developer_id import get_developer_id from ...dependencies.query_filter import MetadataFilter, create_filter_extractor -from ...models.agent.list_agents import list_agents as list_agents_query +from ...queries.agents.list_agents import list_agents as list_agents_query from .router import router @@ -24,7 +24,7 @@ async def list_agents( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> ListResponse[Agent]: - agents = list_agents_query( + agents = await list_agents_query( developer_id=x_developer_id, limit=limit, offset=offset, diff --git a/agents-api/agents_api/routers/agents/patch_agent.py b/agents-api/agents_api/routers/agents/patch_agent.py index f31f2c63e..bb7c16d5c 100644 --- a/agents-api/agents_api/routers/agents/patch_agent.py +++ b/agents-api/agents_api/routers/agents/patch_agent.py @@ -6,7 +6,7 @@ from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse from ...dependencies.developer_id import get_developer_id -from ...models.agent.patch_agent import patch_agent as patch_agent_query +from ...queries.agents.patch_agent import patch_agent as patch_agent_query from .router import router @@ -21,7 +21,7 @@ async def patch_agent( agent_id: UUID, data: PatchAgentRequest, ) -> ResourceUpdatedResponse: - return patch_agent_query( + return await patch_agent_query( agent_id=agent_id, developer_id=x_developer_id, data=data, diff --git a/agents-api/agents_api/routers/agents/patch_agent_tool.py b/agents-api/agents_api/routers/agents/patch_agent_tool.py index e4031810b..cef29dea2 100644 --- a/agents-api/agents_api/routers/agents/patch_agent_tool.py +++ b/agents-api/agents_api/routers/agents/patch_agent_tool.py @@ -8,7 +8,7 @@ ResourceUpdatedResponse, ) from ...dependencies.developer_id import get_developer_id -from ...models.tools.patch_tool import patch_tool as patch_tool_query +from ...queries.tools.patch_tool import 
patch_tool as patch_tool_query from .router import router @@ -19,7 +19,7 @@ async def patch_agent_tool( tool_id: UUID, data: PatchToolRequest, ) -> ResourceUpdatedResponse: - return patch_tool_query( + return await patch_tool_query( developer_id=x_developer_id, agent_id=agent_id, tool_id=tool_id, diff --git a/agents-api/agents_api/routers/agents/update_agent.py b/agents-api/agents_api/routers/agents/update_agent.py index d878b7d6b..608da0b20 100644 --- a/agents-api/agents_api/routers/agents/update_agent.py +++ b/agents-api/agents_api/routers/agents/update_agent.py @@ -6,7 +6,7 @@ from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest from ...dependencies.developer_id import get_developer_id -from ...models.agent.update_agent import update_agent as update_agent_query +from ...queries.agents.update_agent import update_agent as update_agent_query from .router import router @@ -21,7 +21,7 @@ async def update_agent( agent_id: UUID, data: UpdateAgentRequest, ) -> ResourceUpdatedResponse: - return update_agent_query( + return await update_agent_query( developer_id=x_developer_id, agent_id=agent_id, data=data, diff --git a/agents-api/agents_api/routers/agents/update_agent_tool.py b/agents-api/agents_api/routers/agents/update_agent_tool.py index b736ea686..790cff39c 100644 --- a/agents-api/agents_api/routers/agents/update_agent_tool.py +++ b/agents-api/agents_api/routers/agents/update_agent_tool.py @@ -8,7 +8,7 @@ UpdateToolRequest, ) from ...dependencies.developer_id import get_developer_id -from ...models.tools.update_tool import update_tool as update_tool_query +from ...queries.tools.update_tool import update_tool as update_tool_query from .router import router @@ -19,7 +19,7 @@ async def update_agent_tool( tool_id: UUID, data: UpdateToolRequest, ) -> ResourceUpdatedResponse: - return update_tool_query( + return await update_tool_query( developer_id=x_developer_id, agent_id=agent_id, tool_id=tool_id, diff --git 
a/agents-api/agents_api/routers/docs/create_doc.py b/agents-api/agents_api/routers/docs/create_doc.py index b3cac1a87..d089f2802 100644 --- a/agents-api/agents_api/routers/docs/create_doc.py +++ b/agents-api/agents_api/routers/docs/create_doc.py @@ -1,66 +1,20 @@ from typing import Annotated -from uuid import UUID, uuid4 +from uuid import UUID -from fastapi import BackgroundTasks, Depends +from fastapi import Depends from starlette.status import HTTP_201_CREATED -from temporalio.client import Client as TemporalClient -from ...activities.types import EmbedDocsPayload from ...autogen.openapi_model import CreateDocRequest, Doc, ResourceCreatedResponse -from ...clients import temporal -from ...common.retry_policies import DEFAULT_RETRY_POLICY from ...dependencies.developer_id import get_developer_id -from ...env import temporal_task_queue, testing -from ...models.docs.create_doc import create_doc as create_doc_query +from ...queries.docs.create_doc import create_doc as create_doc_query from .router import router -async def run_embed_docs_task( - *, - developer_id: UUID, - doc_id: UUID, - title: str, - content: list[str], - embed_instruction: str | None = None, - job_id: UUID, - background_tasks: BackgroundTasks, - client: TemporalClient | None = None, -): - from ...workflows.embed_docs import EmbedDocsWorkflow - - client = client or (await temporal.get_client()) - - embed_payload = EmbedDocsPayload( - developer_id=developer_id, - doc_id=doc_id, - content=content, - title=title, - # Default embed instruction for docs. See https://docs.voyageai.com/docs/embeddings - embed_instruction=embed_instruction or "Represent the document for retrieval: ", - ) - - handle = await client.start_workflow( - EmbedDocsWorkflow.run, - embed_payload, - task_queue=temporal_task_queue, - id=str(job_id), - retry_policy=DEFAULT_RETRY_POLICY, - ) - - # TODO: Remove this conditional once we have a way to run workflows in - # a test environment. 
- if not testing: - background_tasks.add_task(handle.result) - - return handle - - @router.post("/users/{user_id}/docs", status_code=HTTP_201_CREATED, tags=["docs"]) async def create_user_doc( user_id: UUID, data: CreateDocRequest, x_developer_id: Annotated[UUID, Depends(get_developer_id)], - background_tasks: BackgroundTasks, ) -> ResourceCreatedResponse: """ Creates a new document for a user. @@ -69,34 +23,19 @@ async def create_user_doc( user_id (UUID): The unique identifier of the user associated with the document. data (CreateDocRequest): The data to create the document with. x_developer_id (UUID): The unique identifier of the developer associated with the document. - background_tasks (BackgroundTasks): The background tasks to run. Returns: ResourceCreatedResponse: The created document. """ - doc: Doc = create_doc_query( + doc: Doc = await create_doc_query( developer_id=x_developer_id, owner_type="user", owner_id=user_id, data=data, ) - embed_job_id = uuid4() - - await run_embed_docs_task( - developer_id=x_developer_id, - doc_id=doc.id, - title=doc.title, - content=doc.content, - embed_instruction=data.embed_instruction, - job_id=embed_job_id, - background_tasks=background_tasks, - ) - - return ResourceCreatedResponse( - id=doc.id, created_at=doc.created_at, jobs=[embed_job_id] - ) + return ResourceCreatedResponse(id=doc.id, created_at=doc.created_at, jobs=[]) @router.post("/agents/{agent_id}/docs", status_code=HTTP_201_CREATED, tags=["docs"]) @@ -104,27 +43,12 @@ async def create_agent_doc( agent_id: UUID, data: CreateDocRequest, x_developer_id: Annotated[UUID, Depends(get_developer_id)], - background_tasks: BackgroundTasks, ) -> ResourceCreatedResponse: - doc: Doc = create_doc_query( + doc: Doc = await create_doc_query( developer_id=x_developer_id, owner_type="agent", owner_id=agent_id, data=data, ) - embed_job_id = uuid4() - - await run_embed_docs_task( - developer_id=x_developer_id, - doc_id=doc.id, - title=doc.title, - content=doc.content, - 
embed_instruction=data.embed_instruction, - job_id=embed_job_id, - background_tasks=background_tasks, - ) - - return ResourceCreatedResponse( - id=doc.id, created_at=doc.created_at, jobs=[embed_job_id] - ) + return ResourceCreatedResponse(id=doc.id, created_at=doc.created_at, jobs=[]) diff --git a/agents-api/agents_api/routers/docs/delete_doc.py b/agents-api/agents_api/routers/docs/delete_doc.py index c67e46447..e1ec30a41 100644 --- a/agents-api/agents_api/routers/docs/delete_doc.py +++ b/agents-api/agents_api/routers/docs/delete_doc.py @@ -6,19 +6,17 @@ from ...autogen.openapi_model import ResourceDeletedResponse from ...dependencies.developer_id import get_developer_id -from ...models.docs.delete_doc import delete_doc as delete_doc_query +from ...queries.docs.delete_doc import delete_doc as delete_doc_query from .router import router -@router.delete( - "/agents/{agent_id}/docs/{doc_id}", status_code=HTTP_202_ACCEPTED, tags=["docs"] -) +@router.delete("/agents/{agent_id}/docs/{doc_id}", status_code=HTTP_202_ACCEPTED, tags=["docs"]) async def delete_agent_doc( doc_id: UUID, agent_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> ResourceDeletedResponse: - return delete_doc_query( + return await delete_doc_query( developer_id=x_developer_id, owner_id=agent_id, owner_type="agent", @@ -26,15 +24,13 @@ async def delete_agent_doc( ) -@router.delete( - "/users/{user_id}/docs/{doc_id}", status_code=HTTP_202_ACCEPTED, tags=["docs"] -) +@router.delete("/users/{user_id}/docs/{doc_id}", status_code=HTTP_202_ACCEPTED, tags=["docs"]) async def delete_user_doc( doc_id: UUID, user_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> ResourceDeletedResponse: - return delete_doc_query( + return await delete_doc_query( developer_id=x_developer_id, owner_id=user_id, owner_type="user", diff --git a/agents-api/agents_api/routers/docs/get_doc.py b/agents-api/agents_api/routers/docs/get_doc.py index b120bc867..498fb46e0 100644 --- 
a/agents-api/agents_api/routers/docs/get_doc.py +++ b/agents-api/agents_api/routers/docs/get_doc.py @@ -5,7 +5,7 @@ from ...autogen.openapi_model import Doc from ...dependencies.developer_id import get_developer_id -from ...models.docs.get_doc import get_doc as get_doc_query +from ...queries.docs.get_doc import get_doc as get_doc_query from .router import router @@ -14,4 +14,4 @@ async def get_doc( x_developer_id: Annotated[UUID, Depends(get_developer_id)], doc_id: UUID, ) -> Doc: - return get_doc_query(developer_id=x_developer_id, doc_id=doc_id) + return await get_doc_query(developer_id=x_developer_id, doc_id=doc_id) diff --git a/agents-api/agents_api/routers/docs/list_docs.py b/agents-api/agents_api/routers/docs/list_docs.py index 2f663a324..5f24e42cd 100644 --- a/agents-api/agents_api/routers/docs/list_docs.py +++ b/agents-api/agents_api/routers/docs/list_docs.py @@ -6,7 +6,7 @@ from ...autogen.openapi_model import Doc, ListResponse from ...dependencies.developer_id import get_developer_id from ...dependencies.query_filter import MetadataFilter, create_filter_extractor -from ...models.docs.list_docs import list_docs as list_docs_query +from ...queries.docs.list_docs import list_docs as list_docs_query from .router import router @@ -23,7 +23,7 @@ async def list_user_docs( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> ListResponse[Doc]: - docs = list_docs_query( + docs = await list_docs_query( developer_id=x_developer_id, owner_type="user", owner_id=user_id, @@ -49,7 +49,7 @@ async def list_agent_docs( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> ListResponse[Doc]: - docs = list_docs_query( + docs = await list_docs_query( developer_id=x_developer_id, owner_type="agent", owner_id=agent_id, diff --git a/agents-api/agents_api/routers/docs/search_docs.py b/agents-api/agents_api/routers/docs/search_docs.py index 22bba86a1..0c463a83a 100644 
--- a/agents-api/agents_api/routers/docs/search_docs.py +++ b/agents-api/agents_api/routers/docs/search_docs.py @@ -1,5 +1,5 @@ import time -from typing import Annotated, Any, Dict, List, Optional, Tuple, Union +from typing import Annotated, Any from uuid import UUID import numpy as np @@ -13,30 +13,26 @@ VectorDocSearchRequest, ) from ...dependencies.developer_id import get_developer_id -from ...models.docs.mmr import maximal_marginal_relevance -from ...models.docs.search_docs_by_embedding import search_docs_by_embedding -from ...models.docs.search_docs_by_text import search_docs_by_text -from ...models.docs.search_docs_hybrid import search_docs_hybrid +from ...queries.docs.mmr import maximal_marginal_relevance +from ...queries.docs.search_docs_by_embedding import search_docs_by_embedding +from ...queries.docs.search_docs_by_text import search_docs_by_text +from ...queries.docs.search_docs_hybrid import search_docs_hybrid from .router import router def get_search_fn_and_params( search_params, -) -> Tuple[ - Any, Optional[Dict[str, Union[float, int, str, Dict[str, float], List[float]]]] -]: +) -> tuple[Any, dict[str, float | int | str | dict[str, float] | list[float]] | None]: search_fn, params = None, None match search_params: - case TextOnlyDocSearchRequest( - text=query, limit=k, metadata_filter=metadata_filter - ): + case TextOnlyDocSearchRequest(text=query, limit=k, metadata_filter=metadata_filter): search_fn = search_docs_by_text - params = dict( - query=query, - k=k, - metadata_filter=metadata_filter, - ) + params = { + "query": query, + "k": k, + "metadata_filter": metadata_filter, + } case VectorDocSearchRequest( vector=query_embedding, @@ -45,12 +41,12 @@ def get_search_fn_and_params( metadata_filter=metadata_filter, ): search_fn = search_docs_by_embedding - params = dict( - query_embedding=query_embedding, - k=k * 3 if search_params.mmr_strength > 0 else k, - confidence=confidence, - metadata_filter=metadata_filter, - ) + params = { + "query_embedding": 
query_embedding, + "k": k * 3 if search_params.mmr_strength > 0 else k, + "confidence": confidence, + "metadata_filter": metadata_filter, + } case HybridDocSearchRequest( text=query, @@ -61,14 +57,14 @@ def get_search_fn_and_params( metadata_filter=metadata_filter, ): search_fn = search_docs_hybrid - params = dict( - query=query, - query_embedding=query_embedding, - k=k * 3 if search_params.mmr_strength > 0 else k, - embed_search_options=dict(confidence=confidence), - alpha=alpha, - metadata_filter=metadata_filter, - ) + params = { + "query": query, + "query_embedding": query_embedding, + "k": k * 3 if search_params.mmr_strength > 0 else k, + "embed_search_options": {"confidence": confidence}, + "alpha": alpha, + "metadata_filter": metadata_filter, + } return search_fn, params @@ -76,9 +72,7 @@ def get_search_fn_and_params( @router.post("/users/{user_id}/search", tags=["docs"]) async def search_user_docs( x_developer_id: Annotated[UUID, Depends(get_developer_id)], - search_params: ( - TextOnlyDocSearchRequest | VectorDocSearchRequest | HybridDocSearchRequest - ), + search_params: (TextOnlyDocSearchRequest | VectorDocSearchRequest | HybridDocSearchRequest), user_id: UUID, ) -> DocSearchResponse: """ @@ -97,7 +91,7 @@ async def search_user_docs( search_fn, params = get_search_fn_and_params(search_params) start = time.time() - docs: list[DocReference] = search_fn( + docs: list[DocReference] = await search_fn( developer_id=x_developer_id, owners=[("user", user_id)], **params, @@ -128,9 +122,7 @@ async def search_user_docs( @router.post("/agents/{agent_id}/search", tags=["docs"]) async def search_agent_docs( x_developer_id: Annotated[UUID, Depends(get_developer_id)], - search_params: ( - TextOnlyDocSearchRequest | VectorDocSearchRequest | HybridDocSearchRequest - ), + search_params: (TextOnlyDocSearchRequest | VectorDocSearchRequest | HybridDocSearchRequest), agent_id: UUID, ) -> DocSearchResponse: """ @@ -148,7 +140,7 @@ async def search_agent_docs( search_fn, params = 
get_search_fn_and_params(search_params) start = time.time() - docs: list[DocReference] = search_fn( + docs: list[DocReference] = await search_fn( developer_id=x_developer_id, owners=[("agent", agent_id)], **params, diff --git a/agents-api/agents_api/routers/files/__init__.py b/agents-api/agents_api/routers/files/__init__.py index 5e3d5a62c..daddb2bf7 100644 --- a/agents-api/agents_api/routers/files/__init__.py +++ b/agents-api/agents_api/routers/files/__init__.py @@ -3,4 +3,5 @@ from .create_file import create_file from .delete_file import delete_file from .get_file import get_file +from .list_files import list_files from .router import router diff --git a/agents-api/agents_api/routers/files/create_file.py b/agents-api/agents_api/routers/files/create_file.py index 80d80e6f3..9736f65e8 100644 --- a/agents-api/agents_api/routers/files/create_file.py +++ b/agents-api/agents_api/routers/files/create_file.py @@ -12,24 +12,27 @@ ) from ...clients import async_s3 from ...dependencies.developer_id import get_developer_id -from ...models.files.create_file import create_file as create_file_query +from ...queries.files.create_file import create_file as create_file_query from .router import router async def upload_file_content(file_id: UUID, content: str) -> None: """Upload file content to blob storage using the file ID as the key""" - await async_s3.setup() key = str(file_id) content_bytes = base64.b64decode(content) - await async_s3.add_object(key, content_bytes) + client = await async_s3.setup() + await client.put_object(Bucket=async_s3.blob_store_bucket, Key=key, Body=content_bytes) + + +# TODO: Use streaming for large payloads @router.post("/files", status_code=HTTP_201_CREATED, tags=["files"]) async def create_file( x_developer_id: Annotated[UUID, Depends(get_developer_id)], data: CreateFileRequest, ) -> ResourceCreatedResponse: - file: File = create_file_query( + file: File = await create_file_query( developer_id=x_developer_id, data=data, ) diff --git 
a/agents-api/agents_api/routers/files/delete_file.py b/agents-api/agents_api/routers/files/delete_file.py index fbe10290e..4b949fcf9 100644 --- a/agents-api/agents_api/routers/files/delete_file.py +++ b/agents-api/agents_api/routers/files/delete_file.py @@ -7,22 +7,23 @@ from ...autogen.openapi_model import ResourceDeletedResponse from ...clients import async_s3 from ...dependencies.developer_id import get_developer_id -from ...models.files.delete_file import delete_file as delete_file_query +from ...queries.files.delete_file import delete_file as delete_file_query from .router import router async def delete_file_content(file_id: UUID) -> None: """Delete file content from blob storage using the file ID as the key""" - await async_s3.setup() + client = await async_s3.setup() key = str(file_id) - await async_s3.delete_object(key) + + await client.delete_object(Bucket=async_s3.blob_store_bucket, Key=key) @router.delete("/files/{file_id}", status_code=HTTP_202_ACCEPTED, tags=["files"]) async def delete_file( file_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)] ) -> ResourceDeletedResponse: - resource_deleted = delete_file_query(developer_id=x_developer_id, file_id=file_id) + resource_deleted = await delete_file_query(developer_id=x_developer_id, file_id=file_id) # Delete the file content from blob storage await delete_file_content(file_id) diff --git a/agents-api/agents_api/routers/files/get_file.py b/agents-api/agents_api/routers/files/get_file.py index cc5dcdc35..44ca57656 100644 --- a/agents-api/agents_api/routers/files/get_file.py +++ b/agents-api/agents_api/routers/files/get_file.py @@ -7,23 +7,28 @@ from ...autogen.openapi_model import File from ...clients import async_s3 from ...dependencies.developer_id import get_developer_id -from ...models.files.get_file import get_file as get_file_query +from ...queries.files.get_file import get_file as get_file_query from .router import router +# TODO: Use streaming for large payloads and file ID 
formatting async def fetch_file_content(file_id: UUID) -> str: """Fetch file content from blob storage using the file ID as the key""" - await async_s3.setup() + client = await async_s3.setup() + key = str(file_id) - content = await async_s3.get_object(key) + result = await client.get_object(Bucket=async_s3.blob_store_bucket, Key=key) + content = await result["Body"].read() + return base64.b64encode(content).decode("utf-8") +# TODO: Use streaming for large payloads @router.get("/files/{file_id}", tags=["files"]) async def get_file( file_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)] ) -> File: - file = get_file_query(developer_id=x_developer_id, file_id=file_id) + file = await get_file_query(developer_id=x_developer_id, file_id=file_id) # Fetch the file content from blob storage file.content = await fetch_file_content(file.id) diff --git a/agents-api/agents_api/routers/files/list_files.py b/agents-api/agents_api/routers/files/list_files.py new file mode 100644 index 000000000..abbfcb0e5 --- /dev/null +++ b/agents-api/agents_api/routers/files/list_files.py @@ -0,0 +1,24 @@ +from typing import Annotated +from uuid import UUID + +from fastapi import Depends + +from ...autogen.openapi_model import File +from ...dependencies.developer_id import get_developer_id +from ...queries.files.list_files import list_files as list_files_query +from .get_file import fetch_file_content +from .router import router + + +# TODO: Use streaming for large payloads +@router.get("/files", tags=["files"]) +async def list_files( + x_developer_id: Annotated[UUID, Depends(get_developer_id)], +) -> list[File]: + files = await list_files_query(developer_id=x_developer_id) + + # Fetch the file content from blob storage + for file in files: + file.content = await fetch_file_content(file.id) + + return files diff --git a/agents-api/agents_api/routers/healthz/__init__.py b/agents-api/agents_api/routers/healthz/__init__.py new file mode 100644 index 000000000..5859730f0 --- 
/dev/null +++ b/agents-api/agents_api/routers/healthz/__init__.py @@ -0,0 +1,2 @@ +from .check_health import check_health as check_health +from .router import router as router diff --git a/agents-api/agents_api/routers/healthz/check_health.py b/agents-api/agents_api/routers/healthz/check_health.py new file mode 100644 index 000000000..33fb19eff --- /dev/null +++ b/agents-api/agents_api/routers/healthz/check_health.py @@ -0,0 +1,19 @@ +import logging +from uuid import UUID + +from ...queries.agents.list_agents import list_agents as list_agents_query +from .router import router + + +@router.get("/healthz", tags=["healthz"]) +async def check_health() -> dict: + try: + # Check if the database is reachable + await list_agents_query( + developer_id=UUID("00000000-0000-0000-0000-000000000000"), + ) + except Exception as e: + logging.error("An error occurred while checking health: %s", str(e)) + return {"status": "error", "message": "An internal error has occurred."} + + return {"status": "ok"} diff --git a/agents-api/agents_api/routers/healthz/router.py b/agents-api/agents_api/routers/healthz/router.py new file mode 100644 index 000000000..5c3ec9311 --- /dev/null +++ b/agents-api/agents_api/routers/healthz/router.py @@ -0,0 +1,3 @@ +from fastapi import APIRouter + +router: APIRouter = APIRouter() diff --git a/agents-api/agents_api/routers/jobs/__init__.py b/agents-api/agents_api/routers/jobs/__init__.py index fa07d0740..9c5649244 100644 --- a/agents-api/agents_api/routers/jobs/__init__.py +++ b/agents-api/agents_api/routers/jobs/__init__.py @@ -1 +1 @@ -from .routers import router # noqa: F401 +from .routers import router as router diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index 85a1574ef..710e8481a 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -1,8 +1,9 @@ -from typing import Annotated, Optional -from uuid import UUID, uuid4 +from 
typing import Annotated +from uuid import UUID from fastapi import BackgroundTasks, Depends, Header, HTTPException, status from starlette.status import HTTP_201_CREATED +from uuid_extensions import uuid7 from ...autogen.openapi_model import ( ChatInput, @@ -18,10 +19,10 @@ from ...common.utils.template import render_template from ...dependencies.developer_id import get_developer_data from ...env import max_free_sessions -from ...models.chat.gather_messages import gather_messages -from ...models.chat.prepare_chat_context import prepare_chat_context -from ...models.entry.create_entries import create_entries -from ...models.session.count_sessions import count_sessions as count_sessions_query +from ...queries.chat.gather_messages import gather_messages +from ...queries.chat.prepare_chat_context import prepare_chat_context +from ...queries.entries.create_entries import create_entries +from ...queries.sessions.count_sessions import count_sessions as count_sessions_query from .metrics import total_tokens_per_user from .router import router @@ -38,7 +39,7 @@ async def chat( session_id: UUID, chat_input: ChatInput, background_tasks: BackgroundTasks, - x_custom_api_key: Optional[str] = Header(None, alias="X-Custom-Api-Key"), + x_custom_api_key: str | None = Header(None, alias="X-Custom-Api-Key"), ) -> ChatResponse: """ Initiates a chat session. 
@@ -56,7 +57,7 @@ async def chat( # check if the developer is paid if "paid" not in developer.tags: # get the session length - sessions = count_sessions_query(developer_id=developer.id) + sessions = await count_sessions_query(developer_id=developer.id) session_length = sessions["count"] if session_length > max_free_sessions: raise HTTPException( @@ -65,10 +66,11 @@ async def chat( ) if chat_input.stream: - raise NotImplementedError("Streaming is not yet implemented") + msg = "Streaming is not yet implemented" + raise NotImplementedError(msg) # First get the chat context - chat_context: ChatContext = prepare_chat_context( + chat_context: ChatContext = await prepare_chat_context( developer_id=developer.id, session_id=session_id, ) @@ -88,22 +90,20 @@ async def chat( # Prepare the environment env: dict = chat_context.get_chat_environment() env["docs"] = [ - dict( - title=ref.title, - content=[ref.snippet.content], - ) + { + "title": ref.title, + "content": [ref.snippet.content], + } for ref in doc_references ] # Render the system message - if situation := chat_context.session.situation: - system_message = dict( - role="system", - content=situation, - ) + if system_template := chat_context.session.system_template: + system_message = { + "role": "system", + "content": system_template, + } - system_messages: list[dict] = await render_template( - [system_message], variables=env - ) + system_messages: list[dict] = await render_template([system_message], variables=env) past_messages = system_messages + past_messages # Render the incoming messages @@ -132,7 +132,8 @@ async def chat( # SCRUM-7 if chat_context.session.context_overflow == "truncate": # messages = messages[-settings["max_tokens"] :] - raise NotImplementedError("Truncation is not yet implemented") + msg = "Truncation is not yet implemented" + raise NotImplementedError(msg) # FIXME: Hotfix for datetime not serializable. 
Needs investigation messages = [ @@ -218,7 +219,6 @@ async def chat( developer_id=developer.id, session_id=session_id, data=new_entries, - mark_session_as_updated=True, ) # Adaptive context handling @@ -228,15 +228,14 @@ async def chat( # SCRUM-8 # jobs = [await start_adaptive_context_workflow] - raise NotImplementedError("Adaptive context is not yet implemented") + msg = "Adaptive context is not yet implemented" + raise NotImplementedError(msg) # Return the response # FIXME: Implement streaming for chat - chat_response_class = ( - ChunkChatResponse if chat_input.stream else MessageChatResponse - ) + chat_response_class = ChunkChatResponse if chat_input.stream else MessageChatResponse chat_response: ChatResponse = chat_response_class( - id=uuid4(), + id=uuid7(), created_at=utcnow(), jobs=jobs, docs=doc_references, @@ -245,9 +244,7 @@ async def chat( ) total_tokens_per_user.labels(str(developer.id)).inc( - amount=chat_response.usage.total_tokens - if chat_response.usage is not None - else 0 + amount=chat_response.usage.total_tokens if chat_response.usage is not None else 0 ) return chat_response diff --git a/agents-api/agents_api/routers/sessions/create_or_update_session.py b/agents-api/agents_api/routers/sessions/create_or_update_session.py index a4efb0444..a15479891 100644 --- a/agents-api/agents_api/routers/sessions/create_or_update_session.py +++ b/agents-api/agents_api/routers/sessions/create_or_update_session.py @@ -9,8 +9,8 @@ ResourceUpdatedResponse, ) from ...dependencies.developer_id import get_developer_id -from ...models.session.create_or_update_session import ( - create_or_update_session as create_session_query, +from ...queries.sessions.create_or_update_session import ( + create_or_update_session as create_or_update_session_query, ) from .router import router @@ -21,10 +21,8 @@ async def create_or_update_session( session_id: UUID, data: CreateOrUpdateSessionRequest, ) -> ResourceUpdatedResponse: - session_updated = create_session_query( + return await 
create_or_update_session_query( developer_id=x_developer_id, session_id=session_id, data=data, ) - - return session_updated diff --git a/agents-api/agents_api/routers/sessions/create_session.py b/agents-api/agents_api/routers/sessions/create_session.py index a83b71d5a..8359f808b 100644 --- a/agents-api/agents_api/routers/sessions/create_session.py +++ b/agents-api/agents_api/routers/sessions/create_session.py @@ -9,7 +9,7 @@ ResourceCreatedResponse, ) from ...dependencies.developer_id import get_developer_id -from ...models.session.create_session import create_session as create_session_query +from ...queries.sessions.create_session import create_session as create_session_query from .router import router @@ -18,7 +18,7 @@ async def create_session( x_developer_id: Annotated[UUID, Depends(get_developer_id)], data: CreateSessionRequest, ) -> ResourceCreatedResponse: - session = create_session_query( + session = await create_session_query( developer_id=x_developer_id, data=data, ) diff --git a/agents-api/agents_api/routers/sessions/delete_session.py b/agents-api/agents_api/routers/sessions/delete_session.py index 1a664a871..f3d446d15 100644 --- a/agents-api/agents_api/routers/sessions/delete_session.py +++ b/agents-api/agents_api/routers/sessions/delete_session.py @@ -6,14 +6,12 @@ from ...autogen.openapi_model import ResourceDeletedResponse from ...dependencies.developer_id import get_developer_id -from ...models.session.delete_session import delete_session as delete_session_query +from ...queries.sessions.delete_session import delete_session as delete_session_query from .router import router -@router.delete( - "/sessions/{session_id}", status_code=HTTP_202_ACCEPTED, tags=["sessions"] -) +@router.delete("/sessions/{session_id}", status_code=HTTP_202_ACCEPTED, tags=["sessions"]) async def delete_session( session_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)] ) -> ResourceDeletedResponse: - return delete_session_query(developer_id=x_developer_id, 
session_id=session_id) + return await delete_session_query(developer_id=x_developer_id, session_id=session_id) diff --git a/agents-api/agents_api/routers/sessions/get_session.py b/agents-api/agents_api/routers/sessions/get_session.py index df70a8f72..b77a01176 100644 --- a/agents-api/agents_api/routers/sessions/get_session.py +++ b/agents-api/agents_api/routers/sessions/get_session.py @@ -5,7 +5,7 @@ from ...autogen.openapi_model import Session from ...dependencies.developer_id import get_developer_id -from ...models.session.get_session import get_session as get_session_query +from ...queries.sessions.get_session import get_session as get_session_query from .router import router @@ -13,4 +13,4 @@ async def get_session( session_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)] ) -> Session: - return get_session_query(developer_id=x_developer_id, session_id=session_id) + return await get_session_query(developer_id=x_developer_id, session_id=session_id) diff --git a/agents-api/agents_api/routers/sessions/get_session_history.py b/agents-api/agents_api/routers/sessions/get_session_history.py index fa993975b..e62aa9d2c 100644 --- a/agents-api/agents_api/routers/sessions/get_session_history.py +++ b/agents-api/agents_api/routers/sessions/get_session_history.py @@ -5,7 +5,7 @@ from ...autogen.openapi_model import History from ...dependencies.developer_id import get_developer_id -from ...models.entry.get_history import get_history as get_history_query +from ...queries.entries.get_history import get_history as get_history_query from .router import router @@ -13,4 +13,4 @@ async def get_session_history( session_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)] ) -> History: - return get_history_query(developer_id=x_developer_id, session_id=session_id) + return await get_history_query(developer_id=x_developer_id, session_id=session_id) diff --git a/agents-api/agents_api/routers/sessions/list_sessions.py 
b/agents-api/agents_api/routers/sessions/list_sessions.py index fc9cd2e99..108f1528f 100644 --- a/agents-api/agents_api/routers/sessions/list_sessions.py +++ b/agents-api/agents_api/routers/sessions/list_sessions.py @@ -6,7 +6,7 @@ from ...autogen.openapi_model import ListResponse, Session from ...dependencies.developer_id import get_developer_id from ...dependencies.query_filter import MetadataFilter, create_filter_extractor -from ...models.session.list_sessions import list_sessions as list_sessions_query +from ...queries.sessions.list_sessions import list_sessions as list_sessions_query from .router import router @@ -21,7 +21,7 @@ async def list_sessions( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> ListResponse[Session]: - sessions = list_sessions_query( + sessions = await list_sessions_query( developer_id=x_developer_id, limit=limit, offset=offset, diff --git a/agents-api/agents_api/routers/sessions/patch_session.py b/agents-api/agents_api/routers/sessions/patch_session.py index 8eefab4dc..87acd3c0d 100644 --- a/agents-api/agents_api/routers/sessions/patch_session.py +++ b/agents-api/agents_api/routers/sessions/patch_session.py @@ -8,7 +8,7 @@ ResourceUpdatedResponse, ) from ...dependencies.developer_id import get_developer_id -from ...models.session.patch_session import patch_session as patch_session_query +from ...queries.sessions.patch_session import patch_session as patch_session_query from .router import router @@ -18,7 +18,7 @@ async def patch_session( session_id: UUID, data: PatchSessionRequest, ) -> ResourceUpdatedResponse: - return patch_session_query( + return await patch_session_query( developer_id=x_developer_id, session_id=session_id, data=data, diff --git a/agents-api/agents_api/routers/sessions/update_session.py b/agents-api/agents_api/routers/sessions/update_session.py index f35368d84..0c25e0652 100644 --- a/agents-api/agents_api/routers/sessions/update_session.py +++ 
b/agents-api/agents_api/routers/sessions/update_session.py @@ -8,7 +8,7 @@ UpdateSessionRequest, ) from ...dependencies.developer_id import get_developer_id -from ...models.session.update_session import update_session as update_session_query +from ...queries.sessions.update_session import update_session as update_session_query from .router import router @@ -18,7 +18,7 @@ async def update_session( session_id: UUID, data: UpdateSessionRequest, ) -> ResourceUpdatedResponse: - return update_session_query( + return await update_session_query( developer_id=x_developer_id, session_id=session_id, data=data, diff --git a/agents-api/agents_api/routers/tasks/__init__.py b/agents-api/agents_api/routers/tasks/__init__.py index 5ada6a04e..0c7180cd2 100644 --- a/agents-api/agents_api/routers/tasks/__init__.py +++ b/agents-api/agents_api/routers/tasks/__init__.py @@ -7,7 +7,6 @@ from .list_execution_transitions import list_execution_transitions from .list_task_executions import list_task_executions from .list_tasks import list_tasks -from .patch_execution import patch_execution from .router import router from .stream_transitions_events import stream_transitions_events from .update_execution import update_execution diff --git a/agents-api/agents_api/routers/tasks/create_or_update_task.py b/agents-api/agents_api/routers/tasks/create_or_update_task.py index f40530dfc..946ff7e6b 100644 --- a/agents-api/agents_api/routers/tasks/create_or_update_task.py +++ b/agents-api/agents_api/routers/tasks/create_or_update_task.py @@ -11,15 +11,13 @@ ResourceUpdatedResponse, ) from ...dependencies.developer_id import get_developer_id -from ...models.task.create_or_update_task import ( +from ...queries.tasks.create_or_update_task import ( create_or_update_task as create_or_update_task_query, ) from .router import router -@router.post( - "/agents/{agent_id}/tasks/{task_id}", status_code=HTTP_201_CREATED, tags=["tasks"] -) +@router.post("/agents/{agent_id}/tasks/{task_id}", 
status_code=HTTP_201_CREATED, tags=["tasks"]) async def create_or_update_task( data: CreateOrUpdateTaskRequest, agent_id: UUID, @@ -40,7 +38,7 @@ async def create_or_update_task( except ValidationError: pass - return create_or_update_task_query( + return await create_or_update_task_query( developer_id=x_developer_id, agent_id=agent_id, task_id=task_id, diff --git a/agents-api/agents_api/routers/tasks/create_task.py b/agents-api/agents_api/routers/tasks/create_task.py index 0e233ac97..0e8813102 100644 --- a/agents-api/agents_api/routers/tasks/create_task.py +++ b/agents-api/agents_api/routers/tasks/create_task.py @@ -11,7 +11,7 @@ ResourceCreatedResponse, ) from ...dependencies.developer_id import get_developer_id -from ...models.task.create_task import create_task as create_task_query +from ...queries.tasks.create_task import create_task as create_task_query from .router import router @@ -35,7 +35,7 @@ async def create_task( except ValidationError: pass - return create_task_query( + return await create_task_query( developer_id=x_developer_id, agent_id=agent_id, data=data, diff --git a/agents-api/agents_api/routers/tasks/create_task_execution.py b/agents-api/agents_api/routers/tasks/create_task_execution.py index 09342bf84..185825091 100644 --- a/agents-api/agents_api/routers/tasks/create_task_execution.py +++ b/agents-api/agents_api/routers/tasks/create_task_execution.py @@ -1,38 +1,40 @@ import logging from typing import Annotated -from uuid import UUID, uuid4 +from uuid import UUID from beartype import beartype from fastapi import BackgroundTasks, Depends, HTTPException, status from jsonschema import validate from jsonschema.exceptions import ValidationError -from pycozo.client import QueryException from starlette.status import HTTP_201_CREATED from temporalio.client import WorkflowHandle +from uuid_extensions import uuid7 from ...autogen.openapi_model import ( CreateExecutionRequest, + CreateTransitionRequest, Execution, ResourceCreatedResponse, - 
UpdateExecutionRequest, + TransitionTarget, ) from ...clients.temporal import run_task_execution_workflow from ...common.protocol.developers import Developer +from ...common.protocol.tasks import task_to_spec from ...dependencies.developer_id import get_developer_id from ...env import max_free_executions -from ...models.developer.get_developer import get_developer -from ...models.execution.count_executions import ( +from ...queries.developers.get_developer import get_developer +from ...queries.executions.count_executions import ( count_executions as count_executions_query, ) -from ...models.execution.create_execution import ( +from ...queries.executions.create_execution import ( create_execution as create_execution_query, ) -from ...models.execution.create_temporal_lookup import create_temporal_lookup -from ...models.execution.prepare_execution_input import prepare_execution_input -from ...models.execution.update_execution import ( - update_execution as update_execution_query, +from ...queries.executions.create_execution_transition import ( + create_execution_transition, ) -from ...models.task.get_task import get_task as get_task_query +from ...queries.executions.create_temporal_lookup import create_temporal_lookup +from ...queries.executions.prepare_execution_input import prepare_execution_input +from ...queries.tasks.get_task import get_task as get_task_query from .router import router logger: logging.Logger = logging.getLogger(__name__) @@ -45,26 +47,34 @@ async def start_execution( developer_id: UUID, task_id: UUID, data: CreateExecutionRequest, - client=None, + connection_pool=None, ) -> tuple[Execution, WorkflowHandle]: - execution_id = uuid4() + execution_id = uuid7() - execution = create_execution_query( + execution = await create_execution_query( developer_id=developer_id, task_id=task_id, execution_id=execution_id, data=data, - client=client, + connection_pool=connection_pool, ) - execution_input = prepare_execution_input( + execution_input = await 
prepare_execution_input( developer_id=developer_id, task_id=task_id, execution_id=execution_id, - client=client, + connection_pool=connection_pool, ) - job_id = uuid4() + task = await get_task_query( + developer_id=developer_id, + task_id=task_id, + connection_pool=connection_pool, + ) + + execution_input.task = task_to_spec(task) + + job_id = uuid7() try: handle = await run_task_execution_workflow( @@ -75,12 +85,19 @@ async def start_execution( except Exception as e: logger.exception(e) - update_execution_query( + await create_execution_transition( developer_id=developer_id, - task_id=task_id, execution_id=execution_id, - data=UpdateExecutionRequest(status="failed"), - client=client, + data=CreateTransitionRequest( + type="error", + output={"error": str(e)}, + current=TransitionTarget( + workflow="main", + step=0, + ), + next=None, + ), + connection_pool=connection_pool, ) raise HTTPException( @@ -103,7 +120,7 @@ async def create_task_execution( background_tasks: BackgroundTasks, ) -> ResourceCreatedResponse: try: - task = get_task_query(task_id=task_id, developer_id=x_developer_id) + task = await get_task_query(task_id=task_id, developer_id=x_developer_id) validate(data.input, task.input_schema) except ValidationError: @@ -111,22 +128,13 @@ async def create_task_execution( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request arguments schema", ) - except QueryException as e: - if e.code == "transact::assertion_failure": - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Task not found" - ) - - raise # get developer data - developer: Developer = get_developer(developer_id=x_developer_id) + developer: Developer = await get_developer(developer_id=x_developer_id) # # check if the developer is paid if "paid" not in developer.tags: - executions = count_executions_query( - developer_id=x_developer_id, task_id=task_id - ) + executions = await count_executions_query(developer_id=x_developer_id, task_id=task_id) execution_count = 
executions["count"] if execution_count > max_free_executions: @@ -143,8 +151,6 @@ async def create_task_execution( background_tasks.add_task( create_temporal_lookup, - # - developer_id=x_developer_id, execution_id=execution.id, workflow_handle=handle, ) diff --git a/agents-api/agents_api/routers/tasks/get_execution_details.py b/agents-api/agents_api/routers/tasks/get_execution_details.py index 95bccbc07..53b6ad6d5 100644 --- a/agents-api/agents_api/routers/tasks/get_execution_details.py +++ b/agents-api/agents_api/routers/tasks/get_execution_details.py @@ -3,7 +3,7 @@ from ...autogen.openapi_model import ( Execution, ) -from ...models.execution.get_execution import ( +from ...queries.executions.get_execution import ( get_execution as get_execution_query, ) from .router import router @@ -11,4 +11,4 @@ @router.get("/executions/{execution_id}", tags=["executions"]) async def get_execution_details(execution_id: UUID) -> Execution: - return get_execution_query(execution_id=execution_id) + return await get_execution_query(execution_id=execution_id) diff --git a/agents-api/agents_api/routers/tasks/get_task_details.py b/agents-api/agents_api/routers/tasks/get_task_details.py index 9f8008118..c6a70207e 100644 --- a/agents-api/agents_api/routers/tasks/get_task_details.py +++ b/agents-api/agents_api/routers/tasks/get_task_details.py @@ -1,14 +1,13 @@ from typing import Annotated from uuid import UUID -from fastapi import Depends, HTTPException, status -from pycozo.client import QueryException +from fastapi import Depends from ...autogen.openapi_model import ( Task, ) from ...dependencies.developer_id import get_developer_id -from ...models.task.get_task import get_task as get_task_query +from ...queries.tasks.get_task import get_task as get_task_query from .router import router @@ -17,20 +16,8 @@ async def get_task_details( task_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> Task: - not_found = HTTPException( - 
status_code=status.HTTP_404_NOT_FOUND, detail="Task not found" - ) - - try: - task = get_task_query(developer_id=x_developer_id, task_id=task_id) - task_data = task.model_dump() - except AssertionError: - raise not_found - except QueryException as e: - if e.code == "transact::assertion_failure": - raise not_found - - raise + task = await get_task_query(developer_id=x_developer_id, task_id=task_id) + task_data = task.model_dump() for workflow in task_data.get("workflows", []): if workflow["name"] == "main": diff --git a/agents-api/agents_api/routers/tasks/list_execution_transitions.py b/agents-api/agents_api/routers/tasks/list_execution_transitions.py index 9ce169509..c4e075184 100644 --- a/agents-api/agents_api/routers/tasks/list_execution_transitions.py +++ b/agents-api/agents_api/routers/tasks/list_execution_transitions.py @@ -1,11 +1,13 @@ from typing import Literal from uuid import UUID +from fastapi import HTTPException, status + from ...autogen.openapi_model import ( ListResponse, Transition, ) -from ...models.execution.list_execution_transitions import ( +from ...queries.executions.list_execution_transitions import ( list_execution_transitions as list_execution_transitions_query, ) from .router import router @@ -19,7 +21,7 @@ async def list_execution_transitions( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> ListResponse[Transition]: - transitions = list_execution_transitions_query( + transitions = await list_execution_transitions_query( execution_id=execution_id, limit=limit, offset=offset, @@ -30,22 +32,21 @@ async def list_execution_transitions( return ListResponse[Transition](items=transitions) -# TODO: Do we need this? 
-# @router.get("/executions/{execution_id}/transitions/{transition_id}", tags=["tasks"]) -# async def get_execution_transition( -# execution_id: UUID, -# transition_id: UUID, -# ) -> Transition: -# try: -# res = [ -# row.to_dict() -# for _, row in get_execution_transition_query( -# execution_id, transition_id -# ).iterrows() -# ][0] -# return Transition(**res) -# except (IndexError, KeyError): -# raise HTTPException( -# status_code=status.HTTP_404_NOT_FOUND, -# detail="Transition not found", -# ) +@router.get("/executions/{execution_id}/transitions/{transition_id}", tags=["tasks"]) +async def get_execution_transition( + execution_id: UUID, + transition_id: UUID, +) -> Transition: + try: + transitions = await list_execution_transitions_query( + execution_id=execution_id, + transition_id=transition_id, + ) + if not transitions: + raise IndexError + return transitions[0] + except (IndexError, KeyError): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Transition not found", + ) diff --git a/agents-api/agents_api/routers/tasks/list_task_executions.py b/agents-api/agents_api/routers/tasks/list_task_executions.py index 72cbd9b40..17256f038 100644 --- a/agents-api/agents_api/routers/tasks/list_task_executions.py +++ b/agents-api/agents_api/routers/tasks/list_task_executions.py @@ -8,7 +8,7 @@ ListResponse, ) from ...dependencies.developer_id import get_developer_id -from ...models.execution.list_executions import ( +from ...queries.executions.list_executions import ( list_executions as list_task_executions_query, ) from .router import router @@ -23,7 +23,7 @@ async def list_task_executions( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> ListResponse[Execution]: - executions = list_task_executions_query( + executions = await list_task_executions_query( task_id=task_id, developer_id=x_developer_id, limit=limit, diff --git a/agents-api/agents_api/routers/tasks/list_tasks.py 
b/agents-api/agents_api/routers/tasks/list_tasks.py index a53983006..529700c09 100644 --- a/agents-api/agents_api/routers/tasks/list_tasks.py +++ b/agents-api/agents_api/routers/tasks/list_tasks.py @@ -8,7 +8,7 @@ Task, ) from ...dependencies.developer_id import get_developer_id -from ...models.task.list_tasks import list_tasks as list_tasks_query +from ...queries.tasks.list_tasks import list_tasks as list_tasks_query from .router import router @@ -21,7 +21,7 @@ async def list_tasks( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> ListResponse[Task]: - query_results = list_tasks_query( + query_results = await list_tasks_query( agent_id=agent_id, developer_id=x_developer_id, limit=limit, diff --git a/agents-api/agents_api/routers/tasks/patch_execution.py b/agents-api/agents_api/routers/tasks/patch_execution.py deleted file mode 100644 index 3cc45ee37..000000000 --- a/agents-api/agents_api/routers/tasks/patch_execution.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Annotated -from uuid import UUID - -from fastapi import Depends - -from ...autogen.openapi_model import ( - ResourceUpdatedResponse, - UpdateExecutionRequest, -) -from ...dependencies.developer_id import get_developer_id -from ...models.execution.update_execution import ( - update_execution as update_execution_query, -) -from .router import router - - -@router.patch("/tasks/{task_id}/executions/{execution_id}", tags=["tasks"]) -async def patch_execution( - x_developer_id: Annotated[UUID, Depends(get_developer_id)], - task_id: UUID, - execution_id: UUID, - data: UpdateExecutionRequest, -) -> ResourceUpdatedResponse: - return update_execution_query( - developer_id=x_developer_id, - task_id=task_id, - execution_id=execution_id, - data=data, - ) diff --git a/agents-api/agents_api/routers/tasks/router.py b/agents-api/agents_api/routers/tasks/router.py index 101dcb228..0cecf572e 100644 --- a/agents-api/agents_api/routers/tasks/router.py +++ 
b/agents-api/agents_api/routers/tasks/router.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from fastapi import APIRouter, Request, Response from fastapi.routing import APIRoute diff --git a/agents-api/agents_api/routers/tasks/stream_transitions_events.py b/agents-api/agents_api/routers/tasks/stream_transitions_events.py index 37500b0d6..92633bf08 100644 --- a/agents-api/agents_api/routers/tasks/stream_transitions_events.py +++ b/agents-api/agents_api/routers/tasks/stream_transitions_events.py @@ -18,7 +18,7 @@ from ...autogen.openapi_model import TransitionEvent from ...clients.temporal import get_workflow_handle from ...dependencies.developer_id import get_developer_id -from ...models.execution.lookup_temporal_data import lookup_temporal_data +from ...queries.executions.lookup_temporal_data import lookup_temporal_data from ...worker.codec import from_payload_data from .router import router @@ -36,9 +36,7 @@ async def event_publisher( async for event in history_events: # TODO: We should get the workflow-completed event as well and use that to close the stream if event.event_type == EventType.EVENT_TYPE_ACTIVITY_TASK_COMPLETED: - payloads = ( - event.activity_task_completed_event_attributes.result.payloads - ) + payloads = event.activity_task_completed_event_attributes.result.payloads for payload in payloads: try: @@ -52,11 +50,11 @@ async def event_publisher( continue # FIXME: This does NOT return the last event (and maybe other events) - transition_event_dict = dict( - type=data_item.type, - output=data_item.output, - created_at=data_item.created_at.isoformat(), - ) + transition_event_dict = { + "type": data_item.type, + "output": data_item.output, + "created_at": data_item.created_at.isoformat(), + } next_page_token = ( b64encode(history_events.next_page_token).decode("ascii") @@ -64,18 +62,16 @@ async def event_publisher( else None ) - await inner_send_chan.send( - dict( - data=dict( - transition=transition_event_dict, - 
next_page_token=next_page_token, - ), - ) - ) + await inner_send_chan.send({ + "data": { + "transition": transition_event_dict, + "next_page_token": next_page_token, + }, + }) except anyio.get_cancelled_exc_class() as e: with anyio.move_on_after(STREAM_TIMEOUT, shield=True): - await inner_send_chan.send(dict(closing=True)) + await inner_send_chan.send({"closing": True}) raise e @@ -87,7 +83,7 @@ async def stream_transitions_events( next_page_token: Annotated[str | None, Query()] = None, ): # Get temporal id - temporal_data = lookup_temporal_data( + temporal_data = await lookup_temporal_data( developer_id=x_developer_id, execution_id=execution_id, ) @@ -98,9 +94,7 @@ async def stream_transitions_events( handle_id=temporal_data["id"], ) - next_page_token: bytes | None = ( - b64decode(next_page_token) if next_page_token else None - ) + next_page_token: bytes | None = b64decode(next_page_token) if next_page_token else None history_events = workflow_handle.fetch_history_events( page_size=1, diff --git a/agents-api/agents_api/routers/tasks/update_execution.py b/agents-api/agents_api/routers/tasks/update_execution.py index e88c36ed9..b363c06ce 100644 --- a/agents-api/agents_api/routers/tasks/update_execution.py +++ b/agents-api/agents_api/routers/tasks/update_execution.py @@ -10,10 +10,10 @@ ) from ...clients.temporal import get_client from ...dependencies.developer_id import get_developer_id -from ...models.execution.get_paused_execution_token import ( +from ...queries.executions.get_paused_execution_token import ( get_paused_execution_token, ) -from ...models.execution.get_temporal_workflow_data import ( +from ...queries.executions.get_temporal_workflow_data import ( get_temporal_workflow_data, ) from .router import router @@ -31,24 +31,20 @@ async def update_execution( case StopExecutionRequest(): try: wf_handle = temporal_client.get_workflow_handle_for( - *get_temporal_workflow_data(execution_id=execution_id) + *await 
get_temporal_workflow_data(execution_id=execution_id) ) await wf_handle.cancel() except Exception: raise HTTPException(status_code=500, detail="Failed to stop execution") case ResumeExecutionRequest(): - token_data = get_paused_execution_token( - developer_id=x_developer_id, execution_id=execution_id - ) + token_data = await get_paused_execution_token(execution_id=execution_id) activity_id = token_data["metadata"].get("x-activity-id", None) run_id = token_data["metadata"].get("x-run-id", None) workflow_id = token_data["metadata"].get("x-workflow-id", None) if activity_id is None or run_id is None or workflow_id is None: act_handle = temporal_client.get_async_activity_handle( - task_token=base64.b64decode( - token_data["task_token"].encode("ascii") - ), + task_token=base64.b64decode(token_data["task_token"].encode("ascii")), ) else: @@ -60,8 +56,6 @@ async def update_execution( try: await act_handle.complete(data.input) except Exception: - raise HTTPException( - status_code=500, detail="Failed to resume execution" - ) + raise HTTPException(status_code=500, detail="Failed to resume execution") case _: raise HTTPException(status_code=400, detail="Invalid request data") diff --git a/agents-api/agents_api/routers/users/create_or_update_user.py b/agents-api/agents_api/routers/users/create_or_update_user.py index 0141983c9..0a1f9db37 100644 --- a/agents-api/agents_api/routers/users/create_or_update_user.py +++ b/agents-api/agents_api/routers/users/create_or_update_user.py @@ -6,7 +6,7 @@ from ...autogen.openapi_model import CreateOrUpdateUserRequest, ResourceCreatedResponse from ...dependencies.developer_id import get_developer_id -from ...models.user.create_or_update_user import ( +from ...queries.users.create_or_update_user import ( create_or_update_user as create_or_update_user_query, ) from .router import router @@ -18,7 +18,7 @@ async def create_or_update_user( user_id: UUID, data: CreateOrUpdateUserRequest, ) -> ResourceCreatedResponse: - user = 
create_or_update_user_query( + user = await create_or_update_user_query( developer_id=x_developer_id, user_id=user_id, data=data, diff --git a/agents-api/agents_api/routers/users/create_user.py b/agents-api/agents_api/routers/users/create_user.py index 4724a77b4..1ac42bc36 100644 --- a/agents-api/agents_api/routers/users/create_user.py +++ b/agents-api/agents_api/routers/users/create_user.py @@ -6,7 +6,7 @@ from ...autogen.openapi_model import CreateUserRequest, ResourceCreatedResponse from ...dependencies.developer_id import get_developer_id -from ...models.user.create_user import create_user as create_user_query +from ...queries.users.create_user import create_user as create_user_query from .router import router @@ -15,7 +15,7 @@ async def create_user( data: CreateUserRequest, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> ResourceCreatedResponse: - user = create_user_query( + user = await create_user_query( developer_id=x_developer_id, data=data, ) diff --git a/agents-api/agents_api/routers/users/delete_user.py b/agents-api/agents_api/routers/users/delete_user.py index d9d8032e7..bbc7f8736 100644 --- a/agents-api/agents_api/routers/users/delete_user.py +++ b/agents-api/agents_api/routers/users/delete_user.py @@ -6,7 +6,7 @@ from ...autogen.openapi_model import ResourceDeletedResponse from ...dependencies.developer_id import get_developer_id -from ...models.user.delete_user import delete_user as delete_user_query +from ...queries.users.delete_user import delete_user as delete_user_query from .router import router @@ -14,4 +14,4 @@ async def delete_user( user_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)] ) -> ResourceDeletedResponse: - return delete_user_query(developer_id=x_developer_id, user_id=user_id) + return await delete_user_query(developer_id=x_developer_id, user_id=user_id) diff --git a/agents-api/agents_api/routers/users/get_user_details.py b/agents-api/agents_api/routers/users/get_user_details.py index 
71a26c2dc..4a219869c 100644 --- a/agents-api/agents_api/routers/users/get_user_details.py +++ b/agents-api/agents_api/routers/users/get_user_details.py @@ -5,7 +5,7 @@ from ...autogen.openapi_model import User from ...dependencies.developer_id import get_developer_id -from ...models.user.get_user import get_user as get_user_query +from ...queries.users.get_user import get_user as get_user_query from .router import router @@ -14,4 +14,4 @@ async def get_user_details( x_developer_id: Annotated[UUID, Depends(get_developer_id)], user_id: UUID, ) -> User: - return get_user_query(developer_id=x_developer_id, user_id=user_id) + return await get_user_query(developer_id=x_developer_id, user_id=user_id) diff --git a/agents-api/agents_api/routers/users/list_users.py b/agents-api/agents_api/routers/users/list_users.py index 926699d40..4c027bbd3 100644 --- a/agents-api/agents_api/routers/users/list_users.py +++ b/agents-api/agents_api/routers/users/list_users.py @@ -6,7 +6,7 @@ from ...autogen.openapi_model import ListResponse, User from ...dependencies.developer_id import get_developer_id from ...dependencies.query_filter import MetadataFilter, create_filter_extractor -from ...models.user.list_users import list_users as list_users_query +from ...queries.users.list_users import list_users as list_users_query from .router import router @@ -21,7 +21,7 @@ async def list_users( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> ListResponse[User]: - users = list_users_query( + users = await list_users_query( developer_id=x_developer_id, limit=limit, offset=offset, diff --git a/agents-api/agents_api/routers/users/patch_user.py b/agents-api/agents_api/routers/users/patch_user.py index 8a49aaf93..03cd9bcfe 100644 --- a/agents-api/agents_api/routers/users/patch_user.py +++ b/agents-api/agents_api/routers/users/patch_user.py @@ -5,7 +5,7 @@ from ...autogen.openapi_model import PatchUserRequest, ResourceUpdatedResponse from 
...dependencies.developer_id import get_developer_id -from ...models.user.patch_user import patch_user as patch_user_query +from ...queries.users.patch_user import patch_user as patch_user_query from .router import router @@ -15,7 +15,7 @@ async def patch_user( data: PatchUserRequest, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> ResourceUpdatedResponse: - return patch_user_query( + return await patch_user_query( developer_id=x_developer_id, user_id=user_id, data=data, diff --git a/agents-api/agents_api/routers/users/update_user.py b/agents-api/agents_api/routers/users/update_user.py index d9104da73..8071657d7 100644 --- a/agents-api/agents_api/routers/users/update_user.py +++ b/agents-api/agents_api/routers/users/update_user.py @@ -5,7 +5,7 @@ from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateUserRequest from ...dependencies.developer_id import get_developer_id -from ...models.user.update_user import update_user as update_user_query +from ...queries.users.update_user import update_user as update_user_query from .router import router @@ -15,7 +15,7 @@ async def update_user( data: UpdateUserRequest, x_developer_id: Annotated[UUID, Depends(get_developer_id)], ) -> ResourceUpdatedResponse: - return update_user_query( + return await update_user_query( developer_id=x_developer_id, user_id=user_id, data=data, diff --git a/agents-api/agents_api/web.py b/agents-api/agents_api/web.py index 8e2e7da54..ae27cdaf8 100644 --- a/agents-api/agents_api/web.py +++ b/agents-api/agents_api/web.py @@ -4,30 +4,30 @@ import asyncio import logging -from typing import Any, Callable, Union, cast +from collections.abc import Callable +from typing import Any, cast import sentry_sdk import uvicorn import uvloop -from fastapi import APIRouter, Depends, FastAPI, Request, status +from fastapi import Depends, FastAPI, Request, status from fastapi.exceptions import HTTPException, RequestValidationError from fastapi.middleware.cors import CORSMiddleware from 
fastapi.responses import JSONResponse from litellm.exceptions import APIError -from prometheus_fastapi_instrumentator import Instrumentator -from pycozo.client import QueryException from pydantic import ValidationError -from scalar_fastapi import get_scalar_api_reference from temporalio.service import RPCError +from .app import app from .common.exceptions import BaseCommonException from .dependencies.auth import get_api_key -from .env import api_prefix, hostname, protocol, public_port, sentry_dsn +from .env import sentry_dsn from .exceptions import PromptTooBigError from .routers import ( agents, docs, files, + healthz, internal, jobs, sessions, @@ -64,8 +64,8 @@ async def _handler(request: Request, exc: Exception): offending_input = None # Return the deepest matching possibility - if isinstance(exc, (ValidationError, RequestValidationError)): - exc = cast(Union[ValidationError, RequestValidationError], exc) + if isinstance(exc, ValidationError | RequestValidationError): + exc = cast(ValidationError | RequestValidationError, exc) errors = exc.errors() # Get the deepest matching errors @@ -91,9 +91,7 @@ async def _handler(request: Request, exc: Exception): if loc not in offending_input: break case list(): - if not ( - isinstance(loc, int) and 0 <= loc < len(offending_input) - ): + if not (isinstance(loc, int) and 0 <= loc < len(offending_input)): break case _: break @@ -134,50 +132,16 @@ def register_exceptions(app: FastAPI) -> None: RequestValidationError, make_exception_handler(status.HTTP_422_UNPROCESSABLE_ENTITY), ) - app.add_exception_handler( - QueryException, - make_exception_handler(status.HTTP_500_INTERNAL_SERVER_ERROR), - ) + # app.add_exception_handler( + # QueryException, + # make_exception_handler(status.HTTP_500_INTERNAL_SERVER_ERROR), + # ) # TODO: Auth logic should be moved into global middleware _per router_ # Because some routes don't require auth # See: https://fastapi.tiangolo.com/tutorial/bigger-applications/ # -app: FastAPI = FastAPI( - 
docs_url="/swagger", - openapi_prefix=api_prefix, - redoc_url=None, - title="Julep Agents API", - description="API for Julep Agents", - version="0.4.0", - terms_of_service="https://www.julep.ai/terms", - contact={ - "name": "Julep", - "url": "https://www.julep.ai", - "email": "team@julep.ai", - }, - root_path=api_prefix, -) - -# Enable metrics -Instrumentator().instrument(app).expose(app, include_in_schema=False) - -# Create a new router for the docs -scalar_router = APIRouter() - - -@scalar_router.get("/docs", include_in_schema=False) -async def scalar_html(): - return get_scalar_api_reference( - openapi_url=app.openapi_url[1:], # Remove leading '/' - title=app.title, - servers=[{"url": f"{protocol}://{hostname}:{public_port}{api_prefix}"}], - ) - - -# Add the docs_router without dependencies -app.include_router(scalar_router) # Add other routers with the get_api_key dependency app.include_router(agents.router, dependencies=[Depends(get_api_key)]) @@ -188,6 +152,7 @@ async def scalar_html(): app.include_router(docs.router, dependencies=[Depends(get_api_key)]) app.include_router(tasks.router, dependencies=[Depends(get_api_key)]) app.include_router(internal.router) +app.include_router(healthz.router) # TODO: CORS should be enabled only for JWT auth # @@ -218,9 +183,7 @@ async def http_exception_handler(request, exc: HTTPException): # pylint: disabl async def validation_error_handler(request: Request, exc: RPCError): return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content={ - "error": {"message": "job not found or invalid", "code": exc.status.name} - }, + content={"error": {"message": "job not found or invalid", "code": exc.status.name}}, ) diff --git a/agents-api/agents_api/worker/__main__.py b/agents-api/agents_api/worker/__main__.py index 0c419a0d0..07df1d4bf 100644 --- a/agents-api/agents_api/worker/__main__.py +++ b/agents-api/agents_api/worker/__main__.py @@ -3,7 +3,7 @@ It supports various workflows and activities related to agents' operations. 
""" -#!/usr/bin/env python3 +# !/usr/bin/env python3 import asyncio import logging diff --git a/agents-api/agents_api/worker/codec.py b/agents-api/agents_api/worker/codec.py index 8f213bc34..94a2d89b7 100644 --- a/agents-api/agents_api/worker/codec.py +++ b/agents-api/agents_api/worker/codec.py @@ -1,20 +1,17 @@ ### -### NOTE: Working with temporal's codec is really really weird -### This is a workaround to use pydantic models with temporal -### The codec is used to serialize/deserialize the data -### But this code is quite brittle. Be careful when changing it +# NOTE: Working with temporal's codec is really really weird +# This is a workaround to use pydantic models with temporal +# The codec is used to serialize/deserialize the data +# But this code is quite brittle. Be careful when changing it import dataclasses import logging import sys import time -from typing import Any, Optional, Type +from typing import Any import larch.pickle as pickle import temporalio.converter - -# from beartype import BeartypeConf -# from beartype.door import is_bearable, is_subhint from lz4.frame import compress, decompress from temporalio import workflow from temporalio.api.common.v1 import Payload @@ -57,7 +54,7 @@ def deserialize(b: bytes) -> Any: return object -def from_payload_data(data: bytes, type_hint: Optional[Type] = None) -> Any: +def from_payload_data(data: bytes, type_hint: type | None = None) -> Any: decoded = deserialize(data) if type_hint is None: @@ -65,54 +62,24 @@ def from_payload_data(data: bytes, type_hint: Optional[Type] = None) -> Any: decoded_type = type(decoded) - # TODO: Enable this check when temporal's codec stuff is fixed - # - # # Otherwise, check if the decoded value is bearable to the type hint - # if not is_bearable( - # decoded, - # type_hint, - # conf=BeartypeConf( - # is_pep484_tower=True - # ), # Check PEP 484 type hints. 
(be more lax on numeric types) - # ): - # logging.warning( - # f"WARNING: Decoded value {decoded_type} is not bearable to {type_hint}" - # ) - - # TODO: Enable this check when temporal's codec stuff is fixed - # - # If the decoded value is a BaseModel and the type hint is a subclass of BaseModel - # and the decoded value's class is a subclass of the type hint, then promote the decoded value - # to the type hint. if ( type_hint != decoded_type and hasattr(type_hint, "model_construct") and hasattr(decoded, "model_dump") - # - # TODO: Enable this check when temporal's codec stuff is fixed - # - # and is_subhint(type_hint, decoded_type) ): try: decoded = type_hint(**decoded.model_dump()) except Exception as e: - logging.warning( - f"WARNING: Could not promote {decoded_type} to {type_hint}: {e}" - ) + logging.warning(f"WARNING: Could not promote {decoded_type} to {type_hint}: {e}") return decoded -# TODO: Create a codec server for temporal to use for debugging -# SCRUM-12 -# This will allow us to see the data in the workflow history -# See: https://github.com/temporalio/samples-python/blob/main/encryption/codec_server.py -# https://docs.temporal.io/production-deployment/data-encryption#web-ui class PydanticEncodingPayloadConverter(EncodingPayloadConverter): encoding = "text/pickle+lz4" b_encoding = encoding.encode() - def to_payload(self, value: Any) -> Optional[Payload]: + def to_payload(self, value: Any) -> Payload | None: python_version = f"{sys.version_info.major}.{sys.version_info.minor}".encode() try: @@ -137,10 +104,8 @@ def to_payload(self, value: Any) -> Optional[Payload]: error_bytes = str(value).encode("utf-8") return FailedEncodingSentinel(payload_data=error_bytes) - def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any: - current_python_version = ( - f"{sys.version_info.major}.{sys.version_info.minor}".encode() - ) + def from_payload(self, payload: Payload, type_hint: type | None = None) -> Any: + current_python_version = 
f"{sys.version_info.major}.{sys.version_info.minor}".encode() # Check if this is a payload we can handle if ( diff --git a/agents-api/agents_api/worker/worker.py b/agents-api/agents_api/worker/worker.py index 39eff2b54..e9c5fe78c 100644 --- a/agents-api/agents_api/worker/worker.py +++ b/agents-api/agents_api/worker/worker.py @@ -21,54 +21,34 @@ def create_worker(client: Client) -> Any: from ..activities import task_steps from ..activities.demo import demo_activity - from ..activities.embed_docs import embed_docs from ..activities.excecute_api_call import execute_api_call from ..activities.execute_integration import execute_integration from ..activities.execute_system import execute_system - from ..activities.mem_mgmt import mem_mgmt - from ..activities.mem_rating import mem_rating - from ..activities.summarization import summarization from ..activities.sync_items_remote import load_inputs_remote, save_inputs_remote - from ..activities.truncation import truncation from ..common.interceptors import CustomInterceptor from ..env import ( temporal_task_queue, ) from ..workflows.demo import DemoWorkflow - from ..workflows.embed_docs import EmbedDocsWorkflow - from ..workflows.mem_mgmt import MemMgmtWorkflow - from ..workflows.mem_rating import MemRatingWorkflow - from ..workflows.summarization import SummarizationWorkflow from ..workflows.task_execution import TaskExecutionWorkflow - from ..workflows.truncation import TruncationWorkflow - task_activity_names, task_activities = zip(*getmembers(task_steps, isfunction)) + _task_activity_names, task_activities = zip(*getmembers(task_steps, isfunction)) # Initialize the worker with the specified task queue, workflows, and activities - worker = Worker( + return Worker( client, graceful_shutdown_timeout=timedelta(seconds=30), task_queue=temporal_task_queue, workflows=[ DemoWorkflow, - SummarizationWorkflow, - MemMgmtWorkflow, - MemRatingWorkflow, - EmbedDocsWorkflow, TaskExecutionWorkflow, - TruncationWorkflow, ], activities=[ 
*task_activities, demo_activity, - embed_docs, execute_integration, execute_system, execute_api_call, - mem_mgmt, - mem_rating, - summarization, - truncation, save_inputs_remote, load_inputs_remote, ], @@ -78,5 +58,3 @@ def create_worker(client: Client) -> Any: max_activities_per_second=temporal_max_activities_per_second, max_task_queue_activities_per_second=temporal_max_task_queue_activities_per_second, ) - - return worker diff --git a/agents-api/agents_api/workflows/embed_docs.py b/agents-api/agents_api/workflows/embed_docs.py deleted file mode 100644 index 9e7b43d79..000000000 --- a/agents-api/agents_api/workflows/embed_docs.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python3 - - -from datetime import timedelta - -from temporalio import workflow - -with workflow.unsafe.imports_passed_through(): - from ..activities.embed_docs import embed_docs - from ..activities.types import EmbedDocsPayload - from ..common.retry_policies import DEFAULT_RETRY_POLICY - from ..env import temporal_heartbeat_timeout, temporal_schedule_to_close_timeout - - -@workflow.defn -class EmbedDocsWorkflow: - @workflow.run - async def run(self, embed_payload: EmbedDocsPayload) -> None: - await workflow.execute_activity( - embed_docs, - embed_payload, - schedule_to_close_timeout=timedelta( - seconds=temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) diff --git a/agents-api/agents_api/workflows/mem_mgmt.py b/agents-api/agents_api/workflows/mem_mgmt.py deleted file mode 100644 index 1e945a7c4..000000000 --- a/agents-api/agents_api/workflows/mem_mgmt.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python3 - - -from datetime import timedelta - -from temporalio import workflow - -with workflow.unsafe.imports_passed_through(): - from ..activities.mem_mgmt import mem_mgmt - from ..autogen.openapi_model import InputChatMLMessage - from ..env import temporal_heartbeat_timeout, 
temporal_schedule_to_close_timeout - - -@workflow.defn -class MemMgmtWorkflow: - @workflow.run - async def run( - self, - dialog: list[InputChatMLMessage], - session_id: str, - previous_memories: list[str], - ) -> None: - return await workflow.execute_activity( - mem_mgmt, - [dialog, session_id, previous_memories], - schedule_to_close_timeout=timedelta( - seconds=temporal_schedule_to_close_timeout - ), - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) diff --git a/agents-api/agents_api/workflows/mem_rating.py b/agents-api/agents_api/workflows/mem_rating.py deleted file mode 100644 index 2846c0b97..000000000 --- a/agents-api/agents_api/workflows/mem_rating.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 - - -from datetime import timedelta - -from temporalio import workflow - -with workflow.unsafe.imports_passed_through(): - from ..activities.mem_rating import mem_rating - from ..common.retry_policies import DEFAULT_RETRY_POLICY - from ..env import temporal_heartbeat_timeout, temporal_schedule_to_close_timeout - - -@workflow.defn -class MemRatingWorkflow: - @workflow.run - async def run(self, memory: str) -> None: - return await workflow.execute_activity( - mem_rating, - memory, - schedule_to_close_timeout=timedelta( - seconds=temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) diff --git a/agents-api/agents_api/workflows/summarization.py b/agents-api/agents_api/workflows/summarization.py deleted file mode 100644 index 9338763da..000000000 --- a/agents-api/agents_api/workflows/summarization.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 - - -from datetime import timedelta - -from temporalio import workflow - -with workflow.unsafe.imports_passed_through(): - from ..activities.summarization import summarization - from ..common.retry_policies import DEFAULT_RETRY_POLICY - from ..env import temporal_heartbeat_timeout, 
temporal_schedule_to_close_timeout - - -@workflow.defn -class SummarizationWorkflow: - @workflow.run - async def run(self, session_id: str) -> None: - return await workflow.execute_activity( - summarization, - session_id, - schedule_to_close_timeout=timedelta( - seconds=temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 6ea9239df..245f7c9a7 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -15,7 +15,7 @@ from ...activities.excecute_api_call import execute_api_call from ...activities.execute_integration import execute_integration from ...activities.execute_system import execute_system - from ...activities.sync_items_remote import load_inputs_remote, save_inputs_remote + from ...activities.sync_items_remote import save_inputs_remote from ...autogen.openapi_model import ( ApiCallDef, BaseIntegrationDef, @@ -140,18 +140,15 @@ async def set_last_error(self, value: LastErrorInput): async def run( self, execution_input: ExecutionInput, - start: TransitionTarget = TransitionTarget(workflow="main", step=0), - previous_inputs: list | None = None, + start: TransitionTarget, + previous_inputs: list, ) -> Any: workflow.logger.info( f"TaskExecutionWorkflow for task {execution_input.task.id}" f" [LOC {start.workflow}.{start.step}]" ) - # FIXME: Look into saving arguments to the blob store if necessary # 0. 
Prepare context - previous_inputs = previous_inputs or [execution_input.arguments] - context = StepContext( execution_input=execution_input, inputs=previous_inputs, @@ -193,37 +190,24 @@ async def run( context, # schedule_to_close_timeout=timedelta( - seconds=30 - if debug or testing - else temporal_schedule_to_close_timeout + seconds=30 if debug or testing else temporal_schedule_to_close_timeout ), retry_policy=DEFAULT_RETRY_POLICY, heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), ) - workflow.logger.debug( - f"Step {context.cursor.step} completed successfully" - ) + workflow.logger.debug(f"Step {context.cursor.step} completed successfully") except Exception as e: - workflow.logger.error(f"Error in step {context.cursor.step}: {str(e)}") + workflow.logger.error(f"Error in step {context.cursor.step}: {e!s}") await transition(context, type="error", output=str(e)) - raise ApplicationError(f"Activity {activity} threw error: {e}") from e + msg = f"Activity {activity} threw error: {e}" + raise ApplicationError(msg) from e # --- # 3. 
Then, based on the outcome and step type, decide what to do next workflow.logger.info(f"Processing outcome for step {context.cursor.step}") - [outcome] = await workflow.execute_activity( - load_inputs_remote, - args=[[outcome]], - schedule_to_close_timeout=timedelta( - seconds=60 if debug or testing else temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) - # Init state state = None @@ -232,9 +216,8 @@ async def run( case step, StepOutcome(error=error) if error is not None: workflow.logger.error(f"Error in step {context.cursor.step}: {error}") await transition(context, type="error", output=error) - raise ApplicationError( - f"Step {type(step).__name__} threw error: {error}" - ) + msg = f"Step {type(step).__name__} threw error: {error}" + raise ApplicationError(msg) case LogStep(), StepOutcome(output=log): workflow.logger.info(f"Log step: {log}") @@ -273,7 +256,8 @@ async def run( case SwitchStep(), StepOutcome(output=index) if index < 0: workflow.logger.error("Switch step: Invalid negative index") - raise ApplicationError("Negative indices not allowed") + msg = "Negative indices not allowed" + raise ApplicationError(msg) case IfElseWorkflowStep(then=then_branch, else_=else_branch), StepOutcome( output=condition @@ -336,17 +320,11 @@ async def run( days=days, ) ), _: - total_seconds = ( - seconds + minutes * 60 + hours * 60 * 60 + days * 24 * 60 * 60 - ) - workflow.logger.info( - f"Sleep step: Sleeping for {total_seconds} seconds" - ) + total_seconds = seconds + minutes * 60 + hours * 60 * 60 + days * 24 * 60 * 60 + workflow.logger.info(f"Sleep step: Sleeping for {total_seconds} seconds") assert total_seconds > 0, "Sleep duration must be greater than 0" - result = await asyncio.sleep( - total_seconds, result=context.current_input - ) + result = await asyncio.sleep(total_seconds, result=context.current_input) state = PartialTransition(output=result) @@ -366,14 +344,13 @@ 
async def run( last_error=self.last_error, ) - raise ApplicationError(f"Error raised by ErrorWorkflowStep: {error}") + msg = f"Error raised by ErrorWorkflowStep: {error}" + raise ApplicationError(msg) case YieldStep(), StepOutcome( output=output, transition_to=(yield_transition_type, yield_next_target) ): - workflow.logger.info( - f"Yield step: Transitioning to {yield_transition_type}" - ) + workflow.logger.info(f"Yield step: Transitioning to {yield_transition_type}") await transition( context, output=output, @@ -407,19 +384,17 @@ async def run( workflow.logger.debug(f"Prompt step: Received response: {message}") state = PartialTransition(output=message) - case PromptStep(auto_run_tools=False, unwrap=False), StepOutcome( - output=response - ): + case PromptStep(auto_run_tools=False, unwrap=False), StepOutcome(output=response): workflow.logger.debug(f"Prompt step: Received response: {response}") state = PartialTransition(output=response) - case PromptStep(unwrap=False), StepOutcome(output=response) if response[ - "choices" - ][0]["finish_reason"] != "tool_calls": + case PromptStep(unwrap=False), StepOutcome(output=response) if ( + response["choices"][0]["finish_reason"] != "tool_calls" + ): workflow.logger.debug(f"Prompt step: Received response: {response}") state = PartialTransition(output=response) - ## TODO: Handle multiple tool calls and multiple choices + # TODO: Handle multiple tool calls and multiple choices # case PromptStep(unwrap=False), StepOutcome(output=response) if response[ # "choices" # ][0]["finish_reason"] == "tool_calls": @@ -429,11 +404,9 @@ async def run( case PromptStep(auto_run_tools=True, unwrap=False), StepOutcome( output=response - ) if (choice := response["choices"][0])[ - "finish_reason" - ] == "tool_calls" and (tool_calls_input := choice["message"]["tool_calls"])[ - 0 - ]["type"] not in ["integration", "api_call", "system"]: + ) if (choice := response["choices"][0])["finish_reason"] == "tool_calls" and ( + tool_calls_input := 
choice["message"]["tool_calls"] + )[0]["type"] not in ["integration", "api_call", "system"]: workflow.logger.debug("Prompt step: Received FUNCTION tool call") # Enter a wait-for-input step to ask the developer to run the tool calls @@ -452,9 +425,7 @@ async def run( task_steps.prompt_step, context, schedule_to_close_timeout=timedelta( - seconds=30 - if debug or testing - else temporal_schedule_to_close_timeout + seconds=30 if debug or testing else temporal_schedule_to_close_timeout ), retry_policy=DEFAULT_RETRY_POLICY, heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), @@ -463,46 +434,43 @@ async def run( case PromptStep(auto_run_tools=True, unwrap=False), StepOutcome( output=response - ) if (choice := response["choices"][0])[ - "finish_reason" - ] == "tool_calls" and (tool_calls_input := choice["message"]["tool_calls"])[ - 0 - ]["type"] == "integration": + ) if (choice := response["choices"][0])["finish_reason"] == "tool_calls" and ( + tool_calls_input := choice["message"]["tool_calls"] + )[0]["type"] == "integration": workflow.logger.debug("Prompt step: Received INTEGRATION tool call") # FIXME: Implement integration tool calls # See: MANUAL TOOL CALL INTEGRATION (below) - raise NotImplementedError("Integration tool calls not yet supported") + msg = "Integration tool calls not yet supported" + raise NotImplementedError(msg) # TODO: Feed the tool call results back to the model (see above) case PromptStep(auto_run_tools=True, unwrap=False), StepOutcome( output=response - ) if (choice := response["choices"][0])[ - "finish_reason" - ] == "tool_calls" and (tool_calls_input := choice["message"]["tool_calls"])[ - 0 - ]["type"] == "api_call": + ) if (choice := response["choices"][0])["finish_reason"] == "tool_calls" and ( + tool_calls_input := choice["message"]["tool_calls"] + )[0]["type"] == "api_call": workflow.logger.debug("Prompt step: Received API_CALL tool call") # FIXME: Implement API_CALL tool calls # See: MANUAL TOOL CALL API_CALL (below) - raise 
NotImplementedError("API_CALL tool calls not yet supported") + msg = "API_CALL tool calls not yet supported" + raise NotImplementedError(msg) # TODO: Feed the tool call results back to the model (see above) case PromptStep(auto_run_tools=True, unwrap=False), StepOutcome( output=response - ) if (choice := response["choices"][0])[ - "finish_reason" - ] == "tool_calls" and (tool_calls_input := choice["message"]["tool_calls"])[ - 0 - ]["type"] == "system": + ) if (choice := response["choices"][0])["finish_reason"] == "tool_calls" and ( + tool_calls_input := choice["message"]["tool_calls"] + )[0]["type"] == "system": workflow.logger.debug("Prompt step: Received SYSTEM tool call") # FIXME: Implement SYSTEM tool calls # See: MANUAL TOOL CALL SYSTEM (below) - raise NotImplementedError("SYSTEM tool calls not yet supported") + msg = "SYSTEM tool calls not yet supported" + raise NotImplementedError(msg) # TODO: Feed the tool call results back to the model (see above) @@ -535,11 +503,12 @@ async def run( # FIXME: Implement ParallelStep # SCRUM-17 workflow.logger.error("ParallelStep not yet implemented") - raise ApplicationError("Not implemented") + msg = "Not implemented" + raise ApplicationError(msg) - case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[ - "type" - ] == "function": + case ToolCallStep(), StepOutcome(output=tool_call) if ( + tool_call["type"] == "function" + ): # Enter a wait-for-input step to ask the developer to run the tool calls tool_call_response = await workflow.execute_activity( task_steps.raise_complete_async, @@ -551,20 +520,19 @@ async def run( state = PartialTransition(output=tool_call_response, type="resume") - case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[ - "type" - ] == "integration": + case ToolCallStep(), StepOutcome(output=tool_call) if ( + tool_call["type"] == "integration" + ): # MANUAL TOOL CALL INTEGRATION workflow.logger.debug("ToolCallStep: Received INTEGRATION tool call") call = tool_call["integration"] 
tool_name = call["name"] arguments = call["arguments"] - integration_tool = next( - (t for t in context.tools if t.name == tool_name), None - ) + integration_tool = next((t for t in context.tools if t.name == tool_name), None) if integration_tool is None: - raise ApplicationError(f"Integration {tool_name} not found") + msg = f"Integration {tool_name} not found" + raise ApplicationError(msg) provider = integration_tool.integration.provider setup = ( @@ -584,9 +552,7 @@ async def run( execute_integration, args=[context, tool_name, integration, arguments], schedule_to_close_timeout=timedelta( - seconds=30 - if debug or testing - else temporal_schedule_to_close_timeout + seconds=30 if debug or testing else temporal_schedule_to_close_timeout ), retry_policy=DEFAULT_RETRY_POLICY, heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), @@ -594,20 +560,19 @@ async def run( state = PartialTransition(output=tool_call_response) - case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[ - "type" - ] == "api_call": + case ToolCallStep(), StepOutcome(output=tool_call) if ( + tool_call["type"] == "api_call" + ): # MANUAL TOOL CALL API_CALL workflow.logger.debug("ToolCallStep: Received API_CALL tool call") call = tool_call["api_call"] tool_name = call["name"] arguments = call["arguments"] - apicall_tool = next( - (t for t in context.tools if t.name == tool_name), None - ) + apicall_tool = next((t for t in context.tools if t.name == tool_name), None) if apicall_tool is None: - raise ApplicationError(f"Integration {tool_name} not found") + msg = f"Integration {tool_name} not found" + raise ApplicationError(msg) api_call = ApiCallDef( method=apicall_tool.api_call.method, @@ -628,18 +593,14 @@ async def run( arguments, ], schedule_to_close_timeout=timedelta( - seconds=30 - if debug or testing - else temporal_schedule_to_close_timeout + seconds=30 if debug or testing else temporal_schedule_to_close_timeout ), 
heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), ) state = PartialTransition(output=tool_call_response) - case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[ - "type" - ] == "system": + case ToolCallStep(), StepOutcome(output=tool_call) if tool_call["type"] == "system": # MANUAL TOOL CALL SYSTEM workflow.logger.debug("ToolCallStep: Received SYSTEM tool call") call = tool_call.get("system") @@ -649,9 +610,7 @@ async def run( execute_system, args=[context, system_call], schedule_to_close_timeout=timedelta( - seconds=30 - if debug or testing - else temporal_schedule_to_close_timeout + seconds=30 if debug or testing else temporal_schedule_to_close_timeout ), heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), ) @@ -669,7 +628,8 @@ async def run( last_error=self.last_error, ) - raise ApplicationError("Not implemented") + msg = "Not implemented" + raise ApplicationError(msg) # 4. Transition to the next step workflow.logger.info(f"Transitioning after step {context.cursor.step}") @@ -693,7 +653,8 @@ async def run( # 5b. 
Recurse to the next step if not final_state.next: - raise ApplicationError("No next step") + msg = "No next step" + raise ApplicationError(msg) workflow.logger.info( f"Continuing to next step: {final_state.next.workflow}.{final_state.next.step}" diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index 1d68322f5..6e115be7b 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -19,11 +19,9 @@ ExecutionInput, StepContext, ) - from ...common.storage_handler import auto_blob_store_workflow from ...env import task_max_parallelism, temporal_heartbeat_timeout -@auto_blob_store_workflow async def continue_as_child( execution_input: ExecutionInput, start: TransitionTarget, @@ -50,7 +48,6 @@ async def continue_as_child( ) -@auto_blob_store_workflow async def execute_switch_branch( *, context: StepContext, @@ -84,7 +81,6 @@ async def execute_switch_branch( ) -@auto_blob_store_workflow async def execute_if_else_branch( *, context: StepContext, @@ -123,7 +119,6 @@ async def execute_if_else_branch( ) -@auto_blob_store_workflow async def execute_foreach_step( *, context: StepContext, @@ -137,9 +132,7 @@ async def execute_foreach_step( results = [] for i, item in enumerate(items): - foreach_wf_name = ( - f"`{context.cursor.workflow}`[{context.cursor.step}].foreach[{i}]" - ) + foreach_wf_name = f"`{context.cursor.workflow}`[{context.cursor.step}].foreach[{i}]" foreach_task = execution_input.task.model_copy() foreach_task.workflows = [ Workflow(name=foreach_wf_name, steps=[do_step]), @@ -153,7 +146,7 @@ async def execute_foreach_step( result = await continue_as_child( foreach_execution_input, foreach_next_target, - previous_inputs + [item], + [*previous_inputs, item], user_state=user_state, ) results.append(result) @@ -161,7 +154,6 @@ async def execute_foreach_step( return results -@auto_blob_store_workflow async def 
execute_map_reduce_step( *, context: StepContext, @@ -178,9 +170,7 @@ async def execute_map_reduce_step( reduce = "results + [_]" if reduce is None else reduce for i, item in enumerate(items): - workflow_name = ( - f"`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}]" - ) + workflow_name = f"`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}]" map_reduce_task = execution_input.task.model_copy() map_reduce_task.workflows = [ Workflow(name=workflow_name, steps=[map_defn]), @@ -194,7 +184,7 @@ async def execute_map_reduce_step( output = await continue_as_child( map_reduce_execution_input, map_reduce_next_target, - previous_inputs + [item], + [*previous_inputs, item], user_state=user_state, ) @@ -209,7 +199,6 @@ async def execute_map_reduce_step( return result -@auto_blob_store_workflow async def execute_map_reduce_step_parallel( *, context: StepContext, @@ -235,7 +224,7 @@ async def execute_map_reduce_step_parallel( # Explanation: # - reduce is the reduce expression # - reducer_lambda is the lambda function that will be used to reduce the results - extra_lambda_strs = dict(reducer_lambda=f"lambda _result, _item: ({reduce})") + extra_lambda_strs = {"reducer_lambda": f"lambda _result, _item: ({reduce})"} reduce = "reduce(reducer_lambda, _, results)" @@ -248,7 +237,9 @@ async def execute_map_reduce_step_parallel( for j, item in enumerate(batch): # Parallel batch workflow name # Note: Added PAR: prefix to easily identify parallel batches in logs - workflow_name = f"PAR:`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}][{j}]" + workflow_name = ( + f"PAR:`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}][{j}]" + ) map_reduce_task = execution_input.task.model_copy() map_reduce_task.workflows = [ Workflow(name=workflow_name, steps=[map_defn]), @@ -264,7 +255,7 @@ async def execute_map_reduce_step_parallel( continue_as_child( map_reduce_execution_input, map_reduce_next_target, - previous_inputs + [item], + 
[*previous_inputs, item], user_state=user_state, ) ) @@ -289,6 +280,7 @@ async def execute_map_reduce_step_parallel( except BaseException as e: workflow.logger.error(f"Error in batch {i}: {e}") - raise ApplicationError(f"Error in batch {i}: {e}") from e + msg = f"Error in batch {i}: {e}" + raise ApplicationError(msg) from e return results diff --git a/agents-api/agents_api/workflows/task_execution/transition.py b/agents-api/agents_api/workflows/task_execution/transition.py index a26ac1778..ca1e63cc1 100644 --- a/agents-api/agents_api/workflows/task_execution/transition.py +++ b/agents-api/agents_api/workflows/task_execution/transition.py @@ -14,7 +14,6 @@ from ...common.retry_policies import DEFAULT_RETRY_POLICY from ...env import ( debug, - temporal_activity_after_retry_timeout, temporal_heartbeat_timeout, temporal_schedule_to_close_timeout, testing, @@ -62,5 +61,6 @@ async def transition( ) except Exception as e: - workflow.logger.error(f"Error in transition: {str(e)}") - raise ApplicationError(f"Error in transition: {e}") from e + workflow.logger.error(f"Error in transition: {e!s}") + msg = f"Error in transition: {e}" + raise ApplicationError(msg) from e diff --git a/agents-api/agents_api/workflows/truncation.py b/agents-api/agents_api/workflows/truncation.py deleted file mode 100644 index 1e83aebe7..000000000 --- a/agents-api/agents_api/workflows/truncation.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 - - -from datetime import timedelta - -from temporalio import workflow - -with workflow.unsafe.imports_passed_through(): - from ..activities.truncation import truncation - from ..common.retry_policies import DEFAULT_RETRY_POLICY - from ..env import temporal_heartbeat_timeout, temporal_schedule_to_close_timeout - - -@workflow.defn -class TruncationWorkflow: - @workflow.run - async def run(self, session_id: str, token_count_threshold: int) -> None: - return await workflow.execute_activity( - truncation, - args=[session_id, token_count_threshold], - 
schedule_to_close_timeout=timedelta( - seconds=temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) diff --git a/agents-api/docker-compose.yml b/agents-api/docker-compose.yml index 94129896c..2116eafbc 100644 --- a/agents-api/docker-compose.yml +++ b/agents-api/docker-compose.yml @@ -8,8 +8,7 @@ x--shared-environment: &shared-environment AGENTS_API_PUBLIC_PORT: ${AGENTS_API_PUBLIC_PORT:-80} AGENTS_API_PROTOCOL: ${AGENTS_API_PROTOCOL:-http} AGENTS_API_URL: ${AGENTS_API_URL:-http://agents-api:8080} - COZO_AUTH_TOKEN: ${COZO_AUTH_TOKEN} - COZO_HOST: ${COZO_HOST:-http://memory-store:9070} + PG_DSN: ${PG_DSN:-postgres://postgres:postgres@memory-store:5432/postgres} DEBUG: ${AGENTS_API_DEBUG:-False} EMBEDDING_MODEL_ID: ${EMBEDDING_MODEL_ID:-Alibaba-NLP/gte-large-en-v1.5} INTEGRATION_SERVICE_URL: ${INTEGRATION_SERVICE_URL:-http://integrations:8000} @@ -29,6 +28,7 @@ x--shared-environment: &shared-environment TEMPORAL_MAX_CONCURRENT_ACTIVITIES: ${TEMPORAL_MAX_CONCURRENT_ACTIVITIES:-100} TEMPORAL_MAX_ACTIVITIES_PER_SECOND: ${TEMPORAL_MAX_ACTIVITIES_PER_SECOND} TEMPORAL_MAX_TASK_QUEUE_ACTIVITIES_PER_SECOND: ${TEMPORAL_MAX_TASK_QUEUE_ACTIVITIES_PER_SECOND} + AGENTS_API_TRANSITION_REQUESTS_PER_MINUTE: ${AGENTS_API_TRANSITION_REQUESTS_PER_MINUTE:-500} TRUNCATE_EMBED_TEXT: ${TRUNCATE_EMBED_TEXT:-True} WORKER_URL: ${WORKER_URL:-temporal:7233} USE_BLOB_STORE_FOR_TEMPORAL: ${USE_BLOB_STORE_FOR_TEMPORAL:-false} @@ -111,21 +111,3 @@ services: path: uv.lock - action: rebuild path: Dockerfile.worker - - cozo-migrate: - image: julepai/cozo-migrate:${TAG:-dev} - container_name: cozo-migrate - build: - context: . 
- dockerfile: Dockerfile.migration - restart: "no" # Make sure to double quote this - environment: - <<: *shared-environment - - develop: - watch: - - action: sync+restart - path: ./migrations - target: /app/migrations - - action: rebuild - path: Dockerfile.migration diff --git a/agents-api/migrations/migrate_1704699172_init.py b/agents-api/migrations/migrate_1704699172_init.py deleted file mode 100644 index 3a427ad48..000000000 --- a/agents-api/migrations/migrate_1704699172_init.py +++ /dev/null @@ -1,130 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "init" -CREATED_AT = 1704699172.673636 - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -def up(client): - create_agents_relation_query = """ - :create agents { - agent_id: Uuid, - => - name: String, - about: String, - model: String default 'gpt-4o', - created_at: Float default now(), - updated_at: Float default now(), - } - """ - - create_model_settings_relation_query = """ - :create agent_default_settings { - agent_id: Uuid, - => - frequency_penalty: Float default 0.0, - presence_penalty: Float default 0.0, - length_penalty: Float default 1.0, - repetition_penalty: Float default 1.0, - top_p: Float default 0.95, - temperature: Float default 0.7, - } - """ - - create_entries_relation_query = """ - :create entries { - session_id: Uuid, - entry_id: Uuid default random_uuid_v4(), - source: String, - role: String, - name: String? default null, - => - content: String, - token_count: Int, - tokenizer: String, - created_at: Float default now(), - } - """ - - create_sessions_relation_query = """ - :create sessions { - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? default null, - created_at: Float default now(), - } - """ - - create_session_lookup_relation_query = """ - :create session_lookup { - agent_id: Uuid, - user_id: Uuid? 
default null, - session_id: Uuid, - } - """ - - create_users_relation_query = """ - :create users { - user_id: Uuid, - => - name: String, - about: String, - created_at: Float default now(), - updated_at: Float default now(), - } - """ - - run( - client, - create_agents_relation_query, - create_model_settings_relation_query, - create_entries_relation_query, - create_sessions_relation_query, - create_session_lookup_relation_query, - create_users_relation_query, - ) - - -def down(client): - remove_agents_relation_query = """ - ::remove agents - """ - - remove_model_settings_relation_query = """ - ::remove agent_default_settings - """ - - remove_entries_relation_query = """ - ::remove entries - """ - - remove_sessions_relation_query = """ - ::remove sessions - """ - - remove_session_lookup_relation_query = """ - ::remove session_lookup - """ - - remove_users_relation_query = """ - ::remove users - """ - - run( - client, - remove_users_relation_query, - remove_session_lookup_relation_query, - remove_sessions_relation_query, - remove_entries_relation_query, - remove_model_settings_relation_query, - remove_agents_relation_query, - ) diff --git a/agents-api/migrations/migrate_1704699595_developers.py b/agents-api/migrations/migrate_1704699595_developers.py deleted file mode 100644 index d22edb393..000000000 --- a/agents-api/migrations/migrate_1704699595_developers.py +++ /dev/null @@ -1,151 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "developers" -CREATED_AT = 1704699595.546072 - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -def up(client): - update_agents_relation_query = """ - ?[agent_id, name, about, model, created_at, updated_at, developer_id] := *agents{ - agent_id, - name, - about, - model, - created_at, - updated_at, - }, developer_id = rand_uuid_v4() - - :replace agents { - developer_id: Uuid, - agent_id: Uuid, - => - name: String, - about: String, - model: String 
default 'gpt-4o', - created_at: Float default now(), - updated_at: Float default now(), - } - """ - - update_sessions_relation_query = """ - ?[developer_id, session_id, updated_at, situation, summary, created_at] := *sessions{ - session_id, - updated_at, - situation, - summary, - created_at, - }, developer_id = rand_uuid_v4() - - :replace sessions { - developer_id: Uuid, - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? default null, - created_at: Float default now(), - } - """ - - update_users_relation_query = """ - ?[user_id, name, about, created_at, updated_at, developer_id] := *users{ - user_id, - name, - about, - created_at, - updated_at, - }, developer_id = rand_uuid_v4() - - :replace users { - developer_id: Uuid, - user_id: Uuid, - => - name: String, - about: String, - created_at: Float default now(), - updated_at: Float default now(), - } - """ - - run( - client, - update_agents_relation_query, - update_sessions_relation_query, - update_users_relation_query, - ) - - -def down(client): - update_agents_relation_query = """ - ?[agent_id, name, about, model, created_at, updated_at] := *agents{ - agent_id, - name, - about, - model, - created_at, - updated_at, - } - - :replace agents { - agent_id: Uuid, - => - name: String, - about: String, - model: String default 'gpt-4o', - created_at: Float default now(), - updated_at: Float default now(), - } - """ - - update_sessions_relation_query = """ - ?[session_id, updated_at, situation, summary, created_at] := *sessions{ - session_id, - updated_at, - situation, - summary, - created_at, - } - - :replace sessions { - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? 
default null, - created_at: Float default now(), - } - """ - - update_users_relation_query = """ - ?[user_id, name, about, created_at, updated_at] := *users{ - user_id, - name, - about, - created_at, - updated_at, - } - - :replace users { - user_id: Uuid, - => - name: String, - about: String, - created_at: Float default now(), - updated_at: Float default now(), - } - """ - - run( - client, - update_users_relation_query, - update_sessions_relation_query, - update_agents_relation_query, - ) diff --git a/agents-api/migrations/migrate_1704728076_additional_info.py b/agents-api/migrations/migrate_1704728076_additional_info.py deleted file mode 100644 index c20f021f4..000000000 --- a/agents-api/migrations/migrate_1704728076_additional_info.py +++ /dev/null @@ -1,107 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "additional_info" -CREATED_AT = 1704728076.129496 - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -agent_additional_info_table = dict( - up=""" - :create agent_additional_info { - agent_id: Uuid, - additional_info_id: Uuid - => - created_at: Float default now(), - } - """, - down=""" - ::remove agent_additional_info - """, -) - -user_additional_info_table = dict( - up=""" - :create user_additional_info { - user_id: Uuid, - additional_info_id: Uuid - => - created_at: Float default now(), - } - """, - down=""" - ::remove user_additional_info - """, -) - -information_snippets_table = dict( - up=""" - :create information_snippets { - additional_info_id: Uuid, - snippet_idx: Int, - => - title: String, - snippet: String, - embed_instruction: String default 'Encode this passage for retrieval: ', - embedding: ? 
default null, - } - """, - down=""" - ::remove information_snippets - """, -) - -# See: https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md -information_snippets_hnsw_index = dict( - up=""" - ::hnsw create information_snippets:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 768, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: false, - keep_pruned_connections: false, - } - """, - down=""" - ::hnsw drop information_snippets:embedding_space - """, -) - -# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts -information_snippets_fts_index = dict( - up=""" - ::fts create information_snippets:fts { - extractor: concat(title, ' ', snippet), - tokenizer: Simple, - filters: [Lowercase, Stemmer('english'), Stopwords('en')], - } - """, - down=""" - ::fts drop information_snippets:fts - """, -) - -queries_to_run = [ - agent_additional_info_table, - user_additional_info_table, - information_snippets_table, - information_snippets_hnsw_index, - information_snippets_fts_index, -] - - -def up(client): - run(client, *[q["up"] for q in queries_to_run]) - - -def down(client): - run(client, *[q["down"] for q in reversed(queries_to_run)]) diff --git a/agents-api/migrations/migrate_1704892503_tools.py b/agents-api/migrations/migrate_1704892503_tools.py deleted file mode 100644 index 38fefaa08..000000000 --- a/agents-api/migrations/migrate_1704892503_tools.py +++ /dev/null @@ -1,106 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "tools" -CREATED_AT = 1704892503.302678 - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -agent_instructions_table = dict( - up=""" - :create agent_instructions { - agent_id: Uuid, - instruction_idx: Int, - => - content: String, - important: Bool default false, - embed_instruction: String default 'Embed this historical text chunk for retrieval: ', - embedding: ? 
default null, - created_at: Float default now(), - } - """, - down=""" - ::remove agent_instructions - """, -) - -# See: https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md -agent_instructions_hnsw_index = dict( - up=""" - ::hnsw create agent_instructions:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 768, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: false, - keep_pruned_connections: false, - } - """, - down=""" - ::hnsw drop agent_instructions:embedding_space - """, -) - -agent_functions_table = dict( - up=""" - :create agent_functions { - agent_id: Uuid, - tool_id: Uuid, - => - name: String, - description: String, - parameters: Json, - embed_instruction: String default 'Transform this tool description for retrieval: ', - embedding: ? default null, - updated_at: Float default now(), - created_at: Float default now(), - } - """, - down=""" - ::remove agent_functions - """, -) - - -# See: https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md -agent_functions_hnsw_index = dict( - up=""" - ::hnsw create agent_functions:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 768, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: false, - keep_pruned_connections: false, - } - """, - down=""" - ::hnsw drop agent_functions:embedding_space - """, -) - - -queries_to_run = [ - agent_instructions_table, - agent_instructions_hnsw_index, - agent_functions_table, - agent_functions_hnsw_index, -] - - -def up(client): - run(client, *[q["up"] for q in queries_to_run]) - - -def down(client): - run(client, *[q["down"] for q in reversed(queries_to_run)]) diff --git a/agents-api/migrations/migrate_1706090164_entries_timestamp.py b/agents-api/migrations/migrate_1706090164_entries_timestamp.py deleted file mode 100644 index d85a7170e..000000000 --- a/agents-api/migrations/migrate_1706090164_entries_timestamp.py +++ /dev/null @@ -1,102 +0,0 @@ -# /usr/bin/env python3 
- -MIGRATION_ID = "entries_timestamp" -CREATED_AT = 1706090164.80913 - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -update_entries = { - "up": """ - ?[ - session_id, - entry_id, - source, - role, - name, - content, - token_count, - tokenizer, - created_at, - timestamp, - ] := *entries{ - session_id, - entry_id, - source, - role, - name, - content, - token_count, - tokenizer, - created_at, - }, timestamp = created_at - - :replace entries { - session_id: Uuid, - entry_id: Uuid default random_uuid_v4(), - source: String, - role: String, - name: String? default null, - => - content: String, - token_count: Int, - tokenizer: String, - created_at: Float default now(), - timestamp: Float default now(), - } - """, - "down": """ - ?[ - session_id, - entry_id, - source, - role, - name, - content, - token_count, - tokenizer, - created_at, - ] := *entries{ - session_id, - entry_id, - source, - role, - name, - content, - token_count, - tokenizer, - created_at, - } - - :replace entries { - session_id: Uuid, - entry_id: Uuid default random_uuid_v4(), - source: String, - role: String, - name: String? 
default null, - => - content: String, - token_count: Int, - tokenizer: String, - created_at: Float default now(), - } - """, -} - -queries_to_run = [ - update_entries, -] - - -def up(client): - run(client, *[q["up"] for q in queries_to_run]) - - -def down(client): - run(client, *[q["down"] for q in queries_to_run]) diff --git a/agents-api/migrations/migrate_1706092435_entry_relations.py b/agents-api/migrations/migrate_1706092435_entry_relations.py deleted file mode 100644 index e031b27d1..000000000 --- a/agents-api/migrations/migrate_1706092435_entry_relations.py +++ /dev/null @@ -1,38 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "entry_relations" -CREATED_AT = 1706092435.462968 - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -entry_relations = { - "up": """ - :create entry_relations { - head: Uuid, - relation: String, - tail: Uuid, - } - """, - "down": """ - ::remove entry_relations - """, -} - -queries_to_run = [ - entry_relations, -] - - -def up(client): - run(client, *[q["up"] for q in queries_to_run]) - - -def down(client): - run(client, *[q["down"] for q in queries_to_run]) diff --git a/agents-api/migrations/migrate_1707537826_rename_additional_info.py b/agents-api/migrations/migrate_1707537826_rename_additional_info.py deleted file mode 100644 index d71576f05..000000000 --- a/agents-api/migrations/migrate_1707537826_rename_additional_info.py +++ /dev/null @@ -1,217 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "rename_additional_info" -CREATED_AT = 1707537826.539182 - -rename_agent_doc_id = dict( - up=""" - ?[agent_id, doc_id, created_at] := - *agent_additional_info{ - agent_id, - additional_info_id: doc_id, - created_at, - } - - :replace agent_additional_info { - agent_id: Uuid, - doc_id: Uuid - => - created_at: Float default now(), - } - """, - down=""" - ?[agent_id, additional_info_id, created_at] := - *agent_additional_info{ - agent_id, - doc_id: 
additional_info_id, - created_at, - } - - :replace agent_additional_info { - agent_id: Uuid, - additional_info_id: Uuid - => - created_at: Float default now(), - } - """, -) - - -rename_user_doc_id = dict( - up=""" - ?[user_id, doc_id, created_at] := - *user_additional_info{ - user_id, - additional_info_id: doc_id, - created_at, - } - - :replace user_additional_info { - user_id: Uuid, - doc_id: Uuid - => - created_at: Float default now(), - } - """, - down=""" - ?[user_id, additional_info_id, created_at] := - *user_additional_info{ - user_id, - doc_id: additional_info_id, - created_at, - } - - :replace user_additional_info { - user_id: Uuid, - additional_info_id: Uuid - => - created_at: Float default now(), - } - """, -) - -# See: https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md -information_snippets_hnsw_index = dict( - up=""" - ::hnsw create information_snippets:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 768, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: false, - keep_pruned_connections: false, - } - """, - down=""" - ::hnsw drop information_snippets:embedding_space - """, -) - -# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts -information_snippets_fts_index = dict( - up=""" - ::fts create information_snippets:fts { - extractor: concat(title, ' ', snippet), - tokenizer: Simple, - filters: [Lowercase, Stemmer('english'), Stopwords('en')], - } - """, - down=""" - ::fts drop information_snippets:fts - """, -) - -drop_information_snippets_hnsw_index = { - "up": information_snippets_hnsw_index["down"], - "down": information_snippets_hnsw_index["up"], -} - - -drop_information_snippets_fts_index = { - "up": information_snippets_fts_index["down"], - "down": information_snippets_fts_index["up"], -} - - -rename_information_snippets_doc_id = dict( - up=""" - ?[ - doc_id, - snippet_idx, - title, - snippet, - embed_instruction, - embedding, - ] := - *information_snippets{ - 
snippet_idx, - title, - snippet, - embed_instruction, - embedding, - additional_info_id: doc_id, - } - - :replace information_snippets { - doc_id: Uuid, - snippet_idx: Int, - => - title: String, - snippet: String, - embed_instruction: String default 'Encode this passage for retrieval: ', - embedding: ? default null, - } - """, - down=""" - ?[ - additional_info_id, - snippet_idx, - title, - snippet, - embed_instruction, - embedding, - ] := - *information_snippets{ - snippet_idx, - title, - snippet, - embed_instruction, - embedding, - doc_id: additional_info_id, - } - - :replace information_snippets { - additional_info_id: Uuid, - snippet_idx: Int, - => - title: String, - snippet: String, - embed_instruction: String default 'Encode this passage for retrieval: ', - embedding: ? default null, - } - """, -) - -rename_relations = dict( - up=""" - ::rename - agent_additional_info -> agent_docs, - user_additional_info -> user_docs - """, - down=""" - ::rename - agent_docs -> agent_additional_info, - user_docs -> user_additional_info - """, -) - - -queries_to_run = [ - rename_agent_doc_id, - rename_user_doc_id, - drop_information_snippets_hnsw_index, - drop_information_snippets_fts_index, - rename_information_snippets_doc_id, - information_snippets_hnsw_index, - information_snippets_fts_index, - rename_relations, -] - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - - client.run(query) - - -def up(client): - run(client, *[q["up"] for q in queries_to_run]) - - -def down(client): - run(client, *[q["down"] for q in reversed(queries_to_run)]) diff --git a/agents-api/migrations/migrate_1709200345_extend_agents_default_settings.py b/agents-api/migrations/migrate_1709200345_extend_agents_default_settings.py deleted file mode 100644 index 4a2be5921..000000000 --- a/agents-api/migrations/migrate_1709200345_extend_agents_default_settings.py +++ /dev/null @@ -1,83 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = 
"extend_agents_default_settings" -CREATED_AT = 1709200345.052425 - - -extend_agents_default_settings = { - "up": """ - ?[ - agent_id, - frequency_penalty, - presence_penalty, - length_penalty, - repetition_penalty, - top_p, - temperature, - min_p, - ] := *agent_default_settings{ - agent_id, - frequency_penalty, - presence_penalty, - length_penalty, - repetition_penalty, - top_p, - temperature, - }, min_p = 0.01 - - :replace agent_default_settings { - agent_id: Uuid, - => - frequency_penalty: Float default 0.0, - presence_penalty: Float default 0.0, - length_penalty: Float default 1.0, - repetition_penalty: Float default 1.0, - top_p: Float default 0.95, - temperature: Float default 0.7, - min_p: Float default 0.01, - } - """, - "down": """ - ?[ - agent_id, - frequency_penalty, - presence_penalty, - length_penalty, - repetition_penalty, - top_p, - temperature, - ] := *agent_default_settings{ - agent_id, - frequency_penalty, - presence_penalty, - length_penalty, - repetition_penalty, - top_p, - temperature, - } - - :replace agent_default_settings { - agent_id: Uuid, - => - frequency_penalty: Float default 0.0, - presence_penalty: Float default 0.0, - length_penalty: Float default 1.0, - repetition_penalty: Float default 1.0, - top_p: Float default 0.95, - temperature: Float default 0.7, - } - """, -} - - -queries_to_run = [ - extend_agents_default_settings, -] - - -def up(client): - client.run(extend_agents_default_settings["up"]) - - -def down(client): - client.run(extend_agents_default_settings["down"]) diff --git a/agents-api/migrations/migrate_1709292828_presets.py b/agents-api/migrations/migrate_1709292828_presets.py deleted file mode 100644 index ee2c3885a..000000000 --- a/agents-api/migrations/migrate_1709292828_presets.py +++ /dev/null @@ -1,82 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "presets" -CREATED_AT = 1709292828.203209 - -extend_agents_default_settings = { - "up": """ - ?[ - agent_id, - frequency_penalty, - presence_penalty, - length_penalty, 
- repetition_penalty, - top_p, - temperature, - min_p, - ] := *agent_default_settings{ - agent_id, - frequency_penalty, - presence_penalty, - length_penalty, - repetition_penalty, - top_p, - temperature, - min_p, - }, preset = null - - :replace agent_default_settings { - agent_id: Uuid, - => - frequency_penalty: Float default 0.0, - presence_penalty: Float default 0.0, - length_penalty: Float default 1.0, - repetition_penalty: Float default 1.0, - top_p: Float default 0.95, - temperature: Float default 0.7, - min_p: Float default 0.01, - preset: String? default null, - } - """, - "down": """ - ?[ - agent_id, - frequency_penalty, - presence_penalty, - length_penalty, - repetition_penalty, - top_p, - temperature, - min_p, - ] := *agent_default_settings{ - agent_id, - frequency_penalty, - presence_penalty, - length_penalty, - repetition_penalty, - top_p, - temperature, - min_p, - } - - :replace agent_default_settings { - agent_id: Uuid, - => - frequency_penalty: Float default 0.0, - presence_penalty: Float default 0.0, - length_penalty: Float default 1.0, - repetition_penalty: Float default 1.0, - top_p: Float default 0.95, - temperature: Float default 0.7, - min_p: Float default 0.01, - } - """, -} - - -def up(client): - client.run(extend_agents_default_settings["up"]) - - -def down(client): - client.run(extend_agents_default_settings["down"]) diff --git a/agents-api/migrations/migrate_1709631202_metadata.py b/agents-api/migrations/migrate_1709631202_metadata.py deleted file mode 100644 index 36c1c8ec4..000000000 --- a/agents-api/migrations/migrate_1709631202_metadata.py +++ /dev/null @@ -1,232 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "metadata" -CREATED_AT = 1709631202.917773 - - -extend_agents = { - "up": """ - ?[agent_id, name, about, model, created_at, updated_at, developer_id, metadata] := *agents{ - agent_id, - name, - about, - model, - created_at, - updated_at, - developer_id, - }, metadata = {} - - :replace agents { - developer_id: Uuid, - agent_id: 
Uuid, - => - name: String, - about: String, - model: String default 'gpt-4o', - created_at: Float default now(), - updated_at: Float default now(), - metadata: Json default {}, - } - """, - "down": """ - ?[agent_id, name, about, model, created_at, updated_at, developer_id] := *agents{ - agent_id, - name, - about, - model, - created_at, - updated_at, - developer_id, - } - - :replace agents { - developer_id: Uuid, - agent_id: Uuid, - => - name: String, - about: String, - model: String default 'gpt-4o', - created_at: Float default now(), - updated_at: Float default now(), - } - """, -} - - -extend_users = { - "up": """ - ?[user_id, name, about, created_at, updated_at, developer_id, metadata] := *users{ - user_id, - name, - about, - created_at, - updated_at, - developer_id, - }, metadata = {} - - :replace users { - developer_id: Uuid, - user_id: Uuid, - => - name: String, - about: String, - created_at: Float default now(), - updated_at: Float default now(), - metadata: Json default {}, - } - """, - "down": """ - ?[user_id, name, about, created_at, updated_at, developer_id] := *users{ - user_id, - name, - about, - created_at, - updated_at, - developer_id, - } - - :replace users { - developer_id: Uuid, - user_id: Uuid, - => - name: String, - about: String, - created_at: Float default now(), - updated_at: Float default now(), - } - """, -} - - -extend_sessions = { - "up": """ - ?[developer_id, session_id, updated_at, situation, summary, created_at, developer_id, metadata] := *sessions{ - session_id, - updated_at, - situation, - summary, - created_at, - developer_id - }, metadata = {} - - :replace sessions { - developer_id: Uuid, - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? 
default null, - created_at: Float default now(), - metadata: Json default {}, - } - """, - "down": """ - ?[developer_id, session_id, updated_at, situation, summary, created_at, developer_id] := *sessions{ - session_id, - updated_at, - situation, - summary, - created_at, - developer_id - } - - :replace sessions { - developer_id: Uuid, - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? default null, - created_at: Float default now(), - } - """, -} - - -extend_agent_docs = { - "up": """ - ?[agent_id, doc_id, created_at, metadata] := - *agent_docs{ - agent_id, - doc_id, - created_at, - }, metadata = {} - - :replace agent_docs { - agent_id: Uuid, - doc_id: Uuid - => - created_at: Float default now(), - metadata: Json default {}, - } - """, - "down": """ - ?[agent_id, doc_id, created_at] := - *agent_docs{ - agent_id, - doc_id, - created_at, - } - - :replace agent_docs { - agent_id: Uuid, - doc_id: Uuid - => - created_at: Float default now(), - } - """, -} - - -extend_user_docs = { - "up": """ - ?[user_id, doc_id, created_at, metadata] := - *user_docs{ - user_id, - doc_id, - created_at, - }, metadata = {} - - :replace user_docs { - user_id: Uuid, - doc_id: Uuid - => - created_at: Float default now(), - metadata: Json default {}, - } - """, - "down": """ - ?[user_id, doc_id, created_at] := - *user_docs{ - user_id, - doc_id, - created_at, - } - - :replace user_docs { - user_id: Uuid, - doc_id: Uuid - => - created_at: Float default now(), - } - """, -} - - -queries_to_run = [ - extend_agents, - extend_users, - extend_sessions, - extend_agent_docs, - extend_user_docs, -] - - -def up(client): - for q in queries_to_run: - client.run(q["up"]) - - -def down(client): - for q in reversed(queries_to_run): - client.run(q["down"]) diff --git a/agents-api/migrations/migrate_1709806979_entry_relations_to_relations.py b/agents-api/migrations/migrate_1709806979_entry_relations_to_relations.py deleted file mode 100644 index 
e8c05be8f..000000000 --- a/agents-api/migrations/migrate_1709806979_entry_relations_to_relations.py +++ /dev/null @@ -1,30 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "entry_relations_to_relations" -CREATED_AT = 1709806979.250619 - - -entry_relations_to_relations = { - "up": """ - ::rename - entry_relations -> relations - """, - "down": """ - ::rename - relations -> entry_relations - """, -} - -queries_to_run = [ - entry_relations_to_relations, -] - - -def up(client): - for q in queries_to_run: - client.run(q["up"]) - - -def down(client): - for q in reversed(queries_to_run): - client.run(q["down"]) diff --git a/agents-api/migrations/migrate_1709810233_memories.py b/agents-api/migrations/migrate_1709810233_memories.py deleted file mode 100644 index 5036c1826..000000000 --- a/agents-api/migrations/migrate_1709810233_memories.py +++ /dev/null @@ -1,92 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "memories" -CREATED_AT = 1709810233.271039 - - -memories = { - "up": """ - :create memories { - memory_id: Uuid, - type: String, # enum: belief | episode - => - content: String, - weight: Int, # range: 0-100 - last_accessed_at: Float? default null, - timestamp: Float default now(), - sentiment: Int, - emotions: [String], - duration: Float? default null, - created_at: Float default now(), - embedding: ? default null, - } - """, - "down": """ - ::remove memories - """, -} - - -memory_lookup = { - "up": """ - :create memory_lookup { - agent_id: Uuid, - user_id: Uuid? 
default null, - memory_id: Uuid, - } - """, - "down": """ - ::remove memory_lookup - """, -} - - -memories_hnsw_index = { - "up": """ - ::hnsw create memories:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 768, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: false, - keep_pruned_connections: false, - } - """, - "down": """ - ::hnsw drop memories:embedding_space - """, -} - - -memories_fts_index = { - "up": """ - ::fts create memories:fts { - extractor: content, - tokenizer: Simple, - filters: [Lowercase, Stemmer('english'), Stopwords('en')], - } - """, - "down": """ - ::fts drop memories:fts - """, -} - - -queries_to_run = [ - memories, - memory_lookup, - memories_hnsw_index, - memories_fts_index, -] - - -def up(client): - for q in queries_to_run: - client.run(q["up"]) - - -def down(client): - for q in reversed(queries_to_run): - client.run(q["down"]) diff --git a/agents-api/migrations/migrate_1712309841_simplify_memories.py b/agents-api/migrations/migrate_1712309841_simplify_memories.py deleted file mode 100644 index 5a2656d83..000000000 --- a/agents-api/migrations/migrate_1712309841_simplify_memories.py +++ /dev/null @@ -1,144 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "simplify_memories" -CREATED_AT = 1712309841.289588 - -simplify_memories = { - "up": """ - ?[ - memory_id, - content, - last_accessed_at, - timestamp, - sentiment, - entities, - created_at, - embedding, - ] := - *memories { - memory_id, - content, - last_accessed_at, - timestamp, - sentiment, - created_at, - embedding, - }, - entities = [] - - :replace memories { - memory_id: Uuid, - => - content: String, - last_accessed_at: Float? default null, - timestamp: Float default now(), - sentiment: Int default 0.0, - entities: [Json] default [], - created_at: Float default now(), - embedding: ? 
default null, - } - """, - "down": """ - ?[ - memory_id, - type, - weight, - duration, - emotions, - content, - last_accessed_at, - timestamp, - sentiment, - created_at, - embedding, - ] := - *memories { - memory_id, - content, - last_accessed_at, - timestamp, - sentiment, - created_at, - embedding, - }, - type = 'episode', - weight = 1, - duration = null, - emotions = [] - - :replace memories { - memory_id: Uuid, - type: String, # enum: belief | episode - => - content: String, - weight: Int, # range: 0-100 - last_accessed_at: Float? default null, - timestamp: Float default now(), - sentiment: Int, - emotions: [String], - duration: Float? default null, - created_at: Float default now(), - embedding: ? default null, - } - """, -} - -memories_hnsw_index = { - "up": """ - ::hnsw create memories:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 768, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: false, - keep_pruned_connections: false, - } - """, - "down": """ - ::hnsw drop memories:embedding_space - """, -} - - -memories_fts_index = { - "up": """ - ::fts create memories:fts { - extractor: content, - tokenizer: Simple, - filters: [Lowercase, Stemmer('english'), Stopwords('en')], - } - """, - "down": """ - ::fts drop memories:fts - """, -} - -drop_memories_hnsw_index = { - "up": memories_hnsw_index["down"], - "down": memories_hnsw_index["up"], -} - -drop_memories_fts_index = { - "up": memories_fts_index["down"], - "down": memories_fts_index["up"], -} - -queries_to_run = [ - drop_memories_hnsw_index, - drop_memories_fts_index, - simplify_memories, - memories_hnsw_index, - memories_fts_index, -] - - -def up(client): - for query in queries_to_run: - client.run(query["up"]) - - -def down(client): - for query in reversed(queries_to_run): - client.run(query["down"]) diff --git a/agents-api/migrations/migrate_1712405369_simplify_instructions.py b/agents-api/migrations/migrate_1712405369_simplify_instructions.py deleted 
file mode 100644 index b3f8a289a..000000000 --- a/agents-api/migrations/migrate_1712405369_simplify_instructions.py +++ /dev/null @@ -1,109 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "simplify_instructions" -CREATED_AT = 1712405369.263776 - -update_agents_relation_query = dict( - up=""" - ?[agent_id, name, about, model, created_at, updated_at, developer_id, instructions, metadata] := *agents{ - agent_id, - name, - about, - model, - created_at, - updated_at, - metadata, - }, - developer_id = rand_uuid_v4(), - instructions = [] - - :replace agents { - developer_id: Uuid, - agent_id: Uuid, - => - name: String, - about: String, - instructions: [String] default [], - model: String default 'gpt-4o', - created_at: Float default now(), - updated_at: Float default now(), - metadata: Json default {}, - } - """, - down=""" - ?[agent_id, name, about, model, created_at, updated_at, developer_id, metadata] := *agents{ - agent_id, - name, - about, - model, - created_at, - updated_at, - metadata, - }, developer_id = rand_uuid_v4() - - :replace agents { - developer_id: Uuid, - agent_id: Uuid, - => - name: String, - about: String, - model: String default 'gpt-4o', - created_at: Float default now(), - updated_at: Float default now(), - metadata: Json default {}, - } - """, -) - -drop_instructions_table = dict( - down=""" - :create agent_instructions { - agent_id: Uuid, - instruction_idx: Int, - => - content: String, - important: Bool default false, - embed_instruction: String default 'Embed this historical text chunk for retrieval: ', - embedding: ? 
default null, - created_at: Float default now(), - } - """, - up=""" - ::remove agent_instructions - """, -) - -# See: https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md -drop_agent_instructions_hnsw_index = dict( - down=""" - ::hnsw create agent_instructions:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 768, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: false, - keep_pruned_connections: false, - } - """, - up=""" - ::hnsw drop agent_instructions:embedding_space - """, -) - -queries_to_run = [ - drop_agent_instructions_hnsw_index, - drop_instructions_table, - update_agents_relation_query, -] - - -def up(client): - for query in queries_to_run: - client.run(query["up"]) - - -def down(client): - for query in reversed(queries_to_run): - client.run(query["down"]) diff --git a/agents-api/migrations/migrate_1714119679_session_render_templates.py b/agents-api/migrations/migrate_1714119679_session_render_templates.py deleted file mode 100644 index 93d7dba14..000000000 --- a/agents-api/migrations/migrate_1714119679_session_render_templates.py +++ /dev/null @@ -1,67 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "session_render_templates" -CREATED_AT = 1714119679.493182 - -extend_sessions = { - "up": """ - ?[render_templates, developer_id, session_id, updated_at, situation, summary, created_at, developer_id, metadata] := *sessions{ - session_id, - updated_at, - situation, - summary, - created_at, - developer_id - }, - metadata = {}, - render_templates = false - - :replace sessions { - developer_id: Uuid, - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? 
default null, - created_at: Float default now(), - metadata: Json default {}, - render_templates: Bool default false, - } - """, - "down": """ - ?[developer_id, session_id, updated_at, situation, summary, created_at, developer_id, metadata] := *sessions{ - session_id, - updated_at, - situation, - summary, - created_at, - developer_id - }, metadata = {} - - :replace sessions { - developer_id: Uuid, - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? default null, - created_at: Float default now(), - metadata: Json default {}, - } - """, -} - - -queries_to_run = [ - extend_sessions, -] - - -def up(client): - for q in queries_to_run: - client.run(q["up"]) - - -def down(client): - for q in reversed(queries_to_run): - client.run(q["down"]) diff --git a/agents-api/migrations/migrate_1714566760_change_embeddings_dimensions.py b/agents-api/migrations/migrate_1714566760_change_embeddings_dimensions.py deleted file mode 100644 index dba657345..000000000 --- a/agents-api/migrations/migrate_1714566760_change_embeddings_dimensions.py +++ /dev/null @@ -1,149 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "change_embeddings_dimensions" -CREATED_AT = 1714566760.731964 - - -change_dimensions = { - "up": """ - ?[ - doc_id, - snippet_idx, - title, - snippet, - embed_instruction, - embedding, - ] := - *information_snippets{ - snippet_idx, - title, - snippet, - embed_instruction, - embedding, - doc_id, - } - - :replace information_snippets { - doc_id: Uuid, - snippet_idx: Int, - => - title: String, - snippet: String, - embed_instruction: String default 'Encode this passage for retrieval: ', - embedding: ? 
default null, - } - """, - "down": """ - ?[ - doc_id, - snippet_idx, - title, - snippet, - embed_instruction, - embedding, - ] := - *information_snippets{ - snippet_idx, - title, - snippet, - embed_instruction, - embedding, - doc_id, - } - - :replace information_snippets { - doc_id: Uuid, - snippet_idx: Int, - => - title: String, - snippet: String, - embed_instruction: String default 'Encode this passage for retrieval: ', - embedding: ? default null, - } - """, -} - -snippets_hnsw_768_index = dict( - up=""" - ::hnsw create information_snippets:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 768, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: true, - keep_pruned_connections: false, - } - """, - down=""" - ::hnsw drop information_snippets:embedding_space - """, -) - -drop_snippets_hnsw_768_index = { - "up": snippets_hnsw_768_index["down"], - "down": snippets_hnsw_768_index["up"], -} - -snippets_hnsw_1024_index = dict( - up=""" - ::hnsw create information_snippets:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 1024, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: true, - keep_pruned_connections: false, - } - """, - down=""" - ::hnsw drop information_snippets:embedding_space - """, -) - -drop_snippets_hnsw_1024_index = { - "up": snippets_hnsw_1024_index["down"], - "down": snippets_hnsw_1024_index["up"], -} - - -# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts -information_snippets_fts_index = dict( - up=""" - ::fts create information_snippets:fts { - extractor: concat(title, ' ', snippet), - tokenizer: Simple, - filters: [Lowercase, Stemmer('english'), Stopwords('en')], - } - """, - down=""" - ::fts drop information_snippets:fts - """, -) - -drop_information_snippets_fts_index = { - "up": information_snippets_fts_index["down"], - "down": information_snippets_fts_index["up"], -} - - -queries_to_run = [ - 
drop_information_snippets_fts_index, - drop_snippets_hnsw_768_index, - change_dimensions, - snippets_hnsw_1024_index, - information_snippets_fts_index, -] - - -def up(client): - for q in queries_to_run: - client.run(q["up"]) - - -def down(client): - for q in reversed(queries_to_run): - client.run(q["down"]) diff --git a/agents-api/migrations/migrate_1716013793_session_cache.py b/agents-api/migrations/migrate_1716013793_session_cache.py deleted file mode 100644 index c29f670b3..000000000 --- a/agents-api/migrations/migrate_1716013793_session_cache.py +++ /dev/null @@ -1,33 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "session_cache" -CREATED_AT = 1716013793.746602 - - -session_cache = dict( - up=""" - :create session_cache { - key: String, - => - value: Json, - } - """, - down=""" - ::remove session_cache - """, -) - - -queries_to_run = [ - session_cache, -] - - -def up(client): - for q in queries_to_run: - client.run(q["up"]) - - -def down(client): - for q in reversed(queries_to_run): - client.run(q["down"]) diff --git a/agents-api/migrations/migrate_1716847597_support_multimodal_chatml.py b/agents-api/migrations/migrate_1716847597_support_multimodal_chatml.py deleted file mode 100644 index 8b54b6b06..000000000 --- a/agents-api/migrations/migrate_1716847597_support_multimodal_chatml.py +++ /dev/null @@ -1,93 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "support_multimodal_chatml" -CREATED_AT = 1716847597.155657 - -update_entries = { - "up": """ - ?[ - session_id, - entry_id, - source, - role, - name, - content, - token_count, - tokenizer, - created_at, - timestamp, - ] := *entries{ - session_id, - entry_id, - source, - role, - name, - content: content_string, - token_count, - tokenizer, - created_at, - timestamp, - }, content = [{"type": "text", "content": content_string}] - - :replace entries { - session_id: Uuid, - entry_id: Uuid default random_uuid_v4(), - source: String, - role: String, - name: String? 
default null, - => - content: [Json], - token_count: Int, - tokenizer: String, - created_at: Float default now(), - timestamp: Float default now(), - } - """, - "down": """ - ?[ - session_id, - entry_id, - source, - role, - name, - content, - token_count, - tokenizer, - created_at, - timestamp, - ] := *entries{ - session_id, - entry_id, - source, - role, - name, - content: content_array, - token_count, - tokenizer, - created_at, - timestamp, - }, content = json_to_scalar(get(content_array, 0, "")) - - :replace entries { - session_id: Uuid, - entry_id: Uuid default random_uuid_v4(), - source: String, - role: String, - name: String? default null, - => - content: String, - token_count: Int, - tokenizer: String, - created_at: Float default now(), - timestamp: Float default now(), - } - """, -} - - -def up(client): - client.run(update_entries["up"]) - - -def down(client): - client.run(update_entries["down"]) diff --git a/agents-api/migrations/migrate_1716939839_task_relations.py b/agents-api/migrations/migrate_1716939839_task_relations.py deleted file mode 100644 index 14a6037a1..000000000 --- a/agents-api/migrations/migrate_1716939839_task_relations.py +++ /dev/null @@ -1,87 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "task_relations" -CREATED_AT = 1716939839.690704 - - -def run(client, queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -create_task_relation_query = dict( - up=""" - :create tasks { - agent_id: Uuid, - task_id: Uuid, - updated_at_ms: Validity default [floor(now() * 1000), true], - => - name: String, - description: String? 
default null, - input_schema: Json, - tools_available: [Uuid] default [], - workflows: [Json], - created_at: Float default now(), - } - """, - down="::remove tasks", -) - -create_execution_relation_query = dict( - up=""" - :create executions { - task_id: Uuid, - execution_id: Uuid, - => - status: String default 'queued', - # one of: "queued", "starting", "running", "awaiting_input", "succeeded", "failed" - - arguments: Json, - session_id: Uuid? default null, - created_at: Float default now(), - updated_at: Float default now(), - } - """, - down="::remove executions", -) - -create_transition_relation_query = dict( - up=""" - :create transitions { - execution_id: Uuid, - transition_id: Uuid, - => - type: String, - # one of: "finish", "wait", "error", "step" - - from: (String, Int), - to: (String, Int)?, - output: Json, - - task_token: String? default null, - - # should store: an Activity Id, a Workflow Id, and optionally a Run Id. - metadata: Json default {}, - created_at: Float default now(), - updated_at: Float default now(), - } - """, - down="::remove transitions", -) - -queries = [ - create_task_relation_query, - create_execution_relation_query, - create_transition_relation_query, -] - - -def up(client): - run(client, [q["up"] for q in queries]) - - -def down(client): - run(client, [q["down"] for q in reversed(queries)]) diff --git a/agents-api/migrations/migrate_1717239610_token_budget.py b/agents-api/migrations/migrate_1717239610_token_budget.py deleted file mode 100644 index c042c56e5..000000000 --- a/agents-api/migrations/migrate_1717239610_token_budget.py +++ /dev/null @@ -1,67 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "token_budget" -CREATED_AT = 1717239610.622555 - -update_sessions = { - "up": """ - ?[developer_id, session_id, updated_at, situation, summary, created_at, metadata, render_templates, token_budget, context_overflow] := *sessions{ - developer_id, - session_id, - updated_at, - situation, - summary, - created_at, - metadata, - 
render_templates, - }, - token_budget = null, - context_overflow = null, - - :replace sessions { - developer_id: Uuid, - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? default null, - created_at: Float default now(), - metadata: Json default {}, - render_templates: Bool default false, - token_budget: Int? default null, - context_overflow: String? default null, - } - """, - "down": """ - ?[developer_id, session_id, updated_at, situation, summary, created_at, metadata, render_templates] := *sessions{ - developer_id, - session_id, - updated_at, - situation, - summary, - created_at, - metadata, - render_templates, - } - - :replace sessions { - developer_id: Uuid, - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? default null, - created_at: Float default now(), - metadata: Json default {}, - render_templates: Bool default false, - } - """, -} - - -def up(client): - client.run(update_sessions["up"]) - - -def down(client): - client.run(update_sessions["down"]) diff --git a/agents-api/migrations/migrate_1721576813_extended_tool_relations.py b/agents-api/migrations/migrate_1721576813_extended_tool_relations.py deleted file mode 100644 index 2e4583a18..000000000 --- a/agents-api/migrations/migrate_1721576813_extended_tool_relations.py +++ /dev/null @@ -1,90 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "extended_tool_relations" -CREATED_AT = 1721576813.383905 - - -drop_agent_functions_hnsw_index = dict( - up=""" - ::hnsw drop agent_functions:embedding_space - """, - down=""" - ::hnsw create agent_functions:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 768, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: false, - keep_pruned_connections: false, - } - """, -) - -create_tools_relation = dict( - up=""" - ?[agent_id, tool_id, type, name, spec, updated_at, created_at] := 
*agent_functions{ - agent_id, tool_id, name, description, parameters, updated_at, created_at - }, type = "function", - spec = {"description": description, "parameters": parameters} - - :create tools { - agent_id: Uuid, - tool_id: Uuid, - => - type: String, - name: String, - spec: Json, - - updated_at: Float default now(), - created_at: Float default now(), - } - """, - down=""" - ::remove tools - """, -) - -drop_agent_functions_table = dict( - up=""" - ::remove agent_functions - """, - down=""" - :create agent_functions { - agent_id: Uuid, - tool_id: Uuid, - => - name: String, - description: String, - parameters: Json, - embed_instruction: String default 'Transform this tool description for retrieval: ', - embedding: ? default null, - updated_at: Float default now(), - created_at: Float default now(), - } - """, -) - - -queries_to_run = [ - drop_agent_functions_hnsw_index, - create_tools_relation, - drop_agent_functions_table, -] - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -def up(client): - run(client, *[q["up"] for q in queries_to_run]) - - -def down(client): - run(client, *[q["down"] for q in reversed(queries_to_run)]) diff --git a/agents-api/migrations/migrate_1721609661_task_tool_ref_by_name.py b/agents-api/migrations/migrate_1721609661_task_tool_ref_by_name.py deleted file mode 100644 index 902ec396d..000000000 --- a/agents-api/migrations/migrate_1721609661_task_tool_ref_by_name.py +++ /dev/null @@ -1,105 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "task_tool_ref_by_name" -CREATED_AT = 1721609661.768934 - - -# - add metadata -# - add inherit_tools bool -# - rename tools_available to tools -update_tasks_relation = dict( - up=""" - ?[ - agent_id, - task_id, - updated_at_ms, - name, - description, - input_schema, - inherit_tools, - workflows, - created_at, - metadata, - ] := *tasks { - agent_id, - task_id, - updated_at_ms, - name, - description, - input_schema, - 
workflows, - created_at, - }, - metadata = {}, - inherit_tools = true - - :replace tasks { - agent_id: Uuid, - task_id: Uuid, - updated_at_ms: Validity default [floor(now() * 1000), true], - => - name: String, - description: String? default null, - input_schema: Json, - tools: [Json] default [], - inherit_tools: Bool default true, - workflows: [Json], - created_at: Float default now(), - metadata: Json default {}, - } - """, - down=""" - ?[ - agent_id, - task_id, - updated_at_ms, - name, - description, - input_schema, - workflows, - created_at, - ] := *tasks { - agent_id, - task_id, - updated_at_ms, - name, - description, - input_schema, - workflows, - created_at, - } - - :replace tasks { - agent_id: Uuid, - task_id: Uuid, - updated_at_ms: Validity default [floor(now() * 1000), true], - => - name: String, - description: String? default null, - input_schema: Json, - tools_available: [Uuid] default [], - workflows: [Json], - created_at: Float default now(), - } - """, -) - -queries_to_run = [ - update_tasks_relation, -] - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -def up(client): - run(client, *[q["up"] for q in queries_to_run]) - - -def down(client): - run(client, *[q["down"] for q in reversed(queries_to_run)]) diff --git a/agents-api/migrations/migrate_1721609675_multi_agent_multi_user_session.py b/agents-api/migrations/migrate_1721609675_multi_agent_multi_user_session.py deleted file mode 100644 index 6b144fca3..000000000 --- a/agents-api/migrations/migrate_1721609675_multi_agent_multi_user_session.py +++ /dev/null @@ -1,79 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "multi_agent_multi_user_session" -CREATED_AT = 1721609675.213755 - -add_multiple_participants_in_session = dict( - up=""" - ?[session_id, participant_id, participant_type] := - *session_lookup { - agent_id: participant_id, - user_id: null, - session_id, - }, participant_type = 'agent' - - ?[session_id, 
participant_id, participant_type] := - *session_lookup { - agent_id, - user_id: participant_id, - session_id, - }, participant_type = 'user', - participant_id != null - - :replace session_lookup { - session_id: Uuid, - participant_type: String, - participant_id: Uuid, - } - """, - down=""" - users[user_id, session_id] := - *session_lookup { - session_id, - participant_type: "user", - participant_id: user_id, - } - - agents[agent_id, session_id] := - *session_lookup { - session_id, - participant_type: "agent", - participant_id: agent_id, - } - - ?[agent_id, user_id, session_id] := - agents[agent_id, session_id], - users[user_id, session_id] - - ?[agent_id, user_id, session_id] := - agents[agent_id, session_id], - not users[_, session_id], - user_id = null - - :replace session_lookup { - agent_id: Uuid, - user_id: Uuid? default null, - session_id: Uuid, - } - """, -) - -queries_to_run = [ - add_multiple_participants_in_session, -] - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -def up(client): - run(client, *[q["up"] for q in queries_to_run]) - - -def down(client): - run(client, *[q["down"] for q in reversed(queries_to_run)]) diff --git a/agents-api/migrations/migrate_1721666295_developers_relation.py b/agents-api/migrations/migrate_1721666295_developers_relation.py deleted file mode 100644 index 560b056da..000000000 --- a/agents-api/migrations/migrate_1721666295_developers_relation.py +++ /dev/null @@ -1,32 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "developers_relation" -CREATED_AT = 1721666295.486804 - - -def up(client): - client.run( - """ - # Create developers table and insert default developer - ?[developer_id, email] <- [ - ["00000000-0000-0000-0000-000000000000", "developers@example.com"] - ] - - :create developers { - developer_id: Uuid, - => - email: String, - active: Bool default true, - created_at: Float default now(), - updated_at: Float default now(), - } - 
""" - ) - - -def down(client): - client.run( - """ - ::remove developers - """ - ) diff --git a/agents-api/migrations/migrate_1721678846_rename_information_snippets.py b/agents-api/migrations/migrate_1721678846_rename_information_snippets.py deleted file mode 100644 index a3fdd4f94..000000000 --- a/agents-api/migrations/migrate_1721678846_rename_information_snippets.py +++ /dev/null @@ -1,33 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "rename_information_snippets" -CREATED_AT = 1721678846.468865 - -rename_information_snippets = dict( - up=""" - ::rename information_snippets -> snippets - """, - down=""" - ::rename snippets -> information_snippets - """, -) - -queries_to_run = [ - rename_information_snippets, -] - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -def up(client): - run(client, *[q["up"] for q in queries_to_run]) - - -def down(client): - run(client, *[q["down"] for q in reversed(queries_to_run)]) diff --git a/agents-api/migrations/migrate_1722107354_rename_executions_arguments_col.py b/agents-api/migrations/migrate_1722107354_rename_executions_arguments_col.py deleted file mode 100644 index 9fcb3dac9..000000000 --- a/agents-api/migrations/migrate_1722107354_rename_executions_arguments_col.py +++ /dev/null @@ -1,83 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "rename_executions_arguments_col" -CREATED_AT = 1722107354.988836 - -rename_arguments_add_metadata_query = dict( - up=""" - ?[ - task_id, - execution_id, - status, - input, - session_id, - created_at, - updated_at, - metadata, - ] := - *executions{ - task_id, - execution_id, - arguments: input, - status, - session_id, - created_at, - updated_at, - }, metadata = {} - - :replace executions { - task_id: Uuid, - execution_id: Uuid, - => - status: String default 'queued', - # one of: "queued", "starting", "running", "awaiting_input", "succeeded", "failed" - - input: Json, - session_id: Uuid? 
default null, - metadata: Json default {}, - created_at: Float default now(), - updated_at: Float default now(), - } - """, - down=""" - ?[ - task_id, - execution_id, - status, - arguments, - session_id, - created_at, - updated_at, - ] := - *executions{ - task_id, - execution_id, - input: arguments, - status, - session_id, - created_at, - updated_at, - } - - :replace executions { - task_id: Uuid, - execution_id: Uuid, - => - status: String default 'queued', - # one of: "queued", "starting", "running", "awaiting_input", "succeeded", "failed" - - arguments: Json, - session_id: Uuid? default null, - created_at: Float default now(), - updated_at: Float default now(), - } - """, -) - - -def up(client): - client.run(rename_arguments_add_metadata_query["up"]) - - -def down(client): - client.run(rename_arguments_add_metadata_query["down"]) diff --git a/agents-api/migrations/migrate_1722115427_rename_transitions_from.py b/agents-api/migrations/migrate_1722115427_rename_transitions_from.py deleted file mode 100644 index 63f2660e8..000000000 --- a/agents-api/migrations/migrate_1722115427_rename_transitions_from.py +++ /dev/null @@ -1,103 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "rename_transitions_from" -CREATED_AT = 1722115427.685346 - -rename_transitions_from_to_query = dict( - up=""" - ?[ - execution_id, - transition_id, - type, - current, - next, - output, - task_token, - metadata, - created_at, - updated_at, - ] := *transitions { - execution_id, - transition_id, - type, - from: current, - to: next, - output, - task_token, - metadata, - created_at, - updated_at, - } - - :replace transitions { - execution_id: Uuid, - transition_id: Uuid, - => - type: String, - # one of: "finish", "wait", "error", "step" - - current: (String, Int), - next: (String, Int)?, - output: Json, - - task_token: String? default null, - - # should store: an Activity Id, a Workflow Id, and optionally a Run Id. 
- metadata: Json default {}, - created_at: Float default now(), - updated_at: Float default now(), - } - """, - down=""" - ?[ - execution_id, - transition_id, - type, - from, - to, - output, - task_token, - metadata, - created_at, - updated_at, - ] := *transitions { - execution_id, - transition_id, - type, - current: from, - next: to, - output, - task_token, - metadata, - created_at, - updated_at, - } - - :replace transitions { - execution_id: Uuid, - transition_id: Uuid, - => - type: String, - # one of: "finish", "wait", "error", "step" - - from: (String, Int), - to: (String, Int)?, - output: Json, - - task_token: String? default null, - - # should store: an Activity Id, a Workflow Id, and optionally a Run Id. - metadata: Json default {}, - created_at: Float default now(), - updated_at: Float default now(), - } - """, -) - - -def up(client): - client.run(rename_transitions_from_to_query["up"]) - - -def down(client): - client.run(rename_transitions_from_to_query["down"]) diff --git a/agents-api/migrations/migrate_1722710530_unify_owner_doc_relations.py b/agents-api/migrations/migrate_1722710530_unify_owner_doc_relations.py deleted file mode 100644 index a56bce674..000000000 --- a/agents-api/migrations/migrate_1722710530_unify_owner_doc_relations.py +++ /dev/null @@ -1,204 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "unify_owner_doc_relations" -CREATED_AT = 1722710530.126563 - -create_docs_relations_query = dict( - up=""" - :create docs { - owner_type: String, - owner_id: Uuid, - doc_id: Uuid, - => - title: String, - created_at: Float default now(), - metadata: Json default {}, - } - """, - down="::remove docs", -) - -remove_user_docs_table = dict( - up=""" - doc_title[doc_id, unique(title)] := - *snippets { - doc_id, - title, - } - - ?[owner_type, owner_id, doc_id, title, created_at, metadata] := - owner_type = "user", - *user_docs { - user_id: owner_id, - doc_id, - created_at, - metadata, - }, - doc_title[doc_id, title] - - :insert docs { - owner_type, - 
owner_id, - doc_id, - title, - created_at, - metadata, - } - - } { # <-- this is just a separator between the two queries - ::remove user_docs - """, - down=""" - :create user_docs { - user_id: Uuid, - doc_id: Uuid - => - created_at: Float default now(), - metadata: Json default {}, - } - """, -) - -remove_agent_docs_table = dict( - up=remove_user_docs_table["up"].replace("user", "agent"), - down=remove_user_docs_table["down"].replace("user", "agent"), -) - -# See: https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md -snippets_hnsw_index = dict( - up=""" - ::hnsw create snippets:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 1024, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: true, - keep_pruned_connections: false, - } - """, - down=""" - ::hnsw drop snippets:embedding_space - """, -) - -# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts -snippets_fts_index = dict( - up=""" - ::fts create snippets:fts { - extractor: content, - tokenizer: Simple, - filters: [Lowercase, Stemmer('english'), Stopwords('en')], - } - """, - down=""" - ::fts drop snippets:fts - """, -) - -temp_rename_snippets_table = dict( - up=""" - ::rename snippets -> information_snippets - """, - down=""" - ::rename information_snippets -> snippets - """, -) - -temp_rename_snippets_table_back = dict( - up=temp_rename_snippets_table["down"], - down=temp_rename_snippets_table["up"], -) - -drop_snippets_hnsw_index = { - "up": snippets_hnsw_index["down"].replace("snippets:", "information_snippets:"), - "down": snippets_hnsw_index["up"].replace("snippets:", "information_snippets:"), -} - -drop_snippets_fts_index = dict( - up=""" - ::fts drop information_snippets:fts - """, - down=""" - ::fts create information_snippets:fts { - extractor: concat(title, ' ', snippet), - tokenizer: Simple, - filters: [Lowercase, Stemmer('english'), Stopwords('en')], - } - """, -) - - -remove_title_from_snippets_table = dict( - up=""" - 
?[doc_id, index, content, embedding] := - *snippets { - doc_id, - snippet_idx: index, - snippet: content, - embedding, - } - - :replace snippets { - doc_id: Uuid, - index: Int, - => - content: String, - embedding: ? default null, - } - """, - down=""" - ?[doc_id, snippet_idx, title, snippet, embedding] := - *snippets { - doc_id, - index: snippet_idx, - content: snippet, - embedding, - }, - *docs { - doc_id, - title, - } - - :replace snippets { - doc_id: Uuid, - snippet_idx: Int, - => - title: String, - snippet: String, - embed_instruction: String default 'Encode this passage for retrieval: ', - embedding: ? default null, - } - """, -) - -queries = [ - create_docs_relations_query, - remove_user_docs_table, - remove_agent_docs_table, - temp_rename_snippets_table, # Because of a bug in Cozo - drop_snippets_hnsw_index, - drop_snippets_fts_index, - temp_rename_snippets_table_back, # Because of a bug in Cozo - remove_title_from_snippets_table, - snippets_fts_index, - snippets_hnsw_index, -] - - -def run(client, queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - - client.run(query) - - -def up(client): - run(client, [q["up"] for q in queries]) - - -def down(client): - run(client, [q["down"] for q in reversed(queries)]) diff --git a/agents-api/migrations/migrate_1722875101_add_temporal_mapping.py b/agents-api/migrations/migrate_1722875101_add_temporal_mapping.py deleted file mode 100644 index b38a3717c..000000000 --- a/agents-api/migrations/migrate_1722875101_add_temporal_mapping.py +++ /dev/null @@ -1,40 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "add_temporal_mapping" -CREATED_AT = 1722875101.262791 - - -def run(client, queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -create_temporal_executions_lookup = dict( - up=""" - :create temporal_executions_lookup { - execution_id: Uuid, - id: String, - => - run_id: String?, - first_execution_run_id: String?, - 
result_run_id: String?, - created_at: Float default now(), - } - """, - down="::remove temporal_executions_lookup", -) - -queries = [ - create_temporal_executions_lookup, -] - - -def up(client): - run(client, [q["up"] for q in queries]) - - -def down(client): - run(client, [q["down"] for q in reversed(queries)]) diff --git a/agents-api/migrations/migrate_1723307805_add_lsh_index_to_docs.py b/agents-api/migrations/migrate_1723307805_add_lsh_index_to_docs.py deleted file mode 100644 index 01eaa8a60..000000000 --- a/agents-api/migrations/migrate_1723307805_add_lsh_index_to_docs.py +++ /dev/null @@ -1,44 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "add_lsh_index_to_docs" -CREATED_AT = 1723307805.007054 - -# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts -snippets_lsh_index = dict( - up=""" - ::lsh create snippets:lsh { - extractor: content, - tokenizer: Simple, - filters: [Stopwords('en')], - n_perm: 200, - target_threshold: 0.9, - n_gram: 3, - false_positive_weight: 1.0, - false_negative_weight: 1.0, - } - """, - down=""" - ::lsh drop snippets:lsh - """, -) - -queries = [ - snippets_lsh_index, -] - - -def run(client, queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - - client.run(query) - - -def up(client): - run(client, [q["up"] for q in queries]) - - -def down(client): - run(client, [q["down"] for q in reversed(queries)]) diff --git a/agents-api/migrations/migrate_1723400730_add_settings_to_developers.py b/agents-api/migrations/migrate_1723400730_add_settings_to_developers.py deleted file mode 100644 index e10e71510..000000000 --- a/agents-api/migrations/migrate_1723400730_add_settings_to_developers.py +++ /dev/null @@ -1,68 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "add_settings_to_developers" -CREATED_AT = 1723400730.539554 - - -def up(client): - client.run( - """ - ?[ - developer_id, - email, - active, - tags, - settings, - created_at, - updated_at, - ] := *developers { - 
developer_id, - email, - active, - created_at, - updated_at, - }, - tags = [], - settings = {} - - :replace developers { - developer_id: Uuid, - => - email: String, - active: Bool default true, - tags: [String] default [], - settings: Json, - created_at: Float default now(), - updated_at: Float default now(), - } - """ - ) - - -def down(client): - client.run( - """ - ?[ - developer_id, - email, - active, - created_at, - updated_at, - ] := *developers { - developer_id, - email, - active, - created_at, - updated_at, - } - - :replace developers { - developer_id: Uuid, - => - email: String, - active: Bool default true, - created_at: Float default now(), - updated_at: Float default now(), - } - """ - ) diff --git a/agents-api/migrations/migrate_1725153437_add_output_to_executions.py b/agents-api/migrations/migrate_1725153437_add_output_to_executions.py deleted file mode 100644 index 8118e4f89..000000000 --- a/agents-api/migrations/migrate_1725153437_add_output_to_executions.py +++ /dev/null @@ -1,104 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "add_output_to_executions" -CREATED_AT = 1725153437.489542 - - -def run(client, queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -add_output_to_executions_query = dict( - up=""" - ?[ - task_id, - execution_id, - status, - input, - session_id, - created_at, - updated_at, - output, - error, - metadata, - ] := - *executions { - task_id, - execution_id, - status, - input, - session_id, - created_at, - updated_at, - }, - output = null, - error = null, - metadata = {} - - :replace executions { - task_id: Uuid, - execution_id: Uuid, - => - status: String default 'queued', - # one of: "queued", "starting", "running", "awaiting_input", "succeeded", "failed" - - input: Json, - output: Json? default null, - error: String? default null, - session_id: Uuid? 
default null, - metadata: Json default {}, - created_at: Float default now(), - updated_at: Float default now(), - } - """, - down=""" - ?[ - task_id, - execution_id, - status, - input, - session_id, - created_at, - updated_at, - ] := - *executions { - task_id, - execution_id, - status, - input, - session_id, - created_at, - updated_at, - } - - :replace executions { - task_id: Uuid, - execution_id: Uuid, - => - status: String default 'queued', - # one of: "queued", "starting", "running", "awaiting_input", "succeeded", "failed" - - input: Json, - session_id: Uuid? default null, - created_at: Float default now(), - updated_at: Float default now(), - } - """, -) - - -queries = [ - add_output_to_executions_query, -] - - -def up(client): - run(client, [q["up"] for q in queries]) - - -def down(client): - run(client, [q["down"] for q in reversed(queries)]) diff --git a/agents-api/migrations/migrate_1725323734_make_transition_output_optional.py b/agents-api/migrations/migrate_1725323734_make_transition_output_optional.py deleted file mode 100644 index dd13c3132..000000000 --- a/agents-api/migrations/migrate_1725323734_make_transition_output_optional.py +++ /dev/null @@ -1,109 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "make_transition_output_optional" -CREATED_AT = 1725323734.591567 - - -def run(client, queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -make_transition_output_optional_query = dict( - up=""" - ?[ - execution_id, - transition_id, - output, - type, - current, - next, - task_token, - metadata, - created_at, - updated_at, - ] := - *transitions { - execution_id, - transition_id, - output, - type, - current, - next, - task_token, - metadata, - created_at, - updated_at, - } - - :replace transitions { - execution_id: Uuid, - transition_id: Uuid, - => - type: String, - current: (String, Int), - next: (String, Int)?, - output: Json?, # <--- this is the only change; output is now optional - 
task_token: String? default null, - metadata: Json default {}, - created_at: Float default now(), - updated_at: Float default now(), - } - """, - down=""" - ?[ - execution_id, - transition_id, - output, - type, - current, - next, - task_token, - metadata, - created_at, - updated_at, - ] := - *transitions { - execution_id, - transition_id, - output, - type, - current, - next, - task_token, - metadata, - created_at, - updated_at, - } - - :replace transitions { - execution_id: Uuid, - transition_id: Uuid, - => - type: String, - current: (String, Int), - next: (String, Int)?, - output: Json, - task_token: String? default null, - metadata: Json default {}, - created_at: Float default now(), - updated_at: Float default now(), - } - """, -) - - -queries = [ - make_transition_output_optional_query, -] - - -def up(client): - run(client, [q["up"] for q in queries]) - - -def down(client): - run(client, [q["down"] for q in reversed(queries)]) diff --git a/agents-api/migrations/migrate_1727235852_add_forward_tool_calls_option.py b/agents-api/migrations/migrate_1727235852_add_forward_tool_calls_option.py deleted file mode 100644 index aa1b8441a..000000000 --- a/agents-api/migrations/migrate_1727235852_add_forward_tool_calls_option.py +++ /dev/null @@ -1,87 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "add_forward_tool_calls_option" -CREATED_AT = 1727235852.744035 - - -def run(client, queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -add_forward_tool_calls_option_to_session_query = dict( - up=""" - ?[forward_tool_calls, token_budget, context_overflow, developer_id, session_id, updated_at, situation, summary, created_at, metadata, render_templates, token_budget, context_overflow] := *sessions{ - developer_id, - session_id, - updated_at, - situation, - summary, - created_at, - metadata, - render_templates, - token_budget, - context_overflow, - }, - forward_tool_calls = null - - :replace sessions { - 
developer_id: Uuid, - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? default null, - created_at: Float default now(), - metadata: Json default {}, - render_templates: Bool default false, - token_budget: Int? default null, - context_overflow: String? default null, - forward_tool_calls: Bool? default null, - } - """, - down=""" - ?[token_budget, context_overflow, developer_id, session_id, updated_at, situation, summary, created_at, metadata, render_templates, token_budget, context_overflow] := *sessions{ - developer_id, - session_id, - updated_at, - situation, - summary, - created_at, - metadata, - render_templates, - token_budget, - context_overflow, - } - - :replace sessions { - developer_id: Uuid, - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? default null, - created_at: Float default now(), - metadata: Json default {}, - render_templates: Bool default false, - token_budget: Int? default null, - context_overflow: String? 
default null, - } - """, -) - - -queries = [ - add_forward_tool_calls_option_to_session_query, -] - - -def up(client): - run(client, [q["up"] for q in queries]) - - -def down(client): - run(client, [q["down"] for q in reversed(queries)]) diff --git a/agents-api/migrations/migrate_1727922523_add_description_to_tools.py b/agents-api/migrations/migrate_1727922523_add_description_to_tools.py deleted file mode 100644 index 1d6724090..000000000 --- a/agents-api/migrations/migrate_1727922523_add_description_to_tools.py +++ /dev/null @@ -1,64 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "add_description_to_tools" -CREATED_AT = 1727922523.283493 - - -add_description_to_tools = dict( - up=""" - ?[agent_id, tool_id, type, name, description, spec, updated_at, created_at] := *tools { - agent_id, tool_id, type, name, spec, updated_at, created_at - }, description = null - - :replace tools { - agent_id: Uuid, - tool_id: Uuid, - => - type: String, - name: String, - description: String?, - spec: Json, - - updated_at: Float default now(), - created_at: Float default now(), - } - """, - down=""" - ?[agent_id, tool_id, type, name, spec, updated_at, created_at] := *tools { - agent_id, tool_id, type, name, spec, updated_at, created_at - } - - :replace tools { - agent_id: Uuid, - tool_id: Uuid, - => - type: String, - name: String, - spec: Json, - - updated_at: Float default now(), - created_at: Float default now(), - } - """, -) - - -queries_to_run = [ - add_description_to_tools, -] - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -def up(client): - run(client, *[q["up"] for q in queries_to_run]) - - -def down(client): - run(client, *[q["down"] for q in reversed(queries_to_run)]) diff --git a/agents-api/migrations/migrate_1729114011_tweak_proximity_indices.py b/agents-api/migrations/migrate_1729114011_tweak_proximity_indices.py deleted file mode 100644 index 4852f3603..000000000 --- 
a/agents-api/migrations/migrate_1729114011_tweak_proximity_indices.py +++ /dev/null @@ -1,133 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "tweak_proximity_indices" -CREATED_AT = 1729114011.022733 - - -def run(client, *queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -# See: https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md -drop_snippets_hnsw_index = dict( - down=""" - ::hnsw create snippets:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 1024, - distance: Cosine, - m: 64, - ef_construction: 256, - extend_candidates: true, - keep_pruned_connections: false, - } - """, - up=""" - ::hnsw drop snippets:embedding_space - """, -) - - -# See: https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md -snippets_hnsw_index = dict( - up=""" - ::hnsw create snippets:embedding_space { - fields: [embedding], - filter: !is_null(embedding), - dim: 1024, - distance: Cosine, - m: 64, - ef_construction: 800, - extend_candidates: false, - keep_pruned_connections: false, - } - """, - down=""" - ::hnsw drop snippets:embedding_space - """, -) - -drop_snippets_lsh_index = dict( - up=""" - ::lsh drop snippets:lsh - """, - down=""" - ::lsh create snippets:lsh { - extractor: content, - tokenizer: Simple, - filters: [Stopwords('en')], - n_perm: 200, - target_threshold: 0.9, - n_gram: 3, - false_positive_weight: 1.0, - false_negative_weight: 1.0, - } - """, -) - -snippets_lsh_index = dict( - up=""" - ::lsh create snippets:lsh { - extractor: content, - tokenizer: Simple, - filters: [Lowercase, AsciiFolding, Stemmer('english'), Stopwords('en')], - n_perm: 200, - target_threshold: 0.5, - n_gram: 2, - false_positive_weight: 1.0, - false_negative_weight: 1.0, - } - """, - down=""" - ::lsh drop snippets:lsh - """, -) - -# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts -drop_snippets_fts_index = dict( - down=""" - ::fts create snippets:fts { - extractor: 
content, - tokenizer: Simple, - filters: [Lowercase, Stemmer('english'), Stopwords('en')], - } - """, - up=""" - ::fts drop snippets:fts - """, -) - -# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts -snippets_fts_index = dict( - up=""" - ::fts create snippets:fts { - extractor: content, - tokenizer: Simple, - filters: [Lowercase, AsciiFolding, Stemmer('english'), Stopwords('en')], - } - """, - down=""" - ::fts drop snippets:fts - """, -) - -queries_to_run = [ - drop_snippets_hnsw_index, - drop_snippets_lsh_index, - drop_snippets_fts_index, - snippets_hnsw_index, - snippets_lsh_index, - snippets_fts_index, -] - - -def up(client): - run(client, *[q["up"] for q in queries_to_run]) - - -def down(client): - run(client, *[q["down"] for q in reversed(queries_to_run)]) diff --git a/agents-api/migrations/migrate_1731143165_support_tool_call_id.py b/agents-api/migrations/migrate_1731143165_support_tool_call_id.py deleted file mode 100644 index 9faf4d577..000000000 --- a/agents-api/migrations/migrate_1731143165_support_tool_call_id.py +++ /dev/null @@ -1,100 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "support_tool_call_id" -CREATED_AT = 1731143165.95882 - -update_entries = { - "down": """ - ?[ - session_id, - entry_id, - source, - role, - name, - content, - token_count, - tokenizer, - created_at, - timestamp, - ] := *entries{ - session_id, - entry_id, - source, - role, - name, - content: content_string, - token_count, - tokenizer, - created_at, - timestamp, - }, content = [{"type": "text", "content": content_string}] - - :replace entries { - session_id: Uuid, - entry_id: Uuid default random_uuid_v4(), - source: String, - role: String, - name: String? 
default null, - => - content: [Json], - token_count: Int, - tokenizer: String, - created_at: Float default now(), - timestamp: Float default now(), - } - """, - "up": """ - ?[ - session_id, - entry_id, - source, - role, - name, - content, - token_count, - tokenizer, - created_at, - timestamp, - tool_call_id, - tool_calls, - ] := *entries{ - session_id, - entry_id, - source, - role, - name, - content: content_string, - token_count, - tokenizer, - created_at, - timestamp, - }, - content = [{"type": "text", "content": content_string}], - tool_call_id = null, - tool_calls = null - - :replace entries { - session_id: Uuid, - entry_id: Uuid default random_uuid_v4(), - source: String, - role: String, - name: String? default null, - => - content: [Json], - tool_call_id: String? default null, - tool_calls: [Json]? default null, - token_count: Int, - tokenizer: String, - created_at: Float default now(), - timestamp: Float default now(), - } - """, -} - - -def up(client): - client.run(update_entries["up"]) - - -def down(client): - client.run(update_entries["down"]) diff --git a/agents-api/migrations/migrate_1731953383_create_files_relation.py b/agents-api/migrations/migrate_1731953383_create_files_relation.py deleted file mode 100644 index 9cdc4f8fe..000000000 --- a/agents-api/migrations/migrate_1731953383_create_files_relation.py +++ /dev/null @@ -1,29 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "create_files_relation" -CREATED_AT = 1731953383.258172 - -create_files_query = dict( - up=""" - :create files { - developer_id: Uuid, - file_id: Uuid, - => - name: String, - description: String default "", - mime_type: String? 
default null, - size: Int, - hash: String, - created_at: Float default now(), - } - """, - down="::remove files", -) - - -def up(client): - client.run(create_files_query["up"]) - - -def down(client): - client.run(create_files_query["down"]) diff --git a/agents-api/migrations/migrate_1733493650_add_recall_options_to_sessions.py b/agents-api/migrations/migrate_1733493650_add_recall_options_to_sessions.py deleted file mode 100644 index ba0be5d2b..000000000 --- a/agents-api/migrations/migrate_1733493650_add_recall_options_to_sessions.py +++ /dev/null @@ -1,91 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "add_recall_options_to_sessions" -CREATED_AT = 1733493650.922383 - - -def run(client, queries): - joiner = "}\n\n{" - - query = joiner.join(queries) - query = f"{{\n{query}\n}}" - client.run(query) - - -add_recall_options_to_sessions_query = dict( - up=""" - ?[recall_options, forward_tool_calls, token_budget, context_overflow, developer_id, session_id, updated_at, situation, summary, created_at, metadata, render_templates, token_budget, context_overflow] := *sessions{ - developer_id, - session_id, - updated_at, - situation, - summary, - created_at, - metadata, - render_templates, - token_budget, - context_overflow, - forward_tool_calls, - }, - recall_options = {}, - - :replace sessions { - developer_id: Uuid, - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? default null, - created_at: Float default now(), - metadata: Json default {}, - render_templates: Bool default false, - token_budget: Int? default null, - context_overflow: String? default null, - forward_tool_calls: Bool? 
default null, - recall_options: Json default {}, - } - """, - down=""" - ?[forward_tool_calls, token_budget, context_overflow, developer_id, session_id, updated_at, situation, summary, created_at, metadata, render_templates, token_budget, context_overflow] := *sessions{ - developer_id, - session_id, - updated_at, - situation, - summary, - created_at, - metadata, - render_templates, - token_budget, - context_overflow, - }, - forward_tool_calls = null - - :replace sessions { - developer_id: Uuid, - session_id: Uuid, - updated_at: Validity default [floor(now()), true], - => - situation: String, - summary: String? default null, - created_at: Float default now(), - metadata: Json default {}, - render_templates: Bool default false, - token_budget: Int? default null, - context_overflow: String? default null, - forward_tool_calls: Bool? default null, - } - """, -) - - -queries = [ - add_recall_options_to_sessions_query, -] - - -def up(client): - run(client, [q["up"] for q in queries]) - - -def down(client): - run(client, [q["down"] for q in reversed(queries)]) diff --git a/agents-api/migrations/migrate_1733755642_transition_indices.py b/agents-api/migrations/migrate_1733755642_transition_indices.py deleted file mode 100644 index 1b33f4646..000000000 --- a/agents-api/migrations/migrate_1733755642_transition_indices.py +++ /dev/null @@ -1,42 +0,0 @@ -# /usr/bin/env python3 - -MIGRATION_ID = "transition_indices" -CREATED_AT = 1733755642.881131 - - -create_transition_indices = dict( - up=[ - "::index create executions:execution_id_status_idx { execution_id, status }", - "::index create executions:execution_id_task_id_idx { execution_id, task_id }", - "::index create executions:task_id_execution_id_idx { task_id, execution_id }", - "::index create tasks:task_id_agent_id_idx { task_id, agent_id }", - "::index create agents:agent_id_developer_id_idx { agent_id, developer_id }", - "::index create sessions:session_id_developer_id_idx { session_id, developer_id }", - "::index create 
docs:owner_id_metadata_doc_id_idx { owner_id, metadata, doc_id }", - "::index create agents:developer_id_metadata_agent_id_idx { developer_id, metadata, agent_id }", - "::index create users:developer_id_metadata_user_id_idx { developer_id, metadata, user_id }", - "::index create transitions:execution_id_type_created_at_idx { execution_id, type, created_at }", - ], - down=[ - "::index drop executions:execution_id_status_idx", - "::index drop executions:execution_id_task_id_idx", - "::index drop executions:task_id_execution_id_idx", - "::index drop tasks:task_id_agent_id_idx", - "::index drop agents:agent_id_developer_id_idx", - "::index drop sessions:session_id_developer_id_idx", - "::index drop docs:owner_id_metadata_doc_id_idx", - "::index drop agents:developer_id_metadata_agent_id_idx", - "::index drop users:developer_id_metadata_user_id_idx", - "::index drop transitions:execution_id_type_created_at_idx", - ], -) - - -def up(client): - for q in create_transition_indices["up"]: - client.run(q) - - -def down(client): - for q in create_transition_indices["down"]: - client.run(q) diff --git a/agents-api/poe_tasks.toml b/agents-api/poe_tasks.toml index 60fa533f7..beeb234c1 100644 --- a/agents-api/poe_tasks.toml +++ b/agents-api/poe_tasks.toml @@ -1,10 +1,12 @@ [tasks] format = "ruff format" -lint = "ruff check --select I --fix --unsafe-fixes agents_api/**/*.py migrations/**/*.py tests/**/*.py" +lint = "ruff check" typecheck = "pytype --config pytype.toml" +validate-sql = "sqlvalidator --verbose-validate agents_api/" check = [ "lint", "format", + "validate-sql", "typecheck", ] codegen = """ diff --git a/agents-api/pyproject.toml b/agents-api/pyproject.toml index 677abd678..54028c9a1 100644 --- a/agents-api/pyproject.toml +++ b/agents-api/pyproject.toml @@ -31,8 +31,6 @@ dependencies = [ "pandas~=2.2.2", "prometheus-client~=0.21.0", "prometheus-fastapi-instrumentator~=7.0.0", - "pycozo-async~=0.7.7", - "pycozo[embedded]~=0.7.6", "pydantic-partial~=0.5.5", 
"pydantic[email]~=2.10.2", "python-box~=7.2.0", @@ -50,24 +48,34 @@ dependencies = [ "uvloop~=0.21.0", "xxhash~=3.5.0", "spacy-chunks>=0.0.2", + "uuid7>=0.1.0", + "asyncpg>=0.30.0", + "unique-namer>=1.6.1", ] [dependency-groups] dev = [ - "cozo-migrate>=0.2.4", "datamodel-code-generator>=0.26.3", "ipython>=8.30.0", "ipywidgets>=8.1.5", "julep>=1.43.1", "jupyterlab>=4.3.1", + "pip>=24.3.1", "poethepoet>=0.31.1", "pyjwt>=2.10.1", - "pyright>=1.1.389", + "pyright>=1.1.391", "pytype>=2024.10.11", - "ruff>=0.8.1", + "ruff>=0.8.4", + "sqlvalidator>=0.0.20", + "testcontainers[postgres,localstack]>=4.9.0", "ward>=0.68.0b0", ] +[tool.setuptools] +py-modules = [ + "agents_api", +] + [tool.uv.sources] litellm = { url = "https://github.com/julep-ai/litellm/archive/fix_anthropic_tool_image_content.zip" } diff --git a/agents-api/scripts/agents_api.py b/agents-api/scripts/agents_api.py index 5bacef0c8..8ab7d2e0c 100644 --- a/agents-api/scripts/agents_api.py +++ b/agents-api/scripts/agents_api.py @@ -1,5 +1,4 @@ import fire - from agents_api.web import main diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 2ed346892..72e8f4d7e 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -1,12 +1,8 @@ -import time -from uuid import UUID, uuid4 - -from cozo_migrate.api import apply, init -from fastapi.testclient import TestClient -from pycozo import Client as CozoClient -from pycozo_async import Client as AsyncCozoClient -from temporalio.client import WorkflowHandle -from ward import fixture +import os +import random +import string +import sys +from uuid import UUID from agents_api.autogen.openapi_model import ( CreateAgentRequest, @@ -19,351 +15,293 @@ CreateTransitionRequest, CreateUserRequest, ) +from agents_api.clients.pg import create_db_pool from agents_api.env import api_key, api_key_header_name, multi_tenant_mode -from agents_api.models.agent.create_agent import create_agent -from agents_api.models.agent.delete_agent import 
delete_agent -from agents_api.models.developer.get_developer import get_developer -from agents_api.models.docs.create_doc import create_doc -from agents_api.models.docs.delete_doc import delete_doc -from agents_api.models.execution.create_execution import create_execution -from agents_api.models.execution.create_execution_transition import ( +from agents_api.queries.agents.create_agent import create_agent +from agents_api.queries.developers.create_developer import create_developer +from agents_api.queries.developers.get_developer import get_developer +from agents_api.queries.docs.create_doc import create_doc +from agents_api.queries.docs.get_doc import get_doc +from agents_api.queries.executions.create_execution import create_execution +from agents_api.queries.executions.create_execution_transition import ( create_execution_transition, ) -from agents_api.models.execution.create_temporal_lookup import create_temporal_lookup -from agents_api.models.files.create_file import create_file -from agents_api.models.files.delete_file import delete_file -from agents_api.models.session.create_session import create_session -from agents_api.models.session.delete_session import delete_session -from agents_api.models.task.create_task import create_task -from agents_api.models.task.delete_task import delete_task -from agents_api.models.tools.create_tools import create_tools -from agents_api.models.tools.delete_tool import delete_tool -from agents_api.models.user.create_user import create_user -from agents_api.models.user.delete_user import delete_user +from agents_api.queries.executions.create_temporal_lookup import create_temporal_lookup +from agents_api.queries.files.create_file import create_file +from agents_api.queries.sessions.create_session import create_session +from agents_api.queries.tasks.create_task import create_task +from agents_api.queries.tools.create_tools import create_tools +from agents_api.queries.users.create_user import create_user from agents_api.web import 
app -from tests.utils import ( - patch_embed_acompletion as patch_embed_acompletion_ctx, +from aiobotocore.session import get_session +from fastapi.testclient import TestClient +from temporalio.client import WorkflowHandle +from uuid_extensions import uuid7 +from ward import fixture + +from .utils import ( + get_localstack, + get_pg_dsn, ) -from tests.utils import ( - patch_s3_client, +from .utils import ( + patch_embed_acompletion as patch_embed_acompletion_ctx, ) -EMBEDDING_SIZE: int = 1024 - @fixture(scope="global") -def cozo_client(migrations_dir: str = "./migrations"): - # Create a new client for each test - # and initialize the schema. - client = CozoClient() - - setattr(app.state, "cozo_client", client) - - init(client) - apply(client, migrations_dir=migrations_dir, all_=True) +def pg_dsn(): + with get_pg_dsn() as pg_dsn: + os.environ["PG_DSN"] = pg_dsn - return client + try: + yield pg_dsn + finally: + del os.environ["PG_DSN"] @fixture(scope="global") -def cozo_clients_with_migrations(sync_client=cozo_client): - async_client = AsyncCozoClient() - async_client.embedded = sync_client.embedded - setattr(app.state, "async_cozo_client", async_client) - - return sync_client, async_client - - -@fixture(scope="global") -def async_cozo_client(migrations_dir: str = "./migrations"): - # Create a new client for each test - # and initialize the schema. 
- client = AsyncCozoClient() - migrations_client = CozoClient() - setattr(migrations_client, "embedded", client.embedded) - - setattr(app.state, "async_cozo_client", client) - - init(migrations_client) - apply(migrations_client, migrations_dir=migrations_dir, all_=True) - - return client - - -@fixture(scope="global") -def test_developer_id(cozo_client=cozo_client): +def test_developer_id(): if not multi_tenant_mode: - yield UUID(int=0) - return - - developer_id = uuid4() + return UUID(int=0) - cozo_client.run( - f""" - ?[developer_id, email, settings] <- [["{str(developer_id)}", "developers@julep.ai", {{}}]] - :insert developers {{ developer_id, email, settings }} - """ - ) - - yield developer_id - - cozo_client.run( - f""" - ?[developer_id, email] <- [["{str(developer_id)}", "developers@julep.ai"]] - :delete developers {{ developer_id, email }} - """ - ) + return uuid7() @fixture(scope="global") -def test_file(client=cozo_client, developer_id=test_developer_id): - file = create_file( +async def test_developer(dsn=pg_dsn, developer_id=test_developer_id): + pool = await create_db_pool(dsn=dsn) + return await get_developer( developer_id=developer_id, - data=CreateFileRequest( - name="Hello", - description="World", - mime_type="text/plain", - content="eyJzYW1wbGUiOiAidGVzdCJ9", - ), - client=client, - ) - - yield file - - delete_file( - developer_id=developer_id, - file_id=file.id, - client=client, - ) - - -@fixture(scope="global") -def test_developer(cozo_client=cozo_client, developer_id=test_developer_id): - return get_developer( - developer_id=developer_id, - client=cozo_client, + connection_pool=pool, ) @fixture(scope="test") def patch_embed_acompletion(): output = {"role": "assistant", "content": "Hello, world!"} - with patch_embed_acompletion_ctx(output) as (embed, acompletion): yield embed, acompletion -@fixture(scope="global") -def test_agent(cozo_client=cozo_client, developer_id=test_developer_id): - agent = create_agent( - developer_id=developer_id, 
+@fixture(scope="test") +async def test_agent(dsn=pg_dsn, developer=test_developer): + pool = await create_db_pool(dsn=dsn) + + return await create_agent( + developer_id=developer.id, data=CreateAgentRequest( model="gpt-4o-mini", name="test agent", about="test agent about", metadata={"test": "test"}, ), - client=cozo_client, + connection_pool=pool, ) - yield agent - - delete_agent( - developer_id=developer_id, - agent_id=agent.id, - client=cozo_client, - ) +@fixture(scope="test") +async def test_user(dsn=pg_dsn, developer=test_developer): + pool = await create_db_pool(dsn=dsn) -@fixture(scope="global") -def test_user(cozo_client=cozo_client, developer_id=test_developer_id): - user = create_user( - developer_id=developer_id, + return await create_user( + developer_id=developer.id, data=CreateUserRequest( name="test user", about="test user about", ), - client=cozo_client, + connection_pool=pool, ) - yield user - delete_user( - developer_id=developer_id, - user_id=user.id, - client=cozo_client, +@fixture(scope="test") +async def test_file(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + return await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Hello", + description="World", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + connection_pool=pool, ) -@fixture(scope="global") -def test_session( - cozo_client=cozo_client, - developer_id=test_developer_id, - test_user=test_user, - test_agent=test_agent, -): - session = create_session( - developer_id=developer_id, - data=CreateSessionRequest( - agent=test_agent.id, user=test_user.id, metadata={"test": "test"} +@fixture(scope="test") +async def test_doc(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + resp = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Hello", + content=["World", "World2", "World3"], + metadata={"test": "test"}, + 
embed_instruction="Embed the document", ), - client=cozo_client, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, ) + return await get_doc(developer_id=developer.id, doc_id=resp.id, connection_pool=pool) - yield session - delete_session( - developer_id=developer_id, - session_id=session.id, - client=cozo_client, +@fixture(scope="test") +async def test_task(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + return await create_task( + developer_id=developer.id, + agent_id=agent.id, + task_id=uuid7(), + data=CreateTaskRequest( + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[{"evaluate": {"hi": "_"}}], + metadata={"test": True}, + ), + connection_pool=pool, ) -@fixture(scope="global") -def test_doc( - client=cozo_client, - developer_id=test_developer_id, - agent=test_agent, -): - doc = create_doc( - developer_id=developer_id, - owner_type="agent", - owner_id=agent.id, - data=CreateDocRequest(title="Hello", content=["World"]), - client=client, - ) +@fixture(scope="test") +async def random_email(): + return f"{''.join([random.choice(string.ascii_lowercase) for _ in range(10)])}@mail.com" - time.sleep(0.5) - yield doc +@fixture(scope="test") +async def test_new_developer(dsn=pg_dsn, email=random_email): + pool = await create_db_pool(dsn=dsn) + dev_id = uuid7() + await create_developer( + email=email, + active=True, + tags=["tag1"], + settings={"key1": "val1"}, + developer_id=dev_id, + connection_pool=pool, + ) - delete_doc( - developer_id=developer_id, - doc_id=doc.id, - owner_type="agent", - owner_id=agent.id, - client=client, + return await get_developer( + developer_id=dev_id, + connection_pool=pool, ) -@fixture(scope="global") -def test_user_doc( - client=cozo_client, +@fixture(scope="test") +async def test_session( + dsn=pg_dsn, developer_id=test_developer_id, - user=test_user, + test_user=test_user, + test_agent=test_agent, 
): - doc = create_doc( - developer_id=developer_id, - owner_type="user", - owner_id=user.id, - data=CreateDocRequest(title="Hello", content=["World"]), - client=client, - ) - - time.sleep(0.5) + pool = await create_db_pool(dsn=dsn) - yield doc - - delete_doc( + return await create_session( developer_id=developer_id, - doc_id=doc.id, - owner_type="user", - owner_id=user.id, - client=client, + data=CreateSessionRequest( + agent=test_agent.id, + user=test_user.id, + metadata={"test": "test"}, + system_template="test system template", + ), + connection_pool=pool, ) @fixture(scope="global") -def test_task( - client=cozo_client, +async def test_user_doc( + dsn=pg_dsn, developer_id=test_developer_id, - agent=test_agent, + user=test_user, ): - task = create_task( + pool = await create_db_pool(dsn=dsn) + doc = await create_doc( developer_id=developer_id, - agent_id=agent.id, - data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [{"evaluate": {"hello": '"world"'}}], - } - ), - client=client, + owner_type="user", + owner_id=user.id, + data=CreateDocRequest(title="Hello", content=["World"]), + connection_pool=pool, ) + yield doc - yield task - delete_task( - developer_id=developer_id, - task_id=task.id, - client=client, - ) +# @fixture(scope="global") +# async def test_task( +# dsn=pg_dsn, +# developer_id=test_developer_id, +# agent=test_agent, +# ): +# async with get_pg_client(dsn=dsn) as client: +# task = await create_task( +# developer_id=developer_id, +# agent_id=agent.id, +# data=CreateTaskRequest( +# **{ +# "name": "test task", +# "description": "test task about", +# "input_schema": {"type": "object", "additionalProperties": True}, +# "main": [{"evaluate": {"hello": '"world"'}}], +# } +# ), +# client=client, +# ) +# yield task @fixture(scope="global") -def test_execution( - client=cozo_client, +async def test_execution( + dsn=pg_dsn, 
developer_id=test_developer_id, task=test_task, ): + pool = await create_db_pool(dsn=dsn) workflow_handle = WorkflowHandle( client=None, id="blah", ) - execution = create_execution( + execution = await create_execution( developer_id=developer_id, task_id=task.id, data=CreateExecutionRequest(input={"test": "test"}), - client=client, + connection_pool=pool, ) - create_temporal_lookup( - developer_id=developer_id, + await create_temporal_lookup( execution_id=execution.id, workflow_handle=workflow_handle, - client=client, + connection_pool=pool, ) - yield execution - client.run( - f""" - ?[execution_id] <- ["{str(execution.id)}"] - :delete executions {{ execution_id }} - """ - ) - @fixture(scope="test") -def test_execution_started( - client=cozo_client, +async def test_execution_started( + dsn=pg_dsn, developer_id=test_developer_id, task=test_task, ): + pool = await create_db_pool(dsn=dsn) workflow_handle = WorkflowHandle( client=None, id="blah", ) - execution = create_execution( + execution = await create_execution( developer_id=developer_id, task_id=task.id, data=CreateExecutionRequest(input={"test": "test"}), - client=client, + connection_pool=pool, ) - create_temporal_lookup( - developer_id=developer_id, + await create_temporal_lookup( execution_id=execution.id, workflow_handle=workflow_handle, - client=client, + connection_pool=pool, ) # Start the execution - create_execution_transition( + await create_execution_transition( developer_id=developer_id, - task_id=task.id, execution_id=execution.id, data=CreateTransitionRequest( type="init", @@ -371,27 +309,19 @@ def test_execution_started( current={"workflow": "main", "step": 0}, next={"workflow": "main", "step": 0}, ), - update_execution_status=True, - client=client, + connection_pool=pool, ) - yield execution - client.run( - f""" - ?[execution_id, task_id] <- [[to_uuid("{str(execution.id)}"), to_uuid("{str(task.id)}")]] - :delete executions {{ execution_id, task_id }} - """ - ) - @fixture(scope="global") -def 
test_transition( - client=cozo_client, +async def test_transition( + dsn=pg_dsn, developer_id=test_developer_id, - execution=test_execution, + execution=test_execution_started, ): - transition = create_execution_transition( + pool = await create_db_pool(dsn=dsn) + transition = await create_execution_transition( developer_id=developer_id, execution_id=execution.id, data=CreateTransitionRequest( @@ -400,63 +330,46 @@ def test_transition( current={"workflow": "main", "step": 0}, next={"workflow": "wf1", "step": 1}, ), - client=client, + connection_pool=pool, ) - yield transition - client.run( - f""" - ?[transition_id] <- ["{str(transition.id)}"] - :delete transitions {{ transition_id }} - """ - ) - -@fixture(scope="global") -def test_tool( - client=cozo_client, +@fixture(scope="test") +async def test_tool( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, ): + pool = await create_db_pool(dsn=dsn) function = { "description": "A function that prints hello world", "parameters": {"type": "object", "properties": {}}, } - tool = { + tool_spec = { "function": function, "name": "hello_world1", "type": "function", } - [tool, *_] = create_tools( + [tool, *_] = await create_tools( developer_id=developer_id, agent_id=agent.id, - data=[CreateToolRequest(**tool)], - client=client, - ) - - yield tool - - delete_tool( - developer_id=developer_id, - agent_id=agent.id, - tool_id=tool.id, - client=client, + data=[CreateToolRequest(**tool_spec)], + connection_pool=pool, ) + return tool @fixture(scope="global") -def client(cozo_client=cozo_client): - client = TestClient(app=app) - app.state.cozo_client = cozo_client - - return client +def client(_dsn=pg_dsn): + with TestClient(app=app) as client: + yield client @fixture(scope="global") -def make_request(client=client, developer_id=test_developer_id): +async def make_request(client=client, developer_id=test_developer_id): def _make_request(method, url, **kwargs): headers = kwargs.pop("headers", {}) headers = { @@ -467,12 
+380,30 @@ def _make_request(method, url, **kwargs): if multi_tenant_mode: headers["X-Developer-Id"] = str(developer_id) + headers["Content-Length"] = str(sys.getsizeof(kwargs.get("json", {}))) + return client.request(method, url, headers=headers, **kwargs) return _make_request @fixture(scope="global") -def s3_client(): - with patch_s3_client() as s3_client: - yield s3_client +async def s3_client(): + with get_localstack() as localstack: + s3_endpoint = localstack.get_url() + + session = get_session() + s3_client = await session.create_client( + "s3", + endpoint_url=s3_endpoint, + aws_access_key_id=localstack.env["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=localstack.env["AWS_SECRET_ACCESS_KEY"], + ).__aenter__() + + app.state.s3_client = s3_client + + try: + yield s3_client + finally: + await s3_client.close() + app.state.s3_client = None diff --git a/agents-api/tests/sample_tasks/test_find_selector.py b/agents-api/tests/sample_tasks/test_find_selector.py index 5af7aac54..beaa18613 100644 --- a/agents-api/tests/sample_tasks/test_find_selector.py +++ b/agents-api/tests/sample_tasks/test_find_selector.py @@ -1,126 +1,125 @@ -# Tests for task queries - -import os -from uuid import uuid4 - -from ward import raises, test - -from ..fixtures import cozo_client, test_agent, test_developer_id -from ..utils import patch_embed_acompletion, patch_http_client_with_temporal - -this_dir = os.path.dirname(__file__) - - -@test("workflow sample: find-selector create task") -async def _( - cozo_client=cozo_client, - developer_id=test_developer_id, - agent=test_agent, -): - agent_id = str(agent.id) - task_id = str(uuid4()) - - with ( - patch_embed_acompletion(), - open(f"{this_dir}/find_selector.yaml", "r") as sample_file, - ): - task_def = sample_file.read() - - async with patch_http_client_with_temporal( - cozo_client=cozo_client, developer_id=developer_id - ) as ( - make_request, - _, - ): - make_request( - method="POST", - url=f"/agents/{agent_id}/tasks/{task_id}", - 
headers={"Content-Type": "application/x-yaml"}, - data=task_def, - ).raise_for_status() - - -@test("workflow sample: find-selector start with bad input should fail") -async def _( - cozo_client=cozo_client, - developer_id=test_developer_id, - agent=test_agent, -): - agent_id = str(agent.id) - task_id = str(uuid4()) - - with ( - patch_embed_acompletion(), - open(f"{this_dir}/find_selector.yaml", "r") as sample_file, - ): - task_def = sample_file.read() - - async with patch_http_client_with_temporal( - cozo_client=cozo_client, developer_id=developer_id - ) as ( - make_request, - temporal_client, - ): - make_request( - method="POST", - url=f"/agents/{agent_id}/tasks/{task_id}", - headers={"Content-Type": "application/x-yaml"}, - data=task_def, - ).raise_for_status() - - execution_data = dict(input={"test": "input"}) - - with raises(BaseException): - make_request( - method="POST", - url=f"/tasks/{task_id}/executions", - json=execution_data, - ).raise_for_status() - - -@test("workflow sample: find-selector start with correct input") -async def _( - cozo_client=cozo_client, - developer_id=test_developer_id, - agent=test_agent, -): - agent_id = str(agent.id) - task_id = str(uuid4()) - - with ( - patch_embed_acompletion( - output={"role": "assistant", "content": "found: true\nvalue: 'Gaga'"} - ), - open(f"{this_dir}/find_selector.yaml", "r") as sample_file, - ): - task_def = sample_file.read() - - async with patch_http_client_with_temporal( - cozo_client=cozo_client, developer_id=developer_id - ) as ( - make_request, - temporal_client, - ): - make_request( - method="POST", - url=f"/agents/{agent_id}/tasks/{task_id}", - headers={"Content-Type": "application/x-yaml"}, - data=task_def, - ).raise_for_status() - - input = dict( - screenshot_base64="iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA", - network_requests=[{"request": {}, "response": {"body": "Lady Gaga"}}], - parameters=["name"], - ) - execution_data = dict(input=input) - - 
execution_created = make_request( - method="POST", - url=f"/tasks/{task_id}/executions", - json=execution_data, - ).json() - - handle = temporal_client.get_workflow_handle(execution_created["jobs"][0]) - - await handle.result() +# # Tests for task queries +# import os + +# from uuid_extensions import uuid7 +# from ward import raises, test + +# from ..fixtures import cozo_client, test_agent, test_developer_id +# from ..utils import patch_embed_acompletion, patch_http_client_with_temporal + +# this_dir = os.path.dirname(__file__) + + +# @test("workflow sample: find-selector create task") +# async def _( +# cozo_client=cozo_client, +# developer_id=test_developer_id, +# agent=test_agent, +# ): +# agent_id = str(agent.id) +# task_id = str(uuid7()) + +# with ( +# patch_embed_acompletion(), +# open(f"{this_dir}/find_selector.yaml", "r") as sample_file, +# ): +# task_def = sample_file.read() + +# async with patch_http_client_with_temporal( +# cozo_client=cozo_client, developer_id=developer_id +# ) as ( +# make_request, +# _, +# ): +# make_request( +# method="POST", +# url=f"/agents/{agent_id}/tasks/{task_id}", +# headers={"Content-Type": "application/x-yaml"}, +# data=task_def, +# ).raise_for_status() + + +# @test("workflow sample: find-selector start with bad input should fail") +# async def _( +# cozo_client=cozo_client, +# developer_id=test_developer_id, +# agent=test_agent, +# ): +# agent_id = str(agent.id) +# task_id = str(uuid7()) + +# with ( +# patch_embed_acompletion(), +# open(f"{this_dir}/find_selector.yaml", "r") as sample_file, +# ): +# task_def = sample_file.read() + +# async with patch_http_client_with_temporal( +# cozo_client=cozo_client, developer_id=developer_id +# ) as ( +# make_request, +# temporal_client, +# ): +# make_request( +# method="POST", +# url=f"/agents/{agent_id}/tasks/{task_id}", +# headers={"Content-Type": "application/x-yaml"}, +# data=task_def, +# ).raise_for_status() + +# execution_data = dict(input={"test": "input"}) + +# with 
raises(BaseException): +# make_request( +# method="POST", +# url=f"/tasks/{task_id}/executions", +# json=execution_data, +# ).raise_for_status() + + +# @test("workflow sample: find-selector start with correct input") +# async def _( +# cozo_client=cozo_client, +# developer_id=test_developer_id, +# agent=test_agent, +# ): +# agent_id = str(agent.id) +# task_id = str(uuid7()) + +# with ( +# patch_embed_acompletion( +# output={"role": "assistant", "content": "found: true\nvalue: 'Gaga'"} +# ), +# open(f"{this_dir}/find_selector.yaml", "r") as sample_file, +# ): +# task_def = sample_file.read() + +# async with patch_http_client_with_temporal( +# cozo_client=cozo_client, developer_id=developer_id +# ) as ( +# make_request, +# temporal_client, +# ): +# make_request( +# method="POST", +# url=f"/agents/{agent_id}/tasks/{task_id}", +# headers={"Content-Type": "application/x-yaml"}, +# data=task_def, +# ).raise_for_status() + +# input = dict( +# screenshot_base64="iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA", +# network_requests=[{"request": {}, "response": {"body": "Lady Gaga"}}], +# parameters=["name"], +# ) +# execution_data = dict(input=input) + +# execution_created = make_request( +# method="POST", +# url=f"/tasks/{task_id}/executions", +# json=execution_data, +# ).json() + +# handle = temporal_client.get_workflow_handle(execution_created["jobs"][0]) + +# await handle.result() diff --git a/agents-api/tests/test_activities.py b/agents-api/tests/test_activities.py index 6f65cd034..83c6970ee 100644 --- a/agents-api/tests/test_activities.py +++ b/agents-api/tests/test_activities.py @@ -1,45 +1,13 @@ -from uuid import uuid4 - -from ward import test - -from agents_api.activities.embed_docs import embed_docs -from agents_api.activities.types import EmbedDocsPayload from agents_api.clients import temporal from agents_api.env import temporal_task_queue from agents_api.workflows.demo import DemoWorkflow from 
agents_api.workflows.task_execution.helpers import DEFAULT_RETRY_POLICY +from uuid_extensions import uuid7 +from ward import test -from .fixtures import ( - cozo_client, - test_developer_id, - test_doc, -) from .utils import patch_testing_temporal -@test("activity: call direct embed_docs") -async def _( - cozo_client=cozo_client, - developer_id=test_developer_id, - doc=test_doc, -): - title = "title" - content = ["content 1"] - include_title = True - - await embed_docs( - EmbedDocsPayload( - developer_id=developer_id, - doc_id=doc.id, - title=title, - content=content, - include_title=include_title, - embed_instruction=None, - ), - cozo_client, - ) - - @test("activity: call demo workflow via temporal client") async def _(): async with patch_testing_temporal() as (_, mock_get_client): @@ -48,7 +16,7 @@ async def _(): result = await client.execute_workflow( DemoWorkflow.run, args=[1, 2], - id=str(uuid4()), + id=str(uuid7()), task_queue=temporal_task_queue, retry_policy=DEFAULT_RETRY_POLICY, ) diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index 8c0099419..d9c012e8e 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -1,108 +1,71 @@ # Tests for agent queries -from uuid import uuid4 - -from ward import raises, test from agents_api.autogen.openapi_model import ( Agent, CreateAgentRequest, CreateOrUpdateAgentRequest, PatchAgentRequest, + ResourceDeletedResponse, ResourceUpdatedResponse, UpdateAgentRequest, ) -from agents_api.models.agent.create_agent import create_agent -from agents_api.models.agent.create_or_update_agent import create_or_update_agent -from agents_api.models.agent.delete_agent import delete_agent -from agents_api.models.agent.get_agent import get_agent -from agents_api.models.agent.list_agents import list_agents -from agents_api.models.agent.patch_agent import patch_agent -from agents_api.models.agent.update_agent import update_agent -from tests.fixtures import 
cozo_client, test_agent, test_developer_id - - -@test("model: create agent") -def _(client=cozo_client, developer_id=test_developer_id): - create_agent( - developer_id=developer_id, - data=CreateAgentRequest( - name="test agent", - about="test agent about", - model="gpt-4o-mini", - ), - client=client, - ) +from agents_api.clients.pg import create_db_pool +from agents_api.queries.agents import ( + create_agent, + create_or_update_agent, + delete_agent, + get_agent, + list_agents, + patch_agent, + update_agent, +) +from uuid_extensions import uuid7 +from ward import raises, test +from tests.fixtures import pg_dsn, test_agent, test_developer_id -@test("model: create agent with instructions") -def _(client=cozo_client, developer_id=test_developer_id): - create_agent( - developer_id=developer_id, - data=CreateAgentRequest( - name="test agent", - about="test agent about", - model="gpt-4o-mini", - instructions=["test instruction"], - ), - client=client, - ) +@test("query: create agent sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test that an agent can be successfully created.""" -@test("model: create or update agent") -def _(client=cozo_client, developer_id=test_developer_id): - create_or_update_agent( + pool = await create_db_pool(dsn=dsn) + await create_agent( developer_id=developer_id, - agent_id=uuid4(), - data=CreateOrUpdateAgentRequest( + data=CreateAgentRequest( name="test agent", about="test agent about", model="gpt-4o-mini", - instructions=["test instruction"], ), - client=client, + connection_pool=pool, ) -@test("model: get agent not exists") -def _(client=cozo_client, developer_id=test_developer_id): - agent_id = uuid4() - - with raises(Exception): - get_agent(agent_id=agent_id, developer_id=developer_id, client=client) - - -@test("model: get agent exists") -def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): - result = get_agent(agent_id=agent.id, developer_id=developer_id, client=client) - - assert result is 
not None - assert isinstance(result, Agent) - +@test("query: create or update agent sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test that an agent can be successfully created or updated.""" -@test("model: delete agent") -def _(client=cozo_client, developer_id=test_developer_id): - temp_agent = create_agent( + pool = await create_db_pool(dsn=dsn) + await create_or_update_agent( developer_id=developer_id, - data=CreateAgentRequest( + agent_id=uuid7(), + data=CreateOrUpdateAgentRequest( name="test agent", + canonical_name="test_agent2", about="test agent about", model="gpt-4o-mini", instructions=["test instruction"], ), - client=client, + connection_pool=pool, ) - # Delete the agent - delete_agent(agent_id=temp_agent.id, developer_id=developer_id, client=client) - # Check that the agent is deleted - with raises(Exception): - get_agent(agent_id=temp_agent.id, developer_id=developer_id, client=client) +@test("query: update agent sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + """Test that an existing agent's information can be successfully updated.""" - -@test("model: update agent") -def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): - result = update_agent( + pool = await create_db_pool(dsn=dsn) + result = await update_agent( agent_id=agent.id, developer_id=developer_id, data=UpdateAgentRequest( @@ -112,24 +75,56 @@ def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): default_settings={"temperature": 1.0}, metadata={"hello": "world"}, ), - client=client, + connection_pool=pool, ) assert result is not None assert isinstance(result, ResourceUpdatedResponse) - agent = get_agent( + +@test("query: get agent not exists sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test that retrieving a non-existent agent raises an exception.""" + + agent_id = uuid7() + pool = await create_db_pool(dsn=dsn) + + with raises(Exception): + await 
get_agent(agent_id=agent_id, developer_id=developer_id, connection_pool=pool) + + +@test("query: get agent exists sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + """Test that retrieving an existing agent returns the correct agent information.""" + + pool = await create_db_pool(dsn=dsn) + result = await get_agent( agent_id=agent.id, developer_id=developer_id, - client=client, + connection_pool=pool, ) - assert "test" not in agent.metadata + assert result is not None + assert isinstance(result, Agent) + + +@test("query: list agents sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test that listing agents returns a collection of agent information.""" + + pool = await create_db_pool(dsn=dsn) + result = await list_agents(developer_id=developer_id, connection_pool=pool) + + assert isinstance(result, list) + assert all(isinstance(agent, Agent) for agent in result) + +@test("query: patch agent sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + """Test that an agent can be successfully patched.""" -@test("model: patch agent") -def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): - result = patch_agent( + pool = await create_db_pool(dsn=dsn) + result = await patch_agent( agent_id=agent.id, developer_id=developer_id, data=PatchAgentRequest( @@ -138,26 +133,37 @@ def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): default_settings={"temperature": 1.0}, metadata={"something": "else"}, ), - client=client, + connection_pool=pool, ) assert result is not None assert isinstance(result, ResourceUpdatedResponse) - agent = get_agent( - agent_id=agent.id, - developer_id=developer_id, - client=client, - ) - - assert "hello" in agent.metadata +@test("query: delete agent sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test that an agent can be successfully deleted.""" -@test("model: list agents") -def _(client=cozo_client, 
developer_id=test_developer_id, agent=test_agent): - """Tests listing all agents associated with a developer in the database. Verifies that the correct list of agents is retrieved.""" + pool = await create_db_pool(dsn=dsn) + create_result = await create_agent( + developer_id=developer_id, + data=CreateAgentRequest( + name="test agent", + about="test agent about", + model="gpt-4o-mini", + ), + connection_pool=pool, + ) + delete_result = await delete_agent( + agent_id=create_result.id, developer_id=developer_id, connection_pool=pool + ) - result = list_agents(developer_id=developer_id, client=client) + assert delete_result is not None + assert isinstance(delete_result, ResourceDeletedResponse) - assert isinstance(result, list) - assert all(isinstance(agent, Agent) for agent in result) + with raises(Exception): + await get_agent( + developer_id=developer_id, + agent_id=create_result.id, + connection_pool=pool, + ) diff --git a/agents-api/tests/test_agent_routes.py b/agents-api/tests/test_agent_routes.py index 91ddf9f1a..2da1fec1b 100644 --- a/agents-api/tests/test_agent_routes.py +++ b/agents-api/tests/test_agent_routes.py @@ -1,6 +1,6 @@ -# Tests for agent queries -from uuid import uuid4 +# # Tests for agent queries +from uuid_extensions import uuid7 from ward import test from tests.fixtures import client, make_request, test_agent @@ -8,11 +8,11 @@ @test("route: unauthorized should fail") def _(client=client): - data = dict( - name="test agent", - about="test agent about", - model="gpt-4o-mini", - ) + data = { + "name": "test agent", + "about": "test agent about", + "model": "gpt-4o-mini", + } response = client.request( method="POST", @@ -25,11 +25,11 @@ def _(client=client): @test("route: create agent") def _(make_request=make_request): - data = dict( - name="test agent", - about="test agent about", - model="gpt-4o-mini", - ) + data = { + "name": "test agent", + "about": "test agent about", + "model": "gpt-4o-mini", + } response = make_request( method="POST", @@ 
-42,12 +42,12 @@ def _(make_request=make_request): @test("route: create agent with instructions") def _(make_request=make_request): - data = dict( - name="test agent", - about="test agent about", - model="gpt-4o-mini", - instructions=["test instruction"], - ) + data = { + "name": "test agent", + "about": "test agent about", + "model": "gpt-4o-mini", + "instructions": ["test instruction"], + } response = make_request( method="POST", @@ -60,14 +60,14 @@ def _(make_request=make_request): @test("route: create or update agent") def _(make_request=make_request): - agent_id = str(uuid4()) + agent_id = str(uuid7()) - data = dict( - name="test agent", - about="test agent about", - model="gpt-4o-mini", - instructions=["test instruction"], - ) + data = { + "name": "test agent", + "about": "test agent about", + "model": "gpt-4o-mini", + "instructions": ["test instruction"], + } response = make_request( method="POST", @@ -80,7 +80,7 @@ def _(make_request=make_request): @test("route: get agent not exists") def _(make_request=make_request): - agent_id = str(uuid4()) + agent_id = str(uuid7()) response = make_request( method="GET", @@ -104,12 +104,12 @@ def _(make_request=make_request, agent=test_agent): @test("route: delete agent") def _(make_request=make_request): - data = dict( - name="test agent", - about="test agent about", - model="gpt-4o-mini", - instructions=["test instruction"], - ) + data = { + "name": "test agent", + "about": "test agent about", + "model": "gpt-4o-mini", + "instructions": ["test instruction"], + } response = make_request( method="POST", @@ -135,13 +135,13 @@ def _(make_request=make_request): @test("route: update agent") def _(make_request=make_request, agent=test_agent): - data = dict( - name="updated agent", - about="updated agent about", - default_settings={"temperature": 1.0}, - model="gpt-4o-mini", - metadata={"hello": "world"}, - ) + data = { + "name": "updated agent", + "about": "updated agent about", + "default_settings": {"temperature": 1.0}, + 
"model": "gpt-4o-mini", + "metadata": {"hello": "world"}, + } agent_id = str(agent.id) response = make_request( @@ -169,12 +169,12 @@ def _(make_request=make_request, agent=test_agent): def _(make_request=make_request, agent=test_agent): agent_id = str(agent.id) - data = dict( - name="patched agent", - about="patched agent about", - default_settings={"temperature": 1.0}, - metadata={"something": "else"}, - ) + data = { + "name": "patched agent", + "about": "patched agent about", + "default_settings": {"temperature": 1.0}, + "metadata": {"hello": "world"}, + } response = make_request( method="PATCH", diff --git a/agents-api/tests/test_chat_routes.py b/agents-api/tests/test_chat_routes.py index 4838efcd5..949a712f1 100644 --- a/agents-api/tests/test_chat_routes.py +++ b/agents-api/tests/test_chat_routes.py @@ -1,17 +1,18 @@ # Tests for session queries -from ward import test - from agents_api.autogen.openapi_model import ChatInput, CreateSessionRequest from agents_api.clients import litellm +from agents_api.clients.pg import create_db_pool from agents_api.common.protocol.sessions import ChatContext -from agents_api.models.chat.gather_messages import gather_messages -from agents_api.models.chat.prepare_chat_context import prepare_chat_context -from agents_api.models.session.create_session import create_session +from agents_api.queries.chat.gather_messages import gather_messages +from agents_api.queries.chat.prepare_chat_context import prepare_chat_context +from agents_api.queries.sessions.create_session import create_session +from ward import test + from tests.fixtures import ( - cozo_client, make_request, patch_embed_acompletion, + pg_dsn, test_agent, test_developer, test_developer_id, @@ -26,15 +27,13 @@ async def _( _=patch_embed_acompletion, ): assert (await litellm.acompletion(model="gpt-4o-mini", messages=[])).id == "fake_id" - assert (await litellm.aembedding())[0][ - 0 - ] == 1.0 # pytype: disable=missing-parameter + assert (await litellm.aembedding())[0][0] == 
1.0 # pytype: disable=missing-parameter @test("chat: check that non-recall gather_messages works") async def _( developer=test_developer, - client=cozo_client, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, session=test_session, @@ -44,10 +43,11 @@ async def _( ): (embed, _) = mocks - chat_context = prepare_chat_context( + pool = await create_db_pool(dsn=dsn) + chat_context = await prepare_chat_context( developer_id=developer_id, session_id=session.id, - client=client, + connection_pool=pool, ) session_id = session.id @@ -59,6 +59,7 @@ async def _( session_id=session_id, chat_context=chat_context, chat_input=ChatInput(messages=messages, recall=False), + connection_pool=pool, ) assert isinstance(past_messages, list) @@ -73,7 +74,7 @@ async def _( @test("chat: check that gather_messages works") async def _( developer=test_developer, - client=cozo_client, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, # session=test_session, @@ -81,7 +82,8 @@ async def _( user=test_user, mocks=patch_embed_acompletion, ): - session = create_session( + pool = await create_db_pool(dsn=dsn) + session = await create_session( developer_id=developer_id, data=CreateSessionRequest( agent=agent.id, @@ -92,15 +94,15 @@ async def _( "max_query_length": 1001, }, ), - client=client, + connection_pool=pool, ) (embed, _) = mocks - chat_context = prepare_chat_context( + chat_context = await prepare_chat_context( developer_id=developer_id, session_id=session.id, - client=client, + connection_pool=pool, ) session_id = session.id @@ -112,6 +114,7 @@ async def _( session_id=session_id, chat_context=chat_context, chat_input=ChatInput(messages=messages, recall=True), + connection_pool=pool, ) assert isinstance(past_messages, list) @@ -127,9 +130,10 @@ async def _( developer_id=test_developer_id, agent=test_agent, mocks=patch_embed_acompletion, - client=cozo_client, + dsn=pg_dsn, ): - session = create_session( + pool = await create_db_pool(dsn=dsn) + session = await 
create_session( developer_id=developer_id, data=CreateSessionRequest( agent=agent.id, @@ -140,7 +144,7 @@ async def _( "max_query_length": 1001, }, ), - client=client, + connection_pool=pool, ) (embed, acompletion) = mocks @@ -158,19 +162,20 @@ async def _( acompletion.assert_called() -@test("model: prepare chat context") -def _( - client=cozo_client, +@test("query: prepare chat context") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, session=test_session, tool=test_tool, user=test_user, ): - context = prepare_chat_context( + pool = await create_db_pool(dsn=dsn) + context = await prepare_chat_context( developer_id=developer_id, session_id=session.id, - client=client, + connection_pool=pool, ) assert isinstance(context, ChatContext) diff --git a/agents-api/tests/test_developer_queries.py b/agents-api/tests/test_developer_queries.py index 569733fa5..77d33f32a 100644 --- a/agents-api/tests/test_developer_queries.py +++ b/agents-api/tests/test_developer_queries.py @@ -1,36 +1,92 @@ # Tests for agent queries -from uuid import uuid4 +from agents_api.autogen.openapi_model import ResourceCreatedResponse +from agents_api.clients.pg import create_db_pool +from agents_api.common.protocol.developers import Developer +from agents_api.queries.developers.create_developer import create_developer +from agents_api.queries.developers.get_developer import ( + get_developer, +) +from agents_api.queries.developers.patch_developer import patch_developer +from agents_api.queries.developers.update_developer import update_developer +from uuid_extensions import uuid7 from ward import raises, test -from agents_api.common.protocol.developers import Developer -from agents_api.models.developer.get_developer import get_developer, verify_developer -from tests.fixtures import cozo_client, test_developer_id +from .fixtures import pg_dsn, random_email, test_new_developer + + +@test("query: get developer not exists") +async def _(dsn=pg_dsn): + pool = await 
create_db_pool(dsn=dsn) + with raises(Exception): + await get_developer( + developer_id=uuid7(), + connection_pool=pool, + ) -@test("model: get developer") -def _(client=cozo_client, developer_id=test_developer_id): - developer = get_developer( - developer_id=developer_id, - client=client, +@test("query: get developer exists") +async def _(dsn=pg_dsn, dev=test_new_developer): + pool = await create_db_pool(dsn=dsn) + developer = await get_developer( + developer_id=dev.id, + connection_pool=pool, ) - assert isinstance(developer, Developer) - assert developer.id + assert type(developer) is Developer + assert developer.id == dev.id + assert developer.email == dev.email + assert developer.active + assert developer.tags == dev.tags + assert developer.settings == dev.settings -@test("model: verify developer exists") -def _(client=cozo_client, developer_id=test_developer_id): - verify_developer( - developer_id=developer_id, - client=client, +@test("query: create developer") +async def _(dsn=pg_dsn): + pool = await create_db_pool(dsn=dsn) + dev_id = uuid7() + developer = await create_developer( + email="m@mail.com", + active=True, + tags=["tag1"], + settings={"key1": "val1"}, + developer_id=dev_id, + connection_pool=pool, ) + assert type(developer) is ResourceCreatedResponse + assert developer.id == dev_id + assert developer.created_at is not None -@test("model: verify developer not exists") -def _(client=cozo_client): - with raises(Exception): - verify_developer( - developer_id=uuid4(), - client=client, - ) + +@test("query: update developer") +async def _(dsn=pg_dsn, dev=test_new_developer, email=random_email): + pool = await create_db_pool(dsn=dsn) + developer = await update_developer( + email=email, + tags=["tag2"], + settings={"key2": "val2"}, + developer_id=dev.id, + connection_pool=pool, + ) + + assert developer.id == dev.id + + +@test("query: patch developer") +async def _(dsn=pg_dsn, dev=test_new_developer, email=random_email): + pool = await create_db_pool(dsn=dsn) 
+ developer = await patch_developer( + email=email, + active=True, + tags=["tag2"], + settings={"key2": "val2"}, + developer_id=dev.id, + connection_pool=pool, + ) + + assert developer.id == dev.id + assert developer.email == email + assert developer.active + assert developer.tags == [*dev.tags, "tag2"] + assert developer.settings == {**dev.settings, "key2": "val2"} diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index a7fa7868a..2c49de891 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -1,163 +1,323 @@ -# Tests for entry queries - import asyncio -from ward import test - from agents_api.autogen.openapi_model import CreateDocRequest -from agents_api.models.docs.create_doc import create_doc -from agents_api.models.docs.delete_doc import delete_doc -from agents_api.models.docs.embed_snippets import embed_snippets -from agents_api.models.docs.get_doc import get_doc -from agents_api.models.docs.list_docs import list_docs -from agents_api.models.docs.search_docs_by_embedding import search_docs_by_embedding -from agents_api.models.docs.search_docs_by_text import search_docs_by_text -from tests.fixtures import ( - EMBEDDING_SIZE, - cozo_client, - test_agent, - test_developer_id, - test_doc, - test_user, -) - - -@test("model: create docs") -def _( - client=cozo_client, developer_id=test_developer_id, agent=test_agent, user=test_user -): - create_doc( - developer_id=developer_id, +from agents_api.clients.pg import create_db_pool +from agents_api.queries.docs.create_doc import create_doc +from agents_api.queries.docs.delete_doc import delete_doc +from agents_api.queries.docs.get_doc import get_doc +from agents_api.queries.docs.list_docs import list_docs +from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding +from agents_api.queries.docs.search_docs_by_text import search_docs_by_text +from agents_api.queries.docs.search_docs_hybrid import 
search_docs_hybrid +from ward import skip, test + +from tests.fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user + +EMBEDDING_SIZE: int = 1024 + + +@test("query: create user doc") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + doc_created = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="User Doc", + content=["Docs for user testing", "Docs for user testing 2"], + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + assert doc_created.id is not None + + # Verify doc appears in user's docs + found = await get_doc( + developer_id=developer.id, + doc_id=doc_created.id, + connection_pool=pool, + ) + assert found.id == doc_created.id + + +@test("query: create agent doc") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + doc = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Agent Doc", + content="Docs for agent testing", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), owner_type="agent", owner_id=agent.id, - data=CreateDocRequest(title="Hello", content=["World"]), - client=client, + connection_pool=pool, ) + assert doc.id is not None - create_doc( - developer_id=developer_id, - owner_type="user", - owner_id=user.id, - data=CreateDocRequest(title="Hello", content=["World"]), - client=client, + # Verify doc appears in agent's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, ) + assert any(d.id == doc.id for d in docs_list) -@test("model: get docs") -def _(client=cozo_client, doc=test_doc, developer_id=test_developer_id): - get_doc( - developer_id=developer_id, +@test("query: get doc") +async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): + pool = await 
create_db_pool(dsn=dsn) + doc_test = await get_doc( + developer_id=developer.id, doc_id=doc.id, - client=client, + connection_pool=pool, ) + assert doc_test.id == doc.id + assert doc_test.title is not None + assert doc_test.content is not None -@test("model: delete doc") -def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): - doc = create_doc( - developer_id=developer_id, +@test("query: list user docs") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + + # Create a doc owned by the user + doc_user = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="User List Test", + content="Some user doc content", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # List user's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert len(docs_list) >= 1 + assert any(d.id == doc_user.id for d in docs_list) + + +@test("query: list agent docs") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a doc owned by the agent + doc_agent = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Agent List Test", + content="Some agent doc content", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), owner_type="agent", owner_id=agent.id, - data=CreateDocRequest(title="Hello", content=["World"]), - client=client, + connection_pool=pool, ) - delete_doc( - developer_id=developer_id, - doc_id=doc.id, + # List agent's docs + docs_list = await list_docs( + developer_id=developer.id, owner_type="agent", owner_id=agent.id, - client=client, + connection_pool=pool, ) + assert len(docs_list) >= 1 + assert any(d.id == doc_agent.id for d in docs_list) + +@test("query: delete user doc") +async 
def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) -@test("model: list docs") -def _( - client=cozo_client, developer_id=test_developer_id, doc=test_doc, agent=test_agent -): - result = list_docs( - developer_id=developer_id, + # Create a doc owned by the user + doc_user = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="User Delete Test", + content="Doc for user deletion test", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Delete the doc + await delete_doc( + developer_id=developer.id, + doc_id=doc_user.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Verify doc is no longer in user's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert not any(d.id == doc_user.id for d in docs_list) + + +@test("query: delete agent doc") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a doc owned by the agent + doc_agent = await create_doc( + developer_id=developer.id, + data=CreateDocRequest( + title="Agent Delete Test", + content="Doc for agent deletion test", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), owner_type="agent", owner_id=agent.id, - client=client, - include_without_embeddings=True, + connection_pool=pool, ) - assert len(result) >= 1 + # Delete the doc + await delete_doc( + developer_id=developer.id, + doc_id=doc_agent.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # Verify doc is no longer in agent's docs + docs_list = await list_docs( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert not any(d.id == doc_agent.id for d in docs_list) + +@skip("text search: test container not 
vectorizing") +@test("query: search docs by text") +async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): + pool = await create_db_pool(dsn=dsn) -@test("model: search docs by text") -async def _(client=cozo_client, agent=test_agent, developer_id=test_developer_id): - create_doc( - developer_id=developer_id, + # Create a test document + doc = await create_doc( + developer_id=developer.id, owner_type="agent", owner_id=agent.id, data=CreateDocRequest( - title="Hello", content=["The world is a funny little thing"] + title="Hello", + content="The world is a funny little thing", + metadata={"test": "test"}, + embed_instruction="Embed the document", ), - client=client, + connection_pool=pool, ) - await asyncio.sleep(1) + # Add a longer delay to ensure the search index is updated + await asyncio.sleep(3) - result = search_docs_by_text( - developer_id=developer_id, + # Search using simpler terms first + result = await search_docs_by_text( + developer_id=developer.id, owners=[("agent", agent.id)], - query="funny", - client=client, + query="world", + k=3, + search_language="english", + metadata_filter={"test": "test"}, + connection_pool=pool, ) - assert len(result) >= 1 - assert result[0].metadata is not None + print("\nSearch results:", result) + + # More specific assertions + assert len(result) >= 1, "Should find at least one document" + assert any(d.id == doc.id for d in result), f"Should find document {doc.id}" + assert result[0].metadata == {"test": "test"}, "Metadata should match" -@test("model: search docs by embedding") -async def _(client=cozo_client, agent=test_agent, developer_id=test_developer_id): - doc = create_doc( - developer_id=developer_id, +@skip("embedding search: test container not vectorizing") +@test("query: search docs by embedding") +async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): + pool = await create_db_pool(dsn=dsn) + + # Create a test document + await create_doc( + developer_id=developer.id, owner_type="agent", 
owner_id=agent.id, - data=CreateDocRequest(title="Hello", content=["World"]), - client=client, - ) - - ### Add embedding to the snippet - embed_snippets( - developer_id=developer_id, - doc_id=doc.id, - snippet_indices=[0], - embeddings=[[1.0] * EMBEDDING_SIZE], - client=client, + data=CreateDocRequest( + title="Hello", + content="The world is a funny little thing", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + connection_pool=pool, ) - await asyncio.sleep(1) - - ### Search - query_embedding = [0.99] * EMBEDDING_SIZE - - result = search_docs_by_embedding( - developer_id=developer_id, + # Search using the correct parameter types + result = await search_docs_by_embedding( + developer_id=developer.id, owners=[("agent", agent.id)], - query_embedding=query_embedding, - client=client, + query_embedding=[1.0] * 1024, + k=3, # Add k parameter + metadata_filter={"test": "test"}, # Add metadata filter + connection_pool=pool, ) assert len(result) >= 1 assert result[0].metadata is not None -@test("model: embed snippets") -def _(client=cozo_client, developer_id=test_developer_id, doc=test_doc): - snippet_indices = [0] - embeddings = [[1.0] * EMBEDDING_SIZE] +@skip("hybrid search: test container not vectorizing") +@test("query: search docs by hybrid") +async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): + pool = await create_db_pool(dsn=dsn) - result = embed_snippets( - developer_id=developer_id, - doc_id=doc.id, - snippet_indices=snippet_indices, - embeddings=embeddings, - client=client, + # Create a test document + await create_doc( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + data=CreateDocRequest( + title="Hello", + content="The world is a funny little thing", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), + connection_pool=pool, ) - assert result is not None - assert result.id == doc.id + # Search using the correct parameter types + result = await search_docs_hybrid( + 
developer_id=developer.id, + owners=[("agent", agent.id)], + text_query="funny thing", + embedding=[1.0] * 1024, + k=3, # Add k parameter + metadata_filter={"test": "test"}, # Add metadata filter + connection_pool=pool, + ) + + assert len(result) >= 1 + assert result[0].metadata is not None diff --git a/agents-api/tests/test_docs_routes.py b/agents-api/tests/test_docs_routes.py index 89a14a41c..e62da6c42 100644 --- a/agents-api/tests/test_docs_routes.py +++ b/agents-api/tests/test_docs_routes.py @@ -1,4 +1,4 @@ -import time +import asyncio from ward import skip, test @@ -10,16 +10,17 @@ test_user, test_user_doc, ) -from tests.utils import patch_testing_temporal + +from .utils import patch_testing_temporal @test("route: create user doc") async def _(make_request=make_request, user=test_user): async with patch_testing_temporal(): - data = dict( - title="Test User Doc", - content=["This is a test user document."], - ) + data = { + "title": "Test User Doc", + "content": ["This is a test user document."], + } response = make_request( method="POST", @@ -29,17 +30,14 @@ async def _(make_request=make_request, user=test_user): assert response.status_code == 201 - result = response.json() - assert len(result["jobs"]) > 0 - @test("route: create agent doc") async def _(make_request=make_request, agent=test_agent): async with patch_testing_temporal(): - data = dict( - title="Test Agent Doc", - content=["This is a test agent document."], - ) + data = { + "title": "Test Agent Doc", + "content": ["This is a test agent document."], + } response = make_request( method="POST", @@ -49,17 +47,14 @@ async def _(make_request=make_request, agent=test_agent): assert response.status_code == 201 - result = response.json() - assert len(result["jobs"]) > 0 - @test("route: delete doc") async def _(make_request=make_request, agent=test_agent): async with patch_testing_temporal(): - data = dict( - title="Test Agent Doc", - content=["This is a test agent document."], - ) + data = { + "title": 
"Test Agent Doc", + "content": "This is a test agent document.", + } response = make_request( method="POST", @@ -68,6 +63,16 @@ async def _(make_request=make_request, agent=test_agent): ) doc_id = response.json()["id"] + response = make_request( + method="GET", + url=f"/docs/{doc_id}", + ) + + assert response.status_code == 200 + assert response.json()["id"] == doc_id + assert response.json()["title"] == "Test Agent Doc" + assert response.json()["content"] == "This is a test agent document." + response = make_request( method="DELETE", url=f"/agents/{agent.id}/docs/{doc_id}", @@ -86,10 +91,10 @@ async def _(make_request=make_request, agent=test_agent): @test("route: get doc") async def _(make_request=make_request, agent=test_agent): async with patch_testing_temporal(): - data = dict( - title="Test Agent Doc", - content=["This is a test agent document."], - ) + data = { + "title": "Test Agent Doc", + "content": ["This is a test agent document."], + } response = make_request( method="POST", @@ -168,14 +173,14 @@ def _(make_request=make_request, agent=test_agent): assert isinstance(docs, list) -# TODO: Fix this test. It fails sometimes and sometimes not. 
+@skip("Fails due to FTS not working in Test Container") @test("route: search agent docs") async def _(make_request=make_request, agent=test_agent, doc=test_doc): - time.sleep(0.5) - search_params = dict( - text=doc.content[0], - limit=1, - ) + await asyncio.sleep(0.5) + search_params = { + "text": doc.content[0], + "limit": 1, + } response = make_request( method="POST", @@ -191,15 +196,14 @@ async def _(make_request=make_request, agent=test_agent, doc=test_doc): assert len(docs) >= 1 -# FIXME: This test is failing because the search is not returning the expected results -@skip("Fails randomly on CI") +@skip("Fails due to FTS not working in Test Container") @test("route: search user docs") async def _(make_request=make_request, user=test_user, doc=test_user_doc): - time.sleep(0.5) - search_params = dict( - text=doc.content[0], - limit=1, - ) + await asyncio.sleep(0.5) + search_params = { + "text": doc.content[0], + "limit": 1, + } response = make_request( method="POST", @@ -216,17 +220,18 @@ async def _(make_request=make_request, user=test_user, doc=test_user_doc): assert len(docs) >= 1 +@skip("Fails due to Vectorizer and FTS not working in Test Container") @test("route: search agent docs hybrid with mmr") async def _(make_request=make_request, agent=test_agent, doc=test_doc): - time.sleep(0.5) + await asyncio.sleep(0.5) EMBEDDING_SIZE = 1024 - search_params = dict( - text=doc.content[0], - vector=[1.0] * EMBEDDING_SIZE, - mmr_strength=0.5, - limit=1, - ) + search_params = { + "text": doc.content[0], + "vector": [1.0] * EMBEDDING_SIZE, + "mmr_strength": 0.5, + "limit": 1, + } response = make_request( method="POST", diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index a3c93f465..fe514c31a 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -3,30 +3,32 @@ It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. 
""" -# Tests for entry queries - -import time - -from ward import test - -from agents_api.autogen.openapi_model import CreateEntryRequest -from agents_api.models.entry.create_entries import create_entries -from agents_api.models.entry.delete_entries import delete_entries -from agents_api.models.entry.get_history import get_history -from agents_api.models.entry.list_entries import list_entries -from agents_api.models.session.get_session import get_session -from tests.fixtures import cozo_client, test_developer_id, test_session +from agents_api.autogen.openapi_model import ( + CreateEntryRequest, + Entry, + History, +) +from agents_api.clients.pg import create_db_pool +from agents_api.queries.entries import ( + create_entries, + delete_entries, + get_history, + list_entries, +) +from fastapi import HTTPException +from uuid_extensions import uuid7 +from ward import raises, test + +from tests.fixtures import pg_dsn, test_developer, test_developer_id, test_session MODEL = "gpt-4o-mini" -@test("model: create entry") -def _(client=cozo_client, developer_id=test_developer_id, session=test_session): - """ - Tests the addition of a new entry to the database. - Verifies that the entry can be successfully added using the create_entries function. 
- """ +@test("query: create entry no session") +async def _(dsn=pg_dsn, developer=test_developer): + """Test the addition of a new entry to the database.""" + pool = await create_db_pool(dsn=dsn) test_entry = CreateEntryRequest.from_model_input( model=MODEL, role="user", @@ -34,56 +36,36 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session): content="test entry content", ) - create_entries( - developer_id=developer_id, - session_id=session.id, - data=[test_entry], - mark_session_as_updated=False, - client=client, - ) + with raises(HTTPException) as exc_info: + await create_entries( + developer_id=developer.id, + session_id=uuid7(), + data=[test_entry], + connection_pool=pool, + ) + assert exc_info.raised.status_code == 404 -@test("model: create entry, update session") -def _(client=cozo_client, developer_id=test_developer_id, session=test_session): - """ - Tests the addition of a new entry to the database. - Verifies that the entry can be successfully added using the create_entries function. 
- """ +@test("query: list entries sql - no session") +async def _(dsn=pg_dsn, developer=test_developer): + """Test the retrieval of entries from the database.""" - test_entry = CreateEntryRequest.from_model_input( - model=MODEL, - role="user", - source="internal", - content="test entry content", - ) - - # TODO: We should make sessions.updated_at also a updated_at_ms field to avoid this sleep - time.sleep(1) - - create_entries( - developer_id=developer_id, - session_id=session.id, - data=[test_entry], - mark_session_as_updated=True, - client=client, - ) - - updated_session = get_session( - developer_id=developer_id, - session_id=session.id, - client=client, - ) + pool = await create_db_pool(dsn=dsn) - assert updated_session.updated_at > session.updated_at + with raises(HTTPException) as exc_info: + await list_entries( + developer_id=developer.id, + session_id=uuid7(), + connection_pool=pool, + ) + assert exc_info.raised.status_code == 404 -@test("model: get entries") -def _(client=cozo_client, developer_id=test_developer_id, session=test_session): - """ - Tests the retrieval of entries from the database. - Verifies that entries matching specific criteria can be successfully retrieved. 
- """ +@test("query: list entries sql - session exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test the retrieval of entries from the database.""" + pool = await create_db_pool(dsn=dsn) test_entry = CreateEntryRequest.from_model_input( model=MODEL, role="user", @@ -98,30 +80,30 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session): source="internal", ) - create_entries( + await create_entries( developer_id=developer_id, session_id=session.id, data=[test_entry, internal_entry], - client=client, + connection_pool=pool, ) - result = list_entries( + result = await list_entries( developer_id=developer_id, session_id=session.id, - client=client, + connection_pool=pool, ) - # Asserts that only one entry is retrieved, matching the session_id. + # Assert that only one entry is retrieved, matching the session_id. assert len(result) == 1 + assert isinstance(result[0], Entry) + assert result is not None -@test("model: get history") -def _(client=cozo_client, developer_id=test_developer_id, session=test_session): - """ - Tests the retrieval of entries from the database. - Verifies that entries matching specific criteria can be successfully retrieved. 
- """ +@test("query: get history sql - session exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test the retrieval of entry history from the database.""" + pool = await create_db_pool(dsn=dsn) test_entry = CreateEntryRequest.from_model_input( model=MODEL, role="user", @@ -136,31 +118,31 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session): source="internal", ) - create_entries( + await create_entries( developer_id=developer_id, session_id=session.id, data=[test_entry, internal_entry], - client=client, + connection_pool=pool, ) - result = get_history( + result = await get_history( developer_id=developer_id, session_id=session.id, - client=client, + connection_pool=pool, ) - # Asserts that only one entry is retrieved, matching the session_id. + # Assert that entries are retrieved and have valid IDs. + assert result is not None + assert isinstance(result, History) assert len(result.entries) > 0 assert result.entries[0].id -@test("model: delete entries") -def _(client=cozo_client, developer_id=test_developer_id, session=test_session): - """ - Tests the deletion of entries from the database. - Verifies that entries can be successfully deleted using the delete_entries function. 
- """ +@test("query: delete entries sql - session exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test the deletion of entries from the database.""" + pool = await create_db_pool(dsn=dsn) test_entry = CreateEntryRequest.from_model_input( model=MODEL, role="user", @@ -175,27 +157,29 @@ def _(client=cozo_client, developer_id=test_developer_id, session=test_session): source="internal", ) - created_entries = create_entries( + created_entries = await create_entries( developer_id=developer_id, session_id=session.id, data=[test_entry, internal_entry], - client=client, + connection_pool=pool, ) entry_ids = [entry.id for entry in created_entries] - delete_entries( + await delete_entries( developer_id=developer_id, session_id=session.id, entry_ids=entry_ids, - client=client, + connection_pool=pool, ) - result = list_entries( + result = await list_entries( developer_id=developer_id, session_id=session.id, - client=client, + connection_pool=pool, ) - # Asserts that no entries are retrieved after deletion. + # Assert that no entries are retrieved after deletion. 
assert all(id not in [entry.id for entry in result] for id in entry_ids) + assert len(result) == 0 + assert result is not None diff --git a/agents-api/tests/test_execution_queries.py b/agents-api/tests/test_execution_queries.py index 9e75b3cda..c9acffc3c 100644 --- a/agents-api/tests/test_execution_queries.py +++ b/agents-api/tests/test_execution_queries.py @@ -1,60 +1,65 @@ -# Tests for execution queries - -from temporalio.client import WorkflowHandle -from ward import test +# # Tests for execution queries from agents_api.autogen.openapi_model import ( CreateExecutionRequest, CreateTransitionRequest, Execution, ) -from agents_api.models.execution.count_executions import count_executions -from agents_api.models.execution.create_execution import create_execution -from agents_api.models.execution.create_execution_transition import ( +from agents_api.clients.pg import create_db_pool +from agents_api.queries.executions.count_executions import count_executions +from agents_api.queries.executions.create_execution import create_execution +from agents_api.queries.executions.create_execution_transition import ( create_execution_transition, ) -from agents_api.models.execution.create_temporal_lookup import create_temporal_lookup -from agents_api.models.execution.get_execution import get_execution -from agents_api.models.execution.list_executions import list_executions -from agents_api.models.execution.lookup_temporal_data import lookup_temporal_data +from agents_api.queries.executions.create_temporal_lookup import create_temporal_lookup +from agents_api.queries.executions.get_execution import get_execution +from agents_api.queries.executions.list_executions import list_executions +from agents_api.queries.executions.lookup_temporal_data import lookup_temporal_data +from temporalio.client import WorkflowHandle +from ward import test + from tests.fixtures import ( - cozo_client, + pg_dsn, test_developer_id, test_execution, test_execution_started, test_task, ) -MODEL = 
"gpt-4o-mini-mini" +MODEL = "gpt-4o-mini" -@test("model: create execution") -def _(client=cozo_client, developer_id=test_developer_id, task=test_task): +@test("query: create execution") +async def _(dsn=pg_dsn, developer_id=test_developer_id, task=test_task): + pool = await create_db_pool(dsn=dsn) workflow_handle = WorkflowHandle( client=None, id="blah", ) - execution = create_execution( + execution = await create_execution( developer_id=developer_id, task_id=task.id, data=CreateExecutionRequest(input={"test": "test"}), - client=client, + connection_pool=pool, ) - create_temporal_lookup( - developer_id=developer_id, + await create_temporal_lookup( execution_id=execution.id, workflow_handle=workflow_handle, - client=client, + connection_pool=pool, ) + assert execution.status == "queued" + assert execution.input == {"test": "test"} + -@test("model: get execution") -def _(client=cozo_client, developer_id=test_developer_id, execution=test_execution): - result = get_execution( +@test("query: get execution") +async def _(dsn=pg_dsn, developer_id=test_developer_id, execution=test_execution): + pool = await create_db_pool(dsn=dsn) + result = await get_execution( execution_id=execution.id, - client=client, + connection_pool=pool, ) assert result is not None @@ -62,80 +67,84 @@ def _(client=cozo_client, developer_id=test_developer_id, execution=test_executi assert result.status == "queued" -@test("model: lookup temporal id") -def _(client=cozo_client, developer_id=test_developer_id, execution=test_execution): - result = lookup_temporal_data( +@test("query: lookup temporal id") +async def _(dsn=pg_dsn, developer_id=test_developer_id, execution=test_execution): + pool = await create_db_pool(dsn=dsn) + result = await lookup_temporal_data( execution_id=execution.id, developer_id=developer_id, - client=client, + connection_pool=pool, ) assert result is not None assert result["id"] -@test("model: list executions") -def _( - client=cozo_client, +@test("query: list executions") 
+async def _( + dsn=pg_dsn, developer_id=test_developer_id, - execution=test_execution, + execution=test_execution_started, task=test_task, ): - result = list_executions( + pool = await create_db_pool(dsn=dsn) + result = await list_executions( developer_id=developer_id, task_id=task.id, - client=client, + connection_pool=pool, ) assert isinstance(result, list) assert len(result) >= 1 - assert result[0].status == "queued" + assert result[0].status == "starting" -@test("model: count executions") -def _( - client=cozo_client, +@test("query: count executions") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, - execution=test_execution, + execution=test_execution_started, task=test_task, ): - result = count_executions( + pool = await create_db_pool(dsn=dsn) + result = await count_executions( developer_id=developer_id, task_id=task.id, - client=client, + connection_pool=pool, ) assert isinstance(result, dict) assert result["count"] > 0 -@test("model: create execution transition") -def _(client=cozo_client, developer_id=test_developer_id, execution=test_execution): - result = create_execution_transition( +@test("query: create execution transition") +async def _(dsn=pg_dsn, developer_id=test_developer_id, execution=test_execution): + pool = await create_db_pool(dsn=dsn) + result = await create_execution_transition( developer_id=developer_id, execution_id=execution.id, data=CreateTransitionRequest( - type="step", + type="init_branch", output={"result": "test"}, current={"workflow": "main", "step": 0}, - next={"workflow": "main", "step": 1}, + next={"workflow": "main", "step": 0}, ), - client=client, + connection_pool=pool, ) assert result is not None - assert result.type == "step" + assert result.type == "init_branch" assert result.output == {"result": "test"} -@test("model: create execution transition with execution update") -def _( - client=cozo_client, +@test("query: create execution transition with execution update") +async def _( + dsn=pg_dsn, 
developer_id=test_developer_id, - task=test_task, execution=test_execution_started, ): - result = create_execution_transition( + pool = await create_db_pool(dsn=dsn) + result = await create_execution_transition( developer_id=developer_id, execution_id=execution.id, data=CreateTransitionRequest( @@ -144,9 +153,9 @@ def _( current={"workflow": "main", "step": 0}, next=None, ), - task_id=task.id, - update_execution_status=True, - client=client, + # task_id=task.id, + # update_execution_status=True, + connection_pool=pool, ) assert result is not None diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py index e733f81c0..e953fe138 100644 --- a/agents-api/tests/test_execution_workflow.py +++ b/agents-api/tests/test_execution_workflow.py @@ -1,53 +1,51 @@ # Tests for task queries + import asyncio import json from unittest.mock import patch import yaml -from google.protobuf.json_format import MessageToDict -from litellm.types.utils import Choices, ModelResponse -from ward import raises, skip, test - from agents_api.autogen.openapi_model import ( CreateExecutionRequest, CreateTaskRequest, ) -from agents_api.models.task.create_task import create_task +from agents_api.clients.pg import create_db_pool +from agents_api.queries.tasks.create_task import create_task from agents_api.routers.tasks.create_task_execution import start_execution -from tests.fixtures import ( - async_cozo_client, - cozo_client, - cozo_clients_with_migrations, +from google.protobuf.json_format import MessageToDict +from litellm import Choices, ModelResponse +from ward import raises, skip, test + +from .fixtures import ( + pg_dsn, + s3_client, test_agent, test_developer_id, ) -from tests.utils import patch_integration_service, patch_testing_temporal - -EMBEDDING_SIZE: int = 1024 +from .utils import patch_integration_service, patch_testing_temporal @test("workflow: evaluate step single") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, 
developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [{"evaluate": {"hello": '"world"'}}], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[{"evaluate": {"hello": '"world"'}}], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -55,7 +53,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -69,28 +67,27 @@ async def _( @test("workflow: evaluate step multiple") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - {"evaluate": {"hello": '"nope"'}}, - {"evaluate": {"hello": '"world"'}}, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + {"evaluate": {"hello": '"nope"'}}, + {"evaluate": {"hello": '"world"'}}, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, 
mock_run_task_execution_workflow): @@ -98,7 +95,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -112,28 +109,27 @@ async def _( @test("workflow: variable access in expressions") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - # Testing that we can access the input - {"evaluate": {"hello": '_["test"]'}}, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + # Testing that we can access the input + {"evaluate": {"hello": '_["test"]'}}, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -141,7 +137,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -155,35 +151,34 @@ async def _( @test("workflow: yield step") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": 
"object", "additionalProperties": True}, - "other_workflow": [ - # Testing that we can access the input - {"evaluate": {"hello": '_["test"]'}}, - ], - "main": [ - # Testing that we can access the input - { - "workflow": "other_workflow", - "arguments": {"test": '_["test"]'}, - }, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + other_workflow=[ + # Testing that we can access the input + {"evaluate": {"hello": '_["test"]'}}, + ], + main=[ + # Testing that we can access the input + { + "workflow": "other_workflow", + "arguments": {"test": '_["test"]'}, + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -191,7 +186,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -205,36 +200,35 @@ async def _( @test("workflow: sleep step") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "other_workflow": [ - # Testing that we can access the input - {"evaluate": {"hello": '_["test"]'}}, - {"sleep": {"days": 5}}, - ], - "main": [ - # Testing that we can access the input - { - "workflow": "other_workflow", - "arguments": {"test": '_["test"]'}, - }, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + other_workflow=[ + # Testing that we can access the input + 
{"evaluate": {"hello": '_["test"]'}}, + {"sleep": {"days": 5}}, + ], + main=[ + # Testing that we can access the input + { + "workflow": "other_workflow", + "arguments": {"test": '_["test"]'}, + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -242,7 +236,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -256,30 +250,29 @@ async def _( @test("workflow: return step direct") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - # Testing that we can access the input - {"evaluate": {"hello": '_["test"]'}}, - {"return": {"value": '_["hello"]'}}, - {"return": {"value": '"banana"'}}, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + # Testing that we can access the input + {"evaluate": {"hello": '_["test"]'}}, + {"return": {"value": '_["hello"]'}}, + {"return": {"value": '"banana"'}}, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -287,7 +280,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -301,37 +294,36 @@ async def _( @test("workflow: return step nested") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, 
developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "other_workflow": [ - # Testing that we can access the input - {"evaluate": {"hello": '_["test"]'}}, - {"return": {"value": '_["hello"]'}}, - {"return": {"value": '"banana"'}}, - ], - "main": [ - # Testing that we can access the input - { - "workflow": "other_workflow", - "arguments": {"test": '_["test"]'}, - }, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + other_workflow=[ + # Testing that we can access the input + {"evaluate": {"hello": '_["test"]'}}, + {"return": {"value": '_["hello"]'}}, + {"return": {"value": '"banana"'}}, + ], + main=[ + # Testing that we can access the input + { + "workflow": "other_workflow", + "arguments": {"test": '_["test"]'}, + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -339,7 +331,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -353,36 +345,35 @@ async def _( @test("workflow: log step") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, 
data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "other_workflow": [ - # Testing that we can access the input - {"evaluate": {"hello": '_["test"]'}}, - {"log": "{{_.hello}}"}, - ], - "main": [ - # Testing that we can access the input - { - "workflow": "other_workflow", - "arguments": {"test": '_["test"]'}, - }, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + other_workflow=[ + # Testing that we can access the input + {"evaluate": {"hello": '_["test"]'}}, + {"log": "{{_.hello}}"}, + ], + main=[ + # Testing that we can access the input + { + "workflow": "other_workflow", + "arguments": {"test": '_["test"]'}, + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -390,7 +381,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -404,38 +395,35 @@ async def _( @test("workflow: log step expression fail") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "other_workflow": [ - # Testing that we can access the input - {"evaluate": {"hello": '_["test"]'}}, - { - "log": '{{_["hell"].strip()}}' - }, # <--- The "hell" key does not exist - ], - "main": [ - # Testing that we can access the input - { - "workflow": 
"other_workflow", - "arguments": {"test": '_["test"]'}, - }, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + other_workflow=[ + # Testing that we can access the input + {"evaluate": {"hello": '_["test"]'}}, + {"log": '{{_["hell"].strip()}}'}, # <--- The "hell" key does not exist + ], + main=[ + # Testing that we can access the input + { + "workflow": "other_workflow", + "arguments": {"test": '_["test"]'}, + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -444,7 +432,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -458,40 +446,39 @@ async def _( @test("workflow: system call - list agents") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "Test system tool task", - "description": "List agents using system call", - "input_schema": {"type": "object"}, - "tools": [ - { - "name": "list_agents", - "description": "List all agents", - "type": "system", - "system": {"resource": "agent", "operation": "list"}, - }, - ], - "main": [ - { - "tool": "list_agents", - "arguments": { - "limit": "10", - }, + name="Test system tool task", + description="List agents using system call", + input_schema={"type": "object"}, + tools=[ + { + "name": "list_agents", + "description": "List all agents", + "type": "system", + "system": {"resource": "agent", "operation": "list"}, + }, + ], + main=[ + { + "tool": "list_agents", + "arguments": { + "limit": "10", }, - ], - } + 
}, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -499,7 +486,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -519,45 +506,44 @@ async def _( @test("workflow: tool call api_call") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "tools": [ - { - "type": "api_call", - "name": "hello", - "api_call": { - "method": "GET", - "url": "https://httpbin.org/get", - }, - } - ], - "main": [ - { - "tool": "hello", - "arguments": { - "params": {"test": "_.test"}, - }, + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + tools=[ + { + "type": "api_call", + "name": "hello", + "api_call": { + "method": "GET", + "url": "https://httpbin.org/get", }, - { - "evaluate": {"hello": "_.json.args.test"}, + } + ], + main=[ + { + "tool": "hello", + "arguments": { + "params": {"test": "_.test"}, }, - ], - } + }, + { + "evaluate": {"hello": "_.json.args.test"}, + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -565,7 +551,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -579,51 +565,50 @@ async def _( @test("workflow: tool call api_call test retry") async def _( - 
clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) status_codes_to_retry = ",".join(str(code) for code in (408, 429, 503, 504)) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "tools": [ - { - "type": "api_call", - "name": "hello", - "api_call": { - "method": "GET", - "url": f"https://httpbin.org/status/{status_codes_to_retry}", - }, - } - ], - "main": [ - { - "tool": "hello", - "arguments": { - "params": {"test": "_.test"}, - }, + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + tools=[ + { + "type": "api_call", + "name": "hello", + "api_call": { + "method": "GET", + "url": f"https://httpbin.org/status/{status_codes_to_retry}", }, - ], - } + } + ], + main=[ + { + "tool": "hello", + "arguments": { + "params": {"test": "_.test"}, + }, + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): - execution, handle = await start_execution( + _execution, handle = await start_execution( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -644,48 +629,45 @@ async def _( # NOTE: super janky but works events_strings = [json.dumps(event) for event in events] - num_retries = len( - [event for event in events_strings if "execute_api_call" in event] - ) + num_retries = len([event for event in events_strings if "execute_api_call" in event]) assert num_retries >= 2 @test("workflow: tool call integration dummy") async def _( - 
clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "tools": [ - { - "type": "integration", - "name": "hello", - "integration": { - "provider": "dummy", - }, - } - ], - "main": [ - { - "tool": "hello", - "arguments": {"test": "_.test"}, + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + tools=[ + { + "type": "integration", + "name": "hello", + "integration": { + "provider": "dummy", }, - ], - } + } + ], + main=[ + { + "tool": "hello", + "arguments": {"test": "_.test"}, + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -693,7 +675,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -708,41 +690,40 @@ async def _( @skip("integration service patch not working") @test("workflow: tool call integration mocked weather") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", 
"additionalProperties": True}, - "tools": [ - { - "type": "integration", - "name": "get_weather", - "integration": { - "provider": "weather", - "setup": {"openweathermap_api_key": "test"}, - "arguments": {"test": "fake"}, - }, - } - ], - "main": [ - { - "tool": "get_weather", - "arguments": {"location": "_.test"}, + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + tools=[ + { + "type": "integration", + "name": "get_weather", + "integration": { + "provider": "weather", + "setup": {"openweathermap_api_key": "test"}, + "arguments": {"test": "fake"}, }, - ], - } + } + ], + main=[ + { + "tool": "get_weather", + "arguments": {"location": "_.test"}, + }, + ], ), - client=client, + connection_pool=pool, ) expected_output = {"temperature": 20, "humidity": 60} @@ -753,7 +734,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -768,27 +749,26 @@ async def _( @test("workflow: wait for input step start") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - {"wait_for_input": {"info": {"hi": '"bye"'}}}, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + {"wait_for_input": {"info": {"hi": '"bye"'}}}, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ 
-796,7 +776,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -809,7 +789,7 @@ async def _( task = asyncio.create_task(result_coroutine) try: await asyncio.wait_for(task, timeout=3) - except asyncio.TimeoutError: + except TimeoutError: task.cancel() # Get the history @@ -824,41 +804,38 @@ async def _( for event in events if "ACTIVITY_TASK_SCHEDULED" in event["eventType"] ] - activities_scheduled = [ - activity for activity in activities_scheduled if activity - ] + activities_scheduled = [activity for activity in activities_scheduled if activity] assert "wait_for_input_step" in activities_scheduled @test("workflow: foreach wait for input step start") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - { - "foreach": { - "in": "'a b c'.split()", - "do": {"wait_for_input": {"info": {"hi": '"bye"'}}}, - }, + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + { + "foreach": { + "in": "'a b c'.split()", + "do": {"wait_for_input": {"info": {"hi": '"bye"'}}}, }, - ], - } + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -866,7 +843,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -879,7 +856,7 @@ async def _( task = 
asyncio.create_task(result_coroutine) try: await asyncio.wait_for(task, timeout=3) - except asyncio.TimeoutError: + except TimeoutError: task.cancel() # Get the history @@ -894,42 +871,39 @@ async def _( for event in events if "ACTIVITY_TASK_SCHEDULED" in event["eventType"] ] - activities_scheduled = [ - activity for activity in activities_scheduled if activity - ] + activities_scheduled = [activity for activity in activities_scheduled if activity] assert "for_each_step" in activities_scheduled @test("workflow: if-else step") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) task_def = CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - { - "if": "False", - "then": {"evaluate": {"hello": '"world"'}}, - "else": {"evaluate": {"hello": "random.randint(0, 10)"}}, - }, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + { + "if": "False", + "then": {"evaluate": {"hello": '"world"'}}, + "else": {"evaluate": {"hello": "random.randint(0, 10)"}}, + }, + ], ) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=task_def, - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -937,7 +911,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -952,42 +926,41 @@ async def _( @test("workflow: switch step") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + 
_s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - { - "switch": [ - { - "case": "False", - "then": {"evaluate": {"hello": '"bubbles"'}}, - }, - { - "case": "True", - "then": {"evaluate": {"hello": '"world"'}}, - }, - { - "case": "True", - "then": {"evaluate": {"hello": '"bye"'}}, - }, - ] - }, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + { + "switch": [ + { + "case": "False", + "then": {"evaluate": {"hello": '"bubbles"'}}, + }, + { + "case": "True", + "then": {"evaluate": {"hello": '"world"'}}, + }, + { + "case": "True", + "then": {"evaluate": {"hello": '"bye"'}}, + }, + ] + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -995,7 +968,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -1010,32 +983,31 @@ async def _( @test("workflow: for each step") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - { - 
"foreach": { - "in": "'a b c'.split()", - "do": {"evaluate": {"hello": '"world"'}}, - }, + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + { + "foreach": { + "in": "'a b c'.split()", + "do": {"evaluate": {"hello": '"world"'}}, }, - ], - } + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -1043,7 +1015,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -1058,11 +1030,12 @@ async def _( @test("workflow: map reduce step") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) map_step = { @@ -1079,11 +1052,11 @@ async def _( "main": [map_step], } - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest(**task_def), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -1091,7 +1064,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -1108,10 +1081,12 @@ async def _( @test(f"workflow: map reduce step parallel (parallelism={p})") async def _( - client=cozo_client, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) map_step = { @@ -1129,11 +1104,11 @@ async def _( "main": [map_step], } - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, 
data=CreateTaskRequest(**task_def), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -1141,7 +1116,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -1161,11 +1136,12 @@ async def _( @test("workflow: prompt step (python expression)") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) mock_model_response = ModelResponse( id="fake_id", choices=[Choices(message={"role": "assistant", "content": "Hello, world!"})], @@ -1177,23 +1153,21 @@ async def _( acompletion.return_value = mock_model_response data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - { - "prompt": "$_ [{'role': 'user', 'content': _.test}]", - "settings": {}, - }, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + { + "prompt": "$_ [{'role': 'user', 'content': _.test}]", + "settings": {}, + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -1201,7 +1175,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -1218,11 +1192,12 @@ async def _( @test("workflow: prompt step") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob 
store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) mock_model_response = ModelResponse( id="fake_id", choices=[Choices(message={"role": "assistant", "content": "Hello, world!"})], @@ -1234,28 +1209,26 @@ async def _( acompletion.return_value = mock_model_response data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - { - "prompt": [ - { - "role": "user", - "content": "message", - }, - ], - "settings": {}, - }, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + { + "prompt": [ + { + "role": "user", + "content": "message", + }, + ], + "settings": {}, + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -1263,7 +1236,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -1280,11 +1253,12 @@ async def _( @test("workflow: prompt step unwrap") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, + _s3_client=s3_client, # Adding coz blob store might be used ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) mock_model_response = ModelResponse( id="fake_id", choices=[Choices(message={"role": "assistant", "content": "Hello, world!"})], @@ -1296,29 +1270,27 @@ async def _( acompletion.return_value = mock_model_response data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task 
about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - { - "prompt": [ - { - "role": "user", - "content": "message", - }, - ], - "unwrap": True, - "settings": {}, - }, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + { + "prompt": [ + { + "role": "user", + "content": "message", + }, + ], + "unwrap": True, + "settings": {}, + }, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -1326,7 +1298,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -1341,28 +1313,26 @@ async def _( @test("workflow: set and get steps") async def _( - clients=cozo_clients_with_migrations, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) data = CreateExecutionRequest(input={"test": "input"}) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [ - {"set": {"test_key": '"test_value"'}}, - {"get": "test_key"}, - ], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[ + {"set": {"test_key": '"test_value"'}}, + {"get": "test_key"}, + ], ), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -1370,7 +1340,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None @@ -1385,17 +1355,15 @@ async def _( @test("workflow: execute yaml task") async def _( - clients=cozo_clients_with_migrations, + 
dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, ): - client, _ = clients + pool = await create_db_pool(dsn=dsn) mock_model_response = ModelResponse( id="fake_id", choices=[ - Choices( - message={"role": "assistant", "content": "found: true\nvalue: 'Gaga'"} - ) + Choices(message={"role": "assistant", "content": "found: true\nvalue: 'Gaga'"}) ], created=0, object="text_completion", @@ -1403,22 +1371,22 @@ async def _( with ( patch("agents_api.clients.litellm.acompletion") as acompletion, - open("./tests/sample_tasks/find_selector.yaml", "r") as task_file, + open("./tests/sample_tasks/find_selector.yaml") as task_file, ): - input = dict( - screenshot_base64="iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA", - network_requests=[{"request": {}, "response": {"body": "Lady Gaga"}}], - parameters=["name"], - ) + input = { + "screenshot_base64": "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA", + "network_requests": [{"request": {}, "response": {"body": "Lady Gaga"}}], + "parameters": ["name"], + } task_definition = yaml.safe_load(task_file) acompletion.return_value = mock_model_response data = CreateExecutionRequest(input=input) - task = create_task( + task = await create_task( developer_id=developer_id, agent_id=agent.id, data=CreateTaskRequest(**task_definition), - client=client, + connection_pool=pool, ) async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): @@ -1426,7 +1394,7 @@ async def _( developer_id=developer_id, task_id=task.id, data=data, - client=client, + connection_pool=pool, ) assert handle is not None diff --git a/agents-api/tests/test_files_routes.py b/agents-api/tests/test_file_routes.py similarity index 66% rename from agents-api/tests/test_files_routes.py rename to agents-api/tests/test_file_routes.py index 662612ff5..3eb3dc82d 100644 --- a/agents-api/tests/test_files_routes.py +++ b/agents-api/tests/test_file_routes.py @@ -8,12 +8,12 @@ @test("route: create 
file") async def _(make_request=make_request, s3_client=s3_client): - data = dict( - name="Test File", - description="This is a test file.", - mime_type="text/plain", - content="eyJzYW1wbGUiOiAidGVzdCJ9", - ) + data = { + "name": "Test File", + "description": "This is a test file.", + "mime_type": "text/plain", + "content": "eyJzYW1wbGUiOiAidGVzdCJ9", + } response = make_request( method="POST", @@ -26,12 +26,12 @@ async def _(make_request=make_request, s3_client=s3_client): @test("route: delete file") async def _(make_request=make_request, s3_client=s3_client): - data = dict( - name="Test File", - description="This is a test file.", - mime_type="text/plain", - content="eyJzYW1wbGUiOiAidGVzdCJ9", - ) + data = { + "name": "Test File", + "description": "This is a test file.", + "mime_type": "text/plain", + "content": "eyJzYW1wbGUiOiAidGVzdCJ9", + } response = make_request( method="POST", @@ -58,12 +58,12 @@ async def _(make_request=make_request, s3_client=s3_client): @test("route: get file") async def _(make_request=make_request, s3_client=s3_client): - data = dict( - name="Test File", - description="This is a test file.", - mime_type="text/plain", - content="eyJzYW1wbGUiOiAidGVzdCJ9", - ) + data = { + "name": "Test File", + "description": "This is a test file.", + "mime_type": "text/plain", + "content": "eyJzYW1wbGUiOiAidGVzdCJ9", + } response = make_request( method="POST", @@ -86,3 +86,13 @@ async def _(make_request=make_request, s3_client=s3_client): # Decode base64 content and compute its SHA-256 hash assert result["hash"] == expected_hash + + +@test("route: list files") +async def _(make_request=make_request, s3_client=s3_client): + response = make_request( + method="GET", + url="/files", + ) + + assert response.status_code == 200 diff --git a/agents-api/tests/test_files_queries.py b/agents-api/tests/test_files_queries.py index 712a083ca..a67c68bae 100644 --- a/agents-api/tests/test_files_queries.py +++ b/agents-api/tests/test_files_queries.py @@ -1,57 +1,251 @@ 
-# Tests for entry queries +# # Tests for entry queries +from agents_api.autogen.openapi_model import CreateFileRequest +from agents_api.clients.pg import create_db_pool +from agents_api.queries.files.create_file import create_file +from agents_api.queries.files.delete_file import delete_file +from agents_api.queries.files.get_file import get_file +from agents_api.queries.files.list_files import list_files from ward import test -from agents_api.autogen.openapi_model import CreateFileRequest -from agents_api.models.files.create_file import create_file -from agents_api.models.files.delete_file import delete_file -from agents_api.models.files.get_file import get_file -from tests.fixtures import ( - cozo_client, - test_developer_id, - test_file, -) - - -@test("model: create file") -def _(client=cozo_client, developer_id=test_developer_id): - create_file( - developer_id=developer_id, +from tests.fixtures import pg_dsn, test_agent, test_developer, test_file, test_user + + +@test("query: create file") +async def _(dsn=pg_dsn, developer=test_developer): + pool = await create_db_pool(dsn=dsn) + await create_file( + developer_id=developer.id, data=CreateFileRequest( name="Hello", description="World", mime_type="text/plain", content="eyJzYW1wbGUiOiAidGVzdCJ9", ), - client=client, + connection_pool=pool, + ) + + +@test("query: create user file") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="User File", + description="Test user file", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, ) + assert file.name == "User File" + + # Verify file appears in user's files + files = await list_files( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert any(f.id == file.id for f in files) + +@test("query: 
create agent file") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) -@test("model: get file") -def _(client=cozo_client, file=test_file, developer_id=test_developer_id): - get_file( - developer_id=developer_id, + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Agent File", + description="Test agent file", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert file.name == "Agent File" + + # Verify file appears in agent's files + files = await list_files( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert any(f.id == file.id for f in files) + + +@test("query: get file") +async def _(dsn=pg_dsn, file=test_file, developer=test_developer): + pool = await create_db_pool(dsn=dsn) + file_test = await get_file( + developer_id=developer.id, file_id=file.id, - client=client, + connection_pool=pool, ) + assert file_test.id == file.id + assert file_test.name == "Hello" + assert file_test.description == "World" + assert file_test.mime_type == "text/plain" + assert file_test.hash == file.hash + + +@test("query: list files") +async def _(dsn=pg_dsn, developer=test_developer, file=test_file): + pool = await create_db_pool(dsn=dsn) + files = await list_files( + developer_id=developer.id, + connection_pool=pool, + ) + assert len(files) >= 1 + assert any(f.id == file.id for f in files) + +@test("query: list user files") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) -@test("model: delete file") -def _(client=cozo_client, developer_id=test_developer_id): - file = create_file( - developer_id=developer_id, + # Create a file owned by the user + file = await create_file( + developer_id=developer.id, data=CreateFileRequest( - name="Hello", - description="World", + name="User List 
Test", + description="Test file for user listing", mime_type="text/plain", content="eyJzYW1wbGUiOiAidGVzdCJ9", ), - client=client, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # List user's files + files = await list_files( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, ) + assert len(files) >= 1 + assert any(f.id == file.id for f in files) + + +@test("query: list agent files") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a file owned by the agent + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Agent List Test", + description="Test file for agent listing", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # List agent's files + files = await list_files( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert len(files) >= 1 + assert any(f.id == file.id for f in files) + + +@test("query: delete user file") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + + # Create a file owned by the user + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="User Delete Test", + description="Test file for user deletion", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Delete the file + await delete_file( + developer_id=developer.id, + file_id=file.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Verify file is no longer in user's files + files = await list_files( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert not any(f.id == file.id for f in files) + + +@test("query: 
delete agent file") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a file owned by the agent + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Agent Delete Test", + description="Test file for agent deletion", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # Delete the file + await delete_file( + developer_id=developer.id, + file_id=file.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # Verify file is no longer in agent's files + files = await list_files( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert not any(f.id == file.id for f in files) + + +@test("query: delete file") +async def _(dsn=pg_dsn, developer=test_developer, file=test_file): + pool = await create_db_pool(dsn=dsn) - delete_file( - developer_id=developer_id, + await delete_file( + developer_id=developer.id, file_id=file.id, - client=client, + connection_pool=pool, ) diff --git a/agents-api/tests/test_messages_truncation.py b/agents-api/tests/test_messages_truncation.py index 97516617a..1a6c344e6 100644 --- a/agents-api/tests/test_messages_truncation.py +++ b/agents-api/tests/test_messages_truncation.py @@ -1,4 +1,4 @@ -# from uuid import uuid4 +# from uuid_extensions import uuid7 # from ward import raises, test @@ -26,9 +26,9 @@ # threshold = sum([len(c) // 3.5 for c in contents]) # messages: list[Entry] = [ -# Entry(session_id=uuid4(), role=Role.user, content=contents[0][0]), -# Entry(session_id=uuid4(), role=Role.assistant, content=contents[1][0]), -# Entry(session_id=uuid4(), role=Role.user, content=contents[2][0]), +# Entry(session_id=uuid7(), role=Role.user, content=contents[0][0]), +# Entry(session_id=uuid7(), role=Role.assistant, content=contents[1][0]), +# Entry(session_id=uuid7(), 
role=Role.user, content=contents[2][0]), # ] # result = session.truncate(messages, threshold) @@ -45,7 +45,7 @@ # ("content5", True), # ("content6", True), # ] -# session_ids = [uuid4()] * len(contents) +# session_ids = [uuid7()] * len(contents) # threshold = sum([len(c) // 3.5 for c, i in contents if i]) # messages: list[Entry] = [ @@ -99,7 +99,7 @@ # ("content5", True), # ("content6", True), # ] -# session_ids = [uuid4()] * len(contents) +# session_ids = [uuid7()] * len(contents) # threshold = sum([len(c) // 3.5 for c, i in contents if i]) # messages: list[Entry] = [ @@ -146,7 +146,7 @@ # ("content6", True), # ("content7", False), # ] -# session_ids = [uuid4()] * len(contents) +# session_ids = [uuid7()] * len(contents) # threshold = sum([len(c) // 3.5 for c, i in contents if i]) # messages: list[Entry] = [ @@ -204,7 +204,7 @@ # ("content12", True), # ("content13", False), # ] -# session_ids = [uuid4()] * len(contents) +# session_ids = [uuid7()] * len(contents) # threshold = sum([len(c) // 3.5 for c, i in contents if i]) # messages: list[Entry] = [ @@ -271,7 +271,7 @@ # ("content9", True), # ("content10", False), # ] -# session_ids = [uuid4()] * len(contents) +# session_ids = [uuid7()] * len(contents) # threshold = sum([len(c) // 3.5 for c, i in contents if i]) # all_tokens = sum([len(c) // 3.5 for c, _ in contents]) diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 01fea1375..e2d1dba17 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -1,160 +1,247 @@ -# Tests for session queries -from uuid import uuid4 - -from ward import test +# """ +# This module contains tests for SQL query generation functions in the sessions module. +# Tests verify the SQL queries without actually executing them against a database. 
+# """ from agents_api.autogen.openapi_model import ( CreateOrUpdateSessionRequest, CreateSessionRequest, + PatchSessionRequest, + ResourceCreatedResponse, + ResourceDeletedResponse, + ResourceUpdatedResponse, Session, + UpdateSessionRequest, +) +from agents_api.clients.pg import create_db_pool +from agents_api.queries.sessions import ( + count_sessions, + create_or_update_session, + create_session, + delete_session, + get_session, + list_sessions, + patch_session, + update_session, ) -from agents_api.models.session.count_sessions import count_sessions -from agents_api.models.session.create_or_update_session import create_or_update_session -from agents_api.models.session.create_session import create_session -from agents_api.models.session.delete_session import delete_session -from agents_api.models.session.get_session import get_session -from agents_api.models.session.list_sessions import list_sessions +from uuid_extensions import uuid7 +from ward import raises, test + from tests.fixtures import ( - cozo_client, + pg_dsn, test_agent, test_developer_id, test_session, test_user, ) -MODEL = "gpt-4o-mini" +@test("query: create session sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user): + """Test that a session can be successfully created.""" -@test("model: create session") -def _( - client=cozo_client, developer_id=test_developer_id, agent=test_agent, user=test_user -): - create_session( + pool = await create_db_pool(dsn=dsn) + session_id = uuid7() + data = CreateSessionRequest( + users=[user.id], + agents=[agent.id], + system_template="test system template", + ) + result = await create_session( developer_id=developer_id, - data=CreateSessionRequest( - users=[user.id], - agents=[agent.id], - situation="test session about", - ), - client=client, + session_id=session_id, + data=data, + connection_pool=pool, ) + assert result is not None + assert isinstance(result, ResourceCreatedResponse), f"Result is not a Session, {result}" 
+ assert result.id == session_id + -@test("model: create session no user") -def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): - create_session( +@test("query: create or update session sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user): + """Test that a session can be successfully created or updated.""" + + pool = await create_db_pool(dsn=dsn) + session_id = uuid7() + data = CreateOrUpdateSessionRequest( + users=[user.id], + agents=[agent.id], + system_template="test system template", + ) + result = await create_or_update_session( developer_id=developer_id, - data=CreateSessionRequest( - agents=[agent.id], - situation="test session about", - ), - client=client, + session_id=session_id, + data=data, + connection_pool=pool, ) + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + assert result.id == session_id + assert result.updated_at is not None -@test("model: get session not exists") -def _(client=cozo_client, developer_id=test_developer_id): - session_id = uuid4() - try: - get_session( - session_id=session_id, - developer_id=developer_id, - client=client, - ) - except Exception: - pass - else: - assert False, "Session should not exist" +@test("query: get session exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test retrieving an existing session.""" - -@test("model: get session exists") -def _(client=cozo_client, developer_id=test_developer_id, session=test_session): - result = get_session( - session_id=session.id, + pool = await create_db_pool(dsn=dsn) + result = await get_session( developer_id=developer_id, - client=client, + session_id=session.id, + connection_pool=pool, ) assert result is not None assert isinstance(result, Session) + assert result.id == session.id -@test("model: delete session") -def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): - session = create_session( - 
developer_id=developer_id, - data=CreateSessionRequest( - agent=agent.id, - situation="test session about", - ), - client=client, - ) +@test("query: get session does not exist") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test retrieving a non-existent session.""" - delete_session( - session_id=session.id, + session_id = uuid7() + pool = await create_db_pool(dsn=dsn) + with raises(Exception): + await get_session( + session_id=session_id, + developer_id=developer_id, + connection_pool=pool, + ) + + +@test("query: list sessions") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test listing sessions with default pagination.""" + + pool = await create_db_pool(dsn=dsn) + result = await list_sessions( developer_id=developer_id, - client=client, + limit=10, + offset=0, + connection_pool=pool, ) - try: - get_session( - session_id=session.id, - developer_id=developer_id, - client=client, - ) - except Exception: - pass + assert isinstance(result, list) + assert len(result) >= 1 + assert any(s.id == session.id for s in result) - else: - assert False, "Session should not exist" +@test("query: list sessions with filters") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test listing sessions with specific filters.""" -@test("model: list sessions") -def _(client=cozo_client, developer_id=test_developer_id, session=test_session): - result = list_sessions( + pool = await create_db_pool(dsn=dsn) + result = await list_sessions( developer_id=developer_id, - client=client, + limit=10, + offset=0, + connection_pool=pool, ) assert isinstance(result, list) - assert len(result) > 0 + assert len(result) >= 1 + assert all(isinstance(s, Session) for s in result), ( + f"Result is not a list of sessions, {result}" + ) + +@test("query: count sessions") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test counting the number of sessions for a developer.""" 
-@test("model: count sessions") -def _(client=cozo_client, developer_id=test_developer_id, session=test_session): - result = count_sessions( + pool = await create_db_pool(dsn=dsn) + count = await count_sessions( developer_id=developer_id, - client=client, + connection_pool=pool, ) - assert isinstance(result, dict) - assert result["count"] > 0 + assert isinstance(count, dict) + assert count["count"] >= 1 -@test("model: create or update session") -def _( - client=cozo_client, developer_id=test_developer_id, agent=test_agent, user=test_user +@test("query: update session sql") +async def _( + dsn=pg_dsn, + developer_id=test_developer_id, + session=test_session, + agent=test_agent, + user=test_user, ): - session_id = uuid4() + """Test that an existing session's information can be successfully updated.""" - create_or_update_session( - session_id=session_id, + pool = await create_db_pool(dsn=dsn) + data = UpdateSessionRequest( + token_budget=1000, + forward_tool_calls=True, + system_template="updated system template", + ) + result = await update_session( + session_id=session.id, developer_id=developer_id, - data=CreateOrUpdateSessionRequest( - users=[user.id], - agents=[agent.id], - situation="test session about", - ), - client=client, + data=data, + connection_pool=pool, ) - result = get_session( - session_id=session_id, + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + assert result.updated_at > session.created_at + + updated_session = await get_session( developer_id=developer_id, - client=client, + session_id=session.id, + connection_pool=pool, + ) + assert updated_session.forward_tool_calls is True + + +@test("query: patch session sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent): + """Test that a session can be successfully patched.""" + + pool = await create_db_pool(dsn=dsn) + data = PatchSessionRequest( + metadata={"test": "metadata"}, + ) + result = await patch_session( + 
developer_id=developer_id, + session_id=session.id, + data=data, + connection_pool=pool, ) assert result is not None - assert isinstance(result, Session) - assert result.id == session_id + assert isinstance(result, ResourceUpdatedResponse) + assert result.updated_at > session.created_at + + patched_session = await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + assert patched_session.metadata == {"test": "metadata"} + + +@test("query: delete session sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test that a session can be successfully deleted.""" + + pool = await create_db_pool(dsn=dsn) + delete_result = await delete_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + + assert delete_result is not None + assert isinstance(delete_result, ResourceDeletedResponse) + + with raises(Exception): + await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) diff --git a/agents-api/tests/test_session_routes.py b/agents-api/tests/test_session_routes.py new file mode 100644 index 000000000..aa1380f11 --- /dev/null +++ b/agents-api/tests/test_session_routes.py @@ -0,0 +1,177 @@ +from uuid_extensions import uuid7 +from ward import test + +from tests.fixtures import client, make_request, test_agent, test_session + + +@test("route: unauthorized should fail") +def _(client=client): + response = client.request( + method="GET", + url="/sessions", + ) + + assert response.status_code == 403 + + +@test("route: create session") +def _(make_request=make_request, agent=test_agent): + data = { + "agent": str(agent.id), + "situation": "test session about", + "metadata": {"test": "test"}, + "system_template": "test system template", + } + + response = make_request( + method="POST", + url="/sessions", + json=data, + ) + + assert response.status_code == 201 + + +@test("route: create or update session - create") +def 
_(make_request=make_request, agent=test_agent): + session_id = uuid7() + + data = { + "agent": str(agent.id), + "situation": "test session about", + "metadata": {"test": "test"}, + "system_template": "test system template", + } + + response = make_request( + method="POST", + url=f"/sessions/{session_id}", + json=data, + ) + + assert response.status_code == 201 + + +@test("route: create or update session - update") +def _(make_request=make_request, session=test_session, agent=test_agent): + data = { + "agent": str(agent.id), + "situation": "test session about", + "metadata": {"test": "test"}, + "system_template": "test system template", + } + + response = make_request( + method="POST", + url=f"/sessions/{session.id}", + json=data, + ) + + assert response.status_code == 201, f"{response.json()}" + + +@test("route: get session - exists") +def _(make_request=make_request, session=test_session): + response = make_request( + method="GET", + url=f"/sessions/{session.id}", + ) + + assert response.status_code == 200 + + +@test("route: get session - does not exist") +def _(make_request=make_request): + session_id = uuid7() + response = make_request( + method="GET", + url=f"/sessions/{session_id}", + ) + + assert response.status_code == 404 + + +@test("route: list sessions") +def _(make_request=make_request, session=test_session): + response = make_request( + method="GET", + url="/sessions", + ) + + assert response.status_code == 200 + response = response.json() + sessions = response["items"] + + assert isinstance(sessions, list) + assert len(sessions) > 0 + + +@test("route: list sessions with metadata filter") +def _(make_request=make_request, session=test_session): + response = make_request( + method="GET", + url="/sessions", + params={ + "metadata_filter": {"test": "test"}, + }, + ) + + assert response.status_code == 200 + response = response.json() + sessions = response["items"] + + assert isinstance(sessions, list) + assert len(sessions) > 0 + + +@test("route: get 
session history") +def _(make_request=make_request, session=test_session): + response = make_request( + method="GET", + url=f"/sessions/{session.id}/history", + ) + + assert response.status_code == 200 + + history = response.json() + assert history["session_id"] == str(session.id) + + +@test("route: patch session") +def _(make_request=make_request, session=test_session): + data = { + "situation": "test session about", + } + + response = make_request( + method="PATCH", + url=f"/sessions/{session.id}", + json=data, + ) + + assert response.status_code == 200 + + +@test("route: update session") +def _(make_request=make_request, session=test_session): + data = { + "situation": "test session about", + } + + response = make_request( + method="PUT", + url=f"/sessions/{session.id}", + json=data, + ) + + assert response.status_code == 200 + + +@test("route: delete session") +def _(make_request=make_request, session=test_session): + response = make_request( + method="DELETE", + url=f"/sessions/{session.id}", + ) + + assert response.status_code == 202 diff --git a/agents-api/tests/test_sessions.py b/agents-api/tests/test_sessions.py deleted file mode 100644 index b25a8a706..000000000 --- a/agents-api/tests/test_sessions.py +++ /dev/null @@ -1,36 +0,0 @@ -from ward import test - -from tests.fixtures import make_request - - -@test("model: list sessions") -def _(make_request=make_request): - response = make_request( - method="GET", - url="/sessions", - ) - - assert response.status_code == 200 - response = response.json() - sessions = response["items"] - - assert isinstance(sessions, list) - assert len(sessions) > 0 - - -@test("model: list sessions with metadata filter") -def _(make_request=make_request): - response = make_request( - method="GET", - url="/sessions", - params={ - "metadata_filter": {"test": "test"}, - }, - ) - - assert response.status_code == 200 - response = response.json() - sessions = response["items"] - - assert isinstance(sessions, list) - assert len(sessions) 
> 0 diff --git a/agents-api/tests/test_task_queries.py b/agents-api/tests/test_task_queries.py index e61489df8..84b18cad8 100644 --- a/agents-api/tests/test_task_queries.py +++ b/agents-api/tests/test_task_queries.py @@ -1,160 +1,321 @@ # Tests for task queries -from uuid import uuid4 - -from ward import test from agents_api.autogen.openapi_model import ( CreateTaskRequest, + PatchTaskRequest, ResourceUpdatedResponse, Task, UpdateTaskRequest, ) -from agents_api.models.task.create_or_update_task import create_or_update_task -from agents_api.models.task.create_task import create_task -from agents_api.models.task.delete_task import delete_task -from agents_api.models.task.get_task import get_task -from agents_api.models.task.list_tasks import list_tasks -from agents_api.models.task.update_task import update_task -from tests.fixtures import cozo_client, test_agent, test_developer_id, test_task +from agents_api.clients.pg import create_db_pool +from agents_api.queries.tasks.create_or_update_task import create_or_update_task +from agents_api.queries.tasks.create_task import create_task +from agents_api.queries.tasks.delete_task import delete_task +from agents_api.queries.tasks.get_task import get_task +from agents_api.queries.tasks.list_tasks import list_tasks +from agents_api.queries.tasks.patch_task import patch_task +from agents_api.queries.tasks.update_task import update_task +from fastapi import HTTPException +from uuid_extensions import uuid7 +from ward import raises, test + +from tests.fixtures import pg_dsn, test_agent, test_developer_id, test_task -@test("model: create task") -def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): - task_id = uuid4() +@test("query: create task sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + """Test that a task can be successfully created.""" - create_task( + pool = await create_db_pool(dsn=dsn) + await create_task( developer_id=developer_id, agent_id=agent.id, - 
task_id=task_id, + task_id=uuid7(), data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [{"evaluate": {"hi": "_"}}], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[{"evaluate": {"hi": "_"}}], ), - client=client, + connection_pool=pool, ) -@test("model: create or update task") -def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): - task_id = uuid4() +@test("query: create or update task sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + """Test that a task can be successfully created or updated.""" - create_or_update_task( + pool = await create_db_pool(dsn=dsn) + await create_or_update_task( developer_id=developer_id, agent_id=agent.id, - task_id=task_id, + task_id=uuid7(), data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [{"evaluate": {"hi": "_"}}], - } + name="test task", + description="test task about", + input_schema={"type": "object", "additionalProperties": True}, + main=[{"evaluate": {"hi": "_"}}], ), - client=client, + connection_pool=pool, + ) + + +@test("query: get task sql - exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, task=test_task): + """Test that an existing task can be successfully retrieved.""" + + pool = await create_db_pool(dsn=dsn) + + # Then retrieve it + result = await get_task( + developer_id=developer_id, + task_id=task.id, + connection_pool=pool, ) + assert result is not None + assert isinstance(result, Task), f"Result is not a Task, got {type(result)}" + assert result.id == task.id + assert result.name == "test task" + assert result.description == "test task about" -@test("model: get task not exists") -def _(client=cozo_client, 
developer_id=test_developer_id): - task_id = uuid4() +@test("query: get task sql - not exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test that attempting to retrieve a non-existent task raises an error.""" - try: - get_task( + pool = await create_db_pool(dsn=dsn) + task_id = uuid7() + + with raises(HTTPException) as exc: + await get_task( developer_id=developer_id, task_id=task_id, - client=client, + connection_pool=pool, ) - except Exception: - pass - else: - assert False, "Task should not exist" + assert exc.raised.status_code == 404 + assert "Task not found" in str(exc.raised.detail) + + +@test("query: delete task sql - exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, task=test_task): + """Test that a task can be successfully deleted.""" + + pool = await create_db_pool(dsn=dsn) + + # First verify task exists + result = await get_task( + developer_id=developer_id, + task_id=task.id, + connection_pool=pool, + ) + assert result is not None + assert result.id == task.id -@test("model: get task exists") -def _(client=cozo_client, developer_id=test_developer_id, task=test_task): - result = get_task( + # Delete the task + deleted = await delete_task( developer_id=developer_id, task_id=task.id, - client=client, + connection_pool=pool, ) + assert deleted is not None + assert deleted.id == task.id + # Verify task no longer exists + with raises(HTTPException) as exc: + await get_task( + developer_id=developer_id, + task_id=task.id, + connection_pool=pool, + ) + + assert exc.raised.status_code == 404 + assert "Task not found" in str(exc.raised.detail) + + +@test("query: delete task sql - not exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test that attempting to delete a non-existent task raises an error.""" + + pool = await create_db_pool(dsn=dsn) + task_id = uuid7() + + with raises(HTTPException) as exc: + await delete_task( + developer_id=developer_id, + task_id=task_id, + connection_pool=pool, + ) + + 
assert exc.raised.status_code == 404 + assert "Task not found" in str(exc.raised.detail) + + +# Add tests for list tasks +@test("query: list tasks sql - with filters") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + """Test that tasks can be successfully filtered and retrieved.""" + + pool = await create_db_pool(dsn=dsn) + result = await list_tasks( + developer_id=developer_id, + agent_id=agent.id, + limit=10, + offset=0, + sort_by="updated_at", + direction="asc", + metadata_filter={"test": True}, + connection_pool=pool, + ) assert result is not None - assert isinstance(result, Task) + assert isinstance(result, list) + assert all(isinstance(task, Task) for task in result) + assert all(task.metadata.get("test") is True for task in result) + +@test("query: list tasks sql - no filters") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, task=test_task): + """Test that a list of tasks can be successfully retrieved.""" -@test("model: delete task") -def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): - task = create_task( + pool = await create_db_pool(dsn=dsn) + result = await list_tasks( developer_id=developer_id, agent_id=agent.id, - data=CreateTaskRequest( - **{ - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [{"evaluate": {"hi": "_"}}], - } - ), - client=client, + connection_pool=pool, + ) + assert result is not None, "Result is None" + assert isinstance(result, list), f"Result is not a list, got {type(result)}" + assert len(result) > 0, "Result is empty" + assert all(isinstance(task, Task) for task in result), ( + "Not all listed tasks are of type Task" ) - delete_task( + +@test("query: update task sql - exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, task=test_task): + """Test that a task can be successfully updated.""" + + pool = await create_db_pool(dsn=dsn) + 
updated = await update_task( developer_id=developer_id, + task_id=task.id, agent_id=agent.id, + data=UpdateTaskRequest( + name="updated task", + canonical_name="updated_task", + description="updated task description", + input_schema={"type": "object", "additionalProperties": True}, + main=[{"evaluate": {"hi": "_"}}], + inherit_tools=False, + metadata={"updated": True}, + ), + connection_pool=pool, + ) + + assert updated is not None + assert isinstance(updated, ResourceUpdatedResponse) + assert updated.id == task.id + + # Verify task was updated + updated_task = await get_task( + developer_id=developer_id, task_id=task.id, - client=client, + connection_pool=pool, ) + assert updated_task.name == "updated task" + assert updated_task.description == "updated task description" + assert updated_task.metadata == {"updated": True} + - try: - get_task( +@test("query: update task sql - not exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + """Test that attempting to update a non-existent task raises an error.""" + + pool = await create_db_pool(dsn=dsn) + task_id = uuid7() + + with raises(HTTPException) as exc: + await update_task( developer_id=developer_id, - task_id=task.id, - client=client, + task_id=task_id, + agent_id=agent.id, + data=UpdateTaskRequest( + canonical_name="updated_task", + name="updated task", + description="updated task description", + input_schema={"type": "object", "additionalProperties": True}, + main=[{"evaluate": {"hi": "_"}}], + inherit_tools=False, + ), + connection_pool=pool, ) - except Exception: - pass - else: - assert False, "Task should not exist" + assert exc.raised.status_code == 404 + assert "Task not found" in str(exc.raised.detail) -@test("model: update task") -def _( - client=cozo_client, developer_id=test_developer_id, agent=test_agent, task=test_task -): - result = update_task( +@test("query: patch task sql - exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + 
"""Test that patching an existing task works correctly.""" + pool = await create_db_pool(dsn=dsn) + + # Create initial task + task = await create_task( developer_id=developer_id, - task_id=task.id, agent_id=agent.id, - data=UpdateTaskRequest( - **{ - "name": "updated task", - "description": "updated task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [{"evaluate": {"hi": "_"}}], - } + data=CreateTaskRequest( + canonical_name="test_task", + name="test task", + description="test task description", + input_schema={"type": "object", "additionalProperties": True}, + main=[{"evaluate": {"hi": "_"}}], + inherit_tools=False, + metadata={"initial": True}, ), - client=client, + connection_pool=pool, ) - assert result is not None - assert isinstance(result, ResourceUpdatedResponse) + # Patch the task + updated = await patch_task( + developer_id=developer_id, + task_id=task.id, + agent_id=agent.id, + data=PatchTaskRequest(name="patched task", metadata={"patched": True}), + connection_pool=pool, + ) + assert updated is not None + assert isinstance(updated, ResourceUpdatedResponse) + assert updated.id == task.id -@test("model: list tasks") -def _( - client=cozo_client, developer_id=test_developer_id, task=test_task, agent=test_agent -): - result = list_tasks( + # Verify task was patched correctly + patched_task = await get_task( developer_id=developer_id, - agent_id=agent.id, - client=client, + task_id=task.id, + connection_pool=pool, ) + # Check that patched fields were updated + assert patched_task.name == "patched task" + assert patched_task.metadata == {"patched": True} + # Check that non-patched fields remain unchanged + assert patched_task.canonical_name == "test_task" + assert patched_task.description == "test task description" - assert isinstance(result, list) - assert len(result) > 0 - assert all(isinstance(task, Task) for task in result) + +@test("query: patch task sql - not exists") +async def _(dsn=pg_dsn, 
developer_id=test_developer_id, agent=test_agent): + """Test that attempting to patch a non-existent task raises an error.""" + pool = await create_db_pool(dsn=dsn) + task_id = uuid7() + + with raises(HTTPException) as exc: + await patch_task( + developer_id=developer_id, + task_id=task_id, + agent_id=agent.id, + data=PatchTaskRequest(name="patched task", metadata={"patched": True}), + connection_pool=pool, + ) + + assert exc.raised.status_code == 404 + assert "Task not found" in str(exc.raised.detail) diff --git a/agents-api/tests/test_task_routes.py b/agents-api/tests/test_task_routes.py index 5d3c2f998..1d27d26d7 100644 --- a/agents-api/tests/test_task_routes.py +++ b/agents-api/tests/test_task_routes.py @@ -1,24 +1,34 @@ # Tests for task routes -from uuid import uuid4 - -from ward import test +from agents_api.autogen.openapi_model import ( + Transition, +) +from agents_api.queries.executions.create_execution_transition import ( + create_execution_transition, +) +from uuid_extensions import uuid7 +from ward import skip, test -from tests.fixtures import ( +from .fixtures import ( + CreateTransitionRequest, client, + create_db_pool, make_request, + pg_dsn, test_agent, + test_developer_id, test_execution, + test_execution_started, test_task, ) -from tests.utils import patch_testing_temporal +from .utils import patch_testing_temporal @test("route: unauthorized should fail") def _(client=client, agent=test_agent): - data = dict( - name="test user", - main=[ + data = { + "name": "test user", + "main": [ { "kind_": "evaluate", "evaluate": { @@ -26,12 +36,12 @@ def _(client=client, agent=test_agent): }, } ], - ) + } response = client.request( method="POST", - url=f"/agents/{str(agent.id)}/tasks", - data=data, + url=f"/agents/{agent.id!s}/tasks", + json=data, ) assert response.status_code == 403 @@ -39,9 +49,9 @@ def _(client=client, agent=test_agent): @test("route: create task") def _(make_request=make_request, agent=test_agent): - data = dict( - name="test user", - 
main=[ + data = { + "name": "test user", + "main": [ { "kind_": "evaluate", "evaluate": { @@ -49,11 +59,11 @@ def _(make_request=make_request, agent=test_agent): }, } ], - ) + } response = make_request( method="POST", - url=f"/agents/{str(agent.id)}/tasks", + url=f"/agents/{agent.id!s}/tasks", json=data, ) @@ -62,15 +72,15 @@ def _(make_request=make_request, agent=test_agent): @test("route: create task execution") async def _(make_request=make_request, task=test_task): - data = dict( - input={}, - metadata={}, - ) + data = { + "input": {}, + "metadata": {}, + } async with patch_testing_temporal(): response = make_request( method="POST", - url=f"/tasks/{str(task.id)}/executions", + url=f"/tasks/{task.id!s}/executions", json=data, ) @@ -79,7 +89,7 @@ async def _(make_request=make_request, task=test_task): @test("route: get execution not exists") def _(make_request=make_request): - execution_id = str(uuid4()) + execution_id = str(uuid7()) response = make_request( method="GET", @@ -93,7 +103,7 @@ def _(make_request=make_request): def _(make_request=make_request, execution=test_execution): response = make_request( method="GET", - url=f"/executions/{str(execution.id)}", + url=f"/executions/{execution.id!s}", ) assert response.status_code == 200 @@ -101,47 +111,86 @@ def _(make_request=make_request, execution=test_execution): @test("route: get task not exists") def _(make_request=make_request): - task_id = str(uuid4()) + task_id = str(uuid7()) response = make_request( method="GET", url=f"/tasks/{task_id}", ) - assert response.status_code == 400 + assert response.status_code == 404 @test("route: get task exists") def _(make_request=make_request, task=test_task): response = make_request( method="GET", - url=f"/tasks/{str(task.id)}", + url=f"/tasks/{task.id!s}", ) assert response.status_code == 200 -# FIXME: This test is failing -# @test("route: list execution transitions") -# def _(make_request=make_request, execution=test_execution, transition=test_transition): -# response 
= make_request( -# method="GET", -# url=f"/executions/{str(execution.id)}/transitions", -# ) +@test("route: list all execution transition") +async def _(make_request=make_request, execution=test_execution_started): + response = make_request( + method="GET", + url=f"/executions/{execution.id!s}/transitions", + ) + + assert response.status_code == 200 + response = response.json() + transitions = response["items"] + + assert isinstance(transitions, list) + assert len(transitions) > 0 + + +@test("route: list a single execution transition") +async def _( + dsn=pg_dsn, + make_request=make_request, + execution=test_execution_started, + developer_id=test_developer_id, +): + pool = await create_db_pool(dsn=dsn) + + # Create a transition + transition = await create_execution_transition( + developer_id=developer_id, + execution_id=execution.id, + data=CreateTransitionRequest( + type="step", + output={}, + current={"workflow": "main", "step": 0}, + next={"workflow": "wf1", "step": 1}, + ), + connection_pool=pool, + ) + + response = make_request( + method="GET", + url=f"/executions/{execution.id!s}/transitions/{transition.id!s}", + ) -# assert response.status_code == 200 -# response = response.json() -# transitions = response["items"] + assert response.status_code == 200 + response = response.json() -# assert isinstance(transitions, list) -# assert len(transitions) > 0 + assert isinstance(transition, Transition) + assert str(transition.id) == response["id"] + assert transition.type == response["type"] + assert transition.output == response["output"] + assert transition.current.workflow == response["current"]["workflow"] + assert transition.current.step == response["current"]["step"] + assert transition.next.workflow == response["next"]["workflow"] + assert transition.next.step == response["next"]["step"] @test("route: list task executions") def _(make_request=make_request, execution=test_execution): response = make_request( method="GET", - 
url=f"/tasks/{str(execution.task_id)}/executions", + url=f"/tasks/{execution.task_id!s}/executions", ) assert response.status_code == 200 @@ -156,7 +205,32 @@ def _(make_request=make_request, execution=test_execution): def _(make_request=make_request, agent=test_agent): response = make_request( method="GET", - url=f"/agents/{str(agent.id)}/tasks", + url=f"/agents/{agent.id!s}/tasks", + ) + + data = { + "name": "test user", + "main": [ + { + "kind_": "evaluate", + "evaluate": { + "additionalProp1": "value1", + }, + } + ], + } + + response = make_request( + method="POST", + url=f"/agents/{agent.id!s}/tasks", + json=data, + ) + + assert response.status_code == 201 + + response = make_request( + method="GET", + url=f"/agents/{agent.id!s}/tasks", ) assert response.status_code == 200 @@ -167,44 +241,45 @@ def _(make_request=make_request, agent=test_agent): assert len(tasks) > 0 -# FIXME: This test is failing +@skip("Temporal connextion issue") +@test("route: update execution") +async def _(make_request=make_request, task=test_task): + data = { + "input": {}, + "metadata": {}, + } -# @test("route: patch execution") -# async def _(make_request=make_request, task=test_task): -# data = dict( -# input={}, -# metadata={}, -# ) + async with patch_testing_temporal(): + response = make_request( + method="POST", + url=f"/tasks/{task.id!s}/executions", + json=data, + ) -# async with patch_testing_temporal(): -# response = make_request( -# method="POST", -# url=f"/tasks/{str(task.id)}/executions", -# json=data, -# ) + execution = response.json() -# execution = response.json() + data = { + "status": "running", + } -# data = dict( -# status="running", -# ) + execution_id = execution["id"] -# response = make_request( -# method="PATCH", -# url=f"/tasks/{str(task.id)}/executions/{str(execution['id'])}", -# json=data, -# ) + response = make_request( + method="PUT", + url=f"/executions/{execution_id}", + json=data, + ) -# assert response.status_code == 200 + assert response.status_code == 
200 -# execution_id = response.json()["id"] + execution_id = response.json()["id"] -# response = make_request( -# method="GET", -# url=f"/executions/{execution_id}", -# ) + response = make_request( + method="GET", + url=f"/executions/{execution_id}", + ) -# assert response.status_code == 200 -# execution = response.json() + assert response.status_code == 200 + execution = response.json() -# assert execution["status"] == "running" + assert execution["status"] == "running" diff --git a/agents-api/tests/test_tool_queries.py b/agents-api/tests/test_tool_queries.py index b41125aaf..218136c79 100644 --- a/agents-api/tests/test_tool_queries.py +++ b/agents-api/tests/test_tool_queries.py @@ -1,6 +1,4 @@ -# Tests for tool queries - -from ward import test +# # Tests for tool queries from agents_api.autogen.openapi_model import ( CreateToolRequest, @@ -8,17 +6,21 @@ Tool, UpdateToolRequest, ) -from agents_api.models.tools.create_tools import create_tools -from agents_api.models.tools.delete_tool import delete_tool -from agents_api.models.tools.get_tool import get_tool -from agents_api.models.tools.list_tools import list_tools -from agents_api.models.tools.patch_tool import patch_tool -from agents_api.models.tools.update_tool import update_tool -from tests.fixtures import cozo_client, test_agent, test_developer_id, test_tool +from agents_api.clients.pg import create_db_pool +from agents_api.queries.tools.create_tools import create_tools +from agents_api.queries.tools.delete_tool import delete_tool +from agents_api.queries.tools.get_tool import get_tool +from agents_api.queries.tools.list_tools import list_tools +from agents_api.queries.tools.patch_tool import patch_tool +from agents_api.queries.tools.update_tool import update_tool +from ward import test +from tests.fixtures import pg_dsn, test_agent, test_developer_id, test_tool -@test("model: create tool") -def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): + +@test("query: create tool") +async def 
_(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + pool = await create_db_pool(dsn=dsn) function = { "name": "hello_world", "description": "A function that prints hello world", @@ -31,19 +33,20 @@ def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): "type": "function", } - result = create_tools( + result = await create_tools( developer_id=developer_id, agent_id=agent.id, data=[CreateToolRequest(**tool)], - client=client, + connection_pool=pool, ) assert result is not None assert isinstance(result[0], Tool) -@test("model: delete tool") -def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): +@test("query: delete tool") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + pool = await create_db_pool(dsn=dsn) function = { "name": "temp_temp", "description": "A function that prints hello world", @@ -56,79 +59,78 @@ def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): "type": "function", } - [tool, *_] = create_tools( + [tool, *_] = await create_tools( developer_id=developer_id, agent_id=agent.id, data=[CreateToolRequest(**tool)], - client=client, + connection_pool=pool, ) - result = delete_tool( + result = await delete_tool( developer_id=developer_id, agent_id=agent.id, tool_id=tool.id, - client=client, + connection_pool=pool, ) assert result is not None -@test("model: get tool") -def _( - client=cozo_client, developer_id=test_developer_id, tool=test_tool, agent=test_agent -): - result = get_tool( +@test("query: get tool") +async def _(dsn=pg_dsn, developer_id=test_developer_id, tool=test_tool, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + result = await get_tool( developer_id=developer_id, agent_id=agent.id, tool_id=tool.id, - client=client, + connection_pool=pool, ) - assert result is not None + assert result is not None, "Result is None" -@test("model: list tools") -def _( - client=cozo_client, developer_id=test_developer_id, agent=test_agent, 
tool=test_tool -): - result = list_tools( +@test("query: list tools") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, tool=test_tool): + pool = await create_db_pool(dsn=dsn) + result = await list_tools( developer_id=developer_id, agent_id=agent.id, - client=client, + connection_pool=pool, ) - assert result is not None - assert all(isinstance(tool, Tool) for tool in result) + assert result is not None, "Result is None" + assert len(result) > 0, "Result is empty" + assert all(isinstance(tool, Tool) for tool in result), ( + "Not all listed tools are of type Tool" + ) -@test("model: patch tool") -def _( - client=cozo_client, developer_id=test_developer_id, agent=test_agent, tool=test_tool -): +@test("query: patch tool") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, tool=test_tool): + pool = await create_db_pool(dsn=dsn) patch_data = PatchToolRequest( - **{ - "name": "patched_tool", - "function": { - "description": "A patched function that prints hello world", - }, - } + name="patched_tool", + function={ + "description": "A patched function that prints hello world", + "parameters": {"param1": "value1"}, + }, ) - result = patch_tool( + result = await patch_tool( developer_id=developer_id, agent_id=agent.id, tool_id=tool.id, data=patch_data, - client=client, + connection_pool=pool, ) assert result is not None - tool = get_tool( + tool = await get_tool( developer_id=developer_id, agent_id=agent.id, tool_id=tool.id, - client=client, + connection_pool=pool, ) assert tool.name == "patched_tool" @@ -136,10 +138,9 @@ def _( assert tool.function.parameters -@test("model: update tool") -def _( - client=cozo_client, developer_id=test_developer_id, agent=test_agent, tool=test_tool -): +@test("query: update tool") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, tool=test_tool): + pool = await create_db_pool(dsn=dsn) update_data = UpdateToolRequest( name="updated_tool", description="An updated 
description", @@ -149,21 +150,21 @@ def _( }, ) - result = update_tool( + result = await update_tool( developer_id=developer_id, agent_id=agent.id, tool_id=tool.id, data=update_data, - client=client, + connection_pool=pool, ) assert result is not None - tool = get_tool( + tool = await get_tool( developer_id=developer_id, agent_id=agent.id, tool_id=tool.id, - client=client, + connection_pool=pool, ) assert tool.name == "updated_tool" diff --git a/agents-api/tests/test_user_queries.py b/agents-api/tests/test_user_queries.py index ab5c62ed0..b0a259805 100644 --- a/agents-api/tests/test_user_queries.py +++ b/agents-api/tests/test_user_queries.py @@ -1,67 +1,83 @@ -# This module contains tests for user-related queries against the 'cozodb' database. It includes tests for creating, updating, and retrieving user information. +""" +This module contains tests for SQL query generation functions in the users module. +Tests verify the SQL queries without actually executing them against a database. +""" -# Tests for user queries -from uuid import uuid4 - -from ward import test +from uuid import UUID from agents_api.autogen.openapi_model import ( CreateOrUpdateUserRequest, CreateUserRequest, + PatchUserRequest, + ResourceDeletedResponse, ResourceUpdatedResponse, UpdateUserRequest, User, ) -from agents_api.models.user.create_or_update_user import create_or_update_user -from agents_api.models.user.create_user import create_user -from agents_api.models.user.get_user import get_user -from agents_api.models.user.list_users import list_users -from agents_api.models.user.update_user import update_user -from tests.fixtures import cozo_client, test_developer_id, test_user +from agents_api.clients.pg import create_db_pool +from agents_api.queries.users import ( + create_or_update_user, + create_user, + delete_user, + get_user, + list_users, + patch_user, + update_user, +) +from uuid_extensions import uuid7 +from ward import raises, test + +from tests.fixtures import pg_dsn, 
test_developer_id, test_user + +# Test UUIDs for consistent testing +TEST_DEVELOPER_ID = UUID("123e4567-e89b-12d3-a456-426614174000") +TEST_USER_ID = UUID("987e6543-e21b-12d3-a456-426614174000") -@test("model: create user") -def _(client=cozo_client, developer_id=test_developer_id): +@test("query: create user sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that a user can be successfully created.""" - create_user( + pool = await create_db_pool(dsn=dsn) + await create_user( developer_id=developer_id, data=CreateUserRequest( name="test user", about="test user about", ), - client=client, + connection_pool=pool, ) -@test("model: create or update user") -def _(client=cozo_client, developer_id=test_developer_id): +@test("query: create or update user sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that a user can be successfully created or updated.""" - create_or_update_user( + pool = await create_db_pool(dsn=dsn) + await create_or_update_user( developer_id=developer_id, - user_id=uuid4(), + user_id=uuid7(), data=CreateOrUpdateUserRequest( name="test user", about="test user about", ), - client=client, + connection_pool=pool, ) -@test("model: update user") -def _(client=cozo_client, developer_id=test_developer_id, user=test_user): +@test("query: update user sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, user=test_user): """Test that an existing user's information can be successfully updated.""" - # Verify that the 'updated_at' timestamp is greater than the 'created_at' timestamp, indicating a successful update. 
- update_result = update_user( + pool = await create_db_pool(dsn=dsn) + update_result = await update_user( user_id=user.id, developer_id=developer_id, data=UpdateUserRequest( name="updated user", about="updated user about", ), - client=client, + connection_pool=pool, ) assert update_result is not None @@ -69,50 +85,95 @@ def _(client=cozo_client, developer_id=test_developer_id, user=test_user): assert update_result.updated_at > user.created_at -@test("model: get user not exists") -def _(client=cozo_client, developer_id=test_developer_id): +@test("query: get user not exists sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that retrieving a non-existent user returns an empty result.""" - user_id = uuid4() + user_id = uuid7() - # Ensure that the query for an existing user returns exactly one result. - try: - get_user( + pool = await create_db_pool(dsn=dsn) + + with raises(Exception): + await get_user( user_id=user_id, developer_id=developer_id, - client=client, + connection_pool=pool, ) - except Exception: - pass - else: - assert ( - False - ), "Expected an exception to be raised when retrieving a non-existent user." 
-@test("model: get user exists") -def _(client=cozo_client, developer_id=test_developer_id, user=test_user): +@test("query: get user exists sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, user=test_user): """Test that retrieving an existing user returns the correct user information.""" - result = get_user( + pool = await create_db_pool(dsn=dsn) + result = await get_user( user_id=user.id, developer_id=developer_id, - client=client, + connection_pool=pool, ) assert result is not None assert isinstance(result, User) -@test("model: list users") -def _(client=cozo_client, developer_id=test_developer_id, user=test_user): +@test("query: list users sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that listing users returns a collection of user information.""" - result = list_users( + pool = await create_db_pool(dsn=dsn) + result = await list_users( developer_id=developer_id, - client=client, + connection_pool=pool, ) assert isinstance(result, list) assert len(result) >= 1 assert all(isinstance(user, User) for user in result) + + +@test("query: patch user sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, user=test_user): + """Test that a user can be successfully patched.""" + + pool = await create_db_pool(dsn=dsn) + patch_result = await patch_user( + developer_id=developer_id, + user_id=user.id, + data=PatchUserRequest( + name="patched user", + about="patched user about", + metadata={"test": "metadata"}, + ), + connection_pool=pool, + ) + + assert patch_result is not None + assert isinstance(patch_result, ResourceUpdatedResponse) + assert patch_result.updated_at > user.created_at + + +@test("query: delete user sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, user=test_user): + """Test that a user can be successfully deleted.""" + + pool = await create_db_pool(dsn=dsn) + delete_result = await delete_user( + developer_id=developer_id, + user_id=user.id, + connection_pool=pool, + ) + + assert delete_result is not 
None + assert isinstance(delete_result, ResourceDeletedResponse) + + # Verify the user no longer exists + try: + await get_user( + developer_id=developer_id, + user_id=user.id, + connection_pool=pool, + ) + except Exception: + pass + else: + assert False, "Expected an exception to be raised when retrieving a deleted user." diff --git a/agents-api/tests/test_user_routes.py b/agents-api/tests/test_user_routes.py index 229d85619..b158bea00 100644 --- a/agents-api/tests/test_user_routes.py +++ b/agents-api/tests/test_user_routes.py @@ -1,6 +1,6 @@ # Tests for user routes -from uuid import uuid4 +from uuid_extensions import uuid7 from ward import test from tests.fixtures import client, make_request, test_user @@ -8,15 +8,15 @@ @test("route: unauthorized should fail") def _(client=client): - data = dict( - name="test user", - about="test user about", - ) + data = { + "name": "test user", + "about": "test user about", + } response = client.request( method="POST", url="/users", - data=data, + json=data, ) assert response.status_code == 403 @@ -24,10 +24,10 @@ def _(client=client): @test("route: create user") def _(make_request=make_request): - data = dict( - name="test user", - about="test user about", - ) + data = { + "name": "test user", + "about": "test user about", + } response = make_request( method="POST", @@ -40,7 +40,7 @@ def _(make_request=make_request): @test("route: get user not exists") def _(make_request=make_request): - user_id = str(uuid4()) + user_id = str(uuid7()) response = make_request( method="GET", @@ -64,10 +64,10 @@ def _(make_request=make_request, user=test_user): @test("route: delete user") def _(make_request=make_request): - data = dict( - name="test user", - about="test user about", - ) + data = { + "name": "test user", + "about": "test user about", + } response = make_request( method="POST", @@ -93,10 +93,10 @@ def _(make_request=make_request): @test("route: update user") def _(make_request=make_request, user=test_user): - data = dict( - 
name="updated user", - about="updated user about", - ) + data = { + "name": "updated user", + "about": "updated user about", + } user_id = str(user.id) response = make_request( @@ -121,14 +121,14 @@ def _(make_request=make_request, user=test_user): assert user["about"] == "updated user about" -@test("model: patch user") +@test("query: patch user") def _(make_request=make_request, user=test_user): user_id = str(user.id) - data = dict( - name="patched user", - about="patched user about", - ) + data = { + "name": "patched user", + "about": "patched user about", + } response = make_request( method="PATCH", @@ -152,7 +152,7 @@ def _(make_request=make_request, user=test_user): assert user["about"] == "patched user about" -@test("model: list users") +@test("query: list users") def _(make_request=make_request): response = make_request( method="GET", @@ -167,7 +167,7 @@ def _(make_request=make_request): assert len(users) > 0 -@test("model: list users with right metadata filter") +@test("query: list users with right metadata filter") def _(make_request=make_request, user=test_user): response = make_request( method="GET", diff --git a/agents-api/tests/test_workflow_routes.py b/agents-api/tests/test_workflow_routes.py index 2ffc73173..220bcc820 100644 --- a/agents-api/tests/test_workflow_routes.py +++ b/agents-api/tests/test_workflow_routes.py @@ -1,27 +1,28 @@ # Tests for task queries -from uuid import uuid4 - +from agents_api.clients.pg import create_db_pool +from uuid_extensions import uuid7 from ward import test -from tests.fixtures import cozo_client, test_agent, test_developer_id +from tests.fixtures import pg_dsn, test_agent, test_developer_id from tests.utils import patch_http_client_with_temporal @test("workflow route: evaluate step single") async def _( - cozo_client=cozo_client, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, ): + pool = await create_db_pool(dsn=dsn) agent_id = str(agent.id) - task_id = str(uuid4()) + task_id = str(uuid7()) async 
with patch_http_client_with_temporal( - cozo_client=cozo_client, developer_id=developer_id + postgres_pool=pool, developer_id=developer_id ) as ( make_request, - client, + _postgres_pool, ): task_data = { "name": "test task", @@ -36,7 +37,7 @@ async def _( json=task_data, ).raise_for_status() - execution_data = dict(input={"test": "input"}) + execution_data = {"input": {"test": "input"}} make_request( method="POST", @@ -47,17 +48,18 @@ async def _( @test("workflow route: evaluate step single with yaml") async def _( - cozo_client=cozo_client, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, ): + pool = await create_db_pool(dsn=dsn) agent_id = str(agent.id) async with patch_http_client_with_temporal( - cozo_client=cozo_client, developer_id=developer_id + postgres_pool=pool, developer_id=developer_id ) as ( make_request, - client, + _postgres_pool, ): task_data = """ name: test task @@ -84,7 +86,7 @@ async def _( task_id = result["id"] - execution_data = dict(input={"test": "input"}) + execution_data = {"input": {"test": "input"}} make_request( method="POST", @@ -95,18 +97,19 @@ async def _( @test("workflow route: create or update: evaluate step single with yaml") async def _( - cozo_client=cozo_client, + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, ): + pool = await create_db_pool(dsn=dsn) agent_id = str(agent.id) - task_id = str(uuid4()) + task_id = str(uuid7()) async with patch_http_client_with_temporal( - cozo_client=cozo_client, developer_id=developer_id + postgres_pool=pool, developer_id=developer_id ) as ( make_request, - client, + _postgres_pool, ): task_data = """ name: test task @@ -127,7 +130,7 @@ async def _( headers={"Content-Type": "text/yaml"}, ).raise_for_status() - execution_data = dict(input={"test": "input"}) + execution_data = {"input": {"test": "input"}} make_request( method="POST", diff --git a/agents-api/tests/utils.py b/agents-api/tests/utils.py index 130518419..2049b4689 100644 --- a/agents-api/tests/utils.py 
+++ b/agents-api/tests/utils.py @@ -1,17 +1,16 @@ import asyncio import logging +import subprocess from contextlib import asynccontextmanager, contextmanager -from dataclasses import dataclass -from typing import Any, Dict, Optional from unittest.mock import patch -from botocore import exceptions +from agents_api.worker.codec import pydantic_data_converter +from agents_api.worker.worker import create_worker from fastapi.testclient import TestClient from litellm.types.utils import ModelResponse from temporalio.testing import WorkflowEnvironment - -from agents_api.worker.codec import pydantic_data_converter -from agents_api.worker.worker import create_worker +from testcontainers.localstack import LocalStackContainer +from testcontainers.postgres import PostgresContainer # Replicated here to prevent circular import EMBEDDING_SIZE: int = 1024 @@ -31,7 +30,7 @@ async def patch_testing_temporal(): ) as env: # Create a worker with our workflows and start it worker = create_worker(client=env.client) - asyncio.create_task(worker.run()) + env.worker_task = asyncio.create_task(worker.run()) # Mock the Temporal client mock_client = worker.client @@ -50,13 +49,13 @@ async def patch_testing_temporal(): @asynccontextmanager -async def patch_http_client_with_temporal(*, cozo_client, developer_id): - async with patch_testing_temporal() as (worker, mock_get_client): +async def patch_http_client_with_temporal(*, postgres_pool, developer_id): + async with patch_testing_temporal() as (_worker, mock_get_client): from agents_api.env import api_key, api_key_header_name from agents_api.web import app client = TestClient(app=app) - app.state.cozo_client = cozo_client + app.state.postgres_pool = postgres_pool def make_request(method, url, **kwargs): headers = kwargs.pop("headers", {}) @@ -77,12 +76,12 @@ def patch_embed_acompletion(output={"role": "assistant", "content": "Hello, worl mock_model_response = ModelResponse( id="fake_id", choices=[ - dict( - message=output, - tool_calls=[], - 
created_at=1, + { + "message": output, + "tool_calls": [], + "created_at": 1, # finish_reason="stop", - ) + } ], created=0, object="text_completion", @@ -109,64 +108,20 @@ def patch_integration_service(output: dict = {"result": "ok"}): @contextmanager -def patch_s3_client(): - @dataclass - class AsyncBytesIO: - content: bytes - - async def read(self) -> bytes: - return self.content - - @dataclass - class InMemoryS3Client: - store: Optional[Dict[str, Dict[str, Any]]] = None - - def __post_init__(self): - self.store = {} - - def _get_object_or_raise(self, bucket: str, key: str, operation: str): - obj = self.store.get(bucket, {}).get(key) - if obj is None: - raise exceptions.ClientError( - {"Error": {"Code": "404", "Message": "Not Found"}}, operation - ) - return obj +def get_pg_dsn(): + with PostgresContainer("timescale/timescaledb-ha:pg17") as postgres: + test_psql_url = postgres.get_connection_url() + pg_dsn = f"postgres://{test_psql_url[22:]}?sslmode=disable" + command = f"migrate -database '{pg_dsn}' -path ../memory-store/migrations/ up" + process = subprocess.Popen(command, shell=True) + process.wait() - async def list_buckets(self): - return {"Buckets": [{"Name": bucket} for bucket in self.store]} + yield pg_dsn - async def create_bucket(self, Bucket): - self.store.setdefault(Bucket, {}) - async def head_object(self, Bucket, Key): - return self._get_object_or_raise(Bucket, Key, "HeadObject") - - async def put_object(self, Bucket, Key, Body): - self.store.setdefault(Bucket, {})[Key] = Body - - async def get_object(self, Bucket, Key): - obj = self._get_object_or_raise(Bucket, Key, "GetObject") - return {"Body": AsyncBytesIO(obj)} - - async def delete_object(self, Bucket, Key): - if Bucket in self.store: - self.store[Bucket].pop(Key, None) - - class MockSession: - s3_client = InMemoryS3Client() - - async def __aenter__(self): - return self.s3_client - - async def __aexit__(self, *_): - pass - - mock_session = type( - "MockSessionFactory", - (), - {"create_client": 
lambda self, service_name, **kwargs: MockSession()}, - )() - - with patch("agents_api.clients.async_s3.get_session") as get_session: - get_session.return_value = mock_session - yield mock_session +@contextmanager +def get_localstack(): + with LocalStackContainer(image="localstack/localstack:s3-latest").with_services( + "s3" + ) as localstack: + yield localstack diff --git a/agents-api/uv.lock b/agents-api/uv.lock index 9517c86f3..440b3bb6c 100644 --- a/agents-api/uv.lock +++ b/agents-api/uv.lock @@ -15,6 +15,7 @@ dependencies = [ { name = "anyio" }, { name = "arrow" }, { name = "async-lru" }, + { name = "asyncpg" }, { name = "beartype" }, { name = "en-core-web-sm" }, { name = "environs" }, @@ -36,8 +37,6 @@ dependencies = [ { name = "pandas" }, { name = "prometheus-client" }, { name = "prometheus-fastapi-instrumentator" }, - { name = "pycozo", extra = ["embedded"] }, - { name = "pycozo-async" }, { name = "pydantic", extra = ["email"] }, { name = "pydantic-partial" }, { name = "python-box" }, @@ -52,6 +51,8 @@ dependencies = [ { name = "tenacity" }, { name = "thefuzz" }, { name = "tiktoken" }, + { name = "unique-namer" }, + { name = "uuid7" }, { name = "uvicorn" }, { name = "uvloop" }, { name = "xxhash" }, @@ -59,17 +60,19 @@ dependencies = [ [package.dev-dependencies] dev = [ - { name = "cozo-migrate" }, { name = "datamodel-code-generator" }, { name = "ipython" }, { name = "ipywidgets" }, { name = "julep" }, { name = "jupyterlab" }, + { name = "pip" }, { name = "poethepoet" }, { name = "pyjwt" }, { name = "pyright" }, { name = "pytype" }, { name = "ruff" }, + { name = "sqlvalidator" }, + { name = "testcontainers", extra = ["localstack"] }, { name = "ward" }, ] @@ -80,6 +83,7 @@ requires-dist = [ { name = "anyio", specifier = "~=4.4.0" }, { name = "arrow", specifier = "~=1.3.0" }, { name = "async-lru", specifier = "~=2.0.4" }, + { name = "asyncpg", specifier = ">=0.30.0" }, { name = "beartype", specifier = "~=0.18.5" }, { name = "en-core-web-sm", url = 
"https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl" }, { name = "environs", specifier = "~=10.3.0" }, @@ -101,8 +105,6 @@ requires-dist = [ { name = "pandas", specifier = "~=2.2.2" }, { name = "prometheus-client", specifier = "~=0.21.0" }, { name = "prometheus-fastapi-instrumentator", specifier = "~=7.0.0" }, - { name = "pycozo", extras = ["embedded"], specifier = "~=0.7.6" }, - { name = "pycozo-async", specifier = "~=0.7.7" }, { name = "pydantic", extras = ["email"], specifier = "~=2.10.2" }, { name = "pydantic-partial", specifier = "~=0.5.5" }, { name = "python-box", specifier = "~=7.2.0" }, @@ -117,6 +119,8 @@ requires-dist = [ { name = "tenacity", specifier = "~=9.0.0" }, { name = "thefuzz", specifier = "~=0.22.1" }, { name = "tiktoken", specifier = "~=0.7.0" }, + { name = "unique-namer", specifier = ">=1.6.1" }, + { name = "uuid7", specifier = ">=0.1.0" }, { name = "uvicorn", specifier = "~=0.30.6" }, { name = "uvloop", specifier = "~=0.21.0" }, { name = "xxhash", specifier = "~=3.5.0" }, @@ -124,17 +128,19 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "cozo-migrate", specifier = ">=0.2.4" }, { name = "datamodel-code-generator", specifier = ">=0.26.3" }, { name = "ipython", specifier = ">=8.30.0" }, { name = "ipywidgets", specifier = ">=8.1.5" }, { name = "julep", specifier = ">=1.43.1" }, { name = "jupyterlab", specifier = ">=4.3.1" }, + { name = "pip", specifier = ">=24.3.1" }, { name = "poethepoet", specifier = ">=0.31.1" }, { name = "pyjwt", specifier = ">=2.10.1" }, - { name = "pyright", specifier = ">=1.1.389" }, + { name = "pyright", specifier = ">=1.1.391" }, { name = "pytype", specifier = ">=2024.10.11" }, - { name = "ruff", specifier = ">=0.8.1" }, + { name = "ruff", specifier = ">=0.8.4" }, + { name = "sqlvalidator", specifier = ">=0.0.20" }, + { name = "testcontainers", extras = ["postgres", "localstack"], specifier = ">=4.9.0" }, { name = "ward", 
specifier = ">=0.68.0b0" }, ] @@ -338,6 +344,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/9f/3c3503693386c4b0f245eaf5ca6198e3b28879ca0a40bde6b0e319793453/async_lru-2.0.4-py3-none-any.whl", hash = "sha256:ff02944ce3c288c5be660c42dbcca0742b32c3b279d6dceda655190240b99224", size = 6111 }, ] +[[package]] +name = "asyncpg" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2f/4c/7c991e080e106d854809030d8584e15b2e996e26f16aee6d757e387bc17d/asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851", size = 957746 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/64/9d3e887bb7b01535fdbc45fbd5f0a8447539833b97ee69ecdbb7a79d0cb4/asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e", size = 673162 }, + { url = "https://files.pythonhosted.org/packages/6e/eb/8b236663f06984f212a087b3e849731f917ab80f84450e943900e8ca4052/asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a", size = 637025 }, + { url = "https://files.pythonhosted.org/packages/cc/57/2dc240bb263d58786cfaa60920779af6e8d32da63ab9ffc09f8312bd7a14/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3", size = 3496243 }, + { url = "https://files.pythonhosted.org/packages/f4/40/0ae9d061d278b10713ea9021ef6b703ec44698fe32178715a501ac696c6b/asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737", size = 3575059 }, + { url = "https://files.pythonhosted.org/packages/c3/75/d6b895a35a2c6506952247640178e5f768eeb28b2e20299b6a6f1d743ba0/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a", size = 3473596 }, + { url = "https://files.pythonhosted.org/packages/c8/e7/3693392d3e168ab0aebb2d361431375bd22ffc7b4a586a0fc060d519fae7/asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af", size = 3641632 }, + { url = "https://files.pythonhosted.org/packages/32/ea/15670cea95745bba3f0352341db55f506a820b21c619ee66b7d12ea7867d/asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e", size = 560186 }, + { url = "https://files.pythonhosted.org/packages/7e/6b/fe1fad5cee79ca5f5c27aed7bd95baee529c1bf8a387435c8ba4fe53d5c1/asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305", size = 621064 }, +] + [[package]] name = "attrs" version = "24.2.0" @@ -427,6 +449,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/5d/81aa3ddf94626806eb898b6d481a90a5e82bf55b10087556464ac05c120b/blis-1.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:376188493f590c4310ca534b687ef96c21c8224eb1ef4a0420703eebe175d6fa", size = 6370847 }, ] +[[package]] +name = "boto3" +version = "1.35.36" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, + { name = "jmespath" }, + { name = "s3transfer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/9f/17536f9a1ab4c6ee454c782f27c9f0160558f70502fc55da62e456c47229/boto3-1.35.36.tar.gz", hash = "sha256:586524b623e4fbbebe28b604c6205eb12f263cc4746bccb011562d07e217a4cb", size = 110987 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/6b/8b126c2e1c07fae33185544ea974de67027afc905bd072feef9fbbd38d3d/boto3-1.35.36-py3-none-any.whl", hash = "sha256:33735b9449cd2ef176531ba2cb2265c904a91244440b0e161a17da9d24a1e6d1", size = 139143 }, +] + [[package]] name = "botocore" version = "1.35.36" @@ -510,7 +546,7 
@@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -584,37 +620,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/00/3106b1854b45bd0474ced037dfe6b73b90fe68a68968cef47c23de3d43d2/confection-0.1.5-py3-none-any.whl", hash = "sha256:e29d3c3f8eac06b3f77eb9dfb4bf2fc6bcc9622a98ca00a698e3d019c6430b14", size = 35451 }, ] -[[package]] -name = "cozo-embedded" -version = "0.7.6" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d3/17/e4a139cad601150303095532c51ab981b7b1ee9f6278188bedfe551c46e2/cozo_embedded-0.7.6-cp37-abi3-macosx_10_14_x86_64.whl", hash = "sha256:d146e76736beb5e14e0cf73dc8babefadfbbc358b325c94c64a51b6d5b0031e9", size = 9542067 }, - { url = "https://files.pythonhosted.org/packages/65/3b/92fe8c7c7b2b83974ae051c92697d92e860625326cfc06cb4c54222c2fc0/cozo_embedded-0.7.6-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:7341fa266369181bbc19ad9e68820b51900b0fe1c947318a3d860b570dca6e09", size = 8325766 }, - { url = "https://files.pythonhosted.org/packages/15/bf/19020af2645d8ea398e719bce8fcf7a91c341467aed9804c6d5f6ac878c2/cozo_embedded-0.7.6-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80de79554138628967d4fd2636fc0a0a8dcca1c0c3bb527e638f1ee6cb763d7d", size = 10515504 }, - { url = "https://files.pythonhosted.org/packages/db/a7/3c96a4077520ee3179b5eaeba350132a854b3aca34d1168f335bfcd0038d/cozo_embedded-0.7.6-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:7571f6521041c13b7e9ca8ab8809cf9c8eaad929726ed6190ffc25a5a3ab57a7", size = 11135792 }, - { url = "https://files.pythonhosted.org/packages/58/f7/5c6ec98d3983968df1d6709f1faa88a44b8c0fa7cd80994bc7f7d6b10293/cozo_embedded-0.7.6-cp37-abi3-win_amd64.whl", hash = "sha256:c945ab7b350d0b79d3e643b68ebc8343fc02d223a02ab929eb0fb8e4e0df3542", size = 9532612 }, -] - -[[package]] -name = "cozo-migrate" -version = "0.2.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama" }, - { name = "cozo-embedded" }, - { name = "pandas" }, - { name = "pycozo" }, - { name = "requests" }, - { name = "rich" }, - { name = "shellingham" }, - { name = "typer" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/3a/f66a88c50c5dd7bb7cb98d84f4d3e45bb2cfe1dba524f775f88b065b563b/cozo_migrate-0.2.4.tar.gz", hash = "sha256:ccb852f00bb25ff7c431dc8fa8a81e8f9f10198ad76aa34d1239d67f1613b899", size = 14317 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/26/ce/2dc5dc2be88ab79ed24b1412b7745c690e7f684e1665eb4feeb6300056bd/cozo_migrate-0.2.4-py3-none-any.whl", hash = "sha256:518151d65c81968e42402470418f42c8580e972f0b949df6c5c499cc2b098c1b", size = 21466 }, -] - [[package]] name = "cucumber-tag-expressions" version = "4.1.0" @@ -715,6 +720,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 }, ] +[[package]] +name = "docker" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "requests" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = 
"sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774 }, +] + [[package]] name = "email-validator" version = "2.2.0" @@ -993,7 +1012,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin'" }, + { name = "appnope", marker = "sys_platform == 'darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -2014,6 +2033,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 }, ] +[[package]] +name = "pip" +version = "24.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f4/b1/b422acd212ad7eedddaf7981eee6e5de085154ff726459cf2da7c5a184c1/pip-24.3.1.tar.gz", hash = "sha256:ebcb60557f2aefabc2e0f918751cd24ea0d56d8ec5445fe1807f1d2109660b99", size = 1931073 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/7d/500c9ad20238fcfcb4cb9243eede163594d7020ce87bd9610c9e02771876/pip-24.3.1-py3-none-any.whl", hash = "sha256:3790624780082365f47549d032f3770eeb2b1e8bd1f7b2e02dace1afa361b4ed", size = 1822182 }, +] + [[package]] name = "platformdirs" version = "4.3.6" @@ -2186,35 +2214,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/af/7ba371f966657f6e7b1c9876cae7e9f1c5d3635c3df1329636b99e615494/pycnite-2024.7.31-py3-none-any.whl", hash = "sha256:9ff9c09d35056435b867e14ebf79626ca94b6017923a0bf9935377fa90d4cbb3", size = 22939 }, ] -[[package]] -name = "pycozo" -version = 
"0.7.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/92/18/dc0dd2db0f1661e2cf17a653da59b6812f30ddc976a66b7972fd5d2809bc/pycozo-0.7.6.tar.gz", hash = "sha256:e4be9a091ba71e9d4465179bbf7557d47af84c8114d4889bd5fa13c731d57a95", size = 19091 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a1/e9/47ccff69e94bc80388c67e12b3c25244198fcfb1d3fad96489ed436a8e3f/pycozo-0.7.6-py3-none-any.whl", hash = "sha256:8930de5f82277d6481998a585c79aa898991cfb0692e168bde8b0a4558d579cf", size = 18977 }, -] - -[package.optional-dependencies] -embedded = [ - { name = "cozo-embedded" }, -] - -[[package]] -name = "pycozo-async" -version = "0.7.7" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cozo-embedded" }, - { name = "httpx" }, - { name = "ipython" }, - { name = "pandas" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/01/17/2fc41dd8311f366625fc6fb70fe2dc27c345da8db0a4de78f39ccf759977/pycozo_async-0.7.7.tar.gz", hash = "sha256:fae95d8e9e11448263a752983b12a5a05b7656fa1dda0eeeb6f213d6fc592e1d", size = 21559 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/22/64/63330e6bd9bc30abfc863bd392c20c81f8ad1d6b5d1b6511d477496a6fbe/pycozo_async-0.7.7-py3-none-any.whl", hash = "sha256:2c23b184f6295d4dc6178350425110467e512638b3f4def937ed0609df321dd1", size = 22714 }, -] - [[package]] name = "pycparser" version = "2.22" @@ -2321,15 +2320,15 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.389" +version = "1.1.391" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/72/4e/9a5ab8745e7606b88c2c7ca223449ac9d82a71fd5e31df47b453f2cb39a1/pyright-1.1.389.tar.gz", hash = "sha256:716bf8cc174ab8b4dcf6828c3298cac05c5ed775dda9910106a5dcfe4c7fe220", size = 21940 } +sdist = { url = 
"https://files.pythonhosted.org/packages/11/05/4ea52a8a45cc28897edb485b4102d37cbfd5fce8445d679cdeb62bfad221/pyright-1.1.391.tar.gz", hash = "sha256:66b2d42cdf5c3cbab05f2f4b76e8bec8aa78e679bfa0b6ad7b923d9e027cadb2", size = 21965 } wheels = [ - { url = "https://files.pythonhosted.org/packages/1b/26/c288cabf8cfc5a27e1aa9e5029b7682c0f920b8074f45d22bf844314d66a/pyright-1.1.389-py3-none-any.whl", hash = "sha256:41e9620bba9254406dc1f621a88ceab5a88af4c826feb4f614d95691ed243a60", size = 18581 }, + { url = "https://files.pythonhosted.org/packages/ad/89/66f49552fbeb21944c8077d11834b2201514a56fd1b7747ffff9630f1bd9/pyright-1.1.391-py3-none-any.whl", hash = "sha256:54fa186f8b3e8a55a44ebfa842636635688670c6896dcf6cf4a7fc75062f4d15", size = 18579 }, ] [[package]] @@ -2633,33 +2632,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/a9/d39f3c5ada0a3bb2870d7db41901125dbe2434fa4f12ca8c5b83a42d7c53/ruamel.yaml.clib-0.2.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:749c16fcc4a2b09f28843cda5a193e0283e47454b63ec4b81eaa2242f50e4ccd", size = 706497 }, { url = "https://files.pythonhosted.org/packages/b0/fa/097e38135dadd9ac25aecf2a54be17ddf6e4c23e43d538492a90ab3d71c6/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bf165fef1f223beae7333275156ab2022cffe255dcc51c27f066b4370da81e31", size = 698042 }, { url = "https://files.pythonhosted.org/packages/ec/d5/a659ca6f503b9379b930f13bc6b130c9f176469b73b9834296822a83a132/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:32621c177bbf782ca5a18ba4d7af0f1082a3f6e517ac2a18b3974d4edf349680", size = 745831 }, + { url = "https://files.pythonhosted.org/packages/db/5d/36619b61ffa2429eeaefaab4f3374666adf36ad8ac6330d855848d7d36fd/ruamel.yaml.clib-0.2.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b82a7c94a498853aa0b272fd5bc67f29008da798d4f93a2f9f289feb8426a58d", size = 715692 }, { url = 
"https://files.pythonhosted.org/packages/b1/82/85cb92f15a4231c89b95dfe08b09eb6adca929ef7df7e17ab59902b6f589/ruamel.yaml.clib-0.2.12-cp312-cp312-win32.whl", hash = "sha256:e8c4ebfcfd57177b572e2040777b8abc537cdef58a2120e830124946aa9b42c5", size = 98777 }, { url = "https://files.pythonhosted.org/packages/d7/8f/c3654f6f1ddb75daf3922c3d8fc6005b1ab56671ad56ffb874d908bfa668/ruamel.yaml.clib-0.2.12-cp312-cp312-win_amd64.whl", hash = "sha256:0467c5965282c62203273b838ae77c0d29d7638c8a4e3a1c8bdd3602c10904e4", size = 115523 }, ] [[package]] name = "ruff" -version = "0.8.1" +version = "0.8.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/95/d0/8ff5b189d125f4260f2255d143bf2fa413b69c2610c405ace7a0a8ec81ec/ruff-0.8.1.tar.gz", hash = "sha256:3583db9a6450364ed5ca3f3b4225958b24f78178908d5c4bc0f46251ccca898f", size = 3313222 } +sdist = { url = "https://files.pythonhosted.org/packages/34/37/9c02181ef38d55b77d97c68b78e705fd14c0de0e5d085202bb2b52ce5be9/ruff-0.8.4.tar.gz", hash = "sha256:0d5f89f254836799af1615798caa5f80b7f935d7a670fad66c5007928e57ace8", size = 3402103 } wheels = [ - { url = "https://files.pythonhosted.org/packages/a2/d6/1a6314e568db88acdbb5121ed53e2c52cebf3720d3437a76f82f923bf171/ruff-0.8.1-py3-none-linux_armv6l.whl", hash = "sha256:fae0805bd514066f20309f6742f6ee7904a773eb9e6c17c45d6b1600ca65c9b5", size = 10532605 }, - { url = "https://files.pythonhosted.org/packages/89/a8/a957a8812e31facffb6a26a30be0b5b4af000a6e30c7d43a22a5232a3398/ruff-0.8.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8a4f7385c2285c30f34b200ca5511fcc865f17578383db154e098150ce0a087", size = 10278243 }, - { url = "https://files.pythonhosted.org/packages/a8/23/9db40fa19c453fabf94f7a35c61c58f20e8200b4734a20839515a19da790/ruff-0.8.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:cd054486da0c53e41e0086e1730eb77d1f698154f910e0cd9e0d64274979a209", size = 9917739 }, - { url = 
"https://files.pythonhosted.org/packages/e2/a0/6ee2d949835d5701d832fc5acd05c0bfdad5e89cfdd074a171411f5ccad5/ruff-0.8.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2029b8c22da147c50ae577e621a5bfbc5d1fed75d86af53643d7a7aee1d23871", size = 10779153 }, - { url = "https://files.pythonhosted.org/packages/7a/25/9c11dca9404ef1eb24833f780146236131a3c7941de394bc356912ef1041/ruff-0.8.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2666520828dee7dfc7e47ee4ea0d928f40de72056d929a7c5292d95071d881d1", size = 10304387 }, - { url = "https://files.pythonhosted.org/packages/c8/b9/84c323780db1b06feae603a707d82dbbd85955c8c917738571c65d7d5aff/ruff-0.8.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:333c57013ef8c97a53892aa56042831c372e0bb1785ab7026187b7abd0135ad5", size = 11360351 }, - { url = "https://files.pythonhosted.org/packages/6b/e1/9d4bbb2ace7aad14ded20e4674a48cda5b902aed7a1b14e6b028067060c4/ruff-0.8.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:288326162804f34088ac007139488dcb43de590a5ccfec3166396530b58fb89d", size = 12022879 }, - { url = "https://files.pythonhosted.org/packages/75/28/752ff6120c0e7f9981bc4bc275d540c7f36db1379ba9db9142f69c88db21/ruff-0.8.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b12c39b9448632284561cbf4191aa1b005882acbc81900ffa9f9f471c8ff7e26", size = 11610354 }, - { url = "https://files.pythonhosted.org/packages/ba/8c/967b61c2cc8ebd1df877607fbe462bc1e1220b4a30ae3352648aec8c24bd/ruff-0.8.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:364e6674450cbac8e998f7b30639040c99d81dfb5bbc6dfad69bc7a8f916b3d1", size = 12813976 }, - { url = "https://files.pythonhosted.org/packages/7f/29/e059f945d6bd2d90213387b8c360187f2fefc989ddcee6bbf3c241329b92/ruff-0.8.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b22346f845fec132aa39cd29acb94451d030c10874408dbf776af3aaeb53284c", size = 11154564 
}, - { url = "https://files.pythonhosted.org/packages/55/47/cbd05e5a62f3fb4c072bc65c1e8fd709924cad1c7ec60a1000d1e4ee8307/ruff-0.8.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b2f2f7a7e7648a2bfe6ead4e0a16745db956da0e3a231ad443d2a66a105c04fa", size = 10760604 }, - { url = "https://files.pythonhosted.org/packages/bb/ee/4c3981c47147c72647a198a94202633130cfda0fc95cd863a553b6f65c6a/ruff-0.8.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:adf314fc458374c25c5c4a4a9270c3e8a6a807b1bec018cfa2813d6546215540", size = 10391071 }, - { url = "https://files.pythonhosted.org/packages/6b/e6/083eb61300214590b188616a8ac6ae1ef5730a0974240fb4bec9c17de78b/ruff-0.8.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a885d68342a231b5ba4d30b8c6e1b1ee3a65cf37e3d29b3c74069cdf1ee1e3c9", size = 10896657 }, - { url = "https://files.pythonhosted.org/packages/77/bd/aacdb8285d10f1b943dbeb818968efca35459afc29f66ae3bd4596fbf954/ruff-0.8.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d2c16e3508c8cc73e96aa5127d0df8913d2290098f776416a4b157657bee44c5", size = 11228362 }, - { url = "https://files.pythonhosted.org/packages/39/72/fcb7ad41947f38b4eaa702aca0a361af0e9c2bf671d7fd964480670c297e/ruff-0.8.1-py3-none-win32.whl", hash = "sha256:93335cd7c0eaedb44882d75a7acb7df4b77cd7cd0d2255c93b28791716e81790", size = 8803476 }, - { url = "https://files.pythonhosted.org/packages/e4/ea/cae9aeb0f4822c44651c8407baacdb2e5b4dcd7b31a84e1c5df33aa2cc20/ruff-0.8.1-py3-none-win_amd64.whl", hash = "sha256:2954cdbe8dfd8ab359d4a30cd971b589d335a44d444b6ca2cb3d1da21b75e4b6", size = 9614463 }, - { url = "https://files.pythonhosted.org/packages/eb/76/fbb4bd23dfb48fa7758d35b744413b650a9fd2ddd93bca77e30376864414/ruff-0.8.1-py3-none-win_arm64.whl", hash = "sha256:55873cc1a473e5ac129d15eccb3c008c096b94809d693fc7053f588b67822737", size = 8959621 }, + { url = "https://files.pythonhosted.org/packages/05/67/f480bf2f2723b2e49af38ed2be75ccdb2798fca7d56279b585c8f553aaab/ruff-0.8.4-py3-none-linux_armv6l.whl", hash = 
"sha256:58072f0c06080276804c6a4e21a9045a706584a958e644353603d36ca1eb8a60", size = 10546415 }, + { url = "https://files.pythonhosted.org/packages/eb/7a/5aba20312c73f1ce61814e520d1920edf68ca3b9c507bd84d8546a8ecaa8/ruff-0.8.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ffb60904651c00a1e0b8df594591770018a0f04587f7deeb3838344fe3adabac", size = 10346113 }, + { url = "https://files.pythonhosted.org/packages/76/f4/c41de22b3728486f0aa95383a44c42657b2db4062f3234ca36fc8cf52d8b/ruff-0.8.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6ddf5d654ac0d44389f6bf05cee4caeefc3132a64b58ea46738111d687352296", size = 9943564 }, + { url = "https://files.pythonhosted.org/packages/0e/f0/afa0d2191af495ac82d4cbbfd7a94e3df6f62a04ca412033e073b871fc6d/ruff-0.8.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e248b1f0fa2749edd3350a2a342b67b43a2627434c059a063418e3d375cfe643", size = 10805522 }, + { url = "https://files.pythonhosted.org/packages/12/57/5d1e9a0fd0c228e663894e8e3a8e7063e5ee90f8e8e60cf2085f362bfa1a/ruff-0.8.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bf197b98ed86e417412ee3b6c893f44c8864f816451441483253d5ff22c0e81e", size = 10306763 }, + { url = "https://files.pythonhosted.org/packages/04/df/f069fdb02e408be8aac6853583572a2873f87f866fe8515de65873caf6b8/ruff-0.8.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c41319b85faa3aadd4d30cb1cffdd9ac6b89704ff79f7664b853785b48eccdf3", size = 11359574 }, + { url = "https://files.pythonhosted.org/packages/d3/04/37c27494cd02e4a8315680debfc6dfabcb97e597c07cce0044db1f9dfbe2/ruff-0.8.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:9f8402b7c4f96463f135e936d9ab77b65711fcd5d72e5d67597b543bbb43cf3f", size = 12094851 }, + { url = "https://files.pythonhosted.org/packages/81/b1/c5d7fb68506cab9832d208d03ea4668da9a9887a4a392f4f328b1bf734ad/ruff-0.8.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:e4e56b3baa9c23d324ead112a4fdf20db9a3f8f29eeabff1355114dd96014604", size = 11655539 }, + { url = "https://files.pythonhosted.org/packages/ef/38/8f8f2c8898dc8a7a49bc340cf6f00226917f0f5cb489e37075bcb2ce3671/ruff-0.8.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:736272574e97157f7edbbb43b1d046125fce9e7d8d583d5d65d0c9bf2c15addf", size = 12912805 }, + { url = "https://files.pythonhosted.org/packages/06/dd/fa6660c279f4eb320788876d0cff4ea18d9af7d9ed7216d7bd66877468d0/ruff-0.8.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5fe710ab6061592521f902fca7ebcb9fabd27bc7c57c764298b1c1f15fff720", size = 11205976 }, + { url = "https://files.pythonhosted.org/packages/a8/d7/de94cc89833b5de455750686c17c9e10f4e1ab7ccdc5521b8fe911d1477e/ruff-0.8.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:13e9ec6d6b55f6da412d59953d65d66e760d583dd3c1c72bf1f26435b5bfdbae", size = 10792039 }, + { url = "https://files.pythonhosted.org/packages/6d/15/3e4906559248bdbb74854af684314608297a05b996062c9d72e0ef7c7097/ruff-0.8.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:97d9aefef725348ad77d6db98b726cfdb075a40b936c7984088804dfd38268a7", size = 10400088 }, + { url = "https://files.pythonhosted.org/packages/a2/21/9ed4c0e8133cb4a87a18d470f534ad1a8a66d7bec493bcb8bda2d1a5d5be/ruff-0.8.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ab78e33325a6f5374e04c2ab924a3367d69a0da36f8c9cb6b894a62017506111", size = 10900814 }, + { url = "https://files.pythonhosted.org/packages/0d/5d/122a65a18955bd9da2616b69bc839351f8baf23b2805b543aa2f0aed72b5/ruff-0.8.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:8ef06f66f4a05c3ddbc9121a8b0cecccd92c5bf3dd43b5472ffe40b8ca10f0f8", size = 11268828 }, + { url = "https://files.pythonhosted.org/packages/43/a9/1676ee9106995381e3d34bccac5bb28df70194167337ed4854c20f27c7ba/ruff-0.8.4-py3-none-win32.whl", hash = "sha256:552fb6d861320958ca5e15f28b20a3d071aa83b93caee33a87b471f99a6c0835", size = 8805621 }, + { url = 
"https://files.pythonhosted.org/packages/10/98/ed6b56a30ee76771c193ff7ceeaf1d2acc98d33a1a27b8479cbdb5c17a23/ruff-0.8.4-py3-none-win_amd64.whl", hash = "sha256:f21a1143776f8656d7f364bd264a9d60f01b7f52243fbe90e7670c0dfe0cf65d", size = 9660086 }, + { url = "https://files.pythonhosted.org/packages/13/9f/026e18ca7d7766783d779dae5e9c656746c6ede36ef73c6d934aaf4a6dec/ruff-0.8.4-py3-none-win_arm64.whl", hash = "sha256:9183dd615d8df50defa8b1d9a074053891ba39025cf5ae88e8bcb52edcc4bf08", size = 9074500 }, +] + +[[package]] +name = "s3transfer" +version = "0.10.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/0a/1cdbabf9edd0ea7747efdf6c9ab4e7061b085aa7f9bfc36bb1601563b069/s3transfer-0.10.4.tar.gz", hash = "sha256:29edc09801743c21eb5ecbc617a152df41d3c287f67b615f73e5f750583666a7", size = 145287 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/66/05/7957af15543b8c9799209506df4660cba7afc4cf94bfb60513827e96bed6/s3transfer-0.10.4-py3-none-any.whl", hash = "sha256:244a76a24355363a68164241438de1b72f8781664920260c48465896b712a41e", size = 83175 }, ] [[package]] @@ -2853,6 +2865,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/78/d1a1a026ef3af911159398c939b1509d5c36fe524c7b644f34a5146c4e16/spacy_loggers-1.0.5-py3-none-any.whl", hash = "sha256:196284c9c446cc0cdb944005384270d775fdeaf4f494d8e269466cfa497ef645", size = 22343 }, ] +[[package]] +name = "sqlvalidator" +version = "0.0.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/21/7f/bd1ba351693e60b4dcddd3a84dad89ea75cbc627f9631da17809761a3eb4/sqlvalidator-0.0.20.tar.gz", hash = "sha256:6f399be1bf0ba54a17ad16f6818836c169d17c16306f4cfa6fc883f13b1705fc", size = 24291 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/9d/5434c2b90dac2a8ab12d42027398e2012d1ce347a0bcc9500525d05ac1ee/sqlvalidator-0.0.20-py3-none-any.whl", 
hash = "sha256:8820752d9ec5ccb9cc977099edf991f0090acf4f1e4beb0f2fb35a6e1cc03c89", size = 24182 }, +] + [[package]] name = "srsly" version = "2.4.8" @@ -2974,6 +2995,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/9e/2064975477fdc887e47ad42157e214526dcad8f317a948dee17e1659a62f/terminado-0.18.1-py3-none-any.whl", hash = "sha256:a4468e1b37bb318f8a86514f65814e1afc977cf29b3992a4500d9dd305dcceb0", size = 14154 }, ] +[[package]] +name = "testcontainers" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docker" }, + { name = "python-dotenv" }, + { name = "typing-extensions" }, + { name = "urllib3" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2e/9a/e1ac5231231192b39302fcad7de2c0dbfc718c0636d7e28917c30ec57c41/testcontainers-4.9.0.tar.gz", hash = "sha256:2cd6af070109ff68c1ab5389dc89c86c2dc3ab30a21ca734b2cb8f0f80ad479e", size = 64612 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3e/f8/6425ff800894784160290bcb9737878d910b6da6a08633bfe7f2ed8c9ae3/testcontainers-4.9.0-py3-none-any.whl", hash = "sha256:c6fee929990972c40bf6b91b7072c94064ff3649b405a14fde0274c8b2479d32", size = 105324 }, +] + +[package.optional-dependencies] +localstack = [ + { name = "boto3" }, +] + [[package]] name = "thefuzz" version = "0.22.1" @@ -3120,7 +3162,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [ @@ -3187,6 +3229,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/a6/ab/7e5f53c3b9d14972843a647d8d7a853969a58aecc7559cb3267302c94774/tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd", size = 346586 }, ] +[[package]] +name = "unique-namer" +version = "1.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/04/47/26e9f45b64ad2d7c77eefb48a0e84ae0c0070fa812bf6ab95584559ce53c/unique_namer-1.6.1.tar.gz", hash = "sha256:7f4e3143f923c24baaed56bb93726e10669333271caa71ffd5d8f1a928a5befe", size = 73334 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/72/e06078006bbc3635490b872e8647294cf5921f378634de43520012b7c09e/unique_namer-1.6.1-py3-none-any.whl", hash = "sha256:6e76751c0886244625b43a8e5e7c18168a9205f5a944c0dbbbd9eb219c4812f2", size = 71111 }, +] + [[package]] name = "uri-template" version = "1.3.0" @@ -3205,6 +3256,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ce/d9/5f4c13cecde62396b0d3fe530a50ccea91e7dfc1ccf0e09c228841bb5ba8/urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac", size = 126338 }, ] +[[package]] +name = "uuid7" +version = "0.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/19/7472bd526591e2192926247109dbf78692e709d3e56775792fec877a7720/uuid7-0.1.0.tar.gz", hash = "sha256:8c57aa32ee7456d3cc68c95c4530bc571646defac01895cfc73545449894a63c", size = 14052 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/77/8852f89a91453956582a85024d80ad96f30a41fed4c2b3dce0c9f12ecc7e/uuid7-0.1.0-py2.py3-none-any.whl", hash = "sha256:5e259bb63c8cb4aded5927ff41b444a80d0c7124e8a0ced7cf44efa1f5cccf61", size = 7477 }, +] + [[package]] name = "uvicorn" version = "0.30.6" diff --git a/blob-store/docker-compose.yml b/blob-store/docker-compose.yml index 089b31f39..64d238df4 100644 --- 
a/blob-store/docker-compose.yml +++ b/blob-store/docker-compose.yml @@ -12,7 +12,7 @@ services: environment: - S3_ACCESS_KEY=${S3_ACCESS_KEY} - S3_SECRET_KEY=${S3_SECRET_KEY} - - DEBUG=${DEBUG:-true} + - DEBUG=${DEBUG:-false} ports: - 9333:9333 # master port diff --git a/deploy/simple-docker-compose.yaml b/deploy/simple-docker-compose.yaml index 0b21af407..c87e78174 100644 --- a/deploy/simple-docker-compose.yaml +++ b/deploy/simple-docker-compose.yaml @@ -13,8 +13,6 @@ services: AGENTS_API_PROTOCOL: http AGENTS_API_PUBLIC_PORT: "80" AGENTS_API_URL: http://agents-api:8080 - COZO_AUTH_TOKEN: ${COZO_AUTH_TOKEN} - COZO_HOST: http://memory-store:9070 EMBEDDING_MODEL_ID: voyage/voyage-3 INTEGRATION_SERVICE_URL: http://integrations:8000 LITELLM_MASTER_KEY: ${LITELLM_MASTER_KEY} @@ -35,32 +33,6 @@ services: published: "8080" protocol: tcp - cozo-migrate: - environment: - AGENTS_API_HOSTNAME: localhost - AGENTS_API_KEY: ${AGENTS_API_KEY} - AGENTS_API_KEY_HEADER_NAME: Authorization - AGENTS_API_PROTOCOL: http - AGENTS_API_PUBLIC_PORT: "80" - AGENTS_API_URL: http://agents-api:8080 - COZO_AUTH_TOKEN: ${COZO_AUTH_TOKEN} - COZO_HOST: http://memory-store:9070 - EMBEDDING_MODEL_ID: voyage/voyage-3 - INTEGRATION_SERVICE_URL: http://integrations:8000 - LITELLM_MASTER_KEY: ${LITELLM_MASTER_KEY} - LITELLM_URL: http://litellm:4000 - SUMMARIZATION_MODEL_NAME: gpt-4o-mini - TEMPORAL_ENDPOINT: temporal:7233 - TEMPORAL_NAMESPACE: default - TEMPORAL_TASK_QUEUE: julep-task-queue - TEMPORAL_WORKER_URL: temporal:7233 - TRUNCATE_EMBED_TEXT: "True" - WORKER_URL: temporal:7233 - image: julepai/cozo-migrate:${TAG:-dev} - networks: - default: null - restart: "no" - integrations: image: julepai/integrations:${TAG:-dev} environment: @@ -156,56 +128,11 @@ services: target: /data volume: {} - memory-store: - environment: - COZO_AUTH_TOKEN: ${COZO_AUTH_TOKEN} - COZO_BACKUP_DIR: /backup - COZO_MNT_DIR: /data - COZO_PORT: "9070" - image: julepai/memory-store:${TAG:-dev} - labels: - ofelia.enabled: "true" 
- ofelia.job-exec.backupcron.command: bash /app/backup.sh - ofelia.job-exec.backupcron.environment: '["COZO_PORT=9070", "COZO_AUTH_TOKEN=${COZO_AUTH_TOKEN}", "COZO_BACKUP_DIR=/backup"]' - ofelia.job-exec.backupcron.schedule: '@every 3h' - networks: - default: null - ports: - - mode: ingress - target: 9070 - published: "9070" - protocol: tcp - volumes: - - type: volume - source: cozo_data - target: /data - volume: {} - - type: volume - source: cozo_backup - target: /backup - volume: {} + # TODO: Add memory-store with postgres + # memory-store: - memory-store-backup-cron: - command: - - daemon - - --docker - - -f - - label=com.docker.compose.project=julep - depends_on: - memory-store: - condition: service_started - required: true - image: mcuadros/ofelia:latest - networks: - default: null - restart: unless-stopped - volumes: - - type: bind - source: /var/run/docker.sock - target: /var/run/docker.sock - read_only: true - bind: - create_host_path: true + # TODO: Add memory-store-backup-cron + # memory-store-backup-cron: temporal: depends_on: @@ -295,8 +222,7 @@ services: AGENTS_API_PROTOCOL: http AGENTS_API_PUBLIC_PORT: "80" AGENTS_API_URL: http://agents-api:8080 - COZO_AUTH_TOKEN: ${COZO_AUTH_TOKEN} - COZO_HOST: http://memory-store:9070 + PG_DSN: ${PG_DSN:-postgres://postgres:postgres@memory-store:5432/postgres} EMBEDDING_MODEL_ID: voyage/voyage-3 INTEGRATION_SERVICE_URL: http://integrations:8000 LITELLM_MASTER_KEY: ${LITELLM_MASTER_KEY} @@ -317,10 +243,6 @@ networks: name: julep_default volumes: - cozo_backup: - name: cozo_backup - cozo_data: - name: cozo_data litellm-db-data: name: julep_litellm-db-data litellm-redis-data: diff --git a/drafts/cozo b/drafts/cozo new file mode 160000 index 000000000..faf89ef77 --- /dev/null +++ b/drafts/cozo @@ -0,0 +1 @@ +Subproject commit faf89ef77e6462460f873e9de618001d968a1a40 diff --git a/embedding-service/docker-compose.yml b/embedding-service/docker-compose.yml index 73df579be..a51a93e7f 100644 --- 
a/embedding-service/docker-compose.yml +++ b/embedding-service/docker-compose.yml @@ -17,8 +17,7 @@ x--shared-environment: &shared-environment AGENTS_API_KEY_HEADER_NAME: ${AGENTS_API_KEY_HEADER_NAME:-Authorization} AGENTS_API_HOSTNAME: ${AGENTS_API_HOSTNAME:-localhost} AGENTS_API_URL: ${AGENTS_API_URL:-http://agents-api:8080} - COZO_AUTH_TOKEN: ${COZO_AUTH_TOKEN} - COZO_HOST: ${COZO_HOST:-http://memory-store:9070} + PG_DSN: ${PG_DSN:-postgres://postgres:postgres@memory-store:5432/postgres} DEBUG: ${AGENTS_API_DEBUG:-False} EMBEDDING_MODEL_ID: ${EMBEDDING_MODEL_ID:-Alibaba-NLP/gte-large-en-v1.5} LITELLM_MASTER_KEY: ${LITELLM_MASTER_KEY} diff --git a/integrations-service/gunicorn_conf.py b/integrations-service/gunicorn_conf.py index e7fad22a5..77b9d3009 100644 --- a/integrations-service/gunicorn_conf.py +++ b/integrations-service/gunicorn_conf.py @@ -7,9 +7,7 @@ # Gunicorn config variables workers = ( - (multiprocessing.cpu_count() // 2) - if not (TESTING or AGENTS_API_DEBUG or DEBUG) - else 1 + (multiprocessing.cpu_count() // 2) if not (TESTING or AGENTS_API_DEBUG or DEBUG) else 1 ) worker_class = "uvicorn.workers.UvicornWorker" bind = "0.0.0.0:8000" diff --git a/integrations-service/integrations/autogen/Agents.py b/integrations-service/integrations/autogen/Agents.py index 5dab2c7b2..7390b6338 100644 --- a/integrations-service/integrations/autogen/Agents.py +++ b/integrations-service/integrations/autogen/Agents.py @@ -25,16 +25,17 @@ class Agent(BaseModel): """ When this resource was updated as UTC date-time """ - name: Annotated[ - str, - Field( - max_length=120, - pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$", - ), - ] = "" + name: Annotated[str, Field(max_length=255, min_length=1)] """ Name of the agent """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + Canonical 
name of the agent + """ about: str = "" """ About the agent @@ -62,16 +63,17 @@ class CreateAgentRequest(BaseModel): populate_by_name=True, ) metadata: dict[str, Any] | None = None - name: Annotated[ - str, - Field( - max_length=120, - pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$", - ), - ] = "" + name: Annotated[str, Field(max_length=255, min_length=1)] """ Name of the agent """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + Canonical name of the agent + """ about: str = "" """ About the agent @@ -96,16 +98,17 @@ class CreateOrUpdateAgentRequest(CreateAgentRequest): ) id: UUID metadata: dict[str, Any] | None = None - name: Annotated[ - str, - Field( - max_length=120, - pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$", - ), - ] = "" + name: Annotated[str, Field(max_length=255, min_length=1)] """ Name of the agent """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + Canonical name of the agent + """ about: str = "" """ About the agent @@ -133,16 +136,17 @@ class PatchAgentRequest(BaseModel): populate_by_name=True, ) metadata: dict[str, Any] | None = None - name: Annotated[ - str, - Field( - max_length=120, - pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$", - ), - ] = "" + name: Annotated[str | None, Field(max_length=255, min_length=1)] = None """ Name of the agent """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + Canonical name of the agent + """ about: str = "" """ About the 
agent @@ -170,16 +174,17 @@ class UpdateAgentRequest(BaseModel): populate_by_name=True, ) metadata: dict[str, Any] | None = None - name: Annotated[ - str, - Field( - max_length=120, - pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$", - ), - ] = "" + name: Annotated[str, Field(max_length=255, min_length=1)] """ Name of the agent """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + Canonical name of the agent + """ about: str = "" """ About the agent diff --git a/integrations-service/integrations/autogen/Chat.py b/integrations-service/integrations/autogen/Chat.py index 042f9164d..13dcc9532 100644 --- a/integrations-service/integrations/autogen/Chat.py +++ b/integrations-service/integrations/autogen/Chat.py @@ -59,9 +59,7 @@ class BaseChatResponse(BaseModel): """ Background job IDs that may have been spawned from this interaction. """ - docs: Annotated[ - list[DocReference], Field(json_schema_extra={"readOnly": True}) - ] = [] + docs: Annotated[list[DocReference], Field(json_schema_extra={"readOnly": True})] = [] """ Documents referenced for this request (for citation purposes). 
""" @@ -134,21 +132,15 @@ class CompetionUsage(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - completion_tokens: Annotated[ - int | None, Field(json_schema_extra={"readOnly": True}) - ] = None + completion_tokens: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None """ Number of tokens in the generated completion """ - prompt_tokens: Annotated[ - int | None, Field(json_schema_extra={"readOnly": True}) - ] = None + prompt_tokens: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None """ Number of tokens in the prompt """ - total_tokens: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = ( - None - ) + total_tokens: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None """ Total number of tokens used in the request (prompt + completion) """ @@ -429,9 +421,9 @@ class MessageModel(BaseModel): """ Tool calls generated by the model. """ - created_at: Annotated[ - AwareDatetime | None, Field(json_schema_extra={"readOnly": True}) - ] = None + created_at: Annotated[AwareDatetime | None, Field(json_schema_extra={"readOnly": True})] = ( + None + ) """ When this resource was created as UTC date-time """ @@ -576,9 +568,9 @@ class ChatInput(ChatInputData): """ Modify the likelihood of specified tokens appearing in the completion """ - response_format: ( - SimpleCompletionResponseFormat | SchemaCompletionResponseFormat | None - ) = None + response_format: SimpleCompletionResponseFormat | SchemaCompletionResponseFormat | None = ( + None + ) """ Response format (set to `json_object` to restrict output to JSON) """ @@ -672,9 +664,9 @@ class ChatSettings(DefaultChatSettings): """ Modify the likelihood of specified tokens appearing in the completion """ - response_format: ( - SimpleCompletionResponseFormat | SchemaCompletionResponseFormat | None - ) = None + response_format: SimpleCompletionResponseFormat | SchemaCompletionResponseFormat | None = ( + None + ) """ Response format (set to 
`json_object` to restrict output to JSON) """ diff --git a/integrations-service/integrations/autogen/Docs.py b/integrations-service/integrations/autogen/Docs.py index ffed27c1d..28a421ba5 100644 --- a/integrations-service/integrations/autogen/Docs.py +++ b/integrations-service/integrations/autogen/Docs.py @@ -73,6 +73,24 @@ class Doc(BaseModel): """ Embeddings for the document """ + modality: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Modality of the document + """ + language: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Language of the document + """ + embedding_model: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None + """ + Embedding model used for the document + """ + embedding_dimensions: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = ( + None + ) + """ + Dimensions of the embedding model + """ class DocOwner(BaseModel): diff --git a/integrations-service/integrations/autogen/Entries.py b/integrations-service/integrations/autogen/Entries.py index de37e77d8..867b10192 100644 --- a/integrations-service/integrations/autogen/Entries.py +++ b/integrations-service/integrations/autogen/Entries.py @@ -52,6 +52,7 @@ class BaseEntry(BaseModel): ] tokenizer: str token_count: int + model: str = "gpt-4o-mini" tool_calls: ( list[ ChosenFunctionCall diff --git a/integrations-service/integrations/autogen/Executions.py b/integrations-service/integrations/autogen/Executions.py index 5ccc57e83..36a36b7a5 100644 --- a/integrations-service/integrations/autogen/Executions.py +++ b/integrations-service/integrations/autogen/Executions.py @@ -181,8 +181,6 @@ class Transition(TransitionEvent): ) execution_id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] current: Annotated[TransitionTarget, Field(json_schema_extra={"readOnly": True})] - next: Annotated[ - TransitionTarget | None, Field(json_schema_extra={"readOnly": True}) - ] + next: 
Annotated[TransitionTarget | None, Field(json_schema_extra={"readOnly": True})] id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] metadata: dict[str, Any] | None = None diff --git a/integrations-service/integrations/autogen/Sessions.py b/integrations-service/integrations/autogen/Sessions.py index 460fd25ce..20c9885b1 100644 --- a/integrations-service/integrations/autogen/Sessions.py +++ b/integrations-service/integrations/autogen/Sessions.py @@ -27,9 +27,13 @@ class CreateSessionRequest(BaseModel): Agent ID of agent associated with this session """ agents: list[UUID] | None = None - situation: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if tools -%}\nTools:{{NEWLINE}}\n {%- for tool in tools -%}\n - {{tool.name + NEWLINE}}\n {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%}\n {%- endfor -%}\n{{NEWLINE+NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + situation: str | None = None """ - A specific situation that sets the background for this session + Session situation + """ + system_template: 
str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE}}\n\n{%- if session.situation -%}\nSituation: {{session.situation}}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + """ + A specific system prompt template that sets the background for this session """ render_templates: StrictBool = True """ @@ -51,6 +55,10 @@ class CreateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. 
""" + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -63,9 +71,13 @@ class PatchSessionRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - situation: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if tools -%}\nTools:{{NEWLINE}}\n {%- for tool in tools -%}\n - {{tool.name + NEWLINE}}\n {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%}\n {%- endfor -%}\n{{NEWLINE+NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + situation: str | None = None + """ + Session situation """ - A specific situation that sets the background for this session + system_template: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE}}\n\n{%- if session.situation -%}\nSituation: {{session.situation}}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + """ + A specific system prompt template that sets the background for this session """ render_templates: StrictBool = True """ @@ -87,6 +99,10 @@ class PatchSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptionsUpdate | None = None metadata: dict[str, Any] | None = None @@ -117,9 +133,13 @@ class Session(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - situation: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if tools -%}\nTools:{{NEWLINE}}\n {%- for tool in tools -%}\n - {{tool.name + NEWLINE}}\n {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%}\n {%- endfor -%}\n{{NEWLINE+NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + situation: str | None = None + """ + Session situation + """ + system_template: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE}}\n\n{%- if session.situation -%}\nSituation: {{session.situation}}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' """ - A specific situation that sets the background for this session + A specific system prompt template that sets the background for this session """ summary: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None """ @@ -145,6 +165,10 @@ class Session(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] metadata: dict[str, Any] | None = None @@ -193,9 +217,13 @@ class UpdateSessionRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - situation: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if tools -%}\nTools:{{NEWLINE}}\n {%- for tool in tools -%}\n - {{tool.name + NEWLINE}}\n {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%}\n {%- endfor -%}\n{{NEWLINE+NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + situation: str | None = None + """ + Session situation """ - A specific situation that sets the background for this session + system_template: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE}}\n\n{%- if session.situation -%}\nSituation: {{session.situation}}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + """ + A specific system prompt template that sets the background for this session """ render_templates: StrictBool = True """ @@ -217,6 +245,10 @@ class UpdateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -236,9 +268,13 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): Agent ID of agent associated with this session """ agents: list[UUID] | None = None - situation: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if tools -%}\nTools:{{NEWLINE}}\n {%- for tool in tools -%}\n - {{tool.name + NEWLINE}}\n {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%}\n {%- endfor -%}\n{{NEWLINE+NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' + situation: str | None = None + """ + Session situation + """ + system_template: str = '{%- if agent.name -%}\nYou are {{agent.name}}.{{" "}}\n{%- endif -%}\n\n{%- if agent.about -%}\nAbout you: {{agent.about}}.{{" "}}\n{%- endif -%}\n\n{%- if user -%}\nYou are talking to a user\n {%- if user.name -%}{{" "}} and their name is {{user.name}}\n {%- if user.about -%}. 
About the user: {{user.about}}.{%- else -%}.{%- endif -%}\n {%- endif -%}\n{%- endif -%}\n\n{{NEWLINE}}\n\n{%- if session.situation -%}\nSituation: {{session.situation}}\n{%- endif -%}\n\n{{NEWLINE+NEWLINE}}\n\n{%- if agent.instructions -%}\nInstructions:{{NEWLINE}}\n {%- if agent.instructions is string -%}\n {{agent.instructions}}{{NEWLINE}}\n {%- else -%}\n {%- for instruction in agent.instructions -%}\n - {{instruction}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{NEWLINE}}\n{%- endif -%}\n\n{%- if docs -%}\nRelevant documents:{{NEWLINE}}\n {%- for doc in docs -%}\n {{doc.title}}{{NEWLINE}}\n {%- if doc.content is string -%}\n {{doc.content}}{{NEWLINE}}\n {%- else -%}\n {%- for snippet in doc.content -%}\n {{snippet}}{{NEWLINE}}\n {%- endfor -%}\n {%- endif -%}\n {{"---"}}\n {%- endfor -%}\n{%- endif -%}' """ - A specific situation that sets the background for this session + A specific system prompt template that sets the background for this session """ render_templates: StrictBool = True """ @@ -260,6 +296,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None diff --git a/integrations-service/integrations/autogen/Tasks.py b/integrations-service/integrations/autogen/Tasks.py index b9212d8cb..ebc3a4b84 100644 --- a/integrations-service/integrations/autogen/Tasks.py +++ b/integrations-service/integrations/autogen/Tasks.py @@ -161,8 +161,21 @@ class CreateTaskRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - name: str + name: Annotated[str, Field(max_length=255, min_length=1)] + """ + The name of the task. 
+ """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + The canonical name of the task. + """ description: str = "" + """ + The description of the task. + """ main: Annotated[ list[ EvaluateStep @@ -206,9 +219,7 @@ class ErrorWorkflowStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["error"], Field(json_schema_extra={"readOnly": True})] = ( - "error" - ) + kind_: Annotated[Literal["error"], Field(json_schema_extra={"readOnly": True})] = "error" """ The kind of step """ @@ -226,9 +237,9 @@ class EvaluateStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["evaluate"], Field(json_schema_extra={"readOnly": True}) - ] = "evaluate" + kind_: Annotated[Literal["evaluate"], Field(json_schema_extra={"readOnly": True})] = ( + "evaluate" + ) """ The kind of step """ @@ -294,9 +305,9 @@ class ForeachStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["foreach"], Field(json_schema_extra={"readOnly": True}) - ] = "foreach" + kind_: Annotated[Literal["foreach"], Field(json_schema_extra={"readOnly": True})] = ( + "foreach" + ) """ The kind of step """ @@ -332,9 +343,7 @@ class GetStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["get"], Field(json_schema_extra={"readOnly": True})] = ( - "get" - ) + kind_: Annotated[Literal["get"], Field(json_schema_extra={"readOnly": True})] = "get" """ The kind of step """ @@ -352,9 +361,9 @@ class IfElseWorkflowStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["if_else"], Field(json_schema_extra={"readOnly": True}) - ] = "if_else" + kind_: Annotated[Literal["if_else"], Field(json_schema_extra={"readOnly": True})] = ( + "if_else" + ) """ The kind of step """ @@ -476,9 +485,7 @@ class LogStep(BaseModel): model_config = ConfigDict( 
populate_by_name=True, ) - kind_: Annotated[Literal["log"], Field(json_schema_extra={"readOnly": True})] = ( - "log" - ) + kind_: Annotated[Literal["log"], Field(json_schema_extra={"readOnly": True})] = "log" """ The kind of step """ @@ -496,9 +503,9 @@ class Main(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["map_reduce"], Field(json_schema_extra={"readOnly": True}) - ] = "map_reduce" + kind_: Annotated[Literal["map_reduce"], Field(json_schema_extra={"readOnly": True})] = ( + "map_reduce" + ) """ The kind of step """ @@ -510,15 +517,7 @@ class Main(BaseModel): """ The variable to iterate over """ - map: ( - EvaluateStep - | ToolCallStep - | PromptStep - | GetStep - | SetStep - | LogStep - | YieldStep - ) + map: EvaluateStep | ToolCallStep | PromptStep | GetStep | SetStep | LogStep | YieldStep """ The steps to run for each iteration """ @@ -586,9 +585,9 @@ class ParallelStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["parallel"], Field(json_schema_extra={"readOnly": True}) - ] = "parallel" + kind_: Annotated[Literal["parallel"], Field(json_schema_extra={"readOnly": True})] = ( + "parallel" + ) """ The kind of step """ @@ -598,13 +597,7 @@ class ParallelStep(BaseModel): """ parallel: Annotated[ list[ - EvaluateStep - | ToolCallStep - | PromptStep - | GetStep - | SetStep - | LogStep - | YieldStep + EvaluateStep | ToolCallStep | PromptStep | GetStep | SetStep | LogStep | YieldStep ], Field(max_length=100), ] @@ -650,7 +643,21 @@ class PatchTaskRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, ) + name: Annotated[str | None, Field(max_length=255, min_length=1)] = None + """ + The name of the task. + """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + The canonical name of the task. + """ description: str = "" + """ + The description of the task. 
+ """ main: Annotated[ list[ EvaluateStep @@ -733,9 +740,7 @@ class PromptStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["prompt"], Field(json_schema_extra={"readOnly": True})] = ( - "prompt" - ) + kind_: Annotated[Literal["prompt"], Field(json_schema_extra={"readOnly": True})] = "prompt" """ The kind of step """ @@ -827,9 +832,7 @@ class ReturnStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["return"], Field(json_schema_extra={"readOnly": True})] = ( - "return" - ) + kind_: Annotated[Literal["return"], Field(json_schema_extra={"readOnly": True})] = "return" """ The kind of step """ @@ -850,9 +853,7 @@ class SetStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["set"], Field(json_schema_extra={"readOnly": True})] = ( - "set" - ) + kind_: Annotated[Literal["set"], Field(json_schema_extra={"readOnly": True})] = "set" """ The kind of step """ @@ -892,9 +893,7 @@ class SleepStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["sleep"], Field(json_schema_extra={"readOnly": True})] = ( - "sleep" - ) + kind_: Annotated[Literal["sleep"], Field(json_schema_extra={"readOnly": True})] = "sleep" """ The kind of step """ @@ -924,9 +923,7 @@ class SwitchStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["switch"], Field(json_schema_extra={"readOnly": True})] = ( - "switch" - ) + kind_: Annotated[Literal["switch"], Field(json_schema_extra={"readOnly": True})] = "switch" """ The kind of step """ @@ -966,8 +963,21 @@ class Task(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - name: str + name: Annotated[str, Field(max_length=255, min_length=1)] + """ + The name of the task. 
+ """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + The canonical name of the task. + """ description: str = "" + """ + The description of the task. + """ main: Annotated[ list[ EvaluateStep @@ -1020,9 +1030,7 @@ class TaskTool(CreateToolRequest): model_config = ConfigDict( populate_by_name=True, ) - inherited: Annotated[StrictBool, Field(json_schema_extra={"readOnly": True})] = ( - False - ) + inherited: Annotated[StrictBool, Field(json_schema_extra={"readOnly": True})] = False """ Read-only: Whether the tool was inherited or not. Only applies within tasks. """ @@ -1032,9 +1040,9 @@ class ToolCallStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["tool_call"], Field(json_schema_extra={"readOnly": True}) - ] = "tool_call" + kind_: Annotated[Literal["tool_call"], Field(json_schema_extra={"readOnly": True})] = ( + "tool_call" + ) """ The kind of step """ @@ -1057,9 +1065,7 @@ class ToolCallStep(BaseModel): dict[ str, dict[str, list[str] | dict[str, str] | list[dict[str, str]] | str] - | list[ - dict[str, list[str] | dict[str, str] | list[dict[str, str]] | str] - ] + | list[dict[str, list[str] | dict[str, str] | list[dict[str, str]] | str]] | str, ] ] @@ -1124,7 +1130,21 @@ class UpdateTaskRequest(BaseModel): model_config = ConfigDict( populate_by_name=True, ) + name: Annotated[str, Field(max_length=255, min_length=1)] + """ + The name of the task. + """ + canonical_name: Annotated[ + str | None, + Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"), + ] = None + """ + The canonical name of the task. + """ description: str = "" + """ + The description of the task. 
+ """ main: Annotated[ list[ EvaluateStep @@ -1178,9 +1198,9 @@ class WaitForInputStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[ - Literal["wait_for_input"], Field(json_schema_extra={"readOnly": True}) - ] = "wait_for_input" + kind_: Annotated[Literal["wait_for_input"], Field(json_schema_extra={"readOnly": True})] = ( + "wait_for_input" + ) """ The kind of step """ @@ -1198,9 +1218,7 @@ class YieldStep(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - kind_: Annotated[Literal["yield"], Field(json_schema_extra={"readOnly": True})] = ( - "yield" - ) + kind_: Annotated[Literal["yield"], Field(json_schema_extra={"readOnly": True})] = "yield" """ The kind of step """ @@ -1214,8 +1232,7 @@ class YieldStep(BaseModel): VALIDATION: Should resolve to a defined subworkflow. """ arguments: ( - dict[str, list[str] | dict[str, str] | list[dict[str, str]] | str] - | Literal["_"] + dict[str, list[str] | dict[str, str] | list[dict[str, str]] | str] | Literal["_"] ) = "_" """ The input parameters for the subworkflow (defaults to last step output) diff --git a/integrations-service/integrations/autogen/Tools.py b/integrations-service/integrations/autogen/Tools.py index d872674af..229a866bb 100644 --- a/integrations-service/integrations/autogen/Tools.py +++ b/integrations-service/integrations/autogen/Tools.py @@ -561,9 +561,7 @@ class BrowserbaseGetSessionConnectUrlArguments(BrowserbaseGetSessionArguments): pass -class BrowserbaseGetSessionConnectUrlArgumentsUpdate( - BrowserbaseGetSessionArgumentsUpdate -): +class BrowserbaseGetSessionConnectUrlArgumentsUpdate(BrowserbaseGetSessionArgumentsUpdate): pass @@ -571,9 +569,7 @@ class BrowserbaseGetSessionLiveUrlsArguments(BrowserbaseGetSessionArguments): pass -class BrowserbaseGetSessionLiveUrlsArgumentsUpdate( - BrowserbaseGetSessionArgumentsUpdate -): +class BrowserbaseGetSessionLiveUrlsArgumentsUpdate(BrowserbaseGetSessionArgumentsUpdate): pass @@ -1806,9 +1802,9 @@ class 
SystemDefUpdate(BaseModel): model_config = ConfigDict( populate_by_name=True, ) - resource: ( - Literal["agent", "user", "task", "execution", "doc", "session", "job"] | None - ) = None + resource: Literal["agent", "user", "task", "execution", "doc", "session", "job"] | None = ( + None + ) """ Resource is the name of the resource to use """ @@ -2366,9 +2362,7 @@ class BrowserbaseCompleteSessionIntegrationDef(BaseBrowserbaseIntegrationDef): arguments: BrowserbaseCompleteSessionArguments | None = None -class BrowserbaseCompleteSessionIntegrationDefUpdate( - BaseBrowserbaseIntegrationDefUpdate -): +class BrowserbaseCompleteSessionIntegrationDefUpdate(BaseBrowserbaseIntegrationDefUpdate): """ browserbase complete session integration definition """ @@ -2494,9 +2488,7 @@ class BrowserbaseGetSessionConnectUrlIntegrationDef(BaseBrowserbaseIntegrationDe arguments: BrowserbaseGetSessionConnectUrlArguments | None = None -class BrowserbaseGetSessionConnectUrlIntegrationDefUpdate( - BaseBrowserbaseIntegrationDefUpdate -): +class BrowserbaseGetSessionConnectUrlIntegrationDefUpdate(BaseBrowserbaseIntegrationDefUpdate): """ browserbase get session connect url integration definition """ @@ -2544,9 +2536,7 @@ class BrowserbaseGetSessionLiveUrlsIntegrationDef(BaseBrowserbaseIntegrationDef) arguments: BrowserbaseGetSessionLiveUrlsArguments | None = None -class BrowserbaseGetSessionLiveUrlsIntegrationDefUpdate( - BaseBrowserbaseIntegrationDefUpdate -): +class BrowserbaseGetSessionLiveUrlsIntegrationDefUpdate(BaseBrowserbaseIntegrationDefUpdate): """ browserbase get session live urls integration definition """ diff --git a/integrations-service/integrations/models/arxiv.py b/integrations-service/integrations/models/arxiv.py index 31edf455a..7bbf1753c 100644 --- a/integrations-service/integrations/models/arxiv.py +++ b/integrations-service/integrations/models/arxiv.py @@ -1,26 +1,24 @@ -from typing import List, Optional - from pydantic import BaseModel, Field from .base_models import 
BaseOutput class ArxivSearchResult(BaseModel): - entry_id: Optional[str] = None - title: Optional[str] = None - updated: Optional[str] = None - published: Optional[str] = None - authors: Optional[List[str]] = None - summary: Optional[str] = None - comment: Optional[str] = None - journal_ref: Optional[str] = None - doi: Optional[str] = None - primary_category: Optional[str] = None - categories: Optional[List[str]] = None - links: Optional[List[str]] = None - pdf_url: Optional[str] = None - pdf_downloaded: Optional[dict] = None + entry_id: str | None = None + title: str | None = None + updated: str | None = None + published: str | None = None + authors: list[str] | None = None + summary: str | None = None + comment: str | None = None + journal_ref: str | None = None + doi: str | None = None + primary_category: str | None = None + categories: list[str] | None = None + links: list[str] | None = None + pdf_url: str | None = None + pdf_downloaded: dict | None = None class ArxivSearchOutput(BaseOutput): - result: List[ArxivSearchResult] = Field(..., description="A list of search results") + result: list[ArxivSearchResult] = Field(..., description="A list of search results") diff --git a/integrations-service/integrations/models/base_models.py b/integrations-service/integrations/models/base_models.py index 6d43f67b2..95b79da10 100644 --- a/integrations-service/integrations/models/base_models.py +++ b/integrations-service/integrations/models/base_models.py @@ -1,4 +1,4 @@ -from typing import Annotated, Optional +from typing import Annotated from pydantic import BaseModel, Field from pydantic_core import Url @@ -10,9 +10,9 @@ class BaseOutput(BaseModel): ... 
class ProviderInfo(BaseModel): - url: Optional[Url] = None - docs: Optional[Url] = None - icon: Optional[Url] = None + url: Url | None = None + docs: Url | None = None + icon: Url | None = None friendly_name: str diff --git a/integrations-service/integrations/models/brave.py b/integrations-service/integrations/models/brave.py index dd721d222..629d1feca 100644 --- a/integrations-service/integrations/models/brave.py +++ b/integrations-service/integrations/models/brave.py @@ -1,5 +1,3 @@ -from typing import List - from pydantic import BaseModel, Field from .base_models import BaseOutput @@ -12,4 +10,4 @@ class SearchResult(BaseModel): class BraveSearchOutput(BaseOutput): - result: List[SearchResult] = Field(..., description="A list of search results") + result: list[SearchResult] = Field(..., description="A list of search results") diff --git a/integrations-service/integrations/models/browserbase.py b/integrations-service/integrations/models/browserbase.py index 46f332e57..df683a4ec 100644 --- a/integrations-service/integrations/models/browserbase.py +++ b/integrations-service/integrations/models/browserbase.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from browserbase import DebugConnectionURLs, Session from pydantic import AnyUrl, Field @@ -15,17 +15,13 @@ class BrowserbaseCreateSessionOutput(BaseOutput): createdAt: str | None = Field( None, description="Timestamp indicating when the session was created" ) - projectId: str | None = Field( - None, description="The Project ID linked to the Session" - ) - startedAt: str | None = Field( - None, description="Timestamp when the session started" - ) + projectId: str | None = Field(None, description="The Project ID linked to the Session") + startedAt: str | None = Field(None, description="Timestamp when the session started") endedAt: str | None = Field(None, description="Timestamp when the session ended") expiresAt: str | None = Field( None, description="Timestamp when the session is 
set to expire" ) - status: None | Literal["RUNNING", "ERROR", "TIMED_OUT", "COMPLETED"] = Field( + status: Literal["RUNNING", "ERROR", "TIMED_OUT", "COMPLETED"] | None = Field( None, description="Current status of the session" ) proxyBytes: int | None = Field(None, description="Bytes used via the Proxy") @@ -45,17 +41,13 @@ class BrowserbaseGetSessionOutput(BaseOutput): createdAt: str | None = Field( None, description="Timestamp indicating when the session was created" ) - projectId: str | None = Field( - None, description="The Project ID linked to the Session" - ) - startedAt: str | None = Field( - None, description="Timestamp when the session started" - ) + projectId: str | None = Field(None, description="The Project ID linked to the Session") + startedAt: str | None = Field(None, description="Timestamp when the session started") endedAt: str | None = Field(None, description="Timestamp when the session ended") expiresAt: str | None = Field( None, description="Timestamp when the session is set to expire" ) - status: None | Literal["RUNNING", "ERROR", "TIMED_OUT", "COMPLETED"] = Field( + status: Literal["RUNNING", "ERROR", "TIMED_OUT", "COMPLETED"] | None = Field( None, description="Current status of the session" ) proxyBytes: int | None = Field(None, description="Bytes used via the Proxy") @@ -85,14 +77,14 @@ class BrowserbaseGetSessionConnectUrlOutput(BaseOutput): class PageInfo(BaseOutput): - id: Optional[str] = Field(None, description="Unique identifier for the page") - url: Optional[AnyUrl] = Field(None, description="URL of the page") - faviconUrl: Optional[AnyUrl] = Field(None, description="URL for the page's favicon") - title: Optional[str] = Field(None, description="Title of the page") - debuggerUrl: Optional[AnyUrl] = Field( + id: str | None = Field(None, description="Unique identifier for the page") + url: AnyUrl | None = Field(None, description="URL of the page") + faviconUrl: AnyUrl | None = Field(None, description="URL for the page's favicon") + title: 
str | None = Field(None, description="Title of the page") + debuggerUrl: AnyUrl | None = Field( None, description="URL to access the debugger for this page" ) - debuggerFullscreenUrl: Optional[AnyUrl] = Field( + debuggerFullscreenUrl: AnyUrl | None = Field( None, description="URL to access the debugger in fullscreen for this page" ) diff --git a/integrations-service/integrations/models/cloudinary.py b/integrations-service/integrations/models/cloudinary.py index 4ad59f4bf..7bd7c732b 100644 --- a/integrations-service/integrations/models/cloudinary.py +++ b/integrations-service/integrations/models/cloudinary.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from .base_models import BaseOutput @@ -8,16 +6,16 @@ class CloudinaryUploadOutput(BaseOutput): url: str = Field(..., description="The URL of the uploaded file") public_id: str = Field(..., description="The public ID of the uploaded file") - base64: Optional[str] = Field( + base64: str | None = Field( None, description="The base64 encoded file if return_base64 is true" ) - meta_data: Optional[dict] = Field( + meta_data: dict | None = Field( None, description="Additional metadata from the upload response" ) class CloudinaryEditOutput(BaseOutput): transformed_url: str = Field(..., description="The transformed URL") - base64: Optional[str] = Field( + base64: str | None = Field( None, description="The base64 encoded file if return_base64 is true" ) diff --git a/integrations-service/integrations/models/execution.py b/integrations-service/integrations/models/execution.py index 42cae6cbc..a618fc758 100644 --- a/integrations-service/integrations/models/execution.py +++ b/integrations-service/integrations/models/execution.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - from pydantic import BaseModel from ..autogen.Tools import ( @@ -63,70 +61,70 @@ class ExecutionError(BaseModel): # Setup configurations -ExecutionSetup = Union[ - EmailSetup, - SpiderSetup, - WeatherSetup, - 
BraveSearchSetup, - BrowserbaseSetup, - RemoteBrowserSetup, - LlamaParseSetup, - CloudinarySetup, -] +ExecutionSetup = ( + EmailSetup + | SpiderSetup + | WeatherSetup + | BraveSearchSetup + | BrowserbaseSetup + | RemoteBrowserSetup + | LlamaParseSetup + | CloudinarySetup +) # Argument configurations -ExecutionArguments = Union[ - SpiderFetchArguments, - WeatherGetArguments, - EmailArguments, - WikipediaSearchArguments, - BraveSearchArguments, - BrowserbaseCreateSessionArguments, - BrowserbaseGetSessionArguments, - BrowserbaseGetSessionConnectUrlArguments, - BrowserbaseGetSessionLiveUrlsArguments, - BrowserbaseCompleteSessionArguments, - BrowserbaseContextArguments, - BrowserbaseExtensionArguments, - BrowserbaseListSessionsArguments, - RemoteBrowserArguments, - LlamaParseFetchArguments, - FfmpegSearchArguments, - CloudinaryUploadArguments, - CloudinaryEditArguments, - ArxivSearchArguments, -] +ExecutionArguments = ( + SpiderFetchArguments + | WeatherGetArguments + | EmailArguments + | WikipediaSearchArguments + | BraveSearchArguments + | BrowserbaseCreateSessionArguments + | BrowserbaseGetSessionArguments + | BrowserbaseGetSessionConnectUrlArguments + | BrowserbaseGetSessionLiveUrlsArguments + | BrowserbaseCompleteSessionArguments + | BrowserbaseContextArguments + | BrowserbaseExtensionArguments + | BrowserbaseListSessionsArguments + | RemoteBrowserArguments + | LlamaParseFetchArguments + | FfmpegSearchArguments + | CloudinaryUploadArguments + | CloudinaryEditArguments + | ArxivSearchArguments +) -ExecutionResponse = Union[ - WeatherGetOutput, - EmailOutput, - WikipediaSearchOutput, - BraveSearchOutput, - BrowserbaseCreateSessionOutput, - BrowserbaseGetSessionOutput, - BrowserbaseGetSessionConnectUrlOutput, - BrowserbaseGetSessionLiveUrlsOutput, - BrowserbaseCompleteSessionOutput, - BrowserbaseContextOutput, - BrowserbaseExtensionOutput, - BrowserbaseListSessionsOutput, - RemoteBrowserOutput, - LlamaParseFetchOutput, - FfmpegSearchOutput, - CloudinaryEditOutput, - 
CloudinaryUploadOutput, - ExecutionError, - ArxivSearchOutput, - SpiderOutput, -] +ExecutionResponse = ( + WeatherGetOutput + | EmailOutput + | WikipediaSearchOutput + | BraveSearchOutput + | BrowserbaseCreateSessionOutput + | BrowserbaseGetSessionOutput + | BrowserbaseGetSessionConnectUrlOutput + | BrowserbaseGetSessionLiveUrlsOutput + | BrowserbaseCompleteSessionOutput + | BrowserbaseContextOutput + | BrowserbaseExtensionOutput + | BrowserbaseListSessionsOutput + | RemoteBrowserOutput + | LlamaParseFetchOutput + | FfmpegSearchOutput + | CloudinaryEditOutput + | CloudinaryUploadOutput + | ExecutionError + | ArxivSearchOutput + | SpiderOutput +) class ExecutionRequest(BaseModel): - setup: Optional[ExecutionSetup] + setup: ExecutionSetup | None """ The setup parameters the integration accepts (such as API keys) """ - arguments: Optional[ExecutionArguments] + arguments: ExecutionArguments | None """ The arguments to pass to the integration """ diff --git a/integrations-service/integrations/models/ffmpeg.py b/integrations-service/integrations/models/ffmpeg.py index ad773228c..741f464f6 100644 --- a/integrations-service/integrations/models/ffmpeg.py +++ b/integrations-service/integrations/models/ffmpeg.py @@ -1,15 +1,9 @@ -from typing import Optional - from pydantic import Field from .base_models import BaseOutput class FfmpegSearchOutput(BaseOutput): - fileoutput: Optional[str] = Field( - None, description="The output file from the Ffmpeg command" - ) + fileoutput: str | None = Field(None, description="The output file from the Ffmpeg command") result: bool = Field(..., description="Whether the Ffmpeg command was successful") - mime_type: Optional[str] = Field( - None, description="The MIME type of the output file" - ) + mime_type: str | None = Field(None, description="The MIME type of the output file") diff --git a/integrations-service/integrations/models/llama_parse.py b/integrations-service/integrations/models/llama_parse.py index 759ec949c..6e874760f 100644 --- 
a/integrations-service/integrations/models/llama_parse.py +++ b/integrations-service/integrations/models/llama_parse.py @@ -5,6 +5,4 @@ class LlamaParseFetchOutput(BaseOutput): - documents: list[Document] = Field( - ..., description="The documents returned from the spider" - ) + documents: list[Document] = Field(..., description="The documents returned from the spider") diff --git a/integrations-service/integrations/models/remote_browser.py b/integrations-service/integrations/models/remote_browser.py index f1f585838..7aaf616de 100644 --- a/integrations-service/integrations/models/remote_browser.py +++ b/integrations-service/integrations/models/remote_browser.py @@ -6,7 +6,5 @@ class RemoteBrowserOutput(BaseOutput): output: str | None = Field(None, description="The output of the action") error: str | None = Field(None, description="The error of the action") - base64_image: str | None = Field( - None, description="The base64 encoded image of the action" - ) + base64_image: str | None = Field(None, description="The base64 encoded image of the action") system: str | None = Field(None, description="The system output of the action") diff --git a/integrations-service/integrations/models/spider.py b/integrations-service/integrations/models/spider.py index 4acfd8a66..2f74c39ba 100644 --- a/integrations-service/integrations/models/spider.py +++ b/integrations-service/integrations/models/spider.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any from pydantic import BaseModel, Field @@ -6,14 +6,12 @@ class SpiderResponse(BaseModel): - content: Optional[str] = None - error: Optional[str] = None - status: Optional[int] = None - costs: Optional[dict[Any, Any]] = None - url: Optional[str] = None + content: str | None = None + error: str | None = None + status: int | None = None + costs: dict[Any, Any] | None = None + url: str | None = None class SpiderOutput(BaseOutput): - result: List[SpiderResponse] = Field( - ..., description="The responses from 
the spider" - ) + result: list[SpiderResponse] = Field(..., description="The responses from the spider") diff --git a/integrations-service/integrations/routers/integrations/get_integration_tool.py b/integrations-service/integrations/routers/integrations/get_integration_tool.py index c689be322..ea9e71ed7 100644 --- a/integrations-service/integrations/routers/integrations/get_integration_tool.py +++ b/integrations-service/integrations/routers/integrations/get_integration_tool.py @@ -1,5 +1,3 @@ -from typing import Optional - from fastapi import HTTPException from ...models.base_models import BaseProvider, BaseProviderMethod @@ -7,7 +5,7 @@ def convert_to_openai_tool( - provider: BaseProvider, method: Optional[BaseProviderMethod] = None + provider: BaseProvider, method: BaseProviderMethod | None = None ) -> dict: method = method or provider.methods[0] name = f"{provider.provider}_{method.method}" @@ -26,7 +24,7 @@ def convert_to_openai_tool( @router.get("/integrations/{provider}/tool", tags=["integration_tool"]) @router.get("/integrations/{provider}/{method}/tool", tags=["integration_tool"]) -async def get_integration_tool(provider: str, method: Optional[str] = None): +async def get_integration_tool(provider: str, method: str | None = None): from ...providers import available_providers provider_obj: BaseProvider | None = available_providers.get(provider, None) diff --git a/integrations-service/integrations/routers/integrations/get_integrations.py b/integrations-service/integrations/routers/integrations/get_integrations.py index 5a90ec69a..13ddd4a3b 100644 --- a/integrations-service/integrations/routers/integrations/get_integrations.py +++ b/integrations-service/integrations/routers/integrations/get_integrations.py @@ -1,12 +1,10 @@ -from typing import List - from ...providers import available_providers from .router import router @router.get("/integrations", tags=["integrations"]) -async def get_integrations() -> List[dict]: - integrations = [ +async def 
get_integrations() -> list[dict]: + return [ { "provider": p.provider, "setup": p.setup.model_json_schema() if p.setup else None, @@ -28,4 +26,3 @@ async def get_integrations() -> List[dict]: } for p in available_providers.values() ] - return integrations diff --git a/integrations-service/integrations/utils/__init__.py b/integrations-service/integrations/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/integrations-service/integrations/utils/execute_integration.py b/integrations-service/integrations/utils/execute_integration.py index 5fd298344..aa2fad392 100644 --- a/integrations-service/integrations/utils/execute_integration.py +++ b/integrations-service/integrations/utils/execute_integration.py @@ -37,11 +37,7 @@ async def execute_integration( package="integrations", ) - if ( - setup is not None - and provider_obj.setup - and not isinstance(setup, provider_obj.setup) - ): + if setup is not None and provider_obj.setup and not isinstance(setup, provider_obj.setup): setup = provider_obj.setup(**setup.model_dump()) arguments = ( diff --git a/integrations-service/integrations/utils/integrations/arxiv.py b/integrations-service/integrations/utils/integrations/arxiv.py index 70b3c14df..48bdcbb88 100644 --- a/integrations-service/integrations/utils/integrations/arxiv.py +++ b/integrations-service/integrations/utils/integrations/arxiv.py @@ -83,7 +83,6 @@ def create_arxiv_search_result(result, pdf_content=None): pdf_content = base64.b64encode(pdf_file.read()).decode("utf-8") results.append(create_arxiv_search_result(result, pdf_content)) else: - for result in search_results: - results.append(create_arxiv_search_result(result)) + results.extend(create_arxiv_search_result(result) for result in search_results) return ArxivSearchOutput(result=results) diff --git a/integrations-service/integrations/utils/integrations/brave.py b/integrations-service/integrations/utils/integrations/brave.py index 7414e081a..920f0b246 100644 --- 
a/integrations-service/integrations/utils/integrations/brave.py +++ b/integrations-service/integrations/utils/integrations/brave.py @@ -15,9 +15,7 @@ reraise=True, stop=stop_after_attempt(4), ) -async def search( - setup: BraveSearchSetup, arguments: BraveSearchArguments -) -> BraveSearchOutput: +async def search(setup: BraveSearchSetup, arguments: BraveSearchArguments) -> BraveSearchOutput: """ Searches Brave Search with the provided query. """ @@ -36,6 +34,7 @@ async def search( try: parsed_result = [SearchResult(**item) for item in json.loads(result)] except json.JSONDecodeError as e: - raise ValueError("Malformed JSON response from Brave Search") from e + msg = "Malformed JSON response from Brave Search" + raise ValueError(msg) from e return BraveSearchOutput(result=parsed_result) diff --git a/integrations-service/integrations/utils/integrations/browserbase.py b/integrations-service/integrations/utils/integrations/browserbase.py index efdf6b594..6022f40e2 100644 --- a/integrations-service/integrations/utils/integrations/browserbase.py +++ b/integrations-service/integrations/utils/integrations/browserbase.py @@ -1,3 +1,4 @@ +import contextlib import os import tempfile @@ -38,13 +39,9 @@ def get_browserbase_client(setup: BrowserbaseSetup) -> Browserbase: - setup.api_key = ( - browserbase_api_key if setup.api_key == "DEMO_API_KEY" else setup.api_key - ) + setup.api_key = browserbase_api_key if setup.api_key == "DEMO_API_KEY" else setup.api_key setup.project_id = ( - browserbase_project_id - if setup.project_id == "DEMO_PROJECT_ID" - else setup.project_id + browserbase_project_id if setup.project_id == "DEMO_PROJECT_ID" else setup.project_id ) return Browserbase( @@ -178,8 +175,9 @@ async def install_extension_from_github( ) -> BrowserbaseExtensionOutput: """Download and install an extension from GitHub to the user's Browserbase account.""" - github_url = f"https://github.com/{arguments.repository_name}/archive/refs/tags/{ - arguments.ref}.zip" + github_url = ( + 
f"https://github.com/{arguments.repository_name}/archive/refs/tags/{arguments.ref}.zip" + ) async with httpx.AsyncClient(timeout=600) as client: # Download the extension zip @@ -202,9 +200,7 @@ async def install_extension_from_github( with open(tmp_file_path, "rb") as f: files = {"file": f} - upload_response = await client.post( - upload_url, headers=headers, files=files - ) + upload_response = await client.post(upload_url, headers=headers, files=files) try: upload_response.raise_for_status() @@ -213,9 +209,7 @@ async def install_extension_from_github( raise # Delete the temporary file - try: + with contextlib.suppress(FileNotFoundError): os.remove(tmp_file_path) - except FileNotFoundError: - pass return BrowserbaseExtensionOutput(id=upload_response.json()["id"]) diff --git a/integrations-service/integrations/utils/integrations/cloudinary.py b/integrations-service/integrations/utils/integrations/cloudinary.py index ccfecc7cf..a48a3c77f 100644 --- a/integrations-service/integrations/utils/integrations/cloudinary.py +++ b/integrations-service/integrations/utils/integrations/cloudinary.py @@ -65,16 +65,18 @@ async def media_upload( } if arguments.return_base64: - async with aiohttp.ClientSession() as session: - async with session.get(result["secure_url"]) as response: - if response.status == 200: - content = await response.read() - base64_encoded = base64.b64encode(content).decode("utf-8") - result["base64"] = base64_encoded - else: - raise RuntimeError( - f"Failed to download file from URL: {result['secure_url']}" - ) + async with ( + aiohttp.ClientSession() as session, + session.get(result["secure_url"]) as response, + ): + if response.status == 200: + content = await response.read() + base64_encoded = base64.b64encode(content).decode("utf-8") + result["base64"] = base64_encoded + else: + msg = f"Failed to download file from URL: {result['secure_url']}" + raise RuntimeError(msg) + return CloudinaryUploadOutput( url=result["secure_url"], 
public_id=result["public_id"], @@ -83,9 +85,11 @@ async def media_upload( ) except cloudinary.exceptions.Error as e: - raise RuntimeError(f"Cloudinary error occurred: {e}") + msg = f"Cloudinary error occurred: {e}" + raise RuntimeError(msg) except Exception as e: - raise RuntimeError(f"An unexpected error occurred: {e}") + msg = f"An unexpected error occurred: {e}" + raise RuntimeError(msg) @beartype @@ -128,16 +132,17 @@ async def media_edit( base64=None, ) if arguments.return_base64: - async with aiohttp.ClientSession() as session: - async with session.get(transformed_url[0]) as response: - if response.status == 200: - content = await response.read() - base64_encoded = base64.b64encode(content).decode("utf-8") - transformed_url_base64 = base64_encoded - else: - raise RuntimeError( - f"Failed to download file from URL: {transformed_url[0]}" - ) + async with ( + aiohttp.ClientSession() as session, + session.get(transformed_url[0]) as response, + ): + if response.status == 200: + content = await response.read() + base64_encoded = base64.b64encode(content).decode("utf-8") + transformed_url_base64 = base64_encoded + else: + msg = f"Failed to download file from URL: {transformed_url[0]}" + raise RuntimeError(msg) return CloudinaryEditOutput( transformed_url=transformed_url[0], @@ -145,6 +150,8 @@ async def media_edit( ) except cloudinary.exceptions.Error as e: - raise RuntimeError(f"Cloudinary error occurred: {e}") + msg = f"Cloudinary error occurred: {e}" + raise RuntimeError(msg) except Exception as e: - raise RuntimeError(f"An unexpected error occurred: {e}") + msg = f"An unexpected error occurred: {e}" + raise RuntimeError(msg) diff --git a/integrations-service/integrations/utils/integrations/ffmpeg.py b/integrations-service/integrations/utils/integrations/ffmpeg.py index 456882c0d..040181d3c 100644 --- a/integrations-service/integrations/utils/integrations/ffmpeg.py +++ b/integrations-service/integrations/utils/integrations/ffmpeg.py @@ -4,7 +4,6 @@ import shutil 
import tempfile from functools import lru_cache -from typing import Tuple from beartype import beartype from tenacity import retry, stop_after_attempt, wait_exponential @@ -15,7 +14,7 @@ # Cache for format validation @lru_cache(maxsize=128) -def _sync_validate_format(binary_prefix: bytes) -> Tuple[bool, str]: +def _sync_validate_format(binary_prefix: bytes) -> tuple[bool, str]: """Cached synchronous implementation of format validation""" signatures = { # Video formats @@ -46,7 +45,7 @@ def _sync_validate_format(binary_prefix: bytes) -> Tuple[bool, str]: return False, "application/octet-stream" -async def validate_format(binary_data: bytes) -> Tuple[bool, str]: +async def validate_format(binary_data: bytes) -> tuple[bool, str]: """Validate file format using file signatures""" # Only check first 16 bytes for efficiency binary_prefix = binary_data[:16] @@ -140,6 +139,4 @@ async def bash_cmd(arguments: FfmpegSearchArguments) -> FfmpegSearchOutput: # Clean up in case of exception if "temp_dir" in locals(): shutil.rmtree(temp_dir) - return FfmpegSearchOutput( - fileoutput=f"Error: {str(e)}", result=False, mime_type=None - ) + return FfmpegSearchOutput(fileoutput=f"Error: {e!s}", result=False, mime_type=None) diff --git a/integrations-service/integrations/utils/integrations/llama_parse.py b/integrations-service/integrations/utils/integrations/llama_parse.py index bbdbb13b6..f8b1873bc 100644 --- a/integrations-service/integrations/utils/integrations/llama_parse.py +++ b/integrations-service/integrations/utils/integrations/llama_parse.py @@ -51,10 +51,7 @@ async def parse( base64.b64decode(arguments.file), extra_info=extra_info ) else: - if arguments.filename: - extra_info = {"file_name": arguments.filename} - else: - extra_info = None + extra_info = {"file_name": arguments.filename} if arguments.filename else None # Parse the document (decode inline) documents = await parser.aload_data(arguments.file, extra_info=extra_info) diff --git 
a/integrations-service/integrations/utils/integrations/remote_browser.py b/integrations-service/integrations/utils/integrations/remote_browser.py index 2b83c2be6..0325bea21 100644 --- a/integrations-service/integrations/utils/integrations/remote_browser.py +++ b/integrations-service/integrations/utils/integrations/remote_browser.py @@ -47,14 +47,12 @@ def __init__( async def _is_initialized(self) -> bool: """Check if the page is initialized""" - result = bool( + return bool( await self._execute_javascript(""" window.$$julep$$_initialized """) ) - return result - async def initialize(self, debug: bool = False) -> None: if debug: self.page.on("console", lambda msg: print(msg.text)) @@ -69,7 +67,7 @@ async def initialize(self, debug: bool = False) -> None: // Update mouse coordinates on mouse move // but only on the top document - if (window === window.parent) + if (window === window.parent) window.addEventListener( 'DOMContentLoaded', () => { @@ -137,11 +135,9 @@ async def _get_screen_size(self) -> tuple[int, int]: async def _set_screen_size(self, width: int, height: int) -> None: """Set the current browser viewport size""" - await self.page.set_viewport_size(dict(width=width, height=height)) + await self.page.set_viewport_size({"width": width, "height": height}) - async def _wait_for_load( - self, event: str = "domcontentloaded", timeout: int = 0 - ) -> None: + async def _wait_for_load(self, event: str = "domcontentloaded", timeout: int = 0) -> None: """Wait for document to be fully loaded""" await self.page.wait_for_load_state(event, timeout=timeout) @@ -174,7 +170,8 @@ async def _get_element_coordinates(self, selector: str) -> tuple[int, int]: if element: box = await element.bounding_box() return (box["x"], box["y"]) - raise Exception(f"Element not found: {selector}") + msg = f"Element not found: {selector}" + raise Exception(msg) def _overlay_cursor(self, screenshot_bytes: bytes, x: int, y: int) -> bytes: """Overlay the cursor image on the screenshot at the 
specified coordinates.""" @@ -363,12 +360,14 @@ async def perform_action( } if action not in actions: - raise ValueError(f"Invalid action: {action}") + msg = f"Invalid action: {action}" + raise ValueError(msg) return await actions[action]() except Exception as e: - raise Exception(f"Error performing action {action}: {str(e)}") + msg = f"Error performing action {action}: {e!s}" + raise Exception(msg) @beartype diff --git a/integrations-service/integrations/utils/integrations/spider.py b/integrations-service/integrations/utils/integrations/spider.py index ff31705a0..a858afaf8 100644 --- a/integrations-service/integrations/utils/integrations/spider.py +++ b/integrations-service/integrations/utils/integrations/spider.py @@ -18,11 +18,7 @@ def get_api_key(setup: SpiderSetup) -> str: """ Helper function to get the API key. """ - return ( - setup.spider_api_key - if setup.spider_api_key != "DEMO_API_KEY" - else spider_api_key - ) + return setup.spider_api_key if setup.spider_api_key != "DEMO_API_KEY" else spider_api_key def create_spider_response(pages: list[dict]) -> list[SpiderResponse]: @@ -56,12 +52,13 @@ async def execute_spider_method( results = result if results is None: - raise ValueError("No results found") - else: - final_result = create_spider_response(results) + msg = "No results found" + raise ValueError(msg) + final_result = create_spider_response(results) except Exception as e: # Log the exception or handle it as needed - raise RuntimeError(f"Error executing spider method '{method_name}': {e}") + msg = f"Error executing spider method '{method_name}': {e}" + raise RuntimeError(msg) return SpiderOutput(result=final_result) @@ -102,9 +99,7 @@ async def links(setup: SpiderSetup, arguments: SpiderFetchArguments) -> SpiderOu reraise=True, stop=stop_after_attempt(4), ) -async def screenshot( - setup: SpiderSetup, arguments: SpiderFetchArguments -) -> SpiderOutput: +async def screenshot(setup: SpiderSetup, arguments: SpiderFetchArguments) -> SpiderOutput: """ Take 
a screenshot of the webpage. """ diff --git a/integrations-service/integrations/utils/integrations/weather.py b/integrations-service/integrations/utils/integrations/weather.py index 19e6c659e..9bddeb9ee 100644 --- a/integrations-service/integrations/utils/integrations/weather.py +++ b/integrations-service/integrations/utils/integrations/weather.py @@ -28,7 +28,8 @@ async def get(setup: WeatherSetup, arguments: WeatherGetArguments) -> WeatherGet openweathermap_api_key = openweather_api_key if not location: - raise ValueError("Location parameter is required for weather data") + msg = "Location parameter is required for weather data" + raise ValueError(msg) weather = OpenWeatherMapAPIWrapper(openweathermap_api_key=openweathermap_api_key) result = weather.run(location) diff --git a/integrations-service/integrations/utils/integrations/wikipedia.py b/integrations-service/integrations/utils/integrations/wikipedia.py index 235d9512a..f3d1394a8 100644 --- a/integrations-service/integrations/utils/integrations/wikipedia.py +++ b/integrations-service/integrations/utils/integrations/wikipedia.py @@ -21,7 +21,8 @@ async def search( query = arguments.query if not query: - raise ValueError("Query parameter is required for Wikipedia search") + msg = "Query parameter is required for Wikipedia search" + raise ValueError(msg) load_max_docs = arguments.load_max_docs diff --git a/integrations-service/integrations/web.py b/integrations-service/integrations/web.py index 2445dadbb..62b49af48 100644 --- a/integrations-service/integrations/web.py +++ b/integrations-service/integrations/web.py @@ -1,7 +1,8 @@ import asyncio import logging import os -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import sentry_sdk import uvicorn diff --git a/integrations-service/poe_tasks.toml b/integrations-service/poe_tasks.toml index a43646dd4..258b06679 100644 --- a/integrations-service/poe_tasks.toml +++ b/integrations-service/poe_tasks.toml @@ -1,6 +1,6 @@ 
[tasks] format = "ruff format" -lint = "ruff check --select I --fix --unsafe-fixes integrations/**/*.py" +lint = "ruff check" typecheck = "pytype --config pytype.toml" check = [ "lint", diff --git a/integrations-service/tests/conftest.py b/integrations-service/tests/conftest.py index be5b4ebf7..81d197d86 100644 --- a/integrations-service/tests/conftest.py +++ b/integrations-service/tests/conftest.py @@ -1,13 +1,14 @@ -import pytest from unittest.mock import patch +import pytest from integrations.providers import available_providers + from .mocks.brave import MockBraveSearchClient from .mocks.email import MockEmailClient +from .mocks.llama_parse import MockLlamaParseClient from .mocks.spider import MockSpiderClient from .mocks.weather import MockWeatherClient from .mocks.wikipedia import MockWikipediaClient -from .mocks.llama_parse import MockLlamaParseClient @pytest.fixture(autouse=True) @@ -17,12 +18,8 @@ def mock_external_services(): patch("langchain_community.tools.BraveSearch", MockBraveSearchClient), patch("smtplib.SMTP", MockEmailClient), patch("langchain_community.document_loaders.SpiderLoader", MockSpiderClient), - patch( - "langchain_community.utilities.OpenWeatherMapAPIWrapper", MockWeatherClient - ), - patch( - "langchain_community.document_loaders.WikipediaLoader", MockWikipediaClient - ), + patch("langchain_community.utilities.OpenWeatherMapAPIWrapper", MockWeatherClient), + patch("langchain_community.document_loaders.WikipediaLoader", MockWikipediaClient), patch("llama_parse.LlamaParse", MockLlamaParseClient), ): yield diff --git a/integrations-service/tests/mocks/brave.py b/integrations-service/tests/mocks/brave.py index 958925aed..d9ed4399b 100644 --- a/integrations-service/tests/mocks/brave.py +++ b/integrations-service/tests/mocks/brave.py @@ -12,5 +12,3 @@ def search(self, query: str) -> str: class MockBraveSearchException(Exception): """Mock exception for Brave Search errors""" - - pass diff --git a/integrations-service/tests/mocks/email.py 
b/integrations-service/tests/mocks/email.py index 5f747ddc3..ea302d4ae 100644 --- a/integrations-service/tests/mocks/email.py +++ b/integrations-service/tests/mocks/email.py @@ -15,5 +15,3 @@ def send(self, to: str, from_: str, subject: str, body: str) -> bool: class MockEmailException(Exception): """Mock exception for email errors""" - - pass diff --git a/integrations-service/tests/mocks/llama_parse.py b/integrations-service/tests/mocks/llama_parse.py index 4ca9bd28a..78467fbfd 100644 --- a/integrations-service/tests/mocks/llama_parse.py +++ b/integrations-service/tests/mocks/llama_parse.py @@ -1,6 +1,5 @@ """Mock implementation of llama parse client""" -from typing import List, Dict from llama_index.core.schema import Document @@ -11,7 +10,7 @@ def __init__(self, api_key: str, result_type: str, num_workers: int, language: s self.num_workers = num_workers self.language = language - async def aload_data(self, file_content: bytes, extra_info: dict) -> List[Dict]: + async def aload_data(self, file_content: bytes, extra_info: dict) -> list[dict]: """Mock loading data that returns fixed documents""" return [ Document(page_content="Mock document content 1", metadata=extra_info), @@ -21,5 +20,3 @@ async def aload_data(self, file_content: bytes, extra_info: dict) -> List[Dict]: class MockLlamaParseException(Exception): """Mock exception for llama parse errors""" - - pass diff --git a/integrations-service/tests/mocks/spider.py b/integrations-service/tests/mocks/spider.py index dc6f01c41..9963c7af0 100644 --- a/integrations-service/tests/mocks/spider.py +++ b/integrations-service/tests/mocks/spider.py @@ -1,6 +1,5 @@ """Mock implementation of web spider client""" -from typing import List from langchain_core.documents import Document from pydantic import AnyUrl @@ -9,19 +8,13 @@ class MockSpiderClient: def __init__(self, api_key: str): self.api_key = api_key - def crawl(self, url: AnyUrl, mode: str = "scrape") -> List[Document]: + def crawl(self, url: AnyUrl, mode: str = 
"scrape") -> list[Document]: """Mock crawl that returns fixed documents""" return [ - Document( - page_content="Mock crawled content 1", metadata={"source": str(url)} - ), - Document( - page_content="Mock crawled content 2", metadata={"source": str(url)} - ), + Document(page_content="Mock crawled content 1", metadata={"source": str(url)}), + Document(page_content="Mock crawled content 2", metadata={"source": str(url)}), ] class MockSpiderException(Exception): """Mock exception for spider errors""" - - pass diff --git a/integrations-service/tests/mocks/weather.py b/integrations-service/tests/mocks/weather.py index 4fa4c357d..6ef8a2666 100644 --- a/integrations-service/tests/mocks/weather.py +++ b/integrations-service/tests/mocks/weather.py @@ -12,5 +12,3 @@ def get_weather(self, location: str) -> str: class MockWeatherException(Exception): """Mock exception for weather API errors""" - - pass diff --git a/integrations-service/tests/mocks/wikipedia.py b/integrations-service/tests/mocks/wikipedia.py index 19b11d140..40d52b7b2 100644 --- a/integrations-service/tests/mocks/wikipedia.py +++ b/integrations-service/tests/mocks/wikipedia.py @@ -1,6 +1,5 @@ """Mock implementation of Wikipedia API client""" -from typing import List from langchain_core.documents import Document @@ -15,11 +14,9 @@ def __init__(self, query: str, load_max_docs: int = 2): for _ in range(load_max_docs) ] - def load(self, *args, **kwargs) -> List[Document]: + def load(self, *args, **kwargs) -> list[Document]: return self.result class MockWikipediaException(Exception): """Mock exception for Wikipedia API errors""" - - pass diff --git a/integrations-service/tests/test_provider_execution.py b/integrations-service/tests/test_provider_execution.py index 9b96ee51b..21c43f24e 100644 --- a/integrations-service/tests/test_provider_execution.py +++ b/integrations-service/tests/test_provider_execution.py @@ -1,7 +1,6 @@ """Tests for provider execution using mocks""" import pytest - from 
integrations.autogen.Tools import ( WikipediaSearchArguments, ) @@ -20,7 +19,7 @@ async def test_weather_get_mock(wikipedia_provider): ) assert len(result.documents) > 0 - assert any([(query in doc.page_content) for doc in result.documents]) + assert any((query in doc.page_content) for doc in result.documents) # @pytest.mark.asyncio diff --git a/integrations-service/tests/test_providers.py b/integrations-service/tests/test_providers.py index 181248944..c79d3ff3d 100644 --- a/integrations-service/tests/test_providers.py +++ b/integrations-service/tests/test_providers.py @@ -4,18 +4,16 @@ def test_available_providers(providers): """Test that the available providers dictionary is properly structured""" assert isinstance(providers, dict) - assert all(isinstance(key, str) for key in providers.keys()) + assert all(isinstance(key, str) for key in providers) assert all(isinstance(value, BaseProvider) for value in providers.values()) def test_provider_structure(providers): """Test that each provider has the required attributes""" - for provider_name, provider in providers.items(): + for provider in providers.values(): assert isinstance(provider.provider, str) assert isinstance(provider.methods, list) - assert all( - isinstance(method, BaseProviderMethod) for method in provider.methods - ) + assert all(isinstance(method, BaseProviderMethod) for method in provider.methods) assert isinstance(provider.info, ProviderInfo) diff --git a/memory-store/.gitignore b/memory-store/.gitignore index 9383f36da..c2563b460 100644 --- a/memory-store/.gitignore +++ b/memory-store/.gitignore @@ -1,4 +1,3 @@ -cozo.db/ tmp/ *.pyc \ No newline at end of file diff --git a/memory-store/Dockerfile b/memory-store/Dockerfile deleted file mode 100644 index fa384cb12..000000000 --- a/memory-store/Dockerfile +++ /dev/null @@ -1,64 +0,0 @@ -# syntax=docker/dockerfile:1 -# check=error=true -# We need to build the cozo binary first from the repo -# https://github.com/cozodb/cozo -# Then copy the binary to 
the ./bin directory -# Then copy the run.sh script to the ./run.sh file - -# First stage: Build the Rust project -FROM rust:1.83-bookworm AS builder - -# Install required dependencies -RUN apt-get update && apt-get install -y \ - liburing-dev \ - libclang-dev \ - clang - -# Build cozo-ce-bin from crates.io -WORKDIR /usr/src -# RUN cargo install cozo-ce-bin@0.7.13-alpha.3 --features "requests graph-algo storage-new-rocksdb storage-sqlite jemalloc io-uring malloc-usable-size" -RUN cargo install --git https://github.com/cozo-community/cozo.git --branch f/publish-crate --rev 592f49b --profile release -F graph-algo -F jemalloc -F io-uring -F storage-new-rocksdb -F malloc-usable-size --target x86_64-unknown-linux-gnu cozo-ce-bin - -# Copy the built binary to /usr/local/bin -RUN cp /usr/local/cargo/bin/cozo-ce-bin /usr/local/bin/cozo - -# ------------------------------------------------------------------------------------------------- - -# Second stage: Create the final image -FROM debian:bookworm-slim - -# Install dependencies -RUN \ - apt-get update -yqq && \ - apt-get install -y \ - ca-certificates tini nfs-common nfs-kernel-server procps netbase \ - liburing-dev curl && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* - -# Set fallback mount directory -ENV COZO_MNT_DIR=/data COZO_BACKUP_DIR=/backup APP_HOME=/app COZO_PORT=9070 -WORKDIR $APP_HOME - -# Copy the cozo binary -COPY --from=builder /usr/local/bin/cozo $APP_HOME/bin/cozo - -# Copy local code to the container image. 
-COPY ./run.sh ./run.sh -COPY ./backup.sh ./backup.sh - -# Ensure the script is executable -RUN \ - mkdir -p $COZO_MNT_DIR $COZO_BACKUP_DIR && \ - chmod +x $APP_HOME/bin/cozo && \ - chmod +x $APP_HOME/run.sh - -# Copy the options file into the image -COPY ./options ./options - -# Use tini to manage zombie processes and signal forwarding -# https://github.com/krallin/tini -ENTRYPOINT ["/usr/bin/tini", "--"] - -# Pass the startup script as arguments to tini -CMD ["/app/run.sh"] diff --git a/memory-store/README.md b/memory-store/README.md index a58ba79d1..3441d47a4 100644 --- a/memory-store/README.md +++ b/memory-store/README.md @@ -1,28 +1,7 @@ -Cozo Server +### prototyping flow: -The `memory-store` directory within the julep repository serves as a critical component for managing data persistence and availability. It encompasses functionalities for data backup, service deployment, and containerization, ensuring that the julep project's data management is efficient and scalable. - -## Backup Script - -The `backup.py` script within the `backup` subdirectory is designed to periodically back up data while also cleaning up old backups based on a specified retention period. This ensures that the system maintains only the necessary backups, optimizing storage use. For more details, see the `backup.py` file. - -## Dockerfile - -The Dockerfile is instrumental in creating a Docker image for the memory-store service. It outlines the steps for installing necessary dependencies and setting up the environment to run the service. This includes the installation of software packages and configuration of environment variables. For specifics, refer to the Dockerfile. - -## Docker Compose - -The `docker-compose.yml` file is used to define and run multi-container Docker applications, specifically tailored for the memory-store service. It specifies the service configurations, including environment variables, volumes, and ports, facilitating an organized deployment. 
For more details, see the `docker-compose.yml` file. - -## Deployment Script - -The `deploy.sh` script is aimed at deploying the memory-store service to a cloud provider, utilizing specific configurations to ensure seamless integration and operation. This script includes commands for setting environment variables and deploying the service. For specifics, refer to the `deploy.sh` script. - -## Usage - -To utilize the components of the memory-store directory, follow these general instructions: - -- To build and run the Docker containers, use the Docker and Docker Compose commands as specified in the `docker-compose.yml` file. -- To execute the backup script, run `python backup.py` with the appropriate arguments as detailed in the `backup.py` file. - -This README provides a comprehensive guide to understanding and using the memory-store components within the julep project. +1. Install `pgmigrate` (until I move to golang-migrate) +2. In a separate window, `docker compose up db vectorizer-worker` to start db instances +3. `cd memory-store` and `pgmigrate migrate --database "postgres://postgres:postgres@0.0.0.0:5432/postgres" --migrations ./migrations` to apply the migrations +4. `pip install --user -U pgcli` +5. 
`pgcli "postgres://postgres:postgres@localhost:5432/postgres"` diff --git a/memory-store/backup.sh b/memory-store/backup.sh deleted file mode 100644 index 0a4fff0dd..000000000 --- a/memory-store/backup.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env bash - -set -eo pipefail # Exit on error -set -u # Exit on undefined variable - -# Ensure environment variables are set -if [ -z "$COZO_AUTH_TOKEN" ]; then - echo "COZO_AUTH_TOKEN is not set" - exit 1 -fi - -COZO_PORT=${COZO_PORT:-9070} -COZO_BACKUP_DIR=${COZO_BACKUP_DIR:-/backup} -TIMESTAMP=$(date +%Y-%m-%d_%H-%M-%S) -MAX_BACKUPS=${MAX_BACKUPS:-10} - -curl -X POST \ - http://0.0.0.0:$COZO_PORT/backup \ - -H 'Content-Type: application/json' \ - -H "X-Cozo-Auth: ${COZO_AUTH_TOKEN}" \ - -d "{\"path\": \"${COZO_BACKUP_DIR}/cozo-backup-${TIMESTAMP}.bk\"}" \ - -w "\nStatus: %{http_code}\nResponse:\n" \ - -o /dev/stdout - -# Print the number of backups -echo "Number of backups: $(ls -l ${COZO_BACKUP_DIR} | grep -c "cozo-backup-")" - -# If the backup is successful, remove the oldest backup if the number of backups exceeds MAX_BACKUPS -if [ $(ls -l ${COZO_BACKUP_DIR} | grep -c "cozo-backup-") -gt $MAX_BACKUPS ]; then - oldest_backup=$(ls -t ${COZO_BACKUP_DIR}/cozo-backup-*.bk | tail -n 1) - - if [ -n "$oldest_backup" ]; then - rm "$oldest_backup" - echo "Removed oldest backup: $oldest_backup" - else - echo "No backups found to remove" - fi -fi \ No newline at end of file diff --git a/memory-store/docker-compose.yml b/memory-store/docker-compose.yml index f00d003de..4371c30d5 100644 --- a/memory-store/docker-compose.yml +++ b/memory-store/docker-compose.yml @@ -1,46 +1,56 @@ -name: julep-memory-store - +name: pgai services: memory-store: - image: julepai/memory-store:${TAG:-dev} + image: timescale/timescaledb-ha:pg17 + + # For timescaledb specific options, + # See: https://github.com/timescale/timescaledb-docker?tab=readme-ov-file#notes-on-timescaledb-tune environment: - - COZO_AUTH_TOKEN=${COZO_AUTH_TOKEN} - - 
COZO_PORT=${COZO_PORT:-9070} - - COZO_MNT_DIR=${MNT_DIR:-/data} - - COZO_BACKUP_DIR=${COZO_BACKUP_DIR:-/backup} - volumes: - - cozo_data:/data - - cozo_backup:/backup - build: - context: . + - POSTGRES_PASSWORD=${MEMORY_STORE_PASSWORD:-postgres} + - VOYAGE_API_KEY=${VOYAGE_API_KEY} ports: - - "9070:9070" + - "5432:5432" + volumes: + - memory_store_data:/home/postgres/pgdata/data - develop: - watch: - - action: sync+restart - path: ./options - target: /data/cozo.db/OPTIONS-000007 - - action: rebuild - path: Dockerfile + # TODO: Fix this to install pgaudit + # entrypoint: [] + # command: >- + # sed -r -i "s/[#]*\s*(shared_preload_libraries)\s*=\s*'(.*)'/\1 = 'pgaudit,\2'/;s/,'/'/" /home/postgres/pgdata/data/postgresql.conf + # && exec /docker-entrypoint.sh - labels: - ofelia.enabled: "true" - ofelia.job-exec.backupcron.schedule: "@every 3h" - ofelia.job-exec.backupcron.environment: '["COZO_PORT=${COZO_PORT}", "COZO_AUTH_TOKEN=${COZO_AUTH_TOKEN}", "COZO_BACKUP_DIR=${COZO_BACKUP_DIR}"]' - ofelia.job-exec.backupcron.command: bash /app/backup.sh + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres || exit 1"] + interval: 10s + timeout: 5s + retries: 5 - memory-store-backup-cron: - image: mcuadros/ofelia:latest - restart: unless-stopped + vectorizer-worker: + image: timescale/pgai-vectorizer-worker:v0.3.0 + environment: + - PGAI_VECTORIZER_WORKER_DB_URL=postgres://postgres:${MEMORY_STORE_PASSWORD:-postgres}@memory-store:5432/postgres + - VOYAGE_API_KEY=${VOYAGE_API_KEY} + command: [ "--poll-interval", "5s" ] depends_on: - - memory-store - command: daemon --docker -f label=com.docker.compose.project=${COMPOSE_PROJECT_NAME} + memory-store: + condition: service_healthy + + migration: + image: migrate/migrate:latest volumes: - - /var/run/docker.sock:/var/run/docker.sock:ro + - ./migrations:/migrations + command: [ "-path", "/migrations", "-database", "postgres://postgres:${MEMORY_STORE_PASSWORD:-postgres}@memory-store:5432/postgres?sslmode=disable" , "up"] + + 
restart: "no" + develop: + watch: + - path: ./migrations + target: ./migrations + action: sync+restart + depends_on: + memory-store: + condition: service_healthy volumes: - cozo_data: - external: true - cozo_backup: + memory_store_data: external: true diff --git a/memory-store/migrations/000001_initial.down.sql b/memory-store/migrations/000001_initial.down.sql new file mode 100644 index 000000000..6f5aa4b5c --- /dev/null +++ b/memory-store/migrations/000001_initial.down.sql @@ -0,0 +1,27 @@ +BEGIN; + +-- Drop the update_updated_at_column function +DROP FUNCTION IF EXISTS update_updated_at_column (); + +-- Drop misc extensions +DROP EXTENSION IF EXISTS "uuid-ossp" CASCADE; + +DROP EXTENSION IF EXISTS citext CASCADE; + +DROP EXTENSION IF EXISTS btree_gist CASCADE; + +DROP EXTENSION IF EXISTS btree_gin CASCADE; + +-- Drop timescale's pgai extensions +DROP EXTENSION IF EXISTS ai CASCADE; + +DROP EXTENSION IF EXISTS vectorscale CASCADE; + +DROP EXTENSION IF EXISTS vector CASCADE; + +-- Drop timescaledb extensions +DROP EXTENSION IF EXISTS timescaledb_toolkit CASCADE; + +DROP EXTENSION IF EXISTS timescaledb CASCADE; + +COMMIT; \ No newline at end of file diff --git a/memory-store/migrations/000001_initial.up.sql b/memory-store/migrations/000001_initial.up.sql new file mode 100644 index 000000000..6eba5ab6c --- /dev/null +++ b/memory-store/migrations/000001_initial.up.sql @@ -0,0 +1,35 @@ +BEGIN; + +-- init timescaledb +CREATE EXTENSION IF NOT EXISTS timescaledb CASCADE; + +CREATE EXTENSION IF NOT EXISTS timescaledb_toolkit CASCADE; + +-- add timescale's pgai extension +CREATE EXTENSION IF NOT EXISTS vector CASCADE; + +CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE; + +CREATE EXTENSION IF NOT EXISTS ai CASCADE; + +-- add misc extensions (for indexing etc) +CREATE EXTENSION IF NOT EXISTS btree_gin CASCADE; + +CREATE EXTENSION IF NOT EXISTS btree_gist CASCADE; + +CREATE EXTENSION IF NOT EXISTS citext CASCADE; + +CREATE EXTENSION IF NOT EXISTS "uuid-ossp" CASCADE; + +-- 
Create function to update the updated_at timestamp +CREATE +OR REPLACE FUNCTION update_updated_at_column () RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ language 'plpgsql'; + +COMMENT ON FUNCTION update_updated_at_column () IS 'Trigger function to automatically update updated_at timestamp'; + +COMMIT; \ No newline at end of file diff --git a/memory-store/migrations/000002_developers.down.sql b/memory-store/migrations/000002_developers.down.sql new file mode 100644 index 000000000..ea6c58509 --- /dev/null +++ b/memory-store/migrations/000002_developers.down.sql @@ -0,0 +1,4 @@ +-- Drop the table (this will automatically drop associated indexes and triggers) +DROP TABLE IF EXISTS developers CASCADE; + +-- Note: The update_updated_at_column() function is not dropped as it might be used by other tables diff --git a/memory-store/migrations/000002_developers.up.sql b/memory-store/migrations/000002_developers.up.sql new file mode 100644 index 000000000..3ee83e777 --- /dev/null +++ b/memory-store/migrations/000002_developers.up.sql @@ -0,0 +1,52 @@ +BEGIN; + +-- Create developers table +CREATE TABLE IF NOT EXISTS developers ( + developer_id UUID NOT NULL, + email TEXT NOT NULL CONSTRAINT ct_developers_email_format CHECK ( + email ~* '^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$' + ), + active BOOLEAN NOT NULL DEFAULT TRUE, + tags TEXT[] DEFAULT ARRAY[]::TEXT[], + settings JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT pk_developers PRIMARY KEY (developer_id), + CONSTRAINT uq_developers_email UNIQUE (email), + CONSTRAINT ct_settings_is_object CHECK (jsonb_typeof(settings) = 'object') +); + +-- Create sorted index on developer_id (optimized for UUID v7) +CREATE INDEX IF NOT EXISTS idx_developers_id_sorted ON developers (developer_id DESC) INCLUDE ( + email, + active, + tags, + settings, + created_at, + updated_at +) +WHERE 
+ active = TRUE; + +-- Create index on email +CREATE INDEX IF NOT EXISTS idx_developers_email ON developers (email); + +-- Create GIN index for tags array +CREATE INDEX IF NOT EXISTS idx_developers_tags ON developers USING GIN (tags); + +-- Create trigger to automatically update updated_at +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'trg_developers_updated_at') THEN + CREATE TRIGGER trg_developers_updated_at + BEFORE UPDATE ON developers + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + END IF; +END +$$; + +-- Add comment to table +COMMENT ON TABLE developers IS 'Stores developer information including their settings and tags'; + +COMMIT; diff --git a/memory-store/migrations/000003_users.down.sql b/memory-store/migrations/000003_users.down.sql new file mode 100644 index 000000000..6bae2529e --- /dev/null +++ b/memory-store/migrations/000003_users.down.sql @@ -0,0 +1,16 @@ +BEGIN; + +-- Drop trigger first +DROP TRIGGER IF EXISTS update_users_updated_at ON users; + +-- Drop indexes +DROP INDEX IF EXISTS users_metadata_gin_idx; + +-- Drop foreign key constraint +ALTER TABLE IF EXISTS users +DROP CONSTRAINT IF EXISTS users_developer_id_fkey; + +-- Finally drop the table +DROP TABLE IF EXISTS users; + +COMMIT; diff --git a/memory-store/migrations/000003_users.up.sql b/memory-store/migrations/000003_users.up.sql new file mode 100644 index 000000000..480d39b6c --- /dev/null +++ b/memory-store/migrations/000003_users.up.sql @@ -0,0 +1,45 @@ +BEGIN; + +-- Create users table if it doesn't exist +CREATE TABLE IF NOT EXISTS users ( + developer_id UUID NOT NULL, + user_id UUID NOT NULL, + name TEXT NOT NULL, + about TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + metadata JSONB NOT NULL DEFAULT '{}'::JSONB, + CONSTRAINT pk_users PRIMARY KEY (developer_id, user_id) +); + +-- Create foreign key constraint and index if they don't exist +DO $$ BEGIN + IF NOT 
EXISTS ( + SELECT 1 FROM pg_constraint WHERE conname = 'users_developer_id_fkey' + ) THEN + ALTER TABLE users + ADD CONSTRAINT users_developer_id_fkey + FOREIGN KEY (developer_id) + REFERENCES developers(developer_id); + END IF; +END $$; + +-- Create a GIN index on the entire metadata column if it doesn't exist +CREATE INDEX IF NOT EXISTS users_metadata_gin_idx ON users USING GIN (metadata); + +-- Create trigger if it doesn't exist +DO $$ BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_trigger WHERE tgname = 'update_users_updated_at' + ) THEN + CREATE TRIGGER update_users_updated_at + BEFORE UPDATE ON users + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + END IF; +END $$; + +-- Add comment to table (comments are idempotent by default) +COMMENT ON TABLE users IS 'Stores user information linked to developers'; + +COMMIT; diff --git a/memory-store/migrations/000004_agents.down.sql b/memory-store/migrations/000004_agents.down.sql new file mode 100644 index 000000000..98d75058d --- /dev/null +++ b/memory-store/migrations/000004_agents.down.sql @@ -0,0 +1,12 @@ +BEGIN; + +-- Drop trigger first +DROP TRIGGER IF EXISTS trg_agents_updated_at ON agents; + +-- Drop indexes +DROP INDEX IF EXISTS idx_agents_metadata; + +-- Drop table (this will automatically drop associated constraints) +DROP TABLE IF EXISTS agents CASCADE; + +COMMIT; diff --git a/memory-store/migrations/000004_agents.up.sql b/memory-store/migrations/000004_agents.up.sql new file mode 100644 index 000000000..23de2b68d --- /dev/null +++ b/memory-store/migrations/000004_agents.up.sql @@ -0,0 +1,49 @@ +BEGIN; + +-- Create agents table +CREATE TABLE IF NOT EXISTS agents ( + developer_id UUID NOT NULL, + agent_id UUID NOT NULL, + canonical_name citext NOT NULL CONSTRAINT ct_agents_canonical_name_length CHECK ( + length(canonical_name) >= 1 + AND length(canonical_name) <= 255 + ), + name TEXT NOT NULL CONSTRAINT ct_agents_name_length CHECK ( + length(name) >= 1 + AND length(name) <= 255 + ), + about TEXT 
CONSTRAINT ct_agents_about_length CHECK ( + about IS NULL + OR length(about) <= 1000 + ), + instructions TEXT[] DEFAULT ARRAY[]::TEXT[], + model TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + metadata JSONB NOT NULL DEFAULT '{}'::JSONB, + default_settings JSONB NOT NULL DEFAULT '{}'::JSONB, + CONSTRAINT pk_agents PRIMARY KEY (developer_id, agent_id), + CONSTRAINT uq_agents_canonical_name_unique UNIQUE (developer_id, canonical_name), -- per developer + CONSTRAINT ct_agents_canonical_name_valid_identifier CHECK (canonical_name ~ '^[a-zA-Z][a-zA-Z0-9_]*$'), + CONSTRAINT ct_agents_metadata_is_object CHECK (jsonb_typeof(metadata) = 'object'), + CONSTRAINT ct_agents_default_settings_is_object CHECK (jsonb_typeof(default_settings) = 'object') +); + +-- Create foreign key constraint and index on developer_id +ALTER TABLE agents +DROP CONSTRAINT IF EXISTS fk_agents_developer, +ADD CONSTRAINT fk_agents_developer FOREIGN KEY (developer_id) REFERENCES developers (developer_id); + +-- Create a GIN index on the entire metadata column +CREATE INDEX IF NOT EXISTS idx_agents_metadata ON agents USING GIN (metadata); + +-- Create trigger to automatically update updated_at +CREATE +OR REPLACE TRIGGER trg_agents_updated_at BEFORE +UPDATE ON agents FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column (); + +-- Add comment to table +COMMENT ON TABLE agents IS 'Stores AI agent configurations and metadata for developers'; + +COMMIT; \ No newline at end of file diff --git a/memory-store/migrations/000005_files.down.sql b/memory-store/migrations/000005_files.down.sql new file mode 100644 index 000000000..c582f7b67 --- /dev/null +++ b/memory-store/migrations/000005_files.down.sql @@ -0,0 +1,12 @@ +BEGIN; + +-- Drop file_owners table and its dependencies +DROP TRIGGER IF EXISTS trg_validate_file_owner ON file_owners; +DROP FUNCTION IF EXISTS validate_file_owner(); +DROP TABLE IF EXISTS file_owners; + 
+-- Drop files table and its dependencies +DROP TRIGGER IF EXISTS trg_files_updated_at ON files; +DROP TABLE IF EXISTS files; + +COMMIT; diff --git a/memory-store/migrations/000005_files.up.sql b/memory-store/migrations/000005_files.up.sql new file mode 100644 index 000000000..e408e1db2 --- /dev/null +++ b/memory-store/migrations/000005_files.up.sql @@ -0,0 +1,89 @@ +BEGIN; + +-- Create files table +CREATE TABLE IF NOT EXISTS files ( + developer_id UUID NOT NULL, + file_id UUID NOT NULL, + name TEXT NOT NULL CONSTRAINT ct_files_name_length CHECK ( + length(name) >= 1 + AND length(name) <= 255 + ), + description TEXT DEFAULT NULL CONSTRAINT ct_files_description_length CHECK ( + description IS NULL + OR length(description) <= 1000 + ), + mime_type TEXT DEFAULT NULL CONSTRAINT ct_files_mime_type_length CHECK ( + mime_type IS NULL + OR length(mime_type) <= 127 + ), + size BIGINT NOT NULL CONSTRAINT ct_files_size_positive CHECK (size > 0), + hash BYTEA NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT pk_files PRIMARY KEY (developer_id, file_id) +); + +-- Create foreign key constraint and index if they don't exist +DO $$ BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'fk_files_developer') THEN + ALTER TABLE files + ADD CONSTRAINT fk_files_developer + FOREIGN KEY (developer_id) + REFERENCES developers(developer_id); + END IF; +END $$; + +-- Create trigger if it doesn't exist +DO $$ BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'trg_files_updated_at') THEN + CREATE TRIGGER trg_files_updated_at + BEFORE UPDATE ON files + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + END IF; +END $$; + +-- Create the file_owners table +CREATE TABLE IF NOT EXISTS file_owners ( + developer_id UUID NOT NULL, + file_id UUID NOT NULL, + owner_type TEXT NOT NULL, -- 'user' or 'agent' + owner_id UUID NOT NULL, + CONSTRAINT pk_file_owners PRIMARY KEY 
(developer_id, file_id), + CONSTRAINT fk_file_owners_file FOREIGN KEY (developer_id, file_id) REFERENCES files (developer_id, file_id), + CONSTRAINT ct_file_owners_owner_type CHECK (owner_type IN ('user', 'agent')) +); + +-- Create indexes +CREATE INDEX IF NOT EXISTS idx_file_owners_owner ON file_owners (developer_id, owner_type, owner_id); + +-- Trigger function: reject a file_owners row whose owner_id has no matching users/agents row (picked by owner_type) for the same developer_id +CREATE OR REPLACE FUNCTION validate_file_owner() +RETURNS TRIGGER AS $$ +BEGIN + IF NEW.owner_type = 'user' THEN + IF NOT EXISTS ( + SELECT 1 FROM users + WHERE developer_id = NEW.developer_id AND user_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid user reference'; + END IF; + ELSIF NEW.owner_type = 'agent' THEN + IF NOT EXISTS ( + SELECT 1 FROM agents + WHERE developer_id = NEW.developer_id AND agent_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid agent reference'; + END IF; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Create trigger for validation (OR REPLACE so re-running this migration does not fail on an existing trigger) +CREATE OR REPLACE TRIGGER trg_validate_file_owner +BEFORE INSERT OR UPDATE ON file_owners +FOR EACH ROW +EXECUTE FUNCTION validate_file_owner(); + +COMMIT; diff --git a/memory-store/migrations/000006_docs.down.sql b/memory-store/migrations/000006_docs.down.sql new file mode 100644 index 000000000..f0df5a8e4 --- /dev/null +++ b/memory-store/migrations/000006_docs.down.sql @@ -0,0 +1,26 @@ +BEGIN; + +-- Drop doc_owners table and its dependencies +DROP TRIGGER IF EXISTS trg_validate_doc_owner ON doc_owners; +DROP FUNCTION IF EXISTS validate_doc_owner(); +DROP INDEX IF EXISTS idx_doc_owners_owner; +DROP TABLE IF EXISTS doc_owners CASCADE; + +-- Drop docs table and its dependencies +DROP TRIGGER IF EXISTS trg_docs_search_tsv ON docs; +DROP TRIGGER IF EXISTS trg_docs_updated_at ON docs; +DROP FUNCTION IF EXISTS docs_update_search_tsv(); + +-- Drop indexes +DROP INDEX IF EXISTS idx_docs_content_trgm; +DROP INDEX IF EXISTS idx_docs_title_trgm; +DROP INDEX IF EXISTS idx_docs_search_tsv; +DROP INDEX IF EXISTS idx_docs_metadata; 
+ +-- Drop docs table +DROP TABLE IF EXISTS docs CASCADE; + +-- Drop language validation function +DROP FUNCTION IF EXISTS is_valid_language(text); + +COMMIT; diff --git a/memory-store/migrations/000006_docs.up.sql b/memory-store/migrations/000006_docs.up.sql new file mode 100644 index 000000000..3ec2eab07 --- /dev/null +++ b/memory-store/migrations/000006_docs.up.sql @@ -0,0 +1,206 @@ +BEGIN; + +-- Create function to validate language (make it OR REPLACE) +CREATE +OR REPLACE FUNCTION is_valid_language (lang text) RETURNS boolean AS $$ +BEGIN + RETURN EXISTS ( + SELECT 1 FROM pg_ts_config WHERE cfgname::text = lang + ); +END; +$$ LANGUAGE plpgsql; + +-- Create docs table +CREATE TABLE IF NOT EXISTS docs ( + developer_id UUID NOT NULL, + doc_id UUID NOT NULL, + title TEXT NOT NULL, + content TEXT NOT NULL, + index INTEGER NOT NULL, + modality TEXT NOT NULL, + embedding_model TEXT NOT NULL, + embedding_dimensions INTEGER NOT NULL, + language TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + metadata JSONB NOT NULL DEFAULT '{}'::JSONB, + CONSTRAINT pk_docs PRIMARY KEY (developer_id, doc_id, index), + CONSTRAINT ct_docs_embedding_dimensions_positive CHECK (embedding_dimensions > 0), + CONSTRAINT ct_docs_valid_modality CHECK (modality IN ('text', 'image', 'mixed')), + CONSTRAINT ct_docs_index_positive CHECK (index >= 0), + CONSTRAINT ct_docs_valid_language CHECK (is_valid_language (language)), + CONSTRAINT ct_metadata_is_object CHECK (jsonb_typeof(metadata) = 'object') +); + +-- Create foreign key constraint if not exists (using DO block for safety) +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_constraint WHERE conname = 'fk_docs_developer' + ) THEN + ALTER TABLE docs + ADD CONSTRAINT fk_docs_developer + FOREIGN KEY (developer_id) + REFERENCES developers(developer_id); + END IF; +END $$; + +-- Create trigger if not exists +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM 
pg_trigger WHERE tgname = 'trg_docs_updated_at' + ) THEN + CREATE TRIGGER trg_docs_updated_at + BEFORE UPDATE ON docs + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + END IF; +END $$; + +-- Create the doc_owners table +CREATE TABLE IF NOT EXISTS doc_owners ( + developer_id UUID NOT NULL, + doc_id UUID NOT NULL, + owner_type TEXT NOT NULL, -- 'user' or 'agent' + owner_id UUID NOT NULL, + CONSTRAINT pk_doc_owners PRIMARY KEY (developer_id, doc_id), + -- TODO: Ensure that doc exists. The FK below cannot be declared as written: pk_docs is (developer_id, doc_id, index), so (developer_id, doc_id) alone is not a unique key that a foreign key can reference. + -- CONSTRAINT fk_doc_owners_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id), + CONSTRAINT ct_doc_owners_owner_type CHECK (owner_type IN ('user', 'agent')) +); + +-- Create indexes +CREATE INDEX IF NOT EXISTS idx_doc_owners_owner ON doc_owners (developer_id, owner_type, owner_id); + + +-- Trigger function: reject a doc_owners row whose owner_id has no matching users/agents row (picked by owner_type) for the same developer_id +CREATE +OR REPLACE FUNCTION validate_doc_owner () RETURNS TRIGGER AS $$ +BEGIN + IF NEW.owner_type = 'user' THEN + IF NOT EXISTS ( + SELECT 1 FROM users + WHERE developer_id = NEW.developer_id AND user_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid user reference'; + END IF; + ELSIF NEW.owner_type = 'agent' THEN + IF NOT EXISTS ( + SELECT 1 FROM agents + WHERE developer_id = NEW.developer_id AND agent_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid agent reference'; + END IF; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Create trigger for validation (OR REPLACE so re-running this migration does not fail on an existing trigger) +CREATE OR REPLACE TRIGGER trg_validate_doc_owner BEFORE INSERT +OR +UPDATE ON doc_owners FOR EACH ROW +EXECUTE FUNCTION validate_doc_owner (); + +-- Create indexes if not exists +CREATE INDEX IF NOT EXISTS idx_docs_metadata ON docs USING GIN (metadata); + +-- Enable necessary PostgreSQL extensions +CREATE EXTENSION IF NOT EXISTS unaccent; + +CREATE EXTENSION IF NOT EXISTS pg_trgm; + +CREATE EXTENSION IF NOT EXISTS dict_int CASCADE; + +CREATE EXTENSION IF NOT EXISTS dict_xsyn CASCADE; + +CREATE EXTENSION IF NOT EXISTS 
fuzzystrmatch CASCADE; + +-- Configure text search for all supported languages +DO $$ +DECLARE + lang text; +BEGIN + FOR lang IN (SELECT cfgname FROM pg_ts_config WHERE cfgname IN ( + 'arabic', 'danish', 'dutch', 'english', 'finnish', 'french', + 'german', 'greek', 'hungarian', 'indonesian', 'irish', 'italian', + 'lithuanian', 'nepali', 'norwegian', 'portuguese', 'romanian', + 'russian', 'spanish', 'swedish', 'tamil', 'turkish' + )) + LOOP + -- Configure integer dictionary + EXECUTE format('ALTER TEXT SEARCH CONFIGURATION %I + ALTER MAPPING FOR int, uint WITH intdict', lang); + + -- Configure synonym and stemming + EXECUTE format('ALTER TEXT SEARCH CONFIGURATION %I + ALTER MAPPING FOR asciihword, hword_asciipart, hword, hword_part, word, asciiword + WITH xsyn, %I_stem', lang, lang); + END LOOP; +END +$$; + +-- Add the search_tsv column if it doesn't exist +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'docs' AND column_name = 'search_tsv' + ) THEN + ALTER TABLE docs ADD COLUMN search_tsv tsvector; + END IF; +END $$; + +-- Create function to update tsvector +CREATE +OR REPLACE FUNCTION docs_update_search_tsv () RETURNS trigger AS $$ +BEGIN + NEW.search_tsv := + setweight(to_tsvector(NEW.language::regconfig, unaccent(coalesce(NEW.title, ''))), 'A') || + setweight(to_tsvector(NEW.language::regconfig, unaccent(coalesce(NEW.content, ''))), 'B'); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Create trigger if not exists +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_trigger WHERE tgname = 'trg_docs_search_tsv' + ) THEN + CREATE TRIGGER trg_docs_search_tsv + BEFORE INSERT OR UPDATE OF title, content, language + ON docs + FOR EACH ROW + EXECUTE FUNCTION docs_update_search_tsv(); + END IF; +END $$; + +-- Create indexes if not exists +CREATE INDEX IF NOT EXISTS idx_docs_search_tsv ON docs USING GIN (search_tsv); + +CREATE INDEX IF NOT EXISTS idx_docs_title_trgm ON docs USING GIN (title gin_trgm_ops); + +CREATE INDEX IF 
NOT EXISTS idx_docs_content_trgm ON docs USING GIN (content gin_trgm_ops); + +-- Update existing rows (if any) +UPDATE docs +SET + search_tsv = setweight( + to_tsvector( + language::regconfig, + unaccent (coalesce(title, '')) + ), + 'A' + ) || setweight( + to_tsvector( + language::regconfig, + unaccent (coalesce(content, '')) + ), + 'B' + ) +WHERE + search_tsv IS NULL; + +COMMIT; diff --git a/memory-store/migrations/000007_ann.down.sql b/memory-store/migrations/000007_ann.down.sql new file mode 100644 index 000000000..2458c3dbd --- /dev/null +++ b/memory-store/migrations/000007_ann.down.sql @@ -0,0 +1,17 @@ +BEGIN; + +DO $$ +DECLARE + vectorizer_id INTEGER; +BEGIN + SELECT id INTO vectorizer_id + FROM ai.vectorizer + WHERE source_table = 'docs'; + + -- Drop the vectorizer if it exists + IF vectorizer_id IS NOT NULL THEN + PERFORM ai.drop_vectorizer(vectorizer_id, drop_all => true); + END IF; +END $$; + +COMMIT; diff --git a/memory-store/migrations/000007_ann.up.sql b/memory-store/migrations/000007_ann.up.sql new file mode 100644 index 000000000..725a78786 --- /dev/null +++ b/memory-store/migrations/000007_ann.up.sql @@ -0,0 +1,48 @@ +/* + * VECTOR SIMILARITY SEARCH WITH DISKANN (Complexity: 8/10) + * Uses TimescaleDB's vectorizer to convert text into high-dimensional vectors for semantic search. + * Implements DiskANN (Disk-based Approximate Nearest Neighbor) for efficient similarity search at scale. + * Includes smart text chunking to handle large documents while preserving context and semantic meaning. 
+ */ + +-- Create vector similarity search index using diskann and timescale vectorizer +SELECT + ai.create_vectorizer ( + source => 'docs', + destination => 'docs_embeddings', + embedding => ai.embedding_voyageai ('voyage-3', 1024, 'document'), -- need to parameterize this + -- actual chunking is managed by the docs table + -- this is to prevent running out of context window + chunking => ai.chunking_recursive_character_text_splitter ( + chunk_column => 'content', + chunk_size => 30000, -- 30k characters ~= 7.5k tokens + chunk_overlap => 600, -- 600 characters ~= 150 tokens + separators => ARRAY[ -- tries separators in order + -- markdown headers + E'\n#', + E'\n##', + E'\n###', + E'\n---', + E'\n***', + -- html tags (NOTE(review): these tag literals were stripped from this copy of the patch and are reconstructed here; confirm against the upstream migration) + E'</article>', -- Split on major document sections + E'</div>', -- Split on div boundaries + E'</section>', + E'</p>', -- Split on paragraphs + E'<br>
', -- Split on line breaks + -- other separators + E'\n\n', -- paragraphs + '. ', + '? ', + '! ', + '; ', -- sentences (note space after punctuation) + E'\n', -- line breaks + ' ' -- words (last resort) + ] + ), + scheduling => ai.scheduling_timescaledb (), + indexing => ai.indexing_diskann (), + formatting => ai.formatting_python_template (E'Title: $title\n\n$chunk'), + processing => ai.processing_default (), + enqueue_existing => TRUE + ); diff --git a/memory-store/migrations/000008_tools.down.sql b/memory-store/migrations/000008_tools.down.sql new file mode 100644 index 000000000..2fa3077c0 --- /dev/null +++ b/memory-store/migrations/000008_tools.down.sql @@ -0,0 +1,6 @@ +BEGIN; + +-- Drop table and all its dependent objects (indexes, constraints, triggers) +DROP TABLE IF EXISTS tools CASCADE; + +COMMIT; diff --git a/memory-store/migrations/000008_tools.up.sql b/memory-store/migrations/000008_tools.up.sql new file mode 100644 index 000000000..ad5db146c --- /dev/null +++ b/memory-store/migrations/000008_tools.up.sql @@ -0,0 +1,56 @@ +BEGIN; + +-- Create tools table if it doesn't exist +CREATE TABLE IF NOT EXISTS tools ( + developer_id UUID NOT NULL, + agent_id UUID NOT NULL, + tool_id UUID NOT NULL, + task_id UUID DEFAULT NULL, + type TEXT NOT NULL CONSTRAINT ct_tools_type_length CHECK ( + length(type) >= 1 + AND length(type) <= 255 + ), + name TEXT NOT NULL CONSTRAINT ct_tools_name_length CHECK ( + length(name) >= 1 + AND length(name) <= 255 + ), + description TEXT CONSTRAINT ct_tools_description_length CHECK ( + description IS NULL + OR length(description) <= 1000 + ), + spec JSONB NOT NULL, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT pk_tools PRIMARY KEY (developer_id, agent_id, tool_id), + CONSTRAINT ct_unique_name_per_agent UNIQUE (agent_id, name, task_id), + CONSTRAINT ct_spec_is_object CHECK (jsonb_typeof(spec) = 'object') +); + +-- Create sorted index on task_id if 
it doesn't exist +CREATE INDEX IF NOT EXISTS idx_tools_task_id_sorted ON tools (task_id DESC) +WHERE + task_id IS NOT NULL; + +-- Create foreign key constraint and index if they don't exist +DO $$ BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_constraint WHERE conname = 'fk_tools_agent' + ) THEN + ALTER TABLE tools + ADD CONSTRAINT fk_tools_agent + FOREIGN KEY (developer_id, agent_id) + REFERENCES agents(developer_id, agent_id) ON DELETE CASCADE; + END IF; +END $$; + +-- Drop trigger if exists and recreate +DROP TRIGGER IF EXISTS trg_tools_updated_at ON tools; + +CREATE TRIGGER trg_tools_updated_at BEFORE +UPDATE ON tools FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column (); + +-- Add comment to table +COMMENT ON TABLE tools IS 'Stores tool configurations and specifications for AI agents'; + +COMMIT; diff --git a/memory-store/migrations/000009_sessions.down.sql b/memory-store/migrations/000009_sessions.down.sql new file mode 100644 index 000000000..33d535e53 --- /dev/null +++ b/memory-store/migrations/000009_sessions.down.sql @@ -0,0 +1,22 @@ +BEGIN; + +-- Drop triggers first +DROP TRIGGER IF EXISTS trg_validate_participant_before_update ON session_lookup; + +DROP TRIGGER IF EXISTS trg_validate_participant_before_insert ON session_lookup; + +-- Drop the validation function +DROP FUNCTION IF EXISTS validate_participant (); + +-- Drop session_lookup table and its indexes +DROP TABLE IF EXISTS session_lookup; + +-- Drop sessions table and its indexes +DROP TRIGGER IF EXISTS trg_sessions_updated_at ON sessions; + +DROP TABLE IF EXISTS sessions CASCADE; + +-- Drop the enum type +DROP TYPE IF EXISTS participant_type; + +COMMIT; \ No newline at end of file diff --git a/memory-store/migrations/000009_sessions.up.sql b/memory-store/migrations/000009_sessions.up.sql new file mode 100644 index 000000000..5c7a8717b --- /dev/null +++ b/memory-store/migrations/000009_sessions.up.sql @@ -0,0 +1,131 @@ +BEGIN; + +-- Create sessions table if it doesn't exist +CREATE TABLE IF NOT 
EXISTS sessions ( + developer_id UUID NOT NULL, + session_id UUID NOT NULL, + situation TEXT, + system_template TEXT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + metadata JSONB NOT NULL DEFAULT '{}'::JSONB, + render_templates BOOLEAN NOT NULL DEFAULT TRUE, + token_budget INTEGER, + context_overflow TEXT, + forward_tool_calls BOOLEAN, + recall_options JSONB NOT NULL DEFAULT '{}'::JSONB, + CONSTRAINT pk_sessions PRIMARY KEY (developer_id, session_id), + CONSTRAINT uq_sessions_session_id UNIQUE (session_id), + CONSTRAINT ct_sessions_token_budget_positive CHECK ( + token_budget IS NULL + OR token_budget > 0 + ), + CONSTRAINT ct_sessions_context_overflow_valid CHECK ( + context_overflow IS NULL + OR context_overflow IN ('truncate', 'adaptive') + ), + CONSTRAINT ct_sessions_system_template_not_empty CHECK (length(trim(system_template)) > 0), + CONSTRAINT ct_sessions_situation_not_empty CHECK ( + situation IS NULL + OR length(trim(situation)) > 0 + ), + CONSTRAINT ct_sessions_metadata_is_object CHECK (jsonb_typeof(metadata) = 'object'), + CONSTRAINT ct_sessions_recall_options_is_object CHECK (jsonb_typeof(recall_options) = 'object') +); + +CREATE INDEX IF NOT EXISTS idx_sessions_metadata ON sessions USING GIN (metadata); + +-- Create foreign key if it doesn't exist +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_constraint WHERE conname = 'fk_sessions_developer' + ) THEN + ALTER TABLE sessions + ADD CONSTRAINT fk_sessions_developer + FOREIGN KEY (developer_id) + REFERENCES developers(developer_id); + END IF; +END $$; + +-- Create trigger if it doesn't exist +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_trigger WHERE tgname = 'trg_sessions_updated_at' + ) THEN + CREATE TRIGGER trg_sessions_updated_at + BEFORE UPDATE ON sessions + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + END IF; +END $$; + +-- Create participant_type enum if it doesn't exist +DO $$ +BEGIN + IF 
NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'participant_type') THEN + CREATE TYPE participant_type AS ENUM ('user', 'agent'); + END IF; +END $$; + +-- Create session_lookup table if it doesn't exist +CREATE TABLE IF NOT EXISTS session_lookup ( + developer_id UUID NOT NULL, + session_id UUID NOT NULL, + participant_type participant_type NOT NULL, + participant_id UUID NOT NULL, + PRIMARY KEY ( + developer_id, + session_id, + participant_type, + participant_id + ), + FOREIGN KEY (developer_id, session_id) REFERENCES sessions (developer_id, session_id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_session_lookup_by_participant ON session_lookup (developer_id, participant_type, participant_id); + +-- Create or replace the validation function +CREATE +OR REPLACE FUNCTION validate_participant () RETURNS trigger AS $$ +BEGIN + IF NEW.participant_type = 'user' THEN + PERFORM 1 FROM users WHERE developer_id = NEW.developer_id AND user_id = NEW.participant_id; + IF NOT FOUND THEN + RAISE EXCEPTION 'Invalid participant_id: % for participant_type user', NEW.participant_id; + END IF; + ELSIF NEW.participant_type = 'agent' THEN + PERFORM 1 FROM agents WHERE developer_id = NEW.developer_id AND agent_id = NEW.participant_id; + IF NOT FOUND THEN + RAISE EXCEPTION 'Invalid participant_id: % for participant_type agent', NEW.participant_id; + END IF; + ELSE + RAISE EXCEPTION 'Unknown participant_type: %', NEW.participant_type; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Create triggers if they don't exist +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_trigger WHERE tgname = 'trg_validate_participant_before_insert' + ) THEN + CREATE TRIGGER trg_validate_participant_before_insert + BEFORE INSERT ON session_lookup + FOR EACH ROW + EXECUTE FUNCTION validate_participant(); + END IF; + IF NOT EXISTS ( + SELECT 1 FROM pg_trigger WHERE tgname = 'trg_validate_participant_before_update' + ) THEN + CREATE TRIGGER trg_validate_participant_before_update + BEFORE 
UPDATE ON session_lookup + FOR EACH ROW + EXECUTE FUNCTION validate_participant(); + END IF; +END $$; + +COMMIT; diff --git a/memory-store/migrations/000010_tasks.down.sql b/memory-store/migrations/000010_tasks.down.sql new file mode 100644 index 000000000..3b9b05b8b --- /dev/null +++ b/memory-store/migrations/000010_tasks.down.sql @@ -0,0 +1,31 @@ +BEGIN; + +-- Drop the tool/task validation trigger and function created by the up migration; otherwise tools keeps a trigger that queries the dropped tasks table +DROP TRIGGER IF EXISTS trg_validate_tool_task ON tools; + +DROP FUNCTION IF EXISTS validate_tool_task (); + +-- Drop the foreign key constraint from tools table if it exists +DO $$ +BEGIN + IF EXISTS ( + SELECT + 1 + FROM + information_schema.table_constraints + WHERE + constraint_name = 'fk_tools_task_id' + ) THEN + ALTER TABLE tools + DROP CONSTRAINT fk_tools_task_id; + + END IF; +END $$; + +-- Drop the workflows table first since it depends on tasks +DROP TABLE IF EXISTS workflows CASCADE; + +-- Drop the tasks table and all its dependent objects (CASCADE will handle indexes, triggers, and constraints) +DROP TABLE IF EXISTS tasks CASCADE; + +COMMIT; \ No newline at end of file diff --git a/memory-store/migrations/000010_tasks.up.sql b/memory-store/migrations/000010_tasks.up.sql new file mode 100644 index 000000000..ce711d079 --- /dev/null +++ b/memory-store/migrations/000010_tasks.up.sql @@ -0,0 +1,127 @@ +BEGIN; + +/* + * TASKS, WORKFLOWS AND TOOL/TASK VALIDATION (Complexity: 6/10) + * Creates the versioned tasks table (keyed by developer_id, task_id, "version") and the workflows table holding each task's steps. + * tools.task_id is checked against tasks by the validate_tool_task trigger rather than by a foreign key. + * No DEFERRABLE constraints are declared in this migration. 
+ */ +-- Create tasks table if it doesn't exist +CREATE TABLE IF NOT EXISTS tasks ( + developer_id UUID NOT NULL, + canonical_name CITEXT NOT NULL CONSTRAINT ct_tasks_canonical_name_length CHECK ( + length(canonical_name) >= 1 + AND length(canonical_name) <= 255 + ), + agent_id UUID NOT NULL, + task_id UUID NOT NULL, + "version" INTEGER NOT NULL DEFAULT 1, + name TEXT NOT NULL CONSTRAINT ct_tasks_name_length CHECK ( + length(name) >= 1 + AND length(name) <= 255 + ), + description TEXT DEFAULT NULL CONSTRAINT ct_tasks_description_length CHECK ( + description IS NULL + OR length(description) <= 1000 + ), + input_schema JSONB NOT NULL, + inherit_tools BOOLEAN DEFAULT FALSE, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + metadata JSONB DEFAULT '{}'::JSONB, + CONSTRAINT pk_tasks PRIMARY KEY (developer_id, task_id, "version"), + CONSTRAINT uq_tasks_canonical_name_unique UNIQUE (developer_id, canonical_name, "version"), + CONSTRAINT fk_tasks_agent FOREIGN KEY (developer_id, agent_id) REFERENCES agents (developer_id, agent_id) ON DELETE CASCADE, + CONSTRAINT ct_tasks_canonical_name_valid_identifier CHECK (canonical_name ~ '^[a-zA-Z][a-zA-Z0-9_]*$'), + CONSTRAINT ct_tasks_metadata_is_object CHECK (jsonb_typeof(metadata) = 'object'), + CONSTRAINT ct_tasks_input_schema_is_object CHECK (jsonb_typeof(input_schema) = 'object'), + CONSTRAINT ct_tasks_version_positive CHECK ("version" > 0) +); + +-- Create sorted index on task_id if it doesn't exist +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = 'idx_tasks_id_sorted') THEN + CREATE INDEX idx_tasks_id_sorted ON tasks (task_id DESC); + END IF; +END $$; + +-- Create index on canonical_name if it doesn't exist +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = 'idx_tasks_canonical_name') THEN + CREATE INDEX idx_tasks_canonical_name ON tasks (developer_id DESC, canonical_name); + END IF; +END $$; + +-- Create a 
GIN index on metadata if it doesn't exist +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = 'idx_tasks_metadata') THEN + CREATE INDEX idx_tasks_metadata ON tasks USING GIN (metadata); + END IF; +END $$; + +-- Trigger function: when a tools row carries a task_id, require a tasks row with the same developer_id and task_id +CREATE OR REPLACE FUNCTION validate_tool_task() +RETURNS TRIGGER AS $$ +BEGIN + IF NEW.task_id IS NOT NULL THEN + IF NOT EXISTS ( + SELECT 1 FROM tasks + WHERE developer_id = NEW.developer_id AND task_id = NEW.task_id + ) THEN + RAISE EXCEPTION 'Invalid task reference'; + END IF; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Create trigger for validation (OR REPLACE so re-running this migration does not fail on an existing trigger) +CREATE OR REPLACE TRIGGER trg_validate_tool_task +BEFORE INSERT OR UPDATE ON tools +FOR EACH ROW +EXECUTE FUNCTION validate_tool_task(); + +-- Create updated_at trigger if it doesn't exist +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_trigger + WHERE tgname = 'trg_tasks_updated_at' + ) THEN + CREATE TRIGGER trg_tasks_updated_at + BEFORE UPDATE ON tasks + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + END IF; +END $$; + +-- Add comment to table (comments are idempotent by default) +COMMENT ON TABLE tasks IS 'Stores tasks associated with AI agents for developers'; + +-- Create 'workflows' table +CREATE TABLE IF NOT EXISTS workflows ( + developer_id UUID NOT NULL, + task_id UUID NOT NULL, + "version" INTEGER NOT NULL, + name TEXT NOT NULL CONSTRAINT ct_workflows_name_length CHECK ( + length(name) >= 1 + AND length(name) <= 255 + ), + step_idx INTEGER NOT NULL CONSTRAINT ct_workflows_step_idx_positive CHECK (step_idx >= 0), + step_type TEXT NOT NULL CONSTRAINT ct_workflows_step_type_length CHECK ( + length(step_type) >= 1 + AND length(step_type) <= 255 + ), + step_definition JSONB NOT NULL CONSTRAINT ct_workflows_step_definition_valid CHECK (jsonb_typeof(step_definition) = 'object'), + CONSTRAINT pk_workflows PRIMARY KEY (developer_id, task_id, "version", name, step_idx), + CONSTRAINT fk_workflows_tasks FOREIGN KEY 
(developer_id, task_id, "version") REFERENCES tasks (developer_id, task_id, "version") ON DELETE CASCADE +); + +-- Add comment to 'workflows' table +COMMENT ON TABLE workflows IS 'Stores normalized workflows for tasks'; + +COMMIT; diff --git a/memory-store/migrations/000011_executions.down.sql b/memory-store/migrations/000011_executions.down.sql new file mode 100644 index 000000000..e6c362d0e --- /dev/null +++ b/memory-store/migrations/000011_executions.down.sql @@ -0,0 +1,5 @@ +BEGIN; + +DROP TABLE IF EXISTS executions CASCADE; + +COMMIT; diff --git a/memory-store/migrations/000011_executions.up.sql b/memory-store/migrations/000011_executions.up.sql new file mode 100644 index 000000000..f2811204e --- /dev/null +++ b/memory-store/migrations/000011_executions.up.sql @@ -0,0 +1,35 @@ +BEGIN; + +-- Create executions table if it doesn't exist +CREATE TABLE IF NOT EXISTS executions ( + developer_id UUID NOT NULL, + task_id UUID NOT NULL, + task_version INTEGER NOT NULL, + execution_id UUID NOT NULL, + input JSONB NOT NULL, + -- NOTE: These will be generated using continuous aggregates from transitions + -- status TEXT DEFAULT 'pending', + -- output JSONB DEFAULT NULL, + -- error TEXT DEFAULT NULL, + -- updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + metadata JSONB NOT NULL DEFAULT '{}'::JSONB, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT pk_executions PRIMARY KEY (execution_id), + CONSTRAINT fk_executions_developer FOREIGN KEY (developer_id) REFERENCES developers (developer_id), + CONSTRAINT fk_executions_task FOREIGN KEY (developer_id, task_id, task_version) REFERENCES tasks (developer_id, task_id, "version"), + CONSTRAINT ct_metadata_is_object CHECK (jsonb_typeof(metadata) = 'object') +); + +-- Create index on developer_id +CREATE INDEX IF NOT EXISTS idx_executions_developer_id ON executions (developer_id); + +-- Create index on task_id +CREATE INDEX IF NOT EXISTS idx_executions_task_id ON executions (task_id, task_version); + 
+-- Create a GIN index on the metadata column +CREATE INDEX IF NOT EXISTS idx_executions_metadata ON executions USING GIN (metadata); + +-- Add comment to table (comments are idempotent by default) +COMMENT ON TABLE executions IS 'Stores executions associated with AI agents for developers'; + +COMMIT; diff --git a/memory-store/migrations/000012_transitions.down.sql b/memory-store/migrations/000012_transitions.down.sql new file mode 100644 index 000000000..e6171b495 --- /dev/null +++ b/memory-store/migrations/000012_transitions.down.sql @@ -0,0 +1,29 @@ +BEGIN; + +-- Drop foreign key constraint if exists +ALTER TABLE IF EXISTS transitions +DROP CONSTRAINT IF EXISTS fk_transitions_execution; + +-- Drop indexes if they exist +DROP INDEX IF EXISTS idx_transitions_metadata; + +DROP INDEX IF EXISTS idx_transitions_label; + +DROP INDEX IF EXISTS idx_transitions_next; + +DROP INDEX IF EXISTS idx_transitions_current; + +-- Drop the transitions table (this will also remove it from hypertables) +DROP TABLE IF EXISTS transitions; + +-- Drop custom types if they exist +DROP TYPE IF EXISTS transition_cursor; + +DROP TYPE IF EXISTS transition_type; + +-- Drop the trigger and function for transition validation +DROP TRIGGER IF EXISTS validate_transition ON transitions; + +DROP FUNCTION IF EXISTS check_valid_transition (); + +COMMIT; diff --git a/memory-store/migrations/000012_transitions.up.sql b/memory-store/migrations/000012_transitions.up.sql new file mode 100644 index 000000000..93e08157c --- /dev/null +++ b/memory-store/migrations/000012_transitions.up.sql @@ -0,0 +1,166 @@ +BEGIN; + +/* + * CUSTOM TYPES AND ENUMS WITH COMPLEX CONSTRAINTS (Complexity: 7/10) + * Creates custom composite type transition_cursor to track workflow state and enum type for transition states. + * Uses compound primary key combining timestamps and UUIDs for efficient time-series operations. 
+ * Implements complex indexing strategy optimized for various query patterns (current state, next state, labels). + */ + +-- Create transition type enum if it doesn't exist +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'transition_type') THEN + CREATE TYPE transition_type AS ENUM ( + 'init', + 'finish', + 'init_branch', + 'finish_branch', + 'wait', + 'resume', + 'error', + 'step', + 'cancelled' + ); + END IF; +END $$; + +-- Create transition cursor type if it doesn't exist +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'transition_cursor') THEN + CREATE TYPE transition_cursor AS ( + workflow_name TEXT, + step_index INT + ); + END IF; +END $$; + +-- Create transitions table if it doesn't exist +CREATE TABLE IF NOT EXISTS transitions ( + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + execution_id UUID NOT NULL, + transition_id UUID NOT NULL, + type transition_type NOT NULL, + step_definition JSONB NOT NULL, + step_label TEXT DEFAULT NULL, + current_step transition_cursor NOT NULL, + next_step transition_cursor DEFAULT NULL, + output JSONB, + task_token TEXT DEFAULT NULL, + metadata JSONB DEFAULT '{}'::JSONB, + CONSTRAINT pk_transitions PRIMARY KEY (created_at, execution_id, transition_id), + CONSTRAINT ct_step_definition_is_object CHECK (jsonb_typeof(step_definition) = 'object'), + CONSTRAINT ct_metadata_is_object CHECK (jsonb_typeof(metadata) = 'object') +); + +-- Convert to hypertable if not already +SELECT + create_hypertable ( + 'transitions', + by_range ('created_at', INTERVAL '1 day'), + if_not_exists => TRUE + ); + +SELECT + add_dimension ( + 'transitions', + by_hash ('execution_id', 2), + if_not_exists => TRUE + ); + +-- Create indexes if they don't exist +DO $$ +BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = 'idx_transitions_current') THEN + CREATE UNIQUE INDEX idx_transitions_current ON transitions (execution_id, current_step, created_at DESC); + END IF; + + IF NOT EXISTS (SELECT 
1 FROM pg_indexes WHERE indexname = 'idx_transitions_next') THEN
+        CREATE UNIQUE INDEX idx_transitions_next ON transitions (execution_id, next_step, created_at DESC)
+        WHERE next_step IS NOT NULL;
+    END IF;
+
+    IF NOT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = 'idx_transitions_label') THEN
+        CREATE UNIQUE INDEX idx_transitions_label ON transitions (execution_id, step_label, created_at DESC)
+        WHERE step_label IS NOT NULL;
+    END IF;
+
+    IF NOT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = 'idx_transitions_metadata') THEN
+        CREATE INDEX idx_transitions_metadata ON transitions USING GIN (metadata);
+    END IF;
+END $$;
+
+-- Add foreign key constraint if it doesn't exist (deleting an execution deletes its transitions)
+DO $$
+BEGIN
+    IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = 'fk_transitions_execution') THEN
+        ALTER TABLE transitions
+        ADD CONSTRAINT fk_transitions_execution
+        FOREIGN KEY (execution_id)
+        REFERENCES executions(execution_id)
+        ON DELETE CASCADE;
+    END IF;
+END $$;
+
+-- Add comment to table
+COMMENT ON TABLE transitions IS 'Stores transitions associated with AI agents for developers';
+
+-- Create a trigger function that rejects a new transition whose type may not follow the execution's latest one
+CREATE
+OR REPLACE FUNCTION check_valid_transition () RETURNS trigger AS $$
+DECLARE
+    previous_type transition_type;
+    valid_next_types transition_type[];
+BEGIN
+    -- Get the latest transition_type for this execution_id (NULL when the execution has no transitions yet)
+    SELECT t.type INTO previous_type
+    FROM transitions t
+    WHERE t.execution_id = NEW.execution_id
+    ORDER BY t.created_at DESC
+    LIMIT 1;
+
+    IF previous_type IS NULL THEN
+        -- If there is no previous transition, allow only 'init', 'init_branch', 'error' or 'cancelled'
+        IF NEW.type NOT IN ('init', 'init_branch', 'error', 'cancelled') THEN
+            RAISE EXCEPTION 'First transition must be init / init_branch / error / cancelled, got %', NEW.type;
+        END IF;
+    ELSE
+        -- Define the valid_next_types array based on previous_type
+        CASE previous_type
+            WHEN 'init' THEN
+                valid_next_types := ARRAY['wait', 'error', 'step', 'cancelled',
'init_branch', 'finish'];
+            WHEN 'init_branch' THEN
+                valid_next_types := ARRAY['wait', 'error', 'step', 'cancelled', 'init_branch', 'finish_branch', 'finish'];
+            WHEN 'wait' THEN
+                valid_next_types := ARRAY['resume', 'step', 'cancelled', 'finish', 'finish_branch'];
+            WHEN 'resume' THEN
+                valid_next_types := ARRAY['wait', 'error', 'cancelled', 'step', 'finish', 'finish_branch', 'init_branch'];
+            WHEN 'step' THEN
+                valid_next_types := ARRAY['wait', 'error', 'cancelled', 'step', 'finish', 'finish_branch', 'init_branch'];
+            WHEN 'finish_branch' THEN
+                valid_next_types := ARRAY['wait', 'error', 'cancelled', 'step', 'finish', 'init_branch', 'finish_branch'];
+            WHEN 'finish' THEN
+                valid_next_types := ARRAY[]::transition_type[]; -- Terminal state: no valid next transitions
+            WHEN 'error' THEN
+                valid_next_types := ARRAY[]::transition_type[]; -- Terminal state: no valid next transitions
+            WHEN 'cancelled' THEN
+                valid_next_types := ARRAY[]::transition_type[]; -- Terminal state: no valid next transitions
+            ELSE
+                RAISE EXCEPTION 'Unknown previous transition type: %', previous_type;
+        END CASE;
+
+        IF NOT NEW.type = ANY(valid_next_types) THEN
+            RAISE EXCEPTION 'Invalid transition from % to %', previous_type, NEW.type;
+        END IF;
+    END IF;
+
+    RETURN NEW;
+END;
+$$ LANGUAGE plpgsql;
+
+-- Create a trigger that runs the check before every row inserted into transitions
+CREATE TRIGGER validate_transition BEFORE INSERT ON transitions FOR EACH ROW
+EXECUTE FUNCTION check_valid_transition ();
+
+COMMIT;
diff --git a/memory-store/migrations/000013_executions_continuous_view.down.sql b/memory-store/migrations/000013_executions_continuous_view.down.sql
new file mode 100644
index 000000000..fcab7b023
--- /dev/null
+++ b/memory-store/migrations/000013_executions_continuous_view.down.sql
@@ -0,0 +1,15 @@
+BEGIN;
+
+-- Drop the continuous aggregate refresh policy on latest_transitions
+SELECT
+    remove_continuous_aggregate_policy ('latest_transitions');
+
+-- Drop the views (latest_executions first, since it selects from latest_transitions)
+DROP VIEW IF EXISTS latest_executions;
+
+DROP MATERIALIZED VIEW IF EXISTS latest_transitions;
+
+-- Drop the helper function
+DROP FUNCTION IF EXISTS to_text (transition_type);
+
+COMMIT;
diff --git a/memory-store/migrations/000013_executions_continuous_view.up.sql b/memory-store/migrations/000013_executions_continuous_view.up.sql
new file mode 100644
index 000000000..34bcfdb69
--- /dev/null
+++ b/memory-store/migrations/000013_executions_continuous_view.up.sql
@@ -0,0 +1,93 @@
+BEGIN;
+
+/*
+ * CONTINUOUS AGGREGATES WITH STATE AGGREGATION (Complexity: 9/10)
+ * This is a TimescaleDB feature that automatically maintains a real-time summary of the transitions table.
+ * It uses special aggregation functions like state_agg() to track state changes and last() to get most recent values.
+ * The view updates every 10 minutes and can serve both historical and real-time data (materialized_only = FALSE).
+ */
+-- Create a function to convert transition_type to text (needed because the ::text cast is stable, not immutable)
+CREATE
+OR REPLACE function to_text (transition_type) RETURNS text AS $$
+    select $1
+$$ STRICT IMMUTABLE LANGUAGE sql;
+
+-- Create a continuous aggregate with one row per (1-day bucket, execution_id) holding that bucket's latest transition
+CREATE MATERIALIZED VIEW IF NOT EXISTS latest_transitions
+WITH
+    (
+        timescaledb.continuous,
+        timescaledb.materialized_only = FALSE
+    ) AS
+SELECT
+    time_bucket ('1 day', created_at) AS bucket,
+    execution_id,
+    last(transition_id, created_at) AS transition_id,
+    count(*) AS total_transitions,
+    state_agg(created_at, to_text(type)) AS state,
+    max(created_at) AS created_at,
+    last(type, created_at) AS type,
+    last(step_definition, created_at) AS step_definition,
+    last(step_label, created_at) AS step_label,
+    last(current_step, created_at) AS current_step,
+    last(next_step, created_at) AS next_step,
+    last(output, created_at) AS output,
+    last(task_token, created_at) AS task_token,
+    last(metadata, created_at) AS metadata
+FROM
+    transitions
+GROUP BY
+    bucket,
+    execution_id
+WITH
+    no data;
+
+SELECT
+    add_continuous_aggregate_policy (
+        'latest_transitions',
+        start_offset => NULL,
+        end_offset =>
INTERVAL '10 minutes',
+        schedule_interval => INTERVAL '10 minutes'
+    );
+
+-- Create a view that combines executions with their latest transitions and maps the transition type to a status. NOTE(review): latest_transitions is grouped by (bucket, execution_id), so an execution with transitions in several daily buckets presumably yields several rows here - confirm.
+CREATE OR REPLACE VIEW latest_executions AS
+SELECT
+    e.developer_id,
+    e.task_id,
+    e.task_version,
+    e.execution_id,
+    e.input,
+    e.metadata,
+    e.created_at,
+    coalesce(lt.created_at, e.created_at) AS updated_at,
+    CASE
+        WHEN lt.type::text IS NULL THEN 'queued' -- no transition recorded yet
+        WHEN lt.type::text = 'init' THEN 'starting'
+        WHEN lt.type::text = 'init_branch' THEN 'running'
+        WHEN lt.type::text = 'wait' THEN 'awaiting_input'
+        WHEN lt.type::text = 'resume' THEN 'running'
+        WHEN lt.type::text = 'step' THEN 'running'
+        WHEN lt.type::text = 'finish' THEN 'succeeded'
+        WHEN lt.type::text = 'finish_branch' THEN 'running'
+        WHEN lt.type::text = 'error' THEN 'failed'
+        WHEN lt.type::text = 'cancelled' THEN 'cancelled'
+        ELSE 'queued'
+    END AS status,
+    CASE
+        WHEN lt.type::text = 'error' THEN lt.output ->> 'error' -- error text is read from the 'error' key of the output
+        ELSE NULL
+    END AS error,
+    coalesce(lt.total_transitions, 0) AS total_transitions,
+    coalesce(lt.output, '{}'::jsonb) AS output,
+    lt.current_step,
+    lt.next_step,
+    lt.step_definition,
+    lt.step_label,
+    lt.task_token,
+    lt.metadata AS transition_metadata
+FROM
+    executions e
+    LEFT JOIN latest_transitions lt ON e.execution_id = lt.execution_id;
+
+COMMIT;
\ No newline at end of file
diff --git a/memory-store/migrations/000014_temporal_lookup.down.sql b/memory-store/migrations/000014_temporal_lookup.down.sql
new file mode 100644
index 000000000..ff501819b
--- /dev/null
+++ b/memory-store/migrations/000014_temporal_lookup.down.sql
@@ -0,0 +1,5 @@
+BEGIN;
+
+DROP TABLE IF EXISTS temporal_executions_lookup CASCADE;
+
+COMMIT;
\ No newline at end of file
diff --git a/memory-store/migrations/000014_temporal_lookup.up.sql b/memory-store/migrations/000014_temporal_lookup.up.sql
new file mode 100644
index 000000000..a87ead3cf
--- /dev/null
+++ b/memory-store/migrations/000014_temporal_lookup.up.sql
@@ -0,0 +1,18 @@
+BEGIN;
+
+-- 
Create temporal_executions_lookup table
+CREATE TABLE IF NOT EXISTS temporal_executions_lookup (
+    execution_id UUID NOT NULL,
+    id TEXT NOT NULL,
+    run_id TEXT,
+    first_execution_run_id TEXT,
+    result_run_id TEXT,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    CONSTRAINT pk_temporal_executions_lookup PRIMARY KEY (execution_id, id),
+    CONSTRAINT fk_temporal_executions_lookup_execution FOREIGN KEY (execution_id) REFERENCES executions (execution_id) ON DELETE CASCADE
+);
+
+-- Add comment to table
+COMMENT ON TABLE temporal_executions_lookup IS 'Stores temporal workflow execution lookup data for AI agent executions';
+
+COMMIT;
\ No newline at end of file
diff --git a/memory-store/migrations/000015_entries.down.sql b/memory-store/migrations/000015_entries.down.sql
new file mode 100644
index 000000000..fdfd6c8dd
--- /dev/null
+++ b/memory-store/migrations/000015_entries.down.sql
@@ -0,0 +1,23 @@
+BEGIN;
+
+DROP TRIGGER IF EXISTS trg_optimized_update_token_count_after ON entries;
+
+DROP FUNCTION IF EXISTS optimized_update_token_count_after;
+
+-- Drop foreign key constraint if it exists
+ALTER TABLE IF EXISTS entries
+DROP CONSTRAINT IF EXISTS fk_entries_session;
+
+-- Drop indexes
+DROP INDEX IF EXISTS idx_entries_by_session;
+
+-- Drop the hypertable (this will also drop the table and the triggers defined on it)
+DROP TABLE IF EXISTS entries;
+
+-- Drop the functions created by the up migration (after the table, so no trigger or CHECK still depends on them)
+DROP FUNCTION IF EXISTS all_jsonb_elements_are_objects, update_session_updated_at;
+
+-- Drop the enum type
+DROP TYPE IF EXISTS chat_role;
+
+COMMIT;
\ No newline at end of file
diff --git a/memory-store/migrations/000015_entries.up.sql b/memory-store/migrations/000015_entries.up.sql
new file mode 100644
index 000000000..5b9302f05
--- /dev/null
+++ b/memory-store/migrations/000015_entries.up.sql
@@ -0,0 +1,137 @@
+BEGIN;
+
+-- Create chat_role enum
+CREATE TYPE chat_role AS ENUM(
+    'user',
+    'assistant',
+    'tool',
+    'system',
+    'developer'
+);
+
+-- Create a custom function that checks (an empty `content` array passes)
+-- that every JSONB element 
in the array is an 'object'.
+CREATE
+OR REPLACE FUNCTION all_jsonb_elements_are_objects (content jsonb[]) RETURNS boolean AS $$
+DECLARE
+    elem jsonb;
+BEGIN
+    -- Check each element in the `content` array; return false at the first non-object
+    FOREACH elem IN ARRAY content
+    LOOP
+        IF jsonb_typeof(elem) <> 'object' THEN -- NOTE(review): for a SQL NULL element this comparison is NULL, so the element is not rejected - confirm intended
+            RETURN false;
+        END IF;
+    END LOOP;
+
+    RETURN true;
+END;
+$$ LANGUAGE plpgsql IMMUTABLE;
+
+CREATE TABLE IF NOT EXISTS entries (
+    session_id UUID NOT NULL,
+    entry_id UUID NOT NULL,
+    source TEXT NOT NULL,
+    role chat_role NOT NULL,
+    event_type TEXT NOT NULL DEFAULT 'message.create',
+    name TEXT,
+    content JSONB[] NOT NULL,
+    tool_call_id TEXT DEFAULT NULL,
+    tool_calls JSONB[] DEFAULT NULL,
+    model TEXT NOT NULL,
+    token_count INTEGER DEFAULT NULL,
+    tokenizer TEXT NOT NULL,
+    created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    timestamp DOUBLE PRECISION NOT NULL,
+    CONSTRAINT pk_entries PRIMARY KEY (session_id, entry_id, created_at),
+    CONSTRAINT ct_content_is_array_of_objects CHECK (all_jsonb_elements_are_objects (content))
+);
+
+-- Convert to a hypertable range-partitioned on created_at (1-day chunks), if not already
+SELECT
+    create_hypertable (
+        'entries',
+        by_range ('created_at', INTERVAL '1 day'),
+        if_not_exists => TRUE
+    );
+
+SELECT
+    add_dimension (
+        'entries',
+        by_hash ('session_id', 2), -- second dimension: 2 hash partitions on session_id
+        if_not_exists => TRUE
+    );
+
+-- Create indexes for efficient querying
+CREATE INDEX IF NOT EXISTS idx_entries_by_session ON entries (session_id DESC);
+
+-- Add foreign key constraint to sessions table (deleting a session deletes its entries)
+DO $$
+BEGIN
+    IF NOT EXISTS (
+        SELECT 1 FROM pg_constraint WHERE conname = 'fk_entries_session'
+    ) THEN
+        ALTER TABLE entries
+        ADD CONSTRAINT fk_entries_session
+        FOREIGN KEY (session_id)
+        REFERENCES sessions(session_id) ON DELETE CASCADE;
+    END IF;
+END $$;
+
+-- TODO: We should consider using a timescale background job to update the token count
+-- instead of a trigger.
+-- https://docs.timescale.com/use-timescale/latest/user-defined-actions/create-and-register/
+CREATE
+OR REPLACE FUNCTION optimized_update_token_count_after () RETURNS TRIGGER AS $$
+DECLARE
+    calc_token_count INTEGER;
+BEGIN
+    -- Compute token_count outside the UPDATE statement for clarity and potential optimization
+    calc_token_count := cardinality(
+        ai.openai_tokenize(
+            'gpt-4o', -- FIXME: Use `NEW.model`
+            array_to_string(NEW.content::TEXT[], ' ')
+        )
+    );
+
+    -- Perform the update only if token_count differs; IS DISTINCT FROM also covers a NULL token_count (the column default)
+    IF calc_token_count IS DISTINCT FROM NEW.token_count THEN
+        UPDATE entries
+        SET token_count = calc_token_count
+        WHERE entry_id = NEW.entry_id; -- NOTE(review): matches on entry_id only, not the full (session_id, entry_id, created_at) key - confirm
+    END IF;
+
+    RETURN NULL;
+END;
+$$ LANGUAGE plpgsql;
+
+-- FIXME: This trigger is causing the slow performance of the create_entries query
+--
+-- We should consider using a timescale background job to update the token count
+-- instead of a trigger.
+-- https://docs.timescale.com/use-timescale/latest/user-defined-actions/create-and-register/
+--
+-- CREATE TRIGGER trg_optimized_update_token_count_after
+-- AFTER INSERT
+-- OR
+-- UPDATE ON entries FOR EACH ROW
+-- EXECUTE FUNCTION optimized_update_token_count_after ();
+
+-- Add trigger to update parent session's updated_at whenever an entry is inserted or updated
+CREATE
+OR REPLACE FUNCTION update_session_updated_at () RETURNS TRIGGER AS $$
+BEGIN
+    UPDATE sessions
+    SET updated_at = CURRENT_TIMESTAMP
+    WHERE session_id = NEW.session_id;
+    RETURN NEW;
+END;
+$$ LANGUAGE plpgsql;
+
+CREATE TRIGGER trg_update_session_updated_at
+AFTER INSERT
+OR
+UPDATE ON entries FOR EACH ROW
+EXECUTE FUNCTION update_session_updated_at ();
+
+COMMIT;
diff --git a/memory-store/migrations/000016_entry_relations.down.sql b/memory-store/migrations/000016_entry_relations.down.sql
new file mode 100644
index 000000000..6d54b0c08
--- /dev/null
+++ b/memory-store/migrations/000016_entry_relations.down.sql
@@ -0,0 +1,12 @@
+BEGIN;
+
+-- Drop trigger first (the name must match the trigger created by the up migration)
+DROP TRIGGER IF EXISTS trg_auto_update_leaf_status ON entry_relations;
+
+-- Drop 
function
+DROP FUNCTION IF EXISTS auto_update_leaf_status () CASCADE; -- the function the up migration creates; CASCADE also removes a trigger still bound to it
+
+-- Drop the table and its constraints
+DROP TABLE IF EXISTS entry_relations CASCADE;
+
+COMMIT;
\ No newline at end of file
diff --git a/memory-store/migrations/000016_entry_relations.up.sql b/memory-store/migrations/000016_entry_relations.up.sql
new file mode 100644
index 000000000..a7b0317b9
--- /dev/null
+++ b/memory-store/migrations/000016_entry_relations.up.sql
@@ -0,0 +1,57 @@
+BEGIN;
+
+-- Create citext extension if not exists
+CREATE EXTENSION IF NOT EXISTS citext;
+
+-- Create entry_relations table: (head, relation, tail) edges between entries of one session
+CREATE TABLE IF NOT EXISTS entry_relations (
+    session_id UUID NOT NULL,
+    head UUID NOT NULL,
+    relation CITEXT NOT NULL,
+    tail UUID NOT NULL,
+    is_leaf BOOLEAN NOT NULL DEFAULT FALSE,
+    CONSTRAINT pk_entry_relations PRIMARY KEY (session_id, head, relation, tail)
+);
+
+-- Add foreign key constraint to sessions table
+DO $$
+BEGIN
+    IF NOT EXISTS (
+        SELECT 1 FROM pg_constraint WHERE conname = 'fk_entry_relations_session'
+    ) THEN
+        ALTER TABLE entry_relations
+        ADD CONSTRAINT fk_entry_relations_session
+        FOREIGN KEY (session_id)
+        REFERENCES sessions(session_id) ON DELETE CASCADE;
+    END IF;
+END $$;
+
+-- Create indexes for efficient querying (guarded like the table and constraint above so a re-run does not fail)
+CREATE INDEX IF NOT EXISTS idx_entry_relations_leaf ON entry_relations (session_id, is_leaf);
+
+CREATE OR REPLACE FUNCTION auto_update_leaf_status() RETURNS TRIGGER AS $$
+BEGIN
+    -- Set is_leaf = false for any existing rows that will now have this new relation as a child
+    UPDATE entry_relations
+    SET is_leaf = false
+    WHERE session_id = NEW.session_id
+    AND tail = NEW.head;
+
+    -- Set is_leaf for the new row based on whether it has any children
+    NEW.is_leaf := NOT EXISTS (
+        SELECT 1
+        FROM entry_relations
+        WHERE session_id = NEW.session_id
+        AND head = NEW.tail
+    );
+
+    RETURN NEW;
+END;
+$$ LANGUAGE plpgsql;
+
+CREATE TRIGGER trg_auto_update_leaf_status
+BEFORE INSERT OR UPDATE ON entry_relations
+FOR EACH ROW
+EXECUTE FUNCTION auto_update_leaf_status();
+
+COMMIT;
\
No newline at end of file
diff --git a/memory-store/migrations/000017_compression.down.sql b/memory-store/migrations/000017_compression.down.sql
new file mode 100644
index 000000000..8befeb465
--- /dev/null
+++ b/memory-store/migrations/000017_compression.down.sql
@@ -0,0 +1,17 @@
+BEGIN;
+
+SELECT -- remove the compression policies before turning compression off
+    remove_compression_policy ('entries');
+
+SELECT
+    remove_compression_policy ('transitions');
+
+ALTER TABLE entries -- NOTE(review): TimescaleDB may refuse this while compressed chunks exist - confirm whether chunks must be decompressed first
+SET
+    (timescaledb.compress = FALSE);
+
+ALTER TABLE transitions
+SET
+    (timescaledb.compress = FALSE);
+
+COMMIT;
diff --git a/memory-store/migrations/000017_compression.up.sql b/memory-store/migrations/000017_compression.up.sql
new file mode 100644
index 000000000..06c7e6c77
--- /dev/null
+++ b/memory-store/migrations/000017_compression.up.sql
@@ -0,0 +1,32 @@
+/*
+ * MULTI-DIMENSIONAL HYPERTABLES WITH COMPRESSION (Complexity: 8/10)
+ * TimescaleDB's advanced feature that partitions data by both time (created_at) and space (session_id/execution_id).
+ * Automatically compresses data older than 7 days to save storage while maintaining query performance.
+ * Uses segment_by to group related rows and order_by to optimize decompression speed.
+ */
+
+BEGIN;
+
+ALTER TABLE entries
+SET
+    (
+        timescaledb.compress = TRUE,
+        timescaledb.compress_segmentby = 'session_id',
+        timescaledb.compress_orderby = 'created_at DESC, entry_id DESC'
+    );
+
+SELECT
+    add_compression_policy ('entries', INTERVAL '7 days');
+
+ALTER TABLE transitions
+SET
+    (
+        timescaledb.compress = TRUE,
+        timescaledb.compress_segmentby = 'execution_id',
+        timescaledb.compress_orderby = 'created_at DESC, transition_id DESC'
+    );
+
+SELECT
+    add_compression_policy ('transitions', INTERVAL '7 days');
+
+COMMIT;
diff --git a/memory-store/migrations/000018_doc_search.down.sql b/memory-store/migrations/000018_doc_search.down.sql
new file mode 100644
index 000000000..1ccbc5af8
--- /dev/null
+++ b/memory-store/migrations/000018_doc_search.down.sql
@@ -0,0 +1,27 @@
+BEGIN;
+
+-- Drop the embed and search hybrid function
+DROP FUNCTION IF EXISTS embed_and_search_hybrid;
+
+-- Drop the hybrid search function and the score-normalization helpers the up migration creates for it
+DROP FUNCTION IF EXISTS search_hybrid, dbsf_normalize, array_stddev, array_mean;
+
+-- Drop the text search function
+DROP FUNCTION IF EXISTS search_by_text;
+
+-- Drop the combined embed and search function
+DROP FUNCTION IF EXISTS embed_and_search_by_vector;
+
+-- Drop the search function
+DROP FUNCTION IF EXISTS search_by_vector;
+
+-- Drop the doc_search_result type
+DROP TYPE IF EXISTS doc_search_result;
+
+-- Drop the embed_with_cache function
+DROP FUNCTION IF EXISTS embed_with_cache;
+
+-- Drop the embeddings cache table
+DROP TABLE IF EXISTS embeddings_cache CASCADE;
+
+COMMIT;
diff --git a/memory-store/migrations/000018_doc_search.up.sql b/memory-store/migrations/000018_doc_search.up.sql
new file mode 100644
index 000000000..8fde5e9bb
--- /dev/null
+++ b/memory-store/migrations/000018_doc_search.up.sql
@@ -0,0 +1,499 @@
+BEGIN;
+
+-- Create unlogged table for caching embeddings
+CREATE UNLOGGED TABLE IF NOT EXISTS embeddings_cache (
+    model_input_md5 TEXT NOT NULL,
+    embedding vector (1024) NOT NULL,
+    CONSTRAINT pk_embeddings_cache PRIMARY KEY (model_input_md5)
+);
+
+-- Add 
comment explaining table purpose
+COMMENT ON TABLE embeddings_cache IS 'Unlogged table that caches embedding requests to avoid duplicate API calls';
+
+CREATE
+OR REPLACE function embed_with_cache (
+    _provider text,
+    _model text,
+    _input_text text,
+    _input_type text DEFAULT NULL,
+    _api_key text DEFAULT NULL,
+    _api_key_name text DEFAULT NULL
+) returns vector (1024) language plpgsql AS $$
+
+-- Return the embedding for (_provider, _model, _input_text, _input_type), serving it from embeddings_cache when present
+declare
+    cached_embedding vector(1024);
+    _cache_key text; -- must not share a name with the model_input_md5 column, or the queries below are ambiguous
+begin
+    if _provider != 'voyageai' then
+        raise exception 'Only voyageai provider is supported';
+    end if;
+
+    _cache_key := md5(_provider || '++' || _model || '++' || _input_text || '++' || coalesce(_input_type, '')); -- coalesce: a NULL _input_type would make the whole key NULL
+
+    select c.embedding into cached_embedding
+    from embeddings_cache c
+    where c.model_input_md5 = _cache_key;
+
+    if found then
+        return cached_embedding;
+    end if;
+
+    -- Not found in cache, call AI embedding function
+    cached_embedding := ai.voyageai_embed(
+        _model,
+        _input_text,
+        _input_type,
+        _api_key,
+        _api_key_name
+    );
+
+    -- Cache the result (overwrites the stored embedding if another call inserted the same key meanwhile)
+    insert into embeddings_cache (
+        model_input_md5,
+        embedding
+    ) values (
+        _cache_key,
+        cached_embedding
+    ) on conflict (model_input_md5) do update set embedding = cached_embedding;
+
+    return cached_embedding;
+end;
+$$;
+
+-- Create a type for the search results if it doesn't exist
+DO $$
+BEGIN
+    IF NOT EXISTS (
+        SELECT 1 FROM pg_type WHERE typname = 'doc_search_result'
+    ) THEN
+        CREATE TYPE doc_search_result AS (
+            developer_id uuid,
+            doc_id uuid,
+            index integer,
+            title text,
+            content text,
+            distance float,
+            embedding vector(1024),
+            metadata jsonb,
+            owner_type text,
+            owner_id uuid
+        );
+    END IF;
+END $$;
+
+-- Create the search function
+CREATE
+OR REPLACE FUNCTION search_by_vector (
+    developer_id UUID,
+    query_embedding vector (1024),
+    owner_types TEXT[],
+    owner_ids UUID [],
+    k integer DEFAULT 3,
+    confidence float DEFAULT 0.5,
+    metadata_filter jsonb DEFAULT NULL
+) RETURNS SETOF 
doc_search_result LANGUAGE plpgsql AS $$
+DECLARE
+    search_threshold float;
+    owner_filter_sql text;
+    metadata_filter_sql text;
+BEGIN
+    -- Input validation
+    IF k <= 0 THEN
+        RAISE EXCEPTION 'k must be greater than 0';
+    END IF;
+
+    IF confidence < 0 OR confidence > 1 THEN
+        RAISE EXCEPTION 'confidence must be between 0 and 1';
+    END IF;
+
+    IF owner_types IS NOT NULL AND owner_ids IS NOT NULL AND
+        -- cardinality() is 0 for an empty array (array_length() would be NULL and silently skip the check)
+        cardinality(owner_types) != cardinality(owner_ids) THEN
+        RAISE EXCEPTION 'owner_types and owner_ids arrays must have the same length';
+    END IF;
+
+    -- Calculate search threshold from confidence
+    search_threshold := 1.0 - confidence; -- NOTE(review): compared against similarity (1 - cosine distance) below, so a higher confidence lowers the bar - confirm intended
+
+    -- Build owner filter SQL
+    owner_filter_sql := '
+        AND (
+            doc_owners.owner_id = ANY($5::uuid[]) AND doc_owners.owner_type = ANY($4::text[])
+        )';
+
+    -- Build metadata filter SQL if provided
+    IF metadata_filter IS NOT NULL THEN
+        metadata_filter_sql := 'AND d.metadata @> $6';
+    ELSE
+        metadata_filter_sql := '';
+    END IF;
+
+    -- Return the k best docs: keep each doc's best-scoring row first, then rank the docs by that score before LIMIT
+    RETURN QUERY EXECUTE format(
+        'WITH ranked_docs AS (
+            SELECT
+                d.developer_id,
+                d.doc_id,
+                d.index,
+                d.title,
+                d.content,
+                (1 - (d.embedding <=> $1)) as distance,
+                d.embedding,
+                d.metadata,
+                doc_owners.owner_type,
+                doc_owners.owner_id
+            FROM docs_embeddings d
+            LEFT JOIN doc_owners ON d.doc_id = doc_owners.doc_id
+            WHERE d.developer_id = $7
+            AND 1 - (d.embedding <=> $1) >= $2
+            %s
+            %s
+        )
+        SELECT * FROM (
+            SELECT DISTINCT ON (doc_id) * FROM ranked_docs ORDER BY doc_id, distance DESC
+        ) best_row_per_doc
+        ORDER BY distance DESC LIMIT $3',
+        owner_filter_sql,
+        metadata_filter_sql
+    )
+    USING
+        query_embedding,
+        search_threshold,
+        k,
+        owner_types,
+        owner_ids,
+        metadata_filter,
+        developer_id;
+
+
+END;
+$$;
+
+-- Add helpful comment
+COMMENT ON FUNCTION search_by_vector IS 'Search documents by vector similarity with configurable confidence threshold and filtering options';
+
+-- Create the combined embed and search function
+CREATE
+OR REPLACE FUNCTION 
embed_and_search_by_vector (
+    developer_id UUID,
+    query_text text,
+    owner_types TEXT[],
+    owner_ids UUID [],
+    k integer DEFAULT 3,
+    confidence float DEFAULT 0.5,
+    metadata_filter jsonb DEFAULT NULL,
+    embedding_provider text DEFAULT 'voyageai',
+    embedding_model text DEFAULT 'voyage-3',
+    input_type text DEFAULT 'query',
+    api_key text DEFAULT NULL,
+    api_key_name text DEFAULT NULL
+) RETURNS SETOF doc_search_result LANGUAGE plpgsql AS $$
+DECLARE
+    query_embedding vector(1024);
+BEGIN
+    -- First generate embedding for the query text
+    query_embedding := embed_with_cache(
+        embedding_provider,
+        embedding_model,
+        query_text,
+        input_type,
+        api_key,
+        api_key_name
+    );
+
+    -- Then perform the search using the generated embedding
+    RETURN QUERY SELECT * FROM search_by_vector(
+        developer_id,
+        query_embedding,
+        owner_types,
+        owner_ids,
+        k,
+        confidence,
+        metadata_filter
+    );
+END;
+$$;
+
+COMMENT ON FUNCTION embed_and_search_by_vector IS 'Convenience function that combines text embedding and vector search in one call';
+
+-- Create the text search function
+CREATE
+OR REPLACE FUNCTION search_by_text (
+    developer_id UUID,
+    query_text text,
+    owner_types TEXT[],
+    owner_ids UUID[],
+    search_language text DEFAULT 'english',
+    k integer DEFAULT 3,
+    metadata_filter jsonb DEFAULT NULL
+) RETURNS SETOF doc_search_result LANGUAGE plpgsql AS $$
+DECLARE
+    owner_filter_sql text;
+    metadata_filter_sql text;
+    ts_query tsquery;
+BEGIN
+    -- Input validation
+    IF k <= 0 THEN
+        RAISE EXCEPTION 'k must be greater than 0';
+    END IF;
+
+    IF owner_types IS NOT NULL AND owner_ids IS NOT NULL AND
+        -- cardinality() is 0 for an empty array (array_length() would be NULL and silently skip the check)
+        cardinality(owner_types) != cardinality(owner_ids) THEN
+        RAISE EXCEPTION 'owner_types and owner_ids arrays must have the same length';
+    END IF;
+
+    -- Convert search query to tsquery (web-search syntax, parsed with the requested text search configuration)
+    ts_query := websearch_to_tsquery(search_language::regconfig, query_text);
+
+    -- Build owner filter SQL
+    owner_filter_sql := '
+        AND (
+            
doc_owners.owner_id = ANY($4::uuid[]) AND doc_owners.owner_type = ANY($3::text[])
+        )';
+
+
+    -- Build metadata filter SQL if provided
+    IF metadata_filter IS NOT NULL THEN
+        metadata_filter_sql := 'AND d.metadata @> $5';
+    ELSE
+        metadata_filter_sql := '';
+    END IF;
+
+    -- Return the k best docs: keep each doc's best-ranked row first, then rank the docs by that score before LIMIT
+    RETURN QUERY EXECUTE format(
+        'WITH ranked_docs AS (
+            SELECT
+                d.developer_id,
+                d.doc_id,
+                d.index,
+                d.title,
+                d.content,
+                ts_rank_cd(d.search_tsv, $1, 32)::double precision as distance,
+                d.embedding,
+                d.metadata,
+                doc_owners.owner_type,
+                doc_owners.owner_id
+            FROM docs_embeddings d
+            LEFT JOIN doc_owners ON d.doc_id = doc_owners.doc_id
+            WHERE d.developer_id = $6
+            AND d.search_tsv @@ $1
+            %s
+            %s
+        )
+        SELECT * FROM (
+            SELECT DISTINCT ON (doc_id) * FROM ranked_docs ORDER BY doc_id, distance DESC
+        ) best_row_per_doc
+        ORDER BY distance DESC LIMIT $2',
+        owner_filter_sql,
+        metadata_filter_sql
+    )
+    USING
+        ts_query,
+        k,
+        owner_types,
+        owner_ids,
+        metadata_filter,
+        developer_id;
+
+END;
+$$;
+
+COMMENT ON FUNCTION search_by_text IS 'Search documents using full-text search with configurable language and filtering options';
+
+-- Function to calculate mean of an array
+CREATE
+OR REPLACE FUNCTION array_mean (arr FLOAT[]) RETURNS float AS $$
+    SELECT avg(v) FROM unnest(arr) v;
+$$ LANGUAGE SQL;
+
+-- Function to calculate standard deviation of an array
+CREATE
+OR REPLACE FUNCTION array_stddev (arr FLOAT[]) RETURNS float AS $$
+    SELECT stddev(v) FROM unnest(arr) v;
+$$ LANGUAGE SQL;
+
+-- DBSF normalization function: rescales scores linearly so that mean - 3*stddev maps to 0 and mean + 3*stddev maps to 1
+CREATE
+OR REPLACE FUNCTION dbsf_normalize (scores FLOAT[]) RETURNS FLOAT[] AS $$
+DECLARE
+    m float;
+    sd float;
+    m3d float;
+    m_3d float;
+BEGIN
+    -- Handle edge cases (fewer than two scores, or no spread): return the input unchanged
+    IF array_length(scores, 1) < 2 THEN
+        RETURN scores;
+    END IF;
+
+    -- Calculate statistics
+    sd := array_stddev(scores);
+    IF sd = 0 THEN
+        RETURN scores;
+    END IF;
+
+    m := array_mean(scores);
+    m3d := 3 * sd + m;
+    m_3d := m - 3 * sd;
+
+    -- Apply normalization
+    RETURN array(
+        SELECT (s - m_3d) / (m3d - m_3d)
+        FROM 
unnest(scores) s
+    );
+END;
+$$ LANGUAGE plpgsql;
+
+-- Hybrid search function combining text and vector search: both result sets are merged, their scores normalized and blended by alpha
+CREATE
+OR REPLACE FUNCTION search_hybrid (
+    developer_id UUID,
+    query_text text,
+    query_embedding vector (1024),
+    owner_types TEXT[],
+    owner_ids UUID [],
+    k integer DEFAULT 3,
+    alpha float DEFAULT 0.7, -- Weight for embedding results
+    confidence float DEFAULT 0.5,
+    metadata_filter jsonb DEFAULT NULL,
+    search_language text DEFAULT 'english'
+) RETURNS SETOF doc_search_result AS $$
+DECLARE
+    text_weight float;
+    embedding_weight float;
+BEGIN
+    -- Input validation
+    IF k <= 0 THEN
+        RAISE EXCEPTION 'k must be greater than 0';
+    END IF;
+
+    text_weight := 1.0 - alpha;
+    embedding_weight := alpha;
+
+    RETURN QUERY
+    WITH text_results AS (
+        SELECT * FROM search_by_text(
+            developer_id,
+            query_text,
+            owner_types,
+            owner_ids,
+            search_language,
+            k,
+            metadata_filter
+        )
+    ),
+    embedding_results AS (
+        SELECT * FROM search_by_vector(
+            developer_id,
+            query_embedding,
+            owner_types,
+            owner_ids,
+            k,
+            confidence,
+            metadata_filter
+        )
+    ),
+    all_results AS (
+        SELECT DISTINCT doc_id, title, content, metadata, embedding,
+            index, owner_type, owner_id
+        FROM (
+            SELECT * FROM text_results
+            UNION
+            SELECT * FROM embedding_results
+        ) combined
+    ),
+    scores AS (
+        SELECT
+            -- r.developer_id,
+            r.doc_id,
+            r.title,
+            r.content,
+            r.metadata,
+            r.embedding,
+            r.index,
+            r.owner_type,
+            r.owner_id,
+            COALESCE(t.distance, 0.0) as text_score,
+            COALESCE(e.distance, 0.0) as embedding_score
+        FROM all_results r
+        LEFT JOIN text_results t ON r.doc_id = t.doc_id
+        LEFT JOIN embedding_results e ON r.doc_id = e.doc_id
+    ),
+    normalized_scores AS (
+        SELECT
+            *,
+            (dbsf_normalize(array_agg(text_score) OVER ()))[(row_number() OVER ())::int] as norm_text_score, -- pick this row's own entry; unnest() here would emit every score for every row
+            (dbsf_normalize(array_agg(embedding_score) OVER ()))[(row_number() OVER ())::int] as norm_embedding_score -- same window as array_agg, so positions line up
+        FROM scores
+    )
+    SELECT
+        developer_id,
+        doc_id,
+        index,
+        title,
+        content,
+        1.0 - (text_weight * norm_text_score + embedding_weight * 
norm_embedding_score) as distance,
+        embedding,
+        metadata,
+        owner_type,
+        owner_id
+    FROM normalized_scores
+    ORDER BY distance ASC
+    LIMIT k;
+END;
+$$ LANGUAGE plpgsql;
+
+COMMENT ON FUNCTION search_hybrid IS 'Hybrid search combining text and vector search using Distribution-Based Score Fusion (DBSF)';
+
+-- Convenience function that embeds query_text via embed_with_cache and passes the result to search_hybrid
+CREATE
+OR REPLACE FUNCTION embed_and_search_hybrid (
+    developer_id UUID,
+    query_text text,
+    owner_types TEXT[],
+    owner_ids UUID [],
+    k integer DEFAULT 3,
+    alpha float DEFAULT 0.7,
+    confidence float DEFAULT 0.5,
+    metadata_filter jsonb DEFAULT NULL,
+    search_language text DEFAULT 'english',
+    embedding_provider text DEFAULT 'voyageai',
+    embedding_model text DEFAULT 'voyage-3',
+    input_type text DEFAULT 'query',
+    api_key text DEFAULT NULL,
+    api_key_name text DEFAULT NULL
+) RETURNS SETOF doc_search_result AS $$
+DECLARE
+    query_embedding vector(1024);
+BEGIN
+    -- Generate embedding for query text
+    query_embedding := embed_with_cache(
+        embedding_provider,
+        embedding_model,
+        query_text,
+        input_type,
+        api_key,
+        api_key_name
+    );
+
+    -- Perform hybrid search with the same text plus its embedding
+    RETURN QUERY SELECT * FROM search_hybrid(
+        developer_id,
+        query_text,
+        query_embedding,
+        owner_types,
+        owner_ids,
+        k,
+        alpha,
+        confidence,
+        metadata_filter,
+        search_language
+    );
+END;
+$$ LANGUAGE plpgsql;
+
+COMMENT ON FUNCTION embed_and_search_hybrid IS 'Convenience function that combines text embedding generation and hybrid search in one call';
+
+COMMIT;
diff --git a/memory-store/migrations/000019_system_developer.down.sql b/memory-store/migrations/000019_system_developer.down.sql
new file mode 100644
index 000000000..96c6e1f37
--- /dev/null
+++ b/memory-store/migrations/000019_system_developer.down.sql
@@ -0,0 +1,27 @@
+BEGIN;
+
+-- Remove the docs owned by the system developer
+DELETE FROM docs
+WHERE 
developer_id = '00000000-0000-0000-0000-000000000000'::uuid;
+
+-- Remove the executions owned by the system developer
+DELETE FROM executions
+WHERE developer_id = '00000000-0000-0000-0000-000000000000'::uuid;
+
+-- Remove the tasks owned by the system developer
+DELETE FROM tasks
+WHERE developer_id = '00000000-0000-0000-0000-000000000000'::uuid;
+
+-- Remove the agents owned by the system developer
+DELETE FROM agents
+WHERE developer_id = '00000000-0000-0000-0000-000000000000'::uuid;
+
+-- Remove the users owned by the system developer
+DELETE FROM users
+WHERE developer_id = '00000000-0000-0000-0000-000000000000'::uuid;
+
+-- Finally remove the system developer row itself
+DELETE FROM developers
+WHERE developer_id = '00000000-0000-0000-0000-000000000000'::uuid;
+
+COMMIT;
diff --git a/memory-store/migrations/000019_system_developer.up.sql b/memory-store/migrations/000019_system_developer.up.sql
new file mode 100644
index 000000000..34635b7ad
--- /dev/null
+++ b/memory-store/migrations/000019_system_developer.up.sql
@@ -0,0 +1,18 @@
+BEGIN;
+
+-- Insert system developer with all zeros UUID (does nothing if that developer_id already exists)
+INSERT INTO developers (
+    developer_id,
+    email,
+    active,
+    tags,
+    settings
+) VALUES (
+    '00000000-0000-0000-0000-000000000000',
+    'system@internal.julep.ai',
+    true,
+    ARRAY['system', 'paid'],
+    '{}'::jsonb
+) ON CONFLICT (developer_id) DO NOTHING;
+
+COMMIT;
diff --git a/memory-store/options b/memory-store/options
deleted file mode 100644
index 8a2a30378..000000000
--- a/memory-store/options
+++ /dev/null
@@ -1,213 +0,0 @@
-# This is a RocksDB option file.
-# -# For detailed file format spec, please refer to the example file -# in examples/rocksdb_option_file_example.ini -# - -[Version] - rocksdb_version=9.8.4 - options_file_version=1.1 - -[DBOptions] - compaction_readahead_size=2097152 - strict_bytes_per_sync=false - bytes_per_sync=0 - max_background_jobs=8 - avoid_flush_during_shutdown=false - max_background_flushes=1 - delayed_write_rate=16777216 - max_open_files=-1 - max_subcompactions=1 - writable_file_max_buffer_size=1048576 - wal_bytes_per_sync=0 - max_background_compactions=6 - max_total_wal_size=0 - delete_obsolete_files_period_micros=21600000000 - stats_dump_period_sec=600 - stats_history_buffer_size=1048576 - stats_persist_period_sec=600 - follower_refresh_catchup_period_ms=10000 - enforce_single_del_contracts=true - lowest_used_cache_tier=kNonVolatileBlockTier - bgerror_resume_retry_interval=1000000 - metadata_write_temperature=kUnknown - best_efforts_recovery=false - log_readahead_size=0 - write_identity_file=true - write_dbid_to_manifest=true - prefix_seek_opt_in_only=false - wal_compression=kNoCompression - manual_wal_flush=false - db_host_id=__hostname__ - two_write_queues=false - random_access_max_buffer_size=1048576 - avoid_unnecessary_blocking_io=false - skip_checking_sst_file_sizes_on_db_open=false - flush_verify_memtable_count=true - fail_if_options_file_error=true - atomic_flush=false - verify_sst_unique_id_in_manifest=true - skip_stats_update_on_db_open=false - track_and_verify_wals_in_manifest=false - compaction_verify_record_count=true - paranoid_checks=true - create_if_missing=true - max_write_batch_group_size_bytes=1048576 - follower_catchup_retry_count=10 - avoid_flush_during_recovery=false - file_checksum_gen_factory=nullptr - enable_thread_tracking=false - allow_fallocate=true - allow_data_in_errors=false - error_if_exists=false - use_direct_io_for_flush_and_compaction=false - background_close_inactive_wals=false - create_missing_column_families=false - WAL_size_limit_MB=0 - 
use_direct_reads=false - persist_stats_to_disk=true - allow_2pc=false - is_fd_close_on_exec=true - max_log_file_size=0 - max_file_opening_threads=16 - wal_filter=nullptr - wal_write_temperature=kUnknown - follower_catchup_retry_wait_ms=100 - allow_mmap_reads=false - allow_mmap_writes=false - use_adaptive_mutex=false - use_fsync=false - table_cache_numshardbits=6 - dump_malloc_stats=true - db_write_buffer_size=17179869184 - allow_ingest_behind=false - keep_log_file_num=1000 - max_bgerror_resume_count=2147483647 - allow_concurrent_memtable_write=true - recycle_log_file_num=0 - log_file_time_to_roll=0 - manifest_preallocation_size=4194304 - enable_write_thread_adaptive_yield=true - WAL_ttl_seconds=0 - max_manifest_file_size=1073741824 - wal_recovery_mode=kPointInTimeRecovery - enable_pipelined_write=false - write_thread_slow_yield_usec=3 - unordered_write=false - write_thread_max_yield_usec=100 - advise_random_on_open=true - info_log_level=INFO_LEVEL - - -[CFOptions "default"] - memtable_max_range_deletions=0 - compression_opts={checksum=false;max_dict_buffer_bytes=0;enabled=false;max_dict_bytes=0;max_compressed_bytes_per_kb=896;parallel_threads=1;zstd_max_train_bytes=0;level=32767;use_zstd_dict_trainer=true;strategy=0;window_bits=-14;} - paranoid_memory_checks=false - block_protection_bytes_per_key=0 - uncache_aggressiveness=0 - bottommost_file_compaction_delay=0 - memtable_protection_bytes_per_key=0 - experimental_mempurge_threshold=0.000000 - bottommost_compression=kZSTD - sample_for_compression=0 - prepopulate_blob_cache=kDisable - table_factory=BlockBasedTable - max_successive_merges=0 - max_write_buffer_number=2 - prefix_extractor=nullptr - memtable_huge_page_size=0 - write_buffer_size=33554427 - strict_max_successive_merges=false - blob_compaction_readahead_size=0 - arena_block_size=1048576 - level0_file_num_compaction_trigger=4 - report_bg_io_stats=true - inplace_update_num_locks=10000 - memtable_prefix_bloom_size_ratio=0.000000 - level0_stop_writes_trigger=36 
- blob_compression_type=kNoCompression - level0_slowdown_writes_trigger=20 - hard_pending_compaction_bytes_limit=274877906944 - target_file_size_multiplier=1 - bottommost_compression_opts={checksum=false;max_dict_buffer_bytes=0;enabled=false;max_dict_bytes=0;max_compressed_bytes_per_kb=896;parallel_threads=1;zstd_max_train_bytes=0;level=32767;use_zstd_dict_trainer=true;strategy=0;window_bits=-14;} - paranoid_file_checks=false - blob_garbage_collection_force_threshold=1.000000 - enable_blob_files=true - blob_file_starting_level=0 - soft_pending_compaction_bytes_limit=68719476736 - target_file_size_base=67108864 - max_compaction_bytes=1677721600 - disable_auto_compactions=false - min_blob_size=0 - memtable_whole_key_filtering=false - max_bytes_for_level_base=268435456 - last_level_temperature=kUnknown - compaction_options_fifo={file_temperature_age_thresholds=;allow_compaction=false;age_for_warm=0;max_table_files_size=1073741824;} - max_bytes_for_level_multiplier=10.000000 - max_bytes_for_level_multiplier_additional=1:1:1:1:1:1:1 - max_sequential_skip_in_iterations=8 - compression=kLZ4Compression - default_write_temperature=kUnknown - compaction_options_universal={incremental=false;compression_size_percent=-1;allow_trivial_move=false;max_size_amplification_percent=200;max_merge_width=4294967295;stop_style=kCompactionStopStyleTotalSize;min_merge_width=2;max_read_amp=-1;size_ratio=1;} - blob_garbage_collection_age_cutoff=0.250000 - ttl=2592000 - periodic_compaction_seconds=0 - blob_file_size=268435456 - enable_blob_garbage_collection=true - persist_user_defined_timestamps=true - preserve_internal_time_seconds=0 - preclude_last_level_data_seconds=0 - sst_partitioner_factory=nullptr - num_levels=7 - force_consistency_checks=true - memtable_insert_with_hint_prefix_extractor=nullptr - memtable_factory=SkipListFactory - max_write_buffer_number_to_maintain=0 - optimize_filters_for_hits=false - level_compaction_dynamic_level_bytes=true - default_temperature=kUnknown - 
inplace_update_support=false - merge_operator=nullptr - min_write_buffer_number_to_merge=1 - compaction_filter=nullptr - compaction_style=kCompactionStyleLevel - bloom_locality=0 - comparator=leveldb.BytewiseComparator - compaction_filter_factory=nullptr - max_write_buffer_size_to_maintain=134217728 - compaction_pri=kMinOverlappingRatio - -[TableOptions/BlockBasedTable "default"] - num_file_reads_for_auto_readahead=2 - initial_auto_readahead_size=8192 - metadata_cache_options={unpartitioned_pinning=kFallback;partition_pinning=kFallback;top_level_index_pinning=kFallback;} - enable_index_compression=true - verify_compression=false - prepopulate_block_cache=kDisable - format_version=6 - use_delta_encoding=true - pin_top_level_index_and_filter=true - read_amp_bytes_per_bit=0 - decouple_partitioned_filters=false - partition_filters=false - metadata_block_size=4096 - max_auto_readahead_size=262144 - index_block_restart_interval=1 - block_size_deviation=10 - block_size=4096 - detect_filter_construct_corruption=false - no_block_cache=false - checksum=kXXH3 - filter_policy=ribbonfilter:10 - data_block_hash_table_util_ratio=0.750000 - block_restart_interval=16 - index_type=kBinarySearch - pin_l0_filter_and_index_blocks_in_cache=false - data_block_index_type=kDataBlockBinarySearch - cache_index_and_filter_blocks_with_high_priority=true - whole_key_filtering=true - index_shortening=kShortenSeparators - cache_index_and_filter_blocks=true - block_align=false - optimize_filters_for_memory=true - flush_block_policy_factory=FlushBlockBySizePolicyFactory \ No newline at end of file diff --git a/memory-store/run.sh b/memory-store/run.sh deleted file mode 100755 index fa03f664d..000000000 --- a/memory-store/run.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/usr/bin/env bash - -set -eo pipefail - -# Create mount directory for service and RocksDB directory -mkdir -p ${COZO_MNT_DIR:=/data}/${COZO_ROCKSDB_DIR:-cozo.db} - -# Create auth token if not exists. 
-export COZO_AUTH_TOKEN=${COZO_AUTH_TOKEN:=`tr -dc A-Za-z0-9 $COZO_MNT_DIR/${COZO_ROCKSDB_DIR}.newrocksdb.cozo_auth - -# Copy options file to the RocksDB directory -cp /app/options $COZO_MNT_DIR/${COZO_ROCKSDB_DIR}/OPTIONS-000007 - -# Start server -${APP_HOME:=.}/bin/cozo server \ - --engine newrocksdb \ - --path $COZO_MNT_DIR/${COZO_ROCKSDB_DIR} \ - --bind 0.0.0.0 \ - --port ${COZO_PORT:=9070} \ - -c '{"enable_write_buffer_manager": true, "allow_stall": true, "lru_cache_mb": 4096, "write_buffer_mb": 4096}' diff --git a/monitoring/grafana/provisioning/dashboards/main.yaml b/monitoring/grafana/provisioning/dashboards/main.yaml new file mode 100755 index 000000000..e69de29bb diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..14a22b935 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,92 @@ +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +fix = true +unsafe-fixes = true + +# Enumerate all fixed violations. +show-fixes = true + +respect-gitignore = true + +# Enable preview features. +preview = true + +line-length = 96 +indent-width = 4 + +# Assume Python 3.12 +target-version = "py312" + +# Group violations by containing file. +output-format = "grouped" + +[lint] +# Enable preview features. 
+preview = true + +# TODO: Also exclude "**/autogen/*.py" +exclude = ["gunicorn_conf.py", "*.ipynb"] + +# TODO: Enable C90, S, B, ARG, PTH, ERA, PLW, FURB +select = ["F", "E1", "E2", "E3", "E4", "E5", "E7", "W", "FAST", "I", "UP", "ASYNC", "COM", "C4", "DTZ", "T10", "EM", "FA", "ISC", "ICN", "INP", "PIE", "Q", "RSE", "RET", "SLF", "SLOT", "SIM", "INT", "PD", "PLE", "FLY", "NPY", "PERF", "RUF"] +ignore = [ + "E402", # module-import-not-at-top-of-file + "E501", # line-too-long + "E722", # bare-except + "RUF001", # ambiguous-unicode-character-string + "RUF029", # unused-async + "ASYNC230", # blocking-open-in-async + "ASYNC109", # disallow-async-fns-with-timeout-param + "COM812", "ISC001", # conflict with the formatter +] + +fixable = ["ALL"] +unfixable = [] + +[format] +exclude = ["*.ipynb", "*.pyi", "*.pyc"] + +# Enable preview style formatting. +preview = true + +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. 
+line-ending = "auto" + +docstring-code-format = true +docstring-code-line-length = "dynamic" diff --git a/sdks/node-sdk b/sdks/node-sdk index 6cb742bba..ff0443945 160000 --- a/sdks/node-sdk +++ b/sdks/node-sdk @@ -1 +1 @@ -Subproject commit 6cb742bba2b408ef2ea070bfe284595bcdb974fe +Subproject commit ff0443945294c2638b7cc796456d271c1b669fa2 diff --git a/sdks/python-sdk b/sdks/python-sdk index 7f9bc0c59..96cd8e74b 160000 --- a/sdks/python-sdk +++ b/sdks/python-sdk @@ -1 +1 @@ -Subproject commit 7f9bc0c59d2e80f6e707f5dcc9e78fecd197c3ca +Subproject commit 96cd8e74bc95bc83158f07489916b7f4214aa002 diff --git a/typespec/agents/models.tsp b/typespec/agents/models.tsp index b2763e285..374383c16 100644 --- a/typespec/agents/models.tsp +++ b/typespec/agents/models.tsp @@ -20,7 +20,10 @@ model Agent { ...HasTimestamps; /** Name of the agent */ - name: identifierSafeUnicode = identifierSafeUnicode(""); + name: displayName; + + /** Canonical name of the agent */ + canonical_name?: canonicalName; /** About the agent */ about: string = ""; diff --git a/typespec/common/constants.tsp b/typespec/common/constants.tsp index bcd9e8bc1..da9ed226b 100644 --- a/typespec/common/constants.tsp +++ b/typespec/common/constants.tsp @@ -20,6 +20,12 @@ You are talking to a user {%- endif -%} {%- endif -%} +{{NEWLINE}} + +{%- if session.situation -%} +Situation: {{session.situation}} +{%- endif -%} + {{NEWLINE+NEWLINE}} {%- if agent.instructions -%} @@ -34,15 +40,6 @@ Instructions:{{NEWLINE}} {{NEWLINE}} {%- endif -%} -{%- if tools -%} -Tools:{{NEWLINE}} - {%- for tool in tools -%} - - {{tool.name + NEWLINE}} - {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%} - {%- endfor -%} -{{NEWLINE+NEWLINE}} -{%- endif -%} - {%- if docs -%} Relevant documents:{{NEWLINE}} {%- for doc in docs -%} diff --git a/typespec/common/scalars.tsp b/typespec/common/scalars.tsp index c718f6289..4e8f7b186 100644 --- a/typespec/common/scalars.tsp +++ b/typespec/common/scalars.tsp @@ -66,3 +66,20 @@ 
scalar PyExpression extends string; /** A valid jinja template. */ scalar JinjaTemplate extends string; + +/** + * For canonical names (machine-friendly identifiers) + * Must start with a letter and can only contain letters, numbers, and underscores + */ +@minLength(1) +@maxLength(255) +@pattern("^[a-zA-Z][a-zA-Z0-9_]*$") +scalar canonicalName extends string; + +/** + * For display names + * Must be between 1 and 255 characters + */ +@minLength(1) +@maxLength(255) +scalar displayName extends string; diff --git a/typespec/docs/models.tsp b/typespec/docs/models.tsp index 055fc2003..afc3b36fd 100644 --- a/typespec/docs/models.tsp +++ b/typespec/docs/models.tsp @@ -26,7 +26,23 @@ model Doc { /** Embeddings for the document */ @visibility("read") - embeddings?: float32[] | float32[][]; + embeddings: float32[] | float32[][] | null = null; + + @visibility("read") + /** Modality of the document */ + modality?: string; + + @visibility("read") + /** Language of the document */ + language?: string; + + @visibility("read") + /** Embedding model used for the document */ + embedding_model?: string; + + @visibility("read") + /** Dimensions of the embedding model */ + embedding_dimensions?: uint16; } /** Payload for creating a doc */ @@ -152,4 +168,4 @@ model DocSearchResponse { /** The time taken to search in seconds */ @minValueExclusive(0) time: float; -} \ No newline at end of file +} diff --git a/typespec/entries/models.tsp b/typespec/entries/models.tsp index 7f8c8b9fa..d7eae55e7 100644 --- a/typespec/entries/models.tsp +++ b/typespec/entries/models.tsp @@ -107,6 +107,7 @@ model BaseEntry { tokenizer: string; token_count: uint16; + "model": string = "gpt-4o-mini"; /** Tool calls generated by the model. 
*/ tool_calls?: ChosenToolCall[] | null = null; diff --git a/typespec/sessions/models.tsp b/typespec/sessions/models.tsp index f15453a5f..68b328af0 100644 --- a/typespec/sessions/models.tsp +++ b/typespec/sessions/models.tsp @@ -60,8 +60,11 @@ model Session { @visibility("create") agents?: uuid[]; - /** A specific situation that sets the background for this session */ - situation: string = defaultSessionSystemMessage; + /** Session situation */ + situation: string | null = null; + + /** A specific system prompt template that sets the background for this session */ + system_template: string = defaultSessionSystemMessage; /** Summary (null at the beginning) - generated automatically after every interaction */ @visibility("read") @@ -83,6 +86,9 @@ model Session { * If a tool call is not made, the model's output will be returned as is. */ auto_run_tools: boolean = false; + /** Whether to forward tool calls to the model */ + forward_tool_calls: boolean = false; + recall_options?: RecallOptions | null = null; ...HasId; diff --git a/typespec/tasks/models.tsp b/typespec/tasks/models.tsp index c3b301bd2..ca6b72e00 100644 --- a/typespec/tasks/models.tsp +++ b/typespec/tasks/models.tsp @@ -50,9 +50,14 @@ model ToolRef { /** Object describing a Task */ model Task { - @visibility("read", "create") - name: string; + /** The name of the task. */ + @visibility("read", "create", "update") + name: displayName; + + /** The canonical name of the task. */ + canonical_name?: canonicalName; + /** The description of the task. */ description: string = ""; /** The entrypoint of the task. 
*/ diff --git a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml index eb58eeef2..d9aab47ee 100644 --- a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml +++ b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml @@ -1449,9 +1449,12 @@ components: readOnly: true name: allOf: - - $ref: '#/components/schemas/Common.identifierSafeUnicode' + - $ref: '#/components/schemas/Common.displayName' description: Name of the agent - default: '' + canonical_name: + allOf: + - $ref: '#/components/schemas/Common.canonicalName' + description: Canonical name of the agent about: type: string description: About the agent @@ -1485,9 +1488,12 @@ components: additionalProperties: {} name: allOf: - - $ref: '#/components/schemas/Common.identifierSafeUnicode' + - $ref: '#/components/schemas/Common.displayName' description: Name of the agent - default: '' + canonical_name: + allOf: + - $ref: '#/components/schemas/Common.canonicalName' + description: Canonical name of the agent about: type: string description: About the agent @@ -1525,9 +1531,12 @@ components: additionalProperties: {} name: allOf: - - $ref: '#/components/schemas/Common.identifierSafeUnicode' + - $ref: '#/components/schemas/Common.displayName' description: Name of the agent - default: '' + canonical_name: + allOf: + - $ref: '#/components/schemas/Common.canonicalName' + description: Canonical name of the agent about: type: string description: About the agent @@ -1558,9 +1567,12 @@ components: additionalProperties: {} name: allOf: - - $ref: '#/components/schemas/Common.identifierSafeUnicode' + - $ref: '#/components/schemas/Common.displayName' description: Name of the agent - default: '' + canonical_name: + allOf: + - $ref: '#/components/schemas/Common.canonicalName' + description: Canonical name of the agent about: type: string description: About the agent @@ -1595,9 +1607,12 @@ components: additionalProperties: {} name: allOf: - - $ref: 
'#/components/schemas/Common.identifierSafeUnicode' + - $ref: '#/components/schemas/Common.displayName' description: Name of the agent - default: '' + canonical_name: + allOf: + - $ref: '#/components/schemas/Common.canonicalName' + description: Canonical name of the agent about: type: string description: About the agent @@ -2706,6 +2721,21 @@ components: description: IDs (if any) of jobs created as part of this request default: [] readOnly: true + Common.canonicalName: + type: string + minLength: 1 + maxLength: 255 + pattern: ^[a-zA-Z][a-zA-Z0-9_]*$ + description: |- + For canonical names (machine-friendly identifiers) + Must start with a letter and can only contain letters, numbers, and underscores + Common.displayName: + type: string + minLength: 1 + maxLength: 255 + description: |- + For display names + Must be between 1 and 255 characters Common.identifierSafeUnicode: type: string maxLength: 120 @@ -2808,6 +2838,7 @@ components: - created_at - title - content + - embeddings properties: id: allOf: @@ -2844,7 +2875,26 @@ components: items: type: number format: float + nullable: true description: Embeddings for the document + default: null + readOnly: true + modality: + type: string + description: Modality of the document + readOnly: true + language: + type: string + description: Language of the document + readOnly: true + embedding_model: + type: string + description: Embedding model used for the document + readOnly: true + embedding_dimensions: + type: integer + format: uint16 + description: Dimensions of the embedding model readOnly: true Docs.DocOwner: type: object @@ -3034,6 +3084,7 @@ components: - source - tokenizer - token_count + - model - timestamp properties: role: @@ -3277,6 +3328,9 @@ components: token_count: type: integer format: uint16 + model: + type: string + default: gpt-4o-mini tool_calls: type: array items: @@ -3727,10 +3781,12 @@ components: required: - id - situation + - system_template - render_templates - token_budget - context_overflow - 
auto_run_tools + - forward_tool_calls properties: id: $ref: '#/components/schemas/Common.uuid' @@ -3752,7 +3808,12 @@ components: $ref: '#/components/schemas/Common.uuid' situation: type: string - description: A specific situation that sets the background for this session + nullable: true + description: Session situation + default: null + system_template: + type: string + description: A specific system prompt template that sets the background for this session default: |- {%- if agent.name -%} You are {{agent.name}}.{{" "}} @@ -3769,6 +3830,12 @@ components: {%- endif -%} {%- endif -%} + {{NEWLINE}} + + {%- if session.situation -%} + Situation: {{session.situation}} + {%- endif -%} + {{NEWLINE+NEWLINE}} {%- if agent.instructions -%} @@ -3783,15 +3850,6 @@ components: {{NEWLINE}} {%- endif -%} - {%- if tools -%} - Tools:{{NEWLINE}} - {%- for tool in tools -%} - - {{tool.name + NEWLINE}} - {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%} - {%- endfor -%} - {{NEWLINE+NEWLINE}} - {%- endif -%} - {%- if docs -%} Relevant documents:{{NEWLINE}} {%- for doc in docs -%} @@ -3831,6 +3889,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. 
default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -3846,10 +3908,12 @@ components: type: object required: - situation + - system_template - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls properties: user: allOf: @@ -3869,7 +3933,12 @@ components: $ref: '#/components/schemas/Common.uuid' situation: type: string - description: A specific situation that sets the background for this session + nullable: true + description: Session situation + default: null + system_template: + type: string + description: A specific system prompt template that sets the background for this session default: |- {%- if agent.name -%} You are {{agent.name}}.{{" "}} @@ -3886,6 +3955,12 @@ components: {%- endif -%} {%- endif -%} + {{NEWLINE}} + + {%- if session.situation -%} + Situation: {{session.situation}} + {%- endif -%} + {{NEWLINE+NEWLINE}} {%- if agent.instructions -%} @@ -3900,15 +3975,6 @@ components: {{NEWLINE}} {%- endif -%} - {%- if tools -%} - Tools:{{NEWLINE}} - {%- for tool in tools -%} - - {{tool.name + NEWLINE}} - {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%} - {%- endfor -%} - {{NEWLINE+NEWLINE}} - {%- endif -%} - {%- if docs -%} Relevant documents:{{NEWLINE}} {%- for doc in docs -%} @@ -3948,6 +4014,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. 
default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4008,7 +4078,12 @@ components: properties: situation: type: string - description: A specific situation that sets the background for this session + nullable: true + description: Session situation + default: null + system_template: + type: string + description: A specific system prompt template that sets the background for this session default: |- {%- if agent.name -%} You are {{agent.name}}.{{" "}} @@ -4025,6 +4100,12 @@ components: {%- endif -%} {%- endif -%} + {{NEWLINE}} + + {%- if session.situation -%} + Situation: {{session.situation}} + {%- endif -%} + {{NEWLINE+NEWLINE}} {%- if agent.instructions -%} @@ -4039,15 +4120,6 @@ components: {{NEWLINE}} {%- endif -%} - {%- if tools -%} - Tools:{{NEWLINE}} - {%- for tool in tools -%} - - {{tool.name + NEWLINE}} - {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%} - {%- endfor -%} - {{NEWLINE+NEWLINE}} - {%- endif -%} - {%- if docs -%} Relevant documents:{{NEWLINE}} {%- for doc in docs -%} @@ -4087,6 +4159,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. 
default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4155,18 +4231,25 @@ components: type: object required: - situation + - system_template - summary - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls - id - created_at - updated_at properties: situation: type: string - description: A specific situation that sets the background for this session + nullable: true + description: Session situation + default: null + system_template: + type: string + description: A specific system prompt template that sets the background for this session default: |- {%- if agent.name -%} You are {{agent.name}}.{{" "}} @@ -4183,6 +4266,12 @@ components: {%- endif -%} {%- endif -%} + {{NEWLINE}} + + {%- if session.situation -%} + Situation: {{session.situation}} + {%- endif -%} + {{NEWLINE+NEWLINE}} {%- if agent.instructions -%} @@ -4197,15 +4286,6 @@ components: {{NEWLINE}} {%- endif -%} - {%- if tools -%} - Tools:{{NEWLINE}} - {%- for tool in tools -%} - - {{tool.name + NEWLINE}} - {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%} - {%- endfor -%} - {{NEWLINE+NEWLINE}} - {%- endif -%} - {%- if docs -%} Relevant documents:{{NEWLINE}} {%- for doc in docs -%} @@ -4251,6 +4331,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. 
default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4326,14 +4410,21 @@ components: type: object required: - situation + - system_template - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls properties: situation: type: string - description: A specific situation that sets the background for this session + nullable: true + description: Session situation + default: null + system_template: + type: string + description: A specific system prompt template that sets the background for this session default: |- {%- if agent.name -%} You are {{agent.name}}.{{" "}} @@ -4350,6 +4441,12 @@ components: {%- endif -%} {%- endif -%} + {{NEWLINE}} + + {%- if session.situation -%} + Situation: {{session.situation}} + {%- endif -%} + {{NEWLINE+NEWLINE}} {%- if agent.instructions -%} @@ -4364,15 +4461,6 @@ components: {{NEWLINE}} {%- endif -%} - {%- if tools -%} - Tools:{{NEWLINE}} - {%- for tool in tools -%} - - {{tool.name + NEWLINE}} - {%- if tool.description -%}: {{tool.description + NEWLINE}}{%- endif -%} - {%- endfor -%} - {{NEWLINE+NEWLINE}} - {%- endif -%} - {%- if docs -%} Relevant documents:{{NEWLINE}} {%- for doc in docs -%} @@ -4412,6 +4500,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4487,9 +4579,16 @@ components: - inherit_tools properties: name: - type: string + allOf: + - $ref: '#/components/schemas/Common.displayName' + description: The name of the task. + canonical_name: + allOf: + - $ref: '#/components/schemas/Common.canonicalName' + description: The canonical name of the task. 
description: type: string + description: The description of the task. default: '' main: type: array @@ -5103,8 +5202,17 @@ components: Tasks.PatchTaskRequest: type: object properties: + name: + allOf: + - $ref: '#/components/schemas/Common.displayName' + description: The name of the task. + canonical_name: + allOf: + - $ref: '#/components/schemas/Common.canonicalName' + description: The canonical name of the task. description: type: string + description: The description of the task. default: '' main: type: array @@ -5899,9 +6007,16 @@ components: - updated_at properties: name: - type: string + allOf: + - $ref: '#/components/schemas/Common.displayName' + description: The name of the task. + canonical_name: + allOf: + - $ref: '#/components/schemas/Common.canonicalName' + description: The canonical name of the task. description: type: string + description: The description of the task. default: '' main: type: array @@ -6246,14 +6361,24 @@ components: Tasks.UpdateTaskRequest: type: object required: + - name - description - main - input_schema - tools - inherit_tools properties: + name: + allOf: + - $ref: '#/components/schemas/Common.displayName' + description: The name of the task. + canonical_name: + allOf: + - $ref: '#/components/schemas/Common.canonicalName' + description: The canonical name of the task. description: type: string + description: The description of the task. default: '' main: type: array