Commit 9ce7cfa
fixes for rebase on main; remove pbars in tests.
janfb committed Feb 9, 2024
1 parent 815f07e commit 9ce7cfa
Showing 2 changed files with 30 additions and 15 deletions.
12 changes: 7 additions & 5 deletions sbi/utils/torchutils.py
@@ -3,6 +3,7 @@
 
 """Various PyTorch utility functions."""
 
+import logging
 import os
 import warnings
 from typing import Any, Optional, Union
@@ -25,9 +26,6 @@ def process_device(device: str) -> str:
         device: processed string, e.g., "cuda" is mapped to "cuda:0".
     """
 
-    # NOTE: we might want to add support for other devices in the future, e.g., MPS
-    gpu_devices = ["gpu", "cuda", "mps"]
-
     if device == "cpu":
         return "cpu"
     else:
@@ -38,8 +36,8 @@ def process_device(device: str) -> str:
             "only for large neural networks with operations that are fast on the "
             "GPU, e.g., for a CNN or RNN `embedding_net`."
         )
-        if device in gpu_devices:
-            # check whether either pytorch cuda or mps is available
+        # If device is "gpu", check whether either pytorch cuda or mps is available.
+        if device == "gpu":
             if torch.cuda.is_available():
                 current_gpu_index = torch.cuda.current_device()
                 device = f"cuda:{current_gpu_index}"
@@ -53,6 +51,9 @@ def process_device(device: str) -> str:
                 # MPS framework does not support double precision.
                 torch.set_default_dtype(torch.float32)
                 check_device(device)
+                warnings.warn(
+                    "Using MPS as a device for training the neural network. Note MPS support is not implemented for a number of operations and requires setting `PYTORCH_ENABLE_MPS_FALLBACK=1` and `torch.set_default_dtype(torch.float32)`. MPS is not supported for double precision."
+                )
             else:
                 raise RuntimeError(
                     "Neither CUDA nor MPS is available. "
@@ -63,6 +64,7 @@ def process_device(device: str) -> str:
         else:
             check_device(device)
 
+    logging.info(f"Using device: {device}")
     return device

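For orientation (not part of the commit): a short, hypothetical usage sketch of the revised helper. It assumes an environment with `sbi` installed; the resolved device string depends on the host hardware.

# Hypothetical usage sketch of the revised process_device; requires sbi.
import os

# Per the new warning, MPS needs this fallback flag for unsupported operations.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import logging

import torch

from sbi.utils.torchutils import process_device

# Surface the new `logging.info` line ("Using device: ...").
logging.basicConfig(level=logging.INFO)

assert process_device("cpu") == "cpu"  # "cpu" passes through unchanged.

# "gpu" now resolves explicitly: "cuda:<index>" if CUDA is available, otherwise
# an MPS device (with the warning above), otherwise a RuntimeError is raised.
try:
    print(process_device("gpu"))  # e.g., "cuda:0" on a CUDA machine
except RuntimeError as err:
    print(f"No GPU backend available: {err}")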
33 changes: 23 additions & 10 deletions tests/inference_on_device_test.py
@@ -130,9 +130,13 @@ def simulator(theta):
     )
 
     proposals = [prior]
+    theta = prior.sample((num_simulations,))
+    # ImportanceSamplingPosterior does not support progress bars
+    sampling_kwargs = (
+        {} if sampling_method == "importance" else dict(show_progress_bars=False)
+    )
 
     for _ in range(num_rounds):
-        theta = proposals[-1].sample((num_simulations,))
         x = simulator(theta).to(data_device)
         theta = theta.to(data_device)

@@ -178,11 +182,12 @@ def simulator(theta):
             potential_fn, prior, theta_transform
         )
         proposals.append(posterior.set_default_x(x_o))
+        theta = proposals[-1].sample((num_samples,), **sampling_kwargs)
 
     # Check for default device for inference object
     weights_device = next(inferer._neural_net.parameters()).device
     assert torch.device(training_device) == weights_device
-    samples = proposals[-1].sample(sample_shape=(num_samples,))
+    samples = proposals[-1].sample(sample_shape=(num_samples,), **sampling_kwargs)
     proposals[-1].potential(samples)


@@ -299,7 +304,7 @@ def test_train_with_different_data_and_training_device(
         device=training_device,
     )
 
-    theta, x = simulate_for_sbi(simulator, prior, 32)
+    theta, x = simulate_for_sbi(simulator, prior, 32, show_progress_bar=False)
     theta, x = theta.to(data_device), x.to(data_device)
     x_o = torch.zeros(x.shape[1])
     inference = inference.append_simulations(theta, x, data_device=data_device)
@@ -399,17 +404,17 @@ def test_embedding_nets_integration_training_device(
     train_kwargs = dict(force_first_round_loss=True)
 
     with pytest.raises(Exception) if prior_device != training_device else nullcontext():
-        inference = inference_method(prior=prior, **nn_kwargs, device=training_device)
+        inference = inference_method(
+            prior=prior, **nn_kwargs, device=training_device, show_progress_bars=False
+        )
 
     if prior_device != training_device:
         pytest.xfail("We do not correct the case of invalid prior device")
 
+    theta = prior.sample((samples_per_round,)).to(data_device)
+
     proposal = prior
     for _ in range(num_rounds):
-        # sample theta and x independently - quick way to get 3D simulation data.
-        theta = proposal.sample((samples_per_round,))
         x = (
             MultivariateNormal(torch.zeros((x_dim,)), torch.eye(x_dim))
             .sample((samples_per_round,))
@@ -429,10 +434,18 @@
 
     posterior = inference.build_posterior(
         density_estimator_train,
-        mcmc_method="slice_np_vectorized",
-        mcmc_parameters=dict(thin=5, num_chains=10, warmup_steps=10),
+        # NOTE: SNPE_A only supports DirectPosterior sampling.
+        **(
+            {}
+            if inference_method == SNPE_A
+            else dict(
+                mcmc_method="slice_np_vectorized",
+                mcmc_parameters=dict(thin=5, num_chains=10, warmup_steps=20),
+            )
+        ),
     )
     proposal = posterior.set_default_x(x_o)
+    theta = proposal.sample((samples_per_round,), show_progress_bars=False)
 
 
 @pytest.mark.parametrize(
@@ -458,7 +471,7 @@ def test_nograd_after_inference_train(inference_method) -> None:
         show_progress_bars=False,
     )
 
-    theta, x = simulate_for_sbi(simulator, prior, 32)
+    theta, x = simulate_for_sbi(simulator, prior, 32, show_progress_bar=False)
     inference = inference.append_simulations(theta, x)
 
     posterior_estimator = inference.train(max_num_epochs=2)
@@ -516,7 +529,7 @@ def allow_iid_x(self) -> bool:
     posterior.vi_method = vi_method
 
     posterior.train(min_num_iters=9, max_num_iters=10, warm_up_rounds=10)
-    samples = posterior.sample((1,), method=sampling_method)
+    samples = posterior.sample((1,), method=sampling_method, show_progress_bars=False)
     logprobs = posterior.log_prob(samples)
 
     assert (
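The test changes above lean on one idiom: build a keyword-argument dict conditionally and unpack it at the call site, so a single call serves APIs that do not all accept the same options (ImportanceSamplingPosterior takes no progress-bar flag; SNPE_A's DirectPosterior takes no MCMC parameters). A self-contained sketch of the idiom; the `sample` and `build` stand-ins are hypothetical, only the unpacking pattern mirrors the diff.

# Self-contained sketch of the conditional-kwargs idiom used in the tests above.
def sample(num_samples: int, show_progress_bars: bool = True) -> list:
    """Stand-in for posterior.sample(); accepts an optional progress-bar flag."""
    return [0.0] * num_samples

def build(estimator: str, **kwargs) -> dict:
    """Stand-in for inference.build_posterior(); forwards optional MCMC options."""
    return {"estimator": estimator, **kwargs}

# Variant 1: precompute the dict once, unpack it at every call site.
sampling_method = "mcmc"
sampling_kwargs = (
    {} if sampling_method == "importance" else dict(show_progress_bars=False)
)
samples = sample(10, **sampling_kwargs)

# Variant 2: unpack an inline conditional, as in the build_posterior call above.
method_supports_mcmc = True
posterior = build(
    "estimator",
    **(
        {}
        if not method_supports_mcmc
        else dict(mcmc_method="slice_np_vectorized", mcmc_parameters=dict(thin=5))
    ),
)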
