diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py
index 888a30fab..eafdddc41 100644
--- a/sbi/inference/posteriors/base_posterior.py
+++ b/sbi/inference/posteriors/base_posterior.py
@@ -1,6 +1,7 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
 
+import inspect
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Optional, Union
 
@@ -8,6 +9,10 @@
 import torch.distributions.transforms as torch_tf
 from torch import Tensor
 
+from sbi.inference.potentials.base_potential import (
+    BasePotential,
+    CallablePotentialWrapper,
+)
 from sbi.types import Array, Shape, TorchTransform
 from sbi.utils import gradient_ascent
 from sbi.utils.torchutils import ensure_theta_batched, process_device
@@ -24,20 +29,36 @@ class NeuralPosterior(ABC):
 
     def __init__(
         self,
-        potential_fn: Callable,
+        potential_fn: Union[Callable, BasePotential],
         theta_transform: Optional[TorchTransform] = None,
         device: Optional[str] = None,
         x_shape: Optional[torch.Size] = None,
     ):
         """
         Args:
-            potential_fn: The potential function from which to draw samples.
+            potential_fn: The potential function from which to draw samples. Must be a
+                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
             theta_transform: Transformation that will be applied during sampling.
                 Allows to perform, e.g. MCMC in unconstrained space.
             device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                 `potential_fn.device` is used.
             x_shape: Shape of the observed data.
         """
+        if not isinstance(potential_fn, BasePotential):
+            kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
+            for key in ["theta", "x_o"]:
+                assert key in kwargs_of_callable, (
+                    "If you pass a `Callable` as `potential_fn` then it must have "
+                    "`theta` and `x_o` as inputs, even if some of these keyword "
+                    "arguments are unused."
+                )
+
+            # If the `potential_fn` is a Callable then we wrap it as a
+            # `CallablePotentialWrapper` which inherits from `BasePotential`.
+            potential_device = "cpu" if device is None else device
+            potential_fn = CallablePotentialWrapper(
+                potential_fn, prior=None, x_o=None, device=potential_device
+            )
 
         # Ensure device string.
         self._device = process_device(potential_fn.device if device is None else device)
diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py
index 9074267ad..ef71b100a 100644
--- a/sbi/inference/posteriors/importance_posterior.py
+++ b/sbi/inference/posteriors/importance_posterior.py
@@ -7,6 +7,7 @@
 
 from sbi import utils as utils
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
 from sbi.samplers.importance.importance_sampling import importance_sample
 from sbi.samplers.importance.sir import sampling_importance_resampling
 from sbi.types import Shape, TorchTransform
@@ -24,7 +25,7 @@ class ImportanceSamplingPosterior(NeuralPosterior):
 
     def __init__(
         self,
-        potential_fn: Callable,
+        potential_fn: Union[Callable, BasePotential],
         proposal: Any,
         theta_transform: Optional[TorchTransform] = None,
         method: str = "sir",
@@ -35,7 +36,8 @@ def __init__(
     ):
         """
         Args:
-            potential_fn: The potential function from which to draw samples.
+            potential_fn: The potential function from which to draw samples. Must be a
+                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
             proposal: The proposal distribution.
             theta_transform: Transformation that is applied to parameters. Is not
                 used during but only when calling `.map()`.
diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py
index 784309702..515ee6aa1 100644
--- a/sbi/inference/posteriors/mcmc_posterior.py
+++ b/sbi/inference/posteriors/mcmc_posterior.py
@@ -18,6 +18,7 @@
 from tqdm.auto import tqdm
 
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
 from sbi.samplers.mcmc import (
     IterateParameters,
     Slice,
@@ -41,7 +42,7 @@ class MCMCPosterior(NeuralPosterior):
 
     def __init__(
         self,
-        potential_fn: Callable,
+        potential_fn: Union[Callable, BasePotential],
         proposal: Any,
         theta_transform: Optional[TorchTransform] = None,
         method: str = "slice_np",
@@ -57,7 +58,8 @@ def __init__(
     ):
         """
         Args:
-            potential_fn: The potential function from which to draw samples.
+            potential_fn: The potential function from which to draw samples. Must be a
+                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
             proposal: Proposal distribution that is used to initialize the MCMC chain.
             theta_transform: Transformation that will be applied during sampling.
                 Allows to perform MCMC in unconstrained space.
diff --git a/sbi/inference/posteriors/rejection_posterior.py b/sbi/inference/posteriors/rejection_posterior.py
index 92d55bcc1..a1382a6be 100644
--- a/sbi/inference/posteriors/rejection_posterior.py
+++ b/sbi/inference/posteriors/rejection_posterior.py
@@ -9,6 +9,7 @@
 
 from sbi import utils as utils
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
 from sbi.samplers.rejection.rejection import rejection_sample
 from sbi.types import Shape, TorchTransform
 from sbi.utils.torchutils import ensure_theta_batched
@@ -22,7 +23,7 @@ class RejectionPosterior(NeuralPosterior):
 
     def __init__(
         self,
-        potential_fn: Callable,
+        potential_fn: Union[Callable, BasePotential],
         proposal: Any,
         theta_transform: Optional[TorchTransform] = None,
         max_sampling_batch_size: int = 10_000,
@@ -34,7 +35,8 @@ def __init__(
     ):
         """
         Args:
-            potential_fn: The potential function from which to draw samples.
+            potential_fn: The potential function from which to draw samples. Must be a
+                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
             proposal: The proposal distribution.
             theta_transform: Transformation that is applied to parameters. Is not
                 used during but only when calling `.map()`.
diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py
index 4bd29a4a3..315a69950 100644
--- a/sbi/inference/posteriors/vi_posterior.py
+++ b/sbi/inference/posteriors/vi_posterior.py
@@ -11,6 +11,7 @@
 from tqdm.auto import tqdm
 
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
 from sbi.samplers.vi import (
     adapt_variational_distribution,
     check_variational_distribution,
@@ -47,7 +48,7 @@ class VIPosterior(NeuralPosterior):
 
     def __init__(
         self,
-        potential_fn: Callable,
+        potential_fn: Union[Callable, BasePotential],
         prior: Optional[TorchDistribution] = None,
         q: Union[str, PyroTransformedDistribution, "VIPosterior", Callable] = "maf",
         theta_transform: Optional[TorchTransform] = None,
@@ -59,7 +60,8 @@ def __init__(
     ):
         """
         Args:
-            potential_fn: The potential function from which to draw samples.
+            potential_fn: The potential function from which to draw samples. Must be a
+                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
             prior: This is the prior distribution. Note that this is only used to
                 check/construct the variational distribution or within some quality
                 metrics. Please make sure that this matches with the prior
@@ -115,7 +117,7 @@ def __init__(
             self._prior = q._prior
         else:
             raise ValueError(
-                "We could not find a suitable prior distribution within `potential_fn`"
+                "We could not find a suitable prior distribution within `potential_fn` "
                 "or `q` (if a VIPosterior is given). Please explicitly specify a prior."
             )
         move_all_tensor_to_device(self._prior, device)
@@ -461,9 +463,12 @@ def train(
                 self.evaluate(quality_control_metric=quality_control_metric)
             except Exception as e:
                 print(
-                    f"Quality control did not work, we reset the variational \
-                        posterior,please check your setting. \
-                        \n Following error occured {e}"
+                    f"Quality control showed a low quality of the variational "
+                    f"posterior. We are automatically retraining the variational "
+                    f"posterior from scratch with a smaller learning rate. "
+                    f"Alternatively, if you want to skip quality control, please "
+                    f"retrain with `VIPosterior.train(..., quality_control=False)`. "
+                    f"\nThe error that occurred is: {e}"
                 )
                 self.train(
                     learning_rate=learning_rate * 0.1,
diff --git a/sbi/inference/potentials/base_potential.py b/sbi/inference/potentials/base_potential.py
index 5c1fe32d7..f0f0a6272 100644
--- a/sbi/inference/potentials/base_potential.py
+++ b/sbi/inference/potentials/base_potential.py
@@ -1,6 +1,7 @@
 from abc import ABCMeta, abstractmethod
 from typing import Optional
 
+import torch
 from torch import Tensor
 from torch.distributions import Distribution
 
@@ -9,7 +10,10 @@
 
 class BasePotential(metaclass=ABCMeta):
     def __init__(
-        self, prior: Distribution, x_o: Optional[Tensor] = None, device: str = "cpu"
+        self,
+        prior: Optional[Distribution],
+        x_o: Optional[Tensor] = None,
+        device: str = "cpu",
     ):
         """Initialize potential function.
 
@@ -61,3 +65,23 @@ def return_x_o(self) -> Optional[Tensor]:
             `self._x_o` is `None`.
""" return self._x_o + + +class CallablePotentialWrapper(BasePotential): + """If `potential_fn` is a callable it gets wrapped as this.""" + + allow_iid_x = True # type: ignore + + def __init__( + self, + callable_potential, + prior: Optional[Distribution], + x_o: Optional[Tensor] = None, + device: str = "cpu", + ): + super().__init__(prior, x_o, device) + self.callable_potential = callable_potential + + def __call__(self, theta, track_gradients: bool = True): + with torch.set_grad_enabled(track_gradients): + return self.callable_potential(theta=theta, x_o=self.x_o) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index 5967556c0..d3aa5e51d 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -96,7 +96,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: track_gradients=track_gradients, ) - return log_likelihood_trial_sum + self.prior.log_prob(theta) + return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore def _log_likelihoods_over_trials( @@ -201,4 +201,4 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: self.x_o.shape[0], -1 ).sum(0) - return log_likelihood_trial_sum + self.prior.log_prob(theta) + return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore diff --git a/sbi/inference/potentials/ratio_based_potential.py b/sbi/inference/potentials/ratio_based_potential.py index bd96df247..86a77e0a9 100644 --- a/sbi/inference/potentials/ratio_based_potential.py +++ b/sbi/inference/potentials/ratio_based_potential.py @@ -92,7 +92,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: ) # Move to cpu for comparison with prior. - return log_likelihood_trial_sum + self.prior.log_prob(theta) + return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore def _log_ratios_over_trials( diff --git a/tests/potential_test.py b/tests/potential_test.py new file mode 100644 index 000000000..9b5c9ae6f --- /dev/null +++ b/tests/potential_test.py @@ -0,0 +1,61 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . 
+
+from __future__ import annotations
+
+import pytest
+import torch
+from torch import eye, ones, zeros
+from torch.distributions import MultivariateNormal
+
+from sbi.inference import (
+    ImportanceSamplingPosterior,
+    MCMCPosterior,
+    RejectionPosterior,
+    VIPosterior,
+)
+
+
+@pytest.mark.parametrize(
+    "sampling_method",
+    [ImportanceSamplingPosterior, MCMCPosterior, RejectionPosterior, VIPosterior],
+)
+def test_callable_potential(sampling_method):
+    """Test whether callable potentials can be used to sample from a Gaussian."""
+    dim = 2
+    mean = 2.5
+    cov = 2.0
+    x_o = 1 * ones((dim,))
+    target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim))
+
+    def potential(theta, x_o):
+        return target_density.log_prob(theta + x_o)
+
+    proposal = MultivariateNormal(zeros((dim,)), 5 * eye(dim))
+
+    if sampling_method == ImportanceSamplingPosterior:
+        approx_density = sampling_method(
+            potential_fn=potential, proposal=proposal, method="sir"
+        )
+        approx_samples = approx_density.sample((1024,), oversampling_factor=1024, x=x_o)
+    elif sampling_method == MCMCPosterior:
+        approx_density = sampling_method(potential_fn=potential, proposal=proposal)
+        approx_samples = approx_density.sample(
+            (1024,), x=x_o, num_chains=100, method="slice_np_vectorized"
+        )
+    elif sampling_method == VIPosterior:
+        approx_density = sampling_method(
+            potential_fn=potential, prior=proposal
+        ).set_default_x(x_o)
+        approx_density = approx_density.train()
+        approx_samples = approx_density.sample((1024,))
+    elif sampling_method == RejectionPosterior:
+        approx_density = sampling_method(
+            potential_fn=potential, proposal=proposal
+        ).set_default_x(x_o)
+        approx_samples = approx_density.sample((1024,))
+
+    sample_mean = torch.mean(approx_samples, dim=0)
+    sample_std = torch.std(approx_samples, dim=0)
+    assert torch.allclose(sample_mean, torch.as_tensor(mean) - x_o, atol=0.2)
+    assert torch.allclose(sample_std, torch.sqrt(torch.as_tensor(cov)), atol=0.1)
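
For reference, a minimal usage sketch of the new callable-potential API, modeled on the new test in tests/potential_test.py. The Gaussian `target`, the `proposal`, and the sampler settings below are illustrative only, not part of the change; any callable with `theta` and `x_o` keyword arguments can serve as `potential_fn` and is wrapped in `CallablePotentialWrapper` internally.

from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal

from sbi.inference import MCMCPosterior

# Illustrative target density; in practice this is any unnormalized log-density.
target = MultivariateNormal(2.5 * ones(2), 2.0 * eye(2))

def potential(theta, x_o):
    # Unnormalized log-probability of `theta` given the observation `x_o`.
    return target.log_prob(theta + x_o)

proposal = MultivariateNormal(zeros(2), 5.0 * eye(2))
posterior = MCMCPosterior(potential_fn=potential, proposal=proposal)
samples = posterior.sample(
    (1000,), x=ones(2), num_chains=20, method="slice_np_vectorized"
)

The same callable can be passed to ImportanceSamplingPosterior, RejectionPosterior, or VIPosterior, as exercised by the parametrized test above.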