Skip to content

Commit

Permalink
start of implementation for sketch_symmetric functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Jun 11, 2024
1 parent 7dee45c commit a4b468e
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 0 deletions.
48 changes: 48 additions & 0 deletions RandBLAS/base.hh
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include "RandBLAS/config.h"
#include "RandBLAS/random_gen.hh"
#include "RandBLAS/exceptions.hh"

#include <blas.hh>
#include <tuple>
Expand Down Expand Up @@ -105,6 +106,53 @@ inline submat_spec_64t offset_and_ldim(
}
}

template <typename T>
void symmetrize(blas::Layout layout, blas::Uplo uplo, T* A, int64_t n, int64_t lda) {

auto [inter_row_stride, inter_col_stride] = layout_to_strides(layout, lda);
#define matA(_i, _j) A[(_i)*inter_row_stride + (_j)*inter_col_stride]
if (uplo == blas::Uplo::Upper) {
// copy to lower
for (int64_t i = 0; i < n; ++i) {
for (int64_t j = i+1; j < n; ++j) {
matA(j,i) = matA(i,j);
}
}
} else if (uplo == blas::Uplo::Lower) {
// copy to upper
for (int64_t i = 0; i < n; ++i) {
for (int64_t j = i+1; j < n; ++j) {
matA(i,j) = matA(j,i);
}
}
}
#undef matA
return;
}

template <typename T>
void require_symmetric(blas::Layout layout, T* A, int64_t n, int64_t lda, T tol) {

auto [inter_row_stride, inter_col_stride] = layout_to_strides(layout, lda);
#define matA(_i, _j) A[(_i)*inter_row_stride + (_j)*inter_col_stride]

for (int64_t i = 0; i < n; ++i) {
for (int64_t j = i+1; j < n; ++j) {
T Aij = matA(i,j);
T Aji = matA(j,i);
T viol = abs(Aij - Aji);
T rel_tol = (abs(Aij) + abs(Aji) + 1)*tol;
if (viol > rel_tol) {
randblas_error_if_msg(
viol > rel_tol
"Symmetry check failed. |A(%i,%i) - A(%i,%i)| was %d, which exceeds tolerance of %d", i, j, j, i, viol, rel_tol
);
}
}
}
#undef matA
return;
}

template<typename T>
concept SignedInteger = (std::numeric_limits<T>::is_signed && std::numeric_limits<T>::is_integer);
Expand Down
99 changes: 99 additions & 0 deletions RandBLAS/sksy.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright, 2024. See LICENSE for copyright holder information.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// (1) Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// (2) Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// (3) Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//

#ifndef randblas_sksy_hh
#define randblas_sksy_hh

#include "RandBLAS/base.hh"
#include "RandBLAS/skge.hh"

namespace RandBLAS {

using namespace RandBLAS::dense;
using namespace RandBLAS::sparse;

/* Intended macro definitions.
.. |op| mathmacro:: \operatorname{op}
.. |mat| mathmacro:: \operatorname{mat}
.. |submat| mathmacro:: \operatorname{submat}
.. |lda| mathmacro:: \texttt{lda}
.. |ldb| mathmacro:: \texttt{ldb}
.. |opA| mathmacro:: \texttt{opA}
.. |opS| mathmacro:: \texttt{opS}
*/

template <typename T, typename SKOP>
inline void sketch_symmetric(
// B = alpha*A*S + beta*B, where A is a symmetric matrix stored in the format of a general matrix.
blas::Layout layout, // layout for (A,B)
int64_t n, // number of rows in B
int64_t d, // number of columns in B
T alpha,
const T* A,
int64_t lda,
SKOP &S,
int64_t ro_s,
int64_t co_s,
T beta,
T* B,
int64_t ldb,
T sym_check_tol = -1
) {
if (sym_check_tol >= 0) {
require_symmetric(layout, A, n, lda, sym_check_tol);
}
sketch_general(layout, blas::Op::NoTrans, blas::Op::NoTrans, n, d, n, alpha, A, lda, S, ro_s, co_s, beta, B, ldb);
}

template <typename T, typename SKOP>
inline void sketch_symmetric(
// B = alpha*S*A + beta*B
blas::Layout layout, // layout for (A,B)
int64_t d, // number of rows in B
int64_t n, // number of columns in B
T alpha,
SKOP &S,
int64_t ro_s,
int64_t co_s,
const T* A,
int64_t lda,
T beta,
T* B,
int64_t ldb,
T sym_check_tol = -1
) {
if (sym_check_tol >= 0) {
require_symmetric(layout, A, n, lda, sym_check_tol);
}
sketch_general(layout, blas::Op::NoTrans, blas::Op::NoTrans, d, n, n, alpha, S, ro_s, co_s, A, beta, B, ldb);
}

} // end namespace RandBLAS
#endif

0 comments on commit a4b468e

Please sign in to comment.