Outer mode contraction in tot made more efficient #495

Merged · 3 commits · Nov 21, 2024
183 changes: 95 additions & 88 deletions src/TiledArray/einsum/tiledarray.h
@@ -420,6 +420,9 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
using ResultTensor = typename ArrayC::value_type;
using ResultShape = typename ArrayC::shape_type;

auto const& tnsrExprA = A;
auto const& tnsrExprB = B;

auto a = std::get<0>(Einsum::idx(A));
auto b = std::get<0>(Einsum::idx(B));
Einsum::Index<std::string> c = std::get<0>(cs);
@@ -536,16 +539,10 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
// the evaluation can be delegated to the expression layer
// for distarrays of both nested and non-nested tensor tiles.
// *) If no Hadamard indices are present (!h) the evaluation
// can be delegated to the expression _only_ for distarrays with
// non-nested tensor tiles.
// This is because even if Hadamard indices are not present, a contracted
// index might be present pertinent to the outer tensor in case of a
// nested-tile distarray, which is especially handled within this
// function because expression layer cannot handle that yet.
// can be delegated to the expression layer.
//
if ((h && !(i || e)) // pure Hadamard
|| (IsArrayToT<ArrayC> && !(i || h)) // ToT result from outer-product
|| (IsArrayT<ArrayC> && !h)) // T from general product without Hadamard
if ((h && !(i || e)) // pure Hadamard
|| !h) // no Hadamard
{
ArrayC C;
C(std::string(c) + inner.c) = A * B;
@@ -577,21 +574,6 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
return C;
}

//
// when contraction happens in the outer tensor
// need to evaluate specially..
//
if (IsArrayToT<ArrayC> && i.size() > 0) {
auto annot_c = std::string(h + e + i) + inner.c;
auto temp1 = einsum(A, B, idx<ArrayC>(annot_c), world);
auto temp2 = reduce_modes(temp1, i.size());

auto annot_c_ = std::string(h + e) + inner.c;
decltype(temp2) result;
result(std::string(c) + inner.c) = temp2(annot_c_);
return result;
}

using ::Einsum::index::permutation;
using TiledArray::Permutation;

@@ -640,79 +622,104 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,

using Index = Einsum::Index<size_t>;

if constexpr (AreArraySame<ArrayA, ArrayB> &&
AreArraySame<ArrayB, ArrayC>) {
if (!e) { // hadamard reduction
auto &[A, B] = AB;
TiledRange trange(range_map[i]);
RangeProduct tiles;
for (auto idx : i) {
tiles *= Range(range_map[idx].tiles_range());
if (!e) { // hadamard reduction
auto &[A, B] = AB;
TiledRange trange(range_map[i]);
RangeProduct tiles;
for (auto idx : i) {
tiles *= Range(range_map[idx].tiles_range());
}
auto pa = A.permutation;
auto pb = B.permutation;
for (Index h : H.tiles) {
if (!C.array.is_local(h)) continue;
size_t batch = 1;
for (size_t i = 0; i < h.size(); ++i) {
batch *= H.batch[i].at(h[i]);
}
auto pa = A.permutation;
auto pb = B.permutation;
for (Index h : H.tiles) {
if (!C.array.is_local(h)) continue;
size_t batch = 1;
for (size_t i = 0; i < h.size(); ++i) {
batch *= H.batch[i].at(h[i]);
}
ResultTensor tile(TiledArray::Range{batch},
typename ResultTensor::value_type{});
for (Index i : tiles) {
// skip this unless both input tiles exist
const auto pahi_inv = apply_inverse(pa, h + i);
const auto pbhi_inv = apply_inverse(pb, h + i);
if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv))
continue;

auto ai = A.array.find(pahi_inv).get();
auto bi = B.array.find(pbhi_inv).get();
if (pa) ai = ai.permute(pa);
if (pb) bi = bi.permute(pb);
auto shape = trange.tile(i);
ai = ai.reshape(shape, batch);
bi = bi.reshape(shape, batch);
for (size_t k = 0; k < batch; ++k) {
using Ix = ::Einsum::Index<std::string>;
if constexpr (AreArrayToT<ArrayA, ArrayB>) {
auto aik = ai.batch(k);
auto bik = bi.batch(k);
auto vol = aik.total_size();
TA_ASSERT(vol == bik.total_size());

auto &el = tile({k});
using TensorT = std::remove_reference_t<decltype(el)>;

auto mult_op = [&inner](auto const &l,
auto const &r) -> TensorT {
return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r,
inner.B, inner.C)
: TA::detail::tensor_contract(
l, inner.A, r, inner.B, inner.C);
};

for (auto i = 0; i < vol; ++i)
el.add_to(mult_op(aik.data()[i], bik.data()[i]));

} else {
auto hk = ai.batch(k).dot(bi.batch(k));
tile({k}) += hk;
}
ResultTensor tile(TiledArray::Range{batch},
typename ResultTensor::value_type{});
for (Index i : tiles) {
// skip this unless both input tiles exist
const auto pahi_inv = apply_inverse(pa, h + i);
const auto pbhi_inv = apply_inverse(pb, h + i);
if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;

auto ai = A.array.find(pahi_inv).get();
auto bi = B.array.find(pbhi_inv).get();
if (pa) ai = ai.permute(pa);
if (pb) bi = bi.permute(pb);
auto shape = trange.tile(i);
ai = ai.reshape(shape, batch);
bi = bi.reshape(shape, batch);
for (size_t k = 0; k < batch; ++k) {
using Ix = ::Einsum::Index<std::string>;
if constexpr (AreArrayToT<ArrayA, ArrayB>) {
auto aik = ai.batch(k);
auto bik = bi.batch(k);
auto vol = aik.total_size();
TA_ASSERT(vol == bik.total_size());

auto &el = tile({k});
using TensorT = std::remove_reference_t<decltype(el)>;

auto mult_op = [&inner](auto const &l, auto const &r) -> TensorT {
if (l.empty() || r.empty()) return TensorT{};
return inner.h ? TA::detail::tensor_hadamard(l, inner.A, r,
inner.B, inner.C)
: TA::detail::tensor_contract(l, inner.A, r,
inner.B, inner.C);
};

for (auto i = 0; i < vol; ++i)
el.add_to(mult_op(aik.data()[i], bik.data()[i]));

} else if constexpr (!AreArraySame<ArrayA, ArrayB>) {
auto aik = ai.batch(k);
auto bik = bi.batch(k);
auto vol = aik.total_size();
TA_ASSERT(vol == bik.total_size());

auto &el = tile({k});

for (auto i = 0; i < vol; ++i)
if constexpr (IsArrayToT<ArrayA>) {
el.add_to(aik.data()[i].scale(bik.data()[i]));
} else {
el.add_to(bik.data()[i].scale(aik.data()[i]));
}

} else {
auto hk = ai.batch(k).dot(bi.batch(k));
tile({k}) += hk;
}
}
auto pc = C.permutation;
auto shape = apply_inverse(pc, C.array.trange().tile(h));
tile = tile.reshape(shape);
if (pc) tile = tile.permute(pc);
C.array.set(h, tile);
}
return C.array;
auto pc = C.permutation;
auto shape = apply_inverse(pc, C.array.trange().tile(h));
tile = tile.reshape(shape);
if (pc) tile = tile.permute(pc);
C.array.set(h, tile);
}
return C.array;
}

// generalized contraction

if constexpr (IsArrayToT<ArrayC>) {
if (inner.C != inner.h + inner.e) {
// when the inner tensor permutation is non-trivial (this recursion could
// potentially be elided by extending this function (@c einsum) to take
// inner tensor permutations into account)
auto temp_annot = std::string(c) + ";" + std::string(inner.h + inner.e);
ArrayC temp = einsum(tnsrExprA, tnsrExprB,
Einsum::idx<ArrayC>(temp_annot), world);
ArrayC result;
result(std::string(c) + inner.c) = temp(temp_annot);
return result;
}
}

auto update_tr = [&e = std::as_const(e), &i = std::as_const(i),
&range_map = std::as_const(range_map)](auto &term) {
auto ei = (e + i & term.idx);
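The net effect of the tiledarray.h changes above is that einsum() now delegates to the expression layer whenever no Hadamard outer index is present, even for nested (ToT) tiles with a contracted outer mode, which previously required the reduce_modes-based special case removed in this PR. Below is a minimal standalone restatement of the new dispatch predicate, as a hypothetical helper (the real code tests the Hadamard/external/contracted index sets h, e, i directly for emptiness):

#include <cstddef>

// Hypothetical restatement of `(h && !(i || e)) || !h` from einsum() above;
// h, e, i are the sizes of the Hadamard, external, and contracted outer
// index sets.
bool delegate_to_expression_layer(std::size_t h, std::size_t e, std::size_t i) {
  const bool pure_hadamard = h != 0 && e == 0 && i == 0;  // every outer index appears in A, B, and C
  const bool no_hadamard = h == 0;  // includes outer-mode contractions of ToT arrays
  return pure_hadamard || no_hadamard;
}

Everything that still needs the hand-rolled loops (a Hadamard index combined with contracted or external indices) falls through to the Hadamard-reduction and generalized-contraction branches above.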
54 changes: 47 additions & 7 deletions src/TiledArray/expressions/cont_engine.h
@@ -279,25 +279,62 @@ class ContEngine : public BinaryEngine<Derived> {
outer_size(left_indices_), outer_size(right_indices_),
(!implicit_permute_outer_ ? std::move(outer_perm) : Permutation{}));
} else {

auto make_total_perm = [this]() -> BipartitePermutation {
if (this->product_type() != TensorProduct::Contraction
|| this->implicit_permute_inner_)
return this->implicit_permute_outer_
? BipartitePermutation()
: BipartitePermutation(outer(this->perm_));

// Here,
// this->product_type() is TensorProduct::Contraction, and
// this->implicit_permute_inner_ is false

return this->inner_product_type() == TensorProduct::Scale
? BipartitePermutation(outer(this->perm_))
: this->perm_;
};

auto total_perm = make_total_perm();

// factor_ is absorbed into inner_tile_nonreturn_op_
op_ = op_type(
left_op, right_op, scalar_type(1), outer_size(indices_),
outer_size(left_indices_), outer_size(right_indices_),
(!implicit_permute_outer_ ? std::move(outer_perm) : Permutation{}),
total_perm,
this->element_nonreturn_op_);
}
trange_ = ContEngine_::make_trange(outer_perm);
shape_ = ContEngine_::make_shape(outer_perm);
} else {
// Initialize non-permuted structure

if constexpr (!TiledArray::detail::is_tensor_of_tensor_v<value_type>) {
op_ = op_type(left_op, right_op, factor_, outer_size(indices_),
outer_size(left_indices_), outer_size(right_indices_));
} else {

auto make_total_perm = [this]() -> BipartitePermutation {
if (this->product_type() != TensorProduct::Contraction
|| this->implicit_permute_inner_)
return {};

// Here,
// this->product_type() is TensorProduct::Contraction, and
// this->implicit_permute_inner_ is false

return this->inner_product_type() == TensorProduct::Scale
? BipartitePermutation(outer(this->perm_))
: this->perm_;
};

auto total_perm = make_total_perm();

// factor_ is absorbed into inner_tile_nonreturn_op_
op_ = op_type(left_op, right_op, scalar_type(1), outer_size(indices_),
outer_size(left_indices_), outer_size(right_indices_),
BipartitePermutation{}, this->element_nonreturn_op_);
total_perm, this->element_nonreturn_op_);
}
trange_ = ContEngine_::make_trange();
shape_ = ContEngine_::make_shape();
@@ -509,12 +546,15 @@ class ContEngine : public BinaryEngine<Derived> {
inner_size(this->left_indices_),
inner_size(this->right_indices_));
this->element_nonreturn_op_ =
[contrreduce_op](result_tile_element_type& result,
const left_tile_element_type& left,
const right_tile_element_type& right) {
[contrreduce_op, permute_inner = this->product_type() !=
TensorProduct::Contraction](
result_tile_element_type& result,
const left_tile_element_type& left,
const right_tile_element_type& right) {
contrreduce_op(result, left, right);
if (!TA::empty(result))
result = contrreduce_op(result); // permutations of result are applied as "postprocessing"
// permutations of result are applied as "postprocessing"
if (permute_inner && !TA::empty(result))
result = contrreduce_op(result);
};
} // ToT x ToT
} else if (inner_prod == TensorProduct::Hadamard) {
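The two make_total_perm lambdas above decide how much of the result permutation is fused into the contraction op for tensor-of-tensor tiles: for an outer contraction whose inner permutation is not applied implicitly, the full bipartite (outer;inner) permutation is now passed to the op, so the inner permutation no longer has to be applied as a separate postprocessing pass. A hypothetical restatement of the first lambda (the permuted-structure branch), with simplified stand-in types instead of the engine's actual classes:

// Stand-ins for TensorProduct and BipartitePermutation, used only to restate
// the decision; the real code returns the permutation objects themselves.
enum class Product { Contraction, Hadamard, Scale, General };
enum class FusedPerm { None, OuterOnly, OuterAndInner };

FusedPerm make_total_perm(Product outer_product, Product inner_product,
                          bool implicit_permute_outer,
                          bool implicit_permute_inner) {
  // Not an outer contraction, or the inner permutation is applied implicitly:
  // keep the pre-PR behaviour of fusing at most the outer permutation.
  if (outer_product != Product::Contraction || implicit_permute_inner)
    return implicit_permute_outer ? FusedPerm::None : FusedPerm::OuterOnly;
  // Outer contraction with an explicit inner permutation: fuse the whole
  // bipartite permutation unless the inner product is a plain scale.
  return inner_product == Product::Scale ? FusedPerm::OuterOnly
                                         : FusedPerm::OuterAndInner;
}

Correspondingly, element_nonreturn_op_ only applies the inner permutation as postprocessing when the product is not a contraction, matching the permute_inner capture added above.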
11 changes: 6 additions & 5 deletions src/TiledArray/tensor/kernels.h
@@ -417,14 +417,15 @@ inline void inplace_tensor_op(Op&& op, TR& result, const Ts&... tensors) {
TA_ASSERT(!empty(result, tensors...));
TA_ASSERT(is_range_set_congruent(result, tensors...));

const auto volume = result.range().volume();

for (decltype(result.range().volume()) ord = 0ul; ord < volume; ++ord) {
auto volume = result.total_size();
for (decltype(volume) ord = 0; ord < volume; ++ord) {
if constexpr (is_tensor_of_tensor_v<TR, Ts...>)
if (((tensors.data()[ord].range().volume() == 0) || ...)) continue;
if constexpr (std::is_invocable_r_v<void, Op, typename TR::value_type&,
typename Ts::value_type...>)
op(result.at_ordinal(ord), tensors.at_ordinal(ord)...);
op(result.data()[ord], tensors.data()[ord]...);
else
inplace_tensor_op(op, result.at_ordinal(ord), tensors.at_ordinal(ord)...);
inplace_tensor_op(op, result.data()[ord], tensors.data()[ord]...);
}
}

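The inplace_tensor_op change above walks the flat data() buffer by ordinal and, when the operands are tensors of tensors, skips positions where any input inner tensor has zero volume instead of feeding empty tiles to the element op. A plain-C++ analogue of that guard (hypothetical, using nested std::vector as a stand-in for a tensor of tensors):

#include <cstddef>
#include <vector>

// Element-wise in-place accumulation that skips empty inner "tiles",
// analogous to the zero-volume check added to inplace_tensor_op above.
void add_in_place(std::vector<std::vector<double>>& result,
                  const std::vector<std::vector<double>>& arg) {
  const std::size_t volume = result.size();
  for (std::size_t ord = 0; ord < volume; ++ord) {
    if (arg[ord].empty()) continue;             // nothing to accumulate here
    auto& r = result[ord];
    if (r.empty()) { r = arg[ord]; continue; }  // empty destination: copy
    for (std::size_t k = 0; k < r.size() && k < arg[ord].size(); ++k)
      r[k] += arg[ord][k];
  }
}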
1 change: 1 addition & 0 deletions src/TiledArray/tensor/tensor.h
@@ -1630,6 +1630,7 @@ class Tensor {
template <typename Right,
typename std::enable_if<is_tensor<Right>::value>::type* = nullptr>
Tensor add(const Right& right) const& {
if (right.empty()) return *this;
return binary(
right,
[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {
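The one-line guard added to Tensor::add above turns addition with an empty right-hand side into a no-op that returns a copy of *this, consistent with the empty-tile handling added elsewhere in this PR. A hypothetical usage sketch (the constructor forms are assumptions, not part of this diff):

#include <cassert>
#include <tiledarray.h>

int main() {
  TA::Tensor<double> t(TA::Range{2, 2}, 1.0);  // 2x2 tensor filled with ones
  TA::Tensor<double> empty;                    // default-constructed: no range, no data

  // With this PR the empty right-hand side is detected up front and a copy of
  // t is returned instead of attempting an element-wise binary op on it.
  auto sum = t.add(empty);
  assert(!sum.empty() && sum.range() == t.range());
  return 0;
}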
1 change: 0 additions & 1 deletion src/TiledArray/tile_op/contract_reduce.h
@@ -332,7 +332,6 @@ class ContractReduce : public ContractReduceBase<Result, Left, Right, Scalar> {

if constexpr (!ContractReduceBase_::plain_tensors) {
TA_ASSERT(this->elem_muladd_op());
// not yet implemented
gemm(result, left, right, ContractReduceBase_::gemm_helper(),
this->elem_muladd_op());
} else { // plain tensors