diff --git a/RandBLAS/dense_skops.hh b/RandBLAS/dense_skops.hh index 2bb18869..5d1ce130 100644 --- a/RandBLAS/dense_skops.hh +++ b/RandBLAS/dense_skops.hh @@ -48,12 +48,6 @@ namespace RandBLAS::dense { -template -static inline RNGState compute_next_state(DD dist, RNGState state) { - // Need logic that depends on DenseDistName. - return RNGState(0); -} - template inline void copy_promote(int n, const T_IN &a, T_OUT* b) { for (int i = 0; i < n; ++i) @@ -184,6 +178,25 @@ static RNGState fill_dense_submat_impl( return RNGState {max_c, k}; } +template +RNGState compute_next_state(DD dist, RNGState state) { + if (dist.major_axis == MajorAxis::Undefined) { + // implies dist.family = DenseDistName::BlackBox + return state; + } + int64_t major_len = major_axis_length(dist); + int64_t minor_len = dist.n_rows + (dist.n_cols - major_len); + int64_t ctr_size = RNG::ctr_type::static_size; + int64_t pad = 0; + if (major_len % ctr_size != 0) { + pad = ctr_size - major_len % ctr_size; + } + int64_t ctr_major_axis_stride = (major_len + pad) / ctr_size; + int64_t full_incr = safe_signed_int_product(ctr_major_axis_stride, minor_len); + state.counter.incr(full_incr); + return state; +} + } // end namespace RandBLAS::dense diff --git a/RandBLAS/sparse_skops.hh b/RandBLAS/sparse_skops.hh index cd2ad3cc..47e27919 100644 --- a/RandBLAS/sparse_skops.hh +++ b/RandBLAS/sparse_skops.hh @@ -46,10 +46,6 @@ namespace RandBLAS::sparse { -template -static RNGState compute_next_state(SD dist, RNGState seed_state) { - return RNGState(0); -} // ============================================================================= /// WARNING: this function is not part of the public API. @@ -73,11 +69,11 @@ static RNGState repeated_fisher_yates( auto [ctr, key] = state; for (sint_t i = 0; i < dim_minor; ++i) { sint_t offset = i * vec_nnz; - auto ctri = ctr; - ctri.incr(offset); + auto ctr_work = ctr; + ctr_work.incr(offset); for (sint_t j = 0; j < vec_nnz; ++j) { // one step of Fisher-Yates shuffling - auto rv = gen(ctri, key); + auto rv = gen(ctr_work, key); sint_t ell = j + rv[0] % (dim_major - j); pivots[j] = ell; sint_t swap = vec_work[ell]; @@ -88,7 +84,7 @@ static RNGState repeated_fisher_yates( vals[j + offset] = (rv[1] % 2 == 0) ? 1.0 : -1.0; idxs_minor[j + offset] = (sint_t) i; // increment counter - ctri.incr(); + ctr_work.incr(); } // Restore vec_work for next iteration of Fisher-Yates. // This isn't necessary from a statistical perspective, @@ -101,10 +97,23 @@ static RNGState repeated_fisher_yates( vec_work[jj] = vec_work[ell]; vec_work[ell] = swap; } - ctr = ctri; } return RNGState {ctr, key}; } + +template +static RNGState compute_next_state(SD dist, RNGState state) { + int64_t minor_len; + if (dist.major_axis == MajorAxis::Short) { + minor_len = std::min(dist.n_rows, dist.n_cols); + } else { + minor_len = std::max(dist.n_rows, dist.n_cols); + } + int64_t full_incr = minor_len * dist.vec_nnz; + state.counter.incr(full_incr); + return state; +} + } namespace RandBLAS {