final evaluation notebook (#98)
* model specify + nextlogprobs load + tok specify

* Draft version of plotting

* get prompt examples

* add meeting todos

* grouped imports and cells

* redid typing and fixed bugs

* unique token check

* Remove token category in calculation function

* Remove token category from vis function

* removed token_group in token diff + calculate all loss

* vis_pos_map highlight + optimization

* fixes: tokenization, mask, typing

* use interact_manual for resampling

* update quantile function tests

* small update

* beartype fix

* rm comment

* var rename

* eval notebook updates

---------

Co-authored-by: Siwei Li <[email protected]>
Co-authored-by: VICTOR ABIA <[email protected]>
Co-authored-by: Jett <[email protected]>
Co-authored-by: JaiDhyani <[email protected]>
Co-authored-by: Jai <[email protected]>
6 people committed May 25, 2024
1 parent 426c964 commit 71f77fd
Showing 8 changed files with 610 additions and 115 deletions.
480 changes: 480 additions & 0 deletions notebooks/eval_notebook.ipynb

Large diffs are not rendered by default.

58 changes: 26 additions & 32 deletions src/delphi/eval/calc_model_group_stats.py
@@ -1,54 +1,48 @@
 import numpy as np
+import torch
+from datasets import Dataset
+from jaxtyping import Float


 def calc_model_group_stats(
-    tokenized_corpus_dataset: list,
-    logprobs_by_dataset: dict[str, list[list[float]]],
-    token_labels_by_token: dict[int, dict[str, bool]],
-    token_labels: list[str],
-) -> dict[tuple[str, str], dict[str, float]]:
+    tokenized_corpus_dataset: Dataset,
+    logprobs_by_dataset: dict[str, torch.Tensor],
+    selected_tokens: list[int],
+) -> dict[str, dict[str, float]]:
     """
     For each (model, token group) pair, calculate useful stats (for visualization)
     args:
-    - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"]
+    - tokenized_corpus_dataset: a list of the tokenized corpus datasets, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"]
     - logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
-    - token_groups: a dict of token groups, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...}
-    - models: a list of model names, e.g. constants.LLAMA2_MODELS
-    - token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...]
+    - selected_tokens: a list of selected token IDs, e.g. [46, 402, ...]
-    returns: a dict of (model, token group) pairs to a dict of stats,
-    e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...}
+    returns: a dict of model names as keys and stats dict as values
+    e.g. {"100k": {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...}
-    Technically `models` and `token_labels` are redundant, as they are also keys in `logprob_datasets` and `token_groups`,
-    but it's better to be explicit
-    stats calculated: mean, median, min, max, 25th percentile, 75th percentile
+    Stats calculated: mean, median, min, max, 25th percentile, 75th percentile
     """
     model_group_stats = {}
     for model in logprobs_by_dataset:
-        group_logprobs = {}
+        model_logprobs = []
         print(f"Processing model {model}")
         dataset = logprobs_by_dataset[model]
         for ix_doc_lp, document_lps in enumerate(dataset):
             tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
             for ix_token, token in enumerate(tokens):
                 if ix_token == 0:  # skip the first token, which isn't predicted
                     continue
-                logprob = document_lps[ix_token]
-                for token_group_desc in token_labels:
-                    if token_labels_by_token[token][token_group_desc]:
-                        if token_group_desc not in group_logprobs:
-                            group_logprobs[token_group_desc] = []
-                        group_logprobs[token_group_desc].append(logprob)
-        for token_group_desc in token_labels:
-            if token_group_desc in group_logprobs:
-                model_group_stats[(model, token_group_desc)] = {
-                    "mean": np.mean(group_logprobs[token_group_desc]),
-                    "median": np.median(group_logprobs[token_group_desc]),
-                    "min": np.min(group_logprobs[token_group_desc]),
-                    "max": np.max(group_logprobs[token_group_desc]),
-                    "25th": np.percentile(group_logprobs[token_group_desc], 25),
-                    "75th": np.percentile(group_logprobs[token_group_desc], 75),
-                }
+                logprob = document_lps[ix_token].item()
+                if token in selected_tokens:
+                    model_logprobs.append(logprob)
+
+        if model_logprobs:
+            model_group_stats[model] = {
+                "mean": np.mean(model_logprobs),
+                "median": np.median(model_logprobs),
+                "min": np.min(model_logprobs),
+                "max": np.max(model_logprobs),
+                "25th": np.percentile(model_logprobs, 25),
+                "75th": np.percentile(model_logprobs, 75),
+            }
     return model_group_stats
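
For readers skimming the diff, here is a minimal sketch of calling the reworked per-model API. The documents, logprob values, and token IDs below are made up for illustration; real data would come from the datasets named in constants.py.

```python
import torch
from datasets import Dataset

from delphi.eval.calc_model_group_stats import calc_model_group_stats

# Two toy "documents" in the same layout as the tokenized corpus.
tokenized_corpus = Dataset.from_dict({"tokens": [[1, 5, 46, 7], [1, 402, 9, 46]]})

# One logprob per token position per document, one tensor per model.
logprobs_by_dataset = {
    "llama2-100k": torch.tensor([[0.0, -1.2, -0.4, -2.0], [0.0, -0.8, -1.5, -0.3]]),
    "llama2-200k": torch.tensor([[0.0, -1.0, -0.3, -1.7], [0.0, -0.6, -1.1, -0.2]]),
}

# Stats are now keyed by model name alone, computed only over positions
# where a selected token occurs (position 0 is skipped as unpredicted).
stats = calc_model_group_stats(tokenized_corpus, logprobs_by_dataset, [46, 402])
print(stats["llama2-100k"]["mean"])  # mean logprob of tokens 46/402 under that model
```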
18 changes: 9 additions & 9 deletions src/delphi/eval/constants.py
@@ -2,15 +2,15 @@
 tokenized_corpus_dataset = "delphi-suite/tinystories-v2-clean-tokenized-v0"

 LLAMA2_MODELS = [
-    "delphi-llama2-100k",
-    "delphi-llama2-200k",
-    "delphi-llama2-400k",
-    "delphi-llama2-800k",
-    "delphi-llama2-1.6m",
-    "delphi-llama2-3.2m",
-    "delphi-llama2-6.4m",
-    "delphi-llama2-12.8m",
-    "delphi-llama2-25.6m",
+    "llama2-100k",
+    "llama2-200k",
+    "llama2-400k",
+    "llama2-800k",
+    "llama2-1.6m",
+    "llama2-3.2m",
+    "llama2-6.4m",
+    "llama2-12.8m",
+    "llama2-25.6m",
 ]

 LLAMA2_NEXT_LOGPROBS_DATASETS_MAP = {
12 changes: 4 additions & 8 deletions src/delphi/eval/token_positions.py
@@ -1,18 +1,15 @@
-from numbers import Number
-from typing import Optional, cast
+from typing import Optional

 import torch
-from datasets import Dataset
 from jaxtyping import Int

 from delphi.eval.utils import dict_filter_quantile


 def get_all_tok_metrics_in_label(
     token_ids: Int[torch.Tensor, "prompt pos"],
-    token_labels: dict[int, dict[str, bool]],
+    selected_tokens: list[int],
     metrics: torch.Tensor,
-    label: str,
     q_start: Optional[float] = None,
     q_end: Optional[float] = None,
 ) -> dict[tuple[int, int], float]:
@@ -23,9 +20,8 @@ def get_all_tok_metrics_in_label(
     Args:
     - token_ids (Dataset): token_ids dataset e.g. token_ids[0] = {"tokens": [[1, 2, ...], [2, 5, ...], ...]}
-    - token_labels (dict[int, dict[str, bool]]): dictionary of token labels e.g. { 0: {"Is Noun": True, "Is Verb": False}, ...}
+    - selected_tokens (list[int]): list of token IDs to search for e.g. [46, 402, ...]
     - metrics (torch.Tensor): tensor of metrics to search through e.g. torch.tensor([[0.1, 0.2, ...], [0.3, 0.4, ...], ...])
-    - label (str): the label to search for
     - q_start (float): the start of the quantile range to filter the metrics e.g. 0.1
     - q_end (float): the end of the quantile range to filter the metrics e.g. 0.9
@@ -42,7 +38,7 @@ def get_all_tok_metrics_in_label(
     tok_positions = {}
     for prompt_pos, prompt in enumerate(token_ids.numpy()):
         for tok_pos, tok in enumerate(prompt):
-            if token_labels[tok][label]:
+            if tok in selected_tokens:
                 tok_positions[(prompt_pos, tok_pos)] = metrics[
                     prompt_pos, tok_pos
                 ].item()
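
A short sketch of the new selected_tokens interface, with toy tensors; the quantile filtering behaviour is assumed from the docstring and the dict_filter_quantile import, since that part of the function is not shown in this hunk.

```python
import torch

from delphi.eval.token_positions import get_all_tok_metrics_in_label

# Two toy prompts with a per-position metric (e.g. a loss diff).
token_ids = torch.tensor([[1, 46, 7, 402], [1, 9, 46, 2]])
metrics = torch.tensor([[0.0, 0.5, 0.1, 0.9], [0.0, 0.2, 0.7, 0.3]])

# Every (prompt, position) where a selected token occurs, mapped to its metric:
positions = get_all_tok_metrics_in_label(token_ids, [46, 402], metrics)
# -> {(0, 1): 0.5, (0, 3): 0.9, (1, 2): 0.7}

# Passing a quantile range should filter the result via dict_filter_quantile,
# e.g. keeping only metrics between the 25th and 75th percentiles:
middle = get_all_tok_metrics_in_label(
    token_ids, [46, 402], metrics, q_start=0.25, q_end=0.75
)
```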
84 changes: 47 additions & 37 deletions src/delphi/eval/vis.py
@@ -3,6 +3,7 @@
 import uuid
 from typing import cast

+import numpy as np
 import panel as pn
 import torch
 from IPython.core.display import HTML
@@ -54,6 +55,7 @@ def token_to_html(
     tokenizer: PreTrainedTokenizerBase,
     bg_color: str,
     data: dict,
+    class_name: str = "token",
 ) -> str:
     data = data or {}  # equivalent to if not data: data = {}
     # non-breakable space, w/o it leading spaces wouldn't be displayed
@@ -73,6 +75,7 @@ def token_to_html(
         br += "<br>"
         # this is so we can copy the prompt without "\n"s
         specific_styles["user-select"] = "none"
+    str_token = str_token.replace("<", "&lt;").replace(">", "&gt;")

     style_str = data_str = ""
     # converting style dict into the style attribute
@@ -83,7 +86,7 @@ def token_to_html(
         data_str = "".join(
             f" data-{k}='{v.replace(' ', '&nbsp;')}'" for k, v in data.items()
         )
-    return f"<div class='token'{style_str}{data_str}>{str_token}</div>{br}"
+    return f"<div class='{class_name}'{style_str}{data_str}>{str_token}</div>{br}"


 _token_style = {
@@ -97,7 +100,20 @@ def token_to_html(
     "margin": "1px 0px 1px 1px",
     "padding": "0px 1px 1px 1px",
 }
+_token_emphasized_style = {
+    "border": "3px solid #888",
+    "display": "inline-block",
+    "font-family": "monospace",
+    "font-size": "14px",
+    "color": "black",
+    "background-color": "white",
+    "margin": "1px 0px 1px 1px",
+    "padding": "0px 1px 1px 1px",
+}
 _token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()])
+_token_emphasized_style_str = " ".join(
+    [f"{k}: {v};" for k, v in _token_emphasized_style.items()]
+)


 def vis_sample_prediction_probs(
@@ -130,8 +146,8 @@ def vis_sample_prediction_probs(
             data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer)

         token_htmls.append(
-            token_to_html(tok, tokenizer, bg_color=colors[i], data=data).replace(
-                "class='token'", f"class='{token_class}'"
+            token_to_html(
+                tok, tokenizer, bg_color=colors[i], data=data, class_name=token_class
             )
         )

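
Taken together, the token_to_html changes replace the old `.replace("class='token'", ...)` string surgery with a proper `class_name` parameter, and escape angle brackets so special tokens render as text rather than being parsed as HTML tags. A small sketch; the gpt2 tokenizer is an arbitrary stand-in, any PreTrainedTokenizerBase works.

```python
from transformers import AutoTokenizer

from delphi.eval.vis import token_to_html

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice

# "<|endoftext|>" contains < and >; with the new escaping it shows up
# literally in the rendered HTML instead of being swallowed as a tag.
html = token_to_html(
    tokenizer.eos_token_id,
    tokenizer,
    bg_color="white",
    data={},
    class_name="my_token",  # no more post-hoc patching of class='token'
)
```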
@@ -165,60 +181,55 @@ def vis_sample_prediction_probs(


 def vis_pos_map(
-    pos_map: dict[tuple[int, int], float | int],
+    pos_list: list[tuple[int, int]],
+    selected_tokens: list[int],
+    metrics: Float[torch.Tensor, "prompt pos"],
     token_ids: Int[torch.Tensor, "prompt pos"],
     tokenizer: PreTrainedTokenizerBase,
-    sample: int = 3,
 ):
     """
     Randomly sample from pos_map and visualize the loss diff at the corresponding position.
     """

     token_htmls = []
     unique_id = str(uuid.uuid4())
-    token_class = f"token_{unique_id}"
+    token_class = f"pretoken_{unique_id}"
+    selected_token_class = f"token_{unique_id}"
     hover_div_id = f"hover_info_{unique_id}"

-    # choose n random keys from pos_map
-    keys = random.sample(list(pos_map.keys()), k=sample)
-
-    for key in keys:
-        prompt, pos = key
-        pre_toks = token_ids[prompt][:pos]
-        mask = torch.isin(pre_toks, torch.tensor([0, 1], dtype=torch.int8))
-        pre_toks = pre_toks[
-            ~mask
-        ]  # remove <unk> and <s> tokens, <s> cause strikethrough in html
-
-        for i in range(pre_toks.shape[0]):
-            pre_tok = cast(int, pre_toks[i].item())
-            token_htmls.append(
-                token_to_html(pre_tok, tokenizer, bg_color="white", data={}).replace(
-                    "class='token'", f"class='{token_class}'"
-                )
-            )
+    # choose a random key from pos_list
+    key = random.choice(pos_list)

-        tok = cast(int, token_ids[prompt][pos].item())
-        value = cast(float, pos_map[key])
+    prompt, pos = key
+    all_toks = token_ids[prompt][: pos + 1]
+
+    for i in range(all_toks.shape[0]):
+        token_id = cast(int, all_toks[i].item())
+        value = metrics[prompt][i].item()
         token_htmls.append(
             token_to_html(
-                tok,
+                token_id,
                 tokenizer,
-                bg_color=single_loss_diff_to_color(value),
+                bg_color="white"
+                if np.isnan(value)
+                else single_loss_diff_to_color(value),
                 data={"loss-diff": f"{value:.2f}"},
-            ).replace("class='token'", f"class='{token_class}'")
+                class_name=token_class
+                if token_id not in selected_tokens
+                else selected_token_class,
+            )
         )

-        # add break line
-        token_htmls.append("<br><br>")
+    # add break line
+    token_htmls.append("<br><br>")

     html_str = f"""
-    <style>.{token_class} {{ {_token_style_str} }} #{hover_div_id} {{ height: 100px; font-family: monospace; }}</style>
+    <style>.{token_class} {{ {_token_style_str}}} .{selected_token_class} {{ {_token_emphasized_style_str} }} #{hover_div_id} {{ height: 100px; font-family: monospace; }}</style>
     {"".join(token_htmls)} <div id='{hover_div_id}'></div>
     <script>
         (function() {{
             var token_divs = document.querySelectorAll('.{token_class}');
+            token_divs = Array.from(token_divs).concat(Array.from(document.querySelectorAll('.{selected_token_class}')));
             var hover_info = document.getElementById('{hover_div_id}');
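
A sketch of the new vis_pos_map call shape, intended for a notebook cell. The tensors and token IDs are toy values, the tokenizer choice is illustrative, and the pairing with get_all_tok_metrics_in_label mirrors how the eval notebook appears to wire these together.

```python
import torch
from transformers import AutoTokenizer

from delphi.eval.token_positions import get_all_tok_metrics_in_label
from delphi.eval.vis import vis_pos_map

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative stand-in

token_ids = torch.tensor([[1, 46, 7, 402], [1, 9, 46, 2]])
metrics = torch.rand(2, 4)  # stand-in for real loss diffs
selected = [46, 402]

# Positions of the selected tokens; vis_pos_map picks one at random,
# renders the prompt up to it, and draws the selected tokens with the
# emphasized (boxed) style.
pos_map = get_all_tok_metrics_in_label(token_ids, selected, metrics)
html = vis_pos_map(list(pos_map.keys()), selected, metrics, token_ids, tokenizer)
```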
@@ -239,19 +250,18 @@ def vis_pos_map(
     </script>
     """
     display(HTML(html_str))
+    return html_str


 def token_selector(
     vocab_map: dict[str, int]
 ) -> tuple[pn.widgets.MultiChoice, list[int]]:
     tokens = list(vocab_map.keys())
-    token_selector = pn.widgets.MultiChoice(name="Tokens", options=tokens)
-    token_ids = [vocab_map[token] for token in cast(list[str], token_selector.value)]
+    token_selector_ = pn.widgets.MultiChoice(name="Tokens", options=tokens)
+    token_ids = [vocab_map[token] for token in cast(list[str], token_selector_.value)]

     def update_tokens(event):
         token_ids.clear()
         token_ids.extend([vocab_map[token] for token in event.new])

-    token_selector.param.watch(update_tokens, "value")
-    return token_selector, token_ids
+    token_selector_.param.watch(update_tokens, "value")
+    return token_selector_, token_ids
13 changes: 5 additions & 8 deletions src/delphi/eval/vis_per_token_model.py
@@ -1,12 +1,11 @@
 from typing import Union

-import ipywidgets
 import numpy as np
 import plotly.graph_objects as go


-def visualize_per_token_category(
-    input: dict[Union[str, int], dict[str, tuple]],
+def visualize_selected_tokens(
+    input: dict[Union[str, int], tuple[float, float, float]],
     log_scale=False,
     line_metric="Means",
     checkpoint_mode=True,
@@ -17,18 +16,16 @@ def visualize_per_token_category(
     background_color="AliceBlue",
 ) -> go.FigureWidget:
     input_x = list(input.keys())
-    categories = list(input[input_x[0]].keys())
-    category = categories[0]

     def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
         return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]

-    def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-        x = np.array([input[x][category] for x in input_x]).T
+    def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        x = np.array([input[x] for x in input_x]).T
         means, err_lo, err_hi = x[0], x[1], x[2]
         return means, err_lo, err_hi

-    means, err_lo, err_hi = get_plot_values(category)
+    means, err_lo, err_hi = get_plot_values()

     if checkpoint_mode:
         scatter_plot = go.Figure(
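
With the per-category layer gone, the input is a flat mapping from checkpoint to a (mean, err_lo, err_hi) triple. A sketch with made-up numbers:

```python
from delphi.eval.vis_per_token_model import visualize_selected_tokens

# checkpoint -> (mean, err_lo, err_hi); values invented for illustration
stats = {
    "100k": (-2.5, 0.3, 0.4),
    "200k": (-2.1, 0.2, 0.3),
    "400k": (-1.8, 0.2, 0.2),
}

fig = visualize_selected_tokens(stats)  # returns a go.FigureWidget
fig  # renders inline in a notebook
```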