diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py index 924424881..4a2611892 100644 --- a/agents-api/agents_api/activities/embed_docs.py +++ b/agents-api/agents_api/activities/embed_docs.py @@ -1,3 +1,6 @@ +import asyncio +from itertools import batched + from beartype import beartype from temporalio import activity @@ -8,26 +11,41 @@ @beartype -async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None: +async def embed_docs( + payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100 +) -> None: indices, snippets = list(zip(*enumerate(payload.content))) - embed_instruction: str = payload.embed_instruction or "" - title: str = payload.title or "" - - embeddings = await litellm.aembedding( - inputs=[ - ( - embed_instruction + (title + "\n\n" + snippet) if title else snippet - ).strip() - for snippet in snippets - ] + batched_indices, batched_snippets = ( + batched(indices, max_batch_size), + batched(snippets, max_batch_size), ) - embed_snippets_query( - developer_id=payload.developer_id, - doc_id=payload.doc_id, - snippet_indices=indices, - embeddings=embeddings, - client=cozo_client or cozo.get_cozo_client(), + async def embed_batch(indices, snippets): + embed_instruction: str = payload.embed_instruction or "" + title: str = payload.title or "" + + embeddings = await litellm.aembedding( + inputs=[ + ( + embed_instruction + (title + "\n\n" + snippet) if title else snippet + ).strip() + for snippet in snippets + ] + ) + + embed_snippets_query( + developer_id=payload.developer_id, + doc_id=payload.doc_id, + snippet_indices=indices, + embeddings=embeddings, + client=cozo_client or cozo.get_cozo_client(), + ) + + await asyncio.wait( + [ + embed_batch(indices, snippets) + for indices, snippets in zip(batched_indices, batched_snippets) + ] ) diff --git a/agents-api/tests/test_activities.py b/agents-api/tests/test_activities.py index 6f65cd034..6a6594453 100644 --- a/agents-api/tests/test_activities.py +++ b/agents-api/tests/test_activities.py @@ -1,4 +1,5 @@ from uuid import uuid4 +from unittest.mock import patch from ward import test @@ -40,6 +41,34 @@ async def _( ) +@test("activity: call direct embed_docs with batching") +async def _( + cozo_client=cozo_client, + developer_id=test_developer_id, + doc=test_doc, +): + title = "title" + content = ["content 1", "content 2", "content 3", "content 4", "content 5"] + include_title = True + + with patch("agents_api.activities.embed_docs.embed_snippets_query") as embed_query: + embed_query.return_value = None + + await embed_docs( + EmbedDocsPayload( + developer_id=developer_id, + doc_id=doc.id, + title=title, + content=content, + include_title=include_title, + embed_instruction=None, + ), + cozo_client, + max_batch_size=2, + ) + + embed_query.call_count == 3 + @test("activity: call demo workflow via temporal client") async def _(): async with patch_testing_temporal() as (_, mock_get_client):