-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplotting.py
28 lines (22 loc) · 1.01 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
def plot_character_confusions(confusion_stats_df: pd.DataFrame, char: str, top_n: int = 5):
    """Build a bar chart of the characters most often generated in place of ``char``.

    Rows of ``confusion_stats_df`` whose ``'correct'`` column equals ``char``
    are ranked by their ``'ratio'`` column (highest first). The first ``top_n``
    are plotted, with ``'generated'`` on the x-axis and ``'ratio'`` on the
    y-axis.

    Args:
        confusion_stats_df: Table with at least the columns ``'correct'``,
            ``'generated'`` and ``'ratio'``.
        char: Ground-truth character whose confusions are plotted.
        top_n: Maximum number of confusions (bars) to show.

    Returns:
        A ``plotly.graph_objects.Figure``. If ``char`` has no rows, a message
        is printed and an empty figure is returned.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    char_confusions = confusion_stats_df[confusion_stats_df['correct'] == char]

    # No confusions for this character: report it and hand back an empty
    # figure so callers can still treat the result uniformly.
    if char_confusions.empty:
        print(f"No confusions found for character '{char}'.")
        return go.Figure()

    # Rank by ratio (highest first) and keep the top_n. A stable sort keeps
    # the original row order among equal ratios, so ties render
    # deterministically.
    char_confusions = char_confusions.sort_values(
        'ratio', ascending=False, kind='stable'
    ).head(top_n)
    n_bars = len(char_confusions)

    # One colour per bar. The Plasma palette has a fixed number of entries,
    # so cycle through it instead of passing the whole list. A colour array
    # shorter than the data would leave the extra bars without a palette
    # colour.
    palette = px.colors.sequential.Plasma
    bar_colors = [palette[i % len(palette)] for i in range(n_bars)]

    fig = go.Figure(data=[
        go.Bar(
            x=char_confusions['generated'],
            y=char_confusions['ratio'],
            marker_color=bar_colors,
            name='Ratio',
        )
    ])
    fig.update_layout(
        title_text=f"Top {n_bars} Confusions for Character '{char}'",
        xaxis_title="Generated Character",
        yaxis_title="Ratio",
    )
    # Characters are labels, not numbers. Without this, plotly may autotype
    # digit-like characters ('0', '1', ...) as a numeric axis, placing bars
    # by value and losing the ratio ordering.
    fig.update_xaxes(type='category')
    return fig