Skip to content

Commit

Permalink
refactor tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Feb 16, 2024
1 parent 64af31c commit 9d9b128
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
2 changes: 1 addition & 1 deletion sbi/utils/torchutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def check_device(device: str) -> None:
f"""Could not instantiate torch.randn(1, device={device}). Make sure
the device is set up properly and that you are passing the
corresponding device string. It should be something like 'cuda',
'cuda:0', or 'mps'."""
'cuda:0', or 'mps'. Error message: {exc}."""
) from exc


Expand Down
23 changes: 13 additions & 10 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from contextlib import nullcontext
from typing import Optional, Tuple
from typing import Tuple

import numpy as np
import pytest
Expand Down Expand Up @@ -208,12 +208,14 @@ def test_process_device(device_input: str, device_target: Optional[str]) -> None


@pytest.mark.gpu
@pytest.mark.parametrize("device_datum", ["cpu", "cuda"])
@pytest.mark.parametrize("device_embedding_net", ["cpu", "cuda"])
@pytest.mark.parametrize("device_datum", ["cpu", "gpu"])
@pytest.mark.parametrize("device_embedding_net", ["cpu", "gpu"])
def test_check_embedding_net_device(
device_datum: str, device_embedding_net: str
) -> None:
"""Test check_embedding_net_device and data with different device combinations."""
device_datum = process_device(device_datum)
device_embedding_net = process_device(device_embedding_net)

datum = torch.zeros((1, 1)).to(device_datum)
embedding_net = nn.Linear(in_features=1, out_features=1).to(device_embedding_net)

Expand Down Expand Up @@ -267,7 +269,6 @@ def test_validate_theta_and_x_type() -> None:
@pytest.mark.parametrize("training_device", ["cpu", "gpu"])
@pytest.mark.parametrize("data_device", ["cpu", "gpu"])
def test_validate_theta_and_x_device(training_device: str, data_device: str) -> None:

training_device = process_device(training_device)
data_device = process_device(data_device)

Expand Down Expand Up @@ -493,11 +494,11 @@ def check_no_grad(model):

@pytest.mark.slow
@pytest.mark.gpu
@pytest.mark.parametrize("num_dim", (1, 2))
@pytest.mark.parametrize("num_dim", (1, 3))
@pytest.mark.parametrize("q", ("maf", "nsf", "gaussian_diag", "gaussian", "mcf", "scf"))
@pytest.mark.parametrize("vi_method", ("rKL", "fKL", "IW", "alpha"))
@pytest.mark.parametrize("sampling_method", ("naive", "sir"))
def test_vi_on_gpu(num_dim: int, q: Distribution, vi_method: str, sampling_method: str):
def test_vi_on_gpu(num_dim: int, q: str, vi_method: str, sampling_method: str):
"""Test VI on Gaussian, comparing to ground truth target via c2st.
Args:
Expand All @@ -511,6 +512,10 @@ def test_vi_on_gpu(num_dim: int, q: Distribution, vi_method: str, sampling_metho
if num_dim == 1 and q in ["mcf", "scf"]:
return

# Skip the test for nsf on mps:0 as it results in NaNs.
if device == "mps:0" and num_dim > 1 and q == "nsf":
return

# Good run where everythink is one the correct device.
class FakePotential(BasePotential):
def __call__(self, theta, **kwargs):
Expand All @@ -530,9 +535,7 @@ def allow_iid_x(self) -> bool:
posterior = VIPosterior(
potential_fn=potential_fn, theta_transform=theta_transform, q=q, device=device
)
posterior.set_default_x(
torch.tensor(np.zeros((num_dim,)).astype(np.float32)).to(device)
)
posterior.set_default_x(torch.zeros((num_dim,), dtype=torch.float32).to(device))
posterior.vi_method = vi_method

posterior.train(min_num_iters=9, max_num_iters=10, warm_up_rounds=10)
Expand Down
12 changes: 11 additions & 1 deletion tests/torchutils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Test PyTorch utility functions."""
from __future__ import annotations

from typing import Optional

import numpy as np
Expand Down Expand Up @@ -200,7 +201,16 @@ def test_dkl_gauss():
)


@pytest.mark.parametrize("device_input", ("cpu", "gpu", "cuda", "cuda:0", "mps", ))
@pytest.mark.parametrize(
"device_input",
(
"cpu",
"gpu",
"cuda",
"cuda:0",
"mps",
),
)
def test_process_device(device_input: str) -> None:
"""Test whether the device is processed correctly."""

Expand Down

0 comments on commit 9d9b128

Please sign in to comment.