diff --git a/sbi/utils/torchutils.py b/sbi/utils/torchutils.py
index 4ac0d6c0f..53b248b28 100644
--- a/sbi/utils/torchutils.py
+++ b/sbi/utils/torchutils.py
@@ -3,6 +3,7 @@
 
 """Various PyTorch utility functions."""
 
+import os
 import warnings
 from typing import Any, Optional, Union
 
@@ -16,9 +17,12 @@
 
 
 def process_device(device: str) -> str:
-    """Set and return the default device to cpu or cuda.
+    """Set and return the default device to cpu or gpu (cuda, mps).
 
-    Throws an AssertionError if the prior is not matching the training device not.
+    Args:
+        device: target torch device
 
+    Returns:
+        device: processed string, e.g., "cuda" is mapped to "cuda:0".
     """
     # NOTE: we might want to add support for other devices in the future, e.g., MPS
@@ -34,22 +38,55 @@ def process_device(device: str) -> str:
             "only for large neural networks with operations that are fast on the "
             "GPU, e.g., for a CNN or RNN `embedding_net`."
         )
-    if device in gpu_devices:
-        assert torch.cuda.is_available(), "CUDA is not available."
-        current_gpu_index = torch.cuda.current_device()
-        return f"cuda:{current_gpu_index}"
-    else:
-        # Check if the device is a valid cuda device.
-        try:
-            torch.randn(1, device=device)
-        except RuntimeError:
+    # If user just passes 'gpu', search for CUDA or MPS.
+    if device == "gpu":
+        # check whether either pytorch cuda or mps is available
+        if torch.cuda.is_available():
+            current_gpu_index = torch.cuda.current_device()
+            device = f"cuda:{current_gpu_index}"
+            check_device(device)
+            torch.cuda.set_device(device)
+        elif torch.backends.mps.is_available():
+            device = "mps:0"
+            # MPS support is not implemented for a number of operations.
+            # use CPU as fallback.
+            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+            # MPS framework does not support double precision.
+            torch.set_default_dtype(torch.float32)
+            check_device(device)
+        else:
             raise RuntimeError(
-                f"""Could not instantiate torch.randn(1, device={device}). Please
-                use one in {gpu_devices}, or cuda: with <
-                {torch.cuda.device_count()}."""
+                "Neither CUDA nor MPS is available. "
+                "Please make sure to install a version of PyTorch that supports "
+                "CUDA or MPS."
             )
-        torch.cuda.set_device(device)
-        return device
+    # Else, check whether the custom device is valid.
+    else:
+        check_device(device)
+
+    return device
+
+
+def gpu_available() -> bool:
+    """Check whether GPU is available."""
+    return torch.cuda.is_available() or torch.backends.mps.is_available()
+
+
+def check_device(device: str) -> None:
+    """Check whether the device is valid.
+
+    Args:
+        device: target torch device
+    """
+    try:
+        torch.randn(1, device=device)
+    except (RuntimeError, AssertionError) as exc:
+        raise RuntimeError(
+            f"""Could not instantiate torch.randn(1, device={device}). Make sure
+            the device is set up properly and that you are passing the
+            corresponding device string. It should be something like 'cuda',
+            'cuda:0', or 'mps'. Error message: {exc}."""
+        ) from exc
 
 
 def check_if_prior_on_device(
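For orientation, a minimal usage sketch of the helpers introduced above (illustrative only, not part of the patch; the concrete string returned depends on the hardware PyTorch detects):

    from sbi.utils.torchutils import gpu_available, process_device

    # "cpu" is validated via check_device() and returned unchanged.
    assert process_device("cpu") == "cpu"

    if gpu_available():
        # "gpu" resolves to a concrete backend: "cuda:<current index>" when CUDA
        # is available, otherwise "mps:0" (with float32 as the default dtype and
        # the CPU fallback enabled for unsupported MPS operations).
        device = process_device("gpu")
    else:
        # Without CUDA or MPS, process_device("gpu") raises a RuntimeError.
        device = "cpu"
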
diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py
index 97fca4098..f7dcf3aee 100644
--- a/tests/inference_on_device_test.py
+++ b/tests/inference_on_device_test.py
@@ -4,7 +4,7 @@
 from __future__ import annotations
 
 from contextlib import nullcontext
-from typing import Optional, Tuple
+from typing import Tuple
 
 import numpy as np
 import pytest
@@ -32,7 +32,7 @@
 from sbi.inference.potentials.base_potential import BasePotential
 from sbi.simulators import diagonal_linear_gaussian, linear_gaussian
 from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
-from sbi.utils.torchutils import BoxUniform, process_device
+from sbi.utils.torchutils import BoxUniform, gpu_available, process_device
 from sbi.utils.user_input_checks import (
     check_embedding_net_device,
     prepare_for_sbi,
@@ -208,12 +208,14 @@ def test_process_device(device_input: str, device_target: Optional[str]) -> None
 
 
 @pytest.mark.gpu
-@pytest.mark.parametrize("device_datum", ["cpu", "cuda"])
-@pytest.mark.parametrize("device_embedding_net", ["cpu", "cuda"])
+@pytest.mark.parametrize("device_datum", ["cpu", "gpu"])
+@pytest.mark.parametrize("device_embedding_net", ["cpu", "gpu"])
 def test_check_embedding_net_device(
     device_datum: str, device_embedding_net: str
 ) -> None:
-    """Test check_embedding_net_device and data with different device combinations."""
+    device_datum = process_device(device_datum)
+    device_embedding_net = process_device(device_embedding_net)
+
     datum = torch.zeros((1, 1)).to(device_datum)
     embedding_net = nn.Linear(in_features=1, out_features=1).to(device_embedding_net)
 
@@ -264,10 +266,12 @@ def test_validate_theta_and_x_type() -> None:
 
 
 @pytest.mark.gpu
-@pytest.mark.parametrize("training_device", ["cpu", "cuda:0"])
-@pytest.mark.parametrize("data_device", ["cpu", "cuda:0"])
+@pytest.mark.parametrize("training_device", ["cpu", "gpu"])
+@pytest.mark.parametrize("data_device", ["cpu", "gpu"])
 def test_validate_theta_and_x_device(training_device: str, data_device: str) -> None:
-    """Test validate_theta_and_x with different devices."""
+    training_device = process_device(training_device)
+    data_device = process_device(data_device)
+
     theta = torch.empty((1, 1)).to(data_device)
     x = torch.empty((1, 1)).to(data_device)
 
@@ -285,13 +289,15 @@ def test_validate_theta_and_x_device(training_device: str, data_device: str) ->
 @pytest.mark.parametrize(
     "inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE]
 )
-@pytest.mark.parametrize("data_device", ("cpu", "cuda:0"))
-@pytest.mark.parametrize("training_device", ("cpu", "cuda:0"))
+@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
+@pytest.mark.parametrize("training_device", ("cpu", "gpu"))
 def test_train_with_different_data_and_training_device(
     inference_method, data_device: str, training_device: str
 ) -> None:
-    """Test training with different data and training device."""
-    assert torch.cuda.is_available(), "this test requires that cuda is available."
+    assert gpu_available(), "this test requires that gpu is available."
+
+    data_device = process_device(data_device)
+    training_device = process_device(training_device)
 
     num_dim = 2
     prior_ = BoxUniform(
@@ -317,7 +323,7 @@ def test_train_with_different_data_and_training_device(
     theta, x = simulate_for_sbi(simulator, prior, 32)
     theta, x = theta.to(data_device), x.to(data_device)
     x_o = torch.zeros(x.shape[1])
-    inference = inference.append_simulations(theta, x)
+    inference = inference.append_simulations(theta, x, data_device=data_device)
 
     posterior_estimator = inference.train(max_num_epochs=2)
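The hunks above illustrate the pattern used throughout this test file: the parametrized "cpu"/"gpu" strings are first resolved with process_device, tensors are moved to the resolved data device, and append_simulations is told explicitly where the data lives. A condensed sketch of that flow (assuming simulator, prior, and the inference object are set up as in the test; this is not an additional test from the patch):

    training_device = process_device("gpu")   # e.g. "cuda:0" or "mps:0"
    data_device = process_device("cpu")
    # (the inference object would be constructed for training_device elsewhere)

    theta, x = simulate_for_sbi(simulator, prior, 32)
    theta, x = theta.to(data_device), x.to(data_device)

    # Data may live on a different device than the one used for training;
    # append_simulations() records where the data is kept via data_device.
    inference = inference.append_simulations(theta, x, data_device=data_device)
    posterior_estimator = inference.train(max_num_epochs=2)
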
@@ -331,11 +337,13 @@ def test_train_with_different_data_and_training_device(
 
 
 @pytest.mark.gpu
-@pytest.mark.parametrize("inference_method", [SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE])
-@pytest.mark.parametrize("prior_device", ("cpu", "cuda"))
-@pytest.mark.parametrize("embedding_net_device", ("cpu", "cuda"))
-@pytest.mark.parametrize("data_device", ("cpu", "cuda"))
-@pytest.mark.parametrize("training_device", ("cpu", "cuda"))
+@pytest.mark.parametrize(
+    "inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE]
+)
+@pytest.mark.parametrize("prior_device", ("cpu", "gpu"))
+@pytest.mark.parametrize("embedding_net_device", ("cpu", "gpu"))
+@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
+@pytest.mark.parametrize("training_device", ("cpu", "gpu"))
 def test_embedding_nets_integration_training_device(
     inference_method,
     prior_device: str,
@@ -348,6 +356,12 @@ def test_embedding_nets_integration_training_device(
     theta_dim = 2
     x_dim = 3
 
+    # process all device strings
+    prior_device = process_device(prior_device)
+    embedding_net_device = process_device(embedding_net_device)
+    data_device = process_device(data_device)
+    training_device = process_device(training_device)
+
     samples_per_round = 32
     num_rounds = 2
@@ -480,11 +494,13 @@ def check_no_grad(model):
 
 @pytest.mark.slow
 @pytest.mark.gpu
-@pytest.mark.parametrize("num_dim", (1, 2))
+@pytest.mark.parametrize("num_dim", (1, 3))
+# NOTE: macOS MPS fails for nsf with num_dim > 1
+# might be related to https://github.com/pytorch/pytorch/issues/89127
 @pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
 @pytest.mark.parametrize("vi_method", ("rKL", "fKL", "IW", "alpha"))
 @pytest.mark.parametrize("sampling_method", ("naive", "sir"))
-def test_vi_on_gpu(num_dim: int, q: Distribution, vi_method: str, sampling_method: str):
+def test_vi_on_gpu(num_dim: int, q: str, vi_method: str, sampling_method: str):
     """Test VI on Gaussian, comparing to ground truth target via c2st.
 
     Args:
@@ -493,11 +509,15 @@ def test_vi_on_gpu(num_dim: int, q: Distribution, vi_method: str, sampling_metho
         sampling_method: Different sampling methods
     """
-    device = "cuda:0"
+    device = process_device("gpu")
 
     if num_dim == 1 and q in ["mcf", "scf"]:
         return
 
+    # Skip the test for nsf on mps:0 as it results in NaNs.
+    if device == "mps:0" and num_dim > 1 and q == "nsf":
+        return
+
     # Good run where everythink is one the correct device.
     class FakePotential(BasePotential):
         def __call__(self, theta, **kwargs):
@@ -517,9 +537,7 @@ def allow_iid_x(self) -> bool:
     posterior = VIPosterior(
         potential_fn=potential_fn, theta_transform=theta_transform, q=q, device=device
     )
-    posterior.set_default_x(
-        torch.tensor(np.zeros((num_dim,)).astype(np.float32)).to(device)
-    )
+    posterior.set_default_x(torch.zeros((num_dim,), dtype=torch.float32).to(device))
     posterior.vi_method = vi_method
 
     posterior.train(min_num_iters=9, max_num_iters=10, warm_up_rounds=10)
@@ -539,15 +557,19 @@ def allow_iid_x(self) -> bool:
     "arg_device, device",
     [
         ("cpu", None),
-        ("cuda", None),
+        ("gpu", None),
         ("cpu", "cpu"),
-        ("cuda", "cuda"),
-        pytest.param("cuda", "cpu", marks=pytest.mark.xfail),
-        pytest.param("cpu", "cuda", marks=pytest.mark.xfail),
+        ("gpu", "gpu"),
+        pytest.param("gpu", "cpu", marks=pytest.mark.xfail),
+        pytest.param("cpu", "gpu", marks=pytest.mark.xfail),
     ],
 )
 def test_boxuniform_device_handling(arg_device, device):
     """Test mismatch between device passed via low / high and device kwarg."""
+
+    arg_device = process_device(arg_device)
+    device = process_device(device)
+
     prior = BoxUniform(
         low=zeros(1).to(arg_device), high=ones(1).to(arg_device), device=device
     )
diff --git a/tests/torchutils_test.py b/tests/torchutils_test.py
index fc9db179e..c3c55b2bd 100644
--- a/tests/torchutils_test.py
+++ b/tests/torchutils_test.py
@@ -4,7 +4,10 @@
 """Test PyTorch utility functions."""
 from __future__ import annotations
 
+from typing import Optional
+
 import numpy as np
+import pytest
 import torch
 import torchtestcase
 from torch import distributions as distributions
@@ -196,3 +199,34 @@ def test_dkl_gauss():
         f"Monte-Carlo-based KLd={monte_carlo_dkl} is too far from the torch"
         f" implementation, {torch_dkl}."
     )
+
+
+@pytest.mark.parametrize("device_input", ("cpu", "gpu", "cuda", "cuda:0", "mps"))
+def test_process_device(device_input: str) -> None:
+    """Test whether the device is processed correctly."""
+
+    try:
+        device_output = torchutils.process_device(device_input)
+        if device_input == "cpu":
+            assert device_output == "cpu"
+        elif device_input == "gpu":
+            if torch.cuda.is_available():
+                current_gpu_index = torch.cuda.current_device()
+                assert device_output == f"cuda:{current_gpu_index}"
+            elif torch.backends.mps.is_available():
+                assert device_output == "mps:0"
+
+        if device_input == "cuda" and torch.cuda.is_available():
+            assert device_output == "cuda:0"
+        if device_input == "cuda:0" and torch.cuda.is_available():
+            assert device_output == "cuda:0"
+        if device_input == "mps" and torch.backends.mps.is_available():
+            assert device_output == "mps"
+
+    except RuntimeError:
+        # this should not happen for cpu
+        assert not device_input == "cpu"
+
+        # should only happen if no gpu is available
+        if device_input == "gpu":
+            assert not torchutils.gpu_available()
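Because gpu_available() reports either CUDA or MPS, it can also be used to skip (rather than fail) device-specific tests on machines without an accelerator. A hypothetical guard using standard pytest marks, shown only as a sketch and not part of the patch:

    import pytest

    from sbi.utils.torchutils import gpu_available, process_device

    # Hypothetical helper mark: skip GPU tests when neither CUDA nor MPS exists.
    requires_gpu = pytest.mark.skipif(
        not gpu_available(), reason="requires a CUDA or MPS device"
    )

    @requires_gpu
    def test_runs_on_gpu_only():
        device = process_device("gpu")
        ...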