Skip to content

Commit

Permalink
fix: Split payload content by smaller batches for embedding (#653)
Browse files Browse the repository at this point in the history
<!-- ELLIPSIS_HIDDEN -->


> [!IMPORTANT]
> `embed_docs` in `embed_docs.py` now processes payloads in smaller
batches asynchronously using `batched` and `asyncio`, with a new test
case added.
> 
>   - **Behavior**:
> - `embed_docs` in `embed_docs.py` now processes payload content in
smaller batches using `batched` from `itertools`.
> - Introduces `max_batch_size` parameter to control batch size,
defaulting to 100.
>     - Uses `asyncio.wait` for asynchronous embedding of batches.
>   - **Functions**:
> - Adds `embed_batch` inner function to process each batch of indices
and snippets.
>     - Modifies `embed_docs` to use `embed_batch` for batch processing.
>   - **Imports**:
> - Adds `asyncio` and `batched` imports to support new batching logic.
>   - **Tests**:
> - Adds test case in `test_activities.py` to verify `embed_docs` with
batching logic using `unittest.mock.patch`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=julep-ai%2Fjulep&utm_source=github&utm_medium=referral)<sup>
for 30b26be. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->

Co-authored-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
whiterabbit1983 and creatorrr authored Oct 15, 2024
1 parent 18412c3 commit 3dacacd
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import asyncio
import operator
from functools import reduce
from itertools import batched

from beartype import beartype
from temporalio import activity

Expand All @@ -8,18 +13,27 @@


@beartype
async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
async def embed_docs(
payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100
) -> None:
indices, snippets = list(zip(*enumerate(payload.content)))
batched_snippets = batched(snippets, max_batch_size)
embed_instruction: str = payload.embed_instruction or ""
title: str = payload.title or ""

embeddings = await litellm.aembedding(
inputs=[
(
embed_instruction + (title + "\n\n" + snippet) if title else snippet
).strip()
for snippet in snippets
]
async def embed_batch(snippets):
return await litellm.aembedding(
inputs=[
(
embed_instruction + (title + "\n\n" + snippet) if title else snippet
).strip()
for snippet in snippets
]
)

embeddings = reduce(
operator.add,
await asyncio.gather(*[embed_batch(snippets) for snippets in batched_snippets]),
)

embed_snippets_query(
Expand All @@ -31,7 +45,9 @@ async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
)


async def mock_embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
async def mock_embed_docs(
payload: EmbedDocsPayload, cozo_client=None, max_batch_size=100
) -> None:
# Does nothing
return None

Expand Down

0 comments on commit 3dacacd

Please sign in to comment.