diff --git a/paperqa/docs.py b/paperqa/docs.py
index 06203ff0f..f20c936f5 100644
--- a/paperqa/docs.py
+++ b/paperqa/docs.py
@@ -403,36 +403,47 @@ def add_texts(
         loop = get_loop()
         return loop.run_until_complete(self.aadd_texts(texts, doc))
 
-    async def aadd_texts(
-        self,
-        texts: list[Text],
-        doc: Doc,
-    ) -> bool:
-        """Add chunked texts to the collection. This is useful if you have already chunked the texts yourself.
+    async def aadd_texts(self, texts: list[Text], doc: Doc) -> bool:
+        """
+        Add chunked texts to the collection.
 
-        Returns True if the document was added, False if it was already in the collection.
+        NOTE: this is useful if you have already chunked the texts yourself.
+
+        Returns:
+            True if the doc was added, otherwise False if already in the collection.
         """
         if doc.dockey in self.docs:
             return False
-        if len(texts) == 0:
+        if not texts:
             raise ValueError("No texts to add.")
-        if doc.docname in self.docnames:
-            new_docname = self._get_unique_name(doc.docname)
-            for t in texts:
-                t.name = t.name.replace(doc.docname, new_docname)
-            doc.docname = new_docname
-        if texts[0].embedding is None:
-            text_embeddings = await self.texts_index.embedding_model.embed_documents(
-                self._embedding_client, [t.text for t in texts]
+        # 1. Calculate text embeddings if not already present, but don't set them into
+        # the texts until we've set up the Doc's embedding, so callers can retry upon
+        # OpenAI rate limit errors
+        text_embeddings: list[list[float]] | None = (
+            await self.texts_index.embedding_model.embed_documents(
+                self._embedding_client, texts=[t.text for t in texts]
             )
-            for i, t in enumerate(texts):
-                t.embedding = text_embeddings[i]
+            if texts[0].embedding is None
+            else None
+        )
+        # 2. Set the Doc's embedding to be the Doc's citation embedded
         if doc.embedding is None:
             doc.embedding = (
                 await self.docs_index.embedding_model.embed_documents(
-                    self._embedding_client, [doc.citation]
+                    self._embedding_client, texts=[doc.citation]
                 )
             )[0]
+        # 3. Now we can set the text embeddings
+        if text_embeddings is not None:
+            for t, t_embedding in zip(texts, text_embeddings, strict=True):
+                t.embedding = t_embedding
+        # 4. Update texts and the Doc's name
+        if doc.docname in self.docnames:
+            new_docname = self._get_unique_name(doc.docname)
+            for t in texts:
+                t.name = t.name.replace(doc.docname, new_docname)
+            doc.docname = new_docname
+        # 5. Index remaining updates
         if not self.jit_texts_index:
             self.texts_index.add_texts_and_embeddings(texts)
             self.docs_index.add_texts_and_embeddings([doc])
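
A minimal usage sketch of the revised aadd_texts, for context only (not part of the diff). The import path and the Doc/Text constructor fields are assumptions based on identifiers visible in the hunk and may differ between paperqa versions; the chunk contents and names are hypothetical.

# Usage sketch: adding pre-chunked texts to a Docs collection.
# Assumes `from paperqa import Doc, Docs, Text` resolves in your version.
import asyncio

from paperqa import Doc, Docs, Text


async def main() -> None:
    docs = Docs()
    doc = Doc(
        docname="Smith2023",
        citation="Smith et al. (2023). An example paper.",
        dockey="smith2023",
    )
    # Pre-chunked passages; aadd_texts only computes embeddings when
    # Text.embedding is None, so you can also pass pre-embedded chunks.
    texts = [
        Text(text="First chunk of the paper ...", name="Smith2023 chunk 1", doc=doc),
        Text(text="Second chunk of the paper ...", name="Smith2023 chunk 2", doc=doc),
    ]
    added = await docs.aadd_texts(texts, doc)
    print(added)  # False if doc.dockey was already in the collection


asyncio.run(main())

Because the text embeddings are now computed before any state is mutated and only assigned after the Doc's citation embedding succeeds, a caller that hits an OpenAI rate limit error can simply retry the call without ending up with partially embedded texts.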