Skip to content

Commit

Permalink
add option to pass dtype to nflows base, set default float32.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Feb 9, 2024
1 parent 9ce7cfa commit ce04fa0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
20 changes: 17 additions & 3 deletions sbi/neural_nets/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional
from warnings import warn

import torch
from pyknos.nflows import distributions as distributions_
from pyknos.nflows import flows, transforms
from pyknos.nflows.nn import nets
Expand Down Expand Up @@ -174,7 +175,8 @@ def build_maf(
# Combine transforms.
transform = transforms.CompositeTransform(transform_list)

distribution = distributions_.StandardNormal((x_numel,))
# Pass on kwargs to allow for different dtype.
distribution = get_base_dist(x_numel, **kwargs)
neural_net = flows.Flow(transform, distribution, embedding_net)

return neural_net
Expand Down Expand Up @@ -284,7 +286,8 @@ def build_maf_rqs(
# Combine transforms.
transform = transforms.CompositeTransform(transform_list)

distribution = distributions_.StandardNormal((x_numel,))
# Pass on kwargs to allow for different dtype.
distribution = get_base_dist(x_numel, **kwargs)
neural_net = flows.Flow(transform, distribution, embedding_net)

return neural_net
Expand Down Expand Up @@ -402,7 +405,8 @@ def mask_in_layer(i):
standardizing_net(batch_y, structured_y), embedding_net
)

distribution = distributions_.StandardNormal((x_numel,))
# Pass on kwargs to allow for different dtype.
distribution = get_base_dist(x_numel, **kwargs)

# Combine transforms.
transform = transforms.CompositeTransform(transform_list)
Expand Down Expand Up @@ -471,3 +475,13 @@ def __call__(self, inputs: Tensor, context: Tensor, *args, **kwargs) -> Tensor:
Spline parameters.
"""
return self.spline_predictor(context)


def get_base_dist(
num_dims: int, dtype: torch.dtype = torch.float32, **kwargs
) -> distributions_.Distribution:
"""Returns the base distribution for the flows with given float type."""

base = distributions_.StandardNormal((num_dims,))
base._log_z = base._log_z.to(dtype)
return base
6 changes: 5 additions & 1 deletion tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,11 @@ def simulator(theta):
train_kwargs = dict(force_first_round_loss=True)
elif method == SNLE:
kwargs = dict(
density_estimator=utils.likelihood_nn(model=model, num_transforms=2)
density_estimator=utils.likelihood_nn(
model=model,
num_transforms=2,
dtype=torch.float32, # test passing dtype kwarg to base distribution.
)
)
train_kwargs = dict()
elif method in (SNRE_A, SNRE_B, SNRE_C):
Expand Down

0 comments on commit ce04fa0

Please sign in to comment.