Skip to content

Commit

Permalink
changed colours of plots
Browse files Browse the repository at this point in the history
changed colours of plots
  • Loading branch information
GretaVilla committed Aug 19, 2022
1 parent 291db62 commit 7be4853
Show file tree
Hide file tree
Showing 7 changed files with 54,205 additions and 54,272 deletions.
46 changes: 25 additions & 21 deletions bctools/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .thresholds import get_optimized_thresholds_df

def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01, title = "Interactive Probabilities Violin Plot"):
def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01, marker_size = 3, title = "Interactive Probabilities Violin Plot"):

"""
Plots interactive and customized violin plots of predicted probabilties with plotly,
Expand All @@ -37,6 +37,8 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
threshold_step: float, default=0.01
step between each classification threshold (ranging from 0 to 1) below which prediction label is 0, 1 otherwise
each value will have a corresponding slider step
marker_size: int, default=3
Size of the points to be plotted
title: str, default='Interactive probabilities Violin Plot'
The main title of the plot.
"""
Expand All @@ -55,7 +57,7 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
main_title = f"<b>{title}</b><br>"

# VIOLIN PLOT
full_fig=go.Figure(data=go.Violin(y=data_df['pred'], x=data_df['class'], line_color='black',
full_fig=go.Figure(data=go.Violin(y=data_df['pred'], x=data_df['class'], line_color='#0D2A63',
meanline_visible=True, points=False, fillcolor=None, opacity=0.3, box=None,
scalemode='count', showlegend = False))

Expand Down Expand Up @@ -86,24 +88,26 @@ def predicted_proba_violin_plot(true_y, predicted_proba, threshold_step = 0.01,
titles[threshold] = titles[threshold][:-3] #removes last 3 char (2 spaces and comma)

# NOTE: px strip generates n plots, one for each color class (TN, FP, FN, TP) it finds)
strip_points_fig = px.strip(data_df, x='class', y='pred', color= threshold_string,
color_discrete_map = {'FN':'red', 'FP':'mediumpurple',
'TP':'green', 'TN':'blue'},
strip_points_fig = px.strip(data_df, x='class', y='pred', color=threshold_string,
color_discrete_map = {'FN':'#EF71D9', 'FP':'#EF553B',
'TP':'#00CC96', 'TN':'#636EFA'},
log_y=True, width=550, height=550, hover_data = [data_df.index])

strip_points_fig.update_traces(hovertemplate = 'Idx = %{customdata}<br>Class = %{x}<br>Pred = %{y}', jitter = 1, marker_size=3)
strip_points_fig.update_traces(hovertemplate = 'Idx = %{customdata}<br>Class = %{x}<br>Pred = %{y}', jitter = 1, marker_size=marker_size)

length_fig_list.append(len(strip_points_fig.data))

for i in range(len(strip_points_fig.data)):
strip_points_fig.data[i].visible=False
full_fig.add_trace(strip_points_fig.data[i])

full_fig.add_traces(list(strip_points_fig.select_traces()))

full_fig.update_layout(legend_font_size=9.5, legend_itemsizing='constant', legend_traceorder='grouped',
title=dict(text = main_title + '<span style="font-size: 13px;">' \
+ titles[threshold_values[0]] + '</span>',
y = 0.965, yanchor = 'bottom'),
width=550, height=550)

full_fig.update_layout(margin=dict(l=40, r=40, t=60, b=40))

# makes visible the first strip points figure
Expand Down Expand Up @@ -201,11 +205,10 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
"Recall":recall.tolist(),
"Precision":precision.tolist()})

pr_fig=px.line(curve_df, x="Recall", y="Precision", hover_data=["Thresholds"], title=main_title)
full_fig=px.line(curve_df, x="Recall", y="Precision", hover_data=["Thresholds"], title=main_title)

pr_fig.update_traces(hovertemplate='Threshold: %{customdata:.4f} <br>Precision: %{y:.4f} <br>Recall: %{x:.4f}<extra></extra>')
pr_fig.update_traces(line_color='#222A2A', line_width=2, textposition="top center")
full_fig = pr_fig
full_fig.update_traces(hovertemplate='Threshold: %{customdata:.4f} <br>Precision: %{y:.4f} <br>Recall: %{x:.4f}<extra></extra>')
full_fig.update_traces(textposition="top center")

f_scores = np.linspace(0.2, 0.8, num=4)

Expand All @@ -222,18 +225,19 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C

iso_fig=px.line(recall_precision_df, x="recall", y="precision")
iso_fig.update_traces(hovertemplate=[]) # no hover info displayed but keeps dashed lines
iso_fig.update_traces(line_color='#778AAE', line=dict(dash='dot'), line_width=0.3)
iso_fig.update_traces(line_color='#4C78A8', line=dict(dash='dot'), line_width=0.8)

full_fig.add_annotation(x=0.90, y=y[45] + 0.01, text="f"+ str(beta) + "={0:0.1f}".format(f_score),
showarrow=False,yshift=10)
full_fig=go.Figure(data = full_fig.data + iso_fig.data, layout = full_fig.layout)

full_fig.add_traces(list(iso_fig.select_traces()))

area_under_pr_curve = auc(recall, precision)

full_fig.update_xaxes(range=[0.0, 1.0],title_text='Recall')
full_fig.update_yaxes(range=[0.0, 1.05],title_text='Precision')

full_fig.add_shape(type='line', line=dict(dash='dash'),x0=0, x1=1, y0=baseline, y1=baseline)
full_fig.add_shape(type='line', line=dict(dash='dash', color = '#20313e'),x0=0, x1=1, y0=baseline, y1=baseline)

full_fig['data'][0]['showlegend']= True
full_fig['data'][1]['showlegend']= True
Expand All @@ -244,8 +248,8 @@ def curve_PR_plot(true_y, predicted_proba, beta = 1, title = "Precision Recall C
legend_font_size=9.5,
width=550, height=550)

full_fig.update_xaxes(showspikes=True)
full_fig.update_yaxes(showspikes=True)
full_fig.update_xaxes(showspikes=True, spikedash = 'dot', spikethickness=2)
full_fig.update_yaxes(showspikes=True, spikedash = 'dot', spikethickness=2)
full_fig.update_layout(margin=dict(l=40, r=40, t=40, b=40))
full_fig.show()

Expand Down Expand Up @@ -292,22 +296,22 @@ def curve_ROC_plot(true_y, predicted_proba, title = "Receiver Operating Characte
hover_data=["Thresholds"],
width=550, height=550)

fig.update_traces(line_color="#222A2A", line_width=2, textposition="top center")
fig.update_traces(textposition="top center")
fig.update_traces(hovertemplate='Threshold: %{customdata:.4f} <br>False Positive Rate: %{x:.4f} <br>True Positive Rate: %{y:.4f}<extra></extra>')

fig.add_shape(type="line", line=dict(dash="dash"),
fig.add_shape(type="line", line=dict(dash="dash", color = '#20313e'),
x0=0, x1=1, y0=0, y1=1)

area_under_ROC_curve = auc(fpr, tpr)

fig["data"][0]["name"]= f"ROC Curve (AUC={area_under_ROC_curve:.3f})"

fig["data"][0]["showlegend"]= True

fig.update_layout(legend = dict(yanchor="top", y=0.20, xanchor="left", x=0.5),
legend_font_size=9.5)

fig.update_xaxes(showspikes=True)
fig.update_yaxes(showspikes=True)
fig.update_xaxes(showspikes=True, spikedash = 'dot', spikethickness=2)
fig.update_yaxes(showspikes=True, spikedash = 'dot', spikethickness=2)

fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_xaxes(range=[0,1], constrain="domain")
Expand Down
54,250 changes: 0 additions & 54,250 deletions example-notebook/example_classification_model.ipynb

This file was deleted.

Loading

0 comments on commit 7be4853

Please sign in to comment.