From ae7a97e0b3444f954666a6c4ca280720786ebd40 Mon Sep 17 00:00:00 2001 From: JaiDhyani Date: Fri, 23 Feb 2024 02:35:33 -0800 Subject: [PATCH] add hovertext --- src/delphi/eval/vis_per_token_model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py index a8e269fe..618840b0 100644 --- a/src/delphi/eval/vis_per_token_model.py +++ b/src/delphi/eval/vis_per_token_model.py @@ -10,6 +10,9 @@ def visualize_per_token_category( categories = list(input[model_names[0]].keys()) category = categories[0] + def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]: + return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)] + def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: x = np.array([input[name][category] for name in model_names]).T means, err_lo, err_hi = x[0], x[1], x[2] @@ -32,6 +35,8 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: size=15, line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2), ), + hovertext=get_hovertexts(means, err_low, err_hi), + hoverinfo="text+x", ), layout=go.Layout( yaxis=dict( @@ -55,6 +60,7 @@ def response(change): g.data[0].y = means g.data[0].error_y["array"] = err_hi g.data[0].error_y["arrayminus"] = err_lo + g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi) selected_category.observe(response, names="value")