diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 40076ed0ce..ace7caa15a 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -420,6 +420,9 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, using ResultTensor = typename ArrayC::value_type; using ResultShape = typename ArrayC::shape_type; + auto const& tnsrExprA = A; + auto const& tnsrExprB = B; + auto a = std::get<0>(Einsum::idx(A)); auto b = std::get<0>(Einsum::idx(B)); Einsum::Index c = std::get<0>(cs); @@ -536,16 +539,10 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, // the evaluation can be delegated to the expression layer // for distarrays of both nested and non-nested tensor tiles. // *) If no Hadamard indices are present (!h) the evaluation - // can be delegated to the expression _only_ for distarrays with - // non-nested tensor tiles. - // This is because even if Hadamard indices are not present, a contracted - // index might be present pertinent to the outer tensor in case of a - // nested-tile distarray, which is especially handled within this - // function because expression layer cannot handle that yet. + // can be delegated to the expression layer. // - if ((h && !(i || e)) // pure Hadamard - || (IsArrayToT && !(i || h)) // ToT result from outer-product - || (IsArrayT && !h)) // T from general product without Hadamard + if ((h && !(i || e)) // pure Hadamard + || !h) // no Hadamard { ArrayC C; C(std::string(c) + inner.c) = A * B; @@ -577,21 +574,6 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, return C; } - // - // when contraction happens in the outer tensor - // need to evaluate specially.. - // - if (IsArrayToT && i.size() > 0) { - auto annot_c = std::string(h + e + i) + inner.c; - auto temp1 = einsum(A, B, idx(annot_c), world); - auto temp2 = reduce_modes(temp1, i.size()); - - auto annot_c_ = std::string(h + e) + inner.c; - decltype(temp2) result; - result(std::string(c) + inner.c) = temp2(annot_c_); - return result; - } - using ::Einsum::index::permutation; using TiledArray::Permutation; @@ -640,79 +622,104 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, using Index = Einsum::Index; - if constexpr (AreArraySame && - AreArraySame) { - if (!e) { // hadamard reduction - auto &[A, B] = AB; - TiledRange trange(range_map[i]); - RangeProduct tiles; - for (auto idx : i) { - tiles *= Range(range_map[idx].tiles_range()); + if (!e) { // hadamard reduction + auto &[A, B] = AB; + TiledRange trange(range_map[i]); + RangeProduct tiles; + for (auto idx : i) { + tiles *= Range(range_map[idx].tiles_range()); + } + auto pa = A.permutation; + auto pb = B.permutation; + for (Index h : H.tiles) { + if (!C.array.is_local(h)) continue; + size_t batch = 1; + for (size_t i = 0; i < h.size(); ++i) { + batch *= H.batch[i].at(h[i]); } - auto pa = A.permutation; - auto pb = B.permutation; - for (Index h : H.tiles) { - if (!C.array.is_local(h)) continue; - size_t batch = 1; - for (size_t i = 0; i < h.size(); ++i) { - batch *= H.batch[i].at(h[i]); - } - ResultTensor tile(TiledArray::Range{batch}, - typename ResultTensor::value_type{}); - for (Index i : tiles) { - // skip this unless both input tiles exist - const auto pahi_inv = apply_inverse(pa, h + i); - const auto pbhi_inv = apply_inverse(pb, h + i); - if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) - continue; - - auto ai = A.array.find(pahi_inv).get(); - auto bi = B.array.find(pbhi_inv).get(); - if (pa) ai = ai.permute(pa); - if (pb) bi = bi.permute(pb); - auto shape = trange.tile(i); - ai = ai.reshape(shape, batch); - bi = bi.reshape(shape, batch); - for (size_t k = 0; k < batch; ++k) { - using Ix = ::Einsum::Index; - if constexpr (AreArrayToT) { - auto aik = ai.batch(k); - auto bik = bi.batch(k); - auto vol = aik.total_size(); - TA_ASSERT(vol == bik.total_size()); - - auto &el = tile({k}); - using TensorT = std::remove_reference_t; - - auto mult_op = [&inner](auto const &l, - auto const &r) -> TensorT { - return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r, - inner.B, inner.C) - : TA::detail::tensor_contract( - l, inner.A, r, inner.B, inner.C); - }; - - for (auto i = 0; i < vol; ++i) - el.add_to(mult_op(aik.data()[i], bik.data()[i])); - - } else { - auto hk = ai.batch(k).dot(bi.batch(k)); - tile({k}) += hk; - } + ResultTensor tile(TiledArray::Range{batch}, + typename ResultTensor::value_type{}); + for (Index i : tiles) { + // skip this unless both input tiles exist + const auto pahi_inv = apply_inverse(pa, h + i); + const auto pbhi_inv = apply_inverse(pb, h + i); + if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue; + + auto ai = A.array.find(pahi_inv).get(); + auto bi = B.array.find(pbhi_inv).get(); + if (pa) ai = ai.permute(pa); + if (pb) bi = bi.permute(pb); + auto shape = trange.tile(i); + ai = ai.reshape(shape, batch); + bi = bi.reshape(shape, batch); + for (size_t k = 0; k < batch; ++k) { + using Ix = ::Einsum::Index; + if constexpr (AreArrayToT) { + auto aik = ai.batch(k); + auto bik = bi.batch(k); + auto vol = aik.total_size(); + TA_ASSERT(vol == bik.total_size()); + + auto &el = tile({k}); + using TensorT = std::remove_reference_t; + + auto mult_op = [&inner](auto const &l, auto const &r) -> TensorT { + if (l.empty() || r.empty()) return TensorT{}; + return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r, + inner.B, inner.C) + : TA::detail::tensor_contract(l, inner.A, r, + inner.B, inner.C); + }; + + for (auto i = 0; i < vol; ++i) + el.add_to(mult_op(aik.data()[i], bik.data()[i])); + + } else if constexpr (!AreArraySame) { + auto aik = ai.batch(k); + auto bik = bi.batch(k); + auto vol = aik.total_size(); + TA_ASSERT(vol == bik.total_size()); + + auto &el = tile({k}); + + for (auto i = 0; i < vol; ++i) + if constexpr (IsArrayToT) { + el.add_to(aik.data()[i].scale(bik.data()[i])); + } else { + el.add_to(bik.data()[i].scale(aik.data()[i])); + } + + } else { + auto hk = ai.batch(k).dot(bi.batch(k)); + tile({k}) += hk; } } - auto pc = C.permutation; - auto shape = apply_inverse(pc, C.array.trange().tile(h)); - tile = tile.reshape(shape); - if (pc) tile = tile.permute(pc); - C.array.set(h, tile); } - return C.array; + auto pc = C.permutation; + auto shape = apply_inverse(pc, C.array.trange().tile(h)); + tile = tile.reshape(shape); + if (pc) tile = tile.permute(pc); + C.array.set(h, tile); } + return C.array; } // generalized contraction + if constexpr (IsArrayToT) { + if (inner.C != inner.h + inner.e) { + // when inner tensor permutation is non-trivial (could be potentially + // elided by extending this function (@c einsum) to take into account + // of inner tensor's permutations) + auto temp_annot = std::string(c) + ";" + std::string(inner.h + inner.e); + ArrayC temp = einsum(tnsrExprA, tnsrExprB, + Einsum::idx(temp_annot), world); + ArrayC result; + result(std::string(c) + inner.c) = temp(temp_annot); + return result; + } + } + auto update_tr = [&e = std::as_const(e), &i = std::as_const(i), &range_map = std::as_const(range_map)](auto &term) { auto ei = (e + i & term.idx); diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index f0a94c7e05..3d0ef11c10 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -279,25 +279,62 @@ class ContEngine : public BinaryEngine { outer_size(left_indices_), outer_size(right_indices_), (!implicit_permute_outer_ ? std::move(outer_perm) : Permutation{})); } else { + + auto make_total_perm = [this]() -> BipartitePermutation { + if (this->product_type() != TensorProduct::Contraction + || this->implicit_permute_inner_) + return this->implicit_permute_outer_ + ? BipartitePermutation() + : BipartitePermutation(outer(this->perm_)); + + // Here, + // this->product_type() is Tensor::Contraction, and, + // this->implicit_permute_inner_ is false + + return this->inner_product_type() == TensorProduct::Scale + ? BipartitePermutation(outer(this->perm_)) + : this->perm_; + }; + + auto total_perm = make_total_perm(); + // factor_ is absorbed into inner_tile_nonreturn_op_ op_ = op_type( left_op, right_op, scalar_type(1), outer_size(indices_), outer_size(left_indices_), outer_size(right_indices_), - (!implicit_permute_outer_ ? std::move(outer_perm) : Permutation{}), + total_perm, this->element_nonreturn_op_); } trange_ = ContEngine_::make_trange(outer_perm); shape_ = ContEngine_::make_shape(outer_perm); } else { // Initialize non-permuted structure + if constexpr (!TiledArray::detail::is_tensor_of_tensor_v) { op_ = op_type(left_op, right_op, factor_, outer_size(indices_), outer_size(left_indices_), outer_size(right_indices_)); } else { + + auto make_total_perm = [this]() -> BipartitePermutation { + if (this->product_type() != TensorProduct::Contraction + || this->implicit_permute_inner_) + return {}; + + // Here, + // this->product_type() is Tensor::Contraction, and, + // this->implicit_permute_inner_ is false + + return this->inner_product_type() == TensorProduct::Scale + ? BipartitePermutation(outer(this->perm_)) + : this->perm_; + }; + + auto total_perm = make_total_perm(); + // factor_ is absorbed into inner_tile_nonreturn_op_ op_ = op_type(left_op, right_op, scalar_type(1), outer_size(indices_), outer_size(left_indices_), outer_size(right_indices_), - BipartitePermutation{}, this->element_nonreturn_op_); + total_perm, this->element_nonreturn_op_); } trange_ = ContEngine_::make_trange(); shape_ = ContEngine_::make_shape(); @@ -509,12 +546,15 @@ class ContEngine : public BinaryEngine { inner_size(this->left_indices_), inner_size(this->right_indices_)); this->element_nonreturn_op_ = - [contrreduce_op](result_tile_element_type& result, - const left_tile_element_type& left, - const right_tile_element_type& right) { + [contrreduce_op, permute_inner = this->product_type() != + TensorProduct::Contraction]( + result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { contrreduce_op(result, left, right); - if (!TA::empty(result)) - result = contrreduce_op(result); // permutations of result are applied as "postprocessing" + // permutations of result are applied as "postprocessing" + if (permute_inner && !TA::empty(result)) + result = contrreduce_op(result); }; } // ToT x ToT } else if (inner_prod == TensorProduct::Hadamard) { diff --git a/src/TiledArray/tensor/kernels.h b/src/TiledArray/tensor/kernels.h index a2530f2f5d..0b0767ed81 100644 --- a/src/TiledArray/tensor/kernels.h +++ b/src/TiledArray/tensor/kernels.h @@ -417,14 +417,15 @@ inline void inplace_tensor_op(Op&& op, TR& result, const Ts&... tensors) { TA_ASSERT(!empty(result, tensors...)); TA_ASSERT(is_range_set_congruent(result, tensors...)); - const auto volume = result.range().volume(); - - for (decltype(result.range().volume()) ord = 0ul; ord < volume; ++ord) { + auto volume = result.total_size(); + for (decltype(volume) ord = 0; ord < volume; ++ord) { + if constexpr (is_tensor_of_tensor_v) + if (((tensors.data()[ord].range().volume() == 0) || ...)) continue; if constexpr (std::is_invocable_r_v) - op(result.at_ordinal(ord), tensors.at_ordinal(ord)...); + op(result.data()[ord], tensors.data()[ord]...); else - inplace_tensor_op(op, result.at_ordinal(ord), tensors.at_ordinal(ord)...); + inplace_tensor_op(op, result.data()[ord], tensors.data()[ord]...); } } diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index bd6fb8f3e5..a394594b8e 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -1630,6 +1630,7 @@ class Tensor { template ::value>::type* = nullptr> Tensor add(const Right& right) const& { + if (right.empty()) return *this; return binary( right, [](const value_type& l, const value_t& r) -> decltype(auto) { diff --git a/src/TiledArray/tile_op/contract_reduce.h b/src/TiledArray/tile_op/contract_reduce.h index f0654f1431..2a5e90ea5d 100644 --- a/src/TiledArray/tile_op/contract_reduce.h +++ b/src/TiledArray/tile_op/contract_reduce.h @@ -332,7 +332,6 @@ class ContractReduce : public ContractReduceBase { if constexpr (!ContractReduceBase_::plain_tensors) { TA_ASSERT(this->elem_muladd_op()); - // not yet implemented gemm(result, left, right, ContractReduceBase_::gemm_helper(), this->elem_muladd_op()); } else { // plain tensors