diff --git a/notebooks/eval_notebook.ipynb b/notebooks/eval_notebook.ipynb new file mode 100644 index 00000000..864a1f1e --- /dev/null +++ b/notebooks/eval_notebook.ipynb @@ -0,0 +1,480 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# colab cells (only run if on colab)\n", + "# TODO: experiment on colab to see how to set up the environment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Important\n", + "\n", + "Run this cell by cell. The token selecter cell needs to be ran first so the later cells work." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/javascript": "(function(root) {\n function now() {\n return new Date();\n }\n\n var force = true;\n var py_version = '3.4.1'.replace('rc', '-rc.').replace('.dev', '-dev.');\n var reloading = false;\n var Bokeh = root.Bokeh;\n\n if (typeof (root._bokeh_timeout) === \"undefined\" || force) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks;\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, js_modules, js_exports, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = [];\n if (js_modules == null) js_modules = [];\n if (js_exports == null) js_exports = {};\n\n root._bokeh_onload_callbacks.push(callback);\n\n if (root._bokeh_is_loading > 0) {\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n }\n if (js_urls.length === 0 && js_modules.length === 0 && Object.keys(js_exports).length === 0) {\n run_callbacks();\n return null;\n }\n if (!reloading) {\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n }\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n window._bokeh_on_load = on_load\n\n function on_error() {\n console.error(\"failed to load \" + url);\n }\n\n var skip = [];\n if (window.requirejs) {\n window.requirejs.config({'packages': {}, 'paths': {}, 'shim': {}});\n root._bokeh_is_loading = css_urls.length + 0;\n } else {\n root._bokeh_is_loading = css_urls.length + js_urls.length + js_modules.length + Object.keys(js_exports).length;\n }\n\n var existing_stylesheets = []\n var links = document.getElementsByTagName('link')\n for (var i = 0; i < links.length; i++) {\n var link = links[i]\n if (link.href != null) {\n\texisting_stylesheets.push(link.href)\n }\n }\n for (var i = 0; i < css_urls.length; i++) {\n var url = css_urls[i];\n if (existing_stylesheets.indexOf(url) !== -1) {\n\ton_load()\n\tcontinue;\n }\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = on_error;\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n } var existing_scripts = []\n var scripts = document.getElementsByTagName('script')\n for (var i = 0; i < scripts.length; i++) {\n var script = scripts[i]\n if (script.src != null) {\n\texisting_scripts.push(script.src)\n }\n }\n for (var i = 0; i < js_urls.length; i++) {\n var url = js_urls[i];\n if (skip.indexOf(url) !== -1 || existing_scripts.indexOf(url) !== -1) {\n\tif (!window.requirejs) {\n\t on_load();\n\t}\n\tcontinue;\n }\n var element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error;\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n for (var i = 0; i < js_modules.length; i++) {\n var url = js_modules[i];\n if (skip.indexOf(url) !== -1 || existing_scripts.indexOf(url) !== -1) {\n\tif (!window.requirejs) {\n\t on_load();\n\t}\n\tcontinue;\n }\n var element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error;\n element.async = false;\n element.src = url;\n element.type = \"module\";\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n for (const name in js_exports) {\n var url = js_exports[name];\n if (skip.indexOf(url) >= 0 || root[name] != null) {\n\tif (!window.requirejs) {\n\t on_load();\n\t}\n\tcontinue;\n }\n var element = document.createElement('script');\n element.onerror = on_error;\n element.async = false;\n element.type = \"module\";\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n element.textContent = `\n import ${name} from \"${url}\"\n window.${name} = ${name}\n window._bokeh_on_load()\n `\n document.head.appendChild(element);\n }\n if (!js_urls.length && !js_modules.length) {\n on_load()\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n var js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-3.4.1.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-3.4.1.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-3.4.1.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-3.4.1.min.js\", \"https://cdn.holoviz.org/panel/1.4.0/dist/panel.min.js\"];\n var js_modules = [];\n var js_exports = {};\n var css_urls = [\"https://cdn.holoviz.org/panel/1.4.0/dist/bundled/font-awesome/css/all.min.css\"];\n var inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {} // ensure no trailing comma for IE\n ];\n\n function run_inline_js() {\n if ((root.Bokeh !== undefined) || (force === true)) {\n for (var i = 0; i < inline_js.length; i++) {\n\ttry {\n inline_js[i].call(root, root.Bokeh);\n\t} catch(e) {\n\t if (!reloading) {\n\t throw e;\n\t }\n\t}\n }\n // Cache old bokeh versions\n if (Bokeh != undefined && !reloading) {\n\tvar NewBokeh = root.Bokeh;\n\tif (Bokeh.versions === undefined) {\n\t Bokeh.versions = new Map();\n\t}\n\tif (NewBokeh.version !== Bokeh.version) {\n\t Bokeh.versions.set(NewBokeh.version, NewBokeh)\n\t}\n\troot.Bokeh = Bokeh;\n }} else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n }\n root._bokeh_is_initializing = false\n }\n\n function load_or_wait() {\n // Implement a backoff loop that tries to ensure we do not load multiple\n // versions of Bokeh and its dependencies at the same time.\n // In recent versions we use the root._bokeh_is_initializing flag\n // to determine whether there is an ongoing attempt to initialize\n // bokeh, however for backward compatibility we also try to ensure\n // that we do not start loading a newer (Panel>=1.0 and Bokeh>3) version\n // before older versions are fully initialized.\n if (root._bokeh_is_initializing && Date.now() > root._bokeh_timeout) {\n root._bokeh_is_initializing = false;\n root._bokeh_onload_callbacks = undefined;\n console.log(\"Bokeh: BokehJS was loaded multiple times but one version failed to initialize.\");\n load_or_wait();\n } else if (root._bokeh_is_initializing || (typeof root._bokeh_is_initializing === \"undefined\" && root._bokeh_onload_callbacks !== undefined)) {\n setTimeout(load_or_wait, 100);\n } else {\n root._bokeh_is_initializing = true\n root._bokeh_onload_callbacks = []\n var bokeh_loaded = Bokeh != null && (Bokeh.version === py_version || (Bokeh.versions !== undefined && Bokeh.versions.has(py_version)));\n if (!reloading && !bokeh_loaded) {\n\troot.Bokeh = undefined;\n }\n load_libs(css_urls, js_urls, js_modules, js_exports, function() {\n\tconsole.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n\trun_inline_js();\n });\n }\n }\n // Give older versions of the autoload script a head-start to ensure\n // they initialize before we start loading newer version.\n setTimeout(load_or_wait, 100)\n}(window));", + "application/vnd.holoviews_load.v0+json": "" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": "\nif ((window.PyViz === undefined) || (window.PyViz instanceof HTMLElement)) {\n window.PyViz = {comms: {}, comm_status:{}, kernels:{}, receivers: {}, plot_index: []}\n}\n\n\n function JupyterCommManager() {\n }\n\n JupyterCommManager.prototype.register_target = function(plot_id, comm_id, msg_handler) {\n if (window.comm_manager || ((window.Jupyter !== undefined) && (Jupyter.notebook.kernel != null))) {\n var comm_manager = window.comm_manager || Jupyter.notebook.kernel.comm_manager;\n comm_manager.register_target(comm_id, function(comm) {\n comm.on_msg(msg_handler);\n });\n } else if ((plot_id in window.PyViz.kernels) && (window.PyViz.kernels[plot_id])) {\n window.PyViz.kernels[plot_id].registerCommTarget(comm_id, function(comm) {\n comm.onMsg = msg_handler;\n });\n } else if (typeof google != 'undefined' && google.colab.kernel != null) {\n google.colab.kernel.comms.registerTarget(comm_id, (comm) => {\n var messages = comm.messages[Symbol.asyncIterator]();\n function processIteratorResult(result) {\n var message = result.value;\n console.log(message)\n var content = {data: message.data, comm_id};\n var buffers = []\n for (var buffer of message.buffers || []) {\n buffers.push(new DataView(buffer))\n }\n var metadata = message.metadata || {};\n var msg = {content, buffers, metadata}\n msg_handler(msg);\n return messages.next().then(processIteratorResult);\n }\n return messages.next().then(processIteratorResult);\n })\n }\n }\n\n JupyterCommManager.prototype.get_client_comm = function(plot_id, comm_id, msg_handler) {\n if (comm_id in window.PyViz.comms) {\n return window.PyViz.comms[comm_id];\n } else if (window.comm_manager || ((window.Jupyter !== undefined) && (Jupyter.notebook.kernel != null))) {\n var comm_manager = window.comm_manager || Jupyter.notebook.kernel.comm_manager;\n var comm = comm_manager.new_comm(comm_id, {}, {}, {}, comm_id);\n if (msg_handler) {\n comm.on_msg(msg_handler);\n }\n } else if ((plot_id in window.PyViz.kernels) && (window.PyViz.kernels[plot_id])) {\n var comm = window.PyViz.kernels[plot_id].connectToComm(comm_id);\n comm.open();\n if (msg_handler) {\n comm.onMsg = msg_handler;\n }\n } else if (typeof google != 'undefined' && google.colab.kernel != null) {\n var comm_promise = google.colab.kernel.comms.open(comm_id)\n comm_promise.then((comm) => {\n window.PyViz.comms[comm_id] = comm;\n if (msg_handler) {\n var messages = comm.messages[Symbol.asyncIterator]();\n function processIteratorResult(result) {\n var message = result.value;\n var content = {data: message.data};\n var metadata = message.metadata || {comm_id};\n var msg = {content, metadata}\n msg_handler(msg);\n return messages.next().then(processIteratorResult);\n }\n return messages.next().then(processIteratorResult);\n }\n }) \n var sendClosure = (data, metadata, buffers, disposeOnDone) => {\n return comm_promise.then((comm) => {\n comm.send(data, metadata, buffers, disposeOnDone);\n });\n };\n var comm = {\n send: sendClosure\n };\n }\n window.PyViz.comms[comm_id] = comm;\n return comm;\n }\n window.PyViz.comm_manager = new JupyterCommManager();\n \n\n\nvar JS_MIME_TYPE = 'application/javascript';\nvar HTML_MIME_TYPE = 'text/html';\nvar EXEC_MIME_TYPE = 'application/vnd.holoviews_exec.v0+json';\nvar CLASS_NAME = 'output';\n\n/**\n * Render data to the DOM node\n */\nfunction render(props, node) {\n var div = document.createElement(\"div\");\n var script = document.createElement(\"script\");\n node.appendChild(div);\n node.appendChild(script);\n}\n\n/**\n * Handle when a new output is added\n */\nfunction handle_add_output(event, handle) {\n var output_area = handle.output_area;\n var output = handle.output;\n if ((output.data == undefined) || (!output.data.hasOwnProperty(EXEC_MIME_TYPE))) {\n return\n }\n var id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n var toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n if (id !== undefined) {\n var nchildren = toinsert.length;\n var html_node = toinsert[nchildren-1].children[0];\n html_node.innerHTML = output.data[HTML_MIME_TYPE];\n var scripts = [];\n var nodelist = html_node.querySelectorAll(\"script\");\n for (var i in nodelist) {\n if (nodelist.hasOwnProperty(i)) {\n scripts.push(nodelist[i])\n }\n }\n\n scripts.forEach( function (oldScript) {\n var newScript = document.createElement(\"script\");\n var attrs = [];\n var nodemap = oldScript.attributes;\n for (var j in nodemap) {\n if (nodemap.hasOwnProperty(j)) {\n attrs.push(nodemap[j])\n }\n }\n attrs.forEach(function(attr) { newScript.setAttribute(attr.name, attr.value) });\n newScript.appendChild(document.createTextNode(oldScript.innerHTML));\n oldScript.parentNode.replaceChild(newScript, oldScript);\n });\n if (JS_MIME_TYPE in output.data) {\n toinsert[nchildren-1].children[1].textContent = output.data[JS_MIME_TYPE];\n }\n output_area._hv_plot_id = id;\n if ((window.Bokeh !== undefined) && (id in Bokeh.index)) {\n window.PyViz.plot_index[id] = Bokeh.index[id];\n } else {\n window.PyViz.plot_index[id] = null;\n }\n } else if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n var bk_div = document.createElement(\"div\");\n bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n var script_attrs = bk_div.children[0].attributes;\n for (var i = 0; i < script_attrs.length; i++) {\n toinsert[toinsert.length - 1].childNodes[1].setAttribute(script_attrs[i].name, script_attrs[i].value);\n }\n // store reference to server id on output_area\n output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n }\n}\n\n/**\n * Handle when an output is cleared or removed\n */\nfunction handle_clear_output(event, handle) {\n var id = handle.cell.output_area._hv_plot_id;\n var server_id = handle.cell.output_area._bokeh_server_id;\n if (((id === undefined) || !(id in PyViz.plot_index)) && (server_id !== undefined)) { return; }\n var comm = window.PyViz.comm_manager.get_client_comm(\"hv-extension-comm\", \"hv-extension-comm\", function () {});\n if (server_id !== null) {\n comm.send({event_type: 'server_delete', 'id': server_id});\n return;\n } else if (comm !== null) {\n comm.send({event_type: 'delete', 'id': id});\n }\n delete PyViz.plot_index[id];\n if ((window.Bokeh !== undefined) & (id in window.Bokeh.index)) {\n var doc = window.Bokeh.index[id].model.document\n doc.clear();\n const i = window.Bokeh.documents.indexOf(doc);\n if (i > -1) {\n window.Bokeh.documents.splice(i, 1);\n }\n }\n}\n\n/**\n * Handle kernel restart event\n */\nfunction handle_kernel_cleanup(event, handle) {\n delete PyViz.comms[\"hv-extension-comm\"];\n window.PyViz.plot_index = {}\n}\n\n/**\n * Handle update_display_data messages\n */\nfunction handle_update_output(event, handle) {\n handle_clear_output(event, {cell: {output_area: handle.output_area}})\n handle_add_output(event, handle)\n}\n\nfunction register_renderer(events, OutputArea) {\n function append_mime(data, metadata, element) {\n // create a DOM node to render to\n var toinsert = this.create_output_subarea(\n metadata,\n CLASS_NAME,\n EXEC_MIME_TYPE\n );\n this.keyboard_manager.register_events(toinsert);\n // Render to node\n var props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n render(props, toinsert[0]);\n element.append(toinsert);\n return toinsert\n }\n\n events.on('output_added.OutputArea', handle_add_output);\n events.on('output_updated.OutputArea', handle_update_output);\n events.on('clear_output.CodeCell', handle_clear_output);\n events.on('delete.Cell', handle_clear_output);\n events.on('kernel_ready.Kernel', handle_kernel_cleanup);\n\n OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n safe: true,\n index: 0\n });\n}\n\nif (window.Jupyter !== undefined) {\n try {\n var events = require('base/js/events');\n var OutputArea = require('notebook/js/outputarea').OutputArea;\n if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n register_renderer(events, OutputArea);\n }\n } catch(err) {\n }\n}\n", + "application/vnd.holoviews_load.v0+json": "" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.holoviews_exec.v0+json": "", + "text/html": [ + "
\n", + "
\n", + "
\n", + "" + ] + }, + "metadata": { + "application/vnd.holoviews_exec.v0+json": { + "id": "cc773275-e3eb-4a12-93e1-3b0732805f97" + } + }, + "output_type": "display_data" + } + ], + "source": [ + "# imports\n", + "import torch\n", + "import panel as pn\n", + "from delphi.eval.vis import token_selector\n", + "from datasets import load_dataset, Dataset\n", + "from transformers import AutoTokenizer\n", + "from typing import cast\n", + "from delphi.eval.calc_model_group_stats import calc_model_group_stats\n", + "from delphi.eval.vis_per_token_model import visualize_selected_tokens\n", + "from ipywidgets import interact\n", + "from delphi.eval.token_positions import get_all_tok_metrics_in_label\n", + "from delphi.eval.vis import vis_pos_map\n", + "import ipywidgets as widgets\n", + "\n", + "# refer to https://panel.holoviz.org/reference/panes/IPyWidget.html to integrate ipywidgets with panel\n", + "pn.extension('ipywidgets')\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# specify model names (or checkpoints)\n", + "prefix = \"delphi-suite/v0-next-logprobs-llama2-\"\n", + "suffixes = [\n", + " \"100k\",\n", + " \"200k\",\n", + " \"400k\",\n", + "] # , \"800k\", \"1.6m\", \"3.2m\", \"6.4m\", \"12.8m\", \"25.6m\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# load next logprobs data for all models\n", + "split = \"validation[:100]\"\n", + "next_logprobs = {\n", + " suffix: cast(\n", + " Dataset,\n", + " load_dataset(f\"{prefix}{suffix}\", split=split),\n", + " )\n", + " .with_format(\"torch\")\n", + " .map(lambda x: {\"logprobs\": x[\"logprobs\"].to(device)})\n", + " for suffix in suffixes\n", + "}\n", + "next_logprobs_plot = {k: d[\"logprobs\"] for k, d in next_logprobs.items()}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# load the tokenized dataset\n", + "tokenized_corpus_dataset = (\n", + " cast(\n", + " Dataset,\n", + " load_dataset(\"delphi-suite/stories-tokenized\", split=split),\n", + " )\n", + " .with_format(\"torch\")\n", + " .map(lambda x: {\"tokens\": x[\"tokens\"].to(device)})\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run this notebook until the following cell, then the rest should work." + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7a90e8b37a674fe4a82fa837c5258335", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "BokehModel(combine_events=True, render_bundle={'docs_json': {'fbcd4d70-83e4-4f3b-a073-702f920d0738': {'version…" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# specific token specification\n", + "tokenizer = AutoTokenizer.from_pretrained(\"delphi-suite/stories-tokenizer\")\n", + "\n", + "# Count the frequency of each token using torch.bincount\n", + "token_counts = torch.bincount(tokenized_corpus_dataset[\"tokens\"].view(-1))\n", + "\n", + "# Get the indices that would sort the token counts in descending order\n", + "sorted_indices = torch.argsort(token_counts, descending=True)\n", + "\n", + "# Get the token IDs in descending order of frequency\n", + "valid_tok_ids = sorted_indices.tolist()\n", + "def format_fix(s):\n", + " if s.startswith(\" \"):\n", + " return \"_\" + s[1:]\n", + " return s\n", + "vocab = {format_fix(tokenizer.decode(t, clean_up_tokenization_spaces=True)): t for t in sorted_indices.tolist() if token_counts[t] > 0}\n", + "\n", + "\n", + "selector, selected_ids = token_selector(vocab) # use selected_ids as a dynamic variable\n", + "pn.Row(selector, height=500).servable()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Selected IDs: [40, 2, 14]\n" + ] + } + ], + "source": [ + "if not selected_ids:\n", + " selected_ids = [40, 2, 14]\n", + "print(\"Selected IDs:\", selected_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([100, 513])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(next_logprobs_plot.values())[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing model 100k\n", + "Processing model 200k\n", + "Processing model 400k\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "133ee80dfa814961944d5ddf76f2fc6e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "FigureWidget({\n", + " 'data': [{'line': {'width': 0},\n", + " 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n", + " 'mode': 'lines',\n", + " 'name': 'Upper Bound',\n", + " 'showlegend': False,\n", + " 'type': 'scatter',\n", + " 'uid': '7b5d42bd-2dc2-4c7b-b9ef-500259ecc5e6',\n", + " 'x': [100k, 200k, 400k],\n", + " 'y': array([3.71775752, 3.02412134, 3.20290118])},\n", + " {'fill': 'tonexty',\n", + " 'fillcolor': 'rgba(68, 68, 68, 0.3)',\n", + " 'line': {'width': 0},\n", + " 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n", + " 'mode': 'lines',\n", + " 'name': 'Lower Bound',\n", + " 'showlegend': False,\n", + " 'type': 'scatter',\n", + " 'uid': '716e9e70-2be1-4d42-ad85-8c815b6b66b2',\n", + " 'x': [100k, 200k, 400k],\n", + " 'y': array([0.68709016, 0.66691843, 0.41506469])},\n", + " {'marker': {'color': 'rgb(31, 119, 180)', 'line': {'color': 'rgb(31, 119, 180)', 'width': 1}, 'size': 0},\n", + " 'mode': 'lines',\n", + " 'name': 'Means',\n", + " 'type': 'scatter',\n", + " 'uid': 'f388f676-370a-45d7-8f75-4d51c27d1728',\n", + " 'x': [100k, 200k, 400k],\n", + " 'y': array([0.94124633, 0.75591046, 0.50620776])}],\n", + " 'layout': {'template': '...'}\n", + "})" + ] + }, + "execution_count": 71, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_group_stats = calc_model_group_stats( # i'm not sure if tokenized_corpus_dataset.tolist() is the right input, it was list(tokenized_corpus_dataset) before\n", + " tokenized_corpus_dataset, next_logprobs_plot, selected_ids\n", + ")\n", + "performance_data = {}\n", + "for suffix in suffixes:\n", + " stats = model_group_stats[suffix]\n", + " performance_data[suffix] = (\n", + " -stats[\"median\"],\n", + " -stats[\"75th\"],\n", + " -stats[\"25th\"],\n", + " )\n", + "\n", + "visualize_selected_tokens(performance_data, log_scale=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from ipywidgets import interact_manual" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6ff334e54715402eb32b3a77aef0a830", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(FloatRangeSlider(value=(0.25, 0.75), description='Quantiles', max=1.0, step=0.05), Dropd…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def show_pos_map(\n", + " quantile: tuple[float, float],\n", + " model_name_1: str,\n", + " model_name_2: str,\n", + "):\n", + " logprobs_diff = next_logprobs[model_name_2][\"logprobs\"] - next_logprobs[model_name_1][\"logprobs\"] # type: ignore\n", + " pos_to_diff = get_all_tok_metrics_in_label(tokenized_corpus_dataset[\"tokens\"], selected_tokens=selected_ids, metrics=logprobs_diff, q_start=quantile[0], q_end=quantile[1]) # type: ignore\n", + " try:\n", + " _ = vis_pos_map(list(pos_to_diff.keys()), selected_ids, logprobs_diff, tokenized_corpus_dataset[\"tokens\"], tokenizer) # type: ignore\n", + " except ValueError:\n", + " if pos_to_diff == {}:\n", + " print(\"No tokens found in this label\")\n", + " return\n", + "\n", + "\n", + "interact_manual(\n", + " show_pos_map,\n", + " quantile=widgets.FloatRangeSlider(\n", + " min=0.0, max=1.0, step=0.05, description=\"Quantiles\"\n", + " ),\n", + " samples=widgets.IntSlider(min=1, max=5, description=\"Samples\", value=2),\n", + " model_name_1=widgets.Dropdown(\n", + " options=suffixes,\n", + " description=\"Model 1\",\n", + " value=\"100k\",\n", + " ),\n", + " model_name_2=widgets.Dropdown(\n", + " options=suffixes,\n", + " description=\"Model 2\",\n", + " value=\"200k\",\n", + " ),\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/delphi/eval/calc_model_group_stats.py b/src/delphi/eval/calc_model_group_stats.py index d9c5d4c1..faab8a02 100644 --- a/src/delphi/eval/calc_model_group_stats.py +++ b/src/delphi/eval/calc_model_group_stats.py @@ -1,33 +1,30 @@ import numpy as np +import torch +from datasets import Dataset +from jaxtyping import Float def calc_model_group_stats( - tokenized_corpus_dataset: list, - logprobs_by_dataset: dict[str, list[list[float]]], - token_labels_by_token: dict[int, dict[str, bool]], - token_labels: list[str], -) -> dict[tuple[str, str], dict[str, float]]: + tokenized_corpus_dataset: Dataset, + logprobs_by_dataset: dict[str, torch.Tensor], + selected_tokens: list[int], +) -> dict[str, dict[str, float]]: """ For each (model, token group) pair, calculate useful stats (for visualization) args: - - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"] + - tokenized_corpus_dataset: a list of the tokenized corpus datasets, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"] - logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]} - - token_groups: a dict of token groups, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...} - - models: a list of model names, e.g. constants.LLAMA2_MODELS - - token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...] + - selected_tokens: a list of selected token IDs, e.g. [46, 402, ...] - returns: a dict of (model, token group) pairs to a dict of stats, - e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...} + returns: a dict of model names as keys and stats dict as values + e.g. {"100k": {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...} - Technically `models` and `token_labels` are redundant, as they are also keys in `logprob_datasets` and `token_groups`, - but it's better to be explicit - - stats calculated: mean, median, min, max, 25th percentile, 75th percentile + Stats calculated: mean, median, min, max, 25th percentile, 75th percentile """ model_group_stats = {} for model in logprobs_by_dataset: - group_logprobs = {} + model_logprobs = [] print(f"Processing model {model}") dataset = logprobs_by_dataset[model] for ix_doc_lp, document_lps in enumerate(dataset): @@ -35,20 +32,17 @@ def calc_model_group_stats( for ix_token, token in enumerate(tokens): if ix_token == 0: # skip the first token, which isn't predicted continue - logprob = document_lps[ix_token] - for token_group_desc in token_labels: - if token_labels_by_token[token][token_group_desc]: - if token_group_desc not in group_logprobs: - group_logprobs[token_group_desc] = [] - group_logprobs[token_group_desc].append(logprob) - for token_group_desc in token_labels: - if token_group_desc in group_logprobs: - model_group_stats[(model, token_group_desc)] = { - "mean": np.mean(group_logprobs[token_group_desc]), - "median": np.median(group_logprobs[token_group_desc]), - "min": np.min(group_logprobs[token_group_desc]), - "max": np.max(group_logprobs[token_group_desc]), - "25th": np.percentile(group_logprobs[token_group_desc], 25), - "75th": np.percentile(group_logprobs[token_group_desc], 75), - } + logprob = document_lps[ix_token].item() + if token in selected_tokens: + model_logprobs.append(logprob) + + if model_logprobs: + model_group_stats[model] = { + "mean": np.mean(model_logprobs), + "median": np.median(model_logprobs), + "min": np.min(model_logprobs), + "max": np.max(model_logprobs), + "25th": np.percentile(model_logprobs, 25), + "75th": np.percentile(model_logprobs, 75), + } return model_group_stats diff --git a/src/delphi/eval/constants.py b/src/delphi/eval/constants.py index 5cd3daf1..3a586e00 100644 --- a/src/delphi/eval/constants.py +++ b/src/delphi/eval/constants.py @@ -2,15 +2,15 @@ tokenized_corpus_dataset = "delphi-suite/tinystories-v2-clean-tokenized-v0" LLAMA2_MODELS = [ - "delphi-llama2-100k", - "delphi-llama2-200k", - "delphi-llama2-400k", - "delphi-llama2-800k", - "delphi-llama2-1.6m", - "delphi-llama2-3.2m", - "delphi-llama2-6.4m", - "delphi-llama2-12.8m", - "delphi-llama2-25.6m", + "llama2-100k", + "llama2-200k", + "llama2-400k", + "llama2-800k", + "llama2-1.6m", + "llama2-3.2m", + "llama2-6.4m", + "llama2-12.8m", + "llama2-25.6m", ] LLAMA2_NEXT_LOGPROBS_DATASETS_MAP = { diff --git a/src/delphi/eval/token_positions.py b/src/delphi/eval/token_positions.py index 5239a53f..a98af761 100644 --- a/src/delphi/eval/token_positions.py +++ b/src/delphi/eval/token_positions.py @@ -1,8 +1,6 @@ -from numbers import Number -from typing import Optional, cast +from typing import Optional import torch -from datasets import Dataset from jaxtyping import Int from delphi.eval.utils import dict_filter_quantile @@ -10,9 +8,8 @@ def get_all_tok_metrics_in_label( token_ids: Int[torch.Tensor, "prompt pos"], - token_labels: dict[int, dict[str, bool]], + selected_tokens: list[int], metrics: torch.Tensor, - label: str, q_start: Optional[float] = None, q_end: Optional[float] = None, ) -> dict[tuple[int, int], float]: @@ -23,9 +20,8 @@ def get_all_tok_metrics_in_label( Args: - token_ids (Dataset): token_ids dataset e.g. token_ids[0] = {"tokens": [[1, 2, ...], [2, 5, ...], ...]} - - token_labels (dict[int, dict[str, bool]]): dictionary of token labels e.g. { 0: {"Is Noun": True, "Is Verb": False}, ...} + - selected_tokens (list[int]): list of token IDs to search for e.g. [46, 402, ...] - metrics (torch.Tensor): tensor of metrics to search through e.g. torch.tensor([[0.1, 0.2, ...], [0.3, 0.4, ...], ...]) - - label (str): the label to search for - q_start (float): the start of the quantile range to filter the metrics e.g. 0.1 - q_end (float): the end of the quantile range to filter the metrics e.g. 0.9 @@ -42,7 +38,7 @@ def get_all_tok_metrics_in_label( tok_positions = {} for prompt_pos, prompt in enumerate(token_ids.numpy()): for tok_pos, tok in enumerate(prompt): - if token_labels[tok][label]: + if tok in selected_tokens: tok_positions[(prompt_pos, tok_pos)] = metrics[ prompt_pos, tok_pos ].item() diff --git a/src/delphi/eval/vis.py b/src/delphi/eval/vis.py index 1a69eae2..f13924ad 100644 --- a/src/delphi/eval/vis.py +++ b/src/delphi/eval/vis.py @@ -3,6 +3,7 @@ import uuid from typing import cast +import numpy as np import panel as pn import torch from IPython.core.display import HTML @@ -54,6 +55,7 @@ def token_to_html( tokenizer: PreTrainedTokenizerBase, bg_color: str, data: dict, + class_name: str = "token", ) -> str: data = data or {} # equivalent to if not data: data = {} # non-breakable space, w/o it leading spaces wouldn't be displayed @@ -73,6 +75,7 @@ def token_to_html( br += "
" # this is so we can copy the prompt without "\n"s specific_styles["user-select"] = "none" + str_token = str_token.replace("<", "<").replace(">", ">") style_str = data_str = "" # converting style dict into the style attribute @@ -83,7 +86,7 @@ def token_to_html( data_str = "".join( f" data-{k}='{v.replace(' ', ' ')}'" for k, v in data.items() ) - return f"
{str_token}
{br}" + return f"
{str_token}
{br}" _token_style = { @@ -97,7 +100,20 @@ def token_to_html( "margin": "1px 0px 1px 1px", "padding": "0px 1px 1px 1px", } +_token_emphasized_style = { + "border": "3px solid #888", + "display": "inline-block", + "font-family": "monospace", + "font-size": "14px", + "color": "black", + "background-color": "white", + "margin": "1px 0px 1px 1px", + "padding": "0px 1px 1px 1px", +} _token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()]) +_token_emphasized_style_str = " ".join( + [f"{k}: {v};" for k, v in _token_emphasized_style.items()] +) def vis_sample_prediction_probs( @@ -130,8 +146,8 @@ def vis_sample_prediction_probs( data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer) token_htmls.append( - token_to_html(tok, tokenizer, bg_color=colors[i], data=data).replace( - "class='token'", f"class='{token_class}'" + token_to_html( + tok, tokenizer, bg_color=colors[i], data=data, class_name=token_class ) ) @@ -165,10 +181,11 @@ def vis_sample_prediction_probs( def vis_pos_map( - pos_map: dict[tuple[int, int], float | int], + pos_list: list[tuple[int, int]], + selected_tokens: list[int], + metrics: Float[torch.Tensor, "prompt pos"], token_ids: Int[torch.Tensor, "prompt pos"], tokenizer: PreTrainedTokenizerBase, - sample: int = 3, ): """ Randomly sample from pos_map and visualize the loss diff at the corresponding position. @@ -176,49 +193,43 @@ def vis_pos_map( token_htmls = [] unique_id = str(uuid.uuid4()) - token_class = f"token_{unique_id}" + token_class = f"pretoken_{unique_id}" + selected_token_class = f"token_{unique_id}" hover_div_id = f"hover_info_{unique_id}" - # choose n random keys from pos_map - keys = random.sample(list(pos_map.keys()), k=sample) - - for key in keys: - prompt, pos = key - pre_toks = token_ids[prompt][:pos] - mask = torch.isin(pre_toks, torch.tensor([0, 1], dtype=torch.int8)) - pre_toks = pre_toks[ - ~mask - ] # remove and tokens, cause strikethrough in html - - for i in range(pre_toks.shape[0]): - pre_tok = cast(int, pre_toks[i].item()) - token_htmls.append( - token_to_html(pre_tok, tokenizer, bg_color="white", data={}).replace( - "class='token'", f"class='{token_class}'" - ) - ) + # choose a random keys from pos_map + key = random.choice(pos_list) - tok = cast(int, token_ids[prompt][pos].item()) - value = cast(float, pos_map[key]) + prompt, pos = key + all_toks = token_ids[prompt][: pos + 1] + for i in range(all_toks.shape[0]): + token_id = cast(int, all_toks[i].item()) + value = metrics[prompt][i].item() token_htmls.append( token_to_html( - tok, + token_id, tokenizer, - bg_color=single_loss_diff_to_color(value), + bg_color="white" + if np.isnan(value) + else single_loss_diff_to_color(value), data={"loss-diff": f"{value:.2f}"}, - ).replace("class='token'", f"class='{token_class}'") + class_name=token_class + if token_id not in selected_tokens + else selected_token_class, + ) ) - # add break line - token_htmls.append("

