Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Apple MPS as GPU device #912

Merged
merged 2 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 53 additions & 16 deletions sbi/utils/torchutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Various PyTorch utility functions."""

import os
import warnings
from typing import Any, Optional, Union

Expand All @@ -16,9 +17,12 @@


def process_device(device: str) -> str:
"""Set and return the default device to cpu or cuda.
"""Set and return the default device to cpu or gpu (cuda, mps).

Throws an AssertionError if the prior is not matching the training device not.
Args:
device: target torch device
Returns:
device: processed string, e.g., "cuda" is mapped to "cuda:0".
"""

# NOTE: we might want to add support for other devices in the future, e.g., MPS
Expand All @@ -34,22 +38,55 @@
"only for large neural networks with operations that are fast on the "
"GPU, e.g., for a CNN or RNN `embedding_net`."
)
if device in gpu_devices:
assert torch.cuda.is_available(), "CUDA is not available."
current_gpu_index = torch.cuda.current_device()
return f"cuda:{current_gpu_index}"
else:
# Check if the device is a valid cuda device.
try:
torch.randn(1, device=device)
except RuntimeError:
# If user just passes 'gpu', search for CUDA or MPS.
if device == "gpu":
# check whether either pytorch cuda or mps is available
if torch.cuda.is_available():
current_gpu_index = torch.cuda.current_device()
device = f"cuda:{current_gpu_index}"
check_device(device)
torch.cuda.set_device(device)

Check warning on line 48 in sbi/utils/torchutils.py

View check run for this annotation

Codecov / codecov/patch

sbi/utils/torchutils.py#L45-L48

Added lines #L45 - L48 were not covered by tests
elif torch.backends.mps.is_available():
device = "mps:0"

Check warning on line 50 in sbi/utils/torchutils.py

View check run for this annotation

Codecov / codecov/patch

sbi/utils/torchutils.py#L50

Added line #L50 was not covered by tests
# MPS support is not implemented for a number of operations.
# use CPU as fallback.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

Check warning on line 53 in sbi/utils/torchutils.py

View check run for this annotation

Codecov / codecov/patch

sbi/utils/torchutils.py#L53

Added line #L53 was not covered by tests
# MPS framework does not support double precision.
torch.set_default_dtype(torch.float32)
check_device(device)

Check warning on line 56 in sbi/utils/torchutils.py

View check run for this annotation

Codecov / codecov/patch

sbi/utils/torchutils.py#L55-L56

Added lines #L55 - L56 were not covered by tests
else:
raise RuntimeError(
f"""Could not instantiate torch.randn(1, device={device}). Please
use one in {gpu_devices}, or cuda:<index> with <index> <
{torch.cuda.device_count()}."""
"Neither CUDA nor MPS is available. "
"Please make sure to install a version of PyTorch that supports "
"CUDA or MPS."
)
torch.cuda.set_device(device)
return device
# Else, check whether the custom device is valid.
else:
check_device(device)

return device

Check warning on line 67 in sbi/utils/torchutils.py

View check run for this annotation

Codecov / codecov/patch

sbi/utils/torchutils.py#L67

Added line #L67 was not covered by tests


def gpu_available() -> bool:
    """Return True when a GPU backend (CUDA or Apple MPS) can be used."""
    cuda_ok = torch.cuda.is_available()
    mps_ok = torch.backends.mps.is_available()
    return cuda_ok or mps_ok


def check_device(device: str) -> None:
    """Check whether the device is valid.

    Validity is probed empirically: a one-element tensor is allocated on the
    requested device, so any backend or driver problem surfaces immediately.

    Args:
        device: target torch device

    Raises:
        RuntimeError: if no tensor can be created on ``device``.
    """
    try:
        torch.randn(1, device=device)
    except (RuntimeError, AssertionError) as err:
        msg = f"""Could not instantiate torch.randn(1, device={device}). Make sure
            the device is set up properly and that you are passing the
            corresponding device string. It should be something like 'cuda',
            'cuda:0', or 'mps'. Error message: {err}."""
        raise RuntimeError(msg) from err


