From e57cc09252f91c64e654b363da1304f05540c62d Mon Sep 17 00:00:00 2001 From: Valentin Laurent Date: Sun, 5 Jan 2025 11:13:03 +0100 Subject: [PATCH 1/4] REFACTO: in split setting, remove checking NaNs to avoid inevitable warning, and remove useless aggregation to avoid dependency to agg_function --- mapie/estimator/regressor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mapie/estimator/regressor.py b/mapie/estimator/regressor.py index bad8988ca..9e62136c5 100644 --- a/mapie/estimator/regressor.py +++ b/mapie/estimator/regressor.py @@ -402,9 +402,12 @@ def predict_calib( predictions[i], dtype=float ) self.k_[ind, i] = 1 - check_nan_in_aposteriori_prediction(pred_matrix) - y_pred = aggregate_all(self.agg_function, pred_matrix) + if self.cv == "split": + y_pred = pred_matrix.flatten() + else: + check_nan_in_aposteriori_prediction(pred_matrix) + y_pred = aggregate_all(self.agg_function, pred_matrix) return y_pred From f1c60991c4bf8784837a87e1079b1c5ff32676c0 Mon Sep 17 00:00:00 2001 From: Valentin Laurent Date: Sun, 5 Jan 2025 20:02:11 +0100 Subject: [PATCH 2/4] #2 REFACTO: in split setting, remove checking NaNs to avoid inevitable warning, and remove useless aggregation to avoid dependency to agg_function --- mapie/estimator/regressor.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mapie/estimator/regressor.py b/mapie/estimator/regressor.py index 9e62136c5..daa197f31 100644 --- a/mapie/estimator/regressor.py +++ b/mapie/estimator/regressor.py @@ -5,7 +5,7 @@ import numpy as np from joblib import Parallel, delayed from sklearn.base import RegressorMixin, clone -from sklearn.model_selection import BaseCrossValidator +from sklearn.model_selection import BaseCrossValidator, BaseShuffleSplit from sklearn.utils import _safe_indexing, deprecated from sklearn.utils.validation import _num_samples, check_is_fitted @@ -403,7 +403,10 @@ def predict_calib( ) self.k_[ind, i] = 1 - if self.cv == "split": + if ( + isinstance(self.cv, BaseShuffleSplit) and + self.cv.n_splits == 1 + ): y_pred = pred_matrix.flatten() else: check_nan_in_aposteriori_prediction(pred_matrix) From 3449a48ca4973142ed44e3a3051ae5732245015b Mon Sep 17 00:00:00 2001 From: Valentin Laurent Date: Sun, 5 Jan 2025 22:42:47 +0100 Subject: [PATCH 3/4] FIX: simplify condition, fix tests --- mapie/estimator/regressor.py | 5 +---- mapie/tests/test_regression.py | 2 +- mapie/tests/test_time_series_regression.py | 3 ++- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/mapie/estimator/regressor.py b/mapie/estimator/regressor.py index daa197f31..1703b5f1f 100644 --- a/mapie/estimator/regressor.py +++ b/mapie/estimator/regressor.py @@ -403,10 +403,7 @@ def predict_calib( ) self.k_[ind, i] = 1 - if ( - isinstance(self.cv, BaseShuffleSplit) and - self.cv.n_splits == 1 - ): + if self.use_split_method_: y_pred = pred_matrix.flatten() else: check_nan_in_aposteriori_prediction(pred_matrix) diff --git a/mapie/tests/test_regression.py b/mapie/tests/test_regression.py index e062a3704..f06fff2e3 100644 --- a/mapie/tests/test_regression.py +++ b/mapie/tests/test_regression.py @@ -701,7 +701,7 @@ def test_not_enough_resamplings() -> None: """ with pytest.warns(UserWarning, match=r"WARNING: at least one point of*"): mapie_reg = MapieRegressor( - cv=Subsample(n_resamplings=1), agg_function="mean" + cv=Subsample(n_resamplings=2, random_state=0), agg_function="mean" ) mapie_reg.fit(X, y) diff --git a/mapie/tests/test_time_series_regression.py b/mapie/tests/test_time_series_regression.py index 785cb9088..77e4607b4 100644 --- a/mapie/tests/test_time_series_regression.py +++ b/mapie/tests/test_time_series_regression.py @@ -318,7 +318,8 @@ def test_not_enough_resamplings() -> None: match=r"WARNING: at least one point of*" ): mapie_ts_reg = MapieTimeSeriesRegressor( - cv=BlockBootstrap(n_resamplings=1, n_blocks=1), agg_function="mean" + cv=BlockBootstrap(n_resamplings=2, n_blocks=1, random_state=0), + agg_function="mean" ) mapie_ts_reg.fit(X, y) From aebfe7d34493eb80ef83af461f16975812f976b1 Mon Sep 17 00:00:00 2001 From: Valentin Laurent Date: Sun, 5 Jan 2025 22:46:28 +0100 Subject: [PATCH 4/4] FIX linting --- mapie/estimator/regressor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapie/estimator/regressor.py b/mapie/estimator/regressor.py index 1703b5f1f..ddf778e02 100644 --- a/mapie/estimator/regressor.py +++ b/mapie/estimator/regressor.py @@ -5,7 +5,7 @@ import numpy as np from joblib import Parallel, delayed from sklearn.base import RegressorMixin, clone -from sklearn.model_selection import BaseCrossValidator, BaseShuffleSplit +from sklearn.model_selection import BaseCrossValidator from sklearn.utils import _safe_indexing, deprecated from sklearn.utils.validation import _num_samples, check_is_fitted