Skip to content

Commit

Permalink
SPMM: implement hierarchical broadcast of A and B
Browse files Browse the repository at this point in the history
The hierarchical broadcast is used to avoid sending all the keys
and instead leave it to the recipient to distribute the tile
to all the relevant keys. This is done because the PaRSEC backend sends
keys inline, so a large number of keys may grow a message
past the eager limit. This is a bandage, not a fix. Instead, the PaRSEC
backend should learn to handle large key collections. Eventually, we
need to have edges containing key generators...

Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Sep 13, 2021
1 parent 977d1b8 commit e95d219
Showing 1 changed file with 108 additions and 24 deletions.
132 changes: 108 additions & 24 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,16 @@ class Write_SpMatrix : public Op<Key<2>, std::tuple<>, Write_SpMatrix<Blk>, Blk>
};

// sparse mm
template <typename Blk = blk_t>
template <typename Keymap = std::function<int(const Key<2>&)>, typename Blk = blk_t>
class SpMM {
public:
template<typename Keymap>
SpMM(Edge<Key<2>, Blk> &a, Edge<Key<2>, Blk> &b, Edge<Key<2>, Blk> &c, const SpMatrix<Blk> &a_mat,
const SpMatrix<Blk> &b_mat, std::map<std::tuple<int, int>, bool> &Afilling,
std::map<std::tuple<int, int>, bool> &Bfilling, Keymap &&keymap)
std::map<std::tuple<int, int>, bool> &Bfilling, const Keymap& keymap)
: a_ijk_()
, local_a_ijk_()
, b_ijk_()
, local_b_ijk_()
, c_ijk_()
, a_rowidx_to_colidx_(make_rowidx_to_colidx(Afilling))
, b_colidx_to_rowidx_(make_colidx_to_rowidx(Bfilling))
Expand All @@ -277,8 +278,10 @@ class SpMM {
ttg_broadcast(ttg_default_execution_context(), a_colidx_to_rowidx_, root);
ttg_broadcast(ttg_default_execution_context(), b_colidx_to_rowidx_, root);

bcast_a_ = std::make_unique<BcastA>(a, a_ijk_, b_rowidx_to_colidx_, keymap);
bcast_b_ = std::make_unique<BcastB>(b, b_ijk_, a_colidx_to_rowidx_, keymap);
bcast_a_ = std::make_unique<BcastA>(a, local_a_ijk_, b_rowidx_to_colidx_, keymap);
local_bcast_a_ = std::make_unique<LocalBcastA>(local_a_ijk_, a_ijk_, b_rowidx_to_colidx_, keymap);
bcast_b_ = std::make_unique<BcastB>(b, local_b_ijk_, a_colidx_to_rowidx_, keymap);
local_bcast_b_ = std::make_unique<LocalBcastB>(local_b_ijk_, b_ijk_, a_colidx_to_rowidx_, keymap);
multiplyadd_ =
std::make_unique<MultiplyAdd>(a_ijk_, b_ijk_, c_ijk_, c, a_rowidx_to_colidx_, b_colidx_to_rowidx_, keymap);

Expand All @@ -287,56 +290,133 @@ class SpMM {
TTGUNUSED(multiplyadd_);
}

/// broadcast A[i][k] to all {i,j,k} such that B[j][k] exists

/// Locally broadcast A[i][k] to all {i,j,k} such that B[k][j] exists and {i,j} maps to this rank.
///
/// The input key is {i, k, p}: the keymap handed to the base class returns key[2],
/// and op() asserts that this value equals the rank of the default world.
class LocalBcastA : public Op<Key<3>, std::tuple<Out<Key<3>, Blk>>, LocalBcastA, Blk> {
 public:
  using baseT = Op<Key<3>, std::tuple<Out<Key<3>, Blk>>, LocalBcastA, Blk>;

  /// @param a                  input edge carrying A tiles keyed by {i, k, p}
  /// @param a_ijk              output edge carrying A tiles keyed by {i, j, k}
  /// @param b_rowidx_to_colidx for each row index k of B, the column indices j to fan out to;
  ///                           held by reference, so it must outlive this op
  /// @param keymap             maps a Key<2> {i, j} to a rank; used to keep only the keys owned by this rank
  LocalBcastA(Edge<Key<3>, Blk> &a, Edge<Key<3>, Blk> &a_ijk,
              const std::vector<std::vector<long>> &b_rowidx_to_colidx, Keymap keymap)
      : baseT(edges(a), edges(a_ijk), "SpMM::local_bcast_a", {"a_ik"}, {"a_ijk"},
              [](const Key<3> &key) { return key[2]; })
      , b_rowidx_to_colidx_(b_rowidx_to_colidx)
      , keymap_(keymap) {}

  /// Forward the received tile A[i][k] to every {i,j,k} whose {i,j} is mapped to this rank.
  void op(const Key<3> &key, typename baseT::input_values_tuple_type &&a_ik, std::tuple<Out<Key<3>, Blk>> &a_ijk) {
    const auto i = key[0];
    const auto k = key[1];
    auto world = get_default_world();
    // The rank does not change during the loop below, so query it once.
    const auto rank = world.rank();
    assert(key[2] == rank);
    if (tracing()) ttg::print("LocalBcastA(", i, ", ", k, ")");
    // No entry for row k in the index: nothing to send.
    if (k >= b_rowidx_to_colidx_.size()) return;
    // Collect the {i,j,k} keys owned by this rank.
    std::vector<Key<3>> ijk_keys;
    for (const auto &j : b_rowidx_to_colidx_[k]) {
      if (keymap_(Key<2>({i, j})) != rank) continue;
      // Trace only the keys that are actually sent.
      if (tracing()) ttg::print("Broadcasting A[", i, "][", k, "] to j=", j);
      ijk_keys.emplace_back(Key<3>({i, j, k}));
    }
    ::broadcast<0>(ijk_keys, baseT::template get<0>(a_ik), a_ijk);
  }

 private:
  const std::vector<std::vector<long>> &b_rowidx_to_colidx_;
  Keymap keymap_;
};  // class LocalBcastA


/// Broadcast A[i][k] once to every process that owns at least one {i,j} for which B[k][j] exists.
/// The keys sent on the a_ikp output have the form {i, k, proc}: one key per destination
/// process rather than one key per {i,j,k}.
// NOTE(review): this listing shows pre- and post-change lines side by side (two constructor
// headers, two op() signatures, ijk_keys next to ikp_keys); confirm against the repository
// which of each pair is current before relying on it.
class BcastA : public Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, Blk> {
public:
using baseT = Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, Blk>;

template<typename Keymap>
BcastA(Edge<Key<2>, Blk> &a, Edge<Key<3>, Blk> &a_ijk, const std::vector<std::vector<long>> &b_rowidx_to_colidx, Keymap&& keymap)
: baseT(edges(a), edges(a_ijk), "SpMM::bcast_a", {"a_ik"}, {"a_ijk"}, keymap)
// The keymap is passed on to the base class; b_rowidx_to_colidx is kept by reference
// (see the member below), so it must outlive this op.
BcastA(Edge<Key<2>, Blk> &a, Edge<Key<3>, Blk> &a_ikp, const std::vector<std::vector<long>> &b_rowidx_to_colidx, Keymap keymap)
: baseT(edges(a), edges(a_ikp), "SpMM::bcast_a", {"a_ik"}, {"a_ikp"}, keymap)
, b_rowidx_to_colidx_(b_rowidx_to_colidx) {}

void op(const Key<2> &key, typename baseT::input_values_tuple_type &&a_ik, std::tuple<Out<Key<3>, Blk>> &a_ijk) {
void op(const Key<2> &key, typename baseT::input_values_tuple_type &&a_ik, std::tuple<Out<Key<3>, Blk>> &a_ikp) {
const auto i = key[0];
const auto k = key[1];
if (tracing()) ttg::print("BcastA(", i, ", ", k, ")");
// broadcast a_ik to all existing {i,j,k}
std::vector<Key<3>> ijk_keys;
// keys of the form {i, k, proc}, at most one per destination process
std::vector<Key<3>> ikp_keys;
// no entry for row k in the index: nothing to send
if (k >= b_rowidx_to_colidx_.size()) return;
auto world = get_default_world();
// one flag per index in [0, world.size()) — presumably one per process; a flag is set
// once a key for that process has been queued, so each process is targeted only once
std::vector<bool> procmap(world.size());
auto keymap = baseT::get_keymap();
for (auto &j : b_rowidx_to_colidx_[k]) {
if (tracing()) ttg::print("Broadcasting A[", i, "][", k, "] to j=", j);
ijk_keys.emplace_back(Key<3>({i, j, k}));
// the keymap result for {i,j} is used as the destination process and as the index into procmap
long proc = keymap(Key<2>({i, j}));
if (!procmap[proc]) {
if (tracing()) ttg::print("Broadcasting A[", i, "][", k, "] to proc ", proc);
ikp_keys.emplace_back(Key<3>({i, k, proc}));
procmap[proc] = true;
}
}
::broadcast<0>(ijk_keys, baseT::template get<0>(a_ik), a_ijk);
::broadcast<0>(ikp_keys, baseT::template get<0>(a_ik), a_ikp);
}

private:
// held by reference, not copied
const std::vector<std::vector<long>> &b_rowidx_to_colidx_;
}; // class BcastA

/// Locally broadcast B[k][j] to all {i,j,k} such that A[i][k] exists and {i,j} maps to this rank.
///
/// The input key is {k, j, p}: the keymap handed to the base class returns key[2],
/// and op() asserts that this value equals the rank of the default world.
class LocalBcastB : public Op<Key<3>, std::tuple<Out<Key<3>, Blk>>, LocalBcastB, Blk> {
 public:
  using baseT = Op<Key<3>, std::tuple<Out<Key<3>, Blk>>, LocalBcastB, Blk>;

  /// @param b                  input edge carrying B tiles keyed by {k, j, p}
  /// @param b_ijk              output edge carrying B tiles keyed by {i, j, k}
  /// @param a_colidx_to_rowidx for each column index k of A, the row indices i to fan out to;
  ///                           held by reference, so it must outlive this op
  /// @param keymap             maps a Key<2> {i, j} to a rank; used to keep only the keys owned by this rank
  LocalBcastB(Edge<Key<3>, Blk> &b, Edge<Key<3>, Blk> &b_ijk,
              const std::vector<std::vector<long>> &a_colidx_to_rowidx, Keymap keymap)
      : baseT(edges(b), edges(b_ijk), "SpMM::local_bcast_b", {"b_kj"}, {"b_ijk"},
              [](const Key<3> &key) { return key[2]; })
      , a_colidx_to_rowidx_(a_colidx_to_rowidx)
      , keymap_(keymap) {}

  /// Forward the received tile B[k][j] to every {i,j,k} whose {i,j} is mapped to this rank.
  void op(const Key<3> &key, typename baseT::input_values_tuple_type &&b_kj, std::tuple<Out<Key<3>, Blk>> &b_ijk) {
    const auto k = key[0];
    const auto j = key[1];
    auto world = get_default_world();
    // The rank does not change during the loop below, so query it once.
    const auto rank = world.rank();
    assert(key[2] == rank);
    // Label the trace with this op's own name so it is distinguishable from BcastB.
    if (tracing()) ttg::print("LocalBcastB(", k, ", ", j, ")");
    // No entry for column k in the index: nothing to send.
    if (k >= a_colidx_to_rowidx_.size()) return;
    // Collect the {i,j,k} keys owned by this rank.
    std::vector<Key<3>> ijk_keys;
    for (const auto &i : a_colidx_to_rowidx_[k]) {
      if (keymap_(Key<2>({i, j})) != rank) continue;
      // Trace only the keys that are actually sent.
      if (tracing()) ttg::print("Broadcasting B[", k, "][", j, "] to i=", i);
      ijk_keys.emplace_back(Key<3>({i, j, k}));
    }
    ::broadcast<0>(ijk_keys, baseT::template get<0>(b_kj), b_ijk);
  }

 private:
  const std::vector<std::vector<long>> &a_colidx_to_rowidx_;
  Keymap keymap_;
};  // class LocalBcastB

/// broadcast B[k][j] to all {i,j,k} such that A[i][k] exists
class BcastB : public Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, Blk> {
public:
using baseT = Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, Blk>;

// Constructor: wires input edge b and the output edge to the base class, passes the keymap on,
// and keeps a reference to a_colidx_to_rowidx.
// NOTE(review): this listing shows two constructor headers back to back (one taking b_ijk and
// Keymap&&, one taking b_kjp and Keymap by value); confirm against the repository which is current.
template<typename Keymap>
BcastB(Edge<Key<2>, Blk> &b, Edge<Key<3>, Blk> &b_ijk, const std::vector<std::vector<long>> &a_colidx_to_rowidx, Keymap&& keymap)
: baseT(edges(b), edges(b_ijk), "SpMM::bcast_b", {"b_kj"}, {"b_ijk"}, keymap)
BcastB(Edge<Key<2>, Blk> &b, Edge<Key<3>, Blk> &b_kjp, const std::vector<std::vector<long>> &a_colidx_to_rowidx, Keymap keymap)
// NOTE(review): the two name lists look swapped/stale compared with BcastA, which uses
// {"a_ik"}, {"a_ikp"}: here the input is b_kj and the output is b_kjp, so
// {"b_kj"}, {"b_kjp"} was presumably intended — confirm.
: baseT(edges(b), edges(b_kjp), "SpMM::bcast_b", {"b_kjp"}, {"b_ijk"}, keymap)
, a_colidx_to_rowidx_(a_colidx_to_rowidx) {}

// Sends B[k][j] once to every process that owns at least one {i,j} for which A[i][k] exists;
// the keys queued in kjp_keys have the form {k, j, proc}.
// NOTE(review): this listing shows two op() signatures and both ijk_keys and kjp_keys; confirm
// against the repository which lines are current.
void op(const Key<2> &key, typename baseT::input_values_tuple_type &&b_kj, std::tuple<Out<Key<3>, Blk>> &b_ijk) {
void op(const Key<2> &key, typename baseT::input_values_tuple_type &&b_kj, std::tuple<Out<Key<3>, Blk>> &b_kjp) {
const auto k = key[0];
const auto j = key[1];
// broadcast b_kj to *jk
std::vector<Key<3>> ijk_keys;
// keys of the form {k, j, proc}, at most one per destination process
std::vector<Key<3>> kjp_keys;
if (tracing()) ttg::print("BcastB(", k, ", ", j, ")");
// no entry for column k in the index: nothing to send
if (k >= a_colidx_to_rowidx_.size()) return;
auto world = get_default_world();
// one flag per index in [0, world.size()) — presumably one per process; a flag is set
// once a key for that process has been queued, so each process is targeted only once
std::vector<bool> procmap(world.size());
for (auto &i : a_colidx_to_rowidx_[k]) {
if (tracing()) ttg::print("Broadcasting B[", k, "][", j, "] to i=", i);
ijk_keys.emplace_back(Key<3>({i, j, k}));
// the keymap result for {i,j} is used as the destination process and as the index into procmap;
// unlike BcastA, get_keymap() is called on every iteration here
long proc = baseT::get_keymap()(Key<2>({i, j}));
if (!procmap[proc]) {
// NOTE(review): the message says "Broadcasting A[" although this op broadcasts B[k][j];
// likely a copy-paste from BcastA — confirm and correct the trace text.
if (tracing()) ttg::print("Broadcasting A[", k, "][", j, "] to proc ", proc);
kjp_keys.emplace_back(Key<3>({k, j, proc}));
procmap[proc] = true;
}
}
::broadcast<0>(ijk_keys, baseT::template get<0>(b_kj), b_ijk);
::broadcast<0>(kjp_keys, baseT::template get<0>(b_kj), b_kjp);
}

private:
Expand All @@ -349,10 +429,10 @@ class SpMM {
public:
using baseT = Op<Key<3>, std::tuple<Out<Key<2>, Blk>, Out<Key<3>, Blk>>, MultiplyAdd, const Blk, const Blk, Blk>;

template<typename Keymap>
MultiplyAdd(Edge<Key<3>, Blk> &a_ijk, Edge<Key<3>, Blk> &b_ijk, Edge<Key<3>, Blk> &c_ijk, Edge<Key<2>, Blk> &c,
const std::vector<std::vector<long>> &a_rowidx_to_colidx,
const std::vector<std::vector<long>> &b_colidx_to_rowidx, Keymap &&keymap)
const std::vector<std::vector<long>> &b_colidx_to_rowidx,
Keymap keymap)
: baseT(edges(a_ijk, b_ijk, c_ijk), edges(c, c_ijk), "SpMM::MultiplyAdd", {"a_ijk", "b_ijk", "c_ijk"},
{"c_ij", "c_ijk"},
[keymap](const Key<3> &key) {
Expand Down Expand Up @@ -500,14 +580,18 @@ class SpMM {

private:
Edge<Key<3>, Blk> a_ijk_;
Edge<Key<3>, Blk> local_a_ijk_;
Edge<Key<3>, Blk> b_ijk_;
Edge<Key<3>, Blk> local_b_ijk_;
Edge<Key<3>, Blk> c_ijk_;
std::vector<std::vector<long>> a_rowidx_to_colidx_;
std::vector<std::vector<long>> b_colidx_to_rowidx_;
std::vector<std::vector<long>> a_colidx_to_rowidx_;
std::vector<std::vector<long>> b_rowidx_to_colidx_;
std::unique_ptr<BcastA> bcast_a_;
std::unique_ptr<LocalBcastA> local_bcast_a_;
std::unique_ptr<BcastB> bcast_b_;
std::unique_ptr<LocalBcastB> local_bcast_b_;
std::unique_ptr<MultiplyAdd> multiplyadd_;

// result[i][j] gives the j-th nonzero row for column i in matrix mat
Expand Down

0 comments on commit e95d219

Please sign in to comment.