diff --git a/notebooks/per_token_plot.ipynb b/notebooks/per_token_plot.ipynb new file mode 100644 index 00000000..198057c7 --- /dev/null +++ b/notebooks/per_token_plot.ipynb @@ -0,0 +1,100 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "696575431f65420e9dc22c3b3476bfbb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from collections import defaultdict\n", + "import math\n", + "import random\n", + "import numpy as np\n", + "\n", + "from delphi.eval.vis_per_token_model import visualize_per_token_category\n", + "\n", + "\n", + "random.seed(0)\n", + "\n", + "# generate mock data\n", + "model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m']\n", + "categories = ['nouns', 'verbs', 'prepositions', 'adjectives']\n", + "entries = [200, 100, 150, 300]\n", + "performance_data = defaultdict()\n", + "for i, model in enumerate(model_names):\n", + " performance_data[model] = defaultdict()\n", + " for cat in categories:\n", + " x = [math.log2(random.random()) for _ in range(entries[i])]\n", + " means = np.mean(x)\n", + " err_low = means - np.percentile(x, 25)\n", + " err_hi = np.percentile(x, 75) - means\n", + " performance_data[model][cat] = (-means, err_low, err_hi)\n", + "\n", + "\n", + "visualize_per_token_category(performance_data, log_scale=True, bg_color='LightGrey', line_color=\"Red\", marker_color='Orange', bar_color='Green')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cb3af5248a4a40118c36a527c927289d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "visualize_per_token_category(performance_data, log_scale=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py new file mode 100644 index 00000000..a8e269fe --- /dev/null +++ b/src/delphi/eval/vis_per_token_model.py @@ -0,0 +1,61 @@ +import ipywidgets +import numpy as np +import plotly.graph_objects as go + + +def visualize_per_token_category( + input: dict[str, dict[str, tuple]], log_scale=False, **kwargs: str +) -> ipywidgets.VBox: + model_names = list(input.keys()) + categories = list(input[model_names[0]].keys()) + category = categories[0] + + def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + x = np.array([input[name][category] for name in model_names]).T + means, err_lo, err_hi = x[0], x[1], x[2] + return means, err_lo, err_hi + + means, err_low, err_hi = get_plot_values(category) + g = go.FigureWidget( + data=go.Scatter( + x=model_names, + y=means, + error_y=dict( + type="data", + symmetric=False, + array=err_hi, + arrayminus=err_low, + color=kwargs.get("bar_color", "purple"), + ), + marker=dict( + color=kwargs.get("marker_color", "SkyBlue"), + size=15, + line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2), + ), + ), + layout=go.Layout( + yaxis=dict( + title="Loss", + type="log" if log_scale else "linear", + ), + plot_bgcolor=kwargs.get("bg_color", "AliceBlue"), + ), + ) + + selected_category = ipywidgets.Dropdown( + options=categories, + placeholder="", + description="Token Category:", + disabled=False, + ) + + def response(change): + means, err_lo, err_hi = get_plot_values(selected_category.value) + with g.batch_update(): + g.data[0].y = means + g.data[0].error_y["array"] = err_hi + g.data[0].error_y["arrayminus"] = err_lo + + selected_category.observe(response, names="value") + + return ipywidgets.VBox([selected_category, g])