diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index 4bd29a4a3..95529b7dd 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -115,7 +115,7 @@ def __init__( self._prior = q._prior else: raise ValueError( - "We could not find a suitable prior distribution within `potential_fn`" + "We could not find a suitable prior distribution within `potential_fn` " "or `q` (if a VIPosterior is given). Please explicitly specify a prior." ) move_all_tensor_to_device(self._prior, device) @@ -461,9 +461,12 @@ def train( self.evaluate(quality_control_metric=quality_control_metric) except Exception as e: print( - f"Quality control did not work, we reset the variational \ - posterior,please check your setting. \ - \n Following error occured {e}" + f"Quality control showed a low quality of the variational " + f"posterior. We are automatically retraining the variational " + f"posterior from scratch with a smaller learning rate. " + f"Alternatively, if you want to skip quality control, please " + f"retrain with `VIPosterior.train(..., quality_control=False)`. " + f"\nThe error that occurred is: {e}" ) self.train( learning_rate=learning_rate * 0.1, diff --git a/tests/potential_test.py b/tests/potential_test.py new file mode 100644 index 000000000..4784e4534 --- /dev/null +++ b/tests/potential_test.py @@ -0,0 +1,60 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>. 
+ +from __future__ import annotations + +import pytest +import torch +from torch import eye, ones, zeros +from torch.distributions import MultivariateNormal + +from sbi.inference import ( + ImportanceSamplingPosterior, + MCMCPosterior, + RejectionPosterior, + VIPosterior, +) + + +@pytest.mark.parametrize( + "sampling_method", + [ImportanceSamplingPosterior, MCMCPosterior, RejectionPosterior, VIPosterior], +) +def test_callable_potential(sampling_method): + dim = 2 + mean = 2.5 + cov = 2.0 + x_o = 1 * ones((dim,)) + target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim)) + + def potential(theta, x_o, **kwargs): + return target_density.log_prob(theta + x_o) + + proposal = MultivariateNormal(zeros((dim,)), 5 * eye(dim)) + + if sampling_method == ImportanceSamplingPosterior: + approx_density = sampling_method( + potential_fn=potential, proposal=proposal, method="sir" + ) + approx_samples = approx_density.sample((1024,), oversampling_factor=1024, x=x_o) + elif sampling_method == MCMCPosterior: + approx_density = sampling_method(potential_fn=potential, proposal=proposal) + approx_samples = approx_density.sample( + (1024,), x=x_o, num_chains=100, method="slice_np_vectorized" + ) + elif sampling_method == VIPosterior: + approx_density = sampling_method( + potential_fn=potential, prior=proposal + ).set_default_x(x_o) + approx_density = approx_density.train() + approx_samples = approx_density.sample((1024,)) + elif sampling_method == RejectionPosterior: + approx_density = sampling_method( + potential_fn=potential, proposal=proposal + ).set_default_x(x_o) + approx_samples = approx_density.sample((1024,)) + + sample_mean = torch.mean(approx_samples, dim=0) + sample_std = torch.std(approx_samples, dim=0) + assert torch.allclose(sample_mean, torch.as_tensor(mean) - x_o, atol=0.2) + assert torch.allclose(sample_std, torch.sqrt(torch.as_tensor(cov)), atol=0.1)