Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mle: handle multidimensional samples #568

Merged
merged 2 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions preliz/internal/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,16 +337,35 @@ def fit_to_sample(selected_distributions, sample, x_min, x_max):
Maximize the likelihood given a sample
"""
fitted = Loss(len(selected_distributions))
sample_size = len(sample)
for dist in selected_distributions:
for dist in selected_distributions: # pylint: disable=too-many-nested-blocks
if dist.__class__.__name__ in ["BetaScaled", "TruncatedNormal"]:
update_bounds_beta_scaled(dist, x_min, x_max)

loss = np.inf
if dist._check_endpoints(x_min, x_max, raise_error=False):
dist._fit_mle(sample) # pylint:disable=protected-access
corr = get_penalization(sample_size, dist)
loss = dist._neg_logpdf(sample) + corr
if sample.ndim > 1:
dists = []
neg_logpdf = 0
for s in sample:
dist_i = copy(dist)
dist_i._fit_mle(s)
neg_logpdf += dist_i._neg_logpdf(s)
dists.append(dist_i)
new_dict = {}
for d in dists:
params = d.params_dict
for k, v in params.items():
if k in new_dict:
new_dict[k].append(v)
else:
new_dict[k] = [v]

dist._parametrization(**{k: np.asarray(v) for k, v in new_dict.items()})
else:
dist._fit_mle(sample) # pylint:disable=protected-access
neg_logpdf = dist._neg_logpdf(sample)
corr = get_penalization(sample.size, dist)
loss = neg_logpdf + corr
fitted.update(loss, dist)

return fitted
Expand Down
7 changes: 5 additions & 2 deletions preliz/ppls/agnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging


from preliz.internal.parser import get_engine
from preliz.distributions import Gamma, Normal, HalfNormal
from preliz.unidimensional.mle import mle
Expand Down Expand Up @@ -48,21 +49,23 @@ def posterior_to_prior(model, idata, alternative=None, engine="auto"):

if alternative is None:
for var, dist in model_info.items():
dist._fit_mle(posterior[var].values)
idx, _ = mle([dist], posterior[var].values, plot=False)
new_priors[var] = dist
else:
for var, dist in model_info.items():
dists = [dist]

if alternative == "auto":
dists += [Normal(), HalfNormal(), Gamma()]
alt = [Normal(), HalfNormal(), Gamma()]
dists += [a for a in alt if dist.__class__.__name__ != a.__class__.__name__]
elif isinstance(alternative, list):
dists += alternative
elif isinstance(alternative, dict):
dists += alternative.get(var, [])

idx, _ = mle(dists, posterior[var].values, plot=False)
new_priors[var] = dists[idx[0]]

if engine == "bambi":
new_model = write_bambi_string(new_priors, var_info2)
elif engine == "pymc":
Expand Down
3 changes: 2 additions & 1 deletion preliz/ppls/pymc_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# pylint: disable=protected-access
from sys import modules
from copy import copy

import numpy as np

Expand Down Expand Up @@ -126,7 +127,7 @@ def get_model_information(model): # pylint: disable=too-many-locals
name = (
r_v.owner.op.name if r_v.owner.op.name else str(r_v.owner.op).split("RV", 1)[0].lower()
)
dist = pymc_to_preliz[name]
dist = copy(pymc_to_preliz[name])
p_model[r_v.name] = dist
if nc_parents:
idxs = [free_rvs.index(var_) for var_ in nc_parents]
Expand Down
2 changes: 1 addition & 1 deletion preliz/tests/test_posterior_to_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_p2p_pymc():
"x1": np.random.normal(size=117),
}
)
bmb_prior = {"Intercept": bmb.Prior("HalfStudentT", nu=1)}
bmb_prior = {"Intercept": bmb.Prior("Normal", mu=0, sigma=1)}
bmb_model = bmb.Model("y ~ x + x1", bmb_data, priors=bmb_prior)
bmb_idata = bmb_model.fit(tune=200, draws=200, random_seed=2945)

Expand Down
38 changes: 0 additions & 38 deletions preliz/tests/test_ppe.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,41 +73,3 @@ def test_ppe(params):
assert_allclose(new_prior["x"].mu, params["new_mu_x"], 1)
assert_allclose(new_prior["x"].sigma, params["new_sigma_x"], 1)
assert_allclose(new_prior["z"].sigma, params["new_sigma_z"], 1)


@pytest.mark.parametrize(
"params",
[
{
"mu_a": 0,
"sigma_a": 10,
"sigma_b": 10,
"sigma_z": 10,
"target": pz.Normal(mu=40, sigma=7),
"new_mu_a": 40.00908,
"new_sigma_a": 0.167701,
"new_sigma_b": 7.003267,
"new_sigma_z": 7.003267,
},
],
)
def test_ppe_testdata(params):
csv_data = pd.read_csv(r"testdata/chemical_shifts_theo_exp.csv")
diff = csv_data.theo - csv_data.exp
cat_encode = pd.Categorical(csv_data["aa"])
idx = cat_encode.codes
coords = {"aa": cat_encode.categories}
with pm.Model(coords=coords) as model:
# hyper_priors
a = pm.Normal("a", mu=0, sigma=10)
b = pm.HalfNormal("b", 10)
z = pm.HalfNormal("z", sigma=10)

x = pm.Normal("x", mu=a, sigma=b, dims="aa") # No prior for this!

y = pm.Normal("y", mu=x[idx], sigma=z, observed=diff)
prior, new_prior, pymc_string = pz.ppe(model, params["target"])
assert_allclose(new_prior["a"].mu, params["new_mu_a"], 1e-6)
assert_allclose(new_prior["a"].sigma, params["new_sigma_a"], 1e-6)
assert_allclose(new_prior["b"].sigma, params["new_sigma_b"], 1e-6)
assert_allclose(new_prior["z"].sigma, params["new_sigma_z"], 1e-6)
2 changes: 2 additions & 0 deletions preliz/unidimensional/mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def mle(
updated inplace.
sample : list or 1D array-like
Data used to estimate the distribution parameters.
ignore_support : bool
plot : int
Number of distributions to plots. Defaults to ``1`` (i.e. plot the best match)
If larger than the number of passed distributions it will plot all of them.
Expand Down
Loading
Loading