From 2591496f322d9117ab2eeb26c199b229c8141a24 Mon Sep 17 00:00:00 2001 From: janfb Date: Thu, 25 Jan 2024 10:18:36 +0100 Subject: [PATCH] refactor tests. --- sbi/utils/torchutils.py | 2 +- tests/inference_on_device_test.py | 98 +++++++++++++++---------------- tests/torchutils_test.py | 12 +++- 3 files changed, 60 insertions(+), 52 deletions(-) diff --git a/sbi/utils/torchutils.py b/sbi/utils/torchutils.py index 1b30e244e..a0184491d 100644 --- a/sbi/utils/torchutils.py +++ b/sbi/utils/torchutils.py @@ -82,7 +82,7 @@ def check_device(device: str) -> None: f"""Could not instantiate torch.randn(1, device={device}). Make sure the device is set up properly and that you are passing the corresponding device string. It should be something like 'cuda', - 'cuda:0', or 'mps'.""" + 'cuda:0', or 'mps'. Error message: {exc}.""" ) from exc diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py index dddecef21..55d7bda79 100644 --- a/tests/inference_on_device_test.py +++ b/tests/inference_on_device_test.py @@ -46,22 +46,26 @@ @pytest.mark.slow @pytest.mark.gpu @pytest.mark.parametrize( - "method, model, mcmc_method", + "method, model, sampling_method", [ + (SNPE_C, "maf", "direct"), (SNPE_C, "mdn", "rejection"), (SNPE_C, "maf", "slice_np_vectorized"), - (SNPE_C, "maf", "direct"), + (SNPE_C, "mdn", "slice"), (SNLE, "nsf", "slice_np_vectorized"), + (SNLE, "mdn", "slice"), (SNLE, "nsf", "rejection"), (SNLE, "maf", "importance"), (SNRE_A, "mlp", "slice_np_vectorized"), + (SNRE_A, "mlp", "slice"), (SNRE_B, "resnet", "rejection"), (SNRE_B, "resnet", "importance"), + (SNRE_B, "resnet", "slice"), (SNRE_C, "resnet", "rejection"), (SNRE_C, "resnet", "importance"), + (SNRE_C, "resnet", "nuts"), ], ) -@pytest.mark.parametrize("data_device", ("cpu", "gpu")) @pytest.mark.parametrize( "training_device, prior_device", [ @@ -70,14 +74,14 @@ pytest.param("gpu", "cpu", marks=pytest.mark.xfail), ], ) +@pytest.mark.parametrize("prior_type", ["gaussian", "uniform"]) def test_training_and_mcmc_on_device( method, model, - data_device, - mcmc_method, + sampling_method, training_device, prior_device, - prior_type="gaussian", + prior_type, ): """Test training on devices. @@ -86,13 +90,14 @@ def test_training_and_mcmc_on_device( """ training_device = process_device(training_device) - data_device = process_device(data_device) + data_device = "cpu" prior_device = process_device(prior_device) num_dim = 2 num_samples = 10 num_simulations = 100 max_num_epochs = 5 + num_rounds = 2 # test proposal sampling in round 2. x_o = zeros(1, num_dim).to(data_device) likelihood_shift = -1.0 * ones(num_dim).to(prior_device) @@ -128,59 +133,54 @@ def simulator(theta): else: raise ValueError() - inferer = method(show_progress_bars=False, device=training_device, **kwargs) + inferer = method( + prior=prior, show_progress_bars=False, device=training_device, **kwargs + ) proposals = [prior] - # Test for two rounds. - for _ in range(2): - theta, x = simulate_for_sbi(simulator, proposals[-1], num_simulations) - theta, x = theta.to(data_device), x.to(data_device) + for _ in range(num_rounds): + theta = proposals[-1].sample((num_simulations,)) + x = simulator(theta).to(data_device) + theta = theta.to(data_device) estimator = inferer.append_simulations(theta, x, data_device=data_device).train( training_batch_size=100, max_num_epochs=max_num_epochs, **train_kwargs ) - if method == SNLE: - potential_fn, theta_transform = likelihood_estimator_based_potential( - estimator, prior, x_o - ) - elif method == SNPE_A or method == SNPE_C: - potential_fn, theta_transform = posterior_estimator_based_potential( - estimator, prior, x_o - ) - elif method == SNRE_A or method == SNRE_B or method == SNRE_C: - potential_fn, theta_transform = ratio_estimator_based_potential( - estimator, prior, x_o - ) - else: - raise ValueError - if mcmc_method == "rejection": - posterior = RejectionPosterior( - proposal=prior, - potential_fn=potential_fn, - device=training_device, + # mcmc cases + if sampling_method in ["slice", "slice_np", "slice_np_vectorized", "nuts"]: + posterior = inferer.build_posterior( + sample_with="mcmc", + mcmc_method=sampling_method, + mcmc_parameters=dict(thin=5, num_chains=10), ) - elif mcmc_method == "direct": - posterior = DirectPosterior( - posterior_estimator=estimator, prior=prior - ).set_default_x(x_o) - elif mcmc_method == "importance": - posterior = ImportanceSamplingPosterior( - potential_fn=potential_fn, proposal=prior + elif sampling_method in ["rejection", "direct"]: + # all other cases: rejection, direct + posterior = inferer.build_posterior( + sample_with="rejection" + if sampling_method == "direct" + else sampling_method, + rejection_sampling_parameters={"proposal": prior} + if sampling_method == "rejection" and method == SNPE_C + else {}, ) else: - posterior = MCMCPosterior( - potential_fn=potential_fn, - theta_transform=theta_transform, - proposal=prior, - method=mcmc_method, - device=training_device, - # for speed - num_chains=10, - thin=1, + # build potential for SNLE or SNRE and construct ImportanceSamplingPosterior + if method == SNLE: + potential_fn, theta_transform = likelihood_estimator_based_potential( + estimator, prior, x_o + ) + elif method in [SNRE_A, SNRE_B, SNRE_C]: + potential_fn, theta_transform = ratio_estimator_based_potential( + estimator, prior, x_o + ) + else: + raise ValueError() + posterior = ImportanceSamplingPosterior( + potential_fn, prior, theta_transform ) - proposals.append(posterior) + proposals.append(posterior.set_default_x(x_o)) # Check for default device for inference object weights_device = next(inferer._neural_net.parameters()).device @@ -195,7 +195,6 @@ def simulator(theta): def test_check_embedding_net_device( device_datum: str, device_embedding_net: str ) -> None: - device_datum = process_device(device_datum) device_embedding_net = process_device(device_embedding_net) @@ -249,7 +248,6 @@ def test_validate_theta_and_x_type() -> None: @pytest.mark.parametrize("training_device", ["cpu", "gpu"]) @pytest.mark.parametrize("data_device", ["cpu", "gpu"]) def test_validate_theta_and_x_device(training_device: str, data_device: str) -> None: - training_device = process_device(training_device) data_device = process_device(data_device) diff --git a/tests/torchutils_test.py b/tests/torchutils_test.py index cb6983208..0c0687d5b 100644 --- a/tests/torchutils_test.py +++ b/tests/torchutils_test.py @@ -3,6 +3,7 @@ """Test PyTorch utility functions.""" from __future__ import annotations + from typing import Optional import numpy as np @@ -200,7 +201,16 @@ def test_dkl_gauss(): ) -@pytest.mark.parametrize("device_input", ("cpu", "gpu", "cuda", "cuda:0", "mps", )) +@pytest.mark.parametrize( + "device_input", + ( + "cpu", + "gpu", + "cuda", + "cuda:0", + "mps", + ), +) def test_process_device(device_input: str) -> None: """Test whether the device is processed correctly."""