Skip to content

Commit

Permalink
heavily revise fill_dense_submat_impl
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Jul 27, 2024
1 parent 4dae5a0 commit 421197c
Showing 1 changed file with 30 additions and 76 deletions.
106 changes: 30 additions & 76 deletions RandBLAS/dense_skops.hh
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,6 @@

namespace RandBLAS::dense {

/// Strict "greater than" test for two counters of type RNG::ctr_type.
///
/// The counters are compared word by word, starting at the highest index
/// and moving toward index 0. The first position where the words differ
/// decides the outcome.
///
/// @returns true if c1 compares greater than c2 at the first differing
///          position; false if it compares smaller or if every word is equal.
template<typename RNG>
bool compare_ctr(typename RNG::ctr_type c1, typename RNG::ctr_type c2) {
    int pos = c1.size();
    while (pos-- > 0) {
        // Skip over words that do not distinguish the two counters.
        const bool first_is_larger = c1[pos] > c2[pos];
        const bool first_is_smaller = c1[pos] < c2[pos];
        if (first_is_larger || first_is_smaller) {
            return first_is_larger;
        }
    }
    // Every word matched, so c1 is not strictly greater than c2.
    return false;
}

template <typename RNG, typename DD>
static inline RNGState<RNG> compute_next_state(DD dist, RNGState<RNG> state) {
Expand Down Expand Up @@ -132,101 +119,68 @@ static RNGState<RNG> fill_dense_submat_impl(
RNG rng;
using CTR_t = typename RNG::ctr_type;
using KEY_t = typename RNG::key_type;
CTR_t c = seed.counter;
KEY_t k = seed.key;
int64_t ctr_size = CTR_t::static_size;
const int64_t ctr_size = CTR_t::static_size;

int64_t pad = 0;
// ^ computed such that n_cols+pad is divisible by ctr_size
// ^ computed such that n_cols+pad is divisible by ctr_size
if (n_cols % ctr_size != 0) {
pad = ctr_size - n_cols % ctr_size;
}


int64_t n_cols_padded = n_cols + pad;
// ^ pad as necessary in order to be divisible by ctr_size
int64_t ptr_padded = ptr + ptr / n_cols * pad;
const int64_t ptr_padded = ptr + ptr / n_cols * pad;
// ^ ptr corresponding to the padded matrix
int64_t ctr_mat_start = ptr_padded / ctr_size;
int64_t first_block_start = ptr_padded % ctr_size;
const int64_t ctr_mat_start = ptr_padded / ctr_size;
const int64_t first_block_start = ptr_padded % ctr_size;
// ^ counter and [position within the counter's array] for index "ptr_padded".
int64_t ctr_mat_row_end = (ptr_padded + n_scols - 1) / ctr_size;
int64_t last_block_stop = ((ptr_padded + n_scols - 1) % ctr_size) + 1;
const int64_t ctr_mat_row_end = (ptr_padded + n_scols - 1) / ctr_size;
const int64_t last_block_stop = ((ptr_padded + n_scols - 1) % ctr_size) + 1;
// ^ counter and [1 + position within the counter's array] for index "(ptr_padded + n_scols - 1)".
int64_t ctr_inter_row_stride = n_cols_padded / ctr_size;
const int64_t ctr_inter_row_stride = (n_cols + pad) / ctr_size;
// ^ number of counters between the first counter of a given row to the first counter of the next row;
bool one_block_per_row = ctr_mat_start == ctr_mat_row_end;
int64_t first_block_len = ((one_block_per_row) ? last_block_stop : ctr_size) - first_block_start;
const bool one_block_per_row = ctr_mat_start == ctr_mat_row_end;
const int64_t first_block_len = ((one_block_per_row) ? last_block_stop : ctr_size) - first_block_start;

int64_t num_threads = 1;
#if defined(RandBLAS_HAS_OpenMP)
#pragma omp parallel
{
num_threads = omp_get_num_threads();
}
#endif
CTR_t *ctr_arr = new CTR_t[num_threads]{c};
CTR_t temp_c = seed.counter;
temp_c.incr(ctr_mat_start);
const CTR_t c = temp_c;
const KEY_t k = seed.key;

#pragma omp parallel firstprivate(c, k)
#pragma omp parallel
{

auto cc = c;
int64_t prev = 0;
int64_t row_ctr_start, row_ctr_end;
int64_t thread = 0;

#pragma omp for
#pragma omp for schedule(static)
for (int64_t row = 0; row < n_srows; row++) {

#if defined(RandBLAS_HAS_OpenMP)
thread = omp_get_thread_num();
#endif

int64_t __r01_offset = safe_signed_int_product(ctr_inter_row_stride, row);
row_ctr_start = ctr_mat_start + __r01_offset;
row_ctr_end = ctr_mat_row_end + __r01_offset;

int64_t incr_from_c = safe_signed_int_product(ctr_inter_row_stride, row);

cc.incr(row_ctr_start - prev);
prev = row_ctr_start;
auto rv = OP::generate(rng, cc, k);
auto c_row = c;
c_row.incr(incr_from_c);
auto rv = OP::generate(rng, c_row, k);

T* smat_row = smat + row*lda;
for (int i = 0; i < first_block_len; i++) {
smat_row[i] = rv[i+first_block_start];
}
if ( one_block_per_row )
if ( one_block_per_row ) {
continue;

}
// middle blocks
int64_t ind = first_block_len;
int64_t implicit_ctr = row_ctr_start;
while( implicit_ctr < row_ctr_end - 1) {
cc.incr();
prev++;
rv = OP::generate(rng, cc, k);
for (int i = 0; i < (ctr_mat_row_end - ctr_mat_start - 1); ++i) {
c_row.incr();
rv = OP::generate(rng, c_row, k);
copy_promote(ctr_size, rv, smat_row + ind);
ind = ind + ctr_size;
implicit_ctr++;
}
// last block
cc.incr();
prev++;
rv = OP::generate(rng, cc, k);
c_row.incr();
rv = OP::generate(rng, c_row, k);
copy_promote(last_block_stop, rv, smat_row + ind);
ctr_arr[thread] = cc;
}
}

// find the largest counter in the counter array
CTR_t max_c = ctr_arr[0];
for (int i = 1; i < num_threads; i++) {
if (compare_ctr<RNG>(ctr_arr[i], max_c)) {
max_c = ctr_arr[i];
}
}
delete [] ctr_arr;

max_c.incr();
CTR_t max_c = c;
max_c.incr(n_srows * ctr_inter_row_stride);
return RNGState<RNG> {max_c, k};
}

Expand Down

0 comments on commit 421197c

Please sign in to comment.