Skip to content

Commit

Permalink
vastly improve performance.
Browse files Browse the repository at this point in the history
avoid creating Location objects, do not rebuild exposure caches
  • Loading branch information
reneSchm committed Feb 2, 2024
1 parent e50bd75 commit 144252a
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 41 deletions.
2 changes: 1 addition & 1 deletion cpp/models/abm/common_abm_loggers.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ struct LogInfectionState : mio::LogAlways {
Eigen::VectorXd sum = Eigen::VectorXd::Zero(Eigen::Index(mio::abm::InfectionState::Count));
auto curr_time = sim.get_time();
PRAGMA_OMP(for)
for (auto&& location : sim.get_world().get_locations()) {
for (auto& location : sim.get_world().get_locations()) {
for (uint32_t inf_state = 0; inf_state < (int)mio::abm::InfectionState::Count; inf_state++) {
sum[inf_state] +=
sim.get_world().get_subpopulation(location, curr_time, mio::abm::InfectionState(inf_state));
Expand Down
2 changes: 1 addition & 1 deletion cpp/models/abm/location.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Location::Location(LocationId loc_id, size_t num_agegroups, uint32_t num_cells)
assert(num_cells > 0 && "Number of cells has to be larger than 0.");
}

Location Location::copy_location_without_persons(size_t) const
Location Location::copy() const
{
return *this;
}
Expand Down
6 changes: 3 additions & 3 deletions cpp/models/abm/location.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class Location
* @param[in] num_agegroups [Default: 1] The number of age groups in the model.
* @param[in] num_cells [Default: 1] The number of Cell%s in which the Location is divided.
*/
Location(LocationId loc_id, size_t num_agegroups = 1, uint32_t num_cells = 1);
explicit Location(LocationId loc_id, size_t num_agegroups = 1, uint32_t num_cells = 1);

/**
* @brief Construct a Location with provided parameters.
Expand All @@ -99,7 +99,7 @@ class Location
* @param[in] num_agegroups [Default: 1] The number of age groups in the model.
* @param[in] num_cells [Default: 1] The number of Cell%s in which the Location is divided.
*/
Location(LocationType loc_type, uint32_t loc_index, size_t num_agegroups = 1, uint32_t num_cells = 1)
explicit Location(LocationType loc_type, uint32_t loc_index, size_t num_agegroups = 1, uint32_t num_cells = 1)
: Location(LocationId{loc_index, loc_type}, num_agegroups, num_cells)
{
}
Expand All @@ -119,7 +119,7 @@ class Location
* @brief Return a copy of this #Location object with an empty m_persons.
* @param[in] num_agegroups The number of age groups in the model.
*/
Location copy_location_without_persons(size_t num_agegroups) const;
Location copy() const;

/**
* @brief Compare two Location%s.
Expand Down
19 changes: 13 additions & 6 deletions cpp/models/abm/world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ LocationId World::add_location(LocationType type, uint32_t num_cells)
if (m_local_populations_cache.is_valid()) {
m_local_populations_cache.data[id.index];
}
m_air_exposure_rates_cache.data.emplace(
id.index, Location::AirExposureRates({CellIndex(num_cells), VirusVariant::Count}, 0.));
m_contact_exposure_rates_cache.data.emplace(
id.index, Location::ContactExposureRates(
{CellIndex(num_cells), VirusVariant::Count, AgeGroup(parameters.get_num_groups())}, 0.));
return id;
}

Expand Down Expand Up @@ -75,15 +80,15 @@ void World::migration(TimePoint t, TimeSpan dt)

auto try_migration_rule = [&](auto rule) -> bool {
//run migration rule and check if migration can actually happen
auto target_type = rule(personal_rng, person, t, dt, parameters);
auto target_location = find_location(target_type, person);
auto current_location = person.get_location();
auto target_type = rule(personal_rng, person, t, dt, parameters);
const Location& target_location = get_location(find_location(target_type, person));
const LocationId current_location = person.get_location();
if (m_testing_strategy.run_strategy(personal_rng, person, target_location, t)) {
if (target_location != current_location &&
get_number_persons(target_location) < get_location(target_location).get_capacity().persons) {
if (target_location.get_id() != current_location &&
get_number_persons(target_location) < target_location.get_capacity().persons) {
bool wears_mask = person.apply_mask_intervention(personal_rng, target_location);
if (wears_mask) {
migrate(person.get_person_id(), target_location); // TODO: i == PersonId, use?
migrate(person.get_person_id(), target_location.get_id()); // TODO: i == PersonId, use?
}
return true;
}
Expand Down Expand Up @@ -148,6 +153,8 @@ void World::begin_step(TimePoint t, TimeSpan dt)
m_local_populations_cache.validate();
}
recompute_exposure_rates(t, dt);
m_air_exposure_rates_cache.validate();
m_contact_exposure_rates_cache.validate();
}

auto World::get_locations() const -> Range<std::pair<ConstLocationIterator, ConstLocationIterator>>
Expand Down
41 changes: 12 additions & 29 deletions cpp/models/abm/world.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ class World
});
}

inline size_t get_subpopulation(Location location, TimePoint t, InfectionState state) const
inline size_t get_subpopulation(const Location& location, TimePoint t, InfectionState state) const
{
return get_subpopulation(location.get_id(), t, state);
}
Expand All @@ -364,7 +364,7 @@ class World
});
}

inline size_t get_number_persons(Location location) const
inline size_t get_number_persons(const Location& location) const
{
return get_number_persons(location.get_id());
}
Expand Down Expand Up @@ -452,18 +452,9 @@ class World
template <class T>
struct Cache {
T data;
bool m_is_valid = false;
mutable size_t m_hits = 0;
mutable size_t m_misses = 0;

bool is_valid() const
{
if (m_is_valid) {
m_hits++;
}
else {
m_misses++;
}
return m_is_valid;
}

Expand All @@ -477,10 +468,8 @@ class World
m_is_valid = true;
}

// ~Cache()
// {
// std::cout << "hits: " << m_hits << " misses: " << m_misses << "\n";
// }
private:
bool m_is_valid = false;
};

void rebuild()
Expand All @@ -496,22 +485,16 @@ class World

void recompute_exposure_rates(TimePoint t, TimeSpan dt)
{
m_air_exposure_rates_cache.data.clear();
m_contact_exposure_rates_cache.data.clear();
for (Location& location : m_locations) {
m_air_exposure_rates_cache.data.emplace(
location.get_index(),
Location::AirExposureRates({CellIndex(location.get_cells().size()), VirusVariant::Count}, 0.));
m_contact_exposure_rates_cache.data.emplace(
location.get_index(),
Location::ContactExposureRates({CellIndex(location.get_cells().size()), VirusVariant::Count,
AgeGroup(parameters.get_num_groups())},
0.));
for (Location& location : get_locations()) {
auto index = location.get_index();
m_air_exposure_rates_cache.data.at(index).array().setZero();
m_contact_exposure_rates_cache.data.at(index).array().setZero();
}
for (Person& person : get_persons()) {
mio::abm::add_exposure_contribution(m_air_exposure_rates_cache.data.at(person.get_location().index),
m_contact_exposure_rates_cache.data.at(person.get_location().index),
person, get_location(person.get_person_id()), t, dt);
auto location = person.get_location().index;
mio::abm::add_exposure_contribution(m_air_exposure_rates_cache.data.at(location),
m_contact_exposure_rates_cache.data.at(location), person,
get_location(person.get_person_id()), t, dt);
}
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/tests/test_abm_person.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ TEST(TestPerson, copyPerson)
auto location = mio::abm::Location(mio::abm::LocationType::Work, 0, num_age_groups);
auto t = mio::abm::TimePoint(0);
auto person = mio::abm::Person(rng, location, age_group_60_to_79);
auto copied_location = location.copy_location_without_persons(num_age_groups);
auto copied_location = location.copy();
auto copied_person = person.copy_person(copied_location);

EXPECT_EQ(copied_person.get_infection_state(t), mio::abm::InfectionState::Susceptible);
Expand Down

0 comments on commit 144252a

Please sign in to comment.