Skip to content

Commit

Permalink
refactor device tests
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Feb 16, 2024
1 parent 1319529 commit afc7df2
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 30 deletions.
2 changes: 1 addition & 1 deletion sbi/utils/torchutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def check_device(device: str) -> None:
f"""Could not instantiate torch.randn(1, device={device}). Make sure
the device is set up properly and that you are passing the
corresponding device string. It should be something like 'cuda',
'cuda:0', or 'mps'."""
'cuda:0', or 'mps'. Error message: {exc}."""
) from exc


Expand Down
78 changes: 50 additions & 28 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from contextlib import nullcontext
from typing import Optional, Tuple
from typing import Tuple

import numpy as np
import pytest
Expand Down Expand Up @@ -32,7 +32,7 @@
from sbi.inference.potentials.base_potential import BasePotential
from sbi.simulators import diagonal_linear_gaussian, linear_gaussian
from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
from sbi.utils.torchutils import BoxUniform, process_device
from sbi.utils.torchutils import BoxUniform, gpu_available, process_device
from sbi.utils.user_input_checks import (
check_embedding_net_device,
prepare_for_sbi,
Expand Down Expand Up @@ -208,12 +208,14 @@ def test_process_device(device_input: str, device_target: Optional[str]) -> None


@pytest.mark.gpu
@pytest.mark.parametrize("device_datum", ["cpu", "cuda"])
@pytest.mark.parametrize("device_embedding_net", ["cpu", "cuda"])
@pytest.mark.parametrize("device_datum", ["cpu", "gpu"])
@pytest.mark.parametrize("device_embedding_net", ["cpu", "gpu"])
def test_check_embedding_net_device(
device_datum: str, device_embedding_net: str
) -> None:
"""Test check_embedding_net_device and data with different device combinations."""
device_datum = process_device(device_datum)
device_embedding_net = process_device(device_embedding_net)

datum = torch.zeros((1, 1)).to(device_datum)
embedding_net = nn.Linear(in_features=1, out_features=1).to(device_embedding_net)

Expand Down Expand Up @@ -264,10 +266,12 @@ def test_validate_theta_and_x_type() -> None:


@pytest.mark.gpu
@pytest.mark.parametrize("training_device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("data_device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("training_device", ["cpu", "gpu"])
@pytest.mark.parametrize("data_device", ["cpu", "gpu"])
def test_validate_theta_and_x_device(training_device: str, data_device: str) -> None:
"""Test validate_theta_and_x with different devices."""
training_device = process_device(training_device)
data_device = process_device(data_device)

theta = torch.empty((1, 1)).to(data_device)
x = torch.empty((1, 1)).to(data_device)

Expand All @@ -285,13 +289,15 @@ def test_validate_theta_and_x_device(training_device: str, data_device: str) ->
@pytest.mark.parametrize(
"inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE]
)
@pytest.mark.parametrize("data_device", ("cpu", "cuda:0"))
@pytest.mark.parametrize("training_device", ("cpu", "cuda:0"))
@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
@pytest.mark.parametrize("training_device", ("cpu", "gpu"))
def test_train_with_different_data_and_training_device(
inference_method, data_device: str, training_device: str
) -> None:
"""Test training with different data and training device."""
assert torch.cuda.is_available(), "this test requires that cuda is available."
assert gpu_available(), "this test requires that gpu is available."

data_device = process_device(data_device)
training_device = process_device(training_device)

num_dim = 2
prior_ = BoxUniform(
Expand All @@ -317,7 +323,7 @@ def test_train_with_different_data_and_training_device(
theta, x = simulate_for_sbi(simulator, prior, 32)
theta, x = theta.to(data_device), x.to(data_device)
x_o = torch.zeros(x.shape[1])
inference = inference.append_simulations(theta, x)
inference = inference.append_simulations(theta, x, data_device=data_device)

posterior_estimator = inference.train(max_num_epochs=2)

Expand All @@ -331,11 +337,13 @@ def test_train_with_different_data_and_training_device(


@pytest.mark.gpu
@pytest.mark.parametrize("inference_method", [SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE])
@pytest.mark.parametrize("prior_device", ("cpu", "cuda"))
@pytest.mark.parametrize("embedding_net_device", ("cpu", "cuda"))
@pytest.mark.parametrize("data_device", ("cpu", "cuda"))
@pytest.mark.parametrize("training_device", ("cpu", "cuda"))
@pytest.mark.parametrize(
"inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE]
)
@pytest.mark.parametrize("prior_device", ("cpu", "gpu"))
@pytest.mark.parametrize("embedding_net_device", ("cpu", "gpu"))
@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
@pytest.mark.parametrize("training_device", ("cpu", "gpu"))
def test_embedding_nets_integration_training_device(
inference_method,
prior_device: str,
Expand All @@ -348,6 +356,12 @@ def test_embedding_nets_integration_training_device(

theta_dim = 2
x_dim = 3
# process all device strings
prior_device = process_device(prior_device)
embedding_net_device = process_device(embedding_net_device)
data_device = process_device(data_device)
training_device = process_device(training_device)

samples_per_round = 32
num_rounds = 2

Expand Down Expand Up @@ -480,11 +494,13 @@ def check_no_grad(model):

@pytest.mark.slow
@pytest.mark.gpu
@pytest.mark.parametrize("num_dim", (1, 2))
@pytest.mark.parametrize("num_dim", (1, 3))
# NOTE: macOS MPS fails for nsf with num_dim > 1
# might be related to https://github.com/pytorch/pytorch/issues/89127
@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
@pytest.mark.parametrize("vi_method", ("rKL", "fKL", "IW", "alpha"))
@pytest.mark.parametrize("sampling_method", ("naive", "sir"))
def test_vi_on_gpu(num_dim: int, q: Distribution, vi_method: str, sampling_method: str):
def test_vi_on_gpu(num_dim: int, q: str, vi_method: str, sampling_method: str):
"""Test VI on Gaussian, comparing to ground truth target via c2st.
Args:
Expand All @@ -493,11 +509,15 @@ def test_vi_on_gpu(num_dim: int, q: Distribution, vi_method: str, sampling_metho
sampling_method: Different sampling methods
"""

device = "cuda:0"
device = process_device("gpu")

if num_dim == 1 and q in ["mcf", "scf"]:
return

# Skip the test for nsf on mps:0 as it results in NaNs.
if device == "mps:0" and num_dim > 1 and q == "nsf":
return

# Good run where everythink is one the correct device.
class FakePotential(BasePotential):
def __call__(self, theta, **kwargs):
Expand All @@ -517,9 +537,7 @@ def allow_iid_x(self) -> bool:
posterior = VIPosterior(
potential_fn=potential_fn, theta_transform=theta_transform, q=q, device=device
)
posterior.set_default_x(
torch.tensor(np.zeros((num_dim,)).astype(np.float32)).to(device)
)
posterior.set_default_x(torch.zeros((num_dim,), dtype=torch.float32).to(device))
posterior.vi_method = vi_method

posterior.train(min_num_iters=9, max_num_iters=10, warm_up_rounds=10)
Expand All @@ -539,15 +557,19 @@ def allow_iid_x(self) -> bool:
"arg_device, device",
[
("cpu", None),
("cuda", None),
("gpu", None),
("cpu", "cpu"),
("cuda", "cuda"),
pytest.param("cuda", "cpu", marks=pytest.mark.xfail),
pytest.param("cpu", "cuda", marks=pytest.mark.xfail),
("gpu", "gpu"),
pytest.param("gpu", "cpu", marks=pytest.mark.xfail),
pytest.param("cpu", "gpu", marks=pytest.mark.xfail),
],
)
def test_boxuniform_device_handling(arg_device, device):
"""Test mismatch between device passed via low / high and device kwarg."""

arg_device = process_device(arg_device)
device = process_device(device)

prior = BoxUniform(
low=zeros(1).to(arg_device), high=ones(1).to(arg_device), device=device
)
Expand Down
3 changes: 2 additions & 1 deletion tests/torchutils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Test PyTorch utility functions."""
from __future__ import annotations

from typing import Optional

import numpy as np
Expand Down Expand Up @@ -200,7 +201,7 @@ def test_dkl_gauss():
)


@pytest.mark.parametrize("device_input", ("cpu", "gpu", "cuda", "cuda:0", "mps", ))
@pytest.mark.parametrize("device_input", ("cpu", "gpu", "cuda", "cuda:0", "mps"))
def test_process_device(device_input: str) -> None:
"""Test whether the device is processed correctly."""

Expand Down

0 comments on commit afc7df2

Please sign in to comment.