Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

creatorrr/f add missing tests #450

Merged
merged 2 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 2 additions & 30 deletions agents-api/agents_api/autogen/Tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,35 +46,7 @@ class FunctionDef(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
name: Annotated[str, Field("overriden", pattern="^[^\\W0-9]\\w*$")]
"""
DO NOT USE: This will be overriden by the tool name. Here only for compatibility reasons.
"""
description: Annotated[
str | None,
Field(
None,
pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$",
),
]
"""
Description of the function
"""
parameters: dict[str, Any]
"""
The parameters the function accepts
"""


class FunctionDefUpdate(BaseModel):
"""
Function definition
"""

model_config = ConfigDict(
populate_by_name=True,
)
name: Annotated[str, Field("overriden", pattern="^[^\\W0-9]\\w*$")]
name: Any | None = None
"""
DO NOT USE: This will be overriden by the tool name. Here only for compatibility reasons.
"""
Expand Down Expand Up @@ -124,7 +96,7 @@ class PatchToolRequest(BaseModel):
"""
Name of the tool (must be unique for this agent and a valid python identifier string )
"""
function: FunctionDefUpdate | None = None
function: FunctionDef | None = None
integration: Any | None = None
system: Any | None = None
api_call: Any | None = None
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/agent/create_or_update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_or_update_agent(
developer_id: UUID,
agent_id: UUID,
data: CreateOrUpdateAgentRequest,
) -> tuple[list[str], dict]:
) -> tuple[list[str | None], dict]:
"""
Constructs and executes a datalog query to create a new agent in the database.

Expand Down Expand Up @@ -123,7 +123,7 @@ def create_or_update_agent(

queries = [
verify_developer_id_query(developer_id),
default_settings and default_settings_query,
default_settings_query if default_settings else None,
agent_query,
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(Transition, transform=lambda d: {"id": d["transition_id"], **d})
@wrap_in_class(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The one=True parameter in the @wrap_in_class decorator is used incorrectly here. It should be set to False or removed because the function can potentially handle multiple transitions, not just one. This could lead to incorrect assumptions about the return type when the function is used elsewhere in the code.

Suggested change
@wrap_in_class(
Transition, transform=lambda d: {"id": d["transition_id"], **d}, one=False

Transition, transform=lambda d: {"id": d["transition_id"], **d}, one=True
)
@cozo_query
@beartype
def create_execution_transition(
Expand Down
33 changes: 22 additions & 11 deletions agents-api/agents_api/models/tools/patch_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
@cozo_query
@beartype
def patch_tool(
*, developer_id: UUID, agent_id: UUID, tool_id: UUID, patch_tool: PatchToolRequest
*, developer_id: UUID, agent_id: UUID, tool_id: UUID, data: PatchToolRequest
) -> tuple[list[str], dict]:
"""
# Execute the datalog query and return the results as a DataFrame
Expand All @@ -41,14 +41,17 @@ def patch_tool(
Parameters:
- agent_id (UUID): The unique identifier of the agent.
- tool_id (UUID): The unique identifier of the tool to be updated.
- patch_tool (PatchToolRequest): The request payload containing the updated tool information.
- data (PatchToolRequest): The request payload containing the updated tool information.

Returns:
- ResourceUpdatedResponse: The updated tool data.
"""

agent_id = str(agent_id)
tool_id = str(tool_id)

# Extract the tool data from the payload
patch_data = patch_tool.model_dump(exclude_none=True)
patch_data = data.model_dump(exclude_none=True)

# Assert that only one of the tool type fields is present
tool_specs = [
Expand All @@ -64,28 +67,33 @@ def patch_tool(
patch_data["type"] = patch_data.get("type", tool_type)
assert patch_data["type"] == tool_type, "Invalid tool update"

if tool_spec is not None:
# Rename the tool definition to 'spec'
patch_data["spec"] = tool_spec
tool_spec = tool_spec or {}
if tool_spec:
del patch_data[tool_type]

tool_cols, tool_vals = cozo_process_mutate_data(
{
**patch_data,
"agent_id": str(agent_id),
"tool_id": str(tool_id),
"agent_id": agent_id,
"tool_id": tool_id,
}
)

# Construct the datalog query for updating the tool information
patch_query = f"""
input[{tool_cols}] <- $input

?[{tool_cols}, updated_at] :=
?[{tool_cols}, spec, updated_at] :=
*tools {{
agent_id: to_uuid($agent_id),
tool_id: to_uuid($tool_id),
spec: old_spec,
}},
input[{tool_cols}],
spec = concat(old_spec, $spec),
updated_at = now()

:update tools {{ {tool_cols}, updated_at }}
:update tools {{ {tool_cols}, spec, updated_at }}
:returning
"""

Expand All @@ -95,4 +103,7 @@ def patch_tool(
patch_query,
]

return (queries, dict(input=tool_vals))
return (
queries,
dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id),
)
93 changes: 83 additions & 10 deletions agents-api/agents_api/models/tools/update_tool.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ...autogen.openapi_model import (
PatchToolRequest,
ResourceUpdatedResponse,
UpdateToolRequest,
)
from .patch_tool import patch_tool
from ...common.utils.cozo import cozo_process_mutate_data
from ..utils import (
cozo_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
verify_developer_owns_resource_query,
wrap_in_class,
)


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
@wrap_in_class(
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["tool_id"], "jobs": [], **d},
)
@cozo_query
@beartype
def update_tool(
*,
Expand All @@ -18,12 +41,62 @@ def update_tool(
tool_id: UUID,
data: UpdateToolRequest,
**kwargs,
) -> ResourceUpdatedResponse:
# Same as patch_tool_query, but with a different request payload
return patch_tool(
developer_id=developer_id,
agent_id=agent_id,
tool_id=tool_id,
patch_tool=PatchToolRequest(**data.model_dump()),
**kwargs,
) -> tuple[list[str], dict]:
agent_id = str(agent_id)
tool_id = str(tool_id)

# Extract the tool data from the payload
update_data = data.model_dump(exclude_none=True)

# Assert that only one of the tool type fields is present
tool_specs = [
(tool_type, update_data.get(tool_type))
for tool_type in ["function", "integration", "system", "api_call"]
if update_data.get(tool_type) is not None
]

assert len(tool_specs) <= 1, "Invalid tool update"
tool_type, tool_spec = tool_specs[0] if tool_specs else (None, None)

if tool_type is not None:
update_data["type"] = update_data.get("type", tool_type)
assert update_data["type"] == tool_type, "Invalid tool update"

update_data["spec"] = tool_spec
del update_data[tool_type]

tool_cols, tool_vals = cozo_process_mutate_data(
{
**update_data,
"agent_id": agent_id,
"tool_id": tool_id,
}
)

# Construct the datalog query for updating the tool information
patch_query = f"""
input[{tool_cols}] <- $input

?[{tool_cols}, created_at, updated_at] :=
*tools {{
agent_id: to_uuid($agent_id),
tool_id: to_uuid($tool_id),
created_at
}},
input[{tool_cols}],
updated_at = now()

:put tools {{ {tool_cols}, created_at, updated_at }}
:returning
"""

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id),
patch_query,
]

return (
queries,
dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id),
)
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def cozo_query(
func: Callable[P, tuple[str | list[str], dict]] | None = None,
debug: bool | None = None,
):
def cozo_query_dec(func: Callable[P, tuple[str | list[str], dict]]):
def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):
"""
Decorator that wraps a function that takes arbitrary arguments, and
returns a (query string, variables) tuple.
Expand All @@ -135,7 +135,7 @@ def wrapper(
if isinstance(queries, str):
query = queries
else:
queries = [query for query in queries if query]
queries = [str(query) for query in queries if query]
query = "}\n\n{\n".join(queries)
query = f"{{ {query} }}"

Expand Down
Loading
Loading