Skip to content

Commit

Permalink
implement QR via block classical Gram-Schmidt
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Aug 5, 2024
1 parent 2a00447 commit 2d89a1f
Showing 1 changed file with 23 additions and 22 deletions.
45 changes: 23 additions & 22 deletions test/handrolled_lapack.hh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <cmath>



template <typename T>
void potrf_upper_colmajor_sequential(int64_t n, T* A, int64_t lda) {
for (int64_t j = 0; j < n; ++j) {
Expand Down Expand Up @@ -68,43 +67,45 @@ void chol_qr_colmajor(int64_t m, int64_t n, T* A, T* R, int64_t chol_block_size
}

template <typename T>
void block_gram_schmidt(int64_t m, int64_t n, T* A, std::vector<T> &work, int64_t b = 64) {
void qr_block_cgs(int64_t m, int64_t n, T* A, T* R, std::vector<T>& work, int64_t b) {
if (n > m)
throw std::runtime_error("Invalid dimensions.");

b = std::min(b, n);
if (work.size() < n*b) {
work.resize(n*b);
if (work.size() < n * b) {
work.resize(n * b);
}
auto layout = blas::Layout::ColMajor;
using blas::Op;
chol_qr_colmajor(m, b, A, work.data(), b);
T one = 1.0;
T zero = 0.0;
T one = (T) 1.0;
T zero = (T) 0.0;
T* R1 = work.data();
for (int64_t j = 0; j < b; ++j)
blas::copy(b, R1 + b*j, 1, R + n*j, 1);

if (b < n) {
int64_t n_trail = n - b;
T* A1 = A; // A[:, :b]
T* A2 = A + b * m; // A[:, b:]
T* A2 = A + m * b; // A[:, b:]
T* R2 = R + n * b; // R[:b, b:]
// Compute A1tA2 := A1' * A2 and then update A2 -= A1 * A1tA2
T* A1tA2 = work.data();
blas::gemm(layout, Op::Trans, Op::NoTrans, b, n_trail, m, one, A1, m, A2, m, zero, A1tA2, b);
blas::gemm(layout, Op::NoTrans, Op::NoTrans, m, n_trail, b, -one, A1, m, A1tA2, b, one, A2, m);
block_gram_schmidt(m, n - b, A + b * m, work, b);
// Copy A1tA2 to the appropriate place in R
for (int64_t j = 0; j < n_trail; ++j) {
blas::copy(b, A1tA2 + j*b, 1, R2 + j*n, 1);
}
qr_block_cgs(m, n_trail, A + b * m, R + b * n + b, work, b);
}
}

template <typename T>
void block_chol_qr(int64_t n, T* A, T* R, int64_t b = 64) {
void qr_block_cgs(int64_t n, T* A, T* R, int64_t b = 64) {
b = std::min(b, n);
std::vector<T> work_orth(n*b);
std::vector<T> work_R(n*n);
T* A_copy = work_R.data();
blas::copy(n*n, A, 1, A_copy, 1);
block_gram_schmidt(n, n, A, work_orth, b);
using blas::Layout;
using blas::Op;
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, n, n, n, (T)1.0, A, n, A_copy, n, (T) 0.0, R, n);
std::vector<T> work(n * b);
std::fill(R, R + n * n, (T) 0.0);
qr_block_cgs(n, n, A, R, work, b);
}

// TODOs:
// 1) change block_gram_schmidt(...) to use modified Gram-Schmidt instead of classic Gram-Schmit.
// 2) merge block_gram_schmidt and block_chol_qr so that block_gram_schmidt builds R as it goes.
// --> This might require some attention to submatrix pointers for R.
///

0 comments on commit 2d89a1f

Please sign in to comment.