Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added markdown output with citations #209

Merged
merged 3 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions paperqa/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from .docs import Answer, Docs, PromptCollection, Doc, Text
from .docs import Answer, Docs, PromptCollection, Doc, Text, Context
from .version import __version__

__all__ = ["Docs", "Answer", "PromptCollection", "__version__", "Doc", "Text"]
# Names exported as the package's public API (imported above from
# ``.docs`` and ``.version``).
__all__ = [
"Docs",
"Answer",
"PromptCollection",
"__version__",
"Doc",
"Text",
"Context",
]
3 changes: 1 addition & 2 deletions paperqa/contrib/zotero.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
except ImportError:
raise ImportError("Please install pyzotero: `pip install pyzotero`")
from ..paths import PAPERQA_DIR
from ..types import StrPath
from ..utils import count_pdf_pages
from ..utils import StrPath, count_pdf_pages


class ZoteroPaper(BaseModel):
Expand Down
88 changes: 84 additions & 4 deletions paperqa/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import (
Expand All @@ -13,16 +12,15 @@
except ImportError:
from pydantic import BaseModel, validator


from .prompts import (
citation_prompt,
default_system_prompt,
qa_prompt,
select_paper_prompt,
summary_prompt,
)
from .utils import extract_doi, iter_citations

StrPath = Union[str, Path]
# Key type used for documents; deliberately unconstrained (``Any``).
# NOTE(review): presumably must be hashable since keys end up in sets — confirm.
DocKey = Any
# Either the async or the sync langchain callback manager for a chain run.
CBManager = Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun]
# Callable that takes a string and returns a list of callback handlers, or None.
CallbackFactory = Callable[[str], Union[None, List[BaseCallbackHandler]]]
Expand Down Expand Up @@ -129,3 +127,85 @@ class Answer(BaseModel):
def __str__(self) -> str:
    """Render this object as text by returning its ``formatted_answer``."""
    return self.formatted_answer

def get_citation(self, name: str) -> str:
    """Return the formatted citation for the given docname.

    Scans ``self.contexts`` in order and returns the ``citation`` of the
    document attached to the first context whose ``text.name`` equals
    ``name``.

    Args:
        name: The text name to look up (compared against ``text.name``).

    Returns:
        The ``citation`` attribute of the matching text's ``doc``.

    Raises:
        ValueError: If no context carries a text with the given name.
    """
    for context in self.contexts:
        if context.text.name == name:
            return context.text.doc.citation
    # Raised outside of any ``except`` block so the traceback is not
    # cluttered with an unrelated, implicitly chained StopIteration.
    raise ValueError(f"Could not find docname {name} in contexts")

def markdown(self) -> Tuple[str, str]:
    """Return the answer with footnote-style citations.

    Every citation group that ``iter_citations`` finds in ``self.answer``
    is split on commas, and each individual key is replaced by a markdown
    footnote marker. Keys are numbered from 1 in order of first
    appearance; a key that appears again reuses its number. The literal
    key ``Extra background information`` is dropped without a marker.

    Example:
        This is an answer.[^1]
        [^1]: This is the citation.

    Returns:
        A ``(text, footnotes)`` tuple: the answer text with citations
        replaced by markers, and the newline-joined footnote definitions.
        A footnote links to the DOI URL when ``extract_doi`` finds one in
        the citation; otherwise the citation is emitted as plain text
        rather than as a link with an empty target.

    Raises:
        ValueError: Propagated from ``get_citation`` when a cited key has
            no matching entry in ``self.contexts``.
    """
    output = self.answer
    refs: Dict[str, int] = {}
    for citation in iter_citations(self.answer):
        markers = []
        for part in citation.split(","):
            key = part.strip("() ")
            if key == "Extra background information":
                continue
            # The default is evaluated before insertion, so a new key gets
            # the next free number and a known key keeps its existing one.
            number = refs.setdefault(key, len(refs) + 1)
            markers.append(f"[^{number}]")
        output = output.replace(citation, "".join(markers))

    footnotes = []
    for key, number in refs.items():
        # Look the citation up once per reference instead of twice.
        reference = self.get_citation(key)
        doi_link = extract_doi(reference)
        if doi_link:
            footnotes.append(f"[^{number}]: [{reference}]({doi_link})")
        else:
            # No DOI available: avoid producing a dead ``[text]()`` link.
            footnotes.append(f"[^{number}]: {reference}")
    return output, "\n".join(footnotes)

