diff --git a/examples/spmm/spmm.cc b/examples/spmm/spmm.cc index a0c2d858c..101bff686 100644 --- a/examples/spmm/spmm.cc +++ b/examples/spmm/spmm.cc @@ -304,7 +304,8 @@ class SpMM25D { bcast_b_->add_constraint(constraint, [](const Key<2>& key){ return key[0]; }); local_bcast_b_ = std::make_unique(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_); multiplyadd_ = std::make_unique(a_ijk_, b_ijk_, c_ijk_, c_ij_p_, a_cols_of_row_, - b_rows_of_col_, mTiles, nTiles, ijk_keymap_, constraint, k_cnt_); + b_rows_of_col_, mTiles, nTiles, ijk_keymap_, constraint, + k_cnt_, parallel_bcasts); reduce_c_ = std::make_unique(c_ij_p_, c, ij_keymap_); reduce_c_->template set_input_reducer<0>( @@ -545,13 +546,15 @@ class SpMM25D { const std::vector> &b_rows_of_col, const std::vector &mTiles, const std::vector &nTiles, const Keymap3 &ijk_keymap, std::shared_ptr>> constraint, - std::vector& k_cnt) + std::vector& k_cnt, + std::size_t parallel_bcasts) : baseT(edges(a_ijk, b_ijk, c_ijk), edges(c, c_ijk), "SpMM25D::MultiplyAdd", {"a_ijk", "b_ijk", "c_ijk"}, {"c_ij", "c_ijk"}, ijk_keymap) , a_cols_of_row_(a_cols_of_row) , b_rows_of_col_(b_rows_of_col) , k_cnt_(k_cnt) - , constraint(std::move(constraint)){ + , constraint(std::move(constraint)) + , parallel_bcasts_(parallel_bcasts) { this->set_priomap([=,this](const Key<3> &ijk) { return this->prio(ijk); }); // map a key to an integral priority value // for each {i,j} determine first k that contributes AND belongs to this node, @@ -598,9 +601,10 @@ class SpMM25D { // release the constraint on the next round of broadcasts { std::size_t release_k = k; + std::size_t bcasts_ahead = parallel_bcasts_; while (release_k < k_cnt_.size()) { ++release_k; - if (k_cnt_[release_k]) + if (k_cnt_[release_k] && --bcasts_ahead) break; } constraint->release(release_k); @@ -627,6 +631,7 @@ class SpMM25D { const std::vector> &b_rows_of_col_; std::vector& k_cnt_; std::shared_ptr>> constraint; + std::size_t parallel_bcasts_; /* Compute the length of the remaining sequence on that tile */ int32_t prio(const Key<3> &key) const { diff --git a/ttg/ttg/constraint.h b/ttg/ttg/constraint.h index 199fb5ed6..311bf3ead 100644 --- a/ttg/ttg/constraint.h +++ b/ttg/ttg/constraint.h @@ -78,6 +78,10 @@ namespace ttg { std::map> m_keys; sequence_elem_t() = default; + sequence_elem_t(sequence_elem_t&&) = default; + sequence_elem_t(const sequence_elem_t&) = default; + sequence_elem_t& operator=(sequence_elem_t&&) = default; + sequence_elem_t& operator=(const sequence_elem_t&) = default; void add_key(const key_type& key, ttg::TTBase* tt) { auto it = m_keys.find(tt); @@ -95,9 +99,9 @@ namespace ttg { // key should be executed if (m_auto_release) { // only needed for auto-release m_active.fetch_add(1, std::memory_order_relaxed); + // revert the current ordinal to the lower ordinal + m_current = ord; } - // reset current to the lower ordinal - m_current = ord; return true; } else if (m_sequence.empty() && m_auto_release && 0 == m_active.load(std::memory_order_relaxed)) { // there are no keys (active or blocked) so we execute to avoid a deadlock @@ -172,18 +176,20 @@ namespace ttg { if (!force_check && m_order(ord, this->m_current)) { return; // already at the provided ordinal, nothing to be done } - // set current ordinal - this->m_current = ord; // trigger the next sequence(s) (m_sequence is ordered by ordinal) std::vector seqs; { auto g = this->lock_guard(); - for (auto it = this->m_sequence.begin(); it != this->m_sequence.end(); it = this->m_sequence.begin()) { - if (!this->m_order(it->first, this->m_current)) break; - // extract the next sequence - this->m_current = it->first; - seqs.push_back(std::move(it->second)); - this->m_sequence.erase(it); + // set current ordinal + this->m_current = ord; + { + for (auto it = this->m_sequence.begin(); it != this->m_sequence.end();) { + if (!this->m_order(it->first, this->m_current)) break; + // extract the next sequence + this->m_current = it->first; + seqs.push_back(std::move(it->second)); + it = this->m_sequence.erase(it); + } } } for (auto& elem : seqs) { diff --git a/ttg/ttg/parsec/task.h b/ttg/ttg/parsec/task.h index 5b23d53af..af655dac3 100644 --- a/ttg/ttg/parsec/task.h +++ b/ttg/ttg/parsec/task.h @@ -192,8 +192,9 @@ namespace ttg_parsec { device_state_t dev_state; ttg_data_copy_t *copies[num_copies] = { nullptr }; // the data copies tracked by this task - parsec_ttg_task_t(parsec_thread_mempool_t *mempool, parsec_task_class_t *task_class) - : parsec_ttg_task_base_t(mempool, task_class, num_streams, copies) { + parsec_ttg_task_t(parsec_thread_mempool_t *mempool, parsec_task_class_t *task_class, TT *tt_ptr) + : parsec_ttg_task_base_t(mempool, task_class, num_streams, copies) + , tt(tt_ptr) { tt_ht_item.key = pkey(); this->dev_ptr = this->dev_state.dev_ptr(); // We store the hash of the key and the address where it can be found in locals considered as a scratchpad @@ -250,8 +251,9 @@ namespace ttg_parsec { ttg_data_copy_t *copies[num_streams+1] = { nullptr }; // the data copies tracked by this task // +1 for the copy needed during send/bcast - parsec_ttg_task_t(parsec_thread_mempool_t *mempool, parsec_task_class_t *task_class) - : parsec_ttg_task_base_t(mempool, task_class, num_streams, copies) { + parsec_ttg_task_t(parsec_thread_mempool_t *mempool, parsec_task_class_t *task_class, TT *tt_ptr) + : parsec_ttg_task_base_t(mempool, task_class, num_streams, copies) + , tt(tt_ptr) { tt_ht_item.key = pkey(); this->dev_ptr = this->dev_state.dev_ptr(); } diff --git a/ttg/ttg/parsec/ttg.h b/ttg/ttg/parsec/ttg.h index 7aa3473cb..91b2ca322 100644 --- a/ttg/ttg/parsec/ttg.h +++ b/ttg/ttg/parsec/ttg.h @@ -1925,7 +1925,7 @@ ttg::abort(); // should not happen task_t *dummy; parsec_execution_stream_s *es = world.impl().execution_stream(); parsec_thread_mempool_t *mempool = get_task_mempool(); - dummy = new (parsec_thread_mempool_allocate(mempool)) task_t(mempool, &this->self); + dummy = new (parsec_thread_mempool_allocate(mempool)) task_t(mempool, &this->self, this); dummy->set_dummy(true); // TODO: do we need to copy static_stream_goal in dummy?