def check_if_prior_on_device(
Expand Down
78 changes: 50 additions & 28 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from contextlib import nullcontext
from typing import Optional, Tuple
from typing import Tuple

import numpy as np
import pytest
Expand Down Expand Up @@ -32,7 +32,7 @@
from sbi.inference.potentials.base_potential import BasePotential
from sbi.simulators import diagonal_linear_gaussian, linear_gaussian
from sbi.utils.get_nn_models import classifier_nn, likelihood_nn, posterior_nn
from sbi.utils.torchutils import BoxUniform, process_device
from sbi.utils.torchutils import BoxUniform, gpu_available, process_device
from sbi.utils.user_input_checks import (
check_embedding_net_device,
prepare_for_sbi,
Expand Down Expand Up @@ -208,12 +208,14 @@ def test_process_device(device_input: str, device_target: Optional[str]) -> None


@pytest.mark.gpu
@pytest.mark.parametrize("device_datum", ["cpu", "cuda"])
@pytest.mark.parametrize("device_embedding_net", ["cpu", "cuda"])
@pytest.mark.parametrize("device_datum", ["cpu", "gpu"])
@pytest.mark.parametrize("device_embedding_net", ["cpu", "gpu"])
def test_check_embedding_net_device(
device_datum: str, device_embedding_net: str
) -> None:
"""Test check_embedding_net_device and data with different device combinations."""
device_datum = process_device(device_datum)
device_embedding_net = process_device(device_embedding_net)

datum = torch.zeros((1, 1)).to(device_datum)
embedding_net = nn.Linear(in_features=1, out_features=1).to(device_embedding_net)

Expand Down Expand Up @@ -264,10 +266,12 @@ def test_validate_theta_and_x_type() -> None:


@pytest.mark.gpu
@pytest.mark.parametrize("training_device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("data_device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("training_device", ["cpu", "gpu"])
@pytest.mark.parametrize("data_device", ["cpu", "gpu"])
def test_validate_theta_and_x_device(training_device: str, data_device: str) -> None:
"""Test validate_theta_and_x with different devices."""
training_device = process_device(training_device)
data_device = process_device(data_device)

theta = torch.empty((1, 1)).to(data_device)
x = torch.empty((1, 1)).to(data_device)

Expand All @@ -285,13 +289,15 @@ def test_validate_theta_and_x_device(training_device: str, data_device: str) ->
@pytest.mark.parametrize(
"inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE]
)
@pytest.mark.parametrize("data_device", ("cpu", "cuda:0"))
@pytest.mark.parametrize("training_device", ("cpu", "cuda:0"))
@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
@pytest.mark.parametrize("training_device", ("cpu", "gpu"))
def test_train_with_different_data_and_training_device(
inference_method, data_device: str, training_device: str
) -> None:
"""Test training with different data and training device."""
assert torch.cuda.is_available(), "this test requires that cuda is available."
assert gpu_available(), "this test requires that gpu is available."

data_device = process_device(data_device)
training_device = process_device(training_device)

num_dim = 2
prior_ = BoxUniform(
Expand All @@ -317,7 +323,7 @@ def test_train_with_different_data_and_training_device(
theta, x = simulate_for_sbi(simulator, prior, 32)
theta, x = theta.to(data_device), x.to(data_device)
x_o = torch.zeros(x.shape[1])
inference = inference.append_simulations(theta, x)
inference = inference.append_simulations(theta, x, data_device=data_device)

posterior_estimator = inference.train(max_num_epochs=2)

Expand All @@ -331,11 +337,13 @@ def test_train_with_different_data_and_training_device(


@pytest.mark.gpu
@pytest.mark.parametrize("inference_method", [SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE])
@pytest.mark.parametrize("prior_device", ("cpu", "cuda"))
@pytest.mark.parametrize("embedding_net_device", ("cpu", "cuda"))
@pytest.mark.parametrize("data_device", ("cpu", "cuda"))
@pytest.mark.parametrize("training_device", ("cpu", "cuda"))
@pytest.mark.parametrize(
"inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNRE_C, SNLE]
)
@pytest.mark.parametrize("prior_device", ("cpu", "gpu"))
@pytest.mark.parametrize("embedding_net_device", ("cpu", "gpu"))
@pytest.mark.parametrize("data_device", ("cpu", "gpu"))
@pytest.mark.parametrize("training_device", ("cpu", "gpu"))
def test_embedding_nets_integration_training_device(
inference_method,
prior_device: str,
Expand All @@ -348,6 +356,12 @@ def test_embedding_nets_integration_training_device(

theta_dim = 2
x_dim = 3
# process all device strings
prior_device = process_device(prior_device)
embedding_net_device = process_device(embedding_net_device)
data_device = process_device(data_device)
training_device = process_device(training_device)

samples_per_round = 32
num_rounds = 2

Expand Down Expand Up @@ -480,11 +494,13 @@ def check_no_grad(model):

@pytest.mark.slow
@pytest.mark.gpu
@pytest.mark.parametrize("num_dim", (1, 2))
@pytest.mark.parametrize("num_dim", (1, 3))
# NOTE: macOS MPS fails for nsf with num_dim > 1
# might be related to https://github.com/pytorch/pytorch/issues/89127
@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
@pytest.mark.parametrize("vi_method", ("rKL", "fKL", "IW", "alpha"))
@pytest.mark.parametrize("sampling_method", ("naive", "sir"))
def test_vi_on_gpu(num_dim: int, q: Distribution, vi_method: str, sampling_method: str):
def test_vi_on_gpu(num_dim: int, q: str, vi_method: str, sampling_method: str):
"""Test VI on Gaussian, comparing to ground truth target via c2st.

Args:
Expand All @@ -493,11 +509,15 @@ def test_vi_on_gpu(num_dim: int, q: Distribution, vi_method: str, sampling_metho
sampling_method: Different sampling methods
"""

device = "cuda:0"
device = process_device("gpu")

if num_dim == 1 and q in ["mcf", "scf"]:
return

# Skip the test for nsf on mps:0 as it results in NaNs.
if device == "mps:0" and num_dim > 1 and q == "nsf":
janfb marked this conversation as resolved.
Show resolved Hide resolved
return

# Good run where everything is on the correct device.
class FakePotential(BasePotential):
def __call__(self, theta, **kwargs):
Expand All @@ -517,9 +537,7 @@ def allow_iid_x(self) -> bool:
posterior = VIPosterior(
potential_fn=potential_fn, theta_transform=theta_transform, q=q, device=device
)
posterior.set_default_x(
torch.tensor(np.zeros((num_dim,)).astype(np.float32)).to(device)
)
posterior.set_default_x(torch.zeros((num_dim,), dtype=torch.float32).to(device))
posterior.vi_method = vi_method

posterior.train(min_num_iters=9, max_num_iters=10, warm_up_rounds=10)
Expand All @@ -539,15 +557,19 @@ def allow_iid_x(self) -> bool:
"arg_device, device",
[
("cpu", None),
("cuda", None),
("gpu", None),
("cpu", "cpu"),
("cuda", "cuda"),
pytest.param("cuda", "cpu", marks=pytest.mark.xfail),
pytest.param("cpu", "cuda", marks=pytest.mark.xfail),
("gpu", "gpu"),
pytest.param("gpu", "cpu", marks=pytest.mark.xfail),
pytest.param("cpu", "gpu", marks=pytest.mark.xfail),
],
)
def test_boxuniform_device_handling(arg_device, device):
"""Test mismatch between device passed via low / high and device kwarg."""

arg_device = process_device(arg_device)
device = process_device(device)

prior = BoxUniform(
low=zeros(1).to(arg_device), high=ones(1).to(arg_device), device=device
)
Expand Down
34 changes: 34 additions & 0 deletions tests/torchutils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
"""Test PyTorch utility functions."""
from __future__ import annotations

from typing import Optional

import numpy as np
import pytest
import torch
import torchtestcase
from torch import distributions as distributions
Expand Down Expand Up @@ -196,3 +199,34 @@ def test_dkl_gauss():
f"Monte-Carlo-based KLd={monte_carlo_dkl} is too far from the torch"
f" implementation, {torch_dkl}."
)


@pytest.mark.parametrize("device_input", ("cpu", "gpu", "cuda", "cuda:0", "mps"))
def test_process_device(device_input: str) -> None:
    """Test whether the device string is processed correctly.

    ``process_device`` resolves the generic "gpu" string to a concrete backend
    device ("cuda:<index>" or "mps:0"); explicit device strings ("cpu",
    "cuda", "cuda:0", "mps") are only validated and returned unchanged.
    """

    try:
        device_output = torchutils.process_device(device_input)
        if device_input == "cpu":
            assert device_output == "cpu"
        elif device_input == "gpu":
            if torch.cuda.is_available():
                current_gpu_index = torch.cuda.current_device()
                assert device_output == f"cuda:{current_gpu_index}"
            elif torch.backends.mps.is_available():
                # process_device maps "gpu" to the explicit "mps:0" string,
                # not bare "mps".
                assert device_output == "mps:0"

        # Explicit device strings pass through process_device unchanged:
        # only "gpu" is rewritten; everything else is merely validated.
        if device_input == "cuda" and torch.cuda.is_available():
            assert device_output == "cuda"
        if device_input == "cuda:0" and torch.cuda.is_available():
            assert device_output == "cuda:0"
        if device_input == "mps" and torch.backends.mps.is_available():
            assert device_output == "mps"

    except RuntimeError:
        # process_device must never fail for "cpu".
        assert not device_input == "cpu"

        # A failure for "gpu" should only happen when no GPU backend exists.
        if device_input == "gpu":
            assert not torchutils.gpu_available()
Loading