Skip to content

Commit

Permalink
Set alpha_min_ratio depending on n_samples/n_features ratio
Browse files Browse the repository at this point in the history
  • Loading branch information
sebp committed Jun 27, 2020
1 parent 83490e5 commit 0655c02
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 23 deletions.
46 changes: 28 additions & 18 deletions sksurv/linear_model/coxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
List of alphas where to compute the models.
If ``None`` alphas are set automatically.
alpha_min_ratio : float, optional, default 0.0001
alpha_min_ratio : float or { "auto" }, optional, default: "auto"
Determines minimum alpha of the regularization path
if ``alphas`` is ``None``. The smallest value for alpha
is computed as the fraction of the data derived maximum
alpha (i.e. the smallest value for which all
coefficients are zero).
The default value of alpha_min_ratio will depend on the
sample size relative to the number of features in 0.13.
If `n_samples > n_features`, the current default value 0.0001
will be used. If `n_samples < n_features`, 0.01 will be used instead.
If set to "auto", the value will depend on the
sample size relative to the number of features.
If ``n_samples > n_features``, the default value is 0.0001
If ``n_samples <= n_features``, 0.01 is the default value.
l1_ratio : float, optional, default: 0.5
The ElasticNet mixing parameter, with ``0 < l1_ratio <= 1``.
Expand Down Expand Up @@ -99,6 +99,9 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
alphas_ : ndarray, shape=(n_alphas,)
The actual sequence of alpha values used.
alpha_min_ratio_ : float
The inferred value of alpha_min_ratio.
penalty_factor_ : ndarray, shape=(n_features,)
The actual penalty factors used.
Expand All @@ -115,7 +118,7 @@ class CoxnetSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
Journal of statistical software. 2011 Mar;39(5):1.
"""

def __init__(self, n_alphas=100, alphas=None, alpha_min_ratio="warn", l1_ratio=0.5,
def __init__(self, n_alphas=100, alphas=None, alpha_min_ratio="auto", l1_ratio=0.5,
penalty_factor=None, normalize=False, copy_X=True,
tol=1e-7, max_iter=100000, verbose=False, fit_baseline_model=False):
self.n_alphas = n_alphas
Expand Down Expand Up @@ -147,7 +150,7 @@ def _pre_fit(self, X, y):
time = time[o].astype(numpy.float64)
return X, event_num, time

def _check_params(self, n_features):
def _check_params(self, n_samples, n_features):
if not 0 < self.l1_ratio <= 1:
raise ValueError("l1_ratio must be in interval ]0;1], but was %f" % self.l1_ratio)

Expand Down Expand Up @@ -181,7 +184,21 @@ def _check_params(self, n_features):
if self.max_iter <= 0:
raise ValueError("max_iter must be a positive integer")

return create_path, alphas.astype(numpy.float64), penalty_factor.astype(numpy.float64)
if isinstance(self.alpha_min_ratio, str):
if self.alpha_min_ratio == "auto":
if n_samples > n_features:
alpha_min_ratio = 0.0001
else:
alpha_min_ratio = 0.01
else:
raise ValueError("Invalid value for alpha_min_ratio. "
"Allowed string values are 'auto'.")
else:
alpha_min_ratio = float(self.alpha_min_ratio)
if alpha_min_ratio <= 0 or not numpy.isfinite(alpha_min_ratio):
raise ValueError("alpha_min_ratio must be positive")

return create_path, alphas.astype(numpy.float64), penalty_factor.astype(numpy.float64), alpha_min_ratio

def fit(self, X, y):
"""Fit estimator.
Expand All @@ -201,19 +218,11 @@ def fit(self, X, y):
self
"""
X, event_num, time = self._pre_fit(X, y)
create_path, alphas, penalty = self._check_params(X.shape[1])

if self.alpha_min_ratio == 'warn':
warnings.warn("The default value of alpha_min_ratio will depend on the "
"sample size relative to the number of features in 0.13. "
"If n_samples > n_features, the current default value 0.0001 "
"will be used. If n_samples < n_features, 0.01 will be used instead.",
FutureWarning)
self.alpha_min_ratio = 0.0001
create_path, alphas, penalty, alpha_min_ratio = self._check_params(*X.shape)

coef, alphas, deviance_ratio, n_iter = call_fit_coxnet(
X, time, event_num, penalty, alphas, create_path,
self.alpha_min_ratio, self.l1_ratio, int(self.max_iter),
alpha_min_ratio, self.l1_ratio, int(self.max_iter),
self.tol, self.verbose)
assert numpy.isfinite(coef).all()

Expand All @@ -237,6 +246,7 @@ def fit(self, X, y):
self._baseline_models = None

self.alphas_ = alphas
self.alpha_min_ratio_ = alpha_min_ratio
self.penalty_factor_ = penalty
self.coef_ = coef
self.deviance_ratio_ = deviance_ratio
Expand Down
18 changes: 13 additions & 5 deletions tests/test_coxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,10 +520,16 @@ def test_invalid_alphas(self, infinite_float_array):
match="Input contains NaN, infinity or a value too large"):
self._fit_example(alpha_min_ratio=0.0001, alphas=infinite_float_array)

def test_alpha_min_ratio_future_warning(self):
with pytest.warns(FutureWarning,
match="The default value of alpha_min_ratio will depend "):
self._fit_example()
def test_invalid_alpha_min_ratio_string(self):
with pytest.raises(ValueError,
match="Invalid value for alpha_min_ratio"):
self._fit_example(alpha_min_ratio="max")

@pytest.mark.parametrize("value", [0.0, -1e-12, -1, -numpy.infty, numpy.nan])
def test_invalid_alpha_min_ratio_float(self, value):
with pytest.raises(ValueError,
match="alpha_min_ratio must be positive"):
self._fit_example(alpha_min_ratio=value)

@staticmethod
def test_alpha_too_small():
Expand Down Expand Up @@ -551,9 +557,11 @@ def test_breast_example():
x, y = load_breast_cancer()
x = column.encode_categorical(x)

coxnet = CoxnetSurvivalAnalysis(alpha_min_ratio=0.0001, l1_ratio=1.0)
coxnet = CoxnetSurvivalAnalysis(l1_ratio=1.0)
coxnet.fit(x.values, y)

assert coxnet.alpha_min_ratio_ == 0.0001

expected_alphas = numpy.array([
0.207764947265866, 0.189307681974955, 0.172490109262135, 0.157166563357949, 0.143204319038428,
0.130482442022696, 0.118890741498079, 0.108328815700004, 0.0987051822799425, 0.0899364859290742,
Expand Down

0 comments on commit 0655c02

Please sign in to comment.