Skip to content

Commit

Permalink
Early stopping support (#94)
Browse files Browse the repository at this point in the history
* Early stopping support

* Typo fix
  • Loading branch information
Yard1 authored Feb 14, 2021
1 parent cc56e2f commit f2f1e3c
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion boruta/boruta_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,17 @@ class BorutaPy(BaseEstimator, TransformerMixin):
- 1: displays iteration number
- 2: which features have been selected already
early_stopping : bool, default = False
Whether to use early stopping to terminate the selection process
before reaching `max_iter` iterations if the algorithm cannot
confirm a tentative feature for `n_iter_no_change` iterations.
Will speed up the process at the cost of a possibly
worse result.
n_iter_no_change : int, default = 20
Ignored if `early_stopping` is False. The maximum number of
consecutive iterations without confirming a tentative feature.
Attributes
----------
Expand Down Expand Up @@ -180,7 +191,8 @@ class BorutaPy(BaseEstimator, TransformerMixin):
"""

def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
two_step=True, max_iter=100, random_state=None, verbose=0):
two_step=True, max_iter=100, random_state=None, verbose=0,
early_stopping=False, n_iter_no_change=20):
self.estimator = estimator
self.n_estimators = n_estimators
self.perc = perc
Expand All @@ -189,6 +201,8 @@ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
self.max_iter = max_iter
self.random_state = random_state
self.verbose = verbose
self.early_stopping = early_stopping
self.n_iter_no_change = n_iter_no_change
self.__version__ = '0.3'
self._is_lightgbm = 'lightgbm' in str(type(self.estimator))

Expand Down Expand Up @@ -279,9 +293,25 @@ def _fit(self, X, y):
y = self._validate_pandas_input(y)

self.random_state = check_random_state(self.random_state)

early_stopping = False
if self.early_stopping:
if self.n_iter_no_change >= self.max_iter:
if self.verbose > 0:
print(
f"n_iter_no_change is bigger or equal to max_iter"
f"({self.n_iter_no_change} >= {self.max_iter}), "
f"early stopping will not be used."
)
else:
early_stopping = True

# setup variables for Boruta
n_sample, n_feat = X.shape
_iter = 1
# early stopping vars
_same_iters = 1
_last_dec_reg = None
# holds the decision about each feature:
# 0 - default state = tentative in original code
# 1 - accepted in original code
Expand Down Expand Up @@ -335,6 +365,21 @@ def _fit(self, X, y):
self._print_results(dec_reg, _iter, 0)
if _iter < self.max_iter:
_iter += 1

# early stopping
if early_stopping:
if _last_dec_reg is not None and (_last_dec_reg == dec_reg).all():
_same_iters += 1
if self.verbose > 0:
print(
f"Early stopping: {_same_iters} out "
f"of {self.n_iter_no_change}"
)
else:
_same_iters = 1
_last_dec_reg = dec_reg.copy()
if _same_iters > self.n_iter_no_change:
break

# we automatically apply R package's rough fix for tentative ones
confirmed = np.where(dec_reg == 1)[0]
Expand Down

0 comments on commit f2f1e3c

Please sign in to comment.