Allow for the potential_fn to be a Callable #943

Merged · 4 commits · Feb 16, 2024
25 changes: 23 additions & 2 deletions sbi/inference/posteriors/base_posterior.py
@@ -1,13 +1,18 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Union

import torch
import torch.distributions.transforms as torch_tf
from torch import Tensor

from sbi.inference.potentials.base_potential import (
    BasePotential,
    CallablePotentialWrapper,
)
from sbi.types import Array, Shape, TorchTransform
from sbi.utils import gradient_ascent
from sbi.utils.torchutils import ensure_theta_batched, process_device
@@ -24,20 +29,36 @@ class NeuralPosterior(ABC):

    def __init__(
        self,
        potential_fn: Callable,
        potential_fn: Union[Callable, BasePotential],
        theta_transform: Optional[TorchTransform] = None,
        device: Optional[str] = None,
        x_shape: Optional[torch.Size] = None,
    ):
        """
        Args:
            potential_fn: The potential function from which to draw samples.
            potential_fn: The potential function from which to draw samples. Must be a
                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
            theta_transform: Transformation that will be applied during sampling.
                Allows one to perform, e.g., MCMC in unconstrained space.
            device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                `potential_fn.device` is used.
            x_shape: Shape of the observed data.
        """
        if not isinstance(potential_fn, BasePotential):
            kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
            for key in ["theta", "x_o"]:
                assert key in kwargs_of_callable, (
                    "If you pass a `Callable` as `potential_fn` then it must have "
                    "`theta` and `x_o` as inputs, even if some of these keyword "
                    "arguments are unused."
                )

            # If the `potential_fn` is a Callable then we wrap it as a
            # `CallablePotentialWrapper` which inherits from `BasePotential`.
            potential_device = "cpu" if device is None else device
            potential_fn = CallablePotentialWrapper(
                potential_fn, prior=None, x_o=None, device=potential_device
            )

        # Ensure device string.
        self._device = process_device(potential_fn.device if device is None else device)
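In effect, any plain function that accepts `theta` and `x_o` as keyword arguments can now be passed wherever a `BasePotential` was previously required. A minimal sketch of the new usage (mirroring the test added at the bottom of this PR; the Gaussian potential is illustrative only):

```python
from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal

from sbi.inference import MCMCPosterior

# Any callable taking `theta` and `x_o` qualifies; internally it is
# wrapped into a `CallablePotentialWrapper`, so all samplers treat it
# like any other `BasePotential`.
def potential(theta, x_o):
    return MultivariateNormal(zeros(2), eye(2)).log_prob(theta + x_o)

proposal = MultivariateNormal(zeros(2), 5 * eye(2))
posterior = MCMCPosterior(potential_fn=potential, proposal=proposal)
samples = posterior.sample(
    (100,), x=ones(2), num_chains=10, method="slice_np_vectorized"
)
```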
6 changes: 4 additions & 2 deletions sbi/inference/posteriors/importance_posterior.py
@@ -7,6 +7,7 @@

from sbi import utils as utils
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.importance.importance_sampling import importance_sample
from sbi.samplers.importance.sir import sampling_importance_resampling
from sbi.types import Shape, TorchTransform
@@ -24,7 +25,7 @@ class ImportanceSamplingPosterior(NeuralPosterior):

    def __init__(
        self,
        potential_fn: Callable,
        potential_fn: Union[Callable, BasePotential],
        proposal: Any,
        theta_transform: Optional[TorchTransform] = None,
        method: str = "sir",
@@ -35,7 +36,8 @@ def __init__(
    ):
        """
        Args:
            potential_fn: The potential function from which to draw samples.
            potential_fn: The potential function from which to draw samples. Must be a
                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
            proposal: The proposal distribution.
            theta_transform: Transformation that is applied to parameters. Is not used
                during sampling but only when calling `.map()`.
6 changes: 4 additions & 2 deletions sbi/inference/posteriors/mcmc_posterior.py
@@ -18,6 +18,7 @@
from tqdm.auto import tqdm

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.mcmc import (
    IterateParameters,
    Slice,
@@ -41,7 +42,7 @@ class MCMCPosterior(NeuralPosterior):

    def __init__(
        self,
        potential_fn: Callable,
        potential_fn: Union[Callable, BasePotential],
        proposal: Any,
        theta_transform: Optional[TorchTransform] = None,
        method: str = "slice_np",
@@ -57,7 +58,8 @@ def __init__(
    ):
        """
        Args:
            potential_fn: The potential function from which to draw samples.
            potential_fn: The potential function from which to draw samples. Must be a
                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
            proposal: Proposal distribution that is used to initialize the MCMC chain.
            theta_transform: Transformation that will be applied during sampling.
                Allows one to perform MCMC in unconstrained space.
6 changes: 4 additions & 2 deletions sbi/inference/posteriors/rejection_posterior.py
@@ -9,6 +9,7 @@

from sbi import utils as utils
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.rejection.rejection import rejection_sample
from sbi.types import Shape, TorchTransform
from sbi.utils.torchutils import ensure_theta_batched
@@ -22,7 +23,7 @@ class RejectionPosterior(NeuralPosterior):

    def __init__(
        self,
        potential_fn: Callable,
        potential_fn: Union[Callable, BasePotential],
        proposal: Any,
        theta_transform: Optional[TorchTransform] = None,
        max_sampling_batch_size: int = 10_000,
@@ -34,7 +35,8 @@ def __init__(
    ):
        """
        Args:
            potential_fn: The potential function from which to draw samples.
            potential_fn: The potential function from which to draw samples. Must be a
                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
            proposal: The proposal distribution.
            theta_transform: Transformation that is applied to parameters. Is not used
                during sampling but only when calling `.map()`.
17 changes: 11 additions & 6 deletions sbi/inference/posteriors/vi_posterior.py
@@ -11,6 +11,7 @@
from tqdm.auto import tqdm

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.vi import (
    adapt_variational_distribution,
    check_variational_distribution,
@@ -47,7 +48,7 @@ class VIPosterior(NeuralPosterior):

    def __init__(
        self,
        potential_fn: Callable,
        potential_fn: Union[Callable, BasePotential],
        prior: Optional[TorchDistribution] = None,
        q: Union[str, PyroTransformedDistribution, "VIPosterior", Callable] = "maf",
        theta_transform: Optional[TorchTransform] = None,
@@ -59,7 +60,8 @@ def __init__(
    ):
        """
        Args:
            potential_fn: The potential function from which to draw samples.
            potential_fn: The potential function from which to draw samples. Must be a
                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
            prior: This is the prior distribution. Note that this is only
                used to check/construct the variational distribution or within some
                quality metrics. Please make sure that this matches with the prior
@@ -115,7 +117,7 @@ def __init__(
            self._prior = q._prior
        else:
            raise ValueError(
                "We could not find a suitable prior distribution within `potential_fn`"
                "We could not find a suitable prior distribution within `potential_fn` "
                "or `q` (if a VIPosterior is given). Please explicitly specify a prior."
            )
        move_all_tensor_to_device(self._prior, device)
@@ -461,9 +463,12 @@ def train(
                self.evaluate(quality_control_metric=quality_control_metric)
            except Exception as e:
                print(
                    f"Quality control did not work, we reset the variational \
                    posterior,please check your setting. \
                    \n Following error occured {e}"
                    f"Quality control showed a low quality of the variational "
                    f"posterior. We are automatically retraining the variational "
                    f"posterior from scratch with a smaller learning rate. "
                    f"Alternatively, if you want to skip quality control, please "
                    f"retrain with `VIPosterior.train(..., quality_control=False)`. "
                    f"\nThe error that occurred is: {e}"
                )
                self.train(
                    learning_rate=learning_rate * 0.1,
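As the reworded message points out, the quality-control step can also be skipped outright; a hypothetical follow-up call on an already constructed `posterior`:

```python
# Train without the automatic quality check, and therefore without the
# retrain-at-lower-learning-rate fallback shown above.
posterior = posterior.train(quality_control=False)
```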
26 changes: 25 additions & 1 deletion sbi/inference/potentials/base_potential.py
@@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod
from typing import Optional

import torch
from torch import Tensor
from torch.distributions import Distribution

@@ -9,7 +10,10 @@

class BasePotential(metaclass=ABCMeta):
    def __init__(
        self, prior: Distribution, x_o: Optional[Tensor] = None, device: str = "cpu"
        self,
        prior: Optional[Distribution],
        x_o: Optional[Tensor] = None,
        device: str = "cpu",
    ):
        """Initialize potential function.

@@ -61,3 +65,23 @@ def return_x_o(self) -> Optional[Tensor]:
        `self._x_o` is `None`.
        """
        return self._x_o


class CallablePotentialWrapper(BasePotential):
    """If `potential_fn` is a callable it gets wrapped as this."""

    allow_iid_x = True  # type: ignore

    def __init__(
        self,
        callable_potential,
        prior: Optional[Distribution],
        x_o: Optional[Tensor] = None,
        device: str = "cpu",
    ):
        super().__init__(prior, x_o, device)
        self.callable_potential = callable_potential

    def __call__(self, theta, track_gradients: bool = True):
        with torch.set_grad_enabled(track_gradients):
            return self.callable_potential(theta=theta, x_o=self.x_o)
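The wrapper only stores the callable and forwards `theta` together with the stored `x_o`, toggling autograd via `torch.set_grad_enabled`. A small sketch of that behavior in isolation, constructing the wrapper directly (user code normally never needs to; this assumes the `x_o` passed to `BasePotential.__init__` is exposed as the `x_o` attribute that `__call__` above reads):

```python
import torch
from torch.distributions import MultivariateNormal

from sbi.inference.potentials.base_potential import CallablePotentialWrapper

density = MultivariateNormal(torch.zeros(2), torch.eye(2))
wrapped = CallablePotentialWrapper(
    lambda theta, x_o: density.log_prob(theta + x_o),
    prior=None,  # a prior is not required for a plain callable
    x_o=torch.ones(2),
)

theta = torch.zeros(1, 2, requires_grad=True)
assert wrapped(theta).requires_grad  # gradients tracked by default
assert not wrapped(theta, track_gradients=False).requires_grad
```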
4 changes: 2 additions & 2 deletions sbi/inference/potentials/likelihood_based_potential.py
@@ -96,7 +96,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
            track_gradients=track_gradients,
        )

        return log_likelihood_trial_sum + self.prior.log_prob(theta)
        return log_likelihood_trial_sum + self.prior.log_prob(theta)  # type: ignore


def _log_likelihoods_over_trials(
@@ -201,4 +201,4 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
            self.x_o.shape[0], -1
        ).sum(0)

        return log_likelihood_trial_sum + self.prior.log_prob(theta)
        return log_likelihood_trial_sum + self.prior.log_prob(theta)  # type: ignore
2 changes: 1 addition & 1 deletion sbi/inference/potentials/ratio_based_potential.py
@@ -92,7 +92,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        )

        # Move to cpu for comparison with prior.
        return log_likelihood_trial_sum + self.prior.log_prob(theta)
        return log_likelihood_trial_sum + self.prior.log_prob(theta)  # type: ignore


def _log_ratios_over_trials(
61 changes: 61 additions & 0 deletions tests/potential_test.py
@@ -0,0 +1,61 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import pytest
import torch
from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal

from sbi.inference import (
    ImportanceSamplingPosterior,
    MCMCPosterior,
    RejectionPosterior,
    VIPosterior,
)


@pytest.mark.parametrize(
    "sampling_method",
    [ImportanceSamplingPosterior, MCMCPosterior, RejectionPosterior, VIPosterior],
)
def test_callable_potential(sampling_method):
"""Test whether callable potentials can be used to sample from a Gaussian."""
dim = 2
mean = 2.5
cov = 2.0
x_o = 1 * ones((dim,))
target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim))

def potential(theta, x_o):
return target_density.log_prob(theta + x_o)

proposal = MultivariateNormal(zeros((dim,)), 5 * eye(dim))

if sampling_method == ImportanceSamplingPosterior:
approx_density = sampling_method(
potential_fn=potential, proposal=proposal, method="sir"
)
approx_samples = approx_density.sample((1024,), oversampling_factor=1024, x=x_o)
elif sampling_method == MCMCPosterior:
approx_density = sampling_method(potential_fn=potential, proposal=proposal)
approx_samples = approx_density.sample(
(1024,), x=x_o, num_chains=100, method="slice_np_vectorized"
)
elif sampling_method == VIPosterior:
approx_density = sampling_method(
potential_fn=potential, prior=proposal
).set_default_x(x_o)
approx_density = approx_density.train()
approx_samples = approx_density.sample((1024,))
elif sampling_method == RejectionPosterior:
approx_density = sampling_method(
potential_fn=potential, proposal=proposal
).set_default_x(x_o)
approx_samples = approx_density.sample((1024,))

sample_mean = torch.mean(approx_samples, dim=0)
sample_std = torch.std(approx_samples, dim=0)
assert torch.allclose(sample_mean, torch.as_tensor(mean) - x_o, atol=0.2)
assert torch.allclose(sample_std, torch.sqrt(torch.as_tensor(cov)), atol=0.1)