From 8582141cba7f7dc0c7ee7c313fb077dcea06bc0f Mon Sep 17 00:00:00 2001 From: Niels Nuyttens Date: Thu, 18 Jul 2024 23:50:24 +0200 Subject: [PATCH] Fix average precision calculation replace forgotten model_output_column_names() with class_probability_columns property --- .../metrics/multiclass_classification.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/nannyml/performance_calculation/metrics/multiclass_classification.py b/nannyml/performance_calculation/metrics/multiclass_classification.py index 0c3657df..a8f39f5b 100644 --- a/nannyml/performance_calculation/metrics/multiclass_classification.py +++ b/nannyml/performance_calculation/metrics/multiclass_classification.py @@ -16,7 +16,7 @@ precision_score, recall_score, roc_auc_score, - average_precision_score + average_precision_score, ) from sklearn.preprocessing import LabelBinarizer, label_binarize @@ -43,7 +43,7 @@ ap_sampling_error_components, ap_sampling_error, bv_sampling_error_components, - bv_sampling_error + bv_sampling_error, ) from nannyml.thresholds import Threshold, calculate_threshold_values @@ -106,7 +106,7 @@ def _fit(self, reference_data: pd.DataFrame): _list_missing([self.y_true] + self.class_probability_columns, list(reference_data.columns)) reference_data, empty = common_nan_removal( reference_data[[self.y_true] + self.class_probability_columns], - [self.y_true] + self.class_probability_columns + [self.y_true] + self.class_probability_columns, ) if empty: self._sampling_error_components = [(np.NaN, 0) for clasz in self.classes] @@ -120,7 +120,8 @@ def _fit(self, reference_data: pd.DataFrame): "targets." ) raise InvalidArgumentsException( - "y_pred_proba class and class probabilities dictionary does not match reference data.") + "y_pred_proba class and class probabilities dictionary does not match reference data." + ) # sampling error binarized_y_true = list(label_binarize(reference_data[self.y_true], classes=self.classes).T) @@ -978,7 +979,7 @@ def _fit(self, reference_data: pd.DataFrame): _list_missing([self.y_true] + self.class_probability_columns, list(reference_data.columns)) reference_data, empty = common_nan_removal( reference_data[[self.y_true] + self.class_probability_columns], - [self.y_true] + self.class_probability_columns + [self.y_true] + self.class_probability_columns, ) if empty: self._sampling_error_components = [(np.NaN, 0) for class_col in self.class_probability_columns] @@ -1022,10 +1023,9 @@ def _calculate(self, data: pd.DataFrame): return average_precision_score(y_true, y_pred_proba, average='macro') def _sampling_error(self, data: pd.DataFrame) -> float: - class_y_pred_proba_columns = model_output_column_names(self.y_pred_proba) - _list_missing([self.y_true] + class_y_pred_proba_columns, data) + _list_missing([self.y_true] + self.class_probability_columns, data) data, empty = common_nan_removal( - data[[self.y_true] + class_y_pred_proba_columns], [self.y_true] + class_y_pred_proba_columns + data[[self.y_true] + self.class_probability_columns], [self.y_true] + self.class_probability_columns ) if empty: warnings.warn(