diff --git a/examples/spmm/spmm.cc b/examples/spmm/spmm.cc index 5d745c067..a0c2d858c 100644 --- a/examples/spmm/spmm.cc +++ b/examples/spmm/spmm.cc @@ -39,6 +39,8 @@ using namespace ttg; #include "ttg/util/bug.h" +#define USE_AUTO_CONSTRAINT false + #if defined(BLOCK_SPARSE_GEMM) && defined(BTAS_IS_USABLE) using scalar_t = double; using blk_t = btas::Tensor, btas::Handle::shared_ptr>>; @@ -292,7 +294,7 @@ class SpMM25D { , parallel_bcasts_(parallel_bcasts) { Edge, void> a_ctl, b_ctl; Edge, int> a_rowctl, b_colctl; // TODO: can we have multiple control inputs per TT? - auto constraint = ttg::make_shared_constraint>>(); + auto constraint = ttg::make_shared_constraint>>(USE_AUTO_CONSTRAINT); bcast_a_ = std::make_unique(a, local_a_ijk_, b_cols_of_row_, ij_keymap_, ijk_keymap_); // add constraint with external mapper: key[1] represents `k` bcast_a_->add_constraint(constraint, [](const Key<2>& key){ return key[1]; }); @@ -302,7 +304,8 @@ class SpMM25D { bcast_b_->add_constraint(constraint, [](const Key<2>& key){ return key[0]; }); local_bcast_b_ = std::make_unique(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_); multiplyadd_ = std::make_unique(a_ijk_, b_ijk_, c_ijk_, c_ij_p_, a_cols_of_row_, - b_rows_of_col_, mTiles, nTiles, ijk_keymap_); + b_rows_of_col_, mTiles, nTiles, ijk_keymap_, constraint, k_cnt_); + reduce_c_ = std::make_unique(c_ij_p_, c, ij_keymap_); reduce_c_->template set_input_reducer<0>( [&](Blk &c_ij, const Blk &c_ij_p) { @@ -317,6 +320,9 @@ class SpMM25D { auto world = ttg::default_execution_context(); const auto my_rank = world.rank(); std::vector c_ij_procmask(world.size(), false); + std::vector first_k_map(world.size(), std::numeric_limits::max()); + std::size_t max_k = a_rows_of_col_.size(); + k_cnt_.resize(max_k+1, false); for (auto i = 0ul; i != a_cols_of_row_.size(); ++i) { if (a_cols_of_row_[i].empty()) continue; for (auto j = 0ul; j != b_rows_of_col_.size(); ++j) { @@ -326,10 +332,15 @@ class SpMM25D { decltype(i) k; bool have_k; std::tie(k, have_k) = multiplyadd_->compute_first_k(i, j); + if (have_k) { + k_cnt_[k] = true; + } while (have_k) { const auto pR = ijk_keymap_(Key<3>{i, j, k}); assert(pR < c_ij_procmask.size()); c_ij_procmask[pR] = true; + // find the first k that is needed from us by this rank + first_k_map[pR] = std::min(first_k_map[pR], k); /* get next k */ std::tie(k, have_k) = multiplyadd_->compute_next_k(i, j, k); } @@ -341,6 +352,17 @@ class SpMM25D { } } + k_cnt_.push_back(true); // we always want to release the last k + + // find the maximum k for which we need to release the broadcast constraint + unsigned long first_k = 0; + for (auto k : first_k_map) { + if (k != std::numeric_limits::max()) { + first_k = std::max(first_k, k); + } + } + constraint->release(first_k); + TTGUNUSED(bcast_a_); TTGUNUSED(bcast_b_); TTGUNUSED(multiplyadd_); @@ -521,11 +543,15 @@ class SpMM25D { MultiplyAdd(Edge, Blk> &a_ijk, Edge, Blk> &b_ijk, Edge, Blk> &c_ijk, Edge, Blk> &c, const std::vector> &a_cols_of_row, const std::vector> &b_rows_of_col, const std::vector &mTiles, - const std::vector &nTiles, const Keymap3 &ijk_keymap) + const std::vector &nTiles, const Keymap3 &ijk_keymap, + std::shared_ptr>> constraint, + std::vector& k_cnt) : baseT(edges(a_ijk, b_ijk, c_ijk), edges(c, c_ijk), "SpMM25D::MultiplyAdd", {"a_ijk", "b_ijk", "c_ijk"}, {"c_ij", "c_ijk"}, ijk_keymap) , a_cols_of_row_(a_cols_of_row) - , b_rows_of_col_(b_rows_of_col) { + , b_rows_of_col_(b_rows_of_col) + , k_cnt_(k_cnt) + , constraint(std::move(constraint)){ this->set_priomap([=,this](const Key<3> &ijk) { return this->prio(ijk); }); // map a key to an integral priority value // for each {i,j} determine first k that contributes AND belongs to this node, @@ -569,6 +595,17 @@ class SpMM25D { " C[", i, "][", j, "] += A[", i, "][", k, "] by B[", k, "][", j, "], next_k? ", (have_next_k ? std::to_string(next_k) : "does not exist")); + // release the constraint on the next round of broadcasts + { + std::size_t release_k = k; + while (release_k < k_cnt_.size()) { + ++release_k; + if (k_cnt_[release_k]) + break; + } + constraint->release(release_k); + } + // compute the contrib, pass the running total to the next flow, if needed // otherwise write to the result flow if (have_next_k) { @@ -588,6 +625,8 @@ class SpMM25D { private: const std::vector> &a_cols_of_row_; const std::vector> &b_rows_of_col_; + std::vector& k_cnt_; + std::shared_ptr>> constraint; /* Compute the length of the remaining sequence on that tile */ int32_t prio(const Key<3> &key) const { @@ -733,6 +772,7 @@ class SpMM25D { std::unique_ptr local_bcast_b_; std::unique_ptr multiplyadd_; std::unique_ptr reduce_c_; + std::vector k_cnt_; Keymap2 ij_keymap_; Keymap3 ijk_keymap_; long parallel_bcasts_; diff --git a/ttg/ttg/constraint.h b/ttg/ttg/constraint.h index f624b1ed8..199fb5ed6 100644 --- a/ttg/ttg/constraint.h +++ b/ttg/ttg/constraint.h @@ -63,7 +63,7 @@ namespace ttg { template, + typename Compare = std::less_equal, typename Mapper = ttg::Void> struct SequencedKeysConstraint : public ConstraintBase { @@ -73,7 +73,7 @@ namespace ttg { using compare_t = Compare; using base_t = ConstraintBase; - private: + protected: struct sequence_elem_t { std::map> m_keys; @@ -89,43 +89,22 @@ namespace ttg { } }; - void release_next() { - if (m_stopped) { - // don't release tasks if we're stopped - return; - } - // trigger the next sequence - sequence_elem_t elem; - { - // extract the next sequence - auto g = this->lock_guard(); - auto it = m_sequence.begin(); // sequence is ordered by ordinal - if (it == m_sequence.end()) { - return; // nothing to be done - } - m_current = it->first; - elem = std::move(it->second); - m_sequence.erase(it); - } - - for (auto& seq : elem.m_keys) { - // account for the newly active keys - m_active.fetch_add(seq.second.size(), std::memory_order_relaxed); - this->notify_listener(std::span(seq.second.data(), seq.second.size()), seq.first); - } - } - bool check_key_impl(const key_type& key, Ordinal ord, ttg::TTBase *tt) { if (!m_stopped) { if (m_order(ord, m_current)) { // key should be executed - m_active.fetch_add(1, std::memory_order_relaxed); + if (m_auto_release) { // only needed for auto-release + m_active.fetch_add(1, std::memory_order_relaxed); + } // reset current to the lower ordinal m_current = ord; return true; - } else if (m_sequence.empty() && 0 == m_active.load(std::memory_order_relaxed)) { + } else if (m_sequence.empty() && m_auto_release && 0 == m_active.load(std::memory_order_relaxed)) { // there are no keys (active or blocked) so we execute to avoid a deadlock // we don't change the current ordinal because there may be lower ordinals coming in later + // NOTE: there is a race condition between the check here and the increment above. + // This is mostly benign as it can lead to out-of-sequence released tasks. + // Avoiding this race would incur significant overheads. m_active.fetch_add(1, std::memory_order_relaxed); return true; } @@ -146,67 +125,99 @@ namespace ttg { return false; } + void complete_key_impl() { - auto active = m_active.fetch_sub(1, std::memory_order_relaxed) - 1; - if (0 == active) { - release_next(); + if (m_auto_release) { + auto active = m_active.fetch_sub(1, std::memory_order_relaxed) - 1; + if (0 == active) { + release_next(); + } + } + } + + // used in the auto case + void release_next() { + if (this->m_stopped) { + // don't release tasks if we're stopped + return; + } + // trigger the next sequence + sequence_elem_t elem; + { + // extract the next sequence + auto g = this->lock_guard(); + auto it = this->m_sequence.begin(); // sequence is ordered by ordinal + if (it == this->m_sequence.end()) { + return; // nothing to be done + } + this->m_current = it->first; + elem = std::move(it->second); + this->m_sequence.erase(it); + } + + for (auto& seq : elem.m_keys) { + // account for the newly active keys + this->m_active.fetch_add(seq.second.size(), std::memory_order_relaxed); + this->notify_listener(std::span(seq.second.data(), seq.second.size()), seq.first); } } + // used in the non-auto case + void release_next(ordinal_type ord, bool force_check = false) { + if (this->m_stopped) { + // don't release tasks if we're stopped + return; + } + if (!force_check && m_order(ord, this->m_current)) { + return; // already at the provided ordinal, nothing to be done + } + // set current ordinal + this->m_current = ord; + // trigger the next sequence(s) (m_sequence is ordered by ordinal) + std::vector seqs; + { + auto g = this->lock_guard(); + for (auto it = this->m_sequence.begin(); it != this->m_sequence.end(); it = this->m_sequence.begin()) { + if (!this->m_order(it->first, this->m_current)) break; + // extract the next sequence + this->m_current = it->first; + seqs.push_back(std::move(it->second)); + this->m_sequence.erase(it); + } + } + for (auto& elem : seqs) { + for (auto& e : elem.m_keys) { + // account for the newly active keys + this->notify_listener(std::span(e.second.data(), e.second.size()), e.first); + } + } + } + public: /** * Used for external key mapper. */ - SequencedKeysConstraint() + SequencedKeysConstraint(bool auto_release = false) : base_t() + , m_auto_release(auto_release) { } template, Mapper_>> - SequencedKeysConstraint(Mapper_&& map) + SequencedKeysConstraint(Mapper_&& map, bool auto_release) : base_t() , m_map(std::forward(map)) + , m_auto_release(auto_release) { } - SequencedKeysConstraint(SequencedKeysConstraint&& skc) - : base_t(std::move(skc)) - , m_sequence(std::move(skc.m_sequence)) - , m_active(skc.m_active.load(std::memory_order_relaxed)) - , m_current(std::move(skc.m_current)) - , m_map(std::move(skc.m_map)) - , m_order(std::move(skc.m_order)) - , m_stopped(skc.m_stopped) - { } + SequencedKeysConstraint(SequencedKeysConstraint&& skc) = default; - SequencedKeysConstraint(const SequencedKeysConstraint& skc) - : base_t(skc) - , m_sequence(skc.m_sequence) - , m_active(skc.m_active.load(std::memory_order_relaxed)) - , m_current(skc.m_current) - , m_map(skc.m_map) - , m_order(skc.m_order) - , m_stopped(skc.m_stopped) - { } + SequencedKeysConstraint(const SequencedKeysConstraint& skc) = default; - SequencedKeysConstraint& operator=(SequencedKeysConstraint&& skc) { - base_t::operator=(std::move(skc)); - m_sequence = std::move(skc.m_sequence); - m_active = skc.m_active.load(std::memory_order_relaxed); - m_current = std::move(skc.m_current); - m_map = std::move(skc.m_map); - m_order = std::move(skc.m_order); - m_stopped = skc.m_stopped; - } - SequencedKeysConstraint& operator=(const SequencedKeysConstraint& skc) { - base_t::operator=(skc); - m_sequence = skc.m_sequence; - m_active = skc.m_active.load(std::memory_order_relaxed); - m_current = skc.m_current; - m_map = skc.m_map; - m_order = skc.m_order; - m_stopped = skc.m_stopped; - } + SequencedKeysConstraint& operator=(SequencedKeysConstraint&& skc) = default; + + SequencedKeysConstraint& operator=(const SequencedKeysConstraint& skc) = default; virtual ~SequencedKeysConstraint() = default; @@ -217,49 +228,49 @@ namespace ttg { std::enable_if_t && !ttg::meta::is_void_v, bool> check(const key_type& key, ttg::TTBase *tt) { ordinal_type ord = m_map(key); - return check_key_impl(key, ord, tt); + return this->check_key_impl(key, ord, tt); } template std::enable_if_t && ttg::meta::is_void_v, bool> check(const key_type& key, Ordinal ord, ttg::TTBase *tt) { - return check_key_impl(key, ord, tt); + return this->check_key_impl(key, ord, tt); } template std::enable_if_t && !ttg::meta::is_void_v, bool> check(ttg::TTBase *tt) { - return check_key_impl(ttg::Void{}, m_map(), tt); + return this->check_key_impl(ttg::Void{}, m_map(), tt); } template std::enable_if_t && ttg::meta::is_void_v, bool> check(ordinal_type ord, ttg::TTBase *tt) { - return check_key_impl(ttg::Void{}, ord, tt); + return this->check_key_impl(ttg::Void{}, ord, tt); } template std::enable_if_t && !ttg::meta::is_void_v> complete(const key_type& key, ttg::TTBase *tt) { - complete_key_impl(); + this->complete_key_impl(); } template std::enable_if_t && ttg::meta::is_void_v> complete(const key_type& key, Ordinal ord, ttg::TTBase *tt) { - complete_key_impl(); + this->complete_key_impl(); } template std::enable_if_t && ttg::meta::is_void_v> complete(Ordinal ord, ttg::TTBase *tt) { - complete_key_impl(); + this->complete_key_impl(); } template std::enable_if_t && !ttg::meta::is_void_v> complete(ttg::TTBase *tt) { - complete_key_impl(); + this->complete_key_impl(); } /** @@ -279,23 +290,41 @@ namespace ttg { void start() { if (m_stopped) { m_stopped = false; + release_next(m_current, true); // force the check for a next release even if the current ordinal hasn't changed + } + } + + /** + * Release tasks up to the ordinal. The provided ordinal is ignored if `auto_release` is enabled. + */ + void release(ordinal_type ord = 0) { + if (m_auto_release) { + // last key for this ordinal, release the next + // the provided ordinal is ignored release_next(); + } else { + release_next(ord); } } + bool is_auto() const { + return m_auto_release; + } + - private: + protected: std::map m_sequence; - std::atomic m_active; - ordinal_type m_current; + ordinal_type m_current = std::numeric_limits::min(); [[no_unique_address]] Mapper m_map; [[no_unique_address]] compare_t m_order; + std::atomic m_active; bool m_stopped = false; + bool m_auto_release = false; }; - // deduction guide: take type of first argument to Mapper as the key type + // deduction guides: take type of first argument to Mapper as the key type // TODO: can we use the TTG callable_args instead? template>>>>> SequencedKeysConstraint(Mapper&&) @@ -314,11 +343,36 @@ namespace ttg { SequencedKeysConstraint(const SequencedKeysConstraint&) -> SequencedKeysConstraint; + /** + * Make a constraint that can be shared between multiple TT instances. + * Overload for incomplete templated constraint types. + * + * Example: + * // SequencedKeysConstraint is incomplete + * auto c = ttg::make_shared_constraint([](Key& k){ return k[0]; }); + * auto tt_a = ttg::make_tt(...); + * tt_a->add_constraint(c); + * auto tt_b = ttg::make_tt(...); + * tt_b->add_constraint(c); + * + * -> the constraint will handle keys from both tt_a and tt_b. Both TTs must have the same key type. + */ template typename Constraint, typename... Args> auto make_shared_constraint(Args&&... args) { - return std::make_shared(args)...))>(Constraint(std::forward(args)...)); + return std::make_shared(args)...))>(std::forward(args)...); } + /** + * Make a constraint that can be shared between multiple TT instances. + * Overload for complete constraint types. + */ + template + auto make_shared_constraint(Args&&... args) { + return std::make_shared(std::forward(args)...); + } + + + } // namespace ttg #endif // TTG_CONSTRAINT_H \ No newline at end of file