diff --git a/notebooks/eval_notebook.ipynb b/notebooks/eval_notebook.ipynb
new file mode 100644
index 00000000..864a1f1e
--- /dev/null
+++ b/notebooks/eval_notebook.ipynb
@@ -0,0 +1,480 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# colab cells (only run if on colab)\n",
+ "# TODO: experiment on colab to see how to set up the environment"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Important\n",
+ "\n",
+ "Run this cell by cell. The token selector cell needs to be run first so the later cells work."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/javascript": "(function(root) {\n function now() {\n return new Date();\n }\n\n var force = true;\n var py_version = '3.4.1'.replace('rc', '-rc.').replace('.dev', '-dev.');\n var reloading = false;\n var Bokeh = root.Bokeh;\n\n if (typeof (root._bokeh_timeout) === \"undefined\" || force) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks;\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, js_modules, js_exports, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = [];\n if (js_modules == null) js_modules = [];\n if (js_exports == null) js_exports = {};\n\n root._bokeh_onload_callbacks.push(callback);\n\n if (root._bokeh_is_loading > 0) {\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n }\n if (js_urls.length === 0 && js_modules.length === 0 && Object.keys(js_exports).length === 0) {\n run_callbacks();\n return null;\n }\n if (!reloading) {\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n }\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n window._bokeh_on_load = on_load\n\n function on_error() {\n console.error(\"failed to load \" + url);\n }\n\n var skip = [];\n if (window.requirejs) {\n window.requirejs.config({'packages': {}, 'paths': {}, 'shim': {}});\n root._bokeh_is_loading = css_urls.length + 0;\n } else {\n root._bokeh_is_loading = css_urls.length + js_urls.length + js_modules.length + Object.keys(js_exports).length;\n }\n\n var existing_stylesheets = []\n var links = 
document.getElementsByTagName('link')\n for (var i = 0; i < links.length; i++) {\n var link = links[i]\n if (link.href != null) {\n\texisting_stylesheets.push(link.href)\n }\n }\n for (var i = 0; i < css_urls.length; i++) {\n var url = css_urls[i];\n if (existing_stylesheets.indexOf(url) !== -1) {\n\ton_load()\n\tcontinue;\n }\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = on_error;\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n } var existing_scripts = []\n var scripts = document.getElementsByTagName('script')\n for (var i = 0; i < scripts.length; i++) {\n var script = scripts[i]\n if (script.src != null) {\n\texisting_scripts.push(script.src)\n }\n }\n for (var i = 0; i < js_urls.length; i++) {\n var url = js_urls[i];\n if (skip.indexOf(url) !== -1 || existing_scripts.indexOf(url) !== -1) {\n\tif (!window.requirejs) {\n\t on_load();\n\t}\n\tcontinue;\n }\n var element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error;\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n for (var i = 0; i < js_modules.length; i++) {\n var url = js_modules[i];\n if (skip.indexOf(url) !== -1 || existing_scripts.indexOf(url) !== -1) {\n\tif (!window.requirejs) {\n\t on_load();\n\t}\n\tcontinue;\n }\n var element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error;\n element.async = false;\n element.src = url;\n element.type = \"module\";\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n for (const name in js_exports) {\n var url = js_exports[name];\n if (skip.indexOf(url) >= 0 || root[name] != null) {\n\tif 
(!window.requirejs) {\n\t on_load();\n\t}\n\tcontinue;\n }\n var element = document.createElement('script');\n element.onerror = on_error;\n element.async = false;\n element.type = \"module\";\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n element.textContent = `\n import ${name} from \"${url}\"\n window.${name} = ${name}\n window._bokeh_on_load()\n `\n document.head.appendChild(element);\n }\n if (!js_urls.length && !js_modules.length) {\n on_load()\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n var js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-3.4.1.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-3.4.1.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-3.4.1.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-3.4.1.min.js\", \"https://cdn.holoviz.org/panel/1.4.0/dist/panel.min.js\"];\n var js_modules = [];\n var js_exports = {};\n var css_urls = [\"https://cdn.holoviz.org/panel/1.4.0/dist/bundled/font-awesome/css/all.min.css\"];\n var inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {} // ensure no trailing comma for IE\n ];\n\n function run_inline_js() {\n if ((root.Bokeh !== undefined) || (force === true)) {\n for (var i = 0; i < inline_js.length; i++) {\n\ttry {\n inline_js[i].call(root, root.Bokeh);\n\t} catch(e) {\n\t if (!reloading) {\n\t throw e;\n\t }\n\t}\n }\n // Cache old bokeh versions\n if (Bokeh != undefined && !reloading) {\n\tvar NewBokeh = root.Bokeh;\n\tif (Bokeh.versions === undefined) {\n\t Bokeh.versions = new Map();\n\t}\n\tif (NewBokeh.version !== Bokeh.version) {\n\t Bokeh.versions.set(NewBokeh.version, NewBokeh)\n\t}\n\troot.Bokeh = Bokeh;\n }} else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: 
BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n }\n root._bokeh_is_initializing = false\n }\n\n function load_or_wait() {\n // Implement a backoff loop that tries to ensure we do not load multiple\n // versions of Bokeh and its dependencies at the same time.\n // In recent versions we use the root._bokeh_is_initializing flag\n // to determine whether there is an ongoing attempt to initialize\n // bokeh, however for backward compatibility we also try to ensure\n // that we do not start loading a newer (Panel>=1.0 and Bokeh>3) version\n // before older versions are fully initialized.\n if (root._bokeh_is_initializing && Date.now() > root._bokeh_timeout) {\n root._bokeh_is_initializing = false;\n root._bokeh_onload_callbacks = undefined;\n console.log(\"Bokeh: BokehJS was loaded multiple times but one version failed to initialize.\");\n load_or_wait();\n } else if (root._bokeh_is_initializing || (typeof root._bokeh_is_initializing === \"undefined\" && root._bokeh_onload_callbacks !== undefined)) {\n setTimeout(load_or_wait, 100);\n } else {\n root._bokeh_is_initializing = true\n root._bokeh_onload_callbacks = []\n var bokeh_loaded = Bokeh != null && (Bokeh.version === py_version || (Bokeh.versions !== undefined && Bokeh.versions.has(py_version)));\n if (!reloading && !bokeh_loaded) {\n\troot.Bokeh = undefined;\n }\n load_libs(css_urls, js_urls, js_modules, js_exports, function() {\n\tconsole.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n\trun_inline_js();\n });\n }\n }\n // Give older versions of the autoload script a head-start to ensure\n // they initialize before we start loading newer version.\n setTimeout(load_or_wait, 100)\n}(window));",
+ "application/vnd.holoviews_load.v0+json": ""
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/javascript": "\nif ((window.PyViz === undefined) || (window.PyViz instanceof HTMLElement)) {\n window.PyViz = {comms: {}, comm_status:{}, kernels:{}, receivers: {}, plot_index: []}\n}\n\n\n function JupyterCommManager() {\n }\n\n JupyterCommManager.prototype.register_target = function(plot_id, comm_id, msg_handler) {\n if (window.comm_manager || ((window.Jupyter !== undefined) && (Jupyter.notebook.kernel != null))) {\n var comm_manager = window.comm_manager || Jupyter.notebook.kernel.comm_manager;\n comm_manager.register_target(comm_id, function(comm) {\n comm.on_msg(msg_handler);\n });\n } else if ((plot_id in window.PyViz.kernels) && (window.PyViz.kernels[plot_id])) {\n window.PyViz.kernels[plot_id].registerCommTarget(comm_id, function(comm) {\n comm.onMsg = msg_handler;\n });\n } else if (typeof google != 'undefined' && google.colab.kernel != null) {\n google.colab.kernel.comms.registerTarget(comm_id, (comm) => {\n var messages = comm.messages[Symbol.asyncIterator]();\n function processIteratorResult(result) {\n var message = result.value;\n console.log(message)\n var content = {data: message.data, comm_id};\n var buffers = []\n for (var buffer of message.buffers || []) {\n buffers.push(new DataView(buffer))\n }\n var metadata = message.metadata || {};\n var msg = {content, buffers, metadata}\n msg_handler(msg);\n return messages.next().then(processIteratorResult);\n }\n return messages.next().then(processIteratorResult);\n })\n }\n }\n\n JupyterCommManager.prototype.get_client_comm = function(plot_id, comm_id, msg_handler) {\n if (comm_id in window.PyViz.comms) {\n return window.PyViz.comms[comm_id];\n } else if (window.comm_manager || ((window.Jupyter !== undefined) && (Jupyter.notebook.kernel != null))) {\n var comm_manager = window.comm_manager || Jupyter.notebook.kernel.comm_manager;\n var comm = comm_manager.new_comm(comm_id, {}, {}, {}, comm_id);\n if (msg_handler) {\n comm.on_msg(msg_handler);\n }\n } else if ((plot_id in 
window.PyViz.kernels) && (window.PyViz.kernels[plot_id])) {\n var comm = window.PyViz.kernels[plot_id].connectToComm(comm_id);\n comm.open();\n if (msg_handler) {\n comm.onMsg = msg_handler;\n }\n } else if (typeof google != 'undefined' && google.colab.kernel != null) {\n var comm_promise = google.colab.kernel.comms.open(comm_id)\n comm_promise.then((comm) => {\n window.PyViz.comms[comm_id] = comm;\n if (msg_handler) {\n var messages = comm.messages[Symbol.asyncIterator]();\n function processIteratorResult(result) {\n var message = result.value;\n var content = {data: message.data};\n var metadata = message.metadata || {comm_id};\n var msg = {content, metadata}\n msg_handler(msg);\n return messages.next().then(processIteratorResult);\n }\n return messages.next().then(processIteratorResult);\n }\n }) \n var sendClosure = (data, metadata, buffers, disposeOnDone) => {\n return comm_promise.then((comm) => {\n comm.send(data, metadata, buffers, disposeOnDone);\n });\n };\n var comm = {\n send: sendClosure\n };\n }\n window.PyViz.comms[comm_id] = comm;\n return comm;\n }\n window.PyViz.comm_manager = new JupyterCommManager();\n \n\n\nvar JS_MIME_TYPE = 'application/javascript';\nvar HTML_MIME_TYPE = 'text/html';\nvar EXEC_MIME_TYPE = 'application/vnd.holoviews_exec.v0+json';\nvar CLASS_NAME = 'output';\n\n/**\n * Render data to the DOM node\n */\nfunction render(props, node) {\n var div = document.createElement(\"div\");\n var script = document.createElement(\"script\");\n node.appendChild(div);\n node.appendChild(script);\n}\n\n/**\n * Handle when a new output is added\n */\nfunction handle_add_output(event, handle) {\n var output_area = handle.output_area;\n var output = handle.output;\n if ((output.data == undefined) || (!output.data.hasOwnProperty(EXEC_MIME_TYPE))) {\n return\n }\n var id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n var toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n if (id !== undefined) {\n var nchildren = 
toinsert.length;\n var html_node = toinsert[nchildren-1].children[0];\n html_node.innerHTML = output.data[HTML_MIME_TYPE];\n var scripts = [];\n var nodelist = html_node.querySelectorAll(\"script\");\n for (var i in nodelist) {\n if (nodelist.hasOwnProperty(i)) {\n scripts.push(nodelist[i])\n }\n }\n\n scripts.forEach( function (oldScript) {\n var newScript = document.createElement(\"script\");\n var attrs = [];\n var nodemap = oldScript.attributes;\n for (var j in nodemap) {\n if (nodemap.hasOwnProperty(j)) {\n attrs.push(nodemap[j])\n }\n }\n attrs.forEach(function(attr) { newScript.setAttribute(attr.name, attr.value) });\n newScript.appendChild(document.createTextNode(oldScript.innerHTML));\n oldScript.parentNode.replaceChild(newScript, oldScript);\n });\n if (JS_MIME_TYPE in output.data) {\n toinsert[nchildren-1].children[1].textContent = output.data[JS_MIME_TYPE];\n }\n output_area._hv_plot_id = id;\n if ((window.Bokeh !== undefined) && (id in Bokeh.index)) {\n window.PyViz.plot_index[id] = Bokeh.index[id];\n } else {\n window.PyViz.plot_index[id] = null;\n }\n } else if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n var bk_div = document.createElement(\"div\");\n bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n var script_attrs = bk_div.children[0].attributes;\n for (var i = 0; i < script_attrs.length; i++) {\n toinsert[toinsert.length - 1].childNodes[1].setAttribute(script_attrs[i].name, script_attrs[i].value);\n }\n // store reference to server id on output_area\n output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n }\n}\n\n/**\n * Handle when an output is cleared or removed\n */\nfunction handle_clear_output(event, handle) {\n var id = handle.cell.output_area._hv_plot_id;\n var server_id = handle.cell.output_area._bokeh_server_id;\n if (((id === undefined) || !(id in PyViz.plot_index)) && (server_id !== undefined)) { return; }\n var comm = window.PyViz.comm_manager.get_client_comm(\"hv-extension-comm\", 
\"hv-extension-comm\", function () {});\n if (server_id !== null) {\n comm.send({event_type: 'server_delete', 'id': server_id});\n return;\n } else if (comm !== null) {\n comm.send({event_type: 'delete', 'id': id});\n }\n delete PyViz.plot_index[id];\n if ((window.Bokeh !== undefined) & (id in window.Bokeh.index)) {\n var doc = window.Bokeh.index[id].model.document\n doc.clear();\n const i = window.Bokeh.documents.indexOf(doc);\n if (i > -1) {\n window.Bokeh.documents.splice(i, 1);\n }\n }\n}\n\n/**\n * Handle kernel restart event\n */\nfunction handle_kernel_cleanup(event, handle) {\n delete PyViz.comms[\"hv-extension-comm\"];\n window.PyViz.plot_index = {}\n}\n\n/**\n * Handle update_display_data messages\n */\nfunction handle_update_output(event, handle) {\n handle_clear_output(event, {cell: {output_area: handle.output_area}})\n handle_add_output(event, handle)\n}\n\nfunction register_renderer(events, OutputArea) {\n function append_mime(data, metadata, element) {\n // create a DOM node to render to\n var toinsert = this.create_output_subarea(\n metadata,\n CLASS_NAME,\n EXEC_MIME_TYPE\n );\n this.keyboard_manager.register_events(toinsert);\n // Render to node\n var props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n render(props, toinsert[0]);\n element.append(toinsert);\n return toinsert\n }\n\n events.on('output_added.OutputArea', handle_add_output);\n events.on('output_updated.OutputArea', handle_update_output);\n events.on('clear_output.CodeCell', handle_clear_output);\n events.on('delete.Cell', handle_clear_output);\n events.on('kernel_ready.Kernel', handle_kernel_cleanup);\n\n OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n safe: true,\n index: 0\n });\n}\n\nif (window.Jupyter !== undefined) {\n try {\n var events = require('base/js/events');\n var OutputArea = require('notebook/js/outputarea').OutputArea;\n if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n register_renderer(events, OutputArea);\n 
}\n } catch(err) {\n }\n}\n",
+ "application/vnd.holoviews_load.v0+json": ""
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.holoviews_exec.v0+json": "",
+ "text/html": [
+ "
\n",
+ ""
+ ]
+ },
+ "metadata": {
+ "application/vnd.holoviews_exec.v0+json": {
+ "id": "cc773275-e3eb-4a12-93e1-3b0732805f97"
+ }
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# imports\n",
+ "import torch\n",
+ "import panel as pn\n",
+ "from delphi.eval.vis import token_selector\n",
+ "from datasets import load_dataset, Dataset\n",
+ "from transformers import AutoTokenizer\n",
+ "from typing import cast\n",
+ "from delphi.eval.calc_model_group_stats import calc_model_group_stats\n",
+ "from delphi.eval.vis_per_token_model import visualize_selected_tokens\n",
+ "from ipywidgets import interact\n",
+ "from delphi.eval.token_positions import get_all_tok_metrics_in_label\n",
+ "from delphi.eval.vis import vis_pos_map\n",
+ "import ipywidgets as widgets\n",
+ "\n",
+ "# refer to https://panel.holoviz.org/reference/panes/IPyWidget.html to integrate ipywidgets with panel\n",
+ "pn.extension('ipywidgets')\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# specify model names (or checkpoints)\n",
+ "prefix = \"delphi-suite/v0-next-logprobs-llama2-\"\n",
+ "suffixes = [\n",
+ " \"100k\",\n",
+ " \"200k\",\n",
+ " \"400k\",\n",
+ "] # , \"800k\", \"1.6m\", \"3.2m\", \"6.4m\", \"12.8m\", \"25.6m\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load next logprobs data for all models\n",
+ "split = \"validation[:100]\"\n",
+ "next_logprobs = {\n",
+ " suffix: cast(\n",
+ " Dataset,\n",
+ " load_dataset(f\"{prefix}{suffix}\", split=split),\n",
+ " )\n",
+ " .with_format(\"torch\")\n",
+ " .map(lambda x: {\"logprobs\": x[\"logprobs\"].to(device)})\n",
+ " for suffix in suffixes\n",
+ "}\n",
+ "next_logprobs_plot = {k: d[\"logprobs\"] for k, d in next_logprobs.items()}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "# load the tokenized dataset\n",
+ "tokenized_corpus_dataset = (\n",
+ " cast(\n",
+ " Dataset,\n",
+ " load_dataset(\"delphi-suite/stories-tokenized\", split=split),\n",
+ " )\n",
+ " .with_format(\"torch\")\n",
+ " .map(lambda x: {\"tokens\": x[\"tokens\"].to(device)})\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Run this notebook until the following cell, then the rest should work."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 68,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7a90e8b37a674fe4a82fa837c5258335",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "BokehModel(combine_events=True, render_bundle={'docs_json': {'fbcd4d70-83e4-4f3b-a073-702f920d0738': {'version…"
+ ]
+ },
+ "execution_count": 68,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# specific token specification\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"delphi-suite/stories-tokenizer\")\n",
+ "\n",
+ "# Count the frequency of each token using torch.bincount\n",
+ "token_counts = torch.bincount(tokenized_corpus_dataset[\"tokens\"].view(-1))\n",
+ "\n",
+ "# Get the indices that would sort the token counts in descending order\n",
+ "sorted_indices = torch.argsort(token_counts, descending=True)\n",
+ "\n",
+ "# Get the token IDs in descending order of frequency\n",
+ "valid_tok_ids = sorted_indices.tolist()\n",
+ "def format_fix(s):\n",
+ " if s.startswith(\" \"):\n",
+ " return \"_\" + s[1:]\n",
+ " return s\n",
+ "vocab = {format_fix(tokenizer.decode(t, clean_up_tokenization_spaces=True)): t for t in sorted_indices.tolist() if token_counts[t] > 0}\n",
+ "\n",
+ "\n",
+ "selector, selected_ids = token_selector(vocab) # use selected_ids as a dynamic variable\n",
+ "pn.Row(selector, height=500).servable()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Selected IDs: [40, 2, 14]\n"
+ ]
+ }
+ ],
+ "source": [
+ "if not selected_ids:\n",
+ " selected_ids = [40, 2, 14]\n",
+ "print(\"Selected IDs:\", selected_ids)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([100, 513])"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "list(next_logprobs_plot.values())[0].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing model 100k\n",
+ "Processing model 200k\n",
+ "Processing model 400k\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "133ee80dfa814961944d5ddf76f2fc6e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "FigureWidget({\n",
+ " 'data': [{'line': {'width': 0},\n",
+ " 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n",
+ " 'mode': 'lines',\n",
+ " 'name': 'Upper Bound',\n",
+ " 'showlegend': False,\n",
+ " 'type': 'scatter',\n",
+ " 'uid': '7b5d42bd-2dc2-4c7b-b9ef-500259ecc5e6',\n",
+ " 'x': [100k, 200k, 400k],\n",
+ " 'y': array([3.71775752, 3.02412134, 3.20290118])},\n",
+ " {'fill': 'tonexty',\n",
+ " 'fillcolor': 'rgba(68, 68, 68, 0.3)',\n",
+ " 'line': {'width': 0},\n",
+ " 'marker': {'color': 'rgba(68, 68, 68, 0.3)'},\n",
+ " 'mode': 'lines',\n",
+ " 'name': 'Lower Bound',\n",
+ " 'showlegend': False,\n",
+ " 'type': 'scatter',\n",
+ " 'uid': '716e9e70-2be1-4d42-ad85-8c815b6b66b2',\n",
+ " 'x': [100k, 200k, 400k],\n",
+ " 'y': array([0.68709016, 0.66691843, 0.41506469])},\n",
+ " {'marker': {'color': 'rgb(31, 119, 180)', 'line': {'color': 'rgb(31, 119, 180)', 'width': 1}, 'size': 0},\n",
+ " 'mode': 'lines',\n",
+ " 'name': 'Means',\n",
+ " 'type': 'scatter',\n",
+ " 'uid': 'f388f676-370a-45d7-8f75-4d51c27d1728',\n",
+ " 'x': [100k, 200k, 400k],\n",
+ " 'y': array([0.94124633, 0.75591046, 0.50620776])}],\n",
+ " 'layout': {'template': '...'}\n",
+ "})"
+ ]
+ },
+ "execution_count": 71,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model_group_stats = calc_model_group_stats( # i'm not sure if tokenized_corpus_dataset.tolist() is the right input, it was list(tokenized_corpus_dataset) before\n",
+ " tokenized_corpus_dataset, next_logprobs_plot, selected_ids\n",
+ ")\n",
+ "performance_data = {}\n",
+ "for suffix in suffixes:\n",
+ " stats = model_group_stats[suffix]\n",
+ " performance_data[suffix] = (\n",
+ " -stats[\"median\"],\n",
+ " -stats[\"75th\"],\n",
+ " -stats[\"25th\"],\n",
+ " )\n",
+ "\n",
+ "visualize_selected_tokens(performance_data, log_scale=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from ipywidgets import interact_manual"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6ff334e54715402eb32b3a77aef0a830",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "interactive(children=(FloatRangeSlider(value=(0.25, 0.75), description='Quantiles', max=1.0, step=0.05), Dropd…"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def show_pos_map(\n",
+ " quantile: tuple[float, float],\n",
+ " model_name_1: str,\n",
+ " model_name_2: str,\n",
+ "):\n",
+ " logprobs_diff = next_logprobs[model_name_2][\"logprobs\"] - next_logprobs[model_name_1][\"logprobs\"] # type: ignore\n",
+ " pos_to_diff = get_all_tok_metrics_in_label(tokenized_corpus_dataset[\"tokens\"], selected_tokens=selected_ids, metrics=logprobs_diff, q_start=quantile[0], q_end=quantile[1]) # type: ignore\n",
+ " try:\n",
+ " _ = vis_pos_map(list(pos_to_diff.keys()), selected_ids, logprobs_diff, tokenized_corpus_dataset[\"tokens\"], tokenizer) # type: ignore\n",
+ " except ValueError:\n",
+ " if pos_to_diff == {}:\n",
+ " print(\"No tokens found in this label\")\n",
+ " return\n",
+ "\n",
+ "\n",
+ "interact_manual(\n",
+ " show_pos_map,\n",
+ " quantile=widgets.FloatRangeSlider(\n",
+ " min=0.0, max=1.0, step=0.05, description=\"Quantiles\"\n",
+ " ),\n",
+ " samples=widgets.IntSlider(min=1, max=5, description=\"Samples\", value=2),\n",
+ " model_name_1=widgets.Dropdown(\n",
+ " options=suffixes,\n",
+ " description=\"Model 1\",\n",
+ " value=\"100k\",\n",
+ " ),\n",
+ " model_name_2=widgets.Dropdown(\n",
+ " options=suffixes,\n",
+ " description=\"Model 2\",\n",
+ " value=\"200k\",\n",
+ " ),\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/delphi/eval/calc_model_group_stats.py b/src/delphi/eval/calc_model_group_stats.py
index d9c5d4c1..faab8a02 100644
--- a/src/delphi/eval/calc_model_group_stats.py
+++ b/src/delphi/eval/calc_model_group_stats.py
@@ -1,33 +1,30 @@
import numpy as np
+import torch
+from datasets import Dataset
+from jaxtyping import Float
def calc_model_group_stats(
- tokenized_corpus_dataset: list,
- logprobs_by_dataset: dict[str, list[list[float]]],
- token_labels_by_token: dict[int, dict[str, bool]],
- token_labels: list[str],
-) -> dict[tuple[str, str], dict[str, float]]:
+ tokenized_corpus_dataset: Dataset,
+ logprobs_by_dataset: dict[str, torch.Tensor],
+ selected_tokens: list[int],
+) -> dict[str, dict[str, float]]:
"""
For each (model, token group) pair, calculate useful stats (for visualization)
args:
- - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"]
+ - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset)["validation"]
- logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
- - token_groups: a dict of token groups, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...}
- - models: a list of model names, e.g. constants.LLAMA2_MODELS
- - token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...]
+ - selected_tokens: a list of selected token IDs, e.g. [46, 402, ...]
- returns: a dict of (model, token group) pairs to a dict of stats,
- e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...}
+ returns: a dict with model names as keys and a stats dict as each value,
+ e.g. {"100k": {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...}
- Technically `models` and `token_labels` are redundant, as they are also keys in `logprob_datasets` and `token_groups`,
- but it's better to be explicit
-
- stats calculated: mean, median, min, max, 25th percentile, 75th percentile
+ Stats calculated: mean, median, min, max, 25th percentile, 75th percentile
"""
model_group_stats = {}
for model in logprobs_by_dataset:
- group_logprobs = {}
+ model_logprobs = []
print(f"Processing model {model}")
dataset = logprobs_by_dataset[model]
for ix_doc_lp, document_lps in enumerate(dataset):
@@ -35,20 +32,17 @@ def calc_model_group_stats(
for ix_token, token in enumerate(tokens):
if ix_token == 0: # skip the first token, which isn't predicted
continue
- logprob = document_lps[ix_token]
- for token_group_desc in token_labels:
- if token_labels_by_token[token][token_group_desc]:
- if token_group_desc not in group_logprobs:
- group_logprobs[token_group_desc] = []
- group_logprobs[token_group_desc].append(logprob)
- for token_group_desc in token_labels:
- if token_group_desc in group_logprobs:
- model_group_stats[(model, token_group_desc)] = {
- "mean": np.mean(group_logprobs[token_group_desc]),
- "median": np.median(group_logprobs[token_group_desc]),
- "min": np.min(group_logprobs[token_group_desc]),
- "max": np.max(group_logprobs[token_group_desc]),
- "25th": np.percentile(group_logprobs[token_group_desc], 25),
- "75th": np.percentile(group_logprobs[token_group_desc], 75),
- }
+ logprob = document_lps[ix_token].item()
+ if token in selected_tokens:
+ model_logprobs.append(logprob)
+
+ if model_logprobs:
+ model_group_stats[model] = {
+ "mean": np.mean(model_logprobs),
+ "median": np.median(model_logprobs),
+ "min": np.min(model_logprobs),
+ "max": np.max(model_logprobs),
+ "25th": np.percentile(model_logprobs, 25),
+ "75th": np.percentile(model_logprobs, 75),
+ }
return model_group_stats
diff --git a/src/delphi/eval/constants.py b/src/delphi/eval/constants.py
index 5cd3daf1..3a586e00 100644
--- a/src/delphi/eval/constants.py
+++ b/src/delphi/eval/constants.py
@@ -2,15 +2,15 @@
tokenized_corpus_dataset = "delphi-suite/tinystories-v2-clean-tokenized-v0"
LLAMA2_MODELS = [
- "delphi-llama2-100k",
- "delphi-llama2-200k",
- "delphi-llama2-400k",
- "delphi-llama2-800k",
- "delphi-llama2-1.6m",
- "delphi-llama2-3.2m",
- "delphi-llama2-6.4m",
- "delphi-llama2-12.8m",
- "delphi-llama2-25.6m",
+ "llama2-100k",
+ "llama2-200k",
+ "llama2-400k",
+ "llama2-800k",
+ "llama2-1.6m",
+ "llama2-3.2m",
+ "llama2-6.4m",
+ "llama2-12.8m",
+ "llama2-25.6m",
]
LLAMA2_NEXT_LOGPROBS_DATASETS_MAP = {
diff --git a/src/delphi/eval/token_positions.py b/src/delphi/eval/token_positions.py
index 5239a53f..a98af761 100644
--- a/src/delphi/eval/token_positions.py
+++ b/src/delphi/eval/token_positions.py
@@ -1,8 +1,6 @@
-from numbers import Number
-from typing import Optional, cast
+from typing import Optional
import torch
-from datasets import Dataset
from jaxtyping import Int
from delphi.eval.utils import dict_filter_quantile
@@ -10,9 +8,8 @@
def get_all_tok_metrics_in_label(
token_ids: Int[torch.Tensor, "prompt pos"],
- token_labels: dict[int, dict[str, bool]],
+ selected_tokens: list[int],
metrics: torch.Tensor,
- label: str,
q_start: Optional[float] = None,
q_end: Optional[float] = None,
) -> dict[tuple[int, int], float]:
@@ -23,9 +20,8 @@ def get_all_tok_metrics_in_label(
Args:
- token_ids (Dataset): token_ids dataset e.g. token_ids[0] = {"tokens": [[1, 2, ...], [2, 5, ...], ...]}
- - token_labels (dict[int, dict[str, bool]]): dictionary of token labels e.g. { 0: {"Is Noun": True, "Is Verb": False}, ...}
+ - selected_tokens (list[int]): list of token IDs to search for e.g. [46, 402, ...]
- metrics (torch.Tensor): tensor of metrics to search through e.g. torch.tensor([[0.1, 0.2, ...], [0.3, 0.4, ...], ...])
- - label (str): the label to search for
- q_start (float): the start of the quantile range to filter the metrics e.g. 0.1
- q_end (float): the end of the quantile range to filter the metrics e.g. 0.9
@@ -42,7 +38,7 @@ def get_all_tok_metrics_in_label(
tok_positions = {}
for prompt_pos, prompt in enumerate(token_ids.numpy()):
for tok_pos, tok in enumerate(prompt):
- if token_labels[tok][label]:
+ if tok in selected_tokens:
tok_positions[(prompt_pos, tok_pos)] = metrics[
prompt_pos, tok_pos
].item()
diff --git a/src/delphi/eval/vis.py b/src/delphi/eval/vis.py
index 1a69eae2..f13924ad 100644
--- a/src/delphi/eval/vis.py
+++ b/src/delphi/eval/vis.py
@@ -3,6 +3,7 @@
import uuid
from typing import cast
+import numpy as np
import panel as pn
import torch
from IPython.core.display import HTML
@@ -54,6 +55,7 @@ def token_to_html(
tokenizer: PreTrainedTokenizerBase,
bg_color: str,
data: dict,
+ class_name: str = "token",
) -> str:
data = data or {} # equivalent to if not data: data = {}
# non-breakable space, w/o it leading spaces wouldn't be displayed
@@ -73,6 +75,7 @@ def token_to_html(
        br += "<br>"
# this is so we can copy the prompt without "\n"s
specific_styles["user-select"] = "none"
+    str_token = str_token.replace("<", "&lt;").replace(">", "&gt;")
style_str = data_str = ""
# converting style dict into the style attribute
@@ -83,7 +86,7 @@ def token_to_html(
data_str = "".join(
f" data-{k}='{v.replace(' ', ' ')}'" for k, v in data.items()
)
-    return f"<div class='token' style='{style_str}'{data_str}>{str_token}</div>{br}"
+    return f"<div class='{class_name}' style='{style_str}'{data_str}>{str_token}</div>{br}"
_token_style = {
@@ -97,7 +100,20 @@ def token_to_html(
"margin": "1px 0px 1px 1px",
"padding": "0px 1px 1px 1px",
}
+_token_emphasized_style = {
+ "border": "3px solid #888",
+ "display": "inline-block",
+ "font-family": "monospace",
+ "font-size": "14px",
+ "color": "black",
+ "background-color": "white",
+ "margin": "1px 0px 1px 1px",
+ "padding": "0px 1px 1px 1px",
+}
_token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()])
+_token_emphasized_style_str = " ".join(
+ [f"{k}: {v};" for k, v in _token_emphasized_style.items()]
+)
def vis_sample_prediction_probs(
@@ -130,8 +146,8 @@ def vis_sample_prediction_probs(
data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer)
token_htmls.append(
- token_to_html(tok, tokenizer, bg_color=colors[i], data=data).replace(
- "class='token'", f"class='{token_class}'"
+ token_to_html(
+ tok, tokenizer, bg_color=colors[i], data=data, class_name=token_class
)
)
@@ -165,10 +181,11 @@ def vis_sample_prediction_probs(
def vis_pos_map(
- pos_map: dict[tuple[int, int], float | int],
+ pos_list: list[tuple[int, int]],
+ selected_tokens: list[int],
+ metrics: Float[torch.Tensor, "prompt pos"],
token_ids: Int[torch.Tensor, "prompt pos"],
tokenizer: PreTrainedTokenizerBase,
- sample: int = 3,
):
"""
Randomly sample from pos_map and visualize the loss diff at the corresponding position.
@@ -176,49 +193,43 @@ def vis_pos_map(
token_htmls = []
unique_id = str(uuid.uuid4())
- token_class = f"token_{unique_id}"
+ token_class = f"pretoken_{unique_id}"
+ selected_token_class = f"token_{unique_id}"
hover_div_id = f"hover_info_{unique_id}"
- # choose n random keys from pos_map
- keys = random.sample(list(pos_map.keys()), k=sample)
-
- for key in keys:
- prompt, pos = key
- pre_toks = token_ids[prompt][:pos]
- mask = torch.isin(pre_toks, torch.tensor([0, 1], dtype=torch.int8))
- pre_toks = pre_toks[
- ~mask
-        ] # remove <unk> and <s> tokens, cause strikethrough in html
-
- for i in range(pre_toks.shape[0]):
- pre_tok = cast(int, pre_toks[i].item())
- token_htmls.append(
- token_to_html(pre_tok, tokenizer, bg_color="white", data={}).replace(
- "class='token'", f"class='{token_class}'"
- )
- )
+    # choose a random key from pos_list
+ key = random.choice(pos_list)
- tok = cast(int, token_ids[prompt][pos].item())
- value = cast(float, pos_map[key])
+ prompt, pos = key
+ all_toks = token_ids[prompt][: pos + 1]
+ for i in range(all_toks.shape[0]):
+ token_id = cast(int, all_toks[i].item())
+ value = metrics[prompt][i].item()
token_htmls.append(
token_to_html(
- tok,
+ token_id,
tokenizer,
- bg_color=single_loss_diff_to_color(value),
+ bg_color="white"
+ if np.isnan(value)
+ else single_loss_diff_to_color(value),
data={"loss-diff": f"{value:.2f}"},
- ).replace("class='token'", f"class='{token_class}'")
+ class_name=token_class
+ if token_id not in selected_tokens
+ else selected_token_class,
+ )
)
-        # add break line
-        token_htmls.append("<br><br>")
+    # add break line
+    token_htmls.append("<br><br>")
html_str = f"""
-    <style>.{token_class} {{ {_token_style_str} }}</style>
+    <style>.{token_class} {{ {_token_style_str} }} .{selected_token_class} {{ {_token_emphasized_style_str} }}</style>
{"".join(token_htmls)}
"""
display(HTML(html_str))
- return html_str
def token_selector(
vocab_map: dict[str, int]
) -> tuple[pn.widgets.MultiChoice, list[int]]:
tokens = list(vocab_map.keys())
- token_selector = pn.widgets.MultiChoice(name="Tokens", options=tokens)
- token_ids = [vocab_map[token] for token in cast(list[str], token_selector.value)]
+ token_selector_ = pn.widgets.MultiChoice(name="Tokens", options=tokens)
+ token_ids = [vocab_map[token] for token in cast(list[str], token_selector_.value)]
def update_tokens(event):
token_ids.clear()
token_ids.extend([vocab_map[token] for token in event.new])
- token_selector.param.watch(update_tokens, "value")
- return token_selector, token_ids
+ token_selector_.param.watch(update_tokens, "value")
+ return token_selector_, token_ids
diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py
index 8daaa96f..e5d735f4 100644
--- a/src/delphi/eval/vis_per_token_model.py
+++ b/src/delphi/eval/vis_per_token_model.py
@@ -1,12 +1,11 @@
from typing import Union
-import ipywidgets
import numpy as np
import plotly.graph_objects as go
-def visualize_per_token_category(
- input: dict[Union[str, int], dict[str, tuple]],
+def visualize_selected_tokens(
+ input: dict[Union[str, int], tuple[float, float, float]],
log_scale=False,
line_metric="Means",
checkpoint_mode=True,
@@ -17,18 +16,16 @@ def visualize_per_token_category(
background_color="AliceBlue",
) -> go.FigureWidget:
input_x = list(input.keys())
- categories = list(input[input_x[0]].keys())
- category = categories[0]
def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]
- def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
- x = np.array([input[x][category] for x in input_x]).T
+ def get_plot_values() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ x = np.array([input[x] for x in input_x]).T
means, err_lo, err_hi = x[0], x[1], x[2]
return means, err_lo, err_hi
- means, err_lo, err_hi = get_plot_values(category)
+ means, err_lo, err_hi = get_plot_values()
if checkpoint_mode:
scatter_plot = go.Figure(
diff --git a/tests/eval/test_token_positions.py b/tests/eval/test_token_positions.py
index 1adef6b7..c584b931 100644
--- a/tests/eval/test_token_positions.py
+++ b/tests/eval/test_token_positions.py
@@ -12,40 +12,43 @@ def mock_data():
token_ids = Dataset.from_dict(
{"tokens": [[1, 2, 3], [4, 5, 6], [7, 8, 9]]}
).with_format("torch")
- token_labels = {
- 1: {"Is Noun": False, "Is Verb": True},
- 2: {"Is Noun": True, "Is Verb": True},
- 3: {"Is Noun": False, "Is Verb": False},
- 4: {"Is Noun": True, "Is Verb": False},
- 5: {"Is Noun": False, "Is Verb": True},
- 6: {"Is Noun": True, "Is Verb": True},
- 7: {"Is Noun": False, "Is Verb": False},
- 8: {"Is Noun": True, "Is Verb": False},
- 9: {"Is Noun": False, "Is Verb": True},
- }
- metrics = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
- return token_ids, token_labels, metrics
+ selected_tokens = [2, 4, 6, 8]
+ metrics = torch.tensor([[-1, 0.45, -0.33], [-1.31, 2.3, 0.6], [0.2, 0.8, 0.1]])
+ return token_ids, selected_tokens, metrics
def test_get_all_tok_metrics_in_label(mock_data):
- token_ids, token_labels, metrics = mock_data
+ token_ids, selected_tokens, metrics = mock_data
result = get_all_tok_metrics_in_label(
- token_ids["tokens"], token_labels, metrics, "Is Noun"
+ token_ids["tokens"],
+ selected_tokens,
+ metrics,
)
+ # key: (prompt_pos, tok_pos), value: logprob
expected = {
- (0, 1): 0.2,
- (1, 0): 0.4,
+ (0, 1): 0.45,
+ (1, 0): -1.31,
(1, 2): 0.6,
(2, 1): 0.8,
}
- # use isclose to compare floating point numbers
+
+ # compare keys
+ assert result.keys() == expected.keys()
+ # compare values
for k in result:
assert isclose(cast(float, result[k]), expected[k], rel_tol=1e-6) # type: ignore
# test with quantile filtering
result_q = get_all_tok_metrics_in_label(
- token_ids["tokens"], token_labels, metrics, "Is Noun", q_start=0.3, q_end=1.0
+ token_ids["tokens"], selected_tokens, metrics, q_start=0.6, q_end=1.0
)
- expected_q = {(1, 2): 0.6, (2, 1): 0.8, (1, 0): 0.4}
+ expected_q = {
+ (1, 2): 0.6,
+ (2, 1): 0.8,
+ }
+
+ # compare keys
+ assert result_q.keys() == expected_q.keys()
+ # compare values
for k in result_q:
assert isclose(cast(float, result_q[k]), expected_q[k], rel_tol=1e-6) # type: ignore
diff --git a/tests/eval/test_utils_eval.py b/tests/eval/test_utils_eval.py
index ad0f54b8..a259d16b 100644
--- a/tests/eval/test_utils_eval.py
+++ b/tests/eval/test_utils_eval.py
@@ -61,7 +61,22 @@ def test_load_validation_dataset():
def test_dict_filter_quantile():
d = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4, 5: 0.5}
result = dict_filter_quantile(d, 0.2, 0.6)
- expected = {2: 0.2, 3: 0.3, 4: 0.4}
+ expected = {2: 0.2, 3: 0.3}
+
+ # compare keys
+ assert result.keys() == expected.keys()
+ # compare values
+ for k in result:
+ assert isclose(result[k], expected[k], rel_tol=1e-6)
+
+ # test with negative values
+ d = {1: -0.1, 2: -0.2, 3: -0.3, 4: -0.4, 5: -0.5}
+ result = dict_filter_quantile(d, 0.2, 0.6)
+ expected = {3: -0.3, 4: -0.4}
+
+ # compare keys
+ assert result.keys() == expected.keys()
+ # compare values
for k in result:
assert isclose(result[k], expected[k], rel_tol=1e-6)