Skip to content

Commit

Permalink
API tweaks, plus docs
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Jul 4, 2024
1 parent 4bfe1e0 commit adab6c6
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 25 deletions.
13 changes: 11 additions & 2 deletions RandBLAS/base.hh
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,18 @@
/// code common across the project
namespace RandBLAS {

/**
* Stores stride information for a matrix represented as a buffer.
* The intended semantics for a buffer "A" and the conceptualized
* matrix "mat(A)" are
*
* mat(A)[i, j] == A[i * inter_row_stride + j * inter_col_stride].
*
* for all (i, j) within the bounds of mat(A).
*/
struct stride_64t {
int64_t inter_row_stride;
int64_t inter_col_stride;
int64_t inter_row_stride; // step down a column
int64_t inter_col_stride; // step along a row
};

inline stride_64t layout_to_strides(blas::Layout layout, int64_t ldim) {
Expand Down
49 changes: 30 additions & 19 deletions RandBLAS/util.hh
Original file line number Diff line number Diff line change
Expand Up @@ -244,39 +244,50 @@ void flip_layout(blas::Layout layout_in, int64_t m, int64_t n, std::vector<T> &A


template <typename T>
void weights_to_cdf(T* p, int64_t len_p) {
void weights_to_cdf(int64_t n, T* w, T error_if_below = -std::numeric_limits<T>::epsilon()) {
T sum = 0.0;
for (int64_t i = 0; i < len_p; ++i) {
sum += p[i];
p[i] = sum;
for (int64_t i = 0; i < n; ++i) {
T val = w[i];
randblas_require(val >= error_if_below);
val = std::max(val, (T) 0.0);
sum += val;
w[i] = sum;
}
blas::scal(len_p, ((T)1.0)/sum, p, 1);
randblas_require(sum >= ((T) std::sqrt(n)) * std::numeric_limits<T>::epsilon());
blas::scal(n, ((T)1.0) / sum, w, 1);
return;
}

template <typename TO, typename TI>
static inline TO uneg11_to_uneg01(TI in) {
return ((TO) in + (TO) 1.0)/ ((TO) 2.0);
}

/***
* Assume cdf is a buffer specifying a cumulative probability distribution function.
* cdf represents a cumulative distribution function over {0, ..., n - 1}.
*
* TF is a template parameter for a real floating point type.
*
* This function produces "num_samples" from the distribution specified by "cdf"
* and stores them in "samples".
* We overwrite the "samples" buffer with k (independent) samples from the
* distribution specified by cdf.
*/
template <typename TF, typename int64_t, typename RNG>
RNGState<RNG> sample_indices_iid(
TF* cdf, int64_t len_cdf, int64_t* samples , int64_t num_samples, RandBLAS::RNGState<RNG> state
int64_t n, TF* cdf, int64_t k, int64_t* samples, RandBLAS::RNGState<RNG> state
) {
auto [ctr, key] = state;
RNG gen;
auto rv_array = r123ext::uneg11::generate(gen, ctr, key);
int64_t len_c = (int64_t) state.len_c;
int64_t rv_index = 0;
for (int64_t i = 0; i < num_samples; ++i) {
for (int64_t i = 0; i < k; ++i) {
if ((i+1) % len_c == 1) {
ctr.incr(1);
rv_array = r123ext::uneg11::generate(gen, ctr, key);
rv_index = 0;
}
TF random_unif01 = ((TF) (rv_array[rv_index] + 1.0)) / ((TF) 2.0);
int64_t sample_index = std::lower_bound(cdf, cdf + len_cdf, random_unif01) - cdf;
auto random_unif01 = uneg11_to_uneg01<TF>(rv_array[rv_index]);
int64_t sample_index = std::lower_bound(cdf, cdf + n, random_unif01) - cdf;
// ^ uses binary search to set sample_index to the smallest value for which
// random_unif01 < cdf[sample_index].
samples[i] = sample_index;
Expand All @@ -286,27 +297,27 @@ RNGState<RNG> sample_indices_iid(
}

/***
* This function produces "num_samples" from the uniform distribution over
* {0, ..., max_index_exclusive - 1} and stores them in "samples".
* Overwrite the "samples" buffer with k (independent) samples from the
* uniform distribution over {0, ..., n - 1}.
*/
template <typename int64_t, typename RNG>
RNGState<RNG> sample_indices_iid_uniform(
int64_t max_index_exclusive, int64_t* samples , int64_t num_samples, RandBLAS::RNGState<RNG> state
int64_t n, int64_t* samples , int64_t k, RandBLAS::RNGState<RNG> state
) {
auto [ctr, key] = state;
RNG gen;
auto rv_array = r123ext::uneg11::generate(gen, ctr, key);
int64_t len_c = (int64_t) state.len_c;
int64_t rv_index = 0;
double dmie = (double) max_index_exclusive;
for (int64_t i = 0; i < num_samples; ++i) {
double dN = (double) n;
for (int64_t i = 0; i < k; ++i) {
if ((i+1) % len_c == 1) {
ctr.incr(1);
rv_array = r123ext::uneg11::generate(gen, ctr, key);
rv_index = 0;
}
double random_unif01 = (double) (rv_array[rv_index] + 1.0) / 2.0;
int64_t sample_index = (int64_t) dmie * random_unif01;
auto random_unif01 = uneg11_to_uneg01<double>(rv_array[rv_index]);
int64_t sample_index = (int64_t) dN * random_unif01;
samples[i] = sample_index;
rv_index += 1;
}
Expand Down
8 changes: 4 additions & 4 deletions test/test_basic_rng/test_sample_indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class TestSampleIndices : public ::testing::Test
std::vector<float> sample_cdf(N, 0.0);
for (int64_t s : samples)
sample_cdf[s] += 1;
RandBLAS::util::weights_to_cdf(sample_cdf.data(), N);
RandBLAS::util::weights_to_cdf(N, sample_cdf.data());

for (int i = 0; i < num_samples; ++i) {
auto diff = (double) std::abs(sample_cdf[i] - true_cdf[i]);
Expand All @@ -84,7 +84,7 @@ class TestSampleIndices : public ::testing::Test
auto critical_value = critical_value_rep_mutator(num_samples, significance);

std::vector<float> true_cdf(N, 1.0);
RandBLAS::util::weights_to_cdf(true_cdf.data(), N);
RandBLAS::util::weights_to_cdf(N, true_cdf.data());

RNGState state(seed);
std::vector<int64_t> samples(num_samples, -1);
Expand All @@ -102,11 +102,11 @@ class TestSampleIndices : public ::testing::Test
std::vector<float> true_cdf{};
for (int i = 0; i < N; ++i)
true_cdf.push_back(1.0/((float)i + 1.0));
RandBLAS::util::weights_to_cdf(true_cdf.data(), N);
RandBLAS::util::weights_to_cdf(N, true_cdf.data());

RNGState state(seed);
std::vector<int64_t> samples(num_samples, -1);
RandBLAS::util::sample_indices_iid(true_cdf.data(), N, samples.data(), num_samples, state);
RandBLAS::util::sample_indices_iid(N, true_cdf.data(), num_samples, samples.data(), state);

index_set_kolmogorov_smirnov_tester(samples, true_cdf, critical_value);
return;
Expand Down

0 comments on commit adab6c6

Please sign in to comment.