Skip to content

Commit

Permalink
Merge pull request #3 from devreal/spmm-with-seed-keymap2
Browse files Browse the repository at this point in the history
SPMM: Add keymaps to place tasks more efficiently
  • Loading branch information
therault authored Sep 13, 2021
2 parents 90a3b94 + 8dc8c78 commit 73f137d
Showing 1 changed file with 52 additions and 34 deletions.
86 changes: 52 additions & 34 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,20 +179,23 @@ std::ostream &operator<<(std::ostream &os, const Key<Rank> &key) {
return os;
}

inline int tile2rank(int i, int j, int P, int Q)
{
int p = (i % P);
int q = (j % Q);
int r = (q * P) + p;
return r;
}

// flow data from an existing SpMatrix on rank 0
template <typename Blk = blk_t>
class Read_SpMatrix : public Op<Key<2>, std::tuple<Out<Key<2>, Blk>>, Read_SpMatrix<Blk>, void> {
public:
using baseT = Op<Key<2>, std::tuple<Out<Key<2>, Blk>>, Read_SpMatrix<Blk>, void>;

Read_SpMatrix(const char *label, const SpMatrix<Blk> &matrix, Edge<Key<2>> &ctl, Edge<Key<2>, Blk> &out, const int P,
const int Q)
: baseT(edges(ctl), edges(out), std::string("read_spmatrix(") + label + ")", {"ctl"}, {std::string(label) + "ij"},
[Q](const Key<2> &key) {
int r = (int)key[0] * Q + (int)key[1];
assert(r >= 0 && r < ttg_default_execution_context().size());
return r;
})
template<typename Keymap>
Read_SpMatrix(const char *label, const SpMatrix<Blk> &matrix, Edge<Key<2>> &ctl, Edge<Key<2>, Blk> &out, Keymap &&keymap)
: baseT(edges(ctl), edges(out), std::string("read_spmatrix(") + label + ")", {"ctl"}, {std::string(label) + "ij"}, keymap)
, matrix_(matrix) {}

void op(const Key<2> &key, std::tuple<Out<Key<2>, Blk>> &out) {
Expand All @@ -213,8 +216,9 @@ class Write_SpMatrix : public Op<Key<2>, std::tuple<>, Write_SpMatrix<Blk>, Blk>
public:
using baseT = Op<Key<2>, std::tuple<>, Write_SpMatrix<Blk>, Blk>;

Write_SpMatrix(SpMatrix<Blk> &matrix, Edge<Key<2>, Blk> &in)
: baseT(edges(in), edges(), "write_spmatrix", {"Cij"}, {}, [](auto key) { return 0; }), matrix_(matrix) {}
template<typename Keymap>
Write_SpMatrix(SpMatrix<Blk> &matrix, Edge<Key<2>, Blk> &in, Keymap&& keymap)
: baseT(edges(in), edges(), "write_spmatrix", {"Cij"}, {}, keymap), matrix_(matrix) {}

void op(const Key<2> &key, typename baseT::input_values_tuple_type &&elem, std::tuple<> &) {
std::lock_guard<std::mutex> lock(mtx_);
Expand Down Expand Up @@ -255,9 +259,10 @@ class Write_SpMatrix : public Op<Key<2>, std::tuple<>, Write_SpMatrix<Blk>, Blk>
template <typename Blk = blk_t>
class SpMM {
public:
template<typename Keymap>
SpMM(Edge<Key<2>, Blk> &a, Edge<Key<2>, Blk> &b, Edge<Key<2>, Blk> &c, const SpMatrix<Blk> &a_mat,
const SpMatrix<Blk> &b_mat, std::map<std::tuple<int, int>, bool> &Afilling,
std::map<std::tuple<int, int>, bool> &Bfilling, const int P, const int Q)
std::map<std::tuple<int, int>, bool> &Bfilling, Keymap &&keymap)
: a_ijk_()
, b_ijk_()
, c_ijk_()
Expand All @@ -272,10 +277,10 @@ class SpMM {
ttg_broadcast(ttg_default_execution_context(), a_colidx_to_rowidx_, root);
ttg_broadcast(ttg_default_execution_context(), b_colidx_to_rowidx_, root);

bcast_a_ = std::make_unique<BcastA>(a, a_ijk_, b_rowidx_to_colidx_);
bcast_b_ = std::make_unique<BcastB>(b, b_ijk_, a_colidx_to_rowidx_);
bcast_a_ = std::make_unique<BcastA>(a, a_ijk_, b_rowidx_to_colidx_, keymap);
bcast_b_ = std::make_unique<BcastB>(b, b_ijk_, a_colidx_to_rowidx_, keymap);
multiplyadd_ =
std::make_unique<MultiplyAdd>(a_ijk_, b_ijk_, c_ijk_, c, a_rowidx_to_colidx_, b_colidx_to_rowidx_, P, Q);
std::make_unique<MultiplyAdd>(a_ijk_, b_ijk_, c_ijk_, c, a_rowidx_to_colidx_, b_colidx_to_rowidx_, keymap);

TTGUNUSED(bcast_a_);
TTGUNUSED(bcast_b_);
Expand All @@ -287,8 +292,9 @@ class SpMM {
public:
using baseT = Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, Blk>;

BcastA(Edge<Key<2>, Blk> &a, Edge<Key<3>, Blk> &a_ijk, const std::vector<std::vector<long>> &b_rowidx_to_colidx)
: baseT(edges(a), edges(a_ijk), "SpMM::bcast_a", {"a_ik"}, {"a_ijk"})
template<typename Keymap>
BcastA(Edge<Key<2>, Blk> &a, Edge<Key<3>, Blk> &a_ijk, const std::vector<std::vector<long>> &b_rowidx_to_colidx, Keymap&& keymap)
: baseT(edges(a), edges(a_ijk), "SpMM::bcast_a", {"a_ik"}, {"a_ijk"}, keymap)
, b_rowidx_to_colidx_(b_rowidx_to_colidx) {}

void op(const Key<2> &key, typename baseT::input_values_tuple_type &&a_ik, std::tuple<Out<Key<3>, Blk>> &a_ijk) {
Expand All @@ -314,8 +320,9 @@ class SpMM {
public:
using baseT = Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, Blk>;

BcastB(Edge<Key<2>, Blk> &b, Edge<Key<3>, Blk> &b_ijk, const std::vector<std::vector<long>> &a_colidx_to_rowidx)
: baseT(edges(b), edges(b_ijk), "SpMM::bcast_b", {"b_kj"}, {"b_ijk"})
template<typename Keymap>
BcastB(Edge<Key<2>, Blk> &b, Edge<Key<3>, Blk> &b_ijk, const std::vector<std::vector<long>> &a_colidx_to_rowidx, Keymap&& keymap)
: baseT(edges(b), edges(b_ijk), "SpMM::bcast_b", {"b_kj"}, {"b_ijk"}, keymap)
, a_colidx_to_rowidx_(a_colidx_to_rowidx) {}

void op(const Key<2> &key, typename baseT::input_values_tuple_type &&b_kj, std::tuple<Out<Key<3>, Blk>> &b_ijk) {
Expand All @@ -342,21 +349,19 @@ class SpMM {
public:
using baseT = Op<Key<3>, std::tuple<Out<Key<2>, Blk>, Out<Key<3>, Blk>>, MultiplyAdd, const Blk, const Blk, Blk>;

template<typename Keymap>
MultiplyAdd(Edge<Key<3>, Blk> &a_ijk, Edge<Key<3>, Blk> &b_ijk, Edge<Key<3>, Blk> &c_ijk, Edge<Key<2>, Blk> &c,
const std::vector<std::vector<long>> &a_rowidx_to_colidx,
const std::vector<std::vector<long>> &b_colidx_to_rowidx, const int P, const int Q)
const std::vector<std::vector<long>> &b_colidx_to_rowidx, Keymap &&keymap)
: baseT(edges(a_ijk, b_ijk, c_ijk), edges(c, c_ijk), "SpMM::MultiplyAdd", {"a_ijk", "b_ijk", "c_ijk"},
{"c_ij", "c_ijk"},
[P, Q](const Key<3> &key) {
int i = (int)key[0];
int j = (int)key[1];
int r = (i % P) * Q + (j % Q);
return r;
[keymap](const Key<3> &key) {
auto key2 = Key<2>({key[0], key[1]});
return keymap(key2);
})
, a_rowidx_to_colidx_(a_rowidx_to_colidx)
, b_colidx_to_rowidx_(b_colidx_to_rowidx) {
this->set_priomap([=](const Key<3> &key) { return this->prio(key); });
auto &keymap = this->get_keymap();

// for each i and j that belongs to this node
// determine first k that contributes, initialize input {i,j,first_k} flow to 0
Expand All @@ -366,7 +371,7 @@ class SpMM {
if (b_colidx_to_rowidx_[j].empty()) continue;

// assuming here {i,j,k} for all k map to same node
auto owner = keymap(Key<3>({i, j, 0ul}));
auto owner = keymap(Key<2>({i, j}));
if (owner == ttg_default_execution_context().rank()) {
if (true) {
decltype(i) k;
Expand Down Expand Up @@ -984,13 +989,20 @@ static void timed_measurement(SpMatrix<> &A, SpMatrix<> &B, const std::string &t
Edge<Key<2>> ctl("control");
Control control(ctl);
Edge<Key<2>, blk_t> eA, eB, eC;
Read_SpMatrix<> a("A", A, ctl, eA, P, Q);
Read_SpMatrix<> b("B", B, ctl, eB, P, Q);
Write_SpMatrix<> c(C, eC);

auto keymap = [P, Q](const Key<2> &key) {
int i = (int)key[0];
int j = (int)key[1];
int r = tile2rank(i, j, P, Q);
return r;
};
Read_SpMatrix<> a("A", A, ctl, eA, keymap);
Read_SpMatrix<> b("B", B, ctl, eB, keymap);
Write_SpMatrix<> c(C, eC, keymap);
auto &c_status = c.status();
assert(!has_value(c_status));
// SpMM a_times_b(world, eA, eB, eC, A, B);
SpMM<> a_times_b(eA, eB, eC, A, B, Afilling, Bfilling, P, Q);
SpMM<> a_times_b(eA, eB, eC, A, B, Afilling, Bfilling, keymap);
TTGUNUSED(a);
TTGUNUSED(b);
TTGUNUSED(a_times_b);
Expand Down Expand Up @@ -1200,13 +1212,19 @@ int main(int argc, char **argv) {
Edge<Key<2>> ctl("control");
Control control(ctl);
Edge<Key<2>, blk_t> eA, eB, eC;
Read_SpMatrix<> a("A", A, ctl, eA, P, Q);
Read_SpMatrix<> b("B", B, ctl, eB, P, Q);
Write_SpMatrix<> c(C, eC);
auto keymap = [P, Q](const Key<2> &key) {
int i = (int)key[0];
int j = (int)key[1];
int r = tile2rank(i, j, P, Q);
return r;
};
Read_SpMatrix<> a("A", A, ctl, eA, keymap);
Read_SpMatrix<> b("B", B, ctl, eB, keymap);
Write_SpMatrix<> c(C, eC, keymap);
auto &c_status = c.status();
assert(!has_value(c_status));
// SpMM a_times_b(world, eA, eB, eC, A, B);
SpMM<> a_times_b(eA, eB, eC, A, B, Afilling, Bfilling, P, Q);
SpMM<> a_times_b(eA, eB, eC, A, B, Afilling, Bfilling, keymap);
TTGUNUSED(a_times_b);

std::cout << Dot{}(&a, &b) << std::endl;
Expand Down

0 comments on commit 73f137d

Please sign in to comment.