Skip to content

Commit

Permalink
fix: Split payload content into smaller batches for embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Oct 15, 2024
1 parent 697d95e commit cb9eb91
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import asyncio
import operator
from functools import reduce
from itertools import batched

from beartype import beartype
from temporalio import activity

Expand All @@ -8,18 +13,27 @@


@beartype
async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
async def embed_docs(
payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100
) -> None:
indices, snippets = list(zip(*enumerate(payload.content)))
batched_snippets = batched(snippets, max_batch_size)
embed_instruction: str = payload.embed_instruction or ""
title: str = payload.title or ""

embeddings = await litellm.aembedding(
inputs=[
(
embed_instruction + (title + "\n\n" + snippet) if title else snippet
).strip()
for snippet in snippets
]
async def embed_batch(snippets):
return await litellm.aembedding(
inputs=[
(
embed_instruction + (title + "\n\n" + snippet) if title else snippet
).strip()
for snippet in snippets
]
)

embeddings = reduce(
operator.add,
await asyncio.gather(*[embed_batch(snippets) for snippets in batched_snippets]),
)

embed_snippets_query(
Expand All @@ -31,7 +45,9 @@ async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
)


async def mock_embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
async def mock_embed_docs(
payload: EmbedDocsPayload, cozo_client=None, max_batch_size=100
) -> None:
# Does nothing
return None

Expand Down

0 comments on commit cb9eb91

Please sign in to comment.