def combine_with(self, other: "Answer") -> "Answer":
    """Build a new ``Answer`` that merges this one with ``other``.

    String fields are concatenated (questions joined by `` / ``, the rest
    by a single space), the ``contexts`` lists are concatenated, and token
    counts are merged through ``merge_token_counts``. ``summary_length``
    and ``answer_length`` are taken from ``self``; ``memory`` and ``cost``
    fall back to ``other`` only when the value on ``self`` is falsy. When
    either side has a non-empty ``dockey_filter``, the result carries the
    union of both filters.
    """
    merged_fields = dict(
        question=self.question + " / " + other.question,
        answer=self.answer + " " + other.answer,
        context=self.context + " " + other.context,
        contexts=self.contexts + other.contexts,
        references=self.references + " " + other.references,
        formatted_answer=self.formatted_answer + " " + other.formatted_answer,
        # Both lengths are assumed to be the same on the two answers.
        summary_length=self.summary_length,
        answer_length=self.answer_length,
        memory=self.memory if self.memory else other.memory,
        cost=self.cost if self.cost else other.cost,
        token_counts=self.merge_token_counts(self.token_counts, other.token_counts),
    )
    result = Answer(**merged_fields)

    # Only touch dockey_filter when at least one side actually has one.
    if self.dockey_filter or other.dockey_filter:
        own_keys = self.dockey_filter or set()
        other_keys = other.dockey_filter or set()
        result.dockey_filter = own_keys | other_keys
    return result

@staticmethod
def merge_token_counts(
counts1: Optional[Dict[str, List[int]]], counts2: Optional[Dict[str, List[int]]]
) -> Optional[Dict[str, List[int]]]:
"""
Merge two dictionaries of token counts.
"""
if counts1 is None and counts2 is None:
return None
if counts1 is None:
return counts2
if counts2 is None:
return counts1
merged_counts = counts1.copy()
for key, values in counts2.items():
if key in merged_counts:
merged_counts[key][0] += values[0]
merged_counts[key][1] += values[1]
else:
merged_counts[key] = values
return merged_counts
30 changes: 28 additions & 2 deletions paperqa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import math
import re
import string
from typing import BinaryIO, List
from pathlib import Path
from typing import BinaryIO, List, Union

import pypdf
from langchain.base_language import BaseLanguageModel

from .types import StrPath
# A filesystem path given either as a plain string or as a ``pathlib.Path``.
StrPath = Union[str, Path]


def name_in_text(name: str, text: str) -> bool:
Expand Down Expand Up @@ -105,3 +106,28 @@ def strip_citations(text: str) -> str:
# Remove the citations from the text
text = re.sub(citation_regex, "", text, flags=re.MULTILINE)
return text


def iter_citations(text: str) -> List[str]:
    """Return every citation string found in ``text``, in order of appearance.

    Two forms are recognised: ``Name et al. (YYYY)`` and a parenthesised
    group that contains a letter followed somewhere later by four digits,
    e.g. ``(Foo2019Bar pg. 1-3, Baz2020Qux)``. A parenthesised group is
    returned whole, including its parentheses.
    """
    # Combined regex for identifying citations (see unit tests for examples)
    pattern = re.compile(
        r"\b[\w\-]+\set\sal\.\s\([0-9]{4}\)|\((?:[^\)]*?[a-zA-Z][^\)]*?[0-9]{4}[^\)]*?)\)",
        flags=re.MULTILINE,
    )
    return pattern.findall(text)


def extract_doi(reference: str) -> str:
    """
    Extracts a DOI from the reference string using regex.

    :param reference: A string containing the reference.
    :return: The DOI as a ``https://doi.org/`` link, or an empty string if
        no DOI is found.
    """
    # DOI regex pattern. The dot after "10" is escaped so that only a real
    # "10." prefix matches (an unescaped "." accepts any character, e.g. the
    # page range "1012345/67").
    doi_pattern = r"10\.\d{4,9}/[-._;()/:A-Z0-9]+"
    doi_match = re.search(doi_pattern, reference, re.IGNORECASE)
    if doi_match is None:
        return ""
    # The suffix character class allows ".", ";" and ":", so punctuation that
    # merely ends the sentence gets captured too; trim it off the tail.
    doi = doi_match.group().rstrip(".;:")
    return "https://doi.org/" + doi
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.12.0"
__version__ = "3.13.0"
67 changes: 64 additions & 3 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

