Skip to content

Commit

Permalink
rename DenseDistName to ScalarDist
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Aug 26, 2024
1 parent f08c4bf commit e2e05af
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 94 deletions.
35 changes: 18 additions & 17 deletions RandBLAS/dense_skops.hh
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ static RNGState<RNG> fill_dense_submat_impl(int64_t n_cols, T* smat, int64_t n_s
template <typename RNG, typename DD>
RNGState<RNG> compute_next_state(DD dist, RNGState<RNG> state) {
if (dist.major_axis == MajorAxis::Undefined) {
// implies dist.family = DenseDistName::BlackBox
// implies dist.family = ScalarDist::BlackBox
return state;
}
// ^ This is the only place where MajorAxis is actually used to some
Expand All @@ -190,9 +190,10 @@ RNGState<RNG> compute_next_state(DD dist, RNGState<RNG> state) {
return state;
}

template <typename DDN>
inline double isometry_scale(DDN dn, int64_t n_rows, int64_t n_cols) {
return (dn == DDN::BlackBox) ? 1.0 : std::pow(std::min(n_rows, n_cols), -0.5);
// We only template this function because ScalarDistribution has defined later.
template <typename ScalarDistribution>
inline double isometry_scale(ScalarDistribution sd, int64_t n_rows, int64_t n_cols) {
return (sd == ScalarDistribution::BlackBox) ? 1.0 : std::pow(std::min(n_rows, n_cols), -0.5);
}

inline blas::Layout natural_layout(MajorAxis major_axis, int64_t n_rows, int64_t n_cols) {
Expand Down Expand Up @@ -222,7 +223,7 @@ namespace RandBLAS {
/// For implementation reasons, we also expose an option to indicate that an
/// operator's distribution is unknown but it is still represented by a buffer
/// that can be used in GEMM.
enum class DenseDistName : char {
enum class ScalarDist : char {
// ---------------------------------------------------------------------------
/// Indicates the Gaussian distribution with mean 0 and variance 1.
Gaussian = 'G',
Expand Down Expand Up @@ -266,7 +267,7 @@ struct DenseDist {

// ---------------------------------------------------------------------------
/// The distribution used for the entries of the sketching operator.
const DenseDistName family;
const ScalarDist family;

// ---------------------------------------------------------------------------
/// The natural memory layout implied by major_axis, n_rows, and n_cols.
Expand All @@ -278,19 +279,19 @@ struct DenseDist {
/// A distribution over matrices of shape (n_rows, n_cols) with entries drawn
/// iid from a mean-zero variance-one distribution.
/// The default distribution is standard-normal. One can opt for the uniform
/// distribution over \math{[-\sqrt{3}, \sqrt{3}]} by setting dn = DenseDistName::Uniform.
/// distribution over \math{[-\sqrt{3}, \sqrt{3}]} by setting family = ScalarDist::Uniform.
DenseDist(
int64_t n_rows,
int64_t n_cols,
DenseDistName dn = DenseDistName::Gaussian,
ScalarDist family = ScalarDist::Gaussian,
MajorAxis ma = MajorAxis::Long
) : // variable definitions
n_rows(n_rows), n_cols(n_cols), major_axis(ma),
isometry_scale(dense::isometry_scale(dn, n_rows, n_cols)),
family(dn),
isometry_scale(dense::isometry_scale(family, n_rows, n_cols)),
family(family),
natural_layout(dense::natural_layout(ma, n_rows, n_cols))
{ // argument validation
if (dn == DenseDistName::BlackBox) {
if (family == ScalarDist::BlackBox) {
randblas_require(ma == MajorAxis::Undefined);
} else {
randblas_require(ma != MajorAxis::Undefined);
Expand Down Expand Up @@ -417,7 +418,7 @@ struct DenseSkOp {
{ // sanity checks
randblas_require(this->dist.n_rows > 0);
randblas_require(this->dist.n_cols > 0);
randblas_require(this->dist.family != DenseDistName::BlackBox);
randblas_require(this->dist.family != ScalarDist::BlackBox);
}

///---------------------------------------------------------------------------
Expand Down Expand Up @@ -536,16 +537,16 @@ RNGState<RNG> fill_dense(blas::Layout layout, const DenseDist &D, int64_t n_rows
}
RNGState<RNG> next_state{};
switch (D.family) {
case DenseDistName::Gaussian: {
case ScalarDist::Gaussian: {
next_state = fill_dense_submat_impl<T,RNG,r123ext::boxmul>(ma_len, buff, n_rows_, n_cols_, ptr, seed);
break;
}
case DenseDistName::Uniform: {
case ScalarDist::Uniform: {
next_state = fill_dense_submat_impl<T,RNG,r123ext::uneg11>(ma_len, buff, n_rows_, n_cols_, ptr, seed);
blas::scal(n_rows_ * n_cols_, (T)std::sqrt(3), buff, 1);
break;
}
case DenseDistName::BlackBox: {
case ScalarDist::BlackBox: {
throw std::invalid_argument(std::string("fill_dense cannot be called with the BlackBox distribution."));
}
default: {
Expand Down Expand Up @@ -603,7 +604,7 @@ void fill_dense(DenseSkOp &S) {
randblas_require(S.buff == nullptr);
// TODO: articulate why S.own_memory == true is important. (It's because it safeguards
// against the chance of introducing a memory leak.)
randblas_require(S.dist.family != DenseDistName::BlackBox);
randblas_require(S.dist.family != ScalarDist::BlackBox);
using T = typename DenseSkOp::scalar_t;
S.buff = new T[S.dist.n_rows * S.dist.n_cols];
fill_dense(S.dist, S.buff, S.seed_state);
Expand All @@ -618,7 +619,7 @@ DenseSkOp submatrix_as_blackbox(const DenseSkOp &S, int64_t n_rows, int64_t n_co
T *buff = new T[n_rows * n_cols];
auto layout = S.dist.natural_layout;
fill_dense(layout, S.dist, n_rows, n_cols, ro_s, co_s, buff, S.seed_state);
DenseDist submatrix_dist(n_rows, n_cols, DenseDistName::BlackBox, MajorAxis::Undefined);
DenseDist submatrix_dist(n_rows, n_cols, ScalarDist::BlackBox, MajorAxis::Undefined);
DenseSkOp submatrix(submatrix_dist, S.seed_state, S.next_state, buff, layout);
return submatrix;
}
Expand Down
4 changes: 2 additions & 2 deletions examples/sparse-low-rank-approx/qrcp_matrixmarket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,11 @@ void power_iter_col_sketch(SpMat &A, int64_t k, T* Y, int64_t p_data_aware, STAT

int64_t p_done = 0;
if (p_data_aware % 2 == 0) {
RandBLAS::DenseDist D(k, m, RandBLAS::DenseDistName::Gaussian);
RandBLAS::DenseDist D(k, m, RandBLAS::ScalarDist::Gaussian);
TIMED_LINE(
RandBLAS::fill_dense(D, mat_work2, state), "sampling : ")
} else {
RandBLAS::DenseDist D(k, n, RandBLAS::DenseDistName::Gaussian);
RandBLAS::DenseDist D(k, n, RandBLAS::ScalarDist::Gaussian);
TIMED_LINE(
RandBLAS::fill_dense(D, mat_work1, state), "sampling : ")
TIMED_LINE(
Expand Down
4 changes: 2 additions & 2 deletions examples/sparse-low-rank-approx/svd_rank1_plus_noise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ void iid_sparsify_random_dense(
int64_t n_rows, int64_t n_cols, int64_t stride_row, int64_t stride_col, T* mat, T prob_of_zero, RandBLAS::RNGState<RNG> state
) {
auto spar = new T[n_rows * n_cols];
auto dist = RandBLAS::DenseDist(n_rows, n_cols, RandBLAS::DenseDistName::Uniform);
auto dist = RandBLAS::DenseDist(n_rows, n_cols, RandBLAS::ScalarDist::Uniform);
auto next_state = RandBLAS::fill_dense(dist, spar, state);

auto temp = new T[n_rows * n_cols];
auto D_mat = RandBLAS::DenseDist(n_rows, n_cols, RandBLAS::DenseDistName::Uniform);
auto D_mat = RandBLAS::DenseDist(n_rows, n_cols, RandBLAS::ScalarDist::Uniform);
RandBLAS::fill_dense(D_mat, temp, next_state);

#define SPAR(_i, _j) spar[(_i) + (_j) * n_rows]
Expand Down
2 changes: 1 addition & 1 deletion rtd/source/api_reference/skops_and_dists.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Distributions
:project: RandBLAS
:members:

.. doxygenenum:: RandBLAS::DenseDistName
.. doxygenenum:: RandBLAS::ScalarDist
:project: RandBLAS

.. dropdown:: SparseDist : a distribution over structured sparse matrices
Expand Down
2 changes: 1 addition & 1 deletion test/test_basic_rng/benchmark_speed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ int main(int argc, char **argv)
int64_t m = atoi(argv[1]);
int64_t n = atoi(argv[2]);
int64_t d = m*n;
RandBLAS::DenseDist dist{m, n, RandBLAS::DenseDistName::Uniform};
RandBLAS::DenseDist dist{m, n, RandBLAS::ScalarDist::Uniform};

std::vector<T> mat(d);

Expand Down
52 changes: 26 additions & 26 deletions test/test_basic_rng/test_continuous.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include "RandBLAS/util.hh"
#include "RandBLAS/dense_skops.hh"
using RandBLAS::RNGState;
using RandBLAS::DenseDistName;
using RandBLAS::ScalarDist;
#include "rng_common.hh"

#include <algorithm>
Expand All @@ -54,12 +54,12 @@ class TestScalarDistributions : public ::testing::Test {

template <typename T>
static void kolmogorov_smirnov_tester(
std::vector<T> &samples, double critical_value, DenseDistName dn
std::vector<T> &samples, double critical_value, ScalarDist sd
) {
auto F_true = [dn](T x) {
if (dn == DenseDistName::Gaussian) {
auto F_true = [sd](T x) {
if (sd == ScalarDist::Gaussian) {
return RandBLAS_StatTests::standard_normal_cdf(x);
} else if (dn == DenseDistName::Uniform) {
} else if (sd == ScalarDist::Uniform) {
return RandBLAS_StatTests::uniform_syminterval_cdf(x, (T) std::sqrt(3));
} else {
std::string msg = "Unrecognized distributions name";
Expand Down Expand Up @@ -108,59 +108,59 @@ class TestScalarDistributions : public ::testing::Test {
}

template <typename T>
static void run(double significance, int64_t num_samples, DenseDistName dn, uint32_t seed) {
static void run(double significance, int64_t num_samples, ScalarDist sd, uint32_t seed) {
using RandBLAS_StatTests::KolmogorovSmirnovConstants::critical_value_rep_mutator;
auto critical_value = critical_value_rep_mutator(num_samples, significance);
RNGState state(seed);
std::vector<T> samples(num_samples, -1);
RandBLAS::fill_dense({num_samples, 1, dn, RandBLAS::MajorAxis::Long}, samples.data(), state);
kolmogorov_smirnov_tester(samples, critical_value, dn);
RandBLAS::fill_dense({num_samples, 1, sd, RandBLAS::MajorAxis::Long}, samples.data(), state);
kolmogorov_smirnov_tester(samples, critical_value, sd);
return;
}
};

TEST_F(TestScalarDistributions, uniform_ks_generous) {
double s = 1e-6;
for (uint32_t i = 999; i < 1011; ++i) {
run<double>(s, 100000, DenseDistName::Uniform, i);
run<double>(s, 10000, DenseDistName::Uniform, i*i);
run<double>(s, 1000, DenseDistName::Uniform, i*i*i);
run<double>(s, 100000, ScalarDist::Uniform, i);
run<double>(s, 10000, ScalarDist::Uniform, i*i);
run<double>(s, 1000, ScalarDist::Uniform, i*i*i);
}
}

TEST_F(TestScalarDistributions, uniform_ks_moderate) {
double s = 1e-4;
run<float>(s, 100000, DenseDistName::Uniform, 0);
run<float>(s, 10000, DenseDistName::Uniform, 0);
run<float>(s, 1000, DenseDistName::Uniform, 0);
run<float>(s, 100000, ScalarDist::Uniform, 0);
run<float>(s, 10000, ScalarDist::Uniform, 0);
run<float>(s, 1000, ScalarDist::Uniform, 0);
}

TEST_F(TestScalarDistributions, uniform_ks_skeptical) {
double s = 1e-2;
run<float>(s, 100000, DenseDistName::Uniform, 0);
run<float>(s, 10000, DenseDistName::Uniform, 0);
run<float>(s, 1000, DenseDistName::Uniform, 0);
run<float>(s, 100000, ScalarDist::Uniform, 0);
run<float>(s, 10000, ScalarDist::Uniform, 0);
run<float>(s, 1000, ScalarDist::Uniform, 0);
}

TEST_F(TestScalarDistributions, guassian_ks_generous) {
double s = 1e-6;
for (uint32_t i = 99; i < 103; ++i) {
run<double>(s, 100000, DenseDistName::Gaussian, i);
run<double>(s, 10000, DenseDistName::Gaussian, i*i);
run<double>(s, 1000, DenseDistName::Gaussian, i*i*i);
run<double>(s, 100000, ScalarDist::Gaussian, i);
run<double>(s, 10000, ScalarDist::Gaussian, i*i);
run<double>(s, 1000, ScalarDist::Gaussian, i*i*i);
}
}

TEST_F(TestScalarDistributions, guassian_ks_moderate) {
double s = 1e-4;
run<float>(s, 100000, DenseDistName::Gaussian, 0);
run<float>(s, 10000, DenseDistName::Gaussian, 0);
run<float>(s, 1000, DenseDistName::Gaussian, 0);
run<float>(s, 100000, ScalarDist::Gaussian, 0);
run<float>(s, 10000, ScalarDist::Gaussian, 0);
run<float>(s, 1000, ScalarDist::Gaussian, 0);
}

TEST_F(TestScalarDistributions, guassian_ks_skeptical) {
double s = 1e-2;
run<float>(s, 100000, DenseDistName::Gaussian, 0);
run<float>(s, 10000, DenseDistName::Gaussian, 0);
run<float>(s, 1000, DenseDistName::Gaussian, 0);
run<float>(s, 100000, ScalarDist::Gaussian, 0);
run<float>(s, 10000, ScalarDist::Gaussian, 0);
run<float>(s, 1000, ScalarDist::Gaussian, 0);
}
10 changes: 5 additions & 5 deletions test/test_basic_rng/test_distortion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include "RandBLAS/util.hh"
#include "RandBLAS/dense_skops.hh"
using RandBLAS::DenseDist;
using RandBLAS::DenseDistName;
using RandBLAS::ScalarDist;
using RandBLAS::RNGState;

#include "rng_common.hh"
Expand All @@ -47,14 +47,14 @@ class TestSubspaceDistortion : public ::testing::Test {
protected:

template <typename T>
void run_general(DenseDistName name, T distortion, int64_t d, int64_t N, uint32_t key) {
void run_general(ScalarDist name, T distortion, int64_t d, int64_t N, uint32_t key) {
auto layout = blas::Layout::ColMajor;
DenseDist D(d, N, name);
std::vector<T> S(d*N);
std::cout << "(d, N) = ( " << d << ", " << N << " )\n";
RandBLAS::RNGState<r123::Philox4x32> state(key);
auto next_state = RandBLAS::fill_dense(D, S.data(), state);
T inv_stddev = (name == DenseDistName::Gaussian) ? (T) 1.0 : (T) 1.0;
T inv_stddev = (name == ScalarDist::Gaussian) ? (T) 1.0 : (T) 1.0;
blas::scal(d*N, inv_stddev / std::sqrt(d), S.data(), 1);
std::vector<T> G(N*N, 0.0);
blas::syrk(layout, blas::Uplo::Upper, blas::Op::Trans, N, d, (T)1.0, S.data(), d, (T)0.0, G.data(), N);
Expand Down Expand Up @@ -100,7 +100,7 @@ class TestSubspaceDistortion : public ::testing::Test {
val *= val;
int64_t N = (int64_t) std::ceil(val);
int64_t d = std::ceil( std::pow((1 + tau) / distortion, 2) * N );
run_general<T>(DenseDistName::Gaussian, distortion, d, N, key);
run_general<T>(ScalarDist::Gaussian, distortion, d, N, key);
return;
}

Expand All @@ -111,7 +111,7 @@ class TestSubspaceDistortion : public ::testing::Test {
T epsnet_spectralnorm_factor = 1.0; // should be 4.0
T theta = epsnet_spectralnorm_factor * c6 * (rate + std::log(9));
int64_t d = std::ceil(N * theta * std::pow(distortion, -2));
run_general<T>(DenseDistName::Uniform, distortion, d, N, key);
run_general<T>(ScalarDist::Uniform, distortion, d, N, key);
return;
}
};
Expand Down
Loading

0 comments on commit e2e05af

Please sign in to comment.