implement merge_annotations_from_documents() #428

Merged 4 commits on Oct 8, 2024
2 changes: 1 addition & 1 deletion src/pytorch_ie/core/document.py
@@ -764,7 +764,7 @@ def as_type(
)
return new_doc

def copy(self, with_annotations: bool = True) -> "Document":
def copy(self: D, with_annotations: bool = True) -> D:
doc_dict = self.asdict()
if not with_annotations:
for field in self.annotation_fields():
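The signature change above is a typing refinement: copy() keeps returning the same object as before, but is now annotated to return the concrete document subclass instead of the base Document. A minimal sketch of the effect, assuming the TextDocumentWithLabeledSpans type used in the tests below:

from pytorch_ie.documents import TextDocumentWithLabeledSpans

doc = TextDocumentWithLabeledSpans(id="doc1", text="This is a test.")
doc_copy = doc.copy()  # with `self: D ... -> D`, type checkers infer TextDocumentWithLabeledSpans
assert isinstance(doc_copy, TextDocumentWithLabeledSpans)  # runtime type was already the subclass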
156 changes: 156 additions & 0 deletions src/pytorch_ie/utils/document.py
@@ -0,0 +1,156 @@
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, TypeVar

from pytorch_ie.core.document import BaseAnnotationList, Document
from pytorch_ie.documents import WithMetadata


def deduplicate_annotation_dicts(
annotation_dicts: List[Dict[str, Hashable]]
) -> List[Dict[str, Hashable]]:
"""Remove duplicate annotation dictionaries from a list of annotation dictionaries.

Args:
annotation_dicts: The list of annotation dictionaries to remove duplicates from.

Returns:
The list of annotation dictionaries with duplicates removed.
"""
unique_annotation_dicts = []
seen_annotation_dicts = set()
for annotation_dict in annotation_dicts:
annotation_dict_tuple = tuple(sorted(annotation_dict.items()))
if annotation_dict_tuple not in seen_annotation_dicts:
unique_annotation_dicts.append(annotation_dict)
seen_annotation_dicts.add(annotation_dict_tuple)
return unique_annotation_dicts
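
For illustration, the helper compares dictionaries by their sorted items, so key order does not matter; a minimal sketch:

dicts = [
    {"start": 0, "end": 4, "label": "label1"},
    {"start": 5, "end": 7, "label": "label2"},
    {"label": "label1", "end": 4, "start": 0},  # same content as the first entry, different key order
]
assert deduplicate_annotation_dicts(dicts) == dicts[:2]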


D = TypeVar("D", bound=Document)


def deduplicate_annotations(document: D) -> D:
"""Remove duplicate annotations from a document.

Args:
document: The document to remove duplicate annotations from.

Returns:
The document with duplicate annotations removed.
"""
annotation_field_names = [field.name for field in document.annotation_fields()]
doc_dict = document.asdict()
for annotation_field_name in annotation_field_names:
doc_dict[annotation_field_name]["annotations"] = deduplicate_annotation_dicts(
doc_dict[annotation_field_name]["annotations"]
)
doc_dict[annotation_field_name]["predictions"] = deduplicate_annotation_dicts(
doc_dict[annotation_field_name]["predictions"]
)
return type(document).fromdict(doc_dict)
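
Duplicates typically arise when identical gold annotations are added from several source documents (see merge_annotations_from_documents below, which uses this helper). A rough, hedged sketch of that round trip, reusing the document types from the tests at the bottom of this PR:

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpans

doc_a = TextDocumentWithLabeledSpans(id="doc", text="This is a test.")
doc_a.labeled_spans.append(LabeledSpan(start=0, end=4, label="label1", score=1.0))
doc_b = doc_a.copy()  # a second source carrying the identical gold span

merged = doc_a.copy(with_annotations=False)
merged.add_all_annotations_from_other(other=doc_a, strict=True)
merged.add_all_annotations_from_other(other=doc_b, strict=True)  # may add the identical span a second time
assert len(deduplicate_annotations(merged).labeled_spans) == 1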


def save_annotation_sources_to_metadata(
document: D,
annotation_id2source: Dict[int, List[str]],
metadata_key: str,
use_predictions: bool,
) -> None:
"""Save the source names for the annotations or predictions in the metadata of the document.

Args:
document: The document to save the source names in the metadata for.
metadata_key: The key in the metadata where the source names should be stored.
annotation_id2source: A mapping from annotation IDs to the source names. Should contain
the ids of all annotations or predictions (depending on use_predictions) in the document.
use_predictions: Whether to store the source names for the predictions or the annotations.
"""

if not hasattr(document, "metadata"):
raise ValueError("Document does not have metadata, can not store source names.")
if metadata_key in document.metadata:
raise ValueError(f"Metadata key '{metadata_key}' already exists in the document.")
document.metadata[metadata_key] = defaultdict(dict)
for annotation_field in document.annotation_fields():
layer_name = annotation_field.name
document.metadata[metadata_key][layer_name] = []
layer: BaseAnnotationList
if use_predictions:
layer = document[layer_name].predictions
else:
layer = document[layer_name]
for ann in layer:
document.metadata[metadata_key][layer_name].append(annotation_id2source[ann._id])
document.metadata[metadata_key] = dict(document.metadata[metadata_key])
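
The stored entry maps each layer name to one list of source names per annotation (or prediction), aligned with the layer's order. For the test at the bottom of this PR, the values come out as follows (variable names here are only illustrative):

# gold annotations, stored under the key passed as metadata_key_source_annotations
expected_annotations_source = {"labeled_spans": [["doc1", "doc2"], ["doc1", "doc2"]]}
# predictions, stored under the key passed as metadata_key_source_predictions
expected_predictions_source = {"labeled_spans": [["doc1"], ["doc1", "doc2"], ["doc2"], ["doc2"]]}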


def merge_annotations_from_documents(
documents: Dict[str, D],
metadata_key_source_annotations: Optional[str] = None,
metadata_key_source_predictions: Optional[str] = None,
) -> D:
"""Merge annotations from multiple documents into a single document. Optionally, store the source
names for all annotations / predictions in the metadata at key metadata_key_source_annotations
/ metadata_key_source_predictions, respectively.

Note that this will remove any annotation duplicates.

Args:
documents: A dictionary mapping source names (e.g. dataset names) to documents.
metadata_key_source_annotations: If not None, the key in the metadata where the source names
for the (gold) annotations are stored.
metadata_key_source_predictions: If not None, the key in the metadata where the source names
for the predictions are stored.

Returns:
The document with merged annotations.
"""
if len(documents) == 0:
raise ValueError("No documents provided.")
source_names = sorted(documents)
first_source_name = source_names[0]
merged_document: D = documents[first_source_name].copy(with_annotations=False)

added_annotation_id2source_names: Dict[int, List[str]] = defaultdict(list)
for source_name in source_names:
document = documents[source_name]
if type(document) is not type(merged_document):
raise ValueError(
f"Document types do not match: {type(document)} and {type(merged_document)}"
)
if isinstance(document, WithMetadata) and document.id is not None:
if document.id != merged_document.id:
raise ValueError(
f"Document IDs do not match: {document.id} and {merged_document.id}"
)

# TODO: add_all_annotations_from_other needs to be fixed! it should return a mapping from
# original annotation *IDs* to new annotations!
# Note: this does not check for duplicates!
added_annotations = merged_document.add_all_annotations_from_other(
other=document, strict=True
)

for layer_name, orig_id2new_annotation in added_annotations.items():
for orig_id, new_annotation in orig_id2new_annotation.items():
added_annotation_id2source_names[new_annotation._id].append(source_name)

merged_document = deduplicate_annotations(merged_document)

# save source names in metadata (at key metadata_key_source_annotations / metadata_key_source_predictions
# for each layer in the order of the annotations / predictions)
if metadata_key_source_annotations is not None:
save_annotation_sources_to_metadata(
document=merged_document,
annotation_id2source=added_annotation_id2source_names,
metadata_key=metadata_key_source_annotations,
use_predictions=False,
)
if metadata_key_source_predictions is not None:
save_annotation_sources_to_metadata(
document=merged_document,
annotation_id2source=added_annotation_id2source_names,
metadata_key=metadata_key_source_predictions,
use_predictions=True,
)
return merged_document
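
A second, self-contained usage sketch to complement the test below (document, source, and key names are only illustrative): one source contributes gold annotations, another only predictions, and only the gold sources are recorded in the metadata:

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpans
from pytorch_ie.utils.document import merge_annotations_from_documents

gold = TextDocumentWithLabeledSpans(id="doc", text="This is a test.")
gold.labeled_spans.append(LabeledSpan(start=0, end=4, label="label1", score=1.0))

predicted = gold.copy(with_annotations=False)
predicted.labeled_spans.predictions.append(LabeledSpan(start=0, end=4, label="label1", score=0.8))

merged = merge_annotations_from_documents(
    {"gold_source": gold, "model_a": predicted},
    metadata_key_source_annotations="annotations_source",
)
assert len(merged.labeled_spans) == 1
assert len(merged.labeled_spans.predictions) == 1
assert merged.metadata["annotations_source"] == {"labeled_spans": [["gold_source"]]}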
64 changes: 64 additions & 0 deletions tests/utils/test_document.py
@@ -0,0 +1,64 @@
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpans
from pytorch_ie.utils.document import merge_annotations_from_documents


def test_document_merge_annotations():
base_doc = TextDocumentWithLabeledSpans(id="doc1", text="This is a test.")
# add annotations
base_doc.labeled_spans.append(LabeledSpan(start=0, end=4, label="label1", score=1.0))
base_doc.labeled_spans.append(LabeledSpan(start=5, end=7, label="label2", score=1.0))

input1 = base_doc.copy()
# add predictions
input1.labeled_spans.predictions.append(LabeledSpan(start=0, end=4, label="label1", score=0.9))
input1.labeled_spans.predictions.append(LabeledSpan(start=5, end=7, label="label2", score=0.7))

input2 = base_doc.copy()
# add predictions
input2.labeled_spans.predictions.append(LabeledSpan(start=0, end=4, label="label1", score=0.8))
input2.labeled_spans.predictions.append(LabeledSpan(start=5, end=7, label="label2", score=0.7))
input2.labeled_spans.predictions.append(LabeledSpan(start=5, end=7, label="label3", score=0.6))

documents = {
"doc1": input1,
"doc2": input2,
}
result = merge_annotations_from_documents(
documents,
metadata_key_source_annotations="annotations_source",
metadata_key_source_predictions="predictions_source",
)
assert result.id == "doc1"
assert set(result.labeled_spans) == set(base_doc.labeled_spans)
assert len(result.labeled_spans) == len(base_doc.labeled_spans) == 2
assert len(result.labeled_spans.predictions) == 4
assert result.labeled_spans.predictions.resolve() == [
("label1", "This"),
("label2", "is"),
("label1", "This"),
("label3", "is"),
]
annotations_with_sources = [
(ann.copy(), sources)
for ann, sources in zip(
result.labeled_spans, result.metadata["annotations_source"]["labeled_spans"]
)
]
assert annotations_with_sources == [
(LabeledSpan(start=0, end=4, label="label1", score=1.0), ["doc1", "doc2"]),
(LabeledSpan(start=5, end=7, label="label2", score=1.0), ["doc1", "doc2"]),
]
predictions_with_sources = [
(ann.copy(), sources)
for ann, sources in zip(
result.labeled_spans.predictions,
result.metadata["predictions_source"]["labeled_spans"],
)
]
assert predictions_with_sources == [
(LabeledSpan(start=0, end=4, label="label1", score=0.9), ["doc1"]),
(LabeledSpan(start=5, end=7, label="label2", score=0.7), ["doc1", "doc2"]),
(LabeledSpan(start=0, end=4, label="label1", score=0.8), ["doc2"]),
(LabeledSpan(start=5, end=7, label="label3", score=0.6), ["doc2"]),
]