diff --git a/RandBLAS/DevNotes.md b/RandBLAS/DevNotes.md new file mode 100644 index 00000000..d9266221 --- /dev/null +++ b/RandBLAS/DevNotes.md @@ -0,0 +1,39 @@ +# Developer Notes for RandBLAS + +This file has discussions of RandBLAS' implementation that aren't (currently) suitable +for the RandBLAS User Guide. + +## What's where? + + * Our basic random number generation is handled by [Random123](https://github.com/DEShawResearch/random123). + We have small wrappers around Random123 code in ``RandBLAS/base.hh`` and ``RandBLAS/random_gen.hh``. + + * ``RandBLAS/dense_skops.hh`` has code for representing and sampling dense sketching operators. + The sampling code is complicated because it supports multi-threaded (yet threading invariant!) + random (sub)matrix generation. + + * ``RandBLAS/sparse_skops.hh`` has code for representing and sampling sparse sketching operators. + The sampling code has a customized method for repeatedly sampling from an index set without + replacement, which is needed to quickly generate the structures used in statistically reliable + sparse sketching operators. + + * [BLAS++ (aka blaspp)](https://github.com/icl-utk-edu/blaspp) is our portability layer for BLAS. + We actually use very few functions in BLAS at time of writing (GEMM, GEMV, SCAL, COPY, and + AXPY) but we use its enumerations _everywhere_. Fast GEMM is important for sketching dense + data with dense operators. + + * The ``sketch_general`` functions in ``RandBLAS/skge.hh`` are the main entry point for sketching dense data. + These functions are small wrappers around functions with more BLAS-like names: + * ``lskge3`` and ``rskge3`` in ``RandBLAS/skge3_to_gemm.hh``. + * ``lskges`` and ``rskges`` in ``RandBLAS/skges_to_spmm.hh``. + The former pair of functions are just fancy wrappers around GEMM. + The latter pair of functions trigger a far more opaque call sequence, since they rely on sparse + matrix operations. + + * There is no widely accepted standard for sparse BLAS operations. 
This is a bummer because + sparse matrices are super important in data science and scientific computing. In view of this, + RandBLAS provides its own abstractions for sparse matrices (CSC, CSR, and COO formats). + The abstractions can either own their associated data or just wrap existing data (say, data + attached to a sparse matrix in Eigen). RandBLAS has reasonably flexible and high-performance code + for multiplying a sparse matrix and a dense matrix. All code related to sparse matrices is in + ``RandBLAS/sparse_data``. See that folder's ``DevNotes.md`` file for details. diff --git a/RandBLAS/sparse_data/DevNotes.md b/RandBLAS/sparse_data/DevNotes.md new file mode 100644 index 00000000..66ebe532 --- /dev/null +++ b/RandBLAS/sparse_data/DevNotes.md @@ -0,0 +1,76 @@ +# Developer Notes for RandBLAS' sparse matrix functionality + +RandBLAS provides abstractions for CSC, CSR, and COO-format sparse matrices. +The following functions use these abstractions: + + * ``left_spmm``, which computes a product of a sparse matrix and a dense matrix when the sparse matrix + is the left operand. This function is GEMM-like, in that it allows offsets and transposition flags + for either argument. + * ``right_spmm``, which is analogous to ``left_spmm`` when the sparse matrix is the right operand. + * ``sketch_general``, when called with a SparseSkOp object. + * ``sketch_sparse``, when called with a DenseSkOp object. + +Each of those functions is merely a _dispatcher_ of other (lower level) functions. See below for details on +how the dispatching works. + +## Left_spmm and right_spmm + +These functions are implemented in ``RandBLAS/sparse_data/spmm_dispatch.hh``. + +``right_spmm`` is implemented by falling back on ``left_spmm`` with transformed +values for ``opA, opB`` and ``layout``. +Here's what happens if ``left_spmm`` is called with a sparse matrix ``A``, a dense input matrix ``B``, and a dense output matrix ``C``. + + 1. 
If needed, transposition of ``A`` is resolved by creating a lightweight object for the transpose + called ``At``. This object is just a tool for us to change how we interpret the buffers that underlie ``A``. + * If ``A`` is COO, then ``At`` will also be COO. + * If ``A`` is CSR, then ``At`` will be CSC. + * If ``A`` is CSC, then ``At`` will be CSR. + + We make a recursive call to ``left_spmm`` once we have our hands on ``At``, so + the rest of ``left_spmm``'s logic only needs to handle un-transposed ``A``. + + 2. A memory layout is determined for how we'll read ``B`` in the low-level + sparse matrix multiplication kernels. + * If ``B`` is un-transposed then we'll use the same layout as ``C``. + * If ``B`` is transposed then we'll swap its declared dimensions + (i.e., we'll swap its reported numbers of rows and columns) and + we'll tell the kernel to read it in the opposite layout from ``C``. + + 3. We dispatch a kernel from ``coo_spmm_impl.hh``, or ``csc_spmm_impl.hh``, + or ``csr_spmm_impl.hh``. The precise kernel depends on the type of ``A``, the inferred layout for ``B``, and the declared layout for ``C``. + +## Sketching dense data with sparse operators + +Sketching dense data with a sparse operator is typically handled with ``sketch_general``, +which is defined in ``skge.hh``. + +If we call this function with a SparseSkOp object, ``S``, we'd immediately get routed to +a function in ``skges_to_spmm.hh``: either ``lskges`` or ``rskges``. Here's what would happen +after we entered one of those functions: + + 1. If necessary, we'd sample the defining data of ``S`` with ``RandBLAS::fill_sparse(S)``. + + 2. We'd obtain a lightweight view of ``S`` as a COOMatrix, and we'd pass that matrix to ``left_spmm`` + (if inside ``lskges``) or ``right_spmm`` (if inside ``rskges``).
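To make step 2 above concrete, here is a minimal, self-contained sketch of the kind of product that ``lskges`` ends up requesting: a COO-format sparse matrix applied from the left to a column-major dense matrix. The names ``CooView`` and ``coo_left_apply`` are hypothetical and exist only for this illustration. This is not RandBLAS's kernel; the real kernels in ``coo_spmm_impl.hh`` also handle offsets, both layouts, and threading.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for a COO sparse matrix view.
// (This is NOT RandBLAS's COOMatrix; it is just three parallel arrays.)
struct CooView {
    int64_t n_rows;
    int64_t n_cols;
    std::vector<int64_t> rows;  // row index of each stored entry
    std::vector<int64_t> cols;  // column index of each stored entry
    std::vector<double> vals;   // value of each stored entry
};

// Computes C = alpha * S * A + beta * C, where S is d-by-m in COO format,
// A is m-by-n and C is d-by-n, both stored column-major.
void coo_left_apply(
    double alpha, const CooView &S, int64_t n,
    const double *A, int64_t lda, double beta, double *C, int64_t ldc
) {
    int64_t d = S.n_rows;
    for (int64_t j = 0; j < n; ++j) {
        // Scale (or overwrite) the j-th column of C. As in BLAS, C is not
        // read when beta is zero.
        for (int64_t i = 0; i < d; ++i)
            C[i + j*ldc] = (beta == 0.0) ? 0.0 : beta * C[i + j*ldc];
        // Accumulate the contribution of every stored entry of S.
        for (std::size_t ell = 0; ell < S.vals.size(); ++ell) {
            int64_t i = S.rows[ell];
            int64_t k = S.cols[ell];
            assert(0 <= i && i < d && 0 <= k && k < S.n_cols);
            C[i + j*ldc] += alpha * S.vals[ell] * A[k + j*lda];
        }
    }
}
```

In this sketch the loop over columns of ``C`` is outermost, so each column of the output is touched by one pass over the stored entries of ``S``. That choice is only for readability; other loop orders may be preferable depending on the memory layout.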
+ + +## Sketching sparse data with dense operators + +If we call ``sketch_sparse`` with a DenseSkOp, ``S``, and a sparse matrix, ``A``, then we'll get routed to either +``lsksp3`` or ``rsksp3`` in ``sparse_data/sksp3_to_spmm.hh``. + +From there, we'll do the following. + + 1. If necessary, we sample the defining data of ``S``. The way that we do this is a + little more complicated than using ``RandBLAS::fill_dense(S)``, but it's similar + in spirit. + + 2. We get our hands on the simple buffer representation of ``S``. From there ... + * We call ``right_spmm`` if we're inside ``lsksp3``. + * We call ``left_spmm`` if we're inside ``rsksp3``. + + Note that the ``l`` and ``r`` in the ``[l/r]sksp3`` function names + get matched to opposite sides for ``[left/right]_spmm``! This is because all the fancy abstractions in ``S`` have been stripped away by this point in the call sequence, so the "side" in the function names switches + from referring to ``S`` to referring to ``A``. + diff --git a/RandBLAS/sparse_data/base.hh b/RandBLAS/sparse_data/base.hh index e5ce1300..883f0cc0 100644 --- a/RandBLAS/sparse_data/base.hh +++ b/RandBLAS/sparse_data/base.hh @@ -63,6 +63,12 @@ static inline void sorted_nonzero_locations_to_pointer_array( return; } +// Idea: change all "const" attributes for SpMatrix to return values from inlined functions. +// Looks like there'd be no collision with function/property names for sparse matrix +// types in Eigen, SuiteSparse, OneMKL, etc. These inlined functions could return +// nominally public members like A._n_rows and A._n_cols, which the user will only change +// at their own peril.
+ // ============================================================================= /// @verbatim embed:rst:leading-slashes /// @@ -145,9 +151,11 @@ concept SparseMatrix = requires(SpMat A) { { A.n_cols } -> std::convertible_to; { A.nnz } -> std::convertible_to; { *(A.vals) } -> std::convertible_to; + { SpMat(A.n_rows, A.n_cols) }; + // ^ Is there better way to require a two-argument constructor? { A.own_memory } -> std::convertible_to; - { SpMat(A.n_rows, A.n_cols) }; // Is there better way to require a two-argument constructor? - { A.reserve((int64_t) 10) }; // This will always compile, even though it might error at runtime. + // { A.reserve((int64_t) 10) }; + // ^ Problem: const SpMat objects fail that check. }; } // end namespace RandBLAS::sparse_data diff --git a/RandBLAS/sparse_data/csc_spmm_impl.hh b/RandBLAS/sparse_data/csc_spmm_impl.hh index 1c66e1c5..fbb5a8aa 100644 --- a/RandBLAS/sparse_data/csc_spmm_impl.hh +++ b/RandBLAS/sparse_data/csc_spmm_impl.hh @@ -121,5 +121,65 @@ static void apply_csc_left_jki_p11( return; } +template +static void apply_csc_left_kib_rowmajor_1p1( + T alpha, + int64_t d, + int64_t n, + int64_t m, + CSCMatrix &A, + const T *B, + int64_t ldb, + T *C, + int64_t ldc +) { + randblas_require(A.index_base == IndexBase::Zero); + + randblas_require(d == A.n_rows); + randblas_require(m == A.n_cols); + + + int num_threads = 1; + #if defined(RandBLAS_HAS_OpenMP) + #pragma omp parallel + { + num_threads = omp_get_num_threads(); + } + #endif + + int* block_bounds = new int[num_threads + 1]{}; + int block_size = d / num_threads; + if (block_size == 0) { block_size = 1;} + for (int t = 0; t < num_threads; ++t) + block_bounds[t+1] = block_bounds[t] + block_size; + block_bounds[num_threads] += d % num_threads; + + #pragma omp parallel default(shared) + { + #if defined(RandBLAS_HAS_OpenMP) + int t = omp_get_thread_num(); + #else + int t = 0; + #endif + int i_lower = block_bounds[t]; + int i_upper = block_bounds[t+1]; + for (int64_t k = 0; k < m; 
++k) { + // Rank-1 update: C[:,:] += A[:,k] @ B[k,:] + const T* row_B = &B[k*ldb]; + for (int64_t ell = A.colptr[k]; ell < A.colptr[k+1]; ++ell) { + int64_t i = A.rowidxs[ell]; + if (i_lower <= i && i < i_upper) { + T* row_C = &C[i*ldc]; + T scale = alpha * A.vals[ell]; + blas::axpy(n, scale, row_B, 1, row_C, 1); + } + } + } + } + + delete [] block_bounds; + return; +} + } #endif diff --git a/RandBLAS/sparse_data/csr_spmm_impl.hh b/RandBLAS/sparse_data/csr_spmm_impl.hh index ae662b0e..2e27a85d 100644 --- a/RandBLAS/sparse_data/csr_spmm_impl.hh +++ b/RandBLAS/sparse_data/csr_spmm_impl.hh @@ -24,12 +24,13 @@ static void apply_csr_to_vector_from_left_ik( int64_t incAv // stride between elements of Av ) { for (int64_t i = 0; i < len_Av; ++i) { + T Av_i = Av[i*incAv]; for (int64_t ell = rowptr[i]; ell < rowptr[i+1]; ++ell) { int j = colidxs[ell]; T Aij = vals[ell]; - Av[i*incAv] += Aij * v[j*incv]; - // ^ if v were a matrix, this could be an axpy with the j-th row of v, accumulated into i-th row of Av. 
+ Av_i += Aij * v[j*incv]; } + Av[i*incAv] = Av_i; } } @@ -86,6 +87,42 @@ static void apply_csr_left_jik_p11( return; } +template +static void apply_csr_left_ikb_rowmajor( + T alpha, + int64_t d, + int64_t n, + int64_t m, + CSRMatrix &A, + const T *B, + int64_t ldb, + T *C, + int64_t ldc +) { + randblas_require(A.index_base == IndexBase::Zero); + + randblas_require(d == A.n_rows); + randblas_require(m == A.n_cols); + + #pragma omp parallel default(shared) + { + #pragma omp for schedule(dynamic) + for (int64_t i = 0; i < d; ++i) { + // C[i, 0:n] += alpha * A[i, :] @ B[:, 0:n] + T* row_C = &C[i*ldc]; + for (int64_t ell = A.rowptr[i]; ell < A.rowptr[i+1]; ++ell) { + // we're working with A[i,k] for k = A.colidxs[ell] + // compute C[i, 0:n] += alpha * A[i,k] * B[k, 0:n] + T scale = alpha * A.vals[ell]; + int64_t k = A.colidxs[ell]; + const T* row_B = &B[k*ldb]; + blas::axpy(n, scale, row_B, 1, row_C, 1); + } + } + } + return; +} + } // end namespace RandBLAS::sparse_data::csr #endif diff --git a/RandBLAS/sparse_data/spmm_dispatch.hh b/RandBLAS/sparse_data/spmm_dispatch.hh index e2e75133..0256dbee 100644 --- a/RandBLAS/sparse_data/spmm_dispatch.hh +++ b/RandBLAS/sparse_data/spmm_dispatch.hh @@ -110,11 +110,22 @@ void left_spmm( using RandBLAS::sparse_data::coo::apply_coo_left_jki_p11; apply_coo_left_jki_p11(alpha, layout_opB, layout_C, d, n, m, A, ro_a, co_a, B, ldb, C, ldc); } else if constexpr (is_csc) { - using RandBLAS::sparse_data::csc::apply_csc_left_jki_p11; - apply_csc_left_jki_p11(alpha, layout_opB, layout_C, d, n, m, A, B, ldb, C, ldc); + if (layout_opB == Layout::RowMajor && layout_C == Layout::RowMajor) { + using RandBLAS::sparse_data::csc::apply_csc_left_kib_rowmajor_1p1; + apply_csc_left_kib_rowmajor_1p1(alpha, d, n, m, A, B, ldb, C, ldc); + } else { + using RandBLAS::sparse_data::csc::apply_csc_left_jki_p11; + apply_csc_left_jki_p11(alpha, layout_opB, layout_C, d, n, m, A, B, ldb, C, ldc); + } } else { - using 
RandBLAS::sparse_data::csr::apply_csr_left_jik_p11; - apply_csr_left_jik_p11(alpha, layout_opB, layout_C, d, n, m, A, B, ldb, C, ldc); + if (layout_opB == Layout::RowMajor && layout_C == Layout::RowMajor) { + using RandBLAS::sparse_data::csr::apply_csr_left_ikb_rowmajor; + apply_csr_left_ikb_rowmajor(alpha, d, n, m, A, B, ldb, C, ldc); + } else { + using RandBLAS::sparse_data::csr::apply_csr_left_jik_p11; + apply_csr_left_jik_p11(alpha, layout_opB, layout_C, d, n, m, A, B, ldb, C, ldc); + } + } return; } diff --git a/examples/.gitignore b/examples/.gitignore new file mode 100644 index 00000000..1a393d81 --- /dev/null +++ b/examples/.gitignore @@ -0,0 +1,3 @@ +sparse-data-matrices/* +sparse-low-rank-approx/data-matrices/* +sparse-low-rank-approx/fast-matrix-market/* diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 0f6ffa9b..a6d36735 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -6,46 +6,90 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/CMake") set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED True) +set(CMAKE_CXX_FLAGS "-O3") +# ^ THAT'S SO IMPORTANT!!!! +# TODO: just set CMAKE_BUILD_TYPE to Release if it's not already given. + +# TODO: do a try-catch pattern for finding OpenMP. +# If we get an error, then check if (1) we're macOS +# and (2) if we're using default gcc/g++ linked to system clang, +# when we need to use clang from homebrew. +find_package(OpenMP REQUIRED) message(STATUS "Checking for RandBLAS ... ") find_package(RandBLAS REQUIRED) message(STATUS "Done checking for RandBLAS. ...") -message(STATUS "Looking for BLAS++ ... ") -find_package(blaspp REQUIRED) -message(STATUS "Done looking for BLAS++.") - - message(STATUS "Looking for LAPACK++ ... 
") find_package(lapackpp REQUIRED) message(STATUS "Done looking for LAPACK++.") set( - TLS_DenseSkOp_cxx - TLS_DenseSkOp.cpp + tls_dense_skop_cxx total-least-squares/tls_dense_skop.cc ) add_executable( - TLS_DenseSkOp ${TLS_DenseSkOp_cxx} + tls_dense_skop ${tls_dense_skop_cxx} ) target_include_directories( - TLS_DenseSkOp PUBLIC ${Random123_DIR} + tls_dense_skop PUBLIC ${Random123_DIR} ) target_link_libraries( - TLS_DenseSkOp PUBLIC RandBLAS blaspp lapackpp + tls_dense_skop PUBLIC RandBLAS blaspp lapackpp ) set( - TLS_SparseSkOp_cxx - TLS_SparseSkOp.cpp + tls_sparse_skop_cxx total-least-squares/tls_sparse_skop.cc +) +add_executable( + tls_sparse_skop ${tls_sparse_skop_cxx} +) +target_include_directories( + tls_sparse_skop PUBLIC ${Random123_DIR} +) +target_link_libraries( + tls_sparse_skop PUBLIC RandBLAS blaspp lapackpp +) + +add_executable( + slra_svd_synthetic sparse-low-rank-approx/svd_rank1_plus_noise.cc +) +target_include_directories( + slra_svd_synthetic PUBLIC ${Random123_DIR} ) +target_link_libraries( + slra_svd_synthetic PUBLIC RandBLAS blaspp lapackpp +) + + +include(FetchContent) +FetchContent_Declare( + fast_matrix_market + GIT_REPOSITORY https://github.com/alugowski/fast_matrix_market + GIT_TAG main + GIT_SHALLOW TRUE +) +FetchContent_MakeAvailable( + fast_matrix_market +) + +add_executable( + slra_svd_fmm sparse-low-rank-approx/svd_matrixmarket.cc +) +target_include_directories( + slra_svd_fmm PUBLIC ${Random123_DIR} +) +target_link_libraries( + slra_svd_fmm PUBLIC RandBLAS blaspp lapackpp fast_matrix_market::fast_matrix_market +) + add_executable( - TLS_SparseSkOp ${TLS_SparseSkOp_cxx} + slra_qrcp sparse-low-rank-approx/qrcp_matrixmarket.cc ) target_include_directories( - TLS_SparseSkOp PUBLIC ${Random123_DIR} + slra_qrcp PUBLIC ${Random123_DIR} ) target_link_libraries( - TLS_SparseSkOp PUBLIC RandBLAS blaspp lapackpp + slra_qrcp PUBLIC RandBLAS blaspp lapackpp fast_matrix_market::fast_matrix_market ) diff --git a/examples/README.md 
b/examples/README.md new file mode 100644 index 00000000..53cc3a6c --- /dev/null +++ b/examples/README.md @@ -0,0 +1,26 @@ +# RandBLAS examples + +Files in this directory show how RandBLAS can be used to implement high-level RandNLA algorithms. +Right now we have two types of examples. + 1. A sketch-and-solve approach to the total least squares problem. (Two executables.) + 2. Basic methods for low-rank approximation of sparse matrices. (Three executables.) + +There is a _lot_ of code duplication within the ``.cc`` files for a given type of example. +This is a necessary evil to ensure that each example file can stand on its own. + +## Building the examples +The examples are built with CMake. Before building the examples you have to build _and install_ +RandBLAS and lapackpp. + +We've given an example CMake configuration line below. The values we use in the configuration line +assume that ``cmake`` is invoked in a ``build`` folder that's one level beneath this file's directory. +The values also reflect specific install locations for RandBLAS and lapackpp relative to where the +command is invoked; your situation might require specifying different paths. + +```shell +cmake -DCMAKE_BINARY_DIR=`pwd` -DRandBLAS_DIR=`pwd`/../../../RandBLAS-install/lib/cmake -Dlapackpp_DIR=`pwd`/../../../lapackpp-install/lib/cmake/lapackpp .. +``` + +The curious are welcome to look at the examples' CMakeLists.txt file. That file also contains +a lot of code duplication that could be avoided with some CMake-foo, but we chose to keep things +verbose so others can easily copy-paste portions of the CMakeLists.txt into their own codebase. 
diff --git a/examples/TLS_DenseSkOp.cc b/examples/TLS_DenseSkOp.cc deleted file mode 100644 index 464a035d..00000000 --- a/examples/TLS_DenseSkOp.cc +++ /dev/null @@ -1,174 +0,0 @@ -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -using std::chrono::high_resolution_clock; -using std::chrono::duration_cast; -using std::chrono::duration; -using std::chrono::milliseconds; - - -void init_noisy_data(int64_t m, int64_t n, int64_t d, double* AB){ - double target_x[n*d]; - double eps[m*d]; - for (int i = 0; i < n; i++) { - target_x[i] = 1; // Target X is the vector of 1's - } - - RandBLAS::DenseDist Dist_A(m,n); - RandBLAS::DenseDist Dist_eps(m,d); - auto state = RandBLAS::RNGState(0); - auto state1 = RandBLAS::RNGState(1); - - RandBLAS::fill_dense(Dist_A, AB, state); //Fill A to be a random gaussian - RandBLAS::fill_dense(Dist_eps, eps, state1); //Fill A to be a random gaussian - - blas::gemm(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::NoTrans, m, d, n, 1, AB, m, target_x, n, 0, &AB[m*n], m); - - for (int i = 0; i < m*d; i++){ - AB[m*n + i] += eps[i]; // Add Gaussian Noise to right hand side - } -} - -/* Let A be a tall data matrix of dimensions m by n where m > n and b be a vector of dimension m. - * In ordinary least squares it assumes that the error lies only in the right hand side vector b, - * and it aims to find a vector x that minimizes ||A*x - b||_2. - * On the other hand, total least squares assumes that the input data matrix A could also incur errors. - * Total least squares aims to find a solution where the error is orthogonal to the regression model. - */ - -// To call the executable run ./TLS_DenseSkOp where are the number of rows and columns -// of A respectively. We expect m > 2*n. 
-int main(int argc, char* argv[]){ - - // Initialize dimensions - int64_t m; // Number of rows of A, B - int64_t n; // Number of columns of A - - if (argc == 1) { - m = 10000; - n = 500; - } else if (argc == 3) { - m = atoi(argv[1]); - n = atoi(argv[2]); - if (n > m) { - std::cout << "Make sure number of rows are greater than number of cols" << '\n'; - exit(0); - } - } else { - std::cout << "Invalid arguments" << '\n'; - exit(1); - } - - int64_t sk_dim = 2*(n+1); - - // Initialize workspace - double *AB = new double[m*(n + 1)]; // Store [A B] in column major format - double *SAB = new double[sk_dim*(n+1)]; - double *X = new double[n]; - double *res = new double[n]; - - // Initialize workspace for the sketched svd - double *U = new double[sk_dim*sk_dim]; - double *svals = new double[n+1]; - double *VT = new double[(n+1)*(n+1)]; - - // Initialize noisy gaussian data - init_noisy_data(m, n, 1, AB); - - // Define properties of the sketching operator - - // Initialize seed for random number generation - uint32_t seed = 0; - - // Define the dense distribution that the sketching operator will sample from - /* Additional dense distributions: RandBLAS::DenseDistName::Uniform - entries are iid drawn uniform [-1,1] - * RandBLAS::DenseDistName::BlackBox - entires are user provided through a buffer - */ - auto time_constructsketch1 = high_resolution_clock::now(); - RandBLAS::DenseDistName dn = RandBLAS::DenseDistName::Gaussian; - - // Initialize dense distribution struct for the sketching operator - RandBLAS::DenseDist Dist(sk_dim, // Number of rows of the sketching operator - m, // Number of columns of the sketching operator - dn); // Distribution of the entires - - //Construct the dense sketching operator - RandBLAS::DenseSkOp S(Dist, seed); - RandBLAS::fill_dense(S); - auto time_constructsketch2 = high_resolution_clock::now(); - - // Sketch AB - // SAB = alpha * \op(S) * \op(AB) + beta * SAB - auto time_sketch1 = high_resolution_clock::now(); - RandBLAS::sketch_general( - 
blas::Layout::ColMajor, // Matrix storage layout of AB and SAB - blas::Op::NoTrans, // NoTrans => \op(S) = S, Trans => \op(S) = S^T - blas::Op::NoTrans, // NoTrans => \op(AB) = AB, Trans => \op(AB) = AB^T - sk_dim, // Number of rows of S and SAB - n+1, // Number of columns of AB and SAB - m, // Number of rows of AB and columns of S - 1, // Scalar alpha - if alpha is zero AB is not accessed - S, // A DenseSkOp or SparseSkOp sketching operator - AB, // Matrix to be sketched - m, // Leading dimension of AB - 0, // Scalar beta - if beta is zero SAB is not accessed - SAB, // Sketched matrix SAB - sk_dim // Leading dimension of SAB - ); - auto time_sketch2 = high_resolution_clock::now(); - - // Perform SVD operation on SAB - auto time_TLS1 = high_resolution_clock::now(); - lapack::gesdd(lapack::Job::AllVec, sk_dim, (n+1), SAB, sk_dim, svals, U, sk_dim, VT, n+1); - - for (int i = 0; i < n; i++) { - X[i] = VT[n + i*(n+1)]; // Take the right n by 1 block of V - } - - // Scale X by the inverse of the 1 by 1 bottom right block of V - blas::scal(n, -1/VT[(n+1)*(n+1)-1], X, 1); - auto time_TLS2 = high_resolution_clock::now(); - - //Check TLS solution. 
Expected to be close to a vector of 1's - double res_infnorm = 0; - double res_twonorm = 0; - - for (int i = 0; i < n; i++) { - res[i] = abs(X[i] - 1); - res_twonorm += res[i]*res[i]; - if (res_infnorm < res[i]) { - res_infnorm = res[i]; - } - } - - std::cout << "Matrix dimensions: " << m << " by " << n+1 << '\n'; - std::cout << "Sketch dimension: " << sk_dim << '\n'; - std::cout << "Time to create dense sketch: " << (double) duration_cast(time_constructsketch2 - time_constructsketch1).count()/1000 << " seconds" << '\n'; - std::cout << "Time to sketch AB: " << (double) duration_cast(time_sketch2 - time_sketch1).count()/1000 << " seconds" <<'\n'; - std::cout << "Time to perform TLS on sketched matrix: " << (double) duration_cast(time_TLS2 - time_TLS1).count()/1000 << " seconds" << '\n'; - std::cout << "Inf-norm distance from TLS solution to vector of all ones: " << res_infnorm << '\n'; - std::cout << "Two-norm distance from TLS solution to vector of all ones: " << sqrt(res_twonorm) << '\n'; - - delete[] AB; - delete[] SAB; - delete[] X; - delete[] res; - delete[] U; - delete[] svals; - delete[] VT; - return 0; -} - - - - diff --git a/examples/TLS_SparseSkOp.cc b/examples/TLS_SparseSkOp.cc deleted file mode 100644 index 4dbfcda2..00000000 --- a/examples/TLS_SparseSkOp.cc +++ /dev/null @@ -1,173 +0,0 @@ -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -using std::chrono::high_resolution_clock; -using std::chrono::duration_cast; -using std::chrono::duration; -using std::chrono::milliseconds; - - -//TODO: Read in matrix dimensions -//TODO: Have the user choose between dense and sketch sketching operator (4 nnz per col) - -void init_noisy_data(int64_t m, int64_t n, int64_t d, double* AB){ - double target_x[n*d]; - double eps[m*d]; - for (int i = 0; i < n; i++) { - target_x[i] = 1; // Target X is the vector of 1's - } - - RandBLAS::DenseDist Dist_A(m,n); - RandBLAS::DenseDist Dist_eps(m,d); - auto state = 
RandBLAS::RNGState(0); - auto state1 = RandBLAS::RNGState(1); - - RandBLAS::fill_dense(Dist_A, AB, state); //Fill A to be a random gaussian - RandBLAS::fill_dense(Dist_eps, eps, state1); //Fill A to be a random gaussian - - blas::gemm(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::NoTrans, m, d, n, 1, AB, m, target_x, n, 0, &AB[m*n], m); - - for (int i = 0; i < m*d; i++){ - AB[m*n + i] += eps[i]; // Add Gaussian Noise to right hand side - } -} - -/* Let A be a tall data matrix of dimensions m by n where m > n and b be a vector of dimension m. - * In ordinary least squares it assumes that the error lies only in the right hand side vector b, - * and it aims to find a vector x that minimizes ||A*x - b||_2. - * On the other hand, total least squares assumes that the input data matrix A could also incur errors. - * Total least squares aims to find a solution where the error is orthogonal to the regression model. - */ - -// To call the executable run ./TLS_DenseSkOp where are the number of rows and columns -// of A respectively. We expect m > 2*n. 
-int main(int argc, char* argv[]){ - - // Initialize dimensions - int64_t m; // Number of rows of A, B - int64_t n; // Number of columns of A - - if (argc == 1) { - m = 10000; - n = 500; - } else if (argc == 3) { - m = atoi(argv[1]); - n = atoi(argv[2]); - if (n > m) { - std::cout << "Make sure number of rows are greater than number of cols" << '\n'; - exit(0); - } - } else { - std::cout << "Invalid arguments" << '\n'; - exit(1); - } - - // Define number or rows of the sketching operator - int64_t sk_dim = 2*(n+1); - - // Initialize workspace - double *AB = new double[m*(n+1)]; // Store [A B] in column major format - double *SAB = new double[sk_dim*(n+1)]; - double *X = new double[n]; - double *res = new double[n]; - - // Initialize workspace for the sketched svd - double *U = new double[sk_dim*sk_dim]; - double *svals = new double[n+1]; - double *VT = new double[(n+1)*(n+1)]; - - // Initialize noisy gaussian data - init_noisy_data(m, n, 1, AB); - - - // Define the parameters of the sparse distribution - auto time_constructsketch1 = high_resolution_clock::now(); - - // Initialize sparse distribution struct for the sketching operator - uint32_t seed = 0; // Initialize seed for random number generation - RandBLAS::SparseDist Dist = {.n_rows = sk_dim, // Number of rows of the sketching operator - .n_cols = m, // Number of columns of the sketching operator - .vec_nnz = 4, // Number of non-zero entires per major axis - .major_axis = RandBLAS::MajorAxis::Short // Defines the major axis of the sketching operator - }; - - //Construct the sparse sketching operator - RandBLAS::SparseSkOp S(Dist, seed); - RandBLAS::fill_sparse(S); - auto time_constructsketch2 = high_resolution_clock::now(); - - // Sketch AB - // SAB = alpha * \op(S) * \op(AB) + beta * SAB - auto time_sketch1 = high_resolution_clock::now(); - RandBLAS::sketch_general( - blas::Layout::ColMajor, // Matrix storage layout of AB and SAB - blas::Op::NoTrans, // NoTrans => \op(S) = S, Trans => \op(S) = S^T - 
blas::Op::NoTrans, // NoTrans => \op(AB) = AB, Trans => \op(AB) = AB^T - sk_dim, // Number of rows of S and SAB - n+1, // Number of columns of AB and SAB - m, // Number of rows of AB and columns of S - 1, // Scalar alpha - if alpha is zero AB is not accessed - S, // A DenseSkOp or SparseSkOp sketching operator - AB, // Matrix to be sketched - m, // Leading dimension of AB - 0, // Scalar beta - if beta is zero SAB is not accessed - SAB, // Sketched matrix SAB - sk_dim // Leading dimension of SAB - ); - auto time_sketch2 = high_resolution_clock::now(); - - // Perform SVD operation on SAB - auto time_TLS1 = high_resolution_clock::now(); - lapack::gesdd(lapack::Job::AllVec, sk_dim, (n+1), SAB, sk_dim, svals, U, sk_dim, VT, n+1); - - for (int i = 0; i < n; i++) { - X[i] = VT[n + i*(n+1)]; // Take the right n by 1 block of V - } - - // Scale X by the inverse of the 1 by 1 bottom right block of V - blas::scal(n, -1/VT[(n+1)*(n+1)-1], X, 1); - auto time_TLS2 = high_resolution_clock::now(); - - //Check TLS solution. 
Expected to be a vector of 1's - double res_infnorm = 0; - double res_twonorm = 0; - - for (int i = 0; i < n; i++) { - res[i] = abs(X[i] - 1); - res_twonorm += res[i]*res[i]; - if (res_infnorm < res[i]) { - res_infnorm = res[i]; - } - } - - std::cout << "Matrix dimensions: " << m << " by " << n+1 << '\n'; - std::cout << "Sketch dimension: " << sk_dim << '\n'; - std::cout << "Time to create dense sketch: " << (double) duration_cast(time_constructsketch2 - time_constructsketch1).count()/1000 << " seconds" << '\n'; - std::cout << "Time to sketch AB: " << (double) duration_cast(time_sketch2 - time_sketch1).count()/1000 << " seconds" <<'\n'; - std::cout << "Time to perform TLS on sketched matrix: " << (double) duration_cast(time_TLS2 - time_TLS1).count()/1000 << " seconds" << '\n'; - std::cout << "Inf-norm distance from TLS solution to vector of all ones: " << res_infnorm << '\n'; - std::cout << "Two-norm distance from TLS solution to vector of all ones: " << sqrt(res_twonorm) << '\n'; - - delete[] AB; - delete[] SAB; - delete[] X; - delete[] res; - delete[] U; - delete[] svals; - delete[] VT; - return 0; -} - - - - diff --git a/examples/sparse-data-matrices/.keep b/examples/sparse-data-matrices/.keep new file mode 100644 index 00000000..e69de29b diff --git a/examples/sparse-data-matrices/README.md b/examples/sparse-data-matrices/README.md new file mode 100755 index 00000000..3da2602b --- /dev/null +++ b/examples/sparse-data-matrices/README.md @@ -0,0 +1,28 @@ +## Getting sparse data matrices for the examples +This README shows how to get sparse matrices from the SuiteSparse Matrix Collection: + + https://sparse.tamu.edu/. + +Consider a page for a specific matrix in that collection: + + https://sparse.tamu.edu/Schulthess/N_reactome. + +At time of writing, that page (and ones like it) has a "download" section with three clickable buttons. +One of those buttons says "Matrix Market".
If you right-click that button you can get the link that it +points to: + + https://suitesparse-collection-website.herokuapp.com/MM/Schulthess/N_reactome.tar.gz. + +You can then download that file with a shell command like ``wget``, and uncompress it with a command +like ``tar``. See below for commands that would be executed when running a bash terminal on macOS. + + Note: you'll need to run these commands (or some like them) in order to call some of the + examples in ``RandBLAS/examples/sparse-low-rank-approx`` without specifying a Matrix Market file. + +```shell +wget https://suitesparse-collection-website.herokuapp.com/MM/Schulthess/N_reactome.tar.gz +tar -xvzf N_reactome.tar.gz +wget https://suitesparse-collection-website.herokuapp.com/MM/HB/bcsstk17.tar.gz +tar -xvzf bcsstk17.tar.gz +``` + diff --git a/examples/sparse-low-rank-approx/qrcp_matrixmarket.cc b/examples/sparse-low-rank-approx/qrcp_matrixmarket.cc new file mode 100644 index 00000000..8f7e2626 --- /dev/null +++ b/examples/sparse-low-rank-approx/qrcp_matrixmarket.cc @@ -0,0 +1,428 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +using RandBLAS::sparse_data::COOMatrix; +using RandBLAS::sparse_data::CSCMatrix; +using std_clock = std::chrono::high_resolution_clock; +using timepoint_t = std::chrono::time_point; +using std::chrono::duration_cast; +using std::chrono::microseconds; + +#define DOUT(_d) std::setprecision(8) << _d + +auto parse_args(int argc, char** argv) { + std::string mat{"../sparse-data-matrices/N_reactome/N_reactome.mtx"}; + int k = 4; + if (argc > 1) + k = atoi(argv[1]); + if (argc > 2) + mat = argv[2]; + return std::make_tuple(mat, k); +} + +template +COOMatrix from_matrix_market(std::string fn) { + + int64_t n_rows, n_cols = 0; + std::vector rows{}; + std::vector cols{}; + std::vector vals{}; + + std::ifstream file_stream(fn); + 
fast_matrix_market::read_matrix_market_triplet( + file_stream, n_rows, n_cols, rows, cols, vals + ); + + COOMatrix out(n_rows, n_cols); + out.reserve(vals.size()); + for (int i = 0; i < out.nnz; ++i) { + out.rows[i] = rows[i]; + out.cols[i] = cols[i]; + out.vals[i] = vals[i]; + } + + return out; +} + +template <typename T> +void col_swap(int64_t m, int64_t n, int64_t k, T* A, int64_t lda, const int64_t* idx) { + // Adapted from RandLAPACK code. + if(k > n) + throw std::runtime_error("Invalid rank parameter."); + int64_t *idx_copy = new int64_t[n]{}; + for (int i = 0; i < n; ++i) + idx_copy[i] = idx[i]; + + int64_t i, j; + for (i = 0, j = 0; i < k; ++i) { + j = idx_copy[i] - 1; + blas::swap(m, &A[i * lda], 1, &A[j * lda], 1); + auto it = std::find(idx_copy + i, idx_copy + k, i + 1); + idx_copy[it - idx_copy] = j + 1; + } + delete [] idx_copy; +} + +template <typename T> +int qr_row_stabilize(int64_t m, int64_t n, T* mat, T* vec_work) { + if(lapack::gelqf(m, n, mat, m, vec_work)) + return 1; + randblas_require(m < n); + // The signature of UNGLQ is weird. For real types, LAPACK++ provides it as a wrapper to ORGLQ. See + // https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-fortran/2024-1/orglq.html + // for why we're using these particular arguments to get an orthonormal basis + // for the rowspace of "mat" (where the basis is given by a column-major representation + // of the transposed basis vectors). + lapack::unglq(m, n, m, mat, m, vec_work); + return 0; +} + +template <typename T> +int sketch_orthogonalize_rows(int64_t m, int64_t n, T* A, T* work, int64_t d, int32_t key) { + RandBLAS::RNGState state(key); + // A is wide, m-by-n in column-major format. + // work has at least m*d space. + randblas_require(d >= 2); + randblas_require(d >= m); + std::vector<T> tau(d, 0.0); + int64_t vec_nnz = std::min(d/2, (int64_t) 4); + RandBLAS::SparseDist D{n, d, vec_nnz}; + RandBLAS::SparseSkOp S(D, state); + // Simple option (shown here): + // Sketch A in column-major format, then do LQ on the sketch. 
+ // If the sketch looks singular after we decompose it, then we bite the bullet and do LQ on A. + // + // Fancy option (for a later implementation): + // Compute the sketch in row-major format (this is lying about A's format, but we resolve that by setting transA=T). + // Look at the row-major sketch, interpret it as a transposed column-major sketch, factor the "implicitly de-transposed" version by GEQP3, + // then implicitly transpose back to get the factors for row-pivoted LQ of A_sk: + // A_sk = P R^* Q^* + // Apply M = inv(P R^*) = inv(R^*) P^* to the left of A by TRSM. + // If rank(A) = rank(A_sk) = m, then in exact arithmetic the conditioning of MA should be independent from + // that of A. However, there can be a positive correlation between cond(MA) and cond(A) in finite-precision. + // This happens when cond(A) is very large, and may warrant truncating rows of MA. + // There are many ways to select the size of that row-block. Two options jump out: + // 1. Compute (or estimate) the numerical rank of R by a reliable method of your choosing. + // 2. Proceed in a similar vein as CQRRPT: estimate the condition number of leading row blocks of MA + // by forming the Gram matrix (MA)(MA)^*, doing Cholesky on it, and computing or estimating the + // condition numbers of leading submatrices of the Cholesky factor. + // + // + RandBLAS::sketch_general(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::NoTrans, m, d, n, 1.0, A, m, S, 0.0, work, m); + lapack::gelqf(m, d, work, m, tau.data()); + T tol = std::numeric_limits::epsilon()*100; + for (int i = 0; i < m; ++i) { + if (std::abs(work[i*m + i]) < tol) { + // We can't safely invert. Fall back on LQ of A. + qr_row_stabilize(m, n, A, tau.data()); + std::cout << "\n----> Could not safely sketch-orthogonalize. Falling back on GELQF instead.\n\n"; + return 1; + } + } + // L is in the lower triangle of work. 
+ // Need to transform + // A <- inv(L)A + blas::trsm(blas::Layout::ColMajor, blas::Side::Left, blas::Uplo::Lower, blas::Op::NoTrans, blas::Diag::NonUnit, m, n, 1.0, work, m, A, m); + return 0; +} + +template +int lu_row_stabilize(int64_t m, int64_t n, T* mat, int64_t* piv_work) { + randblas_require(m < n); + for (int64_t i = 0; i < m; ++i) + piv_work[i] = 0; + lapack::getrf(m, n, mat, m, piv_work); + // above: the permutation applied to the rows of mat doesn't matter in our context. + // below: Need to zero-out the strict lower triangle of mat and scale each row. + T tol = std::numeric_limits::epsilon()*10; + bool nonzero_diag_U = true; + for (int64_t j = 0; (j < m-1) & nonzero_diag_U; ++j) { + nonzero_diag_U = abs(mat[j + j*m]) > tol; + for (int64_t i = j + 1; i < m; ++i) { + mat[i + j*m] = 0.0; + } + } + if (!nonzero_diag_U) { + throw std::runtime_error("LU stabilization failed. Matrix has been overwritten, so we cannot recover."); + } + for (int64_t i = 0; i < m; ++i) { + T scale = 1.0 / mat[i + i*m]; + blas::scal(n, scale, mat + i, m); + } + return 0; +} + +#ifdef FINE_GRAINED +#define TIMED_LINE(_op, _name) { \ + auto _tp0 = std_clock::now(); \ + _op; \ + auto _tp1 = std_clock::now(); \ + double dtime = (double) duration_cast(_tp1 - _tp0).count(); \ + std::cout << _name << DOUT(dtime / 1e6) << std::endl; \ + } +#else +#define TIMED_LINE(_op, _name) _op; +#endif + +enum class StabilizationMethod : char { + LU = 'L', + LQ = 'H', // householder + sketch = 'S', + None = 'N' +}; + +template +void power_iter_col_sketch(SpMat &A, int64_t k, T* Y, int64_t p_data_aware, STATE state, T* work, StabilizationMethod sm) { + int64_t m = A.n_rows; + int64_t n = A.n_cols; + using RandBLAS::sparse_data::right_spmm; + using blas::Op; + using blas::Layout; + // Want k-by-n matrix Y = SA, where S has p_data_aware passes over A to build up data-aware geometry. + // Run ... + // p_done = 0 + // if p_data_aware is even: + // S = oblivious k-by-m. 
+ // if p_data_aware is odd: + // T = oblivious k-by-n. + // S = row_orth(T A') + // p_done += 1 + // while (p_data_aware - p_done > 0) + // T = row_orth(S A) + // S = row_orth(T A') + // p_done += 2 + // Y = S A + T* mat_work1 = Y; + T* mat_work2 = work; + + // Messy code to allow for different stabilization methods + T* tau_work = new T[std::max(n, m)]; + int64_t* piv_work = new int64_t[k]; + int64_t sketch_dim = (int64_t) (1.25*m + 1); + T* sketch_orth_work = new T[sketch_dim * m]{0.0}; + auto stab_func = [sm, k, piv_work, tau_work, sketch_orth_work, sketch_dim](T* mat_to_stab, int64_t num_mat_cols, int64_t key) { + if (sm == StabilizationMethod::LU) { + lu_row_stabilize(k, num_mat_cols, mat_to_stab, piv_work); + } else if (sm == StabilizationMethod::LQ) { + qr_row_stabilize(k, num_mat_cols, mat_to_stab, tau_work); + } else if (sm == StabilizationMethod::sketch) { + sketch_orthogonalize_rows(k, num_mat_cols, mat_to_stab, sketch_orth_work, sketch_dim, key); + } else if (sm == StabilizationMethod::None) { + // do nothing + } + return; + }; + + int64_t p_done = 0; + if (p_data_aware % 2 == 0) { + RandBLAS::DenseDist D(k, m, RandBLAS::DenseDistName::Gaussian); + TIMED_LINE( + RandBLAS::fill_dense(D, mat_work2, state), "sampling : ") + } else { + RandBLAS::DenseDist D(k, n, RandBLAS::DenseDistName::Gaussian); + TIMED_LINE( + RandBLAS::fill_dense(D, mat_work1, state), "sampling : ") + TIMED_LINE( + right_spmm(Layout::ColMajor, Op::NoTrans, Op::Trans, k, m, n, 1.0, mat_work1, k, A, 0, 0, 0.0, mat_work2, k), "spmm : ") + p_done += 1; + TIMED_LINE( + stab_func(mat_work2, m, p_done), "stabilization : ") + } + + while (p_data_aware - p_done > 0) { + TIMED_LINE( + right_spmm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, k, n, m, 1.0, mat_work2, k, A, 0, 0, 0.0, mat_work1, k), "right_spmm : ") + p_done += 1; + TIMED_LINE( + stab_func(mat_work1, n, p_done), "stabilization : ") + TIMED_LINE( + right_spmm(Layout::ColMajor, Op::NoTrans, Op::Trans, k, m, n, 1.0, mat_work1, k, A, 0, 
0, 0.0, mat_work2, k), "right_spmm : ") + p_done += 1; + TIMED_LINE( + stab_func(mat_work2, m, p_done), "stabilization : ") + } + TIMED_LINE( + right_spmm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, k, n, m, 1.0, mat_work2, k, A, 0, 0, 0.0, Y, k), "spmm : ") + + delete [] tau_work; + delete [] piv_work; + delete [] sketch_orth_work; + return; +} + +template +void print_row_norms(T* mat, int64_t m, int64_t n, std::string s) { + std::cout << "Row norms for " << s << " : [ "; + int i; + for (i = 0; i < m-1; ++i) { + std::cout << DOUT(blas::nrm2(n, mat + i, m)) << ", "; + } + std::cout << DOUT(blas::nrm2(n, mat + i, m)) << " ] " << std::endl; + return; +} + +void print_pivots(int64_t *piv, int64_t k) { + std::cout << "Leading pivots : [ "; + int i; + for (i = 0; i < k-1; ++i) { + std::cout << piv[i]-1 << ", "; + } + std::cout << piv[i]-1 << " ]" << std::endl; +} + +template +void sketch_to_tqrcp(SpMat &A, int64_t k, T* Q, int64_t ldq, T* Y, int64_t ldy, int64_t *piv) { + // On input, Y is a left-sketch of A. + // On exit, Q, Y, piv are overwritten so that ... + // The columns of Q are an orthonormal basis for A(:, piv(:k)) + // Y = Q' A(:, piv) is upper-triangular. 
+ using sint_t = typename SpMat::index_t; + constexpr bool valid_type = std::is_same_v>; + randblas_require(valid_type); + int64_t m = A.n_rows; + int64_t n = A.n_cols; + using blas::Layout; + using blas::Op; + using blas::Side; + using blas::Uplo; + for (int64_t i = 0; i < n; ++i) + piv[i] = 0; + T* tau = new T[n]{}; + T* precond = new T[k * k]{}; + + // ================================================================ + // Step 1: get the pivots + TIMED_LINE( + lapack::geqp3(k, n, Y, ldy, piv, tau), "GEQP3 : ") + + // ================================================================ + // Step 2: copy A(:, piv(0)-1), ..., A(:, piv(k)-1) into dense Q + for (int64_t j = 0; j < k; ++j) { + RandBLAS::util::safe_scal(m, 0.0, Q + j*ldq, 1); + for (int64_t ell = A.colptr[piv[j]-1]; ell < A.colptr[piv[j]]; ++ell) { + int64_t i = A.rowidxs[ell]; + Q[i + ldq*j] = A.vals[ell]; + } + } + + // ================================================================ + // Step 3: get explicit representation of orth(Q). + TIMED_LINE( + // Extract a preconditioner from the column-pivoted QR decomposition of Y. + for (int64_t j = 0; j < k; j++) { + for (int64_t i = 0; i < k; ++i) { + precond[i + k*j] = Y[i + k*j]; + } + } + // Apply the preconditioner: Q = Q / precond. + blas::trsm(Layout::ColMajor, Side::Right, Uplo::Upper, Op::NoTrans, blas::Diag::NonUnit, m, k, 1.0, precond, k, Q, ldq); + // Cholesky-orthogonalize the preconditioned matrix: + // precond = chol(Q' * Q, "upper") + // Q = Q / precond. 
+ blas::syrk(Layout::ColMajor, Uplo::Upper, Op::Trans, k, m, 1.0, Q, ldq, 0.0, precond, k); + lapack::potrf(Uplo::Upper, k, precond, k); + blas::trsm(Layout::ColMajor, Side::Right, Uplo::Upper, Op::NoTrans, blas::Diag::NonUnit, m, k, 1.0, precond, k, Q, ldq), "getQ : ") + + // ================================================================ + // Step 4: multiply Y = Q'A and pivot Y = Y(:, piv) + TIMED_LINE( + RandBLAS::right_spmm(Layout::ColMajor, Op::Trans, Op::NoTrans, k, n, m, 1.0, Q, ldq, A, 0, 0, 0.0, Y, ldy); + col_swap(k, n, k, Y, ldy, piv), "getR : ") + + delete [] tau; + delete [] precond; + return; +} + +template +int run(SpMat &A, int64_t k, int64_t power_iteration_steps, StabilizationMethod sm, bool extra_verbose) { + auto m = A.n_rows; + auto n = A.n_cols; + + using T = typename SpMat::scalar_t; + T *Q = new T[m*k]{}; + T *R = new T[k*n]{}; + int64_t *piv = new int64_t[n]{}; + RandBLAS::RNGState state(0); + + auto start_timer = std_clock::now(); + TIMED_LINE( + power_iter_col_sketch(A, k, R, power_iteration_steps, state, Q, sm), "\n\tpower iter sketch : ") + if (extra_verbose) + print_row_norms(R, k, n, "Yf"); + TIMED_LINE( + sketch_to_tqrcp(A, k, Q, m, R, k, piv), "\n\tsketch to QRCP : ") + auto stop_timer = std_clock::now(); + if (extra_verbose) + print_row_norms(R, k, n, "R "); + print_pivots(piv, k); + + T runtime = (T) duration_cast(stop_timer - start_timer).count(); + std::cout << "Runtime in μs : " << DOUT(runtime) << std::endl << std::endl; + + delete [] Q; + delete [] R; + delete [] piv; + return 0; +} + +int main(int argc, char** argv) { + /* + This program should be called from a "build" folder that's one level below RandBLAS/examples. + + If called with two arguments, then the first argument will be the approximation rank, + and the second argument will be a path (relative or absolute) to a MatrixMarket file. 
+ + If called with zero or one arguments, we'll assume that there's a file located at + ../sparse-data-matrices/N_reactome/N_reactome.mtx. + + If called with zero arguments, we'll automatically set the approximation rank to 4. + */ + auto [fn, _k] = parse_args(argc, argv); + auto mat_coo = from_matrix_market(fn); + auto m = mat_coo.n_rows; + auto n = mat_coo.n_cols; + int64_t k = (int64_t) _k; + + std::cout << "\nProcessing matrix in " << fn << std::endl; + std::cout << "n_rows : " << mat_coo.n_rows << std::endl; + std::cout << "n_cols : " << mat_coo.n_cols << std::endl; + double density = ((double) mat_coo.nnz) / ((double) (mat_coo.n_rows * mat_coo.n_cols)); + std::cout << "density : " << DOUT(density) << std::endl << std::endl; + + RandBLAS::CSCMatrix mat_csc(m, n); + RandBLAS::conversions::coo_to_csc(mat_coo, mat_csc); + int64_t power_iter_steps = 2; + bool extra_verbose = true; + + std::cout << "Computing rank-" << k << " truncated QRCP.\n"; + std::cout << "Internally use " << power_iter_steps << " steps of power iteration.\n"; + std::cout << "Consider four runs, each stabilizing power iteration in a different way.\n\n"; + std::cout << "Take Q from LQ\n"; + run(mat_csc, k, power_iter_steps, StabilizationMethod::LQ, extra_verbose); + std::cout << "Sketch-orthogonalize\n"; + run(mat_csc, k, power_iter_steps, StabilizationMethod::sketch, extra_verbose); + std::cout << "Do nothing. This is numerically dangerous unless power_iter_steps is extremely small.\n"; + run(mat_csc, k, power_iter_steps, StabilizationMethod::None, extra_verbose); + std::cout << "Take (scaled) U from row-pivoted LU. 
This may exit with an error!\n"; + run(mat_csc, k, power_iter_steps, StabilizationMethod::LU, extra_verbose); + return 0; +} diff --git a/examples/sparse-low-rank-approx/svd_matrixmarket.cc b/examples/sparse-low-rank-approx/svd_matrixmarket.cc new file mode 100644 index 00000000..61c459d0 --- /dev/null +++ b/examples/sparse-low-rank-approx/svd_matrixmarket.cc @@ -0,0 +1,223 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using RandBLAS::sparse_data::COOMatrix; +using std_clock = std::chrono::high_resolution_clock; +using timepoint_t = std::chrono::time_point; +using std::chrono::duration_cast; +using std::chrono::microseconds; + + +#define DOUT(_d) std::setprecision(std::numeric_limits::max_digits10) << _d + +std::string parse_args(int argc, char** argv) { + if (argc > 1) { + return std::string{argv[1]}; + } else { + return "../sparse-data-matrices/bcsstk17/bcsstk17.mtx"; + } +} + +template +COOMatrix from_matrix_market(std::string fn) { + + int64_t n_rows, n_cols = 0; + std::vector rows{}; + std::vector cols{}; + std::vector vals{}; + + std::ifstream file_stream(fn); + fast_matrix_market::read_matrix_market_triplet( + file_stream, n_rows, n_cols, rows, cols, vals + ); + + COOMatrix out(n_rows, n_cols); + out.reserve(vals.size()); + for (int i = 0; i < out.nnz; ++i) { + out.rows[i] = rows[i]; + out.cols[i] = cols[i]; + out.vals[i] = vals[i]; + } + + return out; +} + +template +int householder_orth(int64_t m, int64_t n, T* mat, T* work) { + if(lapack::geqrf(m, n, mat, m, work)) + return 1; + lapack::ungqr(m, n, n, mat, m, work); + return 0; +} + + +#define TIMED_LINE(_op, _name) { \ + auto _tp0 = std_clock::now(); \ + _op; \ + auto _tp1 = std_clock::now(); \ + double dtime = (double) duration_cast(_tp1 - _tp0).count(); \ + std::cout << _name << DOUT(dtime / 1e6) << std::endl; \ + } + + +template +void 
qb_decompose_sparse_matrix(SpMat &A, int64_t k, T* Q, T* B, int64_t p, STATE state, T* work, int64_t lwork) { + int64_t m = A.n_rows; + int64_t n = A.n_cols; + using RandBLAS::sparse_data::left_spmm; + using RandBLAS::sparse_data::right_spmm; + using blas::Op; + using blas::Layout; + + // We use Q and B as workspace and to store the final result. + // To distinguish the semantic use of workspace from the final result, + // we define some alias pointers to Q's and B's memory. + randblas_require(lwork >= std::max(m, n)); + T* mat_work1 = Q; + T* mat_work2 = B; + int64_t p_done = 0; + + std::string sample_log = "sample : "; + std::string lspmmN_log = "left_spmm (NoTrans) : "; + std::string orth_log = "orth : "; + std::string lspmmT_log = "left_spmm (Trans) : "; + + // Convert to CSC. + // CSR would also be okay, but it seems that CSC is faster in this case. + RandBLAS::sparse_data::CSCMatrix A_compressed(A.n_rows, A.n_cols); + TIMED_LINE( + RandBLAS::sparse_data::conversions::coo_to_csc(A, A_compressed), "COO to CSC : ") + + // Step 1: fill S := mat_work2 with the data needed to feed it into power iteration. + if (p % 2 == 0) { + RandBLAS::DenseDist D(n, k); + TIMED_LINE( + RandBLAS::fill_dense(D, mat_work2, state), sample_log) + } else { + RandBLAS::DenseDist D(m, k); + TIMED_LINE( + RandBLAS::fill_dense(D, mat_work1, state), sample_log) + TIMED_LINE( + left_spmm(Layout::ColMajor, Op::Trans, Op::NoTrans, n, k, m, 1.0, A_compressed, 0, 0, mat_work1, m, 0.0, mat_work2, n), lspmmT_log) + TIMED_LINE( + householder_orth(n, k, mat_work2, work), orth_log) + p_done += 1; + } + + // Step 2: fill S := mat_work2 with data needed to feed it into the rangefinder. 
+ while (p - p_done > 0) { + // Update S = orth(A' * orth(A * S)) + TIMED_LINE( + left_spmm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, n, 1.0, A_compressed, 0, 0, mat_work2, n, 0.0, mat_work1, m), lspmmN_log) + TIMED_LINE( + householder_orth(m, k, mat_work1, work), orth_log) + TIMED_LINE( + left_spmm(Layout::ColMajor, Op::Trans, Op::NoTrans, n, k, m, 1.0, A_compressed, 0, 0, mat_work1, m, 0.0, mat_work2, n), lspmmT_log) + TIMED_LINE( + householder_orth(n, k, mat_work2, work), orth_log) + p_done += 2; + } + + // Step 3: compute Q = orth(A * S) and B = Q'A. + TIMED_LINE( + left_spmm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, n, 1.0, A_compressed, 0, 0, mat_work2, n, 0.0, Q, m), lspmmN_log) + TIMED_LINE( + householder_orth(m, k, Q, work), orth_log) + TIMED_LINE( + right_spmm(Layout::ColMajor, Op::Trans, Op::NoTrans, k, n, m, 1.0, Q, m, A_compressed, 0, 0, 0.0, B, k), "right_spmm : ") + return; +} + +template +void qb_to_svd(int64_t m, int64_t n, int64_t k, T* Q, T* svals, int64_t ldq, T* B, int64_t ldb, T* work, int64_t lwork) { + // Input: (Q, B) defining a matrix A = Q*B, where + // Q is m-by-k and column orthonormal + // and + // B is k-by-n and otherwise unstructured. + // + // Output: + // Q holds the top-k left singular vectors of A. + // B holds a matrix that can be described in two equivalent ways: + // 1. a column-major representation of the top-k transposed right singular vectors of A. + // 2. a row-major representation of the top-k right singular vectors of A. + // svals holds the top-k singular values of A. + // + using blas::Op; + using blas::Layout; + using lapack::Job; + using lapack::MatrixType; + + // Compute the SVD of B: B = U diag(svals) VT, where B is overwritten by VT. + int64_t extra_work_size = lwork - k*k; + randblas_require(extra_work_size >= 0); + T* U = work; // <-- just a semantic alias for the start of work. + lapack::gesdd(Job::OverwriteVec, k, n, B, ldb, svals, U, k, nullptr, k); + + // update Q = Q U. 
+ bool allocate_more_work = extra_work_size < m*k; + T* more_work = (allocate_more_work) ? new T[m*k] : (work + k*(k+1)); + lapack::lacpy(MatrixType::General, m, k, Q, ldq, more_work, m); + blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, k, 1.0, more_work, m, U, k, 0.0, Q, ldq); + + if (allocate_more_work) + delete [] more_work; + + return; +} + +int main(int argc, char** argv) { + + auto fn = parse_args(argc, argv); + auto mat_sparse = from_matrix_market(fn); + auto m = mat_sparse.n_rows; + auto n = mat_sparse.n_cols; + + // Run the randomized algorithm! + int64_t k = 64; + double *U = new double[m*k]{}; + double *VT = new double[k*n]{}; + double *qb_work = new double[std::max(m, n)]; + RandBLAS::RNGState state(0); + /* + Effect of various parameters on performance: + It's EXTREMELY important to use -O3 if you want reasonably + fast sparse matrix conversion inside RandBLAS. We're talking + a more-than-10x speedup. + + */ + auto start_timer = std_clock::now(); + qb_decompose_sparse_matrix(mat_sparse, k, U, VT, 2, state, qb_work, std::max(m,n)); + double *svals = new double[std::min(m,n)]; + double *conversion_work = new double[m*k + k*k]; + qb_to_svd(m, n, k, U, svals, m, VT, k, conversion_work, m*k + k*k); + auto stop_timer = std_clock::now(); + double runtime = (double) std::chrono::duration_cast(stop_timer - start_timer).count(); + runtime = runtime / 1e6; + + std::cout << "n_rows : " << mat_sparse.n_rows << std::endl; + std::cout << "n_cols : " << mat_sparse.n_cols << std::endl; + double density = ((double) mat_sparse.nnz) / ((double) (mat_sparse.n_rows * mat_sparse.n_cols)); + std::cout << "density : " << DOUT(density) << std::endl; + std::cout << "runtime of low-rank approximation : " << DOUT(runtime) << std::endl; + + delete [] qb_work; + delete [] conversion_work; + delete [] svals; + return 0; +} diff --git a/examples/sparse-low-rank-approx/svd_rank1_plus_noise.cc b/examples/sparse-low-rank-approx/svd_rank1_plus_noise.cc new file mode 100644 
index 00000000..75cc8f68 --- /dev/null +++ b/examples/sparse-low-rank-approx/svd_rank1_plus_noise.cc @@ -0,0 +1,334 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using RandBLAS::sparse_data::COOMatrix; + +#define DOUT(_d) std::setprecision(std::numeric_limits::max_digits10) << _d + +auto parse_dimension_args(int argc, char** argv) { + int64_t m; + int64_t n; + int64_t vec_nnz; + + if (argc == 1) { + m = 10000; + n = 500; + vec_nnz = 4; + } else if (argc == 4) { + m = atoi(argv[1]); + n = atoi(argv[2]); + vec_nnz = atoi(argv[3]); + } else { + std::cout << "Invalid parameters; must be called with no parameters, or with three positive integers." << '\n'; + exit(1); + } + return std::make_tuple(m, n, vec_nnz); +} + +template +void iid_sparsify_random_dense( + int64_t n_rows, int64_t n_cols, int64_t stride_row, int64_t stride_col, T* mat, T prob_of_zero, RandBLAS::RNGState state +) { + auto spar = new T[n_rows * n_cols]; + auto dist = RandBLAS::DenseDist(n_rows, n_cols, RandBLAS::DenseDistName::Uniform); + auto [unused, next_state] = RandBLAS::fill_dense(dist, spar, state); + + auto temp = new T[n_rows * n_cols]; + auto D_mat = RandBLAS::DenseDist(n_rows, n_cols, RandBLAS::DenseDistName::Uniform); + RandBLAS::fill_dense(D_mat, temp, next_state); + + #define SPAR(_i, _j) spar[(_i) + (_j) * n_rows] + #define TEMP(_i, _j) temp[(_i) + (_j) * n_rows] + #define MAT(_i, _j) mat[(_i) * stride_row + (_j) * stride_col] + for (int64_t i = 0; i < n_rows; ++i) { + for (int64_t j = 0; j < n_cols; ++j) { + T v = (SPAR(i, j) + 1.0) / 2.0; + if (v < prob_of_zero) { + MAT(i, j) = 0.0; + } else { + MAT(i, j) = TEMP(i, j); + } + } + } + + delete [] spar; + delete [] temp; +} + +template +SpMat sum_of_coo_matrices(SpMat &A, SpMat &B) { + randblas_require(A.n_rows == B.n_rows); + randblas_require(A.n_cols == B.n_cols); + + using T = typename SpMat::scalar_t; + using sint_t 
= typename SpMat::index_t; + constexpr bool valid_type = std::is_same_v>; + randblas_require(valid_type); + + using Tuple = std::pair; + struct TupleHasher { + size_t operator()(const Tuple &coordinate_pair) const { + // an implementation suggested by my robot friend. + size_t hash1 = std::hash{}(coordinate_pair.first); + size_t hash2 = std::hash{}(coordinate_pair.second); + size_t hash3 = hash1; + hash3 ^= hash2 + 0x9e3779b9 + (hash1 << 6) + (hash1 >> 2); + return hash3; + } + }; + std::unordered_map c_dict{}; + + for (int ell = 0; ell < A.nnz; ++ell) { + Tuple curr_idx(A.rows[ell], A.cols[ell]); + c_dict[curr_idx] = A.vals[ell]; + } + for (int ell = 0; ell < B.nnz; ++ell) { + Tuple curr_idx(B.rows[ell], B.cols[ell]); + c_dict[curr_idx] = B.vals[ell] + c_dict[curr_idx]; + } + + SpMat C(A.n_rows, A.n_cols); + C.reserve(c_dict.size()); + int64_t ell = 0; + for (auto iter : c_dict) { + Tuple t = iter.first; + auto [i, j] = t; + C.rows[ell] = i; + C.cols[ell] = j; + C.vals[ell] = iter.second; + ++ell; + } + return C; +} + + +template +void make_signal_matrix(double signal_scale, double* u, int64_t m, double* v, int64_t n, int64_t vec_nnz, double* signal_dense, SpMat &signal_sparse) { + using T = typename SpMat::scalar_t; + using sint_t = typename SpMat::index_t; + constexpr bool valid_type = std::is_same_v>; + randblas_require(valid_type); + signal_sparse.reserve(vec_nnz * vec_nnz); + + // populate signal_dense and signal_sparse. 
+ RandBLAS::RNGState u_state(0); + double *work_vals = new double[2*vec_nnz]{}; + int64_t *work_idxs = new int64_t[2*vec_nnz]{}; + int64_t *trash = new int64_t[vec_nnz]{}; + + double uv_scale = 1.0 / std::sqrt((double) vec_nnz); + + + auto v_state = RandBLAS::repeated_fisher_yates(u_state, vec_nnz, m, 1, work_idxs, trash, work_vals); + auto next_state = RandBLAS::repeated_fisher_yates(v_state, vec_nnz, n, 1, work_idxs+vec_nnz, trash, work_vals+vec_nnz); + for (int j = 0; j < vec_nnz; ++j) { + for (int i = 0; i < vec_nnz; ++i) { + int temp = i + j*vec_nnz; + signal_sparse.rows[temp] = work_idxs[i]; + signal_sparse.cols[temp] = work_idxs[j+vec_nnz]; + signal_sparse.vals[temp] = work_vals[i] * work_vals[j+vec_nnz]; + } + u[work_idxs[j]] = uv_scale * work_vals[j]; + v[work_idxs[j + vec_nnz]] = uv_scale * work_vals[j + vec_nnz]; + } + blas::ger(blas::Layout::ColMajor, m, n, signal_scale, u, 1, v, 1, signal_dense, m); + + delete [] work_vals; + delete [] work_idxs; + delete [] trash; + return; +} + + +template <typename SpMat> +void make_noise_matrix(double noise_scale, int64_t m, int64_t n, double prob_of_nonzero, double* noise_dense, SpMat &noise_sparse) { + // populate noise_dense and noise_sparse. + // + // NOTE: it would be more efficient to sample vec_nnz*vec_nnz elements without replacement from the index set + // from 0 to m*n-1, then de-vectorize those indices (in either row-major or col-major interpretation) and + // only sample the values of the nonzeros for these pre-determined structural nonzeros. The current implementation + // has to generate two dense m-by-n matrices whose entries are iid uniform [-1, 1]. 
+ // + using T = typename SpMat::scalar_t; + using sint_t = typename SpMat::index_t; + constexpr bool valid_type = std::is_same_v>; + randblas_require(valid_type); + + RandBLAS::RNGState noise_state(1); + double prob_of_zero = 1 - prob_of_nonzero; + iid_sparsify_random_dense(m, n, 1, m, noise_dense, prob_of_zero, noise_state); + blas::scal(m * n, noise_scale, noise_dense, 1); + RandBLAS::sparse_data::coo::dense_to_coo(blas::Layout::ColMajor, noise_dense, 0.0, noise_sparse); + return; +} + +template +int householder_orth(int64_t m, int64_t n, T* mat, T* work) { + if(lapack::geqrf(m, n, mat, m, work)) + return 1; + lapack::ungqr(m, n, n, mat, m, work); + return 0; +} + +template +void qb_decompose_sparse_matrix(SpMat &A, int64_t k, T* Q, T* B, int64_t p, STATE state, T* work, int64_t lwork) { + int64_t m = A.n_rows; + int64_t n = A.n_cols; + using RandBLAS::sparse_data::left_spmm; + using RandBLAS::sparse_data::right_spmm; + using blas::Op; + using blas::Layout; + + // We use Q and B as workspace and to store the final result. + // To distinguish the semantic use of workspace from the final result, + // we define some alias pointers to Q's and B's memory. + randblas_require(lwork >= std::max(m, n)); + T* mat_work1 = Q; + T* mat_work2 = B; + int64_t p_done = 0; + + // Step 1: fill S := mat_work2 with the data needed to feed it into power iteration. + if (p % 2 == 0) { + RandBLAS::DenseDist D(n, k); + RandBLAS::fill_dense(D, mat_work2, state); + } else { + RandBLAS::DenseDist D(m, k); + RandBLAS::fill_dense(D, mat_work1, state); + left_spmm(Layout::ColMajor, Op::Trans, Op::NoTrans, n, k, m, 1.0, A, 0, 0, mat_work1, m, 0.0, mat_work2, n); + p_done += 1; + householder_orth(n, k, mat_work2, work); + } + + // Step 2: fill S := mat_work2 with data needed to feed it into the rangefinder. 
+ while (p - p_done > 0) { + // Update S = orth(A' * orth(A * S)) + left_spmm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, n, 1.0, A, 0, 0, mat_work2, n, 0.0, mat_work1, m); + householder_orth(m, k, mat_work1, work); + left_spmm(Layout::ColMajor, Op::Trans, Op::NoTrans, n, k, m, 1.0, A, 0, 0, mat_work1, m, 0.0, mat_work2, n); + householder_orth(n, k, mat_work2, work); + p_done += 2; + } + + // Step 3: compute Q = orth(A * S) and B = Q'A. + left_spmm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, n, 1.0, A, 0, 0, mat_work2, n, 0.0, Q, m); + householder_orth(m, k, Q, work); + right_spmm(Layout::ColMajor, Op::Trans, Op::NoTrans, k, n, m, 1.0, Q, m, A, 0, 0, 0.0, B, k); + return; +} + +template +void qb_to_svd(int64_t m, int64_t n, int64_t k, T* Q, T* svals, int64_t ldq, T* B, int64_t ldb, T* work, int64_t lwork) { + // Input: (Q, B) defining a matrix A = Q*B, where + // Q is m-by-k and column orthonormal + // and + // B is k-by-n and otherwise unstructured. + // + // Output: + // Q holds the top-k left singular vectors of A. + // B holds a matrix that can be described in two equivalent ways: + // 1. a column-major representation of the top-k transposed right singular vectors of A. + // 2. a row-major representation of the top-k right singular vectors of A. + // svals holds the top-k singular values of A. + // + using blas::Op; + using blas::Layout; + using lapack::Job; + using lapack::MatrixType; + + // Compute the SVD of B: B = U diag(svals) VT, where B is overwritten by VT. + int64_t extra_work_size = lwork - k*k; + randblas_require(extra_work_size >= 0); + T* U = work; // <-- just a semantic alias for the start of work. + lapack::gesdd(Job::OverwriteVec, k, n, B, ldb, svals, U, k, nullptr, k); + + // update Q = Q U. 
+ T* more_work = work + k*(k+1); + bool allocate_more_work = extra_work_size < m*k; + if (allocate_more_work) + more_work = new T[m*k]; + lapack::lacpy(MatrixType::General, m, k, Q, ldq, more_work, m); + blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, k, 1.0, more_work, m, U, k, 0.0, Q, ldq); + + if (allocate_more_work) + delete [] more_work; + + return; +} + +int main(int argc, char** argv) { + auto [m, n, vec_nnz] = parse_dimension_args(argc, argv); + // First we set up problem data: a sparse matrix of low numerical rank + // given by a sum of a sparse "signal" matrix of rank 1 and a sparse + // "noise" matrix that has very small norm. + double signal_scale = 1e+2; + double noise_scale = 1e-6; + double prob_nonzero = 1e-4; + RandBLAS::sparse_data::COOMatrix signal_sparse(m, n); + RandBLAS::sparse_data::COOMatrix noise_sparse(m, n); + auto mn = m * n; + double *signal_dense = new double[mn]{}; + double *noise_dense = new double[mn]; + double *u_top = new double[m]{}; + double *v_top = new double[n]{}; + + make_signal_matrix(signal_scale, u_top, m, v_top, n, vec_nnz, signal_dense, signal_sparse); + make_noise_matrix(noise_scale, m, n, prob_nonzero, noise_dense, noise_sparse); + + // Add the two matrices together. + auto mat_sparse = sum_of_coo_matrices(noise_sparse, signal_sparse); + std::cout << signal_sparse.nnz << std::endl; + std::cout << noise_sparse.nnz << std::endl; + std::cout << mat_sparse.nnz << std::endl; + double *mat_dense = new double[mn]{}; + blas::copy(mn, noise_dense, 1, mat_dense, 1); + blas::axpy(mn, 1.0, signal_dense, 1, mat_dense, 1); + + // Run the randomized algorithm! 
+ int64_t k = std::max((int64_t) 3, vec_nnz); // the matrix is really rank-1 plus noise + auto start_timer = std::chrono::high_resolution_clock::now(); + double *U = new double[m*k]{}; + double *VT = new double[k*n]{}; + double *qb_work = new double[std::max(m, n)]; + RandBLAS::RNGState state(0); + qb_decompose_sparse_matrix(mat_sparse, k, U, VT, 2, state, qb_work, std::max(m,n)); + double *svals = new double[std::min(m,n)]; + double *conversion_work = new double[m*k + k*k]; + qb_to_svd(m, n, k, U, svals, m, VT, k, conversion_work, m*k + k*k); + auto stop_timer = std::chrono::high_resolution_clock::now(); + double runtime = (double) std::chrono::duration_cast(stop_timer - start_timer).count(); + runtime = runtime / 1e6; + + // compute angles between (u_top, v_top) and the top singular vectors + double cos_utopu = blas::dot(m, u_top, 1, U, 1); + double cos_vtopv = blas::dot(n, v_top, 1, VT, k); + double theta_utopu = std::acos(cos_utopu) / (std::numbers::pi); + double theta_vtopv = std::acos(cos_vtopv) / (std::numbers::pi); + + std::cout << "runtime of low-rank approximation : " << DOUT(runtime) << std::endl; + std::cout << "Relative angle between top left singular vectors : " << DOUT(theta_utopu) << std::endl; + std::cout << "Relative angle between top right singular vectors : " << DOUT(theta_vtopv) << std::endl; + + delete [] u_top; + delete [] v_top; + delete [] qb_work; + delete [] conversion_work; + delete [] svals; + delete [] signal_dense; + delete [] noise_dense; + delete [] mat_dense; + return 0; +} diff --git a/examples/total-least-squares/tls_dense_skop.cc b/examples/total-least-squares/tls_dense_skop.cc new file mode 100644 index 00000000..41a9eb7f --- /dev/null +++ b/examples/total-least-squares/tls_dense_skop.cc @@ -0,0 +1,174 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using std::chrono::high_resolution_clock; +using std::chrono::duration_cast; +using std::chrono::duration; 
+using std::chrono::milliseconds; + + +void init_noisy_data(int64_t m, int64_t n, int64_t d, double* AB){ + double target_x[n*d]; + double eps[m*d]; + for (int i = 0; i < n*d; i++) { + target_x[i] = 1; // Target X is the vector of 1's + } + + RandBLAS::DenseDist Dist_A(m,n); + RandBLAS::DenseDist Dist_eps(m,d); + RandBLAS::RNGState state(0); + RandBLAS::RNGState state1(1); + + RandBLAS::fill_dense(Dist_A, AB, state); // Fill A with random Gaussian entries + RandBLAS::fill_dense(Dist_eps, eps, state1); // Fill eps with random Gaussian noise + + blas::gemm(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::NoTrans, m, d, n, 1, AB, m, target_x, n, 0, &AB[m*n], m); + + for (int i = 0; i < m*d; i++){ + AB[m*n + i] += eps[i]; // Add Gaussian noise to the right-hand side + } +} + +template <typename T> +void total_least_squares(int64_t m, int64_t n, T* AB, int64_t ldab, T* x, T* work_s, T* work_vt) { + // AB is m-by-(n+1) and stored in column-major format with leading dimension "ldab". + // Its first n columns contain a matrix "A", and its last column contains a vector "B". + // + // This function overwrites x with the solution to + // (A+E)x = B+R + // where (E, R) solve + // min{ ||[E, R]||_F : B+R in range(A+E) }. + // + // On exit, AB will have been overwritten by its matrix of left singular vectors, + // its singular values will be stored in work_s, and its (transposed) right singular + // vectors will be stored in work_vt. + lapack::gesdd(lapack::Job::OverwriteVec, m, n+1, AB, ldab, work_s, nullptr, 1, work_vt, n+1); + T scale = work_vt[(n+1)*(n+1)-1]; + for (int i = 0; i < n; i++) { + x[i] = -work_vt[n + i*(n+1)] / scale; + } + return; +} + +/* Let A be a tall data matrix of dimensions m by n where m > n and b be a vector of dimension m. + * Ordinary least squares assumes that the error lies only in the right-hand side vector b, + * and it aims to find a vector x that minimizes ||A*x - b||_2. 
+ * On the other hand, total least squares assumes that the input data matrix A may also contain errors. + * Total least squares aims to find a solution where the error is orthogonal to the regression model. + */ + +// To call the executable run ./TLS_DenseSkOp <m> <n>, where <m> and <n> are the number of rows and columns + // of A respectively. We expect m > 2*n. +int main(int argc, char* argv[]){ + + // Initialize dimensions + int64_t m; // Number of rows of A, B + int64_t n; // Number of columns of A + + if (argc == 1) { + m = 10000; + n = 500; + } else if (argc == 3) { + m = atoi(argv[1]); + n = atoi(argv[2]); + if (n > m) { + std::cout << "Make sure the number of rows is greater than the number of columns" << '\n'; + exit(1); + } + } else { + std::cout << "Invalid arguments" << '\n'; + exit(1); + } + + // Define number of rows of the sketching operator + int64_t sk_dim = 2*(n+1); + + // Initialize workspace + double *AB = new double[m*(n+1)]; + double *SAB = new double[sk_dim*(n+1)]; + double *sketch_x = new double[n]; + double *svals = new double[n+1]; + double *VT = new double[(n+1)*(n+1)]; + + // Initialize noisy Gaussian data + init_noisy_data(m, n, 1, AB); + + std::cout << "\nDimensions of the augmented matrix [A|B] : " << m << " by " << n+1 << '\n'; + std::cout << "Embedding dimension : " << sk_dim << '\n'; + + // Sample the sketching operator + auto time_constructsketch1 = high_resolution_clock::now(); + RandBLAS::DenseDist Dist{ sk_dim, m }; + uint32_t seed = 1997; + RandBLAS::DenseSkOp<double> S(Dist, seed); + RandBLAS::fill_dense(S); + auto time_constructsketch2 = high_resolution_clock::now(); + double sampling_time = (double) duration_cast<milliseconds>(time_constructsketch2 - time_constructsketch1).count()/1000; + std::cout << "\nTime to sample S : " << sampling_time << " seconds" << '\n'; + + // Sketch AB + // SAB = 1.0 * S * AB + 0.0 * SAB + auto time_sketch1 = high_resolution_clock::now(); + RandBLAS::sketch_general( + blas::Layout::ColMajor, // Matrix storage layout of AB and SAB + 
blas::Op::NoTrans, // NoTrans => \op(S) = S, Trans => \op(S) = S^T + blas::Op::NoTrans, // NoTrans => \op(AB) = AB, Trans => \op(AB) = AB^T + sk_dim, // Number of rows of S and SAB + n + 1, // Number of columns of AB and SAB + m, // Number of rows of AB and columns of S + 1.0, // Scalar alpha - if alpha is zero AB is not accessed + S, // A DenseSkOp or SparseSkOp + AB, // Matrix to be sketched + m, // Leading dimension of AB + 0.0, // Scalar beta - if beta is zero the initial value of SAB is not accessed + SAB, // Sketched matrix SAB + sk_dim // Leading dimension of SAB + ); + auto time_sketch2 = high_resolution_clock::now(); + double sketching_time = (double) duration_cast<milliseconds>(time_sketch2 - time_sketch1).count()/1000; + std::cout << "Time to compute SAB = S * AB : " << sketching_time << " seconds\n"; + + auto time_sketched_TLS1 = high_resolution_clock::now(); + total_least_squares(sk_dim, n, SAB, sk_dim, sketch_x, svals, VT); + auto time_sketched_TLS2 = high_resolution_clock::now(); + double sketched_solve_time = (double) duration_cast<milliseconds>(time_sketched_TLS2 - time_sketched_TLS1).count()/1000; + std::cout << "Time to perform TLS on sketched data : " << sketched_solve_time << " seconds\n\n"; + + double total_randomized_time = sampling_time + sketching_time + sketched_solve_time; + std::cout << "Total time for the randomized TLS method : " << total_randomized_time << " seconds\n"; + + double* true_x = new double[n]; + auto time_true_TLS1 = high_resolution_clock::now(); + total_least_squares(m, n, AB, m, true_x, svals, VT); + auto time_true_TLS2 = high_resolution_clock::now(); + double true_solve_time = (double) duration_cast<milliseconds>(time_true_TLS2 - time_true_TLS1).count()/1000; + std::cout << "Time for the classical TLS method : " << true_solve_time << " seconds" << "\n"; + + std::cout << "Speedup of sketched vs classical method : " << true_solve_time / total_randomized_time << "\n\n"; + + double* delta = new double[n]; + blas::copy(n, sketch_x, 1, delta, 1); + blas::axpy(n, -1, 
true_x, 1, delta, 1); + double distance = blas::nrm2(n, delta, 1); + double scale = blas::nrm2(n, true_x, 1); + std::cout << "||sketch_x - true_x|| / ||true_x|| : " << distance/scale << "\n\n"; + + delete[] delta; + delete[] true_x; + delete[] AB; + delete[] SAB; + delete[] sketch_x; + delete[] svals; + delete[] VT; + return 0; +} diff --git a/examples/total-least-squares/tls_sparse_skop.cc b/examples/total-least-squares/tls_sparse_skop.cc new file mode 100644 index 00000000..4c92bde6 --- /dev/null +++ b/examples/total-least-squares/tls_sparse_skop.cc @@ -0,0 +1,181 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using std::chrono::high_resolution_clock; +using std::chrono::duration_cast; +using std::chrono::duration; +using std::chrono::milliseconds; + + +//TODO: Have the user choose between dense and sparse sketching operator (4 nnz per col) + +void init_noisy_data(int64_t m, int64_t n, int64_t d, double* AB){ + double target_x[n*d]; + double eps[m*d]; + for (int i = 0; i < n*d; i++) { + target_x[i] = 1; // Target X is the vector of 1's + } + + RandBLAS::DenseDist Dist_A(m,n); + RandBLAS::DenseDist Dist_eps(m,d); + RandBLAS::RNGState state(0); + RandBLAS::RNGState state1(1); + + RandBLAS::fill_dense(Dist_A, AB, state); // Fill A with random Gaussian entries + RandBLAS::fill_dense(Dist_eps, eps, state1); // Fill eps with random Gaussian noise + + blas::gemm(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::NoTrans, m, d, n, 1, AB, m, target_x, n, 0, &AB[m*n], m); + + for (int i = 0; i < m*d; i++){ + AB[m*n + i] += eps[i]; // Add Gaussian noise to the right-hand side + } +} + +template <typename T> +void total_least_squares(int64_t m, int64_t n, T* AB, int64_t ldab, T* x, T* work_s, T* work_vt) { + // AB is m-by-(n+1) and stored in column-major format with leading dimension "ldab". + // Its first n columns contain a matrix "A", and its last column contains a vector "B". 
+ // + // This function overwrites x with the solution to + // (A+E)x = B+R + // where (E, R) solve + // min{ ||[E, R]||_F : B+R in range(A+E) }. + // + // On exit, AB will have been overwritten by its matrix of left singular vectors, + // its singular values will be stored in work_s, and its (transposed) right singular + // vectors will be stored in work_vt. + lapack::gesdd(lapack::Job::OverwriteVec, m, n+1, AB, ldab, work_s, nullptr, 1, work_vt, n+1); + T scale = work_vt[(n+1)*(n+1)-1]; + for (int i = 0; i < n; i++) { + x[i] = -work_vt[n + i*(n+1)] / scale; + } + return; +} + +/* Let A be a tall data matrix of dimensions m by n where m > n and b be a vector of dimension m. + * Ordinary least squares assumes that the error lies only in the right-hand side vector b, + * and it aims to find a vector x that minimizes ||A*x - b||_2. + * On the other hand, total least squares assumes that the input data matrix A may also contain errors. + * Total least squares aims to find a solution where the error is orthogonal to the regression model. + */ + +// To call the executable run ./TLS_DenseSkOp <m> <n>, where <m> and <n> are the number of rows and columns + // of A respectively. We expect m > 2*n. 
+int main(int argc, char* argv[]){ + + // Initialize dimensions + int64_t m; // Number of rows of A, B + int64_t n; // Number of columns of A + + if (argc == 1) { + m = 10000; + n = 500; + } else if (argc == 3) { + m = atoi(argv[1]); + n = atoi(argv[2]); + if (n > m) { + std::cout << "Make sure the number of rows is greater than the number of columns" << '\n'; + exit(1); + } + } else { + std::cout << "Invalid arguments" << '\n'; + exit(1); + } + + // Define number of rows of the sketching operator + int64_t sk_dim = 2*(n+1); + + // Initialize workspace + double *AB = new double[m*(n+1)]; + double *SAB = new double[sk_dim*(n+1)]; + double *sketch_x = new double[n]; + double *svals = new double[n+1]; + double *VT = new double[(n+1)*(n+1)]; + + // Initialize noisy Gaussian data + init_noisy_data(m, n, 1, AB); + + std::cout << "\nDimensions of the augmented matrix [A|B] : " << m << " by " << n+1 << '\n'; + std::cout << "Embedding dimension : " << sk_dim << '\n'; + + // Sample the sketching operator + auto time_constructsketch1 = high_resolution_clock::now(); + RandBLAS::SparseDist Dist = { + .n_rows = sk_dim, // Number of rows of the sketching operator + .n_cols = m, // Number of columns of the sketching operator + .vec_nnz = 8, // Number of non-zero entries per major-axis vector + .major_axis = RandBLAS::MajorAxis::Short // A "SASO" (aka SJLT, aka OSNAP, aka generalized CountSketch) + }; + uint32_t seed = 1997; + RandBLAS::SparseSkOp<double> S(Dist, seed); + RandBLAS::fill_sparse(S); + auto time_constructsketch2 = high_resolution_clock::now(); + double sampling_time = (double) duration_cast<milliseconds>(time_constructsketch2 - time_constructsketch1).count()/1000; + std::cout << "\nTime to sample S : " << sampling_time << " seconds" << '\n'; + + // Sketch AB + // SAB = 1.0 * S * AB + 0.0 * SAB + auto time_sketch1 = high_resolution_clock::now(); + RandBLAS::sketch_general( + blas::Layout::ColMajor, // Matrix storage layout of AB and SAB + blas::Op::NoTrans, // NoTrans => \op(S) = S, Trans => \op(S) = 
S^T + blas::Op::NoTrans, // NoTrans => \op(AB) = AB, Trans => \op(AB) = AB^T + sk_dim, // Number of rows of S and SAB + n + 1, // Number of columns of AB and SAB + m, // Number of rows of AB and columns of S + 1.0, // Scalar alpha - if alpha is zero AB is not accessed + S, // A DenseSkOp or SparseSkOp + AB, // Matrix to be sketched + m, // Leading dimension of AB + 0.0, // Scalar beta - if beta is zero the initial value of SAB is not accessed + SAB, // Sketched matrix SAB + sk_dim // Leading dimension of SAB + ); + auto time_sketch2 = high_resolution_clock::now(); + double sketching_time = (double) duration_cast<milliseconds>(time_sketch2 - time_sketch1).count()/1000; + std::cout << "Time to compute SAB = S * AB : " << sketching_time << " seconds\n"; + + auto time_sketched_TLS1 = high_resolution_clock::now(); + total_least_squares(sk_dim, n, SAB, sk_dim, sketch_x, svals, VT); + auto time_sketched_TLS2 = high_resolution_clock::now(); + double sketched_solve_time = (double) duration_cast<milliseconds>(time_sketched_TLS2 - time_sketched_TLS1).count()/1000; + std::cout << "Time to perform TLS on sketched data : " << sketched_solve_time << " seconds\n\n"; + + double total_randomized_time = sampling_time + sketching_time + sketched_solve_time; + std::cout << "Total time for the randomized TLS method : " << total_randomized_time << " seconds\n"; + + double* true_x = new double[n]; + auto time_true_TLS1 = high_resolution_clock::now(); + total_least_squares(m, n, AB, m, true_x, svals, VT); + auto time_true_TLS2 = high_resolution_clock::now(); + double true_solve_time = (double) duration_cast<milliseconds>(time_true_TLS2 - time_true_TLS1).count()/1000; + std::cout << "Time for the classical TLS method : " << true_solve_time << " seconds" << "\n"; + + std::cout << "Speedup of sketched vs classical method : " << true_solve_time / total_randomized_time << "\n\n"; + + double* delta = new double[n]; + blas::copy(n, sketch_x, 1, delta, 1); + blas::axpy(n, -1, true_x, 1, delta, 1); + double distance = blas::nrm2(n, delta, 
1); + double scale = blas::nrm2(n, true_x, 1); + std::cout << "||sketch_x - true_x|| / ||true_x|| : " << distance/scale << "\n\n"; + + delete[] delta; + delete[] true_x; + delete[] AB; + delete[] SAB; + delete[] sketch_x; + delete[] svals; + delete[] VT; + return 0; +}
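
Review note: `total_least_squares` in the two files above reads its solution off the right singular vector of `[A | B]` that belongs to the smallest singular value. The sketch below illustrates that same idea for the special case n = 1, where the SVD reduces to a closed-form 2-by-2 symmetric eigenproblem and no LAPACK call is needed. The helper `tls_1d` is hypothetical; it is not part of RandBLAS or of this patch, and it assumes the degenerate case (last component of the singular vector equal to zero) does not occur.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not a RandBLAS function): total least squares for a
// single unknown x in a*x ~ b, allowing errors in both a and b.
//
// The right singular vector of the m-by-2 matrix [a b] for its smallest
// singular value is the eigenvector of the Gram matrix
//     G = [ a.a  a.b ]
//         [ a.b  b.b ]
// for its smallest eigenvalue. With v = (v0, v1), the TLS solution is
// x = -v0 / v1, mirroring "x[i] = -work_vt[...] / scale" in the patch.
double tls_1d(const std::vector<double>& a, const std::vector<double>& b) {
    assert(a.size() == b.size());
    double saa = 0.0, sab = 0.0, sbb = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        saa += a[i] * a[i];
        sab += a[i] * b[i];
        sbb += b[i] * b[i];
    }
    // Smallest eigenvalue of the symmetric 2-by-2 matrix G.
    double half_trace = 0.5 * (saa + sbb);
    double half_diff  = 0.5 * (saa - sbb);
    double lambda_min = half_trace - std::sqrt(half_diff * half_diff + sab * sab);
    // (G - lambda_min*I) v = 0 is satisfied by v = (sab, lambda_min - saa).
    double v0 = sab;
    double v1 = lambda_min - saa;
    // Assumption: v1 != 0. If v1 == 0 the TLS problem has no finite solution.
    return -v0 / v1;
}
```

For example, `tls_1d({1, 2, 3}, {2, 4, 6})` returns 2, since the data lie exactly on the line b = 2a. Swapping the two arguments returns the reciprocal slope, which reflects the fact that total least squares treats errors in `a` and `b` symmetrically.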