diff --git a/copt/__init__.py b/copt/__init__.py index c24d3c68..752bbc53 100644 --- a/copt/__init__.py +++ b/copt/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.1.dev0' +__version__ = '0.2.0' from .gradient_descent import minimize_PGD, minimize_DavisYin, minimize_APGD from .primal_dual import fmin_CondatVu diff --git a/copt/stochastic.py b/copt/stochastic.py index dd33acff..78bae0d9 100644 --- a/copt/stochastic.py +++ b/copt/stochastic.py @@ -299,13 +299,11 @@ def minimize_BCD( else: trace_x = np.zeros((0, 0)) stop_flag = np.zeros(1, dtype=np.bool) - trace_func = [] - trace_time = [] - @njit(nogil=True) def _bcd_algorithm( - x, Ax, A_csc_data, A_csc_indices, A_csc_indptr, b, trace_x): + x, Ax, A_csr_data, A_csr_indices, A_csr_indptr, A_csc_data, + A_csc_indices, A_csc_indptr, b, trace_x, job_id): feature_indices = np.arange(n_features) for it in range(1, max_iter): np.random.shuffle(feature_indices) @@ -316,15 +314,28 @@ def _bcd_algorithm( i_idx = A_csc_indices[i_indptr] grad_j += partial_gradient(Ax[i_idx], b[i_idx]) * A_csc_data[i_indptr] / n_samples x_new = prox(x[j] - step_size * (grad_j + f_alpha * x[j]), step_size) + + # .. update Ax .. for i_indptr in range(A_csc_indptr[j], A_csc_indptr[j+1]): i_idx = A_csc_indices[i_indptr] Ax[i_idx] += A_csc_data[i_indptr] * (x_new - x[j]) x[j] = x_new - if trace: - trace_x[it, :] = x + if job_id == 0: + if trace: + trace_x[it, :] = x + # .. recompute Ax .. + for i in range(n_samples): + p = 0. + for j in range(A_csr_indptr[i], A_csr_indptr[i+1]): + j_idx = A_csr_indices[j] + p += x[j_idx] * A_csr_data[j] + # .. copy back to shared memory .. + Ax[i] = p + return it, None X_csc = sparse.csc_matrix(f.A) + X_csr = sparse.csr_matrix(f.A) trace_func = [] start = datetime.now() @@ -335,8 +346,8 @@ def _bcd_algorithm( for job_id in range(n_jobs): futures.append(executor.submit( _bcd_algorithm, - xk, Ax, X_csc.data, X_csc.indices, X_csc.indptr, f.b, - trace_x)) + xk, Ax, X_csr.data, X_csr.indices, X_csr.indptr, X_csc.data, + X_csc.indices, X_csc.indptr, f.b, trace_x, job_id)) concurrent.futures.wait(futures) n_iter, certificate = futures[0].result()