Skip to content

Commit

Permalink
Add checkpoint_mode kwarg to plotting
Browse files (browse the repository at this point in the history)
  • Loading branch information
Siwei Li committed Mar 31, 2024
1 parent a104d39 commit 20c6fce
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 39 deletions.
89 changes: 77 additions & 12 deletions notebooks/per_token_plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,53 @@
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "696575431f65420e9dc22c3b3476bfbb",
"model_id": "deb5a98615624e32b91fb3fc4d155c7a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…"
"FigureWidget({\n",
" 'data': [{'line': {'width': 0},\n",
" 'marker': {'color': '#444'},\n",
" 'mode': 'lines',\n",
" 'name': 'Upper Bound',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': '4dcf6b1c-9b26-425e-be14-ba73fde289fb',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n",
" 2.43304471])},\n",
" {'fill': 'tonexty',\n",
" 'fillcolor': 'rgba(68, 68, 68, 0.3)',\n",
" 'line': {'width': 0},\n",
" 'marker': {'color': '#444'},\n",
" 'mode': 'lines',\n",
" 'name': 'Lower Bound',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': 'c0a55b83-d045-4faa-9285-a927058cad75',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n",
" 0.81709211])},\n",
" {'marker': {'color': 'rgb(31, 119, 180)', 'line': {'color': 'rgb(31, 119, 180)', 'width': 1}, 'size': 0},\n",
" 'mode': 'lines',\n",
" 'name': 'Means',\n",
" 'type': 'scatter',\n",
" 'uid': 'c4c0a68b-efa3-4930-aa0c-3c17451e3d2e',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n",
" 1.45715223])}],\n",
" 'layout': {'template': '...'}\n",
"})"
]
},
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -33,9 +65,10 @@
"random.seed(0)\n",
"\n",
"# generate mock data\n",
"model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m']\n",
"# model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m', \"0\"]\n",
"model_names = list(range(500))\n",
"categories = ['nouns', 'verbs', 'prepositions', 'adjectives']\n",
"entries = [200, 100, 150, 300]\n",
"entries = [200, 100, 150, 300, 100]*100\n",
"performance_data = defaultdict()\n",
"for i, model in enumerate(model_names):\n",
" performance_data[model] = defaultdict()\n",
Expand All @@ -47,32 +80,64 @@
" performance_data[model][cat] = (-means, err_low, err_hi)\n",
"\n",
"\n",
"visualize_per_token_category(performance_data, log_scale=True, bg_color='LightGrey', line_color=\"Red\", marker_color='Orange', bar_color='Green')"
"visualize_per_token_category(performance_data, log_scale=True, checkpoint_mode=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb3af5248a4a40118c36a527c927289d",
"model_id": "4550d6e8c4f74396b180fd1223c4c3b2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…"
"FigureWidget({\n",
" 'data': [{'line': {'width': 0},\n",
" 'marker': {'color': 'wheat'},\n",
" 'mode': 'lines',\n",
" 'name': 'Upper Bound',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': '453a6e83-3b4f-4090-b8ea-6c8945dba824',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n",
" 2.43304471])},\n",
" {'fill': 'tonexty',\n",
" 'fillcolor': 'wheat',\n",
" 'line': {'width': 0},\n",
" 'marker': {'color': 'wheat'},\n",
" 'mode': 'lines',\n",
" 'name': 'Lower Bound',\n",
" 'showlegend': False,\n",
" 'type': 'scatter',\n",
" 'uid': '4128f987-e0a9-457c-ab4a-c0547589c988',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n",
" 0.81709211])},\n",
" {'marker': {'color': 'Orange', 'line': {'color': 'Orange', 'width': 1}, 'size': 0},\n",
" 'mode': 'lines',\n",
" 'name': 'Median',\n",
" 'type': 'scatter',\n",
" 'uid': '06de0035-f3a6-4b79-aff2-5ae62130a5db',\n",
" 'x': [0, 1, 2, ..., 497, 498, 499],\n",
" 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n",
" 1.45715223])}],\n",
" 'layout': {'template': '...'}\n",
"})"
]
},
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"visualize_per_token_category(performance_data, log_scale=False)"
"visualize_per_token_category(performance_data, log_scale=True, checkpoint_mode=True, line_metric=\"Median\", line_color='Orange' , shade_color=\"wheat\")"
]
}
],
Expand Down
106 changes: 79 additions & 27 deletions src/delphi/eval/vis_per_token_model.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,89 @@
from typing import Union

