Skip to content

Commit

Permalink
reduce_modes function impl. amended to handle sparse dist-arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
bimalgaudel committed May 16, 2024
1 parent 40705d7 commit 6e864eb
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ auto reduce_modes(TA::DistArray<T, Ts...> orig, size_t drank) {

tile_type res(rng, typename tile_type::value_type{});

bool all_summed_tiles_zeros{true};
for (auto &&r : delta_trange.tiles_range()) {
container::svector<TA::Range::index1_type> ix1s = rng.lobound();

Expand All @@ -347,11 +348,18 @@ auto reduce_modes(TA::DistArray<T, Ts...> orig, size_t drank) {
}

auto tix = orig.trange().element_to_tile(ix1s);
if constexpr (std::is_same_v<typename DistArray<T, Ts...>::policy_type,
SparsePolicy>)
if (orig.is_zero(tix)) continue;
auto got = orig.find_local(tix).get(false);

res += reduce_modes(got, drank);
all_summed_tiles_zeros = false;
}

if (all_summed_tiles_zeros)
return typename std::remove_reference_t<decltype(tile)>::scalar_type{0};

tile = res;
return res.norm();
};
Expand Down

0 comments on commit 6e864eb

Please sign in to comment.