diff --git a/boruta/boruta_py.py b/boruta/boruta_py.py index 525b525..4eff5ac 100644 --- a/boruta/boruta_py.py +++ b/boruta/boruta_py.py @@ -114,6 +114,17 @@ class BorutaPy(BaseEstimator, TransformerMixin): - 1: displays iteration number - 2: which features have been selected already + early_stopping : bool, default = False + Whether to use early stopping to terminate the selection process + before reaching `max_iter` iterations if the algorithm cannot + confirm a tentative feature for `n_iter_no_change` iterations. + Will speed up the process at a cost of a possibility of a + worse result. + + n_iter_no_change : int, default = 20 + Ignored if `early_stopping` is False. The maximum amount of + iterations without confirming a tentative feature. + Attributes ---------- @@ -180,7 +191,8 @@ class BorutaPy(BaseEstimator, TransformerMixin): """ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05, - two_step=True, max_iter=100, random_state=None, verbose=0): + two_step=True, max_iter=100, random_state=None, verbose=0, + early_stopping=False, n_iter_no_change=20): self.estimator = estimator self.n_estimators = n_estimators self.perc = perc @@ -189,6 +201,8 @@ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05, self.max_iter = max_iter self.random_state = random_state self.verbose = verbose + self.early_stopping = early_stopping + self.n_iter_no_change = n_iter_no_change self.__version__ = '0.3' self._is_lightgbm = 'lightgbm' in str(type(self.estimator)) @@ -279,9 +293,25 @@ def _fit(self, X, y): y = self._validate_pandas_input(y) self.random_state = check_random_state(self.random_state) + + early_stopping = False + if self.early_stopping: + if self.n_iter_no_change >= self.max_iter: + if self.verbose > 0: + print( + f"n_iter_no_change is bigger or equal to max_iter" + f"({self.n_iter_no_change} >= {self.max_iter}), " + f"early stopping will not be used." + ) + else: + early_stopping = True + # setup variables for Boruta n_sample, n_feat = X.shape _iter = 1 + # early stopping vars + _same_iters = 1 + _last_dec_reg = None # holds the decision about each feature: # 0 - default state = tentative in original code # 1 - accepted in original code @@ -335,6 +365,21 @@ def _fit(self, X, y): self._print_results(dec_reg, _iter, 0) if _iter < self.max_iter: _iter += 1 + + # early stopping + if early_stopping: + if _last_dec_reg is not None and (_last_dec_reg == dec_reg).all(): + _same_iters += 1 + if self.verbose > 0: + print( + f"Early stopping: {_same_iters} out " + f"of {self.n_iter_no_change}" + ) + else: + _same_iters = 1 + _last_dec_reg = dec_reg.copy() + if _same_iters > self.n_iter_no_change: + break # we automatically apply R package's rough fix for tentative ones confirmed = np.where(dec_reg == 1)[0]