Skip to content

Commit

Permalink
More corner cases of ToT evaluations supported.
Browse files Browse the repository at this point in the history
  • Loading branch information
bimalgaudel committed Apr 17, 2024
1 parent 154d427 commit 2bfd5aa
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
117 changes: 117 additions & 0 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,101 @@ auto replicate_array(Array from, TiledRange const &prepend_trng) {
return result;
}

template <typename T, typename... Ts>
auto reduce_modes(Tensor<T, Ts...> const &orig, size_t drank) {
TA_ASSERT(orig.nbatch() == 1);
auto const orig_rng = orig.range();
TA_ASSERT(orig_rng.rank() > drank);

auto const result_rng = [orig_rng, drank]() {
container::vector<Range1> r1s;
for (auto i = 0; i < orig_rng.rank() - drank; ++i)
r1s.emplace_back(orig_rng.dim(i));
return TA::Range(r1s);
}();

auto const delta_rng = [orig_rng, drank]() {
container::vector<Range1> r1s;
for (auto i = orig_rng.rank() - drank; i < orig_rng.rank(); ++i)
r1s.emplace_back(orig_rng.dim(i));
return TA::Range(r1s);
}();

auto const delta_vol = delta_rng.volume();

auto reducer = [orig, delta_vol, delta_rng](auto const &ix) {
auto orig_ix = ix;
std::copy(delta_rng.lobound().begin(), //
delta_rng.lobound().end(), //
std::back_inserter(orig_ix));

auto beg = orig.data() + orig.range().ordinal(orig_ix);
auto end = beg + delta_vol;

// cannot get it done this way: return std::reduce(beg, end);

typename std::iterator_traits<decltype(beg)>::value_type sum{};
for (; beg != end; ++beg) sum += *beg;
return sum;
};

return Tensor<T, Ts...>(result_rng, reducer);
}

///
/// \param orig Input DistArray.
/// \param dmodes Reduce this many modes from the end as implied in the
/// tiled range of the input array.
/// \return Array with reduced rank.
///
template <typename T, typename... Ts>
auto reduce_modes(TA::DistArray<T, Ts...> orig, size_t drank) {
TA_ASSERT(orig.trange().rank() > drank);

auto const result_trange = [orig, drank]() {
container::svector<TiledRange1> tr1s;
for (auto i = 0; i < (orig.trange().rank() - drank); ++i)
tr1s.emplace_back(orig.trange().at(i));
return TiledRange(tr1s);
}();

auto const delta_trange = [orig, drank]() {
container::svector<TiledRange1> tr1s;
for (auto i = orig.trange().rank() - drank; i < orig.trange().rank(); ++i)
tr1s.emplace_back(orig.trange().at(i));
return TiledRange(tr1s);
}();

orig.make_replicated();
orig.world().gop.fence();

auto make_tile = [orig, delta_trange, drank](auto &tile, auto const &rng) {
using tile_type = std::remove_reference_t<decltype(tile)>;

tile_type res(rng, typename tile_type::value_type{});

for (auto &&r : delta_trange.tiles_range()) {
container::svector<TA::Range::index1_type> ix1s = rng.lobound();

{
auto dlo = delta_trange.make_tile_range(r).lobound();
std::copy(dlo.begin(), dlo.end(), std::back_inserter(ix1s));
}

auto tix = orig.trange().element_to_tile(ix1s);
auto got = orig.find_local(tix).get(false);

res += reduce_modes(got, drank);
}

tile = res;
return res.norm();
};

return make_array<DistArray<T, Ts...>>(orig.world(), result_trange,
make_tile);
}

template <typename Ixs>
TiledRange make_trange(RangeMap const &map, Ixs const &ixs) {
container::svector<TiledRange1> tr1s;
Expand Down Expand Up @@ -320,6 +415,28 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
return C;
}

//
// special Hadamard + contraction
// when ToT times T implied and T's indices are contraction AND Hadamard
// BUT not externals
//
if constexpr (!AreArraySame<ArrayA, ArrayB> &&
DeNestFlag == DeNest::False) {
auto hi_size = h.size() + i.size();
if (hi_size != h.size() && hi_size != i.size() &&
((hi_size == a.size() && IsArrayT<ArrayA>) ||
(hi_size == b.size() && IsArrayT<ArrayB>))) {
auto annot_c = std::string(h + e + i) + inner.c;
auto temp1 = einsum(A, B, idx<ArrayC>(annot_c), world);
auto temp2 = reduce_modes(temp1, i.size());

auto annot_c_ = std::string(h + e) + inner.c;
decltype(temp2) result;
result(std::string(c) + inner.c) = temp2(annot_c_);
return result;
}
}

using ::Einsum::index::permutation;
using TiledArray::Permutation;

Expand Down
20 changes: 20 additions & 0 deletions tests/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,26 @@ BOOST_AUTO_TEST_CASE(corner_cases) {
{{0, 3, 5}, {0, 3, 8}}, //
{{0, 3, 8}, {0, 3, 5}, {0, 2}}, //
{3, 9})));

BOOST_REQUIRE(check_manual_eval<ArrayT>("bi,bi->i", //
{{0, 2}, {0, 4}}, //
{{0, 2}, {0, 4}}));

BOOST_REQUIRE(check_manual_eval<ArrayToT>("bi;a,bi;a->i;a", //
{{0, 2}, {0, 4}}, //
{{0, 2}, {0, 4}}, //
{3}, {3}));

BOOST_REQUIRE(
(check_manual_eval<ArrayToT, ArrayT>("jk;a,ijk->i;a", //
{{0, 2}, {0, 4}}, //
{{0, 3}, {0, 2}, {0, 4}}, //
{5})));

BOOST_REQUIRE((check_manual_eval<ArrayToT, ArrayT>("bi;a,bi->i;a", //
{{0, 4, 8}, {0, 4}}, //
{{0, 4, 8}, {0, 4}}, //
{8})));
}

BOOST_AUTO_TEST_SUITE_END()
Expand Down

0 comments on commit 2bfd5aa

Please sign in to comment.