Skip to content

Commit

Permalink
Add "checkpoint_mode" kwarg to plotting (#99)
Browse files Browse the repository at this point in the history
* Add checkpoint_mode kwarg to plotting

* Remove dummy key of token_category

* Revert "Remove dummy key of token_category"

This reverts commit 2edb2c5.

* Update vis notebook

* Remove kwargs and have fixed args in place

---------

Co-authored-by: Siwei Li <[email protected]>
  • Loading branch information
siwei-li and Siwei Li authored Apr 3, 2024
1 parent bfce886 commit 90410c8
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 46 deletions.
89 changes: 77 additions & 12 deletions notebooks/per_token_plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,53 @@
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "696575431f65420e9dc22c3b3476bfbb",
"model_id": "fbda6a916fe84814be64a40423196d76",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…"
"FigureWidget({\n",
" 'data': [{'line': {'width': 0},\n",
" 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n",
" 'mode': 'lines',\n",
" 'name': 'Upper Bound',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': 'a3590fcd-466d-4a73-b167-194ab728efcd',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n",
" 2.43304471])},\n",
" {'fill': 'tonexty',\n",
" 'fillcolor': 'rgba(68, 68, 68, 0.3)',\n",
" 'line': {'width': 0},\n",
" 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n",
" 'mode': 'lines',\n",
" 'name': 'Lower Bound',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': 'fda82808-c8ff-4b6c-878d-c76d66c8ce17',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n",
" 0.81709211])},\n",
" {'marker': {'color': 'rgb(31, 119, 180)', 'line': {'color': 'rgb(31, 119, 180)', 'width': 1}, 'size': 0},\n",
" 'mode': 'lines',\n",
" 'name': 'Means',\n",
" 'type': 'scatter',\n",
" 'uid': 'b11dfbd0-c130-4a97-a8ff-b8c753b95035',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n",
" 1.45715223])}],\n",
" 'layout': {'template': '...'}\n",
"})"
]
},
"execution_count": 5,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -33,9 +65,10 @@
"random.seed(0)\n",
"\n",
"# generate mock data\n",
"model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m']\n",
"# model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m', \"0\"]\n",
"model_names = list(range(500))\n",
"categories = ['nouns', 'verbs', 'prepositions', 'adjectives']\n",
"entries = [200, 100, 150, 300]\n",
"entries = [200, 100, 150, 300, 100]*100\n",
"performance_data = defaultdict()\n",
"for i, model in enumerate(model_names):\n",
" performance_data[model] = defaultdict()\n",
Expand All @@ -47,32 +80,64 @@
" performance_data[model][cat] = (-means, err_low, err_hi)\n",
"\n",
"\n",
"visualize_per_token_category(performance_data, log_scale=True, bg_color='LightGrey', line_color=\"Red\", marker_color='Orange', bar_color='Green')"
"visualize_per_token_category(performance_data, log_scale=True, checkpoint_mode=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb3af5248a4a40118c36a527c927289d",
"model_id": "993e5d66ae56462a8eeec2c9ac6bd972",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…"
"FigureWidget({\n",
" 'data': [{'line': {'width': 0},\n",
" 'marker': {'color': 'wheat'},\n",
" 'mode': 'lines',\n",
" 'name': 'Upper Bound',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': '56999008-205c-4592-a3f7-ea61e3e09d8e',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n",
" 2.43304471])},\n",
" {'fill': 'tonexty',\n",
" 'fillcolor': 'wheat',\n",
" 'line': {'width': 0},\n",
" 'marker': {'color': 'wheat'},\n",
" 'mode': 'lines',\n",
" 'name': 'Lower Bound',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': 'be8a04f1-b8c4-46af-bf5e-03c942eff19f',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n",
" 0.81709211])},\n",
" {'marker': {'color': 'Orange', 'line': {'color': 'Orange', 'width': 1}, 'size': 0},\n",
" 'mode': 'lines',\n",
" 'name': 'Median',\n",
" 'type': 'scatter',\n",
" 'uid': '85fe5113-70fb-4aa7-9821-947287d84e1d',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n",
" 1.45715223])}],\n",
" 'layout': {'template': '...'}\n",
"})"
]
},
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"visualize_per_token_category(performance_data, log_scale=False)"
"visualize_per_token_category(performance_data, log_scale=True, checkpoint_mode=True, line_metric=\"Median\", line_color='Orange' , shade_color=\"wheat\")"
]
}
],
Expand Down
100 changes: 66 additions & 34 deletions src/delphi/eval/vis_per_token_model.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,99 @@
from typing import Union

