Commit f08396f (1 parent: 89b2b4b), showing 7 changed files with 185 additions and 133 deletions.
This file was deleted.
@@ -0,0 +1,115 @@
#!/usr/bin/env python3
import argparse

import numpy as np
import torch
from datasets import Dataset
from tqdm.auto import trange
from transformers import AutoModelForCausalLM

from delphi import utils
from delphi.eval.utils import get_all_and_next_logprobs

torch.set_grad_enabled(False)


def main(
    in_model_repo_id: str,
    in_dataset_repo_id: str,
    split: str,
    feature: str,
    batch_size: int,
    out_repo_id: str,
):
    """
    Outputs the log probabilities of the next token for each token in the dataset,
    and uploads the resulting dataset to HuggingFace.
    """
    model = AutoModelForCausalLM.from_pretrained(in_model_repo_id)
    in_dataset_split = utils.load_dataset_split_sequence_int32_feature(
        in_dataset_repo_id, split, feature
    )
    in_dataset_split.set_format("torch")
    n_seq = len(in_dataset_split)
    seq_len = len(in_dataset_split[0][feature])
    logprobs = np.empty((n_seq, seq_len))
    # the first token in each sequence has no preceding context,
    # so its next-token logprob is undefined
    logprobs[:, 0] = float("nan")
    print("Running inference...")
    for i in trange(0, n_seq, batch_size):
        batch_tokens = in_dataset_split[i : i + batch_size][feature]
        logprobs[i : i + batch_size, 1:] = (
            get_all_and_next_logprobs(model, batch_tokens)[1].cpu().numpy()
        )

    hf_dataset = Dataset.from_dict({"logprobs": [row for row in logprobs]})

    hf_dataset.push_to_hub(
        repo_id=out_repo_id,
        split=utils.hf_split_to_split_name(split),
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run inference and generate log probabilities."
    )
    parser.add_argument(
        "--in-model-repo-id",
        "--im",
        type=str,
        required=True,
        help="HuggingFace repo id of the input model",
    )
    parser.add_argument(
        "--in-dataset-repo-id",
        "--id",
        type=str,
        required=True,
        help="HuggingFace repo id of the tokenized input dataset",
    )
    parser.add_argument(
        "--feature",
        "-f",
        type=str,
        required=True,
        help="Name of the column containing token sequences in the input dataset",
    )
    parser.add_argument(
        "--split",
        "-s",
        type=str,
        required=True,
        help="Split of the tokenized dataset, supports slicing like 'train[:10%%]'",
    )
    parser.add_argument(
        "--out-repo-id",
        "-o",
        type=str,
        required=True,
        help="Where to upload the next logprobs",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        type=int,
        default=80,
        help="How many sequences to evaluate at once",
    )
    # TODO
    # parser.add_argument(
    #     "--chunk-size",
    #     "-c",
    #     type=int,
    #     default=200_000,
    #     help="Size of the parquet chunks uploaded to HuggingFace",
    # )
    args = parser.parse_args()

    main(
        in_model_repo_id=args.in_model_repo_id,
        in_dataset_repo_id=args.in_dataset_repo_id,
        split=args.split,
        feature=args.feature,
        batch_size=args.batch_size,
        out_repo_id=args.out_repo_id,
    )
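For reference, a hypothetical invocation of this script; the script path, repo ids, and feature name below are placeholders for illustration, not values from this commit:

python inference.py \
    --in-model-repo-id my-org/my-model \
    --in-dataset-repo-id my-org/my-tokenized-dataset \
    --feature tokens \
    --split "train[:10%]" \
    --out-repo-id my-org/my-next-logprobs \
    --batch-size 80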
This file was deleted.
@@ -0,0 +1,50 @@
from typing import cast

from datasets import Dataset, Features, Sequence, Value, load_dataset


def hf_split_to_split_name(split: str) -> str:
    # strip any slicing suffix, e.g. "train[:10%]" -> "train"
    return split.split("[")[0]


# TODO: test load_dataset functions
def load_dataset_split_features(
    repo_id: str,
    split: str,
    features: Features,
) -> Dataset:
    dataset = load_dataset(
        repo_id,
        split=split,
        features=features,
    )
    dataset = cast(Dataset, dataset)
    return dataset


def load_dataset_split_string_feature(
    repo_id: str,
    split: str,
    feature_name: str,
) -> Dataset:
    print("Loading string dataset")
    print(f"{repo_id=}, {split=}, {feature_name=}")
    return load_dataset_split_features(
        repo_id,
        split,
        Features({feature_name: Value("string")}),
    )


def load_dataset_split_sequence_int32_feature(
    repo_id: str,
    split: str,
    feature_name: str,
) -> Dataset:
    print("Loading sequence int32 dataset")
    print(f"{repo_id=}, {split=}, {feature_name=}")
    return load_dataset_split_features(
        repo_id,
        split,
        Features({feature_name: Sequence(Value("int32"))}),
    )
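A minimal usage sketch of the loaders above, assuming a hypothetical repo id and feature name; this mirrors how the inference script consumes them:

from delphi import utils

# placeholder repo id and feature name, for illustration only
ds = utils.load_dataset_split_sequence_int32_feature(
    "my-org/my-tokenized-dataset", "train[:1%]", "tokens"
)
ds.set_format("torch")

# slice suffixes are stripped when naming the output split
assert utils.hf_split_to_split_name("train[:1%]") == "train"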
Empty file.
@@ -0,0 +1,14 @@
from delphi.utils import hf_split_to_split_name

from .utils import random_string


def test_hf_split_to_split_name():
    random_split_name = random_string(5)
    assert hf_split_to_split_name(random_split_name) == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[:10%]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[10%:]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[10%:20%]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[:200]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[200:]") == random_split_name
    assert hf_split_to_split_name(f"{random_split_name}[200:400]") == random_split_name
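These tests can be run directly with pytest; the file path here is assumed for illustration:

pytest tests/test_utils.py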
@@ -0,0 +1,6 @@
import random
import string


def random_string(length: int) -> str:
    # lowercase ASCII only; sufficient for generating arbitrary split names in tests
    return "".join(random.choices(string.ascii_lowercase, k=length))