Skip to content

Commit

Permalink
add refereces values to predictive explorer (#302)
Browse files Browse the repository at this point in the history
* add refereces values

* pass none predictive finder
  • Loading branch information
aloctavodia authored Jan 19, 2024
1 parent 0f91769 commit 25301f1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
18 changes: 15 additions & 3 deletions preliz/internal/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def get_textboxes(signature, model):
return textboxes


def plot_decorator(func, iterations, kind_plot):
def plot_decorator(func, iterations, kind_plot, references):
def looper(*args, **kwargs):
results = []
kwargs.pop("__resample__")
Expand All @@ -406,12 +406,12 @@ def looper(*args, **kwargs):
_, ax = plt.subplots()
ax.set_xlim(x_min, x_max, auto=auto)

plot_repr(results, kind_plot, iterations, ax)
plot_repr(results, kind_plot, references, iterations, ax)

return looper


def plot_repr(results, kind_plot, iterations, ax):
def plot_repr(results, kind_plot, references, iterations, ax):
alpha = max(0.01, 1 - iterations * 0.009)

if kind_plot == "hist":
Expand Down Expand Up @@ -450,6 +450,18 @@ def plot_repr(results, kind_plot, iterations, ax):
a = np.concatenate(results)
ax.plot(np.sort(a), np.linspace(0, 1, len(a), endpoint=False), "k--")

if references is not None:
if isinstance(references, dict):
max_value = ax.get_ylim()[1]
for label, ref in references.items():
ax.text(ref, max_value * 0.2, label, rotation=90, bbox={"color": "w", "alpha": 0.5})
ax.axvline(ref, ls="--", color="0.5")
else:
if isinstance(references, (float, int)):
references = [references]
for ref in references:
ax.axvline(ref, ls="--", color="0.5")


def plot_pp_samples(pp_samples, pp_samples_idxs, references, kind="pdf", sharex=True, fig=None):
row_colum = int(np.ceil(len(pp_samples_idxs) ** 0.5))
Expand Down
7 changes: 5 additions & 2 deletions preliz/predictive/predictive_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from preliz.internal.plot_helper import get_textboxes, plot_decorator


def predictive_explorer(fmodel, samples=50, kind_plot="ecdf"):
def predictive_explorer(fmodel, samples=50, kind_plot="ecdf", references=None):
"""
Create textboxes and plot a set of samples returned by a function relating one or more
PreliZ distributions.
Expand All @@ -26,6 +26,9 @@ def predictive_explorer(fmodel, samples=50, kind_plot="ecdf"):
The type of plot to display. Defaults to "kde". Options are "hist" (histogram),
"kde" (kernel density estimate), "ecdf" (empirical cumulative distribution function),
or None (no plot).
references : int, float, list, tuple or dictionary
Value(s) used as reference points representing prior knowledge. For example expected
values or values that are considered extreme. Use a dictionary for labeled references.
"""
source, signature = inspect_source(fmodel)

Expand All @@ -35,7 +38,7 @@ def predictive_explorer(fmodel, samples=50, kind_plot="ecdf"):
if kind_plot is None:
new_fmodel = fmodel
else:
new_fmodel = plot_decorator(fmodel, samples, kind_plot)
new_fmodel = plot_decorator(fmodel, samples, kind_plot, references)

out = interactive_output(new_fmodel, textboxes)
default_names = ["__set_xlim__", "__x_min__", "__x_max__", "__resample__"]
Expand Down
4 changes: 2 additions & 2 deletions preliz/predictive/predictive_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __call__(self, kind_plot):
self.pp_samples = self.fmodel(*self.values)[-1]

reset_dist_panel(self.ax, True)
plot_repr(self.pp_samples, kind_plot, self.draws, self.ax)
plot_repr(self.pp_samples, kind_plot, None, self.draws, self.ax)

if kind_plot == "ecdf":
self.target.plot_cdf(color="C0", legend=False, ax=self.ax)
Expand Down Expand Up @@ -183,7 +183,7 @@ def select(prior_sample, pp_sample, draws, target_octiles, model):
def plot_pp_samples(pp_samples, draws, target, kind_plot, fig, ax):

reset_dist_panel(ax, True)
plot_repr(pp_samples, kind_plot, draws, ax)
plot_repr(pp_samples, kind_plot, None, draws, ax)

if kind_plot == "ecdf":
target.plot_cdf(color="C0", legend=False, ax=ax)
Expand Down

0 comments on commit 25301f1

Please sign in to comment.