Commit
* model specify + nextlogprobs load + tok specify
* Draft version of plotting
* get prompt examples
* add meeting todos
* grouped imports and cells
* redid typing and fixed bugs
* unique token check
* Remove token category in calculation function
* Remove token category from vis function
* removed token_group in token diff + calculate all loss
* vis_pos_map highlight + optimization
* fixes: tokenization, mask, typing
* use interact_manual for resampling
* update quantile function tests
* small update
* beartype fix
* rm comment
* var rename
* eval notebook updates

---------

Co-authored-by: Siwei Li <[email protected]>
Co-authored-by: VICTOR ABIA <[email protected]>
Co-authored-by: Jett <[email protected]>
Co-authored-by: JaiDhyani <[email protected]>
Co-authored-by: Jai <[email protected]>
1 parent 426c964 · commit 71f77fd

Showing 8 changed files with 610 additions and 115 deletions.
```diff
@@ -1,54 +1,48 @@
 import numpy as np
 import torch
 from datasets import Dataset
 from jaxtyping import Float


 def calc_model_group_stats(
-    tokenized_corpus_dataset: list,
-    logprobs_by_dataset: dict[str, list[list[float]]],
-    token_labels_by_token: dict[int, dict[str, bool]],
-    token_labels: list[str],
-) -> dict[tuple[str, str], dict[str, float]]:
+    tokenized_corpus_dataset: Dataset,
+    logprobs_by_dataset: dict[str, torch.Tensor],
+    selected_tokens: list[int],
+) -> dict[str, dict[str, float]]:
     """
     For each (model, token group) pair, calculate useful stats (for visualization)
     args:
-    - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset)["validation"]
+    - tokenized_corpus_dataset: a list of the tokenized corpus datasets, e.g. load_dataset(constants.tokenized_corpus_dataset)["validation"]
     - logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
-    - token_groups: a dict of token groups, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...}
-    - models: a list of model names, e.g. constants.LLAMA2_MODELS
-    - token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...]
+    - selected_tokens: a list of selected token IDs, e.g. [46, 402, ...]
-    returns: a dict of (model, token group) pairs to a dict of stats,
-    e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...}
+    returns: a dict of model names as keys and stats dict as values
+    e.g. {"100k": {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...}
-    Technically `models` and `token_labels` are redundant, as they are also keys in `logprob_datasets` and `token_groups`,
-    but it's better to be explicit
-    stats calculated: mean, median, min, max, 25th percentile, 75th percentile
+    Stats calculated: mean, median, min, max, 25th percentile, 75th percentile
     """
     model_group_stats = {}
     for model in logprobs_by_dataset:
-        group_logprobs = {}
+        model_logprobs = []
         print(f"Processing model {model}")
         dataset = logprobs_by_dataset[model]
         for ix_doc_lp, document_lps in enumerate(dataset):
             tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
             for ix_token, token in enumerate(tokens):
                 if ix_token == 0:  # skip the first token, which isn't predicted
                     continue
-                logprob = document_lps[ix_token]
-                for token_group_desc in token_labels:
-                    if token_labels_by_token[token][token_group_desc]:
-                        if token_group_desc not in group_logprobs:
-                            group_logprobs[token_group_desc] = []
-                        group_logprobs[token_group_desc].append(logprob)
-        for token_group_desc in token_labels:
-            if token_group_desc in group_logprobs:
-                model_group_stats[(model, token_group_desc)] = {
-                    "mean": np.mean(group_logprobs[token_group_desc]),
-                    "median": np.median(group_logprobs[token_group_desc]),
-                    "min": np.min(group_logprobs[token_group_desc]),
-                    "max": np.max(group_logprobs[token_group_desc]),
-                    "25th": np.percentile(group_logprobs[token_group_desc], 25),
-                    "75th": np.percentile(group_logprobs[token_group_desc], 75),
-                }
+                logprob = document_lps[ix_token].item()
+                if token in selected_tokens:
+                    model_logprobs.append(logprob)
+
+        if model_logprobs:
+            model_group_stats[model] = {
+                "mean": np.mean(model_logprobs),
+                "median": np.median(model_logprobs),
+                "min": np.min(model_logprobs),
+                "max": np.max(model_logprobs),
+                "25th": np.percentile(model_logprobs, 25),
+                "75th": np.percentile(model_logprobs, 75),
+            }
     return model_group_stats
```
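For orientation, here is a minimal usage sketch of the new signature. The corpus contents, the checkpoint names ("100k", "200k"), the token IDs, and the logprob values are all invented for illustration, and the module path of `calc_model_group_stats` is not shown in this capture:

```python
import torch
from datasets import Dataset

# from <module> import calc_model_group_stats  # import path not shown above

# Two toy "documents" of token IDs; the "tokens" field name matches what
# calc_model_group_stats reads from each corpus row.
corpus = Dataset.from_dict({"tokens": [[5, 46, 402, 7], [46, 9, 402, 46]]})

# One logprob per (document, position) for each model; the keys are
# hypothetical checkpoint names. Position 0 is skipped by the function,
# since the first token is never predicted.
logprobs_by_dataset = {
    "100k": torch.tensor([[0.0, -1.2, -0.7, -2.1], [0.0, -0.9, -1.5, -0.4]]),
    "200k": torch.tensor([[0.0, -0.8, -0.5, -1.6], [0.0, -0.6, -1.1, -0.3]]),
}

# Only positions whose token ID appears here contribute to the stats.
selected_tokens = [46, 402]

stats = calc_model_group_stats(corpus, logprobs_by_dataset, selected_tokens)
print(stats["100k"])  # {'mean': ..., 'median': ..., 'min': ..., 'max': ..., '25th': ..., '75th': ...}
```

Note that each model's entry pools every selected-token logprob across all documents, so the quartiles describe the combined distribution rather than per-document averages.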