From 78c69939073f9702ef0b9bca1a1ced30cc980d20 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 23 Aug 2023 10:47:21 -0400 Subject: [PATCH 01/16] Change pyproject.toml to add hddm_wfpt as required dependency. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bc6d764c..54d20618 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ huggingface-hub = "^0.15.1" onnxruntime = "^1.15.0" bambi = "^0.12.0" numpyro = "^0.12.1" +hddm-wfpt = {git = "https://github.com/brown-ccv/hddm-wfpt.git"} [tool.poetry.group.dev.dependencies] pytest = "^7.3.1" @@ -36,7 +37,6 @@ mypy = "^1.4.1" pre-commit = "^2.20.0" jupyterlab = "^4.0.2" ipykernel = "^6.16.0" -hddm-wfpt = { git = "https://github.com/brown-ccv/hddm-wfpt.git" } ipywidgets = "^8.0.3" graphviz = "^0.20.1" ruff = "^0.0.272" From 0d6eca0902f041e0c7d1981611010d6bbbb248d6 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 23 Aug 2023 10:47:45 -0400 Subject: [PATCH 02/16] add blackbox likelihood functions --- src/hssm/likelihoods/blackbox.py | 53 ++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 src/hssm/likelihoods/blackbox.py diff --git a/src/hssm/likelihoods/blackbox.py b/src/hssm/likelihoods/blackbox.py new file mode 100644 index 00000000..78bddc9b --- /dev/null +++ b/src/hssm/likelihoods/blackbox.py @@ -0,0 +1,53 @@ +"""Black box likelihoods written in Cython for "ddm" and "ddm_sdv" models.""" + +from __future__ import annotations + +import hddm_wfpt +import numpy as np + + +def logp_ddm_bbox(data: np.ndarray, v, a, z, t) -> np.ndarray: + """Compute blackbox log-likelihoods for ddm models.""" + x = (data[:, 0] * data[:, 1]).astype(np.float64) + size = len(data) + + v, a, z, t = [_broadcast(param, size) for param in [v, a, z, t]] + zeros = np.zeros(size, dtype=np.float64) + + return hddm_wfpt.wfpt.wiener_logp_array( + x=x, + v=v, + sv=zeros, + a=a * 2, # Ensure compatibility with HSSM. + z=z, + sz=zeros, + t=t, + st=zeros, + err=1e-8, + ).astype(data.dtype) + + +def logp_ddm_sdv_bbox(data: np.ndarray, v, a, z, t, sv) -> np.ndarray: + """Compute blackbox log-likelihoods for ddm models.""" + x = (data[:, 0] * data[:, 1]).astype(np.float64) + size = len(x) + + v, a, z, t, sv = [_broadcast(param, size) for param in [v, a, z, t, sv]] + zeros = np.zeros(size, dtype=np.float64) + + return hddm_wfpt.wfpt.wiener_logp_array( + x=x, + v=v, + sv=zeros, + a=a * 2, # Ensure compatibility with HSSM. 
+ z=z, + sz=zeros, + t=t, + st=zeros, + err=1e-8, + ).astype(data.dtype) + + +def _broadcast(x: float | np.ndarray, size: int): + """Broadcast a scalar or an array to size of `size`.""" + return np.broadcast_to(np.array(x, dtype=np.float64), size) From 27f78751fae2a4da58c8e4e96b0354513e8879ca Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 23 Aug 2023 10:48:10 -0400 Subject: [PATCH 03/16] Added tests for blackbox likelihood functions --- tests/test_likelihoods.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index 1ace7a8b..759cb3b2 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -10,7 +10,8 @@ from numpy.random import rand # pylint: disable=C0413 -from hssm.likelihoods.analytical import compare_k, logp_ddm_sdv +from hssm.likelihoods.analytical import compare_k, logp_ddm, logp_ddm_sdv +from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox def test_kterm(data_ddm): @@ -97,3 +98,18 @@ def test_no_inf_values_v(data_ddm, shared_params): assert np.all( np.isfinite(logp.eval()) ), f"log_pdf_sv() returned non-finite values for v = {v}." + + +def test_bbox(data_ddm): + true_values = (0.5, 1.5, 0.5, 0.5) + true_values_sdv = (0.5, 1.5, 0.5, 0.5, 0) + data = data_ddm.values + + np.testing.assert_almost_equal( + logp_ddm(data, *true_values).eval(), logp_ddm_bbox(data, *true_values) + ) + + np.testing.assert_almost_equal( + logp_ddm_sdv(data, *true_values_sdv).eval(), + logp_ddm_sdv_bbox(data, *true_values_sdv), + ) From cb5d7fb0b494c3de54d740eab89475a3b456dcb0 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 23 Aug 2023 10:50:08 -0400 Subject: [PATCH 04/16] Add blackbox likelihood functions to __init__.py --- src/hssm/likelihoods/__init__.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/hssm/likelihoods/__init__.py b/src/hssm/likelihoods/__init__.py index bbd689d9..13669129 100644 --- a/src/hssm/likelihoods/__init__.py +++ b/src/hssm/likelihoods/__init__.py @@ -1,5 +1,13 @@ """Likelihood functions and distributions that use them.""" from .analytical import DDM, DDM_SDV, logp_ddm, logp_ddm_sdv +from .blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox -__all__ = ["logp_ddm", "logp_ddm_sdv", "DDM", "DDM_SDV"] +__all__ = [ + "logp_ddm", + "logp_ddm_sdv", + "DDM", + "DDM_SDV", + "logp_ddm_bbox", + "logp_ddm_sdv_bbox", +] From d7809e29023f1ccb5eb8119330c86d430c941381 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 23 Aug 2023 13:56:21 -0400 Subject: [PATCH 05/16] Added blackbox defaults to ddm and ddm_sdv models --- src/hssm/defaults.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/hssm/defaults.py b/src/hssm/defaults.py index a273de7a..a5e31e63 100644 --- a/src/hssm/defaults.py +++ b/src/hssm/defaults.py @@ -13,6 +13,7 @@ logp_ddm, logp_ddm_sdv, ) +from .likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox from .param import ParamSpec, _make_default_prior LogLik = Union[str, PathLike, Callable, Op, type[Distribution]] @@ -81,6 +82,18 @@ class DefaultConfig(TypedDict): "t": (0.0, 2.0), }, }, + "blackbox": { + "loglik": logp_ddm_bbox, + "backend": None, + "bounds": ddm_bounds, + "default_priors": { + "t": { + "name": "HalfNormal", + "sigma": 2.0, + "initval": 0.1, + }, + }, + }, }, }, "ddm_sdv": { @@ -111,6 +124,18 @@ class DefaultConfig(TypedDict): "sv": (0.0, 1.0), }, }, + "blackbox": { + "loglik": logp_ddm_sdv_bbox, + "backend": None, + "bounds": ddm_bounds, + "default_priors": { + "t": { + 
"name": "HalfNormal", + "sigma": 2.0, + "initval": 0.1, + }, + }, + }, }, }, "angle": { From 87bfeb397ef8128d0d742bad54b3dfa25df87222 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 23 Aug 2023 13:57:43 -0400 Subject: [PATCH 06/16] set parent to be first regression parameter --- src/hssm/hssm.py | 101 ++++++++++++++++++++++++++++++++++++++-------- src/hssm/param.py | 43 -------------------- 2 files changed, 84 insertions(+), 60 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 60cafa3e..63cd6c28 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -27,7 +27,6 @@ from hssm.param import ( Param, _make_default_prior, - _parse_bambi, ) from hssm.utils import ( HSSMModelGraph, @@ -211,7 +210,6 @@ def __init__( self.model_name = self.model_config.model_name self.loglik = self.model_config.loglik self.loglik_kind = self.model_config.loglik_kind - self._parent = self.list_params[0] # Process lapse distribution self.has_lapse = p_outlier is not None and p_outlier != 0 @@ -225,15 +223,17 @@ def __init__( ) # Process parameter specifications include - processed = self._process_include(include) + processed = self._preprocess_include(include) # Process parameter specifications not in include - self.params = self._process_rest(processed) + self.params = self._preprocess_rest(processed) + # Find the parent parameter + self._parent, self._parent_param = self._find_parent() + assert self._parent_param is not None - # Get the bambi formula, priors, and link - self.formula, self.priors, self.link = _parse_bambi(self.params) + self._process_all() - self._parent_param = self.params[self.list_params[0]] - assert self._parent_param is not None + # Get the bambi formula, priors, and link + self.formula, self.priors, self.link = self._parse_bambi() # For parameters that are regression, apply bounds at the likelihood level to # ensure that the samples that are out of bounds are discarded (replaced with @@ -250,6 +250,8 @@ def __init__( self.model_distribution = self._make_model_distribution() + print(self.formula, self._parent) + self.family = make_family( self.model_distribution, self.list_params, @@ -674,7 +676,7 @@ def _add_kwargs_and_p_outlier_to_include( return include, other_kwargs - def _process_include(self, include: list[dict | Param]) -> dict[str, Param]: + def _preprocess_include(self, include: list[dict | Param]) -> dict[str, Param]: """Turn parameter specs in include into Params.""" result: dict[str, Param] = {} @@ -693,12 +695,10 @@ def _process_include(self, include: list[dict | Param]) -> dict[str, Param]: if isinstance(param_with_default, dict) else param_with_default ) - if name == self._parent: - result[name].set_parent() return result - def _process_rest(self, processed: dict[str, Param]) -> dict[str, Param]: + def _preprocess_rest(self, processed: dict[str, Param]) -> dict[str, Param]: """Turn parameter specs not in include into Params.""" not_in_include = {} @@ -716,20 +716,87 @@ def _process_rest(self, processed: dict[str, Param]) -> dict[str, Param]: prior, bounds = self.model_config.get_defaults(param_str) param = Param(param_str, prior=prior, bounds=bounds) param.do_not_truncate() - if param_str == self._parent: - param.set_parent() not_in_include[param_str] = param processed |= not_in_include sorted_params = {} for param_name in self.list_params: - processed_param = processed[param_name] - processed_param.convert() - sorted_params[param_name] = processed_param + sorted_params[param_name] = processed[param_name] return sorted_params + def _find_parent(self) -> 
tuple[str, Param]: + """Find the parent param for the model. + + The first param that has a regression will be set as parent. If none of the + params is a regression, then the first param will be set as parent. + + Returns + ------- + str + The name of the param as string + Param + The parent Param object + """ + for param_str in self.list_params: + param = self.params[param_str] + if param.is_regression: + param.set_parent() + return param_str, param + + param_str = self.list_params[0] + param = self.params[param_str] + param.set_parent() + return param_str, param + + def _process_all(self): + """Process all params.""" + for param in self.list_params: + self.params[param].convert() + + def _parse_bambi( + self, + ) -> tuple[bmb.Formula, dict | None, dict[str, str | bmb.Link] | str]: + """Retrieve three items that helps with bambi model building. + + Returns + ------- + tuple + A tuple containing: + 1. A bmb.Formula object. + 2. A dictionary of priors, if any is specified. + 3. A dictionary of link functions, if any is specified. + """ + # Handle the edge case where list_params is empty: + if not self.params: + return bmb.Formula("c(rt, response) ~ 1"), None, "identity" + + parent_formula = None + other_formulas = [] + priors: dict[str, Any] = {} + links: dict[str, str | bmb.Link] = {} + + for _, param in self.params.items(): + formula, prior, link = param.parse_bambi() + + if param.is_parent: + parent_formula = formula + else: + if formula is not None: + other_formulas.append(formula) + if prior is not None: + priors |= prior + if link is not None: + links |= link + + assert parent_formula is not None + result_formula: bmb.Formula = bmb.Formula(parent_formula, *other_formulas) + result_priors = None if not priors else priors + result_links: dict | str = "identity" if not links else links + + return result_formula, result_priors, result_links + def _make_model_distribution(self) -> type[pm.Distribution]: """Make a pm.Distribution for the model.""" ### Logic for different types of likelihoods: diff --git a/src/hssm/param.py b/src/hssm/param.py index 7ecc7036..d562525f 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -377,49 +377,6 @@ def _make_priors_recursive(prior: dict[str, Any]) -> Prior: return bmb.Prior(**prior) -def _parse_bambi( - params: dict[str, Param], -) -> tuple[bmb.Formula, dict | None, dict[str, str | bmb.Link] | str]: - """From a dict of Params, retrieve three items that helps with bambi model building. - - Parameters - ---------- - params - A dict of Param objects. - - Returns - ------- - tuple - A tuple containing: - 1. A bmb.Formula object. - 2. A dictionary of priors, if any is specified. - 3. A dictionary of link functions, if any is specified. 
- """ - # Handle the edge case where list_params is empty: - if not params: - return bmb.Formula("c(rt, response) ~ 1"), None, "identity" - - formulas = [] - priors: dict[str, Any] = {} - links: dict[str, str | bmb.Link] = {} - - for _, param in params.items(): - formula, prior, link = param.parse_bambi() - - if formula is not None: - formulas.append(formula) - if prior is not None: - priors |= prior - if link is not None: - links |= link - - result_formula: bmb.Formula = bmb.Formula(formulas[0], *formulas[1:]) - result_priors = None if not priors else priors - result_links: dict | str = "identity" if not links else links - - return result_formula, result_priors, result_links - - def _make_bounded_prior( param_name: str, prior: ParamSpec, bounds: tuple[float, float] ) -> float | Prior: From 031e4048f0dd24727ab932b47d142be90c93e243 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 23 Aug 2023 13:57:57 -0400 Subject: [PATCH 07/16] modified tests --- tests/test_mcmc.py | 42 ++++++++++++++++++++++++++++++++++------ tests/test_param.py | 47 --------------------------------------------- tests/test_utils.py | 8 +++++--- 3 files changed, 41 insertions(+), 56 deletions(-) diff --git a/tests/test_mcmc.py b/tests/test_mcmc.py index f1c69f86..4186d706 100644 --- a/tests/test_mcmc.py +++ b/tests/test_mcmc.py @@ -1,3 +1,5 @@ +import pytest + import hssm hssm.set_floatX("float32") @@ -12,6 +14,10 @@ def test_non_reg_models(data_ddm): model2.sample(cores=1, chains=1, tune=10, draws=10) model2.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) + model3 = hssm.HSSM(data_ddm, loglik_kind="blackbox") + model3.sample(cores=1, chains=1, tune=10, draws=10) + model3.sample(cores=1, chains=1, tune=10, draws=10) + def test_reg_models(data_ddm_reg): param_reg = dict( @@ -26,14 +32,38 @@ def test_reg_models(data_ddm_reg): model1 = hssm.HSSM(data_ddm_reg, v=param_reg) model1.sample(cores=1, chains=1, tune=10, draws=10) model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) + model2 = hssm.HSSM(data_ddm_reg, loglik_kind="approx_differentiable", v=param_reg) model2.sample(cores=1, chains=1, tune=10, draws=10) - model2.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) + model2.sample(sampler="mcmc", cores=1, chains=1, tune=10, draws=10) + + model3 = hssm.HSSM(data_ddm_reg, loglik_kind="blackbox", v=param_reg) + model3.sample(cores=1, chains=1, tune=10, draws=10) + + with pytest.raises(ValueError): + model3.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) + + +def test_reg_models_a(data_ddm_reg): + param_reg = dict( + formula="a ~ 1 + x + y", + prior={ + "Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0}, + "x": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, + "y": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, + }, + ) + + model1 = hssm.HSSM(data_ddm_reg, a=param_reg) + model1.sample(cores=1, chains=1, tune=10, draws=10) + model1.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) + + model2 = hssm.HSSM(data_ddm_reg, loglik_kind="approx_differentiable", a=param_reg) + model2.sample(cores=1, chains=1, tune=10, draws=10) + model2.sample(sampler="mcmc", cores=1, chains=1, tune=10, draws=10) - model3 = hssm.HSSM(data_ddm_reg, a=param_reg) + model3 = hssm.HSSM(data_ddm_reg, loglik_kind="blackbox", a=param_reg) model3.sample(cores=1, chains=1, tune=10, draws=10) - model3.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) - model4 = hssm.HSSM(data_ddm_reg, loglik_kind="approx_differentiable", 
a=param_reg) - model4.sample(cores=1, chains=1, tune=10, draws=10) - model4.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) + with pytest.raises(ValueError): + model3.sample(sampler="nuts_numpyro", cores=1, chains=1, tune=10, draws=10) diff --git a/tests/test_param.py b/tests/test_param.py index 20f8faf9..cbedf11d 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -6,58 +6,11 @@ from hssm.param import ( Param, _make_default_prior, - _parse_bambi, _make_priors_recursive, _make_bounded_prior, ) -def test__parse_bambi(): - prior_dict = {"name": "Uniform", "lower": 0.3, "upper": 1.0} - prior_obj = bmb.Prior("Uniform", lower=0.3, upper=1.0) - - param_parent_non_regression = Param("v", prior=prior_dict) - param_parent_regression = Param( - "v", - formula="1 + x1", - prior={ - "Intercept": prior_dict, - "x1": prior_dict, - }, - ) - - param_parent_non_regression.set_parent() - param_parent_non_regression.convert() - - param_parent_regression.set_parent() - param_parent_regression.convert() - - empty_dict = {} - dict_one_parent_non_regression = {"v": param_parent_non_regression} - dict_one_parent_regression = {"v": param_parent_regression} - - f0, p0, l0 = _parse_bambi(empty_dict) - - assert f0.main == "c(rt, response) ~ 1" - assert p0 is None - assert l0 == "identity" - - f3, p3, l3 = _parse_bambi(dict_one_parent_non_regression) - - assert f3.main == "c(rt, response) ~ 1" - assert p3 is not None - assert p3["c(rt, response)"]["Intercept"] == prior_obj - assert l3 == {"v": "identity"} - - f4, p4, l4 = _parse_bambi(dict_one_parent_regression) - - assert f4.main == "c(rt, response) ~ 1 + x1" - assert p4 is not None - assert p4["c(rt, response)"]["Intercept"] == prior_obj - assert p4["c(rt, response)"]["x1"] == prior_obj - assert l4 == {"v": "identity"} - - def test_param_creation_non_regression(): # Test different param creation strategies v = { diff --git a/tests/test_utils.py b/tests/test_utils.py index 0d3f3d7c..c8993c77 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -64,11 +64,13 @@ def test_get_alias_dict(): assert alias_regression["c(rt, response)"] == "rt,response" assert alias_regression["Intercept"] == "v_Intercept" + assert alias_regression["x"] == "v_x" assert alias_regression["1|group"] == "v_1|group" - assert alias_regression_a["c(rt, response)"]["c(rt, response)"] == "rt,response" - assert alias_regression_a["c(rt, response)"]["Intercept"] == "v" - assert alias_regression_a["a"]["a"] == "a" + assert alias_regression_a["c(rt, response)"] == "rt,response" + assert alias_regression_a["Intercept"] == "a_Intercept" + assert alias_regression_a["x"] == "a_x" + assert alias_regression_a["1|group"] == "a_1|group" def test_set_floatX(): From f58489c1516da4756e6d20a5e7baa8b765de1352 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Mon, 28 Aug 2023 16:49:12 -0400 Subject: [PATCH 08/16] Requires results to match by 4 digits instead of 7 --- tests/test_likelihoods.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index 759cb3b2..c3c457d2 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -106,10 +106,13 @@ def test_bbox(data_ddm): data = data_ddm.values np.testing.assert_almost_equal( - logp_ddm(data, *true_values).eval(), logp_ddm_bbox(data, *true_values) + logp_ddm(data, *true_values).eval(), + logp_ddm_bbox(data, *true_values), + decimal=4, ) np.testing.assert_almost_equal( logp_ddm_sdv(data, *true_values_sdv).eval(), logp_ddm_sdv_bbox(data, 
*true_values_sdv), + decimal=4, ) From f31279671fbfb75a74a7a13049cff0cacf0adb01 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Mon, 28 Aug 2023 16:50:42 -0400 Subject: [PATCH 09/16] Fixed bb likelihoods, added full_ddm likelihood --- src/hssm/likelihoods/__init__.py | 3 ++- src/hssm/likelihoods/blackbox.py | 44 ++++++++++++++++++++++++++++---- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/src/hssm/likelihoods/__init__.py b/src/hssm/likelihoods/__init__.py index 13669129..94d9a1c0 100644 --- a/src/hssm/likelihoods/__init__.py +++ b/src/hssm/likelihoods/__init__.py @@ -1,7 +1,7 @@ """Likelihood functions and distributions that use them.""" from .analytical import DDM, DDM_SDV, logp_ddm, logp_ddm_sdv -from .blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox +from .blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox, logp_full_ddm __all__ = [ "logp_ddm", @@ -10,4 +10,5 @@ "DDM_SDV", "logp_ddm_bbox", "logp_ddm_sdv_bbox", + "logp_full_ddm", ] diff --git a/src/hssm/likelihoods/blackbox.py b/src/hssm/likelihoods/blackbox.py index 78bddc9b..70954c8e 100644 --- a/src/hssm/likelihoods/blackbox.py +++ b/src/hssm/likelihoods/blackbox.py @@ -14,7 +14,7 @@ def logp_ddm_bbox(data: np.ndarray, v, a, z, t) -> np.ndarray: v, a, z, t = [_broadcast(param, size) for param in [v, a, z, t]] zeros = np.zeros(size, dtype=np.float64) - return hddm_wfpt.wfpt.wiener_logp_array( + logp = hddm_wfpt.wfpt.wiener_logp_array( x=x, v=v, sv=zeros, @@ -24,18 +24,22 @@ def logp_ddm_bbox(data: np.ndarray, v, a, z, t) -> np.ndarray: t=t, st=zeros, err=1e-8, - ).astype(data.dtype) + ) + + logp = np.where(np.isfinite(logp), logp, -66.1) + + return logp.astype(data.dtype) def logp_ddm_sdv_bbox(data: np.ndarray, v, a, z, t, sv) -> np.ndarray: - """Compute blackbox log-likelihoods for ddm models.""" + """Compute blackbox log-likelihoods for ddm_sdv models.""" x = (data[:, 0] * data[:, 1]).astype(np.float64) size = len(x) v, a, z, t, sv = [_broadcast(param, size) for param in [v, a, z, t, sv]] zeros = np.zeros(size, dtype=np.float64) - return hddm_wfpt.wfpt.wiener_logp_array( + logp = hddm_wfpt.wfpt.wiener_logp_array( x=x, v=v, sv=zeros, @@ -45,7 +49,37 @@ def logp_ddm_sdv_bbox(data: np.ndarray, v, a, z, t, sv) -> np.ndarray: t=t, st=zeros, err=1e-8, - ).astype(data.dtype) + ) + + logp = np.where(np.isfinite(logp), logp, -66.1) + + return logp.astype(data.dtype) + + +def logp_full_ddm(data: np.ndarray, v, a, z, t, sv, sz, st): + """Compute blackbox log-likelihoods for full_ddm models.""" + x = (data[:, 0] * data[:, 1]).astype(np.float64) + size = len(x) + + v, a, z, t, sv, sz, st = [ + _broadcast(param, size) for param in [v, a, z, t, sv, sz, st] + ] + + logp = hddm_wfpt.wfpt.wiener_logp_array( + x=x, + v=v, + sv=sv, + a=a * 2, # Ensure compatibility with HSSM. 
+ z=z, + sz=sz, + t=t, + st=st, + err=1e-8, + ) + + logp = np.where(np.isfinite(logp), logp, -66.1) + + return logp.astype(data.dtype) def _broadcast(x: float | np.ndarray, size: int): From 2a4aa8ed54daa11b93b8bfd9c95e0f9d3f3e12b4 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Mon, 28 Aug 2023 16:51:12 -0400 Subject: [PATCH 10/16] Fixed potential bug in dist.py --- src/hssm/distribution_utils/dist.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/src/hssm/distribution_utils/dist.py b/src/hssm/distribution_utils/dist.py index a2d7e9f1..b14ce46a 100644 --- a/src/hssm/distribution_utils/dist.py +++ b/src/hssm/distribution_utils/dist.py @@ -59,13 +59,8 @@ def apply_param_bounds_to_loglik( """ dist_params_dict = dict(zip(list_params, dist_params)) - bounds = { - k: ( - pm.floatX(v[0]), - pm.floatX(v[1]), - ) - for k, v in bounds.items() - } + bounds = {k: (pm.floatX(v[0]), pm.floatX(v[1])) for k, v in bounds.items()} + out_of_bounds_mask = pt.zeros_like(logp, dtype=bool) for param_name, param in dist_params_dict.items(): # It cannot be assumed that each parameter will have bounds. @@ -75,15 +70,10 @@ def apply_param_bounds_to_loglik( lower_bound, upper_bound = bounds[param_name] - out_of_bounds_mask = pt.bitwise_or( - pt.lt(param, lower_bound), pt.gt(param, upper_bound) - ) - - broadcasted_mask = pt.broadcast_to( - out_of_bounds_mask, logp.shape - ) # typing: ignore + param_mask = pt.bitwise_or(pt.lt(param, lower_bound), pt.gt(param, upper_bound)) + out_of_bounds_mask = pt.bitwise_or(out_of_bounds_mask, param_mask) - logp = pt.where(broadcasted_mask, OUT_OF_BOUNDS_VAL, logp) + logp = pt.where(out_of_bounds_mask, OUT_OF_BOUNDS_VAL, logp) return logp From 25bc711d552e71786044789965bc9c1d5a164055 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Mon, 28 Aug 2023 16:51:42 -0400 Subject: [PATCH 11/16] Added defaults for full ddm model --- src/hssm/defaults.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/hssm/defaults.py b/src/hssm/defaults.py index a5e31e63..60fc134e 100644 --- a/src/hssm/defaults.py +++ b/src/hssm/defaults.py @@ -3,17 +3,19 @@ from typing import Callable, Literal, Optional, TypedDict, Union import bambi as bmb +import numpy as np from pymc import Distribution from pytensor.graph.op import Op from .likelihoods.analytical import ( ddm_bounds, ddm_params, + ddm_sdv_bounds, ddm_sdv_params, logp_ddm, logp_ddm_sdv, ) -from .likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox +from .likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox, logp_full_ddm from .param import ParamSpec, _make_default_prior LogLik = Union[str, PathLike, Callable, Op, type[Distribution]] @@ -21,6 +23,7 @@ SupportedModels = Literal[ "ddm", "ddm_sdv", + "full_ddm", "angle", "levy", "ornstein", @@ -127,7 +130,7 @@ class DefaultConfig(TypedDict): "blackbox": { "loglik": logp_ddm_sdv_bbox, "backend": None, - "bounds": ddm_bounds, + "bounds": ddm_sdv_bounds, "default_priors": { "t": { "name": "HalfNormal", @@ -138,6 +141,24 @@ class DefaultConfig(TypedDict): }, }, }, + "full_ddm": { + "list_params": ["v", "a", "z", "t", "sv", "sz", "st"], + "description": "The full Drift Diffusion Model (DDM)", + "likelihoods": { + "blackbox": { + "loglik": logp_full_ddm, + "backend": None, + "bounds": ddm_sdv_bounds | {"sz": (0, np.inf), "st": (0, np.inf)}, + "default_priors": { + "t": { + "name": "HalfNormal", + "sigma": 2.0, + "initval": 0.1, + }, + }, + } + }, + }, "angle": { "list_params": ["v", "a", "z", "t", "theta"], 
"description": None, From d4642f2486a61af748481c79208f96b5e7ef80c6 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Mon, 28 Aug 2023 16:52:06 -0400 Subject: [PATCH 12/16] Update docstring to support full_ddm --- src/hssm/hssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 63cd6c28..6079b4d4 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -63,7 +63,7 @@ class HSSM: columns "rt" and "response". model The name of the model to use. Currently supported models are "ddm", "ddm_sdv", - "angle", "levy", "ornstein", "weibull", "race_no_bias_angle_4", + "full_ddm", "angle", "levy", "ornstein", "weibull", "race_no_bias_angle_4", "ddm_seq2_no_bias". If any other string is passed, the model will be considered custom, in which case all `model_config`, `loglik`, and `loglik_kind` have to be provided by the user. From 66bb58be15ebc777b15dca422d4c276feb873ebe Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Mon, 28 Aug 2023 16:52:28 -0400 Subject: [PATCH 13/16] Updated data fixture --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1ca04e9d..be019f7e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,7 +34,7 @@ def data_angle(): def data_ddm_reg(): # Generate some fake simulation data intercept = 0.3 - x = np.random.uniform(0.5, 0.2, size=1000) + x = np.random.uniform(0.5, 0.7, size=1000) y = np.random.uniform(0.4, 0.1, size=1000) v = intercept + 0.8 * x + 0.3 * y From 59831130cf241d0d1ce7b14a18225e858d734f0c Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Mon, 28 Aug 2023 16:55:46 -0400 Subject: [PATCH 14/16] Updated documentation to include full_ddm --- docs/api/defaults.md | 3 ++- docs/getting_started/getting_started.ipynb | 1 + docs/tutorials/likelihoods.ipynb | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/api/defaults.md b/docs/api/defaults.md index 870b7560..b159fdec 100644 --- a/docs/api/defaults.md +++ b/docs/api/defaults.md @@ -8,6 +8,7 @@ The module includes a dictionary, `default_model_config`, that provides default - `ddm` - `ddm_sdv` +- `full_ddm` - `angle` - `levy` - `ornstein` @@ -72,7 +73,7 @@ For each model, a dictionary is defined containing configurations for each `Logl - v: Uniform (-10.0, 10.0) - sv: HalfNormal with sigma 2.0 - a: HalfNormal with sigma 2.0 - - t: Uniform (0.0, 5.0) with initial value 0.0 + - t: Uniform (0.0, 5.0) with initial value 0.1 #### Approx Differentiable - **Log-likelihood kind:** Approx Differentiable diff --git a/docs/getting_started/getting_started.ipynb b/docs/getting_started/getting_started.ipynb index 524e07ee..7dbe2a02 100644 --- a/docs/getting_started/getting_started.ipynb +++ b/docs/getting_started/getting_started.ipynb @@ -2358,6 +2358,7 @@ "\n", "- ddm\n", "- ddm_sdv\n", + "- full_ddm\n", "- angle\n", "- levy\n", "- ornstein\n", diff --git a/docs/tutorials/likelihoods.ipynb b/docs/tutorials/likelihoods.ipynb index 64205360..b3ac26dc 100644 --- a/docs/tutorials/likelihoods.ipynb +++ b/docs/tutorials/likelihoods.ipynb @@ -223,6 +223,7 @@ "\n", "- For `analytical` kind: `ddm` and `ddm_sdv` models.\n", "- For `approx_differentiable` kind: `ddm`, `ddm_sdv`, `angle`, `levy`, `ornstein`, `weibull`, `race_no_bias_angle_4` and `ddm_seq2_no_bias`.\n", + "- For `blackbox` kind: `ddm`, `ddm_sdv` and `full_ddm` models.\n", "\n", "For a model that has default likelihood functions, only the `model` argument needs to be specified." ] @@ -3058,7 +3059,7 @@ "\n", "2. 
Specify a `model_config`. It typically contains the following fields:\n", "\n", - " - `\"list_params\"`: Required if your `model` string is not one of `ddm`, `ddm_sdv`, `angle`, `levy`, `ornstein`, `weibull`, `race_no_bias_angle_4` and `ddm_seq2_no_bias`. A list of `str` indicating the parameters of the model.\n", + " - `\"list_params\"`: Required if your `model` string is not one of `ddm`, `ddm_sdv`, `full_ddm`, `angle`, `levy`, `ornstein`, `weibull`, `race_no_bias_angle_4` and `ddm_seq2_no_bias`. A list of `str` indicating the parameters of the model.\n", " The order in which the parameters are specified in this list is important.\n", " Values for each parameter will be passed to the likelihood function in this\n", " order.\n", From 91024a4a6bc1c3df1107b98eafd731c8b37205fa Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 29 Aug 2023 11:46:04 -0400 Subject: [PATCH 15/16] remove a print statement --- src/hssm/hssm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 6079b4d4..8d41146f 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -250,8 +250,6 @@ def __init__( self.model_distribution = self._make_model_distribution() - print(self.formula, self._parent) - self.family = make_family( self.model_distribution, self.list_params, From c9e83af8f74d986f115c72a6e300f528ee86018a Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 31 Aug 2023 13:20:57 -0400 Subject: [PATCH 16/16] Enforce response to be 0, 1 --- src/hssm/likelihoods/blackbox.py | 62 +++++++++++++++----------------- 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/src/hssm/likelihoods/blackbox.py b/src/hssm/likelihoods/blackbox.py index 70954c8e..1712c1d7 100644 --- a/src/hssm/likelihoods/blackbox.py +++ b/src/hssm/likelihoods/blackbox.py @@ -6,16 +6,32 @@ import numpy as np +def hddm_to_hssm(func): + """Make HDDM likelihood function compatible with HSSM.""" + + def outer(data: np.ndarray, *args, **kwargs): + x = data[:, 0] * np.where(data[:, 1] == 1, 1.0, -1.0).astype(np.float64) + size = len(data) + + args_list = [_broadcast(param, size) for param in args] + kwargs = {k: _broadcast(param, size) for k, param in kwargs.items()} + + logp = func(x, *args_list, *kwargs) + logp = np.where(np.isfinite(logp), logp, -66.1) + + return logp.astype(data.dtype) + + return outer + + +@hddm_to_hssm def logp_ddm_bbox(data: np.ndarray, v, a, z, t) -> np.ndarray: """Compute blackbox log-likelihoods for ddm models.""" - x = (data[:, 0] * data[:, 1]).astype(np.float64) size = len(data) - - v, a, z, t = [_broadcast(param, size) for param in [v, a, z, t]] zeros = np.zeros(size, dtype=np.float64) - logp = hddm_wfpt.wfpt.wiener_logp_array( - x=x, + return hddm_wfpt.wfpt.wiener_logp_array( + x=data, v=v, sv=zeros, a=a * 2, # Ensure compatibility with HSSM. @@ -26,23 +42,17 @@ def logp_ddm_bbox(data: np.ndarray, v, a, z, t) -> np.ndarray: err=1e-8, ) - logp = np.where(np.isfinite(logp), logp, -66.1) - - return logp.astype(data.dtype) - +@hddm_to_hssm def logp_ddm_sdv_bbox(data: np.ndarray, v, a, z, t, sv) -> np.ndarray: """Compute blackbox log-likelihoods for ddm_sdv models.""" - x = (data[:, 0] * data[:, 1]).astype(np.float64) - size = len(x) - - v, a, z, t, sv = [_broadcast(param, size) for param in [v, a, z, t, sv]] + size = len(data) zeros = np.zeros(size, dtype=np.float64) - logp = hddm_wfpt.wfpt.wiener_logp_array( - x=x, + return hddm_wfpt.wfpt.wiener_logp_array( + x=data, v=v, - sv=zeros, + sv=sv, a=a * 2, # Ensure compatibility with HSSM. 
z=z, sz=zeros, @@ -51,22 +61,12 @@ def logp_ddm_sdv_bbox(data: np.ndarray, v, a, z, t, sv) -> np.ndarray: err=1e-8, ) - logp = np.where(np.isfinite(logp), logp, -66.1) - - return logp.astype(data.dtype) - +@hddm_to_hssm def logp_full_ddm(data: np.ndarray, v, a, z, t, sv, sz, st): """Compute blackbox log-likelihoods for full_ddm models.""" - x = (data[:, 0] * data[:, 1]).astype(np.float64) - size = len(x) - - v, a, z, t, sv, sz, st = [ - _broadcast(param, size) for param in [v, a, z, t, sv, sz, st] - ] - - logp = hddm_wfpt.wfpt.wiener_logp_array( - x=x, + return hddm_wfpt.wfpt.wiener_logp_array( + x=data, v=v, sv=sv, a=a * 2, # Ensure compatibility with HSSM. @@ -77,10 +77,6 @@ def logp_full_ddm(data: np.ndarray, v, a, z, t, sv, sz, st): err=1e-8, ) - logp = np.where(np.isfinite(logp), logp, -66.1) - - return logp.astype(data.dtype) - def _broadcast(x: float | np.ndarray, size: int): """Broadcast a scalar or an array to size of `size`."""
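
A minimal usage sketch of the blackbox likelihoods introduced in this series, appended for illustration only (it is not part of the patches). It assumes a pandas DataFrame with "rt" and "response" columns; responses here are coded -1/1, which the final patch maps onto the signed coding it expects (0/1 coding would also be accepted). The synthetic data and parameter values below are hypothetical stand-ins for the test fixtures.

import numpy as np
import pandas as pd

import hssm
from hssm.likelihoods import logp_ddm_bbox

# Hypothetical stand-in data; a real analysis would use observed rt/response pairs.
rng = np.random.default_rng(2023)
data = pd.DataFrame(
    {
        "rt": rng.uniform(0.3, 2.0, size=500),
        "response": rng.choice([-1.0, 1.0], size=500),
    }
)

# Direct call to the Cython-backed likelihood with positional parameters
# (v, a, z, t), as exercised in tests/test_likelihoods.py; returns an array
# of pointwise log-likelihoods.
pointwise_logp = logp_ddm_bbox(data.values, 0.5, 1.5, 0.5, 0.3)
print(pointwise_logp.shape)

# The same likelihood through the HSSM interface, mirroring tests/test_mcmc.py;
# sampler="nuts_numpyro" is expected to raise for blackbox likelihoods because
# they provide no gradients.
model = hssm.HSSM(data, loglik_kind="blackbox")
model.sample(cores=1, chains=1, tune=10, draws=10)

# The full_ddm defaults added in this series are blackbox-only and are
# selected the same way.
model_full = hssm.HSSM(data, model="full_ddm", loglik_kind="blackbox")
model_full.sample(cores=1, chains=1, tune=10, draws=10)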