Skip to content

Commit

Permalink
Fix for issue scikit-learn-contrib#29
Browse files Browse the repository at this point in the history
  • Loading branch information
goerch committed Nov 6, 2019
1 parent 9400866 commit 37ac9be
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion stability_selection/stability_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ def _return_estimator_from_pipeline(pipeline):
return pipeline


def _return_support_from_pipeline(pipeline, variable_selector):
"""Returns the support of a Pipeline after variable selection"""
support = variable_selector.get_support()
if isinstance(pipeline, Pipeline):
reverse_iter = reversed(list(pipeline._iter()))
for idx, name, transform in reverse_iter:
if hasattr(transform, 'get_support'):
temp = transform.get_support()
temp[transform.get_support(indices=True)] = support
support = temp
return support


def _bootstrap_generator(n_bootstrap_iterations, bootstrap_func, y,
n_subsamples, random_state=None):
for _ in range(n_bootstrap_iterations):
Expand Down Expand Up @@ -109,7 +122,7 @@ def _fit_bootstrap_sample(base_estimator, X, y, lambda_name, lambda_value,
variable_selector = SelectFromModel(estimator=selector_model,
threshold=threshold,
prefit=True)
return variable_selector.get_support()
return _return_support_from_pipeline(base_estimator, variable_selector)


def plot_stability_path(stability_selection, threshold_highlight=None,
Expand Down

0 comments on commit 37ac9be

Please sign in to comment.