get_next_logprobs revamp (#130)
* get_next_logprobs revamp

* UNTESTED: support revisions

* revisions -> branches, tested
jettjaniak authored Apr 27, 2024
1 parent 9feac6e commit 3281227
Showing 7 changed files with 225 additions and 133 deletions.
28 changes: 0 additions & 28 deletions scripts/generate_logprobs.sh

This file was deleted.

147 changes: 147 additions & 0 deletions scripts/get_next_logprobs.py
@@ -0,0 +1,147 @@
#!/usr/bin/env python3
import argparse
from collections.abc import Iterable

import numpy as np
import torch
from datasets import Dataset
from tqdm.auto import trange
from transformers import AutoModelForCausalLM

from delphi import utils
from delphi.eval.utils import get_all_and_next_logprobs

torch.set_grad_enabled(False)


def main(
    in_model_repo_id: str,
    branches: Iterable[str],
    in_dataset_repo_id: str,
    split: str,
    feature: str,
    batch_size: int,
    out_repo_id: str,
):
    """
    Computes the log probability of the next token for each token in the dataset
    and uploads the resulting dataset to Hugging Face.
    """
    in_dataset_split = utils.load_dataset_split_sequence_int32_feature(
        in_dataset_repo_id, split, feature
    )
    in_dataset_split.set_format("torch")
    for branch in branches:
        print(f"Loading model='{in_model_repo_id}', {branch=}")
        model = AutoModelForCausalLM.from_pretrained(in_model_repo_id, revision=branch)
        logprobs_dataset = get_logprobs_single_model(
            model=model,
            dataset=in_dataset_split,
            feature=feature,
            batch_size=batch_size,
        )
        logprobs_dataset.push_to_hub(
            repo_id=out_repo_id,
            split=utils.hf_split_to_split_name(split),
            revision=branch,
        )


def get_logprobs_single_model(
    model: AutoModelForCausalLM,
    dataset: Dataset,
    feature: str,
    batch_size: int,
) -> Dataset:
    n_seq = len(dataset)
    seq_len = len(dataset[0][feature])
    logprobs = np.empty((n_seq, seq_len))
    # the first token in each sequence has no preceding context,
    # so its next-token logprob is undefined
    logprobs[:, 0] = float("nan")
    print("Running inference...")
    for i in trange(0, n_seq, batch_size):
        batch_tokens = dataset[i : i + batch_size][feature]
        logprobs[i : i + batch_size, 1:] = (
            get_all_and_next_logprobs(model, batch_tokens)[1].cpu().numpy()  # type: ignore
        )
    return Dataset.from_dict({"logprobs": [row for row in logprobs]})


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run inference and generate log probabilities."
    )
    parser.add_argument(
        "--in-model-repo-id",
        "--im",
        type=str,
        required=True,
        help="The model",
    )
    parser.add_argument(
        "--branches",
        help="comma separated branches of the model to use or 'ALL' to use all branches",
        type=str,
        default="main",
        required=False,
    )

    parser.add_argument(
        "--in-dataset-repo-id",
        "--id",
        type=str,
        required=True,
        help="The tokenized dataset",
    )
    parser.add_argument(
        "--feature",
        "-f",
        type=str,
        required=True,
        help="Name of the column containing token sequences in the input dataset",
    )
    parser.add_argument(
        "--split",
        "-s",
        type=str,
        required=True,
        help="Split of the tokenized dataset, supports slicing like 'train[:10%%]'",
    )
    parser.add_argument(
        "--out-repo-id",
        "-o",
        type=str,
        required=True,
        help="Where to upload the next logprobs",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        type=int,
        default=80,
        help="How many sequences to evaluate at once",
    )
    # TODO
    # parser.add_argument(
    #     "--chunk-size",
    #     "-c",
    #     type=int,
    #     default=200_000,
    #     help="Size of the parquet chunks uploaded to HuggingFace",
    # )
    args = parser.parse_args()

    branches = (
        args.branches.split(",")
        if args.branches != "ALL"
        else utils.get_all_hf_branch_names(args.in_model_repo_id)
    )

    main(
        in_model_repo_id=args.in_model_repo_id,
        branches=branches,
        in_dataset_repo_id=args.in_dataset_repo_id,
        split=args.split,
        feature=args.feature,
        batch_size=args.batch_size,
        out_repo_id=args.out_repo_id,
    )
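
For context, the new script is meant to be run from the command line; a typical invocation might look like the following (all repo ids and branch names below are hypothetical placeholders, not names taken from this commit):

python scripts/get_next_logprobs.py \
    --in-model-repo-id my-org/my-model \
    --branches main,step-1000 \
    --in-dataset-repo-id my-org/my-tokenized-corpus \
    --feature tokens \
    --split "train[:10%]" \
    --out-repo-id my-org/my-next-logprobs \
    --batch-size 80

Passing --branches ALL instead resolves every branch of the model repo via utils.get_all_hf_branch_names and pushes one revision of the output dataset per model branch.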
105 changes: 0 additions & 105 deletions scripts/inference.py

This file was deleted.

58 changes: 58 additions & 0 deletions src/delphi/utils.py
@@ -0,0 +1,58 @@
from typing import cast

from datasets import Dataset, Features, Sequence, Value, load_dataset


def hf_split_to_split_name(split: str) -> str:
    return split.split("[")[0]


# TODO: test load_dataset functions
def load_dataset_split_features(
    repo_id: str,
    split: str,
    features: Features,
) -> Dataset:
    dataset = load_dataset(
        repo_id,
        split=split,
        features=features,
    )
    dataset = cast(Dataset, dataset)
    return dataset


def load_dataset_split_string_feature(
    repo_id: str,
    split: str,
    feature_name: str,
) -> Dataset:
    print("Loading string dataset")
    print(f"{repo_id=}, {split=}, {feature_name=}")
    return load_dataset_split_features(
        repo_id,
        split,
        Features({feature_name: Value("string")}),
    )


def load_dataset_split_sequence_int32_feature(
    repo_id: str,
    split: str,
    feature_name: str,
) -> Dataset:
    print("Loading sequence int32 dataset")
    print(f"{repo_id=}, {split=}, {feature_name=}")
    return load_dataset_split_features(
        repo_id,
        split,
        Features({feature_name: Sequence(Value("int32"))}),
    )


def get_all_hf_branch_names(repo_id: str) -> list[str]:
    from huggingface_hub import HfApi

    api = HfApi()
    refs = api.list_repo_refs(repo_id)
    return [branch.name for branch in refs.branches]
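
As a quick illustration of the new helper: hf_split_to_split_name strips datasets-style slicing syntax from a split string, which is how the script above maps an input split like 'train[:10%]' to the plain 'train' split on upload. A minimal sketch (the split strings are arbitrary examples):

from delphi.utils import hf_split_to_split_name

# slicing suffixes are stripped; only the base split name remains
assert hf_split_to_split_name("train[:10%]") == "train"
assert hf_split_to_split_name("validation[200:400]") == "validation"
assert hf_split_to_split_name("test") == "test"  # no slicing, returned unchanged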
Empty file added tests/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,14 @@
from delphi.utils import hf_split_to_split_name

from .utils import random_string


def test_hf_split_to_split_name():
    random_split_name = random_string(5)
    assert hf_split_to_split_name(random_split_name) == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[:10%]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[10%:]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[10%:20%]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[:200]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[200:]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[200:400]") == random_split_name
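
Assuming the delphi package is installed and the command is run from the repository root, this test should pass under standard pytest discovery (the newly added tests/__init__.py makes the tests directory importable as a package):

pytest tests/test_utils.py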
6 changes: 6 additions & 0 deletions tests/utils.py
@@ -0,0 +1,6 @@
import random
import string


def random_string(length: int) -> str:
    return "".join(random.choices(string.ascii_lowercase, k=length))
