Skip to content

Commit

Permalink
relax type requirements on tensor_init to support mixed (ToT alongsid…
Browse files Browse the repository at this point in the history
…e T) invocations, this allows T * ToT expr to compile and unit test to succeed
  • Loading branch information
evaleev committed Nov 29, 2023
1 parent 8341bbb commit 56b49a0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
7 changes: 4 additions & 3 deletions src/TiledArray/tensor/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,10 @@ inline void tensor_init(Op&& op, const Permutation& perm, TR& result,
/// \param[out] result The result tensor
/// \param[in] tensor1 The first argument tensor
/// \param[in] tensors The argument tensors
template <typename Op, typename TR, typename T1, typename... Ts,
typename std::enable_if<
is_tensor_of_tensor<TR, T1, Ts...>::value>::type* = nullptr>
template <
typename Op, typename TR, typename T1, typename... Ts,
typename std::enable_if<is_nested_tensor<TR, T1, Ts...>::value &&
!is_tensor<TR, T1, Ts...>::value>::type* = nullptr>
inline void tensor_init(Op&& op, const Permutation& perm, TR& result,
const T1& tensor1, const Ts&... tensors) {
TA_ASSERT(!empty(result, tensor1, tensors...));
Expand Down
12 changes: 6 additions & 6 deletions tests/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,12 +803,12 @@ BOOST_AUTO_TEST_CASE(ilkj_nm_eq_ij_mn_times_kl) {
const bool are_equal = ToTArrayFixture::are_equal(result, ref_result);
BOOST_CHECK(are_equal);

// { // reverse the order
// tot_type result;
// BOOST_REQUIRE_NO_THROW(result("i,l,k,j;n,m") = rhs("k,l") * lhs("i,j;m,n"));
// const bool are_equal = ToTArrayFixture::are_equal(result, ref_result);
// BOOST_CHECK(are_equal);
// }
{ // reverse the order
tot_type result;
BOOST_REQUIRE_NO_THROW(result("i,l,k,j;n,m") = rhs("k,l") * lhs("i,j;m,n"));
const bool are_equal = ToTArrayFixture::are_equal(result, ref_result);
BOOST_CHECK(are_equal);
}
}

BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) {
Expand Down

0 comments on commit 56b49a0

Please sign in to comment.