Merge branch 'main' into simplify_run_training
Showing 7 changed files with 433 additions and 2 deletions.
@@ -0,0 +1,171 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "\n",
    "\n",
    "from datasets import load_dataset, Dataset\n",
    "\n",
    "\n",
    "from typing import cast\n",
    "from ipywidgets import interact\n",
    "import ipywidgets as widgets\n",
    "\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "from delphi.constants import STATIC_ASSETS_DIR\n",
    "from delphi.eval.token_positions import get_all_tok_metrics_in_label\n",
    "from delphi.eval.vis import vis_pos_map\n",
    "from delphi.eval.constants import LLAMA2_NEXT_LOGPROBS_DATASETS_MAP\n",
    "\n",
    "# from delphi.train.utils import get_device\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"delphi-suite/stories-tokenizer\")\n",
    "token_ids = (\n",
    "    cast(\n",
    "        Dataset,\n",
    "        load_dataset(\n",
    "            \"delphi-suite/v0-tinystories-v2-clean-tokenized\", split=\"validation\"\n",
    "        ),\n",
    "    )\n",
    "    .with_format(\"torch\")\n",
    "    .map(lambda x: {\"tokens\": x[\"tokens\"].to(device)})\n",
    ")\n",
    "\n",
    "next_logprobs = {  # preloading all the logprobs datasets for interactive use\n",
    "    model_name: (\n",
    "        cast(\n",
    "            Dataset,\n",
    "            load_dataset(f\"{dataset_name}\", split=\"validation\"),\n",
    "        )\n",
    "        .with_format(\"torch\")\n",
    "        .map(lambda x: {\"logprobs\": x[\"logprobs\"].to(device)})\n",
    "    )\n",
    "    for model_name, dataset_name in LLAMA2_NEXT_LOGPROBS_DATASETS_MAP.items()\n",
    "}\n",
    "\n",
    "token_labels_filename = \"labelled_token_ids_dict.pkl\"\n",
    "with open(f\"{STATIC_ASSETS_DIR.joinpath(token_labels_filename)}\", \"rb\") as f:\n",
    "    token_labels = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e6f7079bf3b43bcb4b1afb904b36d11",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatRangeSlider(value=(0.25, 0.75), description='Start quantile', max=1.0, step=0.01), …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<function __main__.show_pos_map(quantile: tuple[float, float], model_name_1: str, model_name_2: str, label: str, samples: int)>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def show_pos_map(\n",
    "    quantile: tuple[float, float],\n",
    "    model_name_1: str,\n",
    "    model_name_2: str,\n",
    "    label: str,\n",
    "    samples: int,\n",
    "):\n",
    "    token_id_t = token_ids[\"tokens\"]\n",
    "    logprobs_diff = next_logprobs[model_name_2][\"logprobs\"] - next_logprobs[model_name_1][\"logprobs\"]  # type: ignore\n",
    "    pos_to_diff = get_all_tok_metrics_in_label(token_id_t, token_labels=token_labels, metrics=logprobs_diff, label=label, q_start=quantile[0], q_end=quantile[1])  # type: ignore\n",
    "    try:\n",
    "        _ = vis_pos_map(pos_to_diff, token_id_t, tokenizer, sample=samples)  # type: ignore\n",
    "    except ValueError:\n",
    "        print(\"No tokens found in this label\")\n",
    "        return\n",
    "\n",
    "\n",
    "interact(\n",
    "    show_pos_map,\n",
    "    quantile=widgets.FloatRangeSlider(\n",
    "        min=0.0, max=1.0, step=0.01, description=\"Start quantile\"\n",
    "    ),\n",
    "    samples=widgets.IntSlider(min=1, max=5, description=\"Samples\", value=2),\n",
    "    model_name_1=widgets.Dropdown(\n",
    "        options=LLAMA2_NEXT_LOGPROBS_DATASETS_MAP.keys(),\n",
    "        description=\"Model 1\",\n",
    "        value=\"llama2-100k\",\n",
    "    ),\n",
    "    model_name_2=widgets.Dropdown(\n",
    "        options=LLAMA2_NEXT_LOGPROBS_DATASETS_MAP.keys(),\n",
    "        description=\"Model 2\",\n",
    "        value=\"llama2-200k\",\n",
    "    ),\n",
    "    label=widgets.Dropdown(\n",
    "        options=token_labels[0].keys(), description=\"Label\", value=\"Is Noun\"\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
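The widget in the last code cell drives show_pos_map interactively. To script the same comparison without ipywidgets, the function can be called directly. A minimal sketch, reusing the notebook's globals and the defaults visible in the widget definitions above:

# Non-interactive equivalent of the interact(...) call, passing the
# widget defaults by hand (the quantile range matches the slider value
# shown in the cell output above).
show_pos_map(
    quantile=(0.25, 0.75),
    model_name_1="llama2-100k",
    model_name_2="llama2-200k",
    label="Is Noun",
    samples=2,
)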
delphi/eval/token_positions.py
@@ -0,0 +1,53 @@
from typing import Optional

import torch
from jaxtyping import Int

from delphi.eval.utils import dict_filter_quantile


def get_all_tok_metrics_in_label(
    token_ids: Int[torch.Tensor, "prompt pos"],
    token_labels: dict[int, dict[str, bool]],
    metrics: torch.Tensor,
    label: str,
    q_start: Optional[float] = None,
    q_end: Optional[float] = None,
) -> dict[tuple[int, int], float]:
    """
    Get the positions of all tokens that carry a given label, mapped to their metrics.
    We iterate through token_ids directly rather than using a precomputed token map,
    because for sampling purposes that is more efficient.
    Optionally, filter the tokens to a quantile range of the metrics.

    Args:
    - token_ids (torch.Tensor): tensor of token ids with shape (prompt, pos)
    - token_labels (dict[int, dict[str, bool]]): dictionary of token labels e.g. { 0: {"Is Noun": True, "Is Verb": False}, ...}
    - metrics (torch.Tensor): tensor of metrics to search through e.g. torch.tensor([[0.1, 0.2, ...], [0.3, 0.4, ...], ...])
    - label (str): the label to search for
    - q_start (float): the start of the quantile range to filter the metrics e.g. 0.1
    - q_end (float): the end of the quantile range to filter the metrics e.g. 0.9

    Returns:
    - tok_positions (dict[tuple[int, int], float]): dictionary mapping (prompt, pos) positions to their corresponding metrics
    """

    # check that metrics have the same shape as token_ids
    if metrics.shape != token_ids.shape:
        raise ValueError(
            f"Expected metrics to have the same shape as token_ids, but got {metrics.shape} and {token_ids.shape} instead."
        )

    tok_positions = {}
    # move to CPU before converting to numpy, since token_ids may live on the GPU
    for prompt_pos, prompt in enumerate(token_ids.cpu().numpy()):
        for tok_pos, tok in enumerate(prompt):
            if token_labels[tok][label]:
                tok_positions[(prompt_pos, tok_pos)] = metrics[
                    prompt_pos, tok_pos
                ].item()

    if q_start is not None and q_end is not None:
        tok_positions = dict_filter_quantile(tok_positions, q_start, q_end)

    return tok_positions
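For reference, a minimal usage sketch of get_all_tok_metrics_in_label on toy data. The token ids, labels, and metric values below are invented for illustration; the quantile branch is not exercised here, since dict_filter_quantile is defined elsewhere in the repo and not shown in this diff:

import torch

from delphi.eval.token_positions import get_all_tok_metrics_in_label

# Two prompts of three tokens each; labels cover token ids 1-3.
# float64 keeps the printed metric values tidy.
token_ids = torch.tensor([[1, 2, 3], [3, 1, 2]])
metrics = torch.tensor([[0.1, 0.5, 0.9], [0.2, 0.8, 0.3]], dtype=torch.float64)
token_labels = {
    1: {"Is Noun": True},
    2: {"Is Noun": False},
    3: {"Is Noun": True},
}

# Every (prompt, position) pair whose token is labelled "Is Noun",
# mapped to the metric at that position.
pos_to_metric = get_all_tok_metrics_in_label(
    token_ids, token_labels=token_labels, metrics=metrics, label="Is Noun"
)
print(pos_to_metric)
# {(0, 0): 0.1, (0, 2): 0.9, (1, 0): 0.2, (1, 1): 0.8}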