From ec4e6f858a4795b9eb16e06088e2193d599ae22c Mon Sep 17 00:00:00 2001 From: Riley Murray Date: Thu, 27 Jun 2024 18:45:38 -0400 Subject: [PATCH] Restructure files (#100) --- RandBLAS/DevNotes.md | 7 +- RandBLAS/dense_skops.hh | 375 ++++++++------- RandBLAS/skge.hh | 625 ++++++++++++++++++++++++- RandBLAS/skge3_to_gemm.hh | 368 --------------- RandBLAS/skges_to_spmm.hh | 330 ------------- RandBLAS/sparse_data/DevNotes.md | 5 +- RandBLAS/sparse_data/coo_spmm_impl.hh | 3 + RandBLAS/sparse_data/csc_spmm_impl.hh | 3 + RandBLAS/sparse_data/csr_spmm_impl.hh | 3 + RandBLAS/sparse_data/sksp.hh | 327 ++++++++++++- RandBLAS/sparse_data/sksp3_to_spmm.hh | 368 --------------- RandBLAS/sparse_skops.hh | 3 - RandBLAS/util.hh | 13 - test/test_matmul_cores/linop_common.hh | 3 +- 14 files changed, 1137 insertions(+), 1296 deletions(-) delete mode 100644 RandBLAS/skge3_to_gemm.hh delete mode 100644 RandBLAS/skges_to_spmm.hh delete mode 100644 RandBLAS/sparse_data/sksp3_to_spmm.hh diff --git a/RandBLAS/DevNotes.md b/RandBLAS/DevNotes.md index d4e6cf61..2db03d44 100644 --- a/RandBLAS/DevNotes.md +++ b/RandBLAS/DevNotes.md @@ -23,11 +23,8 @@ for our user guide. * The ``sketch_general`` functions in ``RandBLAS/skge.hh`` are the main entry point for sketching dense data. These functions are small wrappers around functions with more BLAS-like names: - * ``lskge3`` and ``rskge3`` in ``RandBLAS/skge3_to_gemm.hh``. - * ``lskges`` and ``rskges`` in ``RandBLAS/skges_to_spmm.hh``. - The former pair of functions are just fancy wrappers around GEMM. - The latter pair of functions trigger a far more opaque call sequence, since they rely on sparse - matrix operations. + * ``lskge3`` and ``rskge3`` are basically wrappers around GEMM. + * ``lskges`` and ``rskges`` trigger an opaque call sequence that uses sparse matrix operations. * There is no widely accepted standard for sparse BLAS operations. This is a bummer because sparse matrices are super important in data science and scientific computing. In view of this, diff --git a/RandBLAS/dense_skops.hh b/RandBLAS/dense_skops.hh index 08b6222e..c7e60e90 100644 --- a/RandBLAS/dense_skops.hh +++ b/RandBLAS/dense_skops.hh @@ -46,12 +46,192 @@ #include +namespace RandBLAS::dense { + +template +bool compare_ctr(typename RNG::ctr_type c1, typename RNG::ctr_type c2) { + int len = c1.size(); + + for (int ind = len - 1; ind >= 0; ind--) { + if (c1[ind] > c2[ind]) { + return true; + } else if (c1[ind] < c2[ind]) { + return false; + } + } + return false; +} + +/** + * Fill buff with random values so it gives a row-major representation of an n_srows \math{\times} n_scols + * submatrix of some implicitly defined parent matrix. + * + * The implicit parent matrix is **imagined** as a buffer in row-major order with "n_cols" columns. + * "ptr" is the pointer offset for the desired submatrix in the imagined buffer of the parent matrix. + * + * @tparam T the data type of the matrix + * @tparam RNG a random123 CBRNG type + * @tparam OP an operator that transforms raw random values into matrix + * elements. See r123ext::uneg11 and r123ext::boxmul. + * + * @param[in] n_cols + * The number of columns in the implicitly defined parent matrix. + * @param[in] smat + * A pointer to a region of memory with space for n_rows \math{\times} lda elements of type T. + * This memory will be filled with random values by wrting rows of length "n_scols" + * with an inter-row stride of length "lda". + * @param[in] n_srows + * The number of rows in the submatrix. + * @param[in] n_scols + * The number of columns in the submatrix. + * @param[in] ptr + * The starting locaiton within the random matrix, for which + * the submatrix is to be generated + * @param[in] seed + * A CBRNG state + * @param[in] lda + * If positive then must be >= n_scols. + * Otherwise, we automatically set it to n_scols. + * + * @returns the updated CBRNG state + * + * Notes + * ----- + * If RandBLAS is compiled with OpenMP threading support enabled, the operation is parallelized + * using OMP_NUM_THREADS. The sequence of values generated does not depend on the number of threads. + * + */ +template +static RNGState fill_dense_submat_impl( + int64_t n_cols, + T* smat, + int64_t n_srows, + int64_t n_scols, + int64_t ptr, + const RNGState & seed, + int64_t lda = 0 +) { + if (lda <= 0) { + lda = n_scols; + } else { + randblas_require(lda >= n_scols); + } + randblas_require(n_cols >= n_scols); + RNG rng; + using CTR_t = typename RNG::ctr_type; + using KEY_t = typename RNG::key_type; + CTR_t c = seed.counter; + KEY_t k = seed.key; + + int64_t pad = 0; + // ^ computed such that n_cols+pad is divisible by RNG::static_size + if (n_cols % CTR_t::static_size != 0) { + pad = CTR_t::static_size - n_cols % CTR_t::static_size; + } + + int64_t n_cols_padded = n_cols + pad; + // ^ smallest number of columns, greater than or equal to n_cols, that would be divisible by CTR_t::static_size + int64_t ptr_padded = ptr + ptr / n_cols * pad; + // ^ ptr corresponding to the padded matrix + int64_t r0_padded = ptr_padded / CTR_t::static_size; + // ^ starting counter corresponding to ptr_padded + int64_t r1_padded = (ptr_padded + n_scols - 1) / CTR_t::static_size; + // ^ ending counter corresponding to ptr of the last element of the row + int64_t ctr_gap = n_cols_padded / CTR_t::static_size; + // ^ number of counters between the first counter of the row to the first counter of the next row; + int64_t s0 = ptr_padded % CTR_t::static_size; + int64_t e1 = (ptr_padded + n_scols - 1) % CTR_t::static_size; + + int64_t num_thrds = 1; + #if defined(RandBLAS_HAS_OpenMP) + #pragma omp parallel + { + num_thrds = omp_get_num_threads(); + } + #endif + + //Instead of using thrd_arr just initialize ctr_arr to be zero counters; + CTR_t *ctr_arr = new CTR_t[num_thrds]; + for (int i = 0; i < num_thrds; i++) { + ctr_arr[i] = c; + } + + #pragma omp parallel firstprivate(c, k) + { + + auto cc = c; + int64_t prev = 0; + int64_t i; + int64_t r0, r1; + int64_t ind; + int64_t thrd = 0; + + #pragma omp for + for (int row = 0; row < n_srows; row++) { + + #if defined(RandBLAS_HAS_OpenMP) + thrd = omp_get_thread_num(); + #endif + + ind = 0; + r0 = r0_padded + ctr_gap*row; + r1 = r1_padded + ctr_gap*row; + + cc.incr(r0 - prev); + prev = r0; + auto rv = OP::generate(rng, cc, k); + int64_t range = (r1 > r0)? CTR_t::static_size - 1 : e1; + for (i = s0; i <= range; i++) { + smat[ind + row * lda] = rv[i]; + ind++; + } + // middle + int64_t tmp = r0; + while( tmp < r1 - 1) { + cc.incr(); + prev++; + rv = OP::generate(rng, cc, k); + for (i = 0; i < CTR_t::static_size; i++) { + smat[ind + row * lda] = rv[i]; + ind++; + } + tmp++; + } + + // end + if ( r1 > r0 ) { + cc.incr(); + prev++; + rv = OP::generate(rng, cc, k); + for (i = 0; i <= e1; i++) { + smat[ind + row * lda] = rv[i]; + ind++; + } + } + ctr_arr[thrd] = cc; + } + + } + + //finds the largest counter in the counter array + CTR_t max_c = ctr_arr[0]; + for (int i = 1; i < num_thrds; i++) { + if (compare_ctr(ctr_arr[i], max_c)) { + max_c = ctr_arr[i]; + } + } + delete [] ctr_arr; + + max_c.incr(); + return RNGState {max_c, k}; +} + +} // end namespace RandBLAS::dense + + namespace RandBLAS { + // ============================================================================= -/// We call a sketching operator "dense" if (1) it is naturally represented with a -/// buffer and (2) the natural way to apply that operator to a matrix is -/// to use the operator's buffer in GEMM. -/// /// We support two distributions for dense sketching operators: those whose /// entries are iid Gaussians or iid uniform over a symmetric interval. /// For implementation reasons, we also expose an option to indicate that an @@ -287,192 +467,6 @@ DenseSkOp::~DenseSkOp() { } } -} // end namespace RandBLAS (will continue later in this file) - -namespace RandBLAS::dense { - -template -bool compare_ctr(typename RNG::ctr_type c1, typename RNG::ctr_type c2) { - int len = c1.size(); - - for (int ind = len - 1; ind >= 0; ind--) { - if (c1[ind] > c2[ind]) { - return true; - } else if (c1[ind] < c2[ind]) { - return false; - } - } - return false; -} - -/** - * Fill buff with random values so it gives a row-major representation of an n_srows \math{\times} n_scols - * submatrix of some implicitly defined parent matrix. - * - * The implicit parent matrix is **imagined** as a buffer in row-major order with "n_cols" columns. - * "ptr" is the pointer offset for the desired submatrix in the imagined buffer of the parent matrix. - * - * @tparam T the data type of the matrix - * @tparam RNG a random123 CBRNG type - * @tparam OP an operator that transforms raw random values into matrix - * elements. See r123ext::uneg11 and r123ext::boxmul. - * - * @param[in] n_cols - * The number of columns in the implicitly defined parent matrix. - * @param[in] smat - * A pointer to a region of memory with space for n_rows \math{\times} lda elements of type T. - * This memory will be filled with random values by wrting rows of length "n_scols" - * with an inter-row stride of length "lda". - * @param[in] n_srows - * The number of rows in the submatrix. - * @param[in] n_scols - * The number of columns in the submatrix. - * @param[in] ptr - * The starting locaiton within the random matrix, for which - * the submatrix is to be generated - * @param[in] seed - * A CBRNG state - * @param[in] lda - * If positive then must be >= n_scols. - * Otherwise, we automatically set it to n_scols. - * - * @returns the updated CBRNG state - * - * Notes - * ----- - * If RandBLAS is compiled with OpenMP threading support enabled, the operation is parallelized - * using OMP_NUM_THREADS. The sequence of values generated does not depend on the number of threads. - * - */ -template -static RNGState fill_dense_submat_impl( - int64_t n_cols, - T* smat, - int64_t n_srows, - int64_t n_scols, - int64_t ptr, - const RNGState & seed, - int64_t lda = 0 -) { - if (lda <= 0) { - lda = n_scols; - } else { - randblas_require(lda >= n_scols); - } - randblas_require(n_cols >= n_scols); - RNG rng; - using CTR_t = typename RNG::ctr_type; - using KEY_t = typename RNG::key_type; - CTR_t c = seed.counter; - KEY_t k = seed.key; - - int64_t pad = 0; - // ^ computed such that n_cols+pad is divisible by RNG::static_size - if (n_cols % CTR_t::static_size != 0) { - pad = CTR_t::static_size - n_cols % CTR_t::static_size; - } - - int64_t n_cols_padded = n_cols + pad; - // ^ smallest number of columns, greater than or equal to n_cols, that would be divisible by CTR_t::static_size - int64_t ptr_padded = ptr + ptr / n_cols * pad; - // ^ ptr corresponding to the padded matrix - int64_t r0_padded = ptr_padded / CTR_t::static_size; - // ^ starting counter corresponding to ptr_padded - int64_t r1_padded = (ptr_padded + n_scols - 1) / CTR_t::static_size; - // ^ ending counter corresponding to ptr of the last element of the row - int64_t ctr_gap = n_cols_padded / CTR_t::static_size; - // ^ number of counters between the first counter of the row to the first counter of the next row; - int64_t s0 = ptr_padded % CTR_t::static_size; - int64_t e1 = (ptr_padded + n_scols - 1) % CTR_t::static_size; - - int64_t num_thrds = 1; - #if defined(RandBLAS_HAS_OpenMP) - #pragma omp parallel - { - num_thrds = omp_get_num_threads(); - } - #endif - - //Instead of using thrd_arr just initialize ctr_arr to be zero counters; - CTR_t *ctr_arr = new CTR_t[num_thrds]; - for (int i = 0; i < num_thrds; i++) { - ctr_arr[i] = c; - } - - #pragma omp parallel firstprivate(c, k) - { - - auto cc = c; - int64_t prev = 0; - int64_t i; - int64_t r0, r1; - int64_t ind; - int64_t thrd = 0; - - #pragma omp for - for (int row = 0; row < n_srows; row++) { - - #if defined(RandBLAS_HAS_OpenMP) - thrd = omp_get_thread_num(); - #endif - - ind = 0; - r0 = r0_padded + ctr_gap*row; - r1 = r1_padded + ctr_gap*row; - - cc.incr(r0 - prev); - prev = r0; - auto rv = OP::generate(rng, cc, k); - int64_t range = (r1 > r0)? CTR_t::static_size - 1 : e1; - for (i = s0; i <= range; i++) { - smat[ind + row * lda] = rv[i]; - ind++; - } - // middle - int64_t tmp = r0; - while( tmp < r1 - 1) { - cc.incr(); - prev++; - rv = OP::generate(rng, cc, k); - for (i = 0; i < CTR_t::static_size; i++) { - smat[ind + row * lda] = rv[i]; - ind++; - } - tmp++; - } - - // end - if ( r1 > r0 ) { - cc.incr(); - prev++; - rv = OP::generate(rng, cc, k); - for (i = 0; i <= e1; i++) { - smat[ind + row * lda] = rv[i]; - ind++; - } - } - ctr_arr[thrd] = cc; - } - - } - - //finds the largest counter in the counter array - CTR_t max_c = ctr_arr[0]; - for (int i = 1; i < num_thrds; i++) { - if (compare_ctr(ctr_arr[i], max_c)) { - max_c = ctr_arr[i]; - } - } - delete [] ctr_arr; - - max_c.incr(); - return RNGState {max_c, k}; -} - -} // end namespace RandBLAS::dense - -namespace RandBLAS { - // ============================================================================= /// @verbatim embed:rst:leading-slashes /// @@ -644,6 +638,7 @@ RNGState fill_dense( S.del_buff_on_destruct = true; return next_state; } + } // end namespace RandBLAS #endif \ No newline at end of file diff --git a/RandBLAS/skge.hh b/RandBLAS/skge.hh index 5ec0b8d4..647e02dd 100644 --- a/RandBLAS/skge.hh +++ b/RandBLAS/skge.hh @@ -35,8 +35,6 @@ #include "RandBLAS/random_gen.hh" #include "RandBLAS/dense_skops.hh" #include "RandBLAS/sparse_skops.hh" -#include "RandBLAS/skge3_to_gemm.hh" -#include "RandBLAS/skges_to_spmm.hh" #include #include @@ -46,12 +44,6 @@ #include #include -namespace RandBLAS { - -using namespace RandBLAS::dense; -using namespace RandBLAS::sparse; - - /* Intended macro definitions. .. |op| mathmacro:: \operatorname{op} @@ -63,8 +55,617 @@ using namespace RandBLAS::sparse; .. |opS| mathmacro:: \texttt{opS} */ +namespace RandBLAS::dense { + +using RandBLAS::DenseSkOp; +using RandBLAS::fill_dense; + +// MARK: LSKGE3 + +// ============================================================================= +/// @verbatim embed:rst:leading-slashes +/// +/// .. |op| mathmacro:: \operatorname{op} +/// .. |mat| mathmacro:: \operatorname{mat} +/// .. |submat| mathmacro:: \operatorname{submat} +/// .. |lda| mathmacro:: \mathrm{lda} +/// .. |ldb| mathmacro:: \mathrm{ldb} +/// .. |opA| mathmacro:: \mathrm{opA} +/// .. |opS| mathmacro:: \mathrm{opS} +/// +/// @endverbatim +/// LSKGE3: Perform a GEMM-like operation +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(B) = \alpha \cdot \underbrace{\op(\submat(S))}_{d \times m} \cdot \underbrace{\op(\mat(A))}_{m \times n} + \beta \cdot \underbrace{\mat(B)}_{d \times n}, \tag{$\star$} +/// @endverbatim +/// where \math{\alpha} and \math{\beta} are real scalars, \math{\op(X)} either returns a matrix \math{X} +/// or its transpose, and \math{S} is a sketching operator that takes Level 3 BLAS effort to apply. +/// +/// @verbatim embed:rst:leading-slashes +/// What are :math:`\mat(A)` and :math:`\mat(B)`? +/// Their shapes are defined implicitly by :math:`(d, m, n, \opA)`. +/// Their precise contents are determined by :math:`(A, \lda)`, :math:`(B, \ldb)`, +/// and "layout", following the same convention as BLAS. +/// +/// What is :math:`\submat(S)`? +/// Its shape is defined implicitly by :math:`(\opS, d, m)`. +/// If :math:`{\submat(S)}` is of shape :math:`r \times c`, +/// then it is the :math:`r \times c` submatrix of :math:`{S}` whose upper-left corner +/// appears at index :math:`(\texttt{ro_s}, \texttt{co_s})` of :math:`{S}`. +/// @endverbatim +/// @param[in] layout +/// Layout::ColMajor or Layout::RowMajor +/// - Matrix storage for \math{\mat(A)} and \math{\mat(B)}. +/// +/// @param[in] opS +/// - If \math{\opS} = NoTrans, then \math{ \op(\submat(S)) = \submat(S)}. +/// - If \math{\opS} = Trans, then \math{\op(\submat(S)) = \submat(S)^T }. +/// @param[in] opA +/// - If \math{\opA} == NoTrans, then \math{\op(\mat(A)) = \mat(A)}. +/// - If \math{\opA} == Trans, then \math{\op(\mat(A)) = \mat(A)^T}. +/// @param[in] d +/// A nonnegative integer. +/// - The number of rows in \math{\mat(B)} +/// - The number of rows in \math{\op(\mat(S))}. +/// +/// @param[in] n +/// A nonnegative integer. +/// - The number of columns in \math{\mat(B)} +/// - The number of columns in \math{\op(\mat(A))}. +/// +/// @param[in] m +/// A nonnegative integer. +/// - The number of columns in \math{\op(\submat(S))} +/// - The number of rows in \math{\op(\mat(A))}. +/// +/// @param[in] alpha +/// A real scalar. +/// - If zero, then \math{A} is not accessed. +/// +/// @param[in] S +/// A DenseSkOp object. +/// - Defines \math{\submat(S)}. +/// +/// @param[in] ro_s +/// A nonnegative integer. +/// - The rows of \math{\submat(S)} are a contiguous subset of rows of \math{S}. +/// - The rows of \math{\submat(S)} start at \math{S[\texttt{ro_s}, :]}. +/// +/// @param[in] co_s +/// A nonnnegative integer. +/// - The columns of \math{\submat(S)} are a contiguous subset of columns of \math{S}. +/// - The columns \math{\submat(S)} start at \math{S[:,\texttt{co_s}]}. +/// +/// @param[in] A +/// Pointer to a 1D array of real scalars. +/// - Defines \math{\mat(A)}. +/// +/// @param[in] lda +/// A nonnegative integer. +/// * Leading dimension of \math{\mat(A)} when reading from \math{A}. +/// * If layout == ColMajor, then +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(A)[i, j] = A[i + j \cdot \lda]. +/// @endverbatim +/// In this case, \math{\lda} must be \math{\geq} the length of a column in \math{\mat(A)}. +/// * If layout == RowMajor, then +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(A)[i, j] = A[i \cdot \lda + j]. +/// @endverbatim +/// In this case, \math{\lda} must be \math{\geq} the length of a row in \math{\mat(A)}. +/// +/// @param[in] beta +/// A real scalar. +/// - If zero, then \math{B} need not be set on input. +/// +/// @param[in, out] B +/// Pointer to 1D array of real scalars. +/// - On entry, defines \math{\mat(B)} +/// on the RIGHT-hand side of \math{(\star)}. +/// - On exit, defines \math{\mat(B)} +/// on the LEFT-hand side of \math{(\star)}. +/// +/// @param[in] ldb +/// - Leading dimension of \math{\mat(B)} when reading from \math{B}. +/// - Refer to documentation for \math{\lda} for details. +/// +template +void lskge3( + blas::Layout layout, + blas::Op opS, + blas::Op opA, + int64_t d, // B is d-by-n + int64_t n, // op(A) is m-by-n + int64_t m, // op(S) is d-by-m + T alpha, + DenseSkOp &S, + int64_t ro_s, + int64_t co_s, + const T *A, + int64_t lda, + T beta, + T *B, + int64_t ldb +){ + auto [rows_submat_S, cols_submat_S] = dims_before_op(d, m, opS); + if (!S.buff) { + // We'll make a shallow copy of the sketching operator, take responsibility for filling the memory + // of that sketching operator, and then call LSKGE3 with that new object. + T *buff = new T[rows_submat_S * cols_submat_S]; + fill_dense(S.dist, rows_submat_S, cols_submat_S, ro_s, co_s, buff, S.seed_state); + DenseDist D{rows_submat_S, cols_submat_S, DenseDistName::BlackBox, S.dist.major_axis}; + DenseSkOp S_(D, S.seed_state, buff); + lskge3(layout, opS, opA, d, n, m, alpha, S_, 0, 0, A, lda, beta, B, ldb); + delete [] buff; + return; + } + randblas_require( S.dist.n_rows >= rows_submat_S + ro_s ); + randblas_require( S.dist.n_cols >= cols_submat_S + co_s ); + auto [rows_A, cols_A] = dims_before_op(m, n, opA); + if (layout == blas::Layout::ColMajor) { + randblas_require(lda >= rows_A); + randblas_require(ldb >= d); + } else { + randblas_require(lda >= cols_A); + randblas_require(ldb >= n); + } + + auto [pos, lds] = offset_and_ldim(S.layout, S.dist.n_rows, S.dist.n_cols, ro_s, co_s); + T* S_ptr = &S.buff[pos]; + if (S.layout != layout) + opS = (opS == blas::Op::NoTrans) ? blas::Op::Trans : blas::Op::NoTrans; -// MARK: SUBMAT(S), LEFT + blas::gemm(layout, opS, opA, d, n, m, alpha, S_ptr, lds, A, lda, beta, B, ldb); + return; +} + +// MARK: RSKGE3 + +// ============================================================================= +/// RSKGE3: Perform a GEMM-like operation +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(B) = \alpha \cdot \underbrace{\op(\mat(A))}_{m \times n} \cdot \underbrace{\op(\submat(S))}_{n \times d} + \beta \cdot \underbrace{\mat(B)}_{m \times d}, \tag{$\star$} +/// @endverbatim +/// where \math{\alpha} and \math{\beta} are real scalars, \math{\op(X)} either returns a matrix \math{X} +/// or its transpose, and \math{S} is a sketching operator that takes Level 3 BLAS effort to apply. +/// +/// @verbatim embed:rst:leading-slashes +/// What are :math:`\mat(A)` and :math:`\mat(B)`? +/// Their shapes are defined implicitly by :math:`(m, d, n, \opA)`. +/// Their precise contents are determined by :math:`(A, \lda)`, :math:`(B, \ldb)`, +/// and "layout", following the same convention as BLAS. +/// +/// What is :math:`\submat(S)`? +/// Its shape is defined implicitly by :math:`(\opS, n, d)`. +/// If :math:`{\submat(S)}` is of shape :math:`r \times c`, +/// then it is the :math:`r \times c` submatrix of :math:`{S}` whose upper-left corner +/// appears at index :math:`(\texttt{ro_s}, \texttt{co_s})` of :math:`{S}`. +/// @endverbatim +/// @param[in] layout +/// Layout::ColMajor or Layout::RowMajor +/// - Matrix storage for \math{\mat(A)} and \math{\mat(B)}. +/// +/// @param[in] opA +/// - If \math{\opA} == NoTrans, then \math{\op(\mat(A)) = \mat(A)}. +/// - If \math{\opA} == Trans, then \math{\op(\mat(A)) = \mat(A)^T}. +/// +/// @param[in] opS +/// - If \math{\opS} = NoTrans, then \math{ \op(\submat(S)) = \submat(S)}. +/// - If \math{\opS} = Trans, then \math{\op(\submat(S)) = \submat(S)^T }. +/// +/// @param[in] m +/// A nonnegative integer. +/// - The number of rows in \math{\mat(B)}. +/// - The number of rows in \math{\op(\mat(A))}. +/// +/// @param[in] d +/// A nonnegative integer. +/// - The number of columns in \math{\mat(B)} +/// - The number of columns in \math{\op(\mat(S))}. +/// +/// @param[in] n +/// A nonnegative integer. +/// - The number of columns in \math{\op(\mat(A))} +/// - The number of rows in \math{\op(\submat(S))}. +/// +/// @param[in] alpha +/// A real scalar. +/// - If zero, then \math{A} is not accessed. +/// +/// @param[in] A +/// Pointer to a 1D array of real scalars. +/// - Defines \math{\mat(A)}. +/// +/// @param[in] lda +/// A nonnegative integer. +/// * Leading dimension of \math{\mat(A)} when reading from \math{A}. +/// * If layout == ColMajor, then +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(A)[i, j] = A[i + j \cdot \lda]. +/// @endverbatim +/// In this case, \math{\lda} must be \math{\geq} the length of a column in \math{\mat(A)}. +/// * If layout == RowMajor, then +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(A)[i, j] = A[i \cdot \lda + j]. +/// @endverbatim +/// In this case, \math{\lda} must be \math{\geq} the length of a row in \math{\mat(A)}. +/// +/// @param[in] S +/// A DenseSkOp object. +/// - Defines \math{\submat(S)}. +/// +/// @param[in] ro_s +/// A nonnegative integer. +/// - The rows of \math{\submat(S)} are a contiguous subset of rows of \math{S}. +/// - The rows of \math{\submat(S)} start at \math{S[\texttt{ro_s}, :]}. +/// +/// @param[in] co_s +/// A nonnnegative integer. +/// - The columns of \math{\submat(S)} are a contiguous subset of columns of \math{S}. +/// - The columns \math{\submat(S)} start at \math{S[:,\texttt{co_s}]}. +/// +/// @param[in] beta +/// A real scalar. +/// - If zero, then \math{B} need not be set on input. +/// +/// @param[in, out] B +/// Pointer to 1D array of real scalars. +/// - On entry, defines \math{\mat(B)} +/// on the RIGHT-hand side of \math{(\star)}. +/// - On exit, defines \math{\mat(B)} +/// on the LEFT-hand side of \math{(\star)}. +/// +/// @param[in] ldb +/// - Leading dimension of \math{\mat(B)} when reading from \math{B}. +/// - Refer to documentation for \math{\lda} for details. +/// +template +void rskge3( + blas::Layout layout, + blas::Op opA, + blas::Op opS, + int64_t m, // B is m-by-d + int64_t d, // op(S) is n-by-d + int64_t n, // op(A) is m-by-n + T alpha, + const T *A, + int64_t lda, + DenseSkOp &S, + int64_t ro_s, + int64_t co_s, + T beta, + T *B, + int64_t ldb +){ + auto [rows_submat_S, cols_submat_S] = dims_before_op(n, d, opS); + if (!S.buff) { + // We'll make a shallow copy of the sketching operator, take responsibility for filling the memory + // of that sketching operator, and then call RSKGE3 with that new object. + T *buff = new T[rows_submat_S * cols_submat_S]; + fill_dense(S.dist, rows_submat_S, cols_submat_S, ro_s, co_s, buff, S.seed_state); + DenseDist D{rows_submat_S, cols_submat_S, DenseDistName::BlackBox, S.dist.major_axis}; + DenseSkOp S_(D, S.seed_state, buff); + rskge3(layout, opA, opS, m, d, n, alpha, A, lda, S_, 0, 0, beta, B, ldb); + delete [] buff; + return; + } + randblas_require( S.dist.n_rows >= rows_submat_S + ro_s ); + randblas_require( S.dist.n_cols >= cols_submat_S + co_s ); + auto [rows_A, cols_A] = dims_before_op(m, n, opA); + if (layout == blas::Layout::ColMajor) { + randblas_require(lda >= rows_A); + randblas_require(ldb >= m); + } else { + randblas_require(lda >= cols_A); + randblas_require(ldb >= d); + } + + auto [pos, lds] = offset_and_ldim(S.layout, S.dist.n_rows, S.dist.n_cols, ro_s, co_s); + T* S_ptr = &S.buff[pos]; + if (S.layout != layout) + opS = (opS == blas::Op::NoTrans) ? blas::Op::Trans : blas::Op::NoTrans; + + blas::gemm(layout, opA, opS, m, d, n, alpha, A, lda, S_ptr, lds, beta, B, ldb); + return; +} + +} // end namespace RandBLAS::dense + + +namespace RandBLAS::sparse { + +// MARK: LSKGES + +// ============================================================================= +/// @verbatim embed:rst:leading-slashes +/// +/// .. |op| mathmacro:: \operatorname{op} +/// .. |mat| mathmacro:: \operatorname{mat} +/// .. |submat| mathmacro:: \operatorname{submat} +/// .. |lda| mathmacro:: \mathrm{lda} +/// .. |ldb| mathmacro:: \mathrm{ldb} +/// .. |opA| mathmacro:: \mathrm{opA} +/// .. |opS| mathmacro:: \mathrm{opS} +/// +/// @endverbatim +/// LSKGES: Perform a GEMM-like operation +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(B) = \alpha \cdot \underbrace{\op(\submat(S))}_{d \times m} \cdot \underbrace{\op(\mat(A))}_{m \times n} + \beta \cdot \underbrace{\mat(B)}_{d \times n}, \tag{$\star$} +/// @endverbatim +/// where \math{\alpha} and \math{\beta} are real scalars, \math{\op(X)} either returns a matrix \math{X} +/// or its transpose, and \math{S} is a sparse sketching operator. +/// +/// @verbatim embed:rst:leading-slashes +/// What are :math:`\mat(A)` and :math:`\mat(B)`? +/// Their shapes are defined implicitly by :math:`(d, m, n, \opA)`. +/// Their precise contents are determined by :math:`(A, \lda)`, :math:`(B, \ldb)`, +/// and "layout", following the same convention as BLAS. +/// +/// What is :math:`\submat(S)`? +/// Its shape is defined implicitly by :math:`(\opS, d, m)`. +/// If :math:`{\submat(S)}` is of shape :math:`r \times c`, +/// then it is the :math:`r \times c` submatrix of :math:`{S}` whose upper-left corner +/// appears at index :math:`(\texttt{ro_s}, \texttt{co_s})` of :math:`{S}`. +/// @endverbatim +/// @param[in] layout +/// Layout::ColMajor or Layout::RowMajor +/// - Matrix storage for \math{\mat(A)} and \math{\mat(B)}. +/// +/// @param[in] opS +/// - If \math{\opS} = NoTrans, then \math{ \op(\submat(S)) = \submat(S)}. +/// - If \math{\opS} = Trans, then \math{\op(\submat(S)) = \submat(S)^T }. +/// +/// @param[in] opA +/// - If \math{\opA} == NoTrans, then \math{\op(\mat(A)) = \mat(A)}. +/// - If \math{\opA} == Trans, then \math{\op(\mat(A)) = \mat(A)^T}. +/// +/// @param[in] d +/// A nonnegative integer. +/// - The number of rows in \math{\mat(B)} +/// - The number of rows in \math{\op(\mat(S))}. +/// +/// @param[in] n +/// A nonnegative integer. +/// - The number of columns in \math{\mat(B)} +/// - The number of columns in \math{\op(\mat(A))}. +/// +/// @param[in] m +/// A nonnegative integer. +/// - The number of columns in \math{\op(\submat(S))} +/// - The number of rows in \math{\op(\mat(A))}. +/// +/// @param[in] alpha +/// A real scalar. +/// - If zero, then \math{A} is not accessed. +/// +/// @param[in] S +/// A SparseSkOp object. +/// - Defines \math{\submat(S)}. +/// +/// @param[in] ro_s +/// A nonnegative integer. +/// - The rows of \math{\submat(S)} are a contiguous subset of rows of \math{S}. +/// - The rows of \math{\submat(S)} start at \math{S[\texttt{ro_s}, :]}. +/// +/// @param[in] co_s +/// A nonnnegative integer. +/// - The columns of \math{\submat(S)} are a contiguous subset of columns of \math{S}. +/// - The columns \math{\submat(S)} start at \math{S[:,\texttt{co_s}]}. +/// +/// @param[in] A +/// Pointer to a 1D array of real scalars. +/// - Defines \math{\mat(A)}. +/// +/// @param[in] lda +/// A nonnegative integer. +/// * Leading dimension of \math{\mat(A)} when reading from \math{A}. +/// * If layout == ColMajor, then +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(A)[i, j] = A[i + j \cdot \lda]. +/// @endverbatim +/// In this case, \math{\lda} must be \math{\geq} the length of a column in \math{\mat(A)}. +/// * If layout == RowMajor, then +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(A)[i, j] = A[i \cdot \lda + j]. +/// @endverbatim +/// In this case, \math{\lda} must be \math{\geq} the length of a row in \math{\mat(A)}. +/// +/// @param[in] beta +/// A real scalar. +/// - If zero, then \math{B} need not be set on input. +/// +/// @param[in, out] B +/// Pointer to 1D array of real scalars. +/// - On entry, defines \math{\mat(B)} +/// on the RIGHT-hand side of \math{(\star)}. +/// - On exit, defines \math{\mat(B)} +/// on the LEFT-hand side of \math{(\star)}. +/// +/// @param[in] ldb +/// - Leading dimension of \math{\mat(B)} when reading from \math{B}. +/// - Refer to documentation for \math{\lda} for details. +/// +template +inline void lskges( + blas::Layout layout, + blas::Op opS, + blas::Op opA, + int64_t d, // B is d-by-n + int64_t n, // \op(A) is m-by-n + int64_t m, // \op(S) is d-by-m + T alpha, + SKOP &S, + int64_t ro_s, + int64_t co_s, + const T *A, + int64_t lda, + T beta, + T *B, + int64_t ldb +) { + if (!S.known_filled) + fill_sparse(S); + using RNG = typename SKOP::RNG_t; + using sint_t = typename SKOP::index_t; + auto Scoo = coo_view_of_skop(S); + left_spmm( + layout, opS, opA, d, n, m, alpha, Scoo, ro_s, co_s, + A, lda, beta, B, ldb + ); + return; +} + + +// MARK: RSKGES + +// ============================================================================= +/// RSKGES: Perform a GEMM-like operation +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(B) = \alpha \cdot \underbrace{\op(\mat(A))}_{m \times n} \cdot \underbrace{\op(\submat(S))}_{n \times d} + \beta \cdot \underbrace{\mat(B)}_{m \times d}, \tag{$\star$} +/// @endverbatim +/// where \math{\alpha} and \math{\beta} are real scalars, \math{\op(X)} either returns a matrix \math{X} +/// or its transpose, and \math{S} is a sparse sketching operator. +/// +/// @verbatim embed:rst:leading-slashes +/// What are :math:`\mat(A)` and :math:`\mat(B)`? +/// Their shapes are defined implicitly by :math:`(m, d, n, \opA)`. +/// Their precise contents are determined by :math:`(A, \lda)`, :math:`(B, \ldb)`, +/// and "layout", following the same convention as BLAS. +/// +/// What is :math:`\submat(S)`? +/// Its shape is defined implicitly by :math:`(\opS, n, d)`. +/// If :math:`{\submat(S)}` is of shape :math:`r \times c`, +/// then it is the :math:`r \times c` submatrix of :math:`{S}` whose upper-left corner +/// appears at index :math:`(\texttt{ro_s}, \texttt{co_s})` of :math:`{S}`. +/// @endverbatim +/// @param[in] layout +/// Layout::ColMajor or Layout::RowMajor +/// - Matrix storage for \math{\mat(A)} and \math{\mat(B)}. +/// +/// @param[in] opA +/// - If \math{\opA} == NoTrans, then \math{\op(\mat(A)) = \mat(A)}. +/// - If \math{\opA} == Trans, then \math{\op(\mat(A)) = \mat(A)^T}. +/// +/// @param[in] opS +/// - If \math{\opS} = NoTrans, then \math{ \op(\submat(S)) = \submat(S)}. +/// - If \math{\opS} = Trans, then \math{\op(\submat(S)) = \submat(S)^T }. +/// +/// @param[in] m +/// A nonnegative integer. +/// - The number of rows in \math{\mat(B)}. +/// - The number of rows in \math{\op(\mat(A))}. +/// +/// @param[in] d +/// A nonnegative integer. +/// - The number of columns in \math{\mat(B)} +/// - The number of columns in \math{\op(\mat(S))}. +/// +/// @param[in] n +/// A nonnegative integer. +/// - The number of columns in \math{\op(\mat(A))} +/// - The number of rows in \math{\op(\submat(S))}. +/// +/// @param[in] alpha +/// A real scalar. +/// - If zero, then \math{A} is not accessed. +/// +/// @param[in] A +/// Pointer to a 1D array of real scalars. +/// - Defines \math{\mat(A)}. +/// +/// @param[in] lda +/// A nonnegative integer. +/// * Leading dimension of \math{\mat(A)} when reading from \math{A}. +/// * If layout == ColMajor, then +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(A)[i, j] = A[i + j \cdot \lda]. +/// @endverbatim +/// In this case, \math{\lda} must be \math{\geq} the length of a column in \math{\mat(A)}. +/// * If layout == RowMajor, then +/// @verbatim embed:rst:leading-slashes +/// .. math:: +/// \mat(A)[i, j] = A[i \cdot \lda + j]. +/// @endverbatim +/// In this case, \math{\lda} must be \math{\geq} the length of a row in \math{\mat(A)}. +/// +/// @param[in] S +/// A SparseSkOp object. +/// - Defines \math{\submat(S)}. +/// +/// @param[in] ro_s +/// A nonnegative integer. +/// - The rows of \math{\submat(S)} are a contiguous subset of rows of \math{S}. +/// - The rows of \math{\submat(S)} start at \math{S[\texttt{ro_s}, :]}. +/// +/// @param[in] co_s +/// A nonnnegative integer. +/// - The columns of \math{\submat(S)} are a contiguous subset of columns of \math{S}. +/// - The columns \math{\submat(S)} start at \math{S[:,\texttt{co_s}]}. +/// +/// @param[in] beta +/// A real scalar. +/// - If zero, then \math{B} need not be set on input. +/// +/// @param[in, out] B +/// Pointer to 1D array of real scalars. +/// - On entry, defines \math{\mat(B)} +/// on the RIGHT-hand side of \math{(\star)}. +/// - On exit, defines \math{\mat(B)} +/// on the LEFT-hand side of \math{(\star)}. +/// +/// @param[in] ldb +/// - Leading dimension of \math{\mat(B)} when reading from \math{B}. +/// - Refer to documentation for \math{\lda} for details. +/// +template +inline void rskges( + blas::Layout layout, + blas::Op opA, + blas::Op opS, + int64_t m, // B is m-by-d + int64_t d, // op(S) is n-by-d + int64_t n, // op(A) is m-by-n + T alpha, + const T *A, + int64_t lda, + SKOP &S, + int64_t ro_s, + int64_t co_s, + T beta, + T *B, + int64_t ldb +) { + if (!S.known_filled) + fill_sparse(S); + using RNG = typename SKOP::RNG_t; + using sint = typename SKOP::index_t; + auto Scoo = coo_view_of_skop(S); + right_spmm( + layout, opA, opS, m, d, n, alpha, A, lda, Scoo, ro_s, co_s, beta, B, ldb + ); + return; +} + +} // end namespace RandBLAS::sparse + + +namespace RandBLAS { + +using namespace RandBLAS::dense; +using namespace RandBLAS::sparse; + +// MARK: SKGE overloads, sub // ============================================================================= /// \fn sketch_general(blas::Layout layout, blas::Op opS, blas::Op opA, int64_t d, @@ -251,7 +852,6 @@ inline void sketch_general( ); } -// MARK: SUBMAT(S), RIGHT // ============================================================================= /// \fn sketch_general(blas::Layout layout, blas::Op opA, blas::Op opS, int64_t m, int64_t d, int64_t n, @@ -424,7 +1024,7 @@ inline void sketch_general( } -// MARK: FULL(S), LEFT +// MARK: SKGE overloads, full // ============================================================================= /// \fn sketch_general(blas::Layout layout, blas::Op opS, blas::Op opA, int64_t d, @@ -528,7 +1128,6 @@ inline void sketch_general( return sketch_general(layout, opS, opA, d, n, m, alpha, S, 0, 0, A, lda, beta, B, ldb); }; -// MARK: FULL(S), RIGHT // ============================================================================= /// \fn sketch_general(blas::Layout layout, blas::Op opA, blas::Op opS, int64_t m, int64_t d, int64_t n, @@ -632,4 +1231,6 @@ inline void sketch_general( }; } // end namespace RandBLAS + + #endif diff --git a/RandBLAS/skge3_to_gemm.hh b/RandBLAS/skge3_to_gemm.hh deleted file mode 100644 index 8303c947..00000000 --- a/RandBLAS/skge3_to_gemm.hh +++ /dev/null @@ -1,368 +0,0 @@ -// Copyright, 2024. See LICENSE for copyright holder information. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// (1) Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// -// (2) Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// -// (3) Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// - -#ifndef randblas_skge3_to_gemm_hh -#define randblas_skge3_to_gemm_hh - -#include "RandBLAS/base.hh" -#include "RandBLAS/exceptions.hh" -#include "RandBLAS/random_gen.hh" - -#include - -#include -#include -#include -#include -#include - -#include -#include -#include "dense_skops.hh" - - -namespace RandBLAS::dense { - -using RandBLAS::DenseSkOp; -using RandBLAS::fill_dense; - -// ============================================================================= -/// @verbatim embed:rst:leading-slashes -/// -/// .. |op| mathmacro:: \operatorname{op} -/// .. |mat| mathmacro:: \operatorname{mat} -/// .. |submat| mathmacro:: \operatorname{submat} -/// .. |lda| mathmacro:: \mathrm{lda} -/// .. |ldb| mathmacro:: \mathrm{ldb} -/// .. |opA| mathmacro:: \mathrm{opA} -/// .. |opS| mathmacro:: \mathrm{opS} -/// -/// @endverbatim -/// LSKGE3: Perform a GEMM-like operation -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(B) = \alpha \cdot \underbrace{\op(\submat(S))}_{d \times m} \cdot \underbrace{\op(\mat(A))}_{m \times n} + \beta \cdot \underbrace{\mat(B)}_{d \times n}, \tag{$\star$} -/// @endverbatim -/// where \math{\alpha} and \math{\beta} are real scalars, \math{\op(X)} either returns a matrix \math{X} -/// or its transpose, and \math{S} is a sketching operator that takes Level 3 BLAS effort to apply. -/// -/// @verbatim embed:rst:leading-slashes -/// What are :math:`\mat(A)` and :math:`\mat(B)`? -/// Their shapes are defined implicitly by :math:`(d, m, n, \opA)`. -/// Their precise contents are determined by :math:`(A, \lda)`, :math:`(B, \ldb)`, -/// and "layout", following the same convention as BLAS. -/// -/// What is :math:`\submat(S)`? -/// Its shape is defined implicitly by :math:`(\opS, d, m)`. -/// If :math:`{\submat(S)}` is of shape :math:`r \times c`, -/// then it is the :math:`r \times c` submatrix of :math:`{S}` whose upper-left corner -/// appears at index :math:`(\texttt{ro_s}, \texttt{co_s})` of :math:`{S}`. -/// @endverbatim -/// @param[in] layout -/// Layout::ColMajor or Layout::RowMajor -/// - Matrix storage for \math{\mat(A)} and \math{\mat(B)}. -/// -/// @param[in] opS -/// - If \math{\opS} = NoTrans, then \math{ \op(\submat(S)) = \submat(S)}. -/// - If \math{\opS} = Trans, then \math{\op(\submat(S)) = \submat(S)^T }. -/// @param[in] opA -/// - If \math{\opA} == NoTrans, then \math{\op(\mat(A)) = \mat(A)}. -/// - If \math{\opA} == Trans, then \math{\op(\mat(A)) = \mat(A)^T}. -/// @param[in] d -/// A nonnegative integer. -/// - The number of rows in \math{\mat(B)} -/// - The number of rows in \math{\op(\mat(S))}. -/// -/// @param[in] n -/// A nonnegative integer. -/// - The number of columns in \math{\mat(B)} -/// - The number of columns in \math{\op(\mat(A))}. -/// -/// @param[in] m -/// A nonnegative integer. -/// - The number of columns in \math{\op(\submat(S))} -/// - The number of rows in \math{\op(\mat(A))}. -/// -/// @param[in] alpha -/// A real scalar. -/// - If zero, then \math{A} is not accessed. -/// -/// @param[in] S -/// A DenseSkOp object. -/// - Defines \math{\submat(S)}. -/// -/// @param[in] ro_s -/// A nonnegative integer. -/// - The rows of \math{\submat(S)} are a contiguous subset of rows of \math{S}. -/// - The rows of \math{\submat(S)} start at \math{S[\texttt{ro_s}, :]}. -/// -/// @param[in] co_s -/// A nonnnegative integer. -/// - The columns of \math{\submat(S)} are a contiguous subset of columns of \math{S}. -/// - The columns \math{\submat(S)} start at \math{S[:,\texttt{co_s}]}. -/// -/// @param[in] A -/// Pointer to a 1D array of real scalars. -/// - Defines \math{\mat(A)}. -/// -/// @param[in] lda -/// A nonnegative integer. -/// * Leading dimension of \math{\mat(A)} when reading from \math{A}. -/// * If layout == ColMajor, then -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(A)[i, j] = A[i + j \cdot \lda]. -/// @endverbatim -/// In this case, \math{\lda} must be \math{\geq} the length of a column in \math{\mat(A)}. -/// * If layout == RowMajor, then -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(A)[i, j] = A[i \cdot \lda + j]. -/// @endverbatim -/// In this case, \math{\lda} must be \math{\geq} the length of a row in \math{\mat(A)}. -/// -/// @param[in] beta -/// A real scalar. -/// - If zero, then \math{B} need not be set on input. -/// -/// @param[in, out] B -/// Pointer to 1D array of real scalars. -/// - On entry, defines \math{\mat(B)} -/// on the RIGHT-hand side of \math{(\star)}. -/// - On exit, defines \math{\mat(B)} -/// on the LEFT-hand side of \math{(\star)}. -/// -/// @param[in] ldb -/// - Leading dimension of \math{\mat(B)} when reading from \math{B}. -/// - Refer to documentation for \math{\lda} for details. -/// -template -void lskge3( - blas::Layout layout, - blas::Op opS, - blas::Op opA, - int64_t d, // B is d-by-n - int64_t n, // op(A) is m-by-n - int64_t m, // op(S) is d-by-m - T alpha, - DenseSkOp &S, - int64_t ro_s, - int64_t co_s, - const T *A, - int64_t lda, - T beta, - T *B, - int64_t ldb -){ - auto [rows_submat_S, cols_submat_S] = dims_before_op(d, m, opS); - if (!S.buff) { - // We'll make a shallow copy of the sketching operator, take responsibility for filling the memory - // of that sketching operator, and then call LSKGE3 with that new object. - T *buff = new T[rows_submat_S * cols_submat_S]; - fill_dense(S.dist, rows_submat_S, cols_submat_S, ro_s, co_s, buff, S.seed_state); - DenseDist D{rows_submat_S, cols_submat_S, DenseDistName::BlackBox, S.dist.major_axis}; - DenseSkOp S_(D, S.seed_state, buff); - lskge3(layout, opS, opA, d, n, m, alpha, S_, 0, 0, A, lda, beta, B, ldb); - delete [] buff; - return; - } - randblas_require( S.dist.n_rows >= rows_submat_S + ro_s ); - randblas_require( S.dist.n_cols >= cols_submat_S + co_s ); - auto [rows_A, cols_A] = dims_before_op(m, n, opA); - if (layout == blas::Layout::ColMajor) { - randblas_require(lda >= rows_A); - randblas_require(ldb >= d); - } else { - randblas_require(lda >= cols_A); - randblas_require(ldb >= n); - } - - auto [pos, lds] = offset_and_ldim(S.layout, S.dist.n_rows, S.dist.n_cols, ro_s, co_s); - T* S_ptr = &S.buff[pos]; - if (S.layout != layout) - opS = (opS == blas::Op::NoTrans) ? blas::Op::Trans : blas::Op::NoTrans; - - blas::gemm(layout, opS, opA, d, n, m, alpha, S_ptr, lds, A, lda, beta, B, ldb); - return; -} - -// ============================================================================= -/// RSKGE3: Perform a GEMM-like operation -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(B) = \alpha \cdot \underbrace{\op(\mat(A))}_{m \times n} \cdot \underbrace{\op(\submat(S))}_{n \times d} + \beta \cdot \underbrace{\mat(B)}_{m \times d}, \tag{$\star$} -/// @endverbatim -/// where \math{\alpha} and \math{\beta} are real scalars, \math{\op(X)} either returns a matrix \math{X} -/// or its transpose, and \math{S} is a sketching operator that takes Level 3 BLAS effort to apply. -/// -/// @verbatim embed:rst:leading-slashes -/// What are :math:`\mat(A)` and :math:`\mat(B)`? -/// Their shapes are defined implicitly by :math:`(m, d, n, \opA)`. -/// Their precise contents are determined by :math:`(A, \lda)`, :math:`(B, \ldb)`, -/// and "layout", following the same convention as BLAS. -/// -/// What is :math:`\submat(S)`? -/// Its shape is defined implicitly by :math:`(\opS, n, d)`. -/// If :math:`{\submat(S)}` is of shape :math:`r \times c`, -/// then it is the :math:`r \times c` submatrix of :math:`{S}` whose upper-left corner -/// appears at index :math:`(\texttt{ro_s}, \texttt{co_s})` of :math:`{S}`. -/// @endverbatim -/// @param[in] layout -/// Layout::ColMajor or Layout::RowMajor -/// - Matrix storage for \math{\mat(A)} and \math{\mat(B)}. -/// -/// @param[in] opA -/// - If \math{\opA} == NoTrans, then \math{\op(\mat(A)) = \mat(A)}. -/// - If \math{\opA} == Trans, then \math{\op(\mat(A)) = \mat(A)^T}. -/// -/// @param[in] opS -/// - If \math{\opS} = NoTrans, then \math{ \op(\submat(S)) = \submat(S)}. -/// - If \math{\opS} = Trans, then \math{\op(\submat(S)) = \submat(S)^T }. -/// -/// @param[in] m -/// A nonnegative integer. -/// - The number of rows in \math{\mat(B)}. -/// - The number of rows in \math{\op(\mat(A))}. -/// -/// @param[in] d -/// A nonnegative integer. -/// - The number of columns in \math{\mat(B)} -/// - The number of columns in \math{\op(\mat(S))}. -/// -/// @param[in] n -/// A nonnegative integer. -/// - The number of columns in \math{\op(\mat(A))} -/// - The number of rows in \math{\op(\submat(S))}. -/// -/// @param[in] alpha -/// A real scalar. -/// - If zero, then \math{A} is not accessed. -/// -/// @param[in] A -/// Pointer to a 1D array of real scalars. -/// - Defines \math{\mat(A)}. -/// -/// @param[in] lda -/// A nonnegative integer. -/// * Leading dimension of \math{\mat(A)} when reading from \math{A}. -/// * If layout == ColMajor, then -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(A)[i, j] = A[i + j \cdot \lda]. -/// @endverbatim -/// In this case, \math{\lda} must be \math{\geq} the length of a column in \math{\mat(A)}. -/// * If layout == RowMajor, then -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(A)[i, j] = A[i \cdot \lda + j]. -/// @endverbatim -/// In this case, \math{\lda} must be \math{\geq} the length of a row in \math{\mat(A)}. -/// -/// @param[in] S -/// A DenseSkOp object. -/// - Defines \math{\submat(S)}. -/// -/// @param[in] ro_s -/// A nonnegative integer. -/// - The rows of \math{\submat(S)} are a contiguous subset of rows of \math{S}. -/// - The rows of \math{\submat(S)} start at \math{S[\texttt{ro_s}, :]}. -/// -/// @param[in] co_s -/// A nonnnegative integer. -/// - The columns of \math{\submat(S)} are a contiguous subset of columns of \math{S}. -/// - The columns \math{\submat(S)} start at \math{S[:,\texttt{co_s}]}. -/// -/// @param[in] beta -/// A real scalar. -/// - If zero, then \math{B} need not be set on input. -/// -/// @param[in, out] B -/// Pointer to 1D array of real scalars. -/// - On entry, defines \math{\mat(B)} -/// on the RIGHT-hand side of \math{(\star)}. -/// - On exit, defines \math{\mat(B)} -/// on the LEFT-hand side of \math{(\star)}. -/// -/// @param[in] ldb -/// - Leading dimension of \math{\mat(B)} when reading from \math{B}. -/// - Refer to documentation for \math{\lda} for details. -/// -template -void rskge3( - blas::Layout layout, - blas::Op opA, - blas::Op opS, - int64_t m, // B is m-by-d - int64_t d, // op(S) is n-by-d - int64_t n, // op(A) is m-by-n - T alpha, - const T *A, - int64_t lda, - DenseSkOp &S, - int64_t ro_s, - int64_t co_s, - T beta, - T *B, - int64_t ldb -){ - auto [rows_submat_S, cols_submat_S] = dims_before_op(n, d, opS); - if (!S.buff) { - // We'll make a shallow copy of the sketching operator, take responsibility for filling the memory - // of that sketching operator, and then call RSKGE3 with that new object. - T *buff = new T[rows_submat_S * cols_submat_S]; - fill_dense(S.dist, rows_submat_S, cols_submat_S, ro_s, co_s, buff, S.seed_state); - DenseDist D{rows_submat_S, cols_submat_S, DenseDistName::BlackBox, S.dist.major_axis}; - DenseSkOp S_(D, S.seed_state, buff); - rskge3(layout, opA, opS, m, d, n, alpha, A, lda, S_, 0, 0, beta, B, ldb); - delete [] buff; - return; - } - randblas_require( S.dist.n_rows >= rows_submat_S + ro_s ); - randblas_require( S.dist.n_cols >= cols_submat_S + co_s ); - auto [rows_A, cols_A] = dims_before_op(m, n, opA); - if (layout == blas::Layout::ColMajor) { - randblas_require(lda >= rows_A); - randblas_require(ldb >= m); - } else { - randblas_require(lda >= cols_A); - randblas_require(ldb >= d); - } - - auto [pos, lds] = offset_and_ldim(S.layout, S.dist.n_rows, S.dist.n_cols, ro_s, co_s); - T* S_ptr = &S.buff[pos]; - if (S.layout != layout) - opS = (opS == blas::Op::NoTrans) ? blas::Op::Trans : blas::Op::NoTrans; - - blas::gemm(layout, opA, opS, m, d, n, alpha, A, lda, S_ptr, lds, beta, B, ldb); - return; -} - -} // end namespace RandBLAS::dense - -#endif diff --git a/RandBLAS/skges_to_spmm.hh b/RandBLAS/skges_to_spmm.hh deleted file mode 100644 index 12218061..00000000 --- a/RandBLAS/skges_to_spmm.hh +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright, 2024. See LICENSE for copyright holder information. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// (1) Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// -// (2) Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// -// (3) Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// - -#ifndef randblas_skges_to_spmm_hh -#define randblas_skges_to_spmm_hh - -#include "RandBLAS/config.h" -#include "RandBLAS/base.hh" -#include "RandBLAS/exceptions.hh" -#include "RandBLAS/random_gen.hh" -#include "RandBLAS/util.hh" -#include "RandBLAS/sparse_data/spmm_dispatch.hh" - -#include -#include -#include -#include -#include -#if defined(RandBLAS_HAS_OpenMP) -#include -#endif - -#define MAX(a, b) (((a) < (b)) ? (b) : (a)) -#define MIN(a, b) (((a) < (b)) ? (a) : (b)) - -namespace RandBLAS::sparse { - - -// ============================================================================= -/// @verbatim embed:rst:leading-slashes -/// -/// .. |op| mathmacro:: \operatorname{op} -/// .. |mat| mathmacro:: \operatorname{mat} -/// .. |submat| mathmacro:: \operatorname{submat} -/// .. |lda| mathmacro:: \mathrm{lda} -/// .. |ldb| mathmacro:: \mathrm{ldb} -/// .. |opA| mathmacro:: \mathrm{opA} -/// .. |opS| mathmacro:: \mathrm{opS} -/// -/// @endverbatim -/// LSKGES: Perform a GEMM-like operation -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(B) = \alpha \cdot \underbrace{\op(\submat(S))}_{d \times m} \cdot \underbrace{\op(\mat(A))}_{m \times n} + \beta \cdot \underbrace{\mat(B)}_{d \times n}, \tag{$\star$} -/// @endverbatim -/// where \math{\alpha} and \math{\beta} are real scalars, \math{\op(X)} either returns a matrix \math{X} -/// or its transpose, and \math{S} is a sparse sketching operator. -/// -/// @verbatim embed:rst:leading-slashes -/// What are :math:`\mat(A)` and :math:`\mat(B)`? -/// Their shapes are defined implicitly by :math:`(d, m, n, \opA)`. -/// Their precise contents are determined by :math:`(A, \lda)`, :math:`(B, \ldb)`, -/// and "layout", following the same convention as BLAS. -/// -/// What is :math:`\submat(S)`? -/// Its shape is defined implicitly by :math:`(\opS, d, m)`. -/// If :math:`{\submat(S)}` is of shape :math:`r \times c`, -/// then it is the :math:`r \times c` submatrix of :math:`{S}` whose upper-left corner -/// appears at index :math:`(\texttt{ro_s}, \texttt{co_s})` of :math:`{S}`. -/// @endverbatim -/// @param[in] layout -/// Layout::ColMajor or Layout::RowMajor -/// - Matrix storage for \math{\mat(A)} and \math{\mat(B)}. -/// -/// @param[in] opS -/// - If \math{\opS} = NoTrans, then \math{ \op(\submat(S)) = \submat(S)}. -/// - If \math{\opS} = Trans, then \math{\op(\submat(S)) = \submat(S)^T }. -/// -/// @param[in] opA -/// - If \math{\opA} == NoTrans, then \math{\op(\mat(A)) = \mat(A)}. -/// - If \math{\opA} == Trans, then \math{\op(\mat(A)) = \mat(A)^T}. -/// -/// @param[in] d -/// A nonnegative integer. -/// - The number of rows in \math{\mat(B)} -/// - The number of rows in \math{\op(\mat(S))}. -/// -/// @param[in] n -/// A nonnegative integer. -/// - The number of columns in \math{\mat(B)} -/// - The number of columns in \math{\op(\mat(A))}. -/// -/// @param[in] m -/// A nonnegative integer. -/// - The number of columns in \math{\op(\submat(S))} -/// - The number of rows in \math{\op(\mat(A))}. -/// -/// @param[in] alpha -/// A real scalar. -/// - If zero, then \math{A} is not accessed. -/// -/// @param[in] S -/// A SparseSkOp object. -/// - Defines \math{\submat(S)}. -/// -/// @param[in] ro_s -/// A nonnegative integer. -/// - The rows of \math{\submat(S)} are a contiguous subset of rows of \math{S}. -/// - The rows of \math{\submat(S)} start at \math{S[\texttt{ro_s}, :]}. -/// -/// @param[in] co_s -/// A nonnnegative integer. -/// - The columns of \math{\submat(S)} are a contiguous subset of columns of \math{S}. -/// - The columns \math{\submat(S)} start at \math{S[:,\texttt{co_s}]}. -/// -/// @param[in] A -/// Pointer to a 1D array of real scalars. -/// - Defines \math{\mat(A)}. -/// -/// @param[in] lda -/// A nonnegative integer. -/// * Leading dimension of \math{\mat(A)} when reading from \math{A}. -/// * If layout == ColMajor, then -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(A)[i, j] = A[i + j \cdot \lda]. -/// @endverbatim -/// In this case, \math{\lda} must be \math{\geq} the length of a column in \math{\mat(A)}. -/// * If layout == RowMajor, then -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(A)[i, j] = A[i \cdot \lda + j]. -/// @endverbatim -/// In this case, \math{\lda} must be \math{\geq} the length of a row in \math{\mat(A)}. -/// -/// @param[in] beta -/// A real scalar. -/// - If zero, then \math{B} need not be set on input. -/// -/// @param[in, out] B -/// Pointer to 1D array of real scalars. -/// - On entry, defines \math{\mat(B)} -/// on the RIGHT-hand side of \math{(\star)}. -/// - On exit, defines \math{\mat(B)} -/// on the LEFT-hand side of \math{(\star)}. -/// -/// @param[in] ldb -/// - Leading dimension of \math{\mat(B)} when reading from \math{B}. -/// - Refer to documentation for \math{\lda} for details. -/// -template -inline void lskges( - blas::Layout layout, - blas::Op opS, - blas::Op opA, - int64_t d, // B is d-by-n - int64_t n, // \op(A) is m-by-n - int64_t m, // \op(S) is d-by-m - T alpha, - SKOP &S, - int64_t ro_s, - int64_t co_s, - const T *A, - int64_t lda, - T beta, - T *B, - int64_t ldb -) { - if (!S.known_filled) - fill_sparse(S); - using RNG = typename SKOP::RNG_t; - using sint_t = typename SKOP::index_t; - auto Scoo = coo_view_of_skop(S); - left_spmm( - layout, opS, opA, d, n, m, alpha, Scoo, ro_s, co_s, - A, lda, beta, B, ldb - ); - return; -} - - -// ============================================================================= -/// RSKGES: Perform a GEMM-like operation -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(B) = \alpha \cdot \underbrace{\op(\mat(A))}_{m \times n} \cdot \underbrace{\op(\submat(S))}_{n \times d} + \beta \cdot \underbrace{\mat(B)}_{m \times d}, \tag{$\star$} -/// @endverbatim -/// where \math{\alpha} and \math{\beta} are real scalars, \math{\op(X)} either returns a matrix \math{X} -/// or its transpose, and \math{S} is a sparse sketching operator. -/// -/// @verbatim embed:rst:leading-slashes -/// What are :math:`\mat(A)` and :math:`\mat(B)`? -/// Their shapes are defined implicitly by :math:`(m, d, n, \opA)`. -/// Their precise contents are determined by :math:`(A, \lda)`, :math:`(B, \ldb)`, -/// and "layout", following the same convention as BLAS. -/// -/// What is :math:`\submat(S)`? -/// Its shape is defined implicitly by :math:`(\opS, n, d)`. -/// If :math:`{\submat(S)}` is of shape :math:`r \times c`, -/// then it is the :math:`r \times c` submatrix of :math:`{S}` whose upper-left corner -/// appears at index :math:`(\texttt{ro_s}, \texttt{co_s})` of :math:`{S}`. -/// @endverbatim -/// @param[in] layout -/// Layout::ColMajor or Layout::RowMajor -/// - Matrix storage for \math{\mat(A)} and \math{\mat(B)}. -/// -/// @param[in] opA -/// - If \math{\opA} == NoTrans, then \math{\op(\mat(A)) = \mat(A)}. -/// - If \math{\opA} == Trans, then \math{\op(\mat(A)) = \mat(A)^T}. -/// -/// @param[in] opS -/// - If \math{\opS} = NoTrans, then \math{ \op(\submat(S)) = \submat(S)}. -/// - If \math{\opS} = Trans, then \math{\op(\submat(S)) = \submat(S)^T }. -/// -/// @param[in] m -/// A nonnegative integer. -/// - The number of rows in \math{\mat(B)}. -/// - The number of rows in \math{\op(\mat(A))}. -/// -/// @param[in] d -/// A nonnegative integer. -/// - The number of columns in \math{\mat(B)} -/// - The number of columns in \math{\op(\mat(S))}. -/// -/// @param[in] n -/// A nonnegative integer. -/// - The number of columns in \math{\op(\mat(A))} -/// - The number of rows in \math{\op(\submat(S))}. -/// -/// @param[in] alpha -/// A real scalar. -/// - If zero, then \math{A} is not accessed. -/// -/// @param[in] A -/// Pointer to a 1D array of real scalars. -/// - Defines \math{\mat(A)}. -/// -/// @param[in] lda -/// A nonnegative integer. -/// * Leading dimension of \math{\mat(A)} when reading from \math{A}. -/// * If layout == ColMajor, then -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(A)[i, j] = A[i + j \cdot \lda]. -/// @endverbatim -/// In this case, \math{\lda} must be \math{\geq} the length of a column in \math{\mat(A)}. -/// * If layout == RowMajor, then -/// @verbatim embed:rst:leading-slashes -/// .. math:: -/// \mat(A)[i, j] = A[i \cdot \lda + j]. -/// @endverbatim -/// In this case, \math{\lda} must be \math{\geq} the length of a row in \math{\mat(A)}. -/// -/// @param[in] S -/// A SparseSkOp object. -/// - Defines \math{\submat(S)}. -/// -/// @param[in] ro_s -/// A nonnegative integer. -/// - The rows of \math{\submat(S)} are a contiguous subset of rows of \math{S}. -/// - The rows of \math{\submat(S)} start at \math{S[\texttt{ro_s}, :]}. -/// -/// @param[in] co_s -/// A nonnnegative integer. -/// - The columns of \math{\submat(S)} are a contiguous subset of columns of \math{S}. -/// - The columns \math{\submat(S)} start at \math{S[:,\texttt{co_s}]}. -/// -/// @param[in] beta -/// A real scalar. -/// - If zero, then \math{B} need not be set on input. -/// -/// @param[in, out] B -/// Pointer to 1D array of real scalars. -/// - On entry, defines \math{\mat(B)} -/// on the RIGHT-hand side of \math{(\star)}. -/// - On exit, defines \math{\mat(B)} -/// on the LEFT-hand side of \math{(\star)}. -/// -/// @param[in] ldb -/// - Leading dimension of \math{\mat(B)} when reading from \math{B}. -/// - Refer to documentation for \math{\lda} for details. -/// -template -inline void rskges( - blas::Layout layout, - blas::Op opA, - blas::Op opS, - int64_t m, // B is m-by-d - int64_t d, // op(S) is n-by-d - int64_t n, // op(A) is m-by-n - T alpha, - const T *A, - int64_t lda, - SKOP &S, - int64_t ro_s, - int64_t co_s, - T beta, - T *B, - int64_t ldb -) { - if (!S.known_filled) - fill_sparse(S); - using RNG = typename SKOP::RNG_t; - using sint = typename SKOP::index_t; - auto Scoo = coo_view_of_skop(S); - right_spmm( - layout, opA, opS, m, d, n, alpha, A, lda, Scoo, ro_s, co_s, beta, B, ldb - ); - return; -} - -} // end namespace RandBLAS::sparse - -#endif diff --git a/RandBLAS/sparse_data/DevNotes.md b/RandBLAS/sparse_data/DevNotes.md index 66ebe532..c5f138d3 100644 --- a/RandBLAS/sparse_data/DevNotes.md +++ b/RandBLAS/sparse_data/DevNotes.md @@ -46,8 +46,7 @@ Sketching dense data with a sparse operator is typically handled with ``sketch_g which is defined in ``skge.hh``. If we call this function with a SparseSkOp object, ``S``, we'd immediately get routed to -a function in ``skges_to_spmm.hh``: either ``lskges`` or ``rskges``. Here's what would happen -after we entered one of those functions: +either ``lskges`` or ``rskges``. Here's what would happen after we entered one of those functions: 1. If necessary, we'd sample the defining data of ``S`` with ``RandBLAS::fill_sparse(S)``. @@ -58,7 +57,7 @@ after we entered one of those functions: ## Sketching sparse data with dense operators If we call ``sketch_sparse`` with a DenseSkOp, ``S``, and a sparse matrix, ``A``, then we'll get routed to either -``lsksp3`` or ``rsksp3`` in ``sparse_data/sksp3_to_spmm.hh``. +``lsksp3`` or ``rsksp3``. From there, we'll do the following. diff --git a/RandBLAS/sparse_data/coo_spmm_impl.hh b/RandBLAS/sparse_data/coo_spmm_impl.hh index 11196153..60f1d397 100644 --- a/RandBLAS/sparse_data/coo_spmm_impl.hh +++ b/RandBLAS/sparse_data/coo_spmm_impl.hh @@ -36,6 +36,9 @@ #include "RandBLAS/sparse_data/csc_spmm_impl.hh" #include #include +#if defined(RandBLAS_HAS_OpenMP) +#include +#endif namespace RandBLAS::sparse_data::coo { diff --git a/RandBLAS/sparse_data/csc_spmm_impl.hh b/RandBLAS/sparse_data/csc_spmm_impl.hh index 18aabefd..8932e38a 100644 --- a/RandBLAS/sparse_data/csc_spmm_impl.hh +++ b/RandBLAS/sparse_data/csc_spmm_impl.hh @@ -35,6 +35,9 @@ #include "RandBLAS/sparse_data/csc_matrix.hh" #include #include +#if defined(RandBLAS_HAS_OpenMP) +#include +#endif namespace RandBLAS::sparse_data::csc { diff --git a/RandBLAS/sparse_data/csr_spmm_impl.hh b/RandBLAS/sparse_data/csr_spmm_impl.hh index edccc08d..d752064b 100644 --- a/RandBLAS/sparse_data/csr_spmm_impl.hh +++ b/RandBLAS/sparse_data/csr_spmm_impl.hh @@ -35,6 +35,9 @@ #include "RandBLAS/sparse_data/csr_matrix.hh" #include #include +#if defined(RandBLAS_HAS_OpenMP) +#include +#endif namespace RandBLAS::sparse_data::csr { diff --git a/RandBLAS/sparse_data/sksp.hh b/RandBLAS/sparse_data/sksp.hh index a9885f66..02085c7a 100644 --- a/RandBLAS/sparse_data/sksp.hh +++ b/RandBLAS/sparse_data/sksp.hh @@ -32,15 +32,336 @@ #include "RandBLAS/base.hh" #include "RandBLAS/dense_skops.hh" -#include "RandBLAS/sparse_data/sksp3_to_spmm.hh" - #include "RandBLAS/exceptions.hh" + +namespace RandBLAS::sparse_data { + +// MARK: LSKSP3 + +// ============================================================================= +/// \fn lsksp3(blas::Layout layout, blas::Op opS, blas::Op opA, int64_t d, +/// int64_t n, int64_t m, T alpha, DenseSkOp &S, int64_t ro_s, int64_t co_s, +/// SpMat &A, int64_t ro_a, int64_t co_a, T beta, T *B, int64_t ldb +/// ) +/// @verbatim embed:rst:leading-slashes +/// Sketch from the left in an SpMM-like operation +/// +/// .. math:: +/// \mat(B) = \alpha \cdot \underbrace{\op(\submat(S))}_{d \times m} \cdot \underbrace{\op(\submat(A))}_{m \times n} + \beta \cdot \underbrace{\mat(B)}_{d \times n}, \tag{$\star$} +/// +/// where :math:`\alpha` and :math:`\beta` are real scalars, :math:`\op(X)` either returns a matrix :math:`X` +/// or its transpose, :math:`A` is a sparse matrix, and :math:`S` is a dense sketching operator. +/// +/// .. dropdown:: FAQ +/// :animate: fade-in-slide-down +/// +/// **What's** :math:`\mat(B)` **?** +/// +/// It's matrix of shape :math:`d \times n`. Its contents are determined by :math:`(B, \ldb)` +/// and "layout", following the same convention as the Level 3 BLAS function "GEMM." +/// +/// **What are** :math:`\submat(S)` **and** :math:`\submat(A)` **?** +/// +/// Their shapes are determined implicitly by :math:`(\opS, d, m)` and :math:`(\opA, n, m)`. +/// If :math:`{\submat(X)}` is of shape :math:`r \times c`, +/// then it is the :math:`r \times c` submatrix of :math:`{X}` whose upper-left corner +/// appears at index :math:`(\texttt{ro_x}, \texttt{co_x})` of :math:`{X}`. +/// +/// .. dropdown:: Full parameter descriptions +/// :animate: fade-in-slide-down +/// +/// layout - [in] +/// * Layout::ColMajor or Layout::RowMajor. +/// * Matrix storage for :math:`\mat(B)`. +/// +/// opS - [in] +/// * If :math:`\opS` = NoTrans, then :math:`\op(\submat(S)) = \submat(S)`. +/// * If :math:`\opS` = Trans, then :math:`\op(\submat(S)) = \submat(S)^T`. +/// +/// opA - [in] +/// * If :math:`\opA` = NoTrans, then :math:`\op(\submat(A)) = \submat(A)`. +/// * If :math:`\opA` = Trans, then :math:`\op(\submat(A)) = \submat(A)^T`. +/// +/// d - [in] +/// * A nonnegative integer. +/// * The number of rows in :math:`\mat(B)`. +/// * The number of rows in :math:`\op(\submat(S))`. +/// +/// n - [in] +/// * A nonnegative integer. +/// * The number of columns in :math:`\mat(B)`. +/// * The number of columns in :math:`\op(\mat(A))`. +/// +/// m - [in] +/// * A nonnegative integer. +/// * The number of columns in :math:`\op(\submat(S))` +/// * The number of rows in :math:`\op(\mat(A))`. +/// +/// alpha - [in] +/// * A real scalar. +/// +/// S - [in] +/// * A DenseSkOp object. +/// * Defines :math:`\submat(S)`. +/// +/// ro_s - [in] +/// * A nonnegative integer. +/// * The rows of :math:`\submat(S)` are a contiguous subset of rows of :math:`S`. +/// * The rows of :math:`\submat(S)` start at :math:`S[\texttt{ro_s}, :]`. +/// +/// co_s - [in] +/// * A nonnegative integer. +/// * The columns of :math:`\submat(S)` are a contiguous subset of columns of :math:`S`. +/// * The columns :math:`\submat(S)` start at :math:`S[:,\texttt{co_s}]`. +/// +/// A - [in] +/// * A RandBLAS sparse matrix object. +/// * Defines :math:`\submat(A)`. +/// +/// ro_a - [in] +/// * A nonnegative integer. +/// * The rows of :math:`\submat(A)` are a contiguous subset of rows of :math:`A`. +/// * The rows of :math:`\submat(A)` start at :math:`A[\texttt{ro_a}, :]`. +/// +/// co_a - [in] +/// * A nonnegative integer. +/// * The columns of :math:`\submat(A)` are a contiguous subset of columns of :math:`A`. +/// * The columns :math:`\submat(A)` start at :math:`A[:,\texttt{co_a}]`. +/// +/// beta - [in] +/// * A real scalar. +/// * If zero, then :math:`B` need not be set on input. +/// +/// B - [in, out] +/// * Pointer to 1D array of real scalars. +/// * On entry, defines :math:`\mat(B)` +/// on the RIGHT-hand side of :math:`(\star)`. +/// * On exit, defines :math:`\mat(B)` +/// on the LEFT-hand side of :math:`(\star)`. +/// +/// ldb - [in] +/// * A nonnegative integer. +/// * Leading dimension of :math:`\mat(B)` when reading from :math:`B`. +/// +/// @endverbatim +template +void lsksp3( + blas::Layout layout, + blas::Op opS, + blas::Op opA, + int64_t d, // B is d-by-n + int64_t n, // op(submat(A)) is m-by-n + int64_t m, // op(submat(S)) is d-by-m + T alpha, + DenseSkOp &S, + int64_t ro_s, + int64_t co_s, + SpMat &A, + int64_t ro_a, + int64_t co_a, + T beta, + T *B, + int64_t ldb +) { + // B = op(submat(S)) @ op(submat(A)) + auto [rows_submat_S, cols_submat_S] = dims_before_op(d, m, opS); + if (!S.buff) { + T *buff = new T[rows_submat_S * cols_submat_S]; + fill_dense(S.dist, rows_submat_S, cols_submat_S, ro_s, co_s, buff, S.seed_state); + DenseDist D{rows_submat_S, cols_submat_S, DenseDistName::BlackBox, S.dist.major_axis}; + DenseSkOp S_(D, S.seed_state, buff); + lsksp3(layout, opS, opA, d, n, m, alpha, S_, 0, 0, A, ro_a, co_a, beta, B, ldb); + delete [] buff; + return; + } + + auto [rows_submat_A, cols_submat_A] = dims_before_op(m, n, opA); + randblas_require( A.n_rows >= rows_submat_A + ro_a ); + randblas_require( A.n_cols >= cols_submat_A + co_a ); + randblas_require( S.dist.n_rows >= rows_submat_S + ro_s ); + randblas_require( S.dist.n_cols >= cols_submat_S + co_s ); + if (layout == blas::Layout::ColMajor) { + randblas_require(ldb >= d); + } else { + randblas_require(ldb >= n); + } + + auto [pos, lds] = offset_and_ldim(S.layout, S.dist.n_rows, S.dist.n_cols, ro_s, co_s); + T* S_ptr = &S.buff[pos]; + if (S.layout != layout) + opS = (opS == blas::Op::NoTrans) ? blas::Op::Trans : blas::Op::NoTrans; + + right_spmm(layout, opS, opA, d, n, m, alpha, S_ptr, lds, A, ro_a, co_a, beta, B, ldb); + return; +} + +// MARK: RSKSP3 + +// ============================================================================= +/// \fn rsksp3(blas::Layout layout, blas::Op opA, blas::Op opS, int64_t m, +/// int64_t d, int64_t n, T alpha, SpMat &A, int64_t ro_a, int64_t co_a, +/// DenseSkOp &S, int64_t ro_s, int64_t co_s, T beta, T *B, int64_t ldb +/// ) +/// @verbatim embed:rst:leading-slashes +/// Sketch from the right in an SpMM-like operation +/// +/// .. math:: +/// \mat(B) = \alpha \cdot \underbrace{\op(\submat(A))}_{m \times n} \cdot \underbrace{\op(\submat(S))}_{n \times d} + \beta \cdot \underbrace{\mat(B)}_{m \times d}, \tag{$\star$} +/// +/// where :math:`\alpha` and :math:`\beta` are real scalars, :math:`\op(X)` either returns a matrix :math:`X` +/// or its transpose, :math:`A` is a sparse matrix, and :math:`S` is a dense sketching operator. +/// +/// .. dropdown:: FAQ +/// :animate: fade-in-slide-down +/// +/// **What's** :math:`\mat(B)` **?** +/// +/// It's matrix of shape :math:`m \times d`. Its contents are determined by :math:`(B, \ldb)` +/// and "layout", following the same convention as the Level 3 BLAS function "GEMM." +/// +/// **What are** :math:`\submat(S)` **and** :math:`\submat(A)` **?** +/// +/// Their shapes are determined implicitly by :math:`(\opS, n, d)` and :math:`(\opA, m, n)`. +/// If :math:`{\submat(X)}` is of shape :math:`r \times c`, +/// then it is the :math:`r \times c` submatrix of :math:`{X}` whose upper-left corner +/// appears at index :math:`(\texttt{ro_x}, \texttt{co_x})` of :math:`{X}`. +/// +/// .. dropdown:: Full parameter descriptions +/// :animate: fade-in-slide-down +/// +/// layout - [in] +/// * Layout::ColMajor or Layout::RowMajor. +/// * Matrix storage for :math:`\mat(B)`. +/// +/// opA - [in] +/// * If :math:`\opA` == NoTrans, then :math:`\op(\submat(A)) = \submat(A)`. +/// * If :math:`\opA` == Trans, then :math:`\op(\submat(A)) = \submat(A)^T`. +/// +/// opS - [in] +/// * If :math:`\opS` = NoTrans, then :math:`\op(\submat(S)) = \submat(S)`. +/// * If :math:`\opS` = Trans, then :math:`\op(\submat(S)) = \submat(S)^T`. +/// +/// m - [in] +/// * A nonnegative integer. +/// * The number of rows in :math:`\mat(B)`. +/// * The number of rows in :math:`\op(\submat(A))`. +/// +/// d - [in] +/// * A nonnegative integer. +/// * The number of columns in :math:`\mat(B)` +/// * The number of columns in :math:`\op(\submat(S))`. +/// +/// n - [in] +/// * A nonnegative integer. +/// * The number of columns in :math:`\op(\submat(A))` +/// * The number of rows in :math:`\op(\submat(S))`. +/// +/// alpha - [in] +/// * A real scalar. +/// +/// A - [in] +/// * A RandBLAS sparse matrix object. +/// * Defines :math:`\submat(A)`. +/// +/// ro_a - [in] +/// * A nonnegative integer. +/// * The rows of :math:`\submat(A)` are a contiguous subset of rows of :math:`A`. +/// * The rows of :math:`\submat(A)` start at :math:`A[\texttt{ro_a}, :]`. +/// +/// co_a - [in] +/// * A nonnegative integer. +/// * The columns of :math:`\submat(A)` are a contiguous subset of columns of :math:`A`. +/// * The columns :math:`\submat(A)` start at :math:`A[:,\texttt{co_a}]`. +/// +/// S - [in] +/// * A DenseSkOp object. +/// * Defines :math:`\submat(S)`. +/// +/// ro_s - [in] +/// * A nonnegative integer. +/// * The rows of :math:`\submat(S)` are a contiguous subset of rows of :math:`S`. +/// * The rows of :math:`\submat(S)` start at :math:`S[\texttt{ro_s}, :]`. +/// +/// co_s - [in] +/// * A nonnegative integer. +/// * The columns of :math:`\submat(S)` are a contiguous subset of columns of :math:`S`. +/// * The columns :math:`\submat(S)` start at :math:`S[:,\texttt{co_s}]`. +/// +/// beta - [in] +/// * A real scalar. +/// * If zero, then :math:`B` need not be set on input. +/// +/// B - [in, out] +/// * Pointer to 1D array of real scalars. +/// * On entry, defines :math:`\mat(B)` +/// on the RIGHT-hand side of :math:`(\star)`. +/// * On exit, defines :math:`\mat(B)` +/// on the LEFT-hand side of :math:`(\star)`. +/// +/// ldb - [in] +/// * A nonnegative integer. +/// * Leading dimension of :math:`\mat(B)` when reading from :math:`B`. +/// +/// @endverbatim +template +void rsksp3( + blas::Layout layout, + blas::Op opA, + blas::Op opS, + int64_t m, // B is m-by-d + int64_t d, // op(submat(A)) is m-by-n + int64_t n, // op(submat(S)) is n-by-d + T alpha, + SpMat &A, + int64_t ro_a, + int64_t co_a, + DenseSkOp &S, + int64_t ro_s, + int64_t co_s, + T beta, + T *B, + int64_t ldb +) { + auto [rows_submat_S, cols_submat_S] = dims_before_op(n, d, opS); + if (!S.buff) { + T *buff = new T[rows_submat_S * cols_submat_S]; + fill_dense(S.dist, rows_submat_S, cols_submat_S, ro_s, co_s, buff, S.seed_state); + DenseDist D{rows_submat_S, cols_submat_S, DenseDistName::BlackBox, S.dist.major_axis}; + DenseSkOp S_(D, S.seed_state, buff); + rsksp3(layout, opA, opS, m, d, n, alpha, A, ro_a, co_a, S_, 0, 0, beta, B, ldb); + delete [] buff; + return; + } + auto [rows_submat_A, cols_submat_A] = dims_before_op(m, n, opA); + randblas_require( A.n_rows >= rows_submat_A + ro_a ); + randblas_require( A.n_cols >= cols_submat_A + co_a ); + randblas_require( S.dist.n_rows >= rows_submat_S + ro_s ); + randblas_require( S.dist.n_cols >= cols_submat_S + co_s ); + if (layout == blas::Layout::ColMajor) { + randblas_require(ldb >= m); + } else { + randblas_require(ldb >= d); + } + + auto [pos, lds] = offset_and_ldim(S.layout, S.dist.n_rows, S.dist.n_cols, ro_s, co_s); + T* S_ptr = &S.buff[pos]; + if (S.layout != layout) + opS = (opS == blas::Op::NoTrans) ? blas::Op::Trans : blas::Op::NoTrans; + + left_spmm(layout, opA, opS, m, d, n, alpha, A, ro_a, co_a, S_ptr, lds, beta, B, ldb); + return; +} + +} // end namespace RandBLAS::sparse_data + + namespace RandBLAS { using namespace RandBLAS::dense; using namespace RandBLAS::sparse_data; +// MARK: SKSP overloads, sub // ============================================================================= /// \fn sketch_sparse(blas::Layout layout, blas::Op opS, blas::Op opA, int64_t d, @@ -304,4 +625,6 @@ inline void sketch_sparse( } } // end namespace RandBLAS + + #endif diff --git a/RandBLAS/sparse_data/sksp3_to_spmm.hh b/RandBLAS/sparse_data/sksp3_to_spmm.hh deleted file mode 100644 index 13b7150b..00000000 --- a/RandBLAS/sparse_data/sksp3_to_spmm.hh +++ /dev/null @@ -1,368 +0,0 @@ -// Copyright, 2024. See LICENSE for copyright holder information. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// (1) Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// -// (2) Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// -// (3) Neither the name of the copyright holder nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. -// - -#ifndef randblas_sksp3_to_spmm_hh -#define randblas_sksp3_to_spmm_hh - -#include "RandBLAS/base.hh" -#include "RandBLAS/dense_skops.hh" -#include "RandBLAS/sparse_data/spmm_dispatch.hh" - -#include "RandBLAS/exceptions.hh" - -namespace RandBLAS::sparse_data { - -using namespace RandBLAS::dense; - -/* Intended macro definitions. - - .. |op| mathmacro:: \operatorname{op} - .. |mat| mathmacro:: \operatorname{mat} - .. |submat| mathmacro:: \operatorname{submat} - .. |ldb| mathmacro:: \texttt{ldb} - .. |opA| mathmacro:: \texttt{opA} - .. |opS| mathmacro:: \texttt{opS} -*/ - - -// ============================================================================= -/// \fn lsksp3(blas::Layout layout, blas::Op opS, blas::Op opA, int64_t d, -/// int64_t n, int64_t m, T alpha, DenseSkOp &S, int64_t ro_s, int64_t co_s, -/// SpMat &A, int64_t ro_a, int64_t co_a, T beta, T *B, int64_t ldb -/// ) -/// @verbatim embed:rst:leading-slashes -/// Sketch from the left in an SpMM-like operation -/// -/// .. math:: -/// \mat(B) = \alpha \cdot \underbrace{\op(\submat(S))}_{d \times m} \cdot \underbrace{\op(\submat(A))}_{m \times n} + \beta \cdot \underbrace{\mat(B)}_{d \times n}, \tag{$\star$} -/// -/// where :math:`\alpha` and :math:`\beta` are real scalars, :math:`\op(X)` either returns a matrix :math:`X` -/// or its transpose, :math:`A` is a sparse matrix, and :math:`S` is a dense sketching operator. -/// -/// .. dropdown:: FAQ -/// :animate: fade-in-slide-down -/// -/// **What's** :math:`\mat(B)` **?** -/// -/// It's matrix of shape :math:`d \times n`. Its contents are determined by :math:`(B, \ldb)` -/// and "layout", following the same convention as the Level 3 BLAS function "GEMM." -/// -/// **What are** :math:`\submat(S)` **and** :math:`\submat(A)` **?** -/// -/// Their shapes are determined implicitly by :math:`(\opS, d, m)` and :math:`(\opA, n, m)`. -/// If :math:`{\submat(X)}` is of shape :math:`r \times c`, -/// then it is the :math:`r \times c` submatrix of :math:`{X}` whose upper-left corner -/// appears at index :math:`(\texttt{ro_x}, \texttt{co_x})` of :math:`{X}`. -/// -/// .. dropdown:: Full parameter descriptions -/// :animate: fade-in-slide-down -/// -/// layout - [in] -/// * Layout::ColMajor or Layout::RowMajor. -/// * Matrix storage for :math:`\mat(B)`. -/// -/// opS - [in] -/// * If :math:`\opS` = NoTrans, then :math:`\op(\submat(S)) = \submat(S)`. -/// * If :math:`\opS` = Trans, then :math:`\op(\submat(S)) = \submat(S)^T`. -/// -/// opA - [in] -/// * If :math:`\opA` = NoTrans, then :math:`\op(\submat(A)) = \submat(A)`. -/// * If :math:`\opA` = Trans, then :math:`\op(\submat(A)) = \submat(A)^T`. -/// -/// d - [in] -/// * A nonnegative integer. -/// * The number of rows in :math:`\mat(B)`. -/// * The number of rows in :math:`\op(\submat(S))`. -/// -/// n - [in] -/// * A nonnegative integer. -/// * The number of columns in :math:`\mat(B)`. -/// * The number of columns in :math:`\op(\mat(A))`. -/// -/// m - [in] -/// * A nonnegative integer. -/// * The number of columns in :math:`\op(\submat(S))` -/// * The number of rows in :math:`\op(\mat(A))`. -/// -/// alpha - [in] -/// * A real scalar. -/// -/// S - [in] -/// * A DenseSkOp object. -/// * Defines :math:`\submat(S)`. -/// -/// ro_s - [in] -/// * A nonnegative integer. -/// * The rows of :math:`\submat(S)` are a contiguous subset of rows of :math:`S`. -/// * The rows of :math:`\submat(S)` start at :math:`S[\texttt{ro_s}, :]`. -/// -/// co_s - [in] -/// * A nonnegative integer. -/// * The columns of :math:`\submat(S)` are a contiguous subset of columns of :math:`S`. -/// * The columns :math:`\submat(S)` start at :math:`S[:,\texttt{co_s}]`. -/// -/// A - [in] -/// * A RandBLAS sparse matrix object. -/// * Defines :math:`\submat(A)`. -/// -/// ro_a - [in] -/// * A nonnegative integer. -/// * The rows of :math:`\submat(A)` are a contiguous subset of rows of :math:`A`. -/// * The rows of :math:`\submat(A)` start at :math:`A[\texttt{ro_a}, :]`. -/// -/// co_a - [in] -/// * A nonnegative integer. -/// * The columns of :math:`\submat(A)` are a contiguous subset of columns of :math:`A`. -/// * The columns :math:`\submat(A)` start at :math:`A[:,\texttt{co_a}]`. -/// -/// beta - [in] -/// * A real scalar. -/// * If zero, then :math:`B` need not be set on input. -/// -/// B - [in, out] -/// * Pointer to 1D array of real scalars. -/// * On entry, defines :math:`\mat(B)` -/// on the RIGHT-hand side of :math:`(\star)`. -/// * On exit, defines :math:`\mat(B)` -/// on the LEFT-hand side of :math:`(\star)`. -/// -/// ldb - [in] -/// * A nonnegative integer. -/// * Leading dimension of :math:`\mat(B)` when reading from :math:`B`. -/// -/// @endverbatim -template -void lsksp3( - blas::Layout layout, - blas::Op opS, - blas::Op opA, - int64_t d, // B is d-by-n - int64_t n, // op(submat(A)) is m-by-n - int64_t m, // op(submat(S)) is d-by-m - T alpha, - DenseSkOp &S, - int64_t ro_s, - int64_t co_s, - SpMat &A, - int64_t ro_a, - int64_t co_a, - T beta, - T *B, - int64_t ldb -) { - // B = op(submat(S)) @ op(submat(A)) - auto [rows_submat_S, cols_submat_S] = dims_before_op(d, m, opS); - if (!S.buff) { - T *buff = new T[rows_submat_S * cols_submat_S]; - fill_dense(S.dist, rows_submat_S, cols_submat_S, ro_s, co_s, buff, S.seed_state); - DenseDist D{rows_submat_S, cols_submat_S, DenseDistName::BlackBox, S.dist.major_axis}; - DenseSkOp S_(D, S.seed_state, buff); - lsksp3(layout, opS, opA, d, n, m, alpha, S_, 0, 0, A, ro_a, co_a, beta, B, ldb); - delete [] buff; - return; - } - - auto [rows_submat_A, cols_submat_A] = dims_before_op(m, n, opA); - randblas_require( A.n_rows >= rows_submat_A + ro_a ); - randblas_require( A.n_cols >= cols_submat_A + co_a ); - randblas_require( S.dist.n_rows >= rows_submat_S + ro_s ); - randblas_require( S.dist.n_cols >= cols_submat_S + co_s ); - if (layout == blas::Layout::ColMajor) { - randblas_require(ldb >= d); - } else { - randblas_require(ldb >= n); - } - - auto [pos, lds] = offset_and_ldim(S.layout, S.dist.n_rows, S.dist.n_cols, ro_s, co_s); - T* S_ptr = &S.buff[pos]; - if (S.layout != layout) - opS = (opS == blas::Op::NoTrans) ? blas::Op::Trans : blas::Op::NoTrans; - - right_spmm(layout, opS, opA, d, n, m, alpha, S_ptr, lds, A, ro_a, co_a, beta, B, ldb); - return; -} - - -// ============================================================================= -/// \fn rsksp3(blas::Layout layout, blas::Op opA, blas::Op opS, int64_t m, -/// int64_t d, int64_t n, T alpha, SpMat &A, int64_t ro_a, int64_t co_a, -/// DenseSkOp &S, int64_t ro_s, int64_t co_s, T beta, T *B, int64_t ldb -/// ) -/// @verbatim embed:rst:leading-slashes -/// Sketch from the right in an SpMM-like operation -/// -/// .. math:: -/// \mat(B) = \alpha \cdot \underbrace{\op(\submat(A))}_{m \times n} \cdot \underbrace{\op(\submat(S))}_{n \times d} + \beta \cdot \underbrace{\mat(B)}_{m \times d}, \tag{$\star$} -/// -/// where :math:`\alpha` and :math:`\beta` are real scalars, :math:`\op(X)` either returns a matrix :math:`X` -/// or its transpose, :math:`A` is a sparse matrix, and :math:`S` is a dense sketching operator. -/// -/// .. dropdown:: FAQ -/// :animate: fade-in-slide-down -/// -/// **What's** :math:`\mat(B)` **?** -/// -/// It's matrix of shape :math:`m \times d`. Its contents are determined by :math:`(B, \ldb)` -/// and "layout", following the same convention as the Level 3 BLAS function "GEMM." -/// -/// **What are** :math:`\submat(S)` **and** :math:`\submat(A)` **?** -/// -/// Their shapes are determined implicitly by :math:`(\opS, n, d)` and :math:`(\opA, m, n)`. -/// If :math:`{\submat(X)}` is of shape :math:`r \times c`, -/// then it is the :math:`r \times c` submatrix of :math:`{X}` whose upper-left corner -/// appears at index :math:`(\texttt{ro_x}, \texttt{co_x})` of :math:`{X}`. -/// -/// .. dropdown:: Full parameter descriptions -/// :animate: fade-in-slide-down -/// -/// layout - [in] -/// * Layout::ColMajor or Layout::RowMajor. -/// * Matrix storage for :math:`\mat(B)`. -/// -/// opA - [in] -/// * If :math:`\opA` == NoTrans, then :math:`\op(\submat(A)) = \submat(A)`. -/// * If :math:`\opA` == Trans, then :math:`\op(\submat(A)) = \submat(A)^T`. -/// -/// opS - [in] -/// * If :math:`\opS` = NoTrans, then :math:`\op(\submat(S)) = \submat(S)`. -/// * If :math:`\opS` = Trans, then :math:`\op(\submat(S)) = \submat(S)^T`. -/// -/// m - [in] -/// * A nonnegative integer. -/// * The number of rows in :math:`\mat(B)`. -/// * The number of rows in :math:`\op(\submat(A))`. -/// -/// d - [in] -/// * A nonnegative integer. -/// * The number of columns in :math:`\mat(B)` -/// * The number of columns in :math:`\op(\submat(S))`. -/// -/// n - [in] -/// * A nonnegative integer. -/// * The number of columns in :math:`\op(\submat(A))` -/// * The number of rows in :math:`\op(\submat(S))`. -/// -/// alpha - [in] -/// * A real scalar. -/// -/// A - [in] -/// * A RandBLAS sparse matrix object. -/// * Defines :math:`\submat(A)`. -/// -/// ro_a - [in] -/// * A nonnegative integer. -/// * The rows of :math:`\submat(A)` are a contiguous subset of rows of :math:`A`. -/// * The rows of :math:`\submat(A)` start at :math:`A[\texttt{ro_a}, :]`. -/// -/// co_a - [in] -/// * A nonnegative integer. -/// * The columns of :math:`\submat(A)` are a contiguous subset of columns of :math:`A`. -/// * The columns :math:`\submat(A)` start at :math:`A[:,\texttt{co_a}]`. -/// -/// S - [in] -/// * A DenseSkOp object. -/// * Defines :math:`\submat(S)`. -/// -/// ro_s - [in] -/// * A nonnegative integer. -/// * The rows of :math:`\submat(S)` are a contiguous subset of rows of :math:`S`. -/// * The rows of :math:`\submat(S)` start at :math:`S[\texttt{ro_s}, :]`. -/// -/// co_s - [in] -/// * A nonnegative integer. -/// * The columns of :math:`\submat(S)` are a contiguous subset of columns of :math:`S`. -/// * The columns :math:`\submat(S)` start at :math:`S[:,\texttt{co_s}]`. -/// -/// beta - [in] -/// * A real scalar. -/// * If zero, then :math:`B` need not be set on input. -/// -/// B - [in, out] -/// * Pointer to 1D array of real scalars. -/// * On entry, defines :math:`\mat(B)` -/// on the RIGHT-hand side of :math:`(\star)`. -/// * On exit, defines :math:`\mat(B)` -/// on the LEFT-hand side of :math:`(\star)`. -/// -/// ldb - [in] -/// * A nonnegative integer. -/// * Leading dimension of :math:`\mat(B)` when reading from :math:`B`. -/// -/// @endverbatim -template -void rsksp3( - blas::Layout layout, - blas::Op opA, - blas::Op opS, - int64_t m, // B is m-by-d - int64_t d, // op(submat(A)) is m-by-n - int64_t n, // op(submat(S)) is n-by-d - T alpha, - SpMat &A, - int64_t ro_a, - int64_t co_a, - DenseSkOp &S, - int64_t ro_s, - int64_t co_s, - T beta, - T *B, - int64_t ldb -) { - auto [rows_submat_S, cols_submat_S] = dims_before_op(n, d, opS); - if (!S.buff) { - T *buff = new T[rows_submat_S * cols_submat_S]; - fill_dense(S.dist, rows_submat_S, cols_submat_S, ro_s, co_s, buff, S.seed_state); - DenseDist D{rows_submat_S, cols_submat_S, DenseDistName::BlackBox, S.dist.major_axis}; - DenseSkOp S_(D, S.seed_state, buff); - rsksp3(layout, opA, opS, m, d, n, alpha, A, ro_a, co_a, S_, 0, 0, beta, B, ldb); - delete [] buff; - return; - } - auto [rows_submat_A, cols_submat_A] = dims_before_op(m, n, opA); - randblas_require( A.n_rows >= rows_submat_A + ro_a ); - randblas_require( A.n_cols >= cols_submat_A + co_a ); - randblas_require( S.dist.n_rows >= rows_submat_S + ro_s ); - randblas_require( S.dist.n_cols >= cols_submat_S + co_s ); - if (layout == blas::Layout::ColMajor) { - randblas_require(ldb >= m); - } else { - randblas_require(ldb >= d); - } - - auto [pos, lds] = offset_and_ldim(S.layout, S.dist.n_rows, S.dist.n_cols, ro_s, co_s); - T* S_ptr = &S.buff[pos]; - if (S.layout != layout) - opS = (opS == blas::Op::NoTrans) ? blas::Op::Trans : blas::Op::NoTrans; - - left_spmm(layout, opA, opS, m, d, n, alpha, A, ro_a, co_a, S_ptr, lds, beta, B, ldb); - return; -} - -} // end namespace RandBLAS -#endif diff --git a/RandBLAS/sparse_skops.hh b/RandBLAS/sparse_skops.hh index 22d7005c..770d5076 100644 --- a/RandBLAS/sparse_skops.hh +++ b/RandBLAS/sparse_skops.hh @@ -42,9 +42,6 @@ #include #include #include -#if defined(RandBLAS_HAS_OpenMP) -#include -#endif #define MAX(a, b) (((a) < (b)) ? (b) : (a)) #define MIN(a, b) (((a) < (b)) ? (a) : (b)) diff --git a/RandBLAS/util.hh b/RandBLAS/util.hh index 579fd932..b05c51ed 100644 --- a/RandBLAS/util.hh +++ b/RandBLAS/util.hh @@ -111,19 +111,6 @@ void print_colmaj(int64_t n_rows, int64_t n_cols, T *a, char label[]) return; } -template -bool compare_ctr(typename RNG::ctr_type c1, typename RNG::ctr_type c2) { - int len = c1.size(); - - for (int ind = len - 1; ind >= 0; ind--) { - if (c1[ind] > c2[ind]) { - return true; - } else if (c1[ind] < c2[ind]) { - return false; - } - } - return false; -} template std::string type_name() { // call as type_name() diff --git a/test/test_matmul_cores/linop_common.hh b/test/test_matmul_cores/linop_common.hh index d04388c5..0383f63d 100644 --- a/test/test_matmul_cores/linop_common.hh +++ b/test/test_matmul_cores/linop_common.hh @@ -33,8 +33,7 @@ #include "RandBLAS/base.hh" #include "RandBLAS/dense_skops.hh" #include "RandBLAS/sparse_skops.hh" -#include "RandBLAS/skge3_to_gemm.hh" -#include "RandBLAS/skges_to_spmm.hh" +#include "RandBLAS/skge.hh" #include "RandBLAS/sparse_data/spmm_dispatch.hh" #include "RandBLAS/util.hh" #include "test/comparison.hh"