
Commit bff7d28

expression-level support for ToT x T (and vice versa) implemented, need to test

evaleev committed Nov 21, 2023
1 parent c199457 commit bff7d28
Showing 3 changed files with 58 additions and 15 deletions.
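In short: after this commit the expression layer accepts a tensor-of-tensors (ToT) times a plain tensor (and vice versa), classifying the inner product as the new TensorProduct::Scale kind so that each plain-tensor element scales the matching inner tile. A minimal usage sketch, borrowing lhs, rhs, and tot_type from the new test case in tests/einsum.cpp below:

    // lhs("i,j;m,n") is a DistArray<Tensor<Tensor<double>>>,
    // rhs("k,l") is a plain DistArray<Tensor<double>>.
    // Each inner tile of lhs is scaled by a scalar element of rhs; the inner
    // annotation ";n,m" additionally permutes the inner (m,n) indices.
    tot_type result;
    result("i,l,k,j;n,m") = lhs("i,j;m,n") * rhs("k,l");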
19 changes: 10 additions & 9 deletions src/TiledArray/expressions/cont_engine.h
@@ -158,9 +158,10 @@ class ContEngine : public BinaryEngine<Derived> {
   TensorProduct inner_product_type() const {
     TA_ASSERT(inner_product_type_ !=
               TensorProduct::Invalid);  // init_indices() must initialize this
-    /// only Hadamard and contraction are supported now
+    /// only Hadamard, contraction, and scale are supported now
     TA_ASSERT(inner_product_type_ == TensorProduct::Hadamard ||
-              inner_product_type_ == TensorProduct::Contraction);
+              inner_product_type_ == TensorProduct::Contraction ||
+              inner_product_type_ == TensorProduct::Scale);
     return inner_product_type_;
   }

@@ -473,7 +474,8 @@ class ContEngine : public BinaryEngine<Derived> {
             result_tile_type, left_tile_type, right_tile_type>;
     const auto inner_prod = this->inner_product_type();
     TA_ASSERT(inner_prod == TensorProduct::Contraction ||
-              inner_prod == TensorProduct::Hadamard);
+              inner_prod == TensorProduct::Hadamard ||
+              inner_prod == TensorProduct::Scale);
     if (inner_prod == TensorProduct::Contraction) {
       TA_ASSERT(tot_x_tot);
       if constexpr (tot_x_tot) {
@@ -577,8 +579,8 @@ class ContEngine : public BinaryEngine<Derived> {
             }
           };
         }
-      }  // ToT x ToT
-    } else if (inner_prod == TensorProduct::General) {
+      }  // ToT x T or T x ToT
+    } else if (inner_prod == TensorProduct::Scale) {
       TA_ASSERT(!tot_x_tot);
       constexpr bool tot_x_t =
           TiledArray::detail::is_tensor_of_tensor_v<result_tile_type,
@@ -596,20 +598,19 @@ class ContEngine : public BinaryEngine<Derived> {
           std::conditional_t<tot_x_t, right_tile_element_type,
                              left_tile_element_type>;

-      auto scal_op = [do_perm = this->permute_tiles_,
-                      perm = this->permute_tiles_ ? inner(this->perm_)
+      auto scal_op = [perm = this->permute_tiles_ ? inner(this->perm_)
                                                   : Permutation{}](
                          const left_tile_element_type& left,
                          const right_tile_element_type& right)
           -> result_tile_element_type {
        using TiledArray::scale;
        if constexpr (tot_x_t) {
-          if (do_perm)
+          if (perm)
            return scale(left, right, perm);
          else
            return scale(left, right);
-        } else if constexpr (t_x_tot) {
-          if (do_perm)
+        } else if constexpr (t_x_tot) {
+          if (perm)
            return scale(right, left, perm);
          else
            return scale(right, left);
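At the element level the Scale branch delegates to TiledArray::scale: the plain operand contributes a scalar factor, and the optional inner permutation is applied in the same call. A minimal standalone sketch of that per-element computation, assuming the same free-function overloads invoked by scal_op above (the tensor values here are purely illustrative):

    using TiledArray::Tensor;
    using TiledArray::Range;
    using TiledArray::Permutation;
    using TiledArray::scale;

    Tensor<double> inner_tile(Range{2, 3}, {1, 2, 3, 4, 5, 6});
    double factor = 2.0;      // one element of the plain-tensor operand
    Permutation perm{1, 0};   // inner-index permutation, e.g. ";m,n" -> ";n,m"
    auto scaled = scale(inner_tile, factor);             // scale only
    auto scaled_perm = scale(inner_tile, factor, perm);  // scale, then permute

Note the simplified capture: instead of a separate do_perm flag, the lambda now captures only perm and relies on Permutation's boolean conversion (a default-constructed Permutation tests false), so a single capture serves as both the flag and the payload.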
5 changes: 4 additions & 1 deletion src/TiledArray/expressions/product.h
@@ -39,6 +39,9 @@ enum class TensorProduct {
   Contraction,
   /// free, fused, and contracted indices
   General,
+  /// no indices on one, free indices on the other; only used for inner index
+  /// products in mixed nested products (ToT x T)
+  Scale,
   /// invalid
   Invalid = -1
 };
@@ -59,7 +62,7 @@ inline TensorProduct compute_product_type(const IndexList& left_indices,
     result = TensorProduct::Contraction;
   } else if ((left_indices && !right_indices) ||
              (!left_indices && right_indices)) {  // used for ToT*T or T*ToT
-    result = TensorProduct::General;
+    result = TensorProduct::Scale;
   }
   return result;
 }
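With product.h updated, compute_product_type now reports Scale whenever exactly one side carries indices, which is how the engine recognizes the mixed ToT x T case on the inner (";...") indices. A small sketch of the classification, assuming IndexList is constructible from an index string and that these names live in TiledArray::expressions, as the file location suggests:

    using TiledArray::expressions::IndexList;
    using TiledArray::expressions::TensorProduct;
    using TiledArray::expressions::compute_product_type;

    IndexList lhs_inner("m,n");  // inner indices of lhs("i,j;m,n")
    IndexList rhs_inner;         // rhs("k,l") is plain: no inner indices
    auto type = compute_product_type(lhs_inner, rhs_inner);
    // type == TensorProduct::Scale: the plain operand acts as a
    // per-inner-tile scalar rather than a Hadamard or contraction partner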
49 changes: 44 additions & 5 deletions tests/einsum.cpp
@@ -718,6 +718,49 @@ BOOST_AUTO_TEST_SUITE_END()  // einsum_tot

 BOOST_AUTO_TEST_SUITE(einsum_tot_t)

+BOOST_AUTO_TEST_CASE(ilkj_nm_eq_ij_mn_times_kl) {
+  using t_type = DistArray<Tensor<double>, SparsePolicy>;
+  using tot_type = DistArray<Tensor<Tensor<double>>, SparsePolicy>;
+  using matrix_il = TiledArray::detail::matrix_il<Tensor<double>>;
+  auto& world = TiledArray::get_default_world();
+  Tensor<double> lhs_elem_0_0(
+      Range{7, 2}, {49, 73, 28, 46, 12, 83, 29, 61, 61, 98, 57, 28, 96, 57});
+  Tensor<double> lhs_elem_0_1(
+      Range{7, 2}, {78, 15, 69, 55, 87, 94, 28, 94, 79, 30, 26, 88, 48, 74});
+  Tensor<double> lhs_elem_1_0(
+      Range{7, 2}, {70, 32, 25, 71, 6, 56, 4, 13, 72, 50, 15, 95, 52, 89});
+  Tensor<double> lhs_elem_1_1(
+      Range{7, 2}, {12, 29, 17, 68, 37, 79, 5, 52, 13, 35, 53, 54, 78, 71});
+  Tensor<double> lhs_elem_2_0(
+      Range{7, 2}, {77, 39, 34, 94, 16, 82, 63, 27, 75, 12, 14, 59, 3, 14});
+  Tensor<double> lhs_elem_2_1(
+      Range{7, 2}, {65, 90, 37, 41, 65, 75, 59, 16, 44, 85, 86, 11, 40, 24});
+  Tensor<double> lhs_elem_3_0(
+      Range{7, 2}, {77, 53, 11, 6, 99, 63, 46, 68, 83, 56, 76, 86, 91, 79});
+  Tensor<double> lhs_elem_3_1(
+      Range{7, 2}, {56, 11, 33, 90, 36, 38, 33, 54, 60, 21, 16, 28, 6, 97});
+  matrix_il lhs_il{{lhs_elem_0_0, lhs_elem_0_1},
+                   {lhs_elem_1_0, lhs_elem_1_1},
+                   {lhs_elem_2_0, lhs_elem_2_1},
+                   {lhs_elem_3_0, lhs_elem_3_1}};
+  TiledRange lhs_trange{{0, 2, 4}, {0, 2}};
+  tot_type lhs(world, lhs_trange, lhs_il);
+
+  TiledRange rhs_trange{{0, 2}, {0, 2, 4, 6}};
+  t_type rhs(world, rhs_trange);
+  rhs.fill_random();
+
+  TiledRange ref_result_trange{lhs_trange.dim(0), rhs_trange.dim(1),
+                               rhs_trange.dim(0), lhs_trange.dim(1)};
+  tot_type ref_result(world, ref_result_trange);
+  // TODO compute ref_result
+
+  tot_type result;
+  BOOST_REQUIRE_NO_THROW(result("i,l,k,j;n,m") = lhs("i,j;m,n") * rhs("k,l"));
+
+  // TODO check result against ref_result
+}
+
 BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) {
   using t_type = DistArray<Tensor<double>, SparsePolicy>;
   using tot_type = DistArray<Tensor<Tensor<double>>, SparsePolicy>;
@@ -764,11 +807,7 @@ BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) {
   //  tot_type result;
   //  BOOST_REQUIRE_NO_THROW(result("i,k,j;m,n") = lhs("i,j;m,n") * rhs("j,k"));

-  // will try to make this work FIRST since this is used by the einsum code
-  // below
-  tot_type out;
-  out("i,l,k,j;n,m") = lhs("i,j;m,n") * rhs("k,l");
-  // will try to make this work NEXT
+  // will try to make this work
   //  tot_type out = einsum(lhs("i,j;m,n"), rhs("j,k"), "i,j,k;m,n");
 }

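For the record, the reference that the two TODOs above would verify follows directly from the Scale semantics: for every outer index combination, the result's inner tile is the lhs inner tile scaled by the matching rhs element, with the inner permutation applied. Sketched as a comment-level contract (the bracketed indexing is schematic, not part of the committed test):

    // For all outer indices (i, j, k, l) and inner indices (m, n):
    //   ref_result[i, l, k, j](n, m) == lhs[i, j](m, n) * rhs[k, l]
    // i.e. the scalar rhs(k,l) scales the inner tile of lhs(i,j), and the
    // ";m,n" -> ";n,m" annotation permutes the inner dimensions.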
