From 9d0fd20a58aae6e428e580e73f8919d4e5b80a97 Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Tue, 27 Feb 2024 18:11:11 -0500 Subject: [PATCH] Implement simple sequenced keys constraint Signed-off-by: Joseph Schuchart --- examples/spmm/spmm.cc | 5 + ttg/CMakeLists.txt | 1 + ttg/ttg/constraint.h | 230 ++++++++++++++++++++++++++++++++++++++++++ ttg/ttg/parsec/task.h | 14 +++ ttg/ttg/parsec/ttg.h | 140 +++++++++++++++++++++++-- ttg/ttg/util/meta.h | 15 +++ 6 files changed, 395 insertions(+), 10 deletions(-) create mode 100644 ttg/ttg/constraint.h diff --git a/examples/spmm/spmm.cc b/examples/spmm/spmm.cc index 0c85eb3d4..7f23703b5 100644 --- a/examples/spmm/spmm.cc +++ b/examples/spmm/spmm.cc @@ -292,11 +292,16 @@ class SpMM25D { , parallel_bcasts_(parallel_bcasts) { Edge, void> a_ctl, b_ctl; Edge, int> a_rowctl, b_colctl; // TODO: can we have multiple control inputs per TT? + auto constraint = ttg::make_shared_constraint>>(); bcast_a_ = std::make_unique(a, a_ctl, a_rowctl, local_a_ijk_, a_rows_of_col_, a_cols_of_row_, b_cols_of_row_, ij_keymap_, ijk_keymap_, parallel_bcasts_); + // add constraint with external mapper: key[1] represents `k` + bcast_a_->add_constraint(constraint, [](const Key<2>& key){ return key[1]; }); local_bcast_a_ = std::make_unique(local_a_ijk_, a_ijk_, b_cols_of_row_, ijk_keymap_); bcast_b_ = std::make_unique(b, b_ctl, b_colctl, local_b_ijk_, a_rows_of_col_, b_cols_of_row_, b_rows_of_col_, ij_keymap_, ijk_keymap_, parallel_bcasts_); + // add constraint with external mapper: key[0] represents `k` + bcast_b_->add_constraint(constraint, [](const Key<2>& key){ return key[0]; }); local_bcast_b_ = std::make_unique(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_); multiplyadd_ = std::make_unique(a_ijk_, b_ijk_, c_ijk_, c_ij_p_, a_cols_of_row_, b_rows_of_col_, mTiles, nTiles, ijk_keymap_); diff --git a/ttg/CMakeLists.txt b/ttg/CMakeLists.txt index 900df6519..14dba8d73 100644 --- a/ttg/CMakeLists.txt +++ b/ttg/CMakeLists.txt @@ -45,6 +45,7 @@ configure_file( set(ttg-impl-headers ${CMAKE_CURRENT_SOURCE_DIR}/ttg/broadcast.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/buffer.h + ${CMAKE_CURRENT_SOURCE_DIR}/ttg/constraint.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/devicescope.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/devicescratch.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/edge.h diff --git a/ttg/ttg/constraint.h b/ttg/ttg/constraint.h new file mode 100644 index 000000000..67b639ab4 --- /dev/null +++ b/ttg/ttg/constraint.h @@ -0,0 +1,230 @@ +#ifndef TTG_CONSTRAINT_H +#define TTG_CONSTRAINT_H + +#include +#include +#include +#include +#include + +namespace ttg { + + // TODO: do we need a (virtual) base class? + + template + struct ConstraintBase { + using key_type = Key; + using listener_t = std::function&)>; + ConstraintBase() + { } +#if 0 + virtual bool check_key(const key_type& key, ttg::TTBase *tt) = 0; + virtual void complete_key(const key_type& key) = 0; +#endif // 0 + virtual ~ConstraintBase() = default; + + void add_listener(listener_t l, ttg::TTBase *tt) { + auto g = this->lock_guard(); + m_listeners.insert_or_assign(tt, std::move(l)); + } + + void notify_listener(const std::span& keys, ttg::TTBase* tt) { + auto& release = m_listeners[tt]; + release(keys); + } + + protected: + + auto lock_guard() { + return std::lock_guard{m_mtx}; + } + + private: + std::map m_listeners; + std::mutex m_mtx; + }; + + template, + typename Mapper = ttg::Void> + struct SequencedKeysConstraint : public ConstraintBase { + + using key_type = std::conditional_t, ttg::Void, Key>; + using ordinal_type = Ordinal; + using keymap_t = std::function; + using compare_t = Compare; + using base_t = ConstraintBase; + + private: + struct sequence_elem_t { + std::map> m_keys; + + sequence_elem_t() = default; + //sequence_elem_t(const sequence_elem_t&) = default; + //sequence_elem_t(sequence_elem_t&&) = default; + + void add_key(const key_type& key, ttg::TTBase* tt) { + std::cout << "SEQ: add_key " << key << " tt " << tt << std::endl; + auto it = m_keys.find(tt); + if (it == m_keys.end()) { + m_keys.insert(std::make_pair(tt, std::vector{key})); + } else { + it->second.push_back(key); + } + } + }; + + void release_next() { + // trigger the next sequence + sequence_elem_t elem; + { + // extract the next sequence + auto g = this->lock_guard(); + auto it = m_sequence.begin(); // sequence is ordered by ordinal + if (it == m_sequence.end()) { + return; // nothing to be done + } + m_current = it->first; + elem = std::move(it->second); + m_sequence.erase(it); + } + std::cout << "SEQ: releasing ord " << m_current << std::endl; + for (auto& seq : elem.m_keys) { + // account for the newly active keys + m_active.fetch_add(seq.second.size(), std::memory_order_relaxed); + this->notify_listener(std::span(seq.second.data(), seq.second.size()), seq.first); + } + } + + bool check_key_impl(const key_type& key, Ordinal ord, ttg::TTBase *tt) { + if (m_order(ord, m_current)) { + // key should be executed + m_active.fetch_add(1, std::memory_order_relaxed); + // reset current to the lower ordinal + m_current = ord; + std::cout << "SEQ: reset ord to " << m_current << std::endl; + return true; + } else if (m_sequence.empty() && 0 == m_active.load(std::memory_order_relaxed)) { + // there are no keys (active or blocked) so we execute to avoid a deadlock + // we don't change the current ordinal because there may be lower ordinals coming in later + m_active.fetch_add(1, std::memory_order_relaxed); + std::cout << "SEQ: empty-releasing key " << key << std::endl; + return true; + } else { + // key should be deferred + auto g = this->lock_guard(); + if (m_order(ord, m_current)) { + // someone released this ordinal while we took the lock + return true; + } + auto it = m_sequence.find(ord); + if (it == m_sequence.end()) { + auto [iter, success] = m_sequence.insert(std::make_pair(ord, sequence_elem_t{})); + assert(success); + it = iter; + } + it->second.add_key(key, tt); + return false; + } + } + + void complete_key_impl() { + auto active = m_active.fetch_sub(1, std::memory_order_relaxed) - 1; + if (0 == active) { + release_next(); + } + } + + + public: + + /** + * Used for external key mapper. + */ + SequencedKeysConstraint() + : base_t() + { } + + template + SequencedKeysConstraint(Mapper_&& map) + : base_t() + , m_map(std::forward(map)) + { } + + ~SequencedKeysConstraint() = default; + + /* Check whether the key may be executed. + * Returns true if the key may be executed. + * Otherwise, returns false and */ + template + std::enable_if_t && !ttg::meta::is_void_v, bool> + check(const key_type& key, ttg::TTBase *tt) { + ordinal_type ord = m_map(key); + return check_key_impl(key, ord, tt); + } + + template + std::enable_if_t && ttg::meta::is_void_v, bool> + check(const key_type& key, Ordinal ord, ttg::TTBase *tt) { + return check_key_impl(key, ord, tt); + } + + template + std::enable_if_t && !ttg::meta::is_void_v, bool> + check(ttg::TTBase *tt) { + return check_key_impl(ttg::Void{}, m_map(), tt); + } + + template + std::enable_if_t && ttg::meta::is_void_v, bool> + check(ordinal_type ord, ttg::TTBase *tt) { + return check_key_impl(ttg::Void{}, ord, tt); + } + + template + std::enable_if_t && !ttg::meta::is_void_v> + complete(const key_type& key, ttg::TTBase *tt) { + std::cout << "SEQ complete_key " << key << std::endl; + complete_key_impl(); + } + + template + std::enable_if_t && ttg::meta::is_void_v> + complete(const key_type& key, Ordinal ord, ttg::TTBase *tt) { + std::cout << "SEQ complete_key " << key << std::endl; + complete_key_impl(); + } + + template + std::enable_if_t && ttg::meta::is_void_v> + complete(Ordinal ord, ttg::TTBase *tt) { + complete_key_impl(); + } + + template + std::enable_if_t && !ttg::meta::is_void_v> + complete(ttg::TTBase *tt) { + complete_key_impl(); + } + + private: + std::map m_sequence; + std::atomic m_active; + ordinal_type m_current; + [[no_unique_address]] + Mapper m_map; + [[no_unique_address]] + compare_t m_order; + }; + + + + template + std::shared_ptr make_shared_constraint(Args&&... args) { + return std::make_shared(new Constraint(std::forward(args)...)); + } + +} // namespace ttg + +#endif // TTG_CONSTRAINT_H \ No newline at end of file diff --git a/ttg/ttg/parsec/task.h b/ttg/ttg/parsec/task.h index 5b23d53af..2e2a327a6 100644 --- a/ttg/ttg/parsec/task.h +++ b/ttg/ttg/parsec/task.h @@ -83,6 +83,7 @@ namespace ttg_parsec { int32_t data_count = 0; //< number of data elements in the copies array ttg_data_copy_t **copies; //< pointer to the fixed copies array of the derived task parsec_hash_table_item_t tt_ht_item = {}; + std::atomic constraint_blocks; struct stream_info_t { std::size_t goal; @@ -134,6 +135,19 @@ namespace ttg_parsec { release_task_cb(this); } + /* add a constraint to the task + * returns true if this is the first constraint */ + bool add_constraint() { + return (0 == constraint_blocks.fetch_add(1, std::memory_order_relaxed)); + } + + /* remove a constraint from the task + * returns true if this is the last constraint */ + bool release_constaint() { + /* return true if this was the last constraint*/ + return 1 == constraint_blocks.fetch_sub(1, std::memory_order_relaxed); + } + protected: /** * Protected constructors: this class should not be instantiated directly diff --git a/ttg/ttg/parsec/ttg.h b/ttg/ttg/parsec/ttg.h index 65e39991a..bab389a59 100644 --- a/ttg/ttg/parsec/ttg.h +++ b/ttg/ttg/parsec/ttg.h @@ -20,6 +20,7 @@ #include "ttg/base/keymap.h" #include "ttg/base/tt.h" #include "ttg/base/world.h" +#include "ttg/constraint.h" #include "ttg/edge.h" #include "ttg/execution.h" #include "ttg/func.h" @@ -1104,6 +1105,7 @@ namespace ttg_parsec { protected: // static std::map function_id_to_instance; parsec_hash_table_t tasks_table; + parsec_hash_table_t task_constraint_table; parsec_task_class_t self; }; @@ -1272,6 +1274,9 @@ namespace ttg_parsec { bool m_defer_writer = TTG_PARSEC_DEFER_WRITER; + std::vector> constraints_check; + std::vector> constraints_complete; + public: ttg::World get_world() const override final { return world; } @@ -2486,6 +2491,73 @@ ttg::abort(); // should not happen } } + bool check_constraints(task_t *task) { + bool release = true; + for (auto& c : constraints_check) { + bool constrained = false; + if constexpr (ttg::meta::is_void_v) { + constrained = !c(); + } else { + constrained = !c(task->key); + } + if (constrained) { + if (task->add_constraint()) { + if constexpr (!ttg::meta::is_void_v) { + } + parsec_hash_table_insert(&task_constraint_table, &task->tt_ht_item); + } + release = false; + } + } + return release; + } + + template + std::enable_if_t, void> release_constraint(const std::span& keys) { + task_t *task; + parsec_key_t hk = 0; + parsec_hash_table_lock_bucket(&task_constraint_table, hk); + task = (task_t *)parsec_hash_table_nolock_find(&task_constraint_table, hk); + if (task->release_constaint()) { + parsec_hash_table_nolock_remove(&task_constraint_table, hk); + auto &world_impl = world.impl(); + parsec_execution_stream_t *es = world_impl.execution_stream(); + parsec_task_t *vp_task_rings[1] = { &task->parsec_task }; + __parsec_schedule_vp(es, vp_task_rings, 0); + } + parsec_hash_table_unlock_bucket(&task_constraint_table, hk); + } + + template + std::enable_if_t, void> release_constraint(const std::span& keys) { + parsec_task_t *task_ring = nullptr; + for (auto& key : keys) { + task_t *task; + auto hk = reinterpret_cast(&key); + parsec_hash_table_lock_bucket(&task_constraint_table, hk); + task = (task_t *)parsec_hash_table_nolock_find(&task_constraint_table, hk); + assert(task != nullptr); + if (task->release_constaint()) { + parsec_hash_table_nolock_remove(&task_constraint_table, hk); + if (task_ring == nullptr) { + /* the first task is set directly */ + task_ring = &task->parsec_task; + } else { + /* push into the ring */ + parsec_list_item_ring_push_sorted(&task_ring->super, &task->parsec_task.super, + offsetof(parsec_task_t, priority)); + } + } + parsec_hash_table_unlock_bucket(&task_constraint_table, hk); + } + if (nullptr != task_ring) { + auto &world_impl = world.impl(); + parsec_execution_stream_t *es = world_impl.execution_stream(); + parsec_task_t *vp_task_rings[1] = { task_ring }; + __parsec_schedule_vp(es, vp_task_rings, 0); + } + } + void release_task(task_t *task, parsec_task_t **task_ring = nullptr) { constexpr const bool keyT_is_Void = ttg::meta::is_void_v; @@ -2515,16 +2587,19 @@ ttg::abort(); // should not happen } } if (task->remove_from_hash) parsec_hash_table_remove(&tasks_table, hk); - if (nullptr == task_ring) { - parsec_task_t *vp_task_rings[1] = { &task->parsec_task }; - __parsec_schedule_vp(es, vp_task_rings, 0); - } else if (*task_ring == nullptr) { - /* the first task is set directly */ - *task_ring = &task->parsec_task; - } else { - /* push into the ring */ - parsec_list_item_ring_push_sorted(&(*task_ring)->super, &task->parsec_task.super, - offsetof(parsec_task_t, priority)); + + if (check_constraints(task)) { + if (nullptr == task_ring) { + parsec_task_t *vp_task_rings[1] = { &task->parsec_task }; + __parsec_schedule_vp(es, vp_task_rings, 0); + } else if (*task_ring == nullptr) { + /* the first task is set directly */ + *task_ring = &task->parsec_task; + } else { + /* push into the ring */ + parsec_list_item_ring_push_sorted(&(*task_ring)->super, &task->parsec_task.super, + offsetof(parsec_task_t, priority)); + } } } else if constexpr (!ttg::meta::is_void_v) { if ((baseobj->num_pullins + count == numins) && baseobj->is_lazy_pull()) { @@ -3670,6 +3745,14 @@ ttg::abort(); // should not happen detail::release_data_copy(copy); task->copies[i] = nullptr; } + + for (auto& c : task->tt->constraints_complete) { + if constexpr(std::is_void_v) { + c(); + } else { + c(task->key); + } + } return PARSEC_HOOK_RETURN_DONE; } @@ -3822,6 +3905,9 @@ ttg::abort(); // should not happen parsec_hash_table_init(&tasks_table, offsetof(detail::parsec_ttg_task_base_t, tt_ht_item), 8, tasks_hash_fcts, NULL); + + parsec_hash_table_init(&task_constraint_table, offsetof(detail::parsec_ttg_task_base_t, tt_ht_item), 8, tasks_hash_fcts, + NULL); } template , @@ -4160,6 +4246,40 @@ ttg::abort(); // should not happen priomap = std::forward(pm); } + /// add a constraint + /// the constraint must provide a valid override of `check_key(key)` + template + void add_constraint(std::shared_ptr c) { + c->add_listener(&release_constraint, this); + if constexpr(ttg::meta::is_void_v) { + c->add_listener([this](){ this->release_constraint(); }, this); + constraints_check.push_back([c, this](){ return c->check(this); }); + constraints_complete.push_back([c, this](const keyT& key){ c->complete(this); return true; }); + } else { + c->add_listener([this](const std::span& keys){ this->release_constraint(keys); }, this); + constraints_check.push_back([c, this](const keyT& key){ return c->check(key, this); }); + constraints_complete.push_back([c, this](const keyT& key){ c->complete(key, this); return true; }); + } + } + + /// add a constraint + /// the constraint must provide a valid override of `check_key(key, map(key))` + /// ths overload can be used to provide different key mapping functions for each TT + template + void add_constraint(std::shared_ptr c, Mapper&& map) { + static_assert(std::is_same_v); + if constexpr(ttg::meta::is_void_v) { + c->add_listener([this](){ this->release_constraint(); }, this); + constraints_check.push_back([map, c, this](){ return c->check(map(), this); }); + constraints_complete.push_back([map, c, this](){ c->complete(map(), this); return true; }); + } else { + c->add_listener([this](const std::span& keys){ this->release_constraint(keys); }, this); + constraints_check.push_back([map, c, this](const keyT& key){ return c->check(key, map(key), this); }); + constraints_complete.push_back([map, c, this](const keyT& key){ c->complete(key, map(key), this); return true; }); + } + } + + // Register the static_op function to associate it to instance_id void register_static_op_function(void) { int rank; diff --git a/ttg/ttg/util/meta.h b/ttg/ttg/util/meta.h index b7bb31690..8839ba7d4 100644 --- a/ttg/ttg/util/meta.h +++ b/ttg/ttg/util/meta.h @@ -912,6 +912,21 @@ namespace ttg { template using prepare_send_callback_t = typename prepare_send_callback::type; + template + struct constraint_callback; + + template + struct constraint_callback>> { + using type = std::function; + }; + + template + struct constraint_callback>> { + using type = std::function; + }; + + template + using constraint_callback_t = typename constraint_callback::type; } // namespace detail