From 06f0b864199962e54b91903406fc11d6a799671b Mon Sep 17 00:00:00 2001
From: Jan Boelts
Date: Wed, 17 Jan 2024 09:51:30 +0100
Subject: [PATCH] refactor abc classes

---
 sbi/inference/abc/abc_base.py | 36 ++++++++++++--------
 sbi/inference/abc/mcabc.py    | 23 +++++++-----
 sbi/inference/abc/smcabc.py   | 81 ++++++++++++++++++++++-------------------
 3 files changed, 87 insertions(+), 53 deletions(-)

diff --git a/sbi/inference/abc/abc_base.py b/sbi/inference/abc/abc_base.py
index 3da6c7e47..3540ffac9 100644
--- a/sbi/inference/abc/abc_base.py
+++ b/sbi/inference/abc/abc_base.py
@@ -1,3 +1,4 @@
+"""Base class for Approximate Bayesian Computation methods."""
 import logging
 from abc import ABC
 from typing import Callable, Union
@@ -12,6 +13,8 @@


 class ABCBASE(ABC):
+    """Base class for Approximate Bayesian Computation methods."""
+
     def __init__(
         self,
         simulator: Callable,
@@ -47,6 +50,9 @@ def __init__(
         self._simulator = simulator
         self._show_progress_bars = show_progress_bars

+        self.x_o = None
+        self.x_shape = None
+
         # Select distance function.
         self.distance = self.get_distance_function(distance)

@@ -77,19 +83,25 @@ def get_distance_function(distance_type: Union[str, Callable] = "l2") -> Callabl
         if isinstance(distance_type, Callable):
             return distance_type

-        distances = ["l1", "l2", "mse"]
+        # Check that the requested distance is implemented.
+        implemented_distances = ["l1", "l2", "mse"]
         assert (
-            distance_type in distances
-        ), f"{distance_type} must be one of {distances}."
-
-        if distance_type == "mse":
-            distance = lambda xo, x: torch.mean((xo - x) ** 2, dim=-1)
-        elif distance_type == "l2":
-            distance = lambda xo, x: torch.norm((xo - x), dim=-1)
-        elif distance_type == "l1":
-            distance = lambda xo, x: torch.mean(abs(xo - x), dim=-1)
-        else:
-            raise ValueError(r"Distance {distance_type} not supported.")
+            distance_type in implemented_distances
+        ), f"{distance_type} must be one of {implemented_distances}."
+
+        def mse_distance(xo, x):
+            return torch.mean((xo - x) ** 2, dim=-1)
+
+        def l2_distance(xo, x):
+            return torch.norm((xo - x), dim=-1)
+
+        def l1_distance(xo, x):
+            return torch.mean(abs(xo - x), dim=-1)
+
+        distance_functions = {"mse": mse_distance, "l2": l2_distance, "l1": l1_distance}
+
+        # The assert above guarantees that the lookup cannot fail.
+        distance = distance_functions[distance_type]

         def distance_fun(observed_data: Tensor, simulated_data: Tensor) -> Tensor:
             """Return distance over batch dimension.
diff --git a/sbi/inference/abc/mcabc.py b/sbi/inference/abc/mcabc.py
index 30c6e01aa..1b09d6a46 100644
--- a/sbi/inference/abc/mcabc.py
+++ b/sbi/inference/abc/mcabc.py
@@ -1,3 +1,4 @@
+"""Monte-Carlo Approximate Bayesian Computation (Rejection ABC)."""
 from typing import Any, Callable, Dict, Optional, Tuple, Union

 import torch
@@ -9,6 +10,8 @@


 class MCABC(ABCBASE):
+    """Monte-Carlo Approximate Bayesian Computation (Rejection ABC)."""
+
     def __init__(
         self,
         simulator: Callable,
@@ -64,7 +67,7 @@ def __call__(
         sass_fraction: float = 0.25,
         sass_expansion_degree: int = 1,
         kde: bool = False,
-        kde_kwargs: Dict[str, Any] = {},
+        kde_kwargs: Optional[Dict[str, Any]] = None,
         return_summary: bool = False,
     ) -> Union[Tuple[Tensor, dict], Tuple[KDEWrapper, dict], Tensor, KDEWrapper]:
         r"""Run MCABC and return accepted parameters or KDE object fitted on them.
@@ -107,12 +110,14 @@ def __call__(
         assert (eps is not None) ^ (
             quantile is not None
         ), "Eps or quantile must be passed, but not both."
+        if kde_kwargs is None:
+            kde_kwargs = {}

         # Run SASS and change the simulator and x_o accordingly.
         if sass:
             num_pilot_simulations = int(sass_fraction * num_simulations)
             self.logger.info(
-                f"Running SASS with {num_pilot_simulations} pilot samples."
+                "Running SASS with %s pilot samples.", num_pilot_simulations
             )
             num_simulations -= num_pilot_simulations

@@ -123,7 +128,10 @@ def __call__(
                 pilot_theta, pilot_x, sass_expansion_degree
             )

-            simulator = lambda theta: sass_transform(self._batched_simulator(theta))
+            # Add sass transform to simulator and x_o.
+            def simulator(theta):
+                return sass_transform(self._batched_simulator(theta))
+
             x_o = sass_transform(x_o)
         else:
             simulator = self._batched_simulator
@@ -168,10 +176,11 @@ def __call__(

         if kde:
             self.logger.info(
-                f"""KDE on {final_theta.shape[0]} samples with bandwidth option
-                {kde_kwargs["bandwidth"] if "bandwidth" in kde_kwargs else "cv"}.
-                Beware that KDE can give unreliable results when used with too few
-                samples and in high dimensions."""
+                """KDE on %s samples with bandwidth option %s. Beware that KDE can give
+                unreliable results when used with too few samples and in high
+                dimensions.""",
+                final_theta.shape[0],
+                kde_kwargs.get("bandwidth", "cv"),
             )
             kde_dist = get_kde(final_theta, **kde_kwargs)

diff --git a/sbi/inference/abc/smcabc.py b/sbi/inference/abc/smcabc.py
index 0f0eceb9f..3ece908f4 100644
--- a/sbi/inference/abc/smcabc.py
+++ b/sbi/inference/abc/smcabc.py
@@ -1,18 +1,20 @@
+"""Sequential Monte Carlo Approximate Bayesian Computation."""
 from typing import Any, Callable, Dict, Optional, Tuple, Union

 import numpy as np
 import torch
 from numpy import ndarray
-from pyro.distributions import Uniform
-from torch import Tensor, ones, tensor
+from torch import Tensor
 from torch.distributions import Distribution, Multinomial, MultivariateNormal

 from sbi.inference.abc.abc_base import ABCBASE
 from sbi.types import Array
-from sbi.utils import KDEWrapper, get_kde, process_x, within_support
+from sbi.utils import BoxUniform, KDEWrapper, get_kde, process_x, within_support


 class SMCABC(ABCBASE):
+    """Sequential Monte Carlo Approximate Bayesian Computation."""
+
     def __init__(
         self,
         simulator: Callable,
@@ -85,6 +87,7 @@ def __init__(
         self.distance_to_x0 = None
         self.simulation_counter = 0
         self.num_simulations = 0
+        self.kernel_variance = None

         # Define simulator that keeps track of budget.
         def simulate_with_budget(theta):
@@ -106,7 +109,7 @@ def __call__(
         use_last_pop_samples: bool = True,
         return_summary: bool = False,
         kde: bool = False,
-        kde_kwargs: Dict[str, Any] = {},
+        kde_kwargs: Optional[Dict[str, Any]] = None,
         kde_sample_weights: bool = False,
         lra: bool = False,
         lra_with_weights: bool = False,
@@ -135,7 +138,7 @@ def __call__(
             lra: Whether to run linear regression adjustment as in Beaumont et al. 2002
             lra_with_weights: Whether to run lra as weighted linear regression with
                 SMC weights
-            sass: Whether to determine semi-automatic summary statistics as in
+            sass: Whether to determine semi-automatic summary statistics (sass) as in
                 Fearnhead & Prangle 2012.
             sass_fraction: Fraction of simulation budget used for the initial sass run.
             sass_expansion_degree: Degree of the polynomial feature expansion for the
@@ -165,12 +168,15 @@ def __call__(
         pop_idx = 0
         self.num_simulations = num_simulations
+        if kde_kwargs is None:
+            kde_kwargs = {}
+
         assert isinstance(epsilon_decay, float) and epsilon_decay > 0.0

         # Pilot run for SASS.
         if sass:
             num_pilot_simulations = int(sass_fraction * num_simulations)
             self.logger.info(
-                f"Running SASS with {num_pilot_simulations} pilot samples."
+                "Running SASS with %s pilot samples.", num_pilot_simulations
             )
             sass_transform = self.run_sass_set_xo(
                 num_particles, num_pilot_simulations, x_o, lra, sass_expansion_degree
@@ -188,12 +194,13 @@ def sass_simulator(theta):
         particles, epsilon, distances, x = self._set_xo_and_sample_initial_population(
             x_o, num_particles, num_initial_pop
         )
-        log_weights = torch.log(1 / num_particles * ones(num_particles))
+        log_weights = torch.log(1 / num_particles * torch.ones(num_particles))

         self.logger.info(
-            (
-                f"population={pop_idx}, eps={epsilon}, ess={1.0}, "
-                f"num_sims={num_initial_pop}"
-            )
+            "population=%s, eps=%s, ess=%s, num_sims=%s",
+            pop_idx,
+            epsilon,
+            1.0,
+            num_initial_pop,
         )

@@ -238,8 +245,8 @@ def sass_simulator(theta):

             self.logger.info(
-                (
-                    f"population={pop_idx} done: eps={epsilon:.6f},"
-                    f" num_sims={self.simulation_counter}."
-                )
+                "population=%s done: eps=%.6f, num_sims=%s.",
+                pop_idx,
+                epsilon,
+                self.simulation_counter,
             )

@@ -253,7 +260,7 @@ def sass_simulator(theta):
         # Maybe run LRA and adjust weights.
         if lra:
             self.logger.info("Running Linear regression adjustment.")
-            adjusted_particles, adjusted_weights = self.run_lra_update_weights(
+            adjusted_particles, _ = self.run_lra_update_weights(
                 particles=all_particles[-1],
                 xs=all_x[-1],
                 observation=process_x(x_o),
@@ -266,10 +273,11 @@ def sass_simulator(theta):

         if kde:
             self.logger.info(
-                f"""KDE on {final_particles.shape[0]} samples with bandwidth option
-                {kde_kwargs["bandwidth"] if "bandwidth" in kde_kwargs else "cv"}.
-                Beware that KDE can give unreliable results when used with too few
-                samples and in high dimensions."""
+                """KDE on %s samples with bandwidth option %s. Beware that KDE can give
+                unreliable results when used with too few samples and in high
+                dimensions.""",
+                final_particles.shape[0],
+                kde_kwargs.get("bandwidth", "cv"),
             )
             # Maybe get particles weights from last population for weighted KDE.
             if kde_sample_weights:
@@ -398,8 +406,9 @@ def _sample_next_population(
             if use_last_pop_samples:
                 num_remaining = num_particles - num_accepted_particles
                 self.logger.info(
-                    f"""Simulation Budget exceeded, filling up with {num_remaining}
-                    samples from last population."""
+                    """Simulation budget exceeded, filling up with %s
+                    samples from last population.""",
+                    num_remaining,
                 )
                 # Some new particles have been accepted already, therefore
-                # fill up the remaining once with old particles and weights.
+                # fill up the remaining ones with old particles and weights.
@@ -467,8 +476,8 @@ def _get_next_epsilon(self, distances: Tensor, quantile: float) -> float:
             except IndexError:
                 self.logger.warning(
-                    (
-                        f"Accepted unique distances={distances} don't match "
-                        f"quantile={quantile:.2f}. Selecting last distance."
-                    )
+                    "Accepted unique distances=%s don't match quantile=%.2f. "
+                    "Selecting last distance.",
+                    distances,
+                    quantile,
                 )
                 qidx = -1
@@ -494,7 +503,7 @@ def kernel_log_prob(new_particle):

         # We still have to loop over particles here because
         # the kernel log probs are already batched across old particles.
-        log_weighted_sum = tensor(
+        log_weighted_sum = torch.tensor(
             [
                 torch.logsumexp(old_log_weights + kernel_log_prob(new_particle), dim=0)
                 for new_particle in new_particles
@@ -552,6 +561,7 @@ def get_kernel_variance(
         samples_per_dim: int = 100,
         kernel_variance_scale: float = 1.0,
     ) -> Tensor:
+        """Return kernel variance for a given population of particles and weights."""
         if self.kernel == "gaussian":
             # For variant C, Beaumont et al. 2009, the kernel variance comes from the
             # previous population.
@@ -563,8 +573,8 @@ def get_kernel_variance(
             )
             # Make sure variance is nonsingular.
             try:
-                torch.cholesky(kernel_variance_scale * population_cov)
+                torch.linalg.cholesky(kernel_variance_scale * population_cov)
             except RuntimeError:
                 self.logger.warning(
-                    """"Singular particle covariance, using unit covariance."""
+                    "Singular particle covariance, using unit covariance."
                 )
@@ -591,6 +601,7 @@ def get_new_kernel(self, thetas: Tensor) -> Distribution:
-        """Return new kernel distribution for a given set of paramters."""
+        """Return new kernel distribution for a given set of parameters."""

         if self.kernel == "gaussian":
+            assert self.kernel_variance is not None, "Call get_kernel_variance() first."
             assert self.kernel_variance.ndim == 2
             return MultivariateNormal(
                 loc=thetas, covariance_matrix=self.kernel_variance
@@ -601,7 +612,7 @@ def get_new_kernel(self, thetas: Tensor) -> Distribution:
             low = thetas - self.kernel_variance
             high = thetas + self.kernel_variance
             # Move batch shape to event shape to get Uniform that is multivariate in
             # parameter dimension.
-            return Uniform(low=low, high=high).to_event(1)
+            return BoxUniform(low=low, high=high)
         else:
             raise ValueError(f"Kernel, '{self.kernel}' not supported.")
@@ -620,12 +631,12 @@ def resample_if_ess_too_small(
         ess = (1 / torch.sum(torch.exp(2.0 * log_weights), dim=0)) / num_particles
         # Resampling of weights for low ESS only for Sisson et al. 2007.
         if ess < ess_min:
-            self.logger.info(f"ESS={ess:.2f} too low, resampling pop {pop_idx}...")
+            self.logger.info("ESS=%.2f too low, resampling pop %s...", ess, pop_idx)
             # First resample, then set to uniform weights as in Sisson et al. 2007.
             particles = self.sample_from_population_with_weights(
                 particles, torch.exp(log_weights), num_samples=num_particles
             )
-            log_weights = torch.log(1 / num_particles * ones(num_particles))
+            log_weights = torch.log(1 / num_particles * torch.ones(num_particles))

         return particles, log_weights

@@ -684,6 +695,8 @@ def run_sass_set_xo(
         ) = self._set_xo_and_sample_initial_population(
             x_o, num_particles, num_pilot_simulations
         )
+        assert self.x_o is not None, "x_o not set yet."
+
         # Adjust with LRA.
         if lra:
            pilot_particles = self.run_lra(pilot_particles, pilot_xs, self.x_o)
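
Note on the distance refactor in abc_base.py: the if/elif chain becomes small
named functions behind a dict lookup, validated once up front. A minimal,
self-contained sketch of the same dispatch pattern (the standalone name
`select_distance` is illustrative, not part of sbi):

    import torch

    def select_distance(name: str):
        # Map implemented distance names to batched distance functions.
        distances = {
            "mse": lambda xo, x: torch.mean((xo - x) ** 2, dim=-1),
            "l2": lambda xo, x: torch.norm(xo - x, dim=-1),
            "l1": lambda xo, x: torch.mean(torch.abs(xo - x), dim=-1),
        }
        assert name in distances, f"{name} must be one of {list(distances)}."
        return distances[name]

    xo, x = torch.zeros(5, 3), torch.ones(5, 3)
    print(select_distance("l2")(xo, x))  # five entries of sqrt(3)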
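
Note on the `kde_kwargs` signature change in mcabc.py and smcabc.py: a `{}`
default is created once at function definition time and shared across calls,
so mutations leak between calls; `None` plus an in-body check yields a fresh
dict every time. A self-contained demonstration (function names are
hypothetical, not from sbi):

    from typing import Any, Dict, Optional

    def with_mutable_default(kde_kwargs: Dict[str, Any] = {}) -> Dict[str, Any]:
        kde_kwargs.setdefault("bandwidth", "cv")
        return kde_kwargs

    def with_none_default(kde_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        if kde_kwargs is None:
            kde_kwargs = {}  # fresh dict on every call
        kde_kwargs.setdefault("bandwidth", "cv")
        return kde_kwargs

    assert with_mutable_default() is with_mutable_default()  # one shared dict
    assert with_none_default() is not with_none_default()  # independent dicts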
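
Note on the logging changes: %-style placeholders with arguments passed
separately defer string interpolation until a handler actually emits the
record and keep the message template constant. Passing a single tuple instead
would log the tuple's repr, which is why the tuple-wrapped calls were
unwrapped above. A short illustration (the logger name is arbitrary):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("sbi.inference.abc")

    pop_idx, epsilon, num_sims = 0, 0.00123456, 1000
    # Template and args as separate arguments; formatted only if emitted.
    logger.info("population=%s done: eps=%.6f, num_sims=%s.", pop_idx, epsilon, num_sims)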
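
Note on `torch.cholesky` -> `torch.linalg.cholesky`: the old top-level
function is deprecated, and in current PyTorch the replacement raises an error
class that subclasses RuntimeError, so the existing except clause keeps
working as a nonsingularity check. A minimal sketch of that guard:

    import torch

    cov = torch.eye(3)
    cov[0, 0] = 0.0  # make the covariance singular
    try:
        torch.linalg.cholesky(cov)
    except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
        cov = torch.eye(3)  # fall back to unit covariance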
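
Note on the uniform kernel change: `pyro.distributions.Uniform(...).to_event(1)`
is replaced by sbi's own `BoxUniform`, a torch-only multivariate uniform that
treats the parameter dimension as the event dimension, dropping the pyro
dependency here. A quick shape check, assuming sbi is installed:

    import torch
    from sbi.utils import BoxUniform

    kernel = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
    print(kernel.event_shape)  # torch.Size([3])
    print(kernel.log_prob(torch.rand(10, 3)).shape)  # torch.Size([10])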