Skip to content

Commit

Permalink
Implement simple sequenced keys constraint
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Feb 27, 2024
1 parent 1cf37d5 commit 9d0fd20
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 10 deletions.
5 changes: 5 additions & 0 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,16 @@ class SpMM25D {
, parallel_bcasts_(parallel_bcasts) {
Edge<Key<2>, void> a_ctl, b_ctl;
Edge<Key<2>, int> a_rowctl, b_colctl; // TODO: can we have multiple control inputs per TT?
auto constraint = ttg::make_shared_constraint<ttg::SequencedKeysConstraint<Key<2>>>();
bcast_a_ = std::make_unique<BcastA>(a, a_ctl, a_rowctl, local_a_ijk_, a_rows_of_col_, a_cols_of_row_, b_cols_of_row_,
ij_keymap_, ijk_keymap_, parallel_bcasts_);
// add constraint with external mapper: key[1] represents `k`
bcast_a_->add_constraint(constraint, [](const Key<2>& key){ return key[1]; });
local_bcast_a_ = std::make_unique<LocalBcastA>(local_a_ijk_, a_ijk_, b_cols_of_row_, ijk_keymap_);
bcast_b_ = std::make_unique<BcastB>(b, b_ctl, b_colctl, local_b_ijk_, a_rows_of_col_, b_cols_of_row_, b_rows_of_col_,
ij_keymap_, ijk_keymap_, parallel_bcasts_);
// add constraint with external mapper: key[0] represents `k`
bcast_b_->add_constraint(constraint, [](const Key<2>& key){ return key[0]; });
local_bcast_b_ = std::make_unique<LocalBcastB>(local_b_ijk_, b_ijk_, a_rows_of_col_, ijk_keymap_);
multiplyadd_ = std::make_unique<MultiplyAdd>(a_ijk_, b_ijk_, c_ijk_, c_ij_p_, a_cols_of_row_,
b_rows_of_col_, mTiles, nTiles, ijk_keymap_);
Expand Down
1 change: 1 addition & 0 deletions ttg/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ configure_file(
set(ttg-impl-headers
${CMAKE_CURRENT_SOURCE_DIR}/ttg/broadcast.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/buffer.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/constraint.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/devicescope.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/devicescratch.h
${CMAKE_CURRENT_SOURCE_DIR}/ttg/edge.h
Expand Down
230 changes: 230 additions & 0 deletions ttg/ttg/constraint.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
#ifndef TTG_CONSTRAINT_H
#define TTG_CONSTRAINT_H

#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <span>
#include <type_traits>
#include <utility>
#include <vector>

namespace ttg {

// TODO: do we need a (virtual) base class?

template<typename Key>
struct ConstraintBase {
using key_type = Key;
using listener_t = std::function<void(const std::span<key_type>&)>;
ConstraintBase()
{ }
#if 0
virtual bool check_key(const key_type& key, ttg::TTBase *tt) = 0;
virtual void complete_key(const key_type& key) = 0;
#endif // 0
virtual ~ConstraintBase() = default;

void add_listener(listener_t l, ttg::TTBase *tt) {
auto g = this->lock_guard();
m_listeners.insert_or_assign(tt, std::move(l));
}

void notify_listener(const std::span<key_type>& keys, ttg::TTBase* tt) {
auto& release = m_listeners[tt];
release(keys);
}

protected:

auto lock_guard() {
return std::lock_guard{m_mtx};
}

private:
std::map<ttg::TTBase*, listener_t> m_listeners;
std::mutex m_mtx;
};

/**
 * Constraint that sequences keys by an ordinal.
 *
 * Every key is associated with an ordinal, either through the mapper stored
 * in the constraint or through an ordinal passed in by the caller (external
 * mapper, the default when \c Mapper is ttg::Void). check() decides whether a
 * key may execute right away; keys that may not are stored per ordinal and
 * per task template. Once the number of active keys drops to zero in
 * complete(), the keys of the first stored ordinal (in \c Compare order) are
 * handed to the listeners registered for their task templates.
 *
 * \tparam Key     key type; void keys are represented as ttg::Void
 * \tparam Ordinal type of the ordinal, must be default-constructible
 * \tparam Compare strict ordering of ordinals
 * \tparam Mapper  callable mapping a key to its ordinal, ttg::Void if the
 *                 ordinal is provided by the caller
 */
template<typename Key,
         typename Ordinal = std::size_t,
         typename Compare = std::less<Ordinal>,
         typename Mapper = ttg::Void>
struct SequencedKeysConstraint : public ConstraintBase<Key> {

  using key_type = std::conditional_t<ttg::meta::is_void_v<Key>, ttg::Void, Key>;
  using ordinal_type = Ordinal;
  using keymap_t = std::function<Ordinal(const key_type&)>;
  using compare_t = Compare;
  using base_t = ConstraintBase<Key>;

private:
  /* The deferred keys of one ordinal, grouped by the task template they belong to. */
  struct sequence_elem_t {
    std::map<ttg::TTBase*, std::vector<key_type>> m_keys;

    sequence_elem_t() = default;

    /* Append key to the list of tt, creating the list on first use. */
    void add_key(const key_type& key, ttg::TTBase* tt) {
      m_keys[tt].push_back(key);
    }
  };

  /**
   * Release the keys of the first stored ordinal, if any.
   * Called when the last active key has completed.
   */
  void release_next() {
    sequence_elem_t elem;
    {
      // extract the next sequence
      auto g = this->lock_guard();
      auto it = m_sequence.begin(); // sequence is ordered by ordinal
      if (it == m_sequence.end()) {
        return; // nothing to be done
      }
      m_current = it->first;
      elem = std::move(it->second);
      m_sequence.erase(it);

      // Account for ALL keys of this ordinal before anyone is notified and
      // while still holding the lock: otherwise a key completing early could
      // drive the counter to zero (and release the next ordinal) while keys of
      // this ordinal are still waiting to be handed out, and a concurrent
      // check could observe "nothing stored, nothing active" in between.
      std::size_t num_keys = 0;
      for (const auto& [tt, keys] : elem.m_keys) {
        num_keys += keys.size();
      }
      m_active.fetch_add(num_keys, std::memory_order_relaxed);
    }
    // notify outside of the lock so listeners may call back into the constraint
    for (auto& [tt, keys] : elem.m_keys) {
      this->notify_listener(std::span<key_type>(keys.data(), keys.size()), tt);
    }
  }

  /**
   * Decide whether key (with ordinal ord, belonging to tt) may execute.
   *
   * The whole decision is taken under the constraint's lock: m_current and
   * m_sequence are modified under that lock by release_next(), and deciding
   * to defer a key based on an unlocked snapshot can strand the key if the
   * last active key completes in between.
   *
   * Every path returning true increments m_active, matching the decrement
   * performed by the corresponding complete().
   *
   * \return true if the key may execute, false if it was stored for later release.
   */
  bool check_key_impl(const key_type& key, Ordinal ord, ttg::TTBase *tt) {
    auto g = this->lock_guard();
    if (m_order(ord, m_current)) {
      // key should be executed
      m_active.fetch_add(1, std::memory_order_relaxed);
      // reset current to the lower ordinal
      m_current = ord;
      return true;
    }
    if (m_sequence.empty() && 0 == m_active.load(std::memory_order_relaxed)) {
      // there are no keys (active or blocked) so we execute to avoid a deadlock
      // we don't change the current ordinal because there may be lower ordinals coming in later
      m_active.fetch_add(1, std::memory_order_relaxed);
      return true;
    }
    // key should be deferred: file it under its ordinal
    m_sequence[ord].add_key(key, tt);
    return false;
  }

  /* Mark one active key as completed; the last one triggers the next release. */
  void complete_key_impl() {
    auto active = m_active.fetch_sub(1, std::memory_order_relaxed) - 1;
    if (0 == active) {
      release_next();
    }
  }


public:

  /**
   * Used for external key mapper.
   */
  SequencedKeysConstraint()
  : base_t()
  { }

  /**
   * Construct the constraint with a mapper that is stored in the constraint.
   */
  template<typename Mapper_>
  SequencedKeysConstraint(Mapper_&& map)
  : base_t()
  , m_map(std::forward<Mapper_>(map))
  { }

  ~SequencedKeysConstraint() = default;

  /* Check whether the key may be executed.
   * Returns true if the key may be executed.
   * Otherwise, returns false and stores the key; it is passed to the
   * listener registered for tt once its ordinal is released.
   * This overload computes the ordinal with the stored mapper. */
  template<typename Key_ = key_type, typename Mapper_ = Mapper>
  std::enable_if_t<!ttg::meta::is_void_v<Key_> && !ttg::meta::is_void_v<Mapper_>, bool>
  check(const key_type& key, ttg::TTBase *tt) {
    ordinal_type ord = m_map(key);
    return check_key_impl(key, ord, tt);
  }

  /* Same as above, with the ordinal provided by the caller (external mapper). */
  template<typename Key_ = key_type, typename Mapper_ = Mapper>
  std::enable_if_t<!ttg::meta::is_void_v<Key_> && ttg::meta::is_void_v<Mapper_>, bool>
  check(const key_type& key, Ordinal ord, ttg::TTBase *tt) {
    return check_key_impl(key, ord, tt);
  }

  /* Void-key variant using the stored mapper, which is invoked without arguments. */
  template<typename Key_ = key_type, typename Mapper_ = Mapper>
  std::enable_if_t<ttg::meta::is_void_v<Key_> && !ttg::meta::is_void_v<Mapper_>, bool>
  check(ttg::TTBase *tt) {
    return check_key_impl(ttg::Void{}, m_map(), tt);
  }

  /* Void-key variant with the ordinal provided by the caller. */
  template<typename Key_ = key_type, typename Mapper_ = Mapper>
  std::enable_if_t<ttg::meta::is_void_v<Key_> && ttg::meta::is_void_v<Mapper_>, bool>
  check(ordinal_type ord, ttg::TTBase *tt) {
    return check_key_impl(ttg::Void{}, ord, tt);
  }

  /* Signal that a key previously admitted by check() has completed.
   * The arguments mirror the corresponding check() overload; only the
   * number of active keys is updated. */
  template<typename Key_ = key_type, typename Mapper_ = Mapper>
  std::enable_if_t<!ttg::meta::is_void_v<Key_> && !ttg::meta::is_void_v<Mapper_>>
  complete([[maybe_unused]] const key_type& key, [[maybe_unused]] ttg::TTBase *tt) {
    complete_key_impl();
  }

  template<typename Key_ = key_type, typename Mapper_ = Mapper>
  std::enable_if_t<!ttg::meta::is_void_v<Key_> && ttg::meta::is_void_v<Mapper_>>
  complete([[maybe_unused]] const key_type& key, [[maybe_unused]] Ordinal ord,
           [[maybe_unused]] ttg::TTBase *tt) {
    complete_key_impl();
  }

  /* Key-less completion with a caller-provided ordinal.
   * Enabled whenever the mapper is external: this adds the void-key case
   * (counterpart of check(ord, tt)), which had no usable complete() before,
   * while calls that were valid previously remain valid. */
  template<typename Key_ = key_type, typename Mapper_ = Mapper>
  std::enable_if_t<ttg::meta::is_void_v<Mapper_>>
  complete([[maybe_unused]] Ordinal ord, [[maybe_unused]] ttg::TTBase *tt) {
    complete_key_impl();
  }

  /* Key-less completion with a stored mapper.
   * Enabled whenever a mapper is stored: this adds the void-key case
   * (counterpart of check(tt)) while calls that were valid previously
   * remain valid. */
  template<typename Key_ = key_type, typename Mapper_ = Mapper>
  std::enable_if_t<!ttg::meta::is_void_v<Mapper_>>
  complete([[maybe_unused]] ttg::TTBase *tt) {
    complete_key_impl();
  }

private:
  // deferred keys, ordered by ordinal
  std::map<ordinal_type, sequence_elem_t, compare_t> m_sequence;
  // number of keys admitted by check()/release_next() that have not completed yet
  std::atomic<std::size_t> m_active{0};
  // ordinal new keys are compared against; value-initialized so the first check reads a defined value
  ordinal_type m_current{};
  [[no_unique_address]]
  Mapper m_map;
  [[no_unique_address]]
  compare_t m_order;
};



/**
 * Create a constraint of type \c Constraint owned by a std::shared_ptr.
 *
 * The arguments are forwarded to the constructor of \c Constraint. They are
 * handed to std::make_shared directly: wrapping them in a `new` expression
 * would try to construct the constraint from a raw pointer and leak the
 * temporary object.
 *
 * \tparam Constraint the constraint type to create
 * \param args constructor arguments of \c Constraint
 * \return shared pointer owning the new constraint
 */
template<typename Constraint, typename... Args>
std::shared_ptr<Constraint> make_shared_constraint(Args&&... args) {
  return std::make_shared<Constraint>(std::forward<Args>(args)...);
}

} // namespace ttg

#endif // TTG_CONSTRAINT_H
14 changes: 14 additions & 0 deletions ttg/ttg/parsec/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ namespace ttg_parsec {
int32_t data_count = 0; //< number of data elements in the copies array
ttg_data_copy_t **copies; //< pointer to the fixed copies array of the derived task
parsec_hash_table_item_t tt_ht_item = {};
std::atomic<uint32_t> constraint_blocks;

struct stream_info_t {
std::size_t goal;
Expand Down Expand Up @@ -134,6 +135,19 @@ namespace ttg_parsec {
release_task_cb(this);
}

/* Register one more constraint on this task by incrementing the
 * constraint counter.
 * Returns true if the counter was zero before the increment, i.e.,
 * this is the first constraint. */
bool add_constraint() {
  const auto blocks_before = constraint_blocks.fetch_add(1, std::memory_order_relaxed);
  return blocks_before == 0;
}

/* Drop one constraint from this task by decrementing the constraint
 * counter.
 * Returns true if the counter was one before the decrement, i.e.,
 * the constraint removed was the last one. */
bool release_constaint() {
  const auto blocks_before = constraint_blocks.fetch_sub(1, std::memory_order_relaxed);
  return blocks_before == 1;
}

protected:
/**
* Protected constructors: this class should not be instantiated directly
Expand Down
Loading

0 comments on commit 9d0fd20

Please sign in to comment.