Skip to content

Commit

Permalink
add hovertext
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidhyani committed Feb 23, 2024
1 parent 8e6ab2e commit ae7a97e
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/delphi/eval/vis_per_token_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ def visualize_per_token_category(
categories = list(input[model_names[0]].keys())
category = categories[0]

def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]

def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
x = np.array([input[name][category] for name in model_names]).T
means, err_lo, err_hi = x[0], x[1], x[2]
Expand All @@ -32,6 +35,8 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
size=15,
line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2),
),
hovertext=get_hovertexts(means, err_low, err_hi),
hoverinfo="text+x",
),
layout=go.Layout(
yaxis=dict(
Expand All @@ -55,6 +60,7 @@ def response(change):
g.data[0].y = means
g.data[0].error_y["array"] = err_hi
g.data[0].error_y["arrayminus"] = err_lo
g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi)

selected_category.observe(response, names="value")

Expand Down

0 comments on commit ae7a97e

Please sign in to comment.