From 2c9342f604c2d3acb0c4f7646507f853fbf709a6 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sat, 15 Oct 2022 18:18:34 +0200 Subject: [PATCH 01/25] add group logreg --- skglm/datafits/group.py | 75 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/skglm/datafits/group.py b/skglm/datafits/group.py index 7ce9295d3..97e8a1821 100644 --- a/skglm/datafits/group.py +++ b/skglm/datafits/group.py @@ -71,3 +71,78 @@ def gradient_scalar(self, X, y, w, Xw, j): def intercept_update_step(self, y, Xw): return np.mean(Xw - y) + + +class LogisticGroup(BaseDatafit): + r"""Logistic datafit used with group penalties. + + The datafit reads:: + + (1 / n_samples) * \sum_i log(1 + exp(-y_i * Xw_i)) + + Attributes + ---------- + grp_indices : array, shape (n_features,) + The group indices stacked contiguously + ([grp1_indices, grp2_indices, ...]). + + grp_ptr : array, shape (n_groups + 1,) + The group pointers such that two consecutive elements delimit + the indices of a group in ``grp_indices``. + + lipschitz : array, shape (n_groups,) + The lipschitz constants for each group. + """ + + def __init__(self, grp_ptr, grp_indices): + self.grp_ptr, self.grp_indices = grp_ptr, grp_indices + + def get_spec(self): + spec = ( + ('grp_ptr', int32[:]), + ('grp_indices', int32[:]), + ('lipschitz', float64[:]) + ) + return spec + + def params_to_dict(self): + return dict(grp_ptr=self.grp_ptr, + grp_indices=self.grp_indices) + + def initialize(self, X, y): + grp_ptr, grp_indices = self.grp_ptr, self.grp_indices + n_groups = len(grp_ptr) - 1 + + lipschitz = np.zeros(n_groups) + for g in range(n_groups): + grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] + X_g = X[:, grp_g_indices] + lipschitz[g] = norm(X_g, ord=2) ** 2 / (4 * len(y)) + + self.lipschitz = lipschitz + + def value(self, y, w, Xw): + return np.log(1 + np.exp(-y * Xw)).sum() / len(y) + + def raw_grad(self, y, Xw): + """Compute gradient of datafit w.r.t ``Xw``.""" + return -y / (1 + np.exp(y * Xw)) / len(y) + + def raw_hessian(self, y, Xw): + """Compute Hessian of datafit w.r.t ``Xw``.""" + exp_minus_yXw = np.exp(-y * Xw) + return exp_minus_yXw / (1 + exp_minus_yXw) ** 2 / len(y) + + def gradient_g(self, X, y, w, Xw, g): + grp_ptr, grp_indices = self.grp_ptr, self.grp_indices + grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] + raw_grad_val = self.raw_grad(y, Xw) + + grad_g = np.zeros(len(grp_g_indices)) + for idx, j in enumerate(grp_g_indices): + grad_g[idx] = X[:, j] @ raw_grad_val + + return grad_g + + def intercept_update_step(self, y, Xw): + return np.mean(self.raw_grad(y, Xw)) / 4 From ab3eeca52fdb94f99516987a37349768d75153ae Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sat, 15 Oct 2022 18:18:55 +0200 Subject: [PATCH 02/25] unittest group logreg --- skglm/tests/test_group.py | 64 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/skglm/tests/test_group.py b/skglm/tests/test_group.py index a02b8dddf..6a9339894 100644 --- a/skglm/tests/test_group.py +++ b/skglm/tests/test_group.py @@ -5,12 +5,13 @@ from skglm.penalties import L1 from skglm.datafits import Quadratic from skglm.penalties.block_separable import WeightedGroupL2 -from skglm.datafits.group import QuadraticGroup +from skglm.datafits.group import QuadraticGroup, LogisticGroup from skglm.solvers import GroupBCD from skglm.utils import ( _alpha_max_group_lasso, grp_converter, make_correlated_data, compiled_clone, AndersonAcceleration) from celer import GroupLasso, Lasso +from sklearn.linear_model import 
LogisticRegression def _generate_random_grp(n_groups, n_features, shuffle=True): @@ -160,6 +161,66 @@ def test_intercept_grouplasso(): np.testing.assert_allclose(model.intercept_, w[-1], atol=1e-5) +@pytest.mark.parametrize("rho", [1e-1, 1e-2]) +def test_equivalence_logreg(rho): + n_samples, n_features = 30, 50 + rnd = np.random.RandomState(1123) + X, y, _ = make_correlated_data(n_samples, n_features, random_state=rnd) + y = np.sign(y) + + grp_indices, grp_ptr = grp_converter(1, n_features, shuffle=False) + weights = np.ones(n_features) + alpha_max = norm(X.T @ y, ord=np.inf) / (2 * n_samples) + alpha = rho * alpha_max / 10. + + log_group = LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices) + group_penalty = WeightedGroupL2( + alpha=alpha, grp_ptr=grp_ptr, + grp_indices=grp_indices, weights=weights) + + log_group = compiled_clone(log_group, to_float32=X.dtype == np.float32) + group_penalty = compiled_clone(group_penalty) + w = GroupBCD(tol=1e-12).solve(X, y, log_group, group_penalty)[0] + + sk_logreg = LogisticRegression(penalty='l1', C=1/(n_samples * alpha), + fit_intercept=False, tol=1e-12, solver='liblinear') + sk_logreg.fit(X, y) + + np.testing.assert_allclose(sk_logreg.coef_.flatten(), w) + + +@pytest.mark.parametrize("n_groups, rho", [[15, 1e-1], [25, 1e-2]]) +def test_group_logreg(n_groups, rho): + n_samples, n_features, shuffle = 30, 100, True + random_state = 123 + + X, y, _ = make_correlated_data(n_samples, n_features, random_state=random_state) + y = np.sign(y) + + np.random.seed(random_state) + weights = np.abs(np.random.randn(n_groups)) + grp_indices, grp_ptr, _ = _generate_random_grp(n_groups, n_features, shuffle) + + alpha_max = 0. + for g in range(n_groups): + grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] + alpha_max = max( + alpha_max, + norm(X[:, grp_g_indices].T @ y) / n_samples / weights[g] + ) + alpha = rho * alpha_max + + # skglm + log_group = LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices) + group_penalty = WeightedGroupL2(alpha, weights, grp_ptr, grp_indices) + + log_group = compiled_clone(log_group, to_float32=X.dtype == np.float32) + group_penalty = compiled_clone(group_penalty) + stop_crit = GroupBCD(tol=1e-12).solve(X, y, log_group, group_penalty)[2] + + np.testing.assert_array_less(stop_crit, 1e-12) + + def test_anderson_acceleration(): # VAR: w = rho * w + 1 with |rho| < 1 # converges to w_star = 1 / (1 - rho) @@ -201,4 +262,5 @@ def test_anderson_acceleration(): if __name__ == "__main__": + test_group_logreg(20, 1e-1) pass From bfd49e3f102905a3da0217287d1d088cd987d191 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sat, 15 Oct 2022 18:19:23 +0200 Subject: [PATCH 03/25] add group log to api --- doc/api.rst | 1 + skglm/datafits/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 77a44670c..4d003891e 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -55,6 +55,7 @@ Datafits Huber Logistic + LogisticGroup Quadratic QuadraticGroup QuadraticSVC diff --git a/skglm/datafits/__init__.py b/skglm/datafits/__init__.py index 0c236a16b..fad030ac3 100644 --- a/skglm/datafits/__init__.py +++ b/skglm/datafits/__init__.py @@ -1,12 +1,12 @@ from .base import BaseDatafit, BaseMultitaskDatafit from .single_task import Quadratic, QuadraticSVC, Logistic, Huber, Poisson from .multi_task import QuadraticMultiTask -from .group import QuadraticGroup +from .group import QuadraticGroup, LogisticGroup __all__ = [ BaseDatafit, BaseMultitaskDatafit, Quadratic, QuadraticSVC, Logistic, Huber, Poisson, 
QuadraticMultiTask, - QuadraticGroup + QuadraticGroup, LogisticGroup ] From 326636fde9cb4989d8a00a972646512b980846ae Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 16 Oct 2022 00:22:42 +0200 Subject: [PATCH 04/25] fix unittest --- skglm/tests/test_group.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skglm/tests/test_group.py b/skglm/tests/test_group.py index 6a9339894..c3b476a69 100644 --- a/skglm/tests/test_group.py +++ b/skglm/tests/test_group.py @@ -168,7 +168,7 @@ def test_equivalence_logreg(rho): X, y, _ = make_correlated_data(n_samples, n_features, random_state=rnd) y = np.sign(y) - grp_indices, grp_ptr = grp_converter(1, n_features, shuffle=False) + grp_indices, grp_ptr = grp_converter(1, n_features) weights = np.ones(n_features) alpha_max = norm(X.T @ y, ord=np.inf) / (2 * n_samples) alpha = rho * alpha_max / 10. @@ -186,12 +186,12 @@ def test_equivalence_logreg(rho): fit_intercept=False, tol=1e-12, solver='liblinear') sk_logreg.fit(X, y) - np.testing.assert_allclose(sk_logreg.coef_.flatten(), w) + np.testing.assert_allclose(sk_logreg.coef_.flatten(), w, atol=1e-6, rtol=1e-5) @pytest.mark.parametrize("n_groups, rho", [[15, 1e-1], [25, 1e-2]]) def test_group_logreg(n_groups, rho): - n_samples, n_features, shuffle = 30, 100, True + n_samples, n_features, shuffle = 30, 60, True random_state = 123 X, y, _ = make_correlated_data(n_samples, n_features, random_state=random_state) From b9dafc9a0f543edeb56fecd8e13e9788e3db4f0a Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 16 Oct 2022 00:39:23 +0200 Subject: [PATCH 05/25] cleanups --- skglm/tests/test_group.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skglm/tests/test_group.py b/skglm/tests/test_group.py index c3b476a69..d072d2a77 100644 --- a/skglm/tests/test_group.py +++ b/skglm/tests/test_group.py @@ -262,5 +262,4 @@ def test_anderson_acceleration(): if __name__ == "__main__": - test_group_logreg(20, 1e-1) pass From 747493bf0b217dac530b0a2fb932157b5db85f28 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 16 Oct 2022 12:39:13 +0200 Subject: [PATCH 06/25] init group prox newton --- skglm/solvers/group_prox_newton.py | 297 +++++++++++++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 skglm/solvers/group_prox_newton.py diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py new file mode 100644 index 000000000..071f86516 --- /dev/null +++ b/skglm/solvers/group_prox_newton.py @@ -0,0 +1,297 @@ +import numpy as np +from numba import njit +from scipy.sparse import issparse +from skglm.solvers.base import BaseSolver + + +EPS_TOL = 0.3 +MAX_CD_ITER = 20 +MAX_BACKTRACK_ITER = 20 + + +class GroupProxNewton(BaseSolver): + """Prox Newton solver combined with working sets. + + p0 : int, default 10 + Minimum number of features to be included in the working set. + + max_iter : int, default 20 + Maximum number of outer iterations. + + max_pn_iter : int, default 1000 + Maximum number of prox Newton iterations on each subproblem. + + tol : float, default 1e-4 + Tolerance for convergence. + + verbose : bool, default False + Amount of verbosity. 0/False is silent. + + References + ---------- + .. [1] Massias, M. and Vaiter, S. and Gramfort, A. and Salmon, J. + "Dual Extrapolation for Sparse Generalized Linear Models", JMLR, 2020, + https://arxiv.org/abs/1907.05830 + code: https://github.com/mathurinm/celer + + .. [2] Johnson, T. B. and Guestrin, C. + "Blitz: A principled meta-algorithm for scaling sparse optimization", + ICML, 2015. 
+ https://proceedings.mlr.press/v37/johnson15.html + code: https://github.com/tbjohns/BlitzL1 + """ + + def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, + fit_intercept=False, warm_start=False, verbose=0): + self.p0 = p0 + self.max_iter = max_iter + self.max_pn_iter = max_pn_iter + self.tol = tol + self.fit_intercept = fit_intercept + self.warm_start = warm_start + self.verbose = verbose + + def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + n_samples, n_features = X.shape + w = np.zeros(n_features) if w_init is None else w_init + Xw = np.zeros(n_samples) if Xw_init is None else Xw_init + + for t in range(self.max_iter): + + # compute grad + + # check convergence + + # construct ws + + for pn_iter in range(self.max_pn_iter): + # find descent direction + + # find a suitable step size + pass + pass + return + + +@njit +def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, + penalty, ws, tol): + # Given: + # 1) b = \nabla F(X w_epoch) + # 2) D = \nabla^2 F(X w_epoch) <------> raw_hess + # Minimize quadratic approximation for delta_w = w - w_epoch: + # b.T @ X @ delta_w + \ + # 1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w) + raw_hess = datafit.raw_hessian(y, Xw_epoch) + + lipschitz = np.zeros(len(ws)) + for idx, j in enumerate(ws): + lipschitz[idx] = raw_hess @ X[:, j] ** 2 + + # for a less costly stopping criterion, we do no compute the exact gradient, + # but store each coordinate-wise gradient every time we upate one coordinate: + past_grads = np.zeros(len(ws)) + X_delta_w_ws = np.zeros(X.shape[0]) + w_ws = w_epoch[ws] + + for cd_iter in range(MAX_CD_ITER): + for idx, j in enumerate(ws): + # skip when X[:, j] == 0 + if lipschitz[idx] == 0: + continue + + past_grads[idx] = grad_ws[idx] + X[:, j] @ (raw_hess * X_delta_w_ws) + old_w_idx = w_ws[idx] + stepsize = 1 / lipschitz[idx] + + w_ws[idx] = penalty.prox_1d( + old_w_idx - stepsize * past_grads[idx], stepsize, j) + + if w_ws[idx] != old_w_idx: + X_delta_w_ws += (w_ws[idx] - old_w_idx) * X[:, j] + + if cd_iter % 5 == 0: + # TODO: can be improved by passing in w_ws but breaks for WeightedL1 + current_w = w_epoch.copy() + current_w[ws] = w_ws + opt = penalty.subdiff_distance(current_w, past_grads, ws) + if np.max(opt) <= tol: + break + + # descent direction + return w_ws - w_epoch[ws], X_delta_w_ws + + +# sparse version of _compute_descent_direction +@njit +def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch, + Xw_epoch, grad_ws, datafit, penalty, ws, tol): + raw_hess = datafit.raw_hessian(y, Xw_epoch) + + lipschitz = np.zeros(len(ws)) + for idx, j in enumerate(ws): + # equivalent to: lipschitz[idx] += raw_hess * X[:, j] ** 2 + lipschitz[idx] = _sparse_squared_weighted_norm( + X_data, X_indptr, X_indices, j, raw_hess) + + # see _descent_direction() comment + past_grads = np.zeros(len(ws)) + X_delta_w_ws = np.zeros(len(y)) + w_ws = w_epoch[ws] + + for cd_iter in range(MAX_CD_ITER): + for idx, j in enumerate(ws): + # skip when X[:, j] == 0 + if lipschitz[idx] == 0: + continue + + past_grads[idx] = grad_ws[idx] + # equivalent to cached_grads[idx] += X[:, j] @ (raw_hess * X_delta_w_ws) + past_grads[idx] += _sparse_weighted_dot( + X_data, X_indptr, X_indices, j, X_delta_w_ws, raw_hess) + + old_w_idx = w_ws[idx] + stepsize = 1 / lipschitz[idx] + + w_ws[idx] = penalty.prox_1d( + old_w_idx - stepsize * past_grads[idx], stepsize, j) + + if w_ws[idx] != old_w_idx: + _update_X_delta_w(X_data, X_indptr, X_indices, X_delta_w_ws, + w_ws[idx] - old_w_idx, j) + + if cd_iter % 5 == 0: + # 
TODO: could be improved by passing in w_ws + current_w = w_epoch.copy() + current_w[ws] = w_ws + opt = penalty.subdiff_distance(current_w, past_grads, ws) + if np.max(opt) <= tol: + break + + # descent direction + return w_ws - w_epoch[ws], X_delta_w_ws + + +@njit +def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, + X_delta_w_ws, ws): + # 1) find step in [0, 1] such that: + # penalty(w + step * delta_w) - penalty(w) + + # step * \nabla datafit(w + step * delta_w) @ delta_w < 0 + # ref: https://www.di.ens.fr/~aspremon/PDF/ENSAE/Newton.pdf + # 2) inplace update of w and Xw and return grad_ws of the last w and Xw + step, prev_step = 1., 0. + # TODO: could be improved by passing in w[ws] + old_penalty_val = penalty.value(w) + + # try step = 1, 1/2, 1/4, ... + for _ in range(MAX_BACKTRACK_ITER): + w[ws] += (step - prev_step) * delta_w_ws + Xw += (step - prev_step) * X_delta_w_ws + + grad_ws = _construct_grad(X, y, w, Xw, datafit, ws) + # TODO: could be improved by passing in w[ws] + stop_crit = penalty.value(w) - old_penalty_val + stop_crit += step * grad_ws @ delta_w_ws + + if stop_crit < 0: + break + else: + prev_step = step + step /= 2 + else: + pass + # TODO this case is not handled yet + + return grad_ws + + +# sparse version of _backtrack_line_search +@njit +def _backtrack_line_search_s(X_data, X_indptr, X_indices, y, w, Xw, datafit, + penalty, delta_w_ws, X_delta_w_ws, ws): + step, prev_step = 1., 0. + # TODO: could be improved by passing in w[ws] + old_penalty_val = penalty.value(w) + + for _ in range(MAX_BACKTRACK_ITER): + w[ws] += (step - prev_step) * delta_w_ws + Xw += (step - prev_step) * X_delta_w_ws + + grad_ws = _construct_grad_sparse(X_data, X_indptr, X_indices, + y, w, Xw, datafit, ws) + # TODO: could be improved by passing in w[ws] + stop_crit = penalty.value(w) - old_penalty_val + stop_crit += step * grad_ws.T @ delta_w_ws + + if stop_crit < 0: + break + else: + prev_step = step + step /= 2 + else: + pass # TODO + + return grad_ws + + +@njit +def _construct_grad(X, y, w, Xw, datafit, ws): + # Compute grad of datafit restricted to ws. This function avoids + # recomputing raw_grad for every j, which is costly for logreg + grp_ptr, grp_indices = datafit.grp_ptr, datafit.grp_indices + n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws]) + + raw_grad = datafit.raw_grad(y, Xw) + grad = np.zeros(len(ws)) + + for idx, g in enumerate(ws): + # compute grad_g + grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] + for j in grp_g_indices: + grad[idx] = X[:, j] @ raw_grad + return grad + + +@njit +def _construct_grad_sparse(X_data, X_indptr, X_indices, y, w, Xw, datafit, ws): + # Compute grad of datafit restricted to ws in case X sparse + raw_grad = datafit.raw_grad(y, Xw) + grad = np.zeros(len(ws)) + for idx, j in enumerate(ws): + grad[idx] = _sparse_xj_dot(X_data, X_indptr, X_indices, j, raw_grad) + return grad + + +@njit(fastmath=True) +def _sparse_xj_dot(X_data, X_indptr, X_indices, j, other): + # Compute X[:, j] @ other in case X sparse + res = 0. + for i in range(X_indptr[j], X_indptr[j+1]): + res += X_data[i] * other[X_indices[i]] + return res + + +@njit(fastmath=True) +def _sparse_weighted_dot(X_data, X_indptr, X_indices, j, other, weights): + # Compute X[:, j] @ (weights * other) in case X sparse + res = 0. 
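# Illustrative standalone check (not part of this diff) of the CSC access pattern these
# sparse helpers rely on: column j of a CSC matrix is stored in
# data[indptr[j]:indptr[j+1]], with row indices in indices[indptr[j]:indptr[j+1]].
# The toy matrix and the *_toy names below are placeholders.
import numpy as np
from scipy import sparse

rng = np.random.RandomState(0)
X_dense = rng.randn(6, 4)
X_csc = sparse.csc_matrix(X_dense)
weights_toy, other_toy = rng.rand(6), rng.randn(6)
j_toy = 2

res = 0.
for i in range(X_csc.indptr[j_toy], X_csc.indptr[j_toy + 1]):
    res += X_csc.data[i] * other_toy[X_csc.indices[i]] * weights_toy[X_csc.indices[i]]

# matches the dense computation X[:, j] @ (weights * other)
np.testing.assert_allclose(res, X_dense[:, j_toy] @ (weights_toy * other_toy))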
+ for i in range(X_indptr[j], X_indptr[j+1]): + res += X_data[i] * other[X_indices[i]] * weights[X_indices[i]] + return res + + +@njit(fastmath=True) +def _sparse_squared_weighted_norm(X_data, X_indptr, X_indices, j, weights): + # Compute weights @ X[:, j]**2 in case X sparse + res = 0. + for i in range(X_indptr[j], X_indptr[j+1]): + res += weights[X_indices[i]] * X_data[i]**2 + return res + + +@njit(fastmath=True) +def _update_X_delta_w(X_data, X_indptr, X_indices, X_delta_w, diff, j): + # Compute X_delta_w += diff * X[:, j] in case of X sparse + for i in range(X_indptr[j], X_indptr[j+1]): + X_delta_w[X_indices[i]] += diff * X_data[i] From e27a2e70b96861ebd9c1915e7588d3e783cbd84c Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 16 Oct 2022 17:06:34 +0200 Subject: [PATCH 07/25] implement group prox newton --- skglm/solvers/group_prox_newton.py | 266 +++++++++++++---------------- 1 file changed, 118 insertions(+), 148 deletions(-) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 071f86516..9f852fec8 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -1,5 +1,6 @@ import numpy as np from numba import njit +from numpy.linalg import norm from scipy.sparse import issparse from skglm.solvers.base import BaseSolver @@ -53,24 +54,71 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): n_samples, n_features = X.shape + n_groups = len(penalty.grp_ptr) - 1 + w = np.zeros(n_features) if w_init is None else w_init Xw = np.zeros(n_samples) if Xw_init is None else Xw_init + all_groups = np.arange(n_groups) + stop_crit = 0. + p_objs_out = [] for t in range(self.max_iter): - # compute grad + grad = -_construct_grad(X, y, w, Xw, datafit, all_groups) # check convergence + opt = penalty.subdiff_distance(w, -grad, all_groups) + stop_crit = np.max(opt) + + if self.verbose: + p_obj = datafit.value(y, w, Xw) + penalty.value(w) + print( + f"Iteration {t+1}: {p_obj:.10f}, " + f"stopping crit: {stop_crit:.2e}" + ) + + if stop_crit <= self.tol: + break # construct ws + gsupp_size = penalty.generalized_support(w).sum() + ws_size = max(min(self.p0, n_groups), + min(n_groups, 2 * gsupp_size)) + ws = np.argpartition(opt, -ws_size)[-ws_size:] # k-largest items (no sort) + n_features_ws = sum([penalty.grp_ptr[g+1] - penalty.grp_ptr[g] for g in ws]) + grad_ws = grad[:n_features_ws] + tol_in = EPS_TOL * stop_crit + + # solve subproblem for pn_iter in range(self.max_pn_iter): # find descent direction + delta_w_ws, X_delta_w_ws = _descent_direction( + X, y, w, Xw, grad_ws, datafit, penalty, ws, tol=EPS_TOL*tol_in) # find a suitable step size - pass - pass - return + grad_ws[:] = _backtrack_line_search( + X, y, w, Xw, datafit, penalty, delta_w_ws, X_delta_w_ws, ws) + + # check convergence + opt_in = penalty.subdiff_distance(w, -grad_ws, ws) + stop_crit_in = np.max(opt_in) + + if max(self.verbose-1, 0): + p_obj = datafit.value(y, w, Xw) + penalty.value(w) + print( + "PN iteration {}: {:.10f}, ".format(pn_iter+1, p_obj) + + "stopping crit in: {:.2e}".format(stop_crit_in) + ) + + if stop_crit_in <= tol_in: + if max(self.verbose-1, 0): + print("Early exit") + break + + p_obj = datafit.value(y, w, Xw) + penalty.value(w) + p_objs_out.append(p_obj) + return w, np.asarray(p_objs_out), stop_crit @njit @@ -82,94 +130,77 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, # Minimize quadratic approximation for delta_w = w - w_epoch: # b.T @ X @ delta_w 
+ \ # 1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w) - raw_hess = datafit.raw_hessian(y, Xw_epoch) + grp_ptr = penalty.grp_ptr + grp_indices = penalty.grp_indices + n_features_ws = sum([penalty.grp_ptr[g+1] - penalty.grp_ptr[g] for g in ws]) - lipschitz = np.zeros(len(ws)) - for idx, j in enumerate(ws): - lipschitz[idx] = raw_hess @ X[:, j] ** 2 + raw_hess = datafit.raw_hessian(y, Xw_epoch) + lipchitz = np.zeros(len(ws)) + for idx, g in enumerate(ws): + grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] + lipchitz[idx] = norm(np.sqrt(raw_hess) * X[:, grp_g_indices], ord=2) ** 2 # for a less costly stopping criterion, we do no compute the exact gradient, # but store each coordinate-wise gradient every time we upate one coordinate: - past_grads = np.zeros(len(ws)) + past_grads = np.zeros(n_features_ws) X_delta_w_ws = np.zeros(X.shape[0]) - w_ws = w_epoch[ws] + + w_ws = np.zeros(n_features_ws) + w_ptr = 0 + for g in ws: + grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] + w_ws[w_ptr:w_ptr+len(grp_g_indices)] = w_epoch[grp_g_indices] + w_ptr += len(grp_g_indices) for cd_iter in range(MAX_CD_ITER): - for idx, j in enumerate(ws): - # skip when X[:, j] == 0 - if lipschitz[idx] == 0: + ptr = 0 + for idx, g in enumerate(ws): + # skip when X[:, grp_g_indices] == 0 + if lipchitz[idx] == 0: continue - past_grads[idx] = grad_ws[idx] + X[:, j] @ (raw_hess * X_delta_w_ws) - old_w_idx = w_ws[idx] - stepsize = 1 / lipschitz[idx] - - w_ws[idx] = penalty.prox_1d( - old_w_idx - stepsize * past_grads[idx], stepsize, j) + grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] + range_grp_g = slice(ptr, ptr + len(grp_g_indices)) - if w_ws[idx] != old_w_idx: - X_delta_w_ws += (w_ws[idx] - old_w_idx) * X[:, j] + past_grads[range_grp_g] = grad_ws[range_grp_g] + # TODO: compute without copying the cols of X + past_grads[range_grp_g] += X[:, grp_g_indices] @ (raw_hess * X_delta_w_ws) - if cd_iter % 5 == 0: - # TODO: can be improved by passing in w_ws but breaks for WeightedL1 - current_w = w_epoch.copy() - current_w[ws] = w_ws - opt = penalty.subdiff_distance(current_w, past_grads, ws) - if np.max(opt) <= tol: - break - - # descent direction - return w_ws - w_epoch[ws], X_delta_w_ws - - -# sparse version of _compute_descent_direction -@njit -def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch, - Xw_epoch, grad_ws, datafit, penalty, ws, tol): - raw_hess = datafit.raw_hessian(y, Xw_epoch) - - lipschitz = np.zeros(len(ws)) - for idx, j in enumerate(ws): - # equivalent to: lipschitz[idx] += raw_hess * X[:, j] ** 2 - lipschitz[idx] = _sparse_squared_weighted_norm( - X_data, X_indptr, X_indices, j, raw_hess) - - # see _descent_direction() comment - past_grads = np.zeros(len(ws)) - X_delta_w_ws = np.zeros(len(y)) - w_ws = w_epoch[ws] - - for cd_iter in range(MAX_CD_ITER): - for idx, j in enumerate(ws): - # skip when X[:, j] == 0 - if lipschitz[idx] == 0: - continue + old_w_ws_g = w_ws[range_grp_g] + stepsize = 1 / lipchitz[idx] - past_grads[idx] = grad_ws[idx] - # equivalent to cached_grads[idx] += X[:, j] @ (raw_hess * X_delta_w_ws) - past_grads[idx] += _sparse_weighted_dot( - X_data, X_indptr, X_indices, j, X_delta_w_ws, raw_hess) + w_ws[range_grp_g] = penalty.prox_1group( + old_w_ws_g - stepsize * past_grads[range_grp_g], stepsize, g) - old_w_idx = w_ws[idx] - stepsize = 1 / lipschitz[idx] + for idx_j, j in enumerate(grp_g_indices): + if w_ws[ptr + idx_j] != old_w_ws_g[idx_j]: + X_delta_w_ws += (w_ws[ptr + idx_j] - old_w_ws_g[j]) * X[:, j] - w_ws[idx] = penalty.prox_1d( - old_w_idx - stepsize * 
past_grads[idx], stepsize, j) - - if w_ws[idx] != old_w_idx: - _update_X_delta_w(X_data, X_indptr, X_indices, X_delta_w_ws, - w_ws[idx] - old_w_idx, j) + ptr += len(grp_g_indices) if cd_iter % 5 == 0: - # TODO: could be improved by passing in w_ws + # TODO: can be improved by passing in w_ws current_w = w_epoch.copy() - current_w[ws] = w_ws + + ptr = 0 + for g in ws: + grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] + current_w[grp_g_indices] = w_ws[ptr: ptr+len(grp_g_indices)] + ptr += len(grp_g_indices) + opt = penalty.subdiff_distance(current_w, past_grads, ws) if np.max(opt) <= tol: break # descent direction - return w_ws - w_epoch[ws], X_delta_w_ws + delta_w_ws = np.zeros(n_features_ws) + ptr = 0 + for g in ws: + grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] + delta_w_ws[ptr: ptr+len(grp_g_indices)] = w_epoch[grp_g_indices] + ptr += len(grp_g_indices) + + return delta_w_ws, X_delta_w_ws @njit @@ -180,16 +211,25 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, # step * \nabla datafit(w + step * delta_w) @ delta_w < 0 # ref: https://www.di.ens.fr/~aspremon/PDF/ENSAE/Newton.pdf # 2) inplace update of w and Xw and return grad_ws of the last w and Xw + grp_ptr = penalty.grp_ptr + grp_indices = penalty.grp_indices step, prev_step = 1., 0. + # TODO: could be improved by passing in w[ws] old_penalty_val = penalty.value(w) # try step = 1, 1/2, 1/4, ... for _ in range(MAX_BACKTRACK_ITER): - w[ws] += (step - prev_step) * delta_w_ws + ptr = 0 + for g in ws: + grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] + w[grp_g_indices] += ((step - prev_step) * + delta_w_ws[ptr: ptr + len(grp_g_indices)]) + ptr += len(grp_g_indices) + Xw += (step - prev_step) * X_delta_w_ws - grad_ws = _construct_grad(X, y, w, Xw, datafit, ws) + grad_ws = -_construct_grad(X, y, w, Xw, datafit, ws) # TODO: could be improved by passing in w[ws] stop_crit = penalty.value(w) - old_penalty_val stop_crit += step * grad_ws @ delta_w_ws @@ -206,35 +246,6 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, return grad_ws -# sparse version of _backtrack_line_search -@njit -def _backtrack_line_search_s(X_data, X_indptr, X_indices, y, w, Xw, datafit, - penalty, delta_w_ws, X_delta_w_ws, ws): - step, prev_step = 1., 0. - # TODO: could be improved by passing in w[ws] - old_penalty_val = penalty.value(w) - - for _ in range(MAX_BACKTRACK_ITER): - w[ws] += (step - prev_step) * delta_w_ws - Xw += (step - prev_step) * X_delta_w_ws - - grad_ws = _construct_grad_sparse(X_data, X_indptr, X_indices, - y, w, Xw, datafit, ws) - # TODO: could be improved by passing in w[ws] - stop_crit = penalty.value(w) - old_penalty_val - stop_crit += step * grad_ws.T @ delta_w_ws - - if stop_crit < 0: - break - else: - prev_step = step - step /= 2 - else: - pass # TODO - - return grad_ws - - @njit def _construct_grad(X, y, w, Xw, datafit, ws): # Compute grad of datafit restricted to ws. 
This function avoids @@ -243,55 +254,14 @@ def _construct_grad(X, y, w, Xw, datafit, ws): n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws]) raw_grad = datafit.raw_grad(y, Xw) - grad = np.zeros(len(ws)) + minus_grad = np.zeros(n_features_ws) - for idx, g in enumerate(ws): + grad_ptr = 0 + for g in ws: # compute grad_g grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] for j in grp_g_indices: - grad[idx] = X[:, j] @ raw_grad - return grad - + minus_grad[grad_ptr] = -X[:, j] @ raw_grad + grad_ptr += 1 -@njit -def _construct_grad_sparse(X_data, X_indptr, X_indices, y, w, Xw, datafit, ws): - # Compute grad of datafit restricted to ws in case X sparse - raw_grad = datafit.raw_grad(y, Xw) - grad = np.zeros(len(ws)) - for idx, j in enumerate(ws): - grad[idx] = _sparse_xj_dot(X_data, X_indptr, X_indices, j, raw_grad) - return grad - - -@njit(fastmath=True) -def _sparse_xj_dot(X_data, X_indptr, X_indices, j, other): - # Compute X[:, j] @ other in case X sparse - res = 0. - for i in range(X_indptr[j], X_indptr[j+1]): - res += X_data[i] * other[X_indices[i]] - return res - - -@njit(fastmath=True) -def _sparse_weighted_dot(X_data, X_indptr, X_indices, j, other, weights): - # Compute X[:, j] @ (weights * other) in case X sparse - res = 0. - for i in range(X_indptr[j], X_indptr[j+1]): - res += X_data[i] * other[X_indices[i]] * weights[X_indices[i]] - return res - - -@njit(fastmath=True) -def _sparse_squared_weighted_norm(X_data, X_indptr, X_indices, j, weights): - # Compute weights @ X[:, j]**2 in case X sparse - res = 0. - for i in range(X_indptr[j], X_indptr[j+1]): - res += weights[X_indices[i]] * X_data[i]**2 - return res - - -@njit(fastmath=True) -def _update_X_delta_w(X_data, X_indptr, X_indices, X_delta_w, diff, j): - # Compute X_delta_w += diff * X[:, j] in case of X sparse - for i in range(X_indptr[j], X_indptr[j+1]): - X_delta_w[X_indices[i]] += diff * X_data[i] + return minus_grad From cd730bbc91950f3cfc2f4662298446fce597eb6c Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 16 Oct 2022 18:48:55 +0200 Subject: [PATCH 08/25] unittest --- skglm/solvers/__init__.py | 4 +++- skglm/tests/test_group.py | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/__init__.py b/skglm/solvers/__init__.py index 0f8016f40..a685ac58a 100644 --- a/skglm/solvers/__init__.py +++ b/skglm/solvers/__init__.py @@ -4,6 +4,8 @@ from .group_bcd import GroupBCD from .multitask_bcd import MultiTaskBCD from .prox_newton import ProxNewton +from .group_prox_newton import GroupProxNewton -__all__ = [AndersonCD, BaseSolver, GramCD, GroupBCD, MultiTaskBCD, ProxNewton] +__all__ = [AndersonCD, BaseSolver, GramCD, GroupBCD, MultiTaskBCD, ProxNewton, + GroupProxNewton] diff --git a/skglm/tests/test_group.py b/skglm/tests/test_group.py index d072d2a77..04fb17c7d 100644 --- a/skglm/tests/test_group.py +++ b/skglm/tests/test_group.py @@ -6,7 +6,7 @@ from skglm.datafits import Quadratic from skglm.penalties.block_separable import WeightedGroupL2 from skglm.datafits.group import QuadraticGroup, LogisticGroup -from skglm.solvers import GroupBCD +from skglm.solvers import GroupBCD, GroupProxNewton from skglm.utils import ( _alpha_max_group_lasso, grp_converter, make_correlated_data, compiled_clone, AndersonAcceleration) @@ -221,6 +221,39 @@ def test_group_logreg(n_groups, rho): np.testing.assert_array_less(stop_crit, 1e-12) +@pytest.mark.parametrize("n_groups, rho", [[15, 1e-1], [25, 1e-2]]) +def test_group_logreg(n_groups, rho): + n_samples, n_features, 
shuffle = 30, 60, True + random_state = 123 + + X, y, _ = make_correlated_data(n_samples, n_features, random_state=random_state) + y = np.sign(y) + + np.random.seed(random_state) + weights = np.abs(np.random.randn(n_groups)) + grp_indices, grp_ptr, _ = _generate_random_grp(n_groups, n_features, shuffle) + + alpha_max = 0. + for g in range(n_groups): + grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] + alpha_max = max( + alpha_max, + norm(X[:, grp_g_indices].T @ y) / n_samples / weights[g] + ) + alpha = rho * alpha_max + + # skglm + log_group = LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices) + group_penalty = WeightedGroupL2(alpha, weights, grp_ptr, grp_indices) + + log_group = compiled_clone(log_group, to_float32=X.dtype == np.float32) + group_penalty = compiled_clone(group_penalty) + stop_crit = GroupProxNewton(tol=1e-12, verbose=1).solve(X, + y, log_group, group_penalty)[2] + + # np.testing.assert_array_less(stop_crit, 1e-12) + + def test_anderson_acceleration(): # VAR: w = rho * w + 1 with |rho| < 1 # converges to w_star = 1 / (1 - rho) @@ -262,4 +295,5 @@ def test_anderson_acceleration(): if __name__ == "__main__": + test_group_logreg(10, 1e-1) pass From 003c8af0b0fed489a8c190e8cf73c74923987d3f Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 16 Oct 2022 18:49:06 +0200 Subject: [PATCH 09/25] fix bug && concise code --- skglm/solvers/group_prox_newton.py | 78 ++++++++++++++++-------------- 1 file changed, 42 insertions(+), 36 deletions(-) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 9f852fec8..83b5fc552 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -1,7 +1,6 @@ import numpy as np from numba import njit from numpy.linalg import norm -from scipy.sparse import issparse from skglm.solvers.base import BaseSolver @@ -54,7 +53,8 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): n_samples, n_features = X.shape - n_groups = len(penalty.grp_ptr) - 1 + grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices + n_groups = len(grp_ptr) - 1 w = np.zeros(n_features) if w_init is None else w_init Xw = np.zeros(n_samples) if Xw_init is None else Xw_init @@ -86,8 +86,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): min(n_groups, 2 * gsupp_size)) ws = np.argpartition(opt, -ws_size)[-ws_size:] # k-largest items (no sort) - n_features_ws = sum([penalty.grp_ptr[g+1] - penalty.grp_ptr[g] for g in ws]) - grad_ws = grad[:n_features_ws] + grad_ws = _slice_array(grad, ws, grp_ptr, grp_indices) tol_in = EPS_TOL * stop_crit # solve subproblem @@ -121,7 +120,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.asarray(p_objs_out), stop_crit -@njit +# @njit def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, penalty, ws, tol): # Given: @@ -130,33 +129,28 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, # Minimize quadratic approximation for delta_w = w - w_epoch: # b.T @ X @ delta_w + \ # 1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w) - grp_ptr = penalty.grp_ptr - grp_indices = penalty.grp_indices + grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices n_features_ws = sum([penalty.grp_ptr[g+1] - penalty.grp_ptr[g] for g in ws]) - raw_hess = datafit.raw_hessian(y, Xw_epoch) + lipchitz = np.zeros(len(ws)) for idx, g in enumerate(ws): grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] - lipchitz[idx] = 
norm(np.sqrt(raw_hess) * X[:, grp_g_indices], ord=2) ** 2 + # TODO: compute without copying the cols of X + # equivalent to: norm(X[:, grp_g_indices].T @ D @ X[:, grp_g_indices], ord=2) + lipchitz[idx] = norm(np.sqrt(raw_hess) * X[:, grp_g_indices].T, ord=2) ** 2 # for a less costly stopping criterion, we do no compute the exact gradient, - # but store each coordinate-wise gradient every time we upate one coordinate: + # but store each coordinate-wise gradient every time we update one coordinate: past_grads = np.zeros(n_features_ws) X_delta_w_ws = np.zeros(X.shape[0]) - - w_ws = np.zeros(n_features_ws) - w_ptr = 0 - for g in ws: - grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] - w_ws[w_ptr:w_ptr+len(grp_g_indices)] = w_epoch[grp_g_indices] - w_ptr += len(grp_g_indices) + w_ws = _slice_array(w_epoch, ws, grp_ptr, grp_indices) for cd_iter in range(MAX_CD_ITER): ptr = 0 for idx, g in enumerate(ws): # skip when X[:, grp_g_indices] == 0 - if lipchitz[idx] == 0: + if lipchitz[idx] == 0.: continue grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] @@ -164,17 +158,19 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, past_grads[range_grp_g] = grad_ws[range_grp_g] # TODO: compute without copying the cols of X - past_grads[range_grp_g] += X[:, grp_g_indices] @ (raw_hess * X_delta_w_ws) + past_grads[range_grp_g] += X[:, grp_g_indices].T @ (raw_hess * X_delta_w_ws) - old_w_ws_g = w_ws[range_grp_g] + old_w_ws_g = w_ws[range_grp_g].copy() stepsize = 1 / lipchitz[idx] w_ws[range_grp_g] = penalty.prox_1group( old_w_ws_g - stepsize * past_grads[range_grp_g], stepsize, g) + # equivalent to: X_delta_w_ws += X[:, grp_g_indices] @ (w_ws_g - old_w_ws_g) + # but without making a copy of the cols of X for idx_j, j in enumerate(grp_g_indices): if w_ws[ptr + idx_j] != old_w_ws_g[idx_j]: - X_delta_w_ws += (w_ws[ptr + idx_j] - old_w_ws_g[j]) * X[:, j] + X_delta_w_ws += (w_ws[ptr + idx_j] - old_w_ws_g[idx_j]) * X[:, j] ptr += len(grp_g_indices) @@ -182,6 +178,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, # TODO: can be improved by passing in w_ws current_w = w_epoch.copy() + # equivalent to: current_w[ws] = w_ws ptr = 0 for g in ws: grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] @@ -193,17 +190,11 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, break # descent direction - delta_w_ws = np.zeros(n_features_ws) - ptr = 0 - for g in ws: - grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] - delta_w_ws[ptr: ptr+len(grp_g_indices)] = w_epoch[grp_g_indices] - ptr += len(grp_g_indices) - + delta_w_ws = w_ws - _slice_array(w_epoch, ws, grp_ptr, grp_indices) return delta_w_ws, X_delta_w_ws -@njit +# @njit def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, X_delta_w_ws, ws): # 1) find step in [0, 1] such that: @@ -211,8 +202,7 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, # step * \nabla datafit(w + step * delta_w) @ delta_w < 0 # ref: https://www.di.ens.fr/~aspremon/PDF/ENSAE/Newton.pdf # 2) inplace update of w and Xw and return grad_ws of the last w and Xw - grp_ptr = penalty.grp_ptr - grp_indices = penalty.grp_indices + grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices step, prev_step = 1., 0. # TODO: could be improved by passing in w[ws] @@ -220,6 +210,7 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, # try step = 1, 1/2, 1/4, ... 
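# Illustrative standalone sketch (not part of this diff) of the step-halving scheme used in
# the loop below, on a smooth 1-d toy objective with the penalty term left out: accept the
# first step for which grad(w + step * delta) * delta < 0, halve otherwise. The toy objective
# f(w) = exp(w) - 2*w and all *_toy names are placeholders.
import numpy as np

def toy_grad(w):
    return np.exp(w) - 2.   # gradient of f(w) = exp(w) - 2*w, minimized at w = log(2)

w_toy, delta_toy = 0., 1.   # current point and its Newton direction (-grad/hess = 1 at w = 0)
step, prev_step = 1., 0.
for _ in range(20):
    w_toy += (step - prev_step) * delta_toy       # in-place update, as in the solver
    if step * toy_grad(w_toy) * delta_toy < 0:    # descent condition without the penalty term
        break
    prev_step, step = step, step / 2.
# the full step overshoots (the gradient changes sign at w = 1), so it is halved once: step == 0.5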
for _ in range(MAX_BACKTRACK_ITER): + # equivalent to: w[ws] += (step - prev_step) * delta_w_ws ptr = 0 for g in ws: grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] @@ -228,8 +219,8 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, ptr += len(grp_g_indices) Xw += (step - prev_step) * X_delta_w_ws - grad_ws = -_construct_grad(X, y, w, Xw, datafit, ws) + # TODO: could be improved by passing in w[ws] stop_crit = penalty.value(w) - old_penalty_val stop_crit += step * grad_ws @ delta_w_ws @@ -246,7 +237,7 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, return grad_ws -@njit +# @njit def _construct_grad(X, y, w, Xw, datafit, ws): # Compute grad of datafit restricted to ws. This function avoids # recomputing raw_grad for every j, which is costly for logreg @@ -256,12 +247,27 @@ def _construct_grad(X, y, w, Xw, datafit, ws): raw_grad = datafit.raw_grad(y, Xw) minus_grad = np.zeros(n_features_ws) - grad_ptr = 0 + ptr = 0 for g in ws: # compute grad_g grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] for j in grp_g_indices: - minus_grad[grad_ptr] = -X[:, j] @ raw_grad - grad_ptr += 1 + minus_grad[ptr] = -X[:, j] @ raw_grad + ptr += 1 return minus_grad + + +# @njit +def _slice_array(arr, ws, grp_ptr, grp_indices): + # returns [arr[ws_1], arr[ws_2], ...] + n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws]) + sliced_arr = np.zeros(n_features_ws) + + ptr = 0 + for g in ws: + grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] + sliced_arr[ptr: ptr+len(grp_g_indices)] = arr[grp_g_indices] + ptr += len(grp_g_indices) + + return sliced_arr From 21387b420f520283c64155f797a1feff71b303e8 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 16 Oct 2022 19:11:39 +0200 Subject: [PATCH 10/25] uncomment ``njit`` --- skglm/solvers/group_prox_newton.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 83b5fc552..ab2f5b03d 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -10,7 +10,7 @@ class GroupProxNewton(BaseSolver): - """Prox Newton solver combined with working sets. + """Group Prox Newton solver combined with working sets. p0 : int, default 10 Minimum number of features to be included in the working set. 
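For context, a minimal end-to-end sketch of how the pieces added in this series fit together on a dense design; the group size, regularization level and random data below are illustrative placeholders, not values taken from the patches:

import numpy as np
from skglm.datafits import LogisticGroup
from skglm.penalties import WeightedGroupL2
from skglm.solvers import GroupProxNewton
from skglm.utils import make_correlated_data, grp_converter, compiled_clone

n_samples, n_features, grp_size = 50, 60, 3
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
y = np.sign(y)

# contiguous groups of equal size, encoded as (grp_indices, grp_ptr)
grp_indices, grp_ptr = grp_converter(grp_size, n_features)
weights = np.ones(n_features // grp_size)

datafit = compiled_clone(LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices))
penalty = compiled_clone(WeightedGroupL2(alpha=1e-2, grp_ptr=grp_ptr,
                                         grp_indices=grp_indices, weights=weights))

w, p_objs, stop_crit = GroupProxNewton(tol=1e-8, verbose=1).solve(X, y, datafit, penalty)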
@@ -120,15 +120,18 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, np.asarray(p_objs_out), stop_crit -# @njit +@njit def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, penalty, ws, tol): - # Given: + # given: # 1) b = \nabla F(X w_epoch) # 2) D = \nabla^2 F(X w_epoch) <------> raw_hess - # Minimize quadratic approximation for delta_w = w - w_epoch: + # minimize quadratic approximation for delta_w = w - w_epoch: # b.T @ X @ delta_w + \ # 1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w) + # In CD leverage inequality: + # penalty_g(w_g) + 1/2 ||delta_w_g||_H <= \ + # penalty_g(w_g) + 1/2 * || H || * ||delta_w_g|| grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices n_features_ws = sum([penalty.grp_ptr[g+1] - penalty.grp_ptr[g] for g in ws]) raw_hess = datafit.raw_hessian(y, Xw_epoch) @@ -136,7 +139,6 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, lipchitz = np.zeros(len(ws)) for idx, g in enumerate(ws): grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] - # TODO: compute without copying the cols of X # equivalent to: norm(X[:, grp_g_indices].T @ D @ X[:, grp_g_indices], ord=2) lipchitz[idx] = norm(np.sqrt(raw_hess) * X[:, grp_g_indices].T, ord=2) ** 2 @@ -157,7 +159,6 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, range_grp_g = slice(ptr, ptr + len(grp_g_indices)) past_grads[range_grp_g] = grad_ws[range_grp_g] - # TODO: compute without copying the cols of X past_grads[range_grp_g] += X[:, grp_g_indices].T @ (raw_hess * X_delta_w_ws) old_w_ws_g = w_ws[range_grp_g].copy() @@ -194,7 +195,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, return delta_w_ws, X_delta_w_ws -# @njit +@njit def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, X_delta_w_ws, ws): # 1) find step in [0, 1] such that: @@ -237,9 +238,9 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, return grad_ws -# @njit +@njit def _construct_grad(X, y, w, Xw, datafit, ws): - # Compute grad of datafit restricted to ws. This function avoids + # compute grad of datafit restricted to ws. This function avoids # recomputing raw_grad for every j, which is costly for logreg grp_ptr, grp_indices = datafit.grp_ptr, datafit.grp_indices n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws]) @@ -258,7 +259,7 @@ def _construct_grad(X, y, w, Xw, datafit, ws): return minus_grad -# @njit +@njit def _slice_array(arr, ws, grp_ptr, grp_indices): # returns [arr[ws_1], arr[ws_2], ...] 
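# Illustrative standalone example (not part of this diff) of the grp_ptr / grp_indices
# encoding used throughout: group g owns the features grp_indices[grp_ptr[g]:grp_ptr[g+1]].
# The *_toy values below are placeholders.
import numpy as np

grp_indices_toy = np.array([0, 2, 4, 1, 3], dtype=np.int32)  # features stacked group by group
grp_ptr_toy = np.array([0, 3, 5], dtype=np.int32)            # group 0 -> {0, 2, 4}, group 1 -> {1, 3}
arr_toy = np.array([10., 11., 12., 13., 14.])
ws_toy = np.array([1, 0])                                     # working set of group ids

sliced = np.concatenate(
    [arr_toy[grp_indices_toy[grp_ptr_toy[g]: grp_ptr_toy[g + 1]]] for g in ws_toy])
# sliced == array([11., 13., 10., 12., 14.]): arr restricted to group 1 then group 0,
# which is what _slice_array returns for this working set.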
n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws]) From 819b3812c1e9420a1aa9b8e6cd6f061167dae3e7 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Sun, 16 Oct 2022 19:17:54 +0200 Subject: [PATCH 11/25] separate unittest --- skglm/tests/test_group.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/skglm/tests/test_group.py b/skglm/tests/test_group.py index 04fb17c7d..7caf64e01 100644 --- a/skglm/tests/test_group.py +++ b/skglm/tests/test_group.py @@ -222,7 +222,7 @@ def test_group_logreg(n_groups, rho): @pytest.mark.parametrize("n_groups, rho", [[15, 1e-1], [25, 1e-2]]) -def test_group_logreg(n_groups, rho): +def test_group_prox_newton(n_groups, rho): n_samples, n_features, shuffle = 30, 60, True random_state = 123 @@ -248,10 +248,9 @@ def test_group_logreg(n_groups, rho): log_group = compiled_clone(log_group, to_float32=X.dtype == np.float32) group_penalty = compiled_clone(group_penalty) - stop_crit = GroupProxNewton(tol=1e-12, verbose=1).solve(X, - y, log_group, group_penalty)[2] + stop_crit = GroupProxNewton(tol=1e-12).solve(X, y, log_group, group_penalty)[2] - # np.testing.assert_array_less(stop_crit, 1e-12) + np.testing.assert_array_less(stop_crit, 1e-12) def test_anderson_acceleration(): @@ -295,5 +294,4 @@ def test_anderson_acceleration(): if __name__ == "__main__": - test_group_logreg(10, 1e-1) pass From 51c624ec86e392b053123ad2531b893d857b31d6 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 17 Oct 2022 18:12:28 +0200 Subject: [PATCH 12/25] bug p_objs && profile code --- profile_script.py | 71 ++++++++++++++++++++++++++++++ skglm/solvers/group_prox_newton.py | 47 +++++++++++++++----- 2 files changed, 108 insertions(+), 10 deletions(-) create mode 100644 profile_script.py diff --git a/profile_script.py b/profile_script.py new file mode 100644 index 000000000..38f694878 --- /dev/null +++ b/profile_script.py @@ -0,0 +1,71 @@ +import numpy as np +from numpy.linalg import norm +from skglm.utils import make_correlated_data, compiled_clone +from skglm.solvers import GroupProxNewton +from skglm.datafits import LogisticGroup +from skglm.penalties import WeightedGroupL2 + +from skglm.solvers.group_prox_newton import _descent_direction + +import line_profiler + + +def _generate_random_grp(n_groups, n_features, shuffle=True): + grp_indices = np.arange(n_features, dtype=np.int32) + np.random.seed(0) + if shuffle: + np.random.shuffle(grp_indices) + splits = np.random.choice( + n_features, size=n_groups+1, replace=False).astype(np.int32) + splits.sort() + splits[0], splits[-1] = 0, n_features + + groups = [list(grp_indices[splits[i]: splits[i+1]]) + for i in range(n_groups)] + + return grp_indices, splits, groups + + +###### +rho = 1e-1 +n_groups = 100 +n_samples, n_features, shuffle = 500, 5000, True +random_state = 123 + +X, y, _ = make_correlated_data(n_samples, n_features, rho=0.3, + random_state=random_state) +y = np.sign(y) + +np.random.seed(random_state) +weights = np.ones(n_groups) +grp_indices, grp_ptr, _ = _generate_random_grp(n_groups, n_features, shuffle) + +alpha_max = 0. 
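# Illustrative standalone sketch (not part of this diff): the loop below computes
#   alpha_max = max_g ||X_g.T @ y|| / (n_samples * weights[g]),
# the level above which the weighted group-L2 penalty keeps every group at zero for the
# quadratic datafit; here it only serves as a scale for alpha with the logistic datafit.
# The *_toy data and names are placeholders.
import numpy as np
from numpy.linalg import norm

rng = np.random.RandomState(0)
X_toy, y_toy = rng.randn(8, 6), rng.randn(8)
groups_toy = [np.array([0, 1, 2]), np.array([3, 4, 5])]   # two toy groups of three features
weights_toy = np.ones(len(groups_toy))

alpha_max_toy = max(
    norm(X_toy[:, g].T @ y_toy) / (len(y_toy) * weights_toy[i])
    for i, g in enumerate(groups_toy))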
+for g in range(n_groups): + grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] + alpha_max = max( + alpha_max, + norm(X[:, grp_g_indices].T @ y) / n_samples / weights[g] + ) +alpha = rho * alpha_max + + +# skglm +log_group = LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices) +group_penalty = WeightedGroupL2(alpha, weights, grp_ptr, grp_indices) + +log_group = compiled_clone(log_group, to_float32=X.dtype == np.float32) +group_penalty = compiled_clone(group_penalty) + +# cache numba jit compilation +solver = GroupProxNewton(tol=1e-9, fit_intercept=False) +stop_crit = solver.solve(X, y, log_group, group_penalty)[2] +print(stop_crit) + + +# profile code +profiler = line_profiler.LineProfiler() +profiler.add_function(solver.solve) +profiler.enable_by_count() +solver.solve(X, y, log_group, group_penalty) +profiler.print_stats() diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index ab2f5b03d..ad45158f8 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -115,8 +115,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): print("Early exit") break - p_obj = datafit.value(y, w, Xw) + penalty.value(w) - p_objs_out.append(p_obj) + p_obj = datafit.value(y, w, Xw) + penalty.value(w) + p_objs_out.append(p_obj) return w, np.asarray(p_objs_out), stop_crit @@ -140,7 +140,8 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, for idx, g in enumerate(ws): grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] # equivalent to: norm(X[:, grp_g_indices].T @ D @ X[:, grp_g_indices], ord=2) - lipchitz[idx] = norm(np.sqrt(raw_hess) * X[:, grp_g_indices].T, ord=2) ** 2 + lipchitz[idx] = norm( + _matrix_times_X_g(np.sqrt(raw_hess), X, grp_g_indices), ord=2)**2 # for a less costly stopping criterion, we do no compute the exact gradient, # but store each coordinate-wise gradient every time we update one coordinate: @@ -159,7 +160,8 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, range_grp_g = slice(ptr, ptr + len(grp_g_indices)) past_grads[range_grp_g] = grad_ws[range_grp_g] - past_grads[range_grp_g] += X[:, grp_g_indices].T @ (raw_hess * X_delta_w_ws) + past_grads[range_grp_g] += _X_g_dot_vec( + X, raw_hess * X_delta_w_ws, grp_g_indices) old_w_ws_g = w_ws[range_grp_g].copy() stepsize = 1 / lipchitz[idx] @@ -167,11 +169,9 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, w_ws[range_grp_g] = penalty.prox_1group( old_w_ws_g - stepsize * past_grads[range_grp_g], stepsize, g) - # equivalent to: X_delta_w_ws += X[:, grp_g_indices] @ (w_ws_g - old_w_ws_g) - # but without making a copy of the cols of X - for idx_j, j in enumerate(grp_g_indices): - if w_ws[ptr + idx_j] != old_w_ws_g[idx_j]: - X_delta_w_ws += (w_ws[ptr + idx_j] - old_w_ws_g[idx_j]) * X[:, j] + # X_delta_w_ws += X[:, grp_g_indices] @ (w_ws[range_grp_g] - old_w_ws_g) + _update_X_delta_w_ws(X, X_delta_w_ws, w_ws[range_grp_g], old_w_ws_g, + grp_g_indices) ptr += len(grp_g_indices) @@ -211,7 +211,7 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, # try step = 1, 1/2, 1/4, ... 
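# Illustrative standalone check (not part of this diff) of the per-group step sizes computed
# in _descent_direction earlier in this diff: ||sqrt(D) X_g||_2 ** 2 equals the spectral norm
# of the block Hessian X_g.T @ D @ X_g of the quadratic model, so its inverse is a valid block
# step size. The *_toy data below are placeholders.
import numpy as np
from numpy.linalg import norm

rng = np.random.RandomState(0)
X_g_toy = rng.randn(7, 3)          # columns of one group
raw_hess_toy = rng.rand(7)         # diagonal of D (nonnegative, as for the logistic datafit)

block_hessian = X_g_toy.T @ np.diag(raw_hess_toy) @ X_g_toy
np.testing.assert_allclose(
    norm(np.sqrt(raw_hess_toy)[:, None] * X_g_toy, ord=2) ** 2,
    norm(block_hessian, ord=2))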
for _ in range(MAX_BACKTRACK_ITER): - # equivalent to: w[ws] += (step - prev_step) * delta_w_ws + # w[ws] += (step - prev_step) * delta_w_ws ptr = 0 for g in ws: grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] @@ -272,3 +272,30 @@ def _slice_array(arr, ws, grp_ptr, grp_indices): ptr += len(grp_g_indices) return sliced_arr + + +@njit +def _update_X_delta_w_ws(X, X_delta_w_ws, w_ws_g, old_w_ws_g, grp_g_indices): + # + for idx, j in enumerate(grp_g_indices): + delta_w_j = w_ws_g[idx] - old_w_ws_g[idx] + if w_ws_g[idx] != old_w_ws_g[idx]: + X_delta_w_ws += delta_w_j * X[:, j] + + +@njit +def _X_g_dot_vec(X, vec, grp_g_indices): + # + result = np.zeros(len(grp_g_indices)) + for idx, j in enumerate(grp_g_indices): + result[idx] = X[:, j] @ vec + return result + + +@njit +def _matrix_times_X_g(matrix, X, grp_g_indices): + # + result = np.zeros((len(matrix), len(grp_g_indices))) + for idx, j in enumerate(grp_g_indices): + result[:, idx] = matrix * X[:, j] + return result From c5a50beefa692ec378b94ad1fa6a650ccc15f18b Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 18 Oct 2022 10:52:36 +0200 Subject: [PATCH 13/25] refactor test && better namings comments --- skglm/solvers/group_prox_newton.py | 39 ++++++++++--------- skglm/tests/test_group.py | 61 +++++++----------------------- 2 files changed, 35 insertions(+), 65 deletions(-) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index ad45158f8..50733acc7 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -2,7 +2,7 @@ from numba import njit from numpy.linalg import norm from skglm.solvers.base import BaseSolver - +from skglm.utils import check_group_compatible EPS_TOL = 0.3 MAX_CD_ITER = 20 @@ -52,6 +52,9 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, self.verbose = verbose def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + check_group_compatible(datafit) + check_group_compatible(penalty) + n_samples, n_features = X.shape grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices n_groups = len(grp_ptr) - 1 @@ -62,8 +65,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): stop_crit = 0. 
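# Illustrative standalone check (not part of this diff) of the two quantities this solver
# consumes from the logistic group datafit added earlier in the series: raw_grad(y, Xw) is the
# per-sample derivative of (1/n) * sum_i log(1 + exp(-y_i * Xw_i)) with respect to Xw, and
# raw_hessian(y, Xw) its second derivative, checked here by finite differences.
# The *_toy values and the step eps are placeholders.
import numpy as np

rng = np.random.RandomState(0)
n = 5
y_toy = np.sign(rng.randn(n))
Xw_toy = rng.randn(n)

def value(Xw):
    return np.log(1 + np.exp(-y_toy * Xw)).sum() / n

raw_grad_toy = -y_toy / (1 + np.exp(y_toy * Xw_toy)) / n
raw_hess_toy = np.exp(-y_toy * Xw_toy) / (1 + np.exp(-y_toy * Xw_toy)) ** 2 / n

eps = 1e-4
for i in range(n):
    e_i = np.zeros(n)
    e_i[i] = eps
    fd_grad = (value(Xw_toy + e_i) - value(Xw_toy - e_i)) / (2 * eps)
    fd_hess = (value(Xw_toy + e_i) - 2 * value(Xw_toy) + value(Xw_toy - e_i)) / eps ** 2
    np.testing.assert_allclose(fd_grad, raw_grad_toy[i], rtol=1e-4)
    np.testing.assert_allclose(fd_hess, raw_hess_toy[i], rtol=1e-3)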
p_objs_out = [] - for t in range(self.max_iter): - # compute grad + for iter in range(self.max_iter): grad = -_construct_grad(X, y, w, Xw, datafit, all_groups) # check convergence @@ -73,14 +75,14 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): if self.verbose: p_obj = datafit.value(y, w, Xw) + penalty.value(w) print( - f"Iteration {t+1}: {p_obj:.10f}, " + f"Iteration {iter+1}: {p_obj:.10f}, " f"stopping crit: {stop_crit:.2e}" ) if stop_crit <= self.tol: break - # construct ws + # build working set ws gsupp_size = penalty.generalized_support(w).sum() ws_size = max(min(self.p0, n_groups), min(n_groups, 2 * gsupp_size)) @@ -89,13 +91,13 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): grad_ws = _slice_array(grad, ws, grp_ptr, grp_indices) tol_in = EPS_TOL * stop_crit - # solve subproblem + # solve subproblem restricted to ws for pn_iter in range(self.max_pn_iter): # find descent direction delta_w_ws, X_delta_w_ws = _descent_direction( X, y, w, Xw, grad_ws, datafit, penalty, ws, tol=EPS_TOL*tol_in) - # find a suitable step size + # find a suitable step size and in-place update w, Xw grad_ws[:] = _backtrack_line_search( X, y, w, Xw, datafit, penalty, delta_w_ws, X_delta_w_ws, ws) @@ -106,8 +108,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): if max(self.verbose-1, 0): p_obj = datafit.value(y, w, Xw) + penalty.value(w) print( - "PN iteration {}: {:.10f}, ".format(pn_iter+1, p_obj) + - "stopping crit in: {:.2e}".format(stop_crit_in) + f"PN iteration {pn_iter+1,}: {p_obj:.10f}, " + f"stopping crit in: {stop_crit_in:.2e}" ) if stop_crit_in <= tol_in: @@ -129,7 +131,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, # minimize quadratic approximation for delta_w = w - w_epoch: # b.T @ X @ delta_w + \ # 1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w) - # In CD leverage inequality: + # In BCD, we leverage inequality: # penalty_g(w_g) + 1/2 ||delta_w_g||_H <= \ # penalty_g(w_g) + 1/2 * || H || * ||delta_w_g|| grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices @@ -140,8 +142,8 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, for idx, g in enumerate(ws): grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] # equivalent to: norm(X[:, grp_g_indices].T @ D @ X[:, grp_g_indices], ord=2) - lipchitz[idx] = norm( - _matrix_times_X_g(np.sqrt(raw_hess), X, grp_g_indices), ord=2)**2 + lipchitz[idx] = norm(_diag_times_X_g( + np.sqrt(raw_hess), X, grp_g_indices), ord=2)**2 # for a less costly stopping criterion, we do no compute the exact gradient, # but store each coordinate-wise gradient every time we update one coordinate: @@ -160,7 +162,8 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, range_grp_g = slice(ptr, ptr + len(grp_g_indices)) past_grads[range_grp_g] = grad_ws[range_grp_g] - past_grads[range_grp_g] += _X_g_dot_vec( + # += X[:, grp_g_indices].T @ (raw_hess * X_delta_w_ws) + past_grads[range_grp_g] += _X_g_T_dot_vec( X, raw_hess * X_delta_w_ws, grp_g_indices) old_w_ws_g = w_ws[range_grp_g].copy() @@ -179,7 +182,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, # TODO: can be improved by passing in w_ws current_w = w_epoch.copy() - # equivalent to: current_w[ws] = w_ws + # current_w[ws] = w_ws ptr = 0 for g in ws: grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] @@ -284,7 +287,7 @@ def _update_X_delta_w_ws(X, X_delta_w_ws, w_ws_g, old_w_ws_g, grp_g_indices): @njit -def _X_g_dot_vec(X, vec, grp_g_indices): +def _X_g_T_dot_vec(X, vec, 
grp_g_indices): # result = np.zeros(len(grp_g_indices)) for idx, j in enumerate(grp_g_indices): @@ -293,9 +296,9 @@ def _X_g_dot_vec(X, vec, grp_g_indices): @njit -def _matrix_times_X_g(matrix, X, grp_g_indices): +def _diag_times_X_g(diag, X, grp_g_indices): # - result = np.zeros((len(matrix), len(grp_g_indices))) + result = np.zeros((len(diag), len(grp_g_indices))) for idx, j in enumerate(grp_g_indices): - result[:, idx] = matrix * X[:, j] + result[:, idx] = diag * X[:, j] return result diff --git a/skglm/tests/test_group.py b/skglm/tests/test_group.py index 7caf64e01..829ff68ae 100644 --- a/skglm/tests/test_group.py +++ b/skglm/tests/test_group.py @@ -1,4 +1,6 @@ import pytest +from itertools import product + import numpy as np from numpy.linalg import norm @@ -30,13 +32,14 @@ def _generate_random_grp(n_groups, n_features, shuffle=True): return grp_indices, splits, groups -def test_check_group_compatible(): +@pytest.mark.parametrize("solver", [GroupBCD, GroupProxNewton]) +def test_check_group_compatible(solver): l1_penalty = L1(1e-3) quad_datafit = Quadratic() X, y = np.random.randn(5, 5), np.random.randn(5) with np.testing.assert_raises(Exception): - GroupBCD().solve(X, y, quad_datafit, l1_penalty) + solver().solve(X, y, quad_datafit, l1_penalty) @pytest.mark.parametrize("n_groups, n_features, shuffle", @@ -161,8 +164,9 @@ def test_intercept_grouplasso(): np.testing.assert_allclose(model.intercept_, w[-1], atol=1e-5) -@pytest.mark.parametrize("rho", [1e-1, 1e-2]) -def test_equivalence_logreg(rho): +@pytest.mark.parametrize("solver, rho", + product([GroupBCD, GroupProxNewton], [1e-1, 1e-2])) +def test_equivalence_logreg(solver, rho): n_samples, n_features = 30, 50 rnd = np.random.RandomState(1123) X, y, _ = make_correlated_data(n_samples, n_features, random_state=rnd) @@ -180,7 +184,7 @@ def test_equivalence_logreg(rho): log_group = compiled_clone(log_group, to_float32=X.dtype == np.float32) group_penalty = compiled_clone(group_penalty) - w = GroupBCD(tol=1e-12).solve(X, y, log_group, group_penalty)[0] + w = solver(tol=1e-12).solve(X, y, log_group, group_penalty)[0] sk_logreg = LogisticRegression(penalty='l1', C=1/(n_samples * alpha), fit_intercept=False, tol=1e-12, solver='liblinear') @@ -189,40 +193,9 @@ def test_equivalence_logreg(rho): np.testing.assert_allclose(sk_logreg.coef_.flatten(), w, atol=1e-6, rtol=1e-5) -@pytest.mark.parametrize("n_groups, rho", [[15, 1e-1], [25, 1e-2]]) -def test_group_logreg(n_groups, rho): - n_samples, n_features, shuffle = 30, 60, True - random_state = 123 - - X, y, _ = make_correlated_data(n_samples, n_features, random_state=random_state) - y = np.sign(y) - - np.random.seed(random_state) - weights = np.abs(np.random.randn(n_groups)) - grp_indices, grp_ptr, _ = _generate_random_grp(n_groups, n_features, shuffle) - - alpha_max = 0. 
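# Illustrative note (not part of this diff) on why test_equivalence_logreg above can be checked
# against scikit-learn: with groups of size one, the weighted group-L2 penalty
# sum_g weights[g] * ||w_g||_2 reduces to a plain L1 norm, and the objective
# (1/n) * logistic loss + alpha * ||w||_1 matches sklearn's C-parametrized one for
# C = 1 / (n_samples * alpha), as used in that test. A toy check of the penalty identity,
# with placeholder values:
import numpy as np

w_toy = np.array([0.5, -2., 0., 3.])
weights_toy = np.ones(len(w_toy))
group_l2_singletons = sum(
    weights_toy[j] * np.linalg.norm(w_toy[j:j + 1]) for j in range(len(w_toy)))
np.testing.assert_allclose(group_l2_singletons, np.abs(w_toy).sum())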
- for g in range(n_groups): - grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] - alpha_max = max( - alpha_max, - norm(X[:, grp_g_indices].T @ y) / n_samples / weights[g] - ) - alpha = rho * alpha_max - - # skglm - log_group = LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices) - group_penalty = WeightedGroupL2(alpha, weights, grp_ptr, grp_indices) - - log_group = compiled_clone(log_group, to_float32=X.dtype == np.float32) - group_penalty = compiled_clone(group_penalty) - stop_crit = GroupBCD(tol=1e-12).solve(X, y, log_group, group_penalty)[2] - - np.testing.assert_array_less(stop_crit, 1e-12) - - -@pytest.mark.parametrize("n_groups, rho", [[15, 1e-1], [25, 1e-2]]) -def test_group_prox_newton(n_groups, rho): +@pytest.mark.parametrize("solver, n_groups, rho", + product([GroupBCD, GroupProxNewton], [15, 25], [1e-1, 1e-2])) +def test_group_logreg(solver, n_groups, rho): n_samples, n_features, shuffle = 30, 60, True random_state = 123 @@ -233,13 +206,7 @@ def test_group_prox_newton(n_groups, rho): weights = np.abs(np.random.randn(n_groups)) grp_indices, grp_ptr, _ = _generate_random_grp(n_groups, n_features, shuffle) - alpha_max = 0. - for g in range(n_groups): - grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] - alpha_max = max( - alpha_max, - norm(X[:, grp_g_indices].T @ y) / n_samples / weights[g] - ) + alpha_max = _alpha_max_group_lasso(X, y, grp_indices, grp_ptr, weights) alpha = rho * alpha_max # skglm @@ -248,7 +215,7 @@ def test_group_prox_newton(n_groups, rho): log_group = compiled_clone(log_group, to_float32=X.dtype == np.float32) group_penalty = compiled_clone(group_penalty) - stop_crit = GroupProxNewton(tol=1e-12).solve(X, y, log_group, group_penalty)[2] + stop_crit = solver(tol=1e-12).solve(X, y, log_group, group_penalty)[2] np.testing.assert_array_less(stop_crit, 1e-12) From eada218433cf4372c04faec989371ca01e236060 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 18 Oct 2022 14:28:01 +0200 Subject: [PATCH 14/25] info comments --- profile_script.py | 2 +- skglm/solvers/group_prox_newton.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/profile_script.py b/profile_script.py index 38f694878..5900e8657 100644 --- a/profile_script.py +++ b/profile_script.py @@ -5,7 +5,7 @@ from skglm.datafits import LogisticGroup from skglm.penalties import WeightedGroupL2 -from skglm.solvers.group_prox_newton import _descent_direction +from skglm.solvers.group_prox_newton import _descent_direction, _backtrack_line_search import line_profiler diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 50733acc7..917b32174 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -141,7 +141,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, lipchitz = np.zeros(len(ws)) for idx, g in enumerate(ws): grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] - # equivalent to: norm(X[:, grp_g_indices].T @ D @ X[:, grp_g_indices], ord=2) + # norm(X[:, grp_g_indices].T @ np.diag(raw_hess) @ X[:, grp_g_indices], ord=2) lipchitz[idx] = norm(_diag_times_X_g( np.sqrt(raw_hess), X, grp_g_indices), ord=2)**2 @@ -182,7 +182,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, # TODO: can be improved by passing in w_ws current_w = w_epoch.copy() - # current_w[ws] = w_ws + # for g in ws: current_w[ws_g] = w_ws_g ptr = 0 for g in ws: grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] @@ -214,7 +214,7 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, 
delta_w_ws, # try step = 1, 1/2, 1/4, ... for _ in range(MAX_BACKTRACK_ITER): - # w[ws] += (step - prev_step) * delta_w_ws + # for g in ws: w[ws_g] += (step - prev_step) * delta_w_ws_g ptr = 0 for g in ws: grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] @@ -279,7 +279,7 @@ def _slice_array(arr, ws, grp_ptr, grp_indices): @njit def _update_X_delta_w_ws(X, X_delta_w_ws, w_ws_g, old_w_ws_g, grp_g_indices): - # + # X_delta_w_ws += X[:, grp_g_indices] @ (w_ws_g - old_w_ws_g) for idx, j in enumerate(grp_g_indices): delta_w_j = w_ws_g[idx] - old_w_ws_g[idx] if w_ws_g[idx] != old_w_ws_g[idx]: @@ -288,7 +288,7 @@ def _update_X_delta_w_ws(X, X_delta_w_ws, w_ws_g, old_w_ws_g, grp_g_indices): @njit def _X_g_T_dot_vec(X, vec, grp_g_indices): - # + # X[:, grp_g_indices].T @ vec result = np.zeros(len(grp_g_indices)) for idx, j in enumerate(grp_g_indices): result[idx] = X[:, j] @ vec @@ -297,7 +297,7 @@ def _X_g_T_dot_vec(X, vec, grp_g_indices): @njit def _diag_times_X_g(diag, X, grp_g_indices): - # + # np.diag(dig) @ X[:, grp_g_indices] result = np.zeros((len(diag), len(grp_g_indices))) for idx, j in enumerate(grp_g_indices): result[:, idx] = diag * X[:, j] From 2a9c580ed48e2e4f61349d4ffd78e6e83341d44d Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 18 Oct 2022 18:47:06 +0200 Subject: [PATCH 15/25] remove ``-grad`` convention --- skglm/solvers/group_prox_newton.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 917b32174..1d280715c 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -66,10 +66,10 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): p_objs_out = [] for iter in range(self.max_iter): - grad = -_construct_grad(X, y, w, Xw, datafit, all_groups) + grad = _construct_grad(X, y, w, Xw, datafit, all_groups) # check convergence - opt = penalty.subdiff_distance(w, -grad, all_groups) + opt = penalty.subdiff_distance(w, grad, all_groups) stop_crit = np.max(opt) if self.verbose: @@ -102,7 +102,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): X, y, w, Xw, datafit, penalty, delta_w_ws, X_delta_w_ws, ws) # check convergence - opt_in = penalty.subdiff_distance(w, -grad_ws, ws) + opt_in = penalty.subdiff_distance(w, grad_ws, ws) stop_crit_in = np.max(opt_in) if max(self.verbose-1, 0): @@ -223,7 +223,7 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, ptr += len(grp_g_indices) Xw += (step - prev_step) * X_delta_w_ws - grad_ws = -_construct_grad(X, y, w, Xw, datafit, ws) + grad_ws = _construct_grad(X, y, w, Xw, datafit, ws) # TODO: could be improved by passing in w[ws] stop_crit = penalty.value(w) - old_penalty_val @@ -249,17 +249,17 @@ def _construct_grad(X, y, w, Xw, datafit, ws): n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws]) raw_grad = datafit.raw_grad(y, Xw) - minus_grad = np.zeros(n_features_ws) + grad = np.zeros(n_features_ws) ptr = 0 for g in ws: # compute grad_g grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] for j in grp_g_indices: - minus_grad[ptr] = -X[:, j] @ raw_grad + grad[ptr] = X[:, j] @ raw_grad ptr += 1 - return minus_grad + return grad @njit From f98ba6490575026f699e66a0fb37882ce6f2dca7 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 19 Oct 2022 13:43:30 +0200 Subject: [PATCH 16/25] fix verbose && more info comments --- skglm/solvers/group_prox_newton.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git 
a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 1d280715c..8d0383ebb 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -108,7 +108,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): if max(self.verbose-1, 0): p_obj = datafit.value(y, w, Xw) + penalty.value(w) print( - f"PN iteration {pn_iter+1,}: {p_obj:.10f}, " + f"PN iteration {pn_iter+1}: {p_obj:.10f}, " f"stopping crit in: {stop_crit_in:.2e}" ) @@ -186,7 +186,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, ptr = 0 for g in ws: grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] - current_w[grp_g_indices] = w_ws[ptr: ptr+len(grp_g_indices)] + current_w[grp_g_indices] = w_ws[ptr:ptr+len(grp_g_indices)] ptr += len(grp_g_indices) opt = penalty.subdiff_distance(current_w, past_grads, ws) @@ -219,7 +219,7 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, for g in ws: grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] w[grp_g_indices] += ((step - prev_step) * - delta_w_ws[ptr: ptr + len(grp_g_indices)]) + delta_w_ws[ptr:ptr+len(grp_g_indices)]) ptr += len(grp_g_indices) Xw += (step - prev_step) * X_delta_w_ws @@ -264,7 +264,7 @@ def _construct_grad(X, y, w, Xw, datafit, ws): @njit def _slice_array(arr, ws, grp_ptr, grp_indices): - # returns [arr[ws_1], arr[ws_2], ...] + # returns h stacked (arr[ws_1], arr[ws_2], ...) n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws]) sliced_arr = np.zeros(n_features_ws) @@ -280,6 +280,7 @@ def _slice_array(arr, ws, grp_ptr, grp_indices): @njit def _update_X_delta_w_ws(X, X_delta_w_ws, w_ws_g, old_w_ws_g, grp_g_indices): # X_delta_w_ws += X[:, grp_g_indices] @ (w_ws_g - old_w_ws_g) + # but without copying the cols of X for idx, j in enumerate(grp_g_indices): delta_w_j = w_ws_g[idx] - old_w_ws_g[idx] if w_ws_g[idx] != old_w_ws_g[idx]: @@ -289,6 +290,7 @@ def _update_X_delta_w_ws(X, X_delta_w_ws, w_ws_g, old_w_ws_g, grp_g_indices): @njit def _X_g_T_dot_vec(X, vec, grp_g_indices): # X[:, grp_g_indices].T @ vec + # but without copying the cols os X result = np.zeros(len(grp_g_indices)) for idx, j in enumerate(grp_g_indices): result[idx] = X[:, j] @ vec @@ -298,6 +300,7 @@ def _X_g_T_dot_vec(X, vec, grp_g_indices): @njit def _diag_times_X_g(diag, X, grp_g_indices): # np.diag(dig) @ X[:, grp_g_indices] + # but without copying the cols of X result = np.zeros((len(diag), len(grp_g_indices))) for idx, j in enumerate(grp_g_indices): result[:, idx] = diag * X[:, j] From 06e549ca9d46604ebbd876297ffd43b908bb355e Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Wed, 19 Oct 2022 13:59:34 +0200 Subject: [PATCH 17/25] remove profile_script && revert rng in test --- profile_script.py | 71 --------------------------------------- skglm/tests/test_group.py | 11 +++--- 2 files changed, 6 insertions(+), 76 deletions(-) delete mode 100644 profile_script.py diff --git a/profile_script.py b/profile_script.py deleted file mode 100644 index 5900e8657..000000000 --- a/profile_script.py +++ /dev/null @@ -1,71 +0,0 @@ -import numpy as np -from numpy.linalg import norm -from skglm.utils import make_correlated_data, compiled_clone -from skglm.solvers import GroupProxNewton -from skglm.datafits import LogisticGroup -from skglm.penalties import WeightedGroupL2 - -from skglm.solvers.group_prox_newton import _descent_direction, _backtrack_line_search - -import line_profiler - - -def _generate_random_grp(n_groups, n_features, shuffle=True): - grp_indices = np.arange(n_features, dtype=np.int32) - 
np.random.seed(0) - if shuffle: - np.random.shuffle(grp_indices) - splits = np.random.choice( - n_features, size=n_groups+1, replace=False).astype(np.int32) - splits.sort() - splits[0], splits[-1] = 0, n_features - - groups = [list(grp_indices[splits[i]: splits[i+1]]) - for i in range(n_groups)] - - return grp_indices, splits, groups - - -###### -rho = 1e-1 -n_groups = 100 -n_samples, n_features, shuffle = 500, 5000, True -random_state = 123 - -X, y, _ = make_correlated_data(n_samples, n_features, rho=0.3, - random_state=random_state) -y = np.sign(y) - -np.random.seed(random_state) -weights = np.ones(n_groups) -grp_indices, grp_ptr, _ = _generate_random_grp(n_groups, n_features, shuffle) - -alpha_max = 0. -for g in range(n_groups): - grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] - alpha_max = max( - alpha_max, - norm(X[:, grp_g_indices].T @ y) / n_samples / weights[g] - ) -alpha = rho * alpha_max - - -# skglm -log_group = LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices) -group_penalty = WeightedGroupL2(alpha, weights, grp_ptr, grp_indices) - -log_group = compiled_clone(log_group, to_float32=X.dtype == np.float32) -group_penalty = compiled_clone(group_penalty) - -# cache numba jit compilation -solver = GroupProxNewton(tol=1e-9, fit_intercept=False) -stop_crit = solver.solve(X, y, log_group, group_penalty)[2] -print(stop_crit) - - -# profile code -profiler = line_profiler.LineProfiler() -profiler.add_function(solver.solve) -profiler.enable_by_count() -solver.solve(X, y, log_group, group_penalty) -profiler.print_stats() diff --git a/skglm/tests/test_group.py b/skglm/tests/test_group.py index a42f5e9d1..b02f5f8a4 100644 --- a/skglm/tests/test_group.py +++ b/skglm/tests/test_group.py @@ -168,8 +168,8 @@ def test_intercept_grouplasso(): product([GroupBCD, GroupProxNewton], [1e-1, 1e-2])) def test_equivalence_logreg(solver, rho): n_samples, n_features = 30, 50 - rnd = np.random.RandomState(1123) - X, y, _ = make_correlated_data(n_samples, n_features, random_state=rnd) + rng = np.random.RandomState(1123) + X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng) y = np.sign(y) grp_indices, grp_ptr = grp_converter(1, n_features) @@ -198,12 +198,13 @@ def test_equivalence_logreg(solver, rho): def test_group_logreg(solver, n_groups, rho): n_samples, n_features, shuffle = 30, 60, True random_state = 123 + rng = np.random.RandomState(random_state) - X, y, _ = make_correlated_data(n_samples, n_features, random_state=random_state) + X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng) y = np.sign(y) - np.random.seed(random_state) - weights = np.abs(np.random.randn(n_groups)) + rng.seed(random_state) + weights = np.abs(rng.randn(n_groups)) grp_indices, grp_ptr, _ = _generate_random_grp(n_groups, n_features, shuffle) alpha_max = _alpha_max_group_lasso(X, y, grp_indices, grp_ptr, weights) From d0ae56c698fe22c78e297c4b243ec5403e0ccb50 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 24 Oct 2022 21:46:16 +0200 Subject: [PATCH 18/25] add fit ``fit_intercept`` --- skglm/solvers/group_prox_newton.py | 80 ++++++++++++++++++++++++------ 1 file changed, 65 insertions(+), 15 deletions(-) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 8d0383ebb..c882ac0ae 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -55,11 +55,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): check_group_compatible(datafit) check_group_compatible(penalty) + fit_intercept = 
self.fit_intercept n_samples, n_features = X.shape grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices n_groups = len(grp_ptr) - 1 - w = np.zeros(n_features) if w_init is None else w_init + w = np.zeros(n_features + fit_intercept) if w_init is None else w_init Xw = np.zeros(n_samples) if Xw_init is None else Xw_init all_groups = np.arange(n_groups) stop_crit = 0. @@ -72,6 +73,15 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): opt = penalty.subdiff_distance(w, grad, all_groups) stop_crit = np.max(opt) + # optimality of intercept + if fit_intercept: + # gradient w.r.t. intercept (constant features of ones) + intercept_opt = np.abs(np.sum(datafit.raw_grad(y, Xw))) + else: + intercept_opt = 0. + + stop_crit = max(stop_crit, intercept_opt) + if self.verbose: p_obj = datafit.value(y, w, Xw) + penalty.value(w) print( @@ -95,18 +105,29 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): for pn_iter in range(self.max_pn_iter): # find descent direction delta_w_ws, X_delta_w_ws = _descent_direction( - X, y, w, Xw, grad_ws, datafit, penalty, ws, tol=EPS_TOL*tol_in) + X, y, w, Xw, fit_intercept, grad_ws, datafit, penalty, + ws, tol=EPS_TOL*tol_in) # find a suitable step size and in-place update w, Xw grad_ws[:] = _backtrack_line_search( - X, y, w, Xw, datafit, penalty, delta_w_ws, X_delta_w_ws, ws) + X, y, w, Xw, fit_intercept, datafit, penalty, + delta_w_ws, X_delta_w_ws, ws) # check convergence opt_in = penalty.subdiff_distance(w, grad_ws, ws) stop_crit_in = np.max(opt_in) + # optimality of intercept + if fit_intercept: + # gradient w.r.t. intercept (constant features of ones) + intercept_opt_in = np.abs(np.sum(datafit.raw_grad(y, Xw))) + else: + intercept_opt_in = 0. + + stop_crit_in = max(stop_crit_in, intercept_opt_in) + if max(self.verbose-1, 0): - p_obj = datafit.value(y, w, Xw) + penalty.value(w) + p_obj = datafit.value(y, w, Xw) + penalty.value(w[:n_features]) print( f"PN iteration {pn_iter+1}: {p_obj:.10f}, " f"stopping crit in: {stop_crit_in:.2e}" @@ -117,13 +138,13 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): print("Early exit") break - p_obj = datafit.value(y, w, Xw) + penalty.value(w) + p_obj = datafit.value(y, w, Xw) + penalty.value(w[:n_features]) p_objs_out.append(p_obj) return w, np.asarray(p_objs_out), stop_crit @njit -def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, +def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, penalty, ws, tol): # given: # 1) b = \nabla F(X w_epoch) @@ -145,11 +166,15 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, lipchitz[idx] = norm(_diag_times_X_g( np.sqrt(raw_hess), X, grp_g_indices), ord=2)**2 + if fit_intercept: + lipchitz_intercept = np.sum(raw_hess) + grad_intercept = np.sum(datafit.raw_grad(y, Xw_epoch)) + # for a less costly stopping criterion, we do no compute the exact gradient, # but store each coordinate-wise gradient every time we update one coordinate: past_grads = np.zeros(n_features_ws) X_delta_w_ws = np.zeros(X.shape[0]) - w_ws = _slice_array(w_epoch, ws, grp_ptr, grp_indices) + w_ws = _slice_array(w_epoch, ws, grp_ptr, grp_indices, fit_intercept) for cd_iter in range(MAX_CD_ITER): ptr = 0 @@ -178,6 +203,15 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, ptr += len(grp_g_indices) + # intercept update + if fit_intercept: + past_grads_intercept = grad_intercept + raw_hess @ X_delta_w_ws + old_intercept = w_ws[-1] + w_ws[-1] -= past_grads_intercept / lipchitz_intercept + + if 
w_ws[-1] != old_intercept: + X_delta_w_ws += w_ws[-1] - old_intercept + if cd_iter % 5 == 0: # TODO: can be improved by passing in w_ws current_w = w_epoch.copy() @@ -190,16 +224,20 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit, ptr += len(grp_g_indices) opt = penalty.subdiff_distance(current_w, past_grads, ws) - if np.max(opt) <= tol: + stop_crit = np.max(opt) + if fit_intercept: + stop_crit = max(stop_crit, np.abs(past_grads_intercept)) + + if stop_crit <= tol: break # descent direction - delta_w_ws = w_ws - _slice_array(w_epoch, ws, grp_ptr, grp_indices) + delta_w_ws = w_ws - _slice_array(w_epoch, ws, grp_ptr, grp_indices, fit_intercept) return delta_w_ws, X_delta_w_ws @njit -def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, +def _backtrack_line_search(X, y, w, Xw, fit_intercept, datafit, penalty, delta_w_ws, X_delta_w_ws, ws): # 1) find step in [0, 1] such that: # penalty(w + step * delta_w) - penalty(w) + @@ -208,6 +246,8 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, # 2) inplace update of w and Xw and return grad_ws of the last w and Xw grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices step, prev_step = 1., 0. + n_features = X.shape[1] + n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws]) # TODO: could be improved by passing in w[ws] old_penalty_val = penalty.value(w) @@ -222,12 +262,18 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws, delta_w_ws[ptr:ptr+len(grp_g_indices)]) ptr += len(grp_g_indices) + if fit_intercept: + w[-1] += (step - prev_step) * delta_w_ws[-1] + Xw += (step - prev_step) * X_delta_w_ws - grad_ws = _construct_grad(X, y, w, Xw, datafit, ws) + grad_ws = _construct_grad(X, y, w[:n_features], Xw, datafit, ws) # TODO: could be improved by passing in w[ws] - stop_crit = penalty.value(w) - old_penalty_val - stop_crit += step * grad_ws @ delta_w_ws + stop_crit = penalty.value(w[:-1]) - old_penalty_val + stop_crit += step * grad_ws @ delta_w_ws[:n_features_ws] + + if fit_intercept: + stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw)) if stop_crit < 0: break @@ -263,10 +309,11 @@ def _construct_grad(X, y, w, Xw, datafit, ws): @njit -def _slice_array(arr, ws, grp_ptr, grp_indices): +def _slice_array(arr, ws, grp_ptr, grp_indices, fit_intercept=False): # returns h stacked (arr[ws_1], arr[ws_2], ...) 
+ # include last element when fit_intercept=True n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws]) - sliced_arr = np.zeros(n_features_ws) + sliced_arr = np.zeros(n_features_ws + fit_intercept) ptr = 0 for g in ws: @@ -274,6 +321,9 @@ def _slice_array(arr, ws, grp_ptr, grp_indices): sliced_arr[ptr: ptr+len(grp_g_indices)] = arr[grp_g_indices] ptr += len(grp_g_indices) + if fit_intercept: + sliced_arr[-1] = arr[-1] + return sliced_arr From b1b00f51b80a2dd2d21b03e60f5ec54dfc4586af Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Mon, 24 Oct 2022 21:46:35 +0200 Subject: [PATCH 19/25] add unittest intercept --- skglm/tests/test_group.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/skglm/tests/test_group.py b/skglm/tests/test_group.py index b02f5f8a4..1ee3d8d00 100644 --- a/skglm/tests/test_group.py +++ b/skglm/tests/test_group.py @@ -193,9 +193,10 @@ def test_equivalence_logreg(solver, rho): np.testing.assert_allclose(sk_logreg.coef_.flatten(), w, atol=1e-6, rtol=1e-5) -@pytest.mark.parametrize("solver, n_groups, rho", - product([GroupBCD, GroupProxNewton], [15, 25], [1e-1, 1e-2])) -def test_group_logreg(solver, n_groups, rho): +@pytest.mark.parametrize("solver, n_groups, rho, fit_intercept", + product([GroupBCD, GroupProxNewton], [15, 25], [1e-1, 1e-2], + [False, True])) +def test_group_logreg(solver, n_groups, rho, fit_intercept): n_samples, n_features, shuffle = 30, 60, True random_state = 123 rng = np.random.RandomState(random_state) @@ -216,7 +217,8 @@ def test_group_logreg(solver, n_groups, rho): group_logistic = compiled_clone(group_logistic, to_float32=X.dtype == np.float32) group_penalty = compiled_clone(group_penalty) - stop_crit = solver(tol=1e-12).solve(X, y, group_logistic, group_penalty)[2] + stop_crit = solver(tol=1e-12, fit_intercept=fit_intercept).solve( + X, y, group_logistic, group_penalty)[2] np.testing.assert_array_less(stop_crit, 1e-12) From 09613fe75b616887a5dcdf59b5caea1813a0d26f Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 25 Oct 2022 10:08:11 +0200 Subject: [PATCH 20/25] CI trigger From 0567ffd70673973776467572ed383353e170b03a Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 25 Oct 2022 10:19:12 +0200 Subject: [PATCH 21/25] add group prox to doc --- doc/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api.rst b/doc/api.rst index d05ed8a57..f909eb754 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -73,6 +73,7 @@ Solvers FISTA GramCD GroupBCD + GroupProxNewton MultiTaskBCD ProxNewton From a4dc29a6f6f1c1e7dff01c54dc97f22efd12b4a8 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Tue, 25 Oct 2022 10:55:14 +0200 Subject: [PATCH 22/25] script to comapre against yngvem --- yngvem.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 yngvem.py diff --git a/yngvem.py b/yngvem.py new file mode 100644 index 000000000..db53704c5 --- /dev/null +++ b/yngvem.py @@ -0,0 +1,45 @@ +import numpy as np +import matplotlib.pyplot as plt +from group_lasso import LogisticGroupLasso + +from skglm import SparseLogisticRegression + +np.random.seed(0) +X = np.random.randn(6, 10) +y = np.ones(X.shape[0]) +y[:len(y) // 2] = -1 + +X -= X.mean(keepdims=True) + + +alpha_max = np.max(np.abs(X.T @ y)) / (2 * len(y)) + +n_alphas = 75 + +me = np.zeros([n_alphas, X.shape[1]]) +them = me.copy() + +us = SparseLogisticRegression( + alpha=alpha_max, fit_intercept=False, verbose=1, warm_start=True, tol=1e-10) +alphas = alpha_max * np.geomspace(1, 0.01, num=n_alphas) + +for idx, alpha in 
enumerate(alphas): + clf = LogisticGroupLasso( + groups=np.arange(X.shape[1]), group_reg=alpha, l1_reg=0, fit_intercept=False, + old_regularisation=False, supress_warning=True, tol=1e-10) + + clf.fit(X, y) + them[idx] = clf.coef_[:, 1] + us.alpha = alpha + us.fit(X, y) + me[idx] = us.coef_.squeeze() + +fig, axarr = plt.subplots(1, 2, constrained_layout=True) +axarr[0].semilogx(alphas, me) +axarr[0].set_title("Regularization path skglm") +axarr[1].semilogx(alphas, them) +axarr[1].set_title("Regularization path yngvem") + +axarr[1].set_xlabel("alpha") +axarr[0].set_xlabel("alpha") +plt.show(block=False) From 99482daf44739bf9337992a278067386185049a5 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Tue, 25 Oct 2022 14:28:46 +0200 Subject: [PATCH 23/25] remove plot script --- yngvem.py | 45 --------------------------------------------- 1 file changed, 45 deletions(-) delete mode 100644 yngvem.py diff --git a/yngvem.py b/yngvem.py deleted file mode 100644 index db53704c5..000000000 --- a/yngvem.py +++ /dev/null @@ -1,45 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt -from group_lasso import LogisticGroupLasso - -from skglm import SparseLogisticRegression - -np.random.seed(0) -X = np.random.randn(6, 10) -y = np.ones(X.shape[0]) -y[:len(y) // 2] = -1 - -X -= X.mean(keepdims=True) - - -alpha_max = np.max(np.abs(X.T @ y)) / (2 * len(y)) - -n_alphas = 75 - -me = np.zeros([n_alphas, X.shape[1]]) -them = me.copy() - -us = SparseLogisticRegression( - alpha=alpha_max, fit_intercept=False, verbose=1, warm_start=True, tol=1e-10) -alphas = alpha_max * np.geomspace(1, 0.01, num=n_alphas) - -for idx, alpha in enumerate(alphas): - clf = LogisticGroupLasso( - groups=np.arange(X.shape[1]), group_reg=alpha, l1_reg=0, fit_intercept=False, - old_regularisation=False, supress_warning=True, tol=1e-10) - - clf.fit(X, y) - them[idx] = clf.coef_[:, 1] - us.alpha = alpha - us.fit(X, y) - me[idx] = us.coef_.squeeze() - -fig, axarr = plt.subplots(1, 2, constrained_layout=True) -axarr[0].semilogx(alphas, me) -axarr[0].set_title("Regularization path skglm") -axarr[1].semilogx(alphas, them) -axarr[1].set_title("Regularization path yngvem") - -axarr[1].set_xlabel("alpha") -axarr[0].set_xlabel("alpha") -plt.show(block=False) From c3ffae313963a1532d4247b8d3e629bd00ad5609 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Fri, 28 Oct 2022 10:06:13 +0200 Subject: [PATCH 24/25] remarks QB3 --- skglm/solvers/group_prox_newton.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index c882ac0ae..c827bfd0e 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -153,8 +153,8 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, # b.T @ X @ delta_w + \ # 1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w) # In BCD, we leverage inequality: - # penalty_g(w_g) + 1/2 ||delta_w_g||_H <= \ - # penalty_g(w_g) + 1/2 * || H || * ||delta_w_g|| + # penalty_g(w_g) + 1/2 ||delta_w_g||^2_H <= \ + # penalty_g(w_g) + 1/2 * || H || * ||delta_w_g||^2 grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices n_features_ws = sum([penalty.grp_ptr[g+1] - penalty.grp_ptr[g] for g in ws]) raw_hess = datafit.raw_hessian(y, Xw_epoch) @@ -162,6 +162,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, lipchitz = np.zeros(len(ws)) for idx, g in enumerate(ws): grp_g_indices = grp_indices[grp_ptr[g]:grp_ptr[g+1]] + # compute efficiently (few multiplications 
and avoid copying the cols of X) # norm(X[:, grp_g_indices].T @ np.diag(raw_hess) @ X[:, grp_g_indices], ord=2) lipchitz[idx] = norm(_diag_times_X_g( np.sqrt(raw_hess), X, grp_g_indices), ord=2)**2 From e55e9685e62fdeb005852a60053dd5367f809135 Mon Sep 17 00:00:00 2001 From: Badr-MOUFAD Date: Fri, 28 Oct 2022 17:41:07 +0200 Subject: [PATCH 25/25] comments && typo --- skglm/solvers/group_prox_newton.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index c827bfd0e..3a92b5341 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -154,7 +154,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, # 1/2 * delta_w.T @ (X.T @ D @ X) @ delta_w + penalty(w) # In BCD, we leverage inequality: # penalty_g(w_g) + 1/2 ||delta_w_g||^2_H <= \ - # penalty_g(w_g) + 1/2 * || H || * ||delta_w_g||^2 + # penalty_g(w_g) + 1/2 * || H ||^2 * ||delta_w_g||^2 grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices n_features_ws = sum([penalty.grp_ptr[g+1] - penalty.grp_ptr[g] for g in ws]) raw_hess = datafit.raw_hessian(y, Xw_epoch) @@ -198,6 +198,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, w_ws[range_grp_g] = penalty.prox_1group( old_w_ws_g - stepsize * past_grads[range_grp_g], stepsize, g) + # update X_delta_w_ws without copying the cols of X # X_delta_w_ws += X[:, grp_g_indices] @ (w_ws[range_grp_g] - old_w_ws_g) _update_X_delta_w_ws(X, X_delta_w_ws, w_ws[range_grp_g], old_w_ws_g, grp_g_indices) @@ -341,7 +342,7 @@ def _update_X_delta_w_ws(X, X_delta_w_ws, w_ws_g, old_w_ws_g, grp_g_indices): @njit def _X_g_T_dot_vec(X, vec, grp_g_indices): # X[:, grp_g_indices].T @ vec - # but without copying the cols os X + # but without copying the cols of X result = np.zeros(len(grp_g_indices)) for idx, j in enumerate(grp_g_indices): result[idx] = X[:, j] @ vec
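
Usage note — a minimal sketch of the pieces this series adds (the ``LogisticGroup`` datafit, the ``WeightedGroupL2`` penalty and the ``GroupProxNewton`` solver with ``fit_intercept``), mirroring the tests above. It assumes the imports and signatures shown in ``skglm/tests/test_group.py`` and the removed ``profile_script.py``, and that ``grp_converter`` accepts a uniform group size; parameter values are illustrative only:

    import numpy as np
    from skglm.datafits import LogisticGroup
    from skglm.penalties import WeightedGroupL2
    from skglm.solvers import GroupProxNewton
    from skglm.utils import make_correlated_data, compiled_clone, grp_converter

    n_samples, n_features, grp_size = 100, 60, 5
    n_groups = n_features // grp_size

    # binary classification data with labels in {-1, 1}
    X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
    y = np.sign(y)

    # contiguous groups of size 5 with unit weights
    grp_indices, grp_ptr = grp_converter(grp_size, n_features)
    weights = np.ones(n_groups)
    alpha = 0.1  # in practice, a fraction of alpha_max as in the tests

    # numba-compile datafit and penalty, as done in the tests
    datafit = compiled_clone(LogisticGroup(grp_ptr=grp_ptr, grp_indices=grp_indices))
    penalty = compiled_clone(WeightedGroupL2(alpha, weights, grp_ptr, grp_indices))

    # solve; with fit_intercept=True, w has n_features + 1 entries,
    # the last one being the intercept
    w, p_objs, stop_crit = GroupProxNewton(
        tol=1e-10, fit_intercept=True).solve(X, y, datafit, penalty)
    print(f"intercept: {w[-1]:.4f}, final stopping criterion: {stop_crit:.2e}")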