From bfda675111bb49957f85e4dad5008fe4dd85898d Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Thu, 15 Feb 2024 09:41:03 +0100 Subject: [PATCH] Move the CallablePotentialWrapper to make posteriors pickleable --- sbi/inference/posteriors/base_posterior.py | 19 +++++++-------- sbi/inference/posteriors/mcmc_posterior.py | 1 - sbi/inference/potentials/base_potential.py | 24 ++++++++++++++++++- .../potentials/likelihood_based_potential.py | 4 ++-- .../potentials/ratio_based_potential.py | 2 +- tests/potential_test.py | 1 + 6 files changed, 35 insertions(+), 16 deletions(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 89ccab1ae..1c2b52f2b 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -8,11 +8,14 @@ import torch.distributions.transforms as torch_tf from torch import Tensor +from sbi.inference.potentials.base_potential import ( + BasePotential, + CallablePotentialWrapper, +) from sbi.types import Array, Shape, TorchTransform from sbi.utils import gradient_ascent from sbi.utils.torchutils import ensure_theta_batched, process_device from sbi.utils.user_input_checks import process_x -from sbi.inference.potentials.base_potential import BasePotential class NeuralPosterior(ABC): @@ -41,16 +44,10 @@ def __init__( """ if not isinstance(potential_fn, BasePotential): callable_potential = potential_fn - - class CallablePotentialWrapper(BasePotential): - """If `potential_fn` is a callable it gets wrapped as this.""" - allow_iid_x = True - - def __call__(self, theta, **kwargs): - return callable_potential(theta=theta, x_o=self.x_o, **kwargs) - - potential_fn = CallablePotentialWrapper(None, None) - + potential_device = "cpu" if device is None else device + potential_fn = CallablePotentialWrapper( + callable_potential, prior=None, x_o=None, device=potential_device + ) # Ensure device string. self._device = process_device(potential_fn.device if device is None else device) diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index c1586a370..784309702 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -31,7 +31,6 @@ from sbi.types import Shape, TorchTransform from sbi.utils import pyro_potential_wrapper, tensor2numpy, transformed_potential from sbi.utils.torchutils import ensure_theta_batched -from sbi.inference.potentials.base_potential import BasePotential class MCMCPosterior(NeuralPosterior): diff --git a/sbi/inference/potentials/base_potential.py b/sbi/inference/potentials/base_potential.py index b99645fda..4a6bbdb3a 100644 --- a/sbi/inference/potentials/base_potential.py +++ b/sbi/inference/potentials/base_potential.py @@ -9,7 +9,10 @@ class BasePotential(metaclass=ABCMeta): def __init__( - self, prior: Optional[Distribution], x_o: Optional[Tensor] = None, device: str = "cpu" + self, + prior: Optional[Distribution], + x_o: Optional[Tensor] = None, + device: str = "cpu", ): """Initialize potential function. @@ -61,3 +64,22 @@ def return_x_o(self) -> Optional[Tensor]: `self._x_o` is `None`. """ return self._x_o + + +class CallablePotentialWrapper(BasePotential): + """If `potential_fn` is a callable it gets wrapped as this.""" + + allow_iid_x = True # type: ignore + + def __init__( + self, + callable_potential, + prior: Optional[Distribution], + x_o: Optional[Tensor] = None, + device: str = "cpu", + ): + super().__init__(prior, x_o, device) + self.callable_potential = callable_potential + + def __call__(self, theta, **kwargs): + return self.callable_potential(theta=theta, x_o=self.x_o, **kwargs) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index 5967556c0..d3aa5e51d 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -96,7 +96,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: track_gradients=track_gradients, ) - return log_likelihood_trial_sum + self.prior.log_prob(theta) + return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore def _log_likelihoods_over_trials( @@ -201,4 +201,4 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: self.x_o.shape[0], -1 ).sum(0) - return log_likelihood_trial_sum + self.prior.log_prob(theta) + return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore diff --git a/sbi/inference/potentials/ratio_based_potential.py b/sbi/inference/potentials/ratio_based_potential.py index bd96df247..86a77e0a9 100644 --- a/sbi/inference/potentials/ratio_based_potential.py +++ b/sbi/inference/potentials/ratio_based_potential.py @@ -92,7 +92,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: ) # Move to cpu for comparison with prior. - return log_likelihood_trial_sum + self.prior.log_prob(theta) + return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore def _log_ratios_over_trials( diff --git a/tests/potential_test.py b/tests/potential_test.py index 4784e4534..4c38ff6ce 100644 --- a/tests/potential_test.py +++ b/tests/potential_test.py @@ -21,6 +21,7 @@ [ImportanceSamplingPosterior, MCMCPosterior, RejectionPosterior, VIPosterior], ) def test_callable_potential(sampling_method): + """Test whether callable potentials can be used to sample from a Gaussian.""" dim = 2 mean = 2.5 cov = 2.0