Skip to content

Commit

Permalink
Add color option for plot
Browse files Browse the repository at this point in the history
  • Loading branch information
ncaptier committed Dec 3, 2024
1 parent 8e4ddde commit 389dd11
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions multipit/result_analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def plot_metrics(
ylim=None,
y_text=None,
ax=None,
colors=None,
):
"""
Plot the results of the repeated cross-validation experiments for different models with a barplot.
Expand Down Expand Up @@ -58,6 +59,10 @@ def plot_metrics(
ax : matplotlib.axes, None
The default is None.
colors: palette name, list, or dict
Colors to use for the different level of the hue variable (i.e., either for the different metrics when metrics
is a list or for the different models when metrics is a string). The default is None.
Returns
-------
matplotlib.pyplot.figure
Expand All @@ -80,6 +85,7 @@ def plot_metrics(
hue_order=metrics,
ax=ax,
errorbar=None,
palette=colors
)

for i, m in enumerate(metrics):
Expand All @@ -96,8 +102,9 @@ def plot_metrics(

elif isinstance(metrics, str):
df_plot = results[results["metric"] == metrics].melt(id_vars=["metric"])
pal = "tab20" if colors is None else colors
sns.barplot(
data=df_plot, x="variable", y="value", hue="variable", legend=False, ax=ax, errorbar=None, palette="tab20"
data=df_plot, x="variable", y="value", hue="variable", legend=False, ax=ax, errorbar=None, palette=pal
)
ax.errorbar(
x=np.arange(len(models)),
Expand Down Expand Up @@ -180,7 +187,7 @@ def plot_metrics(
# annotator.set_pvalues_and_annotate(pvalues)

plt.tight_layout()
plt.show()
#plt.show()
return fig


Expand Down

0 comments on commit 389dd11

Please sign in to comment.