Skip to content

Latest commit

 

History

History
2535 lines (2321 loc) · 55.4 KB

005_015_brief-analysis-with-larger-data.md

File metadata and controls

2535 lines (2321 loc) · 55.4 KB

Brief analysis of model fit with a larger data set

Setup

%load_ext autoreload
%autoreload 2
import re
import warnings
from pathlib import Path
from time import time
from typing import Final

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotnine as gg
import pymc3 as pm
import scipy.stats as st
import seaborn as sns
from speclet.bayesian_models.hierarchical_nb import HierarchcalNegativeBinomialModel
from speclet.io import DataFile, models_dir
from speclet.managers.data_managers import CrisprScreenDataManager
from speclet.plot.plotnine_helpers import set_gg_theme
from speclet.project_configuration import read_project_configuration
# Notebook execution timer.
notebook_tic = time()

# Plotting setup.
set_gg_theme()
%config InlineBackend.figure_format = "retina"

# Constants
RANDOM_SEED = 847
np.random.seed(RANDOM_SEED)
HDI_PROB = read_project_configuration().modeling.highest_density_interval

Data

def read_posterior_summary(fpath: Path) -> pd.DataFrame:
    """Load a posterior summary CSV and tag each row with its base parameter name.

    The added ``parameter_name`` column is the ``parameter`` label with any
    ``[...]`` index suffix removed (e.g. ``"a[0]"`` -> ``"a"``).
    """
    summary_df = pd.read_csv(fpath)
    base_names = [label.split("[")[0] for label in summary_df["parameter"]]
    return summary_df.assign(parameter_name=base_names)
# Posterior summary of the default-initialized ("jitter+adapt_diag") MCMC fit.
hnb_model_dir = models_dir() / "hierarchical-nb_PYMC3_MCMC"
posterior_summary_path = hnb_model_dir / "posterior-summary.csv"
posterior_summary = read_posterior_summary(fpath=posterior_summary_path)
posterior_summary.head(5)
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
parameter mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat parameter_name
0 z 0.042 0.005 0.035 0.050 0.000 0.000 142.0 404.0 1.02 z
1 a[0] 0.314 0.087 0.165 0.439 0.001 0.001 4659.0 2956.0 1.00 a
2 a[1] -0.004 0.060 -0.106 0.085 0.001 0.001 2523.0 3087.0 1.00 a
3 a[2] 0.179 0.060 0.087 0.279 0.001 0.001 2254.0 2605.0 1.00 a
4 a[3] 0.196 0.065 0.097 0.303 0.001 0.001 3222.0 2927.0 1.00 a
# Model object (its data pipeline and coordinates are used below) and the
# data manager for the CRC + bone subsample.
hnb_model_cls = HierarchcalNegativeBinomialModel()
dm = CrisprScreenDataManager(DataFile.DEPMAP_CRC_BONE_LARGE_SUBSAMPLE)

# Run the model's processing pipeline, then expose the row index as an explicit
# `data_idx` column so posterior predictions can be joined back to the data.
_raw_counts = dm.get_data()
_processed_counts = hnb_model_cls.data_processing_pipeline(_raw_counts)
counts_data = _processed_counts.reset_index(drop=False).rename(
    columns={"index": "data_idx"}
)
counts_data.head()
/var/folders/r4/qpcdgl_14hbd412snp1jnv300000gn/T/ipykernel_56942/1818648733.py:4: DtypeWarning: Columns (3,22) have mixed types.Specify dtype option on import or set low_memory=False.
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
data_idx sgrna replicate_id lfc p_dna_batch genome_alignment hugo_symbol screen multiple_hits_on_gene sgrna_target_chr ... is_mutated copy_number lineage primary_or_metastasis is_male age counts_final_total counts_initial_total counts_final_rpm counts_initial_adj
0 0 AAAGCCCAGGAGTATGGGAG LS513-311Cas9_RepA_p6_batch2 0.594321 2 chr2_130522105_- CFC1B broad True 2 ... False 0.951337 colorectal primary True 63.0 35176093 1.072163e+06 13.309497 257.442323
1 1 ACTTGTCTCATGAACGTGAT LS513-311Cas9_RepA_p6_batch2 0.475678 2 chr2_86917638_+ RGPD1 broad True 2 ... False 0.949234 colorectal primary True 63.0 35176093 1.072163e+06 37.928490 766.756365
2 2 AGAAACTTCACCCCTTTCAT LS513-311Cas9_RepA_p6_batch2 0.296108 2 chr16_18543661_+ NOMO2 broad True 16 ... False 0.944648 colorectal primary True 63.0 35176093 1.072163e+06 29.513684 685.044642
3 3 AGCTGAGCGCAGGGACCGGG LS513-311Cas9_RepA_p6_batch2 -0.020788 2 chr1_27012633_- TENT5B broad True 1 ... False 0.961139 colorectal primary True 63.0 35176093 1.072163e+06 4.837834 142.977169
4 4 ATACTCCTGGGCTTTCGGAG LS513-311Cas9_RepA_p6_batch2 -0.771298 2 chr2_130522124_+ CFC1B broad True 2 ... False 0.951337 colorectal primary True 63.0 35176093 1.072163e+06 14.588775 706.908890

5 rows × 29 columns

# Coordinate labels (e.g. "cell_line", "sgrna", "gene", "lineage") used to
# name the entries of the indexed model parameters.
data_coords = hnb_model_cls._model_coords(counts_data)


