Commit
Update docstring, ensure potential_fn signature
michaeldeistler committed Feb 16, 2024
1 parent bfda675 commit b7851cc
Showing 7 changed files with 36 additions and 15 deletions.
sbi/inference/posteriors/base_posterior.py (15 additions, 4 deletions)
@@ -1,6 +1,7 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.
 
+import inspect
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Optional, Union
 
@@ -28,25 +29,35 @@ class NeuralPosterior(ABC):
 
     def __init__(
         self,
-        potential_fn: Callable,
+        potential_fn: Union[Callable, BasePotential],
         theta_transform: Optional[TorchTransform] = None,
         device: Optional[str] = None,
         x_shape: Optional[torch.Size] = None,
     ):
         """
         Args:
-            potential_fn: The potential function from which to draw samples.
+            potential_fn: The potential function from which to draw samples. Must be a
+                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
             theta_transform: Transformation that will be applied during sampling.
                 Allows to perform, e.g. MCMC in unconstrained space.
             device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                 `potential_fn.device` is used.
             x_shape: Shape of the observed data.
         """
         if not isinstance(potential_fn, BasePotential):
-            callable_potential = potential_fn
+            kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
+            for key in ["theta", "x_o"]:
+                assert key in kwargs_of_callable, (
+                    "If you pass a `Callable` as `potential_fn` then it must have "
+                    "`theta` and `x_o` as inputs, even if some of these keyword "
+                    "arguments are unused."
+                )
 
             # If the `potential_fn` is a Callable then we wrap it as a
             # `CallablePotentialWrapper` which inherits from `BasePotential`.
             potential_device = "cpu" if device is None else device
             potential_fn = CallablePotentialWrapper(
-                callable_potential, prior=None, x_o=None, device=potential_device
+                potential_fn, prior=None, x_o=None, device=potential_device
             )
 
         # Ensure device string.
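
A minimal standalone sketch of what the new assertion enforces (the helper and the two callables below are illustrative and not part of this commit):

import inspect

# Mirrors the check added to `NeuralPosterior.__init__`: a plain callable must
# expose both `theta` and `x_o` as parameters, even if one of them is unused.
def satisfies_potential_signature(fn) -> bool:
    params = list(inspect.signature(fn).parameters.keys())
    return all(key in params for key in ["theta", "x_o"])

def good_potential(theta, x_o):  # accepted, then wrapped in `CallablePotentialWrapper`
    return -(theta - x_o).pow(2).sum(dim=-1)

def bad_potential(theta):  # would now trigger the AssertionError above
    return -theta.pow(2).sum(dim=-1)

print(satisfies_potential_signature(good_potential))  # True
print(satisfies_potential_signature(bad_potential))   # False
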
sbi/inference/posteriors/importance_posterior.py (4 additions, 2 deletions)
@@ -7,6 +7,7 @@
 
 from sbi import utils as utils
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
 from sbi.samplers.importance.importance_sampling import importance_sample
 from sbi.samplers.importance.sir import sampling_importance_resampling
 from sbi.types import Shape, TorchTransform
@@ -24,7 +25,7 @@ class ImportanceSamplingPosterior(NeuralPosterior):
 
     def __init__(
         self,
-        potential_fn: Callable,
+        potential_fn: Union[Callable, BasePotential],
         proposal: Any,
         theta_transform: Optional[TorchTransform] = None,
         method: str = "sir",
@@ -35,7 +36,8 @@ def __init__(
     ):
         """
         Args:
-            potential_fn: The potential function from which to draw samples.
+            potential_fn: The potential function from which to draw samples. Must be a
+                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
             proposal: The proposal distribution.
             theta_transform: Transformation that is applied to parameters. Is not used
                 during but only when calling `.map()`.
sbi/inference/posteriors/mcmc_posterior.py (4 additions, 2 deletions)
@@ -18,6 +18,7 @@
 from tqdm.auto import tqdm
 
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
 from sbi.samplers.mcmc import (
     IterateParameters,
     Slice,
@@ -41,7 +42,7 @@ class MCMCPosterior(NeuralPosterior):
 
     def __init__(
         self,
-        potential_fn: Callable,
+        potential_fn: Union[Callable, BasePotential],
         proposal: Any,
         theta_transform: Optional[TorchTransform] = None,
         method: str = "slice_np",
@@ -57,7 +58,8 @@
     ):
         """
         Args:
-            potential_fn: The potential function from which to draw samples.
+            potential_fn: The potential function from which to draw samples. Must be a
+                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
             proposal: Proposal distribution that is used to initialize the MCMC chain.
             theta_transform: Transformation that will be applied during sampling.
                 Allows to perform MCMC in unconstrained space.
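
For context, the widened annotation lets a plain callable be passed straight to `MCMCPosterior`. The sketch below follows the test at the bottom of this commit; the `from sbi.inference import MCMCPosterior` import and the `set_default_x`/`sample` calls are ordinary sbi usage, assumed here rather than taken from this diff:

from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal

from sbi.inference import MCMCPosterior

dim = 2
target_density = MultivariateNormal(0.5 * ones(dim), 0.1 * eye(dim))

# A plain callable potential: both `theta` and `x_o` must appear in the signature.
def potential(theta, x_o):
    return target_density.log_prob(theta + x_o)

proposal = MultivariateNormal(zeros(dim), 5 * eye(dim))
posterior = MCMCPosterior(potential, proposal=proposal, method="slice_np")

# The callable is wrapped internally; `x_o` is filled in from the observation.
samples = posterior.set_default_x(ones(dim)).sample((100,))
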
sbi/inference/posteriors/rejection_posterior.py (4 additions, 2 deletions)
@@ -9,6 +9,7 @@
 
 from sbi import utils as utils
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
 from sbi.samplers.rejection.rejection import rejection_sample
 from sbi.types import Shape, TorchTransform
 from sbi.utils.torchutils import ensure_theta_batched
@@ -22,7 +23,7 @@ class RejectionPosterior(NeuralPosterior):
 
     def __init__(
         self,
-        potential_fn: Callable,
+        potential_fn: Union[Callable, BasePotential],
         proposal: Any,
         theta_transform: Optional[TorchTransform] = None,
         max_sampling_batch_size: int = 10_000,
@@ -34,7 +35,8 @@
     ):
         """
         Args:
-            potential_fn: The potential function from which to draw samples.
+            potential_fn: The potential function from which to draw samples. Must be a
+                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
             proposal: The proposal distribution.
             theta_transform: Transformation that is applied to parameters. Is not used
                 during but only when calling `.map()`.
sbi/inference/posteriors/vi_posterior.py (4 additions, 2 deletions)
@@ -11,6 +11,7 @@
 from tqdm.auto import tqdm
 
 from sbi.inference.posteriors.base_posterior import NeuralPosterior
+from sbi.inference.potentials.base_potential import BasePotential
 from sbi.samplers.vi import (
     adapt_variational_distribution,
     check_variational_distribution,
@@ -47,7 +48,7 @@ class VIPosterior(NeuralPosterior):
 
     def __init__(
         self,
-        potential_fn: Callable,
+        potential_fn: Union[Callable, BasePotential],
         prior: Optional[TorchDistribution] = None,
         q: Union[str, PyroTransformedDistribution, "VIPosterior", Callable] = "maf",
         theta_transform: Optional[TorchTransform] = None,
@@ -59,7 +60,8 @@
     ):
         """
        Args:
-            potential_fn: The potential function from which to draw samples.
+            potential_fn: The potential function from which to draw samples. Must be a
+                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
             prior: This is the prior distribution. Note that this is only
                 used to check/construct the variational distribution or within some
                 quality metrics. Please make sure that this matches with the prior
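
The same contract applies to `VIPosterior`. A rough sketch under similar assumptions (the `q="maf"` argument and the `set_default_x`/`train`/`sample` calls reflect the usual sbi VI workflow and are not part of this diff):

from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal

from sbi.inference import VIPosterior

dim = 2
target_density = MultivariateNormal(zeros(dim), eye(dim))

def potential(theta, x_o):
    return target_density.log_prob(theta + x_o)

prior = MultivariateNormal(zeros(dim), 5 * eye(dim))
posterior = VIPosterior(potential, prior=prior, q="maf")
posterior = posterior.set_default_x(ones(dim))
posterior.train()  # fit the variational distribution to the callable potential
samples = posterior.sample((100,))
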
sbi/inference/potentials/base_potential.py (4 additions, 2 deletions)
@@ -1,6 +1,7 @@
 from abc import ABCMeta, abstractmethod
 from typing import Optional
 
+import torch
 from torch import Tensor
 from torch.distributions import Distribution
 
@@ -81,5 +82,6 @@ def __init__(
         super().__init__(prior, x_o, device)
         self.callable_potential = callable_potential
 
-    def __call__(self, theta, **kwargs):
-        return self.callable_potential(theta=theta, x_o=self.x_o, **kwargs)
+    def __call__(self, theta, track_gradients: bool = True):
+        with torch.set_grad_enabled(track_gradients):
+            return self.callable_potential(theta=theta, x_o=self.x_o)
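
The effect of the new `track_gradients` flag can be seen by calling the wrapper directly. This is only an illustration: the constructor arguments mirror the `CallablePotentialWrapper(...)` call in `NeuralPosterior.__init__`, and passing a tensor as `x_o` at construction time is an assumption made to keep the snippet self-contained:

import torch

from sbi.inference.potentials.base_potential import CallablePotentialWrapper

def potential(theta, x_o):
    return -((theta - x_o) ** 2).sum(dim=-1)

# Illustrative construction; posteriors normally build this wrapper themselves.
wrapper = CallablePotentialWrapper(potential, prior=None, x_o=torch.ones(2), device="cpu")

theta = torch.zeros(2, requires_grad=True)
tracked = wrapper(theta)  # default: track_gradients=True
untracked = wrapper(theta, track_gradients=False)  # runs under torch.set_grad_enabled(False)

print(tracked.requires_grad, untracked.requires_grad)  # True False
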
tests/potential_test.py (1 addition, 1 deletion)
@@ -28,7 +28,7 @@ def test_callable_potential(sampling_method):
     x_o = 1 * ones((dim,))
     target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim))
 
-    def potential(theta, x_o, **kwargs):
+    def potential(theta, x_o):
         return target_density.log_prob(theta + x_o)
 
     proposal = MultivariateNormal(zeros((dim,)), 5 * eye(dim))

