Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agents-api): Fix doc recall using search by text #506

Merged
merged 4 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agents-api/agents_api/clients/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
@wraps(_acompletion)
@beartype
async def acompletion(
*, model: str, messages: list[dict], custom_api_key: str = None, **kwargs
*, model: str, messages: list[dict], custom_api_key: None | str = None, **kwargs
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
) -> ModelResponse | CustomStreamWrapper:

supported_params = get_supported_openai_params(model)
Expand Down
29 changes: 24 additions & 5 deletions agents-api/agents_api/models/docs/search_docs_by_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def search_docs_by_embedding(
"""

collect_query = """
m[
n[
doc_id,
owner_type,
owner_id,
collect(snippet),
unique(snippet_data),
distance,
title,
] :=
Expand All @@ -209,11 +209,30 @@ def search_docs_by_embedding(
snippet_data,
distance,
title,
}, snippet = {
"index": snippet_data->0,
"content": snippet_data->1,
}

m[
doc_id,
owner_type,
owner_id,
collect(snippet),
distance,
title,
] :=
n[
doc_id,
owner_type,
owner_id,
snippet_data,
distance,
title,
],
snippet = {
"index": snippet_datum->0,
"content": snippet_datum->1
},
snippet_datum in snippet_data

?[
id,
owner_type,
Expand Down
5 changes: 5 additions & 0 deletions agents-api/agents_api/models/docs/search_docs_by_text.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module contains functions for searching documents in the CozoDB based on full-text queries."""

import json
from typing import Any, Literal, TypeVar
from uuid import UUID

Expand Down Expand Up @@ -61,6 +62,10 @@ def search_docs_by_text(
[owner_type, str(owner_id)] for owner_type, owner_id in owners
]

# Wrap the JSON-quoted query text in NEAR/3(...) so that arbitrary search terms match when they occur within 3 words of each other
# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts
query = f"NEAR/3({json.dumps(query)})"

# Construct the datalog query for searching document snippets
search_query = f"""
owners[owner_type, owner_id] <- $owners
Expand Down
2 changes: 0 additions & 2 deletions agents-api/agents_api/models/docs/search_docs_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ def dbsf_fuse(
"""
all_docs = {doc.id: doc for doc in text_results + embedding_results}

assert all(doc.distance is not None in all_docs for doc in text_results)

text_scores: dict[UUID, float] = {
doc.id: -(doc.distance or 0.0) for doc in text_results
}
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/agents/create_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID
from starlette.status import HTTP_201_CREATED

import agents_api.models as models
Expand All @@ -16,7 +16,7 @@

@router.post("/agents", status_code=HTTP_201_CREATED, tags=["agents"])
async def create_agent(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
data: CreateAgentRequest,
) -> ResourceCreatedResponse:
# TODO: Validate model name
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/agents/create_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID
from starlette.status import HTTP_201_CREATED

import agents_api.models as models
Expand All @@ -18,7 +18,7 @@
@router.post("/agents/{agent_id}/tools", status_code=HTTP_201_CREATED, tags=["agents"])
async def create_agent_tool(
agent_id: UUID,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
data: CreateToolRequest,
) -> ResourceCreatedResponse:
tool = models.tools.create_tools(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID
from starlette.status import HTTP_201_CREATED

import agents_api.models as models
Expand All @@ -19,7 +19,7 @@
async def create_or_update_agent(
agent_id: UUID,
data: CreateOrUpdateAgentRequest,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
) -> ResourceCreatedResponse:
# TODO: Validate model name
agent = models.agent.create_or_update_agent(
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/agents/delete_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID
from starlette.status import HTTP_202_ACCEPTED

from ...autogen.openapi_model import ResourceDeletedResponse
Expand All @@ -12,6 +12,6 @@

@router.delete("/agents/{agent_id}", status_code=HTTP_202_ACCEPTED, tags=["agents"])
async def delete_agent(
agent_id: UUID4, x_developer_id: Annotated[UUID4, Depends(get_developer_id)]
agent_id: UUID, x_developer_id: Annotated[UUID, Depends(get_developer_id)]
) -> ResourceDeletedResponse:
return delete_agent_query(developer_id=x_developer_id, agent_id=agent_id)
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/agents/delete_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID

from ...autogen.openapi_model import ResourceDeletedResponse
from ...dependencies.developer_id import get_developer_id
Expand All @@ -14,7 +14,7 @@
async def delete_agent_tool(
agent_id: UUID,
tool_id: UUID,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
) -> ResourceDeletedResponse:
return delete_tool(
developer_id=x_developer_id,
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/routers/agents/get_agent_details.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID

from ...autogen.openapi_model import Agent
from ...dependencies.developer_id import get_developer_id
Expand All @@ -11,7 +11,7 @@

@router.get("/agents/{agent_id}", tags=["agents"])
async def get_agent_details(
agent_id: UUID4,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
agent_id: UUID,
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
) -> Agent:
return get_agent_query(developer_id=x_developer_id, agent_id=agent_id)
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/agents/list_agent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID

from ...autogen.openapi_model import ListResponse, Tool
from ...dependencies.developer_id import get_developer_id
Expand All @@ -12,7 +12,7 @@

@router.get("/agents/{agent_id}/tools", tags=["agents"])
async def list_agent_tools(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
agent_id: UUID,
limit: int = 100,
offset: int = 0,
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/agents/list_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Annotated, Literal

from fastapi import Depends, HTTPException, status
from pydantic import UUID4
from uuid import UUID

from ...autogen.openapi_model import Agent, ListResponse
from ...dependencies.developer_id import get_developer_id
Expand All @@ -13,7 +13,7 @@

@router.get("/agents", tags=["agents"])
async def list_agents(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
limit: int = 100,
offset: int = 0,
sort_by: Literal["created_at", "updated_at"] = "created_at",
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/routers/agents/patch_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID
from starlette.status import HTTP_200_OK

from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse
Expand All @@ -17,8 +17,8 @@
tags=["agents"],
)
async def patch_agent(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
agent_id: UUID4,
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
agent_id: UUID,
data: PatchAgentRequest,
) -> ResourceUpdatedResponse:
return patch_agent_query(
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/agents/patch_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID

from ...autogen.openapi_model import (
PatchToolRequest,
Expand All @@ -15,7 +15,7 @@

@router.patch("/agents/{agent_id}/tools/{tool_id}", tags=["agents"])
async def patch_agent_tool(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
agent_id: UUID,
tool_id: UUID,
data: PatchToolRequest,
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/routers/agents/update_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID
from starlette.status import HTTP_200_OK

from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest
Expand All @@ -17,8 +17,8 @@
tags=["agents"],
)
async def update_agent(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
agent_id: UUID4,
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
agent_id: UUID,
data: UpdateAgentRequest,
) -> ResourceUpdatedResponse:
return update_agent_query(
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/agents/update_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID

from ...autogen.openapi_model import (
ResourceUpdatedResponse,
Expand All @@ -15,7 +15,7 @@

@router.put("/agents/{agent_id}/tools/{tool_id}", tags=["agents"])
async def update_agent_tool(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
agent_id: UUID,
tool_id: UUID,
data: UpdateToolRequest,
Expand Down
10 changes: 5 additions & 5 deletions agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID, uuid4

from fastapi import BackgroundTasks, Depends
from pydantic import UUID4
from uuid import UUID
from starlette.status import HTTP_201_CREATED
from temporalio.client import Client as TemporalClient

Expand Down Expand Up @@ -54,9 +54,9 @@ async def run_embed_docs_task(

@router.post("/users/{user_id}/docs", status_code=HTTP_201_CREATED, tags=["docs"])
async def create_user_doc(
user_id: UUID4,
user_id: UUID,
data: CreateDocRequest,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
background_tasks: BackgroundTasks,
) -> ResourceCreatedResponse:
doc = create_doc_query(
Expand Down Expand Up @@ -84,9 +84,9 @@ async def create_user_doc(

@router.post("/agents/{agent_id}/docs", status_code=HTTP_201_CREATED, tags=["docs"])
async def create_agent_doc(
agent_id: UUID4,
agent_id: UUID,
data: CreateDocRequest,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
background_tasks: BackgroundTasks,
) -> ResourceCreatedResponse:
doc = create_doc_query(
Expand Down
14 changes: 7 additions & 7 deletions agents-api/agents_api/routers/docs/delete_doc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID
from starlette.status import HTTP_202_ACCEPTED

from ...autogen.openapi_model import ResourceDeletedResponse
Expand All @@ -14,9 +14,9 @@
"/agents/{agent_id}/docs/{doc_id}", status_code=HTTP_202_ACCEPTED, tags=["docs"]
)
async def delete_agent_doc(
doc_id: UUID4,
agent_id: UUID4,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
doc_id: UUID,
agent_id: UUID,
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
) -> ResourceDeletedResponse:
return delete_doc_query(
developer_id=x_developer_id,
Expand All @@ -30,9 +30,9 @@ async def delete_agent_doc(
"/users/{user_id}/docs/{doc_id}", status_code=HTTP_202_ACCEPTED, tags=["docs"]
)
async def delete_user_doc(
doc_id: UUID4,
user_id: UUID4,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
doc_id: UUID,
user_id: UUID,
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
) -> ResourceDeletedResponse:
return delete_doc_query(
developer_id=x_developer_id,
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/docs/embed.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated

from fastapi import Depends
from pydantic import UUID4
from uuid import UUID

import agents_api.clients.embed as embedder

Expand All @@ -15,7 +15,7 @@

@router.post("/embed", tags=["docs"])
async def embed(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
data: EmbedQueryRequest,
) -> EmbedQueryResponse:
text_to_embed: str | list[str] = data.text
Expand Down
Loading
Loading