diff --git a/preliz/distributions/vonmises.py b/preliz/distributions/vonmises.py index 419a7b4a..533fa7fc 100644 --- a/preliz/distributions/vonmises.py +++ b/preliz/distributions/vonmises.py @@ -119,6 +119,20 @@ def _fit_mle(self, sample): mu = np.mod(mu + np.pi, 2 * np.pi) - np.pi self._update(mu, kappa) + def eti(self, mass=0.94, fmt=".2f"): + mean = self.mu + self.mu = 0 + hdi_min, hdi_max = super().eti(mass=mass, fmt=fmt) + self.mu = mean + return _warp_interval(hdi_min, hdi_max, self.mu, fmt) + + def hdi(self, mass=0.94, fmt=".2f"): + mean = self.mu + self.mu = 0 + hdi_min, hdi_max = super().hdi(mass=mass, fmt=fmt) + self.mu = mean + return _warp_interval(hdi_min, hdi_max, self.mu, fmt) + def nb_cdf(x, pdf): if isinstance(x, (int, float)): @@ -170,3 +184,15 @@ def nb_logpdf(x, mu, kappa): def nb_neg_logpdf(x, mu, kappa): return -(nb_logpdf(x, mu, kappa)).sum() + + +def _warp_interval(hdi_min, hdi_max, mu, fmt): + hdi_min = hdi_min + mu + hdi_max = hdi_max + mu + + lower_tail = np.arctan2(np.sin(hdi_min), np.cos(hdi_min)) + upper_tail = np.arctan2(np.sin(hdi_max), np.cos(hdi_max)) + if fmt != "none": + lower_tail = float(f"{lower_tail:{fmt}}") + upper_tail = float(f"{upper_tail:{fmt}}") + return (lower_tail, upper_tail) diff --git a/preliz/internal/plot_helper.py b/preliz/internal/plot_helper.py index dad94eee..4be1ef48 100644 --- a/preliz/internal/plot_helper.py +++ b/preliz/internal/plot_helper.py @@ -34,13 +34,19 @@ def plot_pointinterval(distribution, interval="hdi", levels=None, rotated=False, Whether to do the plot along the x-axis (default) or on the y-axis ax : matplotlib axis """ + + if isinstance(distribution, (np.ndarray, list, tuple)): + dist_type = "sample" + else: + dist_type = "preliz" + if interval == "quantiles": if levels is None: levels = [0.05, 0.25, 0.5, 0.75, 0.95] elif len(levels) not in (5, 3, 1, 0): raise ValueError("levels should have 5, 3, 1 or 0 elements") - if isinstance(distribution, (np.ndarray, list, tuple)): + if dist_type == "sample": q_s = np.quantile(distribution, levels).tolist() else: q_s = distribution.ppf(levels).tolist() @@ -52,7 +58,7 @@ def plot_pointinterval(distribution, interval="hdi", levels=None, rotated=False, elif len(levels) not in (2, 1): raise ValueError("levels should have 2 or 1 elements") - if isinstance(distribution, (np.ndarray, list, tuple)): + if dist_type == "sample": if interval == "hdi": func = hdi if interval == "eti": @@ -77,21 +83,32 @@ def plot_pointinterval(distribution, interval="hdi", levels=None, rotated=False, q_s_size = len(q_s) - if rotated: - if q_s_size == 5: - ax.plot([0, 0], (q_s.pop(0), q_s.pop(-1)), "k", solid_capstyle="butt", lw=1.5) - if q_s_size > 2: - ax.plot([0, 0], (q_s.pop(0), q_s.pop(-1)), "k", solid_capstyle="butt", lw=4) - if q_s_size > 0: - ax.plot(0, q_s[0], "wo", mec="k") + if q_s_size == 5: + _plot_sub_iterval(q_s, lw=1.5, rotated=rotated, ax=ax) + if q_s_size > 2: + _plot_sub_iterval(q_s, lw=4, rotated=rotated, ax=ax) + if q_s_size > 0: + x, y = q_s[0], 0 + if rotated: + x, y = y, x + ax.plot(x, y, "wo", mec="k") + + +def _plot_sub_iterval(q_s, lw, rotated, ax): + lower, upper = q_s.pop(0), q_s.pop(-1) + if lower < upper: + x, y = (lower, upper), [0, 0] + if rotated: + x, y = y, x + ax.plot(x, y, "k", solid_capstyle="butt", lw=lw) else: - if q_s_size == 5: - ax.plot((q_s.pop(0), q_s.pop(-1)), [0, 0], "k", solid_capstyle="butt", lw=1.5) - if q_s_size > 2: - ax.plot((q_s.pop(0), q_s.pop(-1)), [0, 0], "k", solid_capstyle="butt", lw=4) - - if q_s_size > 0: - ax.plot(q_s[0], 0, "wo", mec="k") + x0, y0 = (lower, np.pi), [0, 0] + x1, y1 = (-np.pi, upper), [0, 0] + if rotated: + x0, y0 = y0, x0 + x1, y1 = y1, x1 + ax.plot(x0, y0, "k", solid_capstyle="butt", lw=lw) + ax.plot(x1, y1, "k", solid_capstyle="butt", lw=lw) def eti(distribution, mass):