from paperqa import Answer, Docs, PromptCollection, Text
from paperqa import Answer, Context, Doc, Docs, PromptCollection, Text
from paperqa.chains import get_score
from paperqa.readers import read_doc
from paperqa.types import Doc
from paperqa.utils import (
iter_citations,
maybe_is_html,
maybe_is_text,
name_in_text,
Expand All @@ -29,7 +29,36 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
print(token)


# Assume strip_citations is imported or defined in this file.
def test_iter_citations():
    """Check that ``iter_citations`` returns each parenthesised citation group.

    The passage mixes single-key citations, keys with page ranges, and
    comma-separated groups of several keys; every group must come back
    whole and in document order.
    """
    passage = (
        "Yes, COVID-19 vaccines are effective. Various studies have documented the "
        "effectiveness of COVID-19 vaccines in preventing severe disease, "
        "hospitalization, and death. The BNT162b2 vaccine has shown effectiveness "
        "ranging from 65% to -41% for the 5-11 years age group and 76% to 46% for the "
        "12-17 years age group, after the emergence of the Omicron variant in New York "
        "(Dorabawila2022EffectivenessOT). Against the Delta variant, the effectiveness "
        "of the BNT162b2 vaccine was approximately 88% after two doses "
        "(Bernal2021EffectivenessOC pg. 1-3).\n\n"
        "Vaccine effectiveness was also found to be 89% against hospitalization and "
        "91% against emergency department or urgent care clinic visits "
        "(Thompson2021EffectivenessOC pg. 3-5, Goo2031Foo pg. 3-4). In the UK "
        "vaccination program, vaccine effectiveness was approximately 56% in "
        "individuals aged ≥70 years between 28-34 days post-vaccination, increasing to "
        "approximately 58% from day 35 onwards (Marfé2021EffectivenessOC).\n\n"
        "However, it is important to note that vaccine effectiveness can decrease over "
        "time. For instance, the effectiveness of COVID-19 vaccines against severe "
        "COVID-19 declined to 64% after 121 days, compared to around 90% initially "
        "(Chemaitelly2022WaningEO, Foo2019Bar). Despite this, vaccines still provide "
        "significant protection against severe outcomes."
    )
    expected = [
        "(Dorabawila2022EffectivenessOT)",
        "(Bernal2021EffectivenessOC pg. 1-3)",
        "(Thompson2021EffectivenessOC pg. 3-5, Goo2031Foo pg. 3-4)",
        "(Marfé2021EffectivenessOC)",
        "(Chemaitelly2022WaningEO, Foo2019Bar)",
    ]
    found = list(iter_citations(passage))
    assert found == expected


def test_single_author():
Expand Down Expand Up @@ -97,6 +126,38 @@ def test_citations_with_nonstandard_chars():
)


def test_markdown():
    """Check footnote rendering of an answer, before and after self-combination.

    The answer cites a single context. Its markdown must contain the first
    footnote marker, and combining the answer with itself must keep the
    original text as a prefix and leave the footnote list unchanged.
    """
    cited_sentence = (
        "Frederick Bates's greatest accomplishment was his role in resolving land disputes "
        "and his service as governor of Missouri (Wiki2023 chunk 1)."
    )
    source_doc = Doc(
        name="Wiki2023",
        docname="Wiki2023",
        citation="WikiMedia Foundation, 2023, Accessed now",
        texts=[],
    )
    chunk = Text(text=cited_sentence, name="Wiki2023 chunk 1", doc=source_doc)
    answer = Answer(
        question="What was Fredic's greatest accomplishment?",
        answer=cited_sentence,
        contexts=[Context(context="", text=chunk, score=5)],
    )

    body, footnotes = answer.markdown()
    print(footnotes)
    assert "[^1]" in body

    answer = answer.combine_with(answer)
    combined_body, combined_footnotes = answer.markdown()
    assert combined_body.startswith(body)
    assert combined_footnotes == footnotes


def test_ablations():
tests_dir = os.path.dirname(os.path.abspath(__file__))
doc_path = os.path.join(tests_dir, "paper.pdf")
Expand Down
Loading