Skip to content

Commit

Permalink
Modify constraint to make auto-progress optional
Browse files Browse the repository at this point in the history
The constructor of the SequencedKeysConstraint takes a Boolean
argument that determines whether tasks are automatically released
or whether we depend on the application to release the next wave of tasks.

Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Mar 6, 2024
1 parent 52f2725 commit a759f6b
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 19 deletions.
13 changes: 9 additions & 4 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ class SpMM25D {
bcast_b_->add_constraint(constraint, [](const Key<2>& key){ return key[0]; });
local_bcast_b_ = std::make_unique<LocalBcastB>(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_);
multiplyadd_ = std::make_unique<MultiplyAdd>(a_ijk_, b_ijk_, c_ijk_, c_ij_p_, a_cols_of_row_,
b_rows_of_col_, mTiles, nTiles, ijk_keymap_, constraint, k_cnt_);
b_rows_of_col_, mTiles, nTiles, ijk_keymap_, constraint,
k_cnt_, parallel_bcasts);

reduce_c_ = std::make_unique<ReduceC>(c_ij_p_, c, ij_keymap_);
reduce_c_->template set_input_reducer<0>(
Expand Down Expand Up @@ -545,13 +546,15 @@ class SpMM25D {
const std::vector<std::vector<long>> &b_rows_of_col, const std::vector<int> &mTiles,
const std::vector<int> &nTiles, const Keymap3 &ijk_keymap,
std::shared_ptr<ttg::SequencedKeysConstraint<Key<2>>> constraint,
std::vector<bool>& k_cnt)
std::vector<bool>& k_cnt,
std::size_t parallel_bcasts)
: baseT(edges(a_ijk, b_ijk, c_ijk), edges(c, c_ijk), "SpMM25D::MultiplyAdd", {"a_ijk", "b_ijk", "c_ijk"},
{"c_ij", "c_ijk"}, ijk_keymap)
, a_cols_of_row_(a_cols_of_row)
, b_rows_of_col_(b_rows_of_col)
, k_cnt_(k_cnt)
, constraint(std::move(constraint)){
, constraint(std::move(constraint))
, parallel_bcasts_(parallel_bcasts) {
this->set_priomap([=,this](const Key<3> &ijk) { return this->prio(ijk); }); // map a key to an integral priority value

// for each {i,j} determine first k that contributes AND belongs to this node,
Expand Down Expand Up @@ -598,9 +601,10 @@ class SpMM25D {
// release the constraint on the next round of broadcasts
{
std::size_t release_k = k;
std::size_t bcasts_ahead = parallel_bcasts_;
while (release_k < k_cnt_.size()) {
++release_k;
if (k_cnt_[release_k])
if (k_cnt_[release_k] && --bcasts_ahead)
break;
}
constraint->release(release_k);
Expand All @@ -627,6 +631,7 @@ class SpMM25D {
const std::vector<std::vector<long>> &b_rows_of_col_;
std::vector<bool>& k_cnt_;
std::shared_ptr<ttg::SequencedKeysConstraint<Key<2>>> constraint;
std::size_t parallel_bcasts_;

/* Compute the length of the remaining sequence on that tile */
int32_t prio(const Key<3> &key) const {
Expand Down
26 changes: 16 additions & 10 deletions ttg/ttg/constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ namespace ttg {
std::map<ttg::TTBase*, std::vector<key_type>> m_keys;

sequence_elem_t() = default;
sequence_elem_t(sequence_elem_t&&) = default;
sequence_elem_t(const sequence_elem_t&) = default;
sequence_elem_t& operator=(sequence_elem_t&&) = default;
sequence_elem_t& operator=(const sequence_elem_t&) = default;

void add_key(const key_type& key, ttg::TTBase* tt) {
auto it = m_keys.find(tt);
Expand All @@ -95,9 +99,9 @@ namespace ttg {
// key should be executed
if (m_auto_release) { // only needed for auto-release
m_active.fetch_add(1, std::memory_order_relaxed);
// revert the current ordinal to the lower ordinal
m_current = ord;
}
// reset current to the lower ordinal
m_current = ord;
return true;
} else if (m_sequence.empty() && m_auto_release && 0 == m_active.load(std::memory_order_relaxed)) {
// there are no keys (active or blocked) so we execute to avoid a deadlock
Expand Down Expand Up @@ -172,18 +176,20 @@ namespace ttg {
if (!force_check && m_order(ord, this->m_current)) {
return; // already at the provided ordinal, nothing to be done
}
// set current ordinal
this->m_current = ord;
// trigger the next sequence(s) (m_sequence is ordered by ordinal)
std::vector<sequence_elem_t> seqs;
{
auto g = this->lock_guard();
for (auto it = this->m_sequence.begin(); it != this->m_sequence.end(); it = this->m_sequence.begin()) {
if (!this->m_order(it->first, this->m_current)) break;
// extract the next sequence
this->m_current = it->first;
seqs.push_back(std::move(it->second));
this->m_sequence.erase(it);
// set current ordinal
this->m_current = ord;
{
for (auto it = this->m_sequence.begin(); it != this->m_sequence.end();) {
if (!this->m_order(it->first, this->m_current)) break;
// extract the next sequence
this->m_current = it->first;
seqs.push_back(std::move(it->second));
it = this->m_sequence.erase(it);
}
}
}
for (auto& elem : seqs) {
Expand Down
10 changes: 6 additions & 4 deletions ttg/ttg/parsec/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,9 @@ namespace ttg_parsec {
device_state_t<TT::derived_has_device_op()> dev_state;
ttg_data_copy_t *copies[num_copies] = { nullptr }; // the data copies tracked by this task

parsec_ttg_task_t(parsec_thread_mempool_t *mempool, parsec_task_class_t *task_class)
: parsec_ttg_task_base_t(mempool, task_class, num_streams, copies) {
parsec_ttg_task_t(parsec_thread_mempool_t *mempool, parsec_task_class_t *task_class, TT *tt_ptr)
: parsec_ttg_task_base_t(mempool, task_class, num_streams, copies)
, tt(tt_ptr) {
tt_ht_item.key = pkey();
this->dev_ptr = this->dev_state.dev_ptr();
// We store the hash of the key and the address where it can be found in locals considered as a scratchpad
Expand Down Expand Up @@ -250,8 +251,9 @@ namespace ttg_parsec {
ttg_data_copy_t *copies[num_streams+1] = { nullptr }; // the data copies tracked by this task
// +1 for the copy needed during send/bcast

parsec_ttg_task_t(parsec_thread_mempool_t *mempool, parsec_task_class_t *task_class)
: parsec_ttg_task_base_t(mempool, task_class, num_streams, copies) {
parsec_ttg_task_t(parsec_thread_mempool_t *mempool, parsec_task_class_t *task_class, TT *tt_ptr)
: parsec_ttg_task_base_t(mempool, task_class, num_streams, copies)
, tt(tt_ptr) {
tt_ht_item.key = pkey();
this->dev_ptr = this->dev_state.dev_ptr();
}
Expand Down
2 changes: 1 addition & 1 deletion ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1925,7 +1925,7 @@ ttg::abort(); // should not happen
task_t *dummy;
parsec_execution_stream_s *es = world.impl().execution_stream();
parsec_thread_mempool_t *mempool = get_task_mempool();
dummy = new (parsec_thread_mempool_allocate(mempool)) task_t(mempool, &this->self);
dummy = new (parsec_thread_mempool_allocate(mempool)) task_t(mempool, &this->self, this);
dummy->set_dummy(true);
// TODO: do we need to copy static_stream_goal in dummy?

Expand Down

0 comments on commit a759f6b

Please sign in to comment.