From 58d44bc2c17d834ec182d459ba0566d95eb9a112 Mon Sep 17 00:00:00 2001 From: Nikolaos Perrakis <89025229+nikml@users.noreply.github.com> Date: Thu, 9 Nov 2023 22:19:44 +0200 Subject: [PATCH] fix specificity calculation for CBPE (#334) * fix specificity calculation for CBPE * Feature/multiclass confusion matrix (#287) * create tester * Updated tester * multiclass cm performance estimation * Multiclass confusion matrix calc. and estimation + docs and tests for both * Removed scratch testing files * updating MCM docs * Re-align docs with main version * [skip ci] Update CHANGELOG.md --------- Co-authored-by: Nikolaos Perrakis Co-authored-by: Niels Nuyttens * Small refactor to checks in realized performance calculations to make them consistent with the dedicated realized performance calculator. * Fix broken tests * Fix linting errors due to merges --------- Co-authored-by: Carter Blair Co-authored-by: Niels Nuyttens --- .../metrics/binary_classification.py | 2 +- .../metrics/multiclass_classification.py | 11 +- .../confidence_based/metrics.py | 164 +++++- .../CBPE/test_cbpe_metrics.py | 474 +++++++++++++++--- 4 files changed, 586 insertions(+), 65 deletions(-) diff --git a/nannyml/performance_calculation/metrics/binary_classification.py b/nannyml/performance_calculation/metrics/binary_classification.py index 2b4d97bc..44b04aa6 100644 --- a/nannyml/performance_calculation/metrics/binary_classification.py +++ b/nannyml/performance_calculation/metrics/binary_classification.py @@ -1,12 +1,12 @@ # Author: Niels Nuyttens # # License: Apache Software License 2.0 +import warnings from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, roc_auc_score -import warnings from nannyml._typing import ProblemType from nannyml.base import _list_missing diff --git a/nannyml/performance_calculation/metrics/multiclass_classification.py b/nannyml/performance_calculation/metrics/multiclass_classification.py index 30b067c2..33ba3784 100644 --- a/nannyml/performance_calculation/metrics/multiclass_classification.py +++ b/nannyml/performance_calculation/metrics/multiclass_classification.py @@ -7,24 +7,25 @@ # License: Apache Software License 2.0 """Module containing metric utilities and implementations.""" +import warnings from typing import Dict, List, Optional, Tuple, Union # noqa: TYP001 import numpy as np import pandas as pd -import warnings from sklearn.metrics import ( accuracy_score, + confusion_matrix, f1_score, multilabel_confusion_matrix, precision_score, recall_score, roc_auc_score, - confusion_matrix, ) from sklearn.preprocessing import LabelBinarizer, label_binarize from nannyml._typing import ProblemType, class_labels, model_output_column_names from nannyml.base import _list_missing +from nannyml.chunk import Chunker from nannyml.exceptions import InvalidArgumentsException from nannyml.performance_calculation.metrics.base import Metric, MetricFactory, _common_data_cleaning from nannyml.sampling_error.multiclass_classification import ( @@ -44,7 +45,6 @@ multiclass_confusion_matrix_sampling_error_components, ) from nannyml.thresholds import Threshold, calculate_threshold_values -from nannyml.chunk import Chunker @MetricFactory.register(metric='roc_auc', use_case=ProblemType.CLASSIFICATION_MULTICLASS) @@ -674,7 +674,10 @@ def _get_components(self, classes: List[str]) -> List[Tuple[str, str]]: for true_class in classes: for pred_class in classes: components.append( - (f"true class: '{true_class}', predicted class: '{pred_class}'", f'true_{true_class}_pred_{pred_class}') + ( + f"true class: '{true_class}', predicted class: '{pred_class}'", + f'true_{true_class}_pred_{pred_class}', + ) ) return components diff --git a/nannyml/performance_estimation/confidence_based/metrics.py b/nannyml/performance_estimation/confidence_based/metrics.py index d52c9044..a7316f9e 100644 --- a/nannyml/performance_estimation/confidence_based/metrics.py +++ b/nannyml/performance_estimation/confidence_based/metrics.py @@ -11,6 +11,7 @@ import abc import logging +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import numpy as np @@ -402,6 +403,11 @@ def _realized_performance(self, data: pd.DataFrame) -> float: y_pred_proba, _, y_true = self._common_cleaning(data, y_pred_proba_column_name=self.uncalibrated_y_pred_proba) if y_true is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized ROC-AUC.") + return np.NaN + + if y_true.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized ROC-AUC.") return np.NaN return roc_auc_score(y_true, y_pred_proba) @@ -494,6 +500,15 @@ def _realized_performance(self, data: pd.DataFrame) -> float: _, y_pred, y_true = self._common_cleaning(data, y_pred_proba_column_name=self.uncalibrated_y_pred_proba) if y_true is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized F1 score.") + return np.NaN + + if y_true.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized F1 score.") + return np.NaN + + if y_pred.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized F1 score.") return np.NaN return f1_score(y_true=y_true, y_pred=y_pred) @@ -570,6 +585,15 @@ def _realized_performance(self, data: pd.DataFrame) -> float: _, y_pred, y_true = self._common_cleaning(data, y_pred_proba_column_name=self.uncalibrated_y_pred_proba) if y_true is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized precision.") + return np.NaN + + if y_true.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized precision.") + return np.NaN + + if y_pred.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized precision.") return np.NaN return precision_score(y_true=y_true, y_pred=y_pred) @@ -644,6 +668,15 @@ def _realized_performance(self, data: pd.DataFrame) -> float: _, y_pred, y_true = self._common_cleaning(data, y_pred_proba_column_name=self.uncalibrated_y_pred_proba) if y_true is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized recall.") + return np.NaN + + if y_true.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as recall precision.") + return np.NaN + + if y_pred.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as recall precision.") return np.NaN return recall_score(y_true=y_true, y_pred=y_pred) @@ -718,10 +751,19 @@ def _realized_performance(self, data: pd.DataFrame) -> float: _, y_pred, y_true = self._common_cleaning(data, y_pred_proba_column_name=self.uncalibrated_y_pred_proba) if y_true is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized specificity.") + return np.NaN + + if y_true.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized specificity.") + return np.NaN + + if y_pred.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized specificity.") return np.NaN - conf_matrix = confusion_matrix(y_true=y_true, y_pred=y_pred) - return conf_matrix[1, 1] / (conf_matrix[1, 0] + conf_matrix[1, 1]) + tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() + return tn / (tn + fp) def estimate_specificity(y_pred: pd.DataFrame, y_pred_proba: pd.DataFrame) -> float: @@ -797,6 +839,15 @@ def _realized_performance(self, data: pd.DataFrame) -> float: _, y_pred, y_true = self._common_cleaning(data, y_pred_proba_column_name=self.uncalibrated_y_pred_proba) if y_true is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized accuracy.") + return np.NaN + + if y_true.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized accuracy.") + return np.NaN + + if y_pred.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized accuracy.") return np.NaN return accuracy_score(y_true=y_true, y_pred=y_pred) @@ -961,6 +1012,15 @@ def _true_positive_realized_performance(self, data: pd.DataFrame) -> float: _, y_pred, y_true = self._common_cleaning(data, y_pred_proba_column_name=self.uncalibrated_y_pred_proba) if y_true is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized confusion matrix.") + return np.NaN + + if y_true.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized confusion matrix.") + return np.NaN + + if y_pred.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized confusion matrix.") return np.NaN num_tp = np.sum(np.logical_and(y_pred, y_true)) @@ -980,6 +1040,7 @@ def _true_negative_realized_performance(self, data: pd.DataFrame) -> float: _, y_pred, y_true = self._common_cleaning(data, y_pred_proba_column_name=self.uncalibrated_y_pred_proba) if y_true is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized confusion matrix.") return np.NaN num_tn = np.sum(np.logical_and(np.logical_not(y_pred), np.logical_not(y_true))) @@ -999,6 +1060,15 @@ def _false_positive_realized_performance(self, data: pd.DataFrame) -> float: _, y_pred, y_true = self._common_cleaning(data, y_pred_proba_column_name=self.uncalibrated_y_pred_proba) if y_true is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized confusion matrix.") + return np.NaN + + if y_true.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized confusion matrix.") + return np.NaN + + if y_pred.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized confusion matrix.") return np.NaN num_tp = np.sum(np.logical_and(y_pred, y_true)) @@ -1018,6 +1088,15 @@ def _false_negative_realized_performance(self, data: pd.DataFrame) -> float: _, y_pred, y_true = self._common_cleaning(data, y_pred_proba_column_name=self.uncalibrated_y_pred_proba) if y_true is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized confusion matrix.") + return np.NaN + + if y_true.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized confusion matrix.") + return np.NaN + + if y_pred.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized confusion matrix.") return np.NaN num_tp = np.sum(np.logical_and(y_pred, y_true)) @@ -1500,6 +1579,15 @@ def _realized_performance(self, data: pd.DataFrame) -> float: _, y_pred, y_true = self._common_cleaning(data, y_pred_proba_column_name=self.uncalibrated_y_pred_proba) if y_true is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized business value.") + return np.NaN + + if y_true.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized business value.") + return np.NaN + + if y_pred.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized business value.") return np.NaN tp_value = self.business_value_matrix[1, 1] @@ -1677,7 +1765,13 @@ def _sampling_error(self, data: pd.DataFrame) -> float: def _realized_performance(self, data: pd.DataFrame) -> float: data = self._ensure_targets(data) + if data is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized ROC-AUC.") + return np.NaN + + if data[self.y_true].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized ROC-AUC.") return np.NaN _, y_pred_probas, labels = _get_multiclass_uncalibrated_predictions(data, self.y_pred, self.y_pred_proba) @@ -1734,7 +1828,17 @@ def _sampling_error(self, data: pd.DataFrame) -> float: def _realized_performance(self, data: pd.DataFrame) -> float: data = self._ensure_targets(data) + if data is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized F1 score.") + return np.NaN + + if data[self.y_true].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized F1 score.") + return np.NaN + + if data[self.y_pred].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized F1 score.") return np.NaN y_pred, _, labels = _get_multiclass_uncalibrated_predictions(data, self.y_pred, self.y_pred_proba) @@ -1791,7 +1895,17 @@ def _sampling_error(self, data: pd.DataFrame) -> float: def _realized_performance(self, data: pd.DataFrame) -> float: data = self._ensure_targets(data) + if data is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized precision.") + return np.NaN + + if data[self.y_true].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized precision.") + return np.NaN + + if data[self.y_pred].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized precision.") return np.NaN y_pred, _, labels = _get_multiclass_uncalibrated_predictions(data, self.y_pred, self.y_pred_proba) @@ -1848,7 +1962,17 @@ def _sampling_error(self, data: pd.DataFrame) -> float: def _realized_performance(self, data: pd.DataFrame) -> float: data = self._ensure_targets(data) + if data is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized recall.") + return np.NaN + + if data[self.y_true].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized recall.") + return np.NaN + + if data[self.y_pred].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized recall.") return np.NaN y_pred, _, labels = _get_multiclass_uncalibrated_predictions(data, self.y_pred, self.y_pred_proba) @@ -1905,7 +2029,17 @@ def _sampling_error(self, data: pd.DataFrame) -> float: def _realized_performance(self, data: pd.DataFrame) -> float: data = self._ensure_targets(data) + if data is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized specificity.") + return np.NaN + + if data[self.y_true].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized specificity.") + return np.NaN + + if data[self.y_pred].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized specificity.") return np.NaN y_pred, _, labels = _get_multiclass_uncalibrated_predictions(data, self.y_pred, self.y_pred_proba) @@ -1964,8 +2098,19 @@ def _sampling_error(self, data: pd.DataFrame) -> float: def _realized_performance(self, data: pd.DataFrame) -> float: data = self._ensure_targets(data) + if data is None: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized accuracy.") return np.NaN + + if data[self.y_true].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized accuracy.") + return np.NaN + + if data[self.y_pred].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized accuracy.") + return np.NaN + y_pred, _, _ = _get_multiclass_uncalibrated_predictions(data, self.y_pred, self.y_pred_proba) return accuracy_score(data[self.y_true], y_pred) @@ -2011,7 +2156,10 @@ def _get_components(self, classes: List[str]) -> List[Tuple[str, str]]: for true_class in classes: for pred_class in classes: components.append( - (f"true class: '{true_class}', predicted class: '{pred_class}'", f'true_{true_class}_pred_{pred_class}') + ( + f"true class: '{true_class}', predicted class: '{pred_class}'", + f'true_{true_class}_pred_{pred_class}', + ) ) return components @@ -2074,8 +2222,16 @@ def _multiclass_confusion_matrix_alert_thresholds( return alert_thresholds def _multi_class_confusion_matrix_realized_performance(self, data: pd.DataFrame) -> Union[np.ndarray, float]: + if data is None or self.y_true not in data.columns: + warnings.warn("No 'y_true' values given for chunk, returning NaN as realized precision.") + return np.NaN + + if data[self.y_true].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized precision.") + return np.NaN - if self.y_true not in data.columns or data[self.y_true].isna().all(): + if data[self.y_pred].nunique() <= 1: + warnings.warn("Too few unique values present in 'y_pred', returning NaN as realized precision.") return np.NaN cm = confusion_matrix( diff --git a/tests/performance_estimation/CBPE/test_cbpe_metrics.py b/tests/performance_estimation/CBPE/test_cbpe_metrics.py index b515417e..5c6210ab 100644 --- a/tests/performance_estimation/CBPE/test_cbpe_metrics.py +++ b/tests/performance_estimation/CBPE/test_cbpe_metrics.py @@ -2453,15 +2453,51 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 'estimated_recall': [0.7564129287764665, 0.6934788458355289, 0.6319310599943714], 'estimated_specificity': [0.8782068281303994, 0.8469556750949159, 0.8172644220189141], 'estimated_accuracy': [0.7564451493123628, 0.6946947603445697, 0.6378557309960986], - 'estimated_true_highstreet_card_pred_highstreet_card': [4976.829215997277, 5148.649186425118, 5412.348045797111], - 'estimated_true_highstreet_card_pred_prepaid_card': [878.1877379091701, 1038.3533241561252, 1250.9260097761653], - 'estimated_true_highstreet_card_pred_upmarket_card': [831.7702766018707, 993.7691398029524, 1109.9706655490413], - 'estimated_true_prepaid_card_pred_highstreet_card': [806.1451187447954, 1140.1932616586546, 1451.431964364007], - 'estimated_true_prepaid_card_pred_prepaid_card': [5180.838942632071, 4134.524656135082, 3326.8467648553315], - 'estimated_true_prepaid_card_pred_upmarket_card': [755.9948957802203, 998.509495865855, 1200.1095251814281], - 'estimated_true_upmarket_card_pred_highstreet_card': [812.0256652579275, 1062.1575519162266, 1263.219989838882], - 'estimated_true_upmarket_card_pred_prepaid_card': [786.9733194587595, 873.1220197087925, 967.2272253685034], - 'estimated_true_upmarket_card_pred_upmarket_card': [4971.234827617909, 4610.7213643311925, 4017.9198092695306], + 'estimated_true_highstreet_card_pred_highstreet_card': [ + 4976.829215997277, + 5148.649186425118, + 5412.348045797111, + ], + 'estimated_true_highstreet_card_pred_prepaid_card': [ + 878.1877379091701, + 1038.3533241561252, + 1250.9260097761653, + ], + 'estimated_true_highstreet_card_pred_upmarket_card': [ + 831.7702766018707, + 993.7691398029524, + 1109.9706655490413, + ], + 'estimated_true_prepaid_card_pred_highstreet_card': [ + 806.1451187447954, + 1140.1932616586546, + 1451.431964364007, + ], + 'estimated_true_prepaid_card_pred_prepaid_card': [ + 5180.838942632071, + 4134.524656135082, + 3326.8467648553315, + ], + 'estimated_true_prepaid_card_pred_upmarket_card': [ + 755.9948957802203, + 998.509495865855, + 1200.1095251814281, + ], + 'estimated_true_upmarket_card_pred_highstreet_card': [ + 812.0256652579275, + 1062.1575519162266, + 1263.219989838882, + ], + 'estimated_true_upmarket_card_pred_prepaid_card': [ + 786.9733194587595, + 873.1220197087925, + 967.2272253685034, + ], + 'estimated_true_upmarket_card_pred_upmarket_card': [ + 4971.234827617909, + 4610.7213643311925, + 4017.9198092695306, + ], } ), ), @@ -2476,15 +2512,51 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 'estimated_recall': [0.7564129287764665, 0.6934788458355289, 0.6319310599943714], 'estimated_specificity': [0.8782068281303994, 0.8469556750949159, 0.8172644220189141], 'estimated_accuracy': [0.7564451493123628, 0.6946947603445697, 0.6378557309960986], - 'estimated_true_highstreet_card_pred_highstreet_card': [0.7442780881812128, 0.7170050012869645, 0.6962791266676683], - 'estimated_true_highstreet_card_pred_prepaid_card': [0.1313317902358936, 0.14460191393226796, 0.16092713592008898], - 'estimated_true_highstreet_card_pred_upmarket_card': [0.12439012158289371, 0.1383930847807676, 0.1427937374122426], - 'estimated_true_prepaid_card_pred_highstreet_card': [0.11955326034187638, 0.18175544842770236, 0.24277980997563847], - 'estimated_true_prepaid_card_pred_prepaid_card': [0.7683308780213619, 0.6590745693568182, 0.5564788741190233], - 'estimated_true_prepaid_card_pred_upmarket_card': [0.1121158616367618, 0.15916998221547937, 0.20074131590533828], - 'estimated_true_upmarket_card_pred_highstreet_card': [0.1235915933057778, 0.16226052551901615, 0.20216802004274595], - 'estimated_true_upmarket_card_pred_prepaid_card': [0.1197785865673972, 0.13338250761817996, 0.15479680076083163], - 'estimated_true_upmarket_card_pred_upmarket_card': [0.756629820126825, 0.7043569668628038, 0.6430351791964225], + 'estimated_true_highstreet_card_pred_highstreet_card': [ + 0.7442780881812128, + 0.7170050012869645, + 0.6962791266676683, + ], + 'estimated_true_highstreet_card_pred_prepaid_card': [ + 0.1313317902358936, + 0.14460191393226796, + 0.16092713592008898, + ], + 'estimated_true_highstreet_card_pred_upmarket_card': [ + 0.12439012158289371, + 0.1383930847807676, + 0.1427937374122426, + ], + 'estimated_true_prepaid_card_pred_highstreet_card': [ + 0.11955326034187638, + 0.18175544842770236, + 0.24277980997563847, + ], + 'estimated_true_prepaid_card_pred_prepaid_card': [ + 0.7683308780213619, + 0.6590745693568182, + 0.5564788741190233, + ], + 'estimated_true_prepaid_card_pred_upmarket_card': [ + 0.1121158616367618, + 0.15916998221547937, + 0.20074131590533828, + ], + 'estimated_true_upmarket_card_pred_highstreet_card': [ + 0.1235915933057778, + 0.16226052551901615, + 0.20216802004274595, + ], + 'estimated_true_upmarket_card_pred_prepaid_card': [ + 0.1197785865673972, + 0.13338250761817996, + 0.15479680076083163, + ], + 'estimated_true_upmarket_card_pred_upmarket_card': [ + 0.756629820126825, + 0.7043569668628038, + 0.6430351791964225, + ], } ), ), @@ -2519,15 +2591,60 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 0.6364205304514962, 0.6375753072973162, ], - 'estimated_true_highstreet_card_pred_highstreet_card': [0.7546260682147157, 0.7511343683695074, 0.6628383225865804, 0.6651814251770874], - 'estimated_true_highstreet_card_pred_prepaid_card': [0.12922483020709813, 0.12720280190168412, 0.22365956156664257, 0.22578913179209303], - 'estimated_true_highstreet_card_pred_upmarket_card': [0.12747696595643684, 0.12776612448252053, 0.17277613353669485, 0.17660735301820177], - 'estimated_true_prepaid_card_pred_highstreet_card': [0.12118073967907128, 0.1249170750987652, 0.18024418583692642, 0.17798857692081155], - 'estimated_true_prepaid_card_pred_prepaid_card': [0.7554502796336932, 0.7576402255283115, 0.5994574163887797, 0.5998622938235557], - 'estimated_true_prepaid_card_pred_upmarket_card': [0.11748464117810321, 0.11221429055241054, 0.1924554660281669, 0.18783942082621902], - 'estimated_true_upmarket_card_pred_highstreet_card': [0.12419319210621305, 0.12394855653172744, 0.15691749157649315, 0.15682999790210106], - 'estimated_true_upmarket_card_pred_prepaid_card': [0.11532489015920869, 0.1151569725700045, 0.17688302204457784, 0.17434857438435108], - 'estimated_true_upmarket_card_pred_upmarket_card': [0.7550383928654599, 0.7600195849650688, 0.6347684004351383, 0.6355532261555792], + 'estimated_true_highstreet_card_pred_highstreet_card': [ + 0.7546260682147157, + 0.7511343683695074, + 0.6628383225865804, + 0.6651814251770874, + ], + 'estimated_true_highstreet_card_pred_prepaid_card': [ + 0.12922483020709813, + 0.12720280190168412, + 0.22365956156664257, + 0.22578913179209303, + ], + 'estimated_true_highstreet_card_pred_upmarket_card': [ + 0.12747696595643684, + 0.12776612448252053, + 0.17277613353669485, + 0.17660735301820177, + ], + 'estimated_true_prepaid_card_pred_highstreet_card': [ + 0.12118073967907128, + 0.1249170750987652, + 0.18024418583692642, + 0.17798857692081155, + ], + 'estimated_true_prepaid_card_pred_prepaid_card': [ + 0.7554502796336932, + 0.7576402255283115, + 0.5994574163887797, + 0.5998622938235557, + ], + 'estimated_true_prepaid_card_pred_upmarket_card': [ + 0.11748464117810321, + 0.11221429055241054, + 0.1924554660281669, + 0.18783942082621902, + ], + 'estimated_true_upmarket_card_pred_highstreet_card': [ + 0.12419319210621305, + 0.12394855653172744, + 0.15691749157649315, + 0.15682999790210106, + ], + 'estimated_true_upmarket_card_pred_prepaid_card': [ + 0.11532489015920869, + 0.1151569725700045, + 0.17688302204457784, + 0.17434857438435108, + ], + 'estimated_true_upmarket_card_pred_upmarket_card': [ + 0.7550383928654599, + 0.7600195849650688, + 0.6347684004351383, + 0.6355532261555792, + ], } ), ), @@ -2562,15 +2679,60 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 0.6364205304514962, 0.6375753072973162, ], - 'estimated_true_highstreet_card_pred_highstreet_card': [0.24922783612904678, 0.24847524905663304, 0.2702612787293017, 0.2678907326329857], - 'estimated_true_highstreet_card_pred_prepaid_card': [0.044125972021383776, 0.04231613209929359, 0.06202825174114887, 0.06269411559427118], - 'estimated_true_highstreet_card_pred_upmarket_card': [0.04184643869129968, 0.04299755975918424, 0.05441296365515643, 0.05644371002461729], - 'estimated_true_prepaid_card_pred_highstreet_card': [0.04002195895800795, 0.04132256844267153, 0.0734915627052428, 0.07168193287857484], - 'estimated_true_prepaid_card_pred_prepaid_card': [0.25796108881891844, 0.25204164835908494, 0.16624952347848823, 0.16656176358500735], - 'estimated_true_prepaid_card_pred_upmarket_card': [0.03856629154406535, 0.03776384924723789, 0.0606106414344707, 0.060033478896059596], - 'estimated_true_upmarket_card_pred_highstreet_card': [0.041016871579611966, 0.041002182500695435, 0.06398049189878882, 0.06316066782177283], - 'estimated_true_upmarket_card_pred_prepaid_card': [0.03937960582636446, 0.038308886208288165, 0.04905555811369625, 0.04841078748738816], - 'estimated_true_upmarket_card_pred_upmarket_card': [0.24785393643130169, 0.25577192432691115, 0.1999097282437062, 0.2031228110793231], + 'estimated_true_highstreet_card_pred_highstreet_card': [ + 0.24922783612904678, + 0.24847524905663304, + 0.2702612787293017, + 0.2678907326329857, + ], + 'estimated_true_highstreet_card_pred_prepaid_card': [ + 0.044125972021383776, + 0.04231613209929359, + 0.06202825174114887, + 0.06269411559427118, + ], + 'estimated_true_highstreet_card_pred_upmarket_card': [ + 0.04184643869129968, + 0.04299755975918424, + 0.05441296365515643, + 0.05644371002461729, + ], + 'estimated_true_prepaid_card_pred_highstreet_card': [ + 0.04002195895800795, + 0.04132256844267153, + 0.0734915627052428, + 0.07168193287857484, + ], + 'estimated_true_prepaid_card_pred_prepaid_card': [ + 0.25796108881891844, + 0.25204164835908494, + 0.16624952347848823, + 0.16656176358500735, + ], + 'estimated_true_prepaid_card_pred_upmarket_card': [ + 0.03856629154406535, + 0.03776384924723789, + 0.0606106414344707, + 0.060033478896059596, + ], + 'estimated_true_upmarket_card_pred_highstreet_card': [ + 0.041016871579611966, + 0.041002182500695435, + 0.06398049189878882, + 0.06316066782177283, + ], + 'estimated_true_upmarket_card_pred_prepaid_card': [ + 0.03937960582636446, + 0.038308886208288165, + 0.04905555811369625, + 0.04841078748738816, + ], + 'estimated_true_upmarket_card_pred_upmarket_card': [ + 0.24785393643130169, + 0.25577192432691115, + 0.1999097282437062, + 0.2031228110793231, + ], } ), ), @@ -2685,15 +2847,114 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 0.6365172577468735, 0.6393273094601863, ], - 'estimated_true_highstreet_card_pred_highstreet_card': [1483.745037516118, 1536.2546154566053, 1486.1512390473335, 1455.1117469508827, 1504.2836388142573, 1581.5006283773678, 1619.4221852490452, 1653.916785154108, 1596.0668735461204, 1621.3736981076686], - 'estimated_true_highstreet_card_pred_prepaid_card': [271.9744616336458, 263.3288788018858, 255.36687592730394, 249.4490435511216, 256.51230189620316, 373.19858748671334, 376.5016769489258, 379.92027913064413, 383.2996831232405, 357.9152833417768], - 'estimated_true_highstreet_card_pred_upmarket_card': [249.77244098451774, 249.25341402994002, 256.5894445268087, 248.14404310496627, 268.90063411102597, 336.40636999479034, 321.1640326687466, 329.70510362377274, 338.73618885526093, 336.8384100540352], - 'estimated_true_prepaid_card_pred_highstreet_card': [249.18645665281267, 234.635939041771, 241.34496258349301, 243.99437206824282, 251.00618066387275, 426.16695555097976, 444.34831788068584, 437.82379051045245, 436.76196316220245, 432.5014066529442], - 'estimated_true_prepaid_card_pred_prepaid_card': [1570.046457210577, 1517.6502407889316, 1541.0740605507428, 1544.8307224347122, 1476.439576685086, 1000.6602945021018, 1015.9002964606179, 990.2540537898107, 1008.1898500978498, 977.1648111020529], - 'estimated_true_prepaid_card_pred_upmarket_card': [227.67692529471515, 228.16728611276778, 224.39810820574257, 236.66544157504953, 228.04435068127367, 373.80025334213235, 351.54340476695654, 359.16491614682934, 364.41698830746793, 360.7362423945684], - 'estimated_true_upmarket_card_pred_highstreet_card': [243.06850583106933, 254.1094455016238, 238.50379836917327, 249.89388098087457, 244.71018052186997, 378.3324160716525, 385.22949687026846, 390.2594243354395, 372.1711632916771, 381.12489523938723], - 'estimated_true_upmarket_card_pred_prepaid_card': [237.97908115577718, 232.0208804091826, 234.55906352195305, 231.72023401416573, 229.04812141871074, 297.14111801118486, 301.59802659045613, 282.82566707954516, 285.5104667789098, 294.91990555617025], - 'estimated_true_upmarket_card_pred_upmarket_card': [1466.550633720767, 1484.5792998572922, 1522.0124472674486, 1540.1905153199841, 1541.0550152077003, 1232.7933766630772, 1184.292562564297, 1176.129980229398, 1214.8468228372712, 1237.4253475513965], + 'estimated_true_highstreet_card_pred_highstreet_card': [ + 1483.745037516118, + 1536.2546154566053, + 1486.1512390473335, + 1455.1117469508827, + 1504.2836388142573, + 1581.5006283773678, + 1619.4221852490452, + 1653.916785154108, + 1596.0668735461204, + 1621.3736981076686, + ], + 'estimated_true_highstreet_card_pred_prepaid_card': [ + 271.9744616336458, + 263.3288788018858, + 255.36687592730394, + 249.4490435511216, + 256.51230189620316, + 373.19858748671334, + 376.5016769489258, + 379.92027913064413, + 383.2996831232405, + 357.9152833417768, + ], + 'estimated_true_highstreet_card_pred_upmarket_card': [ + 249.77244098451774, + 249.25341402994002, + 256.5894445268087, + 248.14404310496627, + 268.90063411102597, + 336.40636999479034, + 321.1640326687466, + 329.70510362377274, + 338.73618885526093, + 336.8384100540352, + ], + 'estimated_true_prepaid_card_pred_highstreet_card': [ + 249.18645665281267, + 234.635939041771, + 241.34496258349301, + 243.99437206824282, + 251.00618066387275, + 426.16695555097976, + 444.34831788068584, + 437.82379051045245, + 436.76196316220245, + 432.5014066529442, + ], + 'estimated_true_prepaid_card_pred_prepaid_card': [ + 1570.046457210577, + 1517.6502407889316, + 1541.0740605507428, + 1544.8307224347122, + 1476.439576685086, + 1000.6602945021018, + 1015.9002964606179, + 990.2540537898107, + 1008.1898500978498, + 977.1648111020529, + ], + 'estimated_true_prepaid_card_pred_upmarket_card': [ + 227.67692529471515, + 228.16728611276778, + 224.39810820574257, + 236.66544157504953, + 228.04435068127367, + 373.80025334213235, + 351.54340476695654, + 359.16491614682934, + 364.41698830746793, + 360.7362423945684, + ], + 'estimated_true_upmarket_card_pred_highstreet_card': [ + 243.06850583106933, + 254.1094455016238, + 238.50379836917327, + 249.89388098087457, + 244.71018052186997, + 378.3324160716525, + 385.22949687026846, + 390.2594243354395, + 372.1711632916771, + 381.12489523938723, + ], + 'estimated_true_upmarket_card_pred_prepaid_card': [ + 237.97908115577718, + 232.0208804091826, + 234.55906352195305, + 231.72023401416573, + 229.04812141871074, + 297.14111801118486, + 301.59802659045613, + 282.82566707954516, + 285.5104667789098, + 294.91990555617025, + ], + 'estimated_true_upmarket_card_pred_upmarket_card': [ + 1466.550633720767, + 1484.5792998572922, + 1522.0124472674486, + 1540.1905153199841, + 1541.0550152077003, + 1232.7933766630772, + 1184.292562564297, + 1176.129980229398, + 1214.8468228372712, + 1237.4253475513965, + ], } ), ), @@ -2785,15 +3046,114 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 0.6365172577468735, 0.6393273094601863, ], - 'estimated_true_highstreet_card_pred_highstreet_card': [1483.745037516118, 1536.2546154566053, 1486.1512390473335, 1455.1117469508827, 1504.2836388142573, 1581.5006283773678, 1619.4221852490452, 1653.916785154108, 1596.0668735461204, 1621.3736981076686], - 'estimated_true_highstreet_card_pred_prepaid_card': [271.9744616336458, 263.3288788018858, 255.36687592730394, 249.4490435511216, 256.51230189620316, 373.19858748671334, 376.5016769489258, 379.92027913064413, 383.2996831232405, 357.9152833417768], - 'estimated_true_highstreet_card_pred_upmarket_card': [249.77244098451774, 249.25341402994002, 256.5894445268087, 248.14404310496627, 268.90063411102597, 336.40636999479034, 321.1640326687466, 329.70510362377274, 338.73618885526093, 336.8384100540352], - 'estimated_true_prepaid_card_pred_highstreet_card': [249.18645665281267, 234.635939041771, 241.34496258349301, 243.99437206824282, 251.00618066387275, 426.16695555097976, 444.34831788068584, 437.82379051045245, 436.76196316220245, 432.5014066529442], - 'estimated_true_prepaid_card_pred_prepaid_card': [1570.046457210577, 1517.6502407889316, 1541.0740605507428, 1544.8307224347122, 1476.439576685086, 1000.6602945021018, 1015.9002964606179, 990.2540537898107, 1008.1898500978498, 977.1648111020529], - 'estimated_true_prepaid_card_pred_upmarket_card': [227.67692529471515, 228.16728611276778, 224.39810820574257, 236.66544157504953, 228.04435068127367, 373.80025334213235, 351.54340476695654, 359.16491614682934, 364.41698830746793, 360.7362423945684], - 'estimated_true_upmarket_card_pred_highstreet_card': [243.06850583106933, 254.1094455016238, 238.50379836917327, 249.89388098087457, 244.71018052186997, 378.3324160716525, 385.22949687026846, 390.2594243354395, 372.1711632916771, 381.12489523938723], - 'estimated_true_upmarket_card_pred_prepaid_card': [237.97908115577718, 232.0208804091826, 234.55906352195305, 231.72023401416573, 229.04812141871074, 297.14111801118486, 301.59802659045613, 282.82566707954516, 285.5104667789098, 294.91990555617025], - 'estimated_true_upmarket_card_pred_upmarket_card': [1466.550633720767, 1484.5792998572922, 1522.0124472674486, 1540.1905153199841, 1541.0550152077003, 1232.7933766630772, 1184.292562564297, 1176.129980229398, 1214.8468228372712, 1237.4253475513965], + 'estimated_true_highstreet_card_pred_highstreet_card': [ + 1483.745037516118, + 1536.2546154566053, + 1486.1512390473335, + 1455.1117469508827, + 1504.2836388142573, + 1581.5006283773678, + 1619.4221852490452, + 1653.916785154108, + 1596.0668735461204, + 1621.3736981076686, + ], + 'estimated_true_highstreet_card_pred_prepaid_card': [ + 271.9744616336458, + 263.3288788018858, + 255.36687592730394, + 249.4490435511216, + 256.51230189620316, + 373.19858748671334, + 376.5016769489258, + 379.92027913064413, + 383.2996831232405, + 357.9152833417768, + ], + 'estimated_true_highstreet_card_pred_upmarket_card': [ + 249.77244098451774, + 249.25341402994002, + 256.5894445268087, + 248.14404310496627, + 268.90063411102597, + 336.40636999479034, + 321.1640326687466, + 329.70510362377274, + 338.73618885526093, + 336.8384100540352, + ], + 'estimated_true_prepaid_card_pred_highstreet_card': [ + 249.18645665281267, + 234.635939041771, + 241.34496258349301, + 243.99437206824282, + 251.00618066387275, + 426.16695555097976, + 444.34831788068584, + 437.82379051045245, + 436.76196316220245, + 432.5014066529442, + ], + 'estimated_true_prepaid_card_pred_prepaid_card': [ + 1570.046457210577, + 1517.6502407889316, + 1541.0740605507428, + 1544.8307224347122, + 1476.439576685086, + 1000.6602945021018, + 1015.9002964606179, + 990.2540537898107, + 1008.1898500978498, + 977.1648111020529, + ], + 'estimated_true_prepaid_card_pred_upmarket_card': [ + 227.67692529471515, + 228.16728611276778, + 224.39810820574257, + 236.66544157504953, + 228.04435068127367, + 373.80025334213235, + 351.54340476695654, + 359.16491614682934, + 364.41698830746793, + 360.7362423945684, + ], + 'estimated_true_upmarket_card_pred_highstreet_card': [ + 243.06850583106933, + 254.1094455016238, + 238.50379836917327, + 249.89388098087457, + 244.71018052186997, + 378.3324160716525, + 385.22949687026846, + 390.2594243354395, + 372.1711632916771, + 381.12489523938723, + ], + 'estimated_true_upmarket_card_pred_prepaid_card': [ + 237.97908115577718, + 232.0208804091826, + 234.55906352195305, + 231.72023401416573, + 229.04812141871074, + 297.14111801118486, + 301.59802659045613, + 282.82566707954516, + 285.5104667789098, + 294.91990555617025, + ], + 'estimated_true_upmarket_card_pred_upmarket_card': [ + 1466.550633720767, + 1484.5792998572922, + 1522.0124472674486, + 1540.1905153199841, + 1541.0550152077003, + 1232.7933766630772, + 1184.292562564297, + 1176.129980229398, + 1214.8468228372712, + 1237.4253475513965, + ], } ), ), @@ -2825,7 +3185,8 @@ def test_cbpe_for_multiclass_classification_with_timestamps(calculator_opts, exp result = cbpe.estimate(ana_df) column_names = [(m.name, 'value') for m in result.metrics] column_names = [c for c in column_names if c[0] != 'confusion_matrix'] - column_names += [('true_highstreet_card_pred_highstreet_card', 'value'), + column_names += [ + ('true_highstreet_card_pred_highstreet_card', 'value'), ('true_highstreet_card_pred_prepaid_card', 'value'), ('true_highstreet_card_pred_upmarket_card', 'value'), ('true_prepaid_card_pred_highstreet_card', 'value'), @@ -2833,7 +3194,8 @@ def test_cbpe_for_multiclass_classification_with_timestamps(calculator_opts, exp ('true_prepaid_card_pred_upmarket_card', 'value'), ('true_upmarket_card_pred_highstreet_card', 'value'), ('true_upmarket_card_pred_prepaid_card', 'value'), - ('true_upmarket_card_pred_upmarket_card', 'value')] + ('true_upmarket_card_pred_upmarket_card', 'value'), + ] sut = result.filter(period='analysis').to_df()[[('chunk', 'key')] + column_names] sut.columns = [ 'key',