Skip to content

Commit

Permalink
Test for callable potential
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Feb 16, 2024
1 parent 5d0f67b commit 4ed8cc9
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
11 changes: 7 additions & 4 deletions sbi/inference/posteriors/vi_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
self._prior = q._prior
else:
raise ValueError(
"We could not find a suitable prior distribution within `potential_fn`"
"We could not find a suitable prior distribution within `potential_fn` "
"or `q` (if a VIPosterior is given). Please explicitly specify a prior."
)
move_all_tensor_to_device(self._prior, device)
Expand Down Expand Up @@ -461,9 +461,12 @@ def train(
self.evaluate(quality_control_metric=quality_control_metric)
except Exception as e:
print(
f"Quality control did not work, we reset the variational \
posterior,please check your setting. \
\n Following error occured {e}"
f"Quality control showed a low quality of the variational "
f"posterior. We are automatically retraining the variational "
f"posterior from scratch with a smaller learning rate. "
f"Alternatively, if you want to skip quality control, please "
f"retrain with `VIPosterior.train(..., quality_control=False)`. "
f"\nThe error that occured is: {e}"
)
self.train(
learning_rate=learning_rate * 0.1,
Expand Down
60 changes: 60 additions & 0 deletions tests/potential_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import pytest
import torch
from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal

from sbi.inference import (
ImportanceSamplingPosterior,
MCMCPosterior,
RejectionPosterior,
VIPosterior,
)


@pytest.mark.parametrize(
    "sampling_method",
    [ImportanceSamplingPosterior, MCMCPosterior, RejectionPosterior, VIPosterior],
)
def test_callable_potential(sampling_method):
    """Check that posterior classes accept a plain callable as `potential_fn`.

    The potential evaluates a fixed Gaussian `N(mean * 1, cov * I)` at
    `theta + x_o`, so the implied posterior over `theta` is Gaussian with
    mean `mean - x_o` and covariance `cov * I`. Each sampler should recover
    these first two moments from its samples.
    """
    # Seed so the Monte-Carlo moment checks below are deterministic
    # rather than occasionally flaky.
    torch.manual_seed(0)

    dim = 2
    mean = 2.5
    cov = 2.0
    x_o = 1 * ones((dim,))
    target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim))

    def potential(theta, x_o, **kwargs):
        # Log-density of the target, shifted by the conditioning value `x_o`.
        return target_density.log_prob(theta + x_o)

    proposal = MultivariateNormal(zeros((dim,)), 5 * eye(dim))

    if sampling_method == ImportanceSamplingPosterior:
        approx_density = sampling_method(
            potential_fn=potential, proposal=proposal, method="sir"
        )
        approx_samples = approx_density.sample((1024,), oversampling_factor=1024, x=x_o)
    elif sampling_method == MCMCPosterior:
        approx_density = sampling_method(potential_fn=potential, proposal=proposal)
        approx_samples = approx_density.sample(
            (1024,), x=x_o, num_chains=100, method="slice_np_vectorized"
        )
    elif sampling_method == VIPosterior:
        approx_density = sampling_method(
            potential_fn=potential, prior=proposal
        ).set_default_x(x_o)
        approx_density = approx_density.train()
        approx_samples = approx_density.sample((1024,))
    elif sampling_method == RejectionPosterior:
        approx_density = sampling_method(
            potential_fn=potential, proposal=proposal
        ).set_default_x(x_o)
        approx_samples = approx_density.sample((1024,))
    else:
        # Fail loudly if the parametrization gains a sampler this test does
        # not handle — otherwise the checks below raise a confusing NameError.
        raise NotImplementedError(f"No test branch for {sampling_method}.")

    sample_mean = torch.mean(approx_samples, dim=0)
    sample_std = torch.std(approx_samples, dim=0)
    assert torch.allclose(sample_mean, torch.as_tensor(mean) - x_o, atol=0.2)
    assert torch.allclose(sample_std, torch.sqrt(torch.as_tensor(cov)), atol=0.1)

0 comments on commit 4ed8cc9

Please sign in to comment.