Moved the repo to use MultipleChoiceQuestion and MultipleChoiceEvaluation from aviary
jamesbraza committed Dec 19, 2024
1 parent ec989ab commit 3c83946
Showing 7 changed files with 127 additions and 549 deletions.
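
For orientation before the diff: a minimal sketch (not part of this commit) of the aviary grading flow the repo is moving to, built only from what the diff below shows — MultipleChoiceQuestion is constructed from a question, options (the distractors), and an ideal_answer; question_prompt renders the agent-facing question; and the awaited grade() call is unpacked into a MultipleChoiceEvaluation and a graded answer. The sample question and options are hypothetical, and aviary's exact signatures and defaults may differ.

import asyncio

from aviary.utils import MultipleChoiceQuestion

async def demo() -> None:
    mc_question = MultipleChoiceQuestion(
        question="Which compound class does the paper identify as the inhibitor?",  # hypothetical
        options=["Alkaloids", "Terpenoids", "Polyketides"],  # distractors only
        ideal_answer="Flavonoids",  # the correct option, passed separately
    )
    # Rendered prompt (question plus answer options) shown to the agent
    print(mc_question.question_prompt)

    # Grading calls an eval LLM under the hood (note the DEFAULT_EVAL_MODEL_NAME
    # import below), so running this needs model credentials. Per this commit,
    # the result is unpacked into (evaluation, graded_answer).
    evaluation, graded_answer = await mc_question.grade(
        "Based on the retrieved evidence, the answer is flavonoids."
    )
    print(evaluation, graded_answer)

if __name__ == "__main__":
    asyncio.run(demo())
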
4 changes: 3 additions & 1 deletion paperqa/agents/env.py
@@ -11,6 +11,7 @@
ToolRequestMessage,
ToolResponseMessage,
)
from aviary.utils import MultipleChoiceQuestion
from llmclient import EmbeddingModel, LiteLLMModel

from paperqa.docs import Docs
@@ -127,10 +128,11 @@ def make_tools(self) -> list[Tool]:
)

def make_initial_state(self) -> EnvironmentState:
query: str | MultipleChoiceQuestion = self._query.query
return EnvironmentState(
docs=self._docs,
session=PQASession(
question=self._query.query,
question=query if isinstance(query, str) else query.question_prompt,
config_md5=self._query.settings.md5,
id=self._query.id,
),
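
The make_initial_state() change above boils down to this branch: a query may now be either a plain string or an aviary MultipleChoiceQuestion, and the session question is the string itself or the question's rendered question_prompt. A standalone sketch with a hypothetical question:

from aviary.utils import MultipleChoiceQuestion

def resolve_question(query: str | MultipleChoiceQuestion) -> str:
    # Same branch as make_initial_state(): plain strings pass through unchanged,
    # multiple-choice questions contribute their rendered question_prompt
    return query if isinstance(query, str) else query.question_prompt

print(resolve_question("What is the role of p53 in the cell cycle?"))  # unchanged
mc = MultipleChoiceQuestion(
    question="What is the role of p53 in the cell cycle?",  # hypothetical
    options=["Promoting unchecked division", "Encoding a structural protein"],
    ideal_answer="Arresting the cycle in response to DNA damage",
)
print(resolve_question(mc))  # the question plus its answer options
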
9 changes: 8 additions & 1 deletion paperqa/agents/models.py
@@ -8,6 +8,7 @@
from typing import Any, ClassVar, Protocol
from uuid import UUID, uuid4

from aviary.utils import MultipleChoiceQuestion
from llmclient import LiteLLMModel, LLMModel
from pydantic import (
BaseModel,
@@ -55,7 +56,13 @@ class MismatchedModelsError(Exception):
class QueryRequest(BaseModel):
model_config = ConfigDict(extra="forbid")

query: str = ""
query: str | MultipleChoiceQuestion = Field(
default="",
description=(
"The query to be answered. Set to a multiple choice question when grading"
" (e.g. for training)."
),
)
id: UUID = Field(
default_factory=uuid4,
description="Identifier which will be propagated to the Answer object.",
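
A usage sketch for the widened query field above, assuming the remaining QueryRequest fields (id, settings, etc.) keep their defaults; the question text and options are hypothetical:

from aviary.utils import MultipleChoiceQuestion
from paperqa.agents.models import QueryRequest

# Plain-string query, as before this commit
request = QueryRequest(query="How does CRISPR-Cas9 create double-strand breaks?")

# Multiple-choice query, used when grading (e.g. for training)
request = QueryRequest(
    query=MultipleChoiceQuestion(
        question="How does CRISPR-Cas9 create double-strand breaks?",  # hypothetical
        options=["Base excision", "Nick translation", "Homologous recombination"],
        ideal_answer="Blunt cuts made by the RuvC and HNH nuclease domains",
    )
)
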
57 changes: 23 additions & 34 deletions paperqa/agents/task.py
@@ -17,25 +17,27 @@

from aviary.core import (
TASK_DATASET_REGISTRY,
Frame,
Messages,
TaskDataset,
ToolRequestMessage,
ToolResponseMessage,
)
from aviary.env import ENV_REGISTRY
from aviary.utils import (
DEFAULT_EVAL_MODEL_NAME,
MultipleChoiceEvaluation,
MultipleChoiceQuestion,
)
from llmclient import EmbeddingModel, LiteLLMModel, LLMModel

from paperqa._ldp_shims import ComputeTrajectoryMetricsMixin
from paperqa.docs import Docs
from paperqa.litqa import (
DEFAULT_EVAL_MODEL_NAME,
DEFAULT_LABBENCH_HF_HUB_NAME,
DEFAULT_REWARD_MAPPING,
LitQAEvaluation,
read_litqa_v2_from_hub,
)
from paperqa.types import DocDetails, PQASession
from paperqa.types import DocDetails

from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
from .models import QueryRequest
@@ -58,26 +60,22 @@ def __init__(
llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
summary_llm_model: LiteLLMModel | None = POPULATE_FROM_SETTINGS,
embedding_model: EmbeddingModel | None = POPULATE_FROM_SETTINGS,
evaluation_from_answer: (
Callable[[PQASession | str], Awaitable[LitQAEvaluation]] | None
) = None,
sources: str | list[str] | None = None,
rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
evaluation_callback: Callable[[LitQAEvaluation], Awaitable] | None = None,
evaluation_callback: (
Callable[[MultipleChoiceEvaluation], Awaitable] | None
) = None,
**env_kwargs,
):
super().__init__(
query, docs, llm_model, summary_llm_model, embedding_model, **env_kwargs
)
self._evaluation_from_answer = evaluation_from_answer
# Enables checking an Index has the right DOI(s)
self.sources: list[str] | None = (
[sources] if isinstance(sources, str) else sources
)
self._evaluation_callback = evaluation_callback
self._rewards = rewards
self.answer = ""
self.ideal = ""

async def validate_sources(
self, manifest_or_index: dict[str, DocDetails] | SearchIndex | None = None
@@ -120,7 +118,7 @@ async def step(
self, action: ToolRequestMessage
) -> tuple[Messages, float, bool, bool]:
messages, reward, done, truncated = await super().step(action)
if not done or not self._evaluation_from_answer:
if not done or not isinstance(self._query.query, MultipleChoiceQuestion):
return messages, reward, done, truncated
# If the ensuing evaluation fails (e.g. due to OpenAI being down), we can:
# - Suppress the exception and declare the evaluation as incorrect, which can
@@ -130,23 +128,13 @@
# incorrectly reward what otherwise was a good trajectory.
# - Don't suppress the exception, which leads to the trajectory failing, and
# removes it from the learnable pool. This is the only safe default behavior.
evaluation = await self._evaluation_from_answer(self.state.session.answer)
evaluation, self.state.session.graded_answer = await self._query.query.grade(
self.state.session.answer
)
if evaluation_callback := self._evaluation_callback:
await evaluation_callback(evaluation)
self.answer = evaluation.answer or ""
self.ideal = evaluation.ideal or ""
return messages, reward + self._rewards[evaluation.value], done, truncated

def export_frame(self) -> Frame:
return Frame(
state=self.state,
info={
"query": self._query,
"answer": self.answer,
"ideal": self.ideal,
},
)

def __deepcopy__(self, memo) -> Self:
copy_state = deepcopy(self.state, memo)
# We don't know the side effects of deep copying a litellm.Router,
@@ -162,7 +150,6 @@ def __deepcopy__(self, memo) -> Self:
copy_self = type(self)(
query=deepcopy(self._query, memo), # deepcopy for _docs_name
docs=copy_state.docs,
evaluation_from_answer=self._evaluation_from_answer,
sources=self.sources,
rewards=self._rewards,
evaluation_callback=self._evaluation_callback,
@@ -218,24 +205,26 @@

def _make_gradable_environment(
self,
ideal: str,
ideal_answer: str,
distractors: str | list[str],
question: str,
sources: str | list[str] | None = None,
) -> GradablePaperQAEnvironment:
qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
ideal=ideal,
distractors=distractors,
mc_question = MultipleChoiceQuestion(
question=question,
eval_model=self._eval_model,
options=(
distractors
if isinstance(distractors, list)
else MultipleChoiceQuestion.split_options(distractors)
),
ideal_answer=ideal_answer,
**(self._question_kwargs or {}),
)
query = self._base_query.model_copy()
query.query = qa_prompt
query.query = mc_question
return GradablePaperQAEnvironment(
query=query,
docs=self._base_docs.model_copy(),
evaluation_from_answer=evaluation_from_answer,
sources=sources,
rewards=self._rewards,
**self._env_kwargs,
@@ -338,7 +327,7 @@ def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
) from exc
sources.append(doi)
return self._make_gradable_environment(
ideal=self.data.iloc[idx].ideal,
ideal_answer=self.data.iloc[idx].ideal,
distractors=self.data.iloc[idx].distractors,
question=self.data.iloc[idx].question,
sources=sources,
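
A condensed sketch (not the actual classes) tying the task.py changes together: _make_gradable_environment() now builds a MultipleChoiceQuestion directly, splitting a distractor string into options via MultipleChoiceQuestion.split_options(), and step() grades the final free-text answer and converts the MultipleChoiceEvaluation into a reward via its value. The reward mapping below is illustrative, and the real step() additionally stores the graded answer on session.graded_answer and invokes an optional evaluation callback.

from aviary.utils import MultipleChoiceEvaluation, MultipleChoiceQuestion

# Illustrative reward mapping, keyed by MultipleChoiceEvaluation.value
REWARDS: dict[str, float] = {"correct": 1.0, "unsure": 0.1, "incorrect": -1.0}

def build_question(
    ideal_answer: str, distractors: str | list[str], question: str
) -> MultipleChoiceQuestion:
    # Mirrors _make_gradable_environment(): a distractor string is split into options
    return MultipleChoiceQuestion(
        question=question,
        options=(
            distractors
            if isinstance(distractors, list)
            else MultipleChoiceQuestion.split_options(distractors)
        ),
        ideal_answer=ideal_answer,
    )

async def grade_and_reward(
    mc_question: MultipleChoiceQuestion, final_answer: str
) -> tuple[float, str | None]:
    # Mirrors the end of step(): grade the agent's free-text answer and map the
    # resulting evaluation onto a reward via its value
    evaluation, graded_answer = await mc_question.grade(final_answer)
    assert isinstance(evaluation, MultipleChoiceEvaluation)
    return REWARDS[evaluation.value], graded_answer
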
