From 3c83946fbb14e1fc502a17dd6badb045826e6408 Mon Sep 17 00:00:00 2001 From: James Braza Date: Tue, 17 Dec 2024 15:34:48 -0800 Subject: [PATCH] Moved the repo to use MultipleChoiceQuestion and MultipleChoiceEvaluation from aviary --- paperqa/agents/env.py | 4 +- paperqa/agents/models.py | 9 +- paperqa/agents/task.py | 57 ++++---- paperqa/litqa.py | 286 +++------------------------------------ paperqa/types.py | 14 +- tests/test_litqa.py | 283 ++++++-------------------------------- tests/test_task.py | 23 ++-- 7 files changed, 127 insertions(+), 549 deletions(-) diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py index c748f19d..e7947d8a 100644 --- a/paperqa/agents/env.py +++ b/paperqa/agents/env.py @@ -11,6 +11,7 @@ ToolRequestMessage, ToolResponseMessage, ) +from aviary.utils import MultipleChoiceQuestion from llmclient import EmbeddingModel, LiteLLMModel from paperqa.docs import Docs @@ -127,10 +128,11 @@ def make_tools(self) -> list[Tool]: ) def make_initial_state(self) -> EnvironmentState: + query: str | MultipleChoiceQuestion = self._query.query return EnvironmentState( docs=self._docs, session=PQASession( - question=self._query.query, + question=query if isinstance(query, str) else query.question_prompt, config_md5=self._query.settings.md5, id=self._query.id, ), diff --git a/paperqa/agents/models.py b/paperqa/agents/models.py index eaf366d6..906e8058 100644 --- a/paperqa/agents/models.py +++ b/paperqa/agents/models.py @@ -8,6 +8,7 @@ from typing import Any, ClassVar, Protocol from uuid import UUID, uuid4 +from aviary.utils import MultipleChoiceQuestion from llmclient import LiteLLMModel, LLMModel from pydantic import ( BaseModel, @@ -55,7 +56,13 @@ class MismatchedModelsError(Exception): class QueryRequest(BaseModel): model_config = ConfigDict(extra="forbid") - query: str = "" + query: str | MultipleChoiceQuestion = Field( + default="", + description=( + "The query to be answered. Set to a multiple choice question when grading" + " (e.g. for training)." 
+ ), + ) id: UUID = Field( default_factory=uuid4, description="Identifier which will be propagated to the Answer object.", diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py index 96220146..9c99d1b0 100644 --- a/paperqa/agents/task.py +++ b/paperqa/agents/task.py @@ -17,25 +17,27 @@ from aviary.core import ( TASK_DATASET_REGISTRY, - Frame, Messages, TaskDataset, ToolRequestMessage, ToolResponseMessage, ) from aviary.env import ENV_REGISTRY +from aviary.utils import ( + DEFAULT_EVAL_MODEL_NAME, + MultipleChoiceEvaluation, + MultipleChoiceQuestion, +) from llmclient import EmbeddingModel, LiteLLMModel, LLMModel from paperqa._ldp_shims import ComputeTrajectoryMetricsMixin from paperqa.docs import Docs from paperqa.litqa import ( - DEFAULT_EVAL_MODEL_NAME, DEFAULT_LABBENCH_HF_HUB_NAME, DEFAULT_REWARD_MAPPING, - LitQAEvaluation, read_litqa_v2_from_hub, ) -from paperqa.types import DocDetails, PQASession +from paperqa.types import DocDetails from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment from .models import QueryRequest @@ -58,26 +60,22 @@ def __init__( llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS, summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS, embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS, - evaluation_from_answer: ( - Callable[[PQASession | str], Awaitable[LitQAEvaluation]] | None - ) = None, sources: str | list[str] | None = None, rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING, - evaluation_callback: Callable[[LitQAEvaluation], Awaitable] | None = None, + evaluation_callback: ( + Callable[[MultipleChoiceEvaluation], Awaitable] | None + ) = None, **env_kwargs, ): super().__init__( query, docs, llm_model, summary_llm_model, embedding_model, **env_kwargs ) - self._evaluation_from_answer = evaluation_from_answer # Enables checking an Index has the right DOI(s) self.sources: list[str] | None = ( [sources] if isinstance(sources, str) else sources ) self._evaluation_callback = evaluation_callback self._rewards = rewards - self.answer = "" - self.ideal = "" async def validate_sources( self, manifest_or_index: dict[str, DocDetails] | SearchIndex | None = None @@ -120,7 +118,7 @@ async def step( self, action: ToolRequestMessage ) -> tuple[Messages, float, bool, bool]: messages, reward, done, truncated = await super().step(action) - if not done or not self._evaluation_from_answer: + if not done or not isinstance(self._query.query, MultipleChoiceQuestion): return messages, reward, done, truncated # If the ensuring evaluation fails (e.g. due to OpenAI being down), we can: # - Suppress the exception and declare the evaluation as incorrect, which can @@ -130,23 +128,13 @@ async def step( # incorrectly reward what otherwise was a good trajectory. # - Don't suppress the exception, which leads to the trajectory failing, and # removes it from the learnable pool. This is the only safe default behavior. 
- evaluation = await self._evaluation_from_answer(self.state.session.answer) + evaluation, self.state.session.graded_answer = await self._query.query.grade( + self.state.session.answer + ) if evaluation_callback := self._evaluation_callback: await evaluation_callback(evaluation) - self.answer = evaluation.answer or "" - self.ideal = evaluation.ideal or "" return messages, reward + self._rewards[evaluation.value], done, truncated - def export_frame(self) -> Frame: - return Frame( - state=self.state, - info={ - "query": self._query, - "answer": self.answer, - "ideal": self.ideal, - }, - ) - def __deepcopy__(self, memo) -> Self: copy_state = deepcopy(self.state, memo) # We don't know the side effects of deep copying a litellm.Router, @@ -162,7 +150,6 @@ def __deepcopy__(self, memo) -> Self: copy_self = type(self)( query=deepcopy(self._query, memo), # deepcopy for _docs_name docs=copy_state.docs, - evaluation_from_answer=self._evaluation_from_answer, sources=self.sources, rewards=self._rewards, evaluation_callback=self._evaluation_callback, @@ -218,24 +205,26 @@ def __init__( def _make_gradable_environment( self, - ideal: str, + ideal_answer: str, distractors: str | list[str], question: str, sources: str | list[str] | None = None, ) -> GradablePaperQAEnvironment: - qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question( - ideal=ideal, - distractors=distractors, + mc_question = MultipleChoiceQuestion( question=question, - eval_model=self._eval_model, + options=( + distractors + if isinstance(distractors, list) + else MultipleChoiceQuestion.split_options(distractors) + ), + ideal_answer=ideal_answer, **(self._question_kwargs or {}), ) query = self._base_query.model_copy() - query.query = qa_prompt + query.query = mc_question return GradablePaperQAEnvironment( query=query, docs=self._base_docs.model_copy(), - evaluation_from_answer=evaluation_from_answer, sources=sources, rewards=self._rewards, **self._env_kwargs, @@ -338,7 +327,7 @@ def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment: ) from exc sources.append(doi) return self._make_gradable_environment( - ideal=self.data.iloc[idx].ideal, + ideal_answer=self.data.iloc[idx].ideal, distractors=self.data.iloc[idx].distractors, question=self.data.iloc[idx].question, sources=sources, diff --git a/paperqa/litqa.py b/paperqa/litqa.py index 2f5607af..878fe971 100644 --- a/paperqa/litqa.py +++ b/paperqa/litqa.py @@ -2,284 +2,38 @@ from __future__ import annotations -import random -import re -import string -from ast import literal_eval -from collections.abc import Awaitable, Callable, Mapping, Sequence -from enum import StrEnum -from typing import TYPE_CHECKING, Literal, Self +from collections.abc import Mapping +from typing import TYPE_CHECKING -from aviary.core import Message -from llmclient import LiteLLMModel, LLMModel +from aviary.utils import MultipleChoiceEvaluation from paperqa._ldp_shims import discounted_returns -from paperqa.prompts import EVAL_PROMPT_TEMPLATE, QA_PROMPT_TEMPLATE -from paperqa.settings import make_default_litellm_model_list_settings -from paperqa.types import PQASession if TYPE_CHECKING: import pandas as pd -# Special case for LitQA, when ideal == "null" -UNSURE_OPTION = "Insufficient information to answer this question" -_CAPITAL_A_INDEX = ord("A") - -def make_mc_options( - ideal: str, - distractors: str | Sequence[str], - unsure_option: str | None = UNSURE_OPTION, - seed: int | None = None, -) -> tuple[str, str, str | None, list[str]]: - r""" - Return string of options (as letters) and correct 
answer. - - Examples: - >>> text, ideal_answer, unsure_answer, distractor_answers = make_mc_options( - ... ideal="1", distractors=["0", "2", "Dog"], seed=0 - ... ) - >>> text - 'A) Dog\nB) 2\nC) 0\nD) Insufficient information to answer this question\nE) 1' - >>> ideal_answer - 'E' - >>> unsure_answer - 'D' - >>> distractor_answers - ['C', 'B', 'A'] - """ - if isinstance(distractors, str): - try: - split_distractors = literal_eval(distractors) - if not isinstance(split_distractors, list): - raise TypeError("Need split_distractors to be a list.") # noqa: TRY301 - except (ValueError, SyntaxError, TypeError): - split_distractors = [d.strip("'[ ]\"") for d in distractors.split(",")] - distractors = split_distractors - # We are going to modify options in-place, so copy the distractors - options = [*distractors] - - if ideal == "null": - if not unsure_option: - raise ValueError( - 'Dataset configured for "unsure" options via ' - 'ideal="null", please specify "unsure_option".' - ) - correct_answer = unsure_option - else: - # add the answer to the options, only if not null - options.append(ideal) - correct_answer = ideal - - if unsure_option: - options.append(unsure_option) - - if len(options) > len(string.ascii_lowercase): - raise NotImplementedError( - "Didn't handle more multiple choice options than letters, options were" - f" {options}." - ) - random.Random(seed).shuffle(options) - return ( - "\n".join([f"{_CAPITAL_A_INDEX + i:c}) {o}" for i, o in enumerate(options)]), - chr(_CAPITAL_A_INDEX + options.index(correct_answer)), - chr(_CAPITAL_A_INDEX + options.index(unsure_option)) if unsure_option else None, - [chr(_CAPITAL_A_INDEX + options.index(dstr)) for dstr in distractors], - ) - - -DEFAULT_EVAL_MODEL_NAME = "gpt-4-turbo-2024-04-09" DEFAULT_REWARD_MAPPING = {"correct": 1.0, "unsure": 0.1, "incorrect": -1.0} -SEED_USING_QUESTION: Literal["SEED_USING_QUESTION"] = "SEED_USING_QUESTION" # Sentinel - - -class LitQAEvaluation(StrEnum): - """Possible evaluation results for a LitQA question and methods for working with answers.""" - - CORRECT = "correct" - INCORRECT = "incorrect" - UNSURE = "unsure" - - @property - def answer(self) -> str | None: - return getattr(self, "_answer", None) - - @answer.setter - def answer(self, value: str | None) -> None: - self._answer = value - - @property - def ideal(self) -> str | None: - return getattr(self, "_ideal", None) - @ideal.setter - def ideal(self, value: str) -> None: - self._ideal = value - def make_discounted_returns( - self, - num_steps: int, - rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING, - discount: float = 1.0, - ) -> list[float]: - try: - return discounted_returns( - # paper-qa has no intermediary rewards - [0] * (num_steps - 1) + [rewards[self.value]], - terminated=[False] * (num_steps - 1) + [True], - discount=discount, - ) - except TypeError as exc: - raise ImportError( - "Making discounted returns requires the 'ldp' extra for 'ldp'. Please:" - " `pip install paper-qa[ldp]`." 
- ) from exc - - @classmethod - def from_answer( - cls, - text: str, - ideal_mc_answer: str, - unsure_mc_answer: str | None = None, - total_options: int | None = None, - ) -> LitQAEvaluation: - """Compare text with a multiple choice answer or optionally an unsure answer.""" - - def extract_answer(answer: str) -> str: - # first capital letter, like A or A) - s = re.search(r"([A-Z])\)?", answer, re.DOTALL) - if s is not None: - return s.group(1) - return answer.split()[0][0].upper() - - result = extract_answer(text) - if ( - total_options is not None - and ord(result[0]) - _CAPITAL_A_INDEX + 1 > total_options - ): - # The result extracted was not in the options - evaluation = cls.INCORRECT - evaluation.answer = result - # From here, if we don't match either the ideal or the unsure multiple choice - # options then we declare the answer as incorrect. - elif unsure_mc_answer and result[0].lower() == unsure_mc_answer[0].lower(): - evaluation = cls.UNSURE - evaluation.answer = unsure_mc_answer - elif result[0].lower() == ideal_mc_answer[0].lower(): - evaluation = cls.CORRECT - evaluation.answer = ideal_mc_answer - else: - evaluation = cls.INCORRECT - evaluation.answer = result - evaluation.ideal = ideal_mc_answer - return evaluation - - @classmethod - def from_question( - cls, - ideal: str, - distractors: str | list[str], - question: str, - use_unsure: bool = True, - eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME, - seed: int | Literal["SEED_USING_QUESTION"] | None = None, - ) -> tuple[str, Callable[[PQASession | str], Awaitable[LitQAEvaluation]]]: - """ - Create a LitQA question and an answer-to-evaluation function. - - Args: - ideal: Ideal answer term's text (not a multiple choice letter). - distractors: Distractor terms' text (not multiple choice letters). - question: Question text. - use_unsure: Flag (default is enabled) to add an 'insufficient answer' term. - eval_model: Evaluation model to use for multiple choice letter extraction - from a text answer. - seed: Optional seed to use in randomization of multiple choice letters. - Optionally pass in the string literal "SEED_USING_QUESTION" to hash the - input question for the seed. - - Returns: - Two-tuple of created LitQA question, function (that can be thought of as - stateless) to use to extract an evaluation result from an answer. - """ - if seed == SEED_USING_QUESTION: - seed = hash(question) - text, ideal_answer, unsure_answer, distractor_answers = make_mc_options( - ideal=ideal, - distractors=distractors, - seed=seed, - **({} if use_unsure else {"unsure_option": None}), - ) - qa_prompt = QA_PROMPT_TEMPLATE.format(question=question, options=text) - - if isinstance(eval_model, str): - eval_model = LiteLLMModel( - name=eval_model, - config=make_default_litellm_model_list_settings(eval_model), - ) - - async def llm_from_answer(answer: PQASession | str) -> LitQAEvaluation: - if isinstance(answer, PQASession): - answer = answer.answer - eval_chunk = await eval_model.achat( - messages=[ - Message( - role="user", - content=EVAL_PROMPT_TEMPLATE.format( - qa_prompt=qa_prompt, qa_answer=answer - ), - ), - ] - ) - if not isinstance(eval_chunk.text, str): - raise NotImplementedError( - f"Expected evaluation chunk to be a string, not {eval_chunk.text}." 
- ) - evaluation = cls.from_answer( - text=eval_chunk.text, - ideal_mc_answer=ideal_answer, - unsure_mc_answer=unsure_answer, - total_options=len(distractor_answers) + (2 if use_unsure else 1), - ) - # convert MC answers back to full text option so that it - # is meaningful - evaluation.ideal = ideal - if evaluation == cls.CORRECT: - evaluation.answer = ideal - elif evaluation == cls.UNSURE: - evaluation.answer = UNSURE_OPTION - else: - try: - evaluation.answer = distractors[ - distractor_answers.index(evaluation.answer or "") - ] - except ValueError: - evaluation.answer = None - return evaluation - - return qa_prompt, llm_from_answer - - @classmethod - def calculate_accuracy_precision( - cls, evaluations: Sequence[Self | str] - ) -> tuple[float, float]: - """ - Calculate LitQA-specific accuracy and precision metrics upon evaluations. - - Raises: - ZeroDivisionError: if an empty input. - - Returns: - Two-tuple of accuracy = (num correct) / (num questions) and - precision = (num correct) / ((num questions) - (num unsure)). - """ - evaluations = [e if isinstance(e, cls) else cls(e) for e in evaluations] - num_correct = sum(e == cls.CORRECT for e in evaluations) - accuracy = num_correct / len(evaluations) - precision = num_correct / sum( - e in {cls.CORRECT, cls.INCORRECT} for e in evaluations +def make_discounted_returns( + evaluation: MultipleChoiceEvaluation, + num_steps: int, + rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING, + discount: float = 1.0, +) -> list[float]: + try: + return discounted_returns( + # paper-qa has no intermediary rewards + [0] * (num_steps - 1) + [rewards[evaluation.value]], + terminated=[False] * (num_steps - 1) + [True], + discount=discount, ) - return accuracy, precision + except TypeError as exc: + raise ImportError( + "Making discounted returns requires the 'ldp' extra for 'ldp'. Please:" + " `pip install paper-qa[ldp]`." + ) from exc DEFAULT_LABBENCH_HF_HUB_NAME = "futurehouse/lab-bench" diff --git a/paperqa/types.py b/paperqa/types.py index 007caf88..db932b08 100644 --- a/paperqa/types.py +++ b/paperqa/types.py @@ -132,7 +132,19 @@ class PQASession(BaseModel): context: str = "" contexts: list[Context] = Field(default_factory=list) references: str = "" - formatted_answer: str = "" + formatted_answer: str = Field( + default="", + description=( + "Optional prettified answer that includes information like question and" + " citations." + ), + ) + graded_answer: str | None = Field( + default=None, + description=( + "Optional graded answer, used for things like multiple choice questions." 
+ ), + ) cost: float = 0.0 # Map model name to a two-item list of LLM prompt token counts # and LLM completion token counts diff --git a/tests/test_litqa.py b/tests/test_litqa.py index 18e03f48..0a36cefc 100644 --- a/tests/test_litqa.py +++ b/tests/test_litqa.py @@ -1,249 +1,56 @@ -from collections.abc import Sequence from typing import cast import pytest +from aviary.utils import MultipleChoiceEvaluation, MultipleChoiceQuestion -from paperqa.litqa import ( - SEED_USING_QUESTION, - UNSURE_OPTION, - LitQAEvaluation, - read_litqa_v2_from_hub, -) -from tests.conftest import VCR_DEFAULT_MATCH_ON - +from paperqa.litqa import make_discounted_returns, read_litqa_v2_from_hub -class TestLitQAEvaluation: - @staticmethod - def _assert_prompt_is_valid( - qa_prompt: str, question: str, ideal: str, distractors: Sequence[str] - ) -> None: - for substr in (question, "Insufficient information", ideal, *distractors): - assert qa_prompt.count(substr) == 1 - # Use for general purpose testing - ZIP_CODE_QUESTION_IDEAL_DISTRACTORS = ( - "What is my office's zip code?", - "94107", - ["-8", "94106", "cheesecake"], - ) - # The following two are used to check we don't leak on the LLM's innate knowledge - MEANING_OF_LIFE_QUESTION_IDEAL_DISTRACTORS = ( - "What is the meaning of life?", - "42", - ["-84", "11", "cheesecake"], +@pytest.mark.parametrize( + ("evaluation", "expected_dreturns"), + [ + (MultipleChoiceEvaluation.CORRECT, [0.25, 0.5, 1.0]), + (MultipleChoiceEvaluation.INCORRECT, [-0.25, -0.5, -1.0]), + (MultipleChoiceEvaluation.UNSURE, [0.025, 0.05, 0.1]), + ], +) +def test_make_discounted_returns( + evaluation: MultipleChoiceEvaluation, expected_dreturns: list[float] +) -> None: + assert ( + make_discounted_returns(evaluation, num_steps=3, discount=0.5) + == expected_dreturns ) - # Source: https://github.com/Future-House/LAB-Bench/blob/43b2045c67a2da12c233689cf538f1ed5c42f590/LitQA2/litqa-v2-public.jsonl#L130 - LITQA2_QUESTION_IDEAL_DISTRACTORS = ( + + +def test_creating_litqa_questions() -> None: + """Test making LitQA eval questions after downloading from Hugging Face Hub.""" + _, eval_split = read_litqa_v2_from_hub(seed=42) + assert len(eval_split) > 3 + assert [ + MultipleChoiceQuestion( + question=cast(str, row.question), + options=cast(list[str], row.distractors), + ideal_answer=cast(str, row.ideal), + shuffle_seed=42, + ).question_prompt + for row in eval_split[:3].itertuples() + ] == [ ( - "What method was used to demonstrate that the enzyme PafA is stable after" - " incubation with 4M urea for 14 days?" 
+ "Q: Which of the following mutations in yeast Pbs2 increases its" + " interaction with SH3?\n\nOptions:\nA) P97A\nB) N92S\nC) Insufficient" + " information to answer this question\nD) K85W\nE) N92H\nF) I87W\nG) S83F" ), - "circular dichroism", - ["cryo EM", "x-ray crystallography", "NMR"], - ) - - @pytest.mark.asyncio - @pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON, "body"]) - @pytest.mark.parametrize( ( - "question", - "ideal", - "distractors", - "answer", - "expected_eval", - "expected_dreturns", - "extracted_answer", + "Q: What percentage of colorectal cancer-associated fibroblasts typically" + " survive at 2 weeks if cultured with the platinum-based chemotherapy" + " oxaliplatin?\n\nOptions:\nA) Insufficient information to answer this" + " question\nB) 0%\nC) 50-80%\nD) 20-50%\nE) 1-20%\nF) 80-99%" ), - [ - pytest.param( - *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS, - "the answer is 94107", - LitQAEvaluation.CORRECT, - [0.25, 0.5, 1.0], - "94107", - id="matched-correct-option", - ), - pytest.param( - *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS, - "the answer is 14004", - LitQAEvaluation.INCORRECT, - [-0.25, -0.5, -1.0], - None, - id="didnt-match-and-no-llm-innate-knowledge", - ), - pytest.param( - *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS, - "the answer is 94106", - LitQAEvaluation.INCORRECT, - [-0.25, -0.5, -1.0], - "94106", - id="matched-incorrect-option", - ), - pytest.param( - *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS, - "Insufficient information", - LitQAEvaluation.UNSURE, - [0.025, 0.05, 0.1], - UNSURE_OPTION, - id="matched-unsure-option", - ), - pytest.param( - *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS, - "the answer is 94106 or 94107", - LitQAEvaluation.INCORRECT, - [-0.25, -0.5, -1.0], - None, - id="matched-several-options", - ), - pytest.param( - *ZIP_CODE_QUESTION_IDEAL_DISTRACTORS, - "", - LitQAEvaluation.INCORRECT, - [-0.25, -0.5, -1.0], - None, - id="empty-answer1", - ), - pytest.param( - *MEANING_OF_LIFE_QUESTION_IDEAL_DISTRACTORS, - "14", - LitQAEvaluation.INCORRECT, - [-0.25, -0.5, -1.0], - None, - id="didnt-match-and-llm-has-innate-knowledge", - ), - pytest.param( - *MEANING_OF_LIFE_QUESTION_IDEAL_DISTRACTORS, - "", - LitQAEvaluation.INCORRECT, - [-0.25, -0.5, -1.0], - None, - id="empty-answer2", - ), - pytest.param( - *LITQA2_QUESTION_IDEAL_DISTRACTORS, - "", - LitQAEvaluation.INCORRECT, - [-0.25, -0.5, -1.0], - None, - id="empty-answer3", - ), - ], - ) - async def test_from_question( - self, - question: str, - ideal: str, - distractors: str | list[str], - answer: str, - expected_eval: LitQAEvaluation, - expected_dreturns: list[float], - extracted_answer: str, - ) -> None: - """Tests that we can create a LitQA question and evaluate answers.""" - qa_prompt, eval_fn = LitQAEvaluation.from_question( - ideal=ideal, - distractors=distractors, - question=question, - seed=42, # Seed for VCR cassette - ) - self._assert_prompt_is_valid(qa_prompt, question, ideal, distractors) - - evaluation = await eval_fn(answer) - assert evaluation == expected_eval - if evaluation == LitQAEvaluation.CORRECT: - assert evaluation.answer == ideal - assert evaluation.answer == extracted_answer - assert evaluation.ideal == ideal - assert evaluation.make_discounted_returns(3, discount=0.5) == expected_dreturns - - def test_consistent_mc_options(self) -> None: - """Tests that creating multiple evaluations with the same seed results in the same prompt.""" - question, ideal, distractors = self.MEANING_OF_LIFE_QUESTION_IDEAL_DISTRACTORS - - qa_prompt_1a, _ = LitQAEvaluation.from_question( - ideal=ideal, 
distractors=distractors, question=question, seed=0 - ) - self._assert_prompt_is_valid(qa_prompt_1a, question, ideal, distractors) - - qa_prompt_1b, _ = LitQAEvaluation.from_question( - ideal=ideal, distractors=distractors, question=question, seed=0 - ) - self._assert_prompt_is_valid(qa_prompt_1b, question, ideal, distractors) - assert qa_prompt_1a == qa_prompt_1b, "Same seeding should lead to same prompts" - - qa_prompt_2a, _ = LitQAEvaluation.from_question( - ideal=ideal, - distractors=distractors, - question=question, - seed=SEED_USING_QUESTION, - ) - self._assert_prompt_is_valid(qa_prompt_2a, question, ideal, distractors) - - qa_prompt_2b, _ = LitQAEvaluation.from_question( - ideal=ideal, - distractors=distractors, - question=question, - seed=SEED_USING_QUESTION, - ) - self._assert_prompt_is_valid(qa_prompt_2b, question, ideal, distractors) - assert ( - qa_prompt_2a == qa_prompt_2b - ), "Same seeding strategy should lead to same prompts" - assert ( - qa_prompt_2a != qa_prompt_1a - ), "Different seeding strategies should lead to different prompts" - - def test_creating_litqa_questions(self) -> None: - """Test making LitQA eval questions after downloading from Hugging Face Hub.""" - _, eval_split = read_litqa_v2_from_hub(seed=42) - assert len(eval_split) > 3 - assert [ - LitQAEvaluation.from_question( - ideal=cast(str, row.ideal), - distractors=cast(list[str], row.distractors), - question=cast(str, row.question), - seed=42, - )[0] - for row in eval_split[:3].itertuples() - ] == [ - ( - "Q: Which of the following mutations in yeast Pbs2 increases its" - " interaction with SH3?\n\nOptions:\nA) S83F\nB) I87W\nC) N92H\nD) K85W\nE)" - " Insufficient information to answer this question\nF) N92S\nG) P97A" - ), - ( - "Q: What percentage of colorectal cancer-associated fibroblasts typically" - " survive at 2 weeks if cultured with the platinum-based chemotherapy" - " oxaliplatin?\n\nOptions:\nA) 80-99%\nB) 1-20%\nC) 20-50%\nD) 50-80%\nE)" - " 0%\nF) Insufficient information to answer this question" - ), - ( - "Q: Which of the following genes shows the greatest difference in gene" - " expression between homologous cell types in mouse and human" - " brain?\n\nOptions:\nA) Htr3a\nB) Htr5a\nC) Htr6\nD) Insufficient" - " information to answer this question\nE) Htr1d" - ), - ] - - @pytest.mark.parametrize( - ("evals", "accuracy_precision"), - [ - ( - [ - LitQAEvaluation.CORRECT, - LitQAEvaluation.CORRECT, - LitQAEvaluation.CORRECT, - ], - (1, 1), - ), - (["correct", "correct", "unsure"], (2 / 3, 1)), - ( - [LitQAEvaluation.CORRECT, LitQAEvaluation.UNSURE, "incorrect"], - (1 / 3, 1 / 2), - ), - ], - ) - def test_calculate_accuracy_precision( - self, evals: Sequence[LitQAEvaluation], accuracy_precision: tuple[float, float] - ) -> None: - assert LitQAEvaluation.calculate_accuracy_precision(evals) == accuracy_precision + ( + "Q: Which of the following genes shows the greatest difference in gene" + " expression between homologous cell types in mouse and human" + " brain?\n\nOptions:\nA) Htr1d\nB) Insufficient information to answer this" + " question\nC) Htr6\nD) Htr5a\nE) Htr3a" + ), + ] diff --git a/tests/test_task.py b/tests/test_task.py index e1f14d83..dfe7a452 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -5,6 +5,7 @@ import pytest from aviary.env import TASK_DATASET_REGISTRY, TaskConfig, TaskDataset +from aviary.utils import MultipleChoiceEvaluation, MultipleChoiceQuestion from ldp.agent import SimpleAgent from ldp.alg.callbacks import Callback, MeanMetricsCallback, 
StoreTrajectoriesCallback from ldp.alg.runners import Evaluator, EvaluatorConfig @@ -20,7 +21,7 @@ LitQAv2TaskSplit, ) from paperqa.agents.tools import GenerateAnswer -from paperqa.litqa import DEFAULT_REWARD_MAPPING, SEED_USING_QUESTION, LitQAEvaluation +from paperqa.litqa import DEFAULT_REWARD_MAPPING @pytest.fixture(name="base_query_request") @@ -60,7 +61,7 @@ def __init__(self, *args, **kwargs): def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment: return self._make_gradable_environment( - ideal=self.data[idx][0], + ideal_answer=self.data[idx][0], distractors=self.data[idx][1], question=self.data[idx][2], sources=self.data[idx][3], @@ -87,7 +88,10 @@ def __init__(self): self.query_to_envs: dict[str, PaperQAEnvironment] = {} async def before_rollout(self, traj_id: str, env) -> None: # noqa: ARG002 - self.query_to_envs[env._query.query] = env + query: str | MultipleChoiceQuestion = env._query.query + self.query_to_envs[ + query if isinstance(query, str) else query.question_prompt + ] = env class TestTaskDataset: @@ -120,9 +124,9 @@ async def test___len__( obs, _ = await env.reset() assert ( "Q: SLC14A1 been identified as a specific marker for endothelial" - " cells in which organ?\n\nOptions:\nA) heart\nB) eye\nC)" - " prostate\nD) Insufficient information to answer this question\nE)" - " liver" in (obs[0].content or "") + " cells in which organ?\n\nOptions:\nA) liver\nB) eye\nC)" + " prostate\nD) heart\nE) Insufficient information to answer this" + " question" in (obs[0].content or "") ) assert env.sources, "Sources need to be accessible" assert isinstance( @@ -159,7 +163,7 @@ async def test_evaluation( "deleted_dockeys", } ), - "question_kwargs": {"seed": SEED_USING_QUESTION}, + "question_kwargs": {"seed": MultipleChoiceQuestion.SEED_USING_QUESTION}, }, ) # NOTE: set base_query after construction of the TaskConfig. because in @@ -193,7 +197,10 @@ async def test_evaluation( assert metrics_callback.eval_means["reward"] > 0, "Expected some wins" correct_reward, incorrect_reward = ( DEFAULT_REWARD_MAPPING[evaluation.value] - for evaluation in (LitQAEvaluation.CORRECT, LitQAEvaluation.INCORRECT) + for evaluation in ( + MultipleChoiceEvaluation.CORRECT, + MultipleChoiceEvaluation.INCORRECT, + ) ) worst_case_reward_given_correct = ( correct_reward * correct_percentage
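
Illustrative sketch (plain Python, not part of the diff above): the end-to-end flow this patch adopts. It builds an aviary MultipleChoiceQuestion, renders its question_prompt, grades a free-text answer, and converts the resulting MultipleChoiceEvaluation into a reward and discounted returns via paperqa.litqa. The question, ideal answer, and distractors reuse the zip-code fixture from the removed LitQAEvaluation tests; treating grade() as an async call that may consult an evaluation model (and therefore may need model credentials) is an assumption inferred from the DEFAULT_EVAL_MODEL_NAME import and the GradablePaperQAEnvironment.step change above.

    import asyncio

    from aviary.utils import MultipleChoiceQuestion

    from paperqa.litqa import DEFAULT_REWARD_MAPPING, make_discounted_returns

    # Question/ideal/distractors reuse the zip-code fixture from the old test suite.
    mc_question = MultipleChoiceQuestion(
        question="What is my office's zip code?",
        options=["-8", "94106", "cheesecake"],  # distractors only; aviary mixes in the ideal answer
        ideal_answer="94107",
        shuffle_seed=42,  # same seeding style used by the updated tests
    )
    print(mc_question.question_prompt)  # "Q: ...\n\nOptions:\nA) ..."


    async def main() -> None:
        # Per the GradablePaperQAEnvironment.step change above, grade() returns the
        # evaluation plus the graded answer text; it may call an evaluation LLM to
        # extract the chosen option, so model credentials may be needed here.
        evaluation, graded_answer = await mc_question.grade("the answer is 94107")
        reward = DEFAULT_REWARD_MAPPING[evaluation.value]  # correct=1.0, unsure=0.1, incorrect=-1.0
        # Discounted returns require the 'ldp' extra: pip install paper-qa[ldp]
        dreturns = make_discounted_returns(evaluation, num_steps=3, discount=0.5)
        print(evaluation.value, graded_answer, reward, dreturns)


    asyncio.run(main())

This mirrors what GradablePaperQAEnvironment.step and LitQAv2TaskDataset._make_gradable_environment now do internally, with DEFAULT_REWARD_MAPPING supplying the correct/unsure/incorrect reward values.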