From 9aff544907509cb91c8833bf052bc8db1c26afba Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 14 Feb 2024 19:03:05 +0100 Subject: [PATCH 1/4] First working draft of potential_fn being a Callable --- sbi/inference/posteriors/base_posterior.py | 13 +++++++++++++ sbi/inference/posteriors/mcmc_posterior.py | 1 + sbi/inference/potentials/base_potential.py | 2 +- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 888a30fab..89ccab1ae 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -12,6 +12,7 @@ from sbi.utils import gradient_ascent from sbi.utils.torchutils import ensure_theta_batched, process_device from sbi.utils.user_input_checks import process_x +from sbi.inference.potentials.base_potential import BasePotential class NeuralPosterior(ABC): @@ -38,6 +39,18 @@ def __init__( `potential_fn.device` is used. x_shape: Shape of the observed data. """ + if not isinstance(potential_fn, BasePotential): + callable_potential = potential_fn + + class CallablePotentialWrapper(BasePotential): + """If `potential_fn` is a callable it gets wrapped as this.""" + allow_iid_x = True + + def __call__(self, theta, **kwargs): + return callable_potential(theta=theta, x_o=self.x_o, **kwargs) + + potential_fn = CallablePotentialWrapper(None, None) + # Ensure device string. self._device = process_device(potential_fn.device if device is None else device) diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 784309702..c1586a370 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -31,6 +31,7 @@ from sbi.types import Shape, TorchTransform from sbi.utils import pyro_potential_wrapper, tensor2numpy, transformed_potential from sbi.utils.torchutils import ensure_theta_batched +from sbi.inference.potentials.base_potential import BasePotential class MCMCPosterior(NeuralPosterior): diff --git a/sbi/inference/potentials/base_potential.py b/sbi/inference/potentials/base_potential.py index 5c1fe32d7..b99645fda 100644 --- a/sbi/inference/potentials/base_potential.py +++ b/sbi/inference/potentials/base_potential.py @@ -9,7 +9,7 @@ class BasePotential(metaclass=ABCMeta): def __init__( - self, prior: Distribution, x_o: Optional[Tensor] = None, device: str = "cpu" + self, prior: Optional[Distribution], x_o: Optional[Tensor] = None, device: str = "cpu" ): """Initialize potential function. From cb53e67f9098e62e875971b2e01807c59642c756 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Thu, 15 Feb 2024 08:40:30 +0100 Subject: [PATCH 2/4] Test for callable potential --- sbi/inference/posteriors/vi_posterior.py | 11 +++-- tests/potential_test.py | 60 ++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 4 deletions(-) create mode 100644 tests/potential_test.py diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index 4bd29a4a3..95529b7dd 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -115,7 +115,7 @@ def __init__( self._prior = q._prior else: raise ValueError( - "We could not find a suitable prior distribution within `potential_fn`" + "We could not find a suitable prior distribution within `potential_fn` " "or `q` (if a VIPosterior is given). Please explicitly specify a prior." ) move_all_tensor_to_device(self._prior, device) @@ -461,9 +461,12 @@ def train( self.evaluate(quality_control_metric=quality_control_metric) except Exception as e: print( - f"Quality control did not work, we reset the variational \ - posterior,please check your setting. \ - \n Following error occured {e}" + f"Quality control showed a low quality of the variational " + f"posterior. We are automatically retraining the variational " + f"posterior from scratch with a smaller learning rate. " + f"Alternatively, if you want to skip quality control, please " + f"retrain with `VIPosterior.train(..., quality_control=False)`. " + f"\nThe error that occured is: {e}" ) self.train( learning_rate=learning_rate * 0.1, diff --git a/tests/potential_test.py b/tests/potential_test.py new file mode 100644 index 000000000..4784e4534 --- /dev/null +++ b/tests/potential_test.py @@ -0,0 +1,60 @@ +# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed +# under the Affero General Public License v3, see . + +from __future__ import annotations + +import pytest +import torch +from torch import eye, ones, zeros +from torch.distributions import MultivariateNormal + +from sbi.inference import ( + ImportanceSamplingPosterior, + MCMCPosterior, + RejectionPosterior, + VIPosterior, +) + + +@pytest.mark.parametrize( + "sampling_method", + [ImportanceSamplingPosterior, MCMCPosterior, RejectionPosterior, VIPosterior], +) +def test_callable_potential(sampling_method): + dim = 2 + mean = 2.5 + cov = 2.0 + x_o = 1 * ones((dim,)) + target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim)) + + def potential(theta, x_o, **kwargs): + return target_density.log_prob(theta + x_o) + + proposal = MultivariateNormal(zeros((dim,)), 5 * eye(dim)) + + if sampling_method == ImportanceSamplingPosterior: + approx_density = sampling_method( + potential_fn=potential, proposal=proposal, method="sir" + ) + approx_samples = approx_density.sample((1024,), oversampling_factor=1024, x=x_o) + elif sampling_method == MCMCPosterior: + approx_density = sampling_method(potential_fn=potential, proposal=proposal) + approx_samples = approx_density.sample( + (1024,), x=x_o, num_chains=100, method="slice_np_vectorized" + ) + elif sampling_method == VIPosterior: + approx_density = sampling_method( + potential_fn=potential, prior=proposal + ).set_default_x(x_o) + approx_density = approx_density.train() + approx_samples = approx_density.sample((1024,)) + elif sampling_method == RejectionPosterior: + approx_density = sampling_method( + potential_fn=potential, proposal=proposal + ).set_default_x(x_o) + approx_samples = approx_density.sample((1024,)) + + sample_mean = torch.mean(approx_samples, dim=0) + sample_std = torch.std(approx_samples, dim=0) + assert torch.allclose(sample_mean, torch.as_tensor(mean) - x_o, atol=0.2) + assert torch.allclose(sample_std, torch.sqrt(torch.as_tensor(cov)), atol=0.1) From 2d556b0f17c02df94312591b223be0d26e60a1fb Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Thu, 15 Feb 2024 09:41:03 +0100 Subject: [PATCH 3/4] Move the CallablePotentialWrapper to make posteriors pickleable --- sbi/inference/posteriors/base_posterior.py | 19 +++++++-------- sbi/inference/posteriors/mcmc_posterior.py | 1 - sbi/inference/potentials/base_potential.py | 24 ++++++++++++++++++- .../potentials/likelihood_based_potential.py | 4 ++-- .../potentials/ratio_based_potential.py | 2 +- tests/potential_test.py | 1 + 6 files changed, 35 insertions(+), 16 deletions(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 89ccab1ae..1c2b52f2b 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -8,11 +8,14 @@ import torch.distributions.transforms as torch_tf from torch import Tensor +from sbi.inference.potentials.base_potential import ( + BasePotential, + CallablePotentialWrapper, +) from sbi.types import Array, Shape, TorchTransform from sbi.utils import gradient_ascent from sbi.utils.torchutils import ensure_theta_batched, process_device from sbi.utils.user_input_checks import process_x -from sbi.inference.potentials.base_potential import BasePotential class NeuralPosterior(ABC): @@ -41,16 +44,10 @@ def __init__( """ if not isinstance(potential_fn, BasePotential): callable_potential = potential_fn - - class CallablePotentialWrapper(BasePotential): - """If `potential_fn` is a callable it gets wrapped as this.""" - allow_iid_x = True - - def __call__(self, theta, **kwargs): - return callable_potential(theta=theta, x_o=self.x_o, **kwargs) - - potential_fn = CallablePotentialWrapper(None, None) - + potential_device = "cpu" if device is None else device + potential_fn = CallablePotentialWrapper( + callable_potential, prior=None, x_o=None, device=potential_device + ) # Ensure device string. self._device = process_device(potential_fn.device if device is None else device) diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index c1586a370..784309702 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -31,7 +31,6 @@ from sbi.types import Shape, TorchTransform from sbi.utils import pyro_potential_wrapper, tensor2numpy, transformed_potential from sbi.utils.torchutils import ensure_theta_batched -from sbi.inference.potentials.base_potential import BasePotential class MCMCPosterior(NeuralPosterior): diff --git a/sbi/inference/potentials/base_potential.py b/sbi/inference/potentials/base_potential.py index b99645fda..4a6bbdb3a 100644 --- a/sbi/inference/potentials/base_potential.py +++ b/sbi/inference/potentials/base_potential.py @@ -9,7 +9,10 @@ class BasePotential(metaclass=ABCMeta): def __init__( - self, prior: Optional[Distribution], x_o: Optional[Tensor] = None, device: str = "cpu" + self, + prior: Optional[Distribution], + x_o: Optional[Tensor] = None, + device: str = "cpu", ): """Initialize potential function. @@ -61,3 +64,22 @@ def return_x_o(self) -> Optional[Tensor]: `self._x_o` is `None`. """ return self._x_o + + +class CallablePotentialWrapper(BasePotential): + """If `potential_fn` is a callable it gets wrapped as this.""" + + allow_iid_x = True # type: ignore + + def __init__( + self, + callable_potential, + prior: Optional[Distribution], + x_o: Optional[Tensor] = None, + device: str = "cpu", + ): + super().__init__(prior, x_o, device) + self.callable_potential = callable_potential + + def __call__(self, theta, **kwargs): + return self.callable_potential(theta=theta, x_o=self.x_o, **kwargs) diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index 5967556c0..d3aa5e51d 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -96,7 +96,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: track_gradients=track_gradients, ) - return log_likelihood_trial_sum + self.prior.log_prob(theta) + return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore def _log_likelihoods_over_trials( @@ -201,4 +201,4 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: self.x_o.shape[0], -1 ).sum(0) - return log_likelihood_trial_sum + self.prior.log_prob(theta) + return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore diff --git a/sbi/inference/potentials/ratio_based_potential.py b/sbi/inference/potentials/ratio_based_potential.py index bd96df247..86a77e0a9 100644 --- a/sbi/inference/potentials/ratio_based_potential.py +++ b/sbi/inference/potentials/ratio_based_potential.py @@ -92,7 +92,7 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: ) # Move to cpu for comparison with prior. - return log_likelihood_trial_sum + self.prior.log_prob(theta) + return log_likelihood_trial_sum + self.prior.log_prob(theta) # type: ignore def _log_ratios_over_trials( diff --git a/tests/potential_test.py b/tests/potential_test.py index 4784e4534..4c38ff6ce 100644 --- a/tests/potential_test.py +++ b/tests/potential_test.py @@ -21,6 +21,7 @@ [ImportanceSamplingPosterior, MCMCPosterior, RejectionPosterior, VIPosterior], ) def test_callable_potential(sampling_method): + """Test whether callable potentials can be used to sample from a Gaussian.""" dim = 2 mean = 2.5 cov = 2.0 From db89fc7e1ff47f337e2b00ea2fdf927b3ef634f9 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Thu, 15 Feb 2024 11:15:08 +0100 Subject: [PATCH 4/4] Update docstring, ensure potential_fn signature --- sbi/inference/posteriors/base_posterior.py | 19 +++++++++++++++---- .../posteriors/importance_posterior.py | 6 ++++-- sbi/inference/posteriors/mcmc_posterior.py | 6 ++++-- .../posteriors/rejection_posterior.py | 6 ++++-- sbi/inference/posteriors/vi_posterior.py | 6 ++++-- sbi/inference/potentials/base_potential.py | 6 ++++-- tests/potential_test.py | 2 +- 7 files changed, 36 insertions(+), 15 deletions(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 1c2b52f2b..eafdddc41 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -1,6 +1,7 @@ # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed # under the Affero General Public License v3, see . +import inspect from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Optional, Union @@ -28,14 +29,15 @@ class NeuralPosterior(ABC): def __init__( self, - potential_fn: Callable, + potential_fn: Union[Callable, BasePotential], theta_transform: Optional[TorchTransform] = None, device: Optional[str] = None, x_shape: Optional[torch.Size] = None, ): """ Args: - potential_fn: The potential function from which to draw samples. + potential_fn: The potential function from which to draw samples. Must be a + `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs. theta_transform: Transformation that will be applied during sampling. Allows to perform, e.g. MCMC in unconstrained space. device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None, @@ -43,10 +45,19 @@ def __init__( x_shape: Shape of the observed data. """ if not isinstance(potential_fn, BasePotential): - callable_potential = potential_fn + kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys()) + for key in ["theta", "x_o"]: + assert key in kwargs_of_callable, ( + "If you pass a `Callable` as `potential_fn` then it must have " + "`theta` and `x_o` as inputs, even if some of these keyword " + "arguments are unused." + ) + + # If the `potential_fn` is a Callable then we wrap it as a + # `CallablePotentialWrapper` which inherits from `BasePotential`. potential_device = "cpu" if device is None else device potential_fn = CallablePotentialWrapper( - callable_potential, prior=None, x_o=None, device=potential_device + potential_fn, prior=None, x_o=None, device=potential_device ) # Ensure device string. diff --git a/sbi/inference/posteriors/importance_posterior.py b/sbi/inference/posteriors/importance_posterior.py index 9074267ad..ef71b100a 100644 --- a/sbi/inference/posteriors/importance_posterior.py +++ b/sbi/inference/posteriors/importance_posterior.py @@ -7,6 +7,7 @@ from sbi import utils as utils from sbi.inference.posteriors.base_posterior import NeuralPosterior +from sbi.inference.potentials.base_potential import BasePotential from sbi.samplers.importance.importance_sampling import importance_sample from sbi.samplers.importance.sir import sampling_importance_resampling from sbi.types import Shape, TorchTransform @@ -24,7 +25,7 @@ class ImportanceSamplingPosterior(NeuralPosterior): def __init__( self, - potential_fn: Callable, + potential_fn: Union[Callable, BasePotential], proposal: Any, theta_transform: Optional[TorchTransform] = None, method: str = "sir", @@ -35,7 +36,8 @@ def __init__( ): """ Args: - potential_fn: The potential function from which to draw samples. + potential_fn: The potential function from which to draw samples. Must be a + `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs. proposal: The proposal distribution. theta_transform: Transformation that is applied to parameters. Is not used during but only when calling `.map()`. diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index 784309702..515ee6aa1 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -18,6 +18,7 @@ from tqdm.auto import tqdm from sbi.inference.posteriors.base_posterior import NeuralPosterior +from sbi.inference.potentials.base_potential import BasePotential from sbi.samplers.mcmc import ( IterateParameters, Slice, @@ -41,7 +42,7 @@ class MCMCPosterior(NeuralPosterior): def __init__( self, - potential_fn: Callable, + potential_fn: Union[Callable, BasePotential], proposal: Any, theta_transform: Optional[TorchTransform] = None, method: str = "slice_np", @@ -57,7 +58,8 @@ def __init__( ): """ Args: - potential_fn: The potential function from which to draw samples. + potential_fn: The potential function from which to draw samples. Must be a + `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs. proposal: Proposal distribution that is used to initialize the MCMC chain. theta_transform: Transformation that will be applied during sampling. Allows to perform MCMC in unconstrained space. diff --git a/sbi/inference/posteriors/rejection_posterior.py b/sbi/inference/posteriors/rejection_posterior.py index 92d55bcc1..a1382a6be 100644 --- a/sbi/inference/posteriors/rejection_posterior.py +++ b/sbi/inference/posteriors/rejection_posterior.py @@ -9,6 +9,7 @@ from sbi import utils as utils from sbi.inference.posteriors.base_posterior import NeuralPosterior +from sbi.inference.potentials.base_potential import BasePotential from sbi.samplers.rejection.rejection import rejection_sample from sbi.types import Shape, TorchTransform from sbi.utils.torchutils import ensure_theta_batched @@ -22,7 +23,7 @@ class RejectionPosterior(NeuralPosterior): def __init__( self, - potential_fn: Callable, + potential_fn: Union[Callable, BasePotential], proposal: Any, theta_transform: Optional[TorchTransform] = None, max_sampling_batch_size: int = 10_000, @@ -34,7 +35,8 @@ def __init__( ): """ Args: - potential_fn: The potential function from which to draw samples. + potential_fn: The potential function from which to draw samples. Must be a + `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs. proposal: The proposal distribution. theta_transform: Transformation that is applied to parameters. Is not used during but only when calling `.map()`. diff --git a/sbi/inference/posteriors/vi_posterior.py b/sbi/inference/posteriors/vi_posterior.py index 95529b7dd..315a69950 100644 --- a/sbi/inference/posteriors/vi_posterior.py +++ b/sbi/inference/posteriors/vi_posterior.py @@ -11,6 +11,7 @@ from tqdm.auto import tqdm from sbi.inference.posteriors.base_posterior import NeuralPosterior +from sbi.inference.potentials.base_potential import BasePotential from sbi.samplers.vi import ( adapt_variational_distribution, check_variational_distribution, @@ -47,7 +48,7 @@ class VIPosterior(NeuralPosterior): def __init__( self, - potential_fn: Callable, + potential_fn: Union[Callable, BasePotential], prior: Optional[TorchDistribution] = None, q: Union[str, PyroTransformedDistribution, "VIPosterior", Callable] = "maf", theta_transform: Optional[TorchTransform] = None, @@ -59,7 +60,8 @@ def __init__( ): """ Args: - potential_fn: The potential function from which to draw samples. + potential_fn: The potential function from which to draw samples. Must be a + `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs. prior: This is the prior distribution. Note that this is only used to check/construct the variational distribution or within some quality metrics. Please make sure that this matches with the prior diff --git a/sbi/inference/potentials/base_potential.py b/sbi/inference/potentials/base_potential.py index 4a6bbdb3a..f0f0a6272 100644 --- a/sbi/inference/potentials/base_potential.py +++ b/sbi/inference/potentials/base_potential.py @@ -1,6 +1,7 @@ from abc import ABCMeta, abstractmethod from typing import Optional +import torch from torch import Tensor from torch.distributions import Distribution @@ -81,5 +82,6 @@ def __init__( super().__init__(prior, x_o, device) self.callable_potential = callable_potential - def __call__(self, theta, **kwargs): - return self.callable_potential(theta=theta, x_o=self.x_o, **kwargs) + def __call__(self, theta, track_gradients: bool = True): + with torch.set_grad_enabled(track_gradients): + return self.callable_potential(theta=theta, x_o=self.x_o) diff --git a/tests/potential_test.py b/tests/potential_test.py index 4c38ff6ce..9b5c9ae6f 100644 --- a/tests/potential_test.py +++ b/tests/potential_test.py @@ -28,7 +28,7 @@ def test_callable_potential(sampling_method): x_o = 1 * ones((dim,)) target_density = MultivariateNormal(mean * ones((dim,)), cov * eye(dim)) - def potential(theta, x_o, **kwargs): + def potential(theta, x_o): return target_density.log_prob(theta + x_o) proposal = MultivariateNormal(zeros((dim,)), 5 * eye(dim))