Skip to content

Commit

Permalink
feat: Make docs content be split by users (#303)
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 authored May 1, 2024
1 parent 5dbfa7b commit 9bcf5a9
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 76 deletions.
8 changes: 6 additions & 2 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2024-04-26T08:42:42+00:00
# timestamp: 2024-04-30T17:38:56+00:00

from __future__ import annotations

Expand Down Expand Up @@ -632,12 +632,16 @@ class Doc(BaseModel):
"""


# Root model wrapping a single content string. The Field constraint
# (min_length=1) rejects empty strings.
# Kept as a `#` comment rather than a docstring, which would add a
# statement to the class body.
class ContentItem(RootModel[str]):
root: Annotated[str, Field(min_length=1)]


class CreateDoc(BaseModel):
title: str
"""
Title describing what this bit of information contains
"""
content: str
content: List[ContentItem] | str
"""
Information content
"""
Expand Down
3 changes: 2 additions & 1 deletion agents-api/agents_api/clients/embed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import httpx
from ..env import embedding_service_url, truncate_embed_text
from ..env import embedding_service_url, truncate_embed_text, embedding_model_id


async def embed(
Expand All @@ -17,6 +17,7 @@ async def embed(
"normalize": True,
# FIXME: We should control the truncation ourselves and truncate before sending
"truncate": truncate_embed_text,
"model_id": embedding_model_id,
},
)
resp.raise_for_status()
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@
"EMBEDDING_SERVICE_URL", default="http://0.0.0.0:8082/embed"
)

# Embedding model identifier, read from the EMBEDDING_MODEL_ID environment
# variable; falls back to "BAAI/bge-large-en-v1.5" when unset.
embedding_model_id: str = env.str(
"EMBEDDING_MODEL_ID", default="BAAI/bge-large-en-v1.5"
)

# Boolean flag read from TRUNCATE_EMBED_TEXT; defaults to False when unset.
truncate_embed_text: bool = env.bool("TRUNCATE_EMBED_TEXT", default=False)

# Temporal
Expand Down
10 changes: 3 additions & 7 deletions agents-api/agents_api/models/docs/create_docs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Literal
from typing import Literal
from uuid import UUID


Expand All @@ -13,8 +13,7 @@ def create_docs_query(
owner_id: UUID,
id: UUID,
title: str,
content: str,
split_fn: Callable[[str], list[str]] = lambda x: x.split("\n\n"),
content: list[str],
metadata: dict = {},
) -> tuple[str, dict]:
"""
Expand All @@ -26,19 +25,16 @@ def create_docs_query(
- id (UUID): The UUID of the document to be created.
- title (str): The title of the document.
- content (list[str]): The content of the document, already split into snippets; each element is processed as one snippet.
- split_fn (Callable[[str], list[str]]): A function to split the content into snippets. Defaults to splitting by double newlines.
- metadata (dict): Metadata associated with the document. Defaults to an empty dictionary.
Returns:
pd.DataFrame: A DataFrame containing the results of the query execution.
"""
created_at: float = utcnow().timestamp()

snippets = split_fn(content)
snippet_cols, snippet_rows = "", []

# Process each content snippet and prepare data for the datalog query.
for snippet_idx, snippet in enumerate(snippets):
for snippet_idx, snippet in enumerate(content):
snippet_cols, new_snippet_rows = cozo_process_mutate_data(
dict(
doc_id=str(id),
Expand Down
5 changes: 3 additions & 2 deletions agents-api/agents_api/routers/agents/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,12 +301,13 @@ async def list_agents(
@router.post("/agents/{agent_id}/docs", tags=["agents"])
async def create_docs(agent_id: UUID4, request: CreateDoc) -> ResourceCreatedResponse:
doc_id = uuid4()
content = [request.content] if isinstance(request.content, str) else request.content
resp: pd.DataFrame = create_docs_query(
owner_type="agent",
owner_id=agent_id,
id=doc_id,
title=request.title,
content=request.content,
content=content,
metadata=request.metadata or {},
)

Expand All @@ -316,7 +317,7 @@ async def create_docs(agent_id: UUID4, request: CreateDoc) -> ResourceCreatedRes
created_at=resp["created_at"][0],
)

indices, snippets = list(zip(*enumerate(request.content.split("\n\n"))))
indices, snippets = list(zip(*enumerate(content)))
embeddings = await embed(
[
snippet_embed_instruction + request.title + "\n\n" + snippet
Expand Down
5 changes: 3 additions & 2 deletions agents-api/agents_api/routers/users/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,13 @@ async def list_users(
@router.post("/users/{user_id}/docs", tags=["users"])
async def create_docs(user_id: UUID4, request: CreateDoc) -> ResourceCreatedResponse:
doc_id = uuid4()
content = [request.content] if isinstance(request.content, str) else request.content
resp: pd.DataFrame = create_docs_query(
owner_type="user",
owner_id=user_id,
id=doc_id,
title=request.title,
content=request.content,
content=content,
metadata=request.metadata or {},
)

Expand All @@ -253,7 +254,7 @@ async def create_docs(user_id: UUID4, request: CreateDoc) -> ResourceCreatedResp
created_at=resp["created_at"][0],
)

indices, snippets = list(zip(*enumerate(request.content.split("\n\n"))))
indices, snippets = list(zip(*enumerate(content)))
embeddings = await embed(
[
snippet_embed_instruction + request.title + "\n\n" + snippet
Expand Down
Loading

0 comments on commit 9bcf5a9

Please sign in to comment.