From ce04fa0f20ddc144331de982d6210a5382a1e3b9 Mon Sep 17 00:00:00 2001
From: Jan Boelts
Date: Fri, 9 Feb 2024 16:11:36 +0100
Subject: [PATCH] add option to pass dtype to nflows base, set default float32.

---
 sbi/neural_nets/flow.py           | 20 +++++++++++++++++---
 tests/inference_on_device_test.py |  6 +++++-
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py
index 49da8af77..54c5c46fc 100644
--- a/sbi/neural_nets/flow.py
+++ b/sbi/neural_nets/flow.py
@@ -6,6 +6,7 @@
 from typing import Optional
 from warnings import warn
 
+import torch
 from pyknos.nflows import distributions as distributions_
 from pyknos.nflows import flows, transforms
 from pyknos.nflows.nn import nets
@@ -174,7 +175,8 @@ def build_maf(
     # Combine transforms.
     transform = transforms.CompositeTransform(transform_list)
 
-    distribution = distributions_.StandardNormal((x_numel,))
+    # Pass on kwargs to allow for different dtype.
+    distribution = get_base_dist(x_numel, **kwargs)
     neural_net = flows.Flow(transform, distribution, embedding_net)
 
     return neural_net
@@ -284,7 +286,8 @@ def build_maf_rqs(
     # Combine transforms.
     transform = transforms.CompositeTransform(transform_list)
 
-    distribution = distributions_.StandardNormal((x_numel,))
+    # Pass on kwargs to allow for different dtype.
+    distribution = get_base_dist(x_numel, **kwargs)
     neural_net = flows.Flow(transform, distribution, embedding_net)
 
     return neural_net
@@ -402,7 +405,8 @@ def mask_in_layer(i):
             standardizing_net(batch_y, structured_y), embedding_net
         )
 
-    distribution = distributions_.StandardNormal((x_numel,))
+    # Pass on kwargs to allow for different dtype.
+    distribution = get_base_dist(x_numel, **kwargs)
 
     # Combine transforms.
     transform = transforms.CompositeTransform(transform_list)
@@ -471,3 +475,13 @@ def __call__(self, inputs: Tensor, context: Tensor, *args, **kwargs) -> Tensor:
             Spline parameters.
         """
         return self.spline_predictor(context)
+
+
+def get_base_dist(
+    num_dims: int, dtype: torch.dtype = torch.float32, **kwargs
+) -> distributions_.Distribution:
+    """Returns the base distribution for the flows with given float type."""
+
+    base = distributions_.StandardNormal((num_dims,))
+    base._log_z = base._log_z.to(dtype)
+    return base
diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py
index 65458a3c2..f02e2db37 100644
--- a/tests/inference_on_device_test.py
+++ b/tests/inference_on_device_test.py
@@ -116,7 +116,11 @@ def simulator(theta):
         train_kwargs = dict(force_first_round_loss=True)
     elif method == SNLE:
         kwargs = dict(
-            density_estimator=utils.likelihood_nn(model=model, num_transforms=2)
+            density_estimator=utils.likelihood_nn(
+                model=model,
+                num_transforms=2,
+                dtype=torch.float32,  # test passing dtype kwarg to base distribution.
+            )
         )
         train_kwargs = dict()
     elif method in (SNRE_A, SNRE_B, SNRE_C):
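
Not part of the patch: a minimal sketch of how the new get_base_dist helper added above is expected to behave, assuming the patched sbi.neural_nets.flow module is importable. The only behavioral change it introduces is casting the base distribution's log-normalization constant _log_z to the requested dtype (default float32); the returned object is otherwise a plain nflows StandardNormal.

# Sketch only (not included in the patch): exercises get_base_dist as added above.
import torch
from pyknos.nflows import distributions as distributions_

from sbi.neural_nets.flow import get_base_dist

# Default dtype is float32, matching the rest of the networks built in flow.py.
base32 = get_base_dist(num_dims=3)
assert base32._log_z.dtype == torch.float32

# The dtype can be overridden, e.g. for double-precision flows.
base64 = get_base_dist(num_dims=3, dtype=torch.float64)
assert base64._log_z.dtype == torch.float64

# Still a standard normal over num_dims dimensions.
assert isinstance(base32, distributions_.StandardNormal)
assert base32.sample(10).shape == (10, 3)

As the test hunk suggests, an extra keyword argument such as dtype=torch.float32 passed to utils.likelihood_nn is forwarded to the flow builders, which hand it on to get_base_dist via **kwargs.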