diff --git a/cpp/examples/abm_history_object.cpp b/cpp/examples/abm_history_object.cpp index 91a05759f3..e98528a53d 100644 --- a/cpp/examples/abm_history_object.cpp +++ b/cpp/examples/abm_history_object.cpp @@ -22,7 +22,9 @@ #include "abm/simulation.h" #include "abm/model.h" #include "abm/location_type.h" +#include "memilio/utils/parameter_distribution_wrapper.h" #include "memilio/io/history.h" +#include "memilio/utils/parameter_distributions.h" #include #include @@ -68,9 +70,9 @@ int main() // Create the model with 4 age groups. auto model = mio::abm::Model(num_age_groups); - - // Set same infection parameter for all age groups. For example, the incubation period is 4 days. - model.parameters.get() = 4.; + mio::ParameterDistributionLogNormal log_norm(4., 1.); + // Set same infection parameter for all age groups. For example, the incubation period is log normally distributed with parameters 4 and 1. + model.parameters.get() = mio::ParameterDistributionWrapper(log_norm); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) model.parameters.get()[age_group_5_to_14] = true; diff --git a/cpp/examples/abm_minimal.cpp b/cpp/examples/abm_minimal.cpp index 9abb173dcf..bbc5a6d501 100644 --- a/cpp/examples/abm_minimal.cpp +++ b/cpp/examples/abm_minimal.cpp @@ -21,6 +21,7 @@ #include "abm/lockdown_rules.h" #include "abm/model.h" #include "abm/common_abm_loggers.h" +#include "memilio/utils/parameter_distribution_wrapper.h" #include @@ -36,9 +37,9 @@ int main() // Create the model with 4 age groups. auto model = mio::abm::Model(num_age_groups); - - // Set same infection parameter for all age groups. For example, the incubation period is 4 days. - model.parameters.get() = 4.; + mio::ParameterDistributionLogNormal log_norm(4., 1.); + // Set same infection parameter for all age groups. For example, the incubation period is log normally distributed with parameters 4 and 1. + model.parameters.get() = mio::ParameterDistributionWrapper(log_norm); // Set the age group the can go to school is AgeGroup(1) (i.e. 5-14) model.parameters.get() = false; diff --git a/cpp/examples/ode_secir_parameter_sampling.cpp b/cpp/examples/ode_secir_parameter_sampling.cpp index 2ee75b74bd..079b59ea21 100644 --- a/cpp/examples/ode_secir_parameter_sampling.cpp +++ b/cpp/examples/ode_secir_parameter_sampling.cpp @@ -18,6 +18,7 @@ * limitations under the License. */ #include "memilio/utils/parameter_distributions.h" +#include "memilio/utils/random_number_generator.h" #include "ode_secir/parameter_space.h" #include "ode_secir/model.h" @@ -41,7 +42,7 @@ int main() printf("\n N(%.0f,%.0f)-distribution with sampling only in [%.0f,%.0f]", mean, stddev, min, max); int counter[10] = {0}; for (int i = 0; i < 1000; i++) { - int rounded = (int)(some_parameter.get_sample() - 1); + int rounded = (int)(some_parameter.get_sample(mio::thread_local_rng()) - 1); if (rounded >= 0 && rounded < 10) { counter[rounded]++; } @@ -59,7 +60,7 @@ int main() double counter_unif[10] = {0}; for (int i = 0; i < 1000; i++) { - int rounded = (int)(some_other_parameter.get_sample() - 1); + int rounded = (int)(some_other_parameter.get_sample(mio::thread_local_rng()) - 1); if (rounded >= 0 && rounded < 10) { counter_unif[rounded]++; } diff --git a/cpp/memilio/CMakeLists.txt b/cpp/memilio/CMakeLists.txt index e8b58aded5..5af06c920e 100644 --- a/cpp/memilio/CMakeLists.txt +++ b/cpp/memilio/CMakeLists.txt @@ -78,6 +78,7 @@ add_library(memilio utils/custom_index_array.h utils/memory.h utils/parameter_distributions.h + utils/parameter_distribution_wrapper.h utils/time_series.h utils/time_series.cpp utils/span.h diff --git a/cpp/memilio/epidemiology/damping_sampling.h b/cpp/memilio/epidemiology/damping_sampling.h index c17826883d..5b5a8103c1 100644 --- a/cpp/memilio/epidemiology/damping_sampling.h +++ b/cpp/memilio/epidemiology/damping_sampling.h @@ -21,6 +21,7 @@ #define EPI_SECIR_DAMPING_SAMPLING_H #include "memilio/epidemiology/damping.h" +#include "memilio/utils/random_number_generator.h" #include "memilio/utils/uncertain_value.h" #include diff --git a/cpp/memilio/utils/parameter_distribution_wrapper.h b/cpp/memilio/utils/parameter_distribution_wrapper.h new file mode 100644 index 0000000000..cbc6906ce3 --- /dev/null +++ b/cpp/memilio/utils/parameter_distribution_wrapper.h @@ -0,0 +1,139 @@ +/* +* Copyright (C) 2020-2024 MEmilio +* +* Authors: Julia Bicker +* +* Contact: Martin J. Kuehn +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +#ifndef PARAMETER_DISTRIBUTION_WRAPPER_H +#define PARAMETER_DISTRIBUTION_WRAPPER_H + +#include "memilio/io/io.h" +#include "memilio/utils/compiler_diagnostics.h" +#include "parameter_distributions.h" + +namespace mio +{ + +class ParameterDistributionWrapper +{ +public: + ParameterDistributionWrapper() + : m_dist(nullptr) + { + } + + ParameterDistributionWrapper(ParameterDistribution& dist) + : m_dist(std::unique_ptr(dist.clone())) + { + } + + ParameterDistributionWrapper(const ParameterDistributionWrapper& other) + { + m_dist = (other.m_dist == nullptr) ? nullptr : std::unique_ptr(other.m_dist->clone()); + } + + ParameterDistributionWrapper(ParameterDistributionWrapper&& other) + { + m_dist = (other.m_dist == nullptr) ? nullptr : std::unique_ptr(other.m_dist->clone()); + } + + ParameterDistributionWrapper& operator=(ParameterDistributionWrapper const& other) + { + m_dist = (other.m_dist == nullptr) ? nullptr : std::unique_ptr(other.m_dist->clone()); + return *this; + } + + ParameterDistributionWrapper& operator=(ParameterDistributionWrapper&& other) + { + m_dist = (other.m_dist == nullptr) ? nullptr : std::unique_ptr(other.m_dist->clone()); + return *this; + }; + + ~ParameterDistributionWrapper() = default; + + std::vector params() const + { + if (m_dist == nullptr) { + log_error("Distribution is not defined. Parameters cannot be deduced."); + } + return m_dist->params(); + } + + template + double get(RNG& rng) + { + if (m_dist == nullptr) { + log_error("Distribution is not defined. Value cannot be sampled."); + } + return m_dist->get_sample(rng); + } + + /** + * serialize this. + * @see mio::serialize + */ + template + void serialize(IOContext& io) const + { + m_dist->serialize(io); + } + +private: + std::unique_ptr m_dist; +}; + +/** + * deserialize a ParameterDistributionWrapper. + * @see mio::deserialize + */ +template +IOResult deserialize_internal(IOContext& io, Tag) +{ + + auto obj = io.expect_object("ParameterDistribution"); + auto type = obj.expect_element("Type", Tag{}); + if (type) { + if (type.value() == "Uniform") { + BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionUniform::deserialize_elements(io, obj)); + return ParameterDistributionWrapper(r); + } + else if (type.value() == "Normal") { + BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionNormal::deserialize_elements(io, obj)); + return ParameterDistributionWrapper(r); + } + else if (type.value() == "LogNormal") { + BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionLogNormal::deserialize_elements(io, obj)); + return ParameterDistributionWrapper(r); + } + else if (type.value() == "Exponential") { + BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionExponential::deserialize_elements(io, obj)); + return ParameterDistributionWrapper(r); + } + else if (type.value() == "Constant") { + BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionConstant::deserialize_elements(io, obj)); + return ParameterDistributionWrapper(r); + } + else { + return failure(StatusCode::InvalidValue, "Type of ParameterDistribution in ParameterDistributionWrapper" + + type.value() + " not valid."); + } + } + return failure(type.error()); +} + +} // namespace mio + +#endif //PARAMETER_DISTRIBUTION_WRAPPER_H diff --git a/cpp/memilio/utils/parameter_distributions.h b/cpp/memilio/utils/parameter_distributions.h index 36d5cbea1f..05a3eed499 100644 --- a/cpp/memilio/utils/parameter_distributions.h +++ b/cpp/memilio/utils/parameter_distributions.h @@ -1,7 +1,7 @@ /* * Copyright (C) 2020-2024 MEmilio * -* Authors: Martin J. Kuehn, Daniel Abele +* Authors: Martin J. Kuehn, Daniel Abele, Julia Bicker * * Contact: Martin J. Kuehn * @@ -20,11 +20,14 @@ #ifndef PARAMETER_DISTRIBUTIONS_H #define PARAMETER_DISTRIBUTIONS_H +#include "memilio/utils/compiler_diagnostics.h" #include "memilio/utils/logging.h" #include "memilio/utils/visitor.h" #include "memilio/utils/random_number_generator.h" +#include "models/abm/personal_rng.h" #include "memilio/io/io.h" +#include #include #include #include @@ -37,9 +40,13 @@ namespace mio * * More information to the visitor pattern is here: https://en.wikipedia.org/wiki/Visitor_pattern */ -using ParameterDistributionVisitor = Visitor; +using ParameterDistributionVisitor = + Visitor; using ConstParameterDistributionVisitor = - ConstVisitor; + ConstVisitor; template using VisitableParameterDistribution = @@ -55,41 +62,25 @@ struct SerializationVisitor : ConstParameterDistributionVisitor { } virtual void visit(const ParameterDistributionNormal& normal_dist) final; virtual void visit(const ParameterDistributionUniform& uniform_dist) final; + virtual void visit(const ParameterDistributionLogNormal& lognormal_dist) final; + virtual void visit(const ParameterDistributionExponential& lognormal_dist) final; + virtual void visit(const ParameterDistributionConstant& lognormal_dist) final; IOObj& obj; }; } // namespace details /* - * Parameter Distribution class which contains the name of a variable as string - * the lower bound and the upper bound as maximum admissible values and an enum - * item with the name of the distribution + * Parameter Distribution class representing a generic distribution and contains predefined samples */ class ParameterDistribution { public: - ParameterDistribution(double lower_bound, double upper_bound) - : m_lower_bound(lower_bound) - , m_upper_bound(upper_bound) - { - } - ParameterDistribution() - : ParameterDistribution(0, 0) { } virtual ~ParameterDistribution() = default; - void set_lower_bound(double lower_bound) - { - m_lower_bound = lower_bound; - } - - void set_upper_bound(double upper_bound) - { - m_upper_bound = upper_bound; - } - void add_predefined_sample(double sample) { m_predefined_samples.push_back(sample); @@ -105,23 +96,14 @@ class ParameterDistribution return m_predefined_samples; } - double get_lower_bound() const - { - return m_lower_bound; - } - - double get_upper_bound() const - { - return m_upper_bound; - } - /* * @brief returns a value for the given parameter distribution * in case some predefined samples are set, these values are taken * first, in case the vector of predefined values is empty, a 'real' * random sample is taken */ - double get_sample() + template + double get_sample(RNG& rng) { if (m_predefined_samples.size() > 0) { double rnumb = m_predefined_samples[0]; @@ -129,7 +111,7 @@ class ParameterDistribution return rnumb; } else { - return get_rand_sample(); + return get_rand_sample(rng); } } @@ -145,7 +127,13 @@ class ParameterDistribution this->accept(visitor); } - virtual double get_rand_sample() = 0; + /** + * @brief Returns the distribution parameters as vector. + */ + virtual std::vector params() const = 0; + + virtual double get_rand_sample(RandomNumberGenerator& rng) = 0; + virtual double get_rand_sample(abm::PersonalRandomNumberGenerator&) = 0; virtual ParameterDistribution* clone() const = 0; @@ -159,8 +147,6 @@ class ParameterDistribution virtual void accept(ConstParameterDistributionVisitor& visitor) const = 0; protected: - double m_lower_bound; /*< A realistic lower bound on the given parameter */ - double m_upper_bound; /*< A realistic upper bound on the given parameter */ std::vector m_predefined_samples; // if these values are set; no real sample will occur but these values will be taken }; @@ -174,33 +160,43 @@ class ParameterDistributionNormal : public VisitableParameterDistribution() + , m_mean(0) + , m_standard_dev(1) + , m_distribution(0, 1) { - m_mean = 0; - m_standard_dev = 1; } ParameterDistributionNormal(double mean, double standard_dev) : VisitableParameterDistribution() + , m_mean(mean) + , m_standard_dev(standard_dev) + , m_distribution(mean, standard_dev) { m_mean = mean; m_standard_dev = standard_dev; - check_quantiles(m_mean, m_standard_dev); } ParameterDistributionNormal(double lower_bound, double upper_bound, double mean) - : VisitableParameterDistribution(lower_bound, upper_bound) + : VisitableParameterDistribution() + , m_mean(mean) + , m_upper_bound(upper_bound) + , m_lower_bound(lower_bound) { - m_mean = mean; + // if upper and lower bound are given, the standard deviation is calculated such that [lower_bound, upper_bound] represent the 0.995 quartile] m_standard_dev = upper_bound; // set as to high and adapt then - adapt_standard_dev(m_standard_dev); + adapt_standard_dev(m_standard_dev, upper_bound, lower_bound); + m_distribution = mio::NormalDistribution::ParamType(m_mean, m_standard_dev); } ParameterDistributionNormal(double lower_bound, double upper_bound, double mean, double standard_dev) - : VisitableParameterDistribution(lower_bound, upper_bound) + : VisitableParameterDistribution() + , m_mean(mean) + , m_standard_dev(standard_dev) + , m_upper_bound(upper_bound) + , m_lower_bound(lower_bound) { - m_mean = mean; - m_standard_dev = standard_dev; check_quantiles(m_mean, m_standard_dev); + m_distribution = mio::NormalDistribution::ParamType(m_mean, m_standard_dev); } void set_mean(double mean) @@ -220,12 +216,13 @@ class ParameterDistributionNormal : public VisitableParameterDistribution m_upper_bound) { - standard_dev = (m_upper_bound - m_mean) / m_quantile; + if (m_mean + standard_dev * m_quantile > upper_bound) { + standard_dev = (upper_bound - m_mean) / m_quantile; changed = true; } - if (m_mean - standard_dev * m_quantile < m_lower_bound) { - standard_dev = (m_mean - m_lower_bound) / m_quantile; + if (m_mean - standard_dev * m_quantile < lower_bound) { + standard_dev = (m_mean - lower_bound) / m_quantile; changed = true; } @@ -278,13 +275,39 @@ class ParameterDistributionNormal : public VisitableParameterDistribution params() const override + { + return {m_mean, m_standard_dev}; + } + /* * @brief gets a sample of a normally distributed variable * before sampling, it is verified that at least 99% of the * density function lie in the interval defined by the boundaries * otherwise the normal distribution is adapted */ - double get_rand_sample() override + template + double sample(RNG& rng) { //If ub = lb, sampling can only be succesful if mean = lb and dev = 0. //But this degenerate normal distribution is not allowed by the c++ standard. @@ -292,16 +315,16 @@ class ParameterDistributionNormal : public VisitableParameterDistribution{m_mean, m_standard_dev}; + if (check_quantiles(m_mean, m_standard_dev) || m_distribution.params.mean() != m_mean || + m_distribution.params.stddev() != m_standard_dev) { + m_distribution = NormalDistribution::ParamType{m_mean, m_standard_dev}; } int i = 0; int retries = 10; - double rnumb = m_distribution(thread_local_rng()); + double rnumb = m_distribution.get_distribution_instance()(thread_local_rng(), m_distribution.params); while ((rnumb > m_upper_bound || rnumb < m_lower_bound) && i < retries) { - rnumb = m_distribution(thread_local_rng()); + rnumb = m_distribution.get_distribution_instance()(rng, m_distribution.params); i++; if (i == retries) { log_warning("Not successfully sampled within [min,max]."); @@ -316,6 +339,16 @@ class ParameterDistributionNormal : public VisitableParameterDistribution void serialize_elements(IOObject& obj) const { @@ -342,15 +375,15 @@ class ParameterDistributionNormal : public VisitableParameterDistribution{}); auto predef = obj.expect_list("PredefinedSamples", Tag{}); auto p = apply( - io, - [](auto&& lb_, auto&& ub_, auto&& m_, auto&& s_, auto&& predef_) { + io, + [](auto&& lb_, auto&& ub_, auto&& m_, auto&& s_, auto&& predef_) { auto distr = ParameterDistributionNormal(lb_, ub_, m_, s_); for (auto&& e : predef_) { distr.add_predefined_sample(e); } return distr; - }, - lb, ub, m, s, predef); + }, + lb, ub, m, s, predef); if (p) { return success(p.value()); } @@ -374,8 +407,11 @@ class ParameterDistributionNormal : public VisitableParameterDistribution::max(); // upper bound and lower bound can be given to the constructor instead of stddev + double m_lower_bound = std::numeric_limits::min(); constexpr static double m_quantile = 2.5758; // 0.995 quartile - std::normal_distribution m_distribution; + NormalDistribution::ParamType m_distribution; bool m_log_stddev_change = true; }; @@ -392,26 +428,60 @@ void details::SerializationVisitor::visit(const ParameterDistributionNorm class ParameterDistributionUniform : public VisitableParameterDistribution { public: - ParameterDistributionUniform() + ParameterDistributionUniform(double lower_bound, double upper_bound) : VisitableParameterDistribution() + , m_upper_bound(upper_bound) + , m_lower_bound(lower_bound) + , m_distribution(lower_bound, upper_bound) { } - ParameterDistributionUniform(double lower_bound, double upper_bound) - : VisitableParameterDistribution(lower_bound, upper_bound) + std::vector params() const override + { + return {m_lower_bound, m_upper_bound}; + } + + void set_lower_bound(double lower_bound) + { + m_lower_bound = lower_bound; + } + + void set_upper_bound(double upper_bound) + { + m_upper_bound = upper_bound; + } + + double get_lower_bound() const { + return m_lower_bound; + } + + double get_upper_bound() const + { + return m_upper_bound; } /* * @brief gets a sample of a uniformly distributed variable */ - double get_rand_sample() override + template + double sample(RNG& rng) { - if (m_distribution.max() != m_upper_bound || m_distribution.min() != m_lower_bound) { - m_distribution = std::uniform_real_distribution{m_lower_bound, m_upper_bound}; + if (m_distribution.params.b() != m_upper_bound || m_distribution.params.a() != m_lower_bound) { + m_distribution = UniformDistribution::ParamType{m_lower_bound, m_upper_bound}; } - return m_distribution(thread_local_rng()); + return m_distribution.get_distribution_instance()(rng, m_distribution.params); + } + + double get_rand_sample(RandomNumberGenerator& rng) override + { + return sample(rng); + } + + double get_rand_sample(abm::PersonalRandomNumberGenerator& rng) override + { + return sample(rng); } ParameterDistribution* clone() const override @@ -441,15 +511,15 @@ class ParameterDistributionUniform : public VisitableParameterDistribution{}); auto predef = obj.expect_list("PredefinedSamples", Tag{}); auto p = apply( - io, - [](auto&& lb_, auto&& ub_, auto&& predef_) { + io, + [](auto&& lb_, auto&& ub_, auto&& predef_) { auto distr = ParameterDistributionUniform(lb_, ub_); for (auto&& e : predef_) { distr.add_predefined_sample(e); } return distr; - }, - lb, ub, predef); + }, + lb, ub, predef); if (p) { return success(p.value()); } @@ -466,7 +536,9 @@ class ParameterDistributionUniform : public VisitableParameterDistribution m_distribution; + double m_upper_bound; + double m_lower_bound; + UniformDistribution::ParamType m_distribution; }; template @@ -476,6 +548,350 @@ void details::SerializationVisitor::visit(const ParameterDistributionUnif uniform_dist.serialize_elements(obj); } +/* + * Child class of Parameter Distribution class which represents an lognormal distribution + */ +class ParameterDistributionLogNormal : public VisitableParameterDistribution +{ +public: + ParameterDistributionLogNormal(double log_mean, double log_stddev) + : VisitableParameterDistribution() + , m_log_mean(log_mean) + , m_log_stddev(log_stddev) + , m_distribution(log_mean, log_stddev) + { + } + + std::vector params() const override + { + return {m_log_mean, m_log_stddev}; + } + + void set_log_mean(double log_mean) + { + m_log_mean = log_mean; + } + + void set_log_stddev(double log_stddev) + { + m_log_stddev = log_stddev; + } + + double get_log_mean() const + { + return m_log_mean; + } + + double get_log_stddev() const + { + return m_log_stddev; + } + + /* + * @brief gets a sample of a lognormally distributed variable + */ + template + double sample(RNG& rng) + { + if (m_distribution.params.m() != m_log_mean || m_distribution.params.s() != m_log_stddev) { + m_distribution = LogNormalDistribution::ParamType{m_log_mean, m_log_stddev}; + } + + return m_distribution.get_distribution_instance()(rng, m_distribution.params); + } + + double get_rand_sample(RandomNumberGenerator& rng) override + { + return sample(rng); + } + + double get_rand_sample(abm::PersonalRandomNumberGenerator& rng) override + { + return sample(rng); + } + + ParameterDistribution* clone() const override + { + return new ParameterDistributionLogNormal(*this); + } + + template + void serialize_elements(IOObject& obj) const + { + obj.add_element("LogMean", m_log_mean); + obj.add_element("LogStddev", m_log_stddev); + obj.add_list("PredefinedSamples", m_predefined_samples.begin(), m_predefined_samples.end()); + } + + template + void serialize(IOContext& io) const + { + auto obj = io.create_object("ParameterDistributionLogNormal"); + serialize_elements(obj); + } + + template + static IOResult deserialize_elements(IOContext& io, IOObject& obj) + { + auto lm = obj.expect_element("LogMean", Tag{}); + auto ls = obj.expect_element("LogStddev", Tag{}); + auto predef = obj.expect_list("PredefinedSamples", Tag{}); + auto p = apply( + io, + [](auto&& lm_, auto&& ls_, auto&& predef_) { + auto distr = ParameterDistributionLogNormal(lm_, ls_); + for (auto&& e : predef_) { + distr.add_predefined_sample(e); + } + return distr; + }, + lm, ls, predef); + if (p) { + return success(p.value()); + } + else { + return p.as_failure(); + } + } + + template + static IOResult deserialize(IOContext& io) + { + auto obj = io.expect_object("ParameterDistributionLogNormal"); + return deserialize_elements(io, obj); + } + +private: + double m_log_mean; + double m_log_stddev; + LogNormalDistribution::ParamType m_distribution; +}; + +template +void details::SerializationVisitor::visit(const ParameterDistributionLogNormal& uniform_dist) +{ + obj.add_element("Type", std::string("LogNormal")); + uniform_dist.serialize_elements(obj); +} + +/* + * Child class of Parameter Distribution class which represents an exponential distribution + */ +class ParameterDistributionExponential : public VisitableParameterDistribution +{ +public: + ParameterDistributionExponential(double rate) + : VisitableParameterDistribution() + , m_rate(rate) + , m_distribution(rate) + { + } + + std::vector params() const override + { + return {m_rate}; + } + + void set_rate(double rate) + { + m_rate = rate; + } + + double get_rate() const + { + return m_rate; + } + + /* + * @brief gets a sample of a exponentially distributed variable + */ + template + double sample(RNG& rng) + { + if (m_distribution.params.lambda() != m_rate) { + m_distribution = ExponentialDistribution::ParamType{m_rate}; + } + + return m_distribution.get_distribution_instance()(rng, m_distribution.params); + } + + double get_rand_sample(RandomNumberGenerator& rng) override + { + return sample(rng); + } + + double get_rand_sample(abm::PersonalRandomNumberGenerator& rng) override + { + return sample(rng); + } + + ParameterDistribution* clone() const override + { + return new ParameterDistributionExponential(*this); + } + + template + void serialize_elements(IOObject& obj) const + { + obj.add_element("Rate", m_rate); + obj.add_list("PredefinedSamples", m_predefined_samples.begin(), m_predefined_samples.end()); + } + + template + void serialize(IOContext& io) const + { + auto obj = io.create_object("ParameterDistributionExponential"); + serialize_elements(obj); + } + + template + static IOResult deserialize_elements(IOContext& io, IOObject& obj) + { + auto r = obj.expect_element("Rate", Tag{}); + auto predef = obj.expect_list("PredefinedSamples", Tag{}); + auto p = apply( + io, + [](auto&& r_, auto&& predef_) { + auto distr = ParameterDistributionExponential(r_); + for (auto&& e : predef_) { + distr.add_predefined_sample(e); + } + return distr; + }, + r, predef); + if (p) { + return success(p.value()); + } + else { + return p.as_failure(); + } + } + + template + static IOResult deserialize(IOContext& io) + { + auto obj = io.expect_object("ParameterDistributionExponential"); + return deserialize_elements(io, obj); + } + +private: + double m_rate; + ExponentialDistribution::ParamType m_distribution; +}; + +template +void details::SerializationVisitor::visit(const ParameterDistributionExponential& uniform_dist) +{ + obj.add_element("Type", std::string("Exponential")); + uniform_dist.serialize_elements(obj); +} + +/* + * Child class of Parameter Distribution class which represents a constant distribution/value + */ +class ParameterDistributionConstant : public VisitableParameterDistribution +{ +public: + ParameterDistributionConstant(double constant) + : VisitableParameterDistribution() + , m_constant(constant) + { + } + + std::vector params() const override + { + return {m_constant}; + } + + void set_constant(double constant) + { + m_constant = constant; + } + + double get_constant() const + { + return m_constant; + } + + /* + * @brief gets a constant + */ + template + double sample(RNG& /*rng*/) + { + return m_constant; + } + + double get_rand_sample(RandomNumberGenerator& rng) override + { + return sample(rng); + } + + double get_rand_sample(abm::PersonalRandomNumberGenerator& rng) override + { + return sample(rng); + } + + ParameterDistribution* clone() const override + { + return new ParameterDistributionConstant(*this); + } + + template + void serialize_elements(IOObject& obj) const + { + obj.add_element("Constant", m_constant); + obj.add_list("PredefinedSamples", m_predefined_samples.begin(), m_predefined_samples.end()); + } + + template + void serialize(IOContext& io) const + { + auto obj = io.create_object("ParameterDistributionConstant"); + serialize_elements(obj); + } + + template + static IOResult deserialize_elements(IOContext& io, IOObject& obj) + { + auto c = obj.expect_element("Constant", Tag{}); + auto predef = obj.expect_list("PredefinedSamples", Tag{}); + auto p = apply( + io, + [](auto&& c_, auto&& predef_) { + auto distr = ParameterDistributionConstant(c_); + for (auto&& e : predef_) { + distr.add_predefined_sample(e); + } + return distr; + }, + c, predef); + if (p) { + return success(p.value()); + } + else { + return p.as_failure(); + } + } + + template + static IOResult deserialize(IOContext& io) + { + auto obj = io.expect_object("ParameterDistributionConstant"); + return deserialize_elements(io, obj); + } + +private: + double m_constant; +}; + +template +void details::SerializationVisitor::visit(const ParameterDistributionConstant& uniform_dist) +{ + obj.add_element("Type", std::string("Constant")); + uniform_dist.serialize_elements(obj); +} + /** * deserialize a parameter distribution as a shared_ptr. * @see mio::deserialize @@ -495,6 +911,18 @@ IOResult> deserialize_internal(IOContext& BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionNormal::deserialize_elements(io, obj)); return std::make_shared(r); } + else if (type.value() == "LogNormal") { + BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionLogNormal::deserialize_elements(io, obj)); + return std::make_shared(r); + } + else if (type.value() == "Exponential") { + BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionExponential::deserialize_elements(io, obj)); + return std::make_shared(r); + } + else if (type.value() == "Constant") { + BOOST_OUTCOME_TRY(auto&& r, ParameterDistributionConstant::deserialize_elements(io, obj)); + return std::make_shared(r); + } else { return failure(StatusCode::InvalidValue, "Type of ParameterDistribution " + type.value() + " not valid."); } diff --git a/cpp/memilio/utils/random_number_generator.h b/cpp/memilio/utils/random_number_generator.h index 96456dade4..c0e8fa968b 100644 --- a/cpp/memilio/utils/random_number_generator.h +++ b/cpp/memilio/utils/random_number_generator.h @@ -711,6 +711,27 @@ IOResult deserialize_internal(IOContext& io, Tag using PoissonDistribution = DistributionAdapter>; +/** + * adapted lognormal_distribution. + * @see DistributionAdapter + */ +template +using LogNormalDistribution = DistributionAdapter>; + +/** + * adapted gamma_distribution. + * @see DistributionAdapter + */ +template +using GammaDistribution = DistributionAdapter>; + +/** + * adapted normal_distribution. + * @see DistributionAdapter + */ +template +using NormalDistribution = DistributionAdapter>; + } // namespace mio #endif diff --git a/cpp/memilio/utils/uncertain_value.h b/cpp/memilio/utils/uncertain_value.h index 90336d29ab..14f52cbd46 100644 --- a/cpp/memilio/utils/uncertain_value.h +++ b/cpp/memilio/utils/uncertain_value.h @@ -23,6 +23,7 @@ #include "memilio/config.h" #include "memilio/utils/memory.h" #include "memilio/utils/parameter_distributions.h" +#include "memilio/utils/random_number_generator.h" #include #include @@ -150,7 +151,7 @@ class UncertainValue FP draw_sample() { if (m_dist) { - m_value = m_dist->get_sample(); + m_value = m_dist->get_sample(mio::thread_local_rng()); } return m_value; diff --git a/cpp/models/abm/analyze_result.h b/cpp/models/abm/analyze_result.h index 05107af313..acfd8fde53 100644 --- a/cpp/models/abm/analyze_result.h +++ b/cpp/models/abm/analyze_result.h @@ -21,6 +21,7 @@ #define ABM_ANALYZE_RESULT_H #include "abm/parameters.h" +#include "memilio/utils/compiler_diagnostics.h" #include @@ -46,7 +47,7 @@ std::vector ensemble_params_percentile(const std::vector percentile(num_nodes, Model((int)num_groups)); - auto param_percentil = [&ensemble_params, p, num_runs, &percentile](auto n, auto get_param) mutable { + auto param_percentile = [&ensemble_params, p, num_runs, &percentile](auto n, auto get_param) mutable { std::vector single_element(num_runs); for (size_t run = 0; run < num_runs; run++) { auto const& params = ensemble_params[run][n]; @@ -56,139 +57,184 @@ std::vector ensemble_params_percentile(const std::vector(num_runs * p)]; }; + mio::unused(param_percentile); + + auto param_percentile_dist = [&ensemble_params, p, num_runs, &percentile](auto n, auto single_element, + auto get_param, auto sort_fct) mutable { + for (size_t run = 0; run < num_runs; run++) { + auto const& params = ensemble_params[run][n]; + single_element[run] = get_param(params); + } + std::sort(single_element.begin(), single_element.end(), sort_fct); + auto& new_params = get_param(percentile[n]); + new_params = single_element[static_cast(num_runs * p)]; + }; for (size_t node = 0; node < num_nodes; node++) { for (auto age_group = AgeGroup(0); age_group < AgeGroup(num_groups); age_group++) { - for (auto virus_variant = VirusVariant(0); virus_variant < VirusVariant::Count; - virus_variant = static_cast((uint32_t)virus_variant + 1)) { + for (auto virus_variant : enum_members()) { // Global infection parameters - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - return model.parameters.template get()[{virus_variant, age_group}]; + + //Stay time distributions + param_percentile_dist( + node, std::vector(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return dist1.params()[0] < dist2.params()[0]; + }); + param_percentile_dist( + node, std::vector(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters + .template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return dist1.params()[0] < dist2.params()[0]; + }); + param_percentile_dist( + node, std::vector(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters + .template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return dist1.params()[0] < dist2.params()[0]; + }); + param_percentile_dist( + node, std::vector(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters + .template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return dist1.params()[0] < dist2.params()[0]; + }); + param_percentile_dist( + node, std::vector(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters + .template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return dist1.params()[0] < dist2.params()[0]; + }); + param_percentile_dist( + node, std::vector(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters + .template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return dist1.params()[0] < dist2.params()[0]; + }); + param_percentile_dist( + node, std::vector(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters + .template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return dist1.params()[0] < dist2.params()[0]; + }); + param_percentile_dist( + node, std::vector(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return dist1.params()[0] < dist2.params()[0]; + }); + param_percentile_dist( + node, std::vector(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters + .template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return dist1.params()[0] < dist2.params()[0]; + }); + // Uncertain values + param_percentile(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentile(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentile(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }); + param_percentile(node, [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { + + param_percentile(node, [age_group, virus_variant](auto&& model) -> auto& { return model.parameters.template get()[{virus_variant, age_group}]; }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_incline.params.a(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_incline.params.b(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_decline.params.a(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_decline.params.b(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_peak.params.a(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .viral_load_peak.params.b(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .infectivity_alpha.params.a(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .infectivity_alpha.params.b(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .infectivity_beta.params.a(); - return result; - }); - param_percentil(node, [age_group, virus_variant](auto&& model) -> auto& { - static auto result = - model.parameters.template get()[{virus_variant, age_group}] - .infectivity_beta.params.b(); - return result; - }); - param_percentil(node, [virus_variant](auto&& model) -> auto& { + + param_percentile(node, [virus_variant](auto&& model) -> auto& { return model.parameters.template get()[{virus_variant}]; }); + + //Other distributions + param_percentile_dist( + node, std::vector(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return (dist1.viral_load_peak.params.a() + dist1.viral_load_peak.params.a()) / 2. < + (dist2.viral_load_peak.params.a() + dist2.viral_load_peak.params.a()) / 2.; + }); + param_percentile_dist( + node, std::vector(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return (dist1.infectivity_alpha.params.a() + dist1.infectivity_alpha.params.b()) / 2. < + (dist2.infectivity_alpha.params.a() + dist2.infectivity_alpha.params.b()) / 2.; + }); + param_percentile_dist( + node, std::vector::ParamType>(num_runs), + [age_group, virus_variant](auto&& model) -> auto& { + return model.parameters.template get()[{virus_variant, age_group}]; + }, + [](auto& dist1, auto& dist2) { + return (dist1.params.a() + dist1.params.b()) / 2. < (dist2.params.a() + dist2.params.b()) / 2.; + }); } - param_percentil(node, [age_group](auto&& model) -> auto& { + param_percentile(node, [age_group](auto&& model) -> auto& { return model.parameters.template get()[{age_group}]; }); - param_percentil(node, [age_group](auto&& model) -> auto& { + param_percentile(node, [age_group](auto&& model) -> auto& { static auto result = model.parameters.template get()[{age_group}].hours(); return result; }); - param_percentil(node, [age_group](auto&& model) -> auto& { + param_percentile(node, [age_group](auto&& model) -> auto& { static auto result = model.parameters.template get()[{age_group}].hours(); return result; }); - param_percentil(node, [age_group](auto&& model) -> auto& { + param_percentile(node, [age_group](auto&& model) -> auto& { static auto result = model.parameters.template get()[{age_group}].hours(); return result; }); - param_percentil(node, [age_group](auto&& model) -> auto& { + param_percentile(node, [age_group](auto&& model) -> auto& { static auto result = model.parameters.template get()[{age_group}].hours(); return result; }); } - param_percentil(node, [](auto&& model) -> auto& { + param_percentile(node, [](auto&& model) -> auto& { return model.parameters.template get()[MaskType::Community]; }); - param_percentil(node, [](auto&& model) -> auto& { + param_percentile(node, [](auto&& model) -> auto& { return model.parameters.template get()[MaskType::FFP2]; }); - param_percentil(node, [](auto&& model) -> auto& { + param_percentile(node, [](auto&& model) -> auto& { return model.parameters.template get()[MaskType::Surgical]; }); - param_percentil(node, [](auto&& model) -> auto& { + param_percentile(node, [](auto&& model) -> auto& { static auto result = model.parameters.template get().days(); return result; }); diff --git a/cpp/models/abm/infection.cpp b/cpp/models/abm/infection.cpp index e7ca94ab3b..905fb7be71 100644 --- a/cpp/models/abm/infection.cpp +++ b/cpp/models/abm/infection.cpp @@ -19,6 +19,7 @@ */ #include "abm/infection.h" +#include "memilio/utils/compiler_diagnostics.h" #include namespace mio @@ -37,9 +38,8 @@ Infection::Infection(PersonalRandomNumberGenerator& rng, VirusVariant virus, Age auto vl_params = params.get()[{virus, age}]; ScalarType high_viral_load_factor = 1; if (latest_protection.type != ProtectionType::NoProtection) { - high_viral_load_factor -= - params.get()[{latest_protection.type, age, virus}]( - init_date.days() - latest_protection.time.days()); + high_viral_load_factor -= params.get()[{latest_protection.type, age, virus}]( + init_date.days() - latest_protection.time.days()); } m_viral_load.peak = vl_params.viral_load_peak.get_distribution_instance()(rng, vl_params.viral_load_peak.params) * high_viral_load_factor; @@ -55,6 +55,9 @@ Infection::Infection(PersonalRandomNumberGenerator& rng, VirusVariant virus, Age m_log_norm_alpha = inf_params.infectivity_alpha.get_distribution_instance()(rng, inf_params.infectivity_alpha.params); m_log_norm_beta = inf_params.infectivity_beta.get_distribution_instance()(rng, inf_params.infectivity_beta.params); + + auto shedfactor_param = params.get()[{virus, age}]; + m_individual_virus_shed_factor = shedfactor_param.get_distribution_instance()(rng, shedfactor_param.params); } ScalarType Infection::get_viral_load(TimePoint t) const @@ -77,7 +80,7 @@ ScalarType Infection::get_infectivity(TimePoint t) const { if (m_viral_load.start_date >= t || get_infection_state(t) == InfectionState::Exposed) return 0; - return 1 / (1 + exp(-(m_log_norm_alpha + m_log_norm_beta * get_viral_load(t)))); + return m_individual_virus_shed_factor / (1 + exp(-(m_log_norm_alpha + m_log_norm_beta * get_viral_load(t)))); } VirusVariant Infection::get_virus_variant() const @@ -129,79 +132,90 @@ void Infection::draw_infection_course_forward(PersonalRandomNumberGenerator& rng assert(age.get() < params.get_num_groups()); auto t = init_date; TimeSpan time_period{}; // time period for current infection state + auto time_in_state = params.get()[{ + m_virus_variant, age}]; // time distribution parameters for current infection state InfectionState next_state{start_state}; // next state to enter m_infection_course.push_back(std::pair(t, next_state)); auto& uniform_dist = UniformDistribution::get_instance(); - ScalarType v; // random draws + ScalarType p; // uniform random draws from [0, 1] while ((next_state != InfectionState::Recovered && next_state != InfectionState::Dead)) { switch (next_state) { - case InfectionState::Exposed: + case InfectionState::Exposed: { // roll out how long until infected without symptoms - time_period = days(params.get()[{m_virus_variant, age}]); // subject to change - next_state = InfectionState::InfectedNoSymptoms; - break; - case InfectionState::InfectedNoSymptoms: + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); + next_state = InfectionState::InfectedNoSymptoms; + } break; + case InfectionState::InfectedNoSymptoms: { // roll out next infection step - v = uniform_dist(rng); - if (v < 0.5) { // TODO: subject to change - time_period = - days(params.get()[{m_virus_variant, age}]); // TODO: subject to change - next_state = InfectionState::InfectedSymptoms; + + p = uniform_dist(rng); + if (p < params.get()[{m_virus_variant, age}]) { + next_state = InfectionState::InfectedSymptoms; + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); } else { - time_period = days( - params.get()[{m_virus_variant, age}]); // TODO: subject to change - next_state = InfectionState::Recovered; + next_state = InfectionState::Recovered; + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); } - - break; - case InfectionState::InfectedSymptoms: + } break; + case InfectionState::InfectedSymptoms: { // roll out next infection step - { - ScalarType severity_protection_factor = 0.5; - v = uniform_dist(rng); - if (latest_protection.type != ProtectionType::NoProtection) { - severity_protection_factor = - params.get()[{latest_protection.type, age, m_virus_variant}]( - t.days() - latest_protection.time.days()); - } - if (v < (1 - severity_protection_factor) * 0.5) { - time_period = - days(params.get()[{m_virus_variant, age}]); // TODO: subject to change - next_state = InfectionState::InfectedSevere; - } - else { - time_period = days( - params.get()[{m_virus_variant, age}]); // TODO: subject to change - next_state = InfectionState::Recovered; - } - break; + + ScalarType severity_protection_factor = 1.; + p = uniform_dist(rng); + if (latest_protection.type != ProtectionType::NoProtection) { + severity_protection_factor = + params.get()[{latest_protection.type, age, m_virus_variant}]( + t.days() - latest_protection.time.days()); } - case InfectionState::InfectedSevere: + if (p < + (1 - severity_protection_factor) * params.get()[{m_virus_variant, age}]) { + next_state = InfectionState::InfectedSevere; + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); + } + else { + next_state = InfectionState::Recovered; + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); + } + } break; + + case InfectionState::InfectedSevere: { // roll out next infection step - v = uniform_dist(rng); - if (v < 0.5) { // TODO: subject to change - time_period = days(params.get()[{m_virus_variant, age}]); // TODO: subject to change - next_state = InfectionState::InfectedCritical; + + p = uniform_dist(rng); + if (p < params.get()[{m_virus_variant, age}]) { + next_state = InfectionState::InfectedCritical; + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); } else { - time_period = days(params.get()[{m_virus_variant, age}]); // TODO: subject to change - next_state = InfectionState::Recovered; + next_state = InfectionState::Recovered; + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); } - break; - case InfectionState::InfectedCritical: + } break; + + case InfectionState::InfectedCritical: { // roll out next infection step - v = uniform_dist(rng); - if (v < 0.5) { // TODO: subject to change - time_period = days(params.get()[{m_virus_variant, age}]); // TODO: subject to change - next_state = InfectionState::Dead; + + p = uniform_dist(rng); + if (p < params.get()[{m_virus_variant, age}]) { + next_state = InfectionState::Dead; + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); } else { - time_period = - days(params.get()[{m_virus_variant, age}]); // TODO: subject to change - next_state = InfectionState::Recovered; + next_state = InfectionState::Recovered; + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); } - break; + } break; + default: break; } @@ -217,63 +231,79 @@ TimePoint Infection::draw_infection_course_backward(PersonalRandomNumberGenerato assert(age.get() < params.get_num_groups()); auto start_date = init_date; TimeSpan time_period{}; // time period for current infection state - InfectionState previous_state{init_state}; // next state to enter + auto time_in_state = params.get()[{ + m_virus_variant, age}]; // time distribution parameters for current infection state + InfectionState previous_state{init_state}; // previous state to enter auto& uniform_dist = UniformDistribution::get_instance(); - ScalarType v; // random draws + ScalarType p; // uniform random draws from [0, 1] while ((previous_state != InfectionState::Exposed)) { switch (previous_state) { - case InfectionState::InfectedNoSymptoms: - time_period = days(params.get()[{m_virus_variant, age}]); // TODO: subject to change + case InfectionState::InfectedNoSymptoms: { + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); previous_state = InfectionState::Exposed; - break; + } break; - case InfectionState::InfectedSymptoms: - time_period = - days(params.get()[{m_virus_variant, age}]); // TODO: subject to change + case InfectionState::InfectedSymptoms: { + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); previous_state = InfectionState::InfectedNoSymptoms; - break; + } break; - case InfectionState::InfectedSevere: - time_period = - days(params.get()[{m_virus_variant, age}]); // TODO: subject to change + case InfectionState::InfectedSevere: { + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); previous_state = InfectionState::InfectedSymptoms; - break; + } break; - case InfectionState::InfectedCritical: - time_period = days(params.get()[{m_virus_variant, age}]); // TODO: subject to change + case InfectionState::InfectedCritical: { + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); previous_state = InfectionState::InfectedSevere; - break; + } break; - case InfectionState::Recovered: + case InfectionState::Recovered: { // roll out next infection step - v = uniform_dist(rng); - if (v < 0.25) { - time_period = days( - params.get()[{m_virus_variant, age}]); // TODO: subject to change + p = uniform_dist(rng); + // compute correct probabilities while factoring out the chance to die + auto p_death = params.get()[{m_virus_variant, age}] * + params.get()[{m_virus_variant, age}] * + params.get()[{m_virus_variant, age}] * + params.get()[{m_virus_variant, age}]; + if (p < (1 - params.get()[{m_virus_variant, age}]) / (1 - p_death)) { + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); previous_state = InfectionState::InfectedNoSymptoms; } - else if (v < 0.5) { // TODO: subject to change - time_period = - days(params.get()[{m_virus_variant, age}]); // TODO: subject to change + else if (p < (1 - params.get()[{m_virus_variant, age}] * + (1 - params.get()[{m_virus_variant, age}])) / + (1 - p_death)) { + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); previous_state = InfectionState::InfectedSymptoms; } - else if (v < 0.75) { - time_period = days(params.get()[{m_virus_variant, age}]); // TODO: subject to change + else if (p < (1 - params.get()[{m_virus_variant, age}] * + params.get()[{m_virus_variant, age}] * + (1 - params.get()[{m_virus_variant, age}])) / + (1 - p_death)) { + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); previous_state = InfectionState::InfectedSevere; } else { - time_period = - days(params.get()[{m_virus_variant, age}]); // TODO: subject to change + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); previous_state = InfectionState::InfectedCritical; } - break; + } break; - case InfectionState::Dead: - time_period = days(params.get()[{m_virus_variant, age}]); // TODO: subject to change + case InfectionState::Dead: { + time_in_state = params.get()[{m_virus_variant, age}]; + time_period = days(time_in_state.get(rng)); previous_state = InfectionState::InfectedCritical; - break; + } break; default: break; diff --git a/cpp/models/abm/infection.h b/cpp/models/abm/infection.h index 95cf0c4e48..fd0777ee7f 100644 --- a/cpp/models/abm/infection.h +++ b/cpp/models/abm/infection.h @@ -183,6 +183,7 @@ class Infection ViralLoad m_viral_load; ///< ViralLoad of the Infection. ScalarType m_log_norm_alpha, m_log_norm_beta; ///< Parameters for the infectivity mapping, which is modelled through an invlogit function. + ScalarType m_individual_virus_shed_factor; ///< Individual virus shed factor. bool m_detected; ///< Whether an Infection is detected or not. }; diff --git a/cpp/models/abm/mobility_rules.cpp b/cpp/models/abm/mobility_rules.cpp index 32858037f7..9949b92384 100644 --- a/cpp/models/abm/mobility_rules.cpp +++ b/cpp/models/abm/mobility_rules.cpp @@ -21,6 +21,7 @@ #include "abm/person.h" #include "abm/random_events.h" #include "abm/location_type.h" +#include "abm/parameters.h" namespace mio { @@ -34,7 +35,7 @@ LocationType random_mobility(PersonalRandomNumberGenerator& rng, const Person& p auto make_transition = [current_loc](auto l) { return std::make_pair(l, l == current_loc ? 0. : 1.); }; - if (t < params.get()) { + if (t < params.get()) { return random_transition(rng, current_loc, dt, {make_transition(LocationType::Work), make_transition(LocationType::Home), make_transition(LocationType::School), make_transition(LocationType::SocialEvent), @@ -48,7 +49,7 @@ LocationType go_to_school(PersonalRandomNumberGenerator& /*rng*/, const Person& { auto current_loc = person.get_location_type(); - if (current_loc == LocationType::Home && t < params.get() && t.day_of_week() < 5 && + if (current_loc == LocationType::Home && t < params.get() && t.day_of_week() < 5 && person.get_go_to_school_time(params) >= t.time_since_midnight() && person.get_go_to_school_time(params) < t.time_since_midnight() + dt && params.get()[person.get_age()] && person.goes_to_school(t, params) && @@ -67,7 +68,7 @@ LocationType go_to_work(PersonalRandomNumberGenerator& /*rng*/, const Person& pe { auto current_loc = person.get_location_type(); - if (current_loc == LocationType::Home && t < params.get() && + if (current_loc == LocationType::Home && t < params.get() && params.get()[person.get_age()] && t.day_of_week() < 5 && t.time_since_midnight() + dt > person.get_go_to_work_time(params) && t.time_since_midnight() <= person.get_go_to_work_time(params) && person.goes_to_work(t, params) && @@ -105,7 +106,7 @@ LocationType go_to_event(PersonalRandomNumberGenerator& rng, const Person& perso { auto current_loc = person.get_location_type(); //leave - if (current_loc == LocationType::Home && t < params.get() && + if (current_loc == LocationType::Home && t < params.get() && ((t.day_of_week() <= 4 && t.hour_of_day() >= 19) || (t.day_of_week() >= 5 && t.hour_of_day() >= 10)) && !person.is_in_quarantine(t, params)) { return random_transition(rng, current_loc, dt, diff --git a/cpp/models/abm/model.cpp b/cpp/models/abm/model.cpp index 4c8ba1f7f5..b143b91976 100755 --- a/cpp/models/abm/model.cpp +++ b/cpp/models/abm/model.cpp @@ -21,6 +21,7 @@ #include "abm/location_id.h" #include "abm/location_type.h" #include "abm/intervention_type.h" +#include "abm/model_functions.h" #include "abm/person.h" #include "abm/location.h" #include "abm/mobility_rules.h" @@ -28,6 +29,8 @@ #include "memilio/utils/logging.h" #include "memilio/utils/mioomp.h" #include "memilio/utils/stl_util.h" +#include +#include #include namespace mio @@ -247,9 +250,10 @@ void Model::compute_exposure_caches(TimePoint t, TimeSpan dt) const auto num_persons = m_persons.size(); // 1) reset all cached values - // Note: we cannot easily reuse values, as they are time dependant (get_infection_state) + // Note: we cannot easily reuse values, as they are time dependent (get_infection_state) PRAGMA_OMP(taskloop) for (size_t i = 0; i < num_locations; ++i) { + mio::abm::adjust_contact_rates(m_locations[i], parameters.get_num_groups()); const auto index = i; auto& local_air_exposure = m_air_exposure_rates_cache[index]; std::for_each(local_air_exposure.begin(), local_air_exposure.end(), [](auto& r) { @@ -271,6 +275,20 @@ void Model::compute_exposure_caches(TimePoint t, TimeSpan dt) m_contact_exposure_rates_cache[location], person, get_location(person.get_id()), t, dt); } // implicit taskloop barrier + //normalize cached exposure rates + for (size_t i = 0; i < num_locations; ++i) { + for (auto age_group = AgeGroup(0); age_group < AgeGroup(parameters.get_num_groups()); ++age_group) { + auto num_persons_in_location = + std::count_if(m_persons.begin(), m_persons.end(), [age_group](Person& p) { + return p.get_age() == age_group; + }); + if (num_persons_in_location > 0) { + for (auto& v : m_contact_exposure_rates_cache[i].slice(AgeGroup(age_group))) { + v = v / num_persons_in_location; + } + } + } + } } // implicit single barrier } diff --git a/cpp/models/abm/model_functions.cpp b/cpp/models/abm/model_functions.cpp index 9e8a0ae78e..02e4be4a14 100644 --- a/cpp/models/abm/model_functions.cpp +++ b/cpp/models/abm/model_functions.cpp @@ -77,11 +77,10 @@ void interact(PersonalRandomNumberGenerator& personal_rng, Person& person, const for (uint32_t v = 0; v != static_cast(VirusVariant::Count); ++v) { VirusVariant virus = static_cast(v); ScalarType local_indiv_trans_prob_v = - (std::min(local_parameters.get(), - daily_transmissions_by_contacts(local_contact_exposure, cell_index, virus, age_receiver, - local_parameters)) + + (daily_transmissions_by_contacts(local_contact_exposure, cell_index, virus, age_receiver, + local_parameters) + daily_transmissions_by_air(local_air_exposure, cell_index, virus, global_parameters)) * - dt.days() * (1 - mask_protection) * (1 - person.get_protection_factor(t, virus, global_parameters)); + (1 - mask_protection) * (1 - person.get_protection_factor(t, virus, global_parameters)); local_indiv_trans_prob[v] = std::make_pair(virus, local_indiv_trans_prob_v); } @@ -146,5 +145,26 @@ bool change_location(Person& person, const Location& destination, const Transpor } } +void adjust_contact_rates(Location& location, size_t num_agegroups) +{ + if (location.get_infection_parameters().get() == std::numeric_limits::max()) { + return; + } + for (auto contact_from = AgeGroup(0); contact_from < AgeGroup(num_agegroups); contact_from++) { + ScalarType total_contacts = 0.; + // slizing would be preferred but is problematic since both Tags of ContactRates are AgeGroup + for (auto contact_to = AgeGroup(0); contact_to < AgeGroup(num_agegroups); contact_to++) { + total_contacts += location.get_infection_parameters().get()[{contact_from, contact_to}]; + } + if (total_contacts > location.get_infection_parameters().get()) { + for (auto contact_to = AgeGroup(0); contact_to < AgeGroup(num_agegroups); contact_to++) { + location.get_infection_parameters().get()[{contact_from, contact_to}] = + location.get_infection_parameters().get()[{contact_from, contact_to}] * + location.get_infection_parameters().get() / total_contacts; + } + } + } +} + } // namespace abm } // namespace mio diff --git a/cpp/models/abm/model_functions.h b/cpp/models/abm/model_functions.h index 757d7e77d2..e7b0df3d27 100644 --- a/cpp/models/abm/model_functions.h +++ b/cpp/models/abm/model_functions.h @@ -93,6 +93,15 @@ void interact(PersonalRandomNumberGenerator& personal_rng, Person& person, const bool change_location(Person& person, const Location& destination, const TransportMode mode = TransportMode::Unknown, const std::vector& cells = {0}); +/** + * @brief Adjust ContactRates of location by MaximumContacts. + * Every ContactRate is adjusted by the proportion MaximumContacts of the location has on the total + * number of contacts according to the ContactRates. + * @param[in, out] location The location whose ContactRates are adjusted. + * @param[in] num_agegroup The number of AgeGroups in the model. + */ +void adjust_contact_rates(Location& location, size_t num_agegroups); + } // namespace abm } // namespace mio diff --git a/cpp/models/abm/parameters.h b/cpp/models/abm/parameters.h index 204e2eb77d..d3472aa7d7 100644 --- a/cpp/models/abm/parameters.h +++ b/cpp/models/abm/parameters.h @@ -1,7 +1,7 @@ /* * Copyright (C) 2020-2024 MEmilio * -* Authors: Daniel Abele, Elisabeth Kluth, Khoa Nguyen +* Authors: Daniel Abele, Elisabeth Kluth, Khoa Nguyen, David Kerkmann, Julia Bicker * * Contact: Martin J. Kuehn * @@ -24,12 +24,16 @@ #include "abm/time.h" #include "abm/virus_variant.h" #include "abm/protection_event.h" +#include "abm/protection_event.h" #include "abm/test_type.h" #include "memilio/config.h" #include "memilio/io/default_serialize.h" #include "memilio/io/io.h" #include "memilio/math/time_series_functor.h" +#include "memilio/utils/parameter_distribution_wrapper.h" #include "memilio/utils/custom_index_array.h" +#include "memilio/utils/logging.h" +#include "memilio/utils/parameter_distributions.h" #include "memilio/utils/uncertain_value.h" #include "memilio/utils/parameter_set.h" #include "memilio/utils/index_range.h" @@ -46,14 +50,18 @@ namespace mio namespace abm { +// Distribution that can be used for the time spend in InfectionStates +using InfectionStateTimesDistributionsParameters = LogNormalDistribution::ParamType; + /** - * @brief Time that a Person is infected but not yet infectious. + * @brief Time that a Person is infected but not yet infectious in day unit */ struct IncubationPeriod { - using Type = CustomIndexArray, VirusVariant, AgeGroup>; + using Type = CustomIndexArray; static Type get_default(AgeGroup size) { - return Type({VirusVariant::Count, size}, 1.); + ParameterDistributionLogNormal log_norm(1., 1.); + return Type({VirusVariant::Count, size}, ParameterDistributionWrapper(log_norm)); } static std::string name() { @@ -61,113 +69,196 @@ struct IncubationPeriod { } }; -struct InfectedNoSymptomsToSymptoms { - using Type = CustomIndexArray, VirusVariant, AgeGroup>; +/** +* @brief Time that a Person is infected but presymptomatic in day unit +*/ +struct TimeInfectedNoSymptomsToSymptoms { + using Type = CustomIndexArray; static Type get_default(AgeGroup size) { - return Type({VirusVariant::Count, size}, 1.); + ParameterDistributionLogNormal log_norm(1., 1.); + return Type({VirusVariant::Count, size}, ParameterDistributionWrapper(log_norm)); } static std::string name() { - return "InfectedNoSymptomsToSymptoms"; + return "TimeInfectedNoSymptomsToSymptoms"; } }; -struct InfectedNoSymptomsToRecovered { - using Type = CustomIndexArray, VirusVariant, AgeGroup>; +/** +* @brief Time that a Person is infected when staying asymptomatic in day unit +*/ +struct TimeInfectedNoSymptomsToRecovered { + using Type = CustomIndexArray; static Type get_default(AgeGroup size) { - return Type({VirusVariant::Count, size}, 1.); + ParameterDistributionLogNormal log_norm(1., 1.); + return Type({VirusVariant::Count, size}, ParameterDistributionWrapper(log_norm)); } static std::string name() { - return "InfectedNoSymptomsToRecovered"; + return "TimeInfectedNoSymptomsToRecovered"; } }; -struct InfectedSymptomsToRecovered { - using Type = CustomIndexArray, VirusVariant, AgeGroup>; +/** +* @brief Time that a Person is infected and symptomatic but +* who do not need to be hospitalized yet in day unit +*/ +struct TimeInfectedSymptomsToSevere { + using Type = CustomIndexArray; static Type get_default(AgeGroup size) { - return Type({VirusVariant::Count, size}, 1.); + ParameterDistributionLogNormal log_norm(1., 1.); + + return Type({VirusVariant::Count, size}, ParameterDistributionWrapper(log_norm)); } static std::string name() { - return "InfectedSymptomsToRecovered"; + return "TimeInfectedSymptomsToSevere"; } }; -struct InfectedSymptomsToSevere { - using Type = CustomIndexArray, VirusVariant, AgeGroup>; +/** +* @brief Time that a Person is infected and symptomatic who will recover in day unit +*/ +struct TimeInfectedSymptomsToRecovered { + using Type = CustomIndexArray; static Type get_default(AgeGroup size) { - return Type({VirusVariant::Count, size}, 1.); + ParameterDistributionLogNormal log_norm(1., 1.); + return Type({VirusVariant::Count, size}, ParameterDistributionWrapper(log_norm)); } static std::string name() { - return "InfectedSymptomsToSevere"; + return "TimeInfectedSymptomsToRecovered"; } }; -struct SevereToCritical { - using Type = CustomIndexArray, VirusVariant, AgeGroup>; +/** + * @brief Time that a Person is infected and 'simply' hospitalized before becoming critical in day unit + */ +struct TimeInfectedSevereToCritical { + using Type = CustomIndexArray; static Type get_default(AgeGroup size) { - return Type({VirusVariant::Count, size}, 1.); + ParameterDistributionLogNormal log_norm(1., 1.); + return Type({VirusVariant::Count, size}, ParameterDistributionWrapper(log_norm)); + } + static std::string name() + { + return "TimeInfectedSevereToCritical"; + } +}; + +/** + * @brief Time that a Person is infected and 'simply' hospitalized before recovering in day unit + */ +struct TimeInfectedSevereToRecovered { + using Type = CustomIndexArray; + static Type get_default(AgeGroup size) + { + ParameterDistributionLogNormal log_norm(1., 1.); + return Type({VirusVariant::Count, size}, ParameterDistributionWrapper(log_norm)); + } + static std::string name() + { + return "TimeInfectedSevereToRecovered"; + } +}; + +/** + * @brief Time that a Person is treated by ICU before dying in day unit + */ +struct TimeInfectedCriticalToDead { + using Type = CustomIndexArray; + static Type get_default(AgeGroup size) + { + ParameterDistributionLogNormal log_norm(1., 1.); + return Type({VirusVariant::Count, size}, ParameterDistributionWrapper(log_norm)); } static std::string name() { - return "SevereToCritical"; + return "TimeInfectedCriticalToDead"; } }; -struct SevereToRecovered { +/** + * @brief Time that a Person is treated by ICU before recovering in day unit + */ +struct TimeInfectedCriticalToRecovered { + using Type = CustomIndexArray; + static Type get_default(AgeGroup size) + { + ParameterDistributionLogNormal log_norm(1., 1.); + return Type({VirusVariant::Count, size}, ParameterDistributionWrapper(log_norm)); + } + static std::string name() + { + return "TimeInfectedCriticalToRecovered"; + } +}; + +/** +* @brief the percentage of symptomatic cases +*/ +struct SymptomsPerInfectedNoSymptoms { using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { - return Type({VirusVariant::Count, size}, 1.); + return Type({VirusVariant::Count, size}, .5); } static std::string name() { - return "SevereToRecovered"; + return "SymptomaticPerInfectedNoSymptoms"; } }; -struct CriticalToRecovered { +/** +* @brief the percentage of hospitalized cases per infected cases +*/ +struct SeverePerInfectedSymptoms { using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { - return Type({VirusVariant::Count, size}, 1.); + return Type({VirusVariant::Count, size}, .5); } static std::string name() { - return "CriticalToRecovered"; + return "SeverePerInfectedSymptoms"; } }; -struct CriticalToDead { +/** +* @brief the percentage of ICU cases per hospitalized cases +*/ +struct CriticalPerInfectedSevere { using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { - return Type({VirusVariant::Count, size}, 1.); + return Type({VirusVariant::Count, size}, .5); } static std::string name() { - return "CriticalToDead"; + return "CriticalPerInfectedSevere"; } }; -struct RecoveredToSusceptible { +/** +* @brief the percentage of dead cases per ICU cases +*/ +struct DeathsPerInfectedCritical { using Type = CustomIndexArray, VirusVariant, AgeGroup>; static Type get_default(AgeGroup size) { - return Type({VirusVariant::Count, size}, 1.); + return Type({VirusVariant::Count, size}, .5); } static std::string name() { - return "RecoveredToSusceptible"; + return "DeathsPerInfectedCritical"; } }; + /** * @brief Parameters for the ViralLoad course. Default values taken as constant values from the average from * https://github.com/VirologyCharite/SARS-CoV-2-VL-paper/tree/main @@ -232,6 +323,22 @@ struct InfectivityDistributions { } }; +/** + * @brief Individual virus shed factor to account for variability in infectious viral load spread. +*/ +struct VirusShedFactor { + using Type = CustomIndexArray::ParamType, VirusVariant, AgeGroup>; + static Type get_default(AgeGroup size) + { + Type default_val({VirusVariant::Count, size}, UniformDistribution::ParamType{0., 0.28}); + return default_val; + } + static std::string name() + { + return "VirusShedFactor"; + } +}; + /** * @brief Probability that an Infection is detected. */ @@ -550,14 +657,15 @@ struct AgeGroupGotoWork { }; using ParametersBase = - ParameterSet; + ParameterSet; /** * @brief Maximum number of Person%s an infectious Person can infect at the respective Location. @@ -593,7 +701,7 @@ struct ContactRates { // If true, consider the capacity of the Cell%s of this Location for the computation of relative transmission risk. struct UseLocationCapacityForTransmissions { using Type = bool; - static Type get_default(AgeGroup) + static Type get_default(AgeGroup /*size*/) { return false; } @@ -643,116 +751,158 @@ class Parameters : public ParametersBase */ bool check_constraints() const { - for (auto age_group : make_index_range(AgeGroup{m_num_groups})) { - for (auto virus_variant : enum_members()) { + for (auto i = AgeGroup(0); i < AgeGroup(m_num_groups); ++i) { + for (auto&& v : enum_members()) { - if (this->get()[{virus_variant, age_group}] < 0) { - log_error("Constraint check: Parameter IncubationPeriod of age group {:.0f} smaller than {:.4f}", - (size_t)age_group, 0); + if (this->get()[{v, i}].params()[0] < 0) { + log_error("Constraint check: Mean of parameter IncubationPeriod of virus variant {} and " + "age group {:.0f} smaller " + "than {:.4f}", + (uint32_t)v, (size_t)i, 0); return true; } - if (this->get()[{virus_variant, age_group}] < 0.0) { - log_error("Constraint check: Parameter InfectedNoSymptomsToSymptoms of age group {:.0f} smaller " + if (this->get()[{v, i}].params()[0] < 0.0) { + log_error("Constraint check: Mean of parameter TimeInfectedNoSymptomsToSymptoms " + "of virus variant " + "{} and age group {:.0f} smaller " "than {:d}", - (size_t)age_group, 0); + (uint32_t)v, (size_t)i, 0); return true; } - if (this->get()[{virus_variant, age_group}] < 0.0) { - log_error("Constraint check: Parameter InfectedNoSymptomsToRecovered of age group {:.0f} smaller " + if (this->get()[{v, i}].params()[0] < 0.0) { + log_error("Constraint check: Mean of parameter TimeInfectedNoSymptomsToRecovered of " + "virus variant " + "{} and age group {:.0f} smaller " "than {:d}", - (size_t)age_group, 0); + (uint32_t)v, (size_t)i, 0); return true; } - if (this->get()[{virus_variant, age_group}] < 0.0) { - log_error( - "Constraint check: Parameter InfectedSymptomsToRecovered of age group {:.0f} smaller than {:d}", - (size_t)age_group, 0); + if (this->get()[{v, i}].params()[0] < 0.0) { + log_error("Constraint check: Mean of parameter TimeInfectedSymptomsToSevere of virus " + "variant {} " + "and age group {:.0f} smaller " + "than {:d}", + (uint32_t)v, (size_t)i, 0); return true; } - if (this->get()[{virus_variant, age_group}] < 0.0) { - log_error( - "Constraint check: Parameter InfectedSymptomsToSevere of age group {:.0f} smaller than {:d}", - (size_t)age_group, 0); + if (this->get()[{v, i}].params()[0] < 0.0) { + log_error("Constraint check: Mean of parameter TimeInfectedSymptomsToRecovered of virus " + "variant {} " + "and age group {:.0f} smaller " + "than {:d}", + (uint32_t)v, (size_t)i, 0); return true; } - if (this->get()[{virus_variant, age_group}] < 0.0) { - log_error("Constraint check: Parameter SevereToCritical of age group {:.0f} smaller than {:d}", - (size_t)age_group, 0); + if (this->get()[{v, i}].params()[0] < 0.0) { + log_error("Constraint check: Mean of parameter TimeInfectedSevereToCritical of virus " + "variant {} " + "and age group {:.0f} smaller " + "than {:d}", + (uint32_t)v, (size_t)i, 0); + return true; + } + + if (this->get()[{v, i}].params()[0] < 0.0) { + log_error("Constraint check: Mean of parameter TimeInfectedSevereToRecovered of virus " + "variant {} " + "and age group {:.0f} smaller " + "than {:d}", + (uint32_t)v, (size_t)i, 0); + return true; + } + + if (this->get()[{v, i}].params()[0] < 0.0) { + log_error("Constraint check: Mean of parameter TimeInfectedCriticalToDead of virus variant {} " + "and age group {:.0f} smaller " + "than {:d}", + (uint32_t)v, (size_t)i, 0); + return true; + } + + if (this->get()[{v, i}].params()[0] < 0.0) { + log_error("Constraint check: Mean of parameter TimeInfectedCriticalToRecovered of virus " + "variant {} " + "and age group {:.0f} smaller " + "than {:d}", + (uint32_t)v, (size_t)i, 0); return true; } - if (this->get()[{virus_variant, age_group}] < 0.0) { - log_error("Constraint check: Parameter SevereToRecovered of age group {:.0f} smaller than {:d}", - (size_t)age_group, 0); + if (this->get()[{v, i}] < 0.0 || + this->get()[{v, i}] > 1.0) { + log_error("Constraint check: Parameter SymptomsPerInfectedNoSymptoms of virus variant {} and age " + "group {:.0f} smaller than {:d} or larger than {:d}", + (uint32_t)v, (size_t)i, 0, 1); return true; } - if (this->get()[{virus_variant, age_group}] < 0.0) { - log_error("Constraint check: Parameter CriticalToDead of age group {:.0f} smaller than {:d}", - (size_t)age_group, 0); + if (this->get()[{v, i}] < 0.0 || + this->get()[{v, i}] > 1.0) { + log_error("Constraint check: Parameter SeverePerInfectedSymptoms of virus variant {} and age group " + "{:.0f} smaller than {:d} or larger than {:d}", + (uint32_t)v, (size_t)i, 0, 1); return true; } - if (this->get()[{virus_variant, age_group}] < 0.0) { - log_error("Constraint check: Parameter CriticalToRecovered of age group {:.0f} smaller than {:d}", - (size_t)age_group, 0); + if (this->get()[{v, i}] < 0.0 || + this->get()[{v, i}] > 1.0) { + log_error("Constraint check: Parameter CriticalPerInfectedSevere of virus variant {} and age group " + "{:.0f} smaller than {:d} or larger than {:d}", + (uint32_t)v, (size_t)i, 0, 1); return true; } - if (this->get()[{virus_variant, age_group}] < 0.0) { - log_error( - "Constraint check: Parameter RecoveredToSusceptible of age group {:.0f} smaller than {:d}", - (size_t)age_group, 0); + if (this->get()[{v, i}] < 0.0 || + this->get()[{v, i}] > 1.0) { + log_error("Constraint check: Parameter DeathsPerInfectedCritical of age group {:.0f} smaller than " + "{:d} or larger than {:d}", + (uint32_t)v, (size_t)i, 0, 1); return true; } - if (this->get()[{virus_variant, age_group}] < 0.0 || - this->get()[{virus_variant, age_group}] > 1.0) { - log_error("Constraint check: Parameter DetectInfection of age group {:.0f} smaller than {:d} or " + if (this->get()[{v, i}] < 0.0 || this->get()[{v, i}] > 1.0) { + log_error("Constraint check: Parameter DetectInfection of virus variant {} and age group {:.0f} " + "smaller than {:d} or " "larger than {:d}", - (size_t)age_group, 0, 1); + (uint32_t)v, (size_t)i, 0, 1); return true; } } - if (this->get()[age_group].seconds() < 0.0 || - this->get()[age_group].seconds() > - this->get()[age_group].seconds()) { + if (this->get()[i].seconds() < 0.0 || + this->get()[i].seconds() > this->get()[i].seconds()) { log_error("Constraint check: Parameter GotoWorkTimeMinimum of age group {:.0f} smaller {:d} or " "larger {:d}", - (size_t)age_group, 0, this->get()[age_group].seconds()); + (size_t)i, 0, this->get()[i].seconds()); return true; } - if (this->get()[age_group].seconds() < - this->get()[age_group].seconds() || - this->get()[age_group] > days(1)) { + if (this->get()[i].seconds() < this->get()[i].seconds() || + this->get()[i] > days(1)) { log_error("Constraint check: Parameter GotoWorkTimeMaximum of age group {:.0f} smaller {:d} or larger " "than one day time span", - (size_t)age_group, this->get()[age_group].seconds()); + (size_t)i, this->get()[i].seconds()); return true; } - if (this->get()[age_group].seconds() < 0.0 || - this->get()[age_group].seconds() > - this->get()[age_group].seconds()) { + if (this->get()[i].seconds() < 0.0 || + this->get()[i].seconds() > this->get()[i].seconds()) { log_error("Constraint check: Parameter GotoSchoolTimeMinimum of age group {:.0f} smaller {:d} or " "larger {:d}", - (size_t)age_group, 0, this->get()[age_group].seconds()); + (size_t)i, 0, this->get()[i].seconds()); return true; } - if (this->get()[age_group].seconds() < - this->get()[age_group].seconds() || - this->get()[age_group] > days(1)) { + if (this->get()[i].seconds() < this->get()[i].seconds() || + this->get()[i] > days(1)) { log_error("Constraint check: Parameter GotoWorkTimeMaximum of age group {:.0f} smaller {:d} or larger " "than one day time span", - (size_t)age_group, this->get()[age_group].seconds()); + (size_t)i, this->get()[i].seconds()); return true; } } diff --git a/cpp/models/abm/personal_rng.h b/cpp/models/abm/personal_rng.h index 654e8dea5a..9164091e78 100644 --- a/cpp/models/abm/personal_rng.h +++ b/cpp/models/abm/personal_rng.h @@ -22,7 +22,7 @@ #define MIO_ABM_PERSONAL_RNG_H #include "memilio/utils/random_number_generator.h" -#include "abm/person_id.h" +#include "models/abm/person_id.h" namespace mio { diff --git a/cpp/models/ode_secir/parameter_space.h b/cpp/models/ode_secir/parameter_space.h index 2c3ca1148d..26bfad910e 100644 --- a/cpp/models/ode_secir/parameter_space.h +++ b/cpp/models/ode_secir/parameter_space.h @@ -47,12 +47,26 @@ template void set_params_distributions_normal(Model& model, double t0, double tmax, double dev_rel) { auto set_distribution = [dev_rel](UncertainValue& v, double min_val = 0.001) { - v.set_distribution(ParameterDistributionNormal( - //add add limits for nonsense big values. Also mscv has a problem with a few doubles so this fixes it - std::min(std::max(min_val, (1 - dev_rel * 2.6) * v), 0.1 * std::numeric_limits::max()), - std::min(std::max(min_val, (1 + dev_rel * 2.6) * v), 0.5 * std::numeric_limits::max()), - std::min(std::max(min_val, double(v)), 0.3 * std::numeric_limits::max()), - std::min(std::max(min_val, dev_rel * v), std::numeric_limits::max()))); + auto lower_bound = + std::min(std::max(min_val, (1 - dev_rel * 2.6) * v), 0.1 * std::numeric_limits::max()); + auto upper_bound = + std::min(std::max(min_val, (1 + dev_rel * 2.6) * v), 0.5 * std::numeric_limits::max()); + + if (mio::floating_point_equal(lower_bound, upper_bound, mio::Limits::zero_tolerance())) { + //MSVC has problems if standard deviation for normal distribution is zero + mio::log_debug("Bounded ParameterDistribution has standard deviation close to zero. Therefore constant " + "distribution is used."); + v.set_distribution(ParameterDistributionConstant( + std::min(std::max(min_val, double(v)), 0.3 * std::numeric_limits::max()))); + } + else { + v.set_distribution(ParameterDistributionNormal( + //add add limits for nonsense big values. Also mscv has a problem with a few doubles so this fixes it + std::min(std::max(min_val, (1 - dev_rel * 2.6) * v), 0.1 * std::numeric_limits::max()), + std::min(std::max(min_val, (1 + dev_rel * 2.6) * v), 0.5 * std::numeric_limits::max()), + std::min(std::max(min_val, double(v)), 0.3 * std::numeric_limits::max()), + std::min(std::max(min_val, dev_rel * v), std::numeric_limits::max()))); + } }; set_distribution(model.parameters.template get>(), 0.0); diff --git a/cpp/models/ode_secirvvs/parameters.h b/cpp/models/ode_secirvvs/parameters.h index a69f96bb79..6a382ce7f8 100644 --- a/cpp/models/ode_secirvvs/parameters.h +++ b/cpp/models/ode_secirvvs/parameters.h @@ -234,7 +234,7 @@ struct TimeInfectedNoSymptoms { /** * @brief the infectious time for symptomatic cases that are infected but -* who do not need to be hsopitalized in the SECIR model in day unit +* who do not need to be hospitalized in the SECIR model in day unit */ template struct TimeInfectedSymptoms { diff --git a/cpp/simulations/abm.cpp b/cpp/simulations/abm.cpp index ba25076ac1..b9c7ad15d9 100644 --- a/cpp/simulations/abm.cpp +++ b/cpp/simulations/abm.cpp @@ -22,7 +22,10 @@ #include "abm/household.h" #include "abm/lockdown_rules.h" #include "memilio/config.h" +#include "memilio/utils/parameter_distribution_wrapper.h" +#include "memilio/config.h" #include "memilio/io/result_io.h" +#include "memilio/utils/parameter_distributions.h" #include "memilio/utils/random_number_generator.h" #include "memilio/utils/uncertain_value.h" @@ -466,156 +469,238 @@ void set_parameters(mio::abm::Parameters params) // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 and 35-59) params.get().set_multiple({age_group_15_to_34, age_group_35_to_59}, true); - params.set({{mio::abm::VirusVariant::Count, mio::AgeGroup(num_age_groups)}, 4.}); + mio::ParameterDistributionLogNormal log_norm(4., 1.); + params.set( + {{mio::abm::VirusVariant::Count, mio::AgeGroup(num_age_groups)}, mio::ParameterDistributionWrapper(log_norm)}); //0-4 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.276; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.092; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.142; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.186; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.015; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.001; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); //5-14 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.276; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = - 0.092; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.142; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.186; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.015; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); //15-34 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 0.315; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 0.079; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.139; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.003; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.157; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.013; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.021; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); //35-59 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = - 0.315; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = - 0.079; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.136; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.009; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.113; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.02; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.05; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.008; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); //60-79 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = - 0.315; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = - 0.079; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.123; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.024; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.083; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.023; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); //80+ - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.315; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = - 0.079; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.115; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.033; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.055; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.036; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.052; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); // Set each parameter for vaccinated people including personal infection and vaccine protection levels. // Summary: https://doi.org/10.1038/s41577-021-00550-x, //0-4 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.161; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.132; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.186; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.015; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.0; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); //5-14 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.161; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = - 0.132; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.186; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.015; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.0; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); //15-34 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 0.179; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.142; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.157; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.013; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.021; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.0; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); //35-59 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = - 0.179; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = - 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.141; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.003; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.113; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.02; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.05; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.008; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.0; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); //60-79 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = - 0.179; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = - 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.136; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.009; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.083; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.023; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.0; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); //80+ - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.179; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = - 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.133; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.012; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.055; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.036; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.052; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.0; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); } /** diff --git a/cpp/simulations/abm_braunschweig.cpp b/cpp/simulations/abm_braunschweig.cpp index a8be3bb5f9..9434cc97a9 100644 --- a/cpp/simulations/abm_braunschweig.cpp +++ b/cpp/simulations/abm_braunschweig.cpp @@ -21,12 +21,15 @@ #include "abm/location_id.h" #include "abm/lockdown_rules.h" #include "abm/parameters.h" +#include "abm/parameters.h" #include "abm/person.h" #include "abm/simulation.h" #include "abm/model.h" +#include "memilio/utils/parameter_distribution_wrapper.h" #include "memilio/epidemiology/age_group.h" #include "memilio/io/io.h" #include "memilio/io/result_io.h" +#include "memilio/utils/parameter_distributions.h" #include "memilio/utils/uncertain_value.h" #include "boost/algorithm/string/split.hpp" #include "boost/algorithm/string/classification.hpp" @@ -396,94 +399,121 @@ void set_parameters(mio::abm::Parameters params) params.get()[age_group_5_to_14] = true; // Set the age group the can go to work is AgeGroup(2) and AgeGroup(3) (i.e. 15-34 and 35-59) params.get().set_multiple({age_group_15_to_34, age_group_35_to_59}, true); - - params.set({{mio::abm::VirusVariant::Count, mio::AgeGroup(num_age_groups)}, 4.}); + mio::ParameterDistributionLogNormal log_norm(4., 1.); + params.set( + {{mio::abm::VirusVariant::Count, mio::AgeGroup(num_age_groups)}, mio::ParameterDistributionWrapper(log_norm)}); //0-4 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.276; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.092; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.142; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.186; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.015; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.001; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm); //5-14 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.276; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = - 0.092; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.142; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.186; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.015; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = + mio::ParameterDistributionWrapper(log_norm); //15-34 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 0.315; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 0.079; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.139; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.003; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.157; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.013; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.021; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(log_norm); //35-59 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = - 0.315; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = - 0.079; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.136; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.009; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.113; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.02; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.05; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.008; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = + mio::ParameterDistributionWrapper(log_norm); //60-79 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = - 0.315; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = - 0.079; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.123; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.024; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.083; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.023; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(log_norm); //80+ - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.315; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = - 0.079; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.115; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.033; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.055; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.036; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.052; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = + mio::ParameterDistributionWrapper(log_norm); // Set each parameter for vaccinated people including personal infection and vaccine protection levels. // Summary: https://doi.org/10.1038/s41577-021-00550-x, - //0-4 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.161; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.132; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.186; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.015; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.0; // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ProtectionType::NaturalInfection, age_group_0_to_4, @@ -533,17 +563,6 @@ void set_parameters(mio::abm::Parameters params) mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}}; - //5-14 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.161; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = - 0.132; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.186; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.015; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.143; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 0.0; // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ProtectionType::NaturalInfection, age_group_5_to_14, @@ -592,18 +611,6 @@ void set_parameters(mio::abm::Parameters params) mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {450, 0.5}}}; - //15-34 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 0.179; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.142; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.001; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.157; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.013; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.021; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.0; // Set up personal infection and vaccine protection levels, based on: https://doi.org/10.1038/s41577-021-00550-x, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ProtectionType::NaturalInfection, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}] = { @@ -650,18 +657,6 @@ void set_parameters(mio::abm::Parameters params) mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}}; - //35-59 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = - 0.179; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = - 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.141; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.003; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.113; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.02; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.05; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.008; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = 0.0; // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ProtectionType::NaturalInfection, age_group_35_to_59, @@ -709,17 +704,11 @@ void set_parameters(mio::abm::Parameters params) mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.88}, {60, 0.91}, {90, 0.98}, {120, 0.94}, {150, 0.88}, {180, 0.90}, {450, 0.5}}}; //60-79 - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.179; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = - 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.136; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.009; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.083; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.023; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.0; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.009; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.035; + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.023; // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ProtectionType::NaturalInfection, age_group_60_to_79, @@ -767,17 +756,6 @@ void set_parameters(mio::abm::Parameters params) mio::TimeSeriesFunctorType::LinearInterpolation, {{0, 0.5}, {30, 0.91}, {60, 0.86}, {90, 0.91}, {120, 0.94}, {150, 0.95}, {180, 0.90}, {450, 0.5}}}; - //80+ - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.179; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = - 0.126; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.133; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.012; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.055; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.036; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.035; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.052; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_80_plus}] = 0.0; // Protection of reinfection is the same for all age-groups, based on: // https://doi.org/10.1016/S0140-6736(22)02465-5, https://doi.org/10.1038/s41591-021-01377-8 params.get()[{mio::abm::ProtectionType::NaturalInfection, age_group_80_plus, diff --git a/cpp/tests/distributions_helpers.cpp b/cpp/tests/distributions_helpers.cpp index b81d1ffd66..888755605c 100644 --- a/cpp/tests/distributions_helpers.cpp +++ b/cpp/tests/distributions_helpers.cpp @@ -17,6 +17,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "memilio/utils/parameter_distributions.h" #include #include #include @@ -39,10 +40,6 @@ void check_distribution(const mio::ParameterDistribution& dist, const mio::Param EXPECT_THAT(self.get_mean(), FloatingPointEqual(p_other_normal_distribution->get_mean(), 1e-12, 1e-12)); EXPECT_THAT(self.get_standard_dev(), FloatingPointEqual(p_other_normal_distribution->get_standard_dev(), 1e-12, 1e-12)); - EXPECT_THAT(self.get_lower_bound(), - FloatingPointEqual(p_other_normal_distribution->get_lower_bound(), 1e-12, 1e-12)); - EXPECT_THAT(self.get_upper_bound(), - FloatingPointEqual(p_other_normal_distribution->get_upper_bound(), 1e-12, 1e-12)); EXPECT_EQ(self.get_predefined_samples().size(), p_other_normal_distribution->get_predefined_samples().size()); @@ -56,11 +53,6 @@ void check_distribution(const mio::ParameterDistribution& dist, const mio::Param auto p_other_uniform_distribution = dynamic_cast(&other); ASSERT_TRUE(p_other_uniform_distribution != nullptr); - EXPECT_THAT(self.get_lower_bound(), - FloatingPointEqual(p_other_uniform_distribution->get_lower_bound(), 1e-12, 1e-12)); - EXPECT_THAT(self.get_upper_bound(), - FloatingPointEqual(p_other_uniform_distribution->get_upper_bound(), 1e-12, 1e-12)); - EXPECT_EQ(self.get_predefined_samples().size(), p_other_uniform_distribution->get_predefined_samples().size()); for (size_t i = 0; i < self.get_predefined_samples().size(); i++) { @@ -69,6 +61,45 @@ void check_distribution(const mio::ParameterDistribution& dist, const mio::Param FloatingPointEqual(p_other_uniform_distribution->get_predefined_samples()[i], 1e-12, 1e-12)); } } + void visit(const mio::ParameterDistributionLogNormal& self) override + { + auto p_other_lognormal_distribution = dynamic_cast(&other); + ASSERT_TRUE(p_other_lognormal_distribution != nullptr); + + EXPECT_EQ(self.get_predefined_samples().size(), + p_other_lognormal_distribution->get_predefined_samples().size()); + for (size_t i = 0; i < self.get_predefined_samples().size(); i++) { + EXPECT_THAT( + self.get_predefined_samples()[i], + FloatingPointEqual(p_other_lognormal_distribution->get_predefined_samples()[i], 1e-12, 1e-12)); + } + } + void visit(const mio::ParameterDistributionExponential& self) override + { + auto p_other_exponential_distribution = dynamic_cast(&other); + ASSERT_TRUE(p_other_exponential_distribution != nullptr); + + EXPECT_EQ(self.get_predefined_samples().size(), + p_other_exponential_distribution->get_predefined_samples().size()); + for (size_t i = 0; i < self.get_predefined_samples().size(); i++) { + EXPECT_THAT( + self.get_predefined_samples()[i], + FloatingPointEqual(p_other_exponential_distribution->get_predefined_samples()[i], 1e-12, 1e-12)); + } + } + void visit(const mio::ParameterDistributionConstant& self) override + { + auto p_other_constant_distribution = dynamic_cast(&other); + ASSERT_TRUE(p_other_constant_distribution != nullptr); + + EXPECT_EQ(self.get_predefined_samples().size(), + p_other_constant_distribution->get_predefined_samples().size()); + for (size_t i = 0; i < self.get_predefined_samples().size(); i++) { + EXPECT_THAT( + self.get_predefined_samples()[i], + FloatingPointEqual(p_other_constant_distribution->get_predefined_samples()[i], 1e-12, 1e-12)); + } + } const mio::ParameterDistribution& other; }; diff --git a/cpp/tests/distributions_helpers.h b/cpp/tests/distributions_helpers.h index 01cddfae12..32dbffbd85 100644 --- a/cpp/tests/distributions_helpers.h +++ b/cpp/tests/distributions_helpers.h @@ -45,7 +45,7 @@ class MockParameterDistributionRef : public mio::ParameterDistributionNormal public: using mio::ParameterDistributionNormal::ParameterDistributionNormal; - double get_rand_sample() override + double get_rand_sample(mio::RandomNumberGenerator& /*rng*/) override { return mock->get_rand_sample(); } diff --git a/cpp/tests/test_abm_infection.cpp b/cpp/tests/test_abm_infection.cpp index 56e3f873fb..fd8c5258a1 100644 --- a/cpp/tests/test_abm_infection.cpp +++ b/cpp/tests/test_abm_infection.cpp @@ -2,6 +2,7 @@ * Copyright (C) 2020-2024 MEmilio * * Authors: David Kerkmann, Khoa Nguyen +* Authors: David Kerkmann, Khoa Nguyen * * Contact: Martin J. Kuehn * @@ -19,6 +20,7 @@ */ #include "abm/location_type.h" +#include "abm/parameters.h" #include "abm/person.h" #include "abm_helpers.h" #include "random_number_test.h" @@ -35,14 +37,21 @@ TEST_F(TestInfection, init) auto age_group_test = age_group_15_to_34; mio::abm::Location loc(mio::abm::LocationType::Hospital, 0); + params.get()[{virus_variant_test, age_group_test}] = {0.1, 0.2}; + //set up a personal RNG for infections //uses uniformdistribution but result doesn't matter, so init before the mock auto counter = mio::Counter(0); auto prng = mio::abm::PersonalRandomNumberGenerator(this->get_rng().get_key(), mio::abm::PersonId(0), counter); ScopedMockDistribution>>> mock_uniform_dist; + ScopedMockDistribution>>> + mock_logNormal_dist; + + //Distribution for state transitions EXPECT_CALL(mock_uniform_dist.get_mock(), invoke) - .Times(testing::AtLeast(7)) + .Times(testing::AtLeast(15)) + // 1st infection .WillOnce(testing::Return(0.4)) // Transition to Infected .WillOnce(testing::Return(0.6)) // Transition to Recovered .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] @@ -55,8 +64,10 @@ TEST_F(TestInfection, init) .infectivity_alpha.params.a())) // Infectivity draws .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] .infectivity_beta.params.a())) - .WillOnce(testing::Return(0.1)) // Transition to Infected - .WillOnce(testing::Return(0.1)) // Transition to Recovered + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .params.a())) // Virus Shed Factor + // 2nd infection + .WillOnce(testing::Return(1.0)) // Transition to Recovered .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] .viral_load_peak.params.a())) // Viral load draws .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] @@ -67,6 +78,20 @@ TEST_F(TestInfection, init) .infectivity_alpha.params.a())) // Infectivity draws .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] .infectivity_beta.params.a())) + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .params.a())) // Virus Shed Factor + .WillRepeatedly(testing::Return(1.0)); + + //Distribution for stay times + EXPECT_CALL(mock_logNormal_dist.get_mock(), invoke) + // 1st infection + .WillOnce(testing::Return(1.)) // IncubationTime + .WillOnce(testing::Return(1.)) // TimeInfectedNoSymptomsToSymptoms + .WillOnce(testing::Return(1.)) // TimeInfectedSymptomsToRecovered + // 2nd infection + .WillOnce(testing::Return(1.0)) // TimeInfectedNoSymptomsToSymptoms + .WillOnce(testing::Return(1.0)) // IncubationTime + .WillOnce(testing::Return(1.0)) // TimeInfectedSymptomsToRecovered .WillRepeatedly(testing::Return(1.0)); auto infection = mio::abm::Infection(prng, mio::abm::VirusVariant::Wildtype, age_group_15_to_34, params, @@ -82,7 +107,7 @@ TEST_F(TestInfection, init) EXPECT_EQ(infection.get_infection_state(mio::abm::TimePoint(0) + mio::abm::days(1)), mio::abm::InfectionState::InfectedNoSymptoms); // Test infectivity at a specific time point - EXPECT_NEAR(infection.get_infectivity(mio::abm::TimePoint(0) + mio::abm::days(3)), 0.2689414213699951, 1e-14); + EXPECT_NEAR(infection.get_infectivity(mio::abm::TimePoint(0) + mio::abm::days(3)), 0.02689414213699951, 1e-14); // Test infection with previous exposure and recovery state transition. params.get()[{mio::abm::ProtectionType::GenericVaccine, age_group_test, @@ -103,7 +128,7 @@ TEST_F(TestInfection, init) mio::abm::InfectionState::Recovered); // Test infectivity at a specific time point EXPECT_NEAR(infection_w_previous_exp.get_infectivity(mio::abm::TimePoint(0) + mio::abm::days(3)), - 0.45760205922564895, 1e-14); + 9.1105119440064545e-05, 1e-14); } /** @@ -113,13 +138,13 @@ TEST_F(TestInfection, getInfectionState) { auto counter = mio::Counter(0); auto prng = mio::abm::PersonalRandomNumberGenerator(this->get_rng().get_key(), mio::abm::PersonId(0), counter); - auto params = mio::abm::Parameters(num_age_groups); - auto t = mio::abm::TimePoint(0); + auto params = mio::abm::Parameters(num_age_groups); + auto t = mio::abm::TimePoint(0); // Initialize infection in Exposed state auto infection = mio::abm::Infection(prng, mio::abm::VirusVariant::Wildtype, age_group_15_to_34, params, t, mio::abm::InfectionState::Exposed, - {mio::abm::ProtectionType::NoProtection, mio::abm::TimePoint(0)}, true); + {mio::abm::ProtectionType::NoProtection, mio::abm::TimePoint(0)}, true); // Test infection state at different time points EXPECT_EQ(infection.get_infection_state(t), mio::abm::InfectionState::Exposed); @@ -133,18 +158,19 @@ TEST_F(TestInfection, drawInfectionCourseForward) { auto counter = mio::Counter(0); auto prng = mio::abm::PersonalRandomNumberGenerator(this->get_rng().get_key(), mio::abm::PersonId(0), counter); - auto params = mio::abm::Parameters(num_age_groups); - auto t = mio::abm::TimePoint(0); + auto params = mio::abm::Parameters(num_age_groups); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 0.; + auto t = mio::abm::TimePoint(0); // Mock recovery transition - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = 1; - ScopedMockDistribution>>> mock_uniform_dist; - EXPECT_CALL(mock_uniform_dist.get_mock(), invoke) + ScopedMockDistribution>>> + mock_logNormal_dist; + EXPECT_CALL(mock_logNormal_dist.get_mock(), invoke) .Times(testing::AtLeast(1)) .WillRepeatedly(testing::Return(0.8)); // Recovered auto infection = mio::abm::Infection(prng, mio::abm::VirusVariant::Wildtype, age_group_15_to_34, params, t, mio::abm::InfectionState::InfectedCritical, - {mio::abm::ProtectionType::NoProtection, mio::abm::TimePoint(0)}, true); + {mio::abm::ProtectionType::NoProtection, mio::abm::TimePoint(0)}, true); // Test state transitions from Critical to Recovered EXPECT_EQ(infection.get_infection_state(t), mio::abm::InfectionState::InfectedCritical); EXPECT_EQ(infection.get_infection_state(t + mio::abm::days(1)), mio::abm::InfectionState::Recovered); @@ -161,41 +187,89 @@ TEST_F(TestInfection, drawInfectionCourseBackward) auto dt = mio::abm::days(1); mio::abm::Parameters params = mio::abm::Parameters(num_age_groups); - // Time to go from all infected states to recover is 1 day (dt). - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 1; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 1; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 1; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 1; + auto virus_variant_test = mio::abm::VirusVariant::Wildtype; + auto age_group_test = age_group_60_to_79; ScopedMockDistribution>>> mock_uniform_dist; + ScopedMockDistribution>>> + mock_logNormal_dist; EXPECT_CALL(mock_uniform_dist.get_mock(), invoke) - .Times(testing::AtLeast(14)) - .WillOnce(testing::Return(0.1)) // Transition to InfectedNoSymptoms - .WillOnce(testing::Return(0.1)) - .WillOnce(testing::Return(0.1)) - .WillOnce(testing::Return(0.1)) - .WillOnce(testing::Return(0.1)) - .WillOnce(testing::Return(0.1)) - .WillOnce(testing::Return(0.3)) // Transition to InfectedSymptoms - .WillOnce(testing::Return(0.3)) - .WillOnce(testing::Return(0.3)) - .WillOnce(testing::Return(0.3)) - .WillOnce(testing::Return(0.3)) - .WillOnce(testing::Return(0.3)) - .WillOnce(testing::Return(0.6)) // Transition to InfectedSevere - .WillRepeatedly(testing::Return(0.9)); // Transition to InfectedCritical - - auto infection1 = mio::abm::Infection(prng, mio::abm::VirusVariant::Wildtype, age_group_60_to_79, params, - mio::abm::TimePoint(t + dt), mio::abm::InfectionState::Recovered, + .Times(testing::AtLeast(22)) + // 1st infection + .WillOnce(testing::Return(0.4)) // Transition to InfectedNoSymptoms + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .viral_load_peak.params.a())) // Viral load draws + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .viral_load_incline.params.a())) + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .viral_load_decline.params.a())) + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .infectivity_alpha.params.a())) // Infectivity draws + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .infectivity_beta.params.a())) + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .params.a())) // Virus Shed Factor + // 2nd infection + .WillOnce(testing::Return(0.6)) // Transition to InfectedSymptoms + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .viral_load_peak.params.a())) // Viral load draws + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .viral_load_incline.params.a())) + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .viral_load_decline.params.a())) + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .infectivity_alpha.params.a())) // Infectivity draws + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .infectivity_beta.params.a())) + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .params.a())) // Virus Shed Factor + // 3rd infection + .WillOnce(testing::Return(0.8)) // Transition to InfectedSevere + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .viral_load_peak.params.a())) // Viral load draws + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .viral_load_incline.params.a())) + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .viral_load_decline.params.a())) + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .infectivity_alpha.params.a())) // Infectivity draws + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .infectivity_beta.params.a())) + .WillOnce(testing::Return(params.get()[{virus_variant_test, age_group_test}] + .params.a())) // Virus Shed Factor + // 4th infection + .WillOnce(testing::Return(0.95)) // Transition to InfectedCritical + .WillRepeatedly(testing::Return(1.0)); + + EXPECT_CALL(mock_logNormal_dist.get_mock(), invoke) + .Times(testing::AtLeast(10)) + // 1st infection + .WillOnce(testing::Return(1.0)) // TimeInfectedNoSymptomsToRecovered + .WillOnce(testing::Return(1.0)) // IncubationPeriod + // 2nd infection + .WillOnce(testing::Return(1.0)) // TimeInfectedSymptomsToRecovered + .WillOnce(testing::Return(1.0)) // TimeInfectedNoSymptomsToSymptoms + .WillOnce(testing::Return(1.0)) // IncubationPeriod + // 3rd infection + .WillOnce(testing::Return(1.0)) // TimeInfectedSevereToRecovered + .WillOnce(testing::Return(1.0)) // TimeInfectedSymptomsToSevere + .WillOnce(testing::Return(1.0)) // TimeInfectedNoSymptomsToSymptoms + .WillOnce(testing::Return(1.0)) // IncubationPeriod + // 4th infection + .WillOnce(testing::Return(1.0)) //TimeInfectedCriticalToRecovered + .WillRepeatedly(testing::Return(1.0)); + + auto infection1 = mio::abm::Infection(prng, virus_variant_test, age_group_test, params, mio::abm::TimePoint(t + dt), + mio::abm::InfectionState::Recovered, {mio::abm::ProtectionType::NoProtection, mio::abm::TimePoint(0)}, false); - auto infection2 = mio::abm::Infection(prng, mio::abm::VirusVariant::Wildtype, age_group_60_to_79, params, - mio::abm::TimePoint(t + dt), mio::abm::InfectionState::Recovered, + auto infection2 = mio::abm::Infection(prng, virus_variant_test, age_group_test, params, mio::abm::TimePoint(t + dt), + mio::abm::InfectionState::Recovered, {mio::abm::ProtectionType::NoProtection, mio::abm::TimePoint(0)}, false); - auto infection3 = mio::abm::Infection(prng, mio::abm::VirusVariant::Wildtype, age_group_60_to_79, params, - mio::abm::TimePoint(t + dt), mio::abm::InfectionState::Recovered, + auto infection3 = mio::abm::Infection(prng, virus_variant_test, age_group_test, params, mio::abm::TimePoint(t + dt), + mio::abm::InfectionState::Recovered, {mio::abm::ProtectionType::NoProtection, mio::abm::TimePoint(0)}, false); - auto infection4 = mio::abm::Infection(prng, mio::abm::VirusVariant::Wildtype, age_group_60_to_79, params, - mio::abm::TimePoint(t + dt), mio::abm::InfectionState::Recovered, + auto infection4 = mio::abm::Infection(prng, virus_variant_test, age_group_test, params, mio::abm::TimePoint(t + dt), + mio::abm::InfectionState::Recovered, {mio::abm::ProtectionType::NoProtection, mio::abm::TimePoint(0)}, false); // Validate infection state progression backward. @@ -249,42 +323,42 @@ TEST_F(TestInfection, getPersonalProtectiveFactor) // Test Parameter InfectionProtectionFactor and get_protection_factor() t = mio::abm::TimePoint(0) + mio::abm::days(2); auto infection_protection_factor = params.get()[{ - latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( - t.days() - latest_protection.time.days()); + latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}](t.days() - + latest_protection.time.days()); EXPECT_NEAR(infection_protection_factor, 0.91, eps); EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.91, eps); t = mio::abm::TimePoint(0) + mio::abm::days(15); infection_protection_factor = params.get()[{ - latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( - t.days() - latest_protection.time.days()); + latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}](t.days() - + latest_protection.time.days()); EXPECT_NEAR(infection_protection_factor, 0.8635, eps); EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.8635, eps); t = mio::abm::TimePoint(0) + mio::abm::days(40); infection_protection_factor = params.get()[{ - latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( - t.days() - latest_protection.time.days()); + latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}](t.days() - + latest_protection.time.days()); EXPECT_NEAR(infection_protection_factor, 0.81, eps); EXPECT_NEAR(person.get_protection_factor(t, mio::abm::VirusVariant::Wildtype, params), 0.81, eps); // Test Parameter SeverityProtectionFactor t = mio::abm::TimePoint(0) + mio::abm::days(2); auto severity_protection_factor = params.get()[{ - latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( - t.days() - latest_protection.time.days()); + latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}](t.days() - + latest_protection.time.days()); EXPECT_NEAR(severity_protection_factor, 0.91, eps); t = mio::abm::TimePoint(0) + mio::abm::days(15); severity_protection_factor = params.get()[{ - latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( - t.days() - latest_protection.time.days()); + latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}](t.days() - + latest_protection.time.days()); EXPECT_NEAR(severity_protection_factor, 0.8635, eps); t = mio::abm::TimePoint(0) + mio::abm::days(40); severity_protection_factor = params.get()[{ - latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}]( - t.days() - latest_protection.time.days()); + latest_protection.type, age_group_15_to_34, mio::abm::VirusVariant::Wildtype}](t.days() - + latest_protection.time.days()); EXPECT_NEAR(severity_protection_factor, 0.81, eps); // Test Parameter HighViralLoadProtectionFactor diff --git a/cpp/tests/test_abm_location.cpp b/cpp/tests/test_abm_location.cpp index c92f199893..4ffe4209a5 100644 --- a/cpp/tests/test_abm_location.cpp +++ b/cpp/tests/test_abm_location.cpp @@ -22,6 +22,7 @@ #include "abm/parameters.h" #include "abm/person.h" #include "abm_helpers.h" +#include "memilio/utils/compiler_diagnostics.h" #include "random_number_test.h" using TestLocation = RandomNumberTest; @@ -84,6 +85,7 @@ TEST_F(TestLocation, interact) auto t = mio::abm::TimePoint(0); auto dt = mio::abm::seconds(8640); //0.1 days + // Setup model parameters for viral loads and infectivity distributions. // Setup model parameters for viral loads and infectivity distributions. mio::abm::Parameters params = mio::abm::Parameters(num_age_groups); params.set_default(num_age_groups); @@ -92,7 +94,8 @@ TEST_F(TestLocation, interact) params.get()[{variant, age}] = {{1., 1.}, {1., 1.}}; // Set incubtion period to two days so that the newly infected person is still exposed - params.get()[{variant, age}] = 2.; + ScopedMockDistribution>>> mock_logNorm_dist; + EXPECT_CALL(mock_logNorm_dist.get_mock(), invoke).WillRepeatedly(testing::Return(2)); // Setup location with some chance of exposure mio::abm::Location location(mio::abm::LocationType::Work, 0, num_age_groups); diff --git a/cpp/tests/test_abm_masks.cpp b/cpp/tests/test_abm_masks.cpp index 4121229071..0d3b4b2be6 100644 --- a/cpp/tests/test_abm_masks.cpp +++ b/cpp/tests/test_abm_masks.cpp @@ -64,7 +64,6 @@ TEST_F(TestMasks, changeMask) EXPECT_EQ(mask.get_type(), mio::abm::MaskType::Surgical); EXPECT_EQ(mask.get_time_used(t), mio::abm::hours(0)); } - /** * @brief Test mask protection during person interactions. */ @@ -72,8 +71,9 @@ TEST_F(TestMasks, maskProtection) { mio::abm::Parameters params(num_age_groups); - // Set incubation period to two days so that newly infected person is still exposed - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_5_to_14}] = 2.; + // set time for state transition to two days so that the newly infected person is still exposed + ScopedMockDistribution>>> mock_logNorm_dist; + EXPECT_CALL(mock_logNorm_dist.get_mock(), invoke).WillRepeatedly(testing::Return(2)); // Setup location and persons for the test auto t = mio::abm::TimePoint(0); diff --git a/cpp/tests/test_abm_model.cpp b/cpp/tests/test_abm_model.cpp index d0b38aaa3f..0cfbb4ddb2 100644 --- a/cpp/tests/test_abm_model.cpp +++ b/cpp/tests/test_abm_model.cpp @@ -17,9 +17,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "abm/parameters.h" #include "abm/person.h" #include "abm/model.h" +#include "abm/virus_variant.h" #include "abm_helpers.h" +#include "memilio/utils/parameter_distribution_wrapper.h" +#include "memilio/utils/parameter_distributions.h" #include "random_number_test.h" using TestModel = RandomNumberTest; @@ -91,9 +95,12 @@ TEST_F(TestModel, addPerson) EXPECT_EQ(model.get_persons().size(), 2); EXPECT_EQ(model.get_person(0).get_age(), age_group_15_to_34); EXPECT_EQ(model.get_person(1).get_age(), age_group_35_to_59); + // Verify the number of persons in the model and their respective age groups. + EXPECT_EQ(model.get_persons().size(), 2); + EXPECT_EQ(model.get_person(0).get_age(), age_group_15_to_34); + EXPECT_EQ(model.get_person(1).get_age(), age_group_35_to_59); } - /** * @brief Test combined subpopulation count by location type in the Model class. */ @@ -132,7 +139,7 @@ TEST_F(TestModel, getSubpopulationCombined) TEST_F(TestModel, findLocation) { // Create a model and add different location types. - auto model = mio::abm::Model(num_age_groups); + auto model = mio::abm::Model(num_age_groups); model.get_rng() = this->get_rng(); auto home_id = model.add_location(mio::abm::LocationType::Home); @@ -165,27 +172,16 @@ TEST_F(TestModel, evolveStateTransition) { using testing::Return; - auto t = mio::abm::TimePoint(0); - auto dt = mio::abm::hours(1); - auto model = mio::abm::Model(num_age_groups); + auto t = mio::abm::TimePoint(0); + auto dt = mio::abm::hours(1); + auto model = mio::abm::Model(num_age_groups); model.get_rng() = this->get_rng(); // Setup incubation and infection period parameters to prevent state transitions within one hour. p1 and p3 don't transition. - model.parameters.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); - model.parameters - .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); - model.parameters - .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); - model.parameters.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); - model.parameters - .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); + ScopedMockDistribution>>> mock_logNorm_dist; + EXPECT_CALL(mock_logNorm_dist.get_mock(), invoke).WillRepeatedly(testing::Return(2 * dt.days())); - // Add locations and persons to the model with different initial infection states. + // Add locations and persons to the model with different initial infection states. auto location1 = model.add_location(mio::abm::LocationType::School); auto location2 = model.add_location(mio::abm::LocationType::Work); add_test_person(model, location1, age_group_15_to_34, mio::abm::InfectionState::InfectedNoSymptoms); @@ -221,18 +217,14 @@ TEST_F(TestModel, evolveMobilityRules) { using testing::Return; - auto t = mio::abm::TimePoint(0) + mio::abm::hours(8); - auto dt = mio::abm::hours(1); - auto model = mio::abm::Model(num_age_groups); + auto t = mio::abm::TimePoint(0) + mio::abm::hours(8); + auto dt = mio::abm::hours(1); + auto model = mio::abm::Model(num_age_groups); model.get_rng() = this->get_rng(); // Setup infection period parameters to prevent state transitions within one hour. p1 doesn't transition. - model.parameters - .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); - model.parameters - .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); + ScopedMockDistribution>>> mock_logNorm_dist; + EXPECT_CALL(mock_logNorm_dist.get_mock(), invoke).WillRepeatedly(testing::Return(2 * dt.days())); model.parameters.get().set_multiple({age_group_5_to_14}, true); model.parameters.get().set_multiple({age_group_15_to_34, age_group_35_to_59}, true); @@ -287,23 +279,23 @@ TEST_F(TestModel, evolveMobilityTrips) { using testing::Return; - // Initialize model, time, and step size for simulation. auto t = mio::abm::TimePoint(0) + mio::abm::hours(8); auto dt = mio::abm::hours(2); auto model = mio::abm::Model(num_age_groups); - model.get_rng() = this->get_rng(); - - // Setup so p1-p5 don't do transition + mio::ParameterDistributionConstant constant(2 * dt.days()); + //setup so p1-p5 don't do transition + model.parameters + .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(constant); model.parameters - .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); + .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(constant); model.parameters - .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); - model.parameters.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); - model.parameters.get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); + .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(constant); + model.parameters + .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(constant); // Add different location types to the model. auto home_id = model.add_location(mio::abm::LocationType::Home); @@ -311,7 +303,6 @@ TEST_F(TestModel, evolveMobilityTrips) auto work_id = model.add_location(mio::abm::LocationType::Work); auto hospital_id = model.add_location(mio::abm::LocationType::Hospital); - // Mock the random distribution to control random behavior. ScopedMockDistribution>>> mock_uniform_dist; EXPECT_CALL(mock_uniform_dist.get_mock(), invoke) .Times(testing::AtLeast(8)) @@ -323,13 +314,13 @@ TEST_F(TestModel, evolveMobilityTrips) .WillOnce(testing::Return(0.8)) // draw random school group .WillOnce(testing::Return(0.8)) // draw random work hour .WillOnce(testing::Return(0.8)) // draw random school hour - .WillRepeatedly(testing::Return(1.0)); // this forces p1 and p3 to recover + .WillRepeatedly(testing::Return(0.8)); // this forces p1 and p3 to recover // Create persons with various infection states and assign them to multiple locations. auto pid1 = add_test_person(model, home_id, age_group_15_to_34, mio::abm::InfectionState::InfectedNoSymptoms, t); - auto pid2 = add_test_person(model, home_id, age_group_5_to_14, mio::abm::InfectionState::Susceptible, t); - auto pid3 = add_test_person(model, home_id, age_group_5_to_14, mio::abm::InfectionState::InfectedSevere, t); - auto pid4 = add_test_person(model, hospital_id, age_group_5_to_14, mio::abm::InfectionState::Recovered, t); + auto pid2 = add_test_person(model, home_id, age_group_15_to_34, mio::abm::InfectionState::Susceptible, t); + auto pid3 = add_test_person(model, home_id, age_group_15_to_34, mio::abm::InfectionState::InfectedSevere, t); + auto pid4 = add_test_person(model, hospital_id, age_group_15_to_34, mio::abm::InfectionState::Recovered, t); auto pid5 = add_test_person(model, home_id, age_group_15_to_34, mio::abm::InfectionState::Susceptible, t); // Assign persons to locations for trips. @@ -362,10 +353,12 @@ TEST_F(TestModel, evolveMobilityTrips) data.add_trip(trip2); data.add_trip(trip3); + // Set trips to use weekday trips on weekends. + data.use_weekday_trips_on_weekend(); // Set trips to use weekday trips on weekends. data.use_weekday_trips_on_weekend(); - // Mock the distribution to prevent infections or state transitions in the test. + // Mock the distribution to prevent infectionsin the test. ScopedMockDistribution>>> mock_exponential_dist; EXPECT_CALL(mock_exponential_dist.get_mock(), invoke).WillRepeatedly(Return(1.)); @@ -385,7 +378,7 @@ TEST_F(TestModel, evolveMobilityTrips) // Move all persons back to their home location to prepare for weekend trips. model.change_location(p1.get_id(), home_id); - model.change_location(p1.get_id(), home_id); + model.change_location(p3.get_id(), home_id); model.change_location(p2.get_id(), home_id); model.change_location(p5.get_id(), home_id); @@ -439,9 +432,9 @@ TEST_F(TestModel, reachCapacity) using testing::Return; // Initialize time and model. - auto t = mio::abm::TimePoint{mio::abm::hours(8).seconds()}; - auto dt = mio::abm::hours(1); - auto model = mio::abm::Model(num_age_groups); + auto t = mio::abm::TimePoint{mio::abm::hours(8).seconds()}; + auto dt = mio::abm::hours(1); + auto model = mio::abm::Model(num_age_groups); model.get_rng() = this->get_rng(); model.parameters.get()[age_group_5_to_14] = true; @@ -493,11 +486,20 @@ TEST_F(TestModel, checkMobilityOfDeadPerson) auto t = mio::abm::TimePoint(0); auto dt = mio::abm::days(1); auto model = mio::abm::Model(num_age_groups); - + model.parameters + .get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 1.; + model.parameters + .get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 1.; // Time to go from severe to critical infection is 1 day (dt). - model.parameters.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.5; + mio::ParameterDistributionConstant constant1(dt.days()); + model.parameters + .get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(constant1); // Time to go from critical infection to dead state is 1/2 day (0.5 * dt). - model.parameters.get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = 0.5; + mio::ParameterDistributionConstant constant2(0.5 * dt.days()); + model.parameters + .get()[{mio::abm::VirusVariant::Wildtype, age_group_60_to_79}] = + mio::ParameterDistributionWrapper(constant2); auto home_id = model.add_location(mio::abm::LocationType::Home); auto work_id = model.add_location(mio::abm::LocationType::Work); @@ -548,11 +550,12 @@ using TestModelTestingCriteria = RandomNumberTest; TEST_F(TestModelTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) { auto model = mio::abm::Model(num_age_groups); - model.get_rng() = this->get_rng(); - // Make sure the infected person stay in Infected long enough - model.parameters.get()[{mio::abm::VirusVariant(0), age_group_15_to_34}] = - 100; - model.parameters.get()[{mio::abm::VirusVariant(0), age_group_15_to_34}] = 100; + // make sure the infected person stay in Infected long enough + mio::ParameterDistributionConstant constant(100.); + model.parameters.get()[{mio::abm::VirusVariant(0), age_group_15_to_34}] = + mio::ParameterDistributionWrapper(constant); + model.parameters.get()[{mio::abm::VirusVariant(0), age_group_15_to_34}] = + mio::ParameterDistributionWrapper(constant); auto home_id = model.add_location(mio::abm::LocationType::Home); auto work_id = model.add_location(mio::abm::LocationType::Work); @@ -597,7 +600,8 @@ TEST_F(TestModelTestingCriteria, testAddingAndUpdatingAndRunningTestingSchemes) .WillOnce(testing::Return(0.0)) // Draw for isolation compliance (doesn't matter in this test) .WillOnce( testing::Return(0.7)); // Person complies with testing (even though there is not testing strategy left) - EXPECT_EQ(model.get_testing_strategy().run_strategy(rng_person, person, work, current_time), false); // Testing scheme active and restricts entry + EXPECT_EQ(model.get_testing_strategy().run_strategy(rng_person, person, work, current_time), + false); // Testing scheme active and restricts entry // Try to re-add the same testing scheme and confirm it doesn't duplicate, then remove it. model.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, @@ -616,61 +620,101 @@ TEST_F(TestModel, checkParameterConstraints) auto model = mio::abm::Model(num_age_groups); auto params = model.parameters; - // Set valid values for various transition times, infection detection, and mask protection parameters. - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 1.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 2.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 3.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 4.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 5.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 6.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 7.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 8.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 9.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 10.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.3; - params.get()[age_group_35_to_59] = mio::abm::hours(4); - params.get()[age_group_35_to_59] = mio::abm::hours(8); - params.get()[age_group_0_to_4] = mio::abm::hours(3); - params.get()[age_group_0_to_4] = mio::abm::hours(6); - params.get()[mio::abm::MaskType::Community] = 0.5; - params.get()[mio::abm::MaskType::FFP2] = 0.6; - params.get()[mio::abm::MaskType::Surgical] = 0.7; - params.get() = mio::abm::TimePoint(0); - // Check that the parameter values are within their constraints (should pass). - EXPECT_FALSE(params.check_constraints()); - - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -1.; - EXPECT_TRUE(params.check_constraints()); - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 1.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -2.; - EXPECT_TRUE(params.check_constraints()); - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 2.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -3.; - EXPECT_TRUE(params.check_constraints()); - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 3.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -4.; - EXPECT_TRUE(params.check_constraints()); - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 4.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -5.; - EXPECT_TRUE(params.check_constraints()); - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 5.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -6.; - EXPECT_TRUE(params.check_constraints()); - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 6.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -7.; - EXPECT_TRUE(params.check_constraints()); - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 7.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -8.; - EXPECT_TRUE(params.check_constraints()); - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 8.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -9.; - EXPECT_TRUE(params.check_constraints()); - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 9.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -10.; - EXPECT_TRUE(params.check_constraints()); - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 10.; - params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 1.1; - EXPECT_TRUE(params.check_constraints()); + mio::ParameterDistributionLogNormal log_norm1(1., 0.5); + mio::ParameterDistributionLogNormal log_norm2(2., 0.5); + mio::ParameterDistributionLogNormal log_norm3(3., 0.5); + mio::ParameterDistributionLogNormal log_norm4(4., 0.5); + mio::ParameterDistributionLogNormal log_norm5(5., 0.5); + mio::ParameterDistributionLogNormal log_norm6(6., 0.5); + mio::ParameterDistributionLogNormal log_norm7(7., 0.5); + mio::ParameterDistributionLogNormal log_norm8(8., 0.5); + mio::ParameterDistributionLogNormal log_norm9(9., 0.5); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm1); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm2); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm3); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm4); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm5); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm6); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm7); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm8); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm9); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.3; + params.get()[age_group_35_to_59] = mio::abm::hours(4); + params.get()[age_group_35_to_59] = mio::abm::hours(8); + params.get()[age_group_0_to_4] = mio::abm::hours(3); + params.get()[age_group_0_to_4] = mio::abm::hours(6); + params.get()[mio::abm::MaskType::Community] = 0.5; + params.get()[mio::abm::MaskType::FFP2] = 0.6; + params.get()[mio::abm::MaskType::Surgical] = 0.7; + params.get() = mio::abm::TimePoint(0); + ASSERT_EQ(params.check_constraints(), false); + + mio::ParameterDistributionLogNormal log_normm1(-1., 0.5); + mio::ParameterDistributionLogNormal log_normm2(-2., 0.5); + mio::ParameterDistributionLogNormal log_normm3(-3., 0.5); + mio::ParameterDistributionLogNormal log_normm4(-4., 0.5); + mio::ParameterDistributionLogNormal log_normm5(-5., 0.5); + mio::ParameterDistributionLogNormal log_normm6(-6., 0.5); + mio::ParameterDistributionLogNormal log_normm7(-7., 0.5); + mio::ParameterDistributionLogNormal log_normm8(-8., 0.5); + mio::ParameterDistributionLogNormal log_normm9(-9., 0.5); + + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_normm1); + ASSERT_EQ(params.check_constraints(), true); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm1); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_normm2); + ASSERT_EQ(params.check_constraints(), true); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm2); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_normm3); + ASSERT_EQ(params.check_constraints(), true); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm3); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_normm4); + ASSERT_EQ(params.check_constraints(), true); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm4); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_normm5); + ASSERT_EQ(params.check_constraints(), true); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm5); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_normm6); + ASSERT_EQ(params.check_constraints(), true); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm6); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_normm7); + ASSERT_EQ(params.check_constraints(), true); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm7); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_normm8); + ASSERT_EQ(params.check_constraints(), true); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm8); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_normm9); + ASSERT_EQ(params.check_constraints(), true); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = + mio::ParameterDistributionWrapper(log_norm9); + params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 1.1; + ASSERT_EQ(params.check_constraints(), true); params.get()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.3; params.get()[age_group_35_to_59] = mio::abm::hours(30); @@ -708,15 +752,16 @@ TEST_F(TestModel, mobilityRulesWithAppliedNPIs) { using testing::Return; // Test when the NPIs are applied, people can enter targeted location if they comply to the rules. - auto t = mio::abm::TimePoint(0) + mio::abm::hours(8); - auto dt = mio::abm::hours(1); - auto test_time = mio::abm::minutes(30); - auto model = mio::abm::Model(num_age_groups); + auto t = mio::abm::TimePoint(0) + mio::abm::hours(8); + auto dt = mio::abm::hours(1); + auto test_time = mio::abm::minutes(30); + auto model = mio::abm::Model(num_age_groups); model.get_rng() = this->get_rng(); + mio::ParameterDistributionConstant constant(2 * dt.days()); model.parameters - .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); + .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(constant); model.parameters.get().set_multiple({age_group_15_to_34, age_group_35_to_59}, true); model.parameters.get()[age_group_5_to_14] = true; @@ -825,15 +870,16 @@ TEST_F(TestModel, mobilityTripWithAppliedNPIs) { using testing::Return; // Test when the NPIs are applied, people can enter targeted location if they comply to the rules. - auto t = mio::abm::TimePoint(0) + mio::abm::hours(8); - auto dt = mio::abm::hours(1); - auto test_time = mio::abm::minutes(30); - auto model = mio::abm::Model(num_age_groups); + auto t = mio::abm::TimePoint(0) + mio::abm::hours(8); + auto dt = mio::abm::hours(1); + auto test_time = mio::abm::minutes(30); + auto model = mio::abm::Model(num_age_groups); model.get_rng() = this->get_rng(); + mio::ParameterDistributionConstant constant(2 * dt.days()); model.parameters - .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = - 2 * dt.days(); + .get()[{mio::abm::VirusVariant::Wildtype, age_group_15_to_34}] = + mio::ParameterDistributionWrapper(constant); model.parameters.get().set_multiple({age_group_15_to_34, age_group_35_to_59}, true); model.parameters.get()[age_group_5_to_14] = true; diff --git a/cpp/tests/test_abm_person.cpp b/cpp/tests/test_abm_person.cpp index 120ded0714..ece4595d9e 100644 --- a/cpp/tests/test_abm_person.cpp +++ b/cpp/tests/test_abm_person.cpp @@ -136,11 +136,17 @@ TEST_F(TestPerson, quarantine) .WillOnce(testing::Return(0.6)) // goto_school_hour .WillRepeatedly(testing::Return(1.0)); // ViralLoad draws + ScopedMockDistribution>>> mock_logNorm_dist; + auto t_morning = mio::abm::TimePoint(0) + mio::abm::hours(7); auto dt = mio::abm::hours(1); - infection_parameters - .get()[{mio::abm::VirusVariant::Wildtype, age_group_35_to_59}] = - 0.5 * dt.days(); + EXPECT_CALL(mock_logNorm_dist.get_mock(), invoke) + .Times(testing::AtLeast(1)) + .WillOnce(testing::Return(1.0)) // TimeInfectedNoSymptomsToSymptoms + .WillOnce(testing::Return(1.0)) //IncubationTime + .WillOnce(testing::Return(0.5 * dt.days())) // TimeInfectedSymptomsToRecovered + .WillRepeatedly(testing::Return(1.0)); + infection_parameters.get().set_multiple({age_group_5_to_14}, true); infection_parameters.get().set_multiple({age_group_15_to_34, age_group_35_to_59}, true); diff --git a/cpp/tests/test_abm_serialization.cpp b/cpp/tests/test_abm_serialization.cpp index 834abc1059..6219f6b2c0 100644 --- a/cpp/tests/test_abm_serialization.cpp +++ b/cpp/tests/test_abm_serialization.cpp @@ -26,6 +26,7 @@ #include "abm/protection_event.h" #include "memilio/epidemiology/age_group.h" #include "memilio/io/json_serializer.h" +#include "memilio/utils/compiler_diagnostics.h" #include "memilio/utils/custom_index_array.h" #include "memilio/utils/uncertain_value.h" #include "models/abm/location.h" diff --git a/cpp/tests/test_analyze_result.cpp b/cpp/tests/test_analyze_result.cpp index c51d87c6af..d297b7fe8f 100644 --- a/cpp/tests/test_analyze_result.cpp +++ b/cpp/tests/test_analyze_result.cpp @@ -483,31 +483,45 @@ TEST(TestEnsembleParamsPercentile, graph_osecir_basic) TEST(TestEnsembleParamsPercentile, graph_abm_basic) { - size_t num_age_groups = 6; + size_t num_age_groups = 1; auto model1 = mio::abm::Model(num_age_groups); auto model2 = mio::abm::Model(num_age_groups); - model1.parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = - 0.1; - model1.parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = 0.2; - - model2.parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = - 0.2; - model2.parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = 0.3; + mio::ParameterDistributionLogNormal log_norm1(2., 1.2); + model1.parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = + mio::ParameterDistributionWrapper(log_norm1); + mio::ParameterDistributionLogNormal log_norm2(5., 1.2); + model1.parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = + mio::ParameterDistributionWrapper(log_norm2); + mio::ParameterDistributionLogNormal log_norm3(1.3, 2.); + model2.parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = + mio::ParameterDistributionWrapper(log_norm3); + mio::ParameterDistributionLogNormal log_norm4(4., 1.2); + model2.parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = + mio::ParameterDistributionWrapper(log_norm4); auto g1 = std::vector({model1, model2}); + mio::ParameterDistributionLogNormal log_norm5(1.5, 1.5); model1.parameters - .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = 0.2; - model1.parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = - 0.3; - model1.parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = 0.4; - + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = + mio::ParameterDistributionWrapper(log_norm5); + mio::ParameterDistributionLogNormal log_norm(4., 1.5); + model1.parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = + mio::ParameterDistributionWrapper(log_norm); + mio::ParameterDistributionLogNormal log_norm6(1.1, 1.2); + model2.parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = + mio::ParameterDistributionWrapper(log_norm6); + mio::ParameterDistributionLogNormal log_norm7(6., 1.5); model2.parameters - .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = 0.7; - model2.parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = - 0.4; - model2.parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = 0.5; + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] = + mio::ParameterDistributionWrapper(log_norm7); auto g2 = std::vector({model1, model2}); @@ -518,51 +532,59 @@ TEST(TestEnsembleParamsPercentile, graph_abm_basic) auto check1 = ensemble_p49_params[0] - .parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] - .value(); + .parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] + .params()[0]; auto check2 = ensemble_p49_params[1] - .parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] - .value(); + .parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] + .params()[0]; - EXPECT_EQ(check1, 0.1); - EXPECT_EQ(check2, 0.2); + EXPECT_EQ(check1, 1.5); + EXPECT_EQ(check2, 1.1); auto check3 = ensemble_p51_params[0] - .parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] - .value(); + .parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] + .params()[0]; auto check4 = ensemble_p51_params[1] - .parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] - .value(); + .parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] + .params()[0]; - EXPECT_EQ(check3, 0.3); - EXPECT_EQ(check4, 0.4); + EXPECT_EQ(check3, 2.); + EXPECT_EQ(check4, 1.3); auto check5 = ensemble_p49_params[0] - .parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] - .value(); + .parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] + .params()[0]; auto check6 = ensemble_p49_params[1] - .parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] - .value(); + .parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] + .params()[0]; - EXPECT_EQ(check5, 0.2); - EXPECT_EQ(check6, 0.3); + EXPECT_EQ(check5, 4.); + EXPECT_EQ(check6, 4.); auto check7 = ensemble_p51_params[0] - .parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] - .value(); + .parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] + .params()[0]; auto check8 = ensemble_p51_params[1] - .parameters.get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] - .value(); + .parameters + .get()[{mio::abm::VirusVariant::Wildtype, mio::AgeGroup(0)}] + .params()[0]; - EXPECT_EQ(check7, 0.4); - EXPECT_EQ(check8, 0.5); + EXPECT_EQ(check7, 5.); + EXPECT_EQ(check8, 6.); } TEST(TestDistance, same_result_zero_distance) diff --git a/cpp/tests/test_odesecir.cpp b/cpp/tests/test_odesecir.cpp index b1dce952d9..1266154ea6 100644 --- a/cpp/tests/test_odesecir.cpp +++ b/cpp/tests/test_odesecir.cpp @@ -448,7 +448,7 @@ TEST(TestOdeSecir, testSettersAndGetters) { std::vector> vec; - for (int i = 0; i < 22; i++) { + for (int i = 1; i < 23; i++) { mio::UncertainValue val = mio::UncertainValue(i); val.set_distribution(mio::ParameterDistributionNormal(i, 10 * i, 5 * i, i / 10.0)); vec.push_back(val); diff --git a/cpp/tests/test_parameter_studies.cpp b/cpp/tests/test_parameter_studies.cpp index 52cb6dec56..435f60e66b 100644 --- a/cpp/tests/test_parameter_studies.cpp +++ b/cpp/tests/test_parameter_studies.cpp @@ -186,7 +186,7 @@ TEST(ParameterStudies, test_normal_distribution) parameter_dist_normal_1.log_stddev_changes(false); // only avoid warning output in tests double std_dev_demanded = parameter_dist_normal_1.get_standard_dev(); - parameter_dist_normal_1.get_sample(); + parameter_dist_normal_1.get_sample(mio::thread_local_rng()); EXPECT_GE(std_dev_demanded, parameter_dist_normal_1.get_standard_dev()); @@ -211,14 +211,14 @@ TEST(ParameterStudies, test_normal_distribution) // check that sampling only occurs in boundaries for (int i = 0; i < 1000; i++) { - double val = parameter_dist_normal_2.get_sample(); + double val = parameter_dist_normal_2.get_sample(mio::thread_local_rng()); EXPECT_GE(parameter_dist_normal_2.get_upper_bound() + 1e-10, val); EXPECT_LE(parameter_dist_normal_2.get_lower_bound() - 1e-10, val); } - //degenerate case: ub == lb - mio::ParameterDistributionNormal dist3(3.0, 3.0, 3.0, 0.0); - EXPECT_EQ(dist3.get_sample(), 3.0); + //degenerate case: ub == lb //For MSVC the normal distribution cannot have a value of 0.0 for sigma + mio::ParameterDistributionNormal dist3(0.999999999 * 3.0, 1.000000001 * 3.0, 3.0, 0.00000001); + EXPECT_NEAR(dist3.get_sample(mio::thread_local_rng()), 3.0, 1e-07); } TEST(ParameterStudies, test_uniform_distribution) @@ -233,7 +233,7 @@ TEST(ParameterStudies, test_uniform_distribution) // check that sampling only occurs in boundaries for (int i = 0; i < 1000; i++) { - double val = parameter_dist_unif.get_sample(); + double val = parameter_dist_unif.get_sample(mio::thread_local_rng()); EXPECT_GE(parameter_dist_unif.get_upper_bound() + 1e-10, val); EXPECT_LE(parameter_dist_unif.get_lower_bound() - 1e-10, val); } @@ -247,20 +247,20 @@ TEST(ParameterStudies, test_predefined_samples) // set predefined sample (can be out of [min,max]) and get it parameter_dist_unif.add_predefined_sample(2); - double var = parameter_dist_unif.get_sample(); + double var = parameter_dist_unif.get_sample(mio::thread_local_rng()); EXPECT_EQ(var, 2); // predefined sample was deleted, get real sample which cannot be 2 due to [min,max] - var = parameter_dist_unif.get_sample(); + var = parameter_dist_unif.get_sample(mio::thread_local_rng()); EXPECT_NE(var, 2); // set predefined sample (can be out of [min,max]) and get it parameter_dist_normal.add_predefined_sample(2); - var = parameter_dist_normal.get_sample(); + var = parameter_dist_normal.get_sample(mio::thread_local_rng()); EXPECT_EQ(var, 2); // predefined sample was deleted, get real sample which cannot be 2 due to [min,max] - var = parameter_dist_normal.get_sample(); + var = parameter_dist_normal.get_sample(mio::thread_local_rng()); EXPECT_NE(var, 2); } diff --git a/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.cpp b/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.cpp index bcac1cc08b..9c6dee1b29 100644 --- a/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.cpp +++ b/pycode/memilio-simulation/memilio/simulation/bindings/utils/parameter_distributions.cpp @@ -18,6 +18,7 @@ * limitations under the License. */ #include "utils/parameter_distributions.h" +#include "memilio/utils/random_number_generator.h" #include "pybind_util.h" #include "memilio/utils/parameter_distributions.h" @@ -29,21 +30,21 @@ namespace pymio void bind_parameter_distribution(py::module_& m, std::string const& name) { bind_class(m, name.c_str()) - .def_property("lower_bound", &mio::ParameterDistribution::get_lower_bound, - &mio::ParameterDistribution::set_lower_bound) - .def_property("upper_bound", &mio::ParameterDistribution::get_upper_bound, - &mio::ParameterDistribution::set_upper_bound) .def("add_predefined_sample", &mio::ParameterDistribution::add_predefined_sample) .def("remove_predefined_samples", &mio::ParameterDistribution::remove_predefined_samples) - .def("get_sample", &mio::ParameterDistribution::get_sample); + .def("get_sample", [](mio::ParameterDistribution& self) { + return self.get_sample(mio::thread_local_rng()); + }); } void bind_parameter_distribution_normal(py::module_& m, std::string const& name) { - bind_class(m, name.c_str()) + bind_class(m, + name.c_str()) .def(py::init(), py::arg("lb"), py::arg("ub"), py::arg("mean"), py::arg("std_dev")) .def(py::init(), py::arg("lb"), py::arg("ub"), py::arg("mean")) + .def(py::init(), py::arg("mean"), py::arg("std_dev")) .def_property("mean", &mio::ParameterDistributionNormal::get_mean, &mio::ParameterDistributionNormal::set_mean) .def_property("standard_dev", &mio::ParameterDistributionNormal::get_standard_dev, &mio::ParameterDistributionNormal::set_standard_dev); @@ -51,9 +52,13 @@ void bind_parameter_distribution_normal(py::module_& m, std::string const& name) void bind_parameter_distribution_uniform(py::module_& m, std::string const& name) { - bind_class(m, name.c_str()) - .def(py::init<>()) - .def(py::init(), py::arg("lb"), py::arg("ub")); + bind_class(m, + name.c_str()) + .def(py::init(), py::arg("lb"), py::arg("ub")) + .def_property("lower_bound", &mio::ParameterDistributionUniform::get_lower_bound, + &mio::ParameterDistributionUniform::set_lower_bound) + .def_property("upper_bound", &mio::ParameterDistributionUniform::get_upper_bound, + &mio::ParameterDistributionUniform::set_upper_bound); } } // namespace pymio diff --git a/pycode/memilio-simulation/memilio/simulation_test/test_abm.py b/pycode/memilio-simulation/memilio/simulation_test/test_abm.py index f14bbbc5e1..c861231e27 100644 --- a/pycode/memilio-simulation/memilio/simulation_test/test_abm.py +++ b/pycode/memilio-simulation/memilio/simulation_test/test_abm.py @@ -98,11 +98,6 @@ def test_simulation(self): model.assign_location(p1_id, loc_id) model.assign_location(p2_id, loc_id) - model.parameters.InfectedSymptomsToSevere[abm.VirusVariant.Wildtype, mio.AgeGroup( - 0)] = 0.0 - model.parameters.InfectedSymptomsToRecovered[abm.VirusVariant.Wildtype, mio.AgeGroup( - 0)] = 0.0 - # trips trip_list = abm.TripList() trip_list.add_trip(abm.Trip(0, abm.TimePoint( diff --git a/pycode/memilio-simulation/memilio/simulation_test/test_distributions.py b/pycode/memilio-simulation/memilio/simulation_test/test_distributions.py index 6cabf23f4b..66241f3ef2 100644 --- a/pycode/memilio-simulation/memilio/simulation_test/test_distributions.py +++ b/pycode/memilio-simulation/memilio/simulation_test/test_distributions.py @@ -37,8 +37,7 @@ def test_normal(self): N = mio.ParameterDistributionNormal(-1.0, 1.0, 0.0, 1.0) # properties self.assertEqual(N.mean, 0.0) # std_dev automatically adapted - self.assertEqual(N.lower_bound, -1.0) - self.assertEqual(N.upper_bound, 1.0) + # sample n = N.get_sample() self.assertGreaterEqual(n, -1) diff --git a/pycode/memilio-simulation/memilio/simulation_test/test_pickle.py b/pycode/memilio-simulation/memilio/simulation_test/test_pickle.py index b88950c80c..da916542db 100644 --- a/pycode/memilio-simulation/memilio/simulation_test/test_pickle.py +++ b/pycode/memilio-simulation/memilio/simulation_test/test_pickle.py @@ -54,9 +54,6 @@ def test_distribution(self): self.assertEqual(pickle_test.mean, 0.4) self.assertEqual(pickle_test.standard_dev, 0.1) - self.assertEqual(pickle_test.lower_bound, 0) - self.assertEqual(pickle_test.upper_bound, 1) - def test_damping_sampling(self): test = msim.UncertainValue(2.2) test.set_distribution(msim.ParameterDistributionNormal(0, 1, 0.4, 0.1)) @@ -70,8 +67,6 @@ def test_damping_sampling(self): self.assertEqual(pickle_test.value.get_distribution().mean, 0.4) self.assertEqual( pickle_test.value.get_distribution().standard_dev, 0.1) - self.assertEqual(pickle_test.value.get_distribution().lower_bound, 0) - self.assertEqual(pickle_test.value.get_distribution().upper_bound, 1) self.assertEqual(pickle_test.level, 1) self.assertEqual(pickle_test.type, 2)