Skip to content

Commit

Permalink
test: add ensemble posterior weights validation
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Nov 22, 2024
1 parent 3ff1f51 commit 43ca8bb
Showing 1 changed file with 44 additions and 4 deletions.
48 changes: 44 additions & 4 deletions tests/ensemble_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

import pytest
from torch import eye, ones, zeros
from torch import eye, ones, randn_like, tensor, zeros
from torch.distributions import MultivariateNormal

from sbi.inference import NLE_A, NPE_C, NRE_A
Expand Down Expand Up @@ -133,9 +133,9 @@ def simulator(theta):
num_samples=num_samples,
)
max_dkl = 0.15
assert (
dkl < max_dkl
), f"D-KL={dkl} is more than 2 stds above the average performance."
assert dkl < max_dkl, (
f"D-KL={dkl} is more than 2 stds above the average performance."
)

# test individual log_prob and map
posterior.log_prob(samples, individually=True)
Expand All @@ -149,3 +149,43 @@ def simulator(theta):
samples = posterior.sample_batched((10,), ones(x_o_batch_dim, num_dim))

assert samples.shape == (10, x_o_batch_dim, num_dim), "Sample shape wrong"


def test_ensemble_posterior_weights():
"""Test EnsemblePosterior weight handling for valid and invalid formats."""
num_dim = 2
ensemble_size = 2
num_simulations = 50
x_o = zeros(1, num_dim)

prior = MultivariateNormal(loc=zeros(num_dim),
covariance_matrix=eye(num_dim))

posteriors = []
for _ in range(ensemble_size):
theta = prior.sample((num_simulations,))
x = theta + 0.1 * randn_like(theta)
inferer = NPE_C(prior)
inferer.append_simulations(theta, x).train(max_num_epochs=1)
posteriors.append(inferer.build_posterior())

posterior = EnsemblePosterior(posteriors)
posterior.set_default_x(x_o)
_ = posterior.sample((2,))

posterior = EnsemblePosterior(posteriors, weights=[0.3, 0.7])
posterior.set_default_x(x_o)
_ = posterior.sample((2,))

posterior = EnsemblePosterior(posteriors, weights=tensor([0.4, 0.6]))
posterior.set_default_x(x_o)
_ = posterior.sample((2,))

with pytest.raises(TypeError):
EnsemblePosterior(posteriors, weights={"w1": 0.5, "w2": 0.5})

with pytest.raises(TypeError):
EnsemblePosterior(posteriors, weights=0.5)

with pytest.raises(TypeError):
EnsemblePosterior(posteriors, weights=(0.5, 0.5))

0 comments on commit 43ca8bb

Please sign in to comment.