Skip to content

Commit

Permalink
refactor: Lint agents-api (CI)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 authored and github-actions[bot] committed Dec 21, 2024
1 parent 249513d commit d7d9cd4
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 17 deletions.
3 changes: 1 addition & 2 deletions agents-api/agents_api/queries/docs/create_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ async def create_doc(
if isinstance(data.content, list):
final_params_doc = []
final_params_owner = []

for idx, content in enumerate(data.content):
doc_params = [
developer_id,
Expand Down Expand Up @@ -185,7 +185,6 @@ async def create_doc(
queries.append((doc_owner_query, final_params_owner, "fetchmany"))

else:

# Create the doc record
doc_params = [
developer_id,
Expand Down
6 changes: 4 additions & 2 deletions agents-api/agents_api/queries/docs/get_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@
"id": d["doc_id"],
"index": d["indices"][0],
"content": d["content"][0] if len(d["content"]) == 1 else d["content"],
"embeddings": d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"],
"embeddings": d["embeddings"][0]
if len(d["embeddings"]) == 1
else d["embeddings"],
**d,
},
)
Expand All @@ -64,7 +66,7 @@ async def get_doc(
) -> tuple[str, list]:
"""
Fetch a single doc with its embedding, grouping all content chunks and embeddings.
Parameters:
developer_id (UUID): The ID of the developer.
doc_id (UUID): The ID of the document.
Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/queries/docs/list_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@
"id": d["doc_id"],
"index": d["indices"][0],
"content": d["content"][0] if len(d["content"]) == 1 else d["content"],
"embedding": d["embeddings"][0] if d.get("embeddings") and len(d["embeddings"]) == 1 else d.get("embeddings"),
"embedding": d["embeddings"][0]
if d.get("embeddings") and len(d["embeddings"]) == 1
else d.get("embeddings"),
**d,
},
)
Expand Down
16 changes: 7 additions & 9 deletions agents-api/agents_api/queries/docs/search_docs_by_text.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
from typing import Any, Literal, List
import json
from typing import Any, List, Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
import asyncpg
import json

from ...autogen.openapi_model import DocReference
from ..utils import pg_query, wrap_in_class, rewrap_exceptions, partialclass
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

search_docs_text_query = (
"""
search_docs_text_query = """
SELECT * FROM search_by_text(
$1, -- developer_id
$2, -- query
$3, -- owner_types
( SELECT array_agg(*)::UUID[] FROM jsonb_array_elements($4) )
)
"""
)


@rewrap_exceptions(
Expand Down Expand Up @@ -74,10 +72,10 @@ async def search_docs_by_text(
# Extract owner types and IDs
owner_types = [owner[0] for owner in owners]
owner_ids = [owner[1] for owner in owners]

return (
search_docs_text_query,
[
[
developer_id,
query,
owner_types,
Expand Down
1 change: 1 addition & 0 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_developer_id():
developer_id = uuid7()
return developer_id


@fixture(scope="global")
async def test_developer(dsn=pg_dsn, developer_id=test_developer_id):
pool = await create_db_pool(dsn=dsn)
Expand Down
9 changes: 6 additions & 3 deletions agents-api/tests/test_docs_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# If you wish to test text/embedding/hybrid search, import them:
from agents_api.queries.docs.search_docs_by_text import search_docs_by_text

# from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding
# from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid
# You can rename or remove these imports to match your actual fixtures
Expand Down Expand Up @@ -81,6 +82,7 @@ async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc):
assert doc_test.title == doc.title
assert doc_test.content == doc.content


@test("query: list user docs")
async def _(dsn=pg_dsn, developer=test_developer, user=test_user):
pool = await create_db_pool(dsn=dsn)
Expand Down Expand Up @@ -212,17 +214,18 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent):
)
assert not any(d.id == doc_agent.id for d in docs_list)


@test("query: search docs by text")
async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
pool = await create_db_pool(dsn=dsn)

# Create a test document
await create_doc(
developer_id=developer.id,
owner_type="agent",
owner_id=agent.id,
data=CreateDocRequest(
title="Hello",
title="Hello",
content="The world is a funny little thing",
metadata={"test": "test"},
embed_instruction="Embed the document",
Expand All @@ -242,4 +245,4 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
)

assert len(result) >= 1
assert result[0].metadata is not None
assert result[0].metadata is not None

0 comments on commit d7d9cd4

Please sign in to comment.