From adab6c667ea7c13cb12bc05164a8e566ce0fdc03 Mon Sep 17 00:00:00 2001 From: Riley Murray Date: Wed, 3 Jul 2024 21:16:09 -0400 Subject: [PATCH] API tweaks, plus docs --- RandBLAS/base.hh | 13 +++++- RandBLAS/util.hh | 49 +++++++++++++--------- test/test_basic_rng/test_sample_indices.cc | 8 ++-- 3 files changed, 45 insertions(+), 25 deletions(-) diff --git a/RandBLAS/base.hh b/RandBLAS/base.hh index 62aa456a..a2382e80 100644 --- a/RandBLAS/base.hh +++ b/RandBLAS/base.hh @@ -53,9 +53,18 @@ /// code common across the project namespace RandBLAS { +/** + * Stores stride information for a matrix represented as a buffer. + * The intended semantics for a buffer "A" and the conceptualized + * matrix "mat(A)" are + * + * mat(A)[i, j] == A[i * inter_row_stride + j * inter_col_stride]. + * + * for all (i, j) within the bounds of mat(A). + */ struct stride_64t { - int64_t inter_row_stride; - int64_t inter_col_stride; + int64_t inter_row_stride; // step down a column + int64_t inter_col_stride; // step along a row }; inline stride_64t layout_to_strides(blas::Layout layout, int64_t ldim) { diff --git a/RandBLAS/util.hh b/RandBLAS/util.hh index 81babb0b..4dadf795 100644 --- a/RandBLAS/util.hh +++ b/RandBLAS/util.hh @@ -244,39 +244,50 @@ void flip_layout(blas::Layout layout_in, int64_t m, int64_t n, std::vector &A template -void weights_to_cdf(T* p, int64_t len_p) { +void weights_to_cdf(int64_t n, T* w, T error_if_below = -std::numeric_limits::epsilon()) { T sum = 0.0; - for (int64_t i = 0; i < len_p; ++i) { - sum += p[i]; - p[i] = sum; + for (int64_t i = 0; i < n; ++i) { + T val = w[i]; + randblas_require(val >= error_if_below); + val = std::max(val, (T) 0.0); + sum += val; + w[i] = sum; } - blas::scal(len_p, ((T)1.0)/sum, p, 1); + randblas_require(sum >= ((T) std::sqrt(n)) * std::numeric_limits::epsilon()); + blas::scal(n, ((T)1.0) / sum, w, 1); + return; +} + +template +static inline TO uneg11_to_uneg01(TI in) { + return ((TO) in + (TO) 1.0)/ ((TO) 2.0); } /*** - * Assume cdf is a buffer specifying a cumulative probability distribution function. + * cdf represents a cumulative distribution function over {0, ..., n - 1}. + * * TF is a template parameter for a real floating point type. * - * This function produces "num_samples" from the distribution specified by "cdf" - * and stores them in "samples". + * We overwrite the "samples" buffer with k (independent) samples from the + * distribution specified by cdf. */ template RNGState sample_indices_iid( - TF* cdf, int64_t len_cdf, int64_t* samples , int64_t num_samples, RandBLAS::RNGState state + int64_t n, TF* cdf, int64_t k, int64_t* samples, RandBLAS::RNGState state ) { auto [ctr, key] = state; RNG gen; auto rv_array = r123ext::uneg11::generate(gen, ctr, key); int64_t len_c = (int64_t) state.len_c; int64_t rv_index = 0; - for (int64_t i = 0; i < num_samples; ++i) { + for (int64_t i = 0; i < k; ++i) { if ((i+1) % len_c == 1) { ctr.incr(1); rv_array = r123ext::uneg11::generate(gen, ctr, key); rv_index = 0; } - TF random_unif01 = ((TF) (rv_array[rv_index] + 1.0)) / ((TF) 2.0); - int64_t sample_index = std::lower_bound(cdf, cdf + len_cdf, random_unif01) - cdf; + auto random_unif01 = uneg11_to_uneg01(rv_array[rv_index]); + int64_t sample_index = std::lower_bound(cdf, cdf + n, random_unif01) - cdf; // ^ uses binary search to set sample_index to the smallest value for which // random_unif01 < cdf[sample_index]. samples[i] = sample_index; @@ -286,27 +297,27 @@ RNGState sample_indices_iid( } /*** - * This function produces "num_samples" from the uniform distribution over - * {0, ..., max_index_exclusive - 1} and stores them in "samples". + * Overwrite the "samples" buffer with k (independent) samples from the + * uniform distribution over {0, ..., n - 1}. */ template RNGState sample_indices_iid_uniform( - int64_t max_index_exclusive, int64_t* samples , int64_t num_samples, RandBLAS::RNGState state + int64_t n, int64_t* samples , int64_t k, RandBLAS::RNGState state ) { auto [ctr, key] = state; RNG gen; auto rv_array = r123ext::uneg11::generate(gen, ctr, key); int64_t len_c = (int64_t) state.len_c; int64_t rv_index = 0; - double dmie = (double) max_index_exclusive; - for (int64_t i = 0; i < num_samples; ++i) { + double dN = (double) n; + for (int64_t i = 0; i < k; ++i) { if ((i+1) % len_c == 1) { ctr.incr(1); rv_array = r123ext::uneg11::generate(gen, ctr, key); rv_index = 0; } - double random_unif01 = (double) (rv_array[rv_index] + 1.0) / 2.0; - int64_t sample_index = (int64_t) dmie * random_unif01; + auto random_unif01 = uneg11_to_uneg01(rv_array[rv_index]); + int64_t sample_index = (int64_t) dN * random_unif01; samples[i] = sample_index; rv_index += 1; } diff --git a/test/test_basic_rng/test_sample_indices.cc b/test/test_basic_rng/test_sample_indices.cc index aaea689a..08557a8a 100644 --- a/test/test_basic_rng/test_sample_indices.cc +++ b/test/test_basic_rng/test_sample_indices.cc @@ -70,7 +70,7 @@ class TestSampleIndices : public ::testing::Test std::vector sample_cdf(N, 0.0); for (int64_t s : samples) sample_cdf[s] += 1; - RandBLAS::util::weights_to_cdf(sample_cdf.data(), N); + RandBLAS::util::weights_to_cdf(N, sample_cdf.data()); for (int i = 0; i < num_samples; ++i) { auto diff = (double) std::abs(sample_cdf[i] - true_cdf[i]); @@ -84,7 +84,7 @@ class TestSampleIndices : public ::testing::Test auto critical_value = critical_value_rep_mutator(num_samples, significance); std::vector true_cdf(N, 1.0); - RandBLAS::util::weights_to_cdf(true_cdf.data(), N); + RandBLAS::util::weights_to_cdf(N, true_cdf.data()); RNGState state(seed); std::vector samples(num_samples, -1); @@ -102,11 +102,11 @@ class TestSampleIndices : public ::testing::Test std::vector true_cdf{}; for (int i = 0; i < N; ++i) true_cdf.push_back(1.0/((float)i + 1.0)); - RandBLAS::util::weights_to_cdf(true_cdf.data(), N); + RandBLAS::util::weights_to_cdf(N, true_cdf.data()); RNGState state(seed); std::vector samples(num_samples, -1); - RandBLAS::util::sample_indices_iid(true_cdf.data(), N, samples.data(), num_samples, state); + RandBLAS::util::sample_indices_iid(N, true_cdf.data(), num_samples, samples.data(), state); index_set_kolmogorov_smirnov_tester(samples, true_cdf, critical_value); return;