Skip to content

Commit

Permalink
SPMM bcasts: use input_refs_tuple_type and and const
Browse files Browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Aug 14, 2024
1 parent 1948ee7 commit 29c9771
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ class SpMM25D {
}; // class LocalBcastA

/// broadcast `A[i][k]` to all processors which will contain at least one `C[i][j]` such that `B[k][j]` exists
class BcastA : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, ttg::typelist<Blk>> {
class BcastA : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, ttg::typelist<const Blk>> {
public:
using baseT = typename BcastA::ttT;

Expand All @@ -462,7 +462,7 @@ class SpMM25D {
});
}

void op(const Key<2> &ik, typename baseT::input_values_tuple_type &&a_ik,
void op(const Key<2> &ik, typename baseT::input_refs_tuple_type &&a_ik,
std::tuple<Out<Key<3>, Blk>> &outs) {
const auto i = ik[0]; // row
const auto k = ik[1]; // col
Expand Down Expand Up @@ -492,7 +492,7 @@ class SpMM25D {

/// Locally broadcast `B[k][j]` assigned to this processor `p` to matmul tasks `{i,j,k}` for all `k` such that
/// `A[i][k]` exists AND `k` contribution to `C[i][j]` is assigned to this processor
class LocalBcastB : public TT<Key<3>, std::tuple<Out<Key<3>, Blk>>, LocalBcastB, ttg::typelist<Blk>> {
class LocalBcastB : public TT<Key<3>, std::tuple<Out<Key<3>, Blk>>, LocalBcastB, ttg::typelist<const Blk>> {
public:
using baseT = typename LocalBcastB::ttT;

Expand Down Expand Up @@ -528,7 +528,7 @@ class SpMM25D {
}; // class LocalBcastB

/// broadcast `B[k][j]` to all processors which will contain at least one `C[i][j]` such that `A[i][k]` exists
class BcastB : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, ttg::typelist<Blk>> {
class BcastB : public TT<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, ttg::typelist<const Blk>> {
public:
using baseT = typename BcastB::ttT;

Expand All @@ -544,7 +544,7 @@ class SpMM25D {
});
}

void op(const Key<2> &kj, typename baseT::input_values_tuple_type &&b_kj,
void op(const Key<2> &kj, typename baseT::input_refs_tuple_type &&b_kj,
std::tuple<Out<Key<3>, Blk>> &outs) {
const auto k = kj[0]; // row
const auto j = kj[1]; // col
Expand Down

0 comments on commit 29c9771

Please sign in to comment.