Skip to content

Commit

Permalink
Refactored logic to validate Text.doc
Browse files Browse the repository at this point in the history
  • Loading branch information
maykcaldas committed Jan 14, 2025
1 parent c70f3de commit 5a81f71
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 14 deletions.
3 changes: 3 additions & 0 deletions paperqa/agents/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
import zlib
from collections.abc import Callable, Collection, Sequence
from datetime import datetime
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, ClassVar
from uuid import UUID
Expand Down Expand Up @@ -70,6 +71,8 @@ def default(self, o):
return list(o)
if isinstance(o, os.PathLike):
return str(o)
if isinstance(o, datetime):
return o.isoformat()
return json.JSONEncoder.default(self, o)


Expand Down
50 changes: 37 additions & 13 deletions paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import warnings
from collections.abc import Collection
from copy import deepcopy
from datetime import datetime
from typing import Any, ClassVar, cast
from uuid import UUID, uuid4
Expand All @@ -19,6 +20,7 @@
BaseModel,
ConfigDict,
Field,
ValidationError,
computed_field,
field_validator,
model_validator,
Expand All @@ -40,6 +42,8 @@


class Doc(Embeddable):
model_config = ConfigDict(extra="forbid")

docname: str
dockey: DocKey
citation: str
Expand All @@ -53,6 +57,15 @@ class Doc(Embeddable):
def __hash__(self) -> int:
return hash((self.docname, self.dockey))

@model_validator(mode="before")
@classmethod
def ensure_no_extra(cls, data: Any) -> Any:
if isinstance(data, dict):
# formatted_citation is serialized by model_dump
# but it's not an attribute
data.pop("formatted_citation", None)
return data

@computed_field # type: ignore[prop-decorator]
@property
def formatted_citation(self) -> str:
Expand Down Expand Up @@ -80,11 +93,7 @@ def matches_filter_criteria(self, filter_criteria: dict) -> bool:
class Text(Embeddable):
text: str
name: str
# TODO: doc is often used as `DocDetails`.
# However, typing it as `Doc | DocDetails` makes SearchIndex.query
# to fail to return the correct results.
# Will keep it as `Doc` and cast to `DocDetails` when needed for now.
doc: Doc | DocDetails
doc: DocDetails | Doc

@model_validator(mode="before")
@classmethod
Expand All @@ -95,18 +104,31 @@ def ensure_doc(cls, values: Any) -> Any:
return values

doc_data = values.get("doc")
if isinstance(doc_data, Doc | DocDetails):
if isinstance(doc_data, Doc):
# If not deserializing, we don't need to change anything
return values

if doc_data:
maybe_is_docdetails = all(
k in doc_data for k in list(DocDetails.model_fields.keys())
)
maybe_is_doc = all(k in doc_data for k in list(Doc.model_fields.keys()))
if maybe_is_doc and not maybe_is_docdetails:
doc = Doc(**doc_data)
copy_doc_data = deepcopy(doc_data)
# Formatted citation is a computed field
# not an attribute
copy_doc_data.pop("formatted_citation", None)
try:
doc = Doc(**copy_doc_data)
values["doc"] = doc
except ValidationError as exc:
logger.debug(
f"Failed to deserialize doc data {doc_data} due to {exc}."
f"Trying to deserialize it into a DocDetails."
)
try:
doc = DocDetails(**copy_doc_data)
values["doc"] = doc
except ValidationError as exc:
raise ValidationError(
f"Failed to deserialize doc data from {doc_data}."
) from exc

return values

def __hash__(self) -> int:
Expand Down Expand Up @@ -242,7 +264,7 @@ def filter_content_for_user(self) -> None:
text=Text(
text="",
**c.text.model_dump(exclude={"text", "embedding", "doc"}),
doc=Doc(**c.text.doc.model_dump(exclude={"embedding"})),
doc=c.text.doc.model_dump(exclude={"embedding"}),
),
)
for c in self.contexts
Expand Down Expand Up @@ -626,6 +648,8 @@ def populate_bibtex_key_citation( # noqa: PLR0912
@model_validator(mode="before")
@classmethod
def validate_all_fields(cls, data: dict[str, Any]) -> dict[str, Any]:
if isinstance(data, Doc): # type: ignore[unreachable]
data = data.model_dump() # type: ignore[unreachable]
data = cls.lowercase_doi_and_populate_doc_id(data)
data = cls.remove_invalid_authors(data)
data = cls.misc_string_cleaning(data)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@ def test_cli_can_build_and_search_index(
result = search_query("XAI", index_name, settings)
assert len(result) == 1
assert isinstance(result[0][0], Docs)
assert result[0][0].docnames == {"Wellawatte2024"}
assert all(d.startswith("Wellawatte") for d in result[0][0].docnames)
assert result[0][1] == "paper.pdf"

0 comments on commit 5a81f71

Please sign in to comment.