
Commit

refactor abc classes
janfb committed Jan 17, 2024
1 parent ed15539 commit 06f0b86
Showing 3 changed files with 84 additions and 42 deletions.
38 changes: 26 additions & 12 deletions sbi/inference/abc/abc_base.py
@@ -1,3 +1,4 @@
"""Base class for Approximate Bayesian Computation methods."""
import logging
from abc import ABC
from typing import Callable, Union
@@ -12,6 +13,8 @@


class ABCBASE(ABC):
"""Base class for Approximate Bayesian Computation methods."""

def __init__(
self,
simulator: Callable,
@@ -47,6 +50,9 @@ def __init__(
self._simulator = simulator
self._show_progress_bars = show_progress_bars

self.x_o = None
self.x_shape = None

# Select distance function.
self.distance = self.get_distance_function(distance)

@@ -77,19 +83,27 @@ def get_distance_function(distance_type: Union[str, Callable] = "l2") -> Callable:
if isinstance(distance_type, Callable):
return distance_type

distances = ["l1", "l2", "mse"]
# Select distance function.
implemented_distances = ["l1", "l2", "mse"]
assert (
distance_type in distances
), f"{distance_type} must be one of {distances}."

if distance_type == "mse":
distance = lambda xo, x: torch.mean((xo - x) ** 2, dim=-1)
elif distance_type == "l2":
distance = lambda xo, x: torch.norm((xo - x), dim=-1)
elif distance_type == "l1":
distance = lambda xo, x: torch.mean(abs(xo - x), dim=-1)
else:
raise ValueError(r"Distance {distance_type} not supported.")
distance_type in implemented_distances
), f"{distance_type} must be one of {implemented_distances}."

def mse_distance(xo, x):
return torch.mean((xo - x) ** 2, dim=-1)

def l2_distance(xo, x):
return torch.norm((xo - x), dim=-1)

def l1_distance(xo, x):
return torch.mean(abs(xo - x), dim=-1)

distance_functions = {"mse": mse_distance, "l2": l2_distance, "l1": l1_distance}

try:
distance = distance_functions[distance_type]
except KeyError as exc:
raise KeyError(f"Distance {distance_type} not supported.") from exc

def distance_fun(observed_data: Tensor, simulated_data: Tensor) -> Tensor:
"""Return distance over batch dimension.
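The hunk above replaces the lambda/if-elif chain with named distance functions selected from a dict. A standalone, runnable sketch of the same dispatch pattern (illustrative code, not the library itself; tensors are assumed to have shape (batch, dim)):

from typing import Callable, Dict

import torch
from torch import Tensor


def mse_distance(xo: Tensor, x: Tensor) -> Tensor:
    # Mean squared error over the data dimension.
    return torch.mean((xo - x) ** 2, dim=-1)


def l2_distance(xo: Tensor, x: Tensor) -> Tensor:
    # Euclidean norm of the residual over the data dimension.
    return torch.norm(xo - x, dim=-1)


def l1_distance(xo: Tensor, x: Tensor) -> Tensor:
    # Mean absolute error over the data dimension.
    return torch.mean(torch.abs(xo - x), dim=-1)


DISTANCE_FUNCTIONS: Dict[str, Callable[[Tensor, Tensor], Tensor]] = {
    "mse": mse_distance,
    "l2": l2_distance,
    "l1": l1_distance,
}


def get_distance(name: str) -> Callable[[Tensor, Tensor], Tensor]:
    # Dictionary dispatch with an informative KeyError, as in the refactor.
    try:
        return DISTANCE_FUNCTIONS[name]
    except KeyError as exc:
        raise KeyError(f"Distance {name} not supported.") from exc


xo = torch.zeros(1, 3)   # observation
x = torch.ones(5, 3)     # batch of five simulations
print(get_distance("l2")(xo, x))  # tensor of shape (5,)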
19 changes: 14 additions & 5 deletions sbi/inference/abc/mcabc.py
@@ -1,3 +1,4 @@
"""Monte-Carlo Approximate Bayesian Computation (Rejection ABC)."""
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
Expand All @@ -9,6 +10,8 @@


class MCABC(ABCBASE):
"""Monte-Carlo Approximate Bayesian Computation (Rejection ABC)."""

def __init__(
self,
simulator: Callable,
@@ -64,7 +67,7 @@ def __call__(
sass_fraction: float = 0.25,
sass_expansion_degree: int = 1,
kde: bool = False,
kde_kwargs: Dict[str, Any] = {},
kde_kwargs: Optional[Dict[str, Any]] = None,
return_summary: bool = False,
) -> Union[Tuple[Tensor, dict], Tuple[KDEWrapper, dict], Tensor, KDEWrapper]:
r"""Run MCABC and return accepted parameters or KDE object fitted on them.
@@ -107,12 +110,14 @@ def __call__(
assert (eps is not None) ^ (
quantile is not None
), "Eps or quantile must be passed, but not both."
if kde_kwargs is None:
kde_kwargs = {}

# Run SASS and change the simulator and x_o accordingly.
if sass:
num_pilot_simulations = int(sass_fraction * num_simulations)
self.logger.info(
f"Running SASS with {num_pilot_simulations} pilot samples."
"Running SASS with %s pilot samples.", num_pilot_simulations
)
num_simulations -= num_pilot_simulations
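The change from `kde_kwargs: Dict[str, Any] = {}` to an `Optional[...] = None` default (with the `if kde_kwargs is None` guard above) avoids Python's shared-mutable-default pitfall. A tiny sketch of the difference, in generic Python rather than sbi code:

from typing import Any, Dict, Optional


def risky(settings: Dict[str, Any] = {}) -> Dict[str, Any]:
    # The default dict is created once at definition time and shared by all calls.
    settings["calls"] = settings.get("calls", 0) + 1
    return settings


def safe(settings: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # A fresh dict is created for every call that does not pass one.
    if settings is None:
        settings = {}
    settings["calls"] = settings.get("calls", 0) + 1
    return settings


print(risky(), risky())  # {'calls': 2} {'calls': 2}  <- state leaks across calls
print(safe(), safe())    # {'calls': 1} {'calls': 1}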

@@ -123,7 +128,10 @@ def __call__(
pilot_theta, pilot_x, sass_expansion_degree
)

simulator = lambda theta: sass_transform(self._batched_simulator(theta))
# Add sass transform to simulator and x_o.
def simulator(theta):
return sass_transform(self._batched_simulator(theta))

x_o = sass_transform(x_o)
else:
simulator = self._batched_simulator
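In the SASS branch above, the batched simulator is wrapped so that every simulation is pushed through the fitted summary-statistics transform, and x_o is mapped with the same transform. A minimal sketch of that wrapping with a toy linear transform (names and shapes are illustrative):

import torch
from torch import Tensor


def batched_simulator(theta: Tensor) -> Tensor:
    # Toy simulator: noisy identity, batched over the rows of theta.
    return theta + 0.1 * torch.randn_like(theta)


def make_sass_transform(weights: Tensor):
    # Stand-in for the transform fitted on pilot simulations: a linear map.
    def transform(x: Tensor) -> Tensor:
        return x @ weights

    return transform


sass_transform = make_sass_transform(torch.ones(3, 1))


def simulator(theta: Tensor) -> Tensor:
    # Wrapped simulator: simulate, then summarize.
    return sass_transform(batched_simulator(theta))


x_o = sass_transform(torch.zeros(1, 3))    # observation, summarized the same way
print(simulator(torch.zeros(5, 3)).shape)  # torch.Size([5, 1])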
@@ -168,10 +176,11 @@ def __call__(

if kde:
self.logger.info(
f"""KDE on {final_theta.shape[0]} samples with bandwidth option
"""KDE on %s samples with bandwidth option
{kde_kwargs["bandwidth"] if "bandwidth" in kde_kwargs else "cv"}.
Beware that KDE can give unreliable results when used with too few
samples and in high dimensions."""
samples and in high dimensions.""",
final_theta.shape[0],
)

kde_dist = get_kde(final_theta, **kde_kwargs)
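Throughout the commit, f-strings inside logger calls are replaced by %-style arguments. A short sketch of the practical difference (standard library logging, not sbi-specific):

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("abc_example")

num_pilot_simulations = 1_000

# Eager: the f-string is built even though INFO is below the WARNING threshold.
logger.info(f"Running SASS with {num_pilot_simulations} pilot samples.")

# Lazy: the format string and argument are stored on the log record and only
# interpolated if a handler actually emits the message.
logger.info("Running SASS with %s pilot samples.", num_pilot_simulations)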
69 changes: 44 additions & 25 deletions sbi/inference/abc/smcabc.py
@@ -1,18 +1,20 @@
"""Sequential Monte Carlo Approximate Bayesian Computation."""
from typing import Any, Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
from numpy import ndarray
from pyro.distributions import Uniform
from torch import Tensor, ones, tensor
from torch import Tensor
from torch.distributions import Distribution, Multinomial, MultivariateNormal

from sbi.inference.abc.abc_base import ABCBASE
from sbi.types import Array
from sbi.utils import KDEWrapper, get_kde, process_x, within_support
from sbi.utils import BoxUniform, KDEWrapper, get_kde, process_x, within_support


class SMCABC(ABCBASE):
"""Sequential Monte Carlo Approximate Bayesian Computation."""

def __init__(
self,
simulator: Callable,
@@ -85,6 +87,7 @@ def __init__(
self.distance_to_x0 = None
self.simulation_counter = 0
self.num_simulations = 0
self.kernel_variance = None

# Define simulator that keeps track of budget.
def simulate_with_budget(theta):
@@ -106,7 +109,7 @@ def __call__(
use_last_pop_samples: bool = True,
return_summary: bool = False,
kde: bool = False,
kde_kwargs: Dict[str, Any] = {},
kde_kwargs: Optional[Dict[str, Any]] = None,
kde_sample_weights: bool = False,
lra: bool = False,
lra_with_weights: bool = False,
@@ -135,7 +138,7 @@ def __call__(
lra: Whether to run linear regression adjustment as in Beaumont et al. 2002
lra_with_weights: Whether to run lra as weighted linear regression with SMC
weights
sass: Whether to determine semi-automatic summary statistics as in
sass: Whether to determine semi-automatic summary statistics (sass) as in
Fearnhead & Prangle 2012.
sass_fraction: Fraction of simulation budget used for the initial sass run.
sass_expansion_degree: Degree of the polynomial feature expansion for the
@@ -165,12 +168,15 @@ def __call__(

pop_idx = 0
self.num_simulations = num_simulations
if kde_kwargs is None:
kde_kwargs = {}
assert isinstance(epsilon_decay, float) and epsilon_decay > 0.0

# Pilot run for SASS.
if sass:
num_pilot_simulations = int(sass_fraction * num_simulations)
self.logger.info(
f"Running SASS with {num_pilot_simulations} pilot samples."
"Running SASS with %s pilot samples.", num_pilot_simulations
)
sass_transform = self.run_sass_set_xo(
num_particles, num_pilot_simulations, x_o, lra, sass_expansion_degree
@@ -188,12 +194,15 @@ def sass_simulator(theta):
particles, epsilon, distances, x = self._set_xo_and_sample_initial_population(
x_o, num_particles, num_initial_pop
)
log_weights = torch.log(1 / num_particles * ones(num_particles))
log_weights = torch.log(1 / num_particles * torch.ones(num_particles))

self.logger.info(
(
f"population={pop_idx}, eps={epsilon}, ess={1.0}, "
f"num_sims={num_initial_pop}"
"population=%s, eps=%s, ess=%s, num_sims=%s",
pop_idx,
epsilon,
1.0,
num_initial_pop,
)
)
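The initial population starts from uniform weights kept in log space. A small check that the expression above equals a constant -log(N) vector and that the weights normalize (plain torch):

import math

import torch

num_particles = 100

# As in the hunk above: log of a uniform weight vector.
log_weights = torch.log(1 / num_particles * torch.ones(num_particles))

# Equivalent closed form: every entry is -log(N).
assert torch.allclose(
    log_weights, torch.full((num_particles,), -math.log(num_particles))
)

# The weights sum to one, i.e. logsumexp of the log-weights is (numerically) zero.
assert torch.isclose(
    torch.logsumexp(log_weights, dim=0), torch.tensor(0.0), atol=1e-5
)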

@@ -238,8 +247,10 @@ def sass_simulator(theta):

self.logger.info(
(
f"population={pop_idx} done: eps={epsilon:.6f},"
f" num_sims={self.simulation_counter}."
"population=%s done: eps=%.6f, num_sims=%s.",
pop_idx,
epsilon,
self.simulation_counter,
)
)

@@ -253,7 +264,7 @@ def sass_simulator(theta):
# Maybe run LRA and adjust weights.
if lra:
self.logger.info("Running Linear regression adjustment.")
adjusted_particles, adjusted_weights = self.run_lra_update_weights(
adjusted_particles, _ = self.run_lra_update_weights(
particles=all_particles[-1],
xs=all_x[-1],
observation=process_x(x_o),
@@ -266,10 +277,11 @@ def sass_simulator(theta):

if kde:
self.logger.info(
f"""KDE on {final_particles.shape[0]} samples with bandwidth option
{kde_kwargs["bandwidth"] if "bandwidth" in kde_kwargs else "cv"}.
Beware that KDE can give unreliable results when used with too few
samples and in high dimensions."""
"""KDE on %s samples with bandwidth option %s. Beware that KDE can give
unreliable results when used with too few samples and in high
dimensions.""",
final_particles.shape[0],
kde_kwargs["bandwidth"] if "bandwidth" in kde_kwargs else "cv",
)
# Maybe get particles weights from last population for weighted KDE.
if kde_sample_weights:
@@ -398,8 +410,9 @@ def _sample_next_population(
if use_last_pop_samples:
num_remaining = num_particles - num_accepted_particles
self.logger.info(
f"""Simulation Budget exceeded, filling up with {num_remaining}
samples from last population."""
"""Simulation Budget exceeded, filling up with %s
samples from last population.""",
num_remaining,
)
# Some new particles have been accepted already, therefore
# fill up the remaining ones with old particles and weights.
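Filling up the population from the previous one is weighted resampling with replacement, the role played by sample_from_population_with_weights later in this file. A minimal sketch of that step using torch.multinomial (illustrative shapes):

import torch


def resample_with_weights(
    particles: torch.Tensor, weights: torch.Tensor, num_samples: int
) -> torch.Tensor:
    # Draw particle indices with probability proportional to the weights,
    # with replacement, and gather the corresponding rows.
    probs = weights / weights.sum()
    idx = torch.multinomial(probs, num_samples, replacement=True)
    return particles[idx]


old_particles = torch.randn(1000, 2)   # last population, two parameters
old_weights = torch.rand(1000)         # unnormalized SMC weights
num_remaining = 25                     # particles still missing from the new population

fill_up = resample_with_weights(old_particles, old_weights, num_remaining)
print(fill_up.shape)                   # torch.Size([25, 2])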
@@ -467,8 +480,10 @@ def _get_next_epsilon(self, distances: Tensor, quantile: float) -> float:
except IndexError:
self.logger.warning(
(
f"Accepted unique distances={distances} don't match "
f"quantile={quantile:.2f}. Selecting last distance."
"""Accepted unique distances=%s don't match quantile=%s. Selecting
last distance.""",
distances,
quantile,
)
)
qidx = -1
@@ -494,7 +509,7 @@ def kernel_log_prob(new_particle):

# We still have to loop over particles here because
# the kernel log probs are already batched across old particles.
log_weighted_sum = tensor(
log_weighted_sum = torch.tensor(
[
torch.logsumexp(old_log_weights + kernel_log_prob(new_particle), dim=0)
for new_particle in new_particles
@@ -552,6 +567,7 @@ def get_kernel_variance(
samples_per_dim: int = 100,
kernel_variance_scale: float = 1.0,
) -> Tensor:
"""Return kernel variance for a given population of particles and weights."""
if self.kernel == "gaussian":
# For variant C, Beaumont et al. 2009, the kernel variance comes from the
# previous population.
@@ -563,7 +579,7 @@ def get_kernel_variance(
)
# Make sure variance is nonsingular.
try:
torch.cholesky(kernel_variance_scale * population_cov)
torch.linalg.cholesky(kernel_variance_scale * population_cov)
except RuntimeError:
self.logger.warning(
"""Singular particle covariance, using unit covariance."""
@@ -591,6 +607,7 @@ def get_new_kernel(self, thetas: Tensor) -> Distribution:
"""Return new kernel distribution for a given set of parameters."""

if self.kernel == "gaussian":
assert self.kernel_variance is not None, "get kernel variance first."
assert self.kernel_variance.ndim == 2
return MultivariateNormal(
loc=thetas, covariance_matrix=self.kernel_variance
@@ -601,7 +618,7 @@ def get_new_kernel(self, thetas: Tensor) -> Distribution:
high = thetas + self.kernel_variance
# Move batch shape to event shape to get Uniform that is multivariate in
# parameter dimension.
return Uniform(low=low, high=high).to_event(1)
return BoxUniform(low=low, high=high)
else:
raise ValueError(f"Kernel, '{self.kernel}' not supported.")
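The uniform kernel built above is a box-uniform perturbation around each particle, multivariate in the parameter dimension, so log_prob returns one value per particle. The old pyro Uniform(...).to_event(1) and sbi's BoxUniform (assumed here to wrap the same construction) both express a product of independent uniforms; a sketch with plain torch.distributions:

import torch
from torch.distributions import Independent, Uniform

thetas = torch.randn(4, 3)        # four particles, three parameters
kernel_width = 0.5                # per-dimension half-width of the box

low = thetas - kernel_width
high = thetas + kernel_width

# Reinterpret the parameter dimension as the event dimension: one multivariate
# box-uniform kernel per particle.
kernel = Independent(Uniform(low=low, high=high), 1)

print(kernel.batch_shape, kernel.event_shape)  # torch.Size([4]) torch.Size([3])
print(kernel.log_prob(thetas).shape)           # torch.Size([4])
print(kernel.sample().shape)                   # torch.Size([4, 3])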

@@ -620,12 +637,12 @@ def resample_if_ess_too_small(
ess = (1 / torch.sum(torch.exp(2.0 * log_weights), dim=0)) / num_particles
# Resampling of weights for low ESS only for Sisson et al. 2007.
if ess < ess_min:
self.logger.info(f"ESS={ess:.2f} too low, resampling pop {pop_idx}...")
self.logger.info("ESS=%s too low, resampling pop %s...", ess, pop_idx)
# First resample, then set to uniform weights as in Sisson et al. 2007.
particles = self.sample_from_population_with_weights(
particles, torch.exp(log_weights), num_samples=num_particles
)
log_weights = torch.log(1 / num_particles * ones(num_particles))
log_weights = torch.log(1 / num_particles * torch.ones(num_particles))

return particles, log_weights
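The effective sample size used above is the standard 1 / sum(w_i^2) of normalized weights, divided by the number of particles so it lies in (0, 1]. A short check that the log-space expression matches the direct formula:

import torch

num_particles = 5
weights = torch.tensor([0.7, 0.1, 0.1, 0.05, 0.05])   # normalized SMC weights
log_weights = torch.log(weights)

# Direct formula: ESS = 1 / sum(w_i^2), normalized by the number of particles.
ess_direct = (1.0 / torch.sum(weights**2)) / num_particles

# Log-space version as in the hunk above: exp(2 * log w) = w^2.
ess_log = (1 / torch.sum(torch.exp(2.0 * log_weights), dim=0)) / num_particles

assert torch.isclose(ess_direct, ess_log)
print(ess_direct)  # about 0.39: the weights are concentrated on one particle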

@@ -684,6 +701,8 @@ def run_sass_set_xo(
) = self._set_xo_and_sample_initial_population(
x_o, num_particles, num_pilot_simulations
)
assert self.x_o is not None, "x_o not set yet."

# Adjust with LRA.
if lra:
pilot_particles = self.run_lra(pilot_particles, pilot_xs, self.x_o)
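For context, the kernel_log_prob hunk shown earlier computes the denominator of the usual SMC-ABC weight update (Beaumont et al. 2009): w_new is proportional to prior(theta_new) / sum_j w_old_j * K(theta_new | theta_old_j), evaluated in log space with logsumexp. A compact, self-contained sketch of that update with a Gaussian kernel and a toy prior (illustrative, not sbi's exact implementation):

import math

import torch
from torch.distributions import MultivariateNormal

num_old, dim = 200, 2
old_particles = torch.randn(num_old, dim)
old_log_weights = torch.full((num_old,), -math.log(num_old))   # uniform old weights

kernel_cov = 0.25 * torch.eye(dim)                  # perturbation kernel covariance
prior = MultivariateNormal(torch.zeros(dim), 10.0 * torch.eye(dim))


def new_log_weight(new_particle: torch.Tensor) -> torch.Tensor:
    # log K(theta_new | theta_old_j) for all j, batched over the old particles.
    kernel = MultivariateNormal(loc=old_particles, covariance_matrix=kernel_cov)
    kernel_log_probs = kernel.log_prob(new_particle)
    # log sum_j w_old_j K(...) via logsumexp, then the prior-over-mixture ratio.
    log_mixture = torch.logsumexp(old_log_weights + kernel_log_probs, dim=0)
    return prior.log_prob(new_particle) - log_mixture


new_particles = torch.randn(5, dim)
log_weights = torch.stack([new_log_weight(p) for p in new_particles])
log_weights = log_weights - torch.logsumexp(log_weights, dim=0)  # normalize
print(torch.exp(log_weights))                                    # sums to one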
