Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

final evaluation notebook #98

Merged
merged 19 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
480 changes: 480 additions & 0 deletions notebooks/eval_notebook.ipynb

Large diffs are not rendered by default.

58 changes: 26 additions & 32 deletions src/delphi/eval/calc_model_group_stats.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,48 @@
import numpy as np
import torch
from datasets import Dataset
from jaxtyping import Float


def calc_model_group_stats(
    tokenized_corpus_dataset: "Dataset",
    logprobs_by_dataset: dict[str, torch.Tensor],
    selected_tokens: list[int],
) -> dict[str, dict[str, float]]:
    """
    For each model, calculate summary stats (for visualization) of the logprobs
    assigned to occurrences of the selected tokens in the corpus.

    args:
    - tokenized_corpus_dataset: the tokenized corpus, indexable by document; each
      document exposes its token IDs under the "tokens" key,
      e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"]
    - logprobs_by_dataset: a dict of model name -> per-document logprobs, indexed
      as logprobs[document][token position]
    - selected_tokens: a list of selected token IDs, e.g. [46, 402, ...]

    returns: a dict of model names as keys and stats dict as values
    e.g. {"100k": {"mean": -0.5, "median": -0.4, "min": -0.9, "max": -0.1, "25th": -0.7, "75th": -0.3}, ...}
    Models for which none of the selected tokens occur (past the first position)
    are omitted from the result.

    Stats calculated: mean, median, min, max, 25th percentile, 75th percentile
    """
    # Build the lookup once: set membership is O(1), while `in list` is a linear
    # scan repeated for every token of every document of every model.
    # int() normalises IDs so plain ints and 0-d tensor/NumPy scalars hash alike.
    selected = {int(token) for token in selected_tokens}

    model_group_stats = {}
    for model, dataset in logprobs_by_dataset.items():
        print(f"Processing model {model}")
        model_logprobs = []
        for ix_doc_lp, document_lps in enumerate(dataset):
            tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
            for ix_token, token in enumerate(tokens):
                if ix_token == 0:  # skip the first token, which isn't predicted
                    continue
                if int(token) in selected:
                    # only pay for the tensor -> float conversion on kept tokens
                    model_logprobs.append(document_lps[ix_token].item())

        if not model_logprobs:
            # no occurrences: stats are undefined, so leave the model out
            continue

        values = np.asarray(model_logprobs)
        q25, q75 = np.percentile(values, [25, 75])
        model_group_stats[model] = {
            "mean": np.mean(values),
            "median": np.median(values),
            "min": np.min(values),
            "max": np.max(values),
            "25th": q25,
            "75th": q75,
        }
    return model_group_stats
18 changes: 9 additions & 9 deletions src/delphi/eval/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
tokenized_corpus_dataset = "delphi-suite/tinystories-v2-clean-tokenized-v0"

LLAMA2_MODELS = [
"delphi-llama2-100k",
"delphi-llama2-200k",
"delphi-llama2-400k",
"delphi-llama2-800k",
"delphi-llama2-1.6m",
"delphi-llama2-3.2m",
"delphi-llama2-6.4m",
"delphi-llama2-12.8m",
"delphi-llama2-25.6m",
"llama2-100k",
"llama2-200k",
"llama2-400k",
"llama2-800k",
"llama2-1.6m",
"llama2-3.2m",
"llama2-6.4m",
"llama2-12.8m",
"llama2-25.6m",
]

