From 90410c86db375c1fa3057ad10f098e1d219432a3 Mon Sep 17 00:00:00 2001 From: Siwei Li <46750682+siwei-li@users.noreply.github.com> Date: Wed, 3 Apr 2024 11:42:10 -0700 Subject: [PATCH] Add "checkpoint_mode" kwarg to plotting (#99) * Add checkpoint_mode kwarg to plotting * Remove dummy key of token_category * Revert "Remove dummy key of token_category" This reverts commit 2edb2c5ae9406e47e6a66b9739e555ce81878a98. * Update vis notebook * Remove kwargs and have fixed args in place --------- Co-authored-by: Siwei Li --- notebooks/per_token_plot.ipynb | 89 +++++++++++++++++++--- src/delphi/eval/vis_per_token_model.py | 100 ++++++++++++++++--------- 2 files changed, 143 insertions(+), 46 deletions(-) diff --git a/notebooks/per_token_plot.ipynb b/notebooks/per_token_plot.ipynb index 198057c7..12e09926 100644 --- a/notebooks/per_token_plot.ipynb +++ b/notebooks/per_token_plot.ipynb @@ -2,21 +2,53 @@ "cells": [ { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "696575431f65420e9dc22c3b3476bfbb", + "model_id": "fbda6a916fe84814be64a40423196d76", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…" + "FigureWidget({\n", + " 'data': [{'line': {'width': 0},\n", + " 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n", + " 'mode': 'lines',\n", + " 'name': 'Upper Bound',\n", + " 'showlegend': False,\n", + " 'type': 'scatter',\n", + " 'uid': 'a3590fcd-466d-4a73-b167-194ab728efcd',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n", + " 2.43304471])},\n", + " {'fill': 'tonexty',\n", + " 'fillcolor': 'rgba(68, 68, 68, 0.3)',\n", + " 'line': {'width': 0},\n", + " 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n", + " 'mode': 'lines',\n", + " 'name': 'Lower Bound',\n", + " 'showlegend': False,\n", + " 'type': 'scatter',\n", + " 'uid': 'fda82808-c8ff-4b6c-878d-c76d66c8ce17',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n", + " 0.81709211])},\n", + " {'marker': {'color': 'rgb(31, 119, 180)', 'line': {'color': 'rgb(31, 119, 180)', 'width': 1}, 'size': 0},\n", + " 'mode': 'lines',\n", + " 'name': 'Means',\n", + " 'type': 'scatter',\n", + " 'uid': 'b11dfbd0-c130-4a97-a8ff-b8c753b95035',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n", + " 1.45715223])}],\n", + " 'layout': {'template': '...'}\n", + "})" ] }, - "execution_count": 5, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } @@ -33,9 +65,10 @@ "random.seed(0)\n", "\n", "# generate mock data\n", - "model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m']\n", + "# model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m', \"0\"]\n", + "model_names = list(range(500))\n", "categories = ['nouns', 'verbs', 'prepositions', 'adjectives']\n", - "entries = [200, 100, 150, 300]\n", + "entries = [200, 100, 150, 300, 100]*100\n", "performance_data = defaultdict()\n", "for i, model in enumerate(model_names):\n", " performance_data[model] = defaultdict()\n", @@ -47,32 +80,64 @@ " performance_data[model][cat] = (-means, err_low, err_hi)\n", "\n", "\n", - "visualize_per_token_category(performance_data, log_scale=True, bg_color='LightGrey', line_color=\"Red\", marker_color='Orange', bar_color='Green')" + "visualize_per_token_category(performance_data, log_scale=True, checkpoint_mode=True)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "cb3af5248a4a40118c36a527c927289d", + "model_id": "993e5d66ae56462a8eeec2c9ac6bd972", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…" + "FigureWidget({\n", + " 'data': [{'line': {'width': 0},\n", + " 'marker': {'color': 'wheat'},\n", + " 'mode': 'lines',\n", + " 'name': 'Upper Bound',\n", + " 'showlegend': False,\n", + " 'type': 'scatter',\n", + " 'uid': '56999008-205c-4592-a3f7-ea61e3e09d8e',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n", + " 2.43304471])},\n", + " {'fill': 'tonexty',\n", + " 'fillcolor': 'wheat',\n", + " 'line': {'width': 0},\n", + " 'marker': {'color': 'wheat'},\n", + " 'mode': 'lines',\n", + " 'name': 'Lower Bound',\n", + " 'showlegend': False,\n", + " 'type': 'scatter',\n", + " 'uid': 'be8a04f1-b8c4-46af-bf5e-03c942eff19f',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n", + " 0.81709211])},\n", + " {'marker': {'color': 'Orange', 'line': {'color': 'Orange', 'width': 1}, 'size': 0},\n", + " 'mode': 'lines',\n", + " 'name': 'Median',\n", + " 'type': 'scatter',\n", + " 'uid': '85fe5113-70fb-4aa7-9821-947287d84e1d',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n", + " 1.45715223])}],\n", + " 'layout': {'template': '...'}\n", + "})" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "visualize_per_token_category(performance_data, log_scale=False)" + "visualize_per_token_category(performance_data, log_scale=True, checkpoint_mode=True, line_metric=\"Median\", line_color='Orange' , shade_color=\"wheat\")" ] } ], diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py index 618840b0..8daaa96f 100644 --- a/src/delphi/eval/vis_per_token_model.py +++ b/src/delphi/eval/vis_per_token_model.py @@ -1,67 +1,99 @@ +from typing import Union + import ipywidgets import numpy as np import plotly.graph_objects as go def visualize_per_token_category( - input: dict[str, dict[str, tuple]], log_scale=False, **kwargs: str -) -> ipywidgets.VBox: - model_names = list(input.keys()) - categories = list(input[model_names[0]].keys()) + input: dict[Union[str, int], dict[str, tuple]], + log_scale=False, + line_metric="Means", + checkpoint_mode=True, + shade_color="rgba(68, 68, 68, 0.3)", + line_color="rgb(31, 119, 180)", + bar_color="purple", + marker_color="SkyBlue", + background_color="AliceBlue", +) -> go.FigureWidget: + input_x = list(input.keys()) + categories = list(input[input_x[0]].keys()) category = categories[0] def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]: return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)] def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - x = np.array([input[name][category] for name in model_names]).T + x = np.array([input[x][category] for x in input_x]).T means, err_lo, err_hi = x[0], x[1], x[2] return means, err_lo, err_hi - means, err_low, err_hi = get_plot_values(category) - g = go.FigureWidget( - data=go.Scatter( - x=model_names, + means, err_lo, err_hi = get_plot_values(category) + + if checkpoint_mode: + scatter_plot = go.Figure( + [ + go.Scatter( + name="Upper Bound", + x=input_x, + y=means + err_hi, + mode="lines", + marker=dict(color=shade_color), + line=dict(width=0), + showlegend=False, + ), + go.Scatter( + name="Lower Bound", + x=input_x, + y=means - err_lo, + marker=dict(color=shade_color), + line=dict(width=0), + mode="lines", + fillcolor=shade_color, + fill="tonexty", + showlegend=False, + ), + go.Scatter( + name=line_metric, + x=input_x, + y=means, + mode="lines", + marker=dict( + color=line_color, + size=0, + line=dict(color=line_color, width=1), + ), + ), + ] + ) + else: + scatter_plot = go.Scatter( + x=input_x, y=means, error_y=dict( type="data", symmetric=False, array=err_hi, - arrayminus=err_low, - color=kwargs.get("bar_color", "purple"), + arrayminus=err_lo, + color=bar_color, ), marker=dict( - color=kwargs.get("marker_color", "SkyBlue"), + color=marker_color, size=15, - line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2), + line=dict(color=line_color, width=2), ), - hovertext=get_hovertexts(means, err_low, err_hi), + hovertext=get_hovertexts(means, err_lo, err_hi), hoverinfo="text+x", - ), + ) + g = go.FigureWidget( + data=scatter_plot, layout=go.Layout( yaxis=dict( title="Loss", type="log" if log_scale else "linear", ), - plot_bgcolor=kwargs.get("bg_color", "AliceBlue"), + plot_bgcolor=background_color, ), ) - selected_category = ipywidgets.Dropdown( - options=categories, - placeholder="", - description="Token Category:", - disabled=False, - ) - - def response(change): - means, err_lo, err_hi = get_plot_values(selected_category.value) - with g.batch_update(): - g.data[0].y = means - g.data[0].error_y["array"] = err_hi - g.data[0].error_y["arrayminus"] = err_lo - g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi) - - selected_category.observe(response, names="value") - - return ipywidgets.VBox([selected_category, g]) + return g