Added delete function (#121)
* Added delete function

* Fixed subscript

* Improved backwards compatible pickling
whitead authored May 25, 2023
1 parent 6a51327 commit 6ec2273
Showing 4 changed files with 81 additions and 16 deletions.
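
For orientation before the diff, a minimal usage sketch of the new delete API (not part of the commit): the file name and citation string are placeholders, an OPENAI_API_KEY is assumed for embeddings, and the calls mirror the new test_dockey_delete test below.

    import paperqa

    docs = paperqa.Docs()
    # key= gives the document a stable handle that delete() can use later
    docs.add("example.txt", "WikiMedia Foundation, 2023, Accessed now", key="test")

    # soft delete: "test" is dropped from docs.keys/docs.docs and remembered in
    # _deleted_keys so its already-indexed chunks are filtered out of searches
    docs.delete("test")

    answer = paperqa.Answer("What country is Bates from?")
    answer = docs.get_evidence(answer)  # no contexts are drawn from "test"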
62 changes: 49 additions & 13 deletions paperqa/docs.py
@@ -20,8 +20,14 @@
 from langchain.vectorstores import FAISS
 
 from .paths import CACHE_PATH
-from .qaprompts import (citation_prompt, make_chain, qa_prompt, search_prompt,
-                        select_paper_prompt, summary_prompt)
+from .qaprompts import (
+    citation_prompt,
+    make_chain,
+    qa_prompt,
+    search_prompt,
+    select_paper_prompt,
+    summary_prompt,
+)
 from .readers import read_doc
 from .types import Answer, Context
 from .utils import maybe_is_text, md5sum
@@ -65,14 +71,15 @@ def __init__(
         if embeddings is None:
             embeddings = OpenAIEmbeddings()
         self.embeddings = embeddings
+        self._deleted_keys = set()
 
     def update_llm(
         self,
         llm: Optional[Union[LLM, str]] = None,
         summary_llm: Optional[Union[LLM, str]] = None,
     ) -> None:
         """Update the LLM for answering questions."""
-        if llm is None:
+        if llm is None and os.environ.get("OPENAI_API_KEY") is not None:
             llm = "gpt-3.5-turbo"
         if type(llm) is str:
             llm = ChatOpenAI(temperature=0.1, model_name=llm)
@@ -112,11 +119,13 @@ def add(
             raise ValueError(f"Document {path} already in collection.")
 
         if citation is None:
-            cite_chain = make_chain(prompt=citation_prompt, llm=self.summary_llm)
+            cite_chain = make_chain(
+                prompt=citation_prompt, llm=self.summary_llm)
             # peak first chunk
             texts, _ = read_doc(path, "", "", chunk_chars=chunk_chars)
             if len(texts) == 0:
-                raise ValueError(f"Could not read document {path}. Is it empty?")
+                raise ValueError(
+                    f"Could not read document {path}. Is it empty?")
             citation = cite_chain.run(texts[0])
             if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
                 citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"
@@ -136,7 +145,8 @@ def add(
                 year = ""
             key = f"{author}{year}"
         key = self.get_unique_key(key)
-        texts, metadata = read_doc(path, citation, key, chunk_chars=chunk_chars)
+        texts, metadata = read_doc(
+            path, citation, key, chunk_chars=chunk_chars)
         # loose check to see if document was loaded
         #
         if len("".join(texts)) < 10 or (
@@ -188,6 +198,14 @@ def add_texts(
         )
         self.keys.add(key)
 
+    def delete(self, key: str) -> None:
+        """Delete a document from the collection."""
+        if key not in self.keys:
+            return
+        self.keys.remove(key)
+        self.docs = [doc for doc in self.docs if doc["key"] != key]
+        self._deleted_keys.add(key)
+
     def clear(self) -> None:
         """Clear the collection of documents."""
         self.docs = []
@@ -221,11 +239,16 @@ async def adoc_match(
             return ""
         if self._doc_index is None:
             texts = [doc["metadata"][0]["citation"] for doc in self.docs]
-            metadatas = [{"key": doc["metadata"][0]["dockey"]} for doc in self.docs]
+            metadatas = [{"key": doc["metadata"][0]["dockey"]}
+                         for doc in self.docs]
             self._doc_index = FAISS.from_texts(
                 texts, metadatas=metadatas, embedding=self.embeddings
             )
-        docs = self._doc_index.max_marginal_relevance_search(query, k=k)
+        docs = self._doc_index.max_marginal_relevance_search(
+            query, k=k + len(self._deleted_keys)
+        )
+        docs = [doc for doc in docs if doc.metadata["key"]
+                not in self._deleted_keys]
         chain = make_chain(select_paper_prompt, self.summary_llm)
         papers = [f"{d.metadata['key']}: {d.page_content}" for d in docs]
         result = await chain.arun(
@@ -241,11 +264,16 @@ def doc_match(
             return ""
         if self._doc_index is None:
             texts = [doc["metadata"][0]["citation"] for doc in self.docs]
-            metadatas = [{"key": doc["metadata"][0]["dockey"]} for doc in self.docs]
+            metadatas = [{"key": doc["metadata"][0]["dockey"]}
+                         for doc in self.docs]
             self._doc_index = FAISS.from_texts(
                 texts, metadatas=metadatas, embedding=self.embeddings
             )
-        docs = self._doc_index.max_marginal_relevance_search(query, k=k)
+        docs = self._doc_index.max_marginal_relevance_search(
+            query, k=k + len(self._deleted_keys)
+        )
+        docs = [doc for doc in docs if doc.metadata["key"]
+                not in self._deleted_keys]
         chain = make_chain(select_paper_prompt, self.summary_llm)
         papers = [f"{d.metadata['key']}: {d.page_content}" for d in docs]
         result = chain.run(
@@ -264,19 +292,25 @@ def __getstate__(self):
     def __setstate__(self, state):
         self.__dict__.update(state)
         try:
-            self._faiss_index = FAISS.load_local(self.index_path, self.embeddings)
+            self._faiss_index = FAISS.load_local(
+                self.index_path, self.embeddings)
         except:
             # they use some special exception type, but I don't want to import it
             self._faiss_index = None
         if not hasattr(self, "_doc_index"):
             self._doc_index = None
+        # must be a better way to have backwards compatibility
+        if not hasattr(self, "_deleted_keys"):
+            self._deleted_keys = set()
         self.update_llm(None, None)
 
     def _build_faiss_index(self):
         if self._faiss_index is None:
-            texts = reduce(lambda x, y: x + y, [doc["texts"] for doc in self.docs], [])
+            texts = reduce(lambda x, y: x + y,
+                           [doc["texts"] for doc in self.docs], [])
             text_embeddings = reduce(
-                lambda x, y: x + y, [doc["text_embeddings"] for doc in self.docs], []
+                lambda x, y: x + y, [doc["text_embeddings"]
+                                     for doc in self.docs], []
             )
             metadatas = reduce(
                 lambda x, y: x + y, [doc["metadata"] for doc in self.docs], []
@@ -345,6 +379,8 @@ async def aget_evidence(
         )
 
         async def process(doc):
+            if doc.metadata["dockey"] in self._deleted_keys:
+                return None, None
             if key_filter is not None and doc.metadata["dockey"] not in key_filter:
                 return None, None
             # check if it is already in answer (possible in agent setting)
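Design note on the docs.py changes above: rather than removing vectors from the FAISS index, delete() is a soft delete. The key is removed from self.keys and self.docs, but already-embedded chunks stay in the index; _deleted_keys records the deleted keys, the max_marginal_relevance_search calls over-fetch by len(self._deleted_keys), hits from deleted documents are filtered out afterwards, and aget_evidence skips them in process(). The hasattr guard in __setstate__ keeps pickles written before this change (which lack _deleted_keys) loadable. A standalone sketch of the over-fetch-and-filter pattern (illustrative names, not the commit's code):

    def search_excluding(index, query: str, k: int, deleted_keys: set) -> list:
        # over-fetch so that up to k live documents remain even if every
        # deleted document happens to be returned by the index
        hits = index.max_marginal_relevance_search(query, k=k + len(deleted_keys))
        return [hit for hit in hits if hit.metadata["key"] not in deleted_keys]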
3 changes: 1 addition & 2 deletions paperqa/qaprompts.py
@@ -6,8 +6,7 @@
 from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
 from langchain.chains import LLMChain
 from langchain.chat_models import ChatOpenAI
-from langchain.prompts.chat import (ChatPromptTemplate,
-                                    HumanMessagePromptTemplate)
+from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
 from langchain.schema import LLMResult, SystemMessage
 
 summary_prompt = prompts.PromptTemplate(
2 changes: 1 addition & 1 deletion paperqa/version.py
@@ -1 +1 @@
__version__ = "1.8.0"
__version__ = "1.9.0"
30 changes: 30 additions & 0 deletions tests/test_paperqa.py
@@ -312,6 +312,36 @@ def test_dockey_filter():
     docs.get_evidence(answer, key_filter=["test"])
 
 
+def test_dockey_delete():
+    """Test that we can filter evidence with dockeys"""
+    doc_path = "example2.txt"
+    with open(doc_path, "w", encoding="utf-8") as f:
+        # get wiki page about politician
+        r = requests.get("https://en.wikipedia.org/wiki/Frederick_Bates_(politician)")
+        f.write(r.text)
+    docs = paperqa.Docs()
+    docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now")
+    # add with new dockey
+    with open("example.txt", "w", encoding="utf-8") as f:
+        f.write(r.text)
+        f.write("\n\nBates could be from Angola")  # so we don't have same hash
+    docs.add("example.txt", "WikiMedia Foundation, 2023, Accessed now", key="test")
+    answer = paperqa.Answer("What country is Bates from?")
+    answer = docs.get_evidence(answer, marginal_relevance=False)
+    keys = set([c.key for c in answer.contexts])
+    assert len(keys) == 2
+    assert len(docs.docs) == 2
+    assert len(docs.keys) == 2
+
+    docs.delete("test")
+    assert len(docs.docs) == 1
+    assert len(docs.keys) == 1
+    answer = paperqa.Answer("What country is Bates from?")
+    answer = docs.get_evidence(answer, marginal_relevance=False)
+    keys = set([c.key for c in answer.contexts])
+    assert len(keys) == 1
+
+
 def test_query_filter():
     """Test that we can filter evidence with in query"""
     doc_path = "example2.txt"
