Skip to content

Commit

Permalink
Fix average precision calculation
Browse files Browse the repository at this point in the history
Replace the forgotten `model_output_column_names()` call with the `class_probability_columns` property.
  • Loading branch information
nnansters committed Jul 18, 2024
1 parent 34e1cf3 commit 8582141
Showing 1 changed file with 8 additions and 8 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,7 +16,7 @@
precision_score,
recall_score,
roc_auc_score,
average_precision_score
average_precision_score,
)
from sklearn.preprocessing import LabelBinarizer, label_binarize

Expand All @@ -43,7 +43,7 @@
ap_sampling_error_components,
ap_sampling_error,
bv_sampling_error_components,
bv_sampling_error
bv_sampling_error,
)
from nannyml.thresholds import Threshold, calculate_threshold_values

Expand Down Expand Up @@ -106,7 +106,7 @@ def _fit(self, reference_data: pd.DataFrame):
_list_missing([self.y_true] + self.class_probability_columns, list(reference_data.columns))
reference_data, empty = common_nan_removal(
reference_data[[self.y_true] + self.class_probability_columns],
[self.y_true] + self.class_probability_columns
[self.y_true] + self.class_probability_columns,
)
if empty:
self._sampling_error_components = [(np.NaN, 0) for clasz in self.classes]
Expand All @@ -120,7 +120,8 @@ def _fit(self, reference_data: pd.DataFrame):
"targets."
)
raise InvalidArgumentsException(
"y_pred_proba class and class probabilities dictionary does not match reference data.")
"y_pred_proba class and class probabilities dictionary does not match reference data."
)

# sampling error
binarized_y_true = list(label_binarize(reference_data[self.y_true], classes=self.classes).T)
Expand Down Expand Up @@ -978,7 +979,7 @@ def _fit(self, reference_data: pd.DataFrame):
_list_missing([self.y_true] + self.class_probability_columns, list(reference_data.columns))
reference_data, empty = common_nan_removal(
reference_data[[self.y_true] + self.class_probability_columns],
[self.y_true] + self.class_probability_columns
[self.y_true] + self.class_probability_columns,
)
if empty:
self._sampling_error_components = [(np.NaN, 0) for class_col in self.class_probability_columns]
Expand Down Expand Up @@ -1022,10 +1023,9 @@ def _calculate(self, data: pd.DataFrame):
return average_precision_score(y_true, y_pred_proba, average='macro')

def _sampling_error(self, data: pd.DataFrame) -> float:
class_y_pred_proba_columns = model_output_column_names(self.y_pred_proba)
_list_missing([self.y_true] + class_y_pred_proba_columns, data)
_list_missing([self.y_true] + self.class_probability_columns, data)
data, empty = common_nan_removal(
data[[self.y_true] + class_y_pred_proba_columns], [self.y_true] + class_y_pred_proba_columns
data[[self.y_true] + self.class_probability_columns], [self.y_true] + self.class_probability_columns
)
if empty:
warnings.warn(
Expand Down

0 comments on commit 8582141

Please sign in to comment.