[Feature request]: new flepimop post-processing module #413
Comments
@TimothyWillard I've written a new function that also allows for scenario comparisons:

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns


    def plot_scenario_comp(results_dict, strain, fip, save=True, save_path=None, display_plot=True, data_hosp=None):
        # Load hospitalization data
        if data_hosp is None:
            # wherever your calibration data lives
            data_hosp = pd.read_csv(r'./model_input/SMH_Flu_2024_R1_allflu_medVax_H1_training_multiseason_emcee_difflocvarseas/us_data_Flu_2024_R1_allflu_training_multiseason_emcee_difflocvarseas.csv')

        # Prepare simulation data: stack all simulations of one scenario in long format
        def get_sim_df(results_list, scenario):
            dates = results_list[0].index
            sim_data = {
                'Date': np.concatenate([dates] * len(results_list)),
                'sim_id': np.concatenate([np.full(len(dates), str(i + 1)) for i in range(len(results_list))]),
                'scenario': np.concatenate([np.full(len(dates), scenario) for _ in range(len(results_list))]),
                'value': np.concatenate([result['incidH_AllFlu'].values for result in results_list])
            }
            return pd.DataFrame(sim_data)

        # Get seasons to keep and state name (get_seasons_keep is a helper defined elsewhere)
        seasons_keep, state_name, fips = get_seasons_keep(strain, fip)

        sim_dfs = {}
        for key in results_dict.keys():
            results_list = results_dict[key]
            sim_dfs[key] = get_sim_df(results_list, key)
        sim_df = pd.concat(sim_dfs.values(), ignore_index=True)

        # Create subplots
        fig, axs = plt.subplots(2, 2, figsize=(15, 10))

        # Plot 1: Simulation results
        sns.lineplot(data=sim_df, x='Date', y='value', style='sim_id', hue='scenario', ax=axs[0, 0], palette=sns.color_palette("Set3", 3))
        axs[0, 0].set_ylabel('Hospitalization incidence')
        axs[0, 0].tick_params(axis='x', rotation=45)
        axs[0, 0].legend_.remove()

        # Prepare seasons_keep_2, e.g. '15to16' -> '2015-16'
        seasons_keep_2 = ['20' + season.split("to")[0] + '-' + season.split("to")[1] for season in seasons_keep]

        # Filter data_hosp for the state
        temp = data_hosp[data_hosp['source'] == state_name]
        seasons = temp['season'].unique()

        # Plot historical data
        for season in seasons:
            temp_season = temp[temp['season'] == season]
            if season in seasons_keep_2:
                sns.lineplot(x=temp_season['date'].astype('datetime64[ns]'), y=temp_season['incidH'], color='red', marker='o', zorder=2, alpha=0.5, ax=axs[0, 0])
        # Remove the legend only if seaborn re-created one while drawing the historical data
        if axs[0, 0].get_legend() is not None:
            axs[0, 0].get_legend().remove()

        # Prepare outcomes: peak, cumulative total, and peak date per scenario and simulation
        outcomes = []
        for key in results_dict.keys():
            for sim_id in sim_df['sim_id'].unique():
                temp = sim_df[(sim_df['sim_id'] == sim_id) & (sim_df['scenario'] == key)]
                outcomes.append([key, sim_id, temp['value'].max(), temp['value'].sum(), temp['Date'].values[temp['value'].argmax()]])
        outcomes = pd.DataFrame(outcomes, columns=['scenario', 'sim_id', 'max', 'cuml', 'max_date'])

        # Plot 2: Max hospitalization incidence
        sns.histplot(data=outcomes, x='max', ax=axs[0, 1], stat='probability', hue='scenario', palette=sns.color_palette("Set3", 3), multiple="dodge")
        axs[0, 1].set_xlabel('Max hospitalization incidence')

        # Plot 3: Cumulative hospitalizations
        sns.histplot(data=outcomes, x='cuml', ax=axs[1, 0], stat='probability', hue='scenario', palette=sns.color_palette("Set3", 3), multiple="dodge")
        axs[1, 0].set_xlabel('Cumulative hospitalizations')

        # Plot 4: Relative reduction in peak and cumulative hospitalizations between scenarios
        # (the scenario keys 'HiVax', 'MedVax' and 'LowVax' are hard-coded here)
        metrics = ['max', 'cuml']
        combinations = ['High vs. Med', 'Med vs. Low']
        difs = []
        labels = []
        combos = []
        for combo in combinations:
            if combo == 'High vs. Med':
                hi_vals = outcomes[outcomes['scenario'] == 'HiVax'][metrics].values
                med_vals = outcomes[outcomes['scenario'] == 'MedVax'][metrics].values
                # Scaled by 100 so the values match the '(%)' axis label
                vals = 100 * (med_vals - hi_vals) / med_vals
            elif combo == 'Med vs. Low':
                med_vals = outcomes[outcomes['scenario'] == 'MedVax'][metrics].values
                low_vals = outcomes[outcomes['scenario'] == 'LowVax'][metrics].values
                vals = 100 * (low_vals - med_vals) / low_vals
            for i, metric in enumerate(metrics):
                difs.extend(vals[:, i])
                labels.extend([metric] * len(vals))
                combos.extend([combo] * len(vals))
        result_df = pd.DataFrame({'combination': combos, 'label': labels, 'difference': difs})
        sns.violinplot(data=result_df, x='label', y='difference', hue='combination', ax=axs[1, 1], palette=sns.color_palette("Set2", 2))
        axs[1, 1].set_ylabel(r'Hospitalizations averted (%)')
        axs[1, 1].set_xlabel('Target')

        # Set the title of the figure
        fig.suptitle(f'{state_name} {strain}', ha='center', va='bottom')
        plt.tight_layout()

        if save:
            if save_path is None:
                fig.savefig(fname=f'plot_{strain}_scenario_comp_{state_name}.pdf', bbox_inches='tight')
            else:
                fig.savefig(fname=f'{save_path}/plot_{strain}_scenario_comp_{state_name}.pdf', bbox_inches='tight')

        # Display the plot if display_plot is True
        if display_plot:
            plt.show()
The above makes plots that look like this:

[image: example scenario-comparison plot]
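The per-scenario, per-simulation outcome loop in the function above could also be written as a single pandas groupby. The following is only a sketch under assumptions: `summarize_outcomes` is a hypothetical helper name, and it expects a long-format frame with the same `Date`, `sim_id`, `scenario` and `value` columns as `sim_df`, plus a unique row index (which `pd.concat(..., ignore_index=True)` provides).

```python
import pandas as pd


def summarize_outcomes(sim_df):
    """Peak, cumulative total, and peak date for each (scenario, sim_id) pair.

    Hypothetical helper: assumes the columns 'Date', 'sim_id', 'scenario'
    and 'value', and a unique row index.
    """
    grouped = sim_df.groupby(["scenario", "sim_id"], sort=False)
    # Named aggregation: one output column per summary statistic.
    outcomes = grouped["value"].agg(max="max", cuml="sum").reset_index()
    # idxmax gives the row label of each group's peak; use it to look up the date.
    peak_rows = grouped["value"].idxmax()
    outcomes["max_date"] = sim_df.loc[peak_rows.to_numpy(), "Date"].to_numpy()
    return outcomes
```

This yields the same `scenario`, `sim_id`, `max`, `cuml` and `max_date` columns as the loop, without filtering the full frame once per scenario and simulation.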
Also useful is the ability to sample the posterior predictive distribution. Here's the start of a function like that, also developed for flu scenarios. To really generalize this, we would want to index by the parameter labels rather than just the ordering. Note that chains can be read from an h5 file; arviz is another Python package that can read h5 files.
These chains can then be fed to gempyor to simulate the model given a config:

    import numpy as np
    # Import path assumed; adjust it to your flepiMoP version
    from gempyor.inference import GempyorInference


    def shuffle_params(chains, idx_array, intersect, keep_list, Num_samples=None, Num_seasons=None, Num_params=None):
        if Num_samples is None:
            Num_samples = 100
        if Num_seasons is None:  # currently unused
            Num_seasons = 3
        if Num_params is None:
            Num_params = 9
        # Last step of every chain
        samples = chains[-1, :, :]
        shuffled_samples = np.zeros([Num_samples, Num_params])
        shuffled_chains = np.zeros([chains.shape[0], Num_samples, Num_params])
        r0_seasons = []
        indices = []
        for k in range(Num_samples):
            # For each parameter, draw a random season and a random chain
            r_season_idx = np.random.randint(0, len(intersect), Num_params)
            r_chain_idx = np.random.randint(0, chains.shape[1], Num_params)
            r0_seasons.append(keep_list[r_season_idx[0]])
            for j in range(Num_params):
                shuffled_samples[k, j] = samples[r_chain_idx[j], idx_array[r_season_idx[j]][j]]
                shuffled_chains[:, k, j] = chains[:, r_chain_idx[j], idx_array[r_season_idx[j]][j]]
                indices.append([r_chain_idx[j], idx_array[r_season_idx[j]][j]])
        return shuffled_chains, shuffled_samples, indices, np.array(r0_seasons)

    ######################################################################
    # usage
    # chains, idx_array, intersect, keep_list, state_dst_config and run_id
    # are assumed to be defined already
    shuffled_chains, shuffled_samples, indices, r0_seasons = shuffle_params(chains, idx_array, intersect, keep_list)

    gempyor_inference = GempyorInference(
        config_filepath=state_dst_config,
        run_id=run_id,
        prefix=None,
        first_sim_index=1,
        stoch_traj_flag=False,
        rng_seed=None,
        nslots=1,
        inference_filename_prefix="global/final/",  # usually for {global or chimeric}/{intermediate or final}
        inference_filepath_suffix="",  # usually for the slot_id
        out_run_id=None,  # if out_run_id is different from in_run_id, fill this
        out_prefix=None,  # if out_prefix is different from in_prefix, fill this
        # in case the data folder is on another directory
        autowrite_seir=False,
    )
    # generate a list of data frames from gempyor
    result = gempyor_inference.simulate_proposal(shuffled_samples[0])
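As a rough illustration of indexing by parameter labels rather than by ordering, one could map each label to its column in the chain array before sampling. This is a simplified sketch, not a replacement for `shuffle_params`: `sample_by_label`, `param_labels` and `wanted` are hypothetical names, the chain array is assumed to have shape (steps, chains, parameters), and each sample here is drawn jointly from a single chain at the last step instead of mixing seasons and chains per parameter.

```python
import numpy as np


def sample_by_label(chains, param_labels, wanted, num_samples=100, rng=None):
    """Draw samples of the named parameters from the last step of the chains.

    Hypothetical helper: `param_labels` names the columns of `chains`
    (shape: steps x chains x parameters) and `wanted` lists the labels,
    in the order the caller needs them.
    """
    rng = np.random.default_rng(rng)
    label_to_col = {label: i for i, label in enumerate(param_labels)}
    missing = [label for label in wanted if label not in label_to_col]
    if missing:
        raise KeyError(f"unknown parameter labels: {missing}")
    cols = [label_to_col[label] for label in wanted]
    last_step = chains[-1]  # shape: (n_chains, n_params)
    picked = rng.integers(0, last_step.shape[0], size=num_samples)
    # Rows are the randomly picked chains, columns follow the order of `wanted`.
    return last_step[np.ix_(picked, cols)]
```

Looking columns up by name means the sampling code keeps working if the parameter order in the config changes, and a misspelled label fails loudly instead of silently selecting the wrong column.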
Label: post-processing
Priority Label: medium priority
Is your feature request related to a problem? Please describe.
When assessing scenario plots and/or model fit to empirical data, there are a number of targets common across pathogens that it would be useful to plot together with sample time trajectories. This can be done fairly easily with seaborn and subfigures in matplotlib. Here's an implementation of the basic idea for a results list like the one returned by the gempyor package, which is a list of data frames. This should probably be generalized to read the .parquet files from the model_output folder, for cases where other gempyor functions populate that folder because the inference object is set to save.
Example output of the above function for the state of Alabama is here: plot_H1N1_High_Vax_Alabama.pdf
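The generalization to .parquet files mentioned above might start from something like the following sketch, which builds the same "list of data frames" shape from files on disk. Everything here is an assumption for illustration: `load_results_list`, `pattern`, `reader` and `date_col` are hypothetical names, and the folder layout and the presence of a `date` column in the files have not been checked against actual model_output contents.

```python
from pathlib import Path

import pandas as pd


def load_results_list(model_output_dir, pattern="**/*.parquet", reader=pd.read_parquet, date_col="date"):
    """Load every matching file under `model_output_dir` into a list of data frames.

    Hypothetical helper: the glob pattern and the `date` column name are
    assumptions about the folder contents, not a documented layout.
    """
    paths = sorted(Path(model_output_dir).glob(pattern))
    if not paths:
        raise FileNotFoundError(f"no files matching {pattern!r} under {model_output_dir}")
    frames = []
    for path in paths:
        df = reader(path)
        if date_col in df.columns:
            # Index by date so each frame resembles the frames returned in memory.
            df[date_col] = pd.to_datetime(df[date_col])
            df = df.set_index(date_col)
        frames.append(df)
    return frames
```

Passing the reader in as an argument keeps the file format swappable, for example if only a subset of columns should be loaded from each file.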
Is your feature request related to a new application, scenario round, pathogen? Please describe.
This is useful both for scenario rounds and for assessing how well calibration runs fit empirical data.
Describe the solution you'd like
Incorporate something like the above function into the default post-processing plots, with automated post-processing like what currently exists for R-inference runs and will hopefully (soon) exist for emcee runs.