diff --git a/paperqa/docs.py b/paperqa/docs.py index 7580bf09d..cdc064d25 100644 --- a/paperqa/docs.py +++ b/paperqa/docs.py @@ -35,6 +35,7 @@ maybe_is_text, md5sum, name_in_text, + strip_citations, ) @@ -59,6 +60,8 @@ class Docs(BaseModel, arbitrary_types_allowed=True, smart_union=True): memory: bool = False memory_model: Optional[BaseChatMemory] = None jit_texts_index: bool = False + # This is used to strip indirect citations that come up from the summary llm + strip_citations: bool = True # TODO: Not sure how to get this to work # while also passing mypy checks @@ -505,6 +508,9 @@ async def process(match): raise e if "not applicable" in context.lower() or "not relevant" in context.lower(): return None + if self.strip_citations: + # remove citations that collide with our grounded citations (for the answer LLM) + context = strip_citations(context) c = Context( context=context, text=Text( diff --git a/paperqa/utils.py b/paperqa/utils.py index d7da8ee75..dd28e01b1 100644 --- a/paperqa/utils.py +++ b/paperqa/utils.py @@ -97,3 +97,11 @@ def get_llm_name(llm: BaseLanguageModel) -> str: return llm.model_name # type: ignore except AttributeError: return llm.model # type: ignore + + +def strip_citations(text: str) -> str: + # Combined regex for identifying citations (see unit tests for examples) + citation_regex = r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)" + # Remove the citations from the text + text = re.sub(citation_regex, "", text, flags=re.MULTILINE) + return text diff --git a/paperqa/version.py b/paperqa/version.py index e7e98ee6f..d1a7f1e0d 100644 --- a/paperqa/version.py +++ b/paperqa/version.py @@ -1 +1 @@ -__version__ = "3.11.2" +__version__ = "3.12.0" diff --git a/setup.py b/setup.py index c277abaa3..d2463365d 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ "pypdf", "pydantic<2", "langchain>=0.0.303", - "openai >= 0.27.8", + "openai <1", "faiss-cpu", "PyCryptodome", "html2text", diff --git a/tests/test_paperqa.py b/tests/test_paperqa.py index 6a69f8b57..ae7f4e073 100644 --- a/tests/test_paperqa.py +++ b/tests/test_paperqa.py @@ -15,7 +15,13 @@ from paperqa.chains import get_score from paperqa.readers import read_doc from paperqa.types import Doc -from paperqa.utils import maybe_is_html, maybe_is_text, name_in_text, strings_similarity +from paperqa.utils import ( + maybe_is_html, + maybe_is_text, + name_in_text, + strings_similarity, + strip_citations, +) class TestHandler(AsyncCallbackHandler): @@ -23,6 +29,74 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: print(token) +# Assume strip_citations is imported or defined in this file. + + +def test_single_author(): + text = "This was first proposed by (Smith 1999)." + assert strip_citations(text) == "This was first proposed by ." + + +def test_multiple_authors(): + text = "Recent studies (Smith et al. 1999) show that this is true." + assert strip_citations(text) == "Recent studies show that this is true." + + +def test_multiple_citations(): + text = "As discussed by several authors (Smith et al. 1999; Johnson 2001; Lee et al. 2003)." + assert strip_citations(text) == "As discussed by several authors ." + + +def test_citations_with_pages(): + text = "This is shown in (Smith et al. 1999, p. 150)." + assert strip_citations(text) == "This is shown in ." + + +def test_citations_without_space(): + text = "Findings by(Smith et al. 1999)were significant." + assert strip_citations(text) == "Findings bywere significant." + + +def test_citations_with_commas(): + text = "The method was adopted by (Smith, 1999, 2001; Johnson, 2002)." + assert strip_citations(text) == "The method was adopted by ." + + +def test_citations_with_text(): + text = "This was noted (see Smith, 1999, for a review)." + assert strip_citations(text) == "This was noted ." + + +def test_no_citations(): + text = "There are no references in this text." + assert strip_citations(text) == "There are no references in this text." + + +def test_malformed_citations(): + text = "This is a malformed citation (Smith 199)." + assert strip_citations(text) == "This is a malformed citation (Smith 199)." + + +def test_edge_case_citations(): + text = "Edge cases like (Smith et al.1999) should be handled." + assert strip_citations(text) == "Edge cases like should be handled." + + +def test_citations_with_special_characters(): + text = "Some names have dashes (O'Neil et al. 2000; Smith-Jones 1998)." + assert strip_citations(text) == "Some names have dashes ." + + +def test_citations_with_nonstandard_chars(): + text = ( + "In non-English languages, citations might look different (Müller et al. 1999)." + ) + assert ( + strip_citations(text) + == "In non-English languages, citations might look different ." + ) + + def test_ablations(): tests_dir = os.path.dirname(os.path.abspath(__file__)) doc_path = os.path.join(tests_dir, "paper.pdf")