diff --git a/src/TiledArray/dist_eval/array_eval.h b/src/TiledArray/dist_eval/array_eval.h
index a4cbdc47b1..2eaad01a9b 100644
--- a/src/TiledArray/dist_eval/array_eval.h
+++ b/src/TiledArray/dist_eval/array_eval.h
@@ -228,13 +228,15 @@ class ArrayEvalImpl
   /// \param pmap The process map for the result tensor tiles
   /// \param perm The permutation that is applied to the tile coordinate index
   /// \param op The operation that will be used to evaluate the tiles of array
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
   ArrayEvalImpl(const array_type& array, World& world,
                 const trange_type& trange, const shape_type& shape,
-                const std::shared_ptr<const pmap_interface>& pmap,
-                const Perm& perm, const op_type& op)
-      : DistEvalImpl_(world, trange, shape, pmap, outer(perm)),
+                const std::shared_ptr<const pmap_interface>& pmap, Perm&& perm,
+                const op_type& op)
+      : DistEvalImpl_(world, trange, shape, pmap,
+                      outer(std::forward<Perm>(perm))),
         array_(array),
         op_(std::make_shared<op_type>(op)),
         block_range_()
@@ -273,17 +275,19 @@ class ArrayEvalImpl
   /// \param op The operation that will be used to evaluate the tiles of array
   /// \param lower_bound The sub-block lower bound
   /// \param upper_bound The sub-block upper bound
-  template <typename Index1, typename Index2, typename Perm,
-            typename = std::enable_if_t<
-                TiledArray::detail::is_integral_range_v<Index1> &&
-                TiledArray::detail::is_integral_range_v<Index2> &&
-                TiledArray::detail::is_permutation_v<Perm>>>
+  template <
+      typename Index1, typename Index2, typename Perm,
+      typename = std::enable_if_t<
+          TiledArray::detail::is_integral_range_v<Index1> &&
+          TiledArray::detail::is_integral_range_v<Index2> &&
+          TiledArray::detail::is_permutation_v<std::remove_reference_t<Perm>>>>
   ArrayEvalImpl(const array_type& array, World& world,
                 const trange_type& trange, const shape_type& shape,
-                const std::shared_ptr<const pmap_interface>& pmap,
-                const Perm& perm, const op_type& op, const Index1& lower_bound,
+                const std::shared_ptr<const pmap_interface>& pmap, Perm&& perm,
+                const op_type& op, const Index1& lower_bound,
                 const Index2& upper_bound)
-      : DistEvalImpl_(world, trange, shape, pmap, outer(perm)),
+      : DistEvalImpl_(world, trange, shape, pmap,
+                      outer(std::forward<Perm>(perm))),
        array_(array),
        op_(std::make_shared<op_type>(op)),
        block_range_(array.trange().tiles_range(), lower_bound, upper_bound)
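Editorial note on the recurring pattern above: taking the permutation as a forwarding reference (`Perm&&`, with the permutation trait applied to `std::remove_reference_t<Perm>`) lets callers hand in a temporary permutation object without copying it. A minimal standalone sketch of the mechanics, using toy stand-ins rather than the real TiledArray classes:

#include <iostream>
#include <utility>

// Toy stand-in for TiledArray's Permutation/BipartitePermutation.
struct Permutation {
  Permutation() = default;
  Permutation(const Permutation&) { std::cout << "copy\n"; }
  Permutation(Permutation&&) noexcept { std::cout << "move\n"; }
};

// Before: `const Perm&` forces a copy even when the caller passes a
// temporary. After: `Perm&&` + std::forward moves rvalues through.
template <typename Perm>
void consume(Perm&& perm) {
  Permutation local(std::forward<Perm>(perm));  // moves for rvalues
  (void)local;
}

int main() {
  Permutation p;
  consume(p);              // prints "copy" (lvalue argument)
  consume(Permutation{});  // prints "move" (rvalue temporary)
}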
diff --git a/src/TiledArray/einsum/index.h b/src/TiledArray/einsum/index.h
index 58c378704b..67e9d6c1a0 100644
--- a/src/TiledArray/einsum/index.h
+++ b/src/TiledArray/einsum/index.h
@@ -3,10 +3,10 @@
 
 #include "TiledArray/expressions/fwd.h"
 
+#include
 #include
 #include
 #include
-#include
 #include
 #include
@@ -29,10 +29,11 @@ class Index {
  public:
  using container_type = small_vector<T>;
  using value_type = typename container_type::value_type;
+  using iterator = typename container_type::iterator;
 
  Index() = default;
  Index(const container_type &s) : data_(s) {}
-  Index(const std::initializer_list<T> &s) : data_(s) {}
+  explicit Index(const std::initializer_list<T> &s) : data_(s) {}
 
  template <typename S>
  Index(const S &s) {
@@ -45,18 +46,14 @@ class Index {
  Index(const char (&s)[N]) : Index(std::string(s)) {}
 
  template <typename U = T>
-  explicit Index(const char* &s) : Index(std::string(s)) {}
+  explicit Index(const char *&s) : Index(std::string(s)) {}
 
  template <typename U = T>
  explicit Index(const std::string &s) {
-    static_assert(
-      std::is_same_v<U, char> ||
-      std::is_same_v<U, std::string>
-    );
-    if constexpr (std::is_same_v<U, std::string>) {
+    static_assert(std::is_same_v<U, char> || std::is_same_v<U, std::string>);
+    if constexpr (std::is_same_v<U, std::string>) {
      data_ = index::tokenize(s);
-    }
-    else {
+    } else {
      using std::begin;
      using std::end;
      data_.assign(begin(s), end(s));
@@ -78,8 +75,11 @@ class Index {
 
  size_t size() const { return data_.size(); }
 
-  auto begin() const { return data_.begin(); }
-  auto end() const { return data_.end(); }
+  auto begin() const { return data_.cbegin(); }
+  auto end() const { return data_.cend(); }
+
+  auto begin() { return data_.begin(); }
+  auto end() { return data_.end(); }
 
  auto find(const T &v) const {
    return std::find(this->begin(), this->end(), v);
@@ -209,11 +209,8 @@ auto permute(const Permutation &p, const Index<T> &s,
  if (!p) return s;
  using R = typename Index<T>::container_type;
  R r(p.size());
-  TiledArray::detail::permute_n(
-    p.size(),
-    p.begin(), s.begin(), r.begin(),
-    std::bool_constant<Inverse>{}
-  );
+  TiledArray::detail::permute_n(p.size(), p.begin(), s.begin(), r.begin(),
+                                std::bool_constant<Inverse>{});
  return Index<T>{r};
 }
@@ -306,8 +303,8 @@ IndexMap operator|(const IndexMap &a, const IndexMap &b) {
 
 }  // namespace Einsum::index
 
 namespace Einsum {
-  using index::Index;
-  using index::IndexMap;
-}  // namespace TiledArray::Einsum
+using index::Index;
+using index::IndexMap;
+}  // namespace Einsum
 
 #endif /* TILEDARRAY_EINSUM_INDEX_H__INCLUDED */
diff --git a/src/TiledArray/einsum/string.h b/src/TiledArray/einsum/string.h
index 7647aed63b..d2dc6048ab 100644
--- a/src/TiledArray/einsum/string.h
+++ b/src/TiledArray/einsum/string.h
@@ -1,50 +1,50 @@
 #ifndef TILEDARRAY_EINSUM_STRING_H
 #define TILEDARRAY_EINSUM_STRING_H
 
+#include
 #include
 #include
-#include
+#include
 #include
 #include
 
 namespace Einsum::string {
 namespace {
-  // Split delimiter must match completely
-  template <typename T = std::string, typename U = T>
-  std::pair<T, U> split2(const std::string& s, const std::string &d) {
-    auto pos = s.find(d);
-    if (pos == s.npos) return { T(s), U("") };
-    return { T(s.substr(0,pos)), U(s.substr(pos+d.size())) };
-  }
+// Split delimiter must match completely
+template <typename T = std::string, typename U = T>
+std::pair<T, U> split2(const std::string& s, const std::string& d) {
+  auto pos = s.find(d);
+  if (pos == s.npos) return {T(s), U("")};
+  return {T(s.substr(0, pos)), U(s.substr(pos + d.size()))};
+}
 
-  // Split delimiter must match completely
-  std::vector<std::string> split(const std::string& s, char d) {
-    std::vector<std::string> res;
-    return boost::split(res, s, [&d](char c) { return c == d; } /*boost::is_any_of(d)*/);
-  }
+// Split delimiter must match completely
+std::vector<std::string> split(const std::string& s, char d) {
+  std::vector<std::string> res;
+  return boost::split(res, s,
+                      [&d](char c) { return c == d; } /*boost::is_any_of(d)*/);
+}
 
-  std::string trim(const std::string& s) {
-    return boost::trim_copy(s);
-  }
+std::string trim(const std::string& s) { return boost::trim_copy(s); }
 
-  template <typename T>
-  std::string str(const T& obj) {
-    std::stringstream ss;
-    ss << obj;
-    return ss.str();
-  }
+template <typename T>
+std::string str(const T& obj) {
+  std::stringstream ss;
+  ss << obj;
+  return ss.str();
+}
 
-  template <typename T, typename U = std::string>
-  std::string join(const T &s, const U& j = U("")) {
-    std::vector<std::string> strings;
-    for (auto e : s) {
-      strings.push_back(str(e));
-    }
-    return boost::join(strings, j);
+template <typename T, typename U = std::string>
+std::string join(const T& s, const U& j = U("")) {
+  std::vector<std::string> strings;
+  for (auto e : s) {
+    strings.push_back(str(e));
  }
-
-}
+  return boost::join(strings, j);
 }
 
-#endif //TILEDARRAY_EINSUM_STRING_H
+}  // namespace
+}  // namespace Einsum::string
+
+#endif  // TILEDARRAY_EINSUM_STRING_H
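For orientation, the following standalone sketch mirrors the behavior of `split2`/`join` above on a hypothetical nested einsum annotation (the annotation string is illustrative, not taken from this diff):

#include <iostream>
#include <string>
#include <utility>

// Mirrors Einsum::string::split2: split once on the first occurrence of a
// delimiter; the second half is empty when the delimiter is absent.
std::pair<std::string, std::string> split2(const std::string& s,
                                           const std::string& d) {
  auto pos = s.find(d);
  if (pos == s.npos) return {s, ""};
  return {s.substr(0, pos), s.substr(pos + d.size())};
}

int main() {
  // Parsing a hypothetical "outer;inner" annotation per tensor:
  auto [lhs, result] = split2("ik;ab,kj;ab->ij;ab", "->");
  auto [a, b] = split2(lhs, ",");
  auto [a_outer, a_inner] = split2(a, ";");
  std::cout << a_outer << " | " << a_inner << '\n';  // prints "ik | ab"
  std::cout << result << '\n';                       // prints "ij;ab"
}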
diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h
index 1851973709..d2d8b2c9ba 100644
--- a/src/TiledArray/einsum/tiledarray.h
+++ b/src/TiledArray/einsum/tiledarray.h
@@ -9,6 +9,10 @@
 #include "TiledArray/tiled_range.h"
 #include "TiledArray/tiled_range1.h"
 
+namespace TiledArray {
+enum struct DeNest { True, False };
+}
+
 namespace TiledArray::Einsum {
 
 using ::Einsum::index::small_vector;
@@ -82,17 +86,109 @@ template <typename Array1, typename Array2>
 constexpr bool AreArraySame =
     AreArrayT<Array1, Array2> || AreArrayToT<Array1, Array2>;
 
+template <typename Array>
+using DeNestedArray = DistArray<typename Array::value_type::value_type,
+                                typename Array::policy_type>;
+
+template <typename Array1, typename Array2>
+using MaxNestedArray = std::conditional_t<(detail::nested_rank<Array2> >
+                                           detail::nested_rank<Array1>),
+                                          Array2, Array1>;
+
+}  // namespace
+
+namespace {
+
+///
+/// \brief This function replicates a tensor B into a tensor A such that
+///        A(a_1,...,a_k,i_1,...,i_l) = B(i_1,...,i_l). Evidently, the
+///        extents of the i_n modes must match in both A and B.
+///
+/// \tparam Tensor TiledArray::Tensor type.
+/// \param to The target tensor.
+/// \param from The source tensor that will be replicated into \c to.
+///
+template <typename Tensor,
+          typename = std::enable_if_t<detail::is_tensor_v<Tensor>>>
+void replicate_tensor(Tensor &to, Tensor const &from) {
+  // assert that corresponding modes have the same extents
+  TA_ASSERT(std::equal(from.range().extent().rbegin(),
+                       from.range().extent().rend(),
+                       to.range().extent().rbegin()));
+
+  // number of elements to be copied
+  // (same as the number of elements in @c from)
+  auto const N = from.range().volume();
+  for (auto i = 0; i < to.range().volume(); i += N)
+    std::copy(from.begin(), from.end(), to.data() + i);
+}
+
+///
+/// \brief This function is the @c DistArray counterpart of the function
+///        @c replicate_tensor(TA::Tensor&, TA::Tensor const&).
+///
+/// \tparam Array
+/// \param from The DistArray to be by-rank replicated.
+/// \param prepend_trng TiledRange1's in this argument will be prepended to the
+///        `TiledRange` of the argument array.
+/// \return An array whose rank is increased by `prepend_trng.rank()`.
+/// \see `replicate_tensor`
+///
+template <typename Array,
+          typename = std::enable_if_t<detail::is_array_v<Array>>>
+auto replicate_array(Array from, TiledRange const &prepend_trng) {
+  auto const result_rank = prepend_trng.rank() + rank(from);
+  container::svector<TiledRange1> tr1s;
+  tr1s.reserve(result_rank);
+  for (auto const &r : prepend_trng) tr1s.emplace_back(r);
+  for (auto const &r : from.trange()) tr1s.emplace_back(r);
+  auto const result_trange = TiledRange(tr1s);
+
+  from.make_replicated();
+  auto &world = from.world();
+  world.gop.fence();
+
+  auto result = make_array<Array>(
+      world, result_trange,
+      [from, res_tr = result_trange, delta_rank = prepend_trng.rank()](
+          auto &tile, auto const &res_rng) {
+        using std::begin;
+        using std::end;
+        using std::next;
+
+        typename Array::value_type repped(res_rng);
+        auto res_coord_ix = res_tr.element_to_tile(res_rng.lobound());
+        auto from_coord_ix = decltype(res_coord_ix)(
+            next(begin(res_coord_ix), delta_rank), end(res_coord_ix));
+        replicate_tensor(repped, from.find_local(from_coord_ix).get(false));
+        tile = repped;
+        return tile.norm();
+      });
+  return result;
+}
+
+template <typename RangeMap, typename Ixs>
+TiledRange make_trange(RangeMap const &map, Ixs const &ixs) {
+  container::svector<TiledRange1> tr1s;
+  tr1s.reserve(ixs.size());
+  for (auto &&i : ixs) tr1s.emplace_back(map[i]);
+  return TiledRange(tr1s);
+}
+
+}  // namespace
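A plain-`std::vector` model of what `replicate_tensor` above does (assuming row-major storage; the real function works on `TA::Tensor` ranges):

#include <algorithm>
#include <iostream>
#include <vector>

// Tile the elements of `from` across `to`, whose volume is an integer
// multiple of from's volume: a 2x(2x3) target filled from a 2x3 source
// ends up holding two identical 2x3 blocks.
int main() {
  std::vector<int> from = {1, 2, 3, 4, 5, 6};  // a 2x3 "tensor", row-major
  std::vector<int> to(2 * from.size());        // a 2x2x3 "tensor"
  for (std::size_t i = 0; i < to.size(); i += from.size())
    std::copy(from.begin(), from.end(), to.begin() + i);
  for (int x : to) std::cout << x << ' ';  // 1 2 3 4 5 6 1 2 3 4 5 6
}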
 
-template <typename ArrayA_, typename ArrayB_, typename... Indices>
+template <DeNest DeNestFlag = DeNest::False, typename ArrayA_,
+          typename ArrayB_, typename... Indices>
 auto einsum(expressions::TsrExpr<ArrayA_, true> A,
             expressions::TsrExpr<ArrayB_, true> B,
             std::tuple<Einsum::Index<std::string>, Indices...> cs,
             World &world) {
  using ArrayA = std::remove_cv_t<ArrayA_>;
  using ArrayB = std::remove_cv_t<ArrayB_>;
-  using ArrayC = std::conditional_t<
-      AreArraySame<ArrayA, ArrayB>, ArrayA,
-      std::conditional_t<IsArrayToT<ArrayA>, ArrayA, ArrayB>>;
+
+  using ArrayC =
+      std::conditional_t<DeNestFlag == DeNest::True, DeNestedArray<ArrayA>,
+                         MaxNestedArray<ArrayA, ArrayB>>;
+
  using ResultTensor = typename ArrayC::value_type;
  using ResultShape = typename ArrayC::shape_type;
@@ -102,229 +198,346 @@ auto einsum(expressions::TsrExpr<ArrayA_, true> A,
  struct {
    std::string a, b, c;
+    // Hadamard, external, internal indices for inner tensor
+    Einsum::Index<std::string> A, B, C, h, e, i;
  } inner;
-  if constexpr (std::tuple_size<decltype(cs)>::value == 2) {
-    if constexpr (IsArrayToT<ArrayA>)
-      inner.a = ";" + (std::string)std::get<1>(Einsum::idx(A));
-    if constexpr (IsArrayToT<ArrayB>)
-      inner.b = ";" + (std::string)std::get<1>(Einsum::idx(B));
+  if constexpr (IsArrayToT<ArrayA>) {
+    inner.a = ";" + (std::string)std::get<1>(Einsum::idx(A));
+    inner.A = std::get<1>(Einsum::idx(A));
+  }
 
-    static_assert(IsArrayToT<ArrayA> || IsArrayToT<ArrayB>);
-    inner.c = ";" + (std::string)std::get<1>(cs);
+  if constexpr (IsArrayToT<ArrayB>) {
+    inner.b = ";" + (std::string)std::get<1>(Einsum::idx(B));
+    inner.B = std::get<1>(Einsum::idx(B));
  }
 
-  // these are "Hadamard" (fused) indices
-  auto h = a & b & c;
+  if constexpr (std::tuple_size<decltype(cs)>::value == 2) {
+    static_assert(IsArrayToT<ArrayC>);
+    inner.c = ";" + (std::string)std::get<1>(cs);
+    inner.C = std::get<1>(cs);
+  }
 
-  // no Hadamard indices => standard contraction (or even outer product)
-  // same a, b, and c => pure Hadamard
-  if (!h || (!(a ^ b) && !(b ^ c))) {
-    ArrayC C;
-    C(std::string(c) + inner.c) = A * B;
-    return C;
+  {
+    inner.h = inner.A & inner.B & inner.C;
+    inner.e = (inner.A ^ inner.B);
+    inner.i = (inner.A & inner.B) - inner.h;
+    if constexpr (IsArrayToT<ArrayC>)
+      TA_ASSERT(!(inner.h && (inner.i || inner.e)) &&
+                "General product between inner tensors not supported");
  }
 
-  auto e = (a ^ b);
-  // contracted indices
-  auto i = (a & b) - h;
+  if constexpr (DeNestFlag == DeNest::True) {
+    static_assert(detail::nested_rank<ArrayA> == detail::nested_rank<ArrayB> &&
+                  detail::nested_rank<ArrayA> == 2);
+
+    TA_ASSERT(!inner.C &&
+              "Denested result cannot have inner-tensor annotation");
+
+    TA_ASSERT(inner.i.size() == inner.A.size() &&
+              inner.i.size() == inner.B.size() &&
+              "Nested-rank-reduction only supported when the inner tensor "
+              "ranks match on the arguments");
+
+    // Step I: A * B -> C'
+    // Step II: C' -> C
+    //
+    // At "Step I", a general product (without reduction) in outer indices,
+    // and a pure Hadamard product in inner indices, is carried out.
+    // Then at "Step II", the inner tensors are reduced with a unary function.
+    // The reducing function is determined by looking at the contracting and
+    // non-contracting outer indices.
+    //
+    // e.g. A(i,j,k;a,b) * B(k,j;a,b) -> C(i,j) involves the following two
+    // steps:
+    // Step I:  A(i,j,k;a,b) * B(k,j;a,b) -> C'(i,j;a,b)
+    // Step II: C'(i,j;a,b) -> C(i,j)
+
+    auto Cp = einsum(A, B, std::string(c) + ";" + std::string(inner.i));
+
+    auto sum_tot_2_tos = [](auto const &tot) {
+      typename std::remove_reference_t<decltype(tot)>::value_type result(
+          tot.range(), [tot](auto &&ix) { return tot(ix).sum(); });
+      return result;
+    };
 
-  TA_ASSERT(e || h);
+    auto result = TA::foreach<ResultTensor>(
+        Cp, [sum_tot_2_tos](auto &out_tile, auto const &in_tile) {
+          out_tile = sum_tot_2_tos(in_tile);
+        });
 
-  auto range_map =
-      (RangeMap(a, A.array().trange()) | RangeMap(b, B.array().trange()));
+    return result;
+  } else {
+    // these are "Hadamard" (fused) indices
+    auto h = a & b & c;
 
-  using ::Einsum::index::permutation;
-  using TiledArray::Permutation;
+    // external indices
+    auto e = (a ^ b);
 
-  std::tuple<ArrayTerm<ArrayA>, ArrayTerm<ArrayB>> AB{{A.array(), a},
-                                                      {B.array(), b}};
+    // contracted indices
+    auto i = (a & b) - h;
 
-  auto update_perm_and_indices = [&e = std::as_const(e), &i = std::as_const(i),
-                                  &h = std::as_const(h)](auto &term) {
-    auto ei = (e + i & term.idx);
-    if (term.idx != h + ei) {
-      term.permutation = permutation(term.idx, h + ei);
+    // no Hadamard indices => standard contraction (or even outer product)
+    // same a, b, and c => pure Hadamard
+    if (!h || (h && !(i || e))) {
+      ArrayC C;
+      C(std::string(c) + inner.c) = A * B;
+      return C;
    }
-    term.expr = ei;
-  };
 
-  std::invoke(update_perm_and_indices, std::get<0>(AB));
-  std::invoke(update_perm_and_indices, std::get<1>(AB));
+    TA_ASSERT(e || h);
 
-  ArrayTerm<ArrayC> C = {ArrayC(world, TiledRange(range_map[c])), c};
-  for (auto idx : e) {
-    C.tiles *= Range(range_map[idx].tiles_range());
-  }
-  if (C.idx != h + e) {
-    C.permutation = permutation(h + e, C.idx);
-  }
-  C.expr = e;
+    auto range_map =
+        (RangeMap(a, A.array().trange()) | RangeMap(b, B.array().trange()));
 
-  std::get<0>(AB).expr += inner.a;
-  std::get<1>(AB).expr += inner.b;
+    auto perm_and_rank_replicate = [delta_trng = make_trange(range_map, e)](
+                                       auto pre,                       //
+                                       std::string const &pre_annot,   //
+                                       std::string const &permed_annot) {
+      decltype(pre) permed;
+      permed(permed_annot) = pre(pre_annot);
+      return replicate_array(permed, delta_trng);
+    };
 
-  C.expr += inner.c;
+    // special Hadamard
+    if (h.size() == a.size() || h.size() == b.size()) {
+      TA_ASSERT(!i && e);
+      bool small_a = h.size() == a.size();
+      std::string const eh_annot = (e | h);
+      std::string const permed_annot =
+          std::string(h) + (small_a ? inner.a : inner.b);
+      std::string const C_annot = std::string(c) + inner.c;
+      std::string const temp_annot = std::string(e) + "," + permed_annot;
+      ArrayC C;
+      if (small_a) {
+        auto temp =
+            perm_and_rank_replicate(A.array(), A.annotation(), permed_annot);
+        C(C_annot) = temp(temp_annot) * B;
+      } else {
+        auto temp =
+            perm_and_rank_replicate(B.array(), B.annotation(), permed_annot);
+        C(C_annot) = A * temp(temp_annot);
+      }
+      return C;
+    }
 
-  struct {
-    RangeProduct tiles;
-    std::vector<std::vector<size_t>> batch;
-  } H;
+    using ::Einsum::index::permutation;
+    using TiledArray::Permutation;
 
-  for (auto idx : h) {
-    H.tiles *= Range(range_map[idx].tiles_range());
-    H.batch.push_back({});
-    for (auto r : range_map[idx]) {
-      H.batch.back().push_back(Range{r}.size());
-    }
-  }
+    std::tuple<ArrayTerm<ArrayA>, ArrayTerm<ArrayB>> AB{{A.array(), a},
+                                                        {B.array(), b}};
 
-  using Index = Einsum::Index<size_t>;
+    auto update_perm_and_indices = [&e = std::as_const(e),
+                                    &i = std::as_const(i),
+                                    &h = std::as_const(h)](auto &term) {
+      auto ei = (e + i & term.idx);
+      if (term.idx != h + ei) {
+        term.permutation = permutation(term.idx, h + ei);
+      }
+      term.expr = ei;
+    };
 
-  if constexpr (AreArraySame<ArrayA, ArrayB>) {
-    if (!e) {  // hadamard reduction
-      auto &[A, B] = AB;
-      TiledRange trange(range_map[i]);
-      RangeProduct tiles;
-      for (auto idx : i) {
-        tiles *= Range(range_map[idx].tiles_range());
-      }
-      auto pa = A.permutation;
-      auto pb = B.permutation;
-      for (Index h : H.tiles) {
-        if (!C.array.is_local(h)) continue;
-        size_t batch = 1;
-        for (size_t i = 0; i < h.size(); ++i) {
-          batch *= H.batch[i].at(h[i]);
-        }
-        ResultTensor tile(TiledArray::Range{batch},
-                          typename ResultTensor::value_type{});
-        for (Index i : tiles) {
-          // skip this unless both input tiles exist
-          const auto pahi_inv = apply_inverse(pa, h + i);
-          const auto pbhi_inv = apply_inverse(pb, h + i);
-          if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
-
-          auto ai = A.array.find(pahi_inv).get();
-          auto bi = B.array.find(pbhi_inv).get();
-          if (pa) ai = ai.permute(pa);
-          if (pb) bi = bi.permute(pb);
-          auto shape = trange.tile(i);
-          ai = ai.reshape(shape, batch);
-          bi = bi.reshape(shape, batch);
-          for (size_t k = 0; k < batch; ++k) {
-            auto hk = ai.batch(k).dot(bi.batch(k));
-            tile({k}) += hk;
-          }
-        }
-        auto pc = C.permutation;
-        auto shape = apply_inverse(pc, C.array.trange().tile(h));
-        tile = tile.reshape(shape);
-        if (pc) tile = tile.permute(pc);
-        C.array.set(h, tile);
-      }
-      return C.array;
-    }
-  }
+    std::invoke(update_perm_and_indices, std::get<0>(AB));
+    std::invoke(update_perm_and_indices, std::get<1>(AB));
+
+    ArrayTerm<ArrayC> C = {ArrayC(world, TiledRange(range_map[c])), c};
+    for (auto idx : e) {
+      C.tiles *= Range(range_map[idx].tiles_range());
+    }
+    if (C.idx != h + e) {
+      C.permutation = permutation(h + e, C.idx);
+    }
+    C.expr = e;
+
+    std::get<0>(AB).expr += inner.a;
+    std::get<1>(AB).expr += inner.b;
+
+    C.expr += inner.c;
+
+    struct {
+      RangeProduct tiles;
+      std::vector<std::vector<size_t>> batch;
+    } H;
+
+    for (auto idx : h) {
+      H.tiles *= Range(range_map[idx].tiles_range());
+      H.batch.push_back({});
+      for (auto r : range_map[idx]) {
+        H.batch.back().push_back(Range{r}.size());
+      }
+    }
+
+    using Index = Einsum::Index<size_t>;
+
+    if constexpr (AreArraySame<ArrayA, ArrayB> &&
+                  AreArraySame<ArrayB, ArrayC>) {
+      if (!e) {  // hadamard reduction
+        auto &[A, B] = AB;
+        TiledRange trange(range_map[i]);
+        RangeProduct tiles;
+        for (auto idx : i) {
+          tiles *= Range(range_map[idx].tiles_range());
+        }
+        auto pa = A.permutation;
+        auto pb = B.permutation;
+        for (Index h : H.tiles) {
+          if (!C.array.is_local(h)) continue;
+          size_t batch = 1;
+          for (size_t i = 0; i < h.size(); ++i) {
+            batch *= H.batch[i].at(h[i]);
+          }
+          ResultTensor tile(TiledArray::Range{batch},
+                            typename ResultTensor::value_type{});
+          for (Index i : tiles) {
+            // skip this unless both input tiles exist
+            const auto pahi_inv = apply_inverse(pa, h + i);
+            const auto pbhi_inv = apply_inverse(pb, h + i);
+            if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv))
+              continue;
+
+            auto ai = A.array.find(pahi_inv).get();
+            auto bi = B.array.find(pbhi_inv).get();
+            if (pa) ai = ai.permute(pa);
+            if (pb) bi = bi.permute(pb);
+            auto shape = trange.tile(i);
+            ai = ai.reshape(shape, batch);
+            bi = bi.reshape(shape, batch);
+            for (size_t k = 0; k < batch; ++k) {
+              using Ix = ::Einsum::Index<std::string>;
+              if constexpr (AreArrayToT<ArrayC>) {
+                auto aik = ai.batch(k);
+                auto bik = bi.batch(k);
+                auto vol = aik.total_size();
+                TA_ASSERT(vol == bik.total_size());
+
+                auto &el = tile({k});
+                using TensorT = std::remove_reference_t<decltype(el)>;
+
+                auto mult_op = [&inner](auto const &l,
+                                        auto const &r) -> TensorT {
+                  return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r,
+                                                               inner.B, inner.C)
+                                 : TA::detail::tensor_contract(
+                                       l, inner.A, r, inner.B, inner.C);
+                };
+
+                for (auto i = 0; i < vol; ++i)
+                  el.add_to(mult_op(aik.data()[i], bik.data()[i]));
+
+              } else {
+                auto hk = ai.batch(k).dot(bi.batch(k));
+                tile({k}) += hk;
+              }
+            }
+          }
+          auto pc = C.permutation;
+          auto shape = apply_inverse(pc, C.array.trange().tile(h));
+          tile = tile.reshape(shape);
+          if (pc) tile = tile.permute(pc);
+          C.array.set(h, tile);
+        }
+        return C.array;
+      }
+    }
 
-  // generalized contraction
+    // generalized contraction
 
-  auto update_tr = [&e = std::as_const(e), &i = std::as_const(i),
-                    &range_map = std::as_const(range_map)](auto &term) {
-    auto ei = (e + i & term.idx);
-    term.ei_tiled_range = TiledRange(range_map[ei]);
-    for (auto idx : ei) {
-      term.tiles *= Range(range_map[idx].tiles_range());
-    }
-  };
+    auto update_tr = [&e = std::as_const(e), &i = std::as_const(i),
+                      &range_map = std::as_const(range_map)](auto &term) {
+      auto ei = (e + i & term.idx);
+      term.ei_tiled_range = TiledRange(range_map[ei]);
+      for (auto idx : ei) {
+        term.tiles *= Range(range_map[idx].tiles_range());
+      }
+    };
 
-  std::invoke(update_tr, std::get<0>(AB));
-  std::invoke(update_tr, std::get<1>(AB));
+    std::invoke(update_tr, std::get<0>(AB));
+    std::invoke(update_tr, std::get<1>(AB));
 
-  std::vector<std::unique_ptr<World>> worlds;
-  std::vector<std::pair<Index, ResultTensor>> local_tiles;
+    std::vector<std::unique_ptr<World>> worlds;
+    std::vector<std::pair<Index, ResultTensor>> local_tiles;
 
-  // iterates over tiles of hadamard indices
-  for (Index h : H.tiles) {
-    auto &[A, B] = AB;
-    auto own = A.own(h) || B.own(h);
-    auto comm = world.mpi.comm().Split(own, world.rank());
-    worlds.push_back(std::make_unique<World>(comm));
-    auto &owners = worlds.back();
-    if (!own) continue;
-    size_t batch = 1;
-    for (size_t i = 0; i < h.size(); ++i) {
-      batch *= H.batch[i].at(h[i]);
-    }
-
-    auto retile = [&owners, &h = std::as_const(h), batch](auto &term) {
-      term.local_tiles.clear();
+    // iterates over tiles of hadamard indices
+    for (Index h : H.tiles) {
+      auto &[A, B] = AB;
+      auto own = A.own(h) || B.own(h);
+      auto comm = world.mpi.comm().Split(own, world.rank());
+      worlds.push_back(std::make_unique<World>(comm));
+      auto &owners = worlds.back();
+      if (!own) continue;
+      size_t batch = 1;
+      for (size_t i = 0; i < h.size(); ++i) {
+        batch *= H.batch[i].at(h[i]);
+      }
+
+      auto retile = [&owners, &h = std::as_const(h), batch](auto &term) {
+        term.local_tiles.clear();
+        const Permutation &P = term.permutation;
+
+        for (Index ei : term.tiles) {
+          auto idx = apply_inverse(P, h + ei);
+          if (!term.array.is_local(idx)) continue;
+          if (term.array.is_zero(idx)) continue;
+          // TODO no need for immediate evaluation
+          auto tile = term.array.find_local(idx).get();
+          if (P) tile = tile.permute(P);
+          auto shape = term.ei_tiled_range.tile(ei);
+          tile = tile.reshape(shape, batch);
+          term.local_tiles.push_back({ei, tile});
+        }
+        bool replicated = term.array.pmap()->is_replicated();
+        term.ei = TiledArray::make_array<decltype(term.ei)>(
+            *owners, term.ei_tiled_range, term.local_tiles.begin(),
+            term.local_tiles.end(), replicated);
+      };
+      std::invoke(retile, std::get<0>(AB));
+      std::invoke(retile, std::get<1>(AB));
+
+      C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners);
+      A.ei.defer_deleter_to_next_fence();
+      B.ei.defer_deleter_to_next_fence();
+      A.ei = ArrayA();
+      B.ei = ArrayB();
+      // why omitting this fence leads to deadlock?
-      const Permutation &P = term.permutation;
-
-      for (Index ei : term.tiles) {
-        auto idx = apply_inverse(P, h + ei);
-        if (!term.array.is_local(idx)) continue;
-        if (term.array.is_zero(idx)) continue;
-        // TODO no need for immediate evaluation
-        auto tile = term.array.find_local(idx).get();
-        if (P) tile = tile.permute(P);
-        auto shape = term.ei_tiled_range.tile(ei);
-        tile = tile.reshape(shape, batch);
-        term.local_tiles.push_back({ei, tile});
-      }
-      bool replicated = term.array.pmap()->is_replicated();
-      term.ei = TiledArray::make_array<decltype(term.ei)>(
-          *owners, term.ei_tiled_range, term.local_tiles.begin(),
-          term.local_tiles.end(), replicated);
-    };
-    std::invoke(retile, std::get<0>(AB));
-    std::invoke(retile, std::get<1>(AB));
-
-    C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners);
-    A.ei.defer_deleter_to_next_fence();
-    B.ei.defer_deleter_to_next_fence();
-    A.ei = ArrayA();
-    B.ei = ArrayB();
-    // why omitting this fence leads to deadlock?
-    owners->gop.fence();
-    for (Index e : C.tiles) {
-      if (!C.ei.is_local(e)) continue;
-      if (C.ei.is_zero(e)) continue;
-      // TODO no need for immediate evaluation
-      auto tile = C.ei.find_local(e).get();
-      assert(tile.nbatch() == batch);
-      const Permutation &P = C.permutation;
-      auto c = apply(P, h + e);
-      auto shape = C.array.trange().tile(c);
-      shape = apply_inverse(P, shape);
-      tile = tile.reshape(shape);
-      if (P) tile = tile.permute(P);
-      local_tiles.push_back({c, tile});
-    }
-    // mark for lazy deletion
-    C.ei = ArrayC();
-  }
+      owners->gop.fence();
+      for (Index e : C.tiles) {
+        if (!C.ei.is_local(e)) continue;
+        if (C.ei.is_zero(e)) continue;
+        // TODO no need for immediate evaluation
+        auto tile = C.ei.find_local(e).get();
+        assert(tile.nbatch() == batch);
+        const Permutation &P = C.permutation;
+        auto c = apply(P, h + e);
+        auto shape = C.array.trange().tile(c);
+        shape = apply_inverse(P, shape);
+        tile = tile.reshape(shape);
+        if (P) tile = tile.permute(P);
+        local_tiles.push_back({c, tile});
+      }
+      // mark for lazy deletion
+      C.ei = ArrayC();
+    }
 
-  if constexpr (!ResultShape::is_dense()) {
-    TiledRange tiled_range = TiledRange(range_map[c]);
-    std::vector<std::pair<Index, float>> tile_norms;
-    for (auto &[index, tile] : local_tiles) {
-      tile_norms.push_back({index, tile.norm()});
-    }
-    ResultShape shape(world, tile_norms, tiled_range);
-    C.array = ArrayC(world, TiledRange(range_map[c]), shape);
-  }
+    if constexpr (!ResultShape::is_dense()) {
+      TiledRange tiled_range = TiledRange(range_map[c]);
+      std::vector<std::pair<Index, float>> tile_norms;
+      for (auto &[index, tile] : local_tiles) {
+        tile_norms.push_back({index, tile.norm()});
+      }
+      ResultShape shape(world, tile_norms, tiled_range);
+      C.array = ArrayC(world, TiledRange(range_map[c]), shape);
+    }
 
-  for (auto &[index, tile] : local_tiles) {
-    if (C.array.is_zero(index)) continue;
-    C.array.set(index, tile);
-  }
+    for (auto &[index, tile] : local_tiles) {
+      if (C.array.is_zero(index)) continue;
+      C.array.set(index, tile);
+    }
 
-  for (auto &w : worlds) {
-    w->gop.fence();
-  }
+    for (auto &w : worlds) {
+      w->gop.fence();
+    }
 
-  return C.array;
+    return C.array;
+  }
 }
 
 /// Computes ternary tensor product whose result
@@ -463,13 +676,19 @@ auto einsum(expressions::TsrExpr<T, true> A, expressions::TsrExpr<U, true> B) {
 /// @param[in] r result indices
 /// @warning just as in the plain expression code, reductions are a special
 /// case; use Expr::reduce()
-template <typename T, typename U, typename... Indices>
+template <DeNest DeNestFlag = DeNest::False, typename T, typename U,
+          typename... Indices>
 auto einsum(expressions::TsrExpr<T, true> A, expressions::TsrExpr<U, true> B,
             const std::string &cs, World &world = get_default_world()) {
  using ECT = expressions::TsrExpr<const T, true>;
  using ECU = expressions::TsrExpr<const U, true>;
-  using ResultExprT = std::conditional_t<Einsum::IsArrayToT<T>, T, U>;
-  return Einsum::einsum(ECT(A), ECU(B), Einsum::idx<ResultExprT>(cs), world);
+
+  using ResultExprT =
+      std::conditional_t<DeNestFlag == DeNest::True, Einsum::DeNestedArray<T>,
+                         Einsum::MaxNestedArray<T, U>>;
+
+  return Einsum::einsum<DeNestFlag>(ECT(A), ECU(B),
+                                    Einsum::idx<ResultExprT>(cs), world);
 }
 
 template <typename T, typename U>
@@ -488,14 +707,44 @@ namespace TiledArray {
 
 using expressions::dot;
 using expressions::einsum;
 
-template <typename T, typename U, typename P>
-auto einsum(const std::string &expr, const DistArray<T, P> &A,
-            const DistArray<U, P> &B, World &world = get_default_world()) {
-  namespace string = ::Einsum::string;
-  auto [lhs, rhs] = string::split2(expr, "->");
-  auto [a, b] = string::split2(lhs, ",");
-  return einsum(A(string::join(a, ",")), B(string::join(b, ",")),
-                string::join(rhs, ","), world);
+template <DeNest DeNestFlag = DeNest::False, typename T1, typename T2,
+          typename P>
+auto einsum(const std::string &expr, const DistArray<T1, P> &A,
+            const DistArray<T2, P> &B, World &world = get_default_world()) {
+  using ::Einsum::string::join;
+  using ::Einsum::string::split2;
+
+  struct {
+    std::string A, B, C;
+  } annot;
+
+  {
+    struct {
+      std::string A, B, C;
+    } outer;
+
+    struct {
+      std::string A, B, C;
+    } inner;
+
+    auto [ab, aC] = split2(expr, "->");
+    std::tie(outer.C, inner.C) = split2(aC, ";");
+
+    auto [aA, aB] = split2(ab, ",");
+    std::tie(outer.A, inner.A) = split2(aA, ";");
+    std::tie(outer.B, inner.B) = split2(aB, ";");
+
+    auto combine = [](auto const &outer, auto const &inner) {
+      return inner.empty() ? join(outer, ",")
+                           : (join(outer, ",") + ";" + join(inner, ","));
+    };
+
+    annot.A = combine(outer.A, inner.A);
+    annot.B = combine(outer.B, inner.B);
+    annot.C = combine(outer.C, inner.C);
+  }
+
+  return einsum<DeNestFlag>(A(annot.A), B(annot.B), annot.C, world);
 }
 
 /// Computes ternary tensor product whose result
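The index classification driving the dispatch above (Hadamard, external, contracted) can be modeled standalone; the following sketch uses single-character indices and plain `std::string` in place of `Einsum::Index`:

#include <iostream>
#include <string>

// h = a & b & c (Hadamard), e = a ^ b (external), i = (a & b) - h
// (contracted); matches the example A(i,j,k) * B(k,j) -> C(i,j).
int main() {
  std::string a = "ijk", b = "kj", c = "ij";
  auto in = [](char x, const std::string& s) {
    return s.find(x) != std::string::npos;
  };
  std::string h, e, i;
  for (char x : a + b) {
    if (in(x, h) || in(x, e) || in(x, i)) continue;  // already classified
    if (in(x, a) && in(x, b) && in(x, c))
      h += x;
    else if (in(x, a) != in(x, b))
      e += x;
    else if (in(x, a) && in(x, b))
      i += x;
  }
  std::cout << "h=" << h << " e=" << e << " i=" << i << '\n';  // h=j e=i i=k
}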
diff --git a/src/TiledArray/expressions/add_engine.h b/src/TiledArray/expressions/add_engine.h
index 9421f6ffb2..f4a879365a 100644
--- a/src/TiledArray/expressions/add_engine.h
+++ b/src/TiledArray/expressions/add_engine.h
@@ -195,10 +195,11 @@ class AddEngine
 
  /// \param perm The permutation to be applied to tiles
  /// \return The tile operation
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
-  static op_type make_tile_op(const Perm& perm) {
-    return op_type(op_base_type(), perm);
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
+  static op_type make_tile_op(Perm&& perm) {
+    return op_type(op_base_type(), std::forward<Perm>(perm));
  }
 
  /// Expression identification tag
@@ -296,10 +297,11 @@ class ScalAddEngine
 
  /// \param perm The permutation to be applied to tiles
  /// \return The tile operation
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
-  op_type make_tile_op(const Perm& perm) const {
-    return op_type(op_base_type(factor_), perm);
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
+  op_type make_tile_op(Perm&& perm) const {
+    return op_type(op_base_type(factor_), std::forward<Perm>(perm));
  }
 
  /// Scaling factor accessor
diff --git a/src/TiledArray/expressions/binary_engine.h b/src/TiledArray/expressions/binary_engine.h
index 411a1c7c13..33318b57a6 100644
--- a/src/TiledArray/expressions/binary_engine.h
+++ b/src/TiledArray/expressions/binary_engine.h
@@ -75,9 +75,10 @@ class BinaryEngine : public ExprEngine<Derived> {
 protected:
  // Import base class variables to this scope
+  using ExprEngine_::implicit_permute_inner_;
+  using ExprEngine_::implicit_permute_outer_;
  using ExprEngine_::indices_;
  using ExprEngine_::perm_;
-  using ExprEngine_::permute_tiles_;
  using ExprEngine_::pmap_;
  using ExprEngine_::shape_;
  using ExprEngine_::trange_;
@@ -96,14 +97,14 @@ class BinaryEngine : public ExprEngine<Derived> {
  PermutationType right_inner_permtype_ =
      PermutationType::general;  ///< Right-hand permutation type
 
-  template <TensorProduct ProductType>
+  template <TensorProduct OuterProductType>
  void init_indices_(const BipartiteIndexList& target_indices = {}) {
-    static_assert(ProductType == TensorProduct::Contraction ||
-                  ProductType == TensorProduct::Hadamard);
+    static_assert(OuterProductType == TensorProduct::Contraction ||
+                  OuterProductType == TensorProduct::Hadamard);
    // prefer to permute the arg with fewest leaves to try to minimize the
    // number of possible permutations
    using permopt_type =
-        std::conditional_t<ProductType == TensorProduct::Contraction,
+        std::conditional_t<OuterProductType == TensorProduct::Contraction,
                           GEMMPermutationOptimizer,
                           HadamardPermutationOptimizer>;
@@ -150,43 +151,25 @@ class BinaryEngine : public ExprEngine<Derived> {
        !left_tile_is_tot && !right_tile_is_tot;
    constexpr bool args_are_mixed_tensors =
        left_tile_is_tot ^ right_tile_is_tot;
-    if (args_are_plain_tensors &&
-        (left_outer_permtype_ == PermutationType::matrix_transpose ||
-         left_outer_permtype_ == PermutationType::identity)) {
-      left_.permute_tiles(false);
+    // implicit_permute_{outer,inner}() denotes whether permutations will be
+    // fused into consuming operation
+    if (left_outer_permtype_ == PermutationType::matrix_transpose ||
+        left_outer_permtype_ == PermutationType::identity) {
+      left_.implicit_permute_outer(true);
    }
-    if (!args_are_plain_tensors &&
-        ((left_outer_permtype_ == PermutationType::matrix_transpose ||
-          left_outer_permtype_ == PermutationType::identity) ||
-         (left_inner_permtype_ == PermutationType::matrix_transpose ||
-          left_inner_permtype_ == PermutationType::identity))) {
-      left_.permute_tiles(false);
+    if (left_tile_is_tot &&
+        (left_inner_permtype_ == PermutationType::matrix_transpose ||
+         left_inner_permtype_ == PermutationType::identity)) {
+      left_.implicit_permute_inner(true);
    }
-    if (args_are_plain_tensors &&
-        (right_outer_permtype_ == PermutationType::matrix_transpose ||
-         right_outer_permtype_ == PermutationType::identity)) {
-      right_.permute_tiles(false);
+    if (right_outer_permtype_ == PermutationType::matrix_transpose ||
+        right_outer_permtype_ == PermutationType::identity) {
+      right_.implicit_permute_outer(true);
    }
-    if (!args_are_plain_tensors &&
-        ((left_outer_permtype_ == PermutationType::matrix_transpose ||
-          left_outer_permtype_ == PermutationType::identity) ||
-         (right_inner_permtype_ == PermutationType::matrix_transpose ||
-          right_inner_permtype_ == PermutationType::identity))) {
-      right_.permute_tiles(false);
-    }
-    if (args_are_mixed_tensors &&
-        ((left_outer_permtype_ == PermutationType::matrix_transpose ||
-          left_outer_permtype_ == PermutationType::identity) ||
-         (left_inner_permtype_ == PermutationType::matrix_transpose ||
-          left_inner_permtype_ == PermutationType::identity))) {
-      left_.permute_tiles(false);
-    }
-    if (args_are_mixed_tensors &&
-        ((left_outer_permtype_ == PermutationType::matrix_transpose ||
-          left_outer_permtype_ == PermutationType::identity) ||
-         (right_inner_permtype_ == PermutationType::matrix_transpose ||
-          right_inner_permtype_ == PermutationType::identity))) {
-      right_.permute_tiles(false);
+    if (right_tile_is_tot &&
+        (right_inner_permtype_ == PermutationType::matrix_transpose ||
+         right_inner_permtype_ == PermutationType::identity)) {
+      right_.implicit_permute_inner(true);
    }
  }
@@ -203,16 +186,18 @@ class BinaryEngine : public ExprEngine<Derived> {
  /// result of this expression will be permuted to match \c target_indices.
  /// \param target_indices The target index list for this expression
  void perm_indices(const BipartiteIndexList& target_indices) {
-    if (permute_tiles_) {
-      TA_ASSERT(left_.indices().size() == target_indices.size() ||
-                (left_.indices().second().size() ^ target_indices.second().size()));
-      TA_ASSERT(right_.indices().size() == target_indices.size() ||
-                (right_.indices().second().size() ^ target_indices.second().size()));
+    if (!this->implicit_permute()) {
+      TA_ASSERT(
+          left_.indices().size() == target_indices.size() ||
+          (left_.indices().second().size() ^ target_indices.second().size()));
+      TA_ASSERT(
+          right_.indices().size() == target_indices.size() ||
+          (right_.indices().second().size() ^ target_indices.second().size()));
 
      init_indices_(target_indices);
-      TA_ASSERT(right_outer_permtype_ == PermutationType::general ||
-                right_inner_permtype_ == PermutationType::general);
+      TA_ASSERT(left_outer_permtype_ == PermutationType::general &&
+                right_outer_permtype_ == PermutationType::general);
 
      if (left_.indices() != left_indices_) left_.perm_indices(left_indices_);
      if (right_.indices() != right_indices_)
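What "fusing a permutation into the consuming operation" means concretely: a matrix-transpose-type permutation need not be materialized, because GEMM can read the operand transposed in place. A toy illustration (plain loops, not TiledArray code):

#include <cstddef>
#include <iostream>
#include <vector>

// Computes C = A^T * B without materializing A^T: the transpose is folded
// into the kernel's index arithmetic, which is what fusing a
// matrix_transpose-type permutation into GEMM amounts to.
void gemm_at_b(const std::vector<double>& A,  // k x m, row-major
               const std::vector<double>& B,  // k x n, row-major
               std::vector<double>& C,        // m x n, row-major
               std::size_t m, std::size_t n, std::size_t k) {
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j) {
      double sum = 0;
      for (std::size_t l = 0; l < k; ++l)
        sum += A[l * m + i] * B[l * n + j];  // reads A transposed in place
      C[i * n + j] = sum;
    }
}

int main() {
  std::vector<double> A = {1, 2, 3, 4};  // 2x2
  std::vector<double> B = {5, 6, 7, 8};  // 2x2
  std::vector<double> C(4);
  gemm_at_b(A, B, C, 2, 2, 2);
  for (double x : C) std::cout << x << ' ';  // 26 30 38 44
}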
diff --git a/src/TiledArray/expressions/blk_tsr_engine.h b/src/TiledArray/expressions/blk_tsr_engine.h
index 31ad29ee74..e85aac7925 100644
--- a/src/TiledArray/expressions/blk_tsr_engine.h
+++ b/src/TiledArray/expressions/blk_tsr_engine.h
@@ -147,9 +147,10 @@ class BlkTsrEngineBase
 protected:
  // Import base class variables to this scope
+  using ExprEngine_::implicit_permute_inner_;
+  using ExprEngine_::implicit_permute_outer_;
  using ExprEngine_::indices_;
  using ExprEngine_::perm_;
-  using ExprEngine_::permute_tiles_;
  using ExprEngine_::pmap_;
  using ExprEngine_::shape_;
  using ExprEngine_::trange_;
@@ -341,9 +342,10 @@ class BlkTsrEngine
  // Import base class variables to this scope
  using BlkTsrEngineBase_::lower_bound_;
  using BlkTsrEngineBase_::upper_bound_;
+  using ExprEngine_::implicit_permute_inner_;
+  using ExprEngine_::implicit_permute_outer_;
  using ExprEngine_::indices_;
  using ExprEngine_::perm_;
-  using ExprEngine_::permute_tiles_;
  using ExprEngine_::pmap_;
  using ExprEngine_::shape_;
  using ExprEngine_::trange_;
@@ -403,9 +405,10 @@ class BlkTsrEngine
 
  /// \param perm The permutation to be applied to tiles
  /// \return The tile operation
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
-  op_type make_tile_op(const Perm& perm) const {
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
+  op_type make_tile_op(Perm&& perm) const {
    const unsigned int rank = trange_.tiles_range().rank();
 
    // Construct and allocate memory for the shift range
@@ -429,7 +432,7 @@ class BlkTsrEngine
      }
    }
 
-    return op_type(op_base_type(range_shift), perm);
+    return op_type(op_base_type(range_shift), std::forward<Perm>(perm));
  }
 
  /// Expression identification tag
@@ -494,9 +497,10 @@ class ScalBlkTsrEngine
  // Import base class variables to this scope
  using BlkTsrEngineBase_::lower_bound_;
  using BlkTsrEngineBase_::upper_bound_;
+  using ExprEngine_::implicit_permute_inner_;
+  using ExprEngine_::implicit_permute_outer_;
  using ExprEngine_::indices_;
  using ExprEngine_::perm_;
-  using ExprEngine_::permute_tiles_;
  using ExprEngine_::pmap_;
  using ExprEngine_::shape_;
  using ExprEngine_::trange_;
@@ -558,9 +562,10 @@ class ScalBlkTsrEngine
 
  /// \param perm The permutation to be applied to tiles
  /// \return The tile operation
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
-  op_type make_tile_op(const Perm& perm) const {
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
+  op_type make_tile_op(Perm&& perm) const {
    const unsigned int rank = trange_.tiles_range().rank();
 
    // Construct and allocate memory for the shift range
@@ -584,7 +589,8 @@ class ScalBlkTsrEngine
      }
    }
 
-    return op_type(op_base_type(range_shift, factor_), perm);
+    return op_type(op_base_type(range_shift, factor_),
+                   std::forward<Perm>(perm));
  }
 
  /// Expression identification tag
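A simplified model of the `range_shift` computed in the `make_tile_op` overloads above: the shift translates tile bounds from block coordinates to a result range that starts at the origin (the bounds here are hypothetical, and the real code works per-mode on the tiled range):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<long> block_lobound = {3, 5};  // hypothetical block lower bound
  std::vector<long> range_shift(2);
  for (std::size_t d = 0; d < 2; ++d) range_shift[d] = -block_lobound[d];
  // a tile whose range started at {4, 7} inside the block now starts at {1, 2}
  std::cout << 4 + range_shift[0] << ',' << 7 + range_shift[1] << '\n';
}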
diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h
index 94562b5154..d40e9c88fc 100644
--- a/src/TiledArray/expressions/cont_engine.h
+++ b/src/TiledArray/expressions/cont_engine.h
@@ -94,9 +94,10 @@ class ContEngine : public BinaryEngine<Derived> {
  using BinaryEngine_::right_indices_;
  using BinaryEngine_::right_inner_permtype_;
  using BinaryEngine_::right_outer_permtype_;
+  using ExprEngine_::implicit_permute_inner_;
+  using ExprEngine_::implicit_permute_outer_;
  using ExprEngine_::indices_;
  using ExprEngine_::perm_;
-  using ExprEngine_::permute_tiles_;
  using ExprEngine_::pmap_;
  using ExprEngine_::shape_;
  using ExprEngine_::trange_;
@@ -125,7 +126,7 @@ class ContEngine : public BinaryEngine<Derived> {
               ///< nested tensor expressions)
  std::function<result_tile_element_type(const left_tile_element_type&,
                                         const right_tile_element_type&)>
-      element_return_op_;  ///< Same as inner_tile_nonreturn_op_ but returns
+      element_return_op_;  ///< Same as element_nonreturn_op_ but returns
                           ///< the result
  TiledArray::detail::ProcGrid
      proc_grid_;  ///< Process grid for the contraction
@@ -202,7 +203,7 @@ class ContEngine : public BinaryEngine<Derived> {
  void perm_indices(const BipartiteIndexList& target_indices) {
    // assert that init_indices has been called
    TA_ASSERT(left_.indices() && right_.indices());
-    if (permute_tiles_) {
+    if (!this->implicit_permute()) {
      this->template init_indices_<TensorProduct::Contraction>(target_indices);
 
      // propagate the indices down the tree, if needed
@@ -262,31 +263,31 @@ class ContEngine : public BinaryEngine<Derived> {
    // Initialize the tile operation in this function because it is used to
    // evaluate the tiled range and shape.
 
-    const math::blas::Op left_op =
-        (left_outer_permtype_ == PermutationType::matrix_transpose
-             ? math::blas::Transpose
-             : math::blas::NoTranspose);
-    const math::blas::Op right_op =
-        (right_outer_permtype_ == PermutationType::matrix_transpose
-             ? math::blas::Transpose
-             : math::blas::NoTranspose);
+    const auto left_op = to_cblas_op(left_outer_permtype_);
+    const auto right_op = to_cblas_op(right_outer_permtype_);
 
+    // initialize perm_
+    this->init_perm(target_indices);
+
+    // initialize op_, trange_, and shape_ which only refer to the outer modes
    if (outer(target_indices) != outer(indices_)) {
+      const auto outer_perm = outer(perm_);
      // Initialize permuted structure
-      perm_ = ExprEngine_::make_perm(target_indices);
      if constexpr (!TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
-        op_ = op_type(left_op, right_op, factor_, outer_size(indices_),
-                      outer_size(left_indices_), outer_size(right_indices_),
-                      (permute_tiles_ ? perm_ : BipartitePermutation{}));
+        op_ = op_type(
+            left_op, right_op, factor_, outer_size(indices_),
+            outer_size(left_indices_), outer_size(right_indices_),
+            (!implicit_permute_outer_ ? std::move(outer_perm) : Permutation{}));
      } else {
        // factor_ is absorbed into inner_tile_nonreturn_op_
-        op_ = op_type(left_op, right_op, scalar_type(1), outer_size(indices_),
-                      outer_size(left_indices_), outer_size(right_indices_),
-                      (permute_tiles_ ? perm_ : BipartitePermutation{}),
-                      this->element_nonreturn_op_);
+        op_ = op_type(
+            left_op, right_op, scalar_type(1), outer_size(indices_),
+            outer_size(left_indices_), outer_size(right_indices_),
+            (!implicit_permute_outer_ ? std::move(outer_perm) : Permutation{}),
+            this->element_nonreturn_op_);
      }
-      trange_ = ContEngine_::make_trange(outer(perm_));
-      shape_ = ContEngine_::make_shape(outer(perm_));
+      trange_ = ContEngine_::make_trange(outer_perm);
+      shape_ = ContEngine_::make_shape(outer_perm);
    } else {
      // Initialize non-permuted structure
      if constexpr (!TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
@@ -500,8 +501,8 @@ class ContEngine : public BinaryEngine<Derived> {
                        this->factor_, inner_size(this->indices_),
                        inner_size(this->left_indices_),
                        inner_size(this->right_indices_),
-                        (this->permute_tiles_ ? inner(this->perm_)
-                                              : Permutation{}))
+                        (!this->implicit_permute_inner_ ? inner(this->perm_)
+                                                        : Permutation{}))
             : op_type(to_cblas_op(this->left_inner_permtype_),
                       to_cblas_op(this->right_inner_permtype_), this->factor_,
                       inner_size(this->indices_),
@@ -531,7 +532,7 @@ class ContEngine : public BinaryEngine<Derived> {
      // multiple times, e.g. when outer op is gemm
      auto mult_op =
          (inner_target_indices != inner(this->indices_))
-              ? op_type(base_op_type(), this->permute_tiles_
+              ? op_type(base_op_type(), !this->implicit_permute_inner_
                                            ? inner(this->perm_)
                                            : Permutation{})
              : op_type(base_op_type());
@@ -562,12 +563,12 @@ class ContEngine : public BinaryEngine<Derived> {
      using op_type = TiledArray::detail::BinaryWrapper<
          base_op_type>;  // can't consume inputs if they are used
                          // multiple times, e.g. when outer op is gemm
-      auto mult_op =
-          (inner_target_indices != inner(this->indices_))
-              ? op_type(base_op_type(this->factor_),
-                        this->permute_tiles_ ? inner(this->perm_)
-                                             : Permutation{})
-              : op_type(base_op_type(this->factor_));
+      auto mult_op = (inner_target_indices != inner(this->indices_))
+                         ? op_type(base_op_type(this->factor_),
+                                   !this->implicit_permute_inner_
+                                       ? inner(this->perm_)
+                                       : Permutation{})
+                         : op_type(base_op_type(this->factor_));
      this->element_nonreturn_op_ =
          [mult_op, outer_prod](result_tile_element_type& result,
                                const left_tile_element_type& left,
                                const right_tile_element_type& right) {
@@ -608,8 +609,9 @@ class ContEngine : public BinaryEngine<Derived> {
          std::conditional_t;
 
-      auto scal_op = [perm = this->permute_tiles_ ? inner(this->perm_)
-                                                  : Permutation{}](
+      auto scal_op = [perm = !this->implicit_permute_inner_
+                                 ? inner(this->perm_)
+                                 : Permutation{}](
                         const left_tile_element_type& left,
                         const right_tile_element_type& right)
          -> result_tile_element_type {
@@ -628,10 +630,21 @@ class ContEngine : public BinaryEngine<Derived> {
        abort();  // unreachable
      };
      this->element_nonreturn_op_ =
-          [scal_op](result_tile_element_type& result,
-                    const left_tile_element_type& left,
-                    const right_tile_element_type& right) {
-            result = scal_op(left, right);
+          [scal_op, outer_prod = (this->product_type())](
+              result_tile_element_type& result,
+              const left_tile_element_type& left,
+              const right_tile_element_type& right) {
+            if (outer_prod == TensorProduct::Contraction) {
+              if (empty(result))
+                result = scal_op(left, right);
+              else {
+                auto result_increment = scal_op(left, right);
+                add_to(result, result_increment);
+              }
+              // result += scal_op(left, right);
+            } else {
+              result = scal_op(left, right);
+            }
          };
    }
  } else
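Why the new `element_nonreturn_op_` above accumulates when the outer product is a contraction: the element op fires once per contracted outer index, and every firing contributes to the same result element, so assignment would keep only the last term. A scalar-level sketch:

#include <iostream>
#include <vector>

int main() {
  std::vector<double> a = {1, 2, 3}, b = {4, 5, 6};  // "inner results" per k
  double result = 0;                                  // empty result
  for (std::size_t k = 0; k < a.size(); ++k) {
    double increment = a[k] * b[k];  // scal_op(left, right) analogue
    result += increment;             // add_to(result, result_increment)
  }
  std::cout << result << '\n';  // 32, not just the last term 18
}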
diff --git a/src/TiledArray/expressions/expr_engine.h b/src/TiledArray/expressions/expr_engine.h
index c364a5c1ba..a502857af9 100644
--- a/src/TiledArray/expressions/expr_engine.h
+++ b/src/TiledArray/expressions/expr_engine.h
@@ -54,6 +54,8 @@ class ExprEngine : private NO_DEFAULTS {
  typedef typename EngineTrait<Derived>::op_type
      op_type;  ///< Tile operation type
  typedef typename EngineTrait<Derived>::policy policy;  ///< The result policy type
+  typedef typename EngineTrait<Derived>::eval_type
+      eval_type;  ///< Evaluation tile type
  typedef typename EngineTrait<Derived>::dist_eval_type
      dist_eval_type;  ///< This expression's distributed evaluator type
@@ -74,9 +76,12 @@ class ExprEngine : private NO_DEFAULTS {
  BipartiteIndexList
      indices_;  ///< The index list of this expression; bipartite due to need
                 ///< to support nested tensors (e.g. tensors of tensors)
-  bool permute_tiles_;  ///< Result tile permutation flag (\c true == permute
-                        ///< tile)
-  /// The permutation that will be applied to the outer tensor of tensors
+  bool implicit_permute_outer_ = false;  ///< If false, result tiles' outer
+                                         ///< modes will not need to be permuted
+  bool implicit_permute_inner_ = false;  ///< If false, result tiles' inner
+                                         ///< modes will not need to be permuted
+  /// The permutation that will be applied to the result tensor (or tensor of
+  /// tensors)
  BipartitePermutation perm_;
  trange_type trange_;  ///< The tiled range of the result tensor
  shape_type shape_;    ///< The shape of the result tensor
@@ -93,7 +98,6 @@ class ExprEngine : private NO_DEFAULTS {
  ExprEngine(const Expr& expr)
      : world_(NULL),
        indices_(),
-        permute_tiles_(true),
        perm_(),
        trange_(),
        shape_(),
@@ -141,7 +145,7 @@ class ExprEngine : private NO_DEFAULTS {
  /// This function will initialize the permutation, tiled range, and shape
  /// for the result tensor. These members are initialized with the
-  /// make_perm(), \c make_trange(), and make_shape() functions.
+  /// \c init_perm(), \c make_trange(), and make_shape() functions.
  /// Derived classes may customize the structure initialization by
  /// providing their own implementation of this function or any of the
  /// above initialization.
@@ -149,7 +153,7 @@ class ExprEngine : private NO_DEFAULTS {
  /// \param target_indices The target index list for the result tensor
  void init_struct(const BipartiteIndexList& target_indices) {
    if (target_indices != indices_) {
-      perm_ = derived().make_perm(target_indices);
+      if (!perm_) perm_ = make_perm(target_indices);
      trange_ = derived().make_trange(outer(perm_));
      shape_ = derived().make_shape(outer(perm_));
    } else {
@@ -187,20 +191,41 @@ class ExprEngine : private NO_DEFAULTS {
  /// providing their own implementation of it.
  BipartitePermutation make_perm(
      const BipartiteIndexList& target_indices) const {
+    TA_ASSERT(target_indices != indices_);
    return target_indices.permutation(indices_);
  }
 
+  void init_perm(const BipartiteIndexList& target_indices) {
+    if (!perm_ && target_indices != indices_) perm_ = make_perm(target_indices);
+  }
+
  /// Tile operation factory function
 
  /// This function will generate the tile operations by calling
  /// \c make_tile_op(). The permuting or non-permuting version of the tile
-  /// operation will be selected based on permute_tiles(). Derived classes
-  /// may customize this function by providing their own implementation it.
+  /// operation will be selected based on implicit_permute_outer(). Derived
+  /// classes may customize this function by providing their own implementation
+  /// of it.
  op_type make_op() const {
-    if (perm_ && permute_tiles_)
-      // permutation can only be applied to the tile, not to its element (if
-      // tile = tensor-of-tensors)
-      return derived().make_tile_op(perm_);
+    // figure out which permutations (of outer or inner modes) must be enacted
+    // explicitly
+    BipartitePermutation explicit_perm;
+    if (implicit_permute_outer_) {
+      if (!implicit_permute_inner_) {
+        explicit_perm = BipartitePermutation(Permutation{}, inner(perm_));
+      }
+    } else {
+      if (implicit_permute_inner_) {
+        explicit_perm = BipartitePermutation(outer(perm_), Permutation{});
+      } else {
+        explicit_perm = perm_;
+      }
+    }
+    const bool explicit_perm_is_nontrivial =
+        !(explicit_perm.first().is_identity() &&
+          explicit_perm.second().is_identity());
+    if (explicit_perm && explicit_perm_is_nontrivial)
+      return derived().make_tile_op(explicit_perm);
    else
      return derived().make_tile_op();
  }
@@ -243,11 +268,47 @@ class ExprEngine : private NO_DEFAULTS {
  /// \return A const reference to the process map
  const std::shared_ptr<const pmap_interface>& pmap() const { return pmap_; }
 
-  /// Set the permute tiles flag
+  /// Set the flag that controls whether tiles' permutation will be implicit
+
+  /// some consuming operations (like GEMM) can perform some
+  /// permutation types implicitly. Setting this to true indicates that the
+  /// result tiles' outer modes do not need to be permuted and permutation will
+  /// be performed implicitly by the consuming operation
+  /// \param status The new value for the implicit permute flag
+  /// (true => will not permute outer modes of result tiles;
+  /// false => will permute outer modes of result tiles if needed)
+  /// \note for plain tensors, i.e., tensor-of-scalars, any mode is
+  /// outer
+  void implicit_permute_outer(const bool status) {
+    implicit_permute_outer_ = status;
+  }
+
+  /// Set the flag that controls whether tiles' permutation will be implicit
+
+  /// some consuming operations (like GEMM) can perform some
+  /// permutation types implicitly. Setting this to true indicates that the
setting this to true indicates that the + /// result tiles' inner modes do not need to be permuted and permutation will + /// be performed implicitly by the consuming operation + /// \param status The new value for the implicit permute flag + /// (true => will not permute inner modes of result tiles; + /// false => will permute inner modes of result tiles if needed) + /// \note for plain tensors, i.e., tensor-of-scalars, there are no + /// inner modes and this should not be used + void implicit_permute_inner(const bool status) { + TA_ASSERT(TiledArray::detail::is_tensor_of_tensor_v); + implicit_permute_inner_ = status; + } - /// \param status The new status for permute tiles (true == permute result - /// tiles) - void permute_tiles(const bool status) { permute_tiles_ = status; } + /// Reports whether permutation of the result tiles will be implicit, i.e. + /// will be fused into the consuming operation + + /// \return true if will not permute of result tiles; false will indicate that + /// the result tiles will be permuted if needed + bool implicit_permute() const { + constexpr bool is_tot = + TiledArray::detail::is_tensor_of_tensor_v; + return (implicit_permute_outer_ || (is_tot && implicit_permute_inner_)); + } /// Expression print @@ -255,9 +316,23 @@ class ExprEngine : private NO_DEFAULTS { /// \param target_indices The target index list for this expression void print(ExprOStream& os, const BipartiteIndexList& target_indices) const { if (perm_) { - os << "[P " << target_indices << "]" - << (permute_tiles_ ? " " : " [no permute tiles] ") - << derived().make_tag() << indices_ << "\n"; + os << "[P " << target_indices << "]"; + if (implicit_permute_outer_ || implicit_permute_inner_) { + os << " [implicit "; + constexpr bool is_tot = + TiledArray::detail::is_tensor_of_tensor_v; + if constexpr (is_tot) { + if (implicit_permute_outer_ && implicit_permute_inner_) { + os << "outer&inner "; + } else if (implicit_permute_outer_) { + os << "outer "; + } else + os << "inner "; + } + os << "permute ] "; + } else + os << " "; + os << derived().make_tag() << indices_ << "\n"; } else { os << derived().make_tag() << indices_ << "\n"; } diff --git a/src/TiledArray/expressions/leaf_engine.h b/src/TiledArray/expressions/leaf_engine.h index 5e273fb5dc..8804989d6f 100644 --- a/src/TiledArray/expressions/leaf_engine.h +++ b/src/TiledArray/expressions/leaf_engine.h @@ -70,9 +70,10 @@ class LeafEngine : public ExprEngine { protected: // Import base class variables to this scope + using ExprEngine_::implicit_permute_inner_; + using ExprEngine_::implicit_permute_outer_; using ExprEngine_::indices_; using ExprEngine_::perm_; - using ExprEngine_::permute_tiles_; using ExprEngine_::pmap_; using ExprEngine_::shape_; using ExprEngine_::trange_; diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h index 20093b2cec..84d11bd4c0 100644 --- a/src/TiledArray/expressions/mult_engine.h +++ b/src/TiledArray/expressions/mult_engine.h @@ -299,7 +299,6 @@ class MultEngine : public ContEngine> { // the tile op; the type of the tile op does not need to match the type of // the operation on the outer indices if (this->product_type() == TensorProduct::Hadamard) { - // assumes inner op is also Hadamard BinaryEngine_::perm_indices(target_indices); } else { auto children_initialized = true; @@ -335,6 +334,9 @@ class MultEngine : public ContEngine> { /// for the result tensor. 
diff --git a/src/TiledArray/expressions/mult_engine.h b/src/TiledArray/expressions/mult_engine.h
index 20093b2cec..84d11bd4c0 100644
--- a/src/TiledArray/expressions/mult_engine.h
+++ b/src/TiledArray/expressions/mult_engine.h
@@ -299,7 +299,6 @@ class MultEngine
    // the tile op; the type of the tile op does not need to match the type of
    // the operation on the outer indices
    if (this->product_type() == TensorProduct::Hadamard) {
-      // assumes inner op is also Hadamard
      BinaryEngine_::perm_indices(target_indices);
    } else {
      auto children_initialized = true;
@@ -335,6 +334,9 @@ class MultEngine
  /// for the result tensor.
  /// \param target_indices The target index list for the result tensor
  void init_struct(const BipartiteIndexList& target_indices) {
+    this->init_perm(target_indices);
+
+    // for ContEngine_::init_struct need to initialize element op first
    this->init_inner_tile_op(inner(target_indices));
    if (this->product_type() == TensorProduct::Contraction)
      ContEngine_::init_struct(target_indices);
@@ -421,9 +423,10 @@ class MultEngine
 
  /// \param perm The permutation to be applied to the result
  /// \return The tile operation
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
-  op_type make_tile_op(const Perm& perm) const {
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
+  op_type make_tile_op(Perm&& perm) const {
    if constexpr (TiledArray::detail::is_tensor_of_tensor_v<
                      value_type>) {  // nested tensors
      const auto inner_prod = this->inner_product_type();
      if (inner_prod == TensorProduct::Hadamard) {
        TA_ASSERT(this->product_type() == inner_prod);
        // Hadamard automatically works for inner
        // dimensions as well
-        return op_type(op_base_type(), perm);
+        return op_type(op_base_type(), std::forward<Perm>(perm));
      } else if (inner_prod == TensorProduct::Contraction) {
-        return op_type(op_base_type(this->element_return_op_), perm);
+        // inner permutation, if needed, was fused into inner op, do not apply
+        // inner part of the perm again
+        return op_type(op_base_type(this->element_return_op_),
+                       outer(std::forward<Perm>(perm)));
      } else if (inner_prod == TensorProduct::Scale) {
-        return op_type(op_base_type(this->element_return_op_), perm);
+        // inner permutation, if needed, was fused into inner op, do not apply
+        // inner part of the perm again
+        return op_type(op_base_type(this->element_return_op_),
+                       outer(std::forward<Perm>(perm)));
      } else
        abort();
    } else {  // plain tensor
-      return op_type(op_base_type(), perm);
+      return op_type(op_base_type(), std::forward<Perm>(perm));
    }
    abort();  // unreachable
  }
@@ -593,6 +602,9 @@ class ScalMultEngine
  /// for the result tensor.
  /// \param target_indices The target index list for the result tensor
  void init_struct(const BipartiteIndexList& target_indices) {
+    this->init_perm(target_indices);
+
+    // for ContEngine_::init_struct need to initialize element op first
    this->init_inner_tile_op(inner(target_indices));
    if (this->product_type() == TensorProduct::Contraction)
      ContEngine_::init_struct(target_indices);
@@ -673,10 +685,12 @@ class ScalMultEngine
 
  /// \param perm The permutation to be applied to tiles
  /// \return The tile operation
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
-  op_type make_tile_op(const Perm& perm) const {
-    return op_type(op_base_type(ContEngine_::factor_), perm);
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
+  op_type make_tile_op(Perm&& perm) const {
+    return op_type(op_base_type(ContEngine_::factor_),
+                   std::forward<Perm>(perm));
  }
 
  /// Expression identification tag
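For reference, the `PermutationType` to BLAS-op mapping that `to_cblas_op` (changed in the next file) implements; `general` never reaches the kernel, because such arguments are pre-permuted into a GEMM-compatible layout first. A self-contained mirror of that logic:

#include <cassert>

enum class PermutationType { identity = 1, matrix_transpose = 2, general = 3 };
enum class Op { NoTranspose, Transpose };

Op to_op(PermutationType p) {
  return p == PermutationType::matrix_transpose ? Op::Transpose
                                                : Op::NoTranspose;
}

int main() {
  assert(to_op(PermutationType::identity) == Op::NoTranspose);
  assert(to_op(PermutationType::matrix_transpose) == Op::Transpose);
  assert(to_op(PermutationType::general) == Op::NoTranspose);  // pre-permuted
}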
diff --git a/src/TiledArray/expressions/permopt.h b/src/TiledArray/expressions/permopt.h
index 998ea78efe..291604faa8 100644
--- a/src/TiledArray/expressions/permopt.h
+++ b/src/TiledArray/expressions/permopt.h
@@ -45,8 +45,11 @@ namespace expressions {
 enum class PermutationType { identity = 1, matrix_transpose = 2, general = 3 };
 
 inline blas::Op to_cblas_op(PermutationType permtype) {
-  TA_ASSERT(permtype == PermutationType::matrix_transpose ||
-            permtype == PermutationType::identity);
+  // N.B. 3 cases:
+  // - permtype == identity : no transpose needed
+  // - permtype == matrix_transpose : transpose needed
+  // - permtype == general : the argument will be explicitly permuted to be in
+  //   a layout which does not require permutation, hence no need for a switch
  return permtype == PermutationType::matrix_transpose
             ? math::blas::Transpose
             : math::blas::NoTranspose;
diff --git a/src/TiledArray/expressions/product.h b/src/TiledArray/expressions/product.h
index 7111b7831b..df2867a360 100644
--- a/src/TiledArray/expressions/product.h
+++ b/src/TiledArray/expressions/product.h
@@ -73,8 +73,10 @@ inline TensorProduct compute_product_type(const IndexList& left_indices,
                                          const IndexList& right_indices,
                                          const IndexList& target_indices) {
  auto result = compute_product_type(left_indices, right_indices);
-  if (result == TensorProduct::Hadamard)
+  if (result == TensorProduct::Hadamard) {
    TA_ASSERT(left_indices.is_permutation(target_indices));
+    TA_ASSERT(right_indices.is_permutation(target_indices));
+  }
  return result;
 }
diff --git a/src/TiledArray/expressions/scal_engine.h b/src/TiledArray/expressions/scal_engine.h
index a2312fccb7..2c0d33bf33 100644
--- a/src/TiledArray/expressions/scal_engine.h
+++ b/src/TiledArray/expressions/scal_engine.h
@@ -146,10 +146,11 @@ class ScalEngine
 
  /// \param perm The permutation to be applied to tiles
  /// \return The tile operation
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
-  op_type make_tile_op(const Perm& perm) const {
-    return op_type(perm, factor_);
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
+  op_type make_tile_op(Perm&& perm) const {
+    return op_type(std::forward<Perm>(perm), factor_);
  }
 
  /// Expression identification tag
diff --git a/src/TiledArray/expressions/scal_tsr_engine.h b/src/TiledArray/expressions/scal_tsr_engine.h
index 8dfcc596d9..8b38362740 100644
--- a/src/TiledArray/expressions/scal_tsr_engine.h
+++ b/src/TiledArray/expressions/scal_tsr_engine.h
@@ -140,10 +140,11 @@ class ScalTsrEngine
 
  /// \param perm The permutation to be applied to tiles
  /// \return The tile operation
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
-  op_type make_tile_op(const Perm& perm) const {
-    return op_type(op_base_type(factor_), perm);
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
+  op_type make_tile_op(Perm&& perm) const {
+    return op_type(op_base_type(factor_), std::forward<Perm>(perm));
  }
 
  /// Expression identification tag
diff --git a/src/TiledArray/expressions/subt_engine.h b/src/TiledArray/expressions/subt_engine.h
index ab93dde1ea..3750a199c5 100644
--- a/src/TiledArray/expressions/subt_engine.h
+++ b/src/TiledArray/expressions/subt_engine.h
@@ -195,10 +195,11 @@ class SubtEngine
 
  /// \param perm The permutation to be applied to tiles
  /// \return The tile operation
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
-  static op_type make_tile_op(const Perm& perm) {
-    return op_type(op_base_type(), perm);
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
+  static op_type make_tile_op(Perm&& perm) {
+    return op_type(op_base_type(), std::forward<Perm>(perm));
  }
 
  /// Expression identification tag
@@ -296,10 +297,11 @@ class ScalSubtEngine
 
  /// \param perm The permutation to be applied to tiles
  /// \return The tile operation
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
-  op_type make_tile_op(const Perm& perm) const {
-    return op_type(op_base_type(factor_), perm);
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
+  op_type make_tile_op(Perm&& perm) const {
+    return op_type(op_base_type(factor_), std::forward<Perm>(perm));
  }
 
  /// Expression identification tag
diff --git a/src/TiledArray/expressions/tsr_engine.h b/src/TiledArray/expressions/tsr_engine.h
index 5219af37ca..20b893ead3 100644
--- a/src/TiledArray/expressions/tsr_engine.h
+++ b/src/TiledArray/expressions/tsr_engine.h
@@ -126,10 +126,11 @@ class TsrEngine
 
  /// \param perm The permutation to be applied to tiles
  /// \return The tile operation
-  template <typename Perm, typename = std::enable_if_t<
-                               TiledArray::detail::is_permutation_v<Perm>>>
-  static op_type make_tile_op(const Perm& perm) {
-    return op_type(op_base_type(), perm);
+  template <typename Perm,
+            typename = std::enable_if_t<TiledArray::detail::is_permutation_v<
+                std::remove_reference_t<Perm>>>>
static op_type make_tile_op(Perm&& perm) { + return op_type(op_base_type(), std::forward(perm)); } }; // class TsrEngine diff --git a/src/TiledArray/expressions/unary_engine.h b/src/TiledArray/expressions/unary_engine.h index 621c4a71b3..631fca8fed 100644 --- a/src/TiledArray/expressions/unary_engine.h +++ b/src/TiledArray/expressions/unary_engine.h @@ -70,9 +70,10 @@ class UnaryEngine : ExprEngine { protected: // Import base class variables to this scope + using ExprEngine_::implicit_permute_inner_; + using ExprEngine_::implicit_permute_outer_; using ExprEngine_::indices_; using ExprEngine_::perm_; - using ExprEngine_::permute_tiles_; using ExprEngine_::pmap_; using ExprEngine_::shape_; using ExprEngine_::trange_; @@ -99,7 +100,7 @@ class UnaryEngine : ExprEngine { /// children such that the number of permutations is minimized. /// \param target_indices The target index list for this expression void perm_indices(const BipartiteIndexList& target_indices) { - TA_ASSERT(permute_tiles_); + TA_ASSERT(!this->implicit_permute()); indices_ = target_indices; if (arg_.indices() != target_indices) arg_.perm_indices(target_indices); diff --git a/src/TiledArray/permutation.h b/src/TiledArray/permutation.h index cd527dfeef..d70b283034 100644 --- a/src/TiledArray/permutation.h +++ b/src/TiledArray/permutation.h @@ -271,7 +271,10 @@ class Permutation { /// \param i The element index /// \return The i-th element - index_type operator[](unsigned int i) const { return p_[i]; } + index_type operator[](unsigned int i) const { + TA_ASSERT(i < p_.size()); + return p_[i]; + } /// Cycles decomposition @@ -329,6 +332,13 @@ class Permutation { return result; } + /// + /// Checks if this permutation is the identity permutation. + /// + [[nodiscard]] bool is_identity() const { + return std::is_sorted(p_.begin(), p_.end()); + } + /// Identity permutation factory function /// \return An identity permutation @@ -402,11 +412,13 @@ class Permutation { /// Bool conversion /// \return \c true if the permutation is not empty, otherwise \c false. + /// \note equivalent to `this->size() != 0` explicit operator bool() const { return !p_.empty(); } /// Not operator /// \return \c true if the permutation is empty, otherwise \c false. 
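+  /// \note An empty permutation is distinct from an identity permutation of
+  /// nonzero rank; a sketch of the expected semantics (editorial example, not
+  /// introduced by this change):
+  /// \code
+  /// Permutation p0;           // empty, rank 0
+  /// Permutation p3{0, 1, 2};  // identity permutation of rank 3
+  /// assert(!p0);                                        // p0 is empty
+  /// assert(static_cast<bool>(p3) && p3.is_identity());  // p3 is not
+  /// \endcode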
+ /// \note equivalent to `this->size() == 0` bool operator!() const { return p_.empty(); } /// Permutation data accessor /// \param[in,out] ar The serialization archive template void serialize(Archive& ar) { - ar& p_; + ar & p_; } }; // class Permutation @@ -721,6 +733,11 @@ class BipartitePermutation { init(); } + BipartitePermutation(Permutation&& p, index_type second_partition_size = 0) + : base_(std::move(p)), second_size_(second_partition_size) { + init(); + } + BipartitePermutation(const Permutation& first, const Permutation& second) : second_size_(second.size()) { vector base; @@ -778,9 +795,14 @@ class BipartitePermutation { } /// \return reference to the first partition - const Permutation& first() const { return first_; } + const Permutation& first() const& { return first_; } /// \return reference to the second partition - const Permutation& second() const { return second_; } + const Permutation& second() const& { return second_; } + + /// \return rvalue-reference to the first partition + Permutation&& first() && { return std::move(first_); } + /// \return rvalue-reference to the second partition + Permutation&& second() && { return std::move(second_); } /// \return the size of the first partition index_type first_size() const { return this->size() - second_size_; } @@ -795,7 +817,7 @@ class BipartitePermutation { /// \param[in,out] ar The serialization archive template void serialize(Archive& ar) { - ar& base_& second_size_; + ar & base_ & second_size_; if constexpr (madness::is_input_archive_v) { first_ = {}; second_ = {}; @@ -858,6 +880,8 @@ inline auto inner(const Permutation& p) { // temporary inline auto outer(const Permutation& p) { return p; } +inline Permutation&& outer(Permutation&& p) { return std::move(p); } + inline auto inner_size(const Permutation& p) { abort(); return 0; } @@ -867,8 +891,16 @@ inline auto outer_size(const Permutation& p) { return p.size(); } inline auto inner(const BipartitePermutation& p) { return p.second(); } +inline Permutation&& inner(BipartitePermutation&& p) { + return std::move(p).second(); +} + inline auto outer(const BipartitePermutation& p) { return p.first(); } +inline Permutation&& outer(BipartitePermutation&& p) { + return std::move(p).first(); +} + inline auto inner_size(const BipartitePermutation& p) { return p.second_size(); } diff --git a/src/TiledArray/range.h b/src/TiledArray/range.h index b4f2a0d48f..25e4852118 100644 --- a/src/TiledArray/range.h +++ b/src/TiledArray/range.h @@ -613,10 +613,10 @@ class Range { /// Permuting copy constructor - /// \param perm The permutation applied to other - /// \param other The range to be permuted and copied + /// \param perm The permutation applied to other; if `!perm` then no + /// permutation is applied + /// \param other The range to be permuted and copied Range(const Permutation& perm, const Range_& other) { - TA_ASSERT(perm.size() == other.rank_); + TA_ASSERT(perm.size() == other.rank_ || !perm); if (other.rank_ > 0ul) { rank_ = other.rank_; @@ -1139,7 +1139,7 @@ class Range { template void serialize(Archive& ar) { - ar& rank_; + ar & rank_; const auto four_x_rank = rank_ << 2; // read via madness::archive::wrap to be able to // - avoid having to serialize datavec_'s size @@ -1151,7 +1151,7 @@ ar << madness::archive::wrap(datavec_.data(), four_x_rank); } else abort(); // unreachable - ar& offset_& volume_; + ar & offset_ & volume_; } void swap(Range_& other) { diff --git a/src/TiledArray/tensor/complex.h b/src/TiledArray/tensor/complex.h index
69a8971bf6..676327427f 100644 --- a/src/TiledArray/tensor/complex.h +++ b/src/TiledArray/tensor/complex.h @@ -81,30 +81,30 @@ TILEDARRAY_FORCE_INLINE auto inner_product(const L l, const R r) { return TiledArray::detail::conj(l) * r; } -/// Wrapper function for `std::norm` +/// Squared norm of a real number /// This function disables the call to `std::conj` for real values to /// prevent the result from being converted into a complex value. /// \tparam R A real scalar type /// \param r The real scalar -/// \return `r` +/// \return squared norm of `r`, i.e. `r*r` template && !is_complex::value>::type* = nullptr> -TILEDARRAY_FORCE_INLINE R norm(const R r) { +TILEDARRAY_FORCE_INLINE R squared_norm(const R r) { return r * r; } -/// Compute the norm of a complex number `z` +/// Compute the squared norm of a complex number `z` /// \f[ -/// {\rm norm}(z) = zz^* = {\rm Re}(z)^2 + {\rm Im}(z)^2 +/// {\rm squared\_norm}(z) = zz^* = {\rm Re}(z)^2 + {\rm Im}(z)^2 /// \f] /// \tparam R The scalar type /// \param z The complex scalar -/// \return The complex conjugate of `z` +/// \return squared norm of `z` template -TILEDARRAY_FORCE_INLINE R norm(const std::complex z) { +TILEDARRAY_FORCE_INLINE R squared_norm(const std::complex z) { const R real = z.real(); const R imag = z.imag(); return real * real + imag * imag; diff --git a/src/TiledArray/tensor/kernels.h b/src/TiledArray/tensor/kernels.h index 682cb1b209..876ed00feb 100644 --- a/src/TiledArray/tensor/kernels.h +++ b/src/TiledArray/tensor/kernels.h @@ -26,6 +26,8 @@ #ifndef TILEDARRAY_TENSOR_KENERLS_H__INCLUDED #define TILEDARRAY_TENSOR_KENERLS_H__INCLUDED +#include +#include #include #include #include @@ -37,6 +39,196 @@ class Tensor; namespace detail { +// ------------------------------------------------------------------------- +// Tensor GEMM + +/// Contract two tensors + +/// GEMM is limited to matrix-like contractions. For example, the following +/// contractions are supported: +/// \code +/// C[a,b] = A[a,i,j] * B[i,j,b] +/// C[a,b] = A[a,i,j] * B[b,i,j] +/// C[a,b] = A[i,j,a] * B[i,j,b] +/// C[a,b] = A[i,j,a] * B[b,i,j] +/// +/// C[a,b,c,d] = A[a,b,i,j] * B[i,j,c,d] +/// C[a,b,c,d] = A[a,b,i,j] * B[c,d,i,j] +/// C[a,b,c,d] = A[i,j,a,b] * B[i,j,c,d] +/// C[a,b,c,d] = A[i,j,a,b] * B[c,d,i,j] +/// \endcode +/// Notice that in the above contractions, the inner and outer indices of +/// the arguments form exactly two contiguous groups in each tensor and that +/// each group is in the same order in all tensors. That is, the indices of +/// the tensors must fit one of the following patterns: +/// \code +/// C[M...,N...] = A[M...,K...] * B[K...,N...] +/// C[M...,N...] = A[M...,K...] * B[N...,K...] +/// C[M...,N...] = A[K...,M...] * B[K...,N...] +/// C[M...,N...] = A[K...,M...] * B[N...,K...] +/// \endcode +/// This allows use of optimized BLAS functions to evaluate tensor +/// contractions. Tensor contractions that do not fit this pattern require +/// one or more tensor permutations so that the tensors fit the required +/// pattern.
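+/// For example (editorial illustration), the first contraction above,
+/// `C[a,b] = A[a,i,j] * B[i,j,b]`, maps to a single GEMM call with
+/// \code
+/// m = extent(a), n = extent(b), k = extent(i) * extent(j)
+/// C(m,n) = A(m,k) * B(k,n)  // left_op == right_op == NoTranspose
+/// \endcode
+/// whereas `C[a,b] = A[i,j,a] * B[b,i,j]` uses the same m, n, k but passes
+/// both arguments as Transpose, since A is stored as (k,m) and B as (n,k).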
+/// \tparam Alpha The type of the `A*B` scaling factor +/// \tparam Beta The type of the `C` scaling factor +/// \param alpha The scaling factor applied to the product of `A` and `B` +/// \param A The left-hand tensor that will be contracted +/// \param B The right-hand tensor that will be contracted +/// \param beta The scaling factor applied to `C` before the scaled product +/// is accumulated into it +/// \param[in,out] C The result tensor; on return contains +/// `alpha * A * B + beta * C` +/// \param gemm_helper The GEMM operation meta data +/// \pre `C` is initialized (`!C.empty()`) and the ranks of `A`, `B`, and +/// `C` match those expected by \c gemm_helper +template +void gemm(Alpha alpha, const Tensor& A, const Tensor& B, + Beta beta, Tensor& C, const math::GemmHelper& gemm_helper) { + static_assert(!detail::is_tensor_of_tensor_v, Tensor, + Tensor>, + "TA::Tensor::gemm without custom element op is " + "only applicable to " + "plain tensors"); + { + // Check that tensor C is not empty and has the correct rank + TA_ASSERT(!C.empty()); + TA_ASSERT(C.range().rank() == gemm_helper.result_rank()); + + // Check that the arguments are not empty and have the correct ranks + TA_ASSERT(!A.empty()); + TA_ASSERT(A.range().rank() == gemm_helper.left_rank()); + TA_ASSERT(!B.empty()); + TA_ASSERT(B.range().rank() == gemm_helper.right_rank()); + + TA_ASSERT(A.nbatch() == 1); + TA_ASSERT(B.nbatch() == 1); + TA_ASSERT(C.nbatch() == 1); + + // Check that the outer dimensions of left match the corresponding + // dimensions in result + TA_ASSERT(gemm_helper.left_result_congruent(A.range().extent_data(), + C.range().extent_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_result_congruent(A.range().lobound_data(), + C.range().lobound_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_result_congruent(A.range().upbound_data(), + C.range().upbound_data())); + + // Check that the outer dimensions of right match the corresponding + // dimensions in result + TA_ASSERT(gemm_helper.right_result_congruent(B.range().extent_data(), + C.range().extent_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.right_result_congruent(B.range().lobound_data(), + C.range().lobound_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.right_result_congruent(B.range().upbound_data(), + C.range().upbound_data())); + + // Check that the inner dimensions of left and right match + TA_ASSERT(gemm_helper.left_right_congruent(A.range().extent_data(), + B.range().extent_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_right_congruent(A.range().lobound_data(), + B.range().lobound_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_right_congruent(A.range().upbound_data(), + B.range().upbound_data())); + + // Compute gemm dimensions + using integer = TiledArray::math::blas::integer; + integer m, n, k; + gemm_helper.compute_matrix_sizes(m, n, k, A.range(), B.range()); + + // Get the leading dimension for left and right matrices. + const integer lda = std::max( + integer{1}, + (gemm_helper.left_op() == TiledArray::math::blas::NoTranspose ? k : m)); + const integer ldb = std::max( + integer{1}, + (gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ?
n + : k)); + + // may need to split gemm into multiply + accumulate for tracing purposes +#ifdef TA_ENABLE_TILE_OPS_LOGGING + { + using numeric_type = typename Tensor::numeric_type; + using T = numeric_type; + const bool twostep = + TiledArray::TileOpsLogger::get_instance().gemm && + TiledArray::TileOpsLogger::get_instance().gemm_print_contributions; + std::unique_ptr data_copy; + size_t tile_volume; + if (twostep) { + tile_volume = C.range().volume(); + data_copy = std::make_unique(tile_volume); + std::copy(C.data(), C.data() + tile_volume, data_copy.get()); + } + non_distributed::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, n, + k, alpha, A.data(), lda, B.data(), ldb, + twostep ? numeric_type(0) : beta, C.data(), n); + + if (TiledArray::TileOpsLogger::get_instance_ptr() != nullptr && + TiledArray::TileOpsLogger::get_instance().gemm) { + auto& logger = TiledArray::TileOpsLogger::get_instance(); + auto apply = [](auto& fnptr, const Range& arg) { + return fnptr ? fnptr(arg) : arg; + }; + auto tformed_left_range = + apply(logger.gemm_left_range_transform, A.range()); + auto tformed_right_range = + apply(logger.gemm_right_range_transform, B.range()); + auto tformed_result_range = + apply(logger.gemm_result_range_transform, C.range()); + if ((!logger.gemm_result_range_filter || + logger.gemm_result_range_filter(tformed_result_range)) && + (!logger.gemm_left_range_filter || + logger.gemm_left_range_filter(tformed_left_range)) && + (!logger.gemm_right_range_filter || + logger.gemm_right_range_filter(tformed_right_range))) { + logger << "TA::Tensor::gemm+: left=" << tformed_left_range + << " right=" << tformed_right_range + << " result=" << tformed_result_range << std::endl; + if (TiledArray::TileOpsLogger::get_instance() + .gemm_print_contributions) { + if (!TiledArray::TileOpsLogger::get_instance() + .gemm_printer) { // default printer + // must use custom printer if result's range transformed + if (!logger.gemm_result_range_transform) + logger << C << std::endl; + else + logger << make_map(C.data(), tformed_result_range) << std::endl; + } else { + TiledArray::TileOpsLogger::get_instance().gemm_printer( + *logger.log, tformed_left_range, A.data(), + tformed_right_range, B.data(), tformed_right_range, C.data(), + C.nbatch()); + } + } + } + } + + if (twostep) { + for (size_t v = 0; v != tile_volume; ++v) { + C.data()[v] += data_copy[v]; + } + } + } +#else // TA_ENABLE_TILE_OPS_LOGGING + const integer ldc = std::max(integer{1}, n); + math::blas::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, n, k, + alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc); +#endif // TA_ENABLE_TILE_OPS_LOGGING + } +} + /// customization point transform functionality to tensor class T, useful for /// nonintrusive extension of T to be usable as element type T in Tensor template @@ -474,21 +666,20 @@ inline void tensor_init(Op&& op, TR& result, const Ts&... tensors) { /// \param[in] tensors The argument tensors template < typename Op, typename TR, typename... Ts, - typename std::enable_if<(is_nested_tensor::value && - !is_tensor::value) && - is_contiguous_tensor::value>::type* = nullptr> + typename std::enable_if< + (is_nested_tensor::value && !is_tensor::value) && + is_contiguous_tensor::value>::type* = nullptr> inline void tensor_init(Op&& op, TR& result, const Ts&... 
tensors) { TA_ASSERT(!empty(result, tensors...)); TA_ASSERT(is_range_set_congruent(result, tensors...)); - const auto volume = result.range().volume(); - if constexpr (std::is_invocable_r_v) { result = std::forward(op)(tensors...); } else { - for (decltype(result.range().volume()) ord = 0ul; ord < volume; ++ord) { + const auto volume = result.total_size(); + for (std::remove_cv_t ord = 0ul; ord < volume; ++ord) { new (result.data() + ord) typename TR::value_type( - tensor_op(op, tensors.at_ordinal(ord)...)); + tensor_op(op, (*(tensors.data() + ord))...)); } } } @@ -953,6 +1144,146 @@ Scalar tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op, return result; } +/// +/// todo: constraint ResultTensorAllocator type so that non-sensical Allocators +/// are prohibited +/// +template && + is_annotation_v>> +auto tensor_contract(TensorA const& A, Annot const& aA, TensorB const& B, + Annot const& aB, Annot const& aC) { + using Result = result_tensor_t, TensorA, TensorB, + ResultTensorAllocator>; + + using Indices = ::Einsum::index::Index; + using Permutation = ::Einsum::index::Permutation; + using ::Einsum::index::permutation; + + // Check that the ranks of the tensors match that of the annotation. + TA_ASSERT(A.range().rank() == aA.size()); + TA_ASSERT(B.range().rank() == aB.size()); + + struct { + Indices // + A, // indices of A + B, // indices of B + C, // indices of C (target indices) + h, // Hadamard indices (aA intersection aB intersection aC) + e, // external indices (aA symmetric difference aB) + i; // internal indices ((aA intersection aB) set difference aC) + } const indices{aA, + aB, + aC, + (indices.A & indices.B & indices.C), + (indices.A ^ indices.B), + ((indices.A & indices.B) - indices.h)}; + + TA_ASSERT(!indices.h && "Hadamard indices not supported"); + TA_ASSERT(indices.e && "Dot product not supported"); + + struct { + Indices A, B, C; + } const blas_layout{(indices.A - indices.B) | indices.i, + indices.i | (indices.B - indices.A), indices.e}; + + struct { + Permutation A, B, C; + } const perm{permutation(indices.A, blas_layout.A), + permutation(indices.B, blas_layout.B), + permutation(indices.C, blas_layout.C)}; + + struct { + bool A, B, C; + } const do_perm{indices.A != blas_layout.A, indices.B != blas_layout.B, + indices.C != blas_layout.C}; + + math::GemmHelper gemm_helper{blas::Op::NoTrans, blas::Op::NoTrans, + static_cast(indices.e.size()), + static_cast(indices.A.size()), + static_cast(indices.B.size())}; + + // initialize result with the correct extents + Result result; + { + using Index = typename Indices::value_type; + using Extent = std::remove_cv_t< + typename decltype(std::declval().extent())::value_type>; + using ExtentMap = ::Einsum::index::IndexMap; + + // Map tensor indices to their extents. + // Note that whether the contracting indices have matching extents is + // implicitly checked here by the pipe(|) operator on ExtentMap. + + ExtentMap extent = (ExtentMap{indices.A, A.range().extent()} | + ExtentMap{indices.B, B.range().extent()}); + + container::vector rng; + rng.reserve(indices.e.size()); + for (auto&& ix : indices.e) { + // assuming ix _exists_ in extent + rng.emplace_back(extent[ix]); + } + result = Result{TA::Range(rng)}; + } + + using Numeric = typename Result::numeric_type; + + // call gemm + gemm(Numeric{1}, // + do_perm.A ? A.permute(perm.A) : A, // + do_perm.B ? B.permute(perm.B) : B, // + Numeric{0}, result, gemm_helper); + + return do_perm.C ? 
result.permute(perm.C.inv()) : result; +} + +template && + is_annotation_v>> +auto tensor_hadamard(TensorA const& A, Annot const& aA, TensorB const& B, + Annot const& aB, Annot const& aC) { + using ::Einsum::index::Permutation; + using ::Einsum::index::permutation; + using Indices = ::Einsum::index::Index; + + struct { + Permutation // + AB, // permutes A to B + AC, // permutes A to C + BC; // permutes B to C + } const perm{permutation(Indices(aA), Indices(aB)), + permutation(Indices(aA), Indices(aC)), + permutation(Indices(aB), Indices(aC))}; + + struct { + bool no_perm, perm_to_c, perm_a, perm_b; + } const do_this{ + perm.AB.is_identity() && perm.AC.is_identity() && perm.BC.is_identity(), + perm.AB.is_identity(), // + perm.BC.is_identity(), // + perm.AC.is_identity()}; + + if (do_this.no_perm) { + return A.mult(B); + } else if (do_this.perm_to_c) { + return A.mult(B, perm.AC); + } else if (do_this.perm_a) { + auto pA = A.permute(perm.AC); + pA.mult_to(B); + return pA; + } else if (do_this.perm_b) { + auto pB = B.permute(perm.BC); + pB.mult_to(A); + return pB; + } else { + auto pA = A.permute(perm.AC); + pA.mult_to(B.permute(perm.BC)); + return pA; + } +} + } // namespace detail } // namespace TiledArray diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index c12c2c15d1..cd0e7e97f1 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -36,11 +36,6 @@ namespace TiledArray { -template -void gemm(Alpha alpha, const Tensor& A, const Tensor& B, - Beta beta, Tensor& C, const math::GemmHelper& gemm_helper); - namespace detail { /// Signals that we can take the trace of a Tensor (for numeric \c T) @@ -402,7 +397,7 @@ class Tensor { /// \param perm The permutation that will be applied to the copy /// \warning if `T1` is a tensor of tensors its elements are _cloned_ rather /// than copied to make the semantics of this to be consistent - /// between tensors of scalars and tensors of scalars; specifically, + /// between tensors of scalars and tensors of tensors; specifically, /// if `T1` is a tensor of scalars the constructed tensor is /// is independent of \p other, thus should apply clone to inner /// tensor nests to behave similarly for nested tensors @@ -411,22 +406,32 @@ class Tensor { typename std::enable_if && detail::is_permutation_v>::type* = nullptr> Tensor(const T1& other, const Perm& perm) - : Tensor(outer(perm) * other.range(), 1, default_construct{false}) { - detail::tensor_init(value_converter, outer(perm), - *this, other); + : Tensor(outer(perm) * other.range(), other.nbatch(), + default_construct{false}) { + const auto outer_perm = outer(perm); + if (outer_perm) { + detail::tensor_init(value_converter, outer_perm, + *this, other); + } else { + detail::tensor_init(value_converter, *this, + other); + } // If we actually have a ToT the inner permutation was not applied above so // we do that now constexpr bool is_tot = detail::is_tensor_of_tensor_v; constexpr bool is_bperm = detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor - // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor"); if constexpr (is_tot && is_bperm) { if (inner_size(perm) != 0) { - auto inner_perm = inner(perm); + const auto inner_perm = inner(perm); Permute p; - for (auto& x : *this) x = p(x, inner_perm); + + auto volume = total_size(); + for (decltype(volume) i = 0; i < volume; ++i) { + auto& el = *(data() + i); + el = p(el, inner_perm); + } } } } @@
-463,10 +468,8 @@ class Tensor { // If we actually have a ToT the inner permutation was not applied above so // we do that now constexpr bool is_tot = detail::is_tensor_of_tensor_v; - constexpr bool is_bperm = detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor - // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor"); + constexpr bool is_bperm = detail::is_bipartite_permutation_v; if constexpr (is_tot && is_bperm) { if (inner_size(perm) != 0) { auto inner_perm = inner(perm); @@ -511,10 +514,8 @@ class Tensor { // If we actually have a ToT the inner permutation was not applied above so // we do that now constexpr bool is_tot = detail::is_tensor_of_tensor_v; - constexpr bool is_bperm = detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor - // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor"); + constexpr bool is_bperm = detail::is_bipartite_permutation_v; if constexpr (is_tot && is_bperm) { if (inner_size(perm) != 0) { auto inner_perm = inner(perm); @@ -1296,33 +1297,7 @@ class Tensor { template >> Tensor permute(const Perm& perm) const { - constexpr bool is_tot = detail::is_tensor_of_tensor_v; - [[maybe_unused]] constexpr bool is_bperm = - detail::is_bipartite_permutation_v; - // tile ops pass bipartite permutations here even if this is a plain tensor - // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor"); - if constexpr (!is_tot) { - if constexpr (is_bperm) { - TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation - return Tensor(*this, outer(perm)); - } else - return Tensor(*this, perm); - } else { - // If we have a ToT we need to apply the permutation in two steps. 
The - // first step is identical to the non-ToT case (permute the outer modes) - // the second step does the inner modes - Tensor rv(*this, outer(perm)); - if constexpr (is_bperm) { - if (inner_size(perm) != 0) { - auto inner_perm = inner(perm); - Permute p; - for (auto& inner_t : rv) inner_t = p(inner_t, inner_perm); - } - } - return rv; - } - abort(); // unreachable + return Tensor(*this, perm); } /// Shift the lower and upper bound of this tensor @@ -1393,7 +1368,8 @@ class Tensor { std::declval(), std::declval&>())); using result_allocator_type = typename std::allocator_traits< Allocator>::template rebind_alloc; - return Tensor(*this, right, op); + using ResultTensor = Tensor; + return ResultTensor(*this, right, op); } /// Use a binary, element wise operation to construct a new, permuted tensor @@ -1406,34 +1382,36 @@ class Tensor { /// \param perm The permutation to be applied to this tensor /// \return A tensor where element \c i of the new tensor is equal to /// \c op(*this[i],other[i]) - template < - typename Right, typename Op, typename Perm, - typename std::enable_if::value && - detail::is_permutation_v>::type* = nullptr> - auto binary(const Right& right, Op&& op, const Perm& perm) const { - constexpr bool is_tot = detail::is_tensor_of_tensor_v; + template ::value && + detail::is_permutation_v< + std::remove_reference_t>>::type* = + nullptr> + auto binary(const Right& right, Op&& op, Perm&& perm) const { + using result_value_type = decltype(op( + std::declval(), std::declval&>())); + using result_allocator_type = typename std::allocator_traits< + Allocator>::template rebind_alloc; + using ResultTensor = Tensor; + // tile ops pass bipartite permutations here even if the result is a plain + // tensor [[maybe_unused]] constexpr bool is_bperm = detail::is_bipartite_permutation_v; - // tile ops pass bipartite permutations here even if this is a plain tensor - // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor"); - if constexpr (!is_tot) { - using result_value_type = decltype(op( - std::declval(), std::declval&>())); - using result_allocator_type = typename std::allocator_traits< - Allocator>::template rebind_alloc; - using ResultTensor = Tensor; + constexpr bool result_is_tot = detail::is_tensor_of_tensor_v; + + if constexpr (!result_is_tot) { if constexpr (is_bperm) { - TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation - return ResultTensor(*this, right, op, outer(perm)); + TA_ASSERT(!inner(perm)); // ensure this is a plain permutation since + // ResultTensor is plain + return ResultTensor(*this, right, op, outer(std::forward(perm))); } else - return ResultTensor(*this, right, op, perm); + return ResultTensor(*this, right, op, std::forward(perm)); } else { // AFAIK the other branch fundamentally relies on raw pointer arithmetic, // which won't work for ToTs. auto temp = binary(right, std::forward(op)); Permute p; - return p(temp, perm); + return p(temp, std::forward(perm)); } abort(); // unreachable } @@ -1481,24 +1459,23 @@ class Tensor { /// \throw TiledArray::Exception The dimension of \c perm does not match /// that of this tensor. 
template >> - Tensor unary(Op&& op, const Perm& perm) const { + typename = std::enable_if_t< + detail::is_permutation_v>>> + Tensor unary(Op&& op, Perm&& perm) const { constexpr bool is_tot = detail::is_tensor_of_tensor_v; [[maybe_unused]] constexpr bool is_bperm = detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor - // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor"); if constexpr (!is_tot) { if constexpr (is_bperm) { TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation - return Tensor(*this, op, outer(perm)); + return Tensor(*this, op, outer(std::forward(perm))); } else - return Tensor(*this, op, perm); + return Tensor(*this, op, std::forward(perm)); } else { auto temp = unary(std::forward(op)); Permute p; - return p(temp, perm); + return p(temp, std::forward(perm)); } abort(); // unreachable } @@ -1692,6 +1669,9 @@ class Tensor { template ::value>::type* = nullptr> Tensor& add_to(const Right& right) { + if (empty()) { + *this = Tensor{right.range(), value_type{}}; + } return inplace_binary(right, [](value_type& MADNESS_RESTRICT l, const value_t r) { l += r; }); } @@ -2205,7 +2185,8 @@ class Tensor { #else // TA_ENABLE_TILE_OPS_LOGGING for (size_t i = 0; i < this->nbatch(); ++i) { auto Ci = this->batch(i); - TiledArray::gemm(alpha, A.batch(i), B.batch(i), beta, Ci, gemm_helper); + TiledArray::detail::gemm(alpha, A.batch(i), B.batch(i), beta, Ci, + gemm_helper); } #endif // TA_ENABLE_TILE_OPS_LOGGING @@ -2394,7 +2375,7 @@ class Tensor { scalar_type squared_norm() const { auto square_op = [](scalar_type& MADNESS_RESTRICT res, const numeric_type arg) { - res += TiledArray::detail::norm(arg); + res += TiledArray::detail::squared_norm(arg); }; auto sum_op = [](scalar_type& MADNESS_RESTRICT res, const scalar_type arg) { res += arg; @@ -2542,193 +2523,6 @@ Tensor operator*(const Permutation& p, const Tensor& t) { return t.permute(p); } -/// Contract two tensors and accumulate the scaled result to this tensor - -/// GEMM is limited to matrix like contractions. For example, the following -/// contractions are supported: -/// \code -/// C[a,b] = A[a,i,j] * B[i,j,b] -/// C[a,b] = A[a,i,j] * B[b,i,j] -/// C[a,b] = A[i,j,a] * B[i,j,b] -/// C[a,b] = A[i,j,a] * B[b,i,j] -/// -/// C[a,b,c,d] = A[a,b,i,j] * B[i,j,c,d] -/// C[a,b,c,d] = A[a,b,i,j] * B[c,d,i,j] -/// C[a,b,c,d] = A[i,j,a,b] * B[i,j,c,d] -/// C[a,b,c,d] = A[i,j,a,b] * B[c,d,i,j] -/// \endcode -/// Notice that in the above contractions, the inner and outer indices of -/// the arguments for exactly two contiguous groups in each tensor and that -/// each group is in the same order in all tensors. That is, the indices of -/// the tensors must fit the one of the following patterns: -/// \code -/// C[M...,N...] = A[M...,K...] * B[K...,N...] -/// C[M...,N...] = A[M...,K...] * B[N...,K...] -/// C[M...,N...] = A[K...,M...] * B[K...,N...] -/// C[M...,N...] = A[K...,M...] * B[N...,K...] -/// \endcode -/// This allows use of optimized BLAS functions to evaluate tensor -/// contractions. Tensor contractions that do not fit this pattern require -/// one or more tensor permutation so that the tensors fit the required -/// pattern. 
-/// \tparam U The left-hand tensor element type -/// \tparam AU The left-hand tensor allocator type -/// \tparam V The right-hand tensor element type -/// \tparam AV The right-hand tensor allocator type -/// \tparam W The type of the scaling factor -/// \param left The left-hand tensor that will be contracted -/// \param right The right-hand tensor that will be contracted -/// \param factor The contraction result will be scaling by this value, then -/// accumulated into \c this \param gemm_helper The *GEMM operation meta data -/// \return A reference to \c this -/// \note if this is uninitialized, i.e., if \c this->empty()==true will -/// this is equivalent to -/// \code -/// return (*this = left.gemm(right, factor, gemm_helper)); -/// \endcode -template -void gemm(Alpha alpha, const Tensor& A, const Tensor& B, - Beta beta, Tensor& C, const math::GemmHelper& gemm_helper) { - static_assert(!detail::is_tensor_of_tensor_v, Tensor, - Tensor>, - "TA::Tensor::gemm without custom element op is " - "only applicable to " - "plain tensors"); - { - // Check that tensor C is not empty and has the correct rank - TA_ASSERT(!C.empty()); - TA_ASSERT(C.range().rank() == gemm_helper.result_rank()); - - // Check that the arguments are not empty and have the correct ranks - TA_ASSERT(!A.empty()); - TA_ASSERT(A.range().rank() == gemm_helper.left_rank()); - TA_ASSERT(!B.empty()); - TA_ASSERT(B.range().rank() == gemm_helper.right_rank()); - - TA_ASSERT(A.nbatch() == 1); - TA_ASSERT(B.nbatch() == 1); - TA_ASSERT(C.nbatch() == 1); - - // Check that the outer dimensions of left match the corresponding - // dimensions in result - TA_ASSERT(gemm_helper.left_result_congruent(A.range().extent_data(), - C.range().extent_data())); - TA_ASSERT(ignore_tile_position() || - gemm_helper.left_result_congruent(A.range().lobound_data(), - C.range().lobound_data())); - TA_ASSERT(ignore_tile_position() || - gemm_helper.left_result_congruent(A.range().upbound_data(), - C.range().upbound_data())); - - // Check that the outer dimensions of right match the corresponding - // dimensions in result - TA_ASSERT(gemm_helper.right_result_congruent(B.range().extent_data(), - C.range().extent_data())); - TA_ASSERT(ignore_tile_position() || - gemm_helper.right_result_congruent(B.range().lobound_data(), - C.range().lobound_data())); - TA_ASSERT(ignore_tile_position() || - gemm_helper.right_result_congruent(B.range().upbound_data(), - C.range().upbound_data())); - - // Check that the inner dimensions of left and right match - TA_ASSERT(gemm_helper.left_right_congruent(A.range().extent_data(), - B.range().extent_data())); - TA_ASSERT(ignore_tile_position() || - gemm_helper.left_right_congruent(A.range().lobound_data(), - B.range().lobound_data())); - TA_ASSERT(ignore_tile_position() || - gemm_helper.left_right_congruent(A.range().upbound_data(), - B.range().upbound_data())); - - // Compute gemm dimensions - using integer = TiledArray::math::blas::integer; - integer m, n, k; - gemm_helper.compute_matrix_sizes(m, n, k, A.range(), B.range()); - - // Get the leading dimension for left and right matrices. - const integer lda = std::max( - integer{1}, - (gemm_helper.left_op() == TiledArray::math::blas::NoTranspose ? k : m)); - const integer ldb = std::max( - integer{1}, - (gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? 
n - : k)); - - // may need to split gemm into multiply + accumulate for tracing purposes -#ifdef TA_ENABLE_TILE_OPS_LOGGING - { - using numeric_type = typename Tensor::numeric_type; - using T = numeric_type; - const bool twostep = - TiledArray::TileOpsLogger::get_instance().gemm && - TiledArray::TileOpsLogger::get_instance().gemm_print_contributions; - std::unique_ptr data_copy; - size_t tile_volume; - if (twostep) { - tile_volume = C.range().volume(); - data_copy = std::make_unique(tile_volume); - std::copy(C.data(), C.data() + tile_volume, data_copy.get()); - } - non_distributed::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, n, - k, alpha, A.data(), lda, B.data(), ldb, - twostep ? numeric_type(0) : beta, C.data(), n); - - if (TiledArray::TileOpsLogger::get_instance_ptr() != nullptr && - TiledArray::TileOpsLogger::get_instance().gemm) { - auto& logger = TiledArray::TileOpsLogger::get_instance(); - auto apply = [](auto& fnptr, const Range& arg) { - return fnptr ? fnptr(arg) : arg; - }; - auto tformed_left_range = - apply(logger.gemm_left_range_transform, A.range()); - auto tformed_right_range = - apply(logger.gemm_right_range_transform, B.range()); - auto tformed_result_range = - apply(logger.gemm_result_range_transform, C.range()); - if ((!logger.gemm_result_range_filter || - logger.gemm_result_range_filter(tformed_result_range)) && - (!logger.gemm_left_range_filter || - logger.gemm_left_range_filter(tformed_left_range)) && - (!logger.gemm_right_range_filter || - logger.gemm_right_range_filter(tformed_right_range))) { - logger << "TA::Tensor::gemm+: left=" << tformed_left_range - << " right=" << tformed_right_range - << " result=" << tformed_result_range << std::endl; - if (TiledArray::TileOpsLogger::get_instance() - .gemm_print_contributions) { - if (!TiledArray::TileOpsLogger::get_instance() - .gemm_printer) { // default printer - // must use custom printer if result's range transformed - if (!logger.gemm_result_range_transform) - logger << C << std::endl; - else - logger << make_map(C.data(), tformed_result_range) << std::endl; - } else { - TiledArray::TileOpsLogger::get_instance().gemm_printer( - *logger.log, tformed_left_range, A.data(), - tformed_right_range, B.data(), tformed_right_range, C.data(), - C.nbatch()); - } - } - } - } - - if (twostep) { - for (size_t v = 0; v != tile_volume; ++v) { - C.data()[v] += data_copy[v]; - } - } - } -#else // TA_ENABLE_TILE_OPS_LOGGING - const integer ldc = std::max(integer{1}, n); - math::blas::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, n, k, - alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc); -#endif // TA_ENABLE_TILE_OPS_LOGGING - } -} - // template // const typename Tensor::range_type Tensor::empty_range_; diff --git a/src/TiledArray/tensor/tensor_interface.h b/src/TiledArray/tensor/tensor_interface.h index a514959cab..7a23307036 100644 --- a/src/TiledArray/tensor/tensor_interface.h +++ b/src/TiledArray/tensor/tensor_interface.h @@ -1066,7 +1066,7 @@ class TensorInterface { scalar_type squared_norm() const { auto square_op = [](scalar_type& MADNESS_RESTRICT res, const numeric_type arg) { - res += TiledArray::detail::norm(arg); + res += TiledArray::detail::squared_norm(arg); }; auto sum_op = [](scalar_type& MADNESS_RESTRICT res, const scalar_type arg) { res += arg; diff --git a/src/TiledArray/tensor/type_traits.h b/src/TiledArray/tensor/type_traits.h index 89f8da70a2..a32de32e4a 100644 --- a/src/TiledArray/tensor/type_traits.h +++ b/src/TiledArray/tensor/type_traits.h @@ -30,6 +30,7 @@ #include #include +#include 
#include namespace Eigen { @@ -209,6 +210,33 @@ template constexpr const bool tensors_have_equal_nested_rank_v = tensors_have_equal_nested_rank::value; +template +constexpr size_t nested_rank = 0; + +template +constexpr size_t nested_rank> = 1 + nested_rank; + +template +constexpr size_t nested_rank> = + nested_rank>; + +template +constexpr size_t nested_rank> = nested_rank; + +template +constexpr size_t nested_rank> = + nested_rank>; + +template +constexpr size_t max_nested_rank = 0; + +template +constexpr size_t max_nested_rank = nested_rank; + +template +constexpr size_t max_nested_rank = + std::max(nested_rank, std::max(nested_rank, max_nested_rank)); + //////////////////////////////////////////////////////////////////////////////// template @@ -350,6 +378,9 @@ using default_permutation_t = typename default_permutation::type; template struct is_permutation : public std::false_type {}; +template +struct is_permutation : public is_permutation {}; + template <> struct is_permutation : public std::true_type {}; @@ -366,13 +397,114 @@ static constexpr const auto is_permutation_v = is_permutation::value; template static constexpr const auto is_bipartite_permutation_v = - std::is_same_v; + std::is_same_v || + std::is_same_v; template static constexpr const auto is_bipartite_permutable_v = is_free_function_permute_anyreturn_v< const T&, const TiledArray::BipartitePermutation&>; +// +template +constexpr bool is_random_access_container_v{}; + +/// +/// - The container concept is weakly tested -- any type that has +/// @c iterator typedef gets picked up. +/// +/// - The iterator category must be std::random_access_iterator_tag -- +/// random-access-ness is strongly tested. +/// +/// Following lines compile, for example: +/// +/// @c static_assert(is_random_access_container>); +/// @c static_assert(!is_random_access_container>); +/// +template +constexpr bool is_random_access_container_v< + T, std::void_t, + std::enable_if_t::iterator_category, + std::random_access_iterator_tag>>>{true}; + +// +template +constexpr bool is_annotation_v{}; + +/// +/// An annotation type (T) is a type that satisfies the following constraints: +/// - is_random_access_container_v is true. +/// - The value type of the container T are strictly ordered. Note that T is a +/// container from the first constraint. 
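+/// For example (editorial note): `std::string` models an annotation -- it is
+/// a random-access container of strictly ordered `char`s -- and is what the
+/// string-based einsum interface passes here, whereas `std::list<int>` does
+/// not qualify because it is not random-access.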
+/// +template +constexpr bool is_annotation_v< + T, std::void_t, + std::enable_if_t && + is_strictly_ordered_v> + + >{true}; + +namespace { + +template +using binop_result_t = std::invoke_result_t; + +template +constexpr bool is_binop_v{}; + +template +constexpr bool + is_binop_v>>{true}; + +template >> +struct result_tensor_helper { + private: + using TensorA_ = std::remove_reference_t; + using TensorB_ = std::remove_reference_t; + using value_type_A = typename TensorA_::value_type; + using value_type_B = typename TensorB_::value_type; + using allocator_type_A = typename TensorA_::allocator_type; + using allocator_type_B = typename TensorB_::allocator_type; + + public: + using numeric_type = binop_result_t; + using allocator_type = + std::conditional_t && + std::is_same_v, + allocator_type_A, Allocator>; + using result_type = + std::conditional_t, + TA::Tensor, + TA::Tensor>; +}; + +} // namespace + +/// +/// The typedef is a complete TA::Tensor type where +/// - NumericT is determined by Op: +/// - effectively, it is: +/// std::invoke_result_t +/// +/// - AllocatorT is +/// - the default TA::Tensor allocator if @tparam Allocator is void +/// - TensorA::allocator_type if TensorA and TensorB have the same allocator +/// type +/// - the @tparam Allocator otherwise +/// todo: constraint what @tparam Allocator +/// +/// +template >> +using result_tensor_t = + typename result_tensor_helper::result_type; + } // namespace detail /// Specifies how coordinates are mapped to ordinal values diff --git a/src/TiledArray/tile_op/binary_wrapper.h b/src/TiledArray/tile_op/binary_wrapper.h index 4c02b84318..33d021f2b0 100644 --- a/src/TiledArray/tile_op/binary_wrapper.h +++ b/src/TiledArray/tile_op/binary_wrapper.h @@ -129,9 +129,11 @@ class BinaryWrapper { BinaryWrapper& operator=(const BinaryWrapper&) = default; BinaryWrapper& operator=(BinaryWrapper&&) = default; - template >> - BinaryWrapper(const Op& op, const Perm& perm) : op_(op), perm_(perm) {} + template >>> + BinaryWrapper(const Op& op, Perm&& perm) + : op_(op), perm_(std::forward(perm)) {} BinaryWrapper(const Op& op) : op_(op), perm_() {} diff --git a/src/TiledArray/tile_op/contract_reduce.h b/src/TiledArray/tile_op/contract_reduce.h index d9d87d59c8..94c7107343 100644 --- a/src/TiledArray/tile_op/contract_reduce.h +++ b/src/TiledArray/tile_op/contract_reduce.h @@ -85,17 +85,18 @@ class ContractReduceBase { typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref, typename = std::enable_if_t< - TiledArray::detail::is_permutation_v && + TiledArray::detail::is_permutation_v< + std::remove_reference_t> && std::is_invocable_r_v, result_value_type&, const left_value_type&, const right_value_type&>>> Impl(const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, - const Perm& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) + Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) : gemm_helper_(left_op, right_op, result_rank, left_rank, right_rank), alpha_(alpha), - perm_(perm), + perm_(std::forward(perm)), elem_muladd_op_(std::forward(elem_muladd_op)) { // non-unit alpha must be absorbed into elem_muladd_op if (elem_muladd_op_) TA_ASSERT(alpha == scalar_type(1)); @@ -141,7 +142,7 @@ class ContractReduceBase { typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref, typename = std::enable_if_t< - TiledArray::detail::is_permutation_v && + 
TiledArray::detail::is_permutation_v> && std::is_invocable_r_v, result_value_type&, const left_value_type&, const right_value_type&>>> @@ -149,10 +150,11 @@ class ContractReduceBase { const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, - const unsigned int right_rank, const Perm& perm = {}, + const unsigned int right_rank, Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) : pimpl_(std::make_shared( - left_op, right_op, alpha, result_rank, left_rank, right_rank, perm, + left_op, right_op, alpha, result_rank, left_rank, right_rank, + std::forward(perm), std::forward(elem_muladd_op))) {} /// Gemm meta data accessor @@ -276,16 +278,16 @@ class ContractReduce : public ContractReduceBase { typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref, typename = std::enable_if_t< - TiledArray::detail::is_permutation_v && + TiledArray::detail::is_permutation_v> && std::is_invocable_r_v, result_value_type&, const left_value_type&, const right_value_type&>>> ContractReduce(const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, - const Perm& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) + Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) : ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank, - right_rank, perm, + right_rank, std::forward(perm), std::forward(elem_muladd_op)) {} /// Create a result type object @@ -404,16 +406,16 @@ class ContractReduce, typename = std::enable_if_t< - TiledArray::detail::is_permutation_v && + TiledArray::detail::is_permutation_v> && std::is_invocable_r_v, result_value_type&, const left_value_type&, const right_value_type&>>> ContractReduce(const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, - const Perm& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) + Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) : ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank, - right_rank, perm, + right_rank, std::forward(perm), std::forward(elem_muladd_op)) {} /// Create a result type object @@ -530,16 +532,16 @@ class ContractReduce, typename = std::enable_if_t< - TiledArray::detail::is_permutation_v && + TiledArray::detail::is_permutation_v> && std::is_invocable_r_v, result_value_type&, const left_value_type&, const right_value_type&>>> ContractReduce(const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, - const Perm& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) + Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) : ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank, - right_rank, perm, + right_rank, std::forward(perm), std::forward(elem_muladd_op)) {} /// Create a result type object diff --git a/src/TiledArray/tile_op/mult.h b/src/TiledArray/tile_op/mult.h index b9da1d5e24..577ea94115 100644 --- a/src/TiledArray/tile_op/mult.h +++ b/src/TiledArray/tile_op/mult.h @@ -128,17 +128,30 @@ class Mult { template ::type* = nullptr> result_type eval(left_type& first, const right_type& second) const { - TA_ASSERT(!element_op_); - using TiledArray::mult_to; - return mult_to(first, second); + if (!element_op_) { + using TiledArray::mult_to; + return 
mult_to(first, second); + } else { + // TODO figure out why this does not compile + // using TiledArray::inplace_binary; + // return inplace_binary(first, second, element_op_); + using TiledArray::binary; + return binary(first, second, element_op_); + } } template ::type* = nullptr> result_type eval(const left_type& first, right_type& second) const { - TA_ASSERT(!element_op_); - using TiledArray::mult_to; - return mult_to(second, first); + if (!element_op_) { + using TiledArray::mult_to; + return mult_to(second, first); + } else { // WARNING: element_op_ might be noncommuting, so we can't swap + // first and second! For GEMM we could optimize, but we can't + // introspect element_op_ + using TiledArray::binary; + return binary(first, second, element_op_); + } } template ::type* = nullptr> diff --git a/src/TiledArray/type_traits.h b/src/TiledArray/type_traits.h index 428ad63716..1bddff446d 100644 --- a/src/TiledArray/type_traits.h +++ b/src/TiledArray/type_traits.h @@ -1258,8 +1258,9 @@ struct is_array : public std::false_type {}; template struct is_array> : public std::true_type {}; -template -static constexpr bool is_array_v = is_array::value; +template +constexpr bool is_array_v = + (is_array>::value && ...); template using trange_t = typename T::trange_type; diff --git a/tests/blocked_pmap.cpp b/tests/blocked_pmap.cpp index 4ad055d885..80ab449570 100644 --- a/tests/blocked_pmap.cpp +++ b/tests/blocked_pmap.cpp @@ -25,7 +25,7 @@ using namespace TiledArray; struct BlockedPmapFixture { - BlockedPmapFixture() {} + constexpr static std::size_t max_ntiles = 10ul; }; // ============================================================================= @@ -34,7 +34,7 @@ BOOST_FIXTURE_TEST_SUITE(blocked_pmap_suite, BlockedPmapFixture) BOOST_AUTO_TEST_CASE(constructor) { - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { BOOST_REQUIRE_NO_THROW( TiledArray::detail::BlockedPmap pmap(*GlobalFixture::world, tiles)); TiledArray::detail::BlockedPmap pmap(*GlobalFixture::world, tiles); @@ -51,7 +51,7 @@ BOOST_AUTO_TEST_CASE(owner) { ProcessID* p_owner = new ProcessID[size]; // Check various pmap sizes - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { TiledArray::detail::BlockedPmap pmap(*GlobalFixture::world, tiles); for (std::size_t tile = 0; tile < tiles; ++tile) { @@ -71,7 +71,7 @@ } BOOST_AUTO_TEST_CASE(local_size) { - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { TiledArray::detail::BlockedPmap pmap(*GlobalFixture::world, tiles); std::size_t total_size = pmap.local_size(); @@ -87,7 +87,7 @@ BOOST_AUTO_TEST_CASE(local_group) { ProcessID tile_owners[100]; - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { TiledArray::detail::BlockedPmap pmap(*GlobalFixture::world, tiles); // Check that all local elements map to this rank diff --git a/tests/cyclic_pmap.cpp b/tests/cyclic_pmap.cpp index b8c2b9670c..a1a029c1cd 100644 --- a/tests/cyclic_pmap.cpp +++ b/tests/cyclic_pmap.cpp @@ -24,7 +24,7 @@ using namespace TiledArray; struct CyclicPmapFixture { - CyclicPmapFixture() {} + constexpr static std::size_t max_ntiles_per_dim = 4ul; }; // ============================================================================= @@ -86,8 +86,8 @@
BOOST_AUTO_TEST_CASE(owner) { ProcessID* p_owner = new ProcessID[size]; // Check various pmap sizes - for (std::size_t x = 1ul; x < 10ul; ++x) { - for (std::size_t y = 1ul; y < 10ul; ++y) { + for (std::size_t x = 1ul; x < max_ntiles_per_dim; ++x) { + for (std::size_t y = 1ul; y < max_ntiles_per_dim; ++y) { // Compute the limits for process rows const std::size_t min_proc_rows = std::max( ((GlobalFixture::world->size() + y - 1ul) / y), 1ul); @@ -123,8 +123,8 @@ BOOST_AUTO_TEST_CASE(owner) { } BOOST_AUTO_TEST_CASE(local_size) { - for (std::size_t x = 1ul; x < 10ul; ++x) { - for (std::size_t y = 1ul; y < 10ul; ++y) { + for (std::size_t x = 1ul; x < max_ntiles_per_dim; ++x) { + for (std::size_t y = 1ul; y < max_ntiles_per_dim; ++y) { // Compute the limits for process rows const std::size_t min_proc_rows = std::max( ((GlobalFixture::world->size() + y - 1ul) / y), 1ul); @@ -156,8 +156,8 @@ BOOST_AUTO_TEST_CASE(local_size) { BOOST_AUTO_TEST_CASE(local_group) { ProcessID tile_owners[100]; - for (std::size_t x = 1ul; x < 10ul; ++x) { - for (std::size_t y = 1ul; y < 10ul; ++y) { + for (std::size_t x = 1ul; x < max_ntiles_per_dim; ++x) { + for (std::size_t y = 1ul; y < max_ntiles_per_dim; ++y) { // Compute the limits for process rows const std::size_t min_proc_rows = std::max( ((GlobalFixture::world->size() + y - 1ul) / y), 1ul); diff --git a/tests/einsum.cpp b/tests/einsum.cpp index 6155b2cb98..cfae7b5925 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -25,6 +25,389 @@ #include "TiledArray/expressions/contraction_helpers.h" +BOOST_AUTO_TEST_SUITE(manual) + +namespace { +using il_trange = std::initializer_list>; +using il_extent = std::initializer_list; +} // namespace + +template >> +bool check_manual_eval(std::string const& annot, ArrayA A, ArrayB B) { + auto out = TA::einsum(annot, A, B); + auto ref = manual_eval(annot, A, B); + return ToTArrayFixture::are_equal(ref, out); +} + +template >> +bool check_manual_eval(std::string const& annot, ArrayA A, ArrayB B) { + return check_manual_eval(annot, A, B); +} + +template +bool check_manual_eval(std::string const& annot, il_trange trangeA, + il_trange trangeB) { + static_assert(detail::is_array_v && + detail::is_tensor_v); + auto A = random_array(TA::TiledRange(trangeA)); + auto B = random_array(TA::TiledRange(trangeB)); + return check_manual_eval(annot, A, B); +} + +template +bool check_manual_eval(std::string const& annot, il_trange trangeA, + il_trange trangeB) { + return check_manual_eval(annot, trangeA, + trangeB); +} + +template +bool check_manual_eval(std::string const& annot, il_trange trangeA, + il_trange trangeB, il_extent inner_extents) { + static_assert(detail::is_array_v); + + if constexpr (detail::is_tensor_of_tensor_v) { + static_assert(!detail::is_tensor_of_tensor_v); + return check_manual_eval( + annot, random_array(trangeA, inner_extents), + random_array(trangeB)); + } else { + static_assert(detail::is_tensor_of_tensor_v); + return check_manual_eval( + annot, random_array(trangeA), + random_array(trangeB, inner_extents)); + } +} + +template +bool check_manual_eval(std::string const& annot, il_trange trangeA, + il_trange trangeB, il_extent inner_extents) { + return check_manual_eval( + annot, trangeA, trangeB); +} + +template +bool check_manual_eval(std::string const& annot, il_trange trangeA, + il_trange trangeB, il_extent inner_extentsA, + il_extent inner_extentsB) { + static_assert(detail::is_array_v && + detail::is_tensor_of_tensor_v); + return check_manual_eval( + annot, random_array(trangeA, inner_extentsA), + 
random_array(trangeB, inner_extentsB)); +} + +template +bool check_manual_eval(std::string const& annot, il_trange trangeA, + il_trange trangeB, il_extent inner_extentsA, + il_extent inner_extentsB) { + return check_manual_eval( + annot, trangeA, trangeB, inner_extentsA, inner_extentsB); +} + +BOOST_AUTO_TEST_CASE(contract) { + using Array = TA::Array; + + BOOST_REQUIRE(check_manual_eval("ij,j->i", + {{0, 2, 4}, {0, 4, 8}}, // A's trange + {{0, 4, 8}} // B's trange + )); + BOOST_REQUIRE(check_manual_eval("ik,jk->ji", + {{0, 2, 4}, {0, 4, 8}}, // A's trange + {{0, 3}, {0, 4, 8}} // B's trange + )); + + BOOST_REQUIRE(check_manual_eval( + "ijkl,jm->lkmi", // + {{0, 2}, {0, 4, 8}, {0, 3}, {0, 7}}, // + {{0, 4, 8}, {0, 5}} // + )); +} + +BOOST_AUTO_TEST_CASE(hadamard) { + using Array = TA::Array; + BOOST_REQUIRE(check_manual_eval("i,i->i", // + {{0, 1}}, // + {{0, 1}} // + )); + BOOST_REQUIRE(check_manual_eval("i,i->i", // + {{0, 2, 4}}, // + {{0, 2, 4}} // + )); + + BOOST_REQUIRE(check_manual_eval("ijk,kij->ikj", // + {{0, 2, 4}, {0, 2, 3}, {0, 5}}, // + {{0, 5}, {0, 2, 4}, {0, 2, 3}} // + )); +} + +BOOST_AUTO_TEST_CASE(general) { + using Array = TA::Array; + BOOST_REQUIRE(check_manual_eval("ijk,kil->ijl", // + {{0, 2}, {0, 3, 5}, {0, 2, 4}}, // + {{0, 2, 4}, {0, 2}, {0, 1}} // + )); + + using Array = TA::Array; + using Tensor = typename Array::value_type; + using namespace std::string_literals; + + Tensor A(TA::Range{2, 3}, {1, 2, 3, 4, 5, 6}); + Tensor B(TA::Range{2}, {2, 10}); + Tensor C(TA::Range{2, 3}, {2, 4, 6, 40, 50, 60}); + BOOST_REQUIRE( + C == general_product(A, B, ProductSetup("ij"s, "i"s, "ij"s))); +} + +BOOST_AUTO_TEST_CASE(equal_nested_ranks) { + using ArrayToT = TA::DistArray>>; + + // H;H (Hadamard outer; Hadamard inner) + BOOST_REQUIRE(check_manual_eval("ij;mn,ji;nm->ij;mn", // + {{0, 2, 4}, {0, 3}}, // + {{0, 3}, {0, 2, 4}}, // + {5, 7}, // + {7, 5} // + )); + + // H;C (Hadamard outer; contraction inner) + BOOST_REQUIRE(check_manual_eval("ij;mo,ji;on->ij;mn", // + {{0, 2, 4}, {0, 3}}, // + {{0, 3}, {0, 2, 4}}, // + {3, 7}, // + {7, 4} // + )); + + // H;C + BOOST_REQUIRE(check_manual_eval("ij;mo,ji;o->ij;m", // + {{0, 2, 4}, {0, 3}}, // + {{0, 3}, {0, 2, 4}}, // + {3, 7}, // + {7} // + )); + + // C;C + BOOST_REQUIRE(check_manual_eval("ik;mo,kj;on->ij;mn", // + {{0, 3, 5}, {0, 2, 4}}, // + {{0, 2, 4}, {0, 2}}, // + {2, 2}, // + {2, 2})); + + // C;C + BOOST_REQUIRE(check_manual_eval("ijk;dcb,ik;bc->ij;d", // + {{0, 3}, {0, 4}, {0, 5}}, // + {{0, 3}, {0, 5}}, // + {2, 3, 4}, // + {4, 3})); + + // H+C;H + BOOST_REQUIRE(check_manual_eval("ijk;mn,ijk;nm->ij;mn", // + {{0, 2}, {0, 3}, {0, 2}}, // + {{0, 2}, {0, 3}, {0, 2}}, // + {2, 2}, // + {2, 2})); + + // H+C;C + BOOST_REQUIRE(check_manual_eval("ijk;mo,ijk;no->ij;nm", // + {{0, 2}, {0, 3}, {0, 2}}, // + {{0, 2}, {0, 3}, {0, 2}}, // + {3, 2}, // + {3, 2})); + + // H+C;C + BOOST_REQUIRE(check_manual_eval("ijk;m,ijk;n->ij;nm", // + {{0, 2}, {0, 3}, {0, 2}}, // + {{0, 2}, {0, 3}, {0, 2}}, // + {3}, // + {2})); + // H+C;H+C not supported +} + +BOOST_AUTO_TEST_CASE(different_nested_ranks) { + using ArrayT = TA::DistArray>; + using ArrayToT = TA::DistArray>>; + + { + // these tests do not involve permutation of inner tensors + // H + BOOST_REQUIRE( + (check_manual_eval("ij;mn,ji->ji;mn", // + {{0, 2, 5}, {0, 3, 5, 9}}, // + {{0, 3, 5, 9}, {0, 2, 5}}, // + {2, 1}))); + + // H (reversed arguments) + BOOST_REQUIRE( + (check_manual_eval("ji,ij;mn->ji;mn", // + {{0, 3, 5, 9}, {0, 2, 5}}, // + {{0, 2, 5}, {0, 3, 5, 9}}, // + {2, 4}))); + + // 
C (outer product) + BOOST_REQUIRE((check_manual_eval("i;mn,j->ij;mn", // + {{0, 5}}, // + {{0, 3, 8}}, // + {3, 2}))); + + // C (outer product) (reversed arguments) + BOOST_REQUIRE((check_manual_eval("j,i;mn->ij;mn", // + {{0, 3, 8}}, // + {{0, 5}}, // + {2, 2}))); + } + + // C (outer product) + BOOST_REQUIRE((check_manual_eval("ik;mn,j->ijk;nm", // + {{0, 2, 4}, {0, 4}}, // + {{0, 3, 5}}, // + {3, 2}))); + + // C (outer product) (reversed arguments) + BOOST_REQUIRE((check_manual_eval("jl,ik;mn->ijkl;nm", // + {{0, 3, 5}, {0, 3}}, // + {{0, 2, 4}, {0, 4}}, // + {3, 2}))); + + // H+C (outer product) + BOOST_REQUIRE((check_manual_eval("ij;mn,ik->ijk;nm", // + {{0, 2, 5}, {0, 3, 7}}, // + {{0, 2, 5}, {0, 4, 7}}, // + {2, 5}))); + + // H+C (outer product) (reversed arguments) + BOOST_REQUIRE((check_manual_eval("ik,ij;mn->ijk;nm", // + {{0, 2, 5}, {0, 4, 7}}, // + {{0, 2, 5}, {0, 3, 7}}, // + {2, 5}))); + + { + // these tests do not involve permutation of inner tensors + // H+C + BOOST_REQUIRE( + (check_manual_eval("ik;mn,ijk->ij;mn", // + {{0, 2}, {0, 3}}, // + {{0, 2}, {0, 2}, {0, 3}}, // + {2, 2}))); + + // H+C (reversed arguments) + BOOST_REQUIRE( + (check_manual_eval("ijk,ik;mn->ij;mn", // + {{0, 2}, {0, 2}, {0, 3}}, // + {{0, 2}, {0, 3}}, // + {2, 2}))); + } + + // H + BOOST_REQUIRE((check_manual_eval("ij;mn,ji->ji;nm", // + {{0, 2, 4, 6}, {0, 3}}, // + {{0, 3}, {0, 2, 4, 6}}, // + {4, 2}))); + + // H (reversed arguments) + BOOST_REQUIRE((check_manual_eval("ji,ij;mn->ji;nm", // + {{0, 3, 5}, {0, 2, 4}}, // + {{0, 2, 4}, {0, 3, 5}}, // + {1, 2}))); + + // C + BOOST_REQUIRE((check_manual_eval("ij;m,j->i;m", // + {{0, 5}, {0, 2, 3}}, // + {{0, 2, 3}}, // + {3}))); + + // C (reversed arguments) + BOOST_REQUIRE((check_manual_eval("j,ij;m->i;m", // + {{0, 2}}, // + {{0, 1}, {0, 2}}, // + {3}))); + + // H+C + BOOST_REQUIRE(( + check_manual_eval("ik;mn,ijk->ij;nm", // + {{0, 2}, {0, 3, 5}}, // + {{0, 2}, {0, 2, 4, 6}, {0, 3, 5}}, // + {2, 2}))); + + // H+C (reversed arguments) + BOOST_REQUIRE( + (check_manual_eval("ijk,ik;mn->ij;nm", // + {{0, 2}, {0, 4}, {0, 3}}, // + {{0, 2}, {0, 3}}, // + {2, 4}))); +} + +BOOST_AUTO_TEST_CASE(nested_rank_reduction) { + using T = TA::Tensor; + using ToT = TA::Tensor; + using Array = TA::DistArray; + using ArrayToT = TA::DistArray; + BOOST_REQUIRE( + (check_manual_eval("ij;ab,ij;ab->ij", // + {{0, 2, 4}, {0, 4}}, // + {{0, 2, 4}, {0, 4}}, // + {3, 2}, // + {3, 2}))); + BOOST_REQUIRE( + (check_manual_eval("ij;ab,ij;ab->i", // + {{0, 2, 4}, {0, 4}}, // + {{0, 2, 4}, {0, 4}}, // + {3, 2}, // + {3, 2}))); +} + +BOOST_AUTO_TEST_CASE(corner_cases) { + using T = TA::Tensor; + using ToT = TA::Tensor; + using ArrayT = TA::DistArray; + using ArrayToT = TA::DistArray; + + BOOST_REQUIRE(check_manual_eval("ia,i->ia", // + {{0, 2, 5}, {0, 7, 11, 16}}, // + {{0, 2, 5}})); + + BOOST_REQUIRE(check_manual_eval("i,ai->ia", // + {{0, 2, 5}}, // + {{0, 7, 11, 16}, {0, 2, 5}})); + + BOOST_REQUIRE(check_manual_eval("ijk,kj->kij", // + {{0, 2, 5}, {0, 3, 6}, {0, 2, 7}}, // + {{0, 2, 7}, {0, 3, 6}})); + + BOOST_REQUIRE(check_manual_eval("kj,ijk->kij", // + {{0, 2, 7}, {0, 3, 6}}, // + {{0, 2, 5}, {0, 3, 6}, {0, 2, 7}})); + + BOOST_REQUIRE(check_manual_eval("kij;ab,kj;bc->kji;ac", // + {{0, 2}, {0, 3, 5}, {0, 4, 7}}, // + {{0, 2}, {0, 4, 7}}, // + {3, 5}, {5, 2})); + + BOOST_REQUIRE( + (check_manual_eval("ijk;ab,kj->kij;ba", // + {{0, 2}, {0, 4, 6}, {0, 3, 5}}, // + {{0, 3, 5}, {0, 4, 6}}, // + {7, 5}))); + + BOOST_REQUIRE( + (check_manual_eval("ij,jik;ab->kji;ab", // + {{0, 3, 5}, {0, 
3, 8}}, // + {{0, 3, 8}, {0, 3, 5}, {0, 2}}, // + {3, 9}))); +} + +BOOST_AUTO_TEST_SUITE_END() + using namespace TiledArray; using namespace TiledArray::expressions; @@ -630,8 +1013,6 @@ BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_kj_mn) { auto i = res_ix[0]; auto j = res_ix[1]; auto k = res_ix[2]; - using Ix2 = std::array; - using Ix3 = std::array; auto lhs_tile_ix = lhs.trange().element_to_tile({i, j}); auto lhs_tile = lhs.find_local(lhs_tile_ix).get(/* dowork = */ false); @@ -793,6 +1174,105 @@ BOOST_AUTO_TEST_CASE(xxx) { BOOST_CHECK(are_equal); } +BOOST_AUTO_TEST_CASE(ij_mn_eq_ij_mo_times_ji_on) { + auto& world = TA::get_default_world(); + + using Array = TA::DistArray>, TA::DensePolicy>; + using Perm = TA::Permutation; + + TA::TiledRange lhs_trng{{0, 2, 3}, {0, 1}}; + TA::TiledRange rhs_trng{{0, 1}, {0, 2, 3}}; + TA::Range lhs_inner_rng{1, 1}; + TA::Range rhs_inner_rng{1, 1}; + + auto lhs = random_array(lhs_trng, lhs_inner_rng); + auto rhs = random_array(rhs_trng, rhs_inner_rng); + Array out; + BOOST_REQUIRE_NO_THROW(out("i,j;m,n") = lhs("i,j;m,o") * rhs("j,i;o,n")); +} + +BOOST_AUTO_TEST_CASE(ij_mn_eq_ijk_mo_times_ijk_no) { + using Array = TA::DistArray>, TA::DensePolicy>; + using Ix = typename TA::Range::index1_type; + using namespace std::string_literals; + auto& world = TA::get_default_world(); + + Ix const K = 2; // the extent of contracted outer mode + + TA::Range const inner_rng{3, 7}; + TA::TiledRange const lhs_trng{ + std::initializer_list>{ + {0, 2, 4}, {0, 2}, {0, 2}}}; + TA::TiledRange const rhs_trng(lhs_trng); + TA::TiledRange const ref_trng{lhs_trng.dim(0), lhs_trng.dim(1)}; + TA::Range const ref_inner_rng{3, 3}; // contract(3x7,3x7) -> (3,3) + auto lhs = random_array(lhs_trng, inner_rng); + auto rhs = random_array(rhs_trng, inner_rng); + + // + // manual evaluation: ij;mn = ijk;mo * ijk;no + // + Array ref{world, ref_trng}; + { + lhs.make_replicated(); + rhs.make_replicated(); + world.gop.fence(); + + auto make_tile = [lhs, rhs, ref_inner_rng](TA::Range const& rng) { + using InnerT = typename Array::value_type::value_type; + typename Array::value_type result_tile{rng}; + + for (auto&& res_ix : result_tile.range()) { + auto i = res_ix[0]; + auto j = res_ix[1]; + + InnerT mn; + for (Ix k = 0; k < K; ++k) { + auto lhs_tile = + lhs.find_local(lhs.trange().element_to_tile({i, j, k})) + .get(/*dowork = */ false); + auto rhs_tile = + rhs.find_local(rhs.trange().element_to_tile({i, j, k})) + .get(/*dowork = */ false); + mn.add_to(tensor_contract("mo,no->mn", lhs_tile({i, j, k}), + rhs_tile({i, j, k}))); + } + result_tile({i, j}) = std::move(mn); + } + return result_tile; + }; + using std::begin; + using std::end; + + for (auto it = begin(ref); it != end(ref); ++it) + if (ref.is_local(it.index())) { + auto tile = world.taskq.add(make_tile, it.make_range()); + *it = tile; + } + } + + auto out = einsum(lhs("i,j,k;m,o"), rhs("i,j,k;n,o"), "i,j;m,n"); + bool are_equal = ToTArrayFixture::are_equal(ref, out); + + BOOST_CHECK(are_equal); +} + +#ifdef TILEDARRAY_HAS_BTAS +BOOST_AUTO_TEST_CASE(tensor_contract) { + using TensorT = TA::Tensor; + + TA::Range const rng_A{2, 3, 4}; + TA::Range const rng_B{4, 3, 2}; + auto const A = random_tensor(rng_A); + auto const B = random_tensor(rng_B); + + BOOST_CHECK(tensor_contract_equal("ijk,klm->ijlm", A, B)); + BOOST_CHECK(tensor_contract_equal("ijk,klm->milj", A, B)); + BOOST_CHECK(tensor_contract_equal("ijk,kjm->im", A, B)); + BOOST_CHECK(tensor_contract_equal("ijk,kli->lj", A, B)); +} +#endif + BOOST_AUTO_TEST_SUITE_END() // einsum_tot 
BOOST_AUTO_TEST_SUITE(einsum_tot_t) @@ -850,20 +1330,15 @@ BOOST_AUTO_TEST_CASE(ilkj_nm_eq_ij_mn_times_kl) { auto k = res_ix[2]; auto j = res_ix[3]; - using Ix2 = std::array; - using Ix4 = std::array; - - auto lhs_tile_ix = lhs.trange().element_to_tile(Ix2{i, j}); + auto lhs_tile_ix = lhs.trange().element_to_tile({i, j}); auto lhs_tile = lhs.find_local(lhs_tile_ix).get(/* dowork = */ false); - auto rhs_tile_ix = rhs.trange().element_to_tile(Ix2{k, l}); + auto rhs_tile_ix = rhs.trange().element_to_tile({k, l}); auto rhs_tile = rhs.find_local(rhs_tile_ix).get(/* dowork = */ false); - auto& res_el = - result_tile.at_ordinal(result_tile.range().ordinal(Ix4{i, l, k, j})); - auto const& lhs_el = - lhs_tile.at_ordinal(lhs_tile.range().ordinal(Ix2{i, j})); - auto rhs_el = rhs_tile.at_ordinal(rhs_tile.range().ordinal(Ix2{k, l})); + auto& res_el = result_tile({i, l, k, j}); + auto const& lhs_el = lhs_tile({i, j}); + auto rhs_el = rhs_tile({k, l}); res_el = tot_type::element_type( lhs_el.scale(rhs_el), // scale @@ -949,20 +1424,15 @@ BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_jk) { auto j = res_ix[1]; auto k = res_ix[2]; - using Ix2 = std::array; - using Ix3 = std::array; - - auto lhs_tile_ix = lhs.trange().element_to_tile(Ix2{i, j}); + auto lhs_tile_ix = lhs.trange().element_to_tile({i, j}); auto lhs_tile = lhs.find_local(lhs_tile_ix).get(/* dowork = */ false); - auto rhs_tile_ix = rhs.trange().element_to_tile(Ix2{j, k}); + auto rhs_tile_ix = rhs.trange().element_to_tile({j, k}); auto rhs_tile = rhs.find_local(rhs_tile_ix).get(/* dowork = */ false); - auto& res_el = - result_tile.at_ordinal(result_tile.range().ordinal(Ix3{i, j, k})); - auto const& lhs_el = - lhs_tile.at_ordinal(lhs_tile.range().ordinal(Ix2{i, j})); - auto rhs_el = rhs_tile.at_ordinal(rhs_tile.range().ordinal(Ix2{j, k})); + auto& res_el = result_tile({i, j, k}); + auto const& lhs_el = lhs_tile({i, j}); + auto rhs_el = rhs_tile({j, k}); res_el = lhs_el.scale(rhs_el); } @@ -1057,19 +1527,15 @@ BOOST_AUTO_TEST_CASE(ij_mn_eq_ji_mn_times_ij) { auto i = res_ix[0]; auto j = res_ix[1]; - using Ix2 = std::array; - - auto lhs_tile_ix = lhs.trange().element_to_tile(Ix2{j, i}); + auto lhs_tile_ix = lhs.trange().element_to_tile({j, i}); auto lhs_tile = lhs.find_local(lhs_tile_ix).get(/* dowork */ false); - auto rhs_tile_ix = rhs.trange().element_to_tile(Ix2({i, j})); + auto rhs_tile_ix = rhs.trange().element_to_tile({i, j}); auto rhs_tile = rhs.find_local(rhs_tile_ix).get(/* dowork */ false); - auto& res_el = - result_tile.at_ordinal(result_tile.range().ordinal(Ix2{i, j})); - auto const& lhs_el = - lhs_tile.at_ordinal(lhs_tile.range().ordinal(Ix2{j, i})); - auto rhs_el = rhs_tile.at_ordinal(rhs_tile.range().ordinal(Ix2{i, j})); + auto& res_el = result_tile({i, j}); + auto const& lhs_el = lhs_tile({j, i}); + auto rhs_el = rhs_tile({i, j}); res_el = tot_type::element_type(lhs_el.scale(rhs_el), // scale TiledArray::Permutation{0, 1} // permute ); diff --git a/tests/hash_pmap.cpp b/tests/hash_pmap.cpp index a9b573802c..06d721dceb 100644 --- a/tests/hash_pmap.cpp +++ b/tests/hash_pmap.cpp @@ -24,7 +24,7 @@ using namespace TiledArray; struct HashPmapFixture { - HashPmapFixture() {} + constexpr static std::size_t max_ntiles = 10ul; }; // ============================================================================= @@ -33,7 +33,7 @@ struct HashPmapFixture { BOOST_FIXTURE_TEST_SUITE(hash_pmap_suite, HashPmapFixture) BOOST_AUTO_TEST_CASE(constructor) { - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < 
max_ntiles; ++tiles) { BOOST_REQUIRE_NO_THROW( TiledArray::detail::HashPmap pmap(*GlobalFixture::world, tiles)); TiledArray::detail::HashPmap pmap(*GlobalFixture::world, tiles); @@ -50,7 +50,7 @@ BOOST_AUTO_TEST_CASE(owner) { ProcessID* p_owner = new ProcessID[size]; // Check various pmap sizes - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { TiledArray::detail::HashPmap pmap(*GlobalFixture::world, tiles); for (std::size_t tile = 0; tile < tiles; ++tile) { @@ -77,7 +77,7 @@ BOOST_AUTO_TEST_CASE(local_size) { BOOST_AUTO_TEST_CASE(local_group) { ProcessID tile_owners[100]; - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { TiledArray::detail::HashPmap pmap(*GlobalFixture::world, tiles); // Check that all local elements map to this rank diff --git a/tests/range.cpp b/tests/range.cpp index a5ac8898f9..a20f185d44 100644 --- a/tests/range.cpp +++ b/tests/range.cpp @@ -517,6 +517,9 @@ BOOST_AUTO_TEST_CASE(permutation) { BOOST_CHECK_EQUAL_COLLECTIONS(r3.stride_data(), r3.stride_data() + r3.rank(), r2.stride_data(), r2.stride_data() + r2.rank()); BOOST_CHECK_EQUAL(r3, r2); + + // using null Permutation is allowed + BOOST_CHECK_EQUAL(Range(Permutation{}, r1), r1); } BOOST_AUTO_TEST_CASE(include) { @@ -700,13 +703,13 @@ BOOST_AUTO_TEST_CASE(serialization) { 2 * (sizeof(Range) + sizeof(std::size_t) * (4 * GlobalFixture::dim + 1)); unsigned char* buf = new unsigned char[buf_size]; madness::archive::BufferOutputArchive oar(buf, buf_size); - oar& r; + oar & r; std::size_t nbyte = oar.size(); oar.close(); Range rs; madness::archive::BufferInputArchive iar(buf, nbyte); - iar& rs; + iar & rs; iar.close(); delete[] buf; diff --git a/tests/replicated_pmap.cpp b/tests/replicated_pmap.cpp index 1a06b85ea4..f9c8b45618 100644 --- a/tests/replicated_pmap.cpp +++ b/tests/replicated_pmap.cpp @@ -27,16 +27,13 @@ #include "unit_test_config.h" struct ReplicatedPmapFixture { - ReplicatedPmapFixture() {} - - ~ReplicatedPmapFixture() {} - + constexpr static std::size_t max_ntiles = 10ul; }; // Fixture BOOST_FIXTURE_TEST_SUITE(replicated_pmap_suite, ReplicatedPmapFixture) BOOST_AUTO_TEST_CASE(constructor) { - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { BOOST_REQUIRE_NO_THROW( TiledArray::detail::ReplicatedPmap pmap(*GlobalFixture::world, tiles)); TiledArray::detail::ReplicatedPmap pmap(*GlobalFixture::world, tiles); @@ -50,7 +47,7 @@ BOOST_AUTO_TEST_CASE(owner) { const std::size_t rank = GlobalFixture::world->rank(); // Check various pmap sizes - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { TiledArray::detail::ReplicatedPmap pmap(*GlobalFixture::world, tiles); for (std::size_t tile = 0; tile < tiles; ++tile) { @@ -60,7 +57,7 @@ BOOST_AUTO_TEST_CASE(owner) { } BOOST_AUTO_TEST_CASE(local_size) { - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { TiledArray::detail::ReplicatedPmap pmap(*GlobalFixture::world, tiles); // Check that the total number of elements in all local groups is equal to @@ -71,7 +68,7 @@ BOOST_AUTO_TEST_CASE(local_size) { } BOOST_AUTO_TEST_CASE(local_group) { - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { TiledArray::detail::ReplicatedPmap pmap(*GlobalFixture::world, tiles); // 
Check that all local elements map to this rank diff --git a/tests/round_robin_pmap.cpp b/tests/round_robin_pmap.cpp index 4851c5b5b1..7c601d4bfd 100644 --- a/tests/round_robin_pmap.cpp +++ b/tests/round_robin_pmap.cpp @@ -25,7 +25,7 @@ using namespace TiledArray; struct RoundRobinPmapFixture { - RoundRobinPmapFixture() {} + constexpr static std::size_t max_ntiles = 10ul; }; // ============================================================================= @@ -51,7 +51,7 @@ BOOST_AUTO_TEST_CASE(owner) { ProcessID *p_owner = new ProcessID[size]; // Check various pmap sizes - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { TiledArray::detail::RoundRobinPmap pmap(*GlobalFixture::world, tiles); for (std::size_t tile = 0; tile < tiles; ++tile) { @@ -71,7 +71,7 @@ BOOST_AUTO_TEST_CASE(owner) { } BOOST_AUTO_TEST_CASE(local_size) { - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { TiledArray::detail::RoundRobinPmap pmap(*GlobalFixture::world, tiles); std::size_t total_size = pmap.local_size(); @@ -87,7 +87,7 @@ BOOST_AUTO_TEST_CASE(local_size) { BOOST_AUTO_TEST_CASE(local_group) { ProcessID tile_owners[100]; - for (std::size_t tiles = 1ul; tiles < 100ul; ++tiles) { + for (std::size_t tiles = 1ul; tiles < max_ntiles; ++tiles) { TiledArray::detail::RoundRobinPmap pmap(*GlobalFixture::world, tiles); // Check that all local elements map to this rank diff --git a/tests/tot_array_fixture.h b/tests/tot_array_fixture.h index c01399dbba..94f57b0930 100644 --- a/tests/tot_array_fixture.h +++ b/tests/tot_array_fixture.h @@ -19,10 +19,13 @@ #ifndef TILEDARRAY_TEST_TOT_ARRAY_FIXTURE_H__INCLUDED #define TILEDARRAY_TEST_TOT_ARRAY_FIXTURE_H__INCLUDED -#include "tiledarray.h" +#include +#include #include "unit_test_config.h" #ifdef TILEDARRAY_HAS_BTAS +#include #include +#include #endif /* Notes: @@ -88,11 +91,567 @@ using input_archive_type = madness::archive::BinaryFstreamInputArchive; // Type of an output archive using output_archive_type = madness::archive::BinaryFstreamOutputArchive; -enum class ShapeComp { - True, - False +enum class ShapeComp { True, False }; + +template , bool> = true> +auto random_tensor(TA::Range const& rng) { + using NumericT = typename TensorT::numeric_type; + TensorT result{rng}; + + std::generate(/*std::execution::par, */ + result.begin(), result.end(), + TA::detail::MakeRandom::generate_value); + return result; +} + +template +auto random_tensor(std::initializer_list const& extents) { + auto lobounds = TA::container::svector(extents.size(), 0); + return random_tensor(TA::Range{lobounds, extents}); +} + +// +// note: all the inner tensors (elements of the outer tensor) +// have the same @c inner_rng +// +template < + typename TensorT, + std::enable_if_t, bool> = true> +auto random_tensor(TA::Range const& outer_rng, TA::Range const& inner_rng) { + using InnerTensorT = typename TensorT::value_type; + TensorT result{outer_rng}; + + std::generate(/*std::execution::par,*/ + result.begin(), result.end(), [inner_rng]() { + return random_tensor(inner_rng); + }); + + return result; +} + +template +auto random_tensor(TA::Range const& outer_rng, + std::initializer_list const& inner_extents) { + TA::container::svector lobounds(inner_extents.size(), 0); + return random_tensor(outer_rng, TA::Range(lobounds, inner_extents)); +} + +/// +/// \tparam Array The type of DistArray to be generated. Cannot be cv-qualified +/// or reference type. 
+/// \tparam Args TA::Range type for inner tensor if the tile type of the result +/// is a tensor-of-tensor. +/// \param trange The TiledRange of the result DistArray. +/// \param args Either exactly one TA::Range type when the tile type of Array is +/// tensor-of-tensor or nothing. +/// \return Returns a DistArray of type Array whose elements are randomly +/// generated. +/// @note: +/// - Although DistArrays with Sparse policy can be generated all of their +/// tiles are initialized with random values -- technically the returned value +/// is dense. +/// - In case of arrays with tensor-of-tensor tiles, all the inner tensors have +/// the same rank and the same extent of corresponding modes. +/// +template < + typename Array, typename... Args, + typename = + std::void_t, + std::enable_if_t, + bool> = true> +auto random_array(TA::TiledRange const& trange, Args const&... args) { + static_assert( + (sizeof...(Args) == 0 && + TA::detail::is_tensor_v) || + (sizeof...(Args) == 1) && + (TA::detail::is_tensor_of_tensor_v)); + + using TensorT = typename Array::value_type; + using PolicyT = typename Array::policy_type; + + auto make_tile_meta = [](auto&&... args) { + return [=](TensorT& tile, TA::Range const& rng) { + tile = random_tensor(rng, args...); + if constexpr (std::is_same_v) + return tile.norm(); + }; + }; + + return TA::make_array(TA::get_default_world(), trange, + make_tile_meta(args...)); +} + +template +auto random_array(std::initializer_list> trange, + Args&&... args) { + return random_array(TA::TiledRange(trange), + std::forward(args)...); +} + +/// +/// Succinctly call TA::detail::tensor_contract +/// +/// \tparam T TA::Tensor type. +/// \param einsum_annot Example annot: 'ik,kj->ij', when @c A is annotated by +/// 'i' and 'k' for its two modes, and @c B is annotated by 'k' and 'j' for the +/// same. The result tensor is rank-2 as well and its modes are annotated by 'i' +/// and 'j'. +/// \return Tensor contraction result. +/// +template , bool> = true> +auto tensor_contract(std::string const& einsum_annot, T const& A, T const& B) { + using ::Einsum::string::split2; + auto [ab, aC] = split2(einsum_annot, "->"); + auto [aA, aB] = split2(ab, ","); + + return TA::detail::tensor_contract(A, aA, B, aB, aC); +} + +using PartialPerm = TA::container::svector>; + +template +PartialPerm partial_perm(::Einsum::index::Index const& from, + ::Einsum::index::Index const& to) { + PartialPerm result; + for (auto i = 0; i < from.size(); ++i) + if (auto found = to.find(from[i]); found != to.end()) + result.emplace_back(i, std::distance(to.begin(), found)); + return result; +} + +template >> +void apply_partial_perm(T& to, T const& from, PartialPerm const& p) { + for (auto [f, t] : p) { + TA_ASSERT(f < from.size() && t < to.size() && "Invalid permutation used"); + to[t] = from[f]; + } +} + +/// +/// Example: To represent A("ik;ac") * B("kj;cb") -> C("ij;ab"), +/// construct with std::string("ij;ac,kj;cb->ij;ab"); +/// outer_indices;inner_indices annotates a single object (DistArray, Tensor +/// etc.) 
A_indices,B_indices annotates first(A) and second(B) object +/// '->' separates argument objects' annotation from the result's annotation +/// +class OuterInnerIndices { + // array[0] annotes A + // array[1] annotes B + // array[2] annotes C + std::array outer_, inner_; + + public: + OuterInnerIndices(std::string const& annot) { + using ::Einsum::string::split2; + + constexpr size_t A = 0; + constexpr size_t B = 1; + constexpr size_t C = 2; + + auto [ab, aC] = split2(annot, "->"); + std::tie(outer_[C], inner_[C]) = split2(aC, ";"); + + auto [aA, aB] = split2(ab, ","); + + std::tie(outer_[A], inner_[A]) = split2(aA, ";"); + std::tie(outer_[B], inner_[B]) = split2(aB, ";"); + } + + template + OuterInnerIndices(const char (&s)[N]) : OuterInnerIndices{std::string(s)} {} + + [[nodiscard]] auto const& outer() const noexcept { return outer_; } + [[nodiscard]] auto const& inner() const noexcept { return inner_; } + + [[nodiscard]] auto const& outerA() const noexcept { return outer_[0]; } + [[nodiscard]] auto const& outerB() const noexcept { return outer_[1]; } + [[nodiscard]] auto const& outerC() const noexcept { return outer_[2]; } + [[nodiscard]] auto const& innerA() const noexcept { return inner_[0]; } + [[nodiscard]] auto const& innerB() const noexcept { return inner_[1]; } + [[nodiscard]] auto const& innerC() const noexcept { return inner_[2]; } }; +enum struct TensorProduct { General, Dot, Invalid }; + +struct ProductSetup { + TensorProduct product_type{TensorProduct::Invalid}; + + PartialPerm + // - {} index at kth position in C appears at vth position in A + // and so on... + // - {} is sorted by k + C_to_A, + C_to_B, + I_to_A, // 'I' implies for contracted indices + I_to_B; + size_t // + rank_A, // + rank_B, + rank_C, // + rank_H, + rank_E, // + rank_I; + + // ProductSetup() = default; + + template >> + ProductSetup(T const& aA, T const& aB, T const& aC) { + using Indices = ::Einsum::index::Index; + + struct { + // A, B, C tensor indices + // H, E, I Hadamard, external, and internal indices + Indices A, B, C, H, E, I; + } const ixs{Indices(aA), Indices(aB), + Indices(aC), (ixs.A & ixs.B & ixs.C), + (ixs.A ^ ixs.B), ((ixs.A & ixs.B) - ixs.H)}; + + rank_A = ixs.A.size(); + rank_B = ixs.B.size(); + rank_C = ixs.C.size(); + rank_H = ixs.H.size(); + rank_E = ixs.E.size(); + rank_I = ixs.I.size(); + + C_to_A = partial_perm(ixs.C, ixs.A); + C_to_B = partial_perm(ixs.C, ixs.B); + I_to_A = partial_perm(ixs.I, ixs.A); + I_to_B = partial_perm(ixs.I, ixs.B); + + using TP = decltype(product_type); + + if (rank_A + rank_B != 0 && rank_C != 0) + product_type = TP::General; + else if (rank_A == rank_B && rank_B != 0 && rank_C == 0) + product_type = TP::Dot; + else + product_type = TP::Invalid; + } + + template >> + ProductSetup(ArrayLike const& arr) + : ProductSetup(std::get<0>(arr), std::get<1>(arr), std::get<2>(arr)) {} + + [[nodiscard]] bool valid() const noexcept { + return product_type != TensorProduct::Invalid; + } +}; + +namespace { + +auto make_perm(PartialPerm const& pp) { + TA::container::svector p(pp.size()); + for (auto [k, v] : pp) p[k] = v; + return TA::Permutation(p); +} + +template >> +inline Result general_product(Tensor const& t, typename Tensor::numeric_type s, + ProductSetup const& setup, + Setups const&... 
args) { + static_assert(std::is_same_v); + static_assert(sizeof...(args) == 0, + "To-Do: Only scalar times once-nested tensor supported now"); + return t.scale(s, make_perm(setup.C_to_A).inv()); +} + +template >> +inline Result general_product(typename Tensor::numeric_type s, Tensor const& t, + ProductSetup const& setup, + Setups const&... args) { + static_assert(std::is_same_v); + static_assert(sizeof...(args) == 0, + "To-Do: Only scalar times once-nested tensor supported now"); + return t.scale(s, make_perm(setup.C_to_B).inv()); +} + +} // namespace + +template < + typename Result, typename TensorA, typename TensorB, typename... Setups, + typename = + std::enable_if_t>> +Result general_product(TensorA const& A, TensorB const& B, + ProductSetup const& setup, Setups const&... args) { + using TA::detail::max_nested_rank; + using TA::detail::nested_rank; + + static_assert(std::is_same_v); + + static_assert(max_nested_rank == sizeof...(args) + 1); + + TA_ASSERT(setup.valid()); + + constexpr bool is_tot = max_nested_rank > 1; + + if constexpr (std::is_same_v) { + // + // tensor dot product evaluation + // T * T -> scalar + // ToT * ToT -> scalar + // + static_assert(nested_rank == nested_rank); + + TA_ASSERT(setup.rank_C == 0 && + "Attempted to evaluate dot product when the product setup does " + "not allow"); + + Result result{}; + + for (auto&& ix_A : A.range()) { + TA::Range::index_type ix_B(setup.rank_B, 0); + apply_partial_perm(ix_B, ix_A, setup.I_to_B); + + if constexpr (is_tot) { + auto const& lhs = A(ix_A); + auto const& rhs = B(ix_B); + result += general_product(lhs, rhs, args...); + } else + result += A(ix_A) * B(ix_B); + } + + return result; + } else { + // + // general product: + // T * T -> T + // ToT * T -> ToT + // ToT * ToT -> ToT + // ToT * ToT -> T + // + + static_assert(nested_rank <= max_nested_rank, + "Tensor product not supported with increased nested rank in " + "the result"); + + // creating the contracted TA::Range + TA::Range const rng_I = [&setup, &A, &B]() { + TA::container::svector rng1_I(setup.rank_I, TA::Range1{}); + for (auto [f, t] : setup.I_to_A) + // I_to_A implies I[f] == A[t] + rng1_I[f] = A.range().dim(t); + + return TA::Range(rng1_I); + }(); + + // creating the target TA::Range. + TA::Range const rng_C = [&setup, &A, &B]() { + TA::container::svector rng1_C(setup.rank_C, TA::Range1{0, 0}); + for (auto [f, t] : setup.C_to_A) + // C_to_A implies C[f] = A[t] + rng1_C[f] = A.range().dim(t); + + for (auto [f, t] : setup.C_to_B) + // C_to_B implies C[f] = B[t] + rng1_C[f] = B.range().dim(t); + + auto zero_r1 = [](TA::Range1 const& r) { return r == TA::Range1{0, 0}; }; + + TA_ASSERT(std::none_of(rng1_C.begin(), rng1_C.end(), zero_r1)); + + return TA::Range(rng1_C); + }(); + + Result C{rng_C}; + + // do the computation + for (auto ix_C : rng_C) { + // finding corresponding indices of A, and B. + TA::Range::index_type ix_A(setup.rank_A, 0), ix_B(setup.rank_B, 0); + apply_partial_perm(ix_A, ix_C, setup.C_to_A); + apply_partial_perm(ix_B, ix_C, setup.C_to_B); + + if (setup.rank_I == 0) { + if constexpr (is_tot) { + C(ix_C) = general_product( + A(ix_A), B(ix_B), args...); + } else { + TA_ASSERT(!(ix_A.empty() && ix_B.empty())); + C(ix_C) = ix_A.empty() ? B(ix_B) + : ix_B.empty() ? 
A(ix_B) + : A(ix_A) * B(ix_B); + } + } else { + typename Result::value_type temp{}; + for (auto ix_I : rng_I) { + apply_partial_perm(ix_A, ix_I, setup.I_to_A); + apply_partial_perm(ix_B, ix_I, setup.I_to_B); + if constexpr (is_tot) + temp += general_product( + A(ix_A), B(ix_B), args...); + else { + TA_ASSERT(!(ix_A.empty() || ix_B.empty())); + temp += A(ix_A) * B(ix_B); + } + } + C(ix_C) = temp; + } + } + + return C; + } +} + +template +auto general_product(TA::DistArray A, + TA::DistArray B, + ProductSetup const& setup, Setups const&... args) { + using TA::detail::max_nested_rank; + using TA::detail::nested_rank; + static_assert(nested_rank <= max_nested_rank); + static_assert(nested_rank != 0); + TA_ASSERT(setup.product_type == TensorProduct::General); + + auto& world = TA::get_default_world(); + + A.make_replicated(); + B.make_replicated(); + world.gop.fence(); + + TA::Tensor tensorA{A.trange().tiles_range()}; + for (auto&& ix : tensorA.range()) tensorA(ix) = A.find_local(ix).get(false); + + TA::Tensor tensorB{B.trange().tiles_range()}; + for (auto&& ix : tensorB.range()) tensorB(ix) = B.find_local(ix).get(false); + + auto result_tensor = general_product>( + tensorA, tensorB, setup, setup, args...); + + TA::TiledRange result_trange; + { + auto const rank = result_tensor.range().rank(); + auto const result_range = result_tensor.range(); + + TA::container::svector> tr1s(rank, {0}); + + TA::container::svector const ix_hi(result_range.upbound()); + for (auto d = 0; d < rank; ++d) { + TA::container::svector ix(result_range.lobound()); + for (auto& i = ix[d]; i < ix_hi[d]; ++i) { + auto const& elem_tensor = result_tensor(ix); + auto& tr1 = tr1s[d]; + tr1.emplace_back(tr1.back() + elem_tensor.range().extent(d)); + } + } + + TA::container::svector tr1s_explicit; + tr1s_explicit.reserve(tr1s.size()); + for (auto const& v : tr1s) tr1s_explicit.emplace_back(v.begin(), v.end()); + + result_trange = TA::TiledRange(tr1s_explicit); + } + + TA::DistArray C(world, result_trange); + + for (auto it : C) { + if (C.is_local(it.index())) it = result_tensor(it.index()); + } + return C; +} + +template +auto general_product(TA::DistArray A, + TA::DistArray B, + Setups const&... 
args) { + using TA::detail::nested_rank; + using TileC = std::conditional_t<(nested_rank > nested_rank), + TileB, TileA>; + return general_product(A, B, args...); +} + +template >> +auto manual_eval(OuterInnerIndices const& oixs, ArrayA A, ArrayB B) { + constexpr auto mnr = TA::detail::max_nested_rank; + static_assert(mnr == 1 || mnr == 2); + + auto const outer_setup = ProductSetup(oixs.outer()); + + TA_ASSERT(outer_setup.valid()); + + if constexpr (mnr == 2) { + auto const inner_setup = ProductSetup(oixs.inner()); + TA_ASSERT(inner_setup.valid()); + if constexpr (DeNestFlag == DeNest::True) { + // reduced nested rank in result + using TA::detail::nested_rank; + static_assert(nested_rank == nested_rank); + TA_ASSERT(inner_setup.rank_C == 0); + using TileC = typename ArrayA::value_type::value_type; + return general_product(A, B, outer_setup, inner_setup); + } else + return general_product(A, B, outer_setup, inner_setup); + } else { + return general_product(A, B, outer_setup); + } +} + +#ifdef TILEDARRAY_HAS_BTAS + +template >> +auto tensor_to_btas_tensor(T const& ta_tensor) { + using value_type = typename T::value_type; + using range_type = typename T::range_type; + + btas::Tensor result{ta_tensor.range()}; + TA::tensor_to_btas_subtensor(ta_tensor, result); + return result; +} + +template >> +auto btas_tensor_to_tensor( + btas::Tensor const& btas_tensor) { + TA::Tensor result{TA::Range(btas_tensor.range())}; + TA::btas_subtensor_to_tensor(btas_tensor, result); + return result; +} + +/// +/// @c einsum_annot pattern example: 'ik,kj->ij'. See tensor_contract function. +/// +template , bool> = true> +auto tensor_contract_btas(std::string const& einsum_annot, T const& A, + T const& B) { + using ::Einsum::string::split2; + auto [ab, aC] = split2(einsum_annot, "->"); + auto [aA, aB] = split2(ab, ","); + + using NumericT = typename T::numeric_type; + + struct { + btas::Tensor A, B, C; + } btas_tensor{tensor_to_btas_tensor(A), tensor_to_btas_tensor(B), {}}; + + btas::contract(NumericT{1}, btas_tensor.A, aA, btas_tensor.B, aB, NumericT{0}, + btas_tensor.C, aC); + + return btas_tensor_to_tensor(btas_tensor.C); +} + +/// +/// \tparam T TA::Tensor type +/// \param einsum_annot see tensor_contract_mult +/// \return True when TA::detail::tensor_contract and btas::contract result the +/// result. Performs bitwise comparison. +/// +template >> +auto tensor_contract_equal(std::string const& einsum_annot, T const& A, + T const& B) { + T result_ta = tensor_contract(einsum_annot, A, B); + T result_btas = tensor_contract_btas(einsum_annot, A, B); + return result_ta == result_btas; +} + +#endif /* * @@ -244,8 +803,8 @@ struct ToTArrayFixture { * * TODO: pmap comparisons */ - template + template static bool are_equal(const DistArray& lhs, const DistArray& rhs) { // Same type