From 770d8d3bc638930864f425d5021feedc73424965 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Thu, 15 Feb 2024 09:41:03 +0100 Subject: [PATCH] Move the CallablePotentialWrapper to make posteriors pickleable --- sbi/inference/posteriors/base_posterior.py | 19 ++++++++--------- sbi/inference/posteriors/mcmc_posterior.py | 1 - sbi/inference/potentials/base_potential.py | 24 +++++++++++++++++++++- tests/potential_test.py | 1 + 4 files changed, 32 insertions(+), 13 deletions(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 89ccab1ae..1c2b52f2b 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -8,11 +8,14 @@ import torch.distributions.transforms as torch_tf from torch import Tensor +from sbi.inference.potentials.base_potential import ( + BasePotential, + CallablePotentialWrapper, +) from sbi.types import Array, Shape, TorchTransform from sbi.utils import gradient_ascent from sbi.utils.torchutils import ensure_theta_batched, process_device from sbi.utils.user_input_checks import process_x -from sbi.inference.potentials.base_potential import BasePotential class NeuralPosterior(ABC): @@ -41,16 +44,10 @@ def __init__( """ if not isinstance(potential_fn, BasePotential): callable_potential = potential_fn - - class CallablePotentialWrapper(BasePotential): - """If `potential_fn` is a callable it gets wrapped as this.""" - allow_iid_x = True - - def __call__(self, theta, **kwargs): - return callable_potential(theta=theta, x_o=self.x_o, **kwargs) - - potential_fn = CallablePotentialWrapper(None, None) - + potential_device = "cpu" if device is None else device + potential_fn = CallablePotentialWrapper( + callable_potential, prior=None, x_o=None, device=potential_device + ) # Ensure device string. self._device = process_device(potential_fn.device if device is None else device) diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index c1586a370..784309702 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -31,7 +31,6 @@ from sbi.types import Shape, TorchTransform from sbi.utils import pyro_potential_wrapper, tensor2numpy, transformed_potential from sbi.utils.torchutils import ensure_theta_batched -from sbi.inference.potentials.base_potential import BasePotential class MCMCPosterior(NeuralPosterior): diff --git a/sbi/inference/potentials/base_potential.py b/sbi/inference/potentials/base_potential.py index b99645fda..d35124e1d 100644 --- a/sbi/inference/potentials/base_potential.py +++ b/sbi/inference/potentials/base_potential.py @@ -9,7 +9,10 @@ class BasePotential(metaclass=ABCMeta): def __init__( - self, prior: Optional[Distribution], x_o: Optional[Tensor] = None, device: str = "cpu" + self, + prior: Optional[Distribution], + x_o: Optional[Tensor] = None, + device: str = "cpu", ): """Initialize potential function. @@ -61,3 +64,22 @@ def return_x_o(self) -> Optional[Tensor]: `self._x_o` is `None`. """ return self._x_o + + +class CallablePotentialWrapper(BasePotential): + """If `potential_fn` is a callable it gets wrapped as this.""" + + allow_iid_x = True + + def __init__( + self, + callable_potential, + prior: Optional[Distribution], + x_o: Optional[Tensor] = None, + device: str = "cpu", + ): + super().__init__(prior, x_o, device) + self.callable_potential = callable_potential + + def __call__(self, theta, **kwargs): + return self.callable_potential(theta=theta, x_o=self.x_o, **kwargs) diff --git a/tests/potential_test.py b/tests/potential_test.py index 4784e4534..4c38ff6ce 100644 --- a/tests/potential_test.py +++ b/tests/potential_test.py @@ -21,6 +21,7 @@ [ImportanceSamplingPosterior, MCMCPosterior, RejectionPosterior, VIPosterior], ) def test_callable_potential(sampling_method): + """Test whether callable potentials can be used to sample from a Gaussian.""" dim = 2 mean = 2.5 cov = 2.0