diff --git a/src/hssm/distribution_utils/dist.py b/src/hssm/distribution_utils/dist.py index b14ce46a..24d9c350 100644 --- a/src/hssm/distribution_utils/dist.py +++ b/src/hssm/distribution_utils/dist.py @@ -191,6 +191,9 @@ def rng_fn( size = args[-1] args = args[:-1] + if size is None: + size = 1 + # Although we got around the ndims_supp issue, the size parameter passed # here is still an array with one element. We need to take it out. if not np.isscalar(size): @@ -225,7 +228,7 @@ def rng_fn( # All parameters are scalars theta = np.stack(arg_arrays) - n_samples = 1 if not size else size + n_samples = size else: # Preprocess all parameters, reshape them into a matrix of dimension # (size, n_params) where size is the number of elements in the largest @@ -240,7 +243,7 @@ def rng_fn( [np.broadcast_to(arg, max_shape).reshape(-1) for arg in arg_arrays] ) - if size is None: + if size is None or size == 1: n_samples = 1 elif size % new_data_size != 0: raise ValueError( diff --git a/tests/conftest.py b/tests/conftest.py index be019f7e..40238836 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,9 +33,9 @@ def data_angle(): @pytest.fixture(scope="module") def data_ddm_reg(): # Generate some fake simulation data - intercept = 0.3 - x = np.random.uniform(0.5, 0.7, size=1000) - y = np.random.uniform(0.4, 0.1, size=1000) + intercept = 1.5 + x = np.random.uniform(-5.0, 5.0, size=1000) + y = np.random.uniform(-5.0, 5.0, size=1000) v = intercept + 0.8 * x + 0.3 * y true_values = np.column_stack( diff --git a/tests/test_mcmc.py b/tests/test_mcmc.py index 4186d706..f3caa2d4 100644 --- a/tests/test_mcmc.py +++ b/tests/test_mcmc.py @@ -7,6 +7,8 @@ def test_non_reg_models(data_ddm): model1 = hssm.HSSM(data_ddm) + model1.sample_prior_predictive(draws=10) + model1.sample(cores=1, chains=1, tune=10, draws=10) model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) @@ -18,6 +20,8 @@ def test_non_reg_models(data_ddm): model3.sample(cores=1, chains=1, tune=10, draws=10) model3.sample(cores=1, chains=1, tune=10, draws=10) + model1.sample_posterior_predictive(data=data_ddm.iloc[:10, :]) + def test_reg_models(data_ddm_reg): param_reg = dict( @@ -30,6 +34,8 @@ def test_reg_models(data_ddm_reg): ) model1 = hssm.HSSM(data_ddm_reg, v=param_reg) + model1.sample_prior_predictive(draws=10) + model1.sample(cores=1, chains=1, tune=10, draws=10) model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) @@ -43,6 +49,8 @@ def test_reg_models(data_ddm_reg): with pytest.raises(ValueError): model3.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) + model1.sample_posterior_predictive(data=data_ddm_reg.iloc[:10, :]) + def test_reg_models_a(data_ddm_reg): param_reg = dict(