Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH add h npe #1087

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
315 changes: 315 additions & 0 deletions examples/02_hNPE_with_extra_observation.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion sbi/inference/snpe/snpe_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,10 @@ def __init__(
prior: The prior distribution.
"""
# Call nn.Module's constructor.
super().__init__()

super().__init__(flow, flow._condition_shape)
self.net = flow
self._condition_shape = flow._condition_shape

self._neural_net = flow
self._prior = prior
Expand Down
52 changes: 7 additions & 45 deletions sbi/neural_nets/density_estimators/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
from torch import Tensor, nn


class DensityEstimator(nn.Module):
class DensityEstimator(nn.Module, ABC):
r"""Base class for density estimators.

The density estimator class is a wrapper around neural networks that
allows to evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$
allows to evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta, x$
pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`.

Note:
Expand All @@ -19,23 +20,12 @@ class DensityEstimator(nn.Module):

"""

def __init__(self, net: nn.Module, condition_shape: torch.Size) -> None:
r"""Base class for density estimators.

Args:
net: Neural network.
condition_shape: Shape of the condition. If not provided, it will assume a
1D input.
"""
super().__init__()
self.net = net
self._condition_shape = condition_shape

@property
def embedding_net(self) -> Optional[nn.Module]:
r"""Return the embedding network if it exists."""
return None

@abstractmethod
def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
r"""Return the log probabilities of the inputs given a condition or multiple
i.e. batched conditions.
Expand Down Expand Up @@ -65,8 +55,7 @@ def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
- (batch_size1, input_size) + (batch_size2,1, *condition_shape)
-> (batch_size2,batch_size1)
"""

raise NotImplementedError
...

def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
r"""Return the loss for training the density estimator.
Expand All @@ -78,9 +67,9 @@ def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
Returns:
Loss of shape (batch_size,)
"""
return -self.log_prob(input, condition, **kwargs)

raise NotImplementedError

@abstractmethod
def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor:
r"""Return samples from the density estimator.

Expand Down Expand Up @@ -123,30 +112,3 @@ def sample_and_log_prob(
samples = self.sample(sample_shape, condition, **kwargs)
log_probs = self.log_prob(samples, condition, **kwargs)
return samples, log_probs

def _check_condition_shape(self, condition: Tensor):
r"""This method checks whether the condition has the correct shape.

Args:
condition: Conditions of shape (*batch_shape, *condition_shape).

Raises:
ValueError: If the condition has a dimensionality that does not match
the expected input dimensionality.
ValueError: If the shape of the condition does not match the expected
input dimensionality.
"""
if len(condition.shape) < len(self._condition_shape):
raise ValueError(
f"Dimensionality of condition is to small and does not match the\
expected input dimensionality {len(self._condition_shape)}, as provided\
by condition_shape."
)
else:
condition_shape = condition.shape[-len(self._condition_shape) :]
if tuple(condition_shape) != tuple(self._condition_shape):
raise ValueError(
f"Shape of condition {tuple(condition_shape)} does not match the \
expected input dimensionality {tuple(self._condition_shape)}, as \
provided by condition_shape. Please reshape it accordingly."
)
145 changes: 145 additions & 0 deletions sbi/neural_nets/density_estimators/hierarchical_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import functools

import torch

from sbi.neural_nets.density_estimators import DensityEstimator


def split_hierarchical(theta, dim_local):
return theta[..., :dim_local], theta[..., dim_local:]


def hierachical_simulator(n_extra, dim_local, p_local, simulator=None):
"""Return a hierachical simulator, which returns extra observations.
"""
if simulator is None:
return functools.partial(
hierachical_simulator, n_extra, dim_local, p_local)


def h_simulator(theta):
msg = (
"Hierarchical simulator only work with vector parameters, with "
f"of shape (n_batch, theta_dim). Got {theta.shape}."
)
assert theta.ndim == 2, msg
n_batch, theta_dim = theta.shape
local_theta, global_theta = split_hierarchical(theta, dim_local)
extra_local = p_local.sample((n_batch, n_extra))
all_theta_local = torch.concatenate(
(local_theta[:, None], extra_local), dim=1
)
all_theta = torch.concatenate(
(all_theta_local, global_theta.repeat([n_extra+1, 1])
.view(n_batch, n_extra+1, -1)), dim=2
)
observation = simulator(all_theta.view(n_batch * (n_extra+1), -1))
return observation.view((n_batch, n_extra + 1, *observation.shape[1:]))

return h_simulator


class HierarchicalDensityEstimator(DensityEstimator):

def __init__(
self, local_flow, global_flow, dim_local, condition_shape,
embedding_net: torch.nn.Module = torch.nn.Identity()
):

super().__init__()

self.dim_local = dim_local
self.local_flow = local_flow
self.global_flow = global_flow
self._embedding_net = embedding_net
self._condition_shape = condition_shape

@property
def embedding_net(self):
return self._embedding_net


@staticmethod
def embed_condition(embedding_net, condition, condition_shape):
'''Embed the condition for the hierarchical flow

Parameters
----------
condition: torch.Tensor, shape (n_batch, n_extra + 1, *condition_shape)
The hierarchical condition.

Returns
-------
global_condition: torch.Tensor, shape (n_batch, 2*n_embed)
local_condition: torch.Tensor, shape (n_batch, n_embed)
'''
if condition.ndim < len(condition_shape):
raise ValueError(
"condition should be at least with shape (n_extra, *condition_shape) "
f"but got {condition.shape}. This is likely because there is no "
"extra observations."
)
elif condition.ndim == len(condition_shape):
batch_condition_shape, n_extra = (), condition.shape[0]
else:
*batch_condition_shape, n_extra = condition.shape[:-len(condition_shape)+1]
condition_shape = condition_shape[1:] # remove n_extra
embedded_condition = embedding_net(
condition.view(-1, *condition_shape)
).reshape(*batch_condition_shape, n_extra, -1)

batch_slice = tuple(slice(None) for _ in range(len(batch_condition_shape)))
local_slice = (*batch_slice, slice(1))
agg_slice = (*batch_slice, slice(1, None))

local_condition = embedded_condition[local_slice]
agg_condition = torch.mean(
embedded_condition[agg_slice],
dim=len(batch_condition_shape), keepdim=True
)
global_condition = torch.concatenate(
(local_condition, agg_condition), dim=len(batch_condition_shape)
)
return (
local_condition.view(*batch_condition_shape, -1),
global_condition.view(*batch_condition_shape, -1),
)

def log_prob(self, theta, condition):
local_theta, global_theta = split_hierarchical(theta, self.dim_local)
local_condition, global_condition = self.embed_condition(
self.embedding_net, condition, self._condition_shape
)

log_p_global = self.global_flow.log_prob(
global_theta, global_condition
)

local_condition = torch.concatenate(
(local_condition, global_theta), dim=-1
)
log_p_local = self.local_flow.log_prob(
local_theta, local_condition
)
return log_p_global + log_p_local

def loss(self, inputs, condition):
return -self.log_prob(inputs, condition)

def sample(self, sample_shape, condition):
local_condition, global_condition = self.embed_condition(
self.embedding_net, condition, self._condition_shape
)

# shape (n_samples, 1)
global_samples = self.global_flow.sample(
sample_shape, global_condition
)
local_condition = torch.concatenate(
(local_condition.repeat((*sample_shape, 1)), global_samples), dim=-1
)
local_samples = self.local_flow.sample((1,), local_condition)[:, 0]

samples = torch.cat([local_samples, global_samples], dim=-1)
return samples

17 changes: 9 additions & 8 deletions sbi/neural_nets/density_estimators/nflows_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.sbi_types import Shape
from sbi.utils.user_input_checks import check_condition_shape


class NFlowsFlow(DensityEstimator):
Expand All @@ -16,9 +17,9 @@ class NFlowsFlow(DensityEstimator):
"""

def __init__(self, net: Flow, condition_shape: torch.Size) -> None:
super().__init__(net, condition_shape)
# TODO: Remove as soon as DensityEstimator becomes abstract
self.net: Flow
super().__init__()
self.net = net
self._condition_shape = condition_shape

@property
def embedding_net(self) -> nn.Module:
Expand Down Expand Up @@ -54,7 +55,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
- (batch_size1, input_size) + (batch_size2,1, *condition_shape)
-> (batch_size2,batch_size1)
"""
self._check_condition_shape(condition)
check_condition_shape(condition, self._condition_shape)
condition_dims = len(self._condition_shape)

# PyTorch's automatic broadcasting
Expand Down Expand Up @@ -102,10 +103,10 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
- (*batch_shape, *condition_shape)
-> (*batch_shape, *sample_shape, input_size)
"""
self._check_condition_shape(condition)
check_condition_shape(condition, self._condition_shape)
condition_dims = len(self._condition_shape)

num_samples = torch.Size(sample_shape).numel()
condition_dims = len(self._condition_shape)

if len(condition.shape) == condition_dims:
# nflows.sample() expects conditions to be batched.
Expand Down Expand Up @@ -138,10 +139,10 @@ def sample_and_log_prob(
Returns:
Samples and associated log probabilities.
"""
self._check_condition_shape(condition)
check_condition_shape(condition, self._condition_shape)
condition_dims = len(self._condition_shape)

num_samples = torch.Size(sample_shape).numel()
condition_dims = len(self._condition_shape)

if len(condition.shape) == condition_dims:
# nflows.sample() expects conditions to be batched.
Expand Down
13 changes: 7 additions & 6 deletions sbi/neural_nets/density_estimators/zuko_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.sbi_types import Shape
from sbi.utils.user_input_checks import check_condition_shape


class ZukoFlow(DensityEstimator):
Expand All @@ -24,9 +25,9 @@ def __init__(
flow: Flow object.
condition_shape: Shape of the condition.
"""

# assert len(condition_shape) == 1, "Zuko Flows require 1D conditions."
super().__init__(net=net, condition_shape=condition_shape)
super().__init__()
self.net = net
self._condition_shape = condition_shape
self._embedding_net = embedding_net

@property
Expand Down Expand Up @@ -63,7 +64,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
- (batch_size1, input_size) + (batch_size2,1, *condition_shape)
-> (batch_size2,batch_size1)
"""
self._check_condition_shape(condition)
check_condition_shape(condition, self._condition_shape)
condition_dims = len(self._condition_shape)

# PyTorch's automatic broadcasting
Expand Down Expand Up @@ -110,7 +111,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
- (*batch_shape, *condition_shape)
-> (*batch_shape, *sample_shape, input_size)
"""
self._check_condition_shape(condition)
check_condition_shape(condition, self._condition_shape)

condition_dims = len(self._condition_shape)
batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else ()
Expand All @@ -134,7 +135,7 @@ def sample_and_log_prob(
Returns:
Samples and associated log probabilities.
"""
self._check_condition_shape(condition)
check_condition_shape(condition, self._condition_shape)

condition_dims = len(self._condition_shape)
batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else ()
Expand Down
Loading
Loading