diff --git a/.mailmap b/.mailmap index 3c2ce6ea..49a287e0 100644 --- a/.mailmap +++ b/.mailmap @@ -8,3 +8,5 @@ Odhran O'Donoghue <39832722+odhran-o-d@users.nore Samantha Cox Anush008 Anush Mayk Caldas maykcaldas +Harry Vu +Harry Vu harryvu-futurehouse diff --git a/paperqa/types.py b/paperqa/types.py index 5f3f50ff..9d4107ce 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -686,6 +686,16 @@ def __add__(self, other: DocDetails | int) -> DocDetails: # noqa: PLR0912 merged_data[field] = {**self.other, **other.other} # handle the bibtex / sources as special fields for field_to_combine in ("bibtex_source", "client_source"): + # Ensure the fields are lists before combining + if self.other.get(field_to_combine) and not isinstance( + self.other[field_to_combine], list + ): + self.other[field_to_combine] = [self.other[field_to_combine]] + if other.other.get(field_to_combine) and not isinstance( + other.other[field_to_combine], list + ): + other.other[field_to_combine] = [other.other[field_to_combine]] + if self.other.get(field_to_combine) and other.other.get( field_to_combine ): diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index fd5bbba4..690d719a 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -6,6 +6,7 @@ import textwrap from collections.abc import AsyncIterable, Sequence from copy import deepcopy +from datetime import datetime, timedelta from io import BytesIO from pathlib import Path from typing import cast @@ -1239,6 +1240,74 @@ def test_dois_resolve_to_correct_journals(doi_journals): assert details.journal == doi_journals["journal"] +def test_docdetails_merge_with_non_list_fields() -> None: + """Check republication where the source metadata has different shapes.""" + initial_date = datetime(2023, 1, 1) + doc1 = DocDetails( + citation="Citation 1", + publication_date=initial_date, + docname="Document 1", + dockey="key1", + # NOTE: doc1 has non-list bibtex_source and list client_source + other={"bibtex_source": "source1", "client_source": ["client1"]}, + ) + + later_publication_date = initial_date + timedelta(weeks=13) + doc2 = DocDetails( + citation=doc1.citation, + publication_date=later_publication_date, + docname=doc1.docname, + dockey=doc1.dockey, + # NOTE: doc2 has list bibtex_source and non-list client_source + other={"bibtex_source": ["source2"], "client_source": "client2"}, + ) + + # Merge the two DocDetails instances + merged_doc = doc1 + doc2 + + assert {"source1", "source2"}.issubset( + merged_doc.other["bibtex_source"] + ), "Expected merge to keep both bibtex sources" + assert {"client1", "client2"}.issubset( + merged_doc.other["client_source"] + ), "Expected merge to keep both client sources" + assert isinstance(merged_doc, DocDetails), "Merged doc should also be DocDetails" + + +def test_docdetails_merge_with_list_fields() -> None: + """Check republication where the source metadata is the same shape.""" + initial_date = datetime(2023, 1, 1) + doc1 = DocDetails( + citation="Citation 1", + publication_date=initial_date, + docname="Document 1", + dockey="key1", + # NOTE: doc1 has list bibtex_source and list client_source + other={"bibtex_source": ["source1"], "client_source": ["client1"]}, + ) + + later_publication_date = initial_date + timedelta(weeks=13) + doc2 = DocDetails( + citation=doc1.citation, + publication_date=later_publication_date, + docname=doc1.docname, + dockey=doc1.dockey, + # NOTE: doc2 has list bibtex_source and list client_source + other={"bibtex_source": ["source2"], "client_source": ["client2"]}, + ) + + # Merge the two DocDetails instances + merged_doc = doc1 + doc2 + + assert {"source1", "source2"}.issubset( + merged_doc.other["bibtex_source"] + ), "Expected merge to keep both bibtex sources" + assert {"client1", "client2"}.issubset( + merged_doc.other["client_source"] + ), "Expected merge to keep both client sources" + assert isinstance(merged_doc, DocDetails), "Merged doc should also be DocDetails" + + @pytest.mark.vcr @pytest.mark.parametrize("use_partition", [True, False]) @pytest.mark.asyncio