Skip to content

Commit

Permalink
Read{A,B} can use MultipleK schedule
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Feb 2, 2024
1 parent 1ffaf6a commit 3dad2c1
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,22 +238,39 @@ class ReadA : public TT<Key<2>, std::tuple<Out<Key<2>, Blk>>, ReadA<KSchedule, B
// this assumes col-major layout of SpMatrix
static_assert(SpMatrix<Blk>::IsRowMajor == false, "SpMatrix must be col-major");

// MultipleK schedule is not yet correctly implemented
static_assert(KSchedule == ReadSchedule::SingleK, "MultipleK schedule not yet implemented");

// loop over blocks of k at a time, block size controlled by KSchedule
const int k_blk_size = (KSchedule == ReadSchedule::SingleK) ? 1 : R_;

// this keeps iterators over each k-column in this block
// there is only 1 task per process, so no need to synchronize access to this
static std::vector<typename SpMatrix<Blk>::InnerIterator> column_iterators;

for (std::pair<int, int> k_blk = {0, std::min(k_blk_size, K)}; k_blk.first < K;
k_blk = {k_blk.first + k_blk_size, std::min(k_blk.first + k_blk_size + k_blk_size, K)}) {

// within the block, send one A[i][k] for each k in turn, then advance each column to its next i, and so on;
// this requires keeping an iterator over the k-th column for each k in the block
// N.B. : due to the CSC layout of A iterating over (blocks of) columns is efficient
column_iterators.resize(0);
for (int k = k_blk.first; k < k_blk.second; ++k) {
for (typename SpMatrix<Blk>::InnerIterator it(matrix_, k); it; ++it) {
assert(k == it.col());
const auto i = it.row();
// IF the receiver uses the same keymap, these sends are local
if (rank == this->get_keymap()(Key<2>(std::initializer_list<long>({i, k})))) {
::send<0>(Key<2>(std::initializer_list<long>({i, k})), it.value(), out);
column_iterators.emplace_back(matrix_, k);
}

int k_remaining = k_blk.second - k_blk.first;
while(k_remaining != 0) {
for (int k = k_blk.first, k_in_blk = 0; k < k_blk.second; ++k, ++k_in_blk) {
if (auto& it = column_iterators[k_in_blk]) {
assert(k == it.col());
const auto i = it.row();
// IF the receiver uses the same keymap, these sends are local
if (rank == this->get_keymap()(Key<2>(std::initializer_list<long>({i, k})))) {
::send<0>(Key<2>(std::initializer_list<long>({i, k})), it.value(), out);
}
++it;
}
else { // this k is done
assert(k_remaining != 0);
--k_remaining;
}
}
}
Expand Down Expand Up @@ -286,13 +303,10 @@ class ReadB : public TT<Key<2>, std::tuple<Out<Key<2>, Blk>>, ReadB<KSchedule, B
// this assumes col-major layout of SpMatrix
static_assert(SpMatrix<Blk>::IsRowMajor == false, "SpMatrix must be col-major");

// MultipleK schedule is not yet correctly implemented
static_assert(KSchedule == ReadSchedule::SingleK, "MultipleK schedule not yet implemented");

// loop over blocks of k at a time, block size controlled by KSchedule
const int k_blk_size = (KSchedule == ReadSchedule::SingleK) ? 1 : R_;
for (std::pair<int, int> k_blk = {0, std::min(k_blk_size, K)}; k_blk.first < K;
k_blk = {k_blk.first + k_blk_size, std::min(k_blk.first + k_blk_size + k_blk_size, K)}) {
k_blk = {k_blk.first + k_blk_size, std::min(k_blk.first + k_blk_size + k_blk_size, K)}) {

// WARNING : due to the CSC layout of B, iterating over (blocks of) rows k is inefficient: every column j must be scanned
for (int j = 0; j < matrix_.outerSize(); ++j) {
Expand Down Expand Up @@ -580,7 +594,7 @@ class SpMM25D {
const Keymap3 &ijk_keymap_;
}; // class BcastB

/// multiply task has 3 input flows: a_ijk, b_ijk, and c_ijk, c_ijk contains the running total for this kayer of the
/// multiply task has 3 input flows: a_ijk, b_ijk, and c_ijk, c_ijk contains the running total for this layer of the
/// 3-D process grid only
class MultiplyAdd : public TT<Key<3>, std::tuple<Out<Key<2>, Blk>, Out<Key<3>, Blk>>, MultiplyAdd,
ttg::typelist<const Blk, const Blk, Blk>> {
Expand Down Expand Up @@ -1691,6 +1705,9 @@ int main(int argc, char **argv) {
Edge<Key<2>, blk_t> eA, eB, eC;
Read_SpMatrix a("A", A, ctl, eA, ij_keymap);
Read_SpMatrix b("B", B, ctl, eB, ij_keymap);
// uncomment these lines (in place of the Read_SpMatrix declarations of a and b above) to use a more intelligent schedule of reads
// ReadA<> a(A, ctl, eA, ij_keymap, R);
// ReadB<> b(B, ctl, eB, ij_keymap, R);
Write_SpMatrix<> c(C, eC, keymap_write);
auto &c_status = c.status();
assert(!has_value(c_status));
Expand Down

0 comments on commit 3dad2c1

Please sign in to comment.