Skip to content

Commit

Permalink
Test for callable potential
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Feb 15, 2024
1 parent 9aff544 commit e9b07b1
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tests/potential_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import torch
from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal

from sbi.inference import ImportanceSamplingPosterior


def test_callable_potential():
dim = 2
mean = 2.5
cov = 2.0
x_o = 1 * ones((dim,))
target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim))

def potential(theta, x_o, **kwargs):
return target_density.log_prob(theta + x_o)

proposal = MultivariateNormal(zeros((dim,)), 5 * eye(dim))
approx_density = ImportanceSamplingPosterior(
potential_fn=potential,
proposal=proposal,
method="sir",
)
approx_samples = approx_density.sample((512,), oversampling_factor=1024, x=x_o)

sample_mean = torch.mean(approx_samples, dim=0)
sample_std = torch.std(approx_samples, dim=0)
assert torch.allclose(sample_mean, torch.as_tensor(mean) - x_o, atol=0.2)
assert torch.allclose(sample_std, torch.sqrt(torch.as_tensor(cov)), atol=0.05)

0 comments on commit e9b07b1

Please sign in to comment.