Skip to content

Commit

Permalink
[ut] einsum_tot/ijk_mn_eq_ij_mn_times_kj_mn : how to compute ref_result
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Dec 23, 2023
1 parent ba0be00 commit 987040b
Showing 1 changed file with 12 additions and 22 deletions.
34 changes: 12 additions & 22 deletions tests/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,10 @@ BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_kj_mn) {
rhs.trange().dim(0)};
tot_type ref_result(world, ref_result_trange);

// to be able to pull remote tiles make them local AND ready
lhs.make_replicated();
rhs.make_replicated();
world.gop.fence();
auto make_tile = [&lhs, &rhs](TA::Range const& rng) {
tot_type::value_type result_tile{rng};
for (auto&& res_ix : result_tile.range()) {
Expand All @@ -630,9 +634,9 @@ BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_kj_mn) {
using Ix3 = std::array<decltype(i), 3>;

auto lhs_tile_ix = lhs.trange().element_to_tile(Ix2{i, j});
auto lhs_tile = lhs.find(lhs_tile_ix).get(/* dowork = */ false);
auto lhs_tile = lhs.find_local(lhs_tile_ix).get(/* dowork = */ false);
auto rhs_tile_ix = rhs.trange().element_to_tile(Ix2{k, j});
auto rhs_tile = rhs.find(rhs_tile_ix).get(/* dowork = */ false);
auto rhs_tile = rhs.find_local(rhs_tile_ix).get(/* dowork = */ false);

auto& res_el =
result_tile.at_ordinal(result_tile.range().ordinal(Ix3{i, j, k}));
Expand All @@ -647,28 +651,14 @@ BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_kj_mn) {
using std::begin;
using std::end;

const auto have_spare_threads = madness::ThreadPool::size() > 0;
if (have_spare_threads) {
for (auto it = begin(ref_result); it != end(ref_result); ++it) {
if (ref_result.is_local(it.index())) {
// using tasks does not work because:
// - make_tile pulls possibly remote data
// - but it also blocks thread on a remote tile futures, whose
// fulfillment requires available threads in the pool
//
// *it = world.taskq.add(make_tile, it.make_range());

// this technically will only work if the number of free threads in the
// pool is > 0 (i.e. main is not part of the pool or pool has 2 threads)
//
// OK, fine, @bosilca, blocking in tasks is BAD
*it = make_tile(it.make_range());
}
for (auto it = begin(ref_result); it != end(ref_result); ++it) {
if (ref_result.is_local(it.index())) {
*it = world.taskq.add(make_tile, it.make_range());
}
bool are_equal =
ToTArrayFixture::are_equal<ShapeComp::False>(result, ref_result);
BOOST_REQUIRE(are_equal);
}
bool are_equal =
ToTArrayFixture::are_equal<ShapeComp::False>(result, ref_result);
BOOST_REQUIRE(are_equal);
}

BOOST_AUTO_TEST_CASE(xxx) {
Expand Down

0 comments on commit 987040b

Please sign in to comment.