Skip to content

Commit

Permalink
Enforce strict ordering of broadcasts in column/row
Browse files Browse the repository at this point in the history
For matrix A, we make sure that the tiles within each row are sent out in column order (left to right).
For matrix B, we make sure that the tiles within each column are sent out in row order (top to bottom).

Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Feb 2, 2024
1 parent becdbf4 commit 4416f3b
Showing 1 changed file with 188 additions and 58 deletions.
246 changes: 188 additions & 58 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ using namespace ttg;
using scalar_t = double;
using blk_t = btas::Tensor<scalar_t, btas::DEFAULT::range, btas::mohndle<btas::varray<scalar_t>, btas::Handle::shared_ptr>>;

//#include <atomic>
//static std::atomic<uint64_t> reduce_count = 0;

#if defined(TTG_USE_PARSEC)
namespace ttg {
template <>
Expand Down Expand Up @@ -274,26 +277,29 @@ class SpMM25D {
, ijk_keymap_(std::move(ijk_keymap))
, parallel_bcasts_(parallel_bcasts) {
Edge<Key<2>, void> a_ctl, b_ctl;
bcast_a_ = std::make_unique<BcastA>(a, a_ctl, local_a_ijk_, a_cols_of_row_, b_cols_of_row_,
Edge<Key<2>, int> a_rowctl, b_colctl; // TODO: can we have multiple control inputs per TT?
bcast_a_ = std::make_unique<BcastA>(a, a_ctl, a_rowctl, local_a_ijk_, a_rows_of_col_, a_cols_of_row_, b_cols_of_row_,
ij_keymap_, ijk_keymap_, parallel_bcasts_);
local_bcast_a_ = std::make_unique<LocalBcastA>(local_a_ijk_, a_ijk_, b_cols_of_row_, ijk_keymap_);
bcast_b_ = std::make_unique<BcastB>(b, b_ctl, local_b_ijk_, a_rows_of_col_, b_cols_of_row_,
bcast_b_ = std::make_unique<BcastB>(b, b_ctl, b_colctl, local_b_ijk_, a_rows_of_col_, b_cols_of_row_, b_rows_of_col_,
ij_keymap_, ijk_keymap_, parallel_bcasts_);
local_bcast_b_ = std::make_unique<LocalBcastB>(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_);
multiplyadd_ = std::make_unique<MultiplyAdd>(a_ijk_, b_ijk_, c_ijk_, c_ij_p_, a_cols_of_row_,
b_rows_of_col_, mTiles, nTiles, ijk_keymap_);
reduce_c_ = std::make_unique<ReduceC>(c_ij_p_, c, ij_keymap_);
reduce_c_->template set_input_reducer<0>(
[&](Blk &c_ij, const Blk &c_ij_p) {
reduce_count++;
//reduce_count++;
c_ij = c_ij + c_ij_p;
});
// compute how many contributions each C[i][j] should expect ... MultiplyAdd already does this, but need a way to
// send message from each process p to the process owning C[i][j] to expect a contribution from it for now replicate
// this logic ...
// TODO: do this in MultiplyAdd (need to allreduce this info so that everyone has it)
// N.B. only need to set stream size on the rank that will accumulate the C[i][j] contribution
const auto my_rank = ttg::default_execution_context().rank();
auto world = ttg::default_execution_context();
const auto my_rank = world.rank();
std::vector<bool> c_ij_procmask(world.size(), false);
for (auto i = 0ul; i != a_cols_of_row_.size(); ++i) {
if (a_cols_of_row_[i].empty()) continue;
for (auto j = 0ul; j != b_rows_of_col_.size(); ++j) {
Expand All @@ -303,54 +309,76 @@ class SpMM25D {
decltype(i) k;
bool have_k;
std::tie(k, have_k) = multiplyadd_->compute_first_k(i, j);
std::vector<bool> c_ij_procmask(R, false);
if (have_k) {
const auto pR = k % R; // k values are distributed round-robin among the layers of the 3-D grid
while (have_k) {
const auto pR = ijk_keymap_(Key<3>{i, j, k});
assert(pR < c_ij_procmask.size());
c_ij_procmask[pR] = true;
while (have_k) {
std::tie(k, have_k) = multiplyadd_->compute_next_k(i, j, k);
if (have_k) {
const auto pR = k % R;
assert(pR < c_ij_procmask.size());
c_ij_procmask[pR] = true;
}
}
/* get next k */
std::tie(k, have_k) = multiplyadd_->compute_next_k(i, j, k);
}
const auto c_ij_nprocs = std::count_if(c_ij_procmask.begin(), c_ij_procmask.end(), [](bool b) { return b; });
if (c_ij_nprocs > 0) reduce_c_->template set_argstream_size<0>(Key<2>{i, j}, c_ij_nprocs);
/* reset the map */
std::fill(c_ij_procmask.begin(), c_ij_procmask.end(), false);
}
}
}

auto world = ttg::default_execution_context();
/* initial ctl input for bcasts for A */
int to_start = parallel_bcasts;
for (int i = 0;
0 < to_start && i < a_cols_of_row_.size();
++i) {
for (auto k : a_cols_of_row_[i]) {
auto key = Key<2>(i, k);
if (world.rank() == ij_keymap_(key)) {
bcast_a_->template in<1>()->sendk(key);
if (0 == --to_start) break;
}
/* kick off the first broadcast in each row of A
* this is used to enforce strict ordering within a row of A */
for (int i = 0; i < a_cols_of_row_.size(); ++i) {
for (int k : a_cols_of_row_[i]) {
auto key = Key<2>(i, k);
if (world.rank() == ij_keymap_(key)) {
bcast_a_->template in<1>()->send(key, 0);
break;
}
}
}

/* initial ctl input for bcasts for B */
to_start = parallel_bcasts;
for (int k = 0;
0 < to_start && k < b_cols_of_row_.size();
++k) {
for (auto j : b_cols_of_row_[k]) {
auto key = Key<2>(k, j);
if (world.rank() == ij_keymap_(key)) {
bcast_b_->template in<1>()->sendk(key);
if (0 == --to_start) break;
}
/* initial ctl input for a number of bcasts for A
* this is used to limit the number of concurrent bcasts */
int to_start = parallel_bcasts;
for (int k = 0;
0 < to_start && k < a_rows_of_col_.size();
++k) {
for (auto i : a_rows_of_col_[k]) {
auto key = Key<2>(i, k);
if (world.rank() == ij_keymap_(key)) {
//std::cout << "SPMM kick off BcastA " << key << std::endl;
bcast_a_->template in<2>()->sendk(key);
if (0 == --to_start) break;
}
}
}

/* kick off the first broadcast in each column of B
* this is used to enforce strict ordering within a column of B */
for (int j = 0; j < b_rows_of_col_.size(); ++j) {
for (int k : b_rows_of_col_[j]) {
auto key = Key<2>(k, j);
if (world.rank() == ij_keymap_(key)) {
//std::cout << "BcastB kick off " << key << std::endl;
bcast_b_->template in<1>()->send(key, 0);
break;
}
}
}

/* initial ctl input for bcasts for B */
to_start = parallel_bcasts;
for (int k = 0;
0 < to_start && k < b_cols_of_row_.size();
++k) {
for (auto j : b_cols_of_row_[k]) {
auto key = Key<2>(k, j);
if (world.rank() == ij_keymap_(key)) {
//std::cout << "SPMM kick off BcastB " << key << std::endl;
bcast_b_->template in<2>()->sendk(key);
if (0 == --to_start) break;
}
}
}

TTGUNUSED(bcast_a_);
TTGUNUSED(bcast_b_);
Expand Down Expand Up @@ -397,26 +425,34 @@ class SpMM25D {
}; // class LocalBcastA

/// broadcast `A[i][k]` to all processors which will contain at least one `C[i][j]` such that `B[k][j]` exists
class BcastA : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>, Out<Key<2>, void>>, BcastA, ttg::typelist<Blk, void>> {
class BcastA : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>, Out<Key<2>, int>, Out<Key<2>, void>>, BcastA, ttg::typelist<Blk, int, void>> {
public:
using baseT = typename BcastA::ttT;

BcastA(Edge<Key<2>, Blk> &a_ik, Edge<Key<2>, void> &ctl, Edge<Key<3>, Blk> &a_ikp,
BcastA(Edge<Key<2>, Blk> &a_ik, Edge<Key<2>, void> &ctl,
Edge<Key<2>, int> &rowctl, Edge<Key<3>, Blk> &a_ikp,
const std::vector<std::vector<long>> &a_rows_of_col,
const std::vector<std::vector<long>> &a_cols_of_row,
const std::vector<std::vector<long>> &b_cols_of_row,
const Keymap2 &ij_keymap, const Keymap3 &ijk_keymap,
const int parallel_bcasts)
: baseT(edges(a_ik, ctl), edges(a_ikp, ctl), "SpMM25D::bcast_a", {"a_ik", "ctl"}, {"a_ikp", "ctl"}, ij_keymap)
: baseT(edges(a_ik, rowctl, ctl), edges(a_ikp, rowctl, ctl), "SpMM25D::bcast_a", {"a_ik", "rowctl", "ctl"}, {"a_ikp", "rowctl", "ctl"}, ij_keymap)
, a_rows_of_col_(a_rows_of_col)
, a_cols_of_row_(a_cols_of_row)
, b_cols_of_row_(b_cols_of_row)
, ijk_keymap_(ijk_keymap)
, ij_keymap_(ij_keymap)
, parallel_bcasts_(parallel_bcasts) {}
, parallel_bcasts_(parallel_bcasts) {

this->set_priomap([](const Key<2>& key){
return std::numeric_limits<int>::max() - key[0];
});
}

void op(const Key<2> &ik, typename baseT::input_values_tuple_type &&a_ik,
std::tuple<Out<Key<3>, Blk>, Out<Key<2>, void>> &outs) {
auto i = ik[0]; // row
auto k = ik[1]; // col
std::tuple<Out<Key<3>, Blk>, Out<Key<2>, int>, Out<Key<2>, void>> &outs) {
const auto i = ik[0]; // row
const auto k = ik[1]; // col
ttg::trace("BcastA(", i, ", ", k, ")");
std::vector<Key<3>> ikp_keys;

Expand All @@ -428,19 +464,58 @@ class SpMM25D {
{i, j, k})); // N.B. in 2.5D SUMMA different k contributions to C[i][j] are computed on different nodes
if (!procmap[p]) {
ttg::trace("Broadcasting A[", i, "][", k, "] to proc ", p);
//std::cout << "[" << world.rank() << "] BcastA key " << ik << " op " << Key<3>({i, j, k}) << " to proc " << p << std::endl;
ikp_keys.emplace_back(Key<3>({i, k, p}));
procmap[p] = true;
}
// TODO: debug
//if (p != world.rank() && ij_keymap_(Key<2>{k, j}) != p) {
// std::cout << "[" << world.rank() << "] BCAST A " << ik << " for C update " << Key<3>({i, k, p}) << " on " << p << " has B from " << ij_keymap_(Key<2>{k, j}) << std::endl;
//}
}
::broadcast<0>(ikp_keys, std::move(baseT::template get<0>(a_ik)), outs);

/* enable the next broadcast on this row */
int row = i;
int col = k;
auto rowit = std::find(a_cols_of_row_[row].begin(), a_cols_of_row_[row].end(), col);
for (++rowit; rowit != a_cols_of_row_[row].end(); ++rowit) {
Key<2> key = {row, *rowit};
if (world.rank() == this->get_keymap()(key)) {
::send<1>(key, std::move(baseT::template get<1>(a_ik)), outs);
break;
}
}


/* enable next broadcast through a control message
* we don't check whether this tile is in B here, this is
* done inside the next task (see above) */
std::size_t num_rows = a_cols_of_row_.size();
auto it = std::find(a_cols_of_row_[i].begin(), a_cols_of_row_[i].end(), k);
++it; // skip the current tile
* done inside the next task (see above)
* we walk the matrix A column-major in an attempt to send from top to bottom, left to right */
long to_skip = parallel_bcasts_;

auto colit = std::find(a_rows_of_col_[col].begin(), a_rows_of_col_[col].end(), row);
++colit; // skip to next row
do {
for (; colit != a_rows_of_col_[col].end(); ++colit) {
Key<2> key = {*colit, col};
if (world.rank() == this->get_keymap()(key)) {
if (0 == --to_skip) {
//std::cout << "BcastA sending to " << key << " from " << ik << std::endl;
::sendk<2>(key, outs);
return;
}
}
}
/* nothing for us in this column, move on to the next column */
if (++col < a_rows_of_col_.size()) {
colit = a_rows_of_col_[col].begin();
} else {
break;
}
} while (1);

#if 0
do {
for (; it != a_cols_of_row_[i].end(); ++it) {
Key<2> key = {i, *it};
Expand All @@ -457,9 +532,12 @@ class SpMM25D {
break;
}
} while (1);
#endif // 0
}

private:
//const std::vector<std::vector<long>> &a_cols_of_row_;
const std::vector<std::vector<long>> &a_rows_of_col_;
const std::vector<std::vector<long>> &a_cols_of_row_;
const std::vector<std::vector<long>> &b_cols_of_row_;
const Keymap3 &ijk_keymap_;
Expand Down Expand Up @@ -505,25 +583,32 @@ class SpMM25D {
}; // class LocalBcastB

/// broadcast `B[k][j]` to all processors which will contain at least one `C[i][j]` such that `A[i][k]` exists
class BcastB : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>, Out<Key<2>, void>>, BcastB, ttg::typelist<Blk, void>> {
class BcastB : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>, Out<Key<2>, int>, Out<Key<2>, void>>, BcastB, ttg::typelist<Blk, int, void>> {
public:
using baseT = typename BcastB::ttT;

BcastB(Edge<Key<2>, Blk> &b_kj, Edge<Key<2>, void> ctl, Edge<Key<3>, Blk> &b_kjp,
BcastB(Edge<Key<2>, Blk> &b_kj, Edge<Key<2>, void> ctl, Edge<Key<2>, int> colctl, Edge<Key<3>, Blk> &b_kjp,
const std::vector<std::vector<long>> &a_rows_of_col,
const std::vector<std::vector<long>> &b_cols_of_row,
const std::vector<std::vector<long>> &b_rows_of_col,
const Keymap2 &ij_keymap, const Keymap3 &ijk_keymap,
const int parallel_bcasts)
: baseT(edges(b_kj, ctl), edges(b_kjp, ctl), "SpMM25D::bcast_b", {"b_kj", "ctl"}, {"b_kjp", "ctl"}, ij_keymap)
: baseT(edges(b_kj, colctl, ctl), edges(b_kjp, colctl, ctl), "SpMM25D::bcast_b", {"b_kj", "colctl", "ctl"}, {"b_kjp", "colctl", "ctl"}, ij_keymap)
, a_rows_of_col_(a_rows_of_col)
, b_cols_of_row_(b_cols_of_row)
, b_rows_of_col_(b_rows_of_col)
, ijk_keymap_(ijk_keymap)
, parallel_bcasts_(parallel_bcasts) {}
, parallel_bcasts_(parallel_bcasts)
{
this->set_priomap([](const Key<2>& key){
return std::numeric_limits<int>::max() - key[1];
});
}

void op(const Key<2> &kj, typename baseT::input_values_tuple_type &&b_kj,
std::tuple<Out<Key<3>, Blk>, Out<Key<2>, void>> &outs) {
auto k = kj[0]; // row
auto j = kj[1]; // col
std::tuple<Out<Key<3>, Blk>, Out<Key<2>, int>, Out<Key<2>, void>> &outs) {
const auto k = kj[0]; // row
const auto j = kj[1]; // col
// broadcast b_kj to all processors which will contain at least one c_ij such that a_ik exists
std::vector<Key<3>> kjp_keys;
ttg::trace("BcastB(", k, ", ", j, ")");
Expand All @@ -534,15 +619,58 @@ class SpMM25D {
long p = ijk_keymap_(Key<3>({i, j, k}));
if (!procmap[p]) {
ttg::trace("Broadcasting B[", k, "][", j, "] to proc ", p);
//std::cout << "[" << world.rank() << "] BcastB key " << kj << " op " << Key<3>({i, j, k}) << " to proc " << p << std::endl;
kjp_keys.emplace_back(Key<3>({k, j, p}));
procmap[p] = true;
}
}
::broadcast<0>(kjp_keys, std::move(baseT::template get<0>(b_kj)), outs);

/* enable the next broadcast on this row */
int row = k;
int col = j;
auto colit = std::find(b_rows_of_col_[col].begin(), b_rows_of_col_[col].end(), row);
for (++colit; colit != b_rows_of_col_[col].end(); ++colit) {
Key<2> key = {*colit, col};
if (world.rank() == this->get_keymap()(key)) {
//std::cout << "BcastB kick off " << key << std::endl;
::send<1>(key, std::move(baseT::template get<1>(b_kj)), outs);
break;
}
}

/* enable next broadcast through a control message
* we don't check whether this tile is in A here, this is
* done inside the next task (see above) */
* done inside the next task (see above)
* we run across a row to enable broadcasts */
long to_skip = parallel_bcasts_;

// iterator over the current row
auto rowit = std::find(b_cols_of_row_[row].begin(), b_cols_of_row_[row].end(), col);
++rowit; // skip to next col
do {
for (; rowit != b_cols_of_row_[row].end(); ++rowit) {
Key<2> key = {row, *rowit};
if (world.rank() == this->get_keymap()(key)) {
if (0 == --to_skip) {
//std::cout << "BcastB sending to " << key << " from " << kj << " pb " << parallel_bcasts_ << std::endl;
::sendk<2>(key, outs);
return;
} else {
//std::cout << "BcastB skipping " << key << " from " << kj << " pb " << parallel_bcasts_ << std::endl;
}
}
}
/* nothing for us in this row, move on to the next row */
if (++row != b_cols_of_row_.size()) {
rowit = b_cols_of_row_[row].begin();
} else {
break;
}
} while (1);


#if 0
std::size_t num_rows = b_cols_of_row_.size();
auto it = std::find(b_cols_of_row_[k].begin(), b_cols_of_row_[k].end(), j);
++it; // skip the current tile
Expand All @@ -563,11 +691,13 @@ class SpMM25D {
break;
}
} while (1);
#endif // 0
}

private:
const std::vector<std::vector<long>> &a_rows_of_col_;
const std::vector<std::vector<long>> &b_cols_of_row_;
const std::vector<std::vector<long>> &b_rows_of_col_;
const Keymap3 &ijk_keymap_;
const int parallel_bcasts_;
}; // class BcastB
Expand Down Expand Up @@ -1376,7 +1506,7 @@ static void timed_measurement(SpMatrix<> &A, SpMatrix<> &B, const std::function<
SpMatrix<> C;
C.resize(MT, NT);

reduce_count = 0;
//reduce_count = 0;

// flow graph needs to exist on every node
Edge<Key<3>> ctl("control");
Expand Down

0 comments on commit 4416f3b

Please sign in to comment.