Skip to content

Commit

Permalink
[WIP]: Make binary_egine less restrictive on left and right arg types.
Browse files Browse the repository at this point in the history
  • Loading branch information
bimalgaudel committed Nov 20, 2023
1 parent a60315d commit a9a6b58
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
21 changes: 12 additions & 9 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ auto einsum(expressions::TsrExpr<ArrayT_> A, expressions::TsrExpr<ArrayToT_> B,
Einsum::Index<std::string> c = std::get<0>(cs);

struct {
std::string a, b, c;
std::string b, c;
} inner;
if constexpr (std::tuple_size<decltype(cs)>::value == 2) {
inner.b = ";" + (std::string)std::get<1>(Einsum::idx(B));
Expand All @@ -319,16 +319,13 @@ auto einsum(expressions::TsrExpr<ArrayT_> A, expressions::TsrExpr<ArrayToT_> B,
// these are "Hadamard" (fused) indices
auto h = a & b & c;

auto e = (a ^ b);
// contracted indices
auto i = (a & b) - h;
// contraction not allowed in tensor x tensor-of-tensor
TA_ASSERT(!i);

// cannot be hadamard reduction type operation for this overload
TA_ASSERT(e);

// no Hadamard indices => standard contraction (or even outer product)
// same a, b, and c => pure Hadamard
TA_ASSERT(!h || (!(a ^ b) && !(b ^ c)));
// indices exclusively in 'a' or exclusively in 'b'
auto e = (a ^ b);

// maps Index to TiledRange1
// (asserts same index maps to the same TR1 in A, and B)
Expand Down Expand Up @@ -364,6 +361,9 @@ auto einsum(expressions::TsrExpr<ArrayT_> A, expressions::TsrExpr<ArrayToT_> B,
}
C.expr = e;

arrayTermB.expr += inner.b;
C.expr += inner.c;

struct {
RangeProduct tiles;
std::vector<std::vector<size_t>> batch;
Expand Down Expand Up @@ -453,7 +453,10 @@ auto einsum(expressions::TsrExpr<ArrayT_> A, expressions::TsrExpr<ArrayToT_> B,
}

// todo
// C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners);
C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners);

//

A.ei.defer_deleter_to_next_fence();
B.ei.defer_deleter_to_next_fence();
A.ei = ArrayT();
Expand Down
19 changes: 16 additions & 3 deletions src/TiledArray/expressions/binary_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,10 @@ class BinaryEngine : public ExprEngine<Derived> {
TiledArray::detail::is_tensor_of_tensor_v<left_tile_type>;
constexpr bool right_tile_is_tot =
TiledArray::detail::is_tensor_of_tensor_v<right_tile_type>;
static_assert(!(left_tile_is_tot ^ right_tile_is_tot),
"ContEngine can only handle tensors of same nested-ness "
"(both plain or both ToT)");
constexpr bool args_are_plain_tensors =
!left_tile_is_tot && !right_tile_is_tot;
constexpr bool args_are_mixed_tensors =
left_tile_is_tot ^ right_tile_is_tot;
if (args_are_plain_tensors &&
(left_outer_permtype_ == PermutationType::matrix_transpose ||
left_outer_permtype_ == PermutationType::identity)) {
Expand All @@ -175,6 +174,20 @@ class BinaryEngine : public ExprEngine<Derived> {
right_inner_permtype_ == PermutationType::identity))) {
right_.permute_tiles(false);
}
if (args_are_mixed_tensors &&
((left_outer_permtype_ == PermutationType::matrix_transpose ||
left_outer_permtype_ == PermutationType::identity) ||
(left_inner_permtype_ == PermutationType::matrix_transpose ||
left_inner_permtype_ == PermutationType::identity))) {
left_.permute_tiles(false);
}
if (args_are_mixed_tensors &&
((left_outer_permtype_ == PermutationType::matrix_transpose ||
left_outer_permtype_ == PermutationType::identity) ||
(right_inner_permtype_ == PermutationType::matrix_transpose ||
right_inner_permtype_ == PermutationType::identity))) {
right_.permute_tiles(false);
}
}

public:
Expand Down

0 comments on commit a9a6b58

Please sign in to comment.