Skip to content

Commit

Permalink
btas::Tensor: do not fill with zeroes when constructing (e.g., in zero-copy serialization), unless necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Jan 30, 2024
1 parent d42cb92 commit a85c2c4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
13 changes: 8 additions & 5 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ namespace ttg {
return boost::container::small_vector<iovec, 1>{};
}
static auto create_from_metadata(const std::pair<int, int> &meta) {
if (meta != std::pair{0, 0})
return blk_t(btas::Range(std::get<0>(meta), std::get<1>(meta)), 0.0);
if (meta != std::pair{0, 0}) // N.B. allocate only, do not fill with zeroes
return blk_t(btas::Range(std::get<0>(meta), std::get<1>(meta)));
else
return blk_t{};
}
Expand Down Expand Up @@ -121,10 +121,13 @@ namespace btas {
btas::Tensor<T_, Range_, Store_> gemm(btas::Tensor<T_, Range_, Store_> &&C, const btas::Tensor<T_, Range_, Store_> &A,
const btas::Tensor<T_, Range_, Store_> &B) {
using array = btas::DEFAULT::index<int>;
if (C.empty()) {
C = btas::Tensor<T_, Range_, Store_>(btas::Range(A.range().extent(0), B.range().extent(1)), 0.0);
if (C.empty()) { // first contribution to C = allocate it and gemm with beta=0
C = btas::Tensor<T_, Range_, Store_>(btas::Range(A.range().extent(0), B.range().extent(1)));
btas::contract_222(1.0, A, array{1, 2}, B, array{2, 3}, 0.0, C, array{1, 3}, false, false);
}
else { // subsequent contributions to C = gemm with beta=1
btas::contract_222(1.0, A, array{1, 2}, B, array{2, 3}, 1.0, C, array{1, 3}, false, false);
}
btas::contract_222(1.0, A, array{1, 2}, B, array{2, 3}, 1.0, C, array{1, 3}, false, false);
return std::move(C);
}
} // namespace btas
Expand Down
13 changes: 8 additions & 5 deletions examples/spmm/spmm_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,8 @@ namespace ttg {
return boost::container::small_vector<iovec, 1>{};
}
static auto create_from_metadata(const std::pair<int, int> &meta) {
if (meta != std::pair{0, 0})
return blk_t(btas::Range(std::get<0>(meta), std::get<1>(meta)), 0.0);
if (meta != std::pair{0, 0}) // N.B. allocate only, do not fill with zeroes
return blk_t(btas::Range(std::get<0>(meta), std::get<1>(meta)));
else
return blk_t{};
}
Expand Down Expand Up @@ -422,10 +422,13 @@ namespace btas {
btas::Tensor<T_, Range_, Store_> gemm(btas::Tensor<T_, Range_, Store_> &&C, const btas::Tensor<T_, Range_, Store_> &A,
const btas::Tensor<T_, Range_, Store_> &B) {
using array = btas::DEFAULT::index<int>;
if (C.empty()) {
C = btas::Tensor<T_, Range_, Store_>(btas::Range(A.range().extent(0), B.range().extent(1)), 0.0);
if (C.empty()) { // first contribution to C = allocate it and gemm with beta=0
C = btas::Tensor<T_, Range_, Store_>(btas::Range(A.range().extent(0), B.range().extent(1)));
btas::contract_222(1.0, A, array{1, 2}, B, array{2, 3}, 0.0, C, array{1, 3}, false, false);
}
else { // subsequent contributions to C = gemm with beta=1
btas::contract_222(1.0, A, array{1, 2}, B, array{2, 3}, 1.0, C, array{1, 3}, false, false);
}
btas::contract_222(1.0, A, array{1, 2}, B, array{2, 3}, 1.0, C, array{1, 3}, false, false);
return std::move(C);
}
} // namespace btas
Expand Down

0 comments on commit a85c2c4

Please sign in to comment.