Skip to content

Commit

Permalink
Fixed bug in context sorting for multiple calls
Browse files Browse the repository at this point in the history
  • Loading branch information
whitead committed Aug 8, 2023
1 parent fda9b97 commit 89ff929
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 17 deletions.
29 changes: 13 additions & 16 deletions paperqa/docs.py
Original file line number | Diff line number | Diff line change
@@ -368,8 +368,8 @@ def get_evidence(
marginal_relevance: bool = True,
get_callbacks: CallbackFactory = lambda x: None,
detailed_citations: bool = False,
- ablate_vector_search: bool = False,
- ablate_summarization: bool = False,
+ disable_vector_search: bool = False,
+ disable_summarization: bool = False,
) -> Answer:
# special case for jupyter notebooks
if "get_ipython" in globals() or "google.colab" in sys.modules:
@@ -389,8 +389,8 @@ def get_evidence(
marginal_relevance=marginal_relevance,
get_callbacks=get_callbacks,
detailed_citations=detailed_citations,
- ablate_vector_search=ablate_vector_search,
- ablate_summarization=ablate_summarization,
+ disable_vector_search=disable_vector_search,
+ disable_summarization=disable_summarization,
)
)

@@ -402,10 +402,10 @@ async def aget_evidence(
marginal_relevance: bool = True,
get_callbacks: CallbackFactory = lambda x: None,
detailed_citations: bool = False,
- ablate_vector_search: bool = False,
- ablate_summarization: bool = False,
+ disable_vector_search: bool = False,
+ disable_summarization: bool = False,
) -> Answer:
- if ablate_vector_search:
+ if disable_vector_search:
k = k * 10000
if len(self.docs) == 0 and self.doc_index is None:
return answer
@@ -476,7 +476,7 @@ async def process(match):
if guess_is_4xx(str(e)):
return None
raise e
- if "not applicable" in context.lower():
+ if "not applicable" in context.lower() or "not relevant" in context.lower():
return None
c = Context(
context=context,
@@ -489,7 +489,7 @@ async def process(match):
)
return c

- if ablate_summarization:
+ if disable_summarization:
contexts = [
Context(
context=match.page_content,
@@ -502,21 +502,18 @@ async def process(match):
)
for match in matches
]
- answer.contexts += contexts

else:
results = await gather_with_concurrency(
self.max_concurrent, *[process(m) for m in matches]
)
# filter out failures
contexts = [c for c in results if c is not None]
- if len(contexts) == 0:
- return answer
- contexts = sorted(contexts, key=lambda x: x.score, reverse=True)
- contexts = contexts[:max_sources]
- # add to answer contexts
- answer.contexts += contexts

+ answer.contexts = sorted(
+ contexts + answer.contexts, key=lambda x: x.score, reverse=True
+ )
+ answer.contexts = answer.contexts[:max_sources]
context_str = "\n\n".join(
[
f"{c.text.name}: {c.context}"
Expand Down
2 changes: 1 addition & 1 deletion paperqa/version.py
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
- __version__ = "3.5.0"
+ __version__ = "3.6.0"

0 comments on commit 89ff929

Please sign in to comment.