Skip to content

Commit

Permalink
Merge pull request #433 from ValeevGroup/gaudel/feature/t_x_tot_expr_…
Browse files Browse the repository at this point in the history
…support

tiny step towards supporting T*ToT in expr
  • Loading branch information
bimalgaudel authored Nov 17, 2023
2 parents 65f4374 + ab0698d commit a60315d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
7 changes: 4 additions & 3 deletions src/TiledArray/tensor/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ struct is_nested_tensor<T1, T2, Ts...> {
/// @c is_nested_tensor_v<Ts...> is an alias for @c
/// is_nested_tensor<Ts...>::value
template <typename... Ts>
constexpr const bool is_nested_tensor_v = is_nested_tensor<Ts...>::value;
inline constexpr const bool is_nested_tensor_v = is_nested_tensor<Ts...>::value;

////////////////////////////////////////////////////////////////////////////////

Expand Down Expand Up @@ -150,7 +150,7 @@ struct is_tensor<T1, T2, Ts...> {
/// @tparam Ts a parameter pack
/// @c is_tensor_v<Ts...> is an alias for @c is_tensor<Ts...>::value
template <typename... Ts>
constexpr const bool is_tensor_v = is_tensor<Ts...>::value;
inline constexpr const bool is_tensor_v = is_tensor<Ts...>::value;

////////////////////////////////////////////////////////////////////////////////

Expand All @@ -172,7 +172,8 @@ struct is_tensor_of_tensor<T1, T2, Ts...> {
/// @c is_tensor_of_tensor_v<Ts...> is an alias for @c
/// is_tensor_of_tensor<Ts...>::value
template <typename... Ts>
constexpr const bool is_tensor_of_tensor_v = is_tensor_of_tensor<Ts...>::value;
inline constexpr const bool is_tensor_of_tensor_v =
is_tensor_of_tensor<Ts...>::value;

////////////////////////////////////////////////////////////////////////////////

Expand Down
23 changes: 13 additions & 10 deletions src/TiledArray/tile_op/contract_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,20 @@ class ContractReduceBase {
using elem_muladd_op_type = void(result_value_type&, const left_value_type&,
const right_value_type&);

static_assert(
TiledArray::detail::is_tensor_v<left_value_type> ==
TiledArray::detail::is_tensor_v<right_value_type> &&
TiledArray::detail::is_tensor_v<left_value_type> ==
TiledArray::detail::is_tensor_v<result_value_type>,
"ContractReduce can only handle plain tensors or nested tensors "
"(tensors-of-tensors); mixed contractions are not supported");
static constexpr bool plain_tensors =
!(TiledArray::detail::is_tensor_v<left_value_type> &&
TiledArray::detail::is_tensor_v<right_value_type> &&
TiledArray::detail::is_tensor_v<result_value_type>);
!TiledArray::detail::is_nested_tensor_v<left_value_type> &&
!TiledArray::detail::is_nested_tensor_v<right_value_type> &&
!TiledArray::detail::is_nested_tensor_v<result_value_type>;
static constexpr bool nested_tensors =
TiledArray::detail::is_nested_tensor_v<left_value_type, right_value_type,
result_value_type>;
static constexpr bool mixed_tensors = !plain_tensors && !nested_tensors;
static_assert(!mixed_tensors ||
(mixed_tensors &&
TiledArray::detail::is_nested_tensor_v<result_value_type>),
"ContractReduce applied to 1 plain tensor and 1 nested tensor "
"must produce a nested tensor "
"(tensors-of-tensors)");

private:
struct Impl {
Expand Down

0 comments on commit a60315d

Please sign in to comment.