From 9b40b19340b83b776fc92489898cd53336fd88f1 Mon Sep 17 00:00:00 2001
From: James Braza
Date: Wed, 18 Dec 2024 00:04:16 -0800
Subject: [PATCH] Added consensus sampling helper function and storage callback

---
 paperqa/_ldp_shims.py  |   9 +++-
 paperqa/agents/task.py | 106 +++++++++++++++++++++++++++++++++++++++--
 2 files changed, 109 insertions(+), 6 deletions(-)

diff --git a/paperqa/_ldp_shims.py b/paperqa/_ldp_shims.py
index 39c8ef4e..a977960b 100644
--- a/paperqa/_ldp_shims.py
+++ b/paperqa/_ldp_shims.py
@@ -15,6 +15,7 @@
     "UIndexMemoryModel",
     "_Memories",
     "discounted_returns",
+    "evaluate_consensus",
     "set_training_mode",
 ]
 
@@ -29,7 +30,12 @@
         SimpleAgent,
         SimpleAgentState,
     )
-    from ldp.alg import Callback, ComputeTrajectoryMetricsMixin, RolloutManager
+    from ldp.alg import (
+        Callback,
+        ComputeTrajectoryMetricsMixin,
+        RolloutManager,
+        evaluate_consensus,
+    )
     from ldp.graph.memory import Memory, UIndexMemoryModel
     from ldp.graph.op_utils import set_training_mode
     from ldp.utils import discounted_returns
@@ -48,3 +54,4 @@ class Callback:  # type: ignore[no-redef]
 
     RolloutManager = None  # type: ignore[assignment,misc]
     discounted_returns = None  # type: ignore[assignment]
+    evaluate_consensus = None  # type: ignore[assignment]
diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py
index 9c99d1b0..f3b977e8 100644
--- a/paperqa/agents/task.py
+++ b/paperqa/agents/task.py
@@ -10,13 +10,15 @@
 import logging
 import re
 from abc import ABC
-from collections.abc import Awaitable, Callable, Mapping, Sequence
+from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
 from copy import deepcopy
 from enum import StrEnum
 from typing import TYPE_CHECKING, Any, Self, assert_never
 
 from aviary.core import (
     TASK_DATASET_REGISTRY,
+    Environment,
+    Frame,
     Messages,
     TaskDataset,
     ToolRequestMessage,
@@ -30,22 +32,27 @@
 )
 from llmclient import EmbeddingModel, LiteLLMModel, LLMModel
 
-from paperqa._ldp_shims import ComputeTrajectoryMetricsMixin
+from paperqa._ldp_shims import (
+    Callback,
+    ComputeTrajectoryMetricsMixin,
+    evaluate_consensus,
+)
 from paperqa.docs import Docs
 from paperqa.litqa import (
     DEFAULT_LABBENCH_HF_HUB_NAME,
     DEFAULT_REWARD_MAPPING,
     read_litqa_v2_from_hub,
 )
-from paperqa.types import DocDetails
+from paperqa.types import DocDetails, PQASession
 
 from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
 from .models import QueryRequest
 from .search import SearchIndex, maybe_get_manifest
-from .tools import Complete
+from .tools import Complete, EnvironmentState
 
 if TYPE_CHECKING:
-    from ldp.data_structures import Trajectory
+    from ldp.agent import Agent
+    from ldp.data_structures import Trajectory, Transition
 
 logger = logging.getLogger(__name__)
 
@@ -169,6 +176,95 @@ def __deepcopy__(self, memo) -> Self:
         )
 
 
+async def evaluate_consensus_sampling(
+    data: Iterable[GradablePaperQAEnvironment | Frame],
+    num_samples: int = 1,
+    seed: int | None = None,
+) -> tuple[dict[str, list[tuple[str, int]]], float]:
+    def get_question(x: GradablePaperQAEnvironment | Frame) -> str:
+        if isinstance(x, GradablePaperQAEnvironment):
+            query: str | MultipleChoiceQuestion | dict[str, Any] = x._query.query
+        else:
+            qr: QueryRequest | dict[str, Any] = x.info["query"]
+            query = qr.query if isinstance(qr, QueryRequest) else qr["query"]
+        if isinstance(query, str):
+            return query
+        if isinstance(query, MultipleChoiceQuestion):
+            return query.question_prompt
+        return query["question"]
+
+    def extract_answer(x: GradablePaperQAEnvironment | Frame) -> str:
+        sess: PQASession | dict[str, Any] = (
+            x.state.session
+            if isinstance(x.state, EnvironmentState)
+            else x.state["session"]
+        )
+        return (
+            sess.graded_answer
+            if isinstance(sess, PQASession)
+            else sess["graded_answer"]
+        ) or ""
+
+    def extract_ideal(x: GradablePaperQAEnvironment | Frame) -> str:
+        if isinstance(x, GradablePaperQAEnvironment):
+            query: str | MultipleChoiceQuestion | dict[str, Any] = x._query.query
+        else:
+            qr: QueryRequest | dict[str, Any] = x.info["query"]
+            query = qr.query if isinstance(qr, QueryRequest) else qr["query"]
+        if isinstance(query, str):
+            raise ValueError(  # noqa: TRY004
+                f"We require a {MultipleChoiceQuestion.__name__} variant to extract"
+                " an ideal answer, not a string."
+            )
+        if isinstance(query, MultipleChoiceQuestion):
+            return query.ideal_answer
+        return query["ideal_answer"]
+
+    try:
+        return await evaluate_consensus(
+            data=data,
+            grouping_fn=get_question,
+            extract_answer_fn=extract_answer,
+            ideal_answer_fn=extract_ideal,
+            num_samples=num_samples,
+            seed=seed,
+        )
+    except TypeError:
+        raise ImportError(
+            "Evaluating consensus requires the 'ldp' extra for 'ldp'. Please:"
+            " `pip install paper-qa[ldp]`."
+        ) from None
+
+
+class StoreForConsensusSamplingCallback(Callback):
+    def __init__(self):
+        super().__init__()
+        self.stored: list[GradablePaperQAEnvironment | Frame] = []
+
+    async def after_transition(
+        self,
+        traj_id: str,  # noqa: ARG002
+        agent: "Agent",  # noqa: ARG002
+        env: Environment,
+        transition: "Transition",
+    ) -> None:
+        if not isinstance(env, GradablePaperQAEnvironment):
+            raise NotImplementedError(
+                f"So far we have only handled {GradablePaperQAEnvironment} in this"
+                f" callback, not {type(env)}."
+            )
+        if not transition.done:  # Only store once
+            return
+        self.stored.append(env)
+
+    async def evaluate_consensus_sampling(
+        self, num_samples: int = 1, seed: int | None = None
+    ) -> tuple[dict[str, list[tuple[str, int]]], float]:
+        return await evaluate_consensus_sampling(
+            data=self.stored, num_samples=num_samples, seed=seed
+        )
+
+
 class LitQATaskDataset(
     TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
 ):
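
A minimal usage sketch of the new helper follows (illustrative only, not part of the diff). It assumes the `ldp` extra is installed and hand-builds aviary Frame objects whose payloads mirror the shapes the helper reads: info["query"]["query"] for the question and ideal answer, and state["session"]["graded_answer"] for the sampled answer. In a real run these frames would instead come from StoreForConsensusSamplingCallback.stored after rollouts, and the exact consensus/accuracy semantics are those of ldp's evaluate_consensus.

import asyncio

from aviary.core import Frame

from paperqa.agents.task import evaluate_consensus_sampling

# Hypothetical frames standing in for completed environments; the question and
# answers are made up purely to show the expected data shape.
frames = [
    Frame(
        state={"session": {"graded_answer": answer}},
        info={"query": {"query": {"question": "Is water wet?", "ideal_answer": "A"}}},
    )
    for answer in ("A", "A", "B")
]


async def main() -> None:
    # consensus maps each question to (answer, count) pairs; accuracy scores the
    # consensus answers against the ideal answers (per the declared return type)
    consensus, accuracy = await evaluate_consensus_sampling(
        frames, num_samples=3, seed=42
    )
    print(consensus, accuracy)


asyncio.run(main())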