Skip to content

Commit

Permalink
refactor try_merge_level, rename component input
Browse files Browse the repository at this point in the history
  • Loading branch information
julian-risch committed Jan 7, 2025
1 parent f71849a commit e332045
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,58 +113,57 @@ def _check_valid_documents(matched_leaf_documents: List[Document]):
raise ValueError("The matched leaf documents do not have the required meta field '__block_size'")

@component.output_types(documents=List[Document])
def run(self, matched_leaf_documents: List[Document]):
def run(self, documents: List[Document]):
"""
Run the AutoMergingRetriever.
Recursively groups documents by their parents and merges them if they meet the threshold,
continuing up the hierarchy until no more merges are possible.
:param matched_leaf_documents: List of leaf documents that were matched by a retriever
:param documents: List of leaf documents that were matched by a retriever
:returns:
List of documents (could be a mix of different hierarchy levels)
"""

AutoMergingRetriever._check_valid_documents(matched_leaf_documents)
AutoMergingRetriever._check_valid_documents(documents)

def try_merge_level(documents: List[Document], docs_to_return: List[Document]) -> List[Document]:
if not documents:
return []
def get_parent_doc(parent_id: str) -> Document:
parent_docs = self.document_store.filter_documents({"field": "id", "operator": "==", "value": parent_id})
if len(parent_docs) != 1:
raise ValueError(f"Expected 1 parent document with id {parent_id}, found {len(parent_docs)}")

parent_documents: Dict[str, List[Document]] = defaultdict(list) # to group the documents by their parent
parent_doc = parent_docs[0]
if not parent_doc.meta.get("__children_ids"):
raise ValueError(f"Parent document with id {parent_id} does not have any children.")

for doc in documents:
return parent_doc

def try_merge_level(docs_to_merge: List[Document], docs_to_return: List[Document]) -> List[Document]:
parent_doc_id_to_child_docs: Dict[str, List[Document]] = defaultdict(list) # to group documents by parent

for doc in docs_to_merge:
if doc.meta.get("__parent_id"): # only docs that have parents
parent_documents[doc.meta["__parent_id"]].append(doc)
parent_doc_id_to_child_docs[doc.meta["__parent_id"]].append(doc)
else:
docs_to_return.append(doc) # keep docs that have no parents

# Process each parent group
merged_docs = []
for doc_id, child_docs in parent_documents.items():
parent_doc = self.document_store.filter_documents({"field": "id", "operator": "==", "value": doc_id})
if len(parent_doc) != 1:
raise ValueError(f"Expected 1 parent document with id {doc_id}, found {len(parent_doc)}")

parent = parent_doc[0]
if not parent.meta.get("__children_ids"):
raise ValueError(f"Parent document with id {doc_id} does not have any children.")
for parent_doc_id, child_docs in parent_doc_id_to_child_docs.items():
parent_doc = get_parent_doc(parent_doc_id)

# Calculate merge score
score = len(child_docs) / len(parent.meta["__children_ids"])
score = len(child_docs) / len(parent_doc.meta["__children_ids"])
if score > self.threshold:
merged_docs.append(parent) # Merge into parent
merged_docs.append(parent_doc) # Merge into parent
else:
docs_to_return.extend(child_docs) # Keep children separate

# if no new merges were made, we're done
if len(merged_docs) == len(documents):
if merged_docs == docs_to_merge:
return merged_docs + docs_to_return

# Recursively try to merge the next level
return try_merge_level(merged_docs, docs_to_return)

# start the recursive merging process
docs_to_return: List[Document] = []
final_docs = try_merge_level(matched_leaf_documents, docs_to_return)
return {"documents": final_docs + docs_to_return}
return {"documents": try_merge_level(documents, [])}
12 changes: 6 additions & 6 deletions test/components/retrievers/test_auto_merging_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_run_missing_parent_id(self):
]
retriever = AutoMergingRetriever(InMemoryDocumentStore())
with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__parent_id'"):
retriever.run(matched_leaf_documents=docs)
retriever.run(documents=docs)

def test_run_missing_level(self):
docs = [
Expand All @@ -47,7 +47,7 @@ def test_run_missing_level(self):

retriever = AutoMergingRetriever(InMemoryDocumentStore())
with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__level'"):
retriever.run(matched_leaf_documents=docs)
retriever.run(documents=docs)

def test_run_missing_block_size(self):
docs = [
Expand All @@ -62,7 +62,7 @@ def test_run_missing_block_size(self):

retriever = AutoMergingRetriever(InMemoryDocumentStore())
with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__block_size'"):
retriever.run(matched_leaf_documents=docs)
retriever.run(documents=docs)

def test_run_mixed_valid_and_invalid_documents(self):
docs = [
Expand All @@ -84,7 +84,7 @@ def test_run_mixed_valid_and_invalid_documents(self):
]
retriever = AutoMergingRetriever(InMemoryDocumentStore())
with pytest.raises(ValueError, match="The matched leaf documents do not have the required meta field '__parent_id'"):
retriever.run(matched_leaf_documents=docs)
retriever.run(documents=docs)

def test_to_dict(self):
retriever = AutoMergingRetriever(InMemoryDocumentStore(), threshold=0.7)
Expand Down Expand Up @@ -117,7 +117,7 @@ def test_serialization_deserialization_pipeline(self):

pipeline.add_component(name="bm_25_retriever", instance=bm_25_retriever)
pipeline.add_component(name="auto_merging_retriever", instance=auto_merging_retriever)
pipeline.connect("bm_25_retriever.documents", "auto_merging_retriever.matched_leaf_documents")
pipeline.connect("bm_25_retriever.documents", "auto_merging_retriever.documents")
pipeline_dict = pipeline.to_dict()

new_pipeline = Pipeline.from_dict(pipeline_dict)
Expand Down Expand Up @@ -295,4 +295,4 @@ def test_run_go_up_hierarchy_multiple_levels_hit_root_document(self):
result = retriever.run(retrieved_leaf_docs)

assert len(result['documents']) == 1
assert result['documents'][0].meta["__level"] == 0 # hit root document
assert result['documents'][0].meta["__level"] == 0 # hit root document

0 comments on commit e332045

Please sign in to comment.