
[Feature request]: new flepimop post-processing module #413

Closed
MacdonaldJoshuaCaleb opened this issue Dec 5, 2024 · 4 comments
Labels
enhancement Request for improvement or addition of new feature(s). gempyor Concerns the Python core. medium priority Medium priority. plotting Relating to plotting and/or visualizations. post-processing Concern the post-processing.

Comments

@MacdonaldJoshuaCaleb
Collaborator

MacdonaldJoshuaCaleb commented Dec 5, 2024

Label

post-processing

Priority Label

medium priority

Is your feature request related to a problem? Please describe.

When assessing scenario plots and/or model fit to empirical data, there are a number of common targets across pathogens that it would be useful to plot together with sampled time trajectories. This can be done fairly easily with seaborn and matplotlib subplots. Below is an implementation of the basic idea for a results list like the one returned by the gempyor package (a list of data frames). This should probably be generalized to read the .parquet files from the model_output folder when other gempyor functions populate that folder (i.e., when the inference object is set to save output).

Example output of the function below for the state of Alabama:
plot_H1N1_High_Vax_Alabama.pdf

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

def plot_state(results_list, strain, scenario, fip, save=True, save_path=None, display_plot=True, data_hosp=None):
    # Load hospitalization data
    if data_hosp is None:
        # wherever your calibration data lives
        data_hosp = pd.read_csv(r'./model_input/SMH_Flu_2024_R1_allflu_medVax_H1_training_multiseason_emcee_difflocvarseas/us_data_Flu_2024_R1_allflu_training_multiseason_emcee_difflocvarseas.csv')
    
    # Get seasons to keep and state name
    seasons_keep, state_name, fips = get_seasons_keep(strain, fip)
    dates = results_list[0].index

    # Prepare simulation data
    sim_data = {
        'Date': np.concatenate([dates] * len(results_list)),
        'sim_id': np.concatenate([np.full(len(dates), str(i + 1)) for i in range(len(results_list))]),
        'value': np.concatenate([result['incidH_AllFlu'].values for result in results_list])
    }
    sim_df = pd.DataFrame(sim_data)

    # Create subplots
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))

    # Plot 1: simulated trajectories, one semi-transparent line per simulation
    # (units + estimator=None draws each sim_id separately and adds no legend)
    sns.lineplot(data=sim_df, x='Date', y='value', units='sim_id', estimator=None, color='blue', alpha=0.25, ax=axs[0, 0])
    axs[0, 0].set_ylabel('Hospitalization incidence')
    axs[0, 0].tick_params(axis='x', rotation=45)

    # Convert season labels like '22to23' to the '2022-23' format used in the data
    seasons_keep_2 = ['20' + season.split("to")[0] + '-' + season.split("to")[1] for season in seasons_keep]

    # Filter data_hosp for the state
    temp = data_hosp[data_hosp['source'] == state_name]
    seasons = temp['season'].unique()

    # Overlay observed data for the retained seasons
    for season in seasons:
        temp_season = temp[temp['season'] == season]
        if season in seasons_keep_2:
            sns.lineplot(x=temp_season['date'].astype('datetime64[ns]'), y=temp_season['incidH'], color='red', marker='o', zorder=2, alpha=0.5, ax=axs[0, 0])

    # Summarize each simulation: peak incidence, cumulative total, and peak date
    outcomes = []
    for sim_id in sim_df['sim_id'].unique():
        sim_temp = sim_df[sim_df['sim_id'] == sim_id]
        outcomes.append([scenario, sim_id, sim_temp['value'].max(), sim_temp['value'].sum(), sim_temp['Date'].values[sim_temp['value'].argmax()]])
    outcomes = pd.DataFrame(outcomes, columns=['scenario', 'sim_id', 'max', 'cuml', 'max_date'])

    # Plot 2: Max hospitalization incidence
    sns.histplot(data=outcomes, x='max', ax=axs[0, 1], stat='probability')
    axs[0, 1].set_xlabel('Max hospitalization incidence')

    # Plot 3: Cumulative hospitalizations
    sns.histplot(data=outcomes, x='cuml', ax=axs[1, 0], stat='probability')
    axs[1, 0].set_xlabel('Cumulative hospitalizations')

    # Plot 4: Date of max hospitalization incidence
    sns.boxplot(data=outcomes, x='max_date', ax=axs[1, 1])
    axs[1, 1].set_xlabel('Date of max hospitalization incidence')
    axs[1, 1].tick_params(axis='x', rotation=45)

    # Set the title of the figure
    fig.suptitle(f'{state_name} {strain} {scenario}', ha='center', va='bottom')
    plt.tight_layout()

    # Replace spaces in scenario with underscores for file naming
    scenario = scenario.replace(' ', '_')

    # Save the figure if save is True
    if save:
        fname = f'plot_{strain}_{scenario}_{state_name}.pdf'
        if save_path is not None:
            fname = f'{save_path}/{fname}'
        fig.savefig(fname=fname, bbox_inches='tight')

    # Display the plot if display_plot is True; otherwise close it so figures
    # don't accumulate when looping over many states
    if display_plot:
        plt.show()
    else:
        plt.close(fig)
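A possible vectorized alternative to the per-simulation loop that builds `outcomes`: a pandas groupby computes the peak, cumulative total, and peak date for every `sim_id` at once. This is a sketch that assumes the long-format `sim_df` layout used in `plot_state` (`Date`, `sim_id`, `value` columns); `summarize_trajectories` is a hypothetical helper name, not part of gempyor.

```python
import pandas as pd


def summarize_trajectories(sim_df: pd.DataFrame) -> pd.DataFrame:
    """Per-simulation peak, cumulative total, and peak date.

    Hypothetical helper; expects long-format columns 'Date', 'sim_id', 'value'.
    """
    # idxmax returns row labels, so reset to a unique RangeIndex first
    df = sim_df.reset_index(drop=True)
    # sort=False keeps sim_ids in order of first appearance for every aggregation
    grouped = df.groupby("sim_id", sort=False)["value"]
    peak_rows = grouped.idxmax()  # first row of the maximum, like argmax
    return pd.DataFrame(
        {
            "sim_id": peak_rows.index.to_numpy(),
            "max": grouped.max().to_numpy(),
            "cuml": grouped.sum().to_numpy(),
            "max_date": df.loc[peak_rows.to_numpy(), "Date"].to_numpy(),
        }
    )
```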

# Usage
# Note: get_seasons_keep is only relevant for the current flu round, because of the issue with "bad" seasons for simulations
import dill

keep_list, state_name, fips = get_seasons_keep('H1N1', '01000')
path = "./scenario_output/all_results_H1_HiVax_" + state_name + ".pkl"

# Loading a pickled results list is needed here because we are stitching together multiple seasons;
# a general function should read output from the model_output folder or a user-specified location
with open(path, 'rb') as f:
    data = dill.load(f)

plot_state(data, 'H1N1', 'High Vax', '01000', save=True, save_path='./model_plots', display_plot=False)
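A general version should read output files rather than a pickled list. A minimal sketch, assuming each matching file holds one simulation with a `date` column; the glob pattern, the column name, and the `load_results` helper are hypothetical, since the exact model_output layout depends on the run. If a file contains several locations, it would also need to be filtered to one before plotting. The `reader` argument defaults to `pd.read_parquet` (which needs pyarrow or fastparquet installed) but can be swapped for `pd.read_csv`.

```python
import glob
from typing import Callable, List

import pandas as pd


def load_results(
    pattern: str,
    date_col: str = "date",
    reader: Callable[[str], pd.DataFrame] = pd.read_parquet,
) -> List[pd.DataFrame]:
    """Read every file matching `pattern` into a date-indexed DataFrame.

    Hypothetical helper: the pattern and date column name are assumptions
    to adjust to the actual model_output structure of a run.
    """
    paths = sorted(glob.glob(pattern, recursive=True))  # stable simulation order
    if not paths:
        raise FileNotFoundError(f"no files match {pattern!r}")
    results = []
    for path in paths:
        df = reader(path)
        df[date_col] = pd.to_datetime(df[date_col])
        results.append(df.set_index(date_col))  # same shape as gempyor's results list
    return results
```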

Is your feature request related to a new application, scenario round, pathogen? Please describe.

This is useful both for scenario rounds and for assessing the fit of calibration runs to empirical data.
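For the calibration use case, one simple numeric companion to the trajectory panel is the share of observed points that fall inside a pointwise simulation band. This is a sketch, not something gempyor currently provides: `envelope_coverage` is a hypothetical helper, it assumes the long-format `sim_df` from `plot_state` and an observed series indexed by date, and the 5%/95% quantiles are an arbitrary default.

```python
import pandas as pd


def envelope_coverage(sim_df: pd.DataFrame, observed: pd.Series,
                      lower_q: float = 0.05, upper_q: float = 0.95):
    """Pointwise simulation band and the fraction of observations inside it.

    Hypothetical helper. sim_df: long format with 'Date', 'sim_id', 'value'.
    observed: observed values indexed by date.
    """
    # Quantiles across simulations at each date -> columns lower/upper
    band = sim_df.groupby("Date")["value"].quantile([lower_q, upper_q]).unstack()
    band.columns = ["lower", "upper"]
    # Only score dates present in both the simulations and the data
    joined = band.join(observed.rename("obs"), how="inner")
    inside = (joined["obs"] >= joined["lower"]) & (joined["obs"] <= joined["upper"])
    return band, float(inside.mean())
```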

Describe the solution you'd like

Incorporate something like the above function into the default post-processing plots, with automated post-processing like what currently exists for R inference runs and will hopefully (soon) exist for emcee runs.
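For automated post-processing over many locations, the main practical concerns are using a non-interactive backend and closing each figure after saving so memory does not grow across states. A minimal sketch, assuming a hypothetical `make_figure(location)` callable (for example, a variant of `plot_state` that returns `fig` instead of showing it); `batch_plots` is likewise a hypothetical name.

```python
import os
from typing import Callable, Iterable, List

import matplotlib
matplotlib.use("Agg")  # non-interactive backend for batch runs
import matplotlib.pyplot as plt
from matplotlib.figure import Figure


def batch_plots(locations: Iterable[str], make_figure: Callable[[str], Figure],
                out_dir: str, prefix: str = "plot") -> List[str]:
    """Build, save, and close one PDF figure per location (hypothetical helper)."""
    os.makedirs(out_dir, exist_ok=True)
    written = []
    for loc in locations:
        fig = make_figure(loc)
        path = os.path.join(out_dir, f"{prefix}_{loc}.pdf")
        fig.savefig(path, bbox_inches="tight")
        plt.close(fig)  # release the figure before moving to the next location
        written.append(path)
    return written
```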

@TimothyWillard TimothyWillard added enhancement Request for improvement or addition of new feature(s). gempyor Concerns the Python core. post-processing Concern the post-processing. medium priority Medium priority. plotting Relating to plotting and/or visualizations. labels Dec 5, 2024
@MacdonaldJoshuaCaleb
Collaborator Author

MacdonaldJoshuaCaleb commented Dec 7, 2024

@TimothyWillard I've written a new function that also allows for scenario comparisons:

def plot_scenario_comp(results_dict, strain, fip, save=True, save_path=None, display_plot=True, data_hosp=None):
    # Load hospitalization data
    if data_hosp is None:
        # wherever your calibration data lives
        data_hosp = pd.read_csv(r'./model_input/SMH_Flu_2024_R1_allflu_medVax_H1_training_multiseason_emcee_difflocvarseas/us_data_Flu_2024_R1_allflu_training_multiseason_emcee_difflocvarseas.csv')
    
    # Prepare simulation data
    def get_sim_df(results_list, scenario):
        dates = results_list[0].index
        sim_data = {
            'Date': np.concatenate([dates] * len(results_list)),
            'sim_id': np.concatenate([np.full(len(dates), str(i + 1)) for i in range(len(results_list))]),
            'scenario': np.concatenate([np.full(len(dates), scenario) for i in range(len(results_list))]),
            'value': np.concatenate([result['incidH_AllFlu'].values for result in results_list])
        }
        return pd.DataFrame(sim_data)
    
    # Get seasons to keep and state name
    seasons_keep, state_name, fips = get_seasons_keep(strain, fip)

    # Stack every scenario's simulations into one long data frame
    sim_df = pd.concat([get_sim_df(results_list, key) for key, results_list in results_dict.items()], ignore_index=True)

    # Create subplots
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))

    # Plot 1: simulated trajectories colored by scenario, one line per simulation
    # (units + estimator=None draws each sim_id separately; legend=False keeps the panel uncluttered)
    sns.lineplot(data=sim_df, x='Date', y='value', units='sim_id', estimator=None, hue='scenario', ax=axs[0, 0], palette=sns.color_palette("Set3", 3), legend=False)
    axs[0, 0].set_ylabel('Hospitalization incidence')
    axs[0, 0].tick_params(axis='x', rotation=45)

    # Convert season labels like '22to23' to the '2022-23' format used in the data
    seasons_keep_2 = ['20' + season.split("to")[0] + '-' + season.split("to")[1] for season in seasons_keep]

    # Filter data_hosp for the state
    temp = data_hosp[data_hosp['source'] == state_name]
    seasons = temp['season'].unique()

    # Overlay observed data for the retained seasons
    for season in seasons:
        temp_season = temp[temp['season'] == season]
        if season in seasons_keep_2:
            sns.lineplot(x=temp_season['date'].astype('datetime64[ns]'), y=temp_season['incidH'], color='red', marker='o', zorder=2, alpha=0.5, ax=axs[0, 0])

    # Summarize each simulation within each scenario: peak, cumulative total, peak date
    outcomes = []
    for key in results_dict.keys():
        scen_df = sim_df[sim_df['scenario'] == key]
        for sim_id in scen_df['sim_id'].unique():
            sim_temp = scen_df[scen_df['sim_id'] == sim_id]
            outcomes.append([key, sim_id, sim_temp['value'].max(), sim_temp['value'].sum(), sim_temp['Date'].values[sim_temp['value'].argmax()]])
    outcomes = pd.DataFrame(outcomes, columns=['scenario', 'sim_id', 'max', 'cuml', 'max_date'])

    # Plot 2: Max hospitalization incidence
    sns.histplot(data=outcomes, x='max', ax=axs[0, 1], stat='probability', hue='scenario', palette=sns.color_palette("Set3",3),multiple="dodge")
    axs[0, 1].set_xlabel('Max hospitalization incidence')

    # Plot 3: Cumulative hospitalizations
    sns.histplot(data=outcomes, x='cuml', ax=axs[1, 0], stat='probability', hue='scenario', palette=sns.color_palette("Set3",3),multiple="dodge")
    axs[1, 0].set_xlabel('Cumulative hospitalizations')

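    # Plot 4: percent of hospitalizations averted between vaccination scenarios.
    # Simulations are paired by position, so every scenario must have the same number
    # of simulations in the same order; the keys 'HiVax', 'MedVax', 'LowVax' are hard-coded.
    metrics = ['max', 'cuml']
    combinations = ['High vs. Med', 'Med vs. Low']
    difs = []
    labels = []
    combos = []

    for combo in combinations:
        if combo == 'High vs. Med':
            hi_vals = outcomes[outcomes['scenario'] == 'HiVax'][metrics].values
            med_vals = outcomes[outcomes['scenario'] == 'MedVax'][metrics].values
            # percent reduction relative to the Med scenario
            vals = 100 * (med_vals - hi_vals) / med_vals
        elif combo == 'Med vs. Low':
            med_vals = outcomes[outcomes['scenario'] == 'MedVax'][metrics].values
            low_vals = outcomes[outcomes['scenario'] == 'LowVax'][metrics].values
            # percent reduction relative to the Low scenario
            vals = 100 * (low_vals - med_vals) / low_vals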
        
        for i, metric in enumerate(metrics):
            difs.extend(vals[:, i])
            labels.extend([metric] * len(vals))
            combos.extend([combo] * len(vals))

    result_df = pd.DataFrame({'combination': combos, 'label': labels, 'difference': difs})

    sns.violinplot(data=result_df,x='label',y='difference', hue='combination',ax=axs[1,1], palette=sns.color_palette("Set2",2))
    axs[1, 1].set_ylabel(r'Hospitalizations averted (%)')
    axs[1, 1].set_xlabel('Target')


    # Set the title of the figure
    fig.suptitle(f'{state_name} {strain}', ha='center', va='bottom')
    plt.tight_layout()

    # Save the figure if save is True
    if save:
        fname = f'plot_{strain}_scenario_comp_{state_name}.pdf'
        if save_path is not None:
            fname = f'{save_path}/{fname}'
        fig.savefig(fname=fname, bbox_inches='tight')

    # Display the plot if display_plot is True; otherwise close it
    if display_plot:
        plt.show()
    else:
        plt.close(fig)
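The hard-coded High/Med/Low comparisons above could be generalized to any pair of scenarios. A sketch, assuming the `outcomes` layout built in `plot_scenario_comp` and that scenarios share `sim_id`s so simulations can be paired (as with common parameter draws); `percent_averted` is a hypothetical helper. Positive values mean fewer hospitalizations under `scenario` than under `baseline`, matching the sign convention of the violin panel, and the long-format result has the same `combination`/`label`/`difference` columns as `result_df`, so several pairs can be concatenated before plotting.

```python
import pandas as pd


def percent_averted(outcomes: pd.DataFrame, scenario: str, baseline: str,
                    metrics=("max", "cuml")) -> pd.DataFrame:
    """Paired percent reduction of `scenario` relative to `baseline` (hypothetical helper)."""
    metrics = list(metrics)
    a = outcomes[outcomes["scenario"] == scenario].set_index("sim_id")[metrics]
    b = outcomes[outcomes["scenario"] == baseline].set_index("sim_id")[metrics]
    # Pair simulations by sim_id, keeping only ids present in both scenarios
    a, b = a.align(b, join="inner", axis=0)
    pct = 100.0 * (b - a) / b
    pct.index.name = "sim_id"
    long = pct.reset_index().melt(id_vars="sim_id", var_name="label", value_name="difference")
    long["combination"] = f"{scenario} vs. {baseline}"
    return long
```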

@MacdonaldJoshuaCaleb MacdonaldJoshuaCaleb changed the title [Feature request]: edit default output plotting to be a panel figure with main scenario hub targets [Feature request]: new flepimop plotting options Dec 7, 2024
@MacdonaldJoshuaCaleb
Collaborator Author

The function above makes plots that look like this:
plot_H1N1_scenario_comp_Alabama.pdf

@MacdonaldJoshuaCaleb MacdonaldJoshuaCaleb changed the title [Feature request]: new flepimop plotting options [Feature request]: new flepimop post-processing module Dec 7, 2024
@MacdonaldJoshuaCaleb
Collaborator Author

MacdonaldJoshuaCaleb commented Dec 7, 2024

Also useful is the ability to sample from the posterior predictive distribution.

Here's the start of a function like that, also developed for flu scenarios. Note that to really generalize this, we would want to index by the parameter labels rather than just their ordering. The chains can be read from an h5 file as shown below (arviz is another Python package that can read h5 files):

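import emcee

# filename points to the HDF5 backend file written during the emcee run
sampler = emcee.backends.HDFBackend(filename, read_only=True)
chains = sampler.get_chain()  # shape: (n_steps, n_walkers, n_params)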

These chains can then be fed to gempyor to simulate the model for a given config.
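Before passing parameters to gempyor, one usually discards burn-in, optionally thins, and draws parameter vectors from the remaining samples. A minimal sketch operating on the array returned by `get_chain()` (shape `(n_steps, n_walkers, n_params)`); emcee's own `get_chain(discard=..., thin=..., flat=True)` can produce a similar flattened pool directly. `draw_posterior_samples` is a hypothetical helper name.

```python
import numpy as np


def draw_posterior_samples(chains: np.ndarray, n_samples: int, discard: int = 0,
                           thin: int = 1, seed=None) -> np.ndarray:
    """Randomly draw parameter vectors from an MCMC chain array (hypothetical helper).

    chains has shape (n_steps, n_walkers, n_params). Burn-in steps are dropped
    and the rest thinned before walkers and steps are pooled together.
    """
    kept = chains[discard::thin]  # drop burn-in, then keep every `thin`-th step
    pool = kept.reshape(-1, kept.shape[-1])  # (n_kept_steps * n_walkers, n_params)
    rng = np.random.default_rng(seed)
    # Sample without replacement unless more draws are requested than exist
    idx = rng.choice(pool.shape[0], size=n_samples, replace=n_samples > pool.shape[0])
    return pool[idx]
```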

import numpy as np

def shuffle_params(chains, idx_array, intersect, keep_list, Num_samples=None, Num_seasons=None, Num_params=None):
    # Mix draws across seasons: for each sample and each parameter j, pick a random
    # season and a random walker, then read that season's column for parameter j
    # (idx_array[season][j]). Note: Num_seasons is currently unused.
    if Num_samples is None:
        Num_samples = 100
    if Num_seasons is None:
        Num_seasons = 3
    if Num_params is None:
        Num_params = 9

    samples = chains[-1, :, :]  # final step of every walker
    shuffled_samples = np.zeros([Num_samples, Num_params])
    shuffled_chains = np.zeros([chains.shape[0], Num_samples, Num_params])
    r0_seasons = []
    indices = []
    for k in range(Num_samples):
        r_season_idx = np.random.randint(0, len(intersect), Num_params)
        r_chain_idx = np.random.randint(0, chains.shape[1], Num_params)
        # season drawn for the first parameter (assumes keep_list is ordered like intersect)
        r0_seasons.append(keep_list[r_season_idx[0]])
        for j in range(Num_params):
            col = idx_array[r_season_idx[j]][j]
            shuffled_samples[k, j] = samples[r_chain_idx[j], col]
            shuffled_chains[:, k, j] = chains[:, r_chain_idx[j], col]
            indices.append([r_chain_idx[j], col])

    return shuffled_chains, shuffled_samples, indices, np.array(r0_seasons)
######################################################################
# usage
from gempyor.inference import GempyorInference

# state_dst_config and run_id are defined elsewhere in the calling script
gempyor_inference = GempyorInference(
    config_filepath=state_dst_config,
    run_id=run_id,
    prefix=None,
    first_sim_index=1,
    stoch_traj_flag=False,
    rng_seed=None,
    nslots=1,
    inference_filename_prefix="global/final/",  # usually for {global or chimeric}/{intermediate or final}
    inference_filepath_suffix="",  # usually for the slot_id
    out_run_id=None,  # if out_run_id is different from in_run_id, fill this
    out_prefix=None,  # if out_prefix is different from in_prefix, fill this
    # in case the data folder is in another directory
    autowrite_seir=False,
)

# generate a list of data frames from gempyor
result = gempyor_inference.simulate_proposal(shuffled_samples[0])
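On indexing by parameter labels rather than ordering: a sketch of the same season-mixing idea as `shuffle_params` (final MCMC step, independent random season and walker per parameter), where each chain column carries a `(season, parameter)` label and lookups go through those labels instead of `idx_array`. The labeling scheme and `shuffle_params_by_label` are hypothetical; in practice the labels would need to come from the config used for the emcee run. A missing label raises `KeyError` rather than silently reading the wrong column.

```python
import numpy as np
import pandas as pd


def shuffle_params_by_label(chains, col_labels, param_names, seasons,
                            n_samples=100, seed=None):
    """Label-based season mixing (hypothetical helper).

    chains: (n_steps, n_walkers, n_cols); col_labels: one (season, param) tuple per column.
    Returns (sampled values, season picked) as DataFrames with columns = param_names.
    """
    lookup = {label: i for i, label in enumerate(col_labels)}
    rng = np.random.default_rng(seed)
    final = chains[-1]  # (n_walkers, n_cols), final step as in shuffle_params
    values = np.empty((n_samples, len(param_names)))
    picked = np.empty((n_samples, len(param_names)), dtype=object)
    for k in range(n_samples):
        for j, name in enumerate(param_names):
            season = seasons[rng.integers(len(seasons))]  # random season per parameter
            walker = rng.integers(final.shape[0])  # random walker per parameter
            values[k, j] = final[walker, lookup[(season, name)]]
            picked[k, j] = season
    return (pd.DataFrame(values, columns=param_names),
            pd.DataFrame(picked, columns=param_names))
```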

@TimothyWillard
Contributor

This issue has been split into two more manageable pieces, GH-415 & GH-416.

Projects
None yet
Development

No branches or pull requests

2 participants