diff --git a/haystack_experimental/components/retrievers/auto_merging_retriever.py b/haystack_experimental/components/retrievers/auto_merging_retriever.py index a8c102f1..989f7724 100644 --- a/haystack_experimental/components/retrievers/auto_merging_retriever.py +++ b/haystack_experimental/components/retrievers/auto_merging_retriever.py @@ -39,7 +39,7 @@ class AutoMergingRetriever: from haystack_experimental.components.retrievers.auto_merging_retriever import AutoMergingRetriever from haystack.document_stores.in_memory import InMemoryDocumentStore - # create a hierarchical document structure with 2 levels, where the parent document has 3 children + # create a hierarchical document structure with 3 levels, where the parent document has 3 children text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing." original_document = Document(content=text) builder = HierarchicalDocumentSplitter(block_sizes=[10, 3], split_overlap=0, split_by="word") @@ -113,45 +113,57 @@ def _check_valid_documents(matched_leaf_documents: List[Document]): raise ValueError("The matched leaf documents do not have the required meta field '__block_size'") @component.output_types(documents=List[Document]) - def run(self, matched_leaf_documents: List[Document]): + def run(self, documents: List[Document]): """ Run the AutoMergingRetriever. - Groups the matched leaf documents by their parent documents and returns the parent documents if the number of - matched leaf documents below the same parent is higher than the defined threshold. Otherwise, returns the - matched leaf documents. + Recursively groups documents by their parents and merges them if they meet the threshold, + continuing up the hierarchy until no more merges are possible. - :param matched_leaf_documents: List of leaf documents that were matched by a retriever + :param documents: List of leaf documents that were matched by a retriever :returns: - List of parent documents or matched leaf documents based on the threshold value + List of documents (could be a mix of different hierarchy levels) """ - docs_to_return = [] - - # group the matched leaf documents by their parent documents - parent_documents: Dict[str, List[Document]] = defaultdict(list) - for doc in matched_leaf_documents: - parent_documents[doc.meta["__parent_id"]].append(doc) - - # find total number of children for each parent document - for doc_id, retrieved_child_docs in parent_documents.items(): - parent_doc = self.document_store.filter_documents({"field": "id", "operator": "==", "value": doc_id}) - if len(parent_doc) == 0: - raise ValueError(f"Parent document with id {doc_id} not found in the document store.") - if len(parent_doc) > 1: - raise ValueError(f"Multiple parent documents found with id {doc_id} in the document store.") - if not parent_doc[0].meta.get("__children_ids"): - raise ValueError(f"Parent document with id {doc_id} does not have any children.") - parent_children_count = len(parent_doc[0].meta["__children_ids"]) - - # return either the parent document or the matched leaf documents based on the threshold value - score = len(retrieved_child_docs) / parent_children_count - if score >= self.threshold: - # return the parent document - docs_to_return.append(parent_doc[0]) - else: - # return all the matched leaf documents which are child of this parent document - leafs_ids = {doc.id for doc in retrieved_child_docs} - docs_to_return.extend([doc for doc in matched_leaf_documents if doc.id in leafs_ids]) - - return {"documents": docs_to_return} + AutoMergingRetriever._check_valid_documents(documents) + + def _get_parent_doc(parent_id: str) -> Document: + parent_docs = self.document_store.filter_documents({"field": "id", "operator": "==", "value": parent_id}) + if len(parent_docs) != 1: + raise ValueError(f"Expected 1 parent document with id {parent_id}, found {len(parent_docs)}") + + parent_doc = parent_docs[0] + if not parent_doc.meta.get("__children_ids"): + raise ValueError(f"Parent document with id {parent_id} does not have any children.") + + return parent_doc + + def _try_merge_level(docs_to_merge: List[Document], docs_to_return: List[Document]) -> List[Document]: + parent_doc_id_to_child_docs: Dict[str, List[Document]] = defaultdict(list) # to group documents by parent + + for doc in docs_to_merge: + if doc.meta.get("__parent_id"): # only docs that have parents + parent_doc_id_to_child_docs[doc.meta["__parent_id"]].append(doc) + else: + docs_to_return.append(doc) # keep docs that have no parents + + # Process each parent group + merged_docs = [] + for parent_doc_id, child_docs in parent_doc_id_to_child_docs.items(): + parent_doc = _get_parent_doc(parent_doc_id) + + # Calculate merge score + score = len(child_docs) / len(parent_doc.meta["__children_ids"]) + if score > self.threshold: + merged_docs.append(parent_doc) # Merge into parent + else: + docs_to_return.extend(child_docs) # Keep children separate + + # if no new merges were made, we're done + if not merged_docs: + return merged_docs + docs_to_return + + # Recursively try to merge the next level + return _try_merge_level(merged_docs, docs_to_return) + + return {"documents": _try_merge_level(documents, [])} diff --git a/test/components/retrievers/test_auto_merging_retriever.py b/test/components/retrievers/test_auto_merging_retriever.py index 24e20c48..70228c1f 100644 --- a/test/components/retrievers/test_auto_merging_retriever.py +++ b/test/components/retrievers/test_auto_merging_retriever.py @@ -6,8 +6,8 @@ from haystack_experimental.components.retrievers.auto_merging_retriever import AutoMergingRetriever from haystack.document_stores.in_memory import InMemoryDocumentStore - class TestAutoMergingRetriever: + def test_init_default(self): retriever = AutoMergingRetriever(InMemoryDocumentStore()) assert retriever.threshold == 0.5 @@ -20,6 +20,72 @@ def test_init_with_invalid_threshold(self): with pytest.raises(ValueError): AutoMergingRetriever(InMemoryDocumentStore(), threshold=-2) + def test_run_missing_parent_id(self): + docs = [ + Document( + content="test", + meta={ + "__level": 1, + "__block_size": 10, + }, + ) + ] + retriever = AutoMergingRetriever(InMemoryDocumentStore()) + with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__parent_id'"): + retriever.run(documents=docs) + + def test_run_missing_level(self): + docs = [ + Document( + content="test", + meta={ + "__parent_id": "parent1", + "__block_size": 10, + }, + ) + ] + + retriever = AutoMergingRetriever(InMemoryDocumentStore()) + with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__level'"): + retriever.run(documents=docs) + + def test_run_missing_block_size(self): + docs = [ + Document( + content="test", + meta={ + "__parent_id": "parent1", + "__level": 1, + }, + ) + ] + + retriever = AutoMergingRetriever(InMemoryDocumentStore()) + with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__block_size'"): + retriever.run(documents=docs) + + def test_run_mixed_valid_and_invalid_documents(self): + docs = [ + Document( + content="valid", + meta={ + "__parent_id": "parent1", + "__level": 1, + "__block_size": 10, + }, + ), + Document( + content="invalid", + meta={ + "__level": 1, + "__block_size": 10, + }, + ), + ] + retriever = AutoMergingRetriever(InMemoryDocumentStore()) + with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__parent_id'"): + retriever.run(documents=docs) + def test_to_dict(self): retriever = AutoMergingRetriever(InMemoryDocumentStore(), threshold=0.7) expected = retriever.to_dict() @@ -43,6 +109,71 @@ def test_from_dict(self): retriever = AutoMergingRetriever.from_dict(data) assert retriever.threshold == 0.7 + def test_serialization_deserialization_pipeline(self): + pipeline = Pipeline() + doc_store_parents = InMemoryDocumentStore() + bm_25_retriever = InMemoryBM25Retriever(doc_store_parents) + auto_merging_retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5) + + pipeline.add_component(name="bm_25_retriever", instance=bm_25_retriever) + pipeline.add_component(name="auto_merging_retriever", instance=auto_merging_retriever) + pipeline.connect("bm_25_retriever.documents", "auto_merging_retriever.documents") + pipeline_dict = pipeline.to_dict() + + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline + + def test_run_parent_not_found(self): + doc_store = InMemoryDocumentStore() + retriever = AutoMergingRetriever(doc_store, threshold=0.5) + + # a leaf document with a non-existent parent_id + leaf_doc = Document( + content="test", + meta={ + "__parent_id": "non_existent_parent", + "__level": 1, + "__block_size": 10, + } + ) + + with pytest.raises(ValueError, match="Expected 1 parent document with id non_existent_parent, found 0"): + retriever.run([leaf_doc]) + + def test_run_parent_without_children_metadata(self): + """Test case where a parent document exists but doesn't have the __children_ids metadata field""" + doc_store = InMemoryDocumentStore() + + # Create and store a parent document without __children_ids metadata + parent_doc = Document( + content="parent content", + id="parent1", + meta={ + "__level": 1, # Add other required metadata + "__block_size": 10 + } + ) + doc_store.write_documents([parent_doc]) + + retriever = AutoMergingRetriever(doc_store, threshold=0.5) + + # Create a leaf document that points to this parent + leaf_doc = Document( + content="leaf content", + meta={ + "__parent_id": "parent1", + "__level": 2, + "__block_size": 5 + } + ) + + with pytest.raises(ValueError, match="Parent document with id parent1 does not have any children"): + retriever.run([leaf_doc]) + + def test_run_empty_documents(self): + retriever = AutoMergingRetriever(InMemoryDocumentStore()) + assert retriever.run([]) == {"documents": []} + def test_run_return_parent_document(self): text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing." @@ -50,10 +181,10 @@ def test_run_return_parent_document(self): builder = HierarchicalDocumentSplitter(block_sizes={10, 3}, split_overlap=0, split_by="word") docs = builder.run(docs) - # store level-1 parent documents and initialize the retriever + # store all non-leaf documents doc_store_parents = InMemoryDocumentStore() for doc in docs["documents"]: - if doc.meta["__children_ids"] and doc.meta["__level"] == 1: + if doc.meta["__children_ids"]: doc_store_parents.write_documents([doc]) retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5) @@ -100,16 +231,68 @@ def test_run_return_leafs_document_different_parents(self): assert len(result['documents']) == 2 assert result['documents'][0].meta["__parent_id"] != result['documents'][1].meta["__parent_id"] - def test_serialization_deserialization_pipeline(self): - pipeline = Pipeline() + def test_run_go_up_hierarchy_multiple_levels(self): + """ + Test if the retriever can go up the hierarchy multiple levels to find the parent document. + + Simulate a scenario where we have 4 leaf-documents that matched some initial query. The leaf-documents + are continuously merged up the hierarchy until the threshold is no longer met. + In this case it goes from the 4th level in the hierarchy up the 1st level. + """ + text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing." + + docs = [Document(content=text)] + builder = HierarchicalDocumentSplitter(block_sizes={6, 4, 2, 1}, split_overlap=0, split_by="word") + docs = builder.run(docs) + + # store all non-leaf documents doc_store_parents = InMemoryDocumentStore() - bm_25_retriever = InMemoryBM25Retriever(doc_store_parents) - auto_merging_retriever = AutoMergingRetriever(doc_store_parents, threshold=0.5) + for doc in docs["documents"]: + if doc.meta["__children_ids"]: + doc_store_parents.write_documents([doc]) + retriever = AutoMergingRetriever(doc_store_parents, threshold=0.4) - pipeline.add_component(name="bm_25_retriever", instance=bm_25_retriever) - pipeline.add_component(name="auto_merging_retriever", instance=auto_merging_retriever) - pipeline.connect("bm_25_retriever.documents", "auto_merging_retriever.matched_leaf_documents") - pipeline_dict = pipeline.to_dict() + retrieved_leaf_docs_id = [ + '8e65095a31fe5da857e4f939198217d961ea2d5052a4d0f587ec5fc78c743779', + '00409c91c6bb2a989565e963f563aa5a081f6054ab8b7a9307246b3cc0f0d352', + 'e88945a30bec3e084e6aa528bcc940b4a78b6a6353c4243632be3aae84a7f532', + '2d0cc69c40911586d51e3e9afbfed50a0b85475dcbd524c01b46ccf5bdc54d48' + ] - new_pipeline = Pipeline.from_dict(pipeline_dict) - assert new_pipeline == pipeline + retrieved_leaf_docs = [d for d in docs['documents'] if d.id in retrieved_leaf_docs_id] + result = retriever.run(retrieved_leaf_docs) + + assert len(result['documents']) == 1 + assert result['documents'][0].content == 'The sun rose early in the ' + + def test_run_go_up_hierarchy_multiple_levels_hit_root_document(self): + """ + Test case where we go up hierarchy until the root document, so the root document is returned. + + It's the only document in the hierarchy which has no parent. + """ + text = "The sun rose early in the morning. It cast a warm glow over the trees. Birds began to sing." + + docs = [Document(content=text)] + builder = HierarchicalDocumentSplitter(block_sizes={6, 4}, split_overlap=0, split_by="word") + docs = builder.run(docs) + + # store all non-leaf documents + doc_store_parents = InMemoryDocumentStore() + for doc in docs["documents"]: + if doc.meta["__children_ids"]: + doc_store_parents.write_documents([doc]) + retriever = AutoMergingRetriever(doc_store_parents, threshold=0.1) # set a low threshold to hit root document + + retrieved_leaf_docs_id = [ + '7e654d8ae21cc9807e4c377288a590efe7a6d86606676e51992cf719a03a3f42', + 'acb19c71330c1f7515046bbcbacfcdf8fe21d273c40485a6b3f6b8ea13d4adec', + '98480d4a5f97ebd330d2bc06640692d52a8af2265e2ea0e87abf09d6472c7af9', + 'a61b5a9ea9edfbd1572c02f7289c644128dd144a476f9e349bd35fdc93590610' + ] + + retrieved_leaf_docs = [d for d in docs['documents'] if d.id in retrieved_leaf_docs_id] + result = retriever.run(retrieved_leaf_docs) + + assert len(result['documents']) == 1 + assert result['documents'][0].meta["__level"] == 0 # hit root document