import ipywidgets
import numpy as np
import plotly.graph_objects as go


def visualize_per_token_category(
input: dict[str, dict[str, tuple]], log_scale=False, **kwargs: str
) -> ipywidgets.VBox:
model_names = list(input.keys())
categories = list(input[model_names[0]].keys())
input: dict[Union[str, int], dict[str, tuple]],
log_scale=False,
**kwargs: Union[str, bool],
# ) -> ipywidgets.VBox:
) -> go.FigureWidget:
input_x = list(input.keys())
categories = list(input[input_x[0]].keys())
category = categories[0]

def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]

def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
x = np.array([input[name][category] for name in model_names]).T
x = np.array([input[x][category] for x in input_x]).T
means, err_lo, err_hi = x[0], x[1], x[2]
return means, err_lo, err_hi

means, err_low, err_hi = get_plot_values(category)
g = go.FigureWidget(
data=go.Scatter(
x=model_names,
means, err_lo, err_hi = get_plot_values(category)

if kwargs.get("checkpoint_mode"):
scatter_plot = go.Figure(
[
go.Scatter(
name="Upper Bound",
x=input_x,
y=means + err_hi,
mode="lines",
marker=dict(color=kwargs.get("shade_color", "#444")),
line=dict(width=0),
showlegend=False,
),
go.Scatter(
name="Lower Bound",
x=input_x,
y=means - err_lo,
marker=dict(color=kwargs.get("shade_color", "#444")),
line=dict(width=0),
mode="lines",
fillcolor=kwargs.get("shade_color", "rgba(68, 68, 68, 0.3)"),
fill="tonexty",
showlegend=False,
),
go.Scatter(
name=kwargs.get("line_metric", "Means"),
x=input_x,
y=means,
mode="lines",
marker=dict(
color=kwargs.get("line_color", "rgb(31, 119, 180)"),
size=0,
line=dict(
color=kwargs.get("line_color", "rgb(31, 119, 180)"), width=1
),
),
),
]
)
else:
scatter_plot = go.Scatter(
x=input_x,
y=means,
error_y=dict(
type="data",
symmetric=False,
array=err_hi,
arrayminus=err_low,
arrayminus=err_lo,
color=kwargs.get("bar_color", "purple"),
),
marker=dict(
color=kwargs.get("marker_color", "SkyBlue"),
size=15,
line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2),
),
hovertext=get_hovertexts(means, err_low, err_hi),
hovertext=get_hovertexts(means, err_lo, err_hi),
hoverinfo="text+x",
),
)
g = go.FigureWidget(
data=scatter_plot,
layout=go.Layout(
yaxis=dict(
title="Loss",
Expand All @@ -47,21 +93,27 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
),
)

selected_category = ipywidgets.Dropdown(
options=categories,
placeholder="",
description="Token Category:",
disabled=False,
)
# selected_category = ipywidgets.Dropdown(
# options=categories,
# placeholder="",
# description="Token Category:",
# disabled=False,
# )

def response(change):
means, err_lo, err_hi = get_plot_values(selected_category.value)
with g.batch_update():
g.data[0].y = means
g.data[0].error_y["array"] = err_hi
g.data[0].error_y["arrayminus"] = err_lo
g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi)
# def response(change):
# means, err_lo, err_hi = get_plot_values(selected_category.value)
# with g.batch_update():
# if kwargs.get("checkpoint_mode"):
# g.data[0].y = means + err_hi
# g.data[1].y = means - err_lo
# g.data[2].y = means
# else:
# g.data[0].y = means
# g.data[0].error_y["array"] = err_hi
# g.data[0].error_y["arrayminus"] = err_lo
# g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi)

selected_category.observe(response, names="value")
# selected_category.observe(response, names="value")

return ipywidgets.VBox([selected_category, g])
# return ipywidgets.VBox([selected_category, g])
return g

0 comments on commit 20c6fce

Please sign in to comment.