Skip to content

Commit

Permalink
refactor mnle and snle tests
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Jan 30, 2024
1 parent 9e5385f commit dfc375a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 17 deletions.
4 changes: 1 addition & 3 deletions tests/linearGaussian_snle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: s
"""
num_samples = 500
num_simulations = 4500
num_simulations = 5000
trials_to_test = [1]

# likelihood_mean will be likelihood_shift+theta
Expand Down Expand Up @@ -219,8 +219,6 @@ def test_c2st_and_map_snl_on_linearGaussian_different(num_dim: int, prior_str: s
show_progress_bars=False,
)

# TODO: we do not have a test for SNL log_prob(). This is because the output
# TODO: density is not normalized, so KLd does not make sense.
if prior_str == "uniform":
# Check whether the returned probability outside of the support is zero.
posterior_prob = get_prob_outside_uniform_prior(posterior, prior, num_dim)
Expand Down
21 changes: 7 additions & 14 deletions tests/mnle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

import pytest
import torch
from numpy import isin
from pyro.distributions import InverseGamma
from torch.distributions import Beta, Binomial, Categorical, Gamma

from sbi.inference import MNLE, MCMCPosterior, likelihood_estimator_based_potential
from sbi.inference import MNLE, MCMCPosterior
from sbi.inference.posteriors.rejection_posterior import RejectionPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.inference.potentials.base_potential import BasePotential
Expand All @@ -23,6 +22,7 @@

# toy simulator for mixed data
def mixed_simulator(theta, stimulus_condition=2.0):
"""Simulator for mixed data."""
# Extract parameters
beta, ps = theta[:, :1], theta[:, 1:]

Expand All @@ -35,6 +35,7 @@ def mixed_simulator(theta, stimulus_condition=2.0):
return torch.cat((rts, choices), dim=1)


# MCMC kwargs for faster testing
mcmc_kwargs = dict(
num_chains=20,
warmup_steps=50,
Expand All @@ -47,6 +48,7 @@ def mixed_simulator(theta, stimulus_condition=2.0):
@pytest.mark.gpu
@pytest.mark.parametrize("device", ("cpu", "cuda"))
def test_mnle_on_device(device):
"""Test MNLE API on device."""
# Generate mixed data.
num_simulations = 100
mcmc_method = "slice"
Expand Down Expand Up @@ -78,6 +80,7 @@ def test_mnle_on_device(device):

@pytest.mark.parametrize("sampler", ("mcmc", "rejection", "vi"))
def test_mnle_api(sampler):
"""Test MNLE API."""
# Generate mixed data.
num_simulations = 100
theta = torch.rand(num_simulations, 2)
Expand Down Expand Up @@ -124,16 +127,6 @@ def test_mnle_accuracy_with_different_samplers_and_trials(sampler, num_trials: i
num_simulations = 2000
num_samples = 500

def mixed_simulator(theta):
# Extract parameters
beta, ps = theta[:, :1], theta[:, 1:]

# Sample choices and rts independently.
choices = Binomial(probs=ps).sample()
rts = InverseGamma(concentration=1 * torch.ones_like(beta), rate=beta).sample()

return torch.cat((rts, choices), dim=1)

prior = MultipleIndependent(
[
Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
Expand All @@ -143,7 +136,7 @@ def mixed_simulator(theta):
)

theta = prior.sample((num_simulations,))
x = mixed_simulator(theta)
x = mixed_simulator(theta, stimulus_condition=1.0)

# MNLE
trainer = MNLE(prior)
Expand Down Expand Up @@ -236,7 +229,7 @@ def test_mnle_with_experimental_conditions():
categorical parameter is set to a fixed value (conditioned posterior), and the
accuracy of the conditioned posterior is tested against the true posterior.
"""
num_simulations = 5000
num_simulations = 6000
num_samples = 500

def sim_wrapper(theta):
Expand Down

0 comments on commit dfc375a

Please sign in to comment.