# Provenance: contents of src/delphi/eval/vis_per_token_model.py as created by
# git patch 10e557bc3c43c627d27babcbd6aea2248b3bfb0a
#   From: Siwei Li
#   Date: Wed, 14 Feb 2024 17:03:25 -0800
#   Subject: [PATCH] Add visualization function for per token model comparison
# The same patch also renames data/mock-per-token-performance.py to
# src/delphi/dataset/mock_per_token_performance.py (similarity index 100%).
"""Interactive per-token-category comparison plot across models."""

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact


def visualize_per_token_category(input, *, ylim=(-5, 5)):
    """Show an interactive plot comparing models on one token category.

    A dropdown lists the categories; selecting one draws, for every model,
    the mean (line plot) and the median with first/third-quartile error bars.

    Args:
        input: Nested mapping indexed as ``input[model_name][category]``.
            Each leaf is fed to ``np.array`` and reduced per model, so it is
            presumably an equal-length sequence of numbers for every model --
            TODO confirm against the data producer. The categories offered
            are the keys of the first model's mapping. (The name shadows the
            ``input`` builtin; it is kept because it is part of the public
            signature.)
        ylim: ``(bottom, top)`` limits applied to the y axis. Defaults to
            ``(-5, 5)``, the range that used to be hard-coded. Pass ``None``
            to leave matplotlib's autoscaling in effect.

    Raises:
        IndexError: If ``input`` contains no models.
    """
    model_names = list(input.keys())
    # Categories are taken from the first model only.
    categories = list(list(input.values())[0].keys())

    def _plot_category(category):
        # After the transpose each column holds one model's values, so every
        # axis=0 reduction below yields one statistic per model.
        samples = np.array([input[name][category] for name in model_names]).T
        means = np.mean(samples, axis=0)
        median = np.median(samples, axis=0)
        q1 = np.quantile(samples, 0.25, axis=0)
        q3 = np.quantile(samples, 0.75, axis=0)

        if ylim is not None:
            plt.gca().set_ylim(ylim)

        plt.plot(model_names, means)
        # Asymmetric error bars: distance from the median down to Q1 and up to Q3.
        plt.errorbar(
            model_names,
            median,
            yerr=[median - q1, q3 - median],
            fmt="o",
        )

    interact(
        _plot_category,
        category=widgets.Dropdown(
            options=categories,
            description="Token Category:",
            disabled=False,
        ),
    )


# Usage:
# from dataset.mock_per_token_performance import performance_data
# visualize_per_token_category(performance_data)
# NOTE(review): the original usage comment imported `performance_datas` but
# passed `performance_data`; confirm the name exported by the mock module.