Skip to content

Commit

Permalink
Update lu_row_stabilize so that it errors when necessary. Modulo read…
Browse files Browse the repository at this point in the history
…ability, this example is ready!
  • Loading branch information
rileyjmurray committed May 29, 2024
1 parent 8966842 commit 9a13962
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions examples/sparse-low-rank-approx/qrcp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <numbers>
#include <chrono>
#include <fstream>
#include <stdexcept>


using RandBLAS::sparse_data::COOMatrix;
Expand Down Expand Up @@ -151,11 +152,17 @@ int lu_row_stabilize(int64_t m, int64_t n, T* mat, int64_t* piv_work) {
lapack::getrf(m, n, mat, m, piv_work);
// above: the permutation applied to the rows of mat doesn't matter in our context.
// below: Need to zero-out the strict lower triangle of mat and scale each row.
for (int64_t j = 0; j < m-1; ++j) {
T tol = std::numeric_limits<T>::epsilon()*10;
bool nonzero_diag_U = true;
for (int64_t j = 0; (j < m-1) & nonzero_diag_U; ++j) {
nonzero_diag_U = abs(mat[j + j*m]) > tol;
for (int64_t i = j + 1; i < m; ++i) {
mat[i + j*m] = 0.0;
}
}
if (!nonzero_diag_U) {
throw std::runtime_error("LU stabilization failed. Matrix has been overwritten, so we cannot recover.");
}
for (int64_t i = 0; i < m; ++i) {
T scale = 1.0 / mat[i + i*m];
blas::scal(n, scale, mat + i, m);
Expand Down Expand Up @@ -367,7 +374,7 @@ int runall(int argc, char** argv, StabilizationMethod sm) {

auto start_timer = std_clock::now();
TIMED_LINE(
power_iter_col_sketch(mat_csc, k, R, 1, state, Q, sm), "\n\tpower iter sketch : ")
power_iter_col_sketch(mat_csc, k, R, 2, state, Q, sm), "\n\tpower iter sketch : ")
print_row_norms(R, k, n, "Yf");
TIMED_LINE(
sketch_to_tqrcp(mat_csc, k, Q, m, R, k, piv), "\n\tsketch to QRCP : ")
Expand All @@ -392,4 +399,6 @@ int main(int argc, char** argv) {
runall(argc, argv, StabilizationMethod::sketch);
std::cout << "Nothing:\n";
runall(argc, argv, StabilizationMethod::None);
std::cout << "LU:\n";
runall(argc, argv, StabilizationMethod::LU);
}

0 comments on commit 9a13962

Please sign in to comment.