Skip to content

Commit

Permalink
Move the CallablePotentialWrapper to make posteriors pickleable
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Feb 16, 2024
1 parent 4ed8cc9 commit bfda675
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 16 deletions.
19 changes: 8 additions & 11 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
import torch.distributions.transforms as torch_tf
from torch import Tensor

from sbi.inference.potentials.base_potential import (
BasePotential,
CallablePotentialWrapper,
)
from sbi.types import Array, Shape, TorchTransform
from sbi.utils import gradient_ascent
from sbi.utils.torchutils import ensure_theta_batched, process_device
from sbi.utils.user_input_checks import process_x
from sbi.inference.potentials.base_potential import BasePotential


class NeuralPosterior(ABC):
Expand Down Expand Up @@ -41,16 +44,10 @@ def __init__(
"""
if not isinstance(potential_fn, BasePotential):
callable_potential = potential_fn

class CallablePotentialWrapper(BasePotential):
"""If `potential_fn` is a callable it gets wrapped as this."""
allow_iid_x = True

def __call__(self, theta, **kwargs):
return callable_potential(theta=theta, x_o=self.x_o, **kwargs)

potential_fn = CallablePotentialWrapper(None, None)

potential_device = "cpu" if device is None else device
potential_fn = CallablePotentialWrapper(
callable_potential, prior=None, x_o=None, device=potential_device
)

# Ensure device string.
self._device = process_device(potential_fn.device if device is None else device)
Expand Down
1 change: 0 additions & 1 deletion sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from sbi.types import Shape, TorchTransform
from sbi.utils import pyro_potential_wrapper, tensor2numpy, transformed_potential
from sbi.utils.torchutils import ensure_theta_batched
from sbi.inference.potentials.base_potential import BasePotential


class MCMCPosterior(NeuralPosterior):
Expand Down
24 changes: 23 additions & 1 deletion sbi/inference/potentials/base_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

class BasePotential(metaclass=ABCMeta):
def __init__(
self, prior: Optional[Distribution], x_o: Optional[Tensor] = None, device: str = "cpu"
self,
prior: Optional[Distribution],
x_o: Optional[Tensor] = None,
device: str = "cpu",
):
"""Initialize potential function.
Expand Down Expand Up @@ -61,3 +64,22 @@ def return_x_o(self) -> Optional[Tensor]:
`self._x_o` is `None`.
"""
return self._x_o


class CallablePotentialWrapper(BasePotential):
"""If `potential_fn` is a callable it gets wrapped as this."""

allow_iid_x = True # type: ignore

def __init__(
self,
callable_potential,
prior: Optional[Distribution],
x_o: Optional[Tensor] = None,
device: str = "cpu",
):
super().__init__(prior, x_o, device)
self.callable_potential = callable_potential

def __call__(self, theta, **kwargs):
return self.callable_potential(theta=theta, x_o=self.x_o, **kwargs)
4 changes: 2 additions & 2 deletions sbi/inference/potentials/likelihood_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
track_gradients=track_gradients,
)

return log_likelihood_trial_sum + self.prior.log_prob(theta)
return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore


def _log_likelihoods_over_trials(
Expand Down Expand Up @@ -201,4 +201,4 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
self.x_o.shape[0], -1
).sum(0)

return log_likelihood_trial_sum + self.prior.log_prob(theta)
return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore
2 changes: 1 addition & 1 deletion sbi/inference/potentials/ratio_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
)

# Move to cpu for comparison with prior.
return log_likelihood_trial_sum + self.prior.log_prob(theta)
return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore


def _log_ratios_over_trials(
Expand Down
1 change: 1 addition & 0 deletions tests/potential_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
[ImportanceSamplingPosterior, MCMCPosterior, RejectionPosterior, VIPosterior],
)
def test_callable_potential(sampling_method):
"""Test whether callable potentials can be used to sample from a Gaussian."""
dim = 2
mean = 2.5
cov = 2.0
Expand Down

0 comments on commit bfda675

Please sign in to comment.