Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Add doc search system tool #604

Merged
merged 6 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 83 additions & 4 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from uuid import UUID

from beartype import beartype
from fastapi.background import BackgroundTasks
from temporalio import activity

from ..autogen.Docs import CreateDocRequest, HybridDocSearchRequest, TextOnlyDocSearchRequest, VectorDocSearchRequest
from ..autogen.Tools import SystemDef
from ..common.protocol.tasks import StepContext
from ..env import testing
Expand Down Expand Up @@ -31,6 +33,8 @@
from ..models.user.get_user import get_user as get_user_query
from ..models.user.list_users import list_users as list_users_query
from ..models.user.update_user import update_user as update_user_query
from ..routers.docs.create_doc import create_agent_doc, create_user_doc
from ..routers.docs.search_docs import search_agent_docs, search_user_docs


@beartype
Expand Down Expand Up @@ -63,17 +67,54 @@ async def execute_system(
agent_doc_args = {
**{
"owner_type": "agent",
"owner_id": arguments.pop("agent_id"),
"owner_id": arguments["agent_id"],
},
**arguments,
}
agent_doc_args.pop("agent_id")

if system.operation == "list":
return list_docs_query(**agent_doc_args)

elif system.operation == "create":
return create_doc_query(**agent_doc_args)
# The `create_agent_doc` function requires `x_developer_id` instead of `developer_id`.
arguments["x_developer_id"] = arguments.pop("developer_id")
return await create_agent_doc(
data=CreateDocRequest(**arguments.pop("data")),
background_tasks=BackgroundTasks(),
**arguments,
)

elif system.operation == "delete":
return delete_doc_query(**agent_doc_args)

elif system.operation == "search":
# The `search_agent_docs` function requires `x_developer_id` instead of `developer_id`.
arguments["x_developer_id"] = arguments.pop("developer_id")

if "text" in arguments and "vector" in arguments:
search_params = HybridDocSearchRequest(
text=arguments.pop("text"),
vector=arguments.pop("vector"),
limit=arguments.get("limit", 10),
)

elif "text" in arguments:
search_params = TextOnlyDocSearchRequest(
text=arguments.pop("text"),
limit=arguments.get("limit", 10),
)
elif "vector" in arguments:
search_params = VectorDocSearchRequest(
vector=arguments.pop("vector"),
limit=arguments.get("limit", 10),
)

return await search_agent_docs(
search_params=search_params,
**arguments,
)

# NO SUBRESOURCE
elif system.subresource == None:
if system.operation == "list":
Expand All @@ -95,17 +136,55 @@ async def execute_system(
user_doc_args = {
**{
"owner_type": "user",
"owner_id": arguments.pop("user_id"),
"owner_id": arguments["user_id"],
},
**arguments,
}
user_doc_args.pop("user_id")

if system.operation == "list":
return list_docs_query(**user_doc_args)

elif system.operation == "create":
return create_doc_query(**user_doc_args)
# The `create_user_doc` function requires `x_developer_id` instead of `developer_id`.
arguments["x_developer_id"] = arguments.pop("developer_id")
return await create_user_doc(
data=CreateDocRequest(**arguments.pop("data")),
background_tasks=BackgroundTasks(),
**arguments,
)

elif system.operation == "delete":
return delete_doc_query(**user_doc_args)

elif system.operation == "search":
# The `search_user_docs` function requires `x_developer_id` instead of `developer_id`.
arguments["x_developer_id"] = arguments.pop("developer_id")


if "text" in arguments and "vector" in arguments:
search_params = HybridDocSearchRequest(
text=arguments.pop("text"),
vector=arguments.pop("vector"),
limit=arguments.get("limit", 10),
)

elif "text" in arguments:
search_params = TextOnlyDocSearchRequest(
text=arguments.pop("text"),
limit=arguments.get("limit", 10),
)
elif "vector" in arguments:
search_params = VectorDocSearchRequest(
vector=arguments.pop("vector"),
limit=arguments.get("limit", 10),
)

return await search_user_docs(
search_params=search_params,
**arguments,
)

# NO SUBRESOURCE
elif system.subresource == None:
if system.operation == "list":
Expand Down
22 changes: 11 additions & 11 deletions agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,29 +46,29 @@ async def base_evaluate(
evaluator = get_evaluator(names=values, extra_functions=extra_lambdas)

try:
result = None
match exprs:
case str():
return evaluator.eval(exprs)

result = evaluator.eval(exprs)
case list():
return [evaluator.eval(expr) for expr in exprs]

case dict() as d if all(isinstance(v, dict) for v in d.values()):
return {
result = [evaluator.eval(expr) for expr in exprs]
case dict() as d if all(
isinstance(v, dict) or isinstance(v, str) for v in d.values()
):
result = {
k: {ik: evaluator.eval(iv) for ik, iv in v.items()}
if isinstance(v, dict)
else evaluator.eval(v)
for k, v in d.items()
}

case dict():
return {k: evaluator.eval(v) for k, v in exprs.items()}

case _:
raise ValueError(f"Invalid expression: {exprs}")

return result

except BaseException as e:
if activity.in_activity():
activity.logger.error(f"Error in base_evaluate: {e}")

raise


Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/models/docs/list_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def list_docs(
sort_by: Literal["created_at"] = "created_at",
direction: Literal["asc", "desc"] = "desc",
metadata_filter: dict[str, Any] = {},
include_without_embeddings: bool = False,
) -> tuple[list[str], dict]:
# Transforms the metadata_filter dictionary into a string representation for the datalog query.
metadata_filter_str = ", ".join(
Expand All @@ -70,6 +71,7 @@ def list_docs(
content,
embedding,
}},
{"" if include_without_embeddings else "not is_null(embedding),"}
snippet_data = [index, content, embedding]

?[
Expand Down
Loading