diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml
index 3708c2c5..264c5130 100644
--- a/.github/workflows/run_tests.yml
+++ b/.github/workflows/run_tests.yml
@@ -30,7 +30,7 @@ jobs:
       - name: install pymc and poetry
         run: |
           mamba info
-          mamba install -c conda-forge pymc=5.12 poetry
+          mamba install -c conda-forge pymc poetry

       - name: install hssm
         run: |
diff --git a/src/hssm/defaults.py b/src/hssm/defaults.py
index 40ed155f..7561c4ab 100644
--- a/src/hssm/defaults.py
+++ b/src/hssm/defaults.py
@@ -184,7 +184,7 @@ class DefaultConfig(TypedDict):
     },
     "full_ddm": {
         "response": ["rt", "response"],
-        "list_params": ["v", "a", "z", "t", "sv", "sz", "st"],
+        "list_params": ["v", "a", "z", "t", "sz", "sv", "st"],
         "description": "The full Drift Diffusion Model (DDM)",
         "likelihoods": {
             "blackbox": {
diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py
index 4a0fff17..97c520cf 100644
--- a/src/hssm/hssm.py
+++ b/src/hssm/hssm.py
@@ -60,7 +60,7 @@ class HSSM:
-    """The Hierarchical Sequential Sampling Model (HSSM) class.
+    """The basic Hierarchical Sequential Sampling Model (HSSM) class.

     Parameters
     ----------
@@ -497,6 +497,17 @@ def sample(
             inference_method=sampler, init=init, **kwargs
         )

+        # The parent was previously not part of deterministics --> compute it via
+        # posterior_predictive (works because it acts as the 'mu' parameter
+        # in the GLM as far as bambi is concerned)
+        if self._inference_obj is not None:
+            if self._parent not in self._inference_obj.posterior.data_vars.keys():
+                self.model.predict(self._inference_obj, kind="mean", inplace=True)
+                # rename 'rt,response_mean' to 'v' so in the traces everything
+                # looks the way it should
+                self._inference_obj.rename_vars(
+                    {"rt,response_mean": self._parent}, inplace=True
+                )
         return self.traces

     def sample_posterior_predictive(
@@ -526,13 +537,13 @@ def sample_posterior_predictive(
             If `True` will make predictions including the group specific effects.
             Otherwise, predictions are made with common effects only (i.e. group-
             specific are set to zero), by default True.
-        kind
+        kind: optional
             Indicates the type of prediction required. Can be `"mean"` or `"pps"`.
             The first returns draws from the posterior distribution of the mean,
             while the latter returns the draws from the posterior predictive
             distribution (i.e. the posterior probability distribution for a new
             observation). Defaults to `"pps"`.
-        n_samples
+        n_samples: optional
             The number of samples to draw from the posterior predictive
             distribution from each chain.
             When it's an integer >= 1, the number of samples to be extracted from the
@@ -1308,11 +1319,18 @@ def _get_deterministic_var_names(self, idata) -> list[str]:
         var_names = [
             f"~{param_name}"
             for param_name, param in self.params.items()
-            if param.is_regression and not param.is_parent
+            if param.is_regression
         ]

+        # Handle specific case where parent is not explicitly in traces
+        if ("~" + self._parent in var_names) and (
+            self._parent not in idata.posterior.data_vars
+        ):
+            var_names.remove("~" + self._parent)
+
         if f"{self.response_str}_mean" in idata["posterior"].data_vars:
             var_names.append(f"~{self.response_str}_mean")
+
         return var_names

     def _handle_missing_data_and_deadline(self):
diff --git a/src/hssm/likelihoods/blackbox.py b/src/hssm/likelihoods/blackbox.py
index c094c056..4aa91ca4 100644
--- a/src/hssm/likelihoods/blackbox.py
+++ b/src/hssm/likelihoods/blackbox.py
@@ -61,7 +61,7 @@ def logp_ddm_sdv_bbox(data: np.ndarray, v, a, z, t, sv) -> np.ndarray:


 @hddm_to_hssm
-def logp_full_ddm(data: np.ndarray, v, a, z, t, sv, sz, st):
+def logp_full_ddm(data: np.ndarray, v, a, z, t, sz, sv, st):
     """Compute blackbox log-likelihoods for full_ddm models."""
     return wfpt.wiener_logp_array(
         x=data,
diff --git a/src/hssm/prior.py b/src/hssm/prior.py
index 70d3bfab..fcab0d80 100644
--- a/src/hssm/prior.py
+++ b/src/hssm/prior.py
@@ -351,11 +351,3 @@ def get_hddm_default_prior(
     "sz": {"dist": "Gamma", "mu": HDDM_MU["sz"], "sigma": HDDM_SIGMA["sz"]},
     "st": {"dist": "Gamma", "mu": HDDM_MU["st"], "sigma": HDDM_SIGMA["st"]},
 }
-
-# INITVAL_SETTINGS_LOGIT: dict[Any, Any] = {
-#     "t" : {"initval": -4.0},
-# }
-
-# INITVAL_SETTINGS_NOLINK: dict[Any, Any] = {
-#     "t" : {"initval": 0.05},
-# }
diff --git a/tests/slow/test_mcmc.py b/tests/slow/test_mcmc.py
index f0f8728a..83b24347 100644
--- a/tests/slow/test_mcmc.py
+++ b/tests/slow/test_mcmc.py
@@ -10,7 +10,7 @@

 from hssm.utils import _rearrange_data

-hssm.set_floatX("float32")
+hssm.set_floatX("float32", jax=True)

 # AF-TODO: Include more tests that use different link functions!
@@ -124,6 +124,8 @@ def run_sample(model, sampler, step, expected):

 @pytest.mark.parametrize(parameter_names, parameter_grid)
 def test_simple_models(data_ddm, loglik_kind, backend, sampler, step, expected):
+    print("PYMC VERSION: ")
+    print(pm.__version__)
     print("TEST INPUTS WERE: ")
     print("REPORTING FROM SIMPLE MODELS TEST")
     print(loglik_kind, backend, sampler, step, expected)
@@ -147,6 +149,8 @@
 def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected):
+    print("PYMC VERSION: ")
+    print(pm.__version__)
     print("TEST INPUTS WERE: ")
     print("REPORTING FROM REG MODELS TEST")
     print(loglik_kind, backend, sampler, step, expected)
@@ -169,7 +173,7 @@ def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected)

     # Only runs once
     if loglik_kind == "analytical" and sampler is None:
-        assert not model._get_deterministic_var_names(model.traces)
+        assert model._get_deterministic_var_names(model.traces) == ["~v"]
         # test summary:
         summary = model.summary()
         assert summary.shape[0] == 6
@@ -181,6 +185,8 @@ def test_reg_models(data_ddm_reg, loglik_kind, backend, sampler, step, expected)

 @pytest.mark.parametrize(parameter_names, parameter_grid)
 def test_reg_models_v_a(data_ddm_reg, loglik_kind, backend, sampler, step, expected):
+    print("PYMC VERSION: ")
+    print(pm.__version__)
     print("TEST INPUTS WERE: ")
     print("REPORTING FROM REG MODELS V_A TEST")
     print(loglik_kind, backend, sampler, step, expected)
@@ -218,7 +224,12 @@ def test_reg_models_v_a(data_ddm_reg, loglik_kind, backend, sampler, step, expec

     # Only runs once
     if loglik_kind == "analytical" and sampler is None:
-        assert model._get_deterministic_var_names(model.traces) == ["~a"]
+        assert len(model._get_deterministic_var_names(model.traces)) == len(
+            ["~a", "~v"]
+        )
+        assert set(model._get_deterministic_var_names(model.traces)) == set(
+            ["~a", "~v"]
+        )
         # test summary:
         summary = model.summary()
         assert summary.shape[0] == 8
@@ -253,6 +264,8 @@ def test_reg_models_v_a(data_ddm_reg, loglik_kind, backend, sampler, step, expec
 def test_simple_models_missing_data(
     data_ddm_missing, loglik_kind, backend, sampler, step, expected, cpn
 ):
+    print("PYMC VERSION: ")
+    print(pm.__version__)
     print("TEST INPUTS WERE: ")
     print("REPORTING FROM SIMPLE MODELS MISSING DATA TEST")
     print(loglik_kind, backend, sampler, step, expected)
@@ -271,6 +284,8 @@
 def test_reg_models_missing_data(
     data_ddm_reg_missing, loglik_kind, backend, sampler, step, expected, cpn
 ):
+    print("PYMC VERSION: ")
+    print(pm.__version__)
     print("TEST INPUTS WERE: ")
     print("REPORTING FROM REG MODELS MISSING DATA TEST")
     print(loglik_kind, backend, sampler, step, expected)
@@ -298,6 +313,8 @@
 def test_simple_models_deadline(
     data_ddm_deadline, loglik_kind, backend, sampler, step, expected, opn
 ):
+    print("PYMC VERSION: ")
+    print(pm.__version__)
     print("TEST INPUTS WERE: ")
     print("REPORTING FROM SIMPLE MODELS DEADLINE TEST")
     print(loglik_kind, backend, sampler, step, expected)
@@ -315,6 +332,8 @@
 def test_reg_models_deadline(
     data_ddm_reg_deadline, loglik_kind, backend, sampler, step, expected, opn
 ):
+    print("PYMC VERSION: ")
+    print(pm.__version__)
     print("TEST INPUTS WERE: ")
     print("REPORTING FROM REG MODELS DEADLINE TEST")
     print(loglik_kind, backend, sampler, step, expected)
diff --git a/tests/test_hssm.py b/tests/test_hssm.py
index 0a660303..76c0e3d7 100644
--- a/tests/test_hssm.py
+++ b/tests/test_hssm.py
@@ -7,7 +7,7 @@
 from hssm.utils import download_hf
 from hssm.likelihoods import DDM, logp_ddm

-hssm.set_floatX("float32")
+hssm.set_floatX("float32", jax=True)

 param_v = {
     "name": "v",