Skip to content

Commit

Permalink
set dtype in dist_fix_point_cd
Browse files Browse the repository at this point in the history
  • Loading branch information
Badr-MOUFAD committed Nov 8, 2023
1 parent 87fbfcd commit c336643
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion skglm/solvers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws):
dist : array, shape (n_features,)
Violation score for every feature.
"""
dist = np.zeros(ws.shape[0])
dist = np.zeros(ws.shape[0], dtype=w.dtype)

for idx, j in enumerate(ws):
if lipschitz[j] == 0.:
Expand Down
2 changes: 1 addition & 1 deletion skglm/solvers/prox_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
return w, np.asarray(p_objs_out), stop_crit


# @njit
@njit
def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
penalty, ws, tol, ws_strategy):
# Given:
Expand Down

0 comments on commit c336643

Please sign in to comment.