From 7ae5d160bf1b42c77222039945f70a23af45eb92 Mon Sep 17 00:00:00 2001
From: Jett
Date: Wed, 7 Feb 2024 21:51:53 -0800
Subject: [PATCH] eval.utils.load_validation_dataset (#25)

---
 src/delphi/eval/utils.py | 28 ++++++++++++++++++++++++++++
 tests/eval/test_utils.py |  8 +++++++-
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py
index 1ad7c256..ee9893d9 100644
--- a/src/delphi/eval/utils.py
+++ b/src/delphi/eval/utils.py
@@ -1,6 +1,8 @@
 from collections.abc import Callable
+from typing import cast
 
 import torch
+from datasets import Dataset, load_dataset
 from jaxtyping import Float, Int
 
 
@@ -25,3 +27,29 @@ def get_next_logprobs(
     logprobs = get_all_logprobs(model, input_ids[:, :-1])
     next_tokens = input_ids[:, 1:]
     return gather_logprobs(logprobs, next_tokens)
+
+
+def load_validation_dataset(dataset_name: str) -> Dataset:
+    """Load the validation parquet files of a dataset via ``load_dataset``.
+
+    Args:
+        dataset_name: Dataset repository id. A name without a "/" is
+            prefixed with "delphi-suite/" before it is looked up.
+
+    Returns:
+        The ``load_dataset`` result for ``data/validation-*.parquet``.
+    """
+    if "/" not in dataset_name:
+        dataset_name = f"delphi-suite/{dataset_name}"
+    data_str = "data/validation-*.parquet"
+    dataset = load_dataset(
+        dataset_name,
+        data_files=data_str,
+        verification_mode="no_checks",
+        # this seems to be the only split when using data_files
+        # regardless of the files we're actually loading
+        split="train",
+    )
+    # cast only narrows the type for static checkers; it performs no
+    # conversion or validation at runtime
+    return cast(Dataset, dataset)
diff --git a/tests/eval/test_utils.py b/tests/eval/test_utils.py
index cefae455..f1d1c875 100644
--- a/tests/eval/test_utils.py
+++ b/tests/eval/test_utils.py
@@ -1,6 +1,6 @@
 import torch
 
-from delphi.eval.utils import gather_logprobs
+from delphi.eval.utils import gather_logprobs, load_validation_dataset
 
 
 def test_gather_logprobs():
@@ -41,3 +41,9 @@ def test_gather_logprobs():
     )
     result = gather_logprobs(logprobs, tokens)
     assert torch.allclose(result, expected_output)
+
+
+def test_load_validation_dataset():
+    # smoke test: passes as long as both datasets load without raising
+    text = load_validation_dataset("tinystories-v2-clean")
+    tokenized = load_validation_dataset("tinystories-v2-clean-tokenized-v0")