Skip to content

Commit

Permalink
Move the CallablePotentialWrapper to make posteriors pickleable
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Feb 15, 2024
1 parent cb53e67 commit 770d8d3
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 13 deletions.
19 changes: 8 additions & 11 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
import torch.distributions.transforms as torch_tf
from torch import Tensor

from sbi.inference.potentials.base_potential import (
BasePotential,
CallablePotentialWrapper,
)
from sbi.types import Array, Shape, TorchTransform
from sbi.utils import gradient_ascent
from sbi.utils.torchutils import ensure_theta_batched, process_device
from sbi.utils.user_input_checks import process_x
from sbi.inference.potentials.base_potential import BasePotential


class NeuralPosterior(ABC):
Expand Down Expand Up @@ -41,16 +44,10 @@ def __init__(
"""
if not isinstance(potential_fn, BasePotential):
callable_potential = potential_fn

class CallablePotentialWrapper(BasePotential):
"""If `potential_fn` is a callable it gets wrapped as this."""
allow_iid_x = True

def __call__(self, theta, **kwargs):
return callable_potential(theta=theta, x_o=self.x_o, **kwargs)

potential_fn = CallablePotentialWrapper(None, None)

potential_device = "cpu" if device is None else device
potential_fn = CallablePotentialWrapper(
callable_potential, prior=None, x_o=None, device=potential_device
)

# Ensure device string.
self._device = process_device(potential_fn.device if device is None else device)
Expand Down
1 change: 0 additions & 1 deletion sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from sbi.types import Shape, TorchTransform
from sbi.utils import pyro_potential_wrapper, tensor2numpy, transformed_potential
from sbi.utils.torchutils import ensure_theta_batched
from sbi.inference.potentials.base_potential import BasePotential


class MCMCPosterior(NeuralPosterior):
Expand Down
24 changes: 23 additions & 1 deletion sbi/inference/potentials/base_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

class BasePotential(metaclass=ABCMeta):
def __init__(
self, prior: Optional[Distribution], x_o: Optional[Tensor] = None, device: str = "cpu"
self,
prior: Optional[Distribution],
x_o: Optional[Tensor] = None,
device: str = "cpu",
):
"""Initialize potential function.
Expand Down Expand Up @@ -61,3 +64,22 @@ def return_x_o(self) -> Optional[Tensor]:
`self._x_o` is `None`.
"""
return self._x_o


class CallablePotentialWrapper(BasePotential):
"""If `potential_fn` is a callable it gets wrapped as this."""

allow_iid_x = True

def __init__(
self,
callable_potential,
prior: Optional[Distribution],
x_o: Optional[Tensor] = None,
device: str = "cpu",
):
super().__init__(prior, x_o, device)
self.callable_potential = callable_potential

def __call__(self, theta, **kwargs):
return self.callable_potential(theta=theta, x_o=self.x_o, **kwargs)
1 change: 1 addition & 0 deletions tests/potential_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
[ImportanceSamplingPosterior, MCMCPosterior, RejectionPosterior, VIPosterior],
)
def test_callable_potential(sampling_method):
"""Test whether callable potentials can be used to sample from a Gaussian."""
dim = 2
mean = 2.5
cov = 2.0
Expand Down

0 comments on commit 770d8d3

Please sign in to comment.