import ipywidgets
import numpy as np
import plotly.graph_objects as go


def visualize_per_token_category(
input: dict[str, dict[str, tuple]], log_scale=False, **kwargs: str
) -> ipywidgets.VBox:
model_names = list(input.keys())
categories = list(input[model_names[0]].keys())
input: dict[Union[str, int], dict[str, tuple]],
log_scale=False,
line_metric="Means",
checkpoint_mode=True,
shade_color="rgba(68, 68, 68, 0.3)",
line_color="rgb(31, 119, 180)",
bar_color="purple",
marker_color="SkyBlue",
background_color="AliceBlue",
) -> go.FigureWidget:
input_x = list(input.keys())
categories = list(input[input_x[0]].keys())
category = categories[0]

def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]

def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
x = np.array([input[name][category] for name in model_names]).T
x = np.array([input[x][category] for x in input_x]).T
means, err_lo, err_hi = x[0], x[1], x[2]
return means, err_lo, err_hi

means, err_low, err_hi = get_plot_values(category)
g = go.FigureWidget(
data=go.Scatter(
x=model_names,
means, err_lo, err_hi = get_plot_values(category)

if checkpoint_mode:
scatter_plot = go.Figure(
[
go.Scatter(
name="Upper Bound",
x=input_x,
y=means + err_hi,
mode="lines",
marker=dict(color=shade_color),
line=dict(width=0),
showlegend=False,
),
go.Scatter(
name="Lower Bound",
x=input_x,
y=means - err_lo,
marker=dict(color=shade_color),
line=dict(width=0),
mode="lines",
fillcolor=shade_color,
fill="tonexty",
showlegend=False,
),
go.Scatter(
name=line_metric,
x=input_x,
y=means,
mode="lines",
marker=dict(
color=line_color,
size=0,
line=dict(color=line_color, width=1),
),
),
]
)
else:
scatter_plot = go.Scatter(
x=input_x,
y=means,
error_y=dict(
type="data",
symmetric=False,
array=err_hi,
arrayminus=err_low,
color=kwargs.get("bar_color", "purple"),
arrayminus=err_lo,
color=bar_color,
),
marker=dict(
color=kwargs.get("marker_color", "SkyBlue"),
color=marker_color,
size=15,
line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2),
line=dict(color=line_color, width=2),
),
hovertext=get_hovertexts(means, err_low, err_hi),
hovertext=get_hovertexts(means, err_lo, err_hi),
hoverinfo="text+x",
),
)
g = go.FigureWidget(
data=scatter_plot,
layout=go.Layout(
yaxis=dict(
title="Loss",
type="log" if log_scale else "linear",
),
plot_bgcolor=kwargs.get("bg_color", "AliceBlue"),
plot_bgcolor=background_color,
),
)

selected_category = ipywidgets.Dropdown(
options=categories,
placeholder="",
description="Token Category:",
disabled=False,
)

def response(change):
means, err_lo, err_hi = get_plot_values(selected_category.value)
with g.batch_update():
g.data[0].y = means
g.data[0].error_y["array"] = err_hi
g.data[0].error_y["arrayminus"] = err_lo
g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi)

selected_category.observe(response, names="value")

return ipywidgets.VBox([selected_category, g])
return g

0 comments on commit 90410c8

Please sign in to comment.