diff --git a/test/handrolled_lapack.hh b/test/handrolled_lapack.hh index 8b6c3d52..a8e6cefd 100644 --- a/test/handrolled_lapack.hh +++ b/test/handrolled_lapack.hh @@ -7,7 +7,6 @@ #include - template void potrf_upper_colmajor_sequential(int64_t n, T* A, int64_t lda) { for (int64_t j = 0; j < n; ++j) { @@ -68,43 +67,45 @@ void chol_qr_colmajor(int64_t m, int64_t n, T* A, T* R, int64_t chol_block_size } template -void block_gram_schmidt(int64_t m, int64_t n, T* A, std::vector &work, int64_t b = 64) { +void qr_block_cgs(int64_t m, int64_t n, T* A, T* R, std::vector& work, int64_t b) { + if (n > m) + throw std::runtime_error("Invalid dimensions."); + b = std::min(b, n); - if (work.size() < n*b) { - work.resize(n*b); + if (work.size() < n * b) { + work.resize(n * b); } auto layout = blas::Layout::ColMajor; using blas::Op; chol_qr_colmajor(m, b, A, work.data(), b); - T one = 1.0; - T zero = 0.0; + T one = (T) 1.0; + T zero = (T) 0.0; + T* R1 = work.data(); + for (int64_t j = 0; j < b; ++j) + blas::copy(b, R1 + b*j, 1, R + n*j, 1); + if (b < n) { int64_t n_trail = n - b; T* A1 = A; // A[:, :b] - T* A2 = A + b * m; // A[:, b:] + T* A2 = A + m * b; // A[:, b:] + T* R2 = R + n * b; // R[:b, b:] // Compute A1tA2 := A1' * A2 and then update A2 -= A1 * A1tA2 T* A1tA2 = work.data(); blas::gemm(layout, Op::Trans, Op::NoTrans, b, n_trail, m, one, A1, m, A2, m, zero, A1tA2, b); blas::gemm(layout, Op::NoTrans, Op::NoTrans, m, n_trail, b, -one, A1, m, A1tA2, b, one, A2, m); - block_gram_schmidt(m, n - b, A + b * m, work, b); + // Copy A1tA2 to the appropriate place in R + for (int64_t j = 0; j < n_trail; ++j) { + blas::copy(b, A1tA2 + j*b, 1, R2 + j*n, 1); + } + qr_block_cgs(m, n_trail, A + b * m, R + b * n + b, work, b); } } template -void block_chol_qr(int64_t n, T* A, T* R, int64_t b = 64) { +void qr_block_cgs(int64_t n, T* A, T* R, int64_t b = 64) { b = std::min(b, n); - std::vector work_orth(n*b); - std::vector work_R(n*n); - T* A_copy = work_R.data(); - blas::copy(n*n, A, 1, A_copy, 1); - block_gram_schmidt(n, n, A, work_orth, b); - using blas::Layout; - using blas::Op; - blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, n, n, n, (T)1.0, A, n, A_copy, n, (T) 0.0, R, n); + std::vector work(n * b); + std::fill(R, R + n * n, (T) 0.0); + qr_block_cgs(n, n, A, R, work, b); } -// TODOs: -// 1) change block_gram_schmidt(...) to use modified Gram-Schmidt instead of classic Gram-Schmit. -// 2) merge block_gram_schmidt and block_chol_qr so that block_gram_schmidt builds R as it goes. -// --> This might require some attention to submatrix pointers for R. -///