diff --git a/plotsandgraphs/binary_classifier.py b/plotsandgraphs/binary_classifier.py index 168cca5..3f81c01 100644 --- a/plotsandgraphs/binary_classifier.py +++ b/plotsandgraphs/binary_classifier.py @@ -368,7 +368,7 @@ def plot_roc_curve( def plot_calibration_curve( y_prob: np.ndarray, y_true: np.ndarray, n_bins=10, save_fig_path=None -): +) -> Figure: """ Creates calibration plot for a binary classifier and calculates the ECE. @@ -462,7 +462,7 @@ def plot_calibration_curve( path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(save_fig_path, bbox_inches="tight") - return fig, ece + return fig def plot_y_prob_histogram(y_prob: np.ndarray, y_true: Optional[np.ndarray]=None, save_fig_path=None) -> Figure: