Skip to content

Commit

Permalink
Extend sequence constraint to enable/disable auto-release
Browse files Browse the repository at this point in the history
Auto-release makes sure there are no deadlocks by enabling the next
set of keys once the current ordinal is done. Without auto-release
applications must release the next set explitly and ensure there
are no deadlocks.

Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Mar 4, 2024
1 parent 5c8835b commit 52f2725
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 88 deletions.
48 changes: 44 additions & 4 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ using namespace ttg;

#include "ttg/util/bug.h"

#define USE_AUTO_CONSTRAINT false

#if defined(BLOCK_SPARSE_GEMM) && defined(BTAS_IS_USABLE)
using scalar_t = double;
using blk_t = btas::Tensor<scalar_t, btas::DEFAULT::range, btas::mohndle<btas::varray<scalar_t>, btas::Handle::shared_ptr>>;
Expand Down Expand Up @@ -292,7 +294,7 @@ class SpMM25D {
, parallel_bcasts_(parallel_bcasts) {
Edge<Key<2>, void> a_ctl, b_ctl;
Edge<Key<2>, int> a_rowctl, b_colctl; // TODO: can we have multiple control inputs per TT?
auto constraint = ttg::make_shared_constraint<ttg::SequencedKeysConstraint<Key<2>>>();
auto constraint = ttg::make_shared_constraint<ttg::SequencedKeysConstraint<Key<2>>>(USE_AUTO_CONSTRAINT);
bcast_a_ = std::make_unique<BcastA>(a, local_a_ijk_, b_cols_of_row_, ij_keymap_, ijk_keymap_);
// add constraint with external mapper: key[1] represents `k`
bcast_a_->add_constraint(constraint, [](const Key<2>& key){ return key[1]; });
Expand All @@ -302,7 +304,8 @@ class SpMM25D {
bcast_b_->add_constraint(constraint, [](const Key<2>& key){ return key[0]; });
local_bcast_b_ = std::make_unique<LocalBcastB>(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_);
multiplyadd_ = std::make_unique<MultiplyAdd>(a_ijk_, b_ijk_, c_ijk_, c_ij_p_, a_cols_of_row_,
b_rows_of_col_, mTiles, nTiles, ijk_keymap_);
b_rows_of_col_, mTiles, nTiles, ijk_keymap_, constraint, k_cnt_);

reduce_c_ = std::make_unique<ReduceC>(c_ij_p_, c, ij_keymap_);
reduce_c_->template set_input_reducer<0>(
[&](Blk &c_ij, const Blk &c_ij_p) {
Expand All @@ -317,6 +320,9 @@ class SpMM25D {
auto world = ttg::default_execution_context();
const auto my_rank = world.rank();
std::vector<bool> c_ij_procmask(world.size(), false);
std::vector<unsigned long> first_k_map(world.size(), std::numeric_limits<unsigned long>::max());
std::size_t max_k = a_rows_of_col_.size();
k_cnt_.resize(max_k+1, false);
for (auto i = 0ul; i != a_cols_of_row_.size(); ++i) {
if (a_cols_of_row_[i].empty()) continue;
for (auto j = 0ul; j != b_rows_of_col_.size(); ++j) {
Expand All @@ -326,10 +332,15 @@ class SpMM25D {
decltype(i) k;
bool have_k;
std::tie(k, have_k) = multiplyadd_->compute_first_k(i, j);
if (have_k) {
k_cnt_[k] = true;
}
while (have_k) {
const auto pR = ijk_keymap_(Key<3>{i, j, k});
assert(pR < c_ij_procmask.size());
c_ij_procmask[pR] = true;
// find the first k that is needed from us by this rank
first_k_map[pR] = std::min(first_k_map[pR], k);
/* get next k */
std::tie(k, have_k) = multiplyadd_->compute_next_k(i, j, k);
}
Expand All @@ -341,6 +352,17 @@ class SpMM25D {
}
}

k_cnt_.push_back(true); // we always want to release the last k

// find the maximum k for which we need to release the broadcast constraint
unsigned long first_k = 0;
for (auto k : first_k_map) {
if (k != std::numeric_limits<unsigned long>::max()) {
first_k = std::max(first_k, k);
}
}
constraint->release(first_k);

TTGUNUSED(bcast_a_);
TTGUNUSED(bcast_b_);
TTGUNUSED(multiplyadd_);
Expand Down Expand Up @@ -521,11 +543,15 @@ class SpMM25D {
MultiplyAdd(Edge<Key<3>, Blk> &a_ijk, Edge<Key<3>, Blk> &b_ijk, Edge<Key<3>, Blk> &c_ijk, Edge<Key<2>, Blk> &c,
const std::vector<std::vector<long>> &a_cols_of_row,
const std::vector<std::vector<long>> &b_rows_of_col, const std::vector<int> &mTiles,
const std::vector<int> &nTiles, const Keymap3 &ijk_keymap)
const std::vector<int> &nTiles, const Keymap3 &ijk_keymap,
std::shared_ptr<ttg::SequencedKeysConstraint<Key<2>>> constraint,
std::vector<bool>& k_cnt)
: baseT(edges(a_ijk, b_ijk, c_ijk), edges(c, c_ijk), "SpMM25D::MultiplyAdd", {"a_ijk", "b_ijk", "c_ijk"},
{"c_ij", "c_ijk"}, ijk_keymap)
, a_cols_of_row_(a_cols_of_row)
, b_rows_of_col_(b_rows_of_col) {
, b_rows_of_col_(b_rows_of_col)
, k_cnt_(k_cnt)
, constraint(std::move(constraint)){
this->set_priomap([=,this](const Key<3> &ijk) { return this->prio(ijk); }); // map a key to an integral priority value

// for each {i,j} determine first k that contributes AND belongs to this node,
Expand Down Expand Up @@ -569,6 +595,17 @@ class SpMM25D {
" C[",
i, "][", j, "] += A[", i, "][", k, "] by B[", k, "][", j, "], next_k? ",
(have_next_k ? std::to_string(next_k) : "does not exist"));
// release the constraint on the next round of broadcasts
{
std::size_t release_k = k;
while (release_k < k_cnt_.size()) {
++release_k;
if (k_cnt_[release_k])
break;
}
constraint->release(release_k);
}

// compute the contrib, pass the running total to the next flow, if needed
// otherwise write to the result flow
if (have_next_k) {
Expand All @@ -588,6 +625,8 @@ class SpMM25D {
private:
const std::vector<std::vector<long>> &a_cols_of_row_;
const std::vector<std::vector<long>> &b_rows_of_col_;
std::vector<bool>& k_cnt_;
std::shared_ptr<ttg::SequencedKeysConstraint<Key<2>>> constraint;

/* Compute the length of the remaining sequence on that tile */
int32_t prio(const Key<3> &key) const {
Expand Down Expand Up @@ -733,6 +772,7 @@ class SpMM25D {
std::unique_ptr<LocalBcastB> local_bcast_b_;
std::unique_ptr<MultiplyAdd> multiplyadd_;
std::unique_ptr<ReduceC> reduce_c_;
std::vector<bool> k_cnt_;
Keymap2 ij_keymap_;
Keymap3 ijk_keymap_;
long parallel_bcasts_;
Expand Down
Loading

0 comments on commit 52f2725

Please sign in to comment.