Skip to content

Commit

Permalink
fix: Split payload content by smaller batches for embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Oct 15, 2024
1 parent 697d95e commit 0476f38
Showing 1 changed file with 38 additions and 18 deletions.
56 changes: 38 additions & 18 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import asyncio
from itertools import batched

from beartype import beartype
from temporalio import activity

Expand All @@ -8,30 +11,47 @@


@beartype
async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
async def embed_docs(
payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100
) -> None:
indices, snippets = list(zip(*enumerate(payload.content)))
embed_instruction: str = payload.embed_instruction or ""
title: str = payload.title or ""

embeddings = await litellm.aembedding(
inputs=[
(
embed_instruction + (title + "\n\n" + snippet) if title else snippet
).strip()
for snippet in snippets
]
batched_indices, batched_snippets = (
batched(indices, max_batch_size),
batched(snippets, max_batch_size),
)

embed_snippets_query(
developer_id=payload.developer_id,
doc_id=payload.doc_id,
snippet_indices=indices,
embeddings=embeddings,
client=cozo_client or cozo.get_cozo_client(),
async def embed_batch(indices, snippets):
embed_instruction: str = payload.embed_instruction or ""
title: str = payload.title or ""

embeddings = await litellm.aembedding(
inputs=[
(
embed_instruction + (title + "\n\n" + snippet) if title else snippet
).strip()
for snippet in snippets
]
)

embed_snippets_query(
developer_id=payload.developer_id,
doc_id=payload.doc_id,
snippet_indices=indices,
embeddings=embeddings,
client=cozo_client or cozo.get_cozo_client(),
)

await asyncio.wait(
[
embed_batch(indices, snippets)
for indices, snippets in zip(batched_indices, batched_snippets)
]
)


async def mock_embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
async def mock_embed_docs(
payload: EmbedDocsPayload, cozo_client=None, max_batch_size=100
) -> None:
# Does nothing
return None

Expand Down

0 comments on commit 0476f38

Please sign in to comment.