Skip to content

Commit

Permalink
stupid linting
Browse files Browse the repository at this point in the history
  • Loading branch information
20DM committed Nov 22, 2024
1 parent d861da5 commit dd9fba6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 16 deletions.
13 changes: 6 additions & 7 deletions cpp/purify/h5reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class H5Handler {
if (!_comm) throw std::runtime_error("No MPI-collective reading enabled!");

_loadDataSet(label);
if (shuffle) _shuffle();
if (shuffle) _shuffle();

std::vector<T> data;
data.reserve(batchsize);
Expand All @@ -103,15 +103,14 @@ class H5Handler {
size_t len = std::min(batchsize, _slicepos + _slicelen - pos);
_ds[label].select({pos}, {len}).read(tmp, _dtp);
data.insert(data.end(), std::make_move_iterator(std::begin(tmp)),
std::make_move_iterator(std::end(tmp)));
std::make_move_iterator(std::end(tmp)));
pos = _slicepos;
batchsize -= len;
}
return data;
}

private:

void _loadDataSet(const std::string& label) {
if (_ds.find(label) != _ds.end()) return;

Expand All @@ -128,14 +127,13 @@ class H5Handler {
if (_comm->rank() == _comm->size() - 1) {
_slicelen += _datalen % _comm->size();
}
}
else if (len != _datalen) {
} else if (len != _datalen) {
throw std::runtime_error("Inconsistent dataset length!");
}
}

void _shuffle() {
std::uniform_int_distribution<size_t> uni(_slicepos,_slicepos+_slicelen-1);
std::uniform_int_distribution<size_t> uni(_slicepos,_slicepos + _slicelen - 1);
_batchpos = uni(_rng);
}

Expand Down Expand Up @@ -200,7 +198,8 @@ utilities::vis_params read_visibility(const std::string& vis_name, const bool w_
utilities::vis_params stochread_visibility(H5Handler& file, const size_t N, const bool w_term) {
utilities::vis_params uv_vis;

std::vector<t_real> utemp = file.stochread<t_real>("u", N, true); //< shuffle batch starting position
std::vector<t_real> utemp =
file.stochread<t_real>("u", N, true); //< shuffle batch starting position
uv_vis.u = Eigen::Map<Vector<t_real>>(utemp.data(), utemp.size(), 1);

// found that a reflection is needed for the orientation
Expand Down
14 changes: 5 additions & 9 deletions cpp/tests/mpi_read_measurements.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
#include "catch2/catch_all.hpp"
#include "purify/logging.h"

#include <sopt/gradient_utils.h>
#include <iostream>
#include "purify/directories.h"
#include "purify/read_measurements.h"
#include <sopt/gradient_utils.h>
#ifdef PURIFY_H5
#include "purify/h5reader.h"
#include "purify/measurement_operator_factory.h"
#endif

using namespace purify;


TEST_CASE("uvfits") {
auto const comm = sopt::mpi::Communicator::World();
const std::string filename = atca_filename("0332-391");
Expand Down Expand Up @@ -103,7 +102,7 @@ TEST_CASE("uvfits") {
// and constructs a uv_params object from it
const size_t N = 10000;
H5::H5Handler f(filename + ".h5", comm);
const auto uvfits = H5::stochread_visibility(f, N, true); //< true = include w-term
const auto uvfits = H5::stochread_visibility(f, N, true); //< true = include w-term
CAPTURE(uvfits.size());
CHECK(comm.all_sum_all(uvfits.size()) == N * comm.size());
}
Expand All @@ -119,21 +118,18 @@ TEST_CASE("uvfits") {
auto functor = [&f = h5file, &N]() {
utilities::vis_params uv_data = H5::stochread_visibility(f, N, true);
auto phi = factory::measurement_operator_factory<t_complexVec>(
factory::distributed_measurement_operator::mpi_distribute_image,
uv_data, 128, 128, 1, 1, 2,
kernels::kernel_from_string.at("kb"), 4, 4);
factory::distributed_measurement_operator::mpi_distribute_image, uv_data, 128, 128, 1,
1, 2, kernels::kernel_from_string.at("kb"), 4, 4);

return sopt::IterationState<t_complexVec>(uv_data.vis, phi);

};

// And it would be called in Sopt like this
sopt::IterationState<t_complexVec> item = functor();

// Make sure the return values are sensible
const bool pass = comm.all_sum_all(item.target().size()) == N * comm.size() &&
item.phi().sizes()[0] == 0 &&
item.phi().sizes()[1] == 1 &&
item.phi().sizes()[0] == 0 && item.phi().sizes()[1] == 1 &&
item.phi().sizes()[2] == N;
CHECK(pass);
}
Expand Down

0 comments on commit dd9fba6

Please sign in to comment.