Skip to content

Commit

Permalink
refactor device tests
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Feb 16, 2024
1 parent 1319529 commit 64af31c
Showing 1 changed file with 36 additions and 19 deletions.
55 changes: 36 additions & 19 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from sbi.inference.potentials.base_potential import BasePotential
from sbi.simulators import diagonal_linear_gaussian, linear_gaussian
from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
from sbi.utils.torchutils import BoxUniform, process_device
from sbi.utils.torchutils import BoxUniform, gpu_available, process_device
from sbi.utils.user_input_checks import (
check_embedding_net_device,
prepare_for_sbi,
Expand Down Expand Up @@ -264,10 +264,13 @@ def test_validate_theta_and_x_type() -> None:


@pytest.mark.gpu
@pytest.mark.parametrize("training_device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("data_device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("training_device", ["cpu", "gpu"])
@pytest.mark.parametrize("data_device", ["cpu", "gpu"])
def test_validate_theta_and_x_device(training_device: str, data_device: str) -> None:
"""Test validate_theta_and_x with different devices."""

training_device = process_device(training_device)
data_device = process_device(data_device)

theta = torch.empty((1, 1)).to(data_device)
x = torch.empty((1, 1)).to(data_device)

Expand All @@ -285,13 +288,15 @@ def test_validate_theta_and_x_device(training_device: str, data_device: str) ->
@pytest.mark.parametrize(
"inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE]
)
@pytest.mark.parametrize("data_device", ("cpu", "cuda:0"))
@pytest.mark.parametrize("training_device", ("cpu", "cuda:0"))
@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
@pytest.mark.parametrize("training_device", ("cpu", "gpu"))
def test_train_with_different_data_and_training_device(
inference_method, data_device: str, training_device: str
) -> None:
"""Test training with different data and training device."""
assert torch.cuda.is_available(), "this test requires that cuda is available."
assert gpu_available(), "this test requires that gpu is available."

data_device = process_device(data_device)
training_device = process_device(training_device)

num_dim = 2
prior_ = BoxUniform(
Expand All @@ -317,7 +322,7 @@ def test_train_with_different_data_and_training_device(
theta, x = simulate_for_sbi(simulator, prior, 32)
theta, x = theta.to(data_device), x.to(data_device)
x_o = torch.zeros(x.shape[1])
inference = inference.append_simulations(theta, x)
inference = inference.append_simulations(theta, x, data_device=data_device)

posterior_estimator = inference.train(max_num_epochs=2)

Expand All @@ -331,11 +336,13 @@ def test_train_with_different_data_and_training_device(


@pytest.mark.gpu
@pytest.mark.parametrize("inference_method", [SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE])
@pytest.mark.parametrize("prior_device", ("cpu", "cuda"))
@pytest.mark.parametrize("embedding_net_device", ("cpu", "cuda"))
@pytest.mark.parametrize("data_device", ("cpu", "cuda"))
@pytest.mark.parametrize("training_device", ("cpu", "cuda"))
@pytest.mark.parametrize(
"inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE]
)
@pytest.mark.parametrize("prior_device", ("cpu", "gpu"))
@pytest.mark.parametrize("embedding_net_device", ("cpu", "gpu"))
@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
@pytest.mark.parametrize("training_device", ("cpu", "gpu"))
def test_embedding_nets_integration_training_device(
inference_method,
prior_device: str,
Expand All @@ -348,6 +355,12 @@ def test_embedding_nets_integration_training_device(

theta_dim = 2
x_dim = 3
# process all device strings
prior_device = process_device(prior_device)
embedding_net_device = process_device(embedding_net_device)
data_device = process_device(data_device)
training_device = process_device(training_device)

samples_per_round = 32
num_rounds = 2

Expand Down Expand Up @@ -493,7 +506,7 @@ def test_vi_on_gpu(num_dim: int, q: Distribution, vi_method: str, sampling_metho
sampling_method: Different sampling methods
"""

device = "cuda:0"
device = process_device("gpu")

if num_dim == 1 and q in ["mcf", "scf"]:
return
Expand Down Expand Up @@ -539,15 +552,19 @@ def allow_iid_x(self) -> bool:
"arg_device, device",
[
("cpu", None),
("cuda", None),
("gpu", None),
("cpu", "cpu"),
("cuda", "cuda"),
pytest.param("cuda", "cpu", marks=pytest.mark.xfail),
pytest.param("cpu", "cuda", marks=pytest.mark.xfail),
("gpu", "gpu"),
pytest.param("gpu", "cpu", marks=pytest.mark.xfail),
pytest.param("cpu", "gpu", marks=pytest.mark.xfail),
],
)
def test_boxuniform_device_handling(arg_device, device):
"""Test mismatch between device passed via low / high and device kwarg."""

arg_device = process_device(arg_device)
device = process_device(device)

prior = BoxUniform(
low=zeros(1).to(arg_device), high=ones(1).to(arg_device), device=device
)
Expand Down

0 comments on commit 64af31c

Please sign in to comment.