Skip to content

Commit

Permalink
more conflict tracking and reporting
Browse files Browse the repository at this point in the history
  • Loading branch information
jtwhite79 committed Oct 28, 2024
1 parent bef64e6 commit 2501118
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 117 deletions.
245 changes: 130 additions & 115 deletions src/libs/pestpp_common/EnsembleMethodUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2313,10 +2313,23 @@ map<string, double> L2PhiHandler::get_par_group_contrib(Eigen::VectorXd &phi_vec

void L2PhiHandler::update(ObservationEnsemble & oe, ParameterEnsemble & pe)
{
ObservationInfo oinfo = pest_scenario->get_ctl_observation_info();
num_conflict_group.clear();
for (auto& group : pest_scenario->get_ctl_ordered_obs_group_names())
{
num_conflict_group[group] = 0;
}
vector<string> in_conflict = detect_simulation_data_conflict(oe,"");
string group;
for (auto& ic : in_conflict)
{
group = oinfo.get_group(ic);
num_conflict_group[group]++;
}
//build up obs group and par group idx maps for group reporting
obs_group_idx_map.clear();
vector<string> nnz_obs = oe_base->get_var_names();
ObservationInfo oinfo = pest_scenario->get_ctl_observation_info();

vector<int> idx;
for (auto& og : pest_scenario->get_ctl_ordered_obs_group_names())
obs_group_idx_map[og] = vector<int>();
Expand Down Expand Up @@ -2386,10 +2399,22 @@ void L2PhiHandler::update(ObservationEnsemble & oe, ParameterEnsemble & pe)

void L2PhiHandler::update(ObservationEnsemble & oe, ParameterEnsemble & pe, ObservationEnsemble& weights)
{
ObservationInfo oinfo = pest_scenario->get_ctl_observation_info();
num_conflict_group.clear();
for (auto& group : pest_scenario->get_ctl_ordered_obs_group_names())
{
num_conflict_group[group] = 0;
}
vector<string> in_conflict = detect_simulation_data_conflict(oe,"");
string group;
for (auto& ic : in_conflict)
{
group = oinfo.get_group(ic);
num_conflict_group[group]++;
}
//build up obs group and par group idx maps for group reporting
obs_group_idx_map.clear();
vector<string> nnz_obs = oe_base->get_var_names();
ObservationInfo oinfo = pest_scenario->get_ctl_observation_info();
vector<int> idx;
for (auto& og : pest_scenario->get_ctl_ordered_obs_group_names())
obs_group_idx_map[og] = vector<int>();
Expand Down Expand Up @@ -2645,6 +2670,104 @@ bool cmp_pair(pair<string,double>& first, pair<string,double>& second)
return first.second > second.second;
}

vector<string> L2PhiHandler::detect_simulation_data_conflict(ObservationEnsemble& _oe, string csv_tag) {
vector<string> in_conflict;

ofstream pdccsv;
if (csv_tag.size() > 0)
pdccsv.open(file_manager->get_base_filename() + csv_tag);

double smin, smax, omin, omax, smin_stat, smax_stat, omin_stat, omax_stat;
map<string, int> smap, omap;
vector<string> snames = _oe.get_var_names();
vector<string> onames = oe_base->get_var_names();
vector<string> temp = get_lt_obs_names();
set<string> ineq_lt(temp.begin(), temp.end());
//set<string>::iterator end = ineq.end();
temp = get_gt_obs_names();
set<string> ineq_gt(temp.begin(), temp.end());
temp.resize(0);

for (int i = 0; i < snames.size(); i++) {
smap[snames[i]] = i;
}
for (int i = 0; i < onames.size(); i++) {
omap[onames[i]] = i;
}
int sidx, oidx;
bool use_stat_dist = true;
if (pest_scenario->get_pestpp_options().get_ies_pdc_sigma_distance() <= 0.0)
use_stat_dist = false;

double smn, sstd, omn, ostd, dist;
double sd = abs(pest_scenario->get_pestpp_options().get_ies_pdc_sigma_distance());
int oe_nr = _oe.shape().first;
int oe_base_nr = oe_base->shape().first;
Eigen::VectorXd t;

if (csv_tag.size() > 0)
{
pdccsv << "name,obs_mean,obs_std,obs_min,obs_max,obs_stat_min,obs_stat_max,sim_mean,sim_std,sim_min,sim_max,sim_stat_min,sim_stat_max,distance";
pdccsv << endl;
}

for (auto oname: pest_scenario->get_ctl_ordered_nz_obs_names()) {
//if (ineq.find(oname) != end)
// continue;
sidx = smap[oname];
oidx = omap[oname];
smin = _oe.get_eigen_ptr()->col(sidx).minCoeff();
omin = oe_base->get_eigen_ptr()->col(oidx).minCoeff();
smax = _oe.get_eigen_ptr()->col(sidx).maxCoeff();
omax = oe_base->get_eigen_ptr()->col(oidx).maxCoeff();
t = _oe.get_eigen_ptr()->col(sidx);
smn = t.mean();
sstd = std::sqrt((t.array() - smn).square().sum() / (oe_nr - 1));
smin_stat = smn - (sd * sstd);
smax_stat = smn + (sd * sstd);
t = oe_base->get_eigen_ptr()->col(oidx);
omn = t.mean();
ostd = std::sqrt((t.array() - omn).square().sum() / (oe_base_nr - 1));
omin_stat = omn - (sd * ostd);
omax_stat = omn + (sd * ostd);
bool conflicted = false;
if (use_stat_dist) {
if (ineq_lt.find(oname) != ineq_lt.end()) {
if (smin_stat > omax_stat)
conflicted = true;
} else if (ineq_gt.find(oname) != ineq_gt.end()) {
if (smax_stat < omin_stat)
conflicted = true;
} else if ((smin_stat > omax_stat) || (smax_stat < omin_stat)) {
conflicted = true;
}
} else {
if (ineq_lt.find(oname) != ineq_lt.end()) {
if (smin > omax)
conflicted = true;
} else if (ineq_gt.find(oname) != ineq_gt.end()) {
if (smax < omin)
conflicted = true;
} else if ((smin > omax) || (smax < omin)) {
conflicted = true;
}
}
if (conflicted) {
in_conflict.push_back(oname);
if (csv_tag.size() > 0) {
dist = max((smin - omax), (omin - smax));

pdccsv << pest_utils::lower_cp(oname) << "," << omn << "," << ostd << "," << omin << "," << omax << ","
<< omin_stat << ","
<< omax_stat;
pdccsv << "," << smn << "," << sstd << "," << smin << "," << smax << "," << smin_stat << ","
<< smax_stat << "," << dist << endl;
}
}
}
pdccsv.close();
return in_conflict;
}

void L2PhiHandler::report_group(bool echo) {

Expand All @@ -2665,7 +2788,6 @@ void L2PhiHandler::report_group(bool echo) {
snzgroups.emplace(oi_ptr->get_group(o));
}


double tot = 0, ptot = 0;
double v = 0,pv = 0;
int c = 0;
Expand Down Expand Up @@ -2738,7 +2860,7 @@ void L2PhiHandler::report_group(bool echo) {
ss << " --- observation group phi summary --- " << endl;
ss << " (computed using 'actual' phi)" << endl;
ss << " (sorted by mean phi)" << endl;
ss << left << setw(len) << "group" << right << setw(6) << "count" << setw(10) << "mean" << setw(10) << "std";
ss << left << setw(len) << "group" << right << setw(6) << "count" << setw(11) << "nconflict" << setw(10) << "mean" << setw(10) << "std";
ss << setw(10) << "min" << setw(10) << "max";
ss << setw(10) << "percent" << setw(10) << "std" << endl; //<< setw(10) << "min " << setw(10) << "max " << endl;
f << ss.str();
Expand Down Expand Up @@ -2771,6 +2893,8 @@ void L2PhiHandler::report_group(bool echo) {
ss.str("");
ss << left << setw(len) << pest_utils::lower_cp(g) << " ";
ss << right << setw(5) << nzc << " ";
ss << right << setw(10) << num_conflict_group[g] << " ";

ss << right << setw(9) << setprecision(3) << mn_map[g] << " ";
ss << setw(9) << setprecision(3) << std_map[g] << " ";
ss << setw(9) << setprecision(3) << mmn_map[g] << " ";
Expand Down Expand Up @@ -5620,7 +5744,8 @@ void EnsembleMethod::initialize(int cycle, bool run, bool use_existing)
ph.report(true);

pcs = ParChangeSummarizer(&pe_base, &file_manager, &output_file_writer);
vector<string> in_conflict = detect_simulation_data_conflict(oe,"pdc.csv");
message(1,"checking for prior-data conflict");
vector<string> in_conflict = ph.detect_simulation_data_conflict(oe,"pdc.csv");
if (in_conflict.size() > 0)
{
ss.str("");
Expand Down Expand Up @@ -8739,116 +8864,6 @@ void EnsembleMethod::zero_weight_obs(vector<string>& obs_to_zero_weight, bool up
message(1, ss.str());
}

vector<string> EnsembleMethod::detect_simulation_data_conflict(ObservationEnsemble& _oe, string csv_tag)
{
message(1, "checking for simulation-data conflict...");
//for now, just really simple metric - checking for overlap
// write out conflicted obs and some related info
// to a csv file
ofstream pdccsv(file_manager.get_base_filename() + csv_tag);

vector<string> in_conflict;
double smin, smax, omin, omax, smin_stat, smax_stat, omin_stat, omax_stat;
map<string, int> smap, omap;
vector<string> snames = _oe.get_var_names();
vector<string> onames = oe_base.get_var_names();
vector<string> temp = ph.get_lt_obs_names();
set<string> ineq_lt(temp.begin(), temp.end());
//set<string>::iterator end = ineq.end();
temp = ph.get_gt_obs_names();
set<string> ineq_gt(temp.begin(), temp.end());
temp.resize(0);

for (int i = 0; i < snames.size(); i++)
{
smap[snames[i]] = i;
}
for (int i = 0; i < onames.size(); i++)
{
omap[onames[i]] = i;
}
int sidx, oidx;
bool use_stat_dist = true;
if (pest_scenario.get_pestpp_options().get_ies_pdc_sigma_distance() <= 0.0)
use_stat_dist = false;

double smn, sstd, omn, ostd, dist;
double sd = abs(pest_scenario.get_pestpp_options().get_ies_pdc_sigma_distance());
int oe_nr = _oe.shape().first;
int oe_base_nr = oe_base.shape().first;
Eigen::VectorXd t;

pdccsv << "name,obs_mean,obs_std,obs_min,obs_max,obs_stat_min,obs_stat_max,sim_mean,sim_std,sim_min,sim_max,sim_stat_min,sim_stat_max,distance" << endl;
for (auto oname : pest_scenario.get_ctl_ordered_nz_obs_names())
{
//if (ineq.find(oname) != end)
// continue;
sidx = smap[oname];
oidx = omap[oname];
smin = _oe.get_eigen_ptr()->col(sidx).minCoeff();
omin = oe_base.get_eigen_ptr()->col(oidx).minCoeff();
smax = _oe.get_eigen_ptr()->col(sidx).maxCoeff();
omax = oe_base.get_eigen_ptr()->col(oidx).maxCoeff();
t = _oe.get_eigen_ptr()->col(sidx);
smn = t.mean();
sstd = std::sqrt((t.array() - smn).square().sum() / (oe_nr - 1));
smin_stat = smn - (sd * sstd);
smax_stat = smn + (sd * sstd);
t = oe_base.get_eigen_ptr()->col(oidx);
omn = t.mean();
ostd = std::sqrt((t.array() - omn).square().sum() / (oe_base_nr - 1));
omin_stat = omn - (sd * ostd);
omax_stat = omn + (sd * ostd);
bool conflicted = false;
if (use_stat_dist)
{
if (ineq_lt.find(oname) != ineq_lt.end())
{
if (smin_stat > omax_stat)
conflicted = true;
}
else if (ineq_gt.find(oname) != ineq_gt.end())
{
if (smax_stat < omin_stat)
conflicted = true;
}
else if ((smin_stat > omax_stat) || (smax_stat < omin_stat))
{
conflicted = true;
}
}
else
{
if (ineq_lt.find(oname) != ineq_lt.end())
{
if (smin > omax)
conflicted = true;
}
else if (ineq_gt.find(oname) != ineq_gt.end())
{
if (smax < omin)
conflicted = true;
}
else if ((smin > omax) || (smax < omin))
{
conflicted = true;
}
}
if (conflicted)
{
in_conflict.push_back(oname);
dist = max((smin - omax), (omin - smax));

pdccsv << pest_utils::lower_cp(oname) << "," << omn << "," << ostd << "," << omin << "," << omax << "," << omin_stat << ","
<< omax_stat;
pdccsv << "," << smn << "," << sstd << "," << smin << "," << smax << "," << smin_stat << ","
<< smax_stat << "," << dist << endl;

}
}

return in_conflict;
}

Eigen::MatrixXd EnsembleMethod::get_Am(const vector<string>& real_names, const vector<string>& par_names)
{
Expand Down
4 changes: 2 additions & 2 deletions src/libs/pestpp_common/EnsembleMethodUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class L2PhiHandler
map<string,map<string,double>> get_meas_phi_weight_ensemble(ObservationEnsemble& oe, ObservationEnsemble& weights);

vector<string> get_violating_realizations(ObservationEnsemble& oe, const vector<string>& viol_obs_names);

vector<string> detect_simulation_data_conflict(ObservationEnsemble& _oe, string csv_tag);

private:
string tag;
Expand Down Expand Up @@ -158,6 +158,7 @@ class L2PhiHandler
map<string, double> composite;
map<string, double> actual;
map<string, double> noise;
map<string,int> num_conflict_group;

vector<string> lt_obs_names;
vector<string> gt_obs_names;
Expand Down Expand Up @@ -496,7 +497,6 @@ class EnsembleMethod

Eigen::MatrixXd get_Am(const vector<string>& real_names, const vector<string>& par_names);

vector<string> detect_simulation_data_conflict(ObservationEnsemble& _oe, string csv_tag);

void zero_weight_obs(vector<string>& obs_to_zero_weight, bool update_obscov = true, bool update_oe_base = true);

Expand Down
3 changes: 3 additions & 0 deletions src/libs/pestpp_common/EnsembleSmoother.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ void IterEnsembleSmoother::iterate_2_solution()
last_best_mean = ph.get_mean(L2PhiHandler::phiType::COMPOSITE);
last_best_std = ph.get_std(L2PhiHandler::phiType::COMPOSITE);
ph.report(true);
ss.str("");
ss << "." << iter << ".pdc.csv";
ph.detect_simulation_data_conflict(oe,ss.str());
ph.write(iter, run_mgr_ptr->get_total_runs());
if (pest_scenario.get_pestpp_options().get_ies_save_rescov())
ph.save_residual_cov(oe,iter);
Expand Down

0 comments on commit 2501118

Please sign in to comment.