diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py
index 888a30fab..eafdddc41 100644
--- a/sbi/inference/posteriors/base_posterior.py
+++ b/sbi/inference/posteriors/base_posterior.py
@@ -1,6 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see .
+import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Union
@@ -8,6 +9,10 @@
import torch.distributions.transforms as torch_tf
from torch import Tensor
+from sbi.inference.potentials.base_potential import (
+ BasePotential,
+ CallablePotentialWrapper,
+)
from sbi.types import Array, Shape, TorchTransform
from sbi.utils import gradient_ascent
from sbi.utils.torchutils import ensure_theta_batched, process_device
@@ -24,20 +29,36 @@ class NeuralPosterior(ABC):
def __init__(
self,
- potential_fn: Callable,
+ potential_fn: Union[Callable, BasePotential],
theta_transform: Optional[TorchTransform] = None,
device: Optional[str] = None,
x_shape: Optional[torch.Size] = None,
):
"""
Args:
- potential_fn: The potential function from which to draw samples.
+            potential_fn: The potential function from which to draw samples. Must be a
+                `BasePotential` or a `Callable` which takes `theta` and `x_o` as
+                inputs (see the example below).
theta_transform: Transformation that will be applied during sampling.
                Allows one to perform, e.g., MCMC in unconstrained space.
device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
`potential_fn.device` is used.
x_shape: Shape of the observed data.
"""
+ if not isinstance(potential_fn, BasePotential):
+ kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
+ for key in ["theta", "x_o"]:
+ assert key in kwargs_of_callable, (
+ "If you pass a `Callable` as `potential_fn` then it must have "
+ "`theta` and `x_o` as inputs, even if some of these keyword "
+ "arguments are unused."
+ )
+
+ # If the `potential_fn` is a Callable then we wrap it as a
+ # `CallablePotentialWrapper` which inherits from `BasePotential`.
+ potential_device = "cpu" if device is None else device
+ potential_fn = CallablePotentialWrapper(
+ potential_fn, prior=None, x_o=None, device=potential_device
+ )
# Ensure device string.
self._device = process_device(potential_fn.device if device is None else device)
diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py
index 9074267ad..ef71b100a 100644
--- a/sbi/inference/posteriors/importance_posterior.py
+++ b/sbi/inference/posteriors/importance_posterior.py
@@ -7,6 +7,7 @@
from sbi import utils as utils
from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.importance.importance_sampling import importance_sample
from sbi.samplers.importance.sir import sampling_importance_resampling
from sbi.types import Shape, TorchTransform
@@ -24,7 +25,7 @@ class ImportanceSamplingPosterior(NeuralPosterior):
def __init__(
self,
- potential_fn: Callable,
+ potential_fn: Union[Callable, BasePotential],
proposal: Any,
theta_transform: Optional[TorchTransform] = None,
method: str = "sir",
@@ -35,7 +36,8 @@ def __init__(
):
"""
Args:
- potential_fn: The potential function from which to draw samples.
+ potential_fn: The potential function from which to draw samples. Must be a
+ `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
proposal: The proposal distribution.
theta_transform: Transformation that is applied to parameters. Is not used
                during sampling but only when calling `.map()`.
diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py
index 784309702..515ee6aa1 100644
--- a/sbi/inference/posteriors/mcmc_posterior.py
+++ b/sbi/inference/posteriors/mcmc_posterior.py
@@ -18,6 +18,7 @@
from tqdm.auto import tqdm
from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.mcmc import (
IterateParameters,
Slice,
@@ -41,7 +42,7 @@ class MCMCPosterior(NeuralPosterior):
def __init__(
self,
- potential_fn: Callable,
+ potential_fn: Union[Callable, BasePotential],
proposal: Any,
theta_transform: Optional[TorchTransform] = None,
method: str = "slice_np",
@@ -57,7 +58,8 @@ def __init__(
):
"""
Args:
- potential_fn: The potential function from which to draw samples.
+ potential_fn: The potential function from which to draw samples. Must be a
+ `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
proposal: Proposal distribution that is used to initialize the MCMC chain.
theta_transform: Transformation that will be applied during sampling.
                Allows one to perform MCMC in unconstrained space.
diff --git a/sbi/inference/posteriors/rejection_posterior.py b/sbi/inference/posteriors/rejection_posterior.py
index 92d55bcc1..a1382a6be 100644
--- a/sbi/inference/posteriors/rejection_posterior.py
+++ b/sbi/inference/posteriors/rejection_posterior.py
@@ -9,6 +9,7 @@
from sbi import utils as utils
from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.rejection.rejection import rejection_sample
from sbi.types import Shape, TorchTransform
from sbi.utils.torchutils import ensure_theta_batched
@@ -22,7 +23,7 @@ class RejectionPosterior(NeuralPosterior):
def __init__(
self,
- potential_fn: Callable,
+ potential_fn: Union[Callable, BasePotential],
proposal: Any,
theta_transform: Optional[TorchTransform] = None,
max_sampling_batch_size: int = 10_000,
@@ -34,7 +35,8 @@ def __init__(
):
"""
Args:
- potential_fn: The potential function from which to draw samples.
+ potential_fn: The potential function from which to draw samples. Must be a
+ `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
proposal: The proposal distribution.
theta_transform: Transformation that is applied to parameters. Is not used
                during sampling but only when calling `.map()`.
diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index 4bd29a4a3..315a69950 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -11,6 +11,7 @@
from tqdm.auto import tqdm
from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.vi import (
adapt_variational_distribution,
check_variational_distribution,
@@ -47,7 +48,7 @@ class VIPosterior(NeuralPosterior):
def __init__(
self,
- potential_fn: Callable,
+ potential_fn: Union[Callable, BasePotential],
prior: Optional[TorchDistribution] = None,
q: Union[str, PyroTransformedDistribution, "VIPosterior", Callable] = "maf",
theta_transform: Optional[TorchTransform] = None,
@@ -59,7 +60,8 @@ def __init__(
):
"""
Args:
- potential_fn: The potential function from which to draw samples.
+ potential_fn: The potential function from which to draw samples. Must be a
+ `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
prior: This is the prior distribution. Note that this is only
used to check/construct the variational distribution or within some
quality metrics. Please make sure that this matches with the prior
@@ -115,7 +117,7 @@ def __init__(
self._prior = q._prior
else:
raise ValueError(
- "We could not find a suitable prior distribution within `potential_fn`"
+ "We could not find a suitable prior distribution within `potential_fn` "
"or `q` (if a VIPosterior is given). Please explicitly specify a prior."
)
move_all_tensor_to_device(self._prior, device)
@@ -461,9 +463,12 @@ def train(
self.evaluate(quality_control_metric=quality_control_metric)
except Exception as e:
print(
- f"Quality control did not work, we reset the variational \
- posterior,please check your setting. \
- \n Following error occured {e}"
+ f"Quality control showed a low quality of the variational "
+ f"posterior. We are automatically retraining the variational "
+ f"posterior from scratch with a smaller learning rate. "
+ f"Alternatively, if you want to skip quality control, please "
+ f"retrain with `VIPosterior.train(..., quality_control=False)`. "
+            f"\nThe error that occurred is: {e}"
)
self.train(
learning_rate=learning_rate * 0.1,
diff --git a/sbi/inference/potentials/base_potential.py b/sbi/inference/potentials/base_potential.py
index 5c1fe32d7..f0f0a6272 100644
--- a/sbi/inference/potentials/base_potential.py
+++ b/sbi/inference/potentials/base_potential.py
@@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod
from typing import Optional
+import torch
from torch import Tensor
from torch.distributions import Distribution
@@ -9,7 +10,10 @@
class BasePotential(metaclass=ABCMeta):
def __init__(
- self, prior: Distribution, x_o: Optional[Tensor] = None, device: str = "cpu"
+ self,
+ prior: Optional[Distribution],
+ x_o: Optional[Tensor] = None,
+ device: str = "cpu",
):
"""Initialize potential function.
@@ -61,3 +65,23 @@ def return_x_o(self) -> Optional[Tensor]:
`self._x_o` is `None`.
"""
return self._x_o
+
+
+class CallablePotentialWrapper(BasePotential):
+    """Wraps a `Callable` potential so that it can be used like a `BasePotential`."""
+
+ allow_iid_x = True # type: ignore
+
+ def __init__(
+ self,
+ callable_potential,
+ prior: Optional[Distribution],
+ x_o: Optional[Tensor] = None,
+ device: str = "cpu",
+ ):
+ super().__init__(prior, x_o, device)
+ self.callable_potential = callable_potential
+
+ def __call__(self, theta, track_gradients: bool = True):
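+        """Evaluate the wrapped callable at `theta` and the stored `x_o`."""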
+ with torch.set_grad_enabled(track_gradients):
+ return self.callable_potential(theta=theta, x_o=self.x_o)
diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py
index 5967556c0..d3aa5e51d 100644
--- a/sbi/inference/potentials/likelihood_based_potential.py
+++ b/sbi/inference/potentials/likelihood_based_potential.py
@@ -96,7 +96,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
track_gradients=track_gradients,
)
- return log_likelihood_trial_sum + self.prior.log_prob(theta)
+ return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore
def _log_likelihoods_over_trials(
@@ -201,4 +201,4 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
self.x_o.shape[0], -1
).sum(0)
- return log_likelihood_trial_sum + self.prior.log_prob(theta)
+ return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore
diff --git a/sbi/inference/potentials/ratio_based_potential.py b/sbi/inference/potentials/ratio_based_potential.py
index bd96df247..86a77e0a9 100644
--- a/sbi/inference/potentials/ratio_based_potential.py
+++ b/sbi/inference/potentials/ratio_based_potential.py
@@ -92,7 +92,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
)
# Move to cpu for comparison with prior.
- return log_likelihood_trial_sum + self.prior.log_prob(theta)
+ return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore
def _log_ratios_over_trials(
diff --git a/tests/potential_test.py b/tests/potential_test.py
new file mode 100644
index 000000000..9b5c9ae6f
--- /dev/null
+++ b/tests/potential_test.py
@@ -0,0 +1,61 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Affero General Public License v3, see .
+
+from __future__ import annotations
+
+import pytest
+import torch
+from torch import eye, ones, zeros
+from torch.distributions import MultivariateNormal
+
+from sbi.inference import (
+ ImportanceSamplingPosterior,
+ MCMCPosterior,
+ RejectionPosterior,
+ VIPosterior,
+)
+
+
+@pytest.mark.parametrize(
+ "sampling_method",
+ [ImportanceSamplingPosterior, MCMCPosterior, RejectionPosterior, VIPosterior],
+)
+def test_callable_potential(sampling_method):
+ """Test whether callable potentials can be used to sample from a Gaussian."""
+ dim = 2
+ mean = 2.5
+ cov = 2.0
+ x_o = 1 * ones((dim,))
+ target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim))
+
+ def potential(theta, x_o):
+ return target_density.log_prob(theta + x_o)
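+
+    # The potential evaluates the Gaussian at `theta + x_o`, so the density it
+    # induces over `theta` is a Gaussian with mean `mean - x_o` and unchanged
+    # covariance; the assertions at the end check against this shifted mean.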
+
+ proposal = MultivariateNormal(zeros((dim,)), 5 * eye(dim))
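+
+    # Depending on the posterior class, `proposal` serves as the importance or
+    # rejection proposal, the initialization of the MCMC chains, or (for VI)
+    # the prior used to set up the variational distribution.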
+
+ if sampling_method == ImportanceSamplingPosterior:
+ approx_density = sampling_method(
+ potential_fn=potential, proposal=proposal, method="sir"
+ )
+ approx_samples = approx_density.sample((1024,), oversampling_factor=1024, x=x_o)
+ elif sampling_method == MCMCPosterior:
+ approx_density = sampling_method(potential_fn=potential, proposal=proposal)
+ approx_samples = approx_density.sample(
+ (1024,), x=x_o, num_chains=100, method="slice_np_vectorized"
+ )
+ elif sampling_method == VIPosterior:
+ approx_density = sampling_method(
+ potential_fn=potential, prior=proposal
+ ).set_default_x(x_o)
+ approx_density = approx_density.train()
+ approx_samples = approx_density.sample((1024,))
+ elif sampling_method == RejectionPosterior:
+ approx_density = sampling_method(
+ potential_fn=potential, proposal=proposal
+ ).set_default_x(x_o)
+ approx_samples = approx_density.sample((1024,))
+
+ sample_mean = torch.mean(approx_samples, dim=0)
+ sample_std = torch.std(approx_samples, dim=0)
+ assert torch.allclose(sample_mean, torch.as_tensor(mean) - x_o, atol=0.2)
+ assert torch.allclose(sample_std, torch.sqrt(torch.as_tensor(cov)), atol=0.1)