Skip to content

Commit

Permalink
Merge pull request #269 from lnccbrown/fix-posterior-predictive
Browse files Browse the repository at this point in the history
Fixed a bug in sample_posterior_predictive()
  • Loading branch information
digicosmos86 authored Sep 1, 2023
2 parents d1b4f7b + d982f81 commit 9ee29a3
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
7 changes: 5 additions & 2 deletions src/hssm/distribution_utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def rng_fn(
size = args[-1]
args = args[:-1]

if size is None:
size = 1

# Although we got around the ndims_supp issue, the size parameter passed
# here is still an array with one element. We need to take it out.
if not np.isscalar(size):
Expand Down Expand Up @@ -225,7 +228,7 @@ def rng_fn(
# All parameters are scalars

theta = np.stack(arg_arrays)
n_samples = 1 if not size else size
n_samples = size
else:
# Preprocess all parameters, reshape them into a matrix of dimension
# (size, n_params) where size is the number of elements in the largest
Expand All @@ -240,7 +243,7 @@ def rng_fn(
[np.broadcast_to(arg, max_shape).reshape(-1) for arg in arg_arrays]
)

if size is None:
if size is None or size == 1:
n_samples = 1
elif size % new_data_size != 0:
raise ValueError(
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def data_angle():
@pytest.fixture(scope="module")
def data_ddm_reg():
# Generate some fake simulation data
intercept = 0.3
x = np.random.uniform(0.5, 0.7, size=1000)
y = np.random.uniform(0.4, 0.1, size=1000)
intercept = 1.5
x = np.random.uniform(-5.0, 5.0, size=1000)
y = np.random.uniform(-5.0, 5.0, size=1000)

v = intercept + 0.8 * x + 0.3 * y
true_values = np.column_stack(
Expand Down
8 changes: 8 additions & 0 deletions tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

def test_non_reg_models(data_ddm):
model1 = hssm.HSSM(data_ddm)
model1.sample_prior_predictive(draws=10)

model1.sample(cores=1, chains=1, tune=10, draws=10)
model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10)

Expand All @@ -18,6 +20,8 @@ def test_non_reg_models(data_ddm):
model3.sample(cores=1, chains=1, tune=10, draws=10)
model3.sample(cores=1, chains=1, tune=10, draws=10)

model1.sample_posterior_predictive(data=data_ddm.iloc[:10, :])


def test_reg_models(data_ddm_reg):
param_reg = dict(
Expand All @@ -30,6 +34,8 @@ def test_reg_models(data_ddm_reg):
)

model1 = hssm.HSSM(data_ddm_reg, v=param_reg)
model1.sample_prior_predictive(draws=10)

model1.sample(cores=1, chains=1, tune=10, draws=10)
model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10)

Expand All @@ -43,6 +49,8 @@ def test_reg_models(data_ddm_reg):
with pytest.raises(ValueError):
model3.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10)

model1.sample_posterior_predictive(data=data_ddm_reg.iloc[:10, :])


def test_reg_models_a(data_ddm_reg):
param_reg = dict(
Expand Down

0 comments on commit 9ee29a3

Please sign in to comment.