Skip to content

Commit

Permalink
refactor tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Feb 16, 2024
1 parent 14da16e commit 2591496
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 52 deletions.
2 changes: 1 addition & 1 deletion sbi/utils/torchutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def check_device(device: str) -> None:
f"""Could not instantiate torch.randn(1, device={device}). Make sure
the device is set up properly and that you are passing the
corresponding device string. It should be something like 'cuda',
'cuda:0', or 'mps'."""
'cuda:0', or 'mps'. Error message: {exc}."""
) from exc


Expand Down
98 changes: 48 additions & 50 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,26 @@
@pytest.mark.slow
@pytest.mark.gpu
@pytest.mark.parametrize(
"method, model, mcmc_method",
"method, model, sampling_method",
[
(SNPE_C, "maf", "direct"),
(SNPE_C, "mdn", "rejection"),
(SNPE_C, "maf", "slice_np_vectorized"),
(SNPE_C, "maf", "direct"),
(SNPE_C, "mdn", "slice"),
(SNLE, "nsf", "slice_np_vectorized"),
(SNLE, "mdn", "slice"),
(SNLE, "nsf", "rejection"),
(SNLE, "maf", "importance"),
(SNRE_A, "mlp", "slice_np_vectorized"),
(SNRE_A, "mlp", "slice"),
(SNRE_B, "resnet", "rejection"),
(SNRE_B, "resnet", "importance"),
(SNRE_B, "resnet", "slice"),
(SNRE_C, "resnet", "rejection"),
(SNRE_C, "resnet", "importance"),
(SNRE_C, "resnet", "nuts"),
],
)
@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
@pytest.mark.parametrize(
"training_device, prior_device",
[
Expand All @@ -70,14 +74,14 @@
pytest.param("gpu", "cpu", marks=pytest.mark.xfail),
],
)
@pytest.mark.parametrize("prior_type", ["gaussian", "uniform"])
def test_training_and_mcmc_on_device(
method,
model,
data_device,
mcmc_method,
sampling_method,
training_device,
prior_device,
prior_type="gaussian",
prior_type,
):
"""Test training on devices.
Expand All @@ -86,13 +90,14 @@ def test_training_and_mcmc_on_device(
"""

training_device = process_device(training_device)
data_device = process_device(data_device)
data_device = "cpu"
prior_device = process_device(prior_device)

num_dim = 2
num_samples = 10
num_simulations = 100
max_num_epochs = 5
num_rounds = 2 # test proposal sampling in round 2.

x_o = zeros(1, num_dim).to(data_device)
likelihood_shift = -1.0 * ones(num_dim).to(prior_device)
Expand Down Expand Up @@ -128,59 +133,54 @@ def simulator(theta):
else:
raise ValueError()

inferer = method(show_progress_bars=False, device=training_device, **kwargs)
inferer = method(
prior=prior, show_progress_bars=False, device=training_device, **kwargs
)

proposals = [prior]

# Test for two rounds.
for _ in range(2):
theta, x = simulate_for_sbi(simulator, proposals[-1], num_simulations)
theta, x = theta.to(data_device), x.to(data_device)
for _ in range(num_rounds):
theta = proposals[-1].sample((num_simulations,))
x = simulator(theta).to(data_device)
theta = theta.to(data_device)

estimator = inferer.append_simulations(theta, x, data_device=data_device).train(
training_batch_size=100, max_num_epochs=max_num_epochs, **train_kwargs
)
if method == SNLE:
potential_fn, theta_transform = likelihood_estimator_based_potential(
estimator, prior, x_o
)
elif method == SNPE_A or method == SNPE_C:
potential_fn, theta_transform = posterior_estimator_based_potential(
estimator, prior, x_o
)
elif method == SNRE_A or method == SNRE_B or method == SNRE_C:
potential_fn, theta_transform = ratio_estimator_based_potential(
estimator, prior, x_o
)
else:
raise ValueError

if mcmc_method == "rejection":
posterior = RejectionPosterior(
proposal=prior,
potential_fn=potential_fn,
device=training_device,
# mcmc cases
if sampling_method in ["slice", "slice_np", "slice_np_vectorized", "nuts"]:
posterior = inferer.build_posterior(
sample_with="mcmc",
mcmc_method=sampling_method,
mcmc_parameters=dict(thin=5, num_chains=10),
)
elif mcmc_method == "direct":
posterior = DirectPosterior(
posterior_estimator=estimator, prior=prior
).set_default_x(x_o)
elif mcmc_method == "importance":
posterior = ImportanceSamplingPosterior(
potential_fn=potential_fn, proposal=prior
elif sampling_method in ["rejection", "direct"]:
# all other cases: rejection, direct
posterior = inferer.build_posterior(
sample_with="rejection"
if sampling_method == "direct"
else sampling_method,
rejection_sampling_parameters={"proposal": prior}
if sampling_method == "rejection" and method == SNPE_C
else {},
)
else:
posterior = MCMCPosterior(
potential_fn=potential_fn,
theta_transform=theta_transform,
proposal=prior,
method=mcmc_method,
device=training_device,
# for speed
num_chains=10,
thin=1,
# build potential for SNLE or SNRE and construct ImportanceSamplingPosterior
if method == SNLE:
potential_fn, theta_transform = likelihood_estimator_based_potential(
estimator, prior, x_o
)
elif method in [SNRE_A, SNRE_B, SNRE_C]:
potential_fn, theta_transform = ratio_estimator_based_potential(
estimator, prior, x_o
)
else:
raise ValueError()
posterior = ImportanceSamplingPosterior(
potential_fn, prior, theta_transform
)
proposals.append(posterior)
proposals.append(posterior.set_default_x(x_o))

# Check for default device for inference object
weights_device = next(inferer._neural_net.parameters()).device
Expand All @@ -195,7 +195,6 @@ def simulator(theta):
def test_check_embedding_net_device(
device_datum: str, device_embedding_net: str
) -> None:

device_datum = process_device(device_datum)
device_embedding_net = process_device(device_embedding_net)

Expand Down Expand Up @@ -249,7 +248,6 @@ def test_validate_theta_and_x_type() -> None:
@pytest.mark.parametrize("training_device", ["cpu", "gpu"])
@pytest.mark.parametrize("data_device", ["cpu", "gpu"])
def test_validate_theta_and_x_device(training_device: str, data_device: str) -> None:

training_device = process_device(training_device)
data_device = process_device(data_device)

Expand Down
12 changes: 11 additions & 1 deletion tests/torchutils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Test PyTorch utility functions."""
from __future__ import annotations

from typing import Optional

import numpy as np
Expand Down Expand Up @@ -200,7 +201,16 @@ def test_dkl_gauss():
)


@pytest.mark.parametrize("device_input", ("cpu", "gpu", "cuda", "cuda:0", "mps", ))
@pytest.mark.parametrize(
"device_input",
(
"cpu",
"gpu",
"cuda",
"cuda:0",
"mps",
),
)
def test_process_device(device_input: str) -> None:
"""Test whether the device is processed correctly."""

Expand Down

0 comments on commit 2591496

Please sign in to comment.