Skip to content

Commit

Permalink
Remove token category from vis function
Browse files (browse the repository at this point in the history)
  • Loading branch information
Siwei Li authored and menamerai committed Apr 11, 2024
1 parent 367d0d8 commit 4c9cb21
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 97 deletions.
106 changes: 16 additions & 90 deletions notebooks/eval_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
"from typing import cast\n",
"from delphi.eval.calc_model_group_stats import calc_model_group_stats\n",
"from collections import defaultdict\n",
"from delphi.eval.vis_per_token_model import visualize_per_token_category\n",
"from delphi.eval.vis_per_token_model import visualize_selected_tokens\n",
"from ipywidgets import interact\n",
"from delphi.eval.token_positions import get_all_tok_metrics_in_label\n",
"from delphi.eval.vis import vis_pos_map\n",
Expand Down Expand Up @@ -265,86 +265,13 @@
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'100k': {'mean': -1.070133,\n",
" 'median': -0.6912808,\n",
" 'min': -7.122874,\n",
" 'max': -0.017355476,\n",
" '25th': -1.3876090049743652,\n",
" '75th': -0.31879591941833496},\n",
" '200k': {'mean': -1.0078776,\n",
" 'median': -0.6108246,\n",
" 'min': -7.1288815,\n",
" 'max': -0.006140296,\n",
" '25th': -1.3651126325130463,\n",
" '75th': -0.21433717757463455},\n",
" '400k': {'mean': -0.8517932,\n",
" 'median': -0.5542941,\n",
" 'min': -6.2654996,\n",
" 'max': -0.0039506494,\n",
" '25th': -1.0751871466636658,\n",
" '75th': -0.13063477724790573},\n",
" '800k': {'mean': -0.78640485,\n",
" 'median': -0.31092834,\n",
" 'min': -6.6738915,\n",
" 'max': -0.0011469699,\n",
" '25th': -1.117132544517517,\n",
" '75th': -0.11057000048458576},\n",
" '1.6m': {'mean': -0.74975806,\n",
" 'median': -0.30155045,\n",
" 'min': -5.3355055,\n",
" 'max': -0.00043644916,\n",
" '25th': -1.0707703530788422,\n",
" '75th': -0.057139165699481964},\n",
" '3.2m': {'mean': -0.69542694,\n",
" 'median': -0.263493,\n",
" 'min': -4.481785,\n",
" 'max': -0.00014411364,\n",
" '25th': -1.0095961689949036,\n",
" '75th': -0.039097873494029045},\n",
" '6.4m': {'mean': -0.60625404,\n",
" 'median': -0.19129953,\n",
" 'min': -5.051317,\n",
" 'max': -7.00926e-05,\n",
" '25th': -0.804155021905899,\n",
" '75th': -0.028934753965586424},\n",
" '12.8m': {'mean': -0.56314814,\n",
" 'median': -0.13154678,\n",
" 'min': -4.793927,\n",
" 'max': -1.2159274e-05,\n",
" '25th': -0.8005392700433731,\n",
" '75th': -0.01866082102060318},\n",
" '25.6m': {'mean': -0.56998307,\n",
" 'median': -0.091308385,\n",
" 'min': -4.9958663,\n",
" 'max': -1.0967195e-05,\n",
" '25th': -0.577660083770752,\n",
" '75th': -0.006869094213470817}}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_group_stats"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b55f17fb67f641f580d281b0e487aa99",
"model_id": "09e823365e0f436b8c500f6fac5b2095",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -356,10 +283,10 @@
" 'name': 'Upper Bound',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': '59ab3f3e-a75e-4d8a-a7a9-89a2fc60e357',\n",
" 'uid': '389fd792-e2aa-412d-9941-19d0927c3693',\n",
" 'x': [100k, 200k, 400k, 800k, 1.6m, 3.2m, 6.4m, 12.8m, 25.6m],\n",
" 'y': array([1.96107459, 1.69477943, 1.47217423, 1.20097491, 1.20183697, 1.0590816 ,\n",
" 0.85512815, 0.7318224 , 0.67268856])},\n",
" 'y': array([2.07888979, 1.97593722, 1.62948126, 1.42806089, 1.3723208 , 1.27308917,\n",
" 0.99545455, 0.93208605, 0.66896847])},\n",
" {'fill': 'tonexty',\n",
" 'fillcolor': 'rgba(68, 68, 68, 0.3)',\n",
" 'line': {'width': 0},\n",
Expand All @@ -368,39 +295,38 @@
" 'name': 'Lower Bound',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': '0522e51c-57f8-4180-b557-a8727855f2fa',\n",
" 'uid': 'b98bb36c-d650-43c3-9fa0-ca54ce533d1f',\n",
" 'x': [100k, 200k, 400k, 800k, 1.6m, 3.2m, 6.4m, 12.8m, 25.6m],\n",
" 'y': array([0.42223121, 0.33907843, 0.30785988, 0.22143046, 0.187295 , 0.16716134,\n",
" 0.12061603, 0.08660824, 0.07536004])},\n",
" 'y': array([0.37248486, 0.39648741, 0.42365933, 0.20035834, 0.24441128, 0.22439513,\n",
" 0.16236477, 0.11288596, 0.08443929])},\n",
" {'marker': {'color': 'rgb(31, 119, 180)', 'line': {'color': 'rgb(31, 119, 180)', 'width': 1}, 'size': 0},\n",
" 'mode': 'lines',\n",
" 'name': 'Means',\n",
" 'type': 'scatter',\n",
" 'uid': 'a58aa420-f1e7-4a7f-a5b4-b0a1eb9283fd',\n",
" 'uid': 'fbad2bf8-6845-4a3a-b0d8-8a0053769e23',\n",
" 'x': [100k, 200k, 400k, 800k, 1.6m, 3.2m, 6.4m, 12.8m, 25.6m],\n",
" 'y': array([0.53871393, 0.42690507, 0.36171991, 0.26546207, 0.21804789, 0.19020958,\n",
" 0.14060442, 0.10136253, 0.0816711 ])}],\n",
" 'y': array([0.69128078, 0.61082458, 0.55429411, 0.31092834, 0.30155045, 0.263493 ,\n",
" 0.19129953, 0.13154678, 0.09130839])}],\n",
" 'layout': {'template': '...'}\n",
"})"
]
},
"execution_count": 8,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"performance_data = defaultdict(dict)\n",
"for suffix in suffixes:\n",
" stats = model_group_stats[(suffix, \"selected\")]\n",
" performance_data[suffix][\"selected\"] = (\n",
" stats = model_group_stats[suffix]\n",
" performance_data[suffix] = (\n",
" -stats[\"median\"],\n",
" -stats[\"75th\"],\n",
" -stats[\"25th\"],\n",
" )\n",
"\n",
"# TODO: this is using the older version of the plotting func\n",
"visualize_per_token_category(performance_data, log_scale=True)"
"visualize_selected_tokens(performance_data, log_scale=True)"
]
},
{
Expand Down
12 changes: 5 additions & 7 deletions src/delphi/eval/vis_per_token_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import plotly.graph_objects as go


def visualize_per_token_category(
input: dict[Union[str, int], dict[str, tuple]],
def visualize_selected_tokens(
input: dict[Union[str, int], dict[str, float]],
log_scale=False,
line_metric="Means",
checkpoint_mode=True,
Expand All @@ -17,18 +17,16 @@ def visualize_per_token_category(
background_color="AliceBlue",
) -> go.FigureWidget:
input_x = list(input.keys())
categories = list(input[input_x[0]].keys())
category = categories[0]

def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]

def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
x = np.array([input[x][category] for x in input_x]).T
def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
x = np.array([input[x] for x in input_x]).T
means, err_lo, err_hi = x[0], x[1], x[2]
return means, err_lo, err_hi

means, err_lo, err_hi = get_plot_values(category)
means, err_lo, err_hi = get_plot_values()

if checkpoint_mode:
scatter_plot = go.Figure(
Expand Down

0 comments on commit 4c9cb21

Please sign in to comment.