Skip to content

Commit

Permalink
Outer tensor contraction logic update.
Browse files Browse the repository at this point in the history
When the contraction occurs in the outer tensor (between two ToTs or between T or ToT), the contraction step cannot be passed to the expression layer. Example: `J = TA::einsum(I("i,j,k,l;eik,fjl"), S_1("i,j,k;aij,eik"), "i,j,l;aij,fjl");`
  • Loading branch information
bimalgaudel committed May 10, 2024
1 parent 3a68f0c commit 6942120
Showing 1 changed file with 11 additions and 18 deletions.
29 changes: 11 additions & 18 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -525,25 +525,18 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
}

//
// special Hadamard + contraction
// when ToT times T implied and T's indices are contraction AND Hadamard
// BUT not externals
// when contraction happens in the outer tensor
// need to evaluate specially..
//
if constexpr (!AreArraySame<ArrayA, ArrayB> &&
DeNestFlag == DeNest::False) {
auto hi_size = h.size() + i.size();
if (hi_size != h.size() && hi_size != i.size() &&
((hi_size == a.size() && IsArrayT<ArrayA>) ||
(hi_size == b.size() && IsArrayT<ArrayB>))) {
auto annot_c = std::string(h + e + i) + inner.c;
auto temp1 = einsum(A, B, idx<ArrayC>(annot_c), world);
auto temp2 = reduce_modes(temp1, i.size());

auto annot_c_ = std::string(h + e) + inner.c;
decltype(temp2) result;
result(std::string(c) + inner.c) = temp2(annot_c_);
return result;
}
if (IsArrayToT<ArrayC> && i.size() > 0) {
auto annot_c = std::string(h + e + i) + inner.c;
auto temp1 = einsum(A, B, idx<ArrayC>(annot_c), world);
auto temp2 = reduce_modes(temp1, i.size());

auto annot_c_ = std::string(h + e) + inner.c;
decltype(temp2) result;
result(std::string(c) + inner.c) = temp2(annot_c_);
return result;
}

using ::Einsum::index::permutation;
Expand Down

0 comments on commit 6942120

Please sign in to comment.