") + # add break line + token_htmls.append("

") html_str = f""" - + {"".join(token_htmls)}
""" display(HTML(html_str)) - return html_str def token_selector( vocab_map: dict[str, int] ) -> tuple[pn.widgets.MultiChoice, list[int]]: tokens = list(vocab_map.keys()) - token_selector = pn.widgets.MultiChoice(name="Tokens", options=tokens) - token_ids = [vocab_map[token] for token in cast(list[str], token_selector.value)] + token_selector_ = pn.widgets.MultiChoice(name="Tokens", options=tokens) + token_ids = [vocab_map[token] for token in cast(list[str], token_selector_.value)] def update_tokens(event): token_ids.clear() token_ids.extend([vocab_map[token] for token in event.new]) - token_selector.param.watch(update_tokens, "value") - return token_selector, token_ids + token_selector_.param.watch(update_tokens, "value") + return token_selector_, token_ids diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py index 8daaa96f..e5d735f4 100644 --- a/src/delphi/eval/vis_per_token_model.py +++ b/src/delphi/eval/vis_per_token_model.py @@ -1,12 +1,11 @@ from typing import Union -import ipywidgets import numpy as np import plotly.graph_objects as go -def visualize_per_token_category( - input: dict[Union[str, int], dict[str, tuple]], +def visualize_selected_tokens( + input: dict[Union[str, int], tuple[float, float, float]], log_scale=False, line_metric="Means", checkpoint_mode=True, @@ -17,18 +16,16 @@ def visualize_per_token_category( background_color="AliceBlue", ) -> go.FigureWidget: input_x = list(input.keys()) - categories = list(input[input_x[0]].keys()) - category = categories[0] def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]: return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)] - def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - x = np.array([input[x][category] for x in input_x]).T + def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]: + x = np.array([input[x] for x in input_x]).T means, err_lo, err_hi = x[0], x[1], x[2] return means, err_lo, err_hi - means, err_lo, err_hi = get_plot_values(category) + means, err_lo, err_hi = get_plot_values() if checkpoint_mode: scatter_plot = go.Figure( diff --git a/tests/eval/test_token_positions.py b/tests/eval/test_token_positions.py index 1adef6b7..c584b931 100644 --- a/tests/eval/test_token_positions.py +++ b/tests/eval/test_token_positions.py @@ -12,40 +12,43 @@ def mock_data(): token_ids = Dataset.from_dict( {"tokens": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]} ).with_format("torch") - token_labels = { - 1: {"Is Noun": False, "Is Verb": True}, - 2: {"Is Noun": True, "Is Verb": True}, - 3: {"Is Noun": False, "Is Verb": False}, - 4: {"Is Noun": True, "Is Verb": False}, - 5: {"Is Noun": False, "Is Verb": True}, - 6: {"Is Noun": True, "Is Verb": True}, - 7: {"Is Noun": False, "Is Verb": False}, - 8: {"Is Noun": True, "Is Verb": False}, - 9: {"Is Noun": False, "Is Verb": True}, - } - metrics = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]) - return token_ids, token_labels, metrics + selected_tokens = [2, 4, 6, 8] + metrics = torch.tensor([[-1, 0.45, -0.33], [-1.31, 2.3, 0.6], [0.2, 0.8, 0.1]]) + return token_ids, selected_tokens, metrics def test_get_all_tok_metrics_in_label(mock_data): - token_ids, token_labels, metrics = mock_data + token_ids, selected_tokens, metrics = mock_data result = get_all_tok_metrics_in_label( - token_ids["tokens"], token_labels, metrics, "Is Noun" + token_ids["tokens"], + selected_tokens, + metrics, ) + # key: (prompt_pos, tok_pos), value: logprob expected = { - (0, 1): 0.2, - (1, 0): 0.4, + (0, 1): 0.45, + (1, 0): -1.31, (1, 2): 0.6, (2, 1): 0.8, } - # use isclose to compare floating point numbers + + # compare keys + assert result.keys() == expected.keys() + # compare values for k in result: assert isclose(cast(float, result[k]), expected[k], rel_tol=1e-6) # type: ignore # test with quantile filtering result_q = get_all_tok_metrics_in_label( - token_ids["tokens"], token_labels, metrics, "Is Noun", q_start=0.3, q_end=1.0 + token_ids["tokens"], selected_tokens, metrics, q_start=0.6, q_end=1.0 ) - expected_q = {(1, 2): 0.6, (2, 1): 0.8, (1, 0): 0.4} + expected_q = { + (1, 2): 0.6, + (2, 1): 0.8, + } + + # compare keys + assert result_q.keys() == expected_q.keys() + # compare values for k in result_q: assert isclose(cast(float, result_q[k]), expected_q[k], rel_tol=1e-6) # type: ignore diff --git a/tests/eval/test_utils_eval.py b/tests/eval/test_utils_eval.py index ad0f54b8..a259d16b 100644 --- a/tests/eval/test_utils_eval.py +++ b/tests/eval/test_utils_eval.py @@ -61,7 +61,22 @@ def test_load_validation_dataset(): def test_dict_filter_quantile(): d = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4, 5: 0.5} result = dict_filter_quantile(d, 0.2, 0.6) - expected = {2: 0.2, 3: 0.3, 4: 0.4} + expected = {2: 0.2, 3: 0.3} + + # compare keys + assert result.keys() == expected.keys() + # compare values + for k in result: + assert isclose(result[k], expected[k], rel_tol=1e-6) + + # test with negative values + d = {1: -0.1, 2: -0.2, 3: -0.3, 4: -0.4, 5: -0.5} + result = dict_filter_quantile(d, 0.2, 0.6) + expected = {3: -0.3, 4: -0.4} + + # compare keys + assert result.keys() == expected.keys() + # compare values for k in result: assert isclose(result[k], expected[k], rel_tol=1e-6)