def _unique_pairs(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
    """Distinct rows of `df[cols]`, re-indexed 0..n-1."""
    return df[cols].drop_duplicates().reset_index(drop=True)


# Lookup tables: cell line -> lineage and sgRNA -> targeted gene.
cell_line_lineage_map = _unique_pairs(counts_data, ["depmap_id", "lineage"])
sgrna_gene_map = _unique_pairs(counts_data, ["sgrna", "hugo_symbol"])

Analysis

Sampling diagnostics

R-hat

# R-hat distribution for every parameter, grouped by parameter name.
# The axis label is a raw string: "\w" is not a valid escape sequence and
# triggers a DeprecationWarning (SyntaxWarning on Python 3.12+) otherwise.
(
    gg.ggplot(posterior_summary, gg.aes(x="parameter_name", y="r_hat"))
    + gg.geom_boxplot(outlier_size=0.6, outlier_alpha=0.5)
    + gg.theme(axis_text_x=gg.element_text(angle=35, hjust=1))
    + gg.labs(x="parameter", y=r"$\widehat{R}$")
)

png

<ggplot: (355283119)>
# Largest R-hat within each parameter group: sort descending, then keep the
# first row of each group.
max_r_hats = (
    posterior_summary.sort_values("r_hat", ascending=False)
    .groupby("parameter_name")
    .head(1)
)
# Lollipop plot of the per-group maximum, with stems starting at R-hat = 1.
# Raw string label avoids the invalid "\w" escape sequence.
(
    gg.ggplot(max_r_hats, gg.aes(x="parameter_name", y="r_hat"))
    + gg.geom_linerange(gg.aes(ymax="r_hat"), ymin=1)
    + gg.geom_point()
    + gg.scale_y_continuous(expand=(0, 0, 0.05, 0))
    + gg.theme(axis_text_x=gg.element_text(angle=35, hjust=1))
    + gg.labs(x="parameter", y=r"maximum $\widehat{R}$")
)

png

<ggplot: (355285619)>

ESS

# Bulk ESS per parameter group: faint jittered points show individual
# parameters, and red boxplots (outliers hidden) summarize each group.
_ess_bulk_plot = (
    gg.ggplot(posterior_summary, gg.aes(x="parameter_name", y="ess_bulk"))
    + gg.geom_jitter(alpha=0.1, size=0.2, width=0.3, height=0)
    + gg.geom_boxplot(outlier_alpha=0, alpha=0.5, color="red")
    + gg.theme(axis_text_x=gg.element_text(angle=35, hjust=1))
    + gg.labs(x="parameter", y="ESS (bulk)")
)
_ess_bulk_plot

png

<ggplot: (357062199)>
# Smallest bulk ESS within each parameter group: sort ascending, then keep the
# first row of each group.
min_ess = (
    posterior_summary.sort_values("ess_bulk", ascending=True)
    .groupby("parameter_name")
    .head(1)
)

# Breaks every 500, extended to cover the largest minimum ESS (a fixed upper
# bound of 5000 would leave larger values without gridlines).
_breaks = np.arange(0, min_ess["ess_bulk"].max() + 500, 500)

(
    gg.ggplot(min_ess, gg.aes(x="parameter_name", y="ess_bulk"))
    # Stem from 0 up to each group's minimum ESS. The stems must use the ESS
    # column; R-hat values (~1) are invisible on this axis.
    + gg.geom_linerange(gg.aes(ymax="ess_bulk"), ymin=0)
    + gg.geom_point()
    + gg.scale_y_continuous(expand=(0.02, 0), breaks=_breaks)
    + gg.theme(axis_text_x=gg.element_text(angle=35, hjust=1))
    + gg.labs(x="parameter", y="minimum ESS (bulk)")
)

png

<ggplot: (356787182)>

Parameter posteriors

$b$: cell line

# Colors for the two lineages, shared by all lineage-colored plots.
lineage_pal: Final[dict[str, str]] = {"bone": "darkviolet", "colorectal": "green"}

# Posterior mean of sigma_b; drawn as dashed reference lines at +/- sigma_b.
sigma_b_map = posterior_summary.query("parameter_name == 'sigma_b'")["mean"].values[0]

# Cell-line effects b[i]. Cell line IDs are attached positionally (assumes the
# summary rows follow the order of data_coords["cell_line"]). Lineage is then
# added, and cell lines are ordered by posterior mean for the x-axis.
_b_rows = (
    posterior_summary.query("parameter_name == 'b'")
    .reset_index(drop=True)
    .assign(depmap_id=data_coords["cell_line"])
)
b_posterior = _b_rows.merge(cell_line_lineage_map, on="depmap_id").sort_values(
    ["mean"]
)
b_posterior = b_posterior.assign(
    depmap_id=lambda d: pd.Categorical(
        d.depmap_id.values, categories=d.depmap_id.values, ordered=True
    )
)

_b_plot = (
    gg.ggplot(b_posterior, gg.aes(x="depmap_id", y="mean"))
    + gg.geom_hline(yintercept=0, color="gray")
    + gg.geom_hline(yintercept=[-sigma_b_map, sigma_b_map], linetype="--", color="gray")
    + gg.geom_linerange(
        gg.aes(ymin="hdi_5.5%", ymax="hdi_94.5%", color="lineage"), size=0.4
    )
    + gg.geom_point(gg.aes(color="lineage"), size=0.8)
    + gg.scale_color_manual(values=lineage_pal)
    + gg.theme(
        axis_text_x=gg.element_text(angle=90, size=5),
        panel_grid_major_x=gg.element_line(size=0.3, color="lightgray"),
        figure_size=(8, 3),
    )
    + gg.labs(x="cell line", y="posterior $b$\nmean ± 89% HDI")
)
_b_plot

png

<ggplot: (355138771)>

$a$: sgRNA

# sgRNA effects a[i]. Each row gets its sgRNA positionally from
# data_coords["sgrna"] and the gene that sgRNA targets.
_a_rows = posterior_summary.query("parameter_name == 'a'").reset_index(drop=True)
a_posterior = _a_rows.assign(sgrna=data_coords["sgrna"]).merge(
    sgrna_gene_map, on="sgrna"
)
a_posterior.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
parameter mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat parameter_name sgrna hugo_symbol
0 a[0] 0.314 0.087 0.165 0.439 0.001 0.001 4659.0 2956.0 1.0 a GAGCAAATACGAGCACCAAG LRP12
1 a[1] -0.004 0.060 -0.106 0.085 0.001 0.001 2523.0 3087.0 1.0 a CATTCTTTAGTGTAGCTAC CENPI
2 a[2] 0.179 0.060 0.087 0.279 0.001 0.001 2254.0 2605.0 1.0 a GTGTTCCGATTGGAGCCACA LPP
3 a[3] 0.196 0.065 0.097 0.303 0.001 0.001 3222.0 2927.0 1.0 a TTATTGACACCGAAACCGT BCAS3
4 a[4] 0.173 0.078 0.058 0.304 0.001 0.001 3788.0 3091.0 1.0 a AAGGTTTTCTGGTAGCAGA SLC35E3
# Density of the sgRNA-level posterior means of `a`.
_a_density_plot = (
    gg.ggplot(a_posterior, gg.aes(x="mean"))
    + gg.geom_density()
    + gg.scale_x_continuous(expand=(0, 0))
    + gg.scale_y_continuous(expand=(0, 0, 0.02, 0))
    + gg.theme(figure_size=(6, 4))
    + gg.labs(x="$a$ MAP estimates")
)
_a_density_plot

png

<ggplot: (354504689)>
def min_max(df: pd.DataFrame, n: int, drop_idx: bool = True) -> pd.DataFrame:
    """Get the top and bottom `n` rows of a data frame.

    Args:
        df: Data frame, usually pre-sorted so its first and last rows are the
            extremes of interest.
        n: Number of rows to take from each end.
        drop_idx: If True, discard the original index. Otherwise keep it as a
            column (via `reset_index`).

    Returns:
        The first `n` rows followed by the last `n` rows, re-indexed 0..k-1.
        When the two ends would overlap (`2 * n >= len(df)`), every row is
        returned exactly once in the original order, with no duplicates.
    """
    if 2 * n >= len(df):
        # head(n) + tail(n) already covers the whole frame. Returning it directly
        # avoids duplicated rows, which would, e.g., break the unique-category
        # requirement of `pd.Categorical(..., categories=...)` in callers.
        selected = df
    else:
        selected = pd.concat([df.head(n), df.tail(n)])
    return selected.reset_index(drop=drop_idx)


_n_top = 20

# The `_n_top` lowest and highest sgRNA effects. sgRNA becomes an ordered
# categorical so the x-axis follows the sorted order.
_a_sorted = a_posterior.sort_values("mean")
a_min_max = min_max(_a_sorted, n=_n_top)
a_min_max = a_min_max.assign(
    sgrna=lambda d: pd.Categorical(
        d.sgrna.values, categories=d.sgrna.values, ordered=True
    )
)
a_min_max.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
parameter mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat parameter_name sgrna hugo_symbol
0 a[348] -3.567 0.145 -3.790 -3.337 0.002 0.001 5820.0 2632.0 1.0 a GTTGACGACAAGGGCGATG CRHR2
1 a[9682] -3.536 0.149 -3.758 -3.285 0.002 0.001 5308.0 3123.0 1.0 a GATGGAAGTGGAATCGCCC UBP1
2 a[16025] -3.531 0.142 -3.772 -3.311 0.002 0.001 5369.0 2937.0 1.0 a CGACACCACTACCACCCAT ASTL
3 a[349] -3.428 0.143 -3.645 -3.193 0.002 0.001 5352.0 3065.0 1.0 a ATATCGTTCACCCTAAACTT PSAT1
4 a[6869] -3.379 0.137 -3.585 -3.154 0.002 0.001 5291.0 2788.0 1.0 a CGCCCGCGACAGAAAAGAC ANKRD18A
# Gene labels go below the lowest `_n_top` points and above the rest. The flags
# are derived from the actual row count so they always line up with the data.
_is_low = [i < _n_top for i in range(len(a_min_max))]
_nudge_y = [-0.2 if low else 0.2 for low in _is_low]
_va = ["top" if low else "bottom" for low in _is_low]
(
    gg.ggplot(a_min_max, gg.aes(x="sgrna", y="mean"))
    + gg.geom_hline(yintercept=0, alpha=0.2)
    # Stems run from 0 to the posterior mean (no HDI is drawn in this plot).
    + gg.geom_linerange(gg.aes(ymin=0, ymax="mean"), alpha=0.2)
    + gg.geom_point(size=1)
    + gg.geom_text(
        gg.aes(label="hugo_symbol"),
        size=7,
        angle=90,
        nudge_y=_nudge_y,
        va=_va,
        fontstyle="italic",
    )
    + gg.scale_y_continuous(expand=(0, 1.0, 0, 0.9))
    + gg.theme(figure_size=(8, 4), axis_text_x=gg.element_text(angle=90, size=6))
    # The label matches what is shown: posterior means only, no HDI.
    + gg.labs(x="sgRNA", y="$a$ posterior mean")
)

png

<ggplot: (355553437)>
_n_top = 20

# Per-gene median of the sgRNA-level posterior means of `a`.
# NOTE: the column keeps the name "mean" (and is renamed "gene_mean" below),
# but it holds a median.
_gene_medians = a_posterior.groupby("hugo_symbol")["mean"].median()
a_post_gene_avg = (
    _gene_medians.reset_index(drop=False).sort_values("mean").reset_index(drop=True)
)

# Genes with the lowest and highest medians. Their order fixes the x-axis order.
a_post_gene_avg_minmax = min_max(a_post_gene_avg, n=_n_top)
_gene_order = a_post_gene_avg_minmax.hugo_symbol.values.astype(str)

# Re-attach every sgRNA-level posterior row for those genes.
_renamed = a_post_gene_avg_minmax.rename(columns={"mean": "gene_mean"})
a_post_gene_avg_minmax = _renamed.merge(a_posterior, on="hugo_symbol", how="left")
a_post_gene_avg_minmax = a_post_gene_avg_minmax.assign(
    hugo_symbol=lambda d: pd.Categorical(
        d.hugo_symbol.values, categories=_gene_order, ordered=True
    )
)

a_post_gene_avg_minmax.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
hugo_symbol gene_mean parameter mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat parameter_name sgrna
0 CATSPERE -0.5835 a[10087] 0.121 0.057 0.029 0.213 0.001 0.001 2012.0 2893.0 1.0 a TCATCACTCAGAATGTCTGG
1 CATSPERE -0.5835 a[10660] 0.093 0.098 -0.052 0.258 0.001 0.001 4866.0 2890.0 1.0 a TTACCAATCTCCTCACCACG
2 CATSPERE -0.5835 a[13265] -1.646 0.164 -1.925 -1.401 0.002 0.001 7192.0 2734.0 1.0 a GCCATTAATTGACTACCACG
3 CATSPERE -0.5835 a[16733] -1.260 0.149 -1.501 -1.032 0.002 0.001 5925.0 2401.0 1.0 a AAAACACAGCAATCTCCAGA
4 PAGE2B -0.5125 a[7061] -0.604 0.062 -0.706 -0.510 0.001 0.001 2308.0 2849.0 1.0 a TCCCTTCACCTTGAACGGC
# Per-gene distribution of sgRNA-level posterior means of `a` (boxplot plus
# jittered points). The blue triangle is the gene's median ("gene_mean").
(
    gg.ggplot(a_post_gene_avg_minmax, gg.aes(x="hugo_symbol"))
    + gg.geom_boxplot(gg.aes(y="mean"), outlier_alpha=0, color="gray")
    + gg.geom_point(gg.aes(y="gene_mean"), shape="^", color="blue")
    + gg.geom_jitter(gg.aes(y="mean"), width=0.2, height=0, size=0.1)
    + gg.theme(
        figure_size=(8, 4), axis_text_x=gg.element_text(angle=90, size=7, hjust=1)
    )
    # The y label describes what is drawn: per-sgRNA posterior means, no HDI.
    + gg.labs(
        x="gene (lowest and highest median $a$ posteriors)",
        y="$a$ posterior mean\n(per sgRNA)",
    )
)

png

<ggplot: (353525348)>

$d$: gene $\times$ lineage

# Gene-by-lineage effects d[g, l].
d_posterior = posterior_summary.query("parameter_name == 'd'").reset_index(drop=True)


def _bracket_indices(label: str) -> list[str]:
    """Comma-separated tokens inside the brackets of a label like "d[0,1]"."""
    after_open = label.replace("]", "").split("[")[1]
    return after_open.split(",")


# Column 0 indexes genes, column 1 indexes lineages.
_idx = np.asarray([_bracket_indices(lbl) for lbl in d_posterior.parameter], dtype=int)
_genes, _lineages = data_coords["gene"], data_coords["lineage"]
d_posterior["hugo_symbol"] = [_genes[g] for g in _idx[:, 0]]
d_posterior["lineage"] = [_lineages[lin] for lin in _idx[:, 1]]


d_posterior.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
parameter mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat parameter_name hugo_symbol lineage
0 d[0,0] -0.002 0.038 -0.059 0.060 0.001 0.001 973.0 2159.0 1.00 d ALAD bone
1 d[0,1] 0.025 0.037 -0.035 0.083 0.001 0.001 911.0 1550.0 1.00 d ALAD colorectal
2 d[1,0] -0.001 0.038 -0.059 0.062 0.001 0.001 1319.0 2430.0 1.01 d C14orf178 bone
3 d[1,1] 0.006 0.038 -0.058 0.065 0.001 0.001 1263.0 2169.0 1.01 d C14orf178 colorectal
4 d[2,0] -0.012 0.041 -0.073 0.057 0.001 0.001 2434.0 2330.0 1.00 d ERAP1 bone
# Heatmap of the posterior mean of d for every gene (x) and lineage (y). Gene
# labels are hidden because there are too many to read.
_d_heatmap = (
    gg.ggplot(d_posterior, gg.aes(x="hugo_symbol", y="lineage"))
    + gg.geom_tile(gg.aes(fill="mean"), color=None)
    + gg.scale_y_discrete(expand=(0, 0.5))
    + gg.scale_fill_gradient2(low="#3A4CC0", high="#B30326")
    + gg.theme(
        figure_size=(8, 1),
        axis_text_x=gg.element_blank(),
        legend_position=(0.2, -0.4),
        legend_direction="horizontal",
        legend_key_width=10,
        legend_background=gg.element_blank(),
        legend_text=gg.element_text(angle=90, size=7, va="bottom"),
        panel_background=gg.element_blank(),
        panel_border=gg.element_blank(),
    )
    + gg.labs(x="gene", y="lineage", fill="$d$ MAP")
)
_d_heatmap

png

<ggplot: (355640381)>
# Genes appearing among the 30 lowest or 30 highest d values (any lineage).
_d_sorted = d_posterior.sort_values("mean")
top_d_genes = min_max(_d_sorted, n=30).hugo_symbol.unique()
d_posterior_top = d_posterior.filter_column_isin("hugo_symbol", top_d_genes)

# Heatmap of d restricted to those genes, now with readable gene labels.
_d_top_heatmap = (
    gg.ggplot(d_posterior_top, gg.aes(x="hugo_symbol", y="lineage"))
    + gg.geom_tile(gg.aes(fill="mean"), color=None)
    + gg.scale_y_discrete(expand=(0, 0.5))
    + gg.scale_fill_gradient2(low="#3A4CC0", high="#B30326")
    + gg.theme(
        figure_size=(8, 1),
        axis_text_x=gg.element_text(angle=90, hjust=1, size=7),
        panel_background=gg.element_blank(),
        panel_border=gg.element_blank(),
    )
    + gg.labs(x="gene", y="lineage", fill="$d$ MAP")
)
_d_top_heatmap

png

<ggplot: (356083219)>
# Posterior mean of sigma_d; drawn as dashed reference lines at +/- sigma_d.
# Looked up by `parameter_name`, like the other hyperparameters in this notebook.
sigma_d_map = posterior_summary.query("parameter_name == 'sigma_d'")["mean"].values[0]

# Per-gene difference in d between the bone and colorectal lineages.
d_posterior_diff = (
    d_posterior[["mean", "hugo_symbol", "lineage"]]
    .pivot_wider(index="hugo_symbol", names_from="lineage", values_from="mean")
    .assign(diff=lambda d: d.bone - d.colorectal)
)

# The 20 genes with the most negative and 20 with the most positive difference,
# in that order.
most_diff_d = (
    d_posterior_diff.sort_values("diff").pipe(min_max, n=20).hugo_symbol.values
)

plot_df = d_posterior.filter_column_isin("hugo_symbol", most_diff_d).assign(
    hugo_symbol=lambda d: pd.Categorical(
        d.hugo_symbol, categories=most_diff_d, ordered=True
    )
)

(
    gg.ggplot(plot_df, gg.aes(x="hugo_symbol", y="mean", color="lineage"))
    + gg.geom_hline(yintercept=0, color="gray")
    + gg.geom_hline(yintercept=[-sigma_d_map, sigma_d_map], linetype="--", color="gray")
    + gg.geom_linerange(gg.aes(ymin="hdi_5.5%", ymax="hdi_94.5%"))
    + gg.geom_point()
    + gg.scale_color_manual(values=lineage_pal)
    + gg.theme(axis_text_x=gg.element_text(angle=90, size=8), figure_size=(8, 4))
    + gg.labs(x="gene", y="$d$ posterior\nmean ± 89% HDI")
)

png

<ggplot: (356746999)>
# Observed log-fold changes for the genes with the largest lineage difference
# in d, in the same gene order as the d plot.
plot_df = counts_data.filter_column_isin("hugo_symbol", most_diff_d).assign(
    hugo_symbol=lambda d: pd.Categorical(
        d.hugo_symbol, categories=most_diff_d, ordered=True
    )
)
(
    gg.ggplot(plot_df, gg.aes(x="hugo_symbol", y="lfc"))
    + gg.geom_boxplot(gg.aes(color="lineage"), outlier_alpha=0)
    + gg.scale_color_manual(values=lineage_pal)
    # Zoom with the coordinate system instead of scale limits. Scale limits
    # remove out-of-range rows *before* boxplot statistics are computed,
    # which changes the quartiles and whiskers.
    + gg.coord_cartesian(ylim=(-3, 1.5))
    + gg.theme(axis_text_x=gg.element_text(angle=90, size=8), figure_size=(8, 4))
    + gg.labs(x="gene", y="log-fold change")
)
/usr/local/Caskroom/miniconda/base/envs/speclet/lib/python3.9/site-packages/plotnine/layer.py:324: PlotnineWarning: stat_boxplot : Removed 78 rows containing non-finite values.

png

<ggplot: (357135613)>

$\alpha$: gene dispersion

posterior_summary[posterior_summary["parameter_name"].str.contains("_alpha")]
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
parameter mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat parameter_name
21045 alpha_alpha 3.680 0.110 3.512 3.866 0.002 0.001 3324.0 3128.0 1.0 alpha_alpha
21046 beta_alpha 0.529 0.017 0.501 0.555 0.000 0.000 3341.0 2910.0 1.0 beta_alpha
# Population Gamma distribution over gene dispersions, evaluated at the
# posterior means of its hyperparameters. scipy's gamma takes shape and scale,
# so beta_alpha is passed as a rate via scale = 1 / beta.
_param_names = posterior_summary["parameter_name"]
alpha = posterior_summary.loc[_param_names == "alpha_alpha", "mean"].values[0]
beta = posterior_summary.loc[_param_names == "beta_alpha", "mean"].values[0]
x = np.linspace(0, 20, 500)
pdf = st.gamma.pdf(x, alpha, scale=1.0 / beta)
alpha_parent_dist = pd.DataFrame({"x": x, "pdf": pdf})

# Per-gene dispersions alpha[g], labeled positionally with data_coords["gene"].
_alpha_rows = posterior_summary.query("parameter_name == 'alpha'")
alpha_posterior = _alpha_rows.reset_index(drop=True).assign(
    hugo_symbol=data_coords["gene"]
)
alpha_posterior.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
parameter mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat parameter_name hugo_symbol
0 alpha[0] 11.514 0.761 10.304 12.702 0.009 0.006 8002.0 2752.0 1.0 alpha ALAD
1 alpha[1] 9.574 0.647 8.579 10.636 0.008 0.006 6600.0 2627.0 1.0 alpha C14orf178
2 alpha[2] 4.555 0.312 4.098 5.082 0.004 0.003 6702.0 2841.0 1.0 alpha ERAP1
3 alpha[3] 8.987 0.592 8.066 9.936 0.007 0.005 6803.0 2756.0 1.0 alpha PARD3B
4 alpha[4] 9.101 0.582 8.174 10.002 0.007 0.005 6317.0 2917.0 1.0 alpha BHLHB9
# Density of per-gene dispersion posterior means (black) overlaid with the
# fitted population Gamma density (blue).
_alpha_plot = (
    gg.ggplot(alpha_posterior, gg.aes(x="mean"))
    + gg.geom_density()
    + gg.geom_line(gg.aes(x="x", y="pdf"), data=alpha_parent_dist, color="blue")
    + gg.scale_x_continuous(expand=(0, 0))
    + gg.scale_y_continuous(expand=(0, 0, 0.02, 0))
    + gg.labs(x="$\\alpha$ MAP")
)
_alpha_plot

png

<ggplot: (357467862)>

Posterior predictions

def read_posterior_pred(fpath: Path) -> pd.DataFrame:
    """Load posterior-predictive draws from a CSV file.

    The ``ct_final_dim_0`` column is renamed to ``data_idx`` so the draws can be
    matched to rows of the observed data.
    """
    draws = pd.read_csv(fpath)
    return draws.rename(columns={"ct_final_dim_0": "data_idx"})
# Posterior-predictive draws of the final counts for the default-initialized fit.
_ppc_path = hnb_model_dir / "posterior-predictions.csv"
posterior_pred = read_posterior_pred(_ppc_path)
posterior_pred.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
chain draw data_idx ct_final
0 0 0 0 275
1 0 0 1 943
2 0 0 2 1474
3 0 0 3 65
4 0 0 4 615
# Long-format frame stacking posterior-predictive draws and observed final
# counts. Both parts carry `data_idx`, so filtering by sampled data points
# (below) keeps the observed rows. Without it, observed rows had NaN
# `data_idx` and were silently dropped by that filter.
ppc_df = (
    pd.concat(
        [
            posterior_pred[["ct_final", "data_idx"]].assign(data="post. pred."),
            counts_data[["counts_final", "data_idx"]]
            .rename(columns={"counts_final": "ct_final"})
            .assign(data="observed"),
        ]
    )
    .astype({"ct_final": float})
    .assign(
        data=lambda d: pd.Categorical(
            d["data"], categories=["post. pred.", "observed"], ordered=True
        )
    )
)
ppc_df.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
ct_final data_idx data
0 275.0 0.0 post. pred.
1 943.0 1.0 post. pred.
2 1474.0 2.0 post. pred.
3 65.0 3.0 post. pred.
4 615.0 4.0 post. pred.
# Number of data points. `data_idx` is the 0-based position from
# `ct_final_dim_0`, so the count is max + 1; using the max itself made the last
# data point impossible to sample. (Assumes the indices are contiguous.)
n_data_pts = int(posterior_pred["data_idx"].max()) + 1
# Subsample 5% of data points, capped at 10,000, to keep the plots light.
_sample_size = min(10_000, round(n_data_pts * 0.05))
print(f"using {_sample_size:,d} randomly sampled data points")
_data_idx = np.random.choice(np.arange(0, n_data_pts), size=_sample_size, replace=False)

_pal = {"post. pred.": "k", "observed": "b"}

# Overlaid histograms (each scaled to a max of 1) of predicted vs. observed
# final counts for the sampled data points, restricted to counts <= 2000.
(
    gg.ggplot(
        ppc_df.filter_column_isin("data_idx", _data_idx).query("ct_final <= 2000"),
        gg.aes(x="ct_final", fill="data", color="data"),
    )
    + gg.geom_histogram(
        gg.aes(y=gg.after_stat("ncount")),
        position="identity",
        binwidth=50,
        size=0.5,
        alpha=0.1,
    )
    + gg.scale_x_continuous(expand=(0, 0))
    + gg.scale_y_continuous(expand=(0, 0, 0.02, 0))
    + gg.scale_fill_manual(values=_pal)
    + gg.scale_color_manual(values=_pal)
    + gg.theme(
        legend_position=(0.8, 0.45),
        legend_background=gg.element_rect(alpha=0.5),
        figure_size=(6, 4),
    )
)
using 10,000 randomly sampled data points

png

<ggplot: (357527540)>
tuple(frame.shape for frame in (posterior_pred, counts_data))
((87419400, 4), (874194, 29))
def _summarize_ppc(df: pd.DataFrame, hdi_prob: float = HDI_PROB) -> pd.DataFrame:
    """Summarize the posterior-predictive draws of a single data point.

    Args:
        df: Draws for one data point. Must contain a ``ct_final`` column.
        hdi_prob: Probability mass of the highest density interval. Defaults to
            the project-wide ``HDI_PROB`` rather than ArviZ's own default, so the
            interval is consistent with the project configuration.

    Returns:
        A one-row data frame with columns ``mean``, ``median``, ``hdi_low`` and
        ``hdi_high``.
    """
    draws = df["ct_final"]
    hdi = az.hdi(draws.to_numpy().flatten(), hdi_prob=hdi_prob).flatten()
    return pd.DataFrame(
        {
            "mean": draws.mean(),
            "median": draws.median(),
            "hdi_low": hdi[0],
            "hdi_high": hdi[1],
        },
        index=[0],
    )


# Per-data-point PPC summary for the sampled points, joined to observed data.
ppc_summary = (
    posterior_pred.filter_column_isin("data_idx", _data_idx)
    .groupby("data_idx")
    .apply(_summarize_ppc)
    .reset_index(drop=False)
    .drop(columns="level_1")
    .merge(counts_data, on="data_idx", how="left")
    .assign(
        error=lambda d: d["counts_final"] - d["median"],
        # Percent error relative to the observed count (T - P) / T. The +1 goes
        # only in the denominator to avoid division by zero. Adding it to the
        # numerator too made a perfect prediction score 100 / (1 + T), not 0.
        pct_error=lambda d: 100
        * (d["counts_final"] - d["median"])
        / (1 + d["counts_final"]),
    )
)
ppc_summary.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
data_idx mean median hdi_low hdi_high sgrna replicate_id lfc p_dna_batch genome_alignment ... lineage primary_or_metastasis is_male age counts_final_total counts_initial_total counts_final_rpm counts_initial_adj error pct_error
0 144 470.17 431.5 109 915 AAGCATCTTGGGAGACAGCG LS513-311Cas9_RepA_p6_batch2 -0.218908 2 chr17_59697794_- ... colorectal primary True 63.0 35176093 1.072163e+06 12.996784 430.314884 -9.5 -2.009456
1 185 172.27 156.0 46 322 AAGTGGTGCTGGAAAAACAG LS513-311Cas9_RepA_p6_batch2 -0.750700 2 chr15_59236649_- ... colorectal primary True 63.0 35176093 1.072163e+06 2.904703 146.623744 -89.0 -129.411765
2 215 323.33 303.0 170 612 AATCCAGGCGATGTCAGCCA LS513-311Cas9_RepA_p6_batch2 0.018910 2 chr1_25826359_- ... colorectal primary True 63.0 35176093 1.072163e+06 9.727518 273.517832 4.0 1.623377
3 351 229.93 216.5 99 354 ACAGAAGTACATGACCGCCG LS513-311Cas9_RepA_p6_batch2 0.945948 2 chr16_88646818_- ... colorectal primary True 63.0 35176093 1.072163e+06 16.550334 245.670676 330.5 60.492701
4 527 1188.48 1154.0 523 1961 ACTATGTTCCAATTCTTCAG LS513-311Cas9_RepA_p6_batch2 0.511113 2 chr2_33588353_- ... colorectal primary True 63.0 35176093 1.072163e+06 47.025578 941.679979 465.0 28.765432

5 rows × 35 columns

# Observed final count vs. posterior-predictive median (with HDI bars). The
# dashed blue line marks perfect agreement.
_ppc_scatter = (
    gg.ggplot(ppc_summary, gg.aes(x="counts_final", y="median"))
    + gg.geom_linerange(
        gg.aes(ymin="hdi_low", ymax="hdi_high"), alpha=0.5, size=0.5, color="gray"
    )
    + gg.geom_point(size=0.5, alpha=0.4, color="black")
    + gg.geom_abline(slope=1, intercept=0, color="blue", alpha=0.6, linetype="--")
    + gg.scale_x_continuous(expand=(0, 0, 0.02, 0))
    + gg.scale_y_continuous(expand=(0, 0, 0.02, 0))
)
_ppc_scatter

png

<ggplot: (355831330)>
# Distribution of the percent prediction error across sampled data points.
_pct_error_density = (
    gg.ggplot(ppc_summary, gg.aes(x="pct_error"))
    + gg.geom_density(alpha=0.2)
    + gg.scale_x_continuous(expand=(0, 0))
    + gg.scale_y_continuous(expand=(0, 0, 0.02, 0))
    + gg.theme(figure_size=(6, 4))
)
_pct_error_density

png

<ggplot: (354679251)>
# Percent error vs. observed final count on a log10 x-axis. Zero counts have
# no position on a log scale (log10(0) = -inf). Drop them explicitly instead of
# relying on the divide-by-zero warning and the plotting library's silent drop.
(
    gg.ggplot(ppc_summary.query("counts_final > 0"), gg.aes(x="counts_final", y="pct_error"))
    + gg.geom_point(size=0.3, alpha=0.3)
    + gg.scale_x_log10(expand=(0.02, 0))
    + gg.scale_y_continuous(expand=(0.02, 0))
    + gg.theme(figure_size=(6, 4))
    + gg.labs(x="final counts (observed)", y="percent error ($\\frac{T-P}{T}$)")
)
/usr/local/Caskroom/miniconda/base/envs/speclet/lib/python3.9/site-packages/pandas/core/arraylike.py:364: RuntimeWarning: divide by zero encountered in log10

png

<ggplot: (353797681)>
# Percentage of zero counts in the observed data vs. all posterior-predictive
# draws. The predicted rate uses `posterior_pred` directly: `ppc_df` also
# contains the observed rows, which would contaminate the "predicted" figure.
obs_zeros = np.mean(counts_data["counts_final"] == 0) * 100
pred_zeros = np.mean(posterior_pred["ct_final"] == 0) * 100
print("percent of zeros:")
print(f"   observed: {obs_zeros:0.3f}%")
print(f"  predicted: {pred_zeros:0.3f}%")
percent of zeros:
   observed: 1.017%
  predicted: 0.072%

Comparing "jitter+adapt_diag" and "advi" chain initialization

# Posterior summary of the fit whose chains were initialized with ADVI.
advi_posterior_dir = models_dir() / "hierarchical-nb_advi-init_PYMC3_MCMC"
_advi_summary_path = advi_posterior_dir / "posterior-summary.csv"
advi_post_summ = read_posterior_summary(_advi_summary_path)
advi_post_summ.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
parameter mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat parameter_name
0 z 0.040 0.006 0.030 0.050 0.001 0.001 23.0 69.0 1.16 z
1 a[0] 0.314 0.085 0.179 0.453 0.001 0.001 7174.0 2973.0 1.00 a
2 a[1] -0.002 0.059 -0.101 0.088 0.001 0.001 3795.0 3032.0 1.00 a
3 a[2] 0.177 0.062 0.082 0.280 0.001 0.001 3750.0 2886.0 1.00 a
4 a[3] 0.199 0.064 0.098 0.298 0.001 0.001 4707.0 3159.0 1.00 a
# Stack both posterior summaries, tagging each row with its initialization method.
_summaries_by_init = {
    "jitter+adapt_diag": posterior_summary,
    "advi": advi_post_summ,
}
combined_post = pd.concat(
    [summ.assign(init_method=method) for method, summ in _summaries_by_init.items()]
)
combined_post.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
parameter mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat parameter_name init_method
0 z 0.042 0.005 0.035 0.050 0.000 0.000 142.0 404.0 1.02 z jitter+adapt_diag
1 a[0] 0.314 0.087 0.165 0.439 0.001 0.001 4659.0 2956.0 1.00 a jitter+adapt_diag
2 a[1] -0.004 0.060 -0.106 0.085 0.001 0.001 2523.0 3087.0 1.00 a jitter+adapt_diag
3 a[2] 0.179 0.060 0.087 0.279 0.001 0.001 2254.0 2605.0 1.00 a jitter+adapt_diag
4 a[3] 0.196 0.065 0.097 0.303 0.001 0.001 3222.0 2927.0 1.00 a jitter+adapt_diag

Sampling diagnostics

R-hat

# R-hat per parameter group, split by chain-initialization method.
# Raw string label avoids the invalid "\w" escape sequence.
(
    gg.ggplot(combined_post, gg.aes(x="parameter_name", y="r_hat", color="init_method"))
    + gg.geom_boxplot(outlier_size=0.6, outlier_alpha=0.5)
    + gg.theme(axis_text_x=gg.element_text(angle=35, hjust=1))
    + gg.labs(x="parameter", y=r"$\widehat{R}$", color="init. method")
)

png

<ggplot: (353879342)>

ESS

# Bulk ESS per parameter group for both initialization methods. Points are
# jittered and dodged so each method's cloud sits beside its boxplot.
_dodge = gg.position_jitterdodge(jitter_width=0.3, jitter_height=0, dodge_width=0.5)
_combined_ess_bulk_plot = (
    gg.ggplot(
        combined_post,
        gg.aes(x="parameter_name", y="ess_bulk", color="init_method"),
    )
    + gg.geom_point(position=_dodge, alpha=0.1, size=0.2)
    + gg.geom_boxplot(outlier_alpha=0, alpha=0.5)
    + gg.theme(axis_text_x=gg.element_text(angle=35, hjust=1))
    + gg.labs(x="parameter", y="ESS (bulk)", color="init. method")
)
_combined_ess_bulk_plot

png

<ggplot: (343868444)>
# Tail ESS per parameter group for both initialization methods. The y label
# says "tail" to match the plotted `ess_tail` column.
(
    gg.ggplot(
        combined_post,
        gg.aes(x="parameter_name", y="ess_tail", color="init_method"),
    )
    + gg.geom_point(
        position=gg.position_jitterdodge(
            jitter_width=0.3, jitter_height=0, dodge_width=0.5
        ),
        alpha=0.1,
        size=0.2,
    )
    + gg.geom_boxplot(outlier_alpha=0, alpha=0.5)
    + gg.theme(axis_text_x=gg.element_text(angle=35, hjust=1))
    + gg.labs(x="parameter", y="ESS (tail)", color="init. method")
)

png

<ggplot: (357493827)>

Posterior estimates

# Parameter groups compared between the two initialization methods.
keep_cols = ["a", "b", "d", "alpha"]
plot_df = combined_post.filter_column_isin("parameter_name", keep_cols)

# Densities of posterior means, one facet per parameter group.
_mean_density_plot = (
    gg.ggplot(plot_df, gg.aes(x="mean", color="init_method", fill="init_method"))
    + gg.facet_wrap("~parameter_name", nrow=2, scales="free")
    + gg.geom_density(alpha=0.5)
    + gg.scale_x_continuous(expand=(0, 0))
    + gg.scale_y_continuous(expand=(0, 0, 0.02, 0))
    + gg.theme(
        figure_size=(8, 6),
        subplots_adjust={"hspace": 0.25, "wspace": 0.25},
        strip_text=gg.element_text(weight="bold"),
    )
    + gg.labs(x="MAP", y="density", color="init. method", fill="init. method")
)
_mean_density_plot

png

<ggplot: (356429346)>
plot_df = combined_post.filter_column_isin("parameter_name", keep_cols)

# Densities of posterior standard deviations, one facet per parameter group.
_sd_density_plot = (
    gg.ggplot(plot_df, gg.aes(x="sd", color="init_method", fill="init_method"))
    + gg.facet_wrap("~parameter_name", nrow=2, scales="free")
    + gg.geom_density(alpha=0.5)
    + gg.scale_x_continuous(expand=(0, 0))
    + gg.scale_y_continuous(expand=(0, 0, 0.02, 0))
    + gg.theme(
        figure_size=(8, 6),
        subplots_adjust={"hspace": 0.25, "wspace": 0.25},
        strip_text=gg.element_text(weight="bold"),
    )
    + gg.labs(
        x="posterior std. dev.", y="density", color="init. method", fill="init. method"
    )
)
_sd_density_plot

png

<ggplot: (354963658)>
# Cell-line effects b from both fits. `idx` is the integer index parsed from
# labels like "b[12]" and is used as the (unlabeled) x-axis.
plot_df = combined_post.query("parameter_name == 'b'").reset_index(drop=True)
plot_df = plot_df.assign(
    idx=[int(re.findall("[0-9]+", label)[0]) for label in plot_df["parameter"]]
)

_b_compare_plot = (
    gg.ggplot(plot_df, gg.aes(x="factor(idx)", y="mean", color="init_method"))
    + gg.geom_hline(yintercept=0, color="gray")
    + gg.geom_linerange(gg.aes(ymin="hdi_5.5%", ymax="hdi_94.5%"), alpha=0.5, size=1)
    + gg.geom_point(size=2, alpha=0.5)
    + gg.scale_y_continuous(expand=(0.02, 0))
    + gg.scale_color_brewer(type="qual", palette="Dark2")
    + gg.theme(
        figure_size=(8, 6),
        axis_text_x=gg.element_blank(),
        panel_grid_major_x=gg.element_blank(),
    )
)
_b_compare_plot

png

<ggplot: (353776939)>

Posterior predictions

# Posterior-predictive draws from the ADVI-initialized fit.
_advi_ppc_path = advi_posterior_dir / "posterior-predictions.csv"
advi_post_pred = read_posterior_pred(_advi_ppc_path)
advi_post_pred.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
chain draw data_idx ct_final
0 0 0 0 211
1 0 0 1 900
2 0 0 2 1034
3 0 0 3 97
4 0 0 4 1010
# Stack posterior-predictive draws from both fits, tagged by initialization method.
_ppc_by_init = {
    "jitter+adapt_diag": posterior_pred,
    "advi": advi_post_pred,
}
combined_ppc = pd.concat(
    [draws.assign(init_method=method) for method, draws in _ppc_by_init.items()]
)
combined_ppc.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
chain draw data_idx ct_final init_method
0 0 0 0 275 jitter+adapt_diag
1 0 0 1 943 jitter+adapt_diag
2 0 0 2 1474 jitter+adapt_diag
3 0 0 3 65 jitter+adapt_diag
4 0 0 4 615 jitter+adapt_diag
# Median predicted final count per sampled data point, one column per
# initialization method.
_ppc_subset = combined_ppc.filter_column_isin("data_idx", _data_idx)
_ppc_medians = (
    _ppc_subset.groupby(["data_idx", "init_method"])["ct_final"]
    .median()
    .reset_index(drop=False)
)
combined_ppc_avg = _ppc_medians.pivot_wider(
    index="data_idx", names_from="init_method", values_from="ct_final"
)
combined_ppc_avg.head()
<style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; }
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
</style>
data_idx advi jitter+adapt_diag
0 144 478.0 431.5
1 185 148.0 156.0
2 215 316.0 303.0
3 351 230.5 216.5
4 527 1027.0 1154.0
# Median predictions from the two initialization methods against each other.
# The dashed line marks identical predictions.
_init_scatter = (
    gg.ggplot(combined_ppc_avg, gg.aes(x="advi", y="jitter+adapt_diag"))
    + gg.geom_point(alpha=0.5, size=0.5)
    + gg.geom_abline(slope=1, intercept=0, alpha=1, linetype="--", color="blue")
    + gg.scale_x_continuous(expand=(0, 0, 0.02, 0))
    + gg.scale_y_continuous(expand=(0, 0, 0.02, 0))
)
_init_scatter

png

<ggplot: (353749520)>

Watermark

# Report total notebook wall-clock time in minutes.
notebook_toc = time()
_elapsed_minutes = (notebook_toc - notebook_tic) / 60
print(f"execution time: {_elapsed_minutes:.2f} minutes")
execution time: 6.04 minutes
# Print a reproducibility record: date, Python/IPython versions, versions of
# imported packages, git branch, hostname, and machine details.
%load_ext watermark
%watermark -d -u -v -iv -b -h -m
Last updated: 2022-02-11

Python implementation: CPython
Python version       : 3.9.9
IPython version      : 8.0.0

Compiler    : Clang 11.1.0
OS          : Darwin
Release     : 21.2.0
Machine     : x86_64
Processor   : i386
CPU cores   : 4
Architecture: 64bit

Hostname: JHCookMac

Git branch: theano-blas-warning

matplotlib: 3.5.1
seaborn   : 0.11.2
pandas    : 1.3.5
plotnine  : 0.8.0
arviz     : 0.11.4
scipy     : 1.7.3
pymc3     : 3.11.4
numpy     : 1.22.0
re        : 2.2.1