diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 48648407cb..2bd548df5c 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -181,50 +181,51 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, using Index = Einsum::Index; - if constexpr (std::tuple_size::value > 1) { - TA_ASSERT(e); - } else if (!e) { // hadamard reduction - auto &[A, B] = AB; - TiledRange trange(range_map[i]); - RangeProduct tiles; - for (auto idx : i) { - tiles *= Range(range_map[idx].tiles_range()); - } - auto pa = A.permutation; - auto pb = B.permutation; - for (Index h : H.tiles) { - if (!C.array.is_local(h)) continue; - size_t batch = 1; - for (size_t i = 0; i < h.size(); ++i) { - batch *= H.batch[i].at(h[i]); + if constexpr (std::tuple_size::value > 1) TA_ASSERT(e); + if constexpr (AreArraySame) { + if (!e) { // hadamard reduction + auto &[A, B] = AB; + TiledRange trange(range_map[i]); + RangeProduct tiles; + for (auto idx : i) { + tiles *= Range(range_map[idx].tiles_range()); } - ResultTensor tile(TiledArray::Range{batch}, - typename ResultTensor::value_type{}); - for (Index i : tiles) { - // skip this unless both input tiles exist - const auto pahi_inv = apply_inverse(pa, h + i); - const auto pbhi_inv = apply_inverse(pb, h + i); - if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue; - - auto ai = A.array.find(pahi_inv).get(); - auto bi = B.array.find(pbhi_inv).get(); - if (pa) ai = ai.permute(pa); - if (pb) bi = bi.permute(pb); - auto shape = trange.tile(i); - ai = ai.reshape(shape, batch); - bi = bi.reshape(shape, batch); - for (size_t k = 0; k < batch; ++k) { - auto hk = ai.batch(k).dot(bi.batch(k)); - tile({k}) += hk; + auto pa = A.permutation; + auto pb = B.permutation; + for (Index h : H.tiles) { + if (!C.array.is_local(h)) continue; + size_t batch = 1; + for (size_t i = 0; i < h.size(); ++i) { + batch *= H.batch[i].at(h[i]); } + ResultTensor tile(TiledArray::Range{batch}, + typename ResultTensor::value_type{}); + for (Index i : tiles) { + // skip this unless both input tiles exist + const auto pahi_inv = apply_inverse(pa, h + i); + const auto pbhi_inv = apply_inverse(pb, h + i); + if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue; + + auto ai = A.array.find(pahi_inv).get(); + auto bi = B.array.find(pbhi_inv).get(); + if (pa) ai = ai.permute(pa); + if (pb) bi = bi.permute(pb); + auto shape = trange.tile(i); + ai = ai.reshape(shape, batch); + bi = bi.reshape(shape, batch); + for (size_t k = 0; k < batch; ++k) { + auto hk = ai.batch(k).dot(bi.batch(k)); + tile({k}) += hk; + } + } + auto pc = C.permutation; + auto shape = apply_inverse(pc, C.array.trange().tile(h)); + tile = tile.reshape(shape); + if (pc) tile = tile.permute(pc); + C.array.set(h, tile); } - auto pc = C.permutation; - auto shape = apply_inverse(pc, C.array.trange().tile(h)); - tile = tile.reshape(shape); - if (pc) tile = tile.permute(pc); - C.array.set(h, tile); + return C.array; } - return C.array; } // generalized contraction @@ -468,7 +469,8 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, const std::string &cs, World &world = get_default_world()) { using ECT = expressions::TsrExpr; using ECU = expressions::TsrExpr; - return Einsum::einsum(ECT(A), ECU(B), Einsum::idx(cs), world); + using ResultExprT = std::conditional_t, T, U>; + return Einsum::einsum(ECT(A), ECU(B), Einsum::idx(cs), world); } template diff --git a/tests/einsum.cpp b/tests/einsum.cpp index eb976b31f5..3e7b502da9 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -845,7 +845,7 @@ BOOST_AUTO_TEST_CASE(ilkj_nm_eq_ij_mn_times_kl) { } } -BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) { +BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_jk) { using t_type = DistArray, SparsePolicy>; using tot_type = DistArray>, SparsePolicy>; using matrix_il = TiledArray::detail::matrix_il>; @@ -877,7 +877,6 @@ BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) { t_type rhs(world, rhs_trange); rhs.fill_random(); - // TODO compute ref_result // i,j;m,n * j,k => i,j,k;m,n TiledRange ref_result_trange{lhs_trange.dim(0), rhs_trange.dim(0), rhs_trange.dim(1)}; @@ -928,10 +927,17 @@ BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) { // - general product w.r.t. outer indices // - involves ToT * T // tot_type result; - // BOOST_REQUIRE_NO_THROW(result("k,i,j;n,m") = lhs("i,j;m,n") * rhs("j,k")); + // BOOST_REQUIRE_NO_THROW(result("i,j,k;m,n") = lhs("i,j;m,n") * rhs("j,k")); // will try to make this work - // tot_type out = einsum(lhs("i,j;m,n"), rhs("j,k"), "k,i,j;n,m"); + tot_type result = einsum(lhs("i,j;m,n"), rhs("j,k"), "i,j,k;m,n"); + bool are_equal = ToTArrayFixture::are_equal(result, ref_result); + BOOST_REQUIRE(are_equal); + { + result = einsum(rhs("j,k"), lhs("i,j;m,n"), "i,j,k;m,n"); + are_equal = ToTArrayFixture::are_equal(result, ref_result); + BOOST_REQUIRE(are_equal); + } } BOOST_AUTO_TEST_CASE(ij_mn_eq_ji_mn_times_ij) {