ENH - Add fixed-point distance strategy to build working sets in ``ProxNewton`` (#138)

Co-authored-by: Badr-MOUFAD <[email protected]>
PABannier and Badr-MOUFAD authored Nov 9, 2023
1 parent 4f1951c commit 261fee0
Showing 4 changed files with 53 additions and 19 deletions.
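
In short, this commit lets ``ProxNewton`` build its working sets from a fixed-point distance (``ws_strategy="fixpoint"``) as an alternative to the default subdifferential distance (``ws_strategy="subdiff"``). A minimal usage sketch of the new option, based on the solver-level API exercised in the test file below; the import paths for ``Logistic`` and ``compiled_clone`` are assumptions, not part of the diff:

```python
import numpy as np
from skglm.datafits import Logistic            # assumed import path
from skglm.penalties import L1
from skglm.solvers import ProxNewton
from skglm.utils.data import make_correlated_data
from skglm.utils.jit_compilation import compiled_clone  # assumed import path

# a small logistic regression problem
X, y, _ = make_correlated_data(n_samples=100, n_features=300, random_state=0)
y = np.sign(y)
alpha = 0.01 * np.linalg.norm(X.T @ y, ord=np.inf) / len(y)

datafit = compiled_clone(Logistic())
penalty = compiled_clone(L1(alpha))

# build working sets from the fixed-point distance instead of the subdifferential one
solver = ProxNewton(ws_strategy="fixpoint", fit_intercept=False, tol=1e-8)
w = solver.solve(X, y, datafit, penalty)[0]
```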
3 changes: 2 additions & 1 deletion doc/changes/0.4.rst
@@ -1,7 +1,8 @@
 .. _changes_0_4:
 
 Version 0.4 (in progress)
----------------------------
+-------------------------
 - Add support for weights and positive coefficients to :ref:`MCPRegression Estimator <skglm.MCPRegression>` (PR: :gh:`184`)
 - Move solver specific computations from ``Datafit.initialize()`` to separate ``Datafit`` methods to ease ``Solver`` - ``Datafit`` compatibility check (PR: :gh:`192`)
 - Add :ref:`LogSumPenalty <skglm.penalties.LogSumPenalty>` (PR: :gh:`127`)
+- Add fixed-point distance to build working sets in :ref:`ProxNewton <skglm.solvers.ProxNewton>` solver (:gh:`138`)
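
For readers of the changelog entry: writing F for the datafit, P_j for the penalty on feature j, and L_j for a per-feature curvature, the two working-set scores this commit makes available can be sketched as follows (notation is mine, not from the diff):

```latex
% subdifferential distance ("subdiff", the default strategy)
d_j^{\mathrm{sub}} = \operatorname{dist}\bigl(-\nabla_j F(Xw),\; \partial P_j(w_j)\bigr)

% fixed-point distance ("fixpoint", the new strategy)
d_j^{\mathrm{fix}} = \bigl| w_j - \operatorname{prox}_{P_j / L_j}\bigl(w_j - \nabla_j F(Xw) / L_j\bigr) \bigr|
```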
2 changes: 1 addition & 1 deletion skglm/solvers/common.py
@@ -31,7 +31,7 @@ def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws):
     dist : array, shape (n_features,)
         Violation score for every feature.
     """
-    dist = np.zeros(ws.shape[0])
+    dist = np.zeros(ws.shape[0], dtype=w.dtype)
 
     for idx, j in enumerate(ws):
         if lipschitz[j] == 0.:
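
For context on the function touched here: ``dist_fix_point_cd`` scores each working-set feature by its distance from a coordinate-descent fixed point, i.e. coordinate j is optimal when ``w_j == prox_{P_j/L_j}(w_j - grad_j / L_j)``. A minimal standalone sketch of that quantity, assuming the ``penalty.prox_1d(value, stepsize, j)`` signature skglm penalties expose (the function name here is illustrative, this mirrors rather than replaces the patched code):

```python
import numpy as np

def fixpoint_violation(w, grad_ws, lipschitz, penalty, ws):
    # distance of each coordinate from its CD fixed point:
    # |w_j - prox_{P_j / L_j}(w_j - grad_j / L_j)|
    dist = np.zeros(ws.shape[0], dtype=w.dtype)
    for idx, j in enumerate(ws):
        if lipschitz[j] == 0.:
            continue  # flat coordinate: no curvature information
        step = 1 / lipschitz[j]
        dist[idx] = np.abs(w[j] - penalty.prox_1d(w[j] - step * grad_ws[idx], step, j))
    return dist
```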
57 changes: 44 additions & 13 deletions skglm/solvers/prox_newton.py
@@ -4,6 +4,7 @@
 from numba import njit
 from scipy.sparse import issparse
 from skglm.solvers.base import BaseSolver
+from skglm.solvers.common import dist_fix_point_cd
 
 from sklearn.exceptions import ConvergenceWarning
 from skglm.utils.sparse_ops import _sparse_xj_dot
@@ -28,6 +29,9 @@ class ProxNewton(BaseSolver):
     tol : float, default 1e-4
         Tolerance for convergence.
 
+    ws_strategy : ('subdiff'|'fixpoint'), optional
+        The score used to build the working set.
+
     fit_intercept : bool, default True
         If ``True``, fits an unpenalized intercept.
@@ -49,11 +53,13 @@ class ProxNewton(BaseSolver):
     """
 
     def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
-                 fit_intercept=True, warm_start=False, verbose=0):
+                 ws_strategy="subdiff", fit_intercept=True, warm_start=False,
+                 verbose=0):
         self.p0 = p0
         self.max_iter = max_iter
         self.max_pn_iter = max_pn_iter
         self.tol = tol
+        self.ws_strategy = ws_strategy
         self.fit_intercept = fit_intercept
         self.warm_start = warm_start
         self.verbose = verbose
@@ -73,6 +79,9 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         if is_sparse:
             X_bundles = (X.data, X.indptr, X.indices)
 
+        if self.ws_strategy == "fixpoint":
+            X_square = X.multiply(X) if is_sparse else X ** 2
+
         if len(w) != n_features + self.fit_intercept:
             if self.fit_intercept:
                 val_error_message = (
@@ -92,7 +101,13 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
             else:
                 grad = _construct_grad(X, y, w[:n_features], Xw, datafit, all_features)
 
-            opt = penalty.subdiff_distance(w[:n_features], grad, all_features)
+            if self.ws_strategy == "subdiff":
+                opt = penalty.subdiff_distance(w[:n_features], grad, all_features)
+            elif self.ws_strategy == "fixpoint":
+                lipschitz = datafit.raw_hessian(y, Xw) @ X_square
+                opt = dist_fix_point_cd(
+                    w[:n_features], grad, lipschitz, datafit, penalty, all_features
+                )
 
             # optimality of intercept
             if fit_intercept:
@@ -128,13 +143,13 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
             for pn_iter in range(self.max_pn_iter):
                 # find descent direction
                 if is_sparse:
-                    delta_w_ws, X_delta_w_ws = _descent_direction_s(
+                    delta_w_ws, X_delta_w_ws, lipschitz_ws = _descent_direction_s(
                         *X_bundles, y, w, Xw, fit_intercept, grad_ws, datafit,
-                        penalty, ws, tol=EPS_TOL*tol_in)
+                        penalty, ws, tol=EPS_TOL*tol_in, ws_strategy=self.ws_strategy)
                 else:
-                    delta_w_ws, X_delta_w_ws = _descent_direction(
+                    delta_w_ws, X_delta_w_ws, lipschitz_ws = _descent_direction(
                         X, y, w, Xw, fit_intercept, grad_ws, datafit,
-                        penalty, ws, tol=EPS_TOL*tol_in)
+                        penalty, ws, tol=EPS_TOL*tol_in, ws_strategy=self.ws_strategy)
 
                 # backtracking line search with inplace update of w, Xw
                 if is_sparse:
@@ -147,7 +162,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
                         delta_w_ws, X_delta_w_ws, ws)
 
                 # check convergence
-                opt_in = penalty.subdiff_distance(w, grad_ws, ws)
+                if self.ws_strategy == "subdiff":
+                    opt_in = penalty.subdiff_distance(w, grad_ws, ws)
+                elif self.ws_strategy == "fixpoint":
+                    opt_in = dist_fix_point_cd(
+                        w, grad_ws, lipschitz_ws, datafit, penalty, ws
+                    )
                 stop_crit_in = np.max(opt_in)
 
                 if max(self.verbose-1, 0):
@@ -176,7 +196,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
 
 @njit
 def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
-                       penalty, ws, tol):
+                       penalty, ws, tol, ws_strategy):
     # Given:
     # 1) b = \nabla F(X w_epoch)
     # 2) D = \nabla^2 F(X w_epoch) <------> raw_hess
@@ -229,7 +249,12 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
             # TODO: can be improved by passing in w_ws but breaks for WeightedL1
             current_w = w_epoch.copy()
             current_w[ws_intercept] = w_ws
-            opt = penalty.subdiff_distance(current_w, past_grads, ws)
+            if ws_strategy == "subdiff":
+                opt = penalty.subdiff_distance(current_w, past_grads, ws)
+            elif ws_strategy == "fixpoint":
+                opt = dist_fix_point_cd(
+                    current_w, past_grads, lipschitz, datafit, penalty, ws
+                )
             stop_crit = np.max(opt)
 
             if fit_intercept:
@@ -239,13 +264,14 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
                 break
 
     # descent direction
-    return w_ws - w_epoch[ws_intercept], X_delta_w_ws
+    return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz
 
 
 # sparse version of _descent_direction
 @njit
 def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
-                         Xw_epoch, fit_intercept, grad_ws, datafit, penalty, ws, tol):
+                         Xw_epoch, fit_intercept, grad_ws, datafit, penalty, ws, tol,
+                         ws_strategy):
     dtype = X_data.dtype
     raw_hess = datafit.raw_hessian(y, Xw_epoch)
 
@@ -298,7 +324,12 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
             # TODO: could be improved by passing in w_ws
             current_w = w_epoch.copy()
             current_w[ws_intercept] = w_ws
-            opt = penalty.subdiff_distance(current_w, past_grads, ws)
+            if ws_strategy == "subdiff":
+                opt = penalty.subdiff_distance(current_w, past_grads, ws)
+            elif ws_strategy == "fixpoint":
+                opt = dist_fix_point_cd(
+                    current_w, past_grads, lipschitz, datafit, penalty, ws
+                )
             stop_crit = np.max(opt)
 
             if fit_intercept:
@@ -308,7 +339,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
                 break
 
     # descent direction
-    return w_ws - w_epoch[ws_intercept], X_delta_w_ws
+    return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz
 
 
 @njit
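
A note on the ``lipschitz`` values now threaded through the solver: with ``D = datafit.raw_hessian(y, Xw)`` the diagonal of the datafit Hessian, the fixed-point distance needs per-feature curvatures ``L_j = sum_i D_i * X_ij ** 2``, which is exactly ``D @ X_square`` with ``X_square`` the elementwise square of ``X``, as in the ``fixpoint`` branch above. A standalone sketch of that computation (the function name and final ``ravel`` are illustrative, not from the diff):

```python
import numpy as np
from scipy import sparse

def feature_curvatures(X, raw_hess):
    # L_j = sum_i raw_hess_i * X_ij ** 2, computed without densifying X:
    # X.multiply(X) squares a sparse matrix entry-wise, X ** 2 a dense one
    X_square = X.multiply(X) if sparse.issparse(X) else X ** 2
    return np.asarray(raw_hess @ X_square).ravel()
```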
10 changes: 6 additions & 4 deletions skglm/tests/test_prox_newton.py
@@ -1,6 +1,5 @@
 import pytest
 import numpy as np
-from itertools import product
 from sklearn.linear_model import LogisticRegression
 
 from skglm.penalties import L1
@@ -11,8 +10,10 @@
 from skglm.utils.data import make_correlated_data
 
 
-@pytest.mark.parametrize("X_density, fit_intercept", product([1., 0.5], [True, False]))
-def test_pn_vs_sklearn(X_density, fit_intercept):
+@pytest.mark.parametrize("X_density", [1., 0.5])
+@pytest.mark.parametrize("fit_intercept", [True, False])
+@pytest.mark.parametrize("ws_strategy", ["subdiff", "fixpoint"])
+def test_pn_vs_sklearn(X_density, fit_intercept, ws_strategy):
     n_samples, n_features = 12, 25
     rho = 1e-1
 
@@ -30,7 +31,8 @@ def test_pn_vs_sklearn(X_density, fit_intercept):
 
     log_datafit = compiled_clone(Logistic())
     l1_penalty = compiled_clone(L1(alpha))
-    prox_solver = ProxNewton(fit_intercept=fit_intercept, tol=1e-12)
+    prox_solver = ProxNewton(
+        fit_intercept=fit_intercept, tol=1e-12, ws_strategy=ws_strategy)
     w = prox_solver.solve(X, y, log_datafit, l1_penalty)[0]
 
     np.testing.assert_allclose(w[:n_features], sk_log_reg.coef_.flatten())
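
On the test change: stacked ``@pytest.mark.parametrize`` decorators generate the cross-product of their values, so the ``itertools.product`` import becomes unnecessary and the new ``ws_strategy`` axis brings the count to 2 × 2 × 2 = 8 cases. A self-contained illustration of the pattern (toy test, not from the diff):

```python
import pytest

@pytest.mark.parametrize("x_density", [1., 0.5])
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("ws_strategy", ["subdiff", "fixpoint"])
def test_cross_product(x_density, fit_intercept, ws_strategy):
    # pytest collects this test 2 * 2 * 2 = 8 times, once per combination
    assert ws_strategy in ("subdiff", "fixpoint")
```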
