Skip to content

Commit

Permalink
add test for correct incrementing of RNGState in fill_sparse_unpacked…
Browse files Browse the repository at this point in the history
…_nosub. Fix bugs in counter incrementing in the process.
  • Loading branch information
rileyjmurray committed Sep 29, 2024
1 parent 264ce8d commit a2af88a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
3 changes: 2 additions & 1 deletion RandBLAS/sparse_skops.hh
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ RNGState<RNG> compute_next_state(SparseDist dist, RNGState<RNG> state) {
// See repeated_fisher_yates.
} else {
num_mavec = std::min(dist.n_rows, dist.n_cols);
incrs_per_mavec = dist.vec_nnz * ((int64_t) state.len_c/2);
incrs_per_mavec = (int64_t) std::ceil((double) dist.vec_nnz / ((double) state.len_c/2));
// ^ LASOs do try to be frugal with CBRNG increments.
// See sample_indices_iid_uniform.
}
Expand Down Expand Up @@ -571,6 +571,7 @@ void fill_sparse(SparseSkOp &S) {
randblas_require(S.cols != nullptr);
randblas_require(S.vals != nullptr);
fill_sparse_unpacked_nosub(S.dist, S.nnz, S.vals, S.rows, S.cols, S.seed_state);
// ^ We ignore the return value from that function call.
return;
}

Expand Down
8 changes: 5 additions & 3 deletions RandBLAS/util.hh
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ void weights_to_cdf(int64_t n, T* w, T error_if_below = -sqrt_epsilon<T>()) {
}

template <typename TO, typename TI>
static inline TO uneg11_to_uneg01(TI in) {
static inline TO uneg11_to_u01(TI in) {
return ((TO) in + (TO) 1.0)/ ((TO) 2.0);
}

Expand Down Expand Up @@ -458,11 +458,12 @@ state_t sample_indices_iid(int64_t n, const T* cdf, int64_t k, sint_t* samples,
rv_array = r123ext::uneg11::generate(gen, ctr, key);
rv_index = 0;
}
auto random_unif01 = uneg11_to_uneg01<T>(rv_array[rv_index]);
auto random_unif01 = uneg11_to_u01<T>(rv_array[rv_index]);
sint_t sample_index = std::lower_bound(cdf, cdf + n, random_unif01) - cdf;
samples[i] = sample_index;
rv_index += 1;
}
ctr.incr(1);
return state_t(ctr, key);
}

Expand All @@ -480,7 +481,7 @@ state_t sample_indices_iid_uniform(int64_t n, int64_t k, sint_t* samples, T* rad
int64_t rv_index = 0;
double dN = (double) n;
for (int64_t i = 0; i < k; ++i) {
auto random_unif01 = uneg11_to_uneg01<double>(rv_array[rv_index]);
auto random_unif01 = uneg11_to_u01<double>(rv_array[rv_index]);
sint_t sample_index = (sint_t) dN * random_unif01;
samples[i] = sample_index;
rv_index += 1;
Expand All @@ -494,6 +495,7 @@ state_t sample_indices_iid_uniform(int64_t n, int64_t k, sint_t* samples, T* rad
rv_index = 0;
}
}
if (rv_index < len_c) ctr.incr(1);
return state_t(ctr, key);
}

Expand Down
36 changes: 36 additions & 0 deletions test/test_datastructures/test_sparseskop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@
#include <gtest/gtest.h>
#include <cmath>

using std::vector;
using RandBLAS::RNGState;
using RandBLAS::SignedInteger;
using RandBLAS::SparseDist;
using RandBLAS::SparseSkOp;
using RandBLAS::Axis;
using RandBLAS::fill_sparse;
using RandBLAS::fill_sparse_unpacked_nosub;


class TestSparseSkOpConstruction : public ::testing::Test
Expand Down Expand Up @@ -143,6 +145,31 @@ class TestSparseSkOpConstruction : public ::testing::Test
test::comparison::buffs_approx_equal(vals.data(), vals_copy.data(), sd.full_nnz, __PRETTY_FUNCTION__, __FILE__, __LINE__, (T) 0, (T) 0);
return;
}

void unpacked_nosub(const SparseDist &D) {
RNGState<RandBLAS::DefaultRNG> s(1);
SparseSkOp<float> S(D, s);
auto expect_next = S.next_state;
fill_sparse(S);
vector<int64_t> rows(D.full_nnz);
vector<int64_t> cols(D.full_nnz);
vector<float> vals(D.full_nnz);
int64_t nnz = 0;
auto actual_next = fill_sparse_unpacked_nosub(
D, nnz, vals.data(), rows.data(), cols.data(), s
);
EXPECT_EQ(S.nnz, nnz);
EXPECT_TRUE(actual_next == expect_next);
test::comparison::buffs_approx_equal(
vals.data(), S.vals, nnz, __PRETTY_FUNCTION__, __FILE__, __LINE__
);
test::comparison::buffs_approx_equal(
rows.data(), S.rows, nnz, __PRETTY_FUNCTION__, __FILE__, __LINE__
);
test::comparison::buffs_approx_equal(
cols.data(), S.cols, nnz, __PRETTY_FUNCTION__, __FILE__, __LINE__
);
}
};

TEST_F(TestSparseSkOpConstruction, respect_ownership) {
Expand All @@ -155,6 +182,15 @@ TEST_F(TestSparseSkOpConstruction, respect_ownership) {
respect_ownership<int>(7, 20);
}

TEST_F(TestSparseSkOpConstruction, fill_unpacked_nosub_saso) {
unpacked_nosub({10,20,7,Axis::Short});
unpacked_nosub({20,10,7,Axis::Short});
}

TEST_F(TestSparseSkOpConstruction, fill_unpacked_nosub_laso) {
unpacked_nosub({10,20,7,Axis::Long});
unpacked_nosub({20,10,7,Axis::Long});
}

////////////////////////////////////////////////////////////////////////
//
Expand Down

0 comments on commit a2af88a

Please sign in to comment.