From 30ee8dba33614e182e4d226e9aaaf37f49f31754 Mon Sep 17 00:00:00 2001 From: tommoral Date: Thu, 21 Mar 2024 08:45:39 +0100 Subject: [PATCH 01/13] API simplify public DensityEstimator base class --- sbi/neural_nets/density_estimators/base.py | 40 ++++++++++--------- .../density_estimators/nflows_flow.py | 4 +- .../density_estimators/zuko_flow.py | 4 +- sbi/utils/sbiutils.py | 2 +- 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index a2af32bc1..f5e65db49 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -1,14 +1,15 @@ +from abc import ABC, abstractmethod from typing import Optional, Tuple import torch from torch import Tensor, nn -class DensityEstimator(nn.Module): +class DensityEstimator(nn.Module, ABC): r"""Base class for density estimators. The density estimator class is a wrapper around neural networks that - allows to evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta,x$ + allows to evaluate the `log_prob`, `sample`, and provide the `loss` of $\theta, x$ pairs. Here $\theta$ would be the `input` and $x$ would be the `condition`. Note: @@ -19,23 +20,12 @@ class DensityEstimator(nn.Module): """ - def __init__(self, net: nn.Module, condition_shape: torch.Size) -> None: - r"""Base class for density estimators. - - Args: - net: Neural network. - condition_shape: Shape of the condition. If not provided, it will assume a - 1D input. - """ - super().__init__() - self.net = net - self._condition_shape = condition_shape - @property def embedding_net(self) -> Optional[nn.Module]: r"""Return the embedding network if it exists.""" return None + @abstractmethod def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: r"""Return the log probabilities of the inputs given a condition or multiple i.e. batched conditions. @@ -65,9 +55,9 @@ def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: - (batch_size1, input_size) + (batch_size2,1, *condition_shape) -> (batch_size2,batch_size1) """ + ... - raise NotImplementedError - + @abstractmethod def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: r"""Return the loss for training the density estimator. @@ -78,9 +68,9 @@ def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: Returns: Loss of shape (batch_size,) """ + ... - raise NotImplementedError - + @abstractmethod def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor: r"""Return samples from the density estimator. @@ -124,6 +114,20 @@ def sample_and_log_prob( log_probs = self.log_prob(samples, condition, **kwargs) return samples, log_probs + +class SBIDensityEstimator(DensityEstimator): + def __init__(self, net: nn.Module, condition_shape: torch.Size) -> None: + r"""Base class for density estimators. + + Args: + net: Neural network. + condition_shape: Shape of the condition. If not provided, it will assume a + 1D input. + """ + super().__init__() + self.net = net + self._condition_shape = condition_shape + def _check_condition_shape(self, condition: Tensor): r"""This method checks whether the condition has the correct shape. 
diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py index ad578d96c..9358f2671 100644 --- a/sbi/neural_nets/density_estimators/nflows_flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -4,11 +4,11 @@ from pyknos.nflows.flows import Flow from torch import Tensor, nn -from sbi.neural_nets.density_estimators.base import DensityEstimator +from sbi.neural_nets.density_estimators.base import SBIDensityEstimator from sbi.types import Shape -class NFlowsFlow(DensityEstimator): +class NFlowsFlow(SBIDensityEstimator): r"""`nflows`- based normalizing flow density estimator. Flow type objects already have a .log_prob() and .sample() method, so here we just diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py index 7cbb240a9..bb00deb69 100644 --- a/sbi/neural_nets/density_estimators/zuko_flow.py +++ b/sbi/neural_nets/density_estimators/zuko_flow.py @@ -4,11 +4,11 @@ from torch import Tensor, nn from zuko.flows import Flow -from sbi.neural_nets.density_estimators.base import DensityEstimator +from sbi.neural_nets.density_estimators.base import SBIDensityEstimator from sbi.types import Shape -class ZukoFlow(DensityEstimator): +class ZukoFlow(SBIDensityEstimator): r"""`zuko`- based normalizing flow density estimator. Flow type objects already have a .log_prob() and .sample() method, so here we just diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 4f9bcd980..407e52360 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -176,7 +176,7 @@ def standardizing_transform( t_std[t_std < min_std] = min_std if backend == "nflows": - return transforms.AffineTransform(shift=-t_mean / t_std, scale=1 / t_std) + return transforms.PointwiseAffineTransform(shift=-t_mean / t_std, scale=1 / t_std) elif backend == "zuko": return zuko.flows.Unconditional( zuko.transforms.MonotonicAffineTransform, From e27c148145ba213ce7f966a20f8d93706264d514 Mon Sep 17 00:00:00 2001 From: tommoral Date: Thu, 21 Mar 2024 09:13:33 +0100 Subject: [PATCH 02/13] CLN remove SBIDensityEstimator and make dedicated condition_shape check --- sbi/neural_nets/density_estimators/base.py | 41 ------------------- .../density_estimators/nflows_flow.py | 19 +++++---- .../density_estimators/zuko_flow.py | 16 ++++---- sbi/utils/user_input_checks.py | 31 ++++++++++++++ 4 files changed, 50 insertions(+), 57 deletions(-) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index f5e65db49..a35edfd64 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -113,44 +113,3 @@ def sample_and_log_prob( samples = self.sample(sample_shape, condition, **kwargs) log_probs = self.log_prob(samples, condition, **kwargs) return samples, log_probs - - -class SBIDensityEstimator(DensityEstimator): - def __init__(self, net: nn.Module, condition_shape: torch.Size) -> None: - r"""Base class for density estimators. - - Args: - net: Neural network. - condition_shape: Shape of the condition. If not provided, it will assume a - 1D input. - """ - super().__init__() - self.net = net - self._condition_shape = condition_shape - - def _check_condition_shape(self, condition: Tensor): - r"""This method checks whether the condition has the correct shape. - - Args: - condition: Conditions of shape (*batch_shape, *condition_shape). 
- - Raises: - ValueError: If the condition has a dimensionality that does not match - the expected input dimensionality. - ValueError: If the shape of the condition does not match the expected - input dimensionality. - """ - if len(condition.shape) < len(self._condition_shape): - raise ValueError( - f"Dimensionality of condition is to small and does not match the\ - expected input dimensionality {len(self._condition_shape)}, as provided\ - by condition_shape." - ) - else: - condition_shape = condition.shape[-len(self._condition_shape) :] - if tuple(condition_shape) != tuple(self._condition_shape): - raise ValueError( - f"Shape of condition {tuple(condition_shape)} does not match the \ - expected input dimensionality {tuple(self._condition_shape)}, as \ - provided by condition_shape. Please reshape it accordingly." - ) diff --git a/sbi/neural_nets/density_estimators/nflows_flow.py b/sbi/neural_nets/density_estimators/nflows_flow.py index 9358f2671..fe14c9046 100644 --- a/sbi/neural_nets/density_estimators/nflows_flow.py +++ b/sbi/neural_nets/density_estimators/nflows_flow.py @@ -4,11 +4,12 @@ from pyknos.nflows.flows import Flow from torch import Tensor, nn -from sbi.neural_nets.density_estimators.base import SBIDensityEstimator +from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.types import Shape +from sbi.utils.user_input_checks import check_condition_shape -class NFlowsFlow(SBIDensityEstimator): +class NFlowsFlow(DensityEstimator): r"""`nflows`- based normalizing flow density estimator. Flow type objects already have a .log_prob() and .sample() method, so here we just @@ -16,7 +17,9 @@ class NFlowsFlow(SBIDensityEstimator): """ def __init__(self, net: Flow, condition_shape: torch.Size) -> None: - super().__init__(net, condition_shape) + super().__init__() + self._net = net + self._condition_shape = condition_shape @property def embedding_net(self) -> nn.Module: @@ -52,7 +55,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: - (batch_size1, input_size) + (batch_size2,1, *condition_shape) -> (batch_size2,batch_size1) """ - self._check_condition_shape(condition) + check_condition_shape(condition, self._condition_shape) condition_dims = len(self._condition_shape) # PyTorch's automatic broadcasting @@ -100,10 +103,10 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor: - (*batch_shape, *condition_shape) -> (*batch_shape, *sample_shape, input_size) """ - self._check_condition_shape(condition) + check_condition_shape(condition, self._condition_shape) + condition_dims = len(self._condition_shape) num_samples = torch.Size(sample_shape).numel() - condition_dims = len(self._condition_shape) if len(condition.shape) == condition_dims: # nflows.sample() expects conditions to be batched. @@ -136,10 +139,10 @@ def sample_and_log_prob( Returns: Samples and associated log probabilities. """ - self._check_condition_shape(condition) + check_condition_shape(condition, self._condition_shape) + condition_dims = len(self._condition_shape) num_samples = torch.Size(sample_shape).numel() - condition_dims = len(self._condition_shape) if len(condition.shape) == condition_dims: # nflows.sample() expects conditions to be batched. 
diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py
index 7cbb240a9..bb00deb69 100644
--- a/sbi/neural_nets/density_estimators/zuko_flow.py
+++ b/sbi/neural_nets/density_estimators/zuko_flow.py
@@ -4,11 +4,11 @@
 from torch import Tensor, nn
 from zuko.flows import Flow
 
-from sbi.neural_nets.density_estimators.base import SBIDensityEstimator
+from sbi.neural_nets.density_estimators.base import DensityEstimator
 from sbi.types import Shape
 
 
-class ZukoFlow(SBIDensityEstimator):
+class ZukoFlow(DensityEstimator):
     r"""`zuko`- based normalizing flow density estimator.
 
     Flow type objects already have a .log_prob() and .sample() method, so here we just
@@ -24,9 +24,9 @@ def __init__(
             flow: Flow object.
             condition_shape: Shape of the condition.
         """
-
-        # assert len(condition_shape) == 1, "Zuko Flows require 1D conditions."
-        super().__init__(net=net, condition_shape=condition_shape)
+        super().__init__()
+        self.net = net
+        self._condition_shape = condition_shape
         self._embedding_net = embedding_net
 
     @property
@@ -63,7 +63,7 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor:
             - (batch_size1, input_size) + (batch_size2,1, *condition_shape)
               -> (batch_size2,batch_size1)
         """
-        self._check_condition_shape(condition)
+        check_condition_shape(condition, self._condition_shape)
         condition_dims = len(self._condition_shape)
 
         # PyTorch's automatic broadcasting
@@ -110,7 +110,7 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
             - (*batch_shape, *condition_shape)
               -> (*batch_shape, *sample_shape, input_size)
         """
-        self._check_condition_shape(condition)
+        check_condition_shape(condition, self._condition_shape)
 
         condition_dims = len(self._condition_shape)
         batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else ()
@@ -134,7 +134,7 @@ def sample_and_log_prob(
         Returns:
             Samples and associated log probabilities.
         """
-        self._check_condition_shape(condition)
+        check_condition_shape(condition, self._condition_shape)
 
         condition_dims = len(self._condition_shape)
         batch_shape = condition.shape[:-condition_dims] if condition_dims > 0 else ()
diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py
index dfd9ea5e3..799c3d2fd 100644
--- a/sbi/utils/user_input_checks.py
+++ b/sbi/utils/user_input_checks.py
@@ -741,6 +741,37 @@ def validate_theta_and_x(
     return theta, x
 
 
+def check_condition_shape(condition, condition_shape=None):
+    r"""This method checks whether the condition has the correct shape.
+
+    Args:
+        condition: Conditions of shape (*batch_shape, *condition_shape).
+        condition_shape: Shape of the condition.
+            If not provided, it will assume a 1D input.
+
+    Raises:
+        ValueError: If the condition has a dimensionality that does not match
+            the expected input dimensionality.
+        ValueError: If the shape of the condition does not match the expected
+            input dimensionality.
+    """
+    if condition_shape is None:
+        condition_shape = (1,)
+    if condition.ndim < len(condition_shape):
+        raise ValueError(
+            "Dimensionality of condition is too small and does not match the "
+            f"expected input dimensionality {len(condition_shape)}, as provided "
+            "by condition_shape."
+        )
+    else:
+        observed_shape = condition.shape[-len(condition_shape) :]
+        if tuple(observed_shape) != tuple(condition_shape):
+            raise ValueError(
+                f"Shape of condition {tuple(observed_shape)} does not match the "
+                f"expected input dimensionality {tuple(condition_shape)}, as "
+                "provided by condition_shape. Please reshape it accordingly."
+ ) + def test_posterior_net_for_multi_d_x(net: flows.Flow, theta: Tensor, x: Tensor) -> None: """Test log prob method of the net. From 04988162cd88fbc5d6b4cc0d8dca321e206e2935 Mon Sep 17 00:00:00 2001 From: tommoral Date: Thu, 21 Mar 2024 09:21:23 +0100 Subject: [PATCH 03/13] FIX linter+missing import --- sbi/neural_nets/density_estimators/zuko_flow.py | 1 + sbi/utils/sbiutils.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sbi/neural_nets/density_estimators/zuko_flow.py b/sbi/neural_nets/density_estimators/zuko_flow.py index f1ef60980..d903345ee 100644 --- a/sbi/neural_nets/density_estimators/zuko_flow.py +++ b/sbi/neural_nets/density_estimators/zuko_flow.py @@ -6,6 +6,7 @@ from sbi.neural_nets.density_estimators.base import DensityEstimator from sbi.sbi_types import Shape +from sbi.utils.user_input_checks import check_condition_shape class ZukoFlow(DensityEstimator): diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index c223a3c36..0ec9923ab 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -176,7 +176,9 @@ def standardizing_transform( t_std[t_std < min_std] = min_std if backend == "nflows": - return transforms.PointwiseAffineTransform(shift=-t_mean / t_std, scale=1 / t_std) + return transforms.PointwiseAffineTransform( + shift=-t_mean / t_std, scale=1 / t_std + ) elif backend == "zuko": return zuko.flows.Unconditional( zuko.transforms.MonotonicAffineTransform, From 767acf588f6986cdaa4fbc166268eb0d8d0e9069 Mon Sep 17 00:00:00 2001 From: tommoral Date: Thu, 21 Mar 2024 09:23:42 +0100 Subject: [PATCH 04/13] FIX revert PAT warning fix --- sbi/utils/sbiutils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index 0ec9923ab..9e8555cdc 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -176,9 +176,7 @@ def standardizing_transform( t_std[t_std < min_std] = min_std if backend == "nflows": - return transforms.PointwiseAffineTransform( - shift=-t_mean / t_std, scale=1 / t_std - ) + return transforms.AffineTransform(shift=-t_mean / t_std, scale=1 / t_std) elif backend == "zuko": return zuko.flows.Unconditional( zuko.transforms.MonotonicAffineTransform, From 86d3f22e53f705658d1cc337b6ecc3cd1fe1e9db Mon Sep 17 00:00:00 2001 From: tommoral Date: Thu, 21 Mar 2024 09:43:44 +0100 Subject: [PATCH 05/13] FIX linter --- sbi/utils/user_input_checks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py index 23e9bcd69..bc0fd47d9 100644 --- a/sbi/utils/user_input_checks.py +++ b/sbi/utils/user_input_checks.py @@ -738,6 +738,7 @@ def validate_theta_and_x( return theta, x + def check_condition_shape(condition, condition_shape=None): r"""This method checks whether the condition has the correct shape. From 403f855c10ef586f624128eb2a49294b748b6ba6 Mon Sep 17 00:00:00 2001 From: tommoral Date: Thu, 21 Mar 2024 09:52:45 +0100 Subject: [PATCH 06/13] API remove abstractmethod loss --- sbi/neural_nets/density_estimators/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index a35edfd64..efffa2ab1 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -57,7 +57,6 @@ def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: """ ... 
- @abstractmethod def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: r"""Return the loss for training the density estimator. @@ -68,7 +67,7 @@ def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: Returns: Loss of shape (batch_size,) """ - ... + return - self.log_prob(input, condition, **kwargs) @abstractmethod def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor: From 11688631efea13917ee308502cf5866b0c1565e8 Mon Sep 17 00:00:00 2001 From: tommoral Date: Thu, 21 Mar 2024 09:58:51 +0100 Subject: [PATCH 07/13] FIX linter --- sbi/neural_nets/density_estimators/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/neural_nets/density_estimators/base.py b/sbi/neural_nets/density_estimators/base.py index efffa2ab1..28bb3f817 100644 --- a/sbi/neural_nets/density_estimators/base.py +++ b/sbi/neural_nets/density_estimators/base.py @@ -67,7 +67,7 @@ def loss(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor: Returns: Loss of shape (batch_size,) """ - return - self.log_prob(input, condition, **kwargs) + return -self.log_prob(input, condition, **kwargs) @abstractmethod def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor: From 3b5af99eeed18f8f5223f8e8ff71860b85ce463f Mon Sep 17 00:00:00 2001 From: tommoral Date: Thu, 21 Mar 2024 11:02:31 +0100 Subject: [PATCH 08/13] FIX init SNPE_A_MDN --- sbi/inference/snpe/snpe_a.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 6b7d4df00..0dfa45bcb 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -397,8 +397,10 @@ def __init__( prior: The prior distribution. """ # Call nn.Module's constructor. 
+ super().__init__() - super().__init__(flow, flow._condition_shape) + self.net = flow + self._condition_shape = condition_shape self._neural_net = flow self._prior = prior From d47f4a76ae82abe08c1147d7f2c9caa873122a10 Mon Sep 17 00:00:00 2001 From: tommoral Date: Thu, 21 Mar 2024 11:29:43 +0100 Subject: [PATCH 09/13] FIX --- sbi/inference/snpe/snpe_a.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/inference/snpe/snpe_a.py b/sbi/inference/snpe/snpe_a.py index 0dfa45bcb..22989d0c4 100644 --- a/sbi/inference/snpe/snpe_a.py +++ b/sbi/inference/snpe/snpe_a.py @@ -400,7 +400,7 @@ def __init__( super().__init__() self.net = flow - self._condition_shape = condition_shape + self._condition_shape = flow._condition_shape self._neural_net = flow self._prior = prior From e8bb62cf319d52647ca85e609e0d21abfa0c0aeb Mon Sep 17 00:00:00 2001 From: tommoral Date: Fri, 22 Mar 2024 02:26:02 +0100 Subject: [PATCH 10/13] ENH add new class for hierarchical density estimator+helpers --- .../hierarchical_estimator.py | 140 ++++++++++++++++++ sbi/neural_nets/factory.py | 50 +++++++ 2 files changed, 190 insertions(+) create mode 100644 sbi/neural_nets/density_estimators/hierarchical_estimator.py diff --git a/sbi/neural_nets/density_estimators/hierarchical_estimator.py b/sbi/neural_nets/density_estimators/hierarchical_estimator.py new file mode 100644 index 000000000..b863c9f25 --- /dev/null +++ b/sbi/neural_nets/density_estimators/hierarchical_estimator.py @@ -0,0 +1,140 @@ +import torch +import functools + + +from sbi.neural_nets.density_estimators import DensityEstimator + + +def split_hierarchical(theta, dim_local): + return theta[..., :dim_local], theta[..., dim_local:] + + +def hierachical_simulator(n_extra, dim_local, p_local, simulator=None): + """Return a hierachical simulator, which returns extra observations. + """ + if simulator is None: + return functools.partial( + hierachical_simulator, n_extra, dim_local, p_local) + + + def h_simulator(theta): + assert theta.ndim == 2, "Hierarchical simulator only work with vector parameters" + n_batch, theta_dim = theta.shape + local_theta, global_theta = split_hierarchical(theta, dim_local) + extra_local = p_local.sample((n_batch, n_extra)) + all_theta_local = torch.concatenate( + (local_theta[:, None], extra_local), dim=1 + ) + all_theta = torch.concatenate( + (all_theta_local, global_theta.repeat([n_extra+1, 1]).view(n_batch, n_extra+1, -1)), dim=2 + ) + observation = simulator(all_theta.view(n_batch * (n_extra+1), -1)) + return observation.view((n_batch, n_extra + 1, *observation.shape[1:])) + + return h_simulator + + +class HierarchicalDensityEstimator(DensityEstimator): + + def __init__( + self, local_flow, global_flow, dim_local, condition_shape, + embedding_net=torch.nn.Identity() + ): + + super().__init__() + + self.dim_local = dim_local + self.local_flow = local_flow + self.global_flow = global_flow + self._embedding_net = embedding_net + self._condition_shape = condition_shape + + @property + def embedding_net(self): + return self._embedding_net + + + @staticmethod + def embed_condition(embedding_net, condition, condition_shape): + '''Embed the condition for the hierarchical flow + + Parameters + ---------- + condition: torch.Tensor, shape (n_batch, n_extra + 1, *condition_shape) + The hierarchical condition. 
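+        condition_shape: tuple
+            Expected trailing shape of `condition`; its first entry is the
+            `n_extra + 1` axis of stacked observations and is stripped before
+            each observation is embedded.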
+
+        Returns
+        -------
+        local_condition: torch.Tensor, shape (n_batch, n_embed)
+        global_condition: torch.Tensor, shape (n_batch, 2*n_embed)
+        '''
+        if condition.ndim < len(condition_shape):
+            raise ValueError(
+                "condition should have at least shape (n_extra, *condition_shape) "
+                f"but got {condition.shape}. This is likely because there are no "
+                "extra observations."
+            )
+        elif condition.ndim == len(condition_shape):
+            batch_condition_shape, n_extra = (), condition.shape[0]
+        else:
+            *batch_condition_shape, n_extra = condition.shape[:-len(condition_shape)+1]
+        condition_shape = condition_shape[1:]  # remove n_extra
+        embedded_condition = embedding_net(
+            condition.view(-1, *condition_shape)
+        ).reshape(*batch_condition_shape, n_extra, -1)
+
+        batch_slice = tuple(slice(None) for _ in range(len(batch_condition_shape)))
+        local_slice = (*batch_slice, slice(1))
+        agg_slice = (*batch_slice, slice(1, None))
+
+        local_condition = embedded_condition[local_slice]
+        agg_condition = torch.mean(
+            embedded_condition[agg_slice],
+            dim=len(batch_condition_shape), keepdim=True
+        )
+        global_condition = torch.concatenate(
+            (local_condition, agg_condition), dim=len(batch_condition_shape)
+        )
+        return (
+            local_condition.view(*batch_condition_shape, -1),
+            global_condition.view(*batch_condition_shape, -1),
+        )
+
+    def log_prob(self, theta, condition):
+        local_theta, global_theta = split_hierarchical(theta, self.dim_local)
+        local_condition, global_condition = self.embed_condition(
+            self.embedding_net, condition, self._condition_shape
+        )
+
+        log_p_global = self.global_flow.log_prob(
+            global_theta, global_condition
+        )
+
+        local_condition = torch.concatenate(
+            (local_condition, global_theta), dim=-1
+        )
+        log_p_local = self.local_flow.log_prob(
+            local_theta, local_condition
+        )
+        return log_p_global + log_p_local
+
+    def loss(self, inputs, condition):
+        return -self.log_prob(inputs, condition)
+
+    def sample(self, sample_shape, condition):
+        local_condition, global_condition = self.embed_condition(
+            self.embedding_net, condition, self._condition_shape
+        )
+
+        # shape (n_samples, 1)
+        global_samples = self.global_flow.sample(
+            sample_shape, global_condition
+        )
+        local_condition = torch.concatenate(
+            (local_condition.repeat((*sample_shape, 1)), global_samples), dim=-1
+        )
+        local_samples = self.local_flow.sample((1,), local_condition)[:, 0]
+
+        samples = torch.cat([local_samples, global_samples], dim=-1)
+        return samples
+
diff --git a/sbi/neural_nets/factory.py b/sbi/neural_nets/factory.py
index 1273b4587..c57f2116e 100644
--- a/sbi/neural_nets/factory.py
+++ b/sbi/neural_nets/factory.py
@@ -4,8 +4,14 @@
 
 from typing import Any, Callable, Optional
 
+import torch
 from torch import nn
 
+from sbi.neural_nets.density_estimators.hierarchical_estimator import (
+    HierarchicalDensityEstimator,
+    split_hierarchical
+)
+
 from sbi.neural_nets.classifier import (
     build_linear_classifier,
     build_mlp_classifier,
@@ -290,3 +296,47 @@ def build_fn(batch_theta, batch_x):
         kwargs.pop("num_components")
 
     return build_fn_snpe_a if model == "mdn_snpe_a" else build_fn
+
+def hierarchical_nn(
+    build_global_flow: Callable,
+    build_local_flow: Callable,
+    dim_local: int = 1,
+    embedding_net: nn.Module = nn.Identity(),
+) -> Callable:
+    r"""
+    Returns a function that builds a density estimator for learning the posterior.
+
+    This function will usually be used for SNPE. The returned function is to be passed
+    to the inference class when using the flexible interface.
+
+    Args:
+        build_global_flow, build_local_flow: Build functions to create the global
+            and local flows.
+        dim_local: Number of dimensions of the local parameters in theta.
+        embedding_net: Optional embedding network for simulation outputs $x$. This
+            embedding net allows learning features from potentially high-dimensional
+            simulation outputs.
+    """
+    def build_hierarchical(batch_theta, batch_condition):
+        assert batch_theta.ndim == 2, "Only working with 1D theta for now."
+        local_theta, global_theta = split_hierarchical(batch_theta, dim_local)
+        local_condition, global_condition = HierarchicalDensityEstimator.embed_condition(
+            embedding_net, batch_condition, batch_condition.shape[1:]
+        )
+
+        global_flow = build_global_flow(
+            batch_x=global_theta, batch_y=global_condition
+        )
+        local_condition = torch.concatenate(
+            (local_condition, global_theta), dim=-1
+        )
+        local_flow = build_local_flow(
+            batch_x=local_theta, batch_y=local_condition
+        )
+
+        return HierarchicalDensityEstimator(
+            local_flow, global_flow, dim_local, batch_condition.shape[1:],
+            embedding_net=embedding_net
+        )
+
+    return build_hierarchical

From 4df86ebba51283129b3799134e4351532b891e47 Mon Sep 17 00:00:00 2001
From: tommoral
Date: Fri, 22 Mar 2024 02:26:23 +0100
Subject: [PATCH 11/13] DOC add example of use case for hNPE

---
 examples/02_hNPE_with_extra_observation.ipynb | 315 ++++++++++++++++++
 1 file changed, 315 insertions(+)
 create mode 100644 examples/02_hNPE_with_extra_observation.ipynb

diff --git a/examples/02_hNPE_with_extra_observation.ipynb b/examples/02_hNPE_with_extra_observation.ipynb
new file mode 100644
index 000000000..ffe1dfab1
--- /dev/null
+++ b/examples/02_hNPE_with_extra_observation.ipynb
@@ -0,0 +1,315 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Hierarchical Neural Posterior Estimator\n",
+    "\n",
+    "\n",
+    "Consider an ill-posed problem $x = \\alpha \\cdot \\beta$.\n",
+    "When using SBI on such a problem, we obtain a posterior with infinitely many equivalent solutions: "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2024-03-22 02:21:39.984406: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-03-22 02:21:39.986655: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-03-22 02:21:40.034183: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-03-22 02:21:40.035206: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-03-22 02:21:40.928558: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5d3ef23c301d46c88550bc800a6e084c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Running 100 simulations.: 0%| | 0/100 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch\n", + "\n", + "from sbi import analysis\n", + "\n", + "from sbi.inference import SNPE\n", + "from sbi.inference import prepare_for_sbi\n", + "from sbi.inference import simulate_for_sbi\n", + "from sbi.neural_nets.factory import build_maf\n", + "from sbi.utils import BoxUniform\n", + "\n", + "\n", + "prior = BoxUniform(\n", + " low=torch.tensor([0.0, 0.0]),\n", + " high=torch.tensor([1.0, 1.0])\n", + ")\n", + "\n", + "\n", + "def simulator(theta):\n", + " alpha, beta = theta[..., :1], theta[..., 1:]\n", + " return (alpha * beta).reshape(-1, 1)\n", + "\n", + "\n", + "simulator, prior = prepare_for_sbi(simulator, prior)\n", + "theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=100)\n", + "\n", + "# setup the inference procedure\n", + "inference = SNPE(\n", + " prior=prior,\n", + " density_estimator=build_maf,\n", + " show_progress_bars=True,\n", + " device='cpu',\n", + ")\n", + "density_estimator = inference.append_simulations(theta, x).train(\n", + " num_atoms=10,\n", + " training_batch_size=10,\n", + ")\n", + "posterior = inference.build_posterior(density_estimator)\n", + "\n", + "# plot posterior samples\n", + "x0 = torch.tensor([0.25])\n", + "posterior_samples = posterior.sample((10000,), x=x0)\n", + "_ = analysis.pairplot(\n", + " posterior_samples, limits=[[0, 1], [0, 1], [0, 1]], figsize=(5, 5)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This ambiguity can some time be overcome with extra observation which share a common parameter. This is the purpose of hNPE.\n", + "To do this, let's first build a simulator with extra observation.\n", + "For this, we can use the `hierachical_simulator` helper, which can decorate a classical simulator, with parameter `n_extra` for the number of extra observation and the parameter `p_local` for drawing extra value for the non-common parameters, call here `local` parameters.\n", + "We also need to specify how many parameters are local." 
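+    ,
+    "\n",
+    "\n",
+    "As a minimal sketch (the choice of `n_extra` and of the local prior `p_local` below is purely illustrative), the decorated simulator could look like this:\n",
+    "\n",
+    "```python\n",
+    "from sbi.neural_nets.density_estimators.hierarchical_estimator import hierachical_simulator\n",
+    "\n",
+    "# Prior used to draw the local parameter of each extra observation.\n",
+    "p_local = BoxUniform(low=torch.tensor([0.0]), high=torch.tensor([1.0]))\n",
+    "\n",
+    "@hierachical_simulator(n_extra=5, dim_local=1, p_local=p_local)\n",
+    "def simulator(theta):\n",
+    "    alpha, beta = theta[..., :1], theta[..., 1:]\n",
+    "    return (alpha * beta).reshape(-1, 1)\n",
+    "\n",
+    "# The output now has shape (n_batch, n_extra + 1, 1).\n",
+    "```"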
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f9f6842c9c93409881a44129b8c1c418", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Running 100 simulations.: 0%| | 0/100 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "# plot posterior samples\n", + "x0 = simulator(torch.tensor([[0.5, 0.5]]))[0]\n", + "posterior_samples = posterior.sample((10000,), x=x0)\n", + "_ = analysis.pairplot(\n", + " posterior_samples, limits=[[0, 1], [0, 1], [0, 1]], figsize=(5, 5)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From fd0ec1f20246c237196ec00fb8d186b54640847a Mon Sep 17 00:00:00 2001 From: tommoral Date: Fri, 22 Mar 2024 14:58:40 +0100 Subject: [PATCH 12/13] FIX linter --- .../density_estimators/hierarchical_estimator.py | 11 ++++++++--- sbi/neural_nets/factory.py | 15 ++++++++------- sbi/neural_nets/flow.py | 3 ++- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/sbi/neural_nets/density_estimators/hierarchical_estimator.py b/sbi/neural_nets/density_estimators/hierarchical_estimator.py index b863c9f25..cd8fe3b1d 100644 --- a/sbi/neural_nets/density_estimators/hierarchical_estimator.py +++ b/sbi/neural_nets/density_estimators/hierarchical_estimator.py @@ -1,6 +1,6 @@ -import torch import functools +import torch from sbi.neural_nets.density_estimators import DensityEstimator @@ -18,7 +18,11 @@ def hierachical_simulator(n_extra, dim_local, p_local, simulator=None): def h_simulator(theta): - assert theta.ndim == 2, "Hierarchical simulator only work with vector parameters" + msg = ( + "Hierarchical simulator only work with vector parameters, with " + f"of shape (n_batch, theta_dim). Got {theta.shape}." 
+ ) + assert theta.ndim == 2, msg n_batch, theta_dim = theta.shape local_theta, global_theta = split_hierarchical(theta, dim_local) extra_local = p_local.sample((n_batch, n_extra)) @@ -26,7 +30,8 @@ def h_simulator(theta): (local_theta[:, None], extra_local), dim=1 ) all_theta = torch.concatenate( - (all_theta_local, global_theta.repeat([n_extra+1, 1]).view(n_batch, n_extra+1, -1)), dim=2 + (all_theta_local, global_theta.repeat([n_extra+1, 1]) + .view(n_batch, n_extra+1, -1)), dim=2 ) observation = simulator(all_theta.view(n_batch * (n_extra+1), -1)) return observation.view((n_batch, n_extra + 1, *observation.shape[1:])) diff --git a/sbi/neural_nets/factory.py b/sbi/neural_nets/factory.py index c57f2116e..27509c94f 100644 --- a/sbi/neural_nets/factory.py +++ b/sbi/neural_nets/factory.py @@ -7,16 +7,15 @@ import torch from torch import nn -from sbi.neural_nets.density_estimators.hierarchical_estimator import ( - HierarchicalDensityEstimator, - split_hierarchical -) - from sbi.neural_nets.classifier import ( build_linear_classifier, build_mlp_classifier, build_resnet_classifier, ) +from sbi.neural_nets.density_estimators.hierarchical_estimator import ( + HierarchicalDensityEstimator, + split_hierarchical, +) from sbi.neural_nets.flow import ( build_made, build_maf, @@ -320,8 +319,10 @@ def hierarchical_nn( def build_hierarchical(batch_theta, batch_condition): assert batch_theta.ndim == 2, "Only working with 1D theta for now." local_theta, global_theta = split_hierarchical(batch_theta, dim_local) - local_condition, global_condition = HierarchicalDensityEstimator.embed_condition( - embedding_net, batch_condition, batch_condition.shape[1:] + local_condition, global_condition = ( + HierarchicalDensityEstimator.embed_condition( + embedding_net, batch_condition, batch_condition.shape[1:] + ) ) global_flow = build_global_flow( diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 23c531654..f05278df6 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -508,7 +508,8 @@ def build_zuko_maf( # transforms = transforms transforms = ( *transforms, - # Ideally `standardizing_transform` would return a `LazyTransform` instead of ` AffineTransform | Unconditional`, maybe all three are compatible + # Ideally `standardizing_transform` would return a `LazyTransform` instead + # of `AffineTransform | Unconditional`, maybe all three are compatible standardizing_transform(batch_x, structured_x, backend="zuko"), # pyright: ignore[reportAssignmentType] ) From e452ad32082311ae5be93ba261cab1cb9550e5ef Mon Sep 17 00:00:00 2001 From: tommoral Date: Fri, 22 Mar 2024 16:29:05 +0100 Subject: [PATCH 13/13] FIX pyright --- sbi/neural_nets/density_estimators/hierarchical_estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbi/neural_nets/density_estimators/hierarchical_estimator.py b/sbi/neural_nets/density_estimators/hierarchical_estimator.py index cd8fe3b1d..7c5f019ee 100644 --- a/sbi/neural_nets/density_estimators/hierarchical_estimator.py +++ b/sbi/neural_nets/density_estimators/hierarchical_estimator.py @@ -43,7 +43,7 @@ class HierarchicalDensityEstimator(DensityEstimator): def __init__( self, local_flow, global_flow, dim_local, condition_shape, - embedding_net=torch.nn.Identity() + embedding_net: torch.nn.Module = torch.nn.Identity() ): super().__init__()
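
---

Taken together, the series leaves `DensityEstimator` as an abstract interface: `log_prob`
and `sample` must be overridden (patch 01), `loss` falls back to `-log_prob` (patch 06),
and the condition check now lives in `sbi.utils.user_input_checks.check_condition_shape`
(patch 02). A minimal sketch of a custom estimator written against this contract --
`MyEstimator` and its inner `net` are illustrative placeholders, not part of the series:

import torch
from torch import Tensor, nn

from sbi.neural_nets.density_estimators.base import DensityEstimator
from sbi.utils.user_input_checks import check_condition_shape


class MyEstimator(DensityEstimator):
    """Illustrative subclass; `net` is any module exposing log_prob/sample."""

    def __init__(self, net: nn.Module, condition_shape: torch.Size) -> None:
        super().__init__()
        self.net = net
        self._condition_shape = condition_shape

    def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:
        check_condition_shape(condition, self._condition_shape)
        return self.net.log_prob(input, condition)

    # `loss` is inherited: it defaults to -self.log_prob(input, condition).

    def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor:
        check_condition_shape(condition, self._condition_shape)
        return self.net.sample(sample_shape, condition)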