Skip to content

Commit

Permalink
Update docstring, ensure potential_fn signature
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Feb 15, 2024
1 parent 2d556b0 commit feaef83
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 10 deletions.
17 changes: 14 additions & 3 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Union

Expand Down Expand Up @@ -35,18 +36,28 @@ def __init__(
):
"""
Args:
potential_fn: The potential function from which to draw samples.
potential_fn: The potential function from which to draw samples. Must be a
`BasePotential` or a `Callable` which takes `theta` and `x` as inputs.
theta_transform: Transformation that will be applied during sampling.
Allows to perform, e.g. MCMC in unconstrained space.
device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
`potential_fn.device` is used.
x_shape: Shape of the observed data.
"""
if not isinstance(potential_fn, BasePotential):
callable_potential = potential_fn
kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
for key in ["theta", "x_o"]:
assert key in kwargs_of_callable, (
"If you pass a `Callable` as `potential_fn` then it must have "
"`theta` and `x_o` as inputs, even if some of these keyword "
"arguments are unused."
)

# If the `potential_fn` is a Callable then we wrap it as a
# `CallablePotentialWrapper` which inherits from `BasePotential`.
potential_device = "cpu" if device is None else device
potential_fn = CallablePotentialWrapper(
callable_potential, prior=None, x_o=None, device=potential_device
potential_fn, prior=None, x_o=None, device=potential_device
)

# Ensure device string.
Expand Down
3 changes: 2 additions & 1 deletion sbi/inference/posteriors/importance_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def __init__(
):
"""
Args:
potential_fn: The potential function from which to draw samples.
potential_fn: The potential function from which to draw samples. Must be a
`BasePotential` or a `Callable` which takes `theta` and `x` as inputs.
proposal: The proposal distribution.
theta_transform: Transformation that is applied to parameters. Is not used
during but only when calling `.map()`.
Expand Down
3 changes: 2 additions & 1 deletion sbi/inference/posteriors/mcmc_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def __init__(
):
"""
Args:
potential_fn: The potential function from which to draw samples.
potential_fn: The potential function from which to draw samples. Must be a
`BasePotential` or a `Callable` which takes `theta` and `x` as inputs.
proposal: Proposal distribution that is used to initialize the MCMC chain.
theta_transform: Transformation that will be applied during sampling.
Allows to perform MCMC in unconstrained space.
Expand Down
3 changes: 2 additions & 1 deletion sbi/inference/posteriors/rejection_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def __init__(
):
"""
Args:
potential_fn: The potential function from which to draw samples.
potential_fn: The potential function from which to draw samples. Must be a
`BasePotential` or a `Callable` which takes `theta` and `x` as inputs.
proposal: The proposal distribution.
theta_transform: Transformation that is applied to parameters. Is not used
during but only when calling `.map()`.
Expand Down
3 changes: 2 additions & 1 deletion sbi/inference/posteriors/vi_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def __init__(
):
"""
Args:
potential_fn: The potential function from which to draw samples.
potential_fn: The potential function from which to draw samples. Must be a
`BasePotential` or a `Callable` which takes `theta` and `x` as inputs.
prior: This is the prior distribution. Note that this is only
used to check/construct the variational distribution or within some
quality metrics. Please make sure that this matches with the prior
Expand Down
6 changes: 4 additions & 2 deletions sbi/inference/potentials/base_potential.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod
from typing import Optional

import torch
from torch import Tensor
from torch.distributions import Distribution

Expand Down Expand Up @@ -81,5 +82,6 @@ def __init__(
super().__init__(prior, x_o, device)
self.callable_potential = callable_potential

def __call__(self, theta, **kwargs):
return self.callable_potential(theta=theta, x_o=self.x_o, **kwargs)
def __call__(self, theta, track_gradients: bool = True):
with torch.set_grad_enabled(track_gradients):
return self.callable_potential(theta=theta, x_o=self.x_o)
2 changes: 1 addition & 1 deletion tests/potential_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_callable_potential(sampling_method):
x_o = 1 * ones((dim,))
target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim))

def potential(theta, x_o, **kwargs):
def potential(theta, x_o):
return target_density.log_prob(theta + x_o)

proposal = MultivariateNormal(zeros((dim,)), 5 * eye(dim))
Expand Down

0 comments on commit feaef83

Please sign in to comment.