Skip to content

Commit

Permalink
test: ensemble posterior weights validation (#1307)
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls authored Dec 2, 2024
1 parent 3ff1f51 commit 3bd8aa9
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion tests/ensemble_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

import pytest
from torch import eye, ones, zeros
from torch import eye, ones, randn_like, tensor, zeros
from torch.distributions import MultivariateNormal

from sbi.inference import NLE_A, NPE_C, NRE_A
Expand Down Expand Up @@ -149,3 +149,40 @@ def simulator(theta):
samples = posterior.sample_batched((10,), ones(x_o_batch_dim, num_dim))

assert samples.shape == (10, x_o_batch_dim, num_dim), "Sample shape wrong"


@pytest.mark.parametrize(
"weights, expected_exception",
[
(None, None),
([0.3, 0.7], None),
(tensor([0.4, 0.6]), None),
({"w1": 0.5, "w2": 0.5}, TypeError),
(0.5, TypeError),
((0.5, 0.5), TypeError),
],
)
def test_ensemble_posterior_weights(weights, expected_exception):
"""Test EnsemblePosterior weight handling for valid and invalid formats."""
num_dim = 2
ensemble_size = 2
num_simulations = 50
x_o = zeros(1, num_dim)

prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))

posteriors = []
for _ in range(ensemble_size):
theta = prior.sample((num_simulations,))
x = theta + 0.1 * randn_like(theta)
inferer = NPE_C(prior)
inferer.append_simulations(theta, x).train(max_num_epochs=1)
posteriors.append(inferer.build_posterior())

if expected_exception:
with pytest.raises(expected_exception):
EnsemblePosterior(posteriors, weights=weights)
else:
posterior = EnsemblePosterior(posteriors, weights=weights)
posterior.set_default_x(x_o)
_ = posterior.sample((2,))

0 comments on commit 3bd8aa9

Please sign in to comment.