From 7ce17f0b6a2beb5173b4ecf34855a2364a500321 Mon Sep 17 00:00:00 2001
From: Andrew White
Date: Fri, 29 Mar 2024 17:04:14 -0700
Subject: [PATCH] Added batching to embed documents (#262)

---
 paperqa/llms.py | 16 +++++++++++-----
 pyproject.toml  |  2 +-
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/paperqa/llms.py b/paperqa/llms.py
index a39bb1376..7944e10d1 100644
--- a/paperqa/llms.py
+++ b/paperqa/llms.py
@@ -87,17 +87,23 @@ def process_llm_config(
 
 
 async def embed_documents(
-    client: AsyncOpenAI, texts: list[str], embedding_model: str
+    client: AsyncOpenAI, texts: list[str], embedding_model: str, batch_size: int = 16
 ) -> list[list[float]]:
     """Embed a list of documents with batching."""
     if client is None:
         raise ValueError(
             "Your client is None - did you forget to set it after pickling?"
         )
-    response = await client.embeddings.create(
-        model=embedding_model, input=texts, encoding_format="float"
-    )
-    return [e.embedding for e in response.data]
+    N = len(texts)
+    embeddings = []
+    for i in range(0, N, batch_size):
+        response = await client.embeddings.create(
+            model=embedding_model,
+            input=texts[i : i + batch_size],
+            encoding_format="float",
+        )
+        embeddings.extend([e.embedding for e in response.data])
+    return embeddings
 
 
 class EmbeddingModel(ABC, BaseModel):
diff --git a/pyproject.toml b/pyproject.toml
index 28cf09c03..ba9b05016 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ name = "paper-qa"
 readme = "README.md"
 requires-python = ">=3.8"
 urls = {repository = "https://github.com/whitead/paper-qa"}
-version = "4.4.0"
+version = "4.4.1"
 
 [tool.codespell]
 check-filenames = true
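
For reference, a minimal usage sketch of the batched embed_documents (not part of the
commit itself): the model name "text-embedding-3-small" and the sample texts below are
illustrative assumptions, not taken from the patch.

import asyncio
from openai import AsyncOpenAI

from paperqa.llms import embed_documents

async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    texts = [f"document chunk {i}" for i in range(40)]
    # With batch_size=16, the 40 texts are sent in 3 requests (16 + 16 + 8);
    # batches are processed sequentially, so input order is preserved.
    embeddings = await embed_documents(
        client, texts, embedding_model="text-embedding-3-small", batch_size=16
    )
    print(len(embeddings), len(embeddings[0]))  # 40 vectors, one per input text

asyncio.run(main())

Splitting the input into fixed-size batches keeps each embeddings request within the
API's per-request input limits while returning the same flat list of vectors as before.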