From 4416f3bf41fe5f4a01b34eb457001fe90e1103bb Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Fri, 2 Feb 2024 17:40:48 -0500 Subject: [PATCH] Enforce strict ordering of broadcasts in column/row For matrix A, we make sure that tiles are sent out in column-order (left to right). For matrix B, we want to make sure that tiles are sent out in row-order (top to bottom). Signed-off-by: Joseph Schuchart --- examples/spmm/spmm.cc | 246 ++++++++++++++++++++++++++++++++---------- 1 file changed, 188 insertions(+), 58 deletions(-) diff --git a/examples/spmm/spmm.cc b/examples/spmm/spmm.cc index f13d8aad2..5806f4548 100644 --- a/examples/spmm/spmm.cc +++ b/examples/spmm/spmm.cc @@ -43,6 +43,9 @@ using namespace ttg; using scalar_t = double; using blk_t = btas::Tensor, btas::Handle::shared_ptr>>; +//#include +//static std::atomic reduce_count = 0; + #if defined(TTG_USE_PARSEC) namespace ttg { template <> @@ -274,10 +277,11 @@ class SpMM25D { , ijk_keymap_(std::move(ijk_keymap)) , parallel_bcasts_(parallel_bcasts) { Edge, void> a_ctl, b_ctl; - bcast_a_ = std::make_unique(a, a_ctl, local_a_ijk_, a_cols_of_row_, b_cols_of_row_, + Edge, int> a_rowctl, b_colctl; // TODO: can we have multiple control inputs per TT? 
+ bcast_a_ = std::make_unique(a, a_ctl, a_rowctl, local_a_ijk_, a_rows_of_col_, a_cols_of_row_, b_cols_of_row_, ij_keymap_, ijk_keymap_, parallel_bcasts_); local_bcast_a_ = std::make_unique(local_a_ijk_, a_ijk_, b_cols_of_row_, ijk_keymap_); - bcast_b_ = std::make_unique(b, b_ctl, local_b_ijk_, a_rows_of_col_, b_cols_of_row_, + bcast_b_ = std::make_unique(b, b_ctl, b_colctl, local_b_ijk_, a_rows_of_col_, b_cols_of_row_, b_rows_of_col_, ij_keymap_, ijk_keymap_, parallel_bcasts_); local_bcast_b_ = std::make_unique(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_); multiplyadd_ = std::make_unique(a_ijk_, b_ijk_, c_ijk_, c_ij_p_, a_cols_of_row_, @@ -285,7 +289,7 @@ class SpMM25D { reduce_c_ = std::make_unique(c_ij_p_, c, ij_keymap_); reduce_c_->template set_input_reducer<0>( [&](Blk &c_ij, const Blk &c_ij_p) { - reduce_count++; + //reduce_count++; c_ij = c_ij + c_ij_p; }); // compute how many contributions each C[i][j] should expect ... MultiplyAdd already does this, but need a way to @@ -293,7 +297,9 @@ class SpMM25D { // this logic ... // TODO: do this in MultiplyAdd (need to allreduce this info so that everyone has it) // N.B. 
only need to set stream size on the rank that will accumulate the C[i][j] contribution - const auto my_rank = ttg::default_execution_context().rank(); + auto world = ttg::default_execution_context(); + const auto my_rank = world.rank(); + std::vector c_ij_procmask(world.size(), false); for (auto i = 0ul; i != a_cols_of_row_.size(); ++i) { if (a_cols_of_row_[i].empty()) continue; for (auto j = 0ul; j != b_rows_of_col_.size(); ++j) { @@ -303,54 +309,76 @@ class SpMM25D { decltype(i) k; bool have_k; std::tie(k, have_k) = multiplyadd_->compute_first_k(i, j); - std::vector c_ij_procmask(R, false); - if (have_k) { - const auto pR = k % R; // k values are distributed round-robin among the layers of the 3-D grid + while (have_k) { + const auto pR = ijk_keymap_(Key<3>{i, j, k}); assert(pR < c_ij_procmask.size()); c_ij_procmask[pR] = true; - while (have_k) { - std::tie(k, have_k) = multiplyadd_->compute_next_k(i, j, k); - if (have_k) { - const auto pR = k % R; - assert(pR < c_ij_procmask.size()); - c_ij_procmask[pR] = true; - } - } + /* get next k */ + std::tie(k, have_k) = multiplyadd_->compute_next_k(i, j, k); } const auto c_ij_nprocs = std::count_if(c_ij_procmask.begin(), c_ij_procmask.end(), [](bool b) { return b; }); if (c_ij_nprocs > 0) reduce_c_->template set_argstream_size<0>(Key<2>{i, j}, c_ij_nprocs); + /* reset the map */ + std::fill(c_ij_procmask.begin(), c_ij_procmask.end(), false); } } } - auto world = ttg::default_execution_context(); - /* initial ctl input for bcasts for A */ - int to_start = parallel_bcasts; - for (int i = 0; - 0 < to_start && i < a_cols_of_row_.size(); - ++i) { - for (auto k : a_cols_of_row_[i]) { - auto key = Key<2>(i, k); - if (world.rank() == ij_keymap_(key)) { - bcast_a_->template in<1>()->sendk(key); - if (0 == --to_start) break; - } + /* kick off the first broadcast in each row of A + * this is used to enforce strict ordering within a row of A */ + for (int i = 0; i < a_cols_of_row_.size(); ++i) { + for (int k : a_cols_of_row_[i]) { + 
auto key = Key<2>(i, k); + if (world.rank() == ij_keymap_(key)) { + bcast_a_->template in<1>()->send(key, 0); + break; } } + } - /* initial ctl input for bcasts for B */ - to_start = parallel_bcasts; - for (int k = 0; - 0 < to_start && k < b_cols_of_row_.size(); - ++k) { - for (auto j : b_cols_of_row_[k]) { - auto key = Key<2>(k, j); - if (world.rank() == ij_keymap_(key)) { - bcast_b_->template in<1>()->sendk(key); - if (0 == --to_start) break; - } + /* initial ctl input for a number of bcasts for A + * this is used to limit the number of concurrent bcasts */ + int to_start = parallel_bcasts; + for (int k = 0; + 0 < to_start && k < a_rows_of_col_.size(); + ++k) { + for (auto i : a_rows_of_col_[k]) { + auto key = Key<2>(i, k); + if (world.rank() == ij_keymap_(key)) { + //std::cout << "SPMM kick off BcastA " << key << std::endl; + bcast_a_->template in<2>()->sendk(key); + if (0 == --to_start) break; + } + } + } + + /* kick off the first broadcast in each column of B + * this is used to enforce strict ordering within a column of B */ + for (int j = 0; j < b_rows_of_col_.size(); ++j) { + for (int k : b_rows_of_col_[j]) { + auto key = Key<2>(k, j); + if (world.rank() == ij_keymap_(key)) { + //std::cout << "BcastB kick off " << key << std::endl; + bcast_b_->template in<1>()->send(key, 0); + break; + } + } + } + + /* initial ctl input for bcasts for B */ + to_start = parallel_bcasts; + for (int k = 0; + 0 < to_start && k < b_cols_of_row_.size(); + ++k) { + for (auto j : b_cols_of_row_[k]) { + auto key = Key<2>(k, j); + if (world.rank() == ij_keymap_(key)) { + //std::cout << "SPMM kick off BcastB " << key << std::endl; + bcast_b_->template in<2>()->sendk(key); + if (0 == --to_start) break; } } + } TTGUNUSED(bcast_a_); TTGUNUSED(bcast_b_); @@ -397,26 +425,34 @@ class SpMM25D { }; // class LocalBcastA /// broadcast `A[i][k]` to all processors which will contain at least one `C[i][j]` such that `B[k][j]` exists - class BcastA : public TT, std::tuple, Blk>, Out, void>>, 
BcastA, ttg::typelist> { + class BcastA : public TT, std::tuple, Blk>, Out, int>, Out, void>>, BcastA, ttg::typelist> { public: using baseT = typename BcastA::ttT; - BcastA(Edge, Blk> &a_ik, Edge, void> &ctl, Edge, Blk> &a_ikp, + BcastA(Edge, Blk> &a_ik, Edge, void> &ctl, + Edge, int> &rowctl, Edge, Blk> &a_ikp, + const std::vector> &a_rows_of_col, const std::vector> &a_cols_of_row, const std::vector> &b_cols_of_row, const Keymap2 &ij_keymap, const Keymap3 &ijk_keymap, const int parallel_bcasts) - : baseT(edges(a_ik, ctl), edges(a_ikp, ctl), "SpMM25D::bcast_a", {"a_ik", "ctl"}, {"a_ikp", "ctl"}, ij_keymap) + : baseT(edges(a_ik, rowctl, ctl), edges(a_ikp, rowctl, ctl), "SpMM25D::bcast_a", {"a_ik", "rowctl", "ctl"}, {"a_ikp", "rowctl", "ctl"}, ij_keymap) + , a_rows_of_col_(a_rows_of_col) , a_cols_of_row_(a_cols_of_row) , b_cols_of_row_(b_cols_of_row) , ijk_keymap_(ijk_keymap) , ij_keymap_(ij_keymap) - , parallel_bcasts_(parallel_bcasts) {} + , parallel_bcasts_(parallel_bcasts) { + + this->set_priomap([](const Key<2>& key){ + return std::numeric_limits::max() - key[0]; + }); + } void op(const Key<2> &ik, typename baseT::input_values_tuple_type &&a_ik, - std::tuple, Blk>, Out, void>> &outs) { - auto i = ik[0]; // row - auto k = ik[1]; // col + std::tuple, Blk>, Out, int>, Out, void>> &outs) { + const auto i = ik[0]; // row + const auto k = ik[1]; // col ttg::trace("BcastA(", i, ", ", k, ")"); std::vector> ikp_keys; @@ -428,19 +464,58 @@ class SpMM25D { {i, j, k})); // N.B. 
in 2.5D SUMMA different k contributions to C[i][j] are computed on different nodes if (!procmap[p]) { ttg::trace("Broadcasting A[", i, "][", k, "] to proc ", p); + //std::cout << "[" << world.rank() << "] BcastA key " << ik << " op " << Key<3>({i, j, k}) << " to proc " << p << std::endl; ikp_keys.emplace_back(Key<3>({i, k, p})); procmap[p] = true; } + // TODO: debug + //if (p != world.rank() && ij_keymap_(Key<2>{k, j}) != p) { + // std::cout << "[" << world.rank() << "] BCAST A " << ik << " for C update " << Key<3>({i, k, p}) << " on " << p << " has B from " << ij_keymap_(Key<2>{k, j}) << std::endl; + //} } ::broadcast<0>(ikp_keys, std::move(baseT::template get<0>(a_ik)), outs); + /* enable the next broadcast on this row */ + int row = i; + int col = k; + auto rowit = std::find(a_cols_of_row_[row].begin(), a_cols_of_row_[row].end(), col); + for (++rowit; rowit != a_cols_of_row_[row].end(); ++rowit) { + Key<2> key = {row, *rowit}; + if (world.rank() == this->get_keymap()(key)) { + ::send<1>(key, std::move(baseT::template get<1>(a_ik)), outs); + break; + } + } + + /* enable next broadcast through a control message * we don't check whether this tile is in B here, this is - * done inside the next task (see above) */ - std::size_t num_rows = a_cols_of_row_.size(); - auto it = std::find(a_cols_of_row_[i].begin(), a_cols_of_row_[i].end(), k); - ++it; // skip the current tile + * done inside the next task (see above) + * we walk the matrix A column-major in an attempt to send from top to bottom, left to right */ long to_skip = parallel_bcasts_; + + auto colit = std::find(a_rows_of_col_[col].begin(), a_rows_of_col_[col].end(), row); + ++colit; // skip to next row + do { + for (; colit != a_rows_of_col_[col].end(); ++colit) { + Key<2> key = {*colit, col}; + if (world.rank() == this->get_keymap()(key)) { + if (0 == --to_skip) { + //std::cout << "BcastA sending to " << key << " from " << ik << std::endl; + ::sendk<2>(key, outs); + return; + } + } + } + /* nothing for us in 
this column, move on to the next column */ + if (++col < a_rows_of_col_.size()) { + colit = a_rows_of_col_[col].begin(); + } else { + break; + } + } while (1); + +#if 0 do { for (; it != a_cols_of_row_[i].end(); ++it) { Key<2> key = {i, *it}; @@ -457,9 +532,12 @@ class SpMM25D { break; } } while (1); +#endif // 0 } private: + //const std::vector> &a_cols_of_row_; + const std::vector> &a_rows_of_col_; const std::vector> &a_cols_of_row_; const std::vector> &b_cols_of_row_; const Keymap3 &ijk_keymap_; @@ -505,25 +583,32 @@ class SpMM25D { }; // class LocalBcastB /// broadcast `B[k][j]` to all processors which will contain at least one `C[i][j]` such that `A[i][k]` exists - class BcastB : public TT, std::tuple, Blk>, Out, void>>, BcastB, ttg::typelist> { + class BcastB : public TT, std::tuple, Blk>, Out, int>, Out, void>>, BcastB, ttg::typelist> { public: using baseT = typename BcastB::ttT; - BcastB(Edge, Blk> &b_kj, Edge, void> ctl, Edge, Blk> &b_kjp, + BcastB(Edge, Blk> &b_kj, Edge, void> ctl, Edge, int> colctl, Edge, Blk> &b_kjp, const std::vector> &a_rows_of_col, const std::vector> &b_cols_of_row, + const std::vector> &b_rows_of_col, const Keymap2 &ij_keymap, const Keymap3 &ijk_keymap, const int parallel_bcasts) - : baseT(edges(b_kj, ctl), edges(b_kjp, ctl), "SpMM25D::bcast_b", {"b_kj", "ctl"}, {"b_kjp", "ctl"}, ij_keymap) + : baseT(edges(b_kj, colctl, ctl), edges(b_kjp, colctl, ctl), "SpMM25D::bcast_b", {"b_kj", "colctl", "ctl"}, {"b_kjp", "colctl", "ctl"}, ij_keymap) , a_rows_of_col_(a_rows_of_col) , b_cols_of_row_(b_cols_of_row) + , b_rows_of_col_(b_rows_of_col) , ijk_keymap_(ijk_keymap) - , parallel_bcasts_(parallel_bcasts) {} + , parallel_bcasts_(parallel_bcasts) + { + this->set_priomap([](const Key<2>& key){ + return std::numeric_limits::max() - key[1]; + }); + } void op(const Key<2> &kj, typename baseT::input_values_tuple_type &&b_kj, - std::tuple, Blk>, Out, void>> &outs) { - auto k = kj[0]; // row - auto j = kj[1]; // col + std::tuple, Blk>, Out, int>, 
Out, void>> &outs) { + const auto k = kj[0]; // row + const auto j = kj[1]; // col // broadcast b_kj to all processors which will contain at least one c_ij such that a_ik exists std::vector> kjp_keys; ttg::trace("BcastB(", k, ", ", j, ")"); @@ -534,15 +619,58 @@ class SpMM25D { long p = ijk_keymap_(Key<3>({i, j, k})); if (!procmap[p]) { ttg::trace("Broadcasting B[", k, "][", j, "] to proc ", p); + //std::cout << "[" << world.rank() << "] BcastB key " << kj << " op " << Key<3>({i, j, k}) << " to proc " << p << std::endl; kjp_keys.emplace_back(Key<3>({k, j, p})); procmap[p] = true; } } ::broadcast<0>(kjp_keys, std::move(baseT::template get<0>(b_kj)), outs); + /* enable the next broadcast in this column */ + int row = k; + int col = j; + auto colit = std::find(b_rows_of_col_[col].begin(), b_rows_of_col_[col].end(), row); + for (++colit; colit != b_rows_of_col_[col].end(); ++colit) { + Key<2> key = {*colit, col}; + if (world.rank() == this->get_keymap()(key)) { + //std::cout << "BcastB kick off " << key << std::endl; + ::send<1>(key, std::move(baseT::template get<1>(b_kj)), outs); + break; + } + } + /* enable next broadcast through a control message * we don't check whether this tile is in A here, this is - * done inside the next task (see above) */ + * done inside the next task (see above) + * we run across a row to enable broadcasts */ + long to_skip = parallel_bcasts_; + + // iterator over the current row + auto rowit = std::find(b_cols_of_row_[row].begin(), b_cols_of_row_[row].end(), col); + ++rowit; // skip to next col + do { + for (; rowit != b_cols_of_row_[row].end(); ++rowit) { + Key<2> key = {row, *rowit}; + if (world.rank() == this->get_keymap()(key)) { + if (0 == --to_skip) { + //std::cout << "BcastB sending to " << key << " from " << kj << " pb " << parallel_bcasts_ << std::endl; + ::sendk<2>(key, outs); + return; + } else { + //std::cout << "BcastB skipping " << key << " from " << kj << " pb " << parallel_bcasts_ << std::endl; + } + } + } + /* nothing for us 
in this row, move on to the next row */ + if (++row != b_cols_of_row_.size()) { + rowit = b_cols_of_row_[row].begin(); + } else { + break; + } + } while (1); + + +#if 0 std::size_t num_rows = b_cols_of_row_.size(); auto it = std::find(b_cols_of_row_[k].begin(), b_cols_of_row_[k].end(), j); ++it; // skip the current tile @@ -563,11 +691,13 @@ class SpMM25D { break; } } while (1); +#endif // 0 } private: const std::vector> &a_rows_of_col_; const std::vector> &b_cols_of_row_; + const std::vector> &b_rows_of_col_; const Keymap3 &ijk_keymap_; const int parallel_bcasts_; }; // class BcastB @@ -1376,7 +1506,7 @@ static void timed_measurement(SpMatrix<> &A, SpMatrix<> &B, const std::function< SpMatrix<> C; C.resize(MT, NT); - reduce_count = 0; + //reduce_count = 0; // flow graph needs to exist on every node Edge> ctl("control");