Skip to content

Commit

Permalink
Added question ID to HotPotQAEnv (#181)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Jan 27, 2025
1 parent 1c5cf12 commit f4ebd73
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 12 deletions.
47 changes: 36 additions & 11 deletions packages/hotpotqa/src/aviary/envs/hotpotqa/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import string
from collections.abc import Callable
from enum import StrEnum
from typing import Any, ClassVar, cast
from typing import TYPE_CHECKING, Any, ClassVar, cast
from uuid import UUID

import httpx
from bs4 import BeautifulSoup
Expand All @@ -40,6 +41,9 @@
eval_answer,
)

if TYPE_CHECKING:
from datasets import Dataset

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -180,6 +184,7 @@ class HotPotQAEnv(Environment[HotPotQAEnvState]):

def __init__(
self,
question_id: UUID | None,
question: str,
correct_answer: str | float,
correct_reward: float = 1.0,
Expand All @@ -189,6 +194,7 @@ def __init__(
proxy: str | None = None,
):
super().__init__()
self.question_id = question_id
self.question = question
# Normalize the correct answer once here as a minor performance optimization
self.normalized_correct_answer = normalize_answer(correct_answer)
Expand All @@ -214,7 +220,7 @@ def __init__(

@classmethod
def from_task(cls, task: str) -> "HotPotQAEnv":
return cls(question=task, correct_answer=0.0)
return cls(question_id=None, question=task, correct_answer=0.0)

async def calculate_answer_reward(self, answer: str | None) -> float:
"""Calculate the reward based on the agent's answer.
Expand Down Expand Up @@ -549,32 +555,51 @@ class HotPotQADataset(TaskDataset[HotPotQAEnv]):
# SEE: https://huggingface.co/datasets/hotpotqa/hotpot_qa
HOTPOTQA_HUGGING_FACE_DATASET = "hotpotqa/hotpot_qa"

def get_data_from_hugging_face(
self, split: str, hf_dataset: str = HOTPOTQA_HUGGING_FACE_DATASET
) -> list[tuple[str, str]]:
"""Convert a local file and split to a list of (question, answer) tuples."""
@staticmethod
def load_raw(
split: str, hf_dataset: str = HOTPOTQA_HUGGING_FACE_DATASET
) -> "Dataset":
"""Load the HotPotQA dataset split from Hugging Face."""
if split in {"dev", "eval", "val"}: # Map common aliases
split = "validation"
all_datasets = load_dataset(hf_dataset, name="fullwiki", trust_remote_code=True)
try:
data = all_datasets[split].select_columns(
column_names=["question", "answer", "level"]
return all_datasets[split].select_columns(
column_names=["id", "question", "answer", "level"]
)
except KeyError as exc:
raise ValueError(
f"Split {split!r} was invalid for Hugging Face dataset {hf_dataset},"
f" please specify a split from {set(all_datasets.keys())}."
) from exc

def get_data_from_hugging_face(
self, split: str, hf_dataset: str = HOTPOTQA_HUGGING_FACE_DATASET
) -> list[tuple[UUID, str, str]]:
"""Get a list of (id, question, answer) tuples for the Hugging Face dataset."""
data = self.load_raw(split, hf_dataset)

if not all( # Making up for datasets not being typed: https://github.com/huggingface/datasets/issues/3841
isinstance(d["question"], str) and isinstance(d["answer"], str | float)
isinstance(d["id"], str)
and isinstance(d["question"], str)
and isinstance(d["answer"], str | float)
for d in data
):
raise ValueError(
f"Dataset {hf_dataset!r} and split {split!r} contains invalid"
" question(s) or answer(s)."
" ID(s), question(s), or answer(s)."
)
return [
(
# Sadly, using UUID v4 doesn't work with left padding (has collisions)
# or right padding (is a lossy conversion), so leave version unspecified
UUID(bytes=b"\x00" * 4 + bytes.fromhex(d["id"])), # Left zero-pad
d["question"],
d["answer"],
)
return [(d["question"], d["answer"]) for d in data if self._filter_task(d)]
for d in data
if self._filter_task(d)
]

def __init__(
self, split: str, config: HotPotQAEnvConfig | dict | None = None, **kwargs
Expand Down
38 changes: 37 additions & 1 deletion packages/hotpotqa/tests/test_hotpotqa_env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import re
from uuid import UUID

import pytest

Expand All @@ -11,6 +13,7 @@
def test_env_construction() -> None:
hotpotqa_env: HotPotQAEnv = Environment.from_name(
"hotpotqa",
question_id=None,
question=(
"What is the formula for the volume of Abraham Lincoln's favorite hat?"
),
Expand All @@ -19,9 +22,40 @@ def test_env_construction() -> None:
assert isinstance(hotpotqa_env, HotPotQAEnv)


IN_GITHUB_ACTIONS: bool = os.getenv("GITHUB_ACTIONS") == "true"


@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Slow test")
def test_question_id_uniqueness() -> None:
raw_dataset = HotPotQADataset.load_raw(split="dev")
raw_ds_ids: set[str] = set()
raw_ds_ids.update(row["id"] for row in raw_dataset)

dataset = TaskDataset.from_name("hotpotqa", split="dev")

question_ids: set[UUID] = set()
for i in range(len(dataset)):
env = dataset.get_new_env_by_idx(i)
assert isinstance(env, HotPotQAEnv)
assert isinstance(env.question_id, UUID)
question_ids.add(env.question_id)

assert len(raw_ds_ids) == len(dataset) == len(question_ids) == 7405, (
'Expected 7405 examples in "dev" split'
)
converted_back_question_ids = {
str(qid)[8:].replace("-", "") for qid in question_ids
}
assert converted_back_question_ids == raw_ds_ids, (
"Should be able to restore original HotPotQA question ID"
)


def test_dataset_from_name() -> None:
dataset = TaskDataset.from_name("hotpotqa", split="dev")
env_0 = dataset.get_new_env_by_idx(0)
assert isinstance(dataset.get_new_env_by_idx(0), HotPotQAEnv)
assert isinstance(env_0.question_id, UUID)

# double-check we can load with various options
dataset = TaskDataset.from_name(
Expand All @@ -44,7 +78,8 @@ def test_dataset_from_name() -> None:
async def test_tool_results() -> None:
hotpotqa_env: HotPotQAEnv = Environment.from_name(
"hotpotqa",
question=("Which country has a larger population: China or France?"),
question_id=None,
question="Which country has a larger population: China or France?",
correct_answer="China",
)
lookup_pattern = r"^\(Result \d+ / \d+\)\s*(.*)"
Expand Down Expand Up @@ -85,6 +120,7 @@ async def test_answer_evaluation_mode(evaluation_mode: EvalAnswerMode) -> None:
correct_answer = "Golden Gate Bridge"
incorrect_answer = "Bay Bridge"
env = HotPotQAEnv(
question_id=None,
question="What is the reddest bridge in San Francisco?",
correct_answer=correct_answer,
evaluation_mode=evaluation_mode,
Expand Down

0 comments on commit f4ebd73

Please sign in to comment.