From 35da0e2b800a4c7388b45cb7a9eb9eeefb605c06 Mon Sep 17 00:00:00 2001 From: Siwei Li Date: Sun, 31 Mar 2024 11:34:13 -0700 Subject: [PATCH 1/5] Add checkpoint_mode kwarg to plotting --- notebooks/per_token_plot.ipynb | 89 ++++++++++++++++++--- src/delphi/eval/vis_per_token_model.py | 106 ++++++++++++++++++------- 2 files changed, 156 insertions(+), 39 deletions(-) diff --git a/notebooks/per_token_plot.ipynb b/notebooks/per_token_plot.ipynb index 198057c7..12305b64 100644 --- a/notebooks/per_token_plot.ipynb +++ b/notebooks/per_token_plot.ipynb @@ -2,21 +2,53 @@ "cells": [ { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "696575431f65420e9dc22c3b3476bfbb", + "model_id": "deb5a98615624e32b91fb3fc4d155c7a", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…" + "FigureWidget({\n", + " 'data': [{'line': {'width': 0},\n", + " 'marker': {'color': '#444'},\n", + " 'mode': 'lines',\n", + " 'name': 'Upper Bound',\n", + " 'showlegend': False,\n", + " 'type': 'scatter',\n", + " 'uid': '4dcf6b1c-9b26-425e-be14-ba73fde289fb',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n", + " 2.43304471])},\n", + " {'fill': 'tonexty',\n", + " 'fillcolor': 'rgba(68, 68, 68, 0.3)',\n", + " 'line': {'width': 0},\n", + " 'marker': {'color': '#444'},\n", + " 'mode': 'lines',\n", + " 'name': 'Lower Bound',\n", + " 'showlegend': False,\n", + " 'type': 'scatter',\n", + " 'uid': 'c0a55b83-d045-4faa-9285-a927058cad75',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n", + " 0.81709211])},\n", + " {'marker': {'color': 'rgb(31, 119, 180)', 'line': {'color': 'rgb(31, 119, 180)', 'width': 1}, 'size': 0},\n", + " 'mode': 'lines',\n", + " 'name': 'Means',\n", + " 'type': 'scatter',\n", + " 'uid': 'c4c0a68b-efa3-4930-aa0c-3c17451e3d2e',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n", + " 1.45715223])}],\n", + " 'layout': {'template': '...'}\n", + "})" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -33,9 +65,10 @@ "random.seed(0)\n", "\n", "# generate mock data\n", - "model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m']\n", + "# model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m', \"0\"]\n", + "model_names = list(range(500))\n", "categories = ['nouns', 'verbs', 'prepositions', 'adjectives']\n", - "entries = [200, 100, 150, 300]\n", + "entries = [200, 100, 150, 300, 100]*100\n", "performance_data = defaultdict()\n", "for i, model in enumerate(model_names):\n", " performance_data[model] = defaultdict()\n", @@ -47,32 +80,64 @@ " performance_data[model][cat] = (-means, err_low, err_hi)\n", "\n", "\n", - "visualize_per_token_category(performance_data, log_scale=True, bg_color='LightGrey', line_color=\"Red\", marker_color='Orange', bar_color='Green')" + "visualize_per_token_category(performance_data, log_scale=True, checkpoint_mode=True)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "cb3af5248a4a40118c36a527c927289d", + "model_id": "4550d6e8c4f74396b180fd1223c4c3b2", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…" + "FigureWidget({\n", + " 'data': [{'line': {'width': 0},\n", + " 'marker': {'color': 'wheat'},\n", + " 'mode': 'lines',\n", + " 'name': 'Upper Bound',\n", + " 'showlegend': False,\n", + " 'type': 'scatter',\n", + " 'uid': '453a6e83-3b4f-4090-b8ea-6c8945dba824',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n", + " 2.43304471])},\n", + " {'fill': 'tonexty',\n", + " 'fillcolor': 'wheat',\n", + " 'line': {'width': 0},\n", + " 'marker': {'color': 'wheat'},\n", + " 'mode': 'lines',\n", + " 'name': 'Lower Bound',\n", + " 'showlegend': False,\n", + " 'type': 'scatter',\n", + " 'uid': '4128f987-e0a9-457c-ab4a-c0547589c988',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n", + " 0.81709211])},\n", + " {'marker': {'color': 'Orange', 'line': {'color': 'Orange', 'width': 1}, 'size': 0},\n", + " 'mode': 'lines',\n", + " 'name': 'Median',\n", + " 'type': 'scatter',\n", + " 'uid': '06de0035-f3a6-4b79-aff2-5ae62130a5db',\n", + " 'x': [0, 1, 2, ..., 497, 498, 499],\n", + " 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n", + " 1.45715223])}],\n", + " 'layout': {'template': '...'}\n", + "})" ] }, - "execution_count": 3, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "visualize_per_token_category(performance_data, log_scale=False)" + "visualize_per_token_category(performance_data, log_scale=True, checkpoint_mode=True, line_metric=\"Median\", line_color='Orange' , shade_color=\"wheat\")" ] } ], diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py index 618840b0..d1a1acde 100644 --- a/src/delphi/eval/vis_per_token_model.py +++ b/src/delphi/eval/vis_per_token_model.py @@ -1,33 +1,77 @@ +from typing import Union + import ipywidgets import numpy as np import plotly.graph_objects as go def visualize_per_token_category( - input: dict[str, dict[str, tuple]], log_scale=False, **kwargs: str -) -> ipywidgets.VBox: - model_names = list(input.keys()) - categories = list(input[model_names[0]].keys()) + input: dict[Union[str, int], dict[str, tuple]], + log_scale=False, + **kwargs: Union[str, bool], + # ) -> ipywidgets.VBox: +) -> go.FigureWidget: + input_x = list(input.keys()) + categories = list(input[input_x[0]].keys()) category = categories[0] def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]: return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)] def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - x = np.array([input[name][category] for name in model_names]).T + x = np.array([input[x][category] for x in input_x]).T means, err_lo, err_hi = x[0], x[1], x[2] return means, err_lo, err_hi - means, err_low, err_hi = get_plot_values(category) - g = go.FigureWidget( - data=go.Scatter( - x=model_names, + means, err_lo, err_hi = get_plot_values(category) + + if kwargs.get("checkpoint_mode"): + scatter_plot = go.Figure( + [ + go.Scatter( + name="Upper Bound", + x=input_x, + y=means + err_hi, + mode="lines", + marker=dict(color=kwargs.get("shade_color", "#444")), + line=dict(width=0), + showlegend=False, + ), + go.Scatter( + name="Lower Bound", + x=input_x, + y=means - err_lo, + marker=dict(color=kwargs.get("shade_color", "#444")), + line=dict(width=0), + mode="lines", + fillcolor=kwargs.get("shade_color", "rgba(68, 68, 68, 0.3)"), + fill="tonexty", + showlegend=False, + ), + go.Scatter( + name=kwargs.get("line_metric", "Means"), + x=input_x, + y=means, + mode="lines", + marker=dict( + color=kwargs.get("line_color", "rgb(31, 119, 180)"), + size=0, + line=dict( + color=kwargs.get("line_color", "rgb(31, 119, 180)"), width=1 + ), + ), + ), + ] + ) + else: + scatter_plot = go.Scatter( + x=input_x, y=means, error_y=dict( type="data", symmetric=False, array=err_hi, - arrayminus=err_low, + arrayminus=err_lo, color=kwargs.get("bar_color", "purple"), ), marker=dict( @@ -35,9 +79,11 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: size=15, line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2), ), - hovertext=get_hovertexts(means, err_low, err_hi), + hovertext=get_hovertexts(means, err_lo, err_hi), hoverinfo="text+x", - ), + ) + g = go.FigureWidget( + data=scatter_plot, layout=go.Layout( yaxis=dict( title="Loss", @@ -47,21 +93,27 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ), ) - selected_category = ipywidgets.Dropdown( - options=categories, - placeholder="", - description="Token Category:", - disabled=False, - ) + # selected_category = ipywidgets.Dropdown( + # options=categories, + # placeholder="", + # description="Token Category:", + # disabled=False, + # ) - def response(change): - means, err_lo, err_hi = get_plot_values(selected_category.value) - with g.batch_update(): - g.data[0].y = means - g.data[0].error_y["array"] = err_hi - g.data[0].error_y["arrayminus"] = err_lo - g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi) + # def response(change): + # means, err_lo, err_hi = get_plot_values(selected_category.value) + # with g.batch_update(): + # if kwargs.get("checkpoint_mode"): + # g.data[0].y = means + # g.data[1].y = means + err_hi + # g.data[2].y = means - err_lo + # else: + # g.data[0].y = means + # g.data[0].error_y["array"] = err_hi + # g.data[0].error_y["arrayminus"] = err_lo + # g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi) - selected_category.observe(response, names="value") + # selected_category.observe(response, names="value") - return ipywidgets.VBox([selected_category, g]) + # return ipywidgets.VBox([selected_category, g]) + return g From 307c51d90fe2f63c13fdad0e037504d2b50f4e19 Mon Sep 17 00:00:00 2001 From: Siwei Li Date: Mon, 1 Apr 2024 09:03:06 -0700 Subject: [PATCH 2/5] Remove dummy key of token_category --- src/delphi/eval/vis_per_token_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py index d1a1acde..923c4224 100644 --- a/src/delphi/eval/vis_per_token_model.py +++ b/src/delphi/eval/vis_per_token_model.py @@ -6,24 +6,24 @@ def visualize_per_token_category( - input: dict[Union[str, int], dict[str, tuple]], + input: dict[Union[str, int], tuple], + # input: dict[Union[str, int], dict[str, tuple]], log_scale=False, **kwargs: Union[str, bool], - # ) -> ipywidgets.VBox: ) -> go.FigureWidget: input_x = list(input.keys()) - categories = list(input[input_x[0]].keys()) - category = categories[0] + # categories = list(input[input_x[0]].keys()) + # category = categories[0] def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]: return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)] - def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - x = np.array([input[x][category] for x in input_x]).T + def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]: + x = np.array([input[x] for x in input_x]).T means, err_lo, err_hi = x[0], x[1], x[2] return means, err_lo, err_hi - means, err_lo, err_hi = get_plot_values(category) + means, err_lo, err_hi = get_plot_values() if kwargs.get("checkpoint_mode"): scatter_plot = go.Figure( From 28559a5b4f895dc0d82b02427ee30f2ad95f7e16 Mon Sep 17 00:00:00 2001 From: Siwei Li Date: Mon, 1 Apr 2024 10:51:17 -0700 Subject: [PATCH 3/5] Revert "Remove dummy key of token_category" This reverts commit 2edb2c5ae9406e47e6a66b9739e555ce81878a98. --- src/delphi/eval/vis_per_token_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py index 923c4224..d1a1acde 100644 --- a/src/delphi/eval/vis_per_token_model.py +++ b/src/delphi/eval/vis_per_token_model.py @@ -6,24 +6,24 @@ def visualize_per_token_category( - input: dict[Union[str, int], tuple], - # input: dict[Union[str, int], dict[str, tuple]], + input: dict[Union[str, int], dict[str, tuple]], log_scale=False, **kwargs: Union[str, bool], + # ) -> ipywidgets.VBox: ) -> go.FigureWidget: input_x = list(input.keys()) - # categories = list(input[input_x[0]].keys()) - # category = categories[0] + categories = list(input[input_x[0]].keys()) + category = categories[0] def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]: return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)] - def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]: - x = np.array([input[x] for x in input_x]).T + def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + x = np.array([input[x][category] for x in input_x]).T means, err_lo, err_hi = x[0], x[1], x[2] return means, err_lo, err_hi - means, err_lo, err_hi = get_plot_values() + means, err_lo, err_hi = get_plot_values(category) if kwargs.get("checkpoint_mode"): scatter_plot = go.Figure( From 9ec2ec5e695c7cbeb261f05286aec1d7789a0349 Mon Sep 17 00:00:00 2001 From: Siwei Li Date: Mon, 1 Apr 2024 10:53:19 -0700 Subject: [PATCH 4/5] Update vis notebook --- notebooks/per_token_plot.ipynb | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/notebooks/per_token_plot.ipynb b/notebooks/per_token_plot.ipynb index 12305b64..33f7c354 100644 --- a/notebooks/per_token_plot.ipynb +++ b/notebooks/per_token_plot.ipynb @@ -2,13 +2,13 @@ "cells": [ { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "deb5a98615624e32b91fb3fc4d155c7a", + "model_id": "9e971e03344f4b608cfc2b588a477238", "version_major": 2, "version_minor": 0 }, @@ -20,7 +20,7 @@ " 'name': 'Upper Bound',\n", " 'showlegend': False,\n", " 'type': 'scatter',\n", - " 'uid': '4dcf6b1c-9b26-425e-be14-ba73fde289fb',\n", + " 'uid': '4a40445f-0502-42f2-878b-9a9f03d66717',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n", " 2.43304471])},\n", @@ -32,7 +32,7 @@ " 'name': 'Lower Bound',\n", " 'showlegend': False,\n", " 'type': 'scatter',\n", - " 'uid': 'c0a55b83-d045-4faa-9285-a927058cad75',\n", + " 'uid': 'a08f2c43-1513-4042-bd1e-f1f01e28a0ef',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n", " 0.81709211])},\n", @@ -40,7 +40,7 @@ " 'mode': 'lines',\n", " 'name': 'Means',\n", " 'type': 'scatter',\n", - " 'uid': 'c4c0a68b-efa3-4930-aa0c-3c17451e3d2e',\n", + " 'uid': 'da46ea5a-dfe8-4613-80fa-33805fabd2fb',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n", " 1.45715223])}],\n", @@ -48,7 +48,7 @@ "})" ] }, - "execution_count": 4, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } @@ -85,13 +85,13 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4550d6e8c4f74396b180fd1223c4c3b2", + "model_id": "d951d67fc372475bab98e84db51c0cbc", "version_major": 2, "version_minor": 0 }, @@ -103,7 +103,7 @@ " 'name': 'Upper Bound',\n", " 'showlegend': False,\n", " 'type': 'scatter',\n", - " 'uid': '453a6e83-3b4f-4090-b8ea-6c8945dba824',\n", + " 'uid': '274f7f1b-21af-41bb-8c00-6fa385439bff',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n", " 2.43304471])},\n", @@ -115,7 +115,7 @@ " 'name': 'Lower Bound',\n", " 'showlegend': False,\n", " 'type': 'scatter',\n", - " 'uid': '4128f987-e0a9-457c-ab4a-c0547589c988',\n", + " 'uid': '050ad540-2443-452c-8f2c-e4d218640318',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n", " 0.81709211])},\n", @@ -123,7 +123,7 @@ " 'mode': 'lines',\n", " 'name': 'Median',\n", " 'type': 'scatter',\n", - " 'uid': '06de0035-f3a6-4b79-aff2-5ae62130a5db',\n", + " 'uid': '27c3cec6-4e36-4e3b-9083-7b9bec371542',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n", " 1.45715223])}],\n", @@ -131,7 +131,7 @@ "})" ] }, - "execution_count": 5, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } From 2988ddb7499e835e177116c3f6e72ece5692a171 Mon Sep 17 00:00:00 2001 From: Siwei Li Date: Mon, 1 Apr 2024 11:21:46 -0700 Subject: [PATCH 5/5] Remove kwargs and have fixed args in place --- notebooks/per_token_plot.ipynb | 20 ++++----- src/delphi/eval/vis_per_token_model.py | 56 +++++++++----------------- 2 files changed, 28 insertions(+), 48 deletions(-) diff --git a/notebooks/per_token_plot.ipynb b/notebooks/per_token_plot.ipynb index 33f7c354..12e09926 100644 --- a/notebooks/per_token_plot.ipynb +++ b/notebooks/per_token_plot.ipynb @@ -8,31 +8,31 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9e971e03344f4b608cfc2b588a477238", + "model_id": "fbda6a916fe84814be64a40423196d76", "version_major": 2, "version_minor": 0 }, "text/plain": [ "FigureWidget({\n", " 'data': [{'line': {'width': 0},\n", - " 'marker': {'color': '#444'},\n", + " 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n", " 'mode': 'lines',\n", " 'name': 'Upper Bound',\n", " 'showlegend': False,\n", " 'type': 'scatter',\n", - " 'uid': '4a40445f-0502-42f2-878b-9a9f03d66717',\n", + " 'uid': 'a3590fcd-466d-4a73-b167-194ab728efcd',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n", " 2.43304471])},\n", " {'fill': 'tonexty',\n", " 'fillcolor': 'rgba(68, 68, 68, 0.3)',\n", " 'line': {'width': 0},\n", - " 'marker': {'color': '#444'},\n", + " 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n", " 'mode': 'lines',\n", " 'name': 'Lower Bound',\n", " 'showlegend': False,\n", " 'type': 'scatter',\n", - " 'uid': 'a08f2c43-1513-4042-bd1e-f1f01e28a0ef',\n", + " 'uid': 'fda82808-c8ff-4b6c-878d-c76d66c8ce17',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n", " 0.81709211])},\n", @@ -40,7 +40,7 @@ " 'mode': 'lines',\n", " 'name': 'Means',\n", " 'type': 'scatter',\n", - " 'uid': 'da46ea5a-dfe8-4613-80fa-33805fabd2fb',\n", + " 'uid': 'b11dfbd0-c130-4a97-a8ff-b8c753b95035',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n", " 1.45715223])}],\n", @@ -91,7 +91,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d951d67fc372475bab98e84db51c0cbc", + "model_id": "993e5d66ae56462a8eeec2c9ac6bd972", "version_major": 2, "version_minor": 0 }, @@ -103,7 +103,7 @@ " 'name': 'Upper Bound',\n", " 'showlegend': False,\n", " 'type': 'scatter',\n", - " 'uid': '274f7f1b-21af-41bb-8c00-6fa385439bff',\n", + " 'uid': '56999008-205c-4592-a3f7-ea61e3e09d8e',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,\n", " 2.43304471])},\n", @@ -115,7 +115,7 @@ " 'name': 'Lower Bound',\n", " 'showlegend': False,\n", " 'type': 'scatter',\n", - " 'uid': '050ad540-2443-452c-8f2c-e4d218640318',\n", + " 'uid': 'be8a04f1-b8c4-46af-bf5e-03c942eff19f',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,\n", " 0.81709211])},\n", @@ -123,7 +123,7 @@ " 'mode': 'lines',\n", " 'name': 'Median',\n", " 'type': 'scatter',\n", - " 'uid': '27c3cec6-4e36-4e3b-9083-7b9bec371542',\n", + " 'uid': '85fe5113-70fb-4aa7-9821-947287d84e1d',\n", " 'x': [0, 1, 2, ..., 497, 498, 499],\n", " 'y': array([1.3701917 , 1.4372206 , 1.53251235, ..., 1.55583357, 1.50179179,\n", " 1.45715223])}],\n", diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py index d1a1acde..8daaa96f 100644 --- a/src/delphi/eval/vis_per_token_model.py +++ b/src/delphi/eval/vis_per_token_model.py @@ -8,8 +8,13 @@ def visualize_per_token_category( input: dict[Union[str, int], dict[str, tuple]], log_scale=False, - **kwargs: Union[str, bool], - # ) -> ipywidgets.VBox: + line_metric="Means", + checkpoint_mode=True, + shade_color="rgba(68, 68, 68, 0.3)", + line_color="rgb(31, 119, 180)", + bar_color="purple", + marker_color="SkyBlue", + background_color="AliceBlue", ) -> go.FigureWidget: input_x = list(input.keys()) categories = list(input[input_x[0]].keys()) @@ -25,7 +30,7 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: means, err_lo, err_hi = get_plot_values(category) - if kwargs.get("checkpoint_mode"): + if checkpoint_mode: scatter_plot = go.Figure( [ go.Scatter( @@ -33,7 +38,7 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: x=input_x, y=means + err_hi, mode="lines", - marker=dict(color=kwargs.get("shade_color", "#444")), + marker=dict(color=shade_color), line=dict(width=0), showlegend=False, ), @@ -41,24 +46,22 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: name="Lower Bound", x=input_x, y=means - err_lo, - marker=dict(color=kwargs.get("shade_color", "#444")), + marker=dict(color=shade_color), line=dict(width=0), mode="lines", - fillcolor=kwargs.get("shade_color", "rgba(68, 68, 68, 0.3)"), + fillcolor=shade_color, fill="tonexty", showlegend=False, ), go.Scatter( - name=kwargs.get("line_metric", "Means"), + name=line_metric, x=input_x, y=means, mode="lines", marker=dict( - color=kwargs.get("line_color", "rgb(31, 119, 180)"), + color=line_color, size=0, - line=dict( - color=kwargs.get("line_color", "rgb(31, 119, 180)"), width=1 - ), + line=dict(color=line_color, width=1), ), ), ] @@ -72,12 +75,12 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: symmetric=False, array=err_hi, arrayminus=err_lo, - color=kwargs.get("bar_color", "purple"), + color=bar_color, ), marker=dict( - color=kwargs.get("marker_color", "SkyBlue"), + color=marker_color, size=15, - line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2), + line=dict(color=line_color, width=2), ), hovertext=get_hovertexts(means, err_lo, err_hi), hoverinfo="text+x", @@ -89,31 +92,8 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: title="Loss", type="log" if log_scale else "linear", ), - plot_bgcolor=kwargs.get("bg_color", "AliceBlue"), + plot_bgcolor=background_color, ), ) - # selected_category = ipywidgets.Dropdown( - # options=categories, - # placeholder="", - # description="Token Category:", - # disabled=False, - # ) - - # def response(change): - # means, err_lo, err_hi = get_plot_values(selected_category.value) - # with g.batch_update(): - # if kwargs.get("checkpoint_mode"): - # g.data[0].y = means - # g.data[1].y = means + err_hi - # g.data[2].y = means - err_lo - # else: - # g.data[0].y = means - # g.data[0].error_y["array"] = err_hi - # g.data[0].error_y["arrayminus"] = err_lo - # g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi) - - # selected_category.observe(response, names="value") - - # return ipywidgets.VBox([selected_category, g]) return g