diff --git a/eyecite/models.py b/eyecite/models.py index 01d8cc61..4ab71f2c 100644 --- a/eyecite/models.py +++ b/eyecite/models.py @@ -1,6 +1,6 @@ import re from collections import UserString -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from datetime import datetime from typing import ( Any, @@ -15,7 +15,7 @@ cast, ) -from eyecite.utils import HashableDict +from eyecite.utils import hash_sha256 ResourceType = Hashable @@ -60,7 +60,7 @@ def includes_year( ) -@dataclass(eq=True, unsafe_hash=True) +@dataclass(eq=False, unsafe_hash=False) class CitationBase: """Base class for objects returned by `eyecite.find.get_citations`. We define several subclasses of this class below, representing the various @@ -79,7 +79,7 @@ class CitationBase: def __post_init__(self): """Set up groups and metadata.""" # Allow groups to be used in comparisons: - self.groups = HashableDict(self.token.groups) + self.groups = self.token.groups # Make metadata a self.Metadata object: self.metadata = ( self.Metadata(**self.metadata) @@ -101,21 +101,52 @@ def __repr__(self): + ")" ) + def __hash__(self) -> int: + """In general, citations are considered equivalent if they have the + same group values (i.e., the same regex group content that is extracted + from the matched text). Subclasses may override this method in order to + specify equivalence behavior that is more appropriate for certain + kinds of citations (e.g., see CaseCitation override). + + self.groups typically contains different keys for different objects: + + FullLawCitation (non-exhaustive and non-guaranteed): + - chapter + - reporter + - law_section + - issue + - page + - docket_number + - pamphlet + - title + + FullJournalCitation (non-exhaustive and non-guaranteed): + - volume + - reporter + - page + + FullCaseCitation (see CaseCitation.__hash__() notes) + """ + return hash( + hash_sha256( + {**dict(self.groups.items()), **{"class": type(self).__name__}} + ) + ) + + def __eq__(self, other): + """This method is inherited by all subclasses and should not be + overridden. It implements object equality in exactly the same way as + defined in an object's __hash__() function, which should be overridden + instead if desired. + """ + return self.__hash__() == other.__hash__() + @dataclass(eq=True, unsafe_hash=True) class Metadata: """Define fields on self.metadata.""" parenthetical: Optional[str] = None - def comparison_hash(self) -> int: - """Return hash that will be the same if two cites are semantically - equivalent, unless the citation is a CaseCitation missing a page. - """ - if isinstance(self, CaseCitation) and self.groups["page"] is None: - return id(self) - else: - return hash((type(self), tuple(self.groups.items()))) - def corrected_citation(self): """Return citation with any variations normalized.""" return self.matched_text() @@ -170,7 +201,7 @@ def full_span(self) -> Tuple[int, int]: return start, end -@dataclass(eq=True, unsafe_hash=True, repr=False) +@dataclass(eq=False, unsafe_hash=False, repr=False) class ResourceCitation(CitationBase): """Base class for a case, law, or journal citation. Could be short or long.""" @@ -194,6 +225,26 @@ def __post_init__(self): ) super().__post_init__() + def __hash__(self) -> int: + """ResourceCitation objects are hashed in the same way as their + parent class (CitationBase) objects, except that we also take into + consideration the all_editions field. + """ + return hash( + hash_sha256( + { + **dict(self.groups.items()), + **{ + "all_editions": sorted( + [asdict(e) for e in self.all_editions], + key=lambda d: d["short_name"], # type: ignore + ), + "class": type(self).__name__, + }, + } + ) + ) + @dataclass(eq=True, unsafe_hash=True) class Metadata(CitationBase.Metadata): """Define fields on self.metadata.""" @@ -201,11 +252,6 @@ class Metadata(CitationBase.Metadata): pin_cite: Optional[str] = None year: Optional[str] = None - def comparison_hash(self) -> int: - """Return hash that will be the same if two cites are semantically - equivalent.""" - return hash((super().comparison_hash(), self.all_editions)) - def add_metadata(self, words: "Tokens"): """Extract metadata from text before and after citation.""" self.guess_edition() @@ -248,13 +294,13 @@ def guess_edition(self): self.edition_guess = editions[0] -@dataclass(eq=True, unsafe_hash=True, repr=False) +@dataclass(eq=False, unsafe_hash=False, repr=False) class FullCitation(ResourceCitation): """Abstract base class indicating that a citation fully identifies a resource.""" -@dataclass(eq=True, unsafe_hash=True, repr=False) +@dataclass(eq=False, unsafe_hash=False, repr=False) class FullLawCitation(FullCitation): """Citation to a source from `reporters_db/laws.json`.""" @@ -291,7 +337,7 @@ def corrected_citation_full(self): return "".join(parts) -@dataclass(eq=True, unsafe_hash=True, repr=False) +@dataclass(eq=False, unsafe_hash=False, repr=False) class FullJournalCitation(FullCitation): """Citation to a source from `reporters_db/journals.json`.""" @@ -317,12 +363,43 @@ def corrected_citation_full(self): return "".join(parts) -@dataclass(eq=True, unsafe_hash=True, repr=False) +@dataclass(eq=False, unsafe_hash=False, repr=False) class CaseCitation(ResourceCitation): """Convenience class which represents a single citation found in a document. """ + def __hash__(self) -> int: + """CaseCitation objects that have the same volume, reporter, and page + are considered equivalent, unless the citation is missing a page, in + which case the object's hash will be unique for safety. + + self.groups for CaseCitation objects usually contains these keys: + - page (guaranteed here: https://github.com/freelawproject/reporters-db/blob/main/tests.py#L129) # noqa: E501 + - reporter (guaranteed here: https://github.com/freelawproject/reporters-db/blob/main/tests.py#L129) # noqa: E501 + - volume (almost always present, but some tax court citations don't have volumes) # noqa: E501 + - reporter_nominative (sometimes) + - volumes_nominative (sometimes) + """ + if self.groups["page"] is None: + return id(self) + else: + return hash( + hash_sha256( + { + **{ + k: self.groups[k] + for k in ["volume", "page"] + if k in self.groups + }, + **{ + "reporter": self.corrected_reporter(), + "class": type(self).__name__, + }, + } + ) + ) + @dataclass(eq=True, unsafe_hash=True) class Metadata(FullCitation.Metadata): """Define fields on self.metadata.""" @@ -339,7 +416,7 @@ def guess_court(self): self.metadata.court = "scotus" -@dataclass(eq=True, unsafe_hash=True, repr=False) +@dataclass(eq=False, unsafe_hash=False, repr=False) class FullCaseCitation(CaseCitation, FullCitation): """Convenience class which represents a standard, fully named citation, i.e., the kind of citation that marks the first time a document is cited. @@ -389,7 +466,7 @@ def corrected_citation_full(self): return "".join(parts) -@dataclass(eq=True, unsafe_hash=True, repr=False) +@dataclass(eq=False, unsafe_hash=False, repr=False) class ShortCaseCitation(CaseCitation): """Convenience class which represents a short form citation, i.e., the kind of citation made after a full citation has already appeared. This kind of @@ -419,7 +496,7 @@ def corrected_citation_full(self): return "".join(parts) -@dataclass(eq=True, unsafe_hash=True, repr=False) +@dataclass(eq=False, unsafe_hash=False, repr=False) class SupraCitation(CitationBase): """Convenience class which represents a 'supra' citation, i.e., a citation to something that is above in the document. Like a short form citation, @@ -458,7 +535,7 @@ def formatted(self): return "".join(parts) -@dataclass(eq=True, unsafe_hash=True, repr=False) +@dataclass(eq=False, unsafe_hash=False, repr=False) class IdCitation(CitationBase): """Convenience class which represents an 'id' or 'ibid' citation, i.e., a citation to the document referenced immediately prior. An 'id' citation is @@ -469,6 +546,10 @@ class IdCitation(CitationBase): Example: "... foo bar," id., at 240 """ + def __hash__(self) -> int: + """IdCitation objects are always considered unique for safety.""" + return id(self) + @dataclass(eq=True, unsafe_hash=True) class Metadata(CitationBase.Metadata): """Define fields on self.metadata.""" @@ -483,7 +564,7 @@ def formatted(self): return "".join(parts) -@dataclass(eq=True, unsafe_hash=True, repr=False) +@dataclass(eq=False, unsafe_hash=False, repr=False) class UnknownCitation(CitationBase): """Convenience class which represents an unknown citation. A recognized citation should theoretically be parsed as a CaseCitation, FullLawCitation, @@ -491,6 +572,10 @@ class UnknownCitation(CitationBase): a naive catch-all. """ + def __hash__(self) -> int: + """UnknownCitation objects are always considered unique for safety.""" + return id(self) + @dataclass(eq=True, unsafe_hash=True) class Token(UserString): @@ -636,13 +721,20 @@ class Resource(ResourceType): def __hash__(self): """Resources are the same if their citations are semantically - equivalent. + equivalent, as defined by their hash function. Note: Resources composed of citations with missing page numbers are NOT considered the same, even if their other attributes are identical. This is to avoid potential false positives. """ - return self.citation.comparison_hash() + return hash( + hash_sha256( + { + "citation": hash(self.citation), + "class": type(self).__name__, + } + ) + ) def __eq__(self, other): return self.__hash__() == other.__hash__() diff --git a/eyecite/test_factories.py b/eyecite/test_factories.py index ad7fd547..b200ebb2 100644 --- a/eyecite/test_factories.py +++ b/eyecite/test_factories.py @@ -67,8 +67,8 @@ def case_citation( def law_citation( - source_text, - reporter, + source_text=None, + reporter="Mass. Gen. Laws", **kwargs, ): """Convenience function for creating mock FullLawCitation objects.""" diff --git a/eyecite/utils.py b/eyecite/utils.py index a7df001f..b642d23e 100644 --- a/eyecite/utils.py +++ b/eyecite/utils.py @@ -1,3 +1,5 @@ +import hashlib +import json import re from lxml import etree @@ -72,13 +74,6 @@ def on_match(index, start, end, flags, context): return matches -class HashableDict(dict): - """Dict that works as an attribute of a hashable dataclass.""" - - def __hash__(self): - return hash(frozenset(self.items())) - - def dump_citations(citations, text, context_chars=30): """Dump citations extracted from text, for debugging. Example: >>> text = "blah. Foo v. Bar, 1 U.S. 1, 2 (1999). blah" @@ -117,3 +112,20 @@ def dump_citations(citations, text, context_chars=30): else: out.append(f" * {key}={repr(value)}") return "\n".join(out) + + +def hash_sha256(dictionary: dict) -> int: + """Hash dictionaries in a deterministic way. + + :param dictionary: The dictionary to hash + :return: An integer hash + """ + + # Convert the dictionary to a JSON string + json_str: str = json.dumps(dictionary, sort_keys=True) + + # Convert the JSON string to bytes + json_bytes: bytes = json_str.encode("utf-8") + + # Calculate the hash of the bytes, convert to an int, and return + return int.from_bytes(hashlib.sha256(json_bytes).digest(), byteorder="big") diff --git a/tests/test_ModelsTest.py b/tests/test_ModelsTest.py index 8e025af7..906aa29f 100644 --- a/tests/test_ModelsTest.py +++ b/tests/test_ModelsTest.py @@ -2,21 +2,28 @@ from eyecite import get_citations from eyecite.models import Resource -from eyecite.test_factories import case_citation +from eyecite.test_factories import ( + case_citation, + id_citation, + journal_citation, + law_citation, + unknown_citation, +) class ModelsTest(TestCase): def test_citation_comparison(self): """Are two citation objects equal when their attributes are the same?""" - citations = [ - case_citation(2, volume="2", reporter="U.S.", page="2"), - case_citation(2, volume="2", reporter="U.S.", page="2"), - ] - print("Testing citation comparison...", end=" ") - self.assertEqual(citations[0], citations[1]) - self.assertEqual(hash(citations[0]), hash(citations[1])) - print("✓") + for factory in [case_citation, journal_citation, law_citation]: + citations = [ + factory(), + factory(), + ] + print(f"Testing {factory.__name__} comparison...", end=" ") + self.assertEqual(citations[0], citations[1]) + self.assertEqual(hash(citations[0]), hash(citations[1])) + print("✓") def test_resource_comparison(self): """Are two Resource objects equal when their citations' attributes are @@ -42,6 +49,115 @@ def test_resource_comparison_with_missing_page_cites(self): self.assertNotEqual(hash(citations[0]), hash(citations[1])) print("✓") + def test_citation_comparison_with_missing_page_cites(self): + """Are two citation objects different when one of them is missing + a page, even if their other attributes are the same?""" + citations = [ + case_citation(2, volume="2", reporter="U.S.", page="__"), + case_citation(2, volume="2", reporter="U.S.", page="__"), + ] + print("Testing citation comparison with missing pages...", end=" ") + self.assertNotEqual(citations[0], citations[1]) + self.assertNotEqual(hash(citations[0]), hash(citations[1])) + print("✓") + + def test_citation_comparison_with_corrected_reporter(self): + """Are two citation objects equal when their attributes are + the same, even if the reporter has been normalized?""" + citations = [ + case_citation(2, volume="2", reporter="U.S.", page="4"), + case_citation(2, volume="2", reporter="U. S.", page="4"), + ] + print( + "Testing citation comparison with corrected reporter...", end=" " + ) + self.assertEqual(citations[0], citations[1]) + self.assertEqual(hash(citations[0]), hash(citations[1])) + print("✓") + + def test_citation_comparison_with_different_source_text(self): + """Are two citation objects equal when their attributes are + the same, even if they have different source text?""" + citations = [ + case_citation( + source_text="foobar", volume="2", reporter="U.S.", page="4" + ), + case_citation( + source_text="foo", volume="2", reporter="U.S.", page="4" + ), + ] + print( + "Testing citation comparison with different source text...", + end=" ", + ) + self.assertEqual(citations[0], citations[1]) + self.assertEqual(hash(citations[0]), hash(citations[1])) + print("✓") + + def test_citation_comparison_with_nominative_reporter(self): + """Are two citation objects equal when their attributes are + the same, even if one of them has a nominative reporter?""" + citations = [ + get_citations("5 U.S. 137")[0], + get_citations("5 U.S. (1 Cranch) 137")[0], + ] + print( + "Testing citation comparison with nominative reporter...", end=" " + ) + self.assertEqual(citations[0], citations[1]) + self.assertEqual(hash(citations[0]), hash(citations[1])) + print("✓") + + def test_citation_comparison_with_different_reporter(self): + """Are two citation objects different when they have different + reporters, even if their other attributes are the same? + (sanity check)""" + citations = [ + case_citation(2, volume="2", reporter="F. Supp.", page="4"), + case_citation(2, volume="2", reporter="U. S.", page="4"), + ] + print( + "Testing citation comparison with different reporters...", end=" " + ) + self.assertNotEqual(citations[0], citations[1]) + self.assertNotEqual(hash(citations[0]), hash(citations[1])) + print("✓") + + def test_tax_court_citation_comparison(self): + """Are two citation objects equal when their attributes are + the same, even if they are tax court citations and might not + have volumes?""" + citations = [ + get_citations("T.C.M. (RIA) ¶ 95,342")[0], + get_citations("T.C.M. (RIA) ¶ 95,342")[0], + ] + print("Testing tax court citation comparison...", end=" ") + self.assertEqual(citations[0], citations[1]) + self.assertEqual(hash(citations[0]), hash(citations[1])) + print("✓") + + def test_id_citation_comparison(self): + """Are two IdCitation objects always different?""" + citations = [ + id_citation("Id.,", metadata={"pin_cite": "at 123"}), + id_citation("Id.,", metadata={"pin_cite": "at 123"}), + ] + print("Testing id citation comparison...", end=" ") + self.assertNotEqual(citations[0], citations[1]) + self.assertNotEqual(hash(citations[0]), hash(citations[1])) + print("✓") + + def test_unknown_citation_comparison(self): + """Are two UnknownCitation objects always different?""" + citations = [ + unknown_citation("§99"), + unknown_citation("§99"), + ] + print("Testing unknown citation comparison...", end=" ") + self.assertNotEqual(citations[0], citations[1]) + self.assertNotEqual(hash(citations[0]), hash(citations[1])) + print("✓") + def test_missing_page_cite_conversion(self): """Do citations with missing page numbers get their groups['page'] attribute set to None?""" @@ -52,3 +168,37 @@ def test_missing_page_cite_conversion(self): self.assertIsNone(citation1.groups["page"]) self.assertIsNone(citation2.groups["page"]) print("✓") + + def test_persistent_hash(self): + """Are object hashes reproducible across runs?""" + print("Testing persistent citation hash...", end=" ") + objects = [ + ( + case_citation(), + 376794172219282606, + ), + ( + journal_citation(), + 1073308118601601409, + ), + ( + law_citation(), + 407008277458283218, + ), + ( + Resource(case_citation()), + 1986750081022884797, + ), + ] + for citation, citation_hash in objects: + self.assertEqual(hash(citation), citation_hash) + print("✓") + + def test_hash_function_identity(self): + """Do hash() and __hash__() output the same hash?""" + citation = case_citation() + resource = Resource(case_citation()) + print("Testing hash function identity...", end=" ") + self.assertEqual(hash(citation), citation.__hash__()) + self.assertEqual(hash(resource), resource.__hash__()) + print("✓")