diff --git a/tests/ensemble_test.py b/tests/ensemble_test.py index a9fb02235..d7f2a48c1 100644 --- a/tests/ensemble_test.py +++ b/tests/ensemble_test.py @@ -4,7 +4,7 @@ from __future__ import annotations import pytest -from torch import eye, ones, zeros +from torch import eye, ones, randn_like, tensor, zeros from torch.distributions import MultivariateNormal from sbi.inference import NLE_A, NPE_C, NRE_A @@ -133,9 +133,9 @@ def simulator(theta): num_samples=num_samples, ) max_dkl = 0.15 - assert ( - dkl < max_dkl - ), f"D-KL={dkl} is more than 2 stds above the average performance." + assert dkl < max_dkl, ( + f"D-KL={dkl} is more than 2 stds above the average performance." + ) # test individual log_prob and map posterior.log_prob(samples, individually=True) @@ -149,3 +149,42 @@ def simulator(theta): samples = posterior.sample_batched((10,), ones(x_o_batch_dim, num_dim)) assert samples.shape == (10, x_o_batch_dim, num_dim), "Sample shape wrong" + + +def test_ensemble_posterior_weights(): + """Test EnsemblePosterior weight handling for valid and invalid formats.""" + num_dim = 2 + ensemble_size = 2 + num_simulations = 50 + x_o = zeros(1, num_dim) + + prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) + + posteriors = [] + for _ in range(ensemble_size): + theta = prior.sample((num_simulations,)) + x = theta + 0.1 * randn_like(theta) + inferer = NPE_C(prior) + inferer.append_simulations(theta, x).train(max_num_epochs=1) + posteriors.append(inferer.build_posterior()) + + posterior = EnsemblePosterior(posteriors) + posterior.set_default_x(x_o) + _ = posterior.sample((2,)) + + posterior = EnsemblePosterior(posteriors, weights=[0.3, 0.7]) + posterior.set_default_x(x_o) + _ = posterior.sample((2,)) + + posterior = EnsemblePosterior(posteriors, weights=tensor([0.4, 0.6])) + posterior.set_default_x(x_o) + _ = posterior.sample((2,)) + + with pytest.raises(TypeError): + EnsemblePosterior(posteriors, weights={"w1": 0.5, "w2": 0.5}) + + with pytest.raises(TypeError): + EnsemblePosterior(posteriors, weights=0.5) + + with pytest.raises(TypeError): + EnsemblePosterior(posteriors, weights=(0.5, 0.5))