Skip to content

Commit

Permalink
[unit] expression suite: can contract + reduce zero-volume DistArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Feb 16, 2024
1 parent df7e0c8 commit 9eaeafa
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 26 deletions.
8 changes: 5 additions & 3 deletions src/TiledArray/dist_eval/contraction_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -891,9 +891,9 @@ class Summa
ordinal_type initialize(const DenseShape&) {
// Construct static broadcast groups for dense arguments
const madness::DistributedID col_did(DistEvalImpl_::id(), 0ul);
col_group_ = proc_grid_.make_col_group(col_did);
if (k_ > 0) col_group_ = proc_grid_.make_col_group(col_did);
const madness::DistributedID row_did(DistEvalImpl_::id(), k_);
row_group_ = proc_grid_.make_row_group(row_did);
if (k_ > 0) row_group_ = proc_grid_.make_row_group(row_did);

#ifdef TILEDARRAY_ENABLE_SUMMA_TRACE_INITIALIZE
std::stringstream ss;
Expand Down Expand Up @@ -1347,7 +1347,6 @@ class Summa

template <typename Derived>
void make_next_step_tasks(Derived* task, ordinal_type depth) {
TA_ASSERT(depth > 0);
// Set the depth to be no greater than the maximum number steps
if (depth > owner_->k_) depth = owner_->k_;

Expand Down Expand Up @@ -1706,6 +1705,9 @@ class Summa
std::max(ProcGrid::size_type(2),
std::min(proc_grid_.proc_rows(), proc_grid_.proc_cols()));

// corner case: empty result
if (k_ == 0) return 0;

// Construct the first SUMMA iteration task
if (TensorImpl_::shape().is_dense()) {
// We cannot have more iterations than there are blocks in the k
Expand Down
4 changes: 0 additions & 4 deletions src/TiledArray/pmap/cyclic_pmap.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,6 @@ class CyclicPmap : public Pmap {
cols_(cols),
proc_cols_(proc_cols),
proc_rows_(proc_rows) {
// Check that the size is non-zero
TA_ASSERT(rows_ >= 1ul);
TA_ASSERT(cols_ >= 1ul);

// Check limits of process rows and columns
TA_ASSERT(proc_rows_ >= 1ul);
TA_ASSERT(proc_cols_ >= 1ul);
Expand Down
6 changes: 0 additions & 6 deletions src/TiledArray/proc_grid.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,6 @@ class ProcGrid {
local_rows_(0ul),
local_cols_(0ul),
local_size_(0ul) {
// Check for non-zero sizes
TA_ASSERT(rows_ >= 1u);
TA_ASSERT(cols_ >= 1u);
TA_ASSERT(row_size >= 1ul);
TA_ASSERT(col_size >= 1ul);

init(world_->rank(), world_->size(), row_size, col_size);
}

Expand Down
50 changes: 37 additions & 13 deletions tests/expressions_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2947,24 +2947,48 @@ BOOST_FIXTURE_TEST_CASE_TEMPLATE(empty_trange1, F, Fixtures, F) {
auto& c = F::c;
auto& aC = F::aC;

BOOST_CHECK_NO_THROW(c("a,b,c") = aC("a,b,c"));
BOOST_CHECK_NO_THROW(c("a,b,c") += aC("a,b,c"));
BOOST_CHECK_NO_THROW(c("a,b,c") *= aC("a,b,c"));
BOOST_CHECK_NO_THROW(c("a,b,c") *= 2 * aC("a,b,c"));
BOOST_CHECK_NO_THROW(c("a,b,c") += 2 * aC("a,b,c").conj());
BOOST_CHECK_NO_THROW(c("a,b,c") = aC("a,c,b"));
BOOST_CHECK_NO_THROW(c("a,b,c") += 2 * aC("a,c,b").conj());
BOOST_CHECK_NO_THROW(c("a,b,c") *= 2 * aC("a,c,b").conj());
// unary/binary expressions
{
BOOST_CHECK_NO_THROW(c("a,b,c") = aC("a,b,c"));
BOOST_CHECK_NO_THROW(c("a,b,c") += aC("a,b,c"));
BOOST_CHECK_NO_THROW(c("a,b,c") *= aC("a,b,c"));
BOOST_CHECK_NO_THROW(c("a,b,c") *= 2 * aC("a,b,c"));
BOOST_CHECK_NO_THROW(c("a,b,c") += 2 * aC("a,b,c").conj());
BOOST_CHECK_NO_THROW(c("a,b,c") = aC("a,c,b"));
BOOST_CHECK_NO_THROW(c("a,b,c") += 2 * aC("a,c,b").conj());
BOOST_CHECK_NO_THROW(c("a,b,c") *= 2 * aC("a,c,b").conj());
}

using TiledArray::eigen::iv;
const std::array<int, 3> lobound{{0, 0, 1}};
const std::array<int, 3> upbound{{1, 0, 2}};

BOOST_CHECK_NO_THROW(c("a,b,c") = aC("a,b,c").block(lobound, upbound));
BOOST_CHECK_NO_THROW(c("a,b,c") +=
2 * aC("a,b,c").block(lobound, upbound).conj());
BOOST_CHECK_NO_THROW(c("a,b,c") =
2 * conj(aC("a,c,b").block(lobound, upbound)));
// unary/binary block expressions
{
BOOST_CHECK_NO_THROW(c("a,b,c") = aC("a,b,c").block(lobound, upbound));
BOOST_CHECK_NO_THROW(c("a,b,c") +=
2 * aC("a,b,c").block(lobound, upbound).conj());
BOOST_CHECK_NO_THROW(c("a,b,c") =
2 * conj(aC("a,c,b").block(lobound, upbound)));
}

// contraction expressions
{
std::decay_t<decltype(c)> t2, t4;
// contraction over empty dim
BOOST_CHECK_NO_THROW(t4("a,c,e,d") = aC("a,b,c") * aC("d,b,e"));
// contraction over empty and nonempty dims
BOOST_CHECK_NO_THROW(t2("a,d") = aC("a,b,c") * aC("d,b,c"));
// contraction over nonempty dims
BOOST_CHECK_NO_THROW(t4("b,a,e,d") = aC("a,b,c") * aC("d,e,c"));
}

// reduction expressions
{
// contraction over empty dim
BOOST_CHECK_NO_THROW(aC("a,b,c").dot(2 * aC("a,b,c").conj()).get());
BOOST_CHECK_EQUAL(aC("a,b,c").dot(2 * aC("a,b,c").conj()).get(), 0);
}
}

BOOST_AUTO_TEST_SUITE_END()
Expand Down

0 comments on commit 9eaeafa

Please sign in to comment.