Commit 3281227 (1 parent: 9feac6e)

* get_next_logprobs revamp
* UNTESTED: support revisions
* revisions -> branches, tested

Showing 7 changed files with 225 additions and 133 deletions.
This file was deleted.
@@ -0,0 +1,147 @@
#!/usr/bin/env python3
import argparse
from collections.abc import Iterable

import numpy as np
import torch
from datasets import Dataset
from tqdm.auto import trange
from transformers import AutoModelForCausalLM

from delphi import utils
from delphi.eval.utils import get_all_and_next_logprobs

torch.set_grad_enabled(False)


def main(
    in_model_repo_id: str,
    branches: Iterable[str],
    in_dataset_repo_id: str,
    split: str,
    feature: str,
    batch_size: int,
    out_repo_id: str,
):
    """
    Outputs the log probabilities of the next token for each token in the dataset,
    and uploads the resulting dataset to Hugging Face.
    """
    in_dataset_split = utils.load_dataset_split_sequence_int32_feature(
        in_dataset_repo_id, split, feature
    )
    in_dataset_split.set_format("torch")
    for branch in branches:
        print(f"Loading model='{in_model_repo_id}', {branch=}")
        model = AutoModelForCausalLM.from_pretrained(in_model_repo_id, revision=branch)
        logprobs_dataset = get_logprobs_single_model(
            model=model,
            dataset=in_dataset_split,
            feature=feature,
            batch_size=batch_size,
        )
        logprobs_dataset.push_to_hub(
            repo_id=out_repo_id,
            split=utils.hf_split_to_split_name(split),
            revision=branch,
        )


def get_logprobs_single_model(
    model: AutoModelForCausalLM,
    dataset: Dataset,
    feature: str,
    batch_size: int,
) -> Dataset:
    n_seq = len(dataset)
    seq_len = len(dataset[0][feature])
    logprobs = np.empty((n_seq, seq_len))
    # The first token of each sequence has no preceding context, so it gets NaN.
    logprobs[:, 0] = float("nan")
    print("Running inference...")
    for i in trange(0, n_seq, batch_size):
        batch_tokens = dataset[i : i + batch_size][feature]
        # [1] selects the next-token logprobs, one per position after the first.
        logprobs[i : i + batch_size, 1:] = (
            get_all_and_next_logprobs(model, batch_tokens)[1].cpu().numpy()  # type: ignore
        )
    return Dataset.from_dict({"logprobs": [row for row in logprobs]})


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run inference and generate log probabilities."
    )
    parser.add_argument(
        "--in-model-repo-id",
        "--im",
        type=str,
        required=True,
        help="The model repo id on Hugging Face",
    )
    parser.add_argument(
        "--branches",
        help="Comma-separated branches of the model to use, or 'ALL' to use all branches",
        type=str,
        default="main",
        required=False,
    )
    parser.add_argument(
        "--in-dataset-repo-id",
        "--id",
        type=str,
        required=True,
        help="The tokenized dataset repo id",
    )
    parser.add_argument(
        "--feature",
        "-f",
        type=str,
        required=True,
        help="Name of the column containing token sequences in the input dataset",
    )
    parser.add_argument(
        "--split",
        "-s",
        type=str,
        required=True,
        help="Split of the tokenized dataset; supports slicing like 'train[:10%%]'",
    )
    parser.add_argument(
        "--out-repo-id",
        "-o",
        type=str,
        required=True,
        help="Where to upload the next logprobs",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        type=int,
        default=80,
        help="How many sequences to evaluate at once",
    )
    # TODO
    # parser.add_argument(
    #     "--chunk-size",
    #     "-c",
    #     type=int,
    #     default=200_000,
    #     help="Size of the parquet chunks uploaded to HuggingFace",
    # )
    args = parser.parse_args()

    branches = (
        args.branches.split(",")
        if args.branches != "ALL"
        else utils.get_all_hf_branch_names(args.in_model_repo_id)
    )

    main(
        in_model_repo_id=args.in_model_repo_id,
        branches=branches,
        in_dataset_repo_id=args.in_dataset_repo_id,
        split=args.split,
        feature=args.feature,
        batch_size=args.batch_size,
        out_repo_id=args.out_repo_id,
    )
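For context: get_all_and_next_logprobs is imported from delphi.eval.utils and its implementation is not part of this diff. Judging from the call site above (the script takes element [1] of its result and writes it into logprobs[:, 1:]), the second element holds, for each position after the first, the log probability the model assigned to the token that actually came next. A minimal sketch of a helper with that contract; the name next_logprobs_sketch and the exact shapes are assumptions, not the real delphi implementation:

    import torch
    import torch.nn.functional as F

    def next_logprobs_sketch(model, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, seq_len) integer tensor of input ids.
        logits = model(tokens).logits  # (batch, seq_len, vocab)
        logprobs = F.log_softmax(logits, dim=-1)
        # At each position i < seq_len - 1, pick the logprob assigned to the
        # token that actually appears at position i + 1.
        return torch.gather(
            logprobs[:, :-1], 2, tokens[:, 1:].long().unsqueeze(-1)
        ).squeeze(-1)  # (batch, seq_len - 1)

Assuming the script is saved as, say, get_next_logprobs.py (its file name is not shown in this view), a run over every checkpoint branch might look like: python get_next_logprobs.py --im <model-repo> --id <tokenized-dataset-repo> -f tokens -s "train[:10%]" -o <out-repo> --branches ALL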
This file was deleted.
@@ -0,0 +1,58 @@
from typing import cast

from datasets import Dataset, Features, Sequence, Value, load_dataset


def hf_split_to_split_name(split: str) -> str:
    return split.split("[")[0]


# TODO: test load_dataset functions
def load_dataset_split_features(
    repo_id: str,
    split: str,
    features: Features,
) -> Dataset:
    dataset = load_dataset(
        repo_id,
        split=split,
        features=features,
    )
    dataset = cast(Dataset, dataset)
    return dataset


def load_dataset_split_string_feature(
    repo_id: str,
    split: str,
    feature_name: str,
) -> Dataset:
    print("Loading string dataset")
    print(f"{repo_id=}, {split=}, {feature_name=}")
    return load_dataset_split_features(
        repo_id,
        split,
        Features({feature_name: Value("string")}),
    )


def load_dataset_split_sequence_int32_feature(
    repo_id: str,
    split: str,
    feature_name: str,
) -> Dataset:
    print("Loading sequence int32 dataset")
    print(f"{repo_id=}, {split=}, {feature_name=}")
    return load_dataset_split_features(
        repo_id,
        split,
        Features({feature_name: Sequence(Value("int32"))}),
    )


def get_all_hf_branch_names(repo_id: str) -> list[str]:
    from huggingface_hub import HfApi

    api = HfApi()
    refs = api.list_repo_refs(repo_id)
    return [branch.name for branch in refs.branches]
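A quick illustration of how these helpers compose; the repo ids below are hypothetical, for illustration only:

    ds = load_dataset_split_sequence_int32_feature(
        "my-org/tokenized-corpus",  # hypothetical repo id
        "train[:10%]",
        "tokens",
    )
    print(hf_split_to_split_name("train[:10%]"))  # -> "train"
    print(get_all_hf_branch_names("my-org/some-model"))  # e.g. ["main", ...]

Note the local import of HfApi inside get_all_hf_branch_names: it keeps huggingface_hub out of the module's import-time dependencies until a branch listing is actually requested.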
Empty file.
@@ -0,0 +1,14 @@
from delphi.utils import hf_split_to_split_name

from .utils import random_string


def test_hf_split_to_split_name():
    random_split_name = random_string(5)
    assert hf_split_to_split_name(random_split_name) == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[:10%]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[10%:]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[10%:20%]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[:200]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[200:]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[200:400]") == random_split_name
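The test pins down the contract that motivates the helper: load_dataset accepts slice suffixes such as train[:10%] or train[200:400], while push_to_hub expects a plain split name, which is why the script above calls utils.hf_split_to_split_name before uploading. A sketch of that round trip, with hypothetical repo ids:

    from datasets import load_dataset
    from delphi.utils import hf_split_to_split_name

    split = "train[:10%]"  # sliced split spec, fine for loading
    ds = load_dataset("my-org/tokenized-corpus", split=split)  # hypothetical repo
    # Re-uploading needs the plain name, i.e. "train":
    # ds.push_to_hub("my-org/out-repo", split=hf_split_to_split_name(split))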
@@ -0,0 +1,6 @@
import random
import string


def random_string(length: int) -> str:
    return "".join(random.choices(string.ascii_lowercase, k=length))