Skip to content

Commit

Permalink
feat: Add update tool query
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Dec 20, 2024
1 parent 32dbbba commit 281e1a8
Showing 1 changed file with 41 additions and 52 deletions.
93 changes: 41 additions & 52 deletions agents-api/agents_api/queries/tools/update_tool.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,55 @@
from typing import Any, TypeVar
from uuid import UUID

import sqlvalidator
from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ...autogen.openapi_model import (
ResourceUpdatedResponse,
UpdateToolRequest,
)
from ...common.utils.cozo import cozo_process_mutate_data
from ...exceptions import InvalidSQLQuery
from ...metrics.counters import increase_counter
from ..utils import (
cozo_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
verify_developer_owns_resource_query,
pg_query,
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
sql_query = sqlvalidator.parse("""
UPDATE tools
SET
type = $4,
name = $5,
description = $6,
spec = $7
WHERE
developer_id = $1 AND
agent_id = $2 AND
tool_id = $3
RETURNING *;
""")

if not sql_query.is_valid():
raise InvalidSQLQuery("update_tool")


# @rewrap_exceptions(
# {
# QueryException: partialclass(HTTPException, status_code=400),
# ValidationError: partialclass(HTTPException, status_code=400),
# TypeError: partialclass(HTTPException, status_code=400),
# }
# )
@wrap_in_class(
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["tool_id"], "jobs": [], **d},
_kind="inserted",
)
@cozo_query
@pg_query
@increase_counter("update_tool")
@beartype
def update_tool(
Expand All @@ -48,7 +59,8 @@ def update_tool(
tool_id: UUID,
data: UpdateToolRequest,
**kwargs,
) -> tuple[list[str], dict]:
) -> tuple[list[str], list]:
developer_id = str(developer_id)
agent_id = str(agent_id)
tool_id = str(tool_id)

Expand All @@ -72,38 +84,15 @@ def update_tool(
update_data["spec"] = tool_spec
del update_data[tool_type]

tool_cols, tool_vals = cozo_process_mutate_data(
{
**update_data,
"agent_id": agent_id,
"tool_id": tool_id,
}
)

# Construct the datalog query for updating the tool information
patch_query = f"""
input[{tool_cols}] <- $input
?[{tool_cols}, created_at, updated_at] :=
*tools {{
agent_id: to_uuid($agent_id),
tool_id: to_uuid($tool_id),
created_at
}},
input[{tool_cols}],
updated_at = now()
:put tools {{ {tool_cols}, created_at, updated_at }}
:returning
"""

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id),
patch_query,
]

return (
queries,
dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id),
sql_query.format(),
[
developer_id,
agent_id,
tool_id,
tool_type,
data.name,
data.description,
tool_spec,
],
)

0 comments on commit 281e1a8

Please sign in to comment.