From e3320459c9383423b6ffaa49e32bf80eb0c47524 Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Tue, 7 Jan 2025 16:17:14 +0100 Subject: [PATCH] refactor try_merge_level, rename component input --- .../retrievers/auto_merging_retriever.py | 47 +++++++++---------- .../retrievers/test_auto_merging_retriever.py | 12 ++--- 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/haystack_experimental/components/retrievers/auto_merging_retriever.py b/haystack_experimental/components/retrievers/auto_merging_retriever.py index 9c462929..c4ffce36 100644 --- a/haystack_experimental/components/retrievers/auto_merging_retriever.py +++ b/haystack_experimental/components/retrievers/auto_merging_retriever.py @@ -113,58 +113,57 @@ def _check_valid_documents(matched_leaf_documents: List[Document]): raise ValueError("The matched leaf documents do not have the required meta field '__block_size'") @component.output_types(documents=List[Document]) - def run(self, matched_leaf_documents: List[Document]): + def run(self, documents: List[Document]): """ Run the AutoMergingRetriever. Recursively groups documents by their parents and merges them if they meet the threshold, continuing up the hierarchy until no more merges are possible. - :param matched_leaf_documents: List of leaf documents that were matched by a retriever + :param documents: List of leaf documents that were matched by a retriever :returns: List of documents (could be a mix of different hierarchy levels) """ - AutoMergingRetriever._check_valid_documents(matched_leaf_documents) + AutoMergingRetriever._check_valid_documents(documents) - def try_merge_level(documents: List[Document], docs_to_return: List[Document]) -> List[Document]: - if not documents: - return [] + def get_parent_doc(parent_id: str) -> Document: + parent_docs = self.document_store.filter_documents({"field": "id", "operator": "==", "value": parent_id}) + if len(parent_docs) != 1: + raise ValueError(f"Expected 1 parent document with id {parent_id}, found {len(parent_docs)}") - parent_documents: Dict[str, List[Document]] = defaultdict(list) # to group the documents by their parent + parent_doc = parent_docs[0] + if not parent_doc.meta.get("__children_ids"): + raise ValueError(f"Parent document with id {parent_id} does not have any children.") - for doc in documents: + return parent_doc + + def try_merge_level(docs_to_merge: List[Document], docs_to_return: List[Document]) -> List[Document]: + parent_doc_id_to_child_docs: Dict[str, List[Document]] = defaultdict(list) # to group documents by parent + + for doc in docs_to_merge: if doc.meta.get("__parent_id"): # only docs that have parents - parent_documents[doc.meta["__parent_id"]].append(doc) + parent_doc_id_to_child_docs[doc.meta["__parent_id"]].append(doc) else: docs_to_return.append(doc) # keep docs that have no parents # Process each parent group merged_docs = [] - for doc_id, child_docs in parent_documents.items(): - parent_doc = self.document_store.filter_documents({"field": "id", "operator": "==", "value": doc_id}) - if len(parent_doc) != 1: - raise ValueError(f"Expected 1 parent document with id {doc_id}, found {len(parent_doc)}") - - parent = parent_doc[0] - if not parent.meta.get("__children_ids"): - raise ValueError(f"Parent document with id {doc_id} does not have any children.") + for parent_doc_id, child_docs in parent_doc_id_to_child_docs.items(): + parent_doc = get_parent_doc(parent_doc_id) # Calculate merge score - score = len(child_docs) / len(parent.meta["__children_ids"]) + score = len(child_docs) / len(parent_doc.meta["__children_ids"]) if score > self.threshold: - merged_docs.append(parent) # Merge into parent + merged_docs.append(parent_doc) # Merge into parent else: docs_to_return.extend(child_docs) # Keep children separate # if no new merges were made, we're done - if len(merged_docs) == len(documents): + if merged_docs == docs_to_merge: return merged_docs + docs_to_return # Recursively try to merge the next level return try_merge_level(merged_docs, docs_to_return) - # start the recursive merging process - docs_to_return: List[Document] = [] - final_docs = try_merge_level(matched_leaf_documents, docs_to_return) - return {"documents": final_docs + docs_to_return} + return {"documents": try_merge_level(documents, [])} diff --git a/test/components/retrievers/test_auto_merging_retriever.py b/test/components/retrievers/test_auto_merging_retriever.py index 0074ab6b..70228c1f 100644 --- a/test/components/retrievers/test_auto_merging_retriever.py +++ b/test/components/retrievers/test_auto_merging_retriever.py @@ -32,7 +32,7 @@ def test_run_missing_parent_id(self): ] retriever = AutoMergingRetriever(InMemoryDocumentStore()) with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__parent_id'"): - retriever.run(matched_leaf_documents=docs) + retriever.run(documents=docs) def test_run_missing_level(self): docs = [ @@ -47,7 +47,7 @@ def test_run_missing_level(self): retriever = AutoMergingRetriever(InMemoryDocumentStore()) with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__level'"): - retriever.run(matched_leaf_documents=docs) + retriever.run(documents=docs) def test_run_missing_block_size(self): docs = [ @@ -62,7 +62,7 @@ def test_run_missing_block_size(self): retriever = AutoMergingRetriever(InMemoryDocumentStore()) with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__block_size'"): - retriever.run(matched_leaf_documents=docs) + retriever.run(documents=docs) def test_run_mixed_valid_and_invalid_documents(self): docs = [ @@ -84,7 +84,7 @@ def test_run_mixed_valid_and_invalid_documents(self): ] retriever = AutoMergingRetriever(InMemoryDocumentStore()) with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__parent_id'"): - retriever.run(matched_leaf_documents=docs) + retriever.run(documents=docs) def test_to_dict(self): retriever = AutoMergingRetriever(InMemoryDocumentStore(), threshold=0.7) @@ -117,7 +117,7 @@ def test_serialization_deserialization_pipeline(self): pipeline.add_component(name="bm_25_retriever", instance=bm_25_retriever) pipeline.add_component(name="auto_merging_retriever", instance=auto_merging_retriever) - pipeline.connect("bm_25_retriever.documents", "auto_merging_retriever.matched_leaf_documents") + pipeline.connect("bm_25_retriever.documents", "auto_merging_retriever.documents") pipeline_dict = pipeline.to_dict() new_pipeline = Pipeline.from_dict(pipeline_dict) @@ -295,4 +295,4 @@ def test_run_go_up_hierarchy_multiple_levels_hit_root_document(self): result = retriever.run(retrieved_leaf_docs) assert len(result['documents']) == 1 - assert result['documents'][0].meta["__level"] == 0 # hit root document \ No newline at end of file + assert result['documents'][0].meta["__level"] == 0 # hit root document