Skip to content

Commit

Permalink
remove Axis::Undefined. Removed ScalarDist::BlackBox. Create BLASFrie…
Browse files Browse the repository at this point in the history
…ndlyOperator struct to replace the old abuses of DenseSkOp that used Axis::Undefined and ScalarDist::BlackBox.
  • Loading branch information
rileyjmurray committed Sep 1, 2024
1 parent 0fef526 commit 2311ab3
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 115 deletions.
13 changes: 9 additions & 4 deletions RandBLAS/base.hh
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,17 @@ enum class Axis : char {
Short = 'S',

// ---------------------------------------------------------------------------
Long = 'L',

// ---------------------------------------------------------------------------
Undefined = 'U'
Long = 'L'
};

inline int64_t get_dim_major(Axis major_axis, int64_t n_rows, int64_t n_cols) {
if (major_axis == Axis::Long) {
return std::max(n_rows, n_cols);
} else {
return std::min(n_rows, n_cols);
}
}


#ifdef __cpp_concepts
// =============================================================================
Expand Down
110 changes: 51 additions & 59 deletions RandBLAS/dense_skops.hh
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,6 @@ static RNGState<RNG> fill_dense_submat_impl(int64_t n_cols, T* smat, int64_t n_s

template <typename RNG, typename DD>
RNGState<RNG> compute_next_state(DD dist, RNGState<RNG> state) {
if (dist.major_axis == Axis::Undefined) {
// implies dist.family = ScalarDist::BlackBox
throw std::invalid_argument("Cannot compute next_state when dist.family is BlackBox");
}
// ^ This is the only place where Axis is actually used to some
// productive end.
int64_t major_len = dist.dim_major;
int64_t minor_len = dist.dim_minor;
int64_t ctr_size = RNG::ctr_type::static_size;
Expand All @@ -191,8 +185,6 @@ RNGState<RNG> compute_next_state(DD dist, RNGState<RNG> state) {
}

inline blas::Layout natural_layout(Axis major_axis, int64_t n_rows, int64_t n_cols) {
if (major_axis == Axis::Undefined || n_rows == n_cols)
return blas::Layout::ColMajor;
bool is_wide = n_rows < n_cols;
bool fa_long = major_axis == Axis::Long;
if (is_wide && fa_long) {
Expand All @@ -214,9 +206,7 @@ namespace RandBLAS {
// =============================================================================
/// We support two distributions for dense sketching operators: those whose
/// entries are iid Gaussians or iid uniform over a symmetric interval.
/// For implementation reasons, we also expose an option to indicate that an
/// operator's distribution is unknown but it is still represented by a buffer
/// that can be used in GEMM.
///
enum class ScalarDist : char {
// ---------------------------------------------------------------------------
/// Indicates the Gaussian distribution with mean 0 and variance 1.
Expand All @@ -225,12 +215,7 @@ enum class ScalarDist : char {
// ---------------------------------------------------------------------------
/// Indicates the uniform distribution over [-r, r] where r := sqrt(3)
/// is the radius that provides for a variance of 1.
Uniform = 'U',

// ---------------------------------------------------------------------------
/// Indicates that the sketching operator's entries will only be specified by
/// a user-provided buffer.
BlackBox = 'B'
Uniform = 'U'
};

// =============================================================================
Expand Down Expand Up @@ -274,6 +259,7 @@ struct DenseDist {
/// Defined as \math{\ttt{n_rows} + \ttt{n_cols} - \ttt{dim_major}.} This is
/// just whichever of \math{(\ttt{n_rows}, \ttt{n_cols})} wasn't identified
/// as \math{\ttt{dim_major}.}
///
const int64_t dim_minor;

// ---------------------------------------------------------------------------
Expand All @@ -282,7 +268,7 @@ struct DenseDist {
const double isometry_scale;

// ---------------------------------------------------------------------------
/// The distribution used for the entries of operators sampled from this distribution.
/// The distribution on \math{\mathbb{R}} for entries of operators sampled from this distribution.
const ScalarDist family;

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -325,9 +311,9 @@ struct DenseDist {
) : // variable definitions
n_rows(n_rows), n_cols(n_cols),
major_axis(major_axis),
dim_major((major_axis == Axis::Short) ? std::min(n_rows, n_cols) : std::max(n_rows, n_cols)),
dim_minor(n_rows + n_cols - dim_major),
isometry_scale((family == ScalarDist::BlackBox) ? 1.0 : std::pow(dim_minor, -0.5)),
dim_major((major_axis == Axis::Long) ? std::max(n_rows, n_cols) : std::min(n_rows, n_cols)),
dim_minor((major_axis == Axis::Long) ? std::min(n_rows, n_cols) : std::max(n_rows, n_cols)),
isometry_scale(std::pow(dim_minor, -0.5)),
family(family),
natural_layout(dense::natural_layout(major_axis, n_rows, n_cols))
{ // argument validation
Expand Down Expand Up @@ -395,17 +381,22 @@ struct DenseSkOp {
/// as a dense matrix.
///
/// If non-null this must point to an array of length at least
/// \math{\ttt{dist.n_cols * dist.n_rows}.} Furthermore, we will presume that
/// this array contains the random samples from \math{\ttt{dist}} implied
/// by \math{\ttt{seed_state}.} See DenseSkOp::layout for more information.
/// \math{\ttt{dist.n_cols * dist.n_rows},} and this array must contain the
/// random samples from \math{\ttt{dist}} implied by \math{\ttt{seed_state}.} See DenseSkOp::layout for more information.
T *buff = nullptr;

// ---------------------------------------------------------------------------
/// The storage order that should be used for any read or write operations
/// with \math{\ttt{buff}.} The leading dimension when reading from the buffer is
/// \math{\ttt{S.dist.dim_major}.}
/// with \math{\ttt{buff}.} The leading dimension when reading from \math{\ttt{buff}}
/// is assumed to be
/// @verbatim embed:rst:leading-slashes
/// .. math::
///
/// \ttt{lds} = \begin{cases} \ttt{n_rows} & \text{ if } ~~ \ttt{layout == ColMajor} \\ \ttt{n_cols} & \text{ if } ~~ \ttt{layout == RowMajor}.
///
/// @endverbatim
///
blas::Layout layout;
const blas::Layout layout;


/////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -441,26 +432,6 @@ struct DenseSkOp {
buff(nullptr),
layout(dist.natural_layout) { }

/// ---------------------------------------------------------------------------
/// **Expert constructor**. Arguments passed to this function are
/// used to initialize members of the same name. own_memory is initialized to false.
///
DenseSkOp(
DenseDist dist,
const state_t &seed_state,
const state_t &next_state,
// ^ It would be nice to set next_state in an initializer list based on seed_state like we do with SparseSkOp.
// We can't do that since the possibility of dist.family == BlackBox means we might be allowed to handle
// random number generation. When this constructor is used it's the user's responsibility to set next_state
// correctly based on the value of buff (or the value of buff that they intend to use eventually). If a user
// is confident that they won't need next_state then they can just set it to state_t(0).
T *buff,
blas::Layout layout
) :
dist(dist), seed_state(seed_state), next_state(next_state),
n_rows(dist.n_rows), n_cols(dist.n_cols),
own_memory(false), buff(buff), layout(layout) { }

// Move constructor
DenseSkOp(
DenseSkOp<T,RNG> &&S
Expand Down Expand Up @@ -596,9 +567,6 @@ RNGState<RNG> fill_dense(blas::Layout layout, const DenseDist &D, int64_t n_rows
blas::scal(n_rows_ * n_cols_, (T)std::sqrt(3), buff, 1);
break;
}
case ScalarDist::BlackBox: {
throw std::invalid_argument(std::string("fill_dense cannot be called with the BlackBox distribution."));
}
default: {
throw std::runtime_error(std::string("Unrecognized distribution."));
}
Expand Down Expand Up @@ -650,31 +618,55 @@ RNGState<RNG> fill_dense(const DenseDist &D, T *buff, const RNGState<RNG> &seed)
/// On exit, one can encode a BLAS-style representation of \math{\ttt{S}} with the tuple
/// @verbatim embed:rst:leading-slashes
/// .. math::
/// (\ttt{S.layout},~\ttt{S.n_rows},~\ttt{S.n_cols},~\ttt{S.buff},~\ttt{S.dist.dim_major})
/// @endverbatim
///
/// (\ttt{S.layout},~\ttt{S.n_rows},~\ttt{S.n_cols},~\ttt{S.buff},~\ttt{lds})
///
/// where
///
/// .. math::
///
/// \ttt{lds} = \begin{cases} \ttt{n_rows} & \text{ if } ~~ \ttt{S.layout == ColMajor} \\ \ttt{n_cols} & \text{ if } ~~ \ttt{S.layout == RowMajor}.
///
/// @endverbatim
template <typename DenseSkOp>
void fill_dense(DenseSkOp &S) {
if (S.own_memory && S.buff == nullptr) {
using T = typename DenseSkOp::scalar_t;
S.buff = new T[S.n_rows * S.n_cols];
}
randblas_require(S.buff != nullptr);
fill_dense(S.dist, S.buff, S.seed_state);
fill_dense(S.layout, S.dist, S.n_rows, S.n_cols, 0, 0, S.buff, S.seed_state);
return;
}

template <typename DenseSkOp>
DenseSkOp submatrix_as_blackbox(const DenseSkOp &S, int64_t n_rows, int64_t n_cols, int64_t ro_s, int64_t co_s) {
template <typename T>
struct BLASFriendlyOperator {
using scalar_t = T;
const blas::Layout layout;
const int64_t n_rows;
const int64_t n_cols;
T* buff;
const int64_t ldim;
const bool own_memory;

~BLASFriendlyOperator() {
if (own_memory && buff != nullptr) {
delete [] buff;
}
}
};

// NOTE: the returned operator satisfies submatrix.layout == S.dist.natural_layout even if this differs from S.layout.
template <typename BFO, typename DenseSkOp>
BFO submatrix_as_blackbox(const DenseSkOp &S, int64_t n_rows, int64_t n_cols, int64_t ro_s, int64_t co_s) {
randblas_require(ro_s + n_rows <= S.n_rows);
randblas_require(co_s + n_cols <= S.n_cols);
using T = typename DenseSkOp::scalar_t;
T *buff = new T[n_rows * n_cols];
auto layout = S.dist.natural_layout;
auto layout = S.layout;
fill_dense(layout, S.dist, n_rows, n_cols, ro_s, co_s, buff, S.seed_state);
DenseDist submatrix_dist(n_rows, n_cols, ScalarDist::BlackBox, Axis::Undefined);
DenseSkOp submatrix(submatrix_dist, S.seed_state, S.next_state, buff, layout);
submatrix.own_memory = true;
int64_t dim_major = S.dist.dim_major;
BFO submatrix{layout, n_rows, n_cols, buff, dim_major, true};
return submatrix;
}

Expand Down
66 changes: 38 additions & 28 deletions RandBLAS/skge.hh
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ using RandBLAS::fill_dense;
/// - Leading dimension of \math{\mat(B)} when reading from \math{B}.
/// - Refer to documentation for \math{\lda} for details.
///
template <typename T, typename RNG>
template <typename T, typename DenseSkOp>
void lskge3(
blas::Layout layout,
blas::Op opS,
Expand All @@ -158,7 +158,7 @@ void lskge3(
int64_t n, // op(A) is m-by-n
int64_t m, // op(S) is d-by-m
T alpha,
DenseSkOp<T,RNG> &S,
DenseSkOp &S,
int64_t ro_s,
int64_t co_s,
const T *A,
Expand All @@ -168,13 +168,19 @@ void lskge3(
int64_t ldb
){
auto [rows_submat_S, cols_submat_S] = dims_before_op(d, m, opS);
if (!S.buff) {
auto submat_S = submatrix_as_blackbox(S, rows_submat_S, cols_submat_S, ro_s, co_s);
lskge3(layout, opS, opA, d, n, m, alpha, submat_S, 0, 0, A, lda, beta, B, ldb);
return;
constexpr bool maybe_denseskop = !std::is_same_v<std::remove_cv_t<DenseSkOp>, BLASFriendlyOperator<T>>;
if constexpr (maybe_denseskop) {
if (!S.buff) {
// DenseSkOp doesn't permit defining a "black box" distribution, so we have to pack the submatrix
// into an equivalent datastructure ourselves.
auto submat_S = submatrix_as_blackbox<BLASFriendlyOperator<T>>(S, rows_submat_S, cols_submat_S, ro_s, co_s);
lskge3(layout, opS, opA, d, n, m, alpha, submat_S, 0, 0, A, lda, beta, B, ldb);
return;
} // else, continue with the function as usual.
}
randblas_require( S.dist.n_rows >= rows_submat_S + ro_s );
randblas_require( S.dist.n_cols >= cols_submat_S + co_s );
randblas_require( S.buff != nullptr );
randblas_require( S.n_rows >= rows_submat_S + ro_s );
randblas_require( S.n_cols >= cols_submat_S + co_s );
auto [rows_A, cols_A] = dims_before_op(m, n, opA);
if (layout == blas::Layout::ColMajor) {
randblas_require(lda >= rows_A);
Expand All @@ -184,7 +190,7 @@ void lskge3(
randblas_require(ldb >= n);
}

auto [pos, lds] = offset_and_ldim(S.layout, S.dist.n_rows, S.dist.n_cols, ro_s, co_s);
auto [pos, lds] = offset_and_ldim(S.layout, S.n_rows, S.n_cols, ro_s, co_s);
T* S_ptr = &S.buff[pos];
if (S.layout != layout)
opS = (opS == blas::Op::NoTrans) ? blas::Op::Trans : blas::Op::NoTrans;
Expand Down Expand Up @@ -296,7 +302,7 @@ void lskge3(
/// - Leading dimension of \math{\mat(B)} when reading from \math{B}.
/// - Refer to documentation for \math{\lda} for details.
///
template <typename T, typename RNG>
template <typename T, typename DenseSkOp>
void rskge3(
blas::Layout layout,
blas::Op opA,
Expand All @@ -307,23 +313,27 @@ void rskge3(
T alpha,
const T *A,
int64_t lda,
DenseSkOp<T,RNG> &S,
DenseSkOp &S,
int64_t ro_s,
int64_t co_s,
T beta,
T *B,
int64_t ldb
){
auto [rows_submat_S, cols_submat_S] = dims_before_op(n, d, opS);
if (!S.buff) {
// We'll make a shallow copy of the sketching operator, take responsibility for filling the memory
// of that sketching operator, and then call RSKGE3 with that new object.
auto submat_S = submatrix_as_blackbox(S, rows_submat_S, cols_submat_S, ro_s, co_s);
rskge3(layout, opA, opS, m, d, n, alpha, A, lda, submat_S, 0, 0, beta, B, ldb);
return;
constexpr bool maybe_denseskop = !std::is_same_v<std::remove_cv_t<DenseSkOp>, BLASFriendlyOperator<T>>;
if constexpr (maybe_denseskop) {
if (!S.buff) {
// DenseSkOp doesn't permit defining a "black box" distribution, so we have to pack the submatrix
// into an equivalent datastructure ourselves.
auto submat_S = submatrix_as_blackbox<BLASFriendlyOperator<T>>(S, rows_submat_S, cols_submat_S, ro_s, co_s);
rskge3(layout, opA, opS, m, d, n, alpha, A, lda, submat_S, 0, 0, beta, B, ldb);
return;
}
}
randblas_require( S.dist.n_rows >= rows_submat_S + ro_s );
randblas_require( S.dist.n_cols >= cols_submat_S + co_s );
randblas_require( S.buff != nullptr );
randblas_require( S.n_rows >= rows_submat_S + ro_s );
randblas_require( S.n_cols >= cols_submat_S + co_s );
auto [rows_A, cols_A] = dims_before_op(m, n, opA);
if (layout == blas::Layout::ColMajor) {
randblas_require(lda >= rows_A);
Expand All @@ -333,7 +343,7 @@ void rskge3(
randblas_require(ldb >= d);
}

auto [pos, lds] = offset_and_ldim(S.layout, S.dist.n_rows, S.dist.n_cols, ro_s, co_s);
auto [pos, lds] = offset_and_ldim(S.layout, S.n_rows, S.n_cols, ro_s, co_s);
T* S_ptr = &S.buff[pos];
if (S.layout != layout)
opS = (opS == blas::Op::NoTrans) ? blas::Op::Trans : blas::Op::NoTrans;
Expand Down Expand Up @@ -1070,11 +1080,11 @@ inline void sketch_general(
int64_t ldb
) {
if (opS == blas::Op::NoTrans) {
randblas_require(S.dist.n_rows == d);
randblas_require(S.dist.n_cols == m);
randblas_require(S.n_rows == d);
randblas_require(S.n_cols == m);
} else {
randblas_require(S.dist.n_rows == m);
randblas_require(S.dist.n_cols == d);
randblas_require(S.n_rows == m);
randblas_require(S.n_cols == d);
}
return sketch_general(layout, opS, opA, d, n, m, alpha, S, 0, 0, A, lda, beta, B, ldb);
};
Expand Down Expand Up @@ -1172,11 +1182,11 @@ inline void sketch_general(
int64_t ldb
) {
if (opS == blas::Op::NoTrans) {
randblas_require(S.dist.n_rows == n);
randblas_require(S.dist.n_cols == d);
randblas_require(S.n_rows == n);
randblas_require(S.n_cols == d);
} else {
randblas_require(S.dist.n_rows == d);
randblas_require(S.dist.n_cols == n);
randblas_require(S.n_rows == d);
randblas_require(S.n_cols == n);
}
return sketch_general(layout, opA, opS, m, d, n, alpha, A, lda, S, 0, 0, beta, B, ldb);
};
Expand Down
Loading

0 comments on commit 2311ab3

Please sign in to comment.