Skip to content

Commit

Permalink
Merge branch 'main' into safe-prior-strategy
Browse files Browse the repository at this point in the history
digicosmos86 committed Dec 15, 2023
2 parents 0d4c507 + 186b5f6 commit 5061f44
Showing 5 changed files with 234 additions and 37 deletions.
141 changes: 112 additions & 29 deletions docs/tutorials/plotting.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -40,12 +40,12 @@ pre-commit = "^2.20.0"
jupyterlab = "^4.0.2"
ipykernel = "^6.16.0"
ipywidgets = "^8.0.3"
graphviz = "^0.20.1"
ruff = "^0.1.3"
mkdocs = "^1.4.3"
mkdocs-material = "^9.1.17"
mkdocstrings-python = "^1.1.2"
mkdocs-jupyter = "^0.24.1"
graphviz = "^0.20.1"

[tool.black]
line-length = 88
14 changes: 14 additions & 0 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
@@ -523,6 +523,20 @@ def plot_posterior_predictive(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid:
"""
return plotting.plot_posterior_predictive(self, **kwargs)

def plot_quantile_probability(self, **kwargs) -> mpl.axes.Axes | sns.FacetGrid:
"""Produce a quantile probability plot.
Equivalent to calling `hssm.plotting.plot_quantile_probability()` with the
model. Please see that function for
[full documentation][hssm.plotting.plot_quantile_probability].
Returns
-------
mpl.axes.Axes | sns.FacetGrid
The matplotlib axis or seaborn FacetGrid object containing the plot.
"""
return plotting.plot_quantile_probability(self, **kwargs)

def sample_prior_predictive(
self,
draws: int = 500,
8 changes: 5 additions & 3 deletions src/hssm/plotting/quantile_probability.py
Original file line number Diff line number Diff line change
@@ -94,7 +94,6 @@ def _plot_quantile_probability_1D(
)
xticks = ticks_and_labels[x]
xticklabels = xticklabels or ticks_and_labels[cond]
print(xticks, xticklabels)
secax = ax.twiny()
secax.set_xticks(xticks)
secax.set_xticklabels(xticklabels)
@@ -189,7 +188,7 @@ def _plot_quantile_probability_2D(
def plot_quantile_probability(
model,
cond: str,
data: pd.DataFrame,
data: pd.DataFrame | None = None,
idata: az.InferenceData | None = None,
n_samples: int = 20,
x: str = "proportion",
@@ -220,7 +219,7 @@ def plot_quantile_probability(
A model object that has a `plot_quantile_probability` method.
cond
The column in `data` that indicates the conditions.
data
data : optional
A pandas DataFrame containing the observed data. If None, the data from
`idata.observed_data` will be used.
idata : optional
@@ -302,6 +301,9 @@ def plot_quantile_probability(
mpl.axes.Axes | sns.FacetGrid | list[sns.FacetGrid]
A seaborn FacetGrid object containing the plot.
"""
if data is None:
data = model.data

groups, groups_order = _check_groups_and_groups_order(
groups, groups_order, row, col
)
106 changes: 102 additions & 4 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,11 @@
_plot_posterior_predictive_2D,
plot_posterior_predictive,
)
from hssm.plotting.quantile_probability import (
_plot_quantile_probability_1D,
_plot_quantile_probability_2D,
plot_quantile_probability,
)

hssm.set_floatX("float32")

@@ -228,8 +233,101 @@ def test_plot_posterior_predictive(cav_idata, cavanagh_test):


def test__process_df_for_qp_plot(cav_idata, cavanagh_test):
df = _get_plotting_df(cav_idata, cavanagh_test, extra_dims=["participant_id"])
n_chain = cav_idata.posterior_predictive["rt,response"].chain.size
n_draw = cav_idata.posterior_predictive["rt,response"].draw.size
df = _get_plotting_df(
cav_idata, cavanagh_test, extra_dims=["participant_id", "conf"]
)

processed_df = _process_df_for_qp_plot(df, 6, "conf", None)

assert "conf" in processed_df.columns
assert "is_correct" in processed_df.columns
assert processed_df["quantile"].nunique() == 4
assert np.all(
processed_df.groupby(["observed", "chain", "draw", "conf", "quantile"])[
"proportion"
].sum()
== 1
)


def has_twin(ax):
"""Checks if an axes has a twin axes with the same bounds.
Credit: https://stackoverflow.com/questions/36209575/how-to-detect-if-a-twin-axis-has-been-generated-for-a-matplotlib-axis
"""
for other_ax in ax.figure.axes:
if other_ax is ax:
continue
if other_ax.bbox.bounds == ax.bbox.bounds:
return True
return False


def test__plot_quantile_probability_1D(cav_idata, cavanagh_test):
df = _get_plotting_df(cav_idata, cavanagh_test, extra_dims=["stim"])
ax = _plot_quantile_probability_1D(df, cond="stim")

assert has_twin(ax)
assert ax.get_xlabel() == "Proportion"
assert ax.get_ylabel() == "rt"
assert ax.get_title() == "Quantile Probability Plot"


def test__plot_quantile_probability_2D(cav_idata, cavanagh_test):
df = _get_plotting_df(
cav_idata, cavanagh_test, extra_dims=["participant_id", "stim"]
)
g = _plot_quantile_probability_2D(df, cond="stim", col="participant_id", col_wrap=3)

assert len(g.fig.axes) == 10

df = _get_plotting_df(
cav_idata, cavanagh_test, extra_dims=["participant_id", "stim", "conf"]
)
g = _plot_quantile_probability_2D(df, cond="stim", col="participant_id", row="conf")

assert len(g.fig.axes) == 5 * 4


def test_plot_quantile_probability(cav_idata, cavanagh_test):
# Mock model object
model = hssm.HSSM(
data=cavanagh_test,
include=[
{
"name": "v",
"prior": {
"Intercept": {"name": "Normal", "mu": 0.0, "sigma": 1.0},
"theta": {"name": "Normal", "mu": 0.0, "sigma": 1.0},
},
"formula": "v ~ (1|participant_id) + theta",
"link": "identity",
},
],
) # Doesn't matter what model or data we use here
with pytest.raises(ValueError):
plot_quantile_probability(model, cond="stim")

model._inference_obj = cav_idata.copy()
ax1 = plot_quantile_probability(
model, cond="stim", data=cavanagh_test
) # Should work directly
assert len(ax1.get_lines()) == 9

processed_df = _process_df_for_qp_plot(df, 6, "response", "participant_id")
delattr(model.traces, "posterior_predictive")
ax2 = plot_quantile_probability(
model, cond="stim", data=cavanagh_test, n_samples=2
) # Should sample posterior predictive
assert len(ax2.get_lines()) == 9
assert "posterior_predictive" in model.traces
assert model.traces.posterior_predictive.draw.size == 2

with pytest.raises(ValueError):
plot_quantile_probability(model, groups="participant_id", cond="stim")
with pytest.raises(ValueError):
plot_quantile_probability(model, groups_order=["5", "4"], cond="stim")

plots = plot_quantile_probability(
model, row="dbs", col="participant_id", cond="stim", groups="conf"
)
assert len(plots) == 2

0 comments on commit 5061f44

Please sign in to comment.