Skip to content

Commit

Permalink
Ensure fields are lists before combining in DocDetails class (#801)
Browse files Browse the repository at this point in the history
  • Loading branch information
harryvu-futurehouse authored Jan 13, 2025
1 parent 8b41c1a commit e5b6447
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .mailmap
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ Odhran O'Donoghue <[email protected]> <[email protected]
Samantha Cox <[email protected]> <[email protected]>
Anush008 <[email protected]> Anush <[email protected]>
Mayk Caldas <[email protected]> maykcaldas <[email protected]>
Harry Vu <[email protected]> <[email protected]>
Harry Vu <[email protected]> harryvu-futurehouse
10 changes: 10 additions & 0 deletions paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,16 @@ def __add__(self, other: DocDetails | int) -> DocDetails: # noqa: PLR0912
merged_data[field] = {**self.other, **other.other}
# handle the bibtex / sources as special fields
for field_to_combine in ("bibtex_source", "client_source"):
# Ensure the fields are lists before combining
if self.other.get(field_to_combine) and not isinstance(
self.other[field_to_combine], list
):
self.other[field_to_combine] = [self.other[field_to_combine]]
if other.other.get(field_to_combine) and not isinstance(
other.other[field_to_combine], list
):
other.other[field_to_combine] = [other.other[field_to_combine]]

if self.other.get(field_to_combine) and other.other.get(
field_to_combine
):
Expand Down
69 changes: 69 additions & 0 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import textwrap
from collections.abc import AsyncIterable, Sequence
from copy import deepcopy
from datetime import datetime, timedelta
from io import BytesIO
from pathlib import Path
from typing import cast
Expand Down Expand Up @@ -1239,6 +1240,74 @@ def test_dois_resolve_to_correct_journals(doi_journals):
assert details.journal == doi_journals["journal"]


def test_docdetails_merge_with_non_list_fields() -> None:
    """Merge two records of one document whose source fields differ in shape.

    For each of ``bibtex_source`` and ``client_source``, one record stores a
    bare string while the other stores a list. Adding the records together
    must keep every source value from both sides.
    """
    first_published = datetime(2023, 1, 1)
    original = DocDetails(
        citation="Citation 1",
        publication_date=first_published,
        docname="Document 1",
        dockey="key1",
        # Bare-string bibtex_source next to a list-valued client_source
        other={"bibtex_source": "source1", "client_source": ["client1"]},
    )

    # Same identifying fields, published 13 weeks later, with the
    # string/list shapes swapped relative to the first record.
    republished = DocDetails(
        citation=original.citation,
        publication_date=first_published + timedelta(weeks=13),
        docname=original.docname,
        dockey=original.dockey,
        other={"bibtex_source": ["source2"], "client_source": "client2"},
    )

    combined = original + republished

    wanted_bibtex = {"source1", "source2"}
    assert wanted_bibtex.issubset(
        combined.other["bibtex_source"]
    ), "Expected merge to keep both bibtex sources"
    wanted_clients = {"client1", "client2"}
    assert wanted_clients.issubset(
        combined.other["client_source"]
    ), "Expected merge to keep both client sources"
    assert isinstance(combined, DocDetails), "Merged doc should also be DocDetails"


def test_docdetails_merge_with_list_fields() -> None:
    """Merge two records of one document whose source fields are all lists.

    Both records store ``bibtex_source`` and ``client_source`` as lists.
    Adding the records together must keep every source value from both sides.
    """
    first_published = datetime(2023, 1, 1)
    original = DocDetails(
        citation="Citation 1",
        publication_date=first_published,
        docname="Document 1",
        dockey="key1",
        # Both source fields are list-valued on this record
        other={"bibtex_source": ["source1"], "client_source": ["client1"]},
    )

    # Same identifying fields, published 13 weeks later; source fields are
    # list-valued here as well.
    republished = DocDetails(
        citation=original.citation,
        publication_date=first_published + timedelta(weeks=13),
        docname=original.docname,
        dockey=original.dockey,
        other={"bibtex_source": ["source2"], "client_source": ["client2"]},
    )

    combined = original + republished

    wanted_bibtex = {"source1", "source2"}
    assert wanted_bibtex.issubset(
        combined.other["bibtex_source"]
    ), "Expected merge to keep both bibtex sources"
    wanted_clients = {"client1", "client2"}
    assert wanted_clients.issubset(
        combined.other["client_source"]
    ), "Expected merge to keep both client sources"
    assert isinstance(combined, DocDetails), "Merged doc should also be DocDetails"


@pytest.mark.vcr
@pytest.mark.parametrize("use_partition", [True, False])
@pytest.mark.asyncio
Expand Down

0 comments on commit e5b6447

Please sign in to comment.