LLAMA2_NEXT_LOGPROBS_DATASETS_MAP = {
Expand Down
12 changes: 4 additions & 8 deletions src/delphi/eval/token_positions.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
from numbers import Number
from typing import Optional, cast
from typing import Optional

import torch
from datasets import Dataset
from jaxtyping import Int

from delphi.eval.utils import dict_filter_quantile


def get_all_tok_metrics_in_label(
token_ids: Int[torch.Tensor, "prompt pos"],
token_labels: dict[int, dict[str, bool]],
selected_tokens: list[int],
metrics: torch.Tensor,
label: str,
q_start: Optional[float] = None,
q_end: Optional[float] = None,
) -> dict[tuple[int, int], float]:
Expand All @@ -23,9 +20,8 @@ def get_all_tok_metrics_in_label(

Args:
- token_ids (Dataset): token_ids dataset e.g. token_ids[0] = {"tokens": [[1, 2, ...], [2, 5, ...], ...]}
- token_labels (dict[int, dict[str, bool]]): dictionary of token labels e.g. { 0: {"Is Noun": True, "Is Verb": False}, ...}
- selected_tokens (list[int]): list of token IDs to search for e.g. [46, 402, ...]
- metrics (torch.Tensor): tensor of metrics to search through e.g. torch.tensor([[0.1, 0.2, ...], [0.3, 0.4, ...], ...])
- label (str): the label to search for
- q_start (float): the start of the quantile range to filter the metrics e.g. 0.1
- q_end (float): the end of the quantile range to filter the metrics e.g. 0.9

Expand All @@ -42,7 +38,7 @@ def get_all_tok_metrics_in_label(
tok_positions = {}
for prompt_pos, prompt in enumerate(token_ids.numpy()):
for tok_pos, tok in enumerate(prompt):
if token_labels[tok][label]:
if tok in selected_tokens:
tok_positions[(prompt_pos, tok_pos)] = metrics[
prompt_pos, tok_pos
].item()
Expand Down
84 changes: 47 additions & 37 deletions src/delphi/eval/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid
from typing import cast

import numpy as np
import panel as pn
import torch
from IPython.core.display import HTML
Expand Down Expand Up @@ -54,6 +55,7 @@ def token_to_html(
tokenizer: PreTrainedTokenizerBase,
bg_color: str,
data: dict,
class_name: str = "token",
) -> str:
data = data or {} # equivalent to if not data: data = {}
# non-breakable space, w/o it leading spaces wouldn't be displayed
Expand All @@ -73,6 +75,7 @@ def token_to_html(
br += "<br>"
# this is so we can copy the prompt without "\n"s
specific_styles["user-select"] = "none"
str_token = str_token.replace("<", "&lt;").replace(">", "&gt;")

style_str = data_str = ""
# converting style dict into the style attribute
Expand All @@ -83,7 +86,7 @@ def token_to_html(
data_str = "".join(
f" data-{k}='{v.replace(' ', '&nbsp;')}'" for k, v in data.items()
)
return f"<div class='token'{style_str}{data_str}>{str_token}</div>{br}"
return f"<div class='{class_name}'{style_str}{data_str}>{str_token}</div>{br}"


_token_style = {
Expand All @@ -97,7 +100,20 @@ def token_to_html(
"margin": "1px 0px 1px 1px",
"padding": "0px 1px 1px 1px",
}
_token_emphasized_style = {
"border": "3px solid #888",
"display": "inline-block",
"font-family": "monospace",
"font-size": "14px",
"color": "black",
"background-color": "white",
"margin": "1px 0px 1px 1px",
"padding": "0px 1px 1px 1px",
}
_token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()])
_token_emphasized_style_str = " ".join(
[f"{k}: {v};" for k, v in _token_emphasized_style.items()]
)


def vis_sample_prediction_probs(
Expand Down Expand Up @@ -130,8 +146,8 @@ def vis_sample_prediction_probs(
data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer)

token_htmls.append(
token_to_html(tok, tokenizer, bg_color=colors[i], data=data).replace(
"class='token'", f"class='{token_class}'"
token_to_html(
tok, tokenizer, bg_color=colors[i], data=data, class_name=token_class
)
)

Expand Down Expand Up @@ -165,60 +181,55 @@ def vis_sample_prediction_probs(


def vis_pos_map(
pos_map: dict[tuple[int, int], float | int],
pos_list: list[tuple[int, int]],
selected_tokens: list[int],
metrics: Float[torch.Tensor, "prompt pos"],
token_ids: Int[torch.Tensor, "prompt pos"],
tokenizer: PreTrainedTokenizerBase,
sample: int = 3,
):
"""
Randomly sample from pos_map and visualize the loss diff at the corresponding position.
"""

token_htmls = []
unique_id = str(uuid.uuid4())
token_class = f"token_{unique_id}"
token_class = f"pretoken_{unique_id}"
selected_token_class = f"token_{unique_id}"
hover_div_id = f"hover_info_{unique_id}"

# choose n random keys from pos_map
keys = random.sample(list(pos_map.keys()), k=sample)

for key in keys:
prompt, pos = key
pre_toks = token_ids[prompt][:pos]
mask = torch.isin(pre_toks, torch.tensor([0, 1], dtype=torch.int8))
pre_toks = pre_toks[
~mask
] # remove <unk> and <s> tokens, <s> cause strikethrough in html

for i in range(pre_toks.shape[0]):
pre_tok = cast(int, pre_toks[i].item())
token_htmls.append(
token_to_html(pre_tok, tokenizer, bg_color="white", data={}).replace(
"class='token'", f"class='{token_class}'"
)
)
# choose a random keys from pos_map
key = random.choice(pos_list)

tok = cast(int, token_ids[prompt][pos].item())
value = cast(float, pos_map[key])
prompt, pos = key
all_toks = token_ids[prompt][: pos + 1]

for i in range(all_toks.shape[0]):
token_id = cast(int, all_toks[i].item())
value = metrics[prompt][i].item()
token_htmls.append(
token_to_html(
tok,
token_id,
tokenizer,
bg_color=single_loss_diff_to_color(value),
bg_color="white"
if np.isnan(value)
else single_loss_diff_to_color(value),
data={"loss-diff": f"{value:.2f}"},
).replace("class='token'", f"class='{token_class}'")
class_name=token_class
if token_id not in selected_tokens
else selected_token_class,
)
)

# add break line
token_htmls.append("<br><br>")
# add break line
token_htmls.append("<br><br>")

html_str = f"""
<style>.{token_class} {{ {_token_style_str} }} #{hover_div_id} {{ height: 100px; font-family: monospace; }}</style>
<style>.{token_class} {{ {_token_style_str}}} .{selected_token_class} {{ {_token_emphasized_style_str} }} #{hover_div_id} {{ height: 100px; font-family: monospace; }}</style>
{"".join(token_htmls)} <div id='{hover_div_id}'></div>
<script>
(function() {{
var token_divs = document.querySelectorAll('.{token_class}');
token_divs = Array.from(token_divs).concat(Array.from(document.querySelectorAll('.{selected_token_class}')));
var hover_info = document.getElementById('{hover_div_id}');


Expand All @@ -239,19 +250,18 @@ def vis_pos_map(
</script>
"""
display(HTML(html_str))
return html_str


def token_selector(
    vocab_map: dict[str, int]
) -> tuple[pn.widgets.MultiChoice, list[int]]:
    """
    Build a multi-choice widget for picking tokens from a vocabulary.

    Args:
    - vocab_map (dict[str, int]): mapping of token string -> token ID; the keys
      become the widget's options

    Returns:
    - the MultiChoice widget, and a list of the token IDs currently selected.
      The list is updated in place whenever the widget's value changes, so
      callers should keep the reference rather than copy it.
    """
    # trailing underscore keeps the local from shadowing this function's name
    token_selector_ = pn.widgets.MultiChoice(name="Tokens", options=list(vocab_map))
    token_ids = [vocab_map[token] for token in cast(list[str], token_selector_.value)]

    def update_tokens(event):
        # mutate in place: the caller already holds a reference to this list
        token_ids.clear()
        token_ids.extend([vocab_map[token] for token in event.new])

    token_selector_.param.watch(update_tokens, "value")
    return token_selector_, token_ids
13 changes: 5 additions & 8 deletions src/delphi/eval/vis_per_token_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import Union

import ipywidgets
import numpy as np
import plotly.graph_objects as go


def visualize_per_token_category(
input: dict[Union[str, int], dict[str, tuple]],
def visualize_selected_tokens(
input: dict[Union[str, int], tuple[float, float, float]],
log_scale=False,
line_metric="Means",
checkpoint_mode=True,
Expand All @@ -17,18 +16,16 @@ def visualize_per_token_category(
background_color="AliceBlue",
) -> go.FigureWidget:
input_x = list(input.keys())
categories = list(input[input_x[0]].keys())
category = categories[0]

def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]

def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
x = np.array([input[x][category] for x in input_x]).T
def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
x = np.array([input[x] for x in input_x]).T
means, err_lo, err_hi = x[0], x[1], x[2]
return means, err_lo, err_hi

means, err_lo, err_hi = get_plot_values(category)
means, err_lo, err_hi = get_plot_values()

if checkpoint_mode:
scatter_plot = go.Figure(
Expand Down
Loading
Loading