Implement ToT x T (and reverse) generalized contraction.
bimalgaudel committed Dec 17, 2023
1 parent 2520fe5 commit eacc22b
Showing 2 changed files with 53 additions and 45 deletions.
84 changes: 43 additions & 41 deletions src/TiledArray/einsum/tiledarray.h
@@ -181,50 +181,51 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,

   using Index = Einsum::Index<size_t>;
 
-  if constexpr (std::tuple_size<decltype(cs)>::value > 1) {
-    TA_ASSERT(e);
-  } else if (!e) {  // hadamard reduction
-    auto &[A, B] = AB;
-    TiledRange trange(range_map[i]);
-    RangeProduct tiles;
-    for (auto idx : i) {
-      tiles *= Range(range_map[idx].tiles_range());
-    }
-    auto pa = A.permutation;
-    auto pb = B.permutation;
-    for (Index h : H.tiles) {
-      if (!C.array.is_local(h)) continue;
-      size_t batch = 1;
-      for (size_t i = 0; i < h.size(); ++i) {
-        batch *= H.batch[i].at(h[i]);
-      }
-      ResultTensor tile(TiledArray::Range{batch},
-                        typename ResultTensor::value_type{});
-      for (Index i : tiles) {
-        // skip this unless both input tiles exist
-        const auto pahi_inv = apply_inverse(pa, h + i);
-        const auto pbhi_inv = apply_inverse(pb, h + i);
-        if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
-
-        auto ai = A.array.find(pahi_inv).get();
-        auto bi = B.array.find(pbhi_inv).get();
-        if (pa) ai = ai.permute(pa);
-        if (pb) bi = bi.permute(pb);
-        auto shape = trange.tile(i);
-        ai = ai.reshape(shape, batch);
-        bi = bi.reshape(shape, batch);
-        for (size_t k = 0; k < batch; ++k) {
-          auto hk = ai.batch(k).dot(bi.batch(k));
-          tile({k}) += hk;
-        }
-      }
-      auto pc = C.permutation;
-      auto shape = apply_inverse(pc, C.array.trange().tile(h));
-      tile = tile.reshape(shape);
-      if (pc) tile = tile.permute(pc);
-      C.array.set(h, tile);
-    }
-    return C.array;
-  }
+  if constexpr (std::tuple_size<decltype(cs)>::value > 1) TA_ASSERT(e);
+  if constexpr (AreArraySame<ArrayA, ArrayB>) {
+    if (!e) {  // hadamard reduction
+      auto &[A, B] = AB;
+      TiledRange trange(range_map[i]);
+      RangeProduct tiles;
+      for (auto idx : i) {
+        tiles *= Range(range_map[idx].tiles_range());
+      }
+      auto pa = A.permutation;
+      auto pb = B.permutation;
+      for (Index h : H.tiles) {
+        if (!C.array.is_local(h)) continue;
+        size_t batch = 1;
+        for (size_t i = 0; i < h.size(); ++i) {
+          batch *= H.batch[i].at(h[i]);
+        }
+        ResultTensor tile(TiledArray::Range{batch},
+                          typename ResultTensor::value_type{});
+        for (Index i : tiles) {
+          // skip this unless both input tiles exist
+          const auto pahi_inv = apply_inverse(pa, h + i);
+          const auto pbhi_inv = apply_inverse(pb, h + i);
+          if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
+
+          auto ai = A.array.find(pahi_inv).get();
+          auto bi = B.array.find(pbhi_inv).get();
+          if (pa) ai = ai.permute(pa);
+          if (pb) bi = bi.permute(pb);
+          auto shape = trange.tile(i);
+          ai = ai.reshape(shape, batch);
+          bi = bi.reshape(shape, batch);
+          for (size_t k = 0; k < batch; ++k) {
+            auto hk = ai.batch(k).dot(bi.batch(k));
+            tile({k}) += hk;
+          }
+        }
+        auto pc = C.permutation;
+        auto shape = apply_inverse(pc, C.array.trange().tile(h));
+        tile = tile.reshape(shape);
+        if (pc) tile = tile.permute(pc);
+        C.array.set(h, tile);
+      }
+      return C.array;
+    }
+  }
 
   // generalized contraction
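Note: the hadamard-reduction branch above is now compiled only when both operands are the same array kind (the new AreArraySame<ArrayA, ArrayB> guard); a mixed ToT x T pair falls through to the generalized-contraction path below. For intuition, the inner k-loop reduces each batch slice of the two reshaped tiles to one scalar, cf. tile({k}) += ai.batch(k).dot(bi.batch(k)). A minimal stand-alone sketch of that batched-dot pattern on toy buffers (plain C++; batched_dot is a hypothetical name, not TiledArray API):

#include <cassert>
#include <cstddef>
#include <vector>

// Each input models one reshaped tile holding `batch` contiguous slices of
// `n` elements; the result holds one scalar per batch index k, mirroring
// tile({k}) += ai.batch(k).dot(bi.batch(k)) in the branch above.
std::vector<double> batched_dot(const std::vector<double> &a,
                                const std::vector<double> &b,
                                std::size_t batch) {
  assert(a.size() == b.size() && batch != 0 && a.size() % batch == 0);
  const std::size_t n = a.size() / batch;
  std::vector<double> out(batch, 0.0);
  for (std::size_t k = 0; k < batch; ++k)    // one scalar per batch slice
    for (std::size_t j = 0; j < n; ++j)      // dot within the k-th slice
      out[k] += a[k * n + j] * b[k * n + j];
  return out;
}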
@@ -468,7 +469,8 @@ auto einsum(expressions::TsrExpr<T> A, expressions::TsrExpr<U> B,
             const std::string &cs, World &world = get_default_world()) {
   using ECT = expressions::TsrExpr<const T>;
   using ECU = expressions::TsrExpr<const U>;
-  return Einsum::einsum(ECT(A), ECU(B), Einsum::idx<T>(cs), world);
+  using ResultExprT = std::conditional_t<Einsum::IsArrayToT<T>, T, U>;
+  return Einsum::einsum(ECT(A), ECU(B), Einsum::idx<ResultExprT>(cs), world);
 }
 
 template <typename T, typename U, typename V>
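The driver overload above now derives the annotation traits from whichever operand is the tensor-of-tensors (via Einsum::IsArrayToT<T>) instead of unconditionally from T, which is what makes the reverse T x ToT order work. A self-contained sketch of that compile-time selection with std::conditional_t (PlainArray, ToTArray, and is_tot_v are illustrative stand-ins, not TiledArray types):

#include <type_traits>

// Toy stand-ins: a plain array of scalar tiles vs. a tensor-of-tensors
// array whose tiles are themselves tensor-like objects.
struct PlainArray { using value_type = double; };
struct ToTArray {
  struct InnerTensor { double v; };
  using value_type = InnerTensor;
};

// True iff the tile type is not a plain scalar (crude model of IsArrayToT).
template <typename A>
inline constexpr bool is_tot_v =
    !std::is_arithmetic_v<typename A::value_type>;

// Mirrors the fix: parse the index annotation ("i,j,k;m,n") with the traits
// of whichever operand carries the inner (";m,n") structure.
template <typename T, typename U>
using result_expr_t = std::conditional_t<is_tot_v<T>, T, U>;

static_assert(std::is_same_v<result_expr_t<ToTArray, PlainArray>, ToTArray>);
static_assert(std::is_same_v<result_expr_t<PlainArray, ToTArray>, ToTArray>);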
14 changes: 10 additions & 4 deletions tests/einsum.cpp
@@ -845,7 +845,7 @@ BOOST_AUTO_TEST_CASE(ilkj_nm_eq_ij_mn_times_kl) {
   }
 }
 
-BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) {
+BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_jk) {
   using t_type = DistArray<Tensor<double>, SparsePolicy>;
   using tot_type = DistArray<Tensor<Tensor<double>>, SparsePolicy>;
   using matrix_il = TiledArray::detail::matrix_il<Tensor<double>>;
@@ -877,7 +877,6 @@ BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) {
   t_type rhs(world, rhs_trange);
   rhs.fill_random();
 
-  // TODO compute ref_result
   // i,j;m,n * j,k => i,j,k;m,n
   TiledRange ref_result_trange{lhs_trange.dim(0), rhs_trange.dim(0),
                                rhs_trange.dim(1)};
@@ -928,10 +927,17 @@ BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) {
   // - general product w.r.t. outer indices
   // - involves ToT * T
   // tot_type result;
-  // BOOST_REQUIRE_NO_THROW(result("k,i,j;n,m") = lhs("i,j;m,n") * rhs("j,k"));
+  // BOOST_REQUIRE_NO_THROW(result("i,j,k;m,n") = lhs("i,j;m,n") * rhs("j,k"));
 
   // will try to make this work
-  // tot_type out = einsum(lhs("i,j;m,n"), rhs("j,k"), "k,i,j;n,m");
+  tot_type result = einsum(lhs("i,j;m,n"), rhs("j,k"), "i,j,k;m,n");
+  bool are_equal = ToTArrayFixture::are_equal<false>(result, ref_result);
+  BOOST_REQUIRE(are_equal);
+  {
+    result = einsum(rhs("j,k"), lhs("i,j;m,n"), "i,j,k;m,n");
+    are_equal = ToTArrayFixture::are_equal<false>(result, ref_result);
+    BOOST_REQUIRE(are_equal);
+  }
 }

 BOOST_AUTO_TEST_CASE(ij_mn_eq_ji_mn_times_ij) {
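As the test's comment notes ("i,j;m,n * j,k => i,j,k;m,n"), every outer index survives into the result, so the shared index j is a Hadamard (elementwise) index rather than a contracted one: C[i,j,k](m,n) = A[i,j](m,n) * B[j,k]. A plain-loop reference of those semantics on dense row-major buffers (a hypothetical helper, independent of how the test assembles ref_result):

#include <cstddef>
#include <vector>

// Reference semantics for C("i,j,k;m,n") = A("i,j;m,n") * B("j,k"):
// j is shared but kept in the result, so nothing is summed; the inner
// (m,n) block of each A tile is scaled by the matching scalar of B.
void ref_ijk_mn(const std::vector<double> &A,  // size I*J*M*N
                const std::vector<double> &B,  // size J*K
                std::vector<double> &C,        // size I*J*K*M*N
                std::size_t I, std::size_t J, std::size_t K,
                std::size_t M, std::size_t N) {
  for (std::size_t i = 0; i < I; ++i)
    for (std::size_t j = 0; j < J; ++j)
      for (std::size_t k = 0; k < K; ++k)
        for (std::size_t m = 0; m < M; ++m)
          for (std::size_t n = 0; n < N; ++n)
            C[(((i * J + j) * K + k) * M + m) * N + n] =
                A[((i * J + j) * M + m) * N + n] * B[j * K + k];
}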
