diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 1c2b52f2b..d34d2c19d 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -1,6 +1,7 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Affero General Public License v3, see . +import inspect from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Optional, Union @@ -35,7 +36,8 @@ def __init__( ): """ Args: - potential_fn: The potential function from which to draw samples. + potential_fn: The potential function from which to draw samples. Must be a + `BasePotential` or a `Callable` which takes `theta` and `x` as inputs. theta_transform: Transformation that will be applied during sampling. Allows to perform, e.g. MCMC in unconstrained space. device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None, @@ -43,10 +45,19 @@ def __init__( x_shape: Shape of the observed data. """ if not isinstance(potential_fn, BasePotential): - callable_potential = potential_fn + kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys()) + for key in ["theta", "x_o"]: + assert key in kwargs_of_callable, ( + "If you pass a `Callable` as `potential_fn` then it must have " + "`theta` and `x_o` as inputs, even if some of these keyword " + "arguments are unused." + ) + + # If the `potential_fn` is a Callable then we wrap it as a + # `CallablePotentialWrapper` which inherits from `BasePotential`. potential_device = "cpu" if device is None else device potential_fn = CallablePotentialWrapper( - callable_potential, prior=None, x_o=None, device=potential_device + potential_fn, prior=None, x_o=None, device=potential_device ) # Ensure device string. diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py index 9074267ad..55a684bdb 100644 --- a/sbi/inference/posteriors/importance_posterior.py +++ b/sbi/inference/posteriors/importance_posterior.py @@ -35,7 +35,8 @@ def __init__( ): """ Args: - potential_fn: The potential function from which to draw samples. + potential_fn: The potential function from which to draw samples. Must be a + `BasePotential` or a `Callable` which takes `theta` and `x` as inputs. proposal: The proposal distribution. theta_transform: Transformation that is applied to parameters. Is not used during but only when calling `.map()`. diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 784309702..5ca112377 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -57,7 +57,8 @@ def __init__( ): """ Args: - potential_fn: The potential function from which to draw samples. + potential_fn: The potential function from which to draw samples. Must be a + `BasePotential` or a `Callable` which takes `theta` and `x` as inputs. proposal: Proposal distribution that is used to initialize the MCMC chain. theta_transform: Transformation that will be applied during sampling. Allows to perform MCMC in unconstrained space. diff --git a/sbi/inference/posteriors/rejection_posterior.py b/sbi/inference/posteriors/rejection_posterior.py index 92d55bcc1..2820bf5eb 100644 --- a/sbi/inference/posteriors/rejection_posterior.py +++ b/sbi/inference/posteriors/rejection_posterior.py @@ -34,7 +34,8 @@ def __init__( ): """ Args: - potential_fn: The potential function from which to draw samples. + potential_fn: The potential function from which to draw samples. Must be a + `BasePotential` or a `Callable` which takes `theta` and `x` as inputs. proposal: The proposal distribution. theta_transform: Transformation that is applied to parameters. Is not used during but only when calling `.map()`. diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index 95529b7dd..7b31aae3f 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -59,7 +59,8 @@ def __init__( ): """ Args: - potential_fn: The potential function from which to draw samples. + potential_fn: The potential function from which to draw samples. Must be a + `BasePotential` or a `Callable` which takes `theta` and `x` as inputs. prior: This is the prior distribution. Note that this is only used to check/construct the variational distribution or within some quality metrics. Please make sure that this matches with the prior diff --git a/sbi/inference/potentials/base_potential.py b/sbi/inference/potentials/base_potential.py index 4a6bbdb3a..f0f0a6272 100644 --- a/sbi/inference/potentials/base_potential.py +++ b/sbi/inference/potentials/base_potential.py @@ -1,6 +1,7 @@ from abc import ABCMeta, abstractmethod from typing import Optional +import torch from torch import Tensor from torch.distributions import Distribution @@ -81,5 +82,6 @@ def __init__( super().__init__(prior, x_o, device) self.callable_potential = callable_potential - def __call__(self, theta, **kwargs): - return self.callable_potential(theta=theta, x_o=self.x_o, **kwargs) + def __call__(self, theta, track_gradients: bool = True): + with torch.set_grad_enabled(track_gradients): + return self.callable_potential(theta=theta, x_o=self.x_o) diff --git a/tests/potential_test.py b/tests/potential_test.py index 4c38ff6ce..9b5c9ae6f 100644 --- a/tests/potential_test.py +++ b/tests/potential_test.py @@ -28,7 +28,7 @@ def test_callable_potential(sampling_method): x_o = 1 * ones((dim,)) target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim)) - def potential(theta, x_o, **kwargs): + def potential(theta, x_o): return target_density.log_prob(theta + x_o) proposal = MultivariateNormal(zeros((dim,)), 5 * eye(dim))