SPMM: Remove control flow from BcastA/B
Replaced by constraints.

Signed-off-by: Joseph Schuchart <[email protected]>
devreal committed Mar 1, 2024
1 parent 739762f commit 5c8835b
Showing 1 changed file with 13 additions and 223 deletions.
236 changes: 13 additions & 223 deletions examples/spmm/spmm.cc
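
The commit drops the explicit `ctl`/`rowctl`/`colctl` control inputs and the `parallel_bcasts` throttle from BcastA/BcastB; ordering of the broadcasts is instead handled by the shared `ttg::SequencedKeysConstraint` wiring visible in the first hunk below. A minimal sketch of that wiring, copied from the surrounding code in examples/spmm/spmm.cc (a fragment, not a standalone program; `bcast_a_`, `bcast_b_`, and `Key<2>` are members/types of the `SpMM25D` class in this file):

// One shared constraint; each broadcast TT supplies a mapper that extracts
// the contraction index k from its key.
auto constraint = ttg::make_shared_constraint<ttg::SequencedKeysConstraint<Key<2>>>();

// For A[i][k] the sequencing component is key[1] == k ...
bcast_a_->add_constraint(constraint, [](const Key<2>& key){ return key[1]; });

// ... and for B[k][j] it is key[0] == k, so A- and B-broadcasts are sequenced
// by the same contraction index instead of being throttled by hand through
// control messages.
bcast_b_->add_constraint(constraint, [](const Key<2>& key){ return key[0]; });
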
@@ -293,13 +293,11 @@ class SpMM25D {
Edge<Key<2>, void> a_ctl, b_ctl;
Edge<Key<2>, int> a_rowctl, b_colctl; // TODO: can we have multiple control inputs per TT?
auto constraint = ttg::make_shared_constraint<ttg::SequencedKeysConstraint<Key<2>>>();
- bcast_a_ = std::make_unique<BcastA>(a, a_ctl, a_rowctl, local_a_ijk_, a_rows_of_col_, a_cols_of_row_, b_cols_of_row_,
- ij_keymap_, ijk_keymap_, parallel_bcasts_);
+ bcast_a_ = std::make_unique<BcastA>(a, local_a_ijk_, b_cols_of_row_, ij_keymap_, ijk_keymap_);
// add constraint with external mapper: key[1] represents `k`
bcast_a_->add_constraint(constraint, [](const Key<2>& key){ return key[1]; });
local_bcast_a_ = std::make_unique<LocalBcastA>(local_a_ijk_, a_ijk_, b_cols_of_row_, ijk_keymap_);
- bcast_b_ = std::make_unique<BcastB>(b, b_ctl, b_colctl, local_b_ijk_, a_rows_of_col_, b_cols_of_row_, b_rows_of_col_,
- ij_keymap_, ijk_keymap_, parallel_bcasts_);
+ bcast_b_ = std::make_unique<BcastB>(b, local_b_ijk_, a_rows_of_col_, ij_keymap_, ijk_keymap_);
// add constraint with external mapper: key[0] represents `k`
bcast_b_->add_constraint(constraint, [](const Key<2>& key){ return key[0]; });
local_bcast_b_ = std::make_unique<LocalBcastB>(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_);
@@ -343,62 +341,6 @@ class SpMM25D {
}
}

- /* kick off the first broadcast in each row of A
- * this is used to enforce strict ordering within a row of A */
- for (int i = 0; i < a_cols_of_row_.size(); ++i) {
- for (int k : a_cols_of_row_[i]) {
- auto key = Key<2>(i, k);
- if (world.rank() == ij_keymap_(key)) {
- bcast_a_->template in<1>()->send(key, 0);
- break;
- }
- }
- }
-
- /* initial ctl input for a number of bcasts for A
- * this is used to limit the number of concurrent bcasts */
- int to_start = parallel_bcasts;
- for (int k = 0;
- 0 < to_start && k < a_rows_of_col_.size();
- ++k) {
- for (auto i : a_rows_of_col_[k]) {
- auto key = Key<2>(i, k);
- if (world.rank() == ij_keymap_(key)) {
- //std::cout << "SPMM kick off BcastA " << key << std::endl;
- bcast_a_->template in<2>()->sendk(key);
- if (0 == --to_start) break;
- }
- }
- }
-
- /* kick off the first broadcast in each column of B
- * this is used to enforce strict ordering within a column of B */
- for (int j = 0; j < b_rows_of_col_.size(); ++j) {
- for (int k : b_rows_of_col_[j]) {
- auto key = Key<2>(k, j);
- if (world.rank() == ij_keymap_(key)) {
- //std::cout << "BcastB kick off " << key << std::endl;
- bcast_b_->template in<1>()->send(key, 0);
- break;
- }
- }
- }
-
- /* initial ctl input for bcasts for B */
- to_start = parallel_bcasts;
- for (int k = 0;
- 0 < to_start && k < b_cols_of_row_.size();
- ++k) {
- for (auto j : b_cols_of_row_[k]) {
- auto key = Key<2>(k, j);
- if (world.rank() == ij_keymap_(key)) {
- //std::cout << "SPMM kick off BcastB " << key << std::endl;
- bcast_b_->template in<2>()->sendk(key);
- if (0 == --to_start) break;
- }
- }
- }

TTGUNUSED(bcast_a_);
TTGUNUSED(bcast_b_);
TTGUNUSED(multiplyadd_);
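
The loops deleted in the hunk above seeded the first broadcast in each row of A and column of B and then metered at most `parallel_bcasts` concurrent broadcasts via explicit control messages. With the shared constraint, each broadcast task simply registers its contraction index `k` and is released in sequence. The following self-contained sketch models only that release discipline; it is an illustration, not TTG's `SequencedKeysConstraint` implementation, and the `SequencedRelease` name is invented for this example.

#include <cstdio>
#include <functional>
#include <map>
#include <vector>

// Toy model of a sequenced-keys constraint: tasks are queued under a sequence
// value (here the contraction index k) and released in increasing order of k.
class SequencedRelease {
  std::map<long, std::vector<std::function<void()>>> queued_;  // k -> pending tasks
 public:
  void defer(long k, std::function<void()> task) { queued_[k].push_back(std::move(task)); }
  // Release all tasks whose sequence value is <= k, smallest k first.
  void release_up_to(long k) {
    for (auto it = queued_.begin(); it != queued_.end() && it->first <= k; it = queued_.erase(it))
      for (auto& task : it->second) task();
  }
};

int main() {
  SequencedRelease constraint;
  // A- and B-broadcasts for the same k are sequenced together, so no explicit
  // per-row/per-column kick-off messages are needed.
  constraint.defer(1, [] { std::puts("BcastA(0,1)"); });
  constraint.defer(0, [] { std::puts("BcastA(0,0)"); });
  constraint.defer(0, [] { std::puts("BcastB(0,2)"); });
  constraint.release_up_to(0);  // runs the two k == 0 broadcasts
  constraint.release_up_to(1);  // then the k == 1 broadcast
  return 0;
}
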
@@ -444,32 +386,24 @@ class SpMM25D {
}; // class LocalBcastA

/// broadcast `A[i][k]` to all processors which will contain at least one `C[i][j]` such that `B[k][j]` exists
- class BcastA : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>, Out<Key<2>, int>, Out<Key<2>, void>>, BcastA, ttg::typelist<Blk, int, void>> {
+ class BcastA : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, ttg::typelist<Blk>> {
public:
using baseT = typename BcastA::ttT;

- BcastA(Edge<Key<2>, Blk> &a_ik, Edge<Key<2>, void> &ctl,
- Edge<Key<2>, int> &rowctl, Edge<Key<3>, Blk> &a_ikp,
- const std::vector<std::vector<long>> &a_rows_of_col,
- const std::vector<std::vector<long>> &a_cols_of_row,
+ BcastA(Edge<Key<2>, Blk> &a_ik, Edge<Key<3>, Blk> &a_ikp,
const std::vector<std::vector<long>> &b_cols_of_row,
- const Keymap2 &ij_keymap, const Keymap3 &ijk_keymap,
- const int parallel_bcasts)
- : baseT(edges(a_ik, rowctl, ctl), edges(a_ikp, rowctl, ctl), "SpMM25D::bcast_a", {"a_ik", "rowctl", "ctl"}, {"a_ikp", "rowctl", "ctl"}, ij_keymap)
- , a_rows_of_col_(a_rows_of_col)
- , a_cols_of_row_(a_cols_of_row)
+ const Keymap2 &ij_keymap, const Keymap3 &ijk_keymap)
+ : baseT(edges(a_ik), edges(a_ikp), "SpMM25D::bcast_a", {"a_ik"}, {"a_ikp"}, ij_keymap)
, b_cols_of_row_(b_cols_of_row)
- , ijk_keymap_(ijk_keymap)
- , ij_keymap_(ij_keymap)
- , parallel_bcasts_(parallel_bcasts) {
+ , ijk_keymap_(ijk_keymap) {

this->set_priomap([](const Key<2>& key){
return std::numeric_limits<int>::max() - key[0];
});
}

void op(const Key<2> &ik, typename baseT::input_values_tuple_type &&a_ik,
- std::tuple<Out<Key<3>, Blk>, Out<Key<2>, int>, Out<Key<2>, void>> &outs) {
+ std::tuple<Out<Key<3>, Blk>> &outs) {
const auto i = ik[0]; // row
const auto k = ik[1]; // col
ttg::trace("BcastA(", i, ", ", k, ")");
@@ -487,81 +421,13 @@ class SpMM25D {
ikp_keys.emplace_back(Key<3>({i, k, p}));
procmap[p] = true;
}
- // TODO: debug
- //if (p != world.rank() && ij_keymap_(Key<2>{k, j}) != p) {
- // std::cout << "[" << world.rank() << "] BCAST A " << ik << " for C update " << Key<3>({i, k, p}) << " on " << p << " has B from " << ij_keymap_(Key<2>{k, j}) << std::endl;
- //}
}
::broadcast<0>(ikp_keys, std::move(baseT::template get<0>(a_ik)), outs);

- /* enable the next broadcast on this row */
- int row = i;
- int col = k;
- auto rowit = std::find(a_cols_of_row_[row].begin(), a_cols_of_row_[row].end(), col);
- for (++rowit; rowit != a_cols_of_row_[row].end(); ++rowit) {
- Key<2> key = {row, *rowit};
- if (world.rank() == this->get_keymap()(key)) {
- ::send<1>(key, std::move(baseT::template get<1>(a_ik)), outs);
- break;
- }
- }
-
-
- /* enable next broadcast through a control message
- * we don't check whether this tile is in B here, this is
- * done inside the next task (see above)
- * we walk the matrix A column-major in an attempt to send from top to bottom, left to right */
- long to_skip = parallel_bcasts_;
-
- auto colit = std::find(a_rows_of_col_[col].begin(), a_rows_of_col_[col].end(), row);
- ++colit; // skip to next row
- do {
- for (; colit != a_rows_of_col_[col].end(); ++colit) {
- Key<2> key = {*colit, col};
- if (world.rank() == this->get_keymap()(key)) {
- if (0 == --to_skip) {
- //std::cout << "BcastA sending to " << key << " from " << ik << std::endl;
- ::sendk<2>(key, outs);
- return;
- }
- }
- }
- /* nothing for us in this column, move on to the next column */
- if (++col < a_rows_of_col_.size()) {
- colit = a_rows_of_col_[col].begin();
- } else {
- break;
- }
- } while (1);
-
- #if 0
- do {
- for (; it != a_cols_of_row_[i].end(); ++it) {
- Key<2> key = {i, *it};
- if (world.rank() == this->get_keymap()(key)) {
- if (0 == --to_skip) {
- ::sendk<1>(key, outs);
- return;
- }
- }
- }
- if ((i+1) < num_rows) {
- it = a_cols_of_row_[++i].begin();
- } else {
- break;
- }
- } while (1);
- #endif // 0
}

private:
- //const std::vector<std::vector<long>> &a_cols_of_row_;
- const std::vector<std::vector<long>> &a_rows_of_col_;
- const std::vector<std::vector<long>> &a_cols_of_row_;
const std::vector<std::vector<long>> &b_cols_of_row_;
const Keymap3 &ijk_keymap_;
- const Keymap2 &ij_keymap_;
- const int parallel_bcasts_;
}; // class BcastA

/// Locally broadcast `B[k][j]` assigned to this processor `p` to matmul tasks `{i,j,k}` for all `k` such that
@@ -602,30 +468,24 @@ class SpMM25D {
}; // class LocalBcastB

/// broadcast `B[k][j]` to all processors which will contain at least one `C[i][j]` such that `A[i][k]` exists
- class BcastB : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>, Out<Key<2>, int>, Out<Key<2>, void>>, BcastB, ttg::typelist<Blk, int, void>> {
+ class BcastB : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, ttg::typelist<Blk>> {
public:
using baseT = typename BcastB::ttT;

- BcastB(Edge<Key<2>, Blk> &b_kj, Edge<Key<2>, void> ctl, Edge<Key<2>, int> colctl, Edge<Key<3>, Blk> &b_kjp,
+ BcastB(Edge<Key<2>, Blk> &b_kj, Edge<Key<3>, Blk> &b_kjp,
const std::vector<std::vector<long>> &a_rows_of_col,
- const std::vector<std::vector<long>> &b_cols_of_row,
- const std::vector<std::vector<long>> &b_rows_of_col,
- const Keymap2 &ij_keymap, const Keymap3 &ijk_keymap,
- const int parallel_bcasts)
- : baseT(edges(b_kj, colctl, ctl), edges(b_kjp, colctl, ctl), "SpMM25D::bcast_b", {"b_kj", "colctl", "ctl"}, {"b_kjp", "colctl", "ctl"}, ij_keymap)
+ const Keymap2 &ij_keymap, const Keymap3 &ijk_keymap)
+ : baseT(edges(b_kj), edges(b_kjp), "SpMM25D::bcast_b", {"b_kj"}, {"b_kjp"}, ij_keymap)
, a_rows_of_col_(a_rows_of_col)
- , b_cols_of_row_(b_cols_of_row)
- , b_rows_of_col_(b_rows_of_col)
, ijk_keymap_(ijk_keymap)
- , parallel_bcasts_(parallel_bcasts)
{
this->set_priomap([](const Key<2>& key){
return std::numeric_limits<int>::max() - key[1];
});
}

void op(const Key<2> &kj, typename baseT::input_values_tuple_type &&b_kj,
- std::tuple<Out<Key<3>, Blk>, Out<Key<2>, int>, Out<Key<2>, void>> &outs) {
+ std::tuple<Out<Key<3>, Blk>> &outs) {
const auto k = kj[0]; // row
const auto j = kj[1]; // col
// broadcast b_kj to all processors which will contain at least one c_ij such that a_ik exists
@@ -644,81 +504,11 @@ class SpMM25D {
}
}
::broadcast<0>(kjp_keys, std::move(baseT::template get<0>(b_kj)), outs);

- /* enable the next broadcast on this row */
- int row = k;
- int col = j;
- auto colit = std::find(b_rows_of_col_[col].begin(), b_rows_of_col_[col].end(), row);
- for (++colit; colit != b_rows_of_col_[col].end(); ++colit) {
- Key<2> key = {*colit, col};
- if (world.rank() == this->get_keymap()(key)) {
- //std::cout << "BcastB kick off " << key << std::endl;
- ::send<1>(key, std::move(baseT::template get<1>(b_kj)), outs);
- break;
- }
- }
-
- /* enable next broadcast through a control message
- * we don't check whether this tile is in A here, this is
- * done inside the next task (see above)
- * we run across a row to enable broadcasts */
- long to_skip = parallel_bcasts_;
-
- // iterator over the current row
- auto rowit = std::find(b_cols_of_row_[row].begin(), b_cols_of_row_[row].end(), col);
- ++rowit; // skip to next col
- do {
- for (; rowit != b_cols_of_row_[row].end(); ++rowit) {
- Key<2> key = {row, *rowit};
- if (world.rank() == this->get_keymap()(key)) {
- if (0 == --to_skip) {
- //std::cout << "BcastB sending to " << key << " from " << kj << " pb " << parallel_bcasts_ << std::endl;
- ::sendk<2>(key, outs);
- return;
- } else {
- //std::cout << "BcastB skipping " << key << " from " << kj << " pb " << parallel_bcasts_ << std::endl;
- }
- }
- }
- /* nothing for us in this row, move on to the next row */
- if (++row != b_cols_of_row_.size()) {
- rowit = b_cols_of_row_[row].begin();
- } else {
- break;
- }
- } while (1);
-
-
- #if 0
- std::size_t num_rows = b_cols_of_row_.size();
- auto it = std::find(b_cols_of_row_[k].begin(), b_cols_of_row_[k].end(), j);
- ++it; // skip the current tile
- long to_skip = parallel_bcasts_;
- do {
- for (; it != b_cols_of_row_[k].end(); ++it) {
- Key<2> key = {k, *it};
- if (world.rank() == this->get_keymap()(key)) {
- if (0 == --to_skip) {
- ::sendk<1>(key, outs);
- return;
- }
- }
- }
- if ((k+1) < num_rows) {
- it = b_cols_of_row_[++k].begin();
- } else {
- break;
- }
- } while (1);
- #endif // 0
}

private:
const std::vector<std::vector<long>> &a_rows_of_col_;
- const std::vector<std::vector<long>> &b_cols_of_row_;
- const std::vector<std::vector<long>> &b_rows_of_col_;
const Keymap3 &ijk_keymap_;
- const int parallel_bcasts_;
}; // class BcastB

/// multiply task has 3 input flows: a_ijk, b_ijk, and c_ijk, c_ijk contains the running total for this layer of the