Commit

Merge branch 'main' into simplify_run_training
jaidhyani authored Apr 1, 2024
2 parents 53ec272 + 1b3ce22 commit 8ec1043
Showing 7 changed files with 433 additions and 2 deletions.
171 changes: 171 additions & 0 deletions notebooks/model_diff.ipynb
@@ -0,0 +1,171 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import pickle\n",
"\n",
"\n",
"from datasets import load_dataset, Dataset\n",
"\n",
"\n",
"from typing import cast\n",
"from ipywidgets import interact\n",
"import ipywidgets as widgets\n",
"\n",
"\n",
"from transformers import AutoTokenizer\n",
"from delphi.constants import STATIC_ASSETS_DIR\n",
"from delphi.eval.token_positions import get_all_tok_metrics_in_label\n",
"from delphi.eval.vis import vis_pos_map\n",
"from delphi.eval.constants import LLAMA2_NEXT_LOGPROBS_DATASETS_MAP\n",
"\n",
"# from delphi.train.utils import get_device\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(\"delphi-suite/stories-tokenizer\")\n",
"token_ids = (\n",
" cast(\n",
" Dataset,\n",
" load_dataset(\n",
" \"delphi-suite/v0-tinystories-v2-clean-tokenized\", split=\"validation\"\n",
" ),\n",
" )\n",
" .with_format(\"torch\")\n",
" .map(lambda x: {\"tokens\": x[\"tokens\"].to(device)})\n",
")\n",
"\n",
"next_logprobs = { # preloading all the logprobs datasets for interactive use\n",
" model_name: (\n",
" cast(\n",
" Dataset,\n",
" load_dataset(f\"{dataset_name}\", split=\"validation\"),\n",
" )\n",
" .with_format(\"torch\")\n",
" .map(lambda x: {\"logprobs\": x[\"logprobs\"].to(device)})\n",
" )\n",
" for model_name, dataset_name in LLAMA2_NEXT_LOGPROBS_DATASETS_MAP.items()\n",
"}\n",
"\n",
"token_labels_filename = \"labelled_token_ids_dict.pkl\"\n",
"with open(f\"{STATIC_ASSETS_DIR.joinpath(token_labels_filename)}\", \"rb\") as f:\n",
" token_labels = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8e6f7079bf3b43bcb4b1afb904b36d11",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(FloatRangeSlider(value=(0.25, 0.75), description='Start quantile', max=1.0, step=0.01), …"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<function __main__.show_pos_map(quantile: tuple[float, float], model_name_1: str, model_name_2: str, label: str, samples: int)>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def show_pos_map(\n",
" quantile: tuple[float, float],\n",
" model_name_1: str,\n",
" model_name_2: str,\n",
" label: str,\n",
" samples: int,\n",
"):\n",
" token_id_t = token_ids[\"tokens\"]\n",
" logprobs_diff = next_logprobs[model_name_2][\"logprobs\"] - next_logprobs[model_name_1][\"logprobs\"] # type: ignore\n",
" pos_to_diff = get_all_tok_metrics_in_label(token_id_t, token_labels=token_labels, metrics=logprobs_diff, label=label, q_start=quantile[0], q_end=quantile[1]) # type: ignore\n",
" try:\n",
" _ = vis_pos_map(pos_to_diff, token_id_t, tokenizer, sample=samples) # type: ignore\n",
" except ValueError:\n",
" print(\"No tokens found in this label\")\n",
" return\n",
"\n",
"\n",
"interact(\n",
" show_pos_map,\n",
" quantile=widgets.FloatRangeSlider(\n",
" min=0.0, max=1.0, step=0.01, description=\"Start quantile\"\n",
" ),\n",
" samples=widgets.IntSlider(min=1, max=5, description=\"Samples\", value=2),\n",
" model_name_1=widgets.Dropdown(\n",
" options=LLAMA2_NEXT_LOGPROBS_DATASETS_MAP.keys(),\n",
" description=\"Model 1\",\n",
" value=\"llama2-100k\",\n",
" ),\n",
" model_name_2=widgets.Dropdown(\n",
" options=LLAMA2_NEXT_LOGPROBS_DATASETS_MAP.keys(),\n",
" description=\"Model 2\",\n",
" value=\"llama2-200k\",\n",
" ),\n",
" label=widgets.Dropdown(\n",
" options=token_labels[0].keys(), description=\"Label\", value=\"Is Noun\"\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
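
The notebook drives show_pos_map through ipywidgets, but the same computation can be run directly. A minimal non-interactive sketch, assuming the token_ids, next_logprobs, token_labels, and tokenizer objects from the cells above are already loaded; the model names and label are illustrative picks from the dropdown options:

# Non-interactive version of the show_pos_map call above.
# Assumes token_ids, next_logprobs, token_labels, and tokenizer from the notebook cells.
quantile = (0.25, 0.75)
label = "Is Noun"

logprobs_diff = (
    next_logprobs["llama2-200k"]["logprobs"] - next_logprobs["llama2-100k"]["logprobs"]
)
pos_to_diff = get_all_tok_metrics_in_label(
    token_ids["tokens"],
    token_labels=token_labels,
    metrics=logprobs_diff,
    label=label,
    q_start=quantile[0],
    q_end=quantile[1],
)
_ = vis_pos_map(pos_to_diff, token_ids["tokens"], tokenizer, sample=2)
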
12 changes: 12 additions & 0 deletions src/delphi/eval/constants.py
@@ -12,3 +12,15 @@
"delphi-llama2-12.8m",
"delphi-llama2-25.6m",
]

LLAMA2_NEXT_LOGPROBS_DATASETS_MAP = {
"llama2-100k": "delphi-suite/v0-next-logprobs-llama2-100k",
"llama2-200k": "delphi-suite/v0-next-logprobs-llama2-200k",
"llama2-400k": "delphi-suite/v0-next-logprobs-llama2-400k",
"llama2-800k": "delphi-suite/v0-next-logprobs-llama2-800k",
"llama2-1.6m": "delphi-suite/v0-next-logprobs-llama2-1.6m",
"llama2-3.2m": "delphi-suite/v0-next-logprobs-llama2-3.2m",
"llama2-6.4m": "delphi-suite/v0-next-logprobs-llama2-6.4m",
"llama2-12.8m": "delphi-suite/v0-next-logprobs-llama2-12.8m",
"llama2-25.6m": "delphi-suite/v0-next-logprobs-llama2-25.6m",
}
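
Each entry above maps a short model name to the Hugging Face Hub dataset holding that model's next-token logprobs; the notebook above preloads the validation split of every entry. A small sketch of resolving one name (assuming the datasets library and Hub access):

# Sketch: resolve a model name to its logprobs dataset and load the validation split.
from datasets import load_dataset

from delphi.eval.constants import LLAMA2_NEXT_LOGPROBS_DATASETS_MAP

dataset_name = LLAMA2_NEXT_LOGPROBS_DATASETS_MAP["llama2-100k"]
logprobs_ds = load_dataset(dataset_name, split="validation")  # exposes a "logprobs" column
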
53 changes: 53 additions & 0 deletions src/delphi/eval/token_positions.py
@@ -0,0 +1,53 @@
from typing import Optional

import torch
from datasets import Dataset
from jaxtyping import Int

from delphi.eval.utils import dict_filter_quantile


def get_all_tok_metrics_in_label(
token_ids: Int[torch.Tensor, "prompt pos"],
token_labels: dict[int, dict[str, bool]],
metrics: torch.Tensor,
label: str,
q_start: Optional[float] = None,
q_end: Optional[float] = None,
) -> dict[tuple[int, int], float]:
"""
    Get all positions of tokens in token_ids that carry a given label, together with their metric values.
    We iterate over token_ids directly (rather than over a precomputed token map) because that is more efficient for sampling.
Optionally, filter the tokens based on the quantile range of the metrics.
Args:
    - token_ids (torch.Tensor): tensor of token ids with shape (n_prompts, n_positions), e.g. torch.tensor([[1, 2, ...], [2, 5, ...], ...])
- token_labels (dict[int, dict[str, bool]]): dictionary of token labels e.g. { 0: {"Is Noun": True, "Is Verb": False}, ...}
- metrics (torch.Tensor): tensor of metrics to search through e.g. torch.tensor([[0.1, 0.2, ...], [0.3, 0.4, ...], ...])
- label (str): the label to search for
- q_start (float): the start of the quantile range to filter the metrics e.g. 0.1
- q_end (float): the end of the quantile range to filter the metrics e.g. 0.9
Returns:
    - tok_positions (dict[tuple[int, int], float]): dictionary mapping (prompt, pos) to the corresponding metric value
"""

# check if metrics have the same dimensions as token_ids
if metrics.shape != token_ids.shape:
raise ValueError(
f"Expected metrics to have the same shape as token_ids, but got {metrics.shape} and {token_ids.shape} instead."
)

tok_positions = {}
for prompt_pos, prompt in enumerate(token_ids.numpy()):
for tok_pos, tok in enumerate(prompt):
if token_labels[tok][label]:
tok_positions[(prompt_pos, tok_pos)] = metrics[
prompt_pos, tok_pos
].item()

if q_start is not None and q_end is not None:
tok_positions = dict_filter_quantile(tok_positions, q_start, q_end)

return tok_positions
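
A toy usage sketch of get_all_tok_metrics_in_label with synthetic tokens, labels, and metrics (not the real delphi label dictionary), to show the expected input shapes and the (prompt, pos) -> metric output:

import torch

from delphi.eval.token_positions import get_all_tok_metrics_in_label

token_ids = torch.tensor([[5, 7], [7, 5]])            # 2 prompts, 2 positions each
metrics = torch.tensor([[0.5, 0.75], [0.25, 0.125]])  # same shape as token_ids
token_labels = {5: {"Is Noun": True}, 7: {"Is Noun": False}}

positions = get_all_tok_metrics_in_label(
    token_ids, token_labels, metrics, label="Is Noun"
)
print(positions)  # {(0, 0): 0.5, (1, 1): 0.125}
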
15 changes: 14 additions & 1 deletion src/delphi/eval/utils.py
@@ -1,7 +1,8 @@
import logging
from collections.abc import Callable
from typing import cast
from typing import Any, cast

import numpy as np
import torch
from datasets import Dataset, load_dataset
from jaxtyping import Float, Int
@@ -118,3 +119,15 @@ def load_logprob_datasets(split: str = "validation") -> dict[str, list[list[float]]]:
model: cast(dict, load_logprob_dataset(model)[split])["logprobs"] # type: ignore
for model in constants.LLAMA2_MODELS
}


def dict_filter_quantile(
d: dict[Any, float], q_start: float, q_end: float
) -> dict[Any, float]:
if not (0 <= q_start < q_end <= 1):
raise ValueError("Invalid quantile range")
q_start_val = np.nanquantile(list(d.values()), q_start)
q_end_val = np.nanquantile(list(d.values()), q_end)
return {
k: v for k, v in d.items() if q_start_val <= v <= q_end_val and not np.isnan(v)
}
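
A quick sketch of dict_filter_quantile on toy data: with q_start=0.25 and q_end=0.75, only entries whose value falls between the 25th and 75th percentile of the values survive (NaN values are dropped regardless):

from delphi.eval.utils import dict_filter_quantile

d = {(0, 0): 0.1, (0, 1): 0.4, (1, 0): 0.6, (1, 1): 0.9}
print(dict_filter_quantile(d, 0.25, 0.75))  # {(0, 1): 0.4, (1, 0): 0.6}
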
101 changes: 101 additions & 0 deletions src/delphi/eval/vis.py
@@ -1,3 +1,5 @@
import math
import random
import uuid
from typing import cast

@@ -20,6 +22,27 @@ def probs_to_colors(probs: Float[torch.Tensor, "next_pos"]) -> list[str]:
return colors


def single_loss_diff_to_color(loss_diff: float) -> str:
# if loss_diff is negative, we want the color to be red
# if loss_diff is positive, we want the color to be green
# if loss_diff is 0, we want the color to be white
# the color should be more intense the larger the absolute value of loss_diff

def sigmoid(x: float) -> float:
return 1 / (1 + math.exp(-x))

scaled_loss_diff = sigmoid(loss_diff) # scale to 0-1

if scaled_loss_diff < 0.5: # red
red_val = 255
green_blue_val = min(int(255 * 2 * scaled_loss_diff), 255)
return f"rgb({red_val}, {green_blue_val}, {green_blue_val})"
else: # green
green_val = 255
red_blue_val = min(int(255 * 2 * (1 - scaled_loss_diff)), 255)
return f"rgb({red_blue_val}, {green_val}, {red_blue_val})"


def to_tok_prob_str(tok: int, prob: float, tokenizer: PreTrainedTokenizerBase) -> str:
tok_str = tokenizer.decode(tok).replace(" ", "&nbsp;").replace("\n", r"\n")
prob_str = f"{prob:.2%}"
@@ -141,6 +164,84 @@ def vis_sample_prediction_probs(
return html_str


def vis_pos_map(
pos_map: dict[tuple[int, int], float | int],
token_ids: Int[torch.Tensor, "prompt pos"],
tokenizer: PreTrainedTokenizerBase,
sample: int = 3,
):
"""
Randomly sample from pos_map and visualize the loss diff at the corresponding position.
"""

token_htmls = []
unique_id = str(uuid.uuid4())
token_class = f"token_{unique_id}"
hover_div_id = f"hover_info_{unique_id}"

    # choose `sample` random keys from pos_map
keys = random.sample(list(pos_map.keys()), k=sample)

for key in keys:
prompt, pos = key
pre_toks = token_ids[prompt][:pos]
mask = torch.isin(pre_toks, torch.tensor([0, 1], dtype=torch.int8))
        pre_toks = pre_toks[~mask]  # remove <unk> and <s> tokens; <s> causes strikethrough in HTML

for i in range(pre_toks.shape[0]):
pre_tok = cast(int, pre_toks[i].item())
token_htmls.append(
token_to_html(pre_tok, tokenizer, bg_color="white", data={}).replace(
"class='token'", f"class='{token_class}'"
)
)

tok = cast(int, token_ids[prompt][pos].item())
value = cast(float, pos_map[key])

token_htmls.append(
token_to_html(
tok,
tokenizer,
bg_color=single_loss_diff_to_color(value),
data={"loss-diff": f"{value:.2f}"},
).replace("class='token'", f"class='{token_class}'")
)

        # add a line break between samples
token_htmls.append("<br><br>")

html_str = f"""
<style>.{token_class} {{ {_token_style_str} }} #{hover_div_id} {{ height: 100px; font-family: monospace; }}</style>
{"".join(token_htmls)} <div id='{hover_div_id}'></div>
<script>
(function() {{
var token_divs = document.querySelectorAll('.{token_class}');
var hover_info = document.getElementById('{hover_div_id}');
token_divs.forEach(function(token_div) {{
token_div.addEventListener('mousemove', function(e) {{
hover_info.innerHTML = ""
for( var d in this.dataset) {{
hover_info.innerHTML += "<b>" + d + "</b> ";
hover_info.innerHTML += this.dataset[d] + "<br>";
}}
}});
token_div.addEventListener('mouseout', function(e) {{
hover_info.innerHTML = ""
}});
}});
}})();
</script>
"""
display(HTML(html_str))
return html_str


def token_selector(
vocab_map: dict[str, int]
) -> tuple[pn.widgets.MultiChoice, list[int]]:
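
A quick sanity check of single_loss_diff_to_color from the hunk above: a loss diff of 0 maps to white, negative diffs shade toward red, and positive diffs toward green; the sample values are illustrative:

from delphi.eval.vis import single_loss_diff_to_color

for diff in (-2.0, 0.0, 2.0):
    print(diff, single_loss_diff_to_color(diff))
# -2.0 rgb(255, 60, 60)    (reddish)
#  0.0 rgb(255, 255, 255)  (white)
#  2.0 rgb(60, 255, 60)    (greenish)
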