Skip to content

Commit

Permalink
Add clinical trials search tool (#777)
Browse files Browse the repository at this point in the history
  • Loading branch information
mskarlin authored Jan 2, 2025
1 parent 525bb32 commit 919bf0c
Show file tree
Hide file tree
Showing 18 changed files with 991 additions and 36 deletions.
142 changes: 127 additions & 15 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@

from paperqa.docs import Docs
from paperqa.settings import Settings
from paperqa.sources.clinical_trials import (
CLINICAL_TRIALS_BASE,
partition_clinical_trials_by_source,
)
from paperqa.types import PQASession
from paperqa.utils import get_year

from .models import QueryRequest
from .tools import (
AVAILABLE_TOOL_NAME_TO_CLASS,
DEFAULT_TOOL_NAMES,
ClinicalTrialsSearch,
Complete,
EnvironmentState,
GatherEvidence,
Expand All @@ -34,7 +40,7 @@
POPULATE_FROM_SETTINGS = None


def settings_to_tools(
def settings_to_tools( # noqa: PLR0912
settings: Settings,
llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
Expand All @@ -50,7 +56,7 @@ def settings_to_tools(
embedding_model = embedding_model or settings.get_embedding_model()
tools: list[Tool] = []
for tool_type in (
(PaperSearch, GatherEvidence, GenerateAnswer, Reset, Complete)
[AVAILABLE_TOOL_NAME_TO_CLASS[name] for name in DEFAULT_TOOL_NAMES]
if settings.agent.tool_names is None
else [
AVAILABLE_TOOL_NAME_TO_CLASS[name]
Expand All @@ -68,26 +74,56 @@ def settings_to_tools(
str, tool.info.parameters.properties[pname]["description"]
).format(current_year=get_year())
elif issubclass(tool_type, GatherEvidence):
tool = Tool.from_function(
GatherEvidence(
settings=settings,
summary_llm_model=summary_llm_model,
embedding_model=embedding_model,
).gather_evidence
gather_evidence_tool = GatherEvidence(
settings=settings,
summary_llm_model=summary_llm_model,
embedding_model=embedding_model,
)

# If the ClinicalTrialsSearch tool is enabled, override this tool's
# docstring/prompt, because the default prompt is unaware of the
# clinical trials tool.

if ClinicalTrialsSearch.TOOL_FN_NAME in (
settings.agent.tool_names or DEFAULT_TOOL_NAMES
):
gather_evidence_tool.gather_evidence.__func__.__doc__ = ( # type: ignore[attr-defined]
ClinicalTrialsSearch.GATHER_EVIDENCE_TOOL_PROMPT_OVERRIDE
)
gather_evidence_tool.partitioning_fn = (
partition_clinical_trials_by_source
)

tool = Tool.from_function(gather_evidence_tool.gather_evidence)

elif issubclass(tool_type, GenerateAnswer):
tool = Tool.from_function(
GenerateAnswer(
settings=settings,
llm_model=llm_model,
summary_llm_model=summary_llm_model,
embedding_model=embedding_model,
).gen_answer
generate_answer_tool = GenerateAnswer(
settings=settings,
llm_model=llm_model,
summary_llm_model=summary_llm_model,
embedding_model=embedding_model,
)

if ClinicalTrialsSearch.TOOL_FN_NAME in (
settings.agent.tool_names or DEFAULT_TOOL_NAMES
):
generate_answer_tool.partitioning_fn = (
partition_clinical_trials_by_source
)

tool = Tool.from_function(generate_answer_tool.gen_answer)

elif issubclass(tool_type, Reset):
tool = Tool.from_function(Reset().reset)
elif issubclass(tool_type, Complete):
tool = Tool.from_function(Complete().complete)
elif issubclass(tool_type, ClinicalTrialsSearch):
tool = Tool.from_function(
ClinicalTrialsSearch(
search_count=settings.agent.search_count,
settings=settings,
).clinical_trials_search
)
else:
raise NotImplementedError(f"Didn't handle tool type {tool_type}.")
if tool.info.name == Complete.complete.__name__:
Expand All @@ -97,6 +133,74 @@ def settings_to_tools(
return tools


def make_clinical_trial_status(
    total_paper_count: int,
    relevant_paper_count: int,
    total_clinical_trials: int,
    relevant_clinical_trials: int,
    evidence_count: int,
    cost: float,
) -> str:
    """Render the one-line status string that includes clinical trial counts.

    Args:
        total_paper_count: Value shown as "Paper Count".
        relevant_paper_count: Value shown as "Relevant Papers".
        total_clinical_trials: Value shown as "Clinical Trial Count".
        relevant_clinical_trials: Value shown as "Relevant Clinical Trials".
        evidence_count: Value shown as "Current Evidence".
        cost: Value shown as "Current Cost", prefixed with "$" and
            formatted to four decimal places.

    Returns:
        A string starting with "Status: " followed by the "label=value"
        fields above, in that order, separated by " | ".
    """
    labelled_values = (
        ("Paper Count", total_paper_count),
        ("Relevant Papers", relevant_paper_count),
        ("Clinical Trial Count", total_clinical_trials),
        ("Relevant Clinical Trials", relevant_clinical_trials),
        ("Current Evidence", evidence_count),
        # Cost is pre-formatted so it keeps its "$" prefix and fixed precision.
        ("Current Cost", f"${cost:.4f}"),
    )
    return "Status: " + " | ".join(
        f"{label}={value}" for label, value in labelled_values
    )


# Regex for parsing a status line back into its counts.
# Capture groups, in order:
#   1. Paper Count
#   2. Relevant Papers
#   3. Clinical Trial Count      (inside the optional section)
#   4. Relevant Clinical Trials  (inside the optional section)
#   5. Current Evidence
# The clinical-trial section is wrapped in an optional non-capturing group,
# so the pattern also matches status lines that omit those two fields; in
# that case groups 3 and 4 do not participate in the match.
# The trailing "Current Cost" field is not part of the pattern.
# SEE: https://regex101.com/r/L0L5MH/1
CLINICAL_STATUS_SEARCH_REGEX_PATTERN: str = (
r"Status: Paper Count=(\d+) \| Relevant Papers=(\d+)(?:\s\|\sClinical Trial Count=(\d+)\s"
r"\|\sRelevant Clinical Trials=(\d+))?\s\|\sCurrent Evidence=(\d+)"
)


def clinical_trial_status(state: "EnvironmentState") -> str:
    """Summarize `state` as a status string with separate paper and trial counts.

    A document is counted as a clinical trial when CLINICAL_TRIALS_BASE
    appears in the "client_source" entry of its `other` mapping; every other
    document (including one without an `other` attribute or without that
    entry) is counted as a paper. A context is relevant when its score is
    strictly greater than `state.RELEVANT_SCORE_CUTOFF`.

    Args:
        state: Environment state providing `docs.docs`, `session.contexts`,
            `session.cost` and `RELEVANT_SCORE_CUTOFF`.

    Returns:
        The string built by `make_clinical_trial_status` from the unique
        document keys in each category, the number of relevant contexts,
        and the session cost.
    """

    def is_trial(doc) -> bool:
        # Missing `other` / "client_source" falls back to an empty container,
        # which classifies the document as a paper.
        return CLINICAL_TRIALS_BASE in getattr(doc, "other", {}).get(
            "client_source", []
        )

    # Unique dockeys across everything held in the docs collection.
    paper_keys: set = set()
    trial_keys: set = set()
    for doc in state.docs.docs.values():
        (trial_keys if is_trial(doc) else paper_keys).add(doc.dockey)

    # Evidence counts contexts; the "relevant" sets count unique documents.
    relevant_contexts = [
        ctx for ctx in state.session.contexts if ctx.score > state.RELEVANT_SCORE_CUTOFF
    ]
    relevant_paper_keys: set = set()
    relevant_trial_keys: set = set()
    for ctx in relevant_contexts:
        source_doc = ctx.text.doc
        bucket = relevant_trial_keys if is_trial(source_doc) else relevant_paper_keys
        bucket.add(source_doc.dockey)

    return make_clinical_trial_status(
        total_paper_count=len(paper_keys),
        relevant_paper_count=len(relevant_paper_keys),
        total_clinical_trials=len(trial_keys),
        relevant_clinical_trials=len(relevant_trial_keys),
        evidence_count=len(relevant_contexts),
        cost=state.session.cost,
    )


class PaperQAEnvironment(Environment[EnvironmentState]):
"""Environment connecting paper-qa's tools with state."""

Expand Down Expand Up @@ -127,13 +231,21 @@ def make_tools(self) -> list[Tool]:
)

def make_initial_state(self) -> EnvironmentState:
    """Create the starting EnvironmentState for this environment's query.

    The state wraps this environment's docs and a new PQASession built from
    the query's text, settings md5 and id. When the clinical trials search
    tool is among the configured tool names (or among DEFAULT_TOOL_NAMES if
    none are configured), `clinical_trial_status` is passed as `status_fn`;
    otherwise `status_fn` is None.
    """
    query = self._query
    enabled_tool_names = query.settings.agent.tool_names or DEFAULT_TOOL_NAMES
    clinical_trials_enabled = ClinicalTrialsSearch.TOOL_FN_NAME in enabled_tool_names

    return EnvironmentState(
        docs=self._docs,
        session=PQASession(
            question=query.query,
            config_md5=query.settings.md5,
            id=query.id,
        ),
        status_fn=clinical_trial_status if clinical_trials_enabled else None,
    )

async def reset(self) -> tuple[list[Message], list[Tool]]:
Expand Down
5 changes: 3 additions & 2 deletions paperqa/agents/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rich.table import Table

from paperqa.docs import Docs
from paperqa.types import DocDetails

from .models import AnswerResponse

Expand Down Expand Up @@ -92,10 +93,10 @@ def table_formatter(
table.add_column("File", style="magenta")
for obj, filename in objects:
try:
display_name = cast(Docs, obj).texts[0].doc.title
display_name = cast(DocDetails, cast(Docs, obj).texts[0].doc).title
except AttributeError:
display_name = cast(Docs, obj).texts[0].doc.formatted_citation
table.add_row(display_name[:max_chars_per_column], filename)
table.add_row(cast(str, display_name)[:max_chars_per_column], filename)
return table
raise NotImplementedError(
f"Object type {type(example_object)} can not be converted to table."
Expand Down
Loading

0 comments on commit 919bf0c

Please sign in to comment.