Skip to content

Commit

Permalink
Add to_bambi method (#578)
Browse files Browse the repository at this point in the history
* add to_bambi method

* add to_bambi method
  • Loading branch information
aloctavodia authored Nov 1, 2024
1 parent f29cf14 commit 84d301f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
25 changes: 21 additions & 4 deletions preliz/distributions/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,6 @@ def to_pymc(self, name=None, **kwargs):
-------
PyMC distribution
"""
pymc_dist = None

try:
import pymc.distributions as pm_dists
from pymc.model import Model
Expand Down Expand Up @@ -369,10 +367,29 @@ def to_pymc(self, name=None, **kwargs):
else:
pymc_dist = pymc_class(name, **self.params_dict, **kwargs)

return pymc_dist

except ImportError:
pass
raise ImportError("This function requires PyMC") from None

def to_bambi(self, **kwargs):
"""
Convert the PreliZ distribution to a Bambi Prior.
kwargs : PyMC distributions properties
kwargs are used to specify properties such as shape or dims
return pymc_dist
Returns
-------
Bambi Prior
"""
try:
from bambi import Prior

return Prior(self.__class__.__name__, **self.params_dict, **kwargs)

except ImportError:
raise ImportError("This function requires Bambi") from None

def _check_endpoints(self, lower, upper, raise_error=True):
"""
Expand Down
7 changes: 7 additions & 0 deletions preliz/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,10 @@ def test_to_pymc():
assert model.basic_RVs[2].ndim == 0
assert Normal(0, 1).to_pymc(shape=2).ndim == 1
assert Censored(Normal(0, 1), lower=0).to_pymc().ndim == 0


def test_to_bambi():
bambi_prior = Gamma(mu=2, sigma=1).to_bambi()
assert bambi_prior.name == "Gamma"
assert bambi_prior.args["mu"] == 2
assert bambi_prior.args["sigma"] == 1

0 comments on commit 84d301f

Please sign in to comment.