From 76b20734a43e18e304153328e4b6cae580d06779 Mon Sep 17 00:00:00 2001
From: Laurent Sorber
Date: Sun, 21 Apr 2024 15:48:15 +0200
Subject: [PATCH] feat: support pre-fitted estimators (#19)

---
 .../_conformal_coherent_quantile_regressor.py | 26 ++++++++++++-------
 tests/conftest.py                             |  6 ++---
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/src/conformal_tights/_conformal_coherent_quantile_regressor.py b/src/conformal_tights/_conformal_coherent_quantile_regressor.py
index 0c945a7..8c61e15 100644
--- a/src/conformal_tights/_conformal_coherent_quantile_regressor.py
+++ b/src/conformal_tights/_conformal_coherent_quantile_regressor.py
@@ -5,6 +5,7 @@
 import numpy as np
 import numpy.typing as npt
 from sklearn.base import BaseEstimator, MetaEstimatorMixin, RegressorMixin, clone
+from sklearn.exceptions import NotFittedError
 from sklearn.model_selection import train_test_split
 from sklearn.utils.validation import (
     check_array,
@@ -134,15 +135,22 @@ def fit(
         self.sample_weight_calib_l1_, self.sample_weight_calib_l2_ = (
             sample_weights_calib[:2] if sample_weight is not None else (None, None)  # type: ignore[has-type]
         )
-        # Fit the given estimator on the training data.
-        self.estimator_ = (
-            clone(self.estimator)
-            if self.estimator != "auto"
-            else XGBRegressor(objective="reg:absoluteerror")
-        )
-        if isinstance(self.estimator_, XGBRegressor):
-            self.estimator_.set_params(enable_categorical=True, random_state=self.random_state)
-        self.estimator_.fit(X_train, y_train, sample_weight=sample_weight_train)
+        # Check if the estimator was pre-fitted.
+        try:
+            check_is_fitted(self.estimator)
+        except (NotFittedError, TypeError):
+            # Fit the given estimator on the training data.
+            self.estimator_ = (
+                clone(self.estimator)
+                if self.estimator != "auto"
+                else XGBRegressor(objective="reg:absoluteerror")
+            )
+            if isinstance(self.estimator_, XGBRegressor):
+                self.estimator_.set_params(enable_categorical=True, random_state=self.random_state)
+            self.estimator_.fit(X_train, y_train, sample_weight=sample_weight_train)
+        else:
+            # Use the pre-fitted estimator.
+            self.estimator_ = self.estimator
         # Fit a nonconformity estimator on the training data with XGBRegressor's vector quantile
         # regression. We fit a minimal number of quantiles to reduce the computational cost, but
         # also to reduce the risk of overfitting in the coherent quantile regressor that is applied
diff --git a/tests/conftest.py b/tests/conftest.py
index 0bb132b..eb63dd7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,6 @@
 """Test fixtures."""
 
-from typing import TypeAlias
+from typing import Literal, TypeAlias
 
 import pandas as pd
 import pytest
@@ -39,12 +39,12 @@ def dataset(request: SubRequest) -> Dataset:
 
 @pytest.fixture(
     params=[
-        pytest.param(XGBRegressor(objective="reg:absoluteerror"), id="model:XGBRegressor-L1"),
+        pytest.param("auto", id="model:auto"),
         pytest.param(XGBRegressor(objective="reg:squarederror"), id="model:XGBRegressor-L2"),
         pytest.param(LGBMRegressor(objective="regression_l1"), id="model:LGBMRegressor-L1"),
        pytest.param(LGBMRegressor(objective="regression_l2"), id="model:LGBMRegressor-L2"),
     ]
 )
-def regressor(request: SubRequest) -> BaseEstimator:
+def regressor(request: SubRequest) -> BaseEstimator | Literal["auto"]:
    """Return a scikit-learn regressor."""
     return request.param