refactor device tests
janfb committed Jan 17, 2024
1 parent d28bcfe commit 14da16e
Showing 1 changed file with 56 additions and 62 deletions.
tests/inference_on_device_test.py
@@ -35,7 +35,7 @@
 from sbi.inference.potentials.base_potential import BasePotential
 from sbi.simulators import diagonal_linear_gaussian, linear_gaussian
 from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
-from sbi.utils.torchutils import BoxUniform, process_device
+from sbi.utils.torchutils import BoxUniform, gpu_available, process_device
 from sbi.utils.user_input_checks import (
     check_embedding_net_device,
     prepare_for_sbi,
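The newly imported gpu_available helper lives in sbi.utils.torchutils and is not shown in this diff. A minimal sketch of such a check, assuming it treats both CUDA and Apple's MPS backend as GPUs (the body below is an assumption, not sbi's actual implementation):

    import torch

    def gpu_available() -> bool:
        # Assumed behavior: True if any supported GPU backend is usable.
        mps = getattr(torch.backends, "mps", None)
        return torch.cuda.is_available() or (mps is not None and mps.is_available())
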
@@ -49,32 +49,25 @@
     "method, model, mcmc_method",
     [
         (SNPE_C, "mdn", "rejection"),
-        (SNPE_C, "maf", "slice"),
+        (SNPE_C, "maf", "slice_np_vectorized"),
         (SNPE_C, "maf", "direct"),
         (SNPE_C, "maf_rqs", "direct"),
-        (SNLE, "maf", "slice"),
-        (SNLE, "nsf", "slice_np"),
+        (SNLE, "nsf", "slice_np_vectorized"),
         (SNLE, "nsf", "rejection"),
         (SNLE, "maf", "importance"),
-        (SNLE, "maf_rqs", "slice"),
         (SNRE_A, "mlp", "slice_np_vectorized"),
-        (SNRE_B, "resnet", "nuts"),
         (SNRE_B, "resnet", "rejection"),
         (SNRE_B, "resnet", "importance"),
-        (SNRE_C, "resnet", "nuts"),
         (SNRE_C, "resnet", "rejection"),
         (SNRE_C, "resnet", "importance"),
     ],
 )
-@pytest.mark.parametrize("data_device", ("cpu", "cuda:0"))
+@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
 @pytest.mark.parametrize(
     "training_device, prior_device",
     [
-        pytest.param("cpu", "cuda", marks=pytest.mark.xfail),
-        pytest.param("cuda:0", "cpu", marks=pytest.mark.xfail),
-        ("cuda:0", "cuda:0"),
-        ("cuda:0", "cuda:0"),
         ("cpu", "cpu"),
+        ("gpu", "gpu"),
+        pytest.param("cpu", "gpu", marks=pytest.mark.xfail),
+        pytest.param("gpu", "cpu", marks=pytest.mark.xfail),
     ],
 )
 def test_training_and_mcmc_on_device(
@@ -92,6 +85,10 @@ def test_training_and_mcmc_on_device(
     """
 
     training_device = process_device(training_device)
+    data_device = process_device(data_device)
+    prior_device = process_device(prior_device)
+
     num_dim = 2
     num_samples = 10
     num_simulations = 100
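The pattern introduced throughout this refactor: tests parametrize over the abstract device names "cpu" and "gpu" and resolve them to concrete devices up front via process_device. A hedged illustration (the exact string "gpu" resolves to, e.g. "cuda:0" on a CUDA machine or "mps:0" on Apple silicon, is assumed, not confirmed by this diff):

    from sbi.utils.torchutils import process_device

    training_device = process_device("gpu")  # e.g. "cuda:0" (assumed mapping)
    data_device = process_device("cpu")      # stays "cpu"
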
@@ -115,8 +112,6 @@
     def simulator(theta):
         return linear_gaussian(theta, likelihood_shift, likelihood_cov)
 
-    training_device = process_device(training_device)
-
     if method in [SNPE_A, SNPE_C]:
         kwargs = dict(
             density_estimator=utils.posterior_nn(model=model, num_transforms=2)
@@ -142,7 +137,7 @@ def simulator(theta):
         theta, x = simulate_for_sbi(simulator, proposals[-1], num_simulations)
         theta, x = theta.to(data_device), x.to(data_device)
 
-        estimator = inferer.append_simulations(theta, x).train(
+        estimator = inferer.append_simulations(theta, x, data_device=data_device).train(
             training_batch_size=100, max_num_epochs=max_num_epochs, **train_kwargs
         )
         if method == SNLE:
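append_simulations now receives data_device explicitly, so the training data may live on a different device than the network. A usage sketch under the assumptions of this test (class and helper names from the imports above; treat as an illustration, not canonical API docs):

    # Keep simulations on CPU while the network trains on the resolved GPU.
    theta, x = simulate_for_sbi(simulator, prior, num_simulations=100)
    inference = SNPE_C(prior, device=process_device("gpu"))
    estimator = inference.append_simulations(theta, x, data_device="cpu").train(
        max_num_epochs=2
    )
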
@@ -181,6 +176,9 @@ def simulator(theta):
             proposal=prior,
             method=mcmc_method,
             device=training_device,
+            # for speed
+            num_chains=10,
+            thin=1,
         )
         proposals.append(posterior)
 
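num_chains=10 and thin=1 are passed purely for speed: the vectorized slice sampler runs chains in parallel, so ten chains with no thinning collect the ten posterior samples the test needs quickly. A sketch of the same settings on an SNLE-style potential (estimator, prior, x_o, and training_device are taken from the surrounding test; treat the wiring as an assumption):

    from sbi.inference import MCMCPosterior, likelihood_estimator_based_potential

    potential_fn, theta_transform = likelihood_estimator_based_potential(
        estimator, prior, x_o
    )
    posterior = MCMCPosterior(
        potential_fn,
        theta_transform=theta_transform,
        proposal=prior,
        method="slice_np_vectorized",
        device=training_device,
        num_chains=10,  # parallel chains cut wall-clock time
        thin=1,         # no thinning needed for a smoke test
    )
    samples = posterior.sample((10,))
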
@@ -192,31 +190,15 @@
 
 
 @pytest.mark.gpu
-@pytest.mark.parametrize(
-    "device_input, device_target",
-    [
-        ("cpu", "cpu"),
-        ("cuda", "cuda:0"),
-        ("cuda:0", "cuda:0"),
-        pytest.param("cuda:42", None, marks=pytest.mark.xfail),
-        pytest.param("qwerty", None, marks=pytest.mark.xfail),
-    ],
-)
-def test_process_device(device_input: str, device_target: Optional[str]) -> None:
-    device_output = process_device(device_input)
-    assert device_output == device_target, (
-        f"Failure when processing device '{device_input}': "
-        f"result should have been '{device_target}' and is "
-        f"instead '{device_output}'"
-    )
-
-
-@pytest.mark.gpu
-@pytest.mark.parametrize("device_datum", ["cpu", "cuda"])
-@pytest.mark.parametrize("device_embedding_net", ["cpu", "cuda"])
+@pytest.mark.parametrize("device_datum", ["cpu", "gpu"])
+@pytest.mark.parametrize("device_embedding_net", ["cpu", "gpu"])
 def test_check_embedding_net_device(
     device_datum: str, device_embedding_net: str
 ) -> None:
+
+    device_datum = process_device(device_datum)
+    device_embedding_net = process_device(device_embedding_net)
+
     datum = torch.zeros((1, 1)).to(device_datum)
     embedding_net = nn.Linear(in_features=1, out_features=1).to(device_embedding_net)
 
@@ -234,12 +216,7 @@ def test_check_embedding_net_device(
     )
 
 
-@pytest.mark.parametrize(
-    "shape_x",
-    [
-        (3, 1),
-    ],
-)
+@pytest.mark.parametrize("shape_x", [(3, 1)])
 @pytest.mark.parametrize(
     "shape_theta", [(3, 2), pytest.param((2, 1), marks=pytest.mark.xfail)]
 )
@@ -269,9 +246,13 @@ def test_validate_theta_and_x_type() -> None:
 
 
 @pytest.mark.gpu
-@pytest.mark.parametrize("training_device", ["cpu", "cuda:0"])
-@pytest.mark.parametrize("data_device", ["cpu", "cuda:0"])
+@pytest.mark.parametrize("training_device", ["cpu", "gpu"])
+@pytest.mark.parametrize("data_device", ["cpu", "gpu"])
 def test_validate_theta_and_x_device(training_device: str, data_device: str) -> None:
+
+    training_device = process_device(training_device)
+    data_device = process_device(data_device)
+
     theta = torch.empty((1, 1)).to(data_device)
     x = torch.empty((1, 1)).to(data_device)
 
@@ -289,12 +270,15 @@ def test_validate_theta_and_x_device(training_device: str, data_device: str) -> None:
 @pytest.mark.parametrize(
     "inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE]
 )
-@pytest.mark.parametrize("data_device", ("cpu", "cuda:0"))
-@pytest.mark.parametrize("training_device", ("cpu", "cuda:0"))
+@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
+@pytest.mark.parametrize("training_device", ("cpu", "gpu"))
 def test_train_with_different_data_and_training_device(
     inference_method, data_device: str, training_device: str
 ) -> None:
-    assert torch.cuda.is_available(), "this test requires that cuda is available."
+    assert gpu_available(), "this test requires that gpu is available."
 
+    data_device = process_device(data_device)
+    training_device = process_device(training_device)
+
     num_dim = 2
     prior_ = BoxUniform(
@@ -320,7 +304,7 @@ def test_train_with_different_data_and_training_device(
     theta, x = simulate_for_sbi(simulator, prior, 32)
     theta, x = theta.to(data_device), x.to(data_device)
     x_o = torch.zeros(x.shape[1])
-    inference = inference.append_simulations(theta, x)
+    inference = inference.append_simulations(theta, x, data_device=data_device)
 
     posterior_estimator = inference.train(max_num_epochs=2)
 
@@ -337,10 +321,10 @@
 @pytest.mark.parametrize(
     "inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE]
 )
-@pytest.mark.parametrize("prior_device", ("cpu", "cuda"))
-@pytest.mark.parametrize("embedding_net_device", ("cpu", "cuda"))
-@pytest.mark.parametrize("data_device", ("cpu", "cuda"))
-@pytest.mark.parametrize("training_device", ("cpu", "cuda"))
+@pytest.mark.parametrize("prior_device", ("cpu", "gpu"))
+@pytest.mark.parametrize("embedding_net_device", ("cpu", "gpu"))
+@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
+@pytest.mark.parametrize("training_device", ("cpu", "gpu"))
 def test_embedding_nets_integration_training_device(
     inference_method,
     prior_device: str,
@@ -350,6 +334,12 @@ def test_embedding_nets_integration_training_device(
 ) -> None:
     # add other methods
 
+    # process all device strings
+    prior_device = process_device(prior_device)
+    embedding_net_device = process_device(embedding_net_device)
+    data_device = process_device(data_device)
+    training_device = process_device(training_device)
+
     D_theta = 2
     D_x = 3
     samples_per_round = 32
@@ -416,7 +406,7 @@ def test_embedding_nets_integration_training_device(
     theta = prior.sample((samples_per_round,)).to(data_device)
 
     proposal = prior
-    for round_idx in range(num_rounds):
+    for _ in range(num_rounds):
         X = (
             MultivariateNormal(torch.zeros((D_x,)), torch.eye(D_x))
             .sample((samples_per_round,))
@@ -487,7 +477,7 @@ def test_vi_on_gpu(num_dim: int, q: Distribution, vi_method: str, sampling_method):
         sampling_method: Different sampling methods
     """
 
-    device = "cuda:0"
+    device = process_device("gpu")
 
     if num_dim == 1 and q in ["mcf", "scf"]:
         return
@@ -529,15 +519,19 @@ def allow_iid_x(self) -> bool:
     "arg_device, device",
     [
         ("cpu", None),
-        ("cuda", None),
+        ("gpu", None),
         ("cpu", "cpu"),
-        ("cuda", "cuda"),
-        pytest.param("cuda", "cpu", marks=pytest.mark.xfail),
-        pytest.param("cpu", "cuda", marks=pytest.mark.xfail),
+        ("gpu", "gpu"),
+        pytest.param("gpu", "cpu", marks=pytest.mark.xfail),
+        pytest.param("cpu", "gpu", marks=pytest.mark.xfail),
     ],
 )
-def test_BoxUniform_device_handling(arg_device, device):
+def test_boxuniform_device_handling(arg_device, device):
     """Test mismatch between device passed via low / high and device kwarg."""
+
+    arg_device = process_device(arg_device)
+    device = process_device(device)
+
     prior = BoxUniform(
         low=zeros(1).to(arg_device), high=ones(1).to(arg_device), device=device
     )
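The renamed test_boxuniform_device_handling asserts that BoxUniform accepts matching devices and rejects mismatches between low/high and the device keyword. A minimal sketch of the asserted behavior (whether a mismatch raises or silently moves tensors is exactly the property under test; the failure mode is assumed):

    import torch
    from sbi.utils.torchutils import BoxUniform

    # Matching devices: fine.
    prior = BoxUniform(low=torch.zeros(1), high=torch.ones(1), device="cpu")

    # Mismatch (marked xfail above): expected to fail.
    # BoxUniform(low=torch.zeros(1).to("cuda:0"), high=torch.ones(1).to("cuda:0"), device="cpu")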