Skip to content

Commit

Permalink
Merge pull request #4 from devreal/spmm-with-seed-hierarchical-bcast
Browse files Browse the repository at this point in the history
Spmm with seed hierarchical bcast
  • Loading branch information
therault authored Sep 16, 2021
2 parents 73f137d + e95d219 commit 48156aa
Showing 1 changed file with 111 additions and 26 deletions.
137 changes: 111 additions & 26 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,16 @@ class Write_SpMatrix : public Op<Key<2>, std::tuple<>, Write_SpMatrix<Blk>, Blk>
};

// sparse mm
template <typename Blk = blk_t>
template <typename Keymap = std::function<int(const Key<2>&)>, typename Blk = blk_t>
class SpMM {
public:
template<typename Keymap>
SpMM(Edge<Key<2>, Blk> &a, Edge<Key<2>, Blk> &b, Edge<Key<2>, Blk> &c, const SpMatrix<Blk> &a_mat,
const SpMatrix<Blk> &b_mat, std::map<std::tuple<int, int>, bool> &Afilling,
std::map<std::tuple<int, int>, bool> &Bfilling, Keymap &&keymap)
std::map<std::tuple<int, int>, bool> &Bfilling, const Keymap& keymap)
: a_ijk_()
, local_a_ijk_()
, b_ijk_()
, local_b_ijk_()
, c_ijk_()
, a_rowidx_to_colidx_(make_rowidx_to_colidx(Afilling))
, b_colidx_to_rowidx_(make_colidx_to_rowidx(Bfilling))
Expand All @@ -277,8 +278,10 @@ class SpMM {
ttg_broadcast(ttg_default_execution_context(), a_colidx_to_rowidx_, root);
ttg_broadcast(ttg_default_execution_context(), b_colidx_to_rowidx_, root);

bcast_a_ = std::make_unique<BcastA>(a, a_ijk_, b_rowidx_to_colidx_, keymap);
bcast_b_ = std::make_unique<BcastB>(b, b_ijk_, a_colidx_to_rowidx_, keymap);
bcast_a_ = std::make_unique<BcastA>(a, local_a_ijk_, b_rowidx_to_colidx_, keymap);
local_bcast_a_ = std::make_unique<LocalBcastA>(local_a_ijk_, a_ijk_, b_rowidx_to_colidx_, keymap);
bcast_b_ = std::make_unique<BcastB>(b, local_b_ijk_, a_colidx_to_rowidx_, keymap);
local_bcast_b_ = std::make_unique<LocalBcastB>(local_b_ijk_, b_ijk_, a_colidx_to_rowidx_, keymap);
multiplyadd_ =
std::make_unique<MultiplyAdd>(a_ijk_, b_ijk_, c_ijk_, c, a_rowidx_to_colidx_, b_colidx_to_rowidx_, keymap);

Expand All @@ -287,56 +290,133 @@ class SpMM {
TTGUNUSED(multiplyadd_);
}

/// broadcast A[i][k] to all {i,j,k} such that B[j][k] exists

/// Process-local fan-out of A[i][k].
///
/// Receives A[i][k] under the key {i, k, p} (the op keymap returns key[2], i.e. p) and
/// forwards it to every {i, j, k}, with j taken from b_rowidx_to_colidx_[k], that the
/// user-supplied keymap assigns to the executing process.
class LocalBcastA : public Op<Key<3>, std::tuple<Out<Key<3>, Blk>>, LocalBcastA, Blk> {
 public:
  using baseT = Op<Key<3>, std::tuple<Out<Key<3>, Blk>>, LocalBcastA, Blk>;

  /// @param a                   input edge carrying A[i][k] keyed by {i, k, p}
  /// @param a_ijk               output edge carrying A[i][k] keyed by {i, j, k}
  /// @param b_rowidx_to_colidx  for each k, the list of j indices to fan out to (held by reference)
  /// @param keymap              maps {i, j} to the rank that owns it (stored by value)
  LocalBcastA(Edge<Key<3>, Blk> &a, Edge<Key<3>, Blk> &a_ijk,
              const std::vector<std::vector<long>> &b_rowidx_to_colidx, Keymap keymap)
      : baseT(edges(a), edges(a_ijk), "SpMM::local_bcast_a", {"a_ik"}, {"a_ijk"},
              [](const Key<3> &key) { return key[2]; })
      , b_rowidx_to_colidx_(b_rowidx_to_colidx)
      , keymap_(keymap) {}

  void op(const Key<3> &key, typename baseT::input_values_tuple_type &&a_ik, std::tuple<Out<Key<3>, Blk>> &a_ijk) {
    const auto i = key[0];
    const auto k = key[1];
    auto world = get_default_world();
    // the third key component is the rank this task was mapped to
    assert(key[2] == world.rank());
    if (tracing()) ttg::print("LocalBcastA(", i, ", ", k, ")");

    // k lies outside the index table: there is nothing to forward to
    if (k >= b_rowidx_to_colidx_.size()) return;

    // gather the {i,j,k} destinations that live on this process
    const auto &cols_for_k = b_rowidx_to_colidx_[k];
    std::vector<Key<3>> local_targets;
    for (const long j : cols_for_k) {
      if (tracing()) ttg::print("Broadcasting A[", i, "][", k, "] to j=", j);
      if (keymap_(Key<2>({i, j})) != world.rank()) continue;  // owned by another process
      local_targets.push_back(Key<3>({i, j, k}));
    }
    ::broadcast<0>(local_targets, baseT::template get<0>(a_ik), a_ijk);
  }

 private:
  const std::vector<std::vector<long>> &b_rowidx_to_colidx_;
  Keymap keymap_;
};  // class LocalBcastA


/// broadcast A[i][k] once to every process that the keymap assigns some {i,j} to,
/// for j in b_rowidx_to_colidx_[k]
// NOTE(review): this listing interleaves the pre-change lines (a_ijk / ijk_keys) with the
// post-change lines (a_ikp / ikp_keys) of the commit; the comments below describe the latter.
class BcastA : public Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, Blk> {
public:
using baseT = Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, Blk>;

template<typename Keymap>
BcastA(Edge<Key<2>, Blk> &a, Edge<Key<3>, Blk> &a_ijk, const std::vector<std::vector<long>> &b_rowidx_to_colidx, Keymap&& keymap)
: baseT(edges(a), edges(a_ijk), "SpMM::bcast_a", {"a_ik"}, {"a_ijk"}, keymap)
// a: input edge keyed by {i,k}; a_ikp: output edge keyed by {i,k,proc};
// b_rowidx_to_colidx is held by reference; keymap is handed to the base class
BcastA(Edge<Key<2>, Blk> &a, Edge<Key<3>, Blk> &a_ikp, const std::vector<std::vector<long>> &b_rowidx_to_colidx, Keymap keymap)
: baseT(edges(a), edges(a_ikp), "SpMM::bcast_a", {"a_ik"}, {"a_ikp"}, keymap)
, b_rowidx_to_colidx_(b_rowidx_to_colidx) {}

void op(const Key<2> &key, typename baseT::input_values_tuple_type &&a_ik, std::tuple<Out<Key<3>, Blk>> &a_ijk) {
void op(const Key<2> &key, typename baseT::input_values_tuple_type &&a_ik, std::tuple<Out<Key<3>, Blk>> &a_ikp) {
const auto i = key[0];
const auto k = key[1];
if (tracing()) ttg::print("BcastA(", i, ", ", k, ")");
// broadcast a_ik to all existing {i,j,k}
std::vector<Key<3>> ijk_keys;
std::vector<Key<3>> ikp_keys;
// k lies outside the index table: no destinations, nothing to send
if (k >= b_rowidx_to_colidx_.size()) return;
auto world = get_default_world();
// one flag per rank; set once a key for that rank has been queued so each rank is targeted at most once
std::vector<bool> procmap(world.size());
auto keymap = baseT::get_keymap();
for (auto &j : b_rowidx_to_colidx_[k]) {
if (tracing()) ttg::print("Broadcasting A[", i, "][", k, "] to j=", j);
ijk_keys.emplace_back(Key<3>({i, j, k}));
// rank that the keymap assigns to {i,j}
long proc = keymap(Key<2>({i, j}));
if (!procmap[proc]) {
if (tracing()) ttg::print("Broadcasting A[", i, "][", k, "] to proc ", proc);
// the destination rank travels as the third key component
ikp_keys.emplace_back(Key<3>({i, k, proc}));
procmap[proc] = true;
}
}
::broadcast<0>(ijk_keys, baseT::template get<0>(a_ik), a_ijk);
::broadcast<0>(ikp_keys, baseT::template get<0>(a_ik), a_ikp);
}

private:
const std::vector<std::vector<long>> &b_rowidx_to_colidx_;
}; // class BcastA

/// Locally broadcast B[k][j] to all {i,j,k} owned by this process, for i in a_colidx_to_rowidx_[k]
///
/// Receives B[k][j] under the key {k, j, p} (the op keymap returns key[2], i.e. p) and
/// forwards it to every {i, j, k} that the user-supplied keymap assigns to the executing process.
class LocalBcastB : public Op<Key<3>, std::tuple<Out<Key<3>, Blk>>, LocalBcastB, Blk> {
 public:
  using baseT = Op<Key<3>, std::tuple<Out<Key<3>, Blk>>, LocalBcastB, Blk>;

  /// @param b                   input edge carrying B[k][j] keyed by {k, j, p}
  /// @param b_ijk               output edge carrying B[k][j] keyed by {i, j, k}
  /// @param a_colidx_to_rowidx  for each k, the list of i indices to fan out to (held by reference)
  /// @param keymap              maps {i, j} to the rank that owns it (stored by value)
  LocalBcastB(Edge<Key<3>, Blk> &b, Edge<Key<3>, Blk> &b_ijk,
              const std::vector<std::vector<long>> &a_colidx_to_rowidx, Keymap keymap)
      : baseT(edges(b), edges(b_ijk), "SpMM::local_bcast_b", {"b_kj"}, {"b_ijk"},
              [](const Key<3> &key) { return key[2]; })
      , a_colidx_to_rowidx_(a_colidx_to_rowidx)
      , keymap_(keymap) {}

  void op(const Key<3> &key, typename baseT::input_values_tuple_type &&b_kj, std::tuple<Out<Key<3>, Blk>> &b_ijk) {
    const auto k = key[0];
    const auto j = key[1];
    auto world = get_default_world();
    // the third key component is the rank this task was mapped to
    assert(key[2] == world.rank());
    // trace under this op's own name so it can be told apart from BcastB in the log
    if (tracing()) ttg::print("LocalBcastB(", k, ", ", j, ")");
    // k lies outside the index table: there is nothing to forward to
    if (k >= a_colidx_to_rowidx_.size()) return;
    // broadcast b_kj to the locally owned *jk
    std::vector<Key<3>> ijk_keys;
    for (auto &i : a_colidx_to_rowidx_[k]) {
      if (tracing()) ttg::print("Broadcasting B[", k, "][", j, "] to i=", i);
      // keep only the destinations that the keymap places on this process
      if (keymap_(Key<2>({i, j})) == world.rank()) {
        ijk_keys.emplace_back(Key<3>({i, j, k}));
      }
    }
    ::broadcast<0>(ijk_keys, baseT::template get<0>(b_kj), b_ijk);
  }

 private:
  const std::vector<std::vector<long>> &a_colidx_to_rowidx_;
  Keymap keymap_;
};  // class LocalBcastB

/// broadcast B[k][j] once to every process that owns some {i,j} such that A[i][k] exists
class BcastB : public Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, Blk> {
public:
using baseT = Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, Blk>;

template<typename Keymap>
BcastB(Edge<Key<2>, Blk> &b, Edge<Key<3>, Blk> &b_ijk, const std::vector<std::vector<long>> &a_colidx_to_rowidx, Keymap&& keymap)
: baseT(edges(b), edges(b_ijk), "SpMM::bcast_b", {"b_kj"}, {"b_ijk"}, keymap)
BcastB(Edge<Key<2>, Blk> &b, Edge<Key<3>, Blk> &b_kjp, const std::vector<std::vector<long>> &a_colidx_to_rowidx, Keymap keymap)
: baseT(edges(b), edges(b_kjp), "SpMM::bcast_b", {"b_kjp"}, {"b_ijk"}, keymap)
, a_colidx_to_rowidx_(a_colidx_to_rowidx) {}

void op(const Key<2> &key, typename baseT::input_values_tuple_type &&b_kj, std::tuple<Out<Key<3>, Blk>> &b_ijk) {
void op(const Key<2> &key, typename baseT::input_values_tuple_type &&b_kj, std::tuple<Out<Key<3>, Blk>> &b_kjp) {
const auto k = key[0];
const auto j = key[1];
// broadcast b_kj to *jk
std::vector<Key<3>> ijk_keys;
std::vector<Key<3>> kjp_keys;
if (tracing()) ttg::print("BcastB(", k, ", ", j, ")");
if (k >= a_colidx_to_rowidx_.size()) return;
auto world = get_default_world();
std::vector<bool> procmap(world.size());
for (auto &i : a_colidx_to_rowidx_[k]) {
if (tracing()) ttg::print("Broadcasting B[", k, "][", j, "] to i=", i);
ijk_keys.emplace_back(Key<3>({i, j, k}));
long proc = baseT::get_keymap()(Key<2>({i, j}));
if (!procmap[proc]) {
if (tracing()) ttg::print("Broadcasting A[", k, "][", j, "] to proc ", proc);
kjp_keys.emplace_back(Key<3>({k, j, proc}));
procmap[proc] = true;
}
}
::broadcast<0>(ijk_keys, baseT::template get<0>(b_kj), b_ijk);
::broadcast<0>(kjp_keys, baseT::template get<0>(b_kj), b_kjp);
}

private:
Expand All @@ -349,10 +429,10 @@ class SpMM {
public:
using baseT = Op<Key<3>, std::tuple<Out<Key<2>, Blk>, Out<Key<3>, Blk>>, MultiplyAdd, const Blk, const Blk, Blk>;

template<typename Keymap>
MultiplyAdd(Edge<Key<3>, Blk> &a_ijk, Edge<Key<3>, Blk> &b_ijk, Edge<Key<3>, Blk> &c_ijk, Edge<Key<2>, Blk> &c,
const std::vector<std::vector<long>> &a_rowidx_to_colidx,
const std::vector<std::vector<long>> &b_colidx_to_rowidx, Keymap &&keymap)
const std::vector<std::vector<long>> &b_colidx_to_rowidx,
Keymap keymap)
: baseT(edges(a_ijk, b_ijk, c_ijk), edges(c, c_ijk), "SpMM::MultiplyAdd", {"a_ijk", "b_ijk", "c_ijk"},
{"c_ij", "c_ijk"},
[keymap](const Key<3> &key) {
Expand Down Expand Up @@ -500,14 +580,18 @@ class SpMM {

private:
Edge<Key<3>, Blk> a_ijk_;
Edge<Key<3>, Blk> local_a_ijk_;
Edge<Key<3>, Blk> b_ijk_;
Edge<Key<3>, Blk> local_b_ijk_;
Edge<Key<3>, Blk> c_ijk_;
std::vector<std::vector<long>> a_rowidx_to_colidx_;
std::vector<std::vector<long>> b_colidx_to_rowidx_;
std::vector<std::vector<long>> a_colidx_to_rowidx_;
std::vector<std::vector<long>> b_rowidx_to_colidx_;
std::unique_ptr<BcastA> bcast_a_;
std::unique_ptr<LocalBcastA> local_bcast_a_;
std::unique_ptr<BcastB> bcast_b_;
std::unique_ptr<LocalBcastB> local_bcast_b_;
std::unique_ptr<MultiplyAdd> multiplyadd_;

// result[i][j] gives the j-th nonzero row for column i in matrix mat
Expand Down Expand Up @@ -1209,6 +1293,7 @@ int main(int argc, char **argv) {
}
} else {
// flow graph needs to exist on every node
auto keymap_write = [](const Key<2> &key) { return 0; };
Edge<Key<2>> ctl("control");
Control control(ctl);
Edge<Key<2>, blk_t> eA, eB, eC;
Expand All @@ -1220,14 +1305,14 @@ int main(int argc, char **argv) {
};
Read_SpMatrix<> a("A", A, ctl, eA, keymap);
Read_SpMatrix<> b("B", B, ctl, eB, keymap);
Write_SpMatrix<> c(C, eC, keymap);
Write_SpMatrix<> c(C, eC, keymap_write);
auto &c_status = c.status();
assert(!has_value(c_status));
// SpMM a_times_b(world, eA, eB, eC, A, B);
SpMM<> a_times_b(eA, eB, eC, A, B, Afilling, Bfilling, keymap);
TTGUNUSED(a_times_b);

std::cout << Dot{}(&a, &b) << std::endl;
if (get_default_world().rank() == 0) std::cout << Dot{}(&a, &b) << std::endl;

// ready to run!
auto connected = make_graph_executable(&control);
Expand Down

0 comments on commit 48156aa

Please sign in to comment.