diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 52dab7477e..09640d31f6 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -309,7 +309,7 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, Einsum::Index c = std::get<0>(cs); struct { - std::string a, b, c; + std::string b, c; } inner; if constexpr (std::tuple_size::value == 2) { inner.b = ";" + (std::string)std::get<1>(Einsum::idx(B)); @@ -319,16 +319,13 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, // these are "Hadamard" (fused) indices auto h = a & b & c; - auto e = (a ^ b); // contracted indices auto i = (a & b) - h; + // contraction not allowed in tensor x tensor-of-tensor + TA_ASSERT(!i); - // cannot be hadamard reduction type operation for this overload - TA_ASSERT(e); - - // no Hadamard indices => standard contraction (or even outer product) - // same a, b, and c => pure Hadamard - TA_ASSERT(!h || (!(a ^ b) && !(b ^ c))); + // indices exclusively in 'a' or exclusively in 'b' + auto e = (a ^ b); // maps Index to TiledRange1 // (asserts same index maps to the same TR1 in A, and B) @@ -364,6 +361,9 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, } C.expr = e; + arrayTermB.expr += inner.b; + C.expr += inner.c; + struct { RangeProduct tiles; std::vector> batch; @@ -453,7 +453,10 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, } // todo - // C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners); + C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners); + + // + A.ei.defer_deleter_to_next_fence(); B.ei.defer_deleter_to_next_fence(); A.ei = ArrayT(); diff --git a/src/TiledArray/expressions/binary_engine.h b/src/TiledArray/expressions/binary_engine.h index 4758ab0069..93192e2b5e 100644 --- a/src/TiledArray/expressions/binary_engine.h +++ b/src/TiledArray/expressions/binary_engine.h @@ -146,11 +146,10 @@ class BinaryEngine : public ExprEngine { TiledArray::detail::is_tensor_of_tensor_v; constexpr bool right_tile_is_tot = TiledArray::detail::is_tensor_of_tensor_v; - static_assert(!(left_tile_is_tot ^ right_tile_is_tot), - "ContEngine can only handle tensors of same nested-ness " - "(both plain or both ToT)"); constexpr bool args_are_plain_tensors = !left_tile_is_tot && !right_tile_is_tot; + constexpr bool args_are_mixed_tensors = + left_tile_is_tot ^ right_tile_is_tot; if (args_are_plain_tensors && (left_outer_permtype_ == PermutationType::matrix_transpose || left_outer_permtype_ == PermutationType::identity)) { @@ -175,6 +174,20 @@ class BinaryEngine : public ExprEngine { right_inner_permtype_ == PermutationType::identity))) { right_.permute_tiles(false); } + if (args_are_mixed_tensors && + ((left_outer_permtype_ == PermutationType::matrix_transpose || + left_outer_permtype_ == PermutationType::identity) || + (left_inner_permtype_ == PermutationType::matrix_transpose || + left_inner_permtype_ == PermutationType::identity))) { + left_.permute_tiles(false); + } + if (args_are_mixed_tensors && + ((left_outer_permtype_ == PermutationType::matrix_transpose || + left_outer_permtype_ == PermutationType::identity) || + (right_inner_permtype_ == PermutationType::matrix_transpose || + right_inner_permtype_ == PermutationType::identity))) { + right_.permute_tiles(false); + } } public: