Skip to content

Commit

Permalink
add consistency checks for allow halving/doubling flags. closes #266 (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
slayoo authored Jun 18, 2024
1 parent ebaffda commit 186eced
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 6 deletions.
12 changes: 11 additions & 1 deletion src/aero_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ auto pointer_vec_magic(arr_t &data_vec, const arg_t &arg) {
struct AeroState {
PMCResource ptr;
std::shared_ptr<AeroData> aero_data;
int allow_halving = -1, allow_doubling = -1;

AeroState(
std::shared_ptr<AeroData> aero_data,
Expand Down Expand Up @@ -572,7 +573,7 @@ struct AeroState {
}

static int dist_sample(
const AeroState &self,
AeroState &self,
const AeroDist &aero_dist,
const double &sample_prop,
const double &create_time,
Expand All @@ -581,6 +582,15 @@ struct AeroState {
) {
int n_part_add = 0;

if (
(self.allow_doubling != -1 && self.allow_doubling != allow_doubling) ||
(self.allow_halving != -1 && self.allow_halving != allow_halving)
)
throw std::runtime_error("dist_sample() called with different halving/doubling settings then in last call");

self.allow_doubling = allow_doubling;
self.allow_halving = allow_halving;

f_aero_state_add_aero_dist_sample(
self.ptr.f_arg(),
self.aero_data->ptr.f_arg(),
Expand Down
14 changes: 14 additions & 0 deletions src/run_part.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
#include "run_part.hpp"
#include "pybind11/stl.h"

void check_allow_flags(
const AeroState &aero_state,
const RunPartOpt &run_part_opt
) {
if (
(aero_state.allow_halving != -1 && run_part_opt.allow_halving != aero_state.allow_halving) ||
(aero_state.allow_doubling != -1 && run_part_opt.allow_doubling != aero_state.allow_doubling)
)
throw std::runtime_error("allow halving/doubling flags set differently then while sampling");
}

void run_part(
const Scenario &scenario,
EnvState &env_state,
Expand All @@ -18,6 +29,7 @@ void run_part(
const CampCore &camp_core,
const Photolysis &photolysis
) {
check_allow_flags(aero_state, run_part_opt);
f_run_part(
scenario.ptr.f_arg(),
env_state.ptr.f_arg_non_const(),
Expand Down Expand Up @@ -47,6 +59,7 @@ std::tuple<double, double, int> run_part_timestep(
double &last_progress_time,
int &i_output
) {
check_allow_flags(aero_state, run_part_opt);
f_run_part_timestep(
scenario.ptr.f_arg(),
env_state.ptr.f_arg_non_const(),
Expand Down Expand Up @@ -84,6 +97,7 @@ std::tuple<double, double, int> run_part_timeblock(
double &last_progress_time,
int &i_output
) {
check_allow_flags(aero_state, run_part_opt);
f_run_part_timeblock(
scenario.ptr.f_arg(),
env_state.ptr.f_arg_non_const(),
Expand Down
3 changes: 3 additions & 0 deletions src/run_part_opt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ extern "C" void f_run_part_opt_del_t(const void *ptr, double *del_t) noexcept;

struct RunPartOpt {
PMCResource ptr;
bool allow_halving, allow_doubling;

RunPartOpt(const nlohmann::json &json) :
ptr(f_run_part_opt_ctor, f_run_part_opt_dtor)
Expand All @@ -39,6 +40,8 @@ struct RunPartOpt {
}))
if (json_copy.find(key) == json_copy.end())
json_copy[key] = true;
allow_halving = json_copy["allow_halving"];
allow_doubling = json_copy["allow_doubling"];

for (auto key : std::set<std::string>({
"t_output", "t_progress", "rand_init"
Expand Down
34 changes: 34 additions & 0 deletions tests/test_aero_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,37 @@ def test_dist_sample_mono():

# assert
assert np.isclose(np.array(sut.diameters()), diam).all()

@staticmethod
@pytest.mark.parametrize(
"args",
(
((True, True), (True, False)),
((True, True), (False, True)),
((True, True), (False, False)),
((False, False), (True, False)),
((False, False), (False, True)),
((False, False), (True, True)),
((True, False), (False, False)),
((True, False), (False, True)),
((False, True), (False, False)),
((False, True), (True, False)),
),
)
@pytest.mark.skipif(platform.machine() == "arm64", reason="TODO #348")
def test_dist_sample_different_halving(args):
# arrange
aero_data = ppmc.AeroData(AERO_DATA_CTOR_ARG_MINIMAL)
aero_dist = ppmc.AeroDist(aero_data, [AERO_MODE_CTOR_SAMPLED])
sut = ppmc.AeroState(aero_data, *AERO_STATE_CTOR_ARG_MINIMAL)

# act
with pytest.raises(RuntimeError) as excinfo:
_ = sut.dist_sample(aero_dist, 1.0, 0.0, *args[0])
_ = sut.dist_sample(aero_dist, 1.0, 0.0, *args[1])

# assert
assert (
str(excinfo.value)
== "dist_sample() called with different halving/doubling settings then in last call"
)
4 changes: 2 additions & 2 deletions tests/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def test_input_netcdf(tmp_path):
aero_dist,
sample_prop=1.0,
create_time=0.0,
allow_doubling=True,
allow_halving=True,
allow_doubling=False,
allow_halving=False,
)

num_concs = aero_state.num_concs
Expand Down
69 changes: 66 additions & 3 deletions tests/test_run_part.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Authors: https://github.com/open-atmos/PyPartMC/graphs/contributors #
####################################################################################################

import platform

import numpy as np
import pytest

Expand Down Expand Up @@ -52,7 +54,7 @@ def test_run_part(common_args):

@staticmethod
def test_run_part_timestep(common_args):
(last_output_time, last_progress_time, i_output) = ppmc.run_part_timestep(
last_output_time, last_progress_time, i_output = ppmc.run_part_timestep(
*common_args, 1, 0, 0, 0, 1
)

Expand All @@ -63,14 +65,18 @@ def test_run_part_timestep(common_args):

@staticmethod
def test_run_part_timeblock(common_args):
# arrange
num_times = int(
RUN_PART_OPT_CTOR_ARG_SIMULATION["t_output"]
/ RUN_PART_OPT_CTOR_ARG_SIMULATION["del_t"]
)
(last_output_time, last_progress_time, i_output) = ppmc.run_part_timeblock(

# act
last_output_time, last_progress_time, i_output = ppmc.run_part_timeblock(
*common_args, 1, num_times, 0, 0, 0, 1
)

# assert
assert last_output_time == RUN_PART_OPT_CTOR_ARG_SIMULATION["t_output"]
assert last_progress_time == 0.0
assert i_output == 2
Expand All @@ -94,8 +100,65 @@ def test_run_part_do_condensation(common_args, tmp_path):
"do_condensation": True,
}
)
aero_state.dist_sample(aero_dist, 1.0, 0.0, True, True)
aero_state.dist_sample(aero_dist, 1.0, 0.0, False, False)
ppmc.condense_equilib_particles(env_state, aero_data, aero_state)
ppmc.run_part(*args)

assert np.sum(aero_state.masses(include=["H2O"])) > 0.0

@staticmethod
@pytest.mark.parametrize(
"flags",
(
((True, True), (True, False)),
((True, True), (False, True)),
((True, True), (False, False)),
((False, False), (True, False)),
((False, False), (False, True)),
((False, False), (True, True)),
((True, False), (False, False)),
((True, False), (False, True)),
((False, True), (False, False)),
((False, True), (True, False)),
),
)
@pytest.mark.parametrize(
"fun_args",
(
("run_part", []),
("run_part_timestep", [0, 0, 0, 0, 0]),
("run_part_timeblock", [0, 0, 0, 0, 0, 0]),
),
)
@pytest.mark.skipif(platform.machine() == "arm64", reason="TODO #348")
def test_run_part_allow_flag_mismatch(common_args, tmp_path, fun_args, flags):
# arrange
filename = tmp_path / "test"
env_state = ppmc.EnvState(ENV_STATE_CTOR_ARG_HIGH_RH)
aero_data = ppmc.AeroData(AERO_DATA_CTOR_ARG_FULL)
aero_dist = ppmc.AeroDist(aero_data, AERO_DIST_CTOR_ARG_FULL)
aero_state = ppmc.AeroState(aero_data, *AERO_STATE_CTOR_ARG_MINIMAL)
args = list(common_args)
args[0].init_env_state(env_state, 0.0)
args[1] = env_state
args[2] = aero_data
args[3] = aero_state
args[6] = ppmc.RunPartOpt(
{
**RUN_PART_OPT_CTOR_ARG_SIMULATION,
"output_prefix": str(filename),
"allow_doubling": flags[0][0],
"allow_halving": flags[0][1],
}
)
aero_state.dist_sample(aero_dist, 1.0, 0.0, flags[1][0], flags[1][1])

# act
with pytest.raises(RuntimeError) as excinfo:
getattr(ppmc, fun_args[0])(*args, *fun_args[1])

# assert
assert (
str(excinfo.value)
== "allow halving/doubling flags set differently then while sampling"
)

0 comments on commit 186eced

Please sign in to comment.