
Merge pull request #267 from lnccbrown/221-wfpt-as-blackbox
Added blackbox likelihoods for `ddm` and `ddm_sdv` models
digicosmos86 authored Sep 1, 2023
2 parents 2027eaa + c9e83af commit d1b4f7b
Showing 15 changed files with 295 additions and 138 deletions.
3 changes: 2 additions & 1 deletion docs/api/defaults.md
@@ -8,6 +8,7 @@ The module includes a dictionary, `default_model_config`, that provides default

- `ddm`
- `ddm_sdv`
- `full_ddm`
- `angle`
- `levy`
- `ornstein`
@@ -72,7 +73,7 @@ For each model, a dictionary is defined containing configurations for each `Logl
- v: Uniform (-10.0, 10.0)
- sv: HalfNormal with sigma 2.0
- a: HalfNormal with sigma 2.0
- t: Uniform (0.0, 5.0) with initial value 0.0
- t: Uniform (0.0, 5.0) with initial value 0.1

#### Approx Differentiable
- **Log-likelihood kind:** Approx Differentiable
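The substantive change in this file is the default initial value for `t`, which moves off the lower bound of its Uniform (0.0, 5.0) prior from 0.0 to 0.1, presumably so that sampling does not start a bounded parameter exactly at its boundary. As a minimal sketch of how such a prior (with an explicit `initval`) can be supplied by a user, assuming the dict-based `include` spec with `name`/`prior` keys that `_process_include` accepts; the toy data frame and its response coding are placeholders, not part of this PR:

```python
import pandas as pd

import hssm

# Toy data; HSSM expects columns "rt" and "response" (the response coding
# here is purely illustrative).
data = pd.DataFrame({"rt": [0.52, 0.71, 0.88, 0.61], "response": [1, -1, 1, 1]})

# Override the default prior for `t`, mirroring the new default: a Uniform
# prior on (0.0, 5.0) whose initial value starts away from the lower bound.
model = hssm.HSSM(
    data=data,
    model="ddm",
    include=[
        {
            "name": "t",
            "prior": {"name": "Uniform", "lower": 0.0, "upper": 5.0, "initval": 0.1},
        }
    ],
)
```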
1 change: 1 addition & 0 deletions docs/getting_started/getting_started.ipynb
@@ -2358,6 +2358,7 @@
"\n",
"- ddm\n",
"- ddm_sdv\n",
"- full_ddm\n",
"- angle\n",
"- levy\n",
"- ornstein\n",
3 changes: 2 additions & 1 deletion docs/tutorials/likelihoods.ipynb
@@ -223,6 +223,7 @@
"\n",
"- For `analytical` kind: `ddm` and `ddm_sdv` models.\n",
"- For `approx_differentiable` kind: `ddm`, `ddm_sdv`, `angle`, `levy`, `ornstein`, `weibull`, `race_no_bias_angle_4` and `ddm_seq2_no_bias`.\n",
"- For `blackbox` kind: `ddm`, `ddm_sdv` and `full_ddm` models.\n",
"\n",
"For a model that has default likelihood functions, only the `model` argument needs to be specified."
]
@@ -3058,7 +3059,7 @@
"\n",
"2. Specify a `model_config`. It typically contains the following fields:\n",
"\n",
" - `\"list_params\"`: Required if your `model` string is not one of `ddm`, `ddm_sdv`, `angle`, `levy`, `ornstein`, `weibull`, `race_no_bias_angle_4` and `ddm_seq2_no_bias`. A list of `str` indicating the parameters of the model.\n",
" - `\"list_params\"`: Required if your `model` string is not one of `ddm`, `ddm_sdv`, `full_ddm`, `angle`, `levy`, `ornstein`, `weibull`, `race_no_bias_angle_4` and `ddm_seq2_no_bias`. A list of `str` indicating the parameters of the model.\n",
" The order in which the parameters are specified in this list is important.\n",
" Values for each parameter will be passed to the likelihood function in this\n",
" order.\n",
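Picking up the note above that models with default likelihood functions need only the `model` argument, here is a minimal sketch of requesting the new blackbox defaults. The toy data frame and its response coding are placeholders; `loglik_kind` is passed explicitly for clarity even where a default would apply:

```python
import pandas as pd

import hssm

data = pd.DataFrame({"rt": [0.52, 0.71, 0.88, 0.61], "response": [1, -1, 1, 1]})

# `full_ddm` ships only a blackbox likelihood; for `ddm` and `ddm_sdv`,
# the blackbox variant is selected via `loglik_kind`.
full_ddm_model = hssm.HSSM(data=data, model="full_ddm", loglik_kind="blackbox")
ddm_bbox_model = hssm.HSSM(data=data, model="ddm", loglik_kind="blackbox")
```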
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -28,6 +28,7 @@ huggingface-hub = "^0.15.1"
onnxruntime = "^1.15.0"
bambi = "^0.12.0"
numpyro = "^0.12.1"
hddm-wfpt = {git = "https://github.com/brown-ccv/hddm-wfpt.git"}

[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
@@ -36,7 +37,6 @@ mypy = "^1.4.1"
pre-commit = "^2.20.0"
jupyterlab = "^4.0.2"
ipykernel = "^6.16.0"
hddm-wfpt = { git = "https://github.com/brown-ccv/hddm-wfpt.git" }
ipywidgets = "^8.0.3"
graphviz = "^0.20.1"
ruff = "^0.0.272"
46 changes: 46 additions & 0 deletions src/hssm/defaults.py
@@ -3,23 +3,27 @@
from typing import Callable, Literal, Optional, TypedDict, Union

import bambi as bmb
import numpy as np
from pymc import Distribution
from pytensor.graph.op import Op

from .likelihoods.analytical import (
ddm_bounds,
ddm_params,
ddm_sdv_bounds,
ddm_sdv_params,
logp_ddm,
logp_ddm_sdv,
)
from .likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox, logp_full_ddm
from .param import ParamSpec, _make_default_prior

LogLik = Union[str, PathLike, Callable, Op, type[Distribution]]

SupportedModels = Literal[
"ddm",
"ddm_sdv",
"full_ddm",
"angle",
"levy",
"ornstein",
@@ -81,6 +85,18 @@ class DefaultConfig(TypedDict):
"t": (0.0, 2.0),
},
},
"blackbox": {
"loglik": logp_ddm_bbox,
"backend": None,
"bounds": ddm_bounds,
"default_priors": {
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
},
},
},
},
},
"ddm_sdv": {
@@ -111,6 +127,36 @@ class DefaultConfig(TypedDict):
"sv": (0.0, 1.0),
},
},
"blackbox": {
"loglik": logp_ddm_sdv_bbox,
"backend": None,
"bounds": ddm_sdv_bounds,
"default_priors": {
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
},
},
},
},
},
"full_ddm": {
"list_params": ["v", "a", "z", "t", "sv", "sz", "st"],
"description": "The full Drift Diffusion Model (DDM)",
"likelihoods": {
"blackbox": {
"loglik": logp_full_ddm,
"backend": None,
"bounds": ddm_sdv_bounds | {"sz": (0, np.inf), "st": (0, np.inf)},
"default_priors": {
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
},
},
}
},
},
"angle": {
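To see what the new `full_ddm` entry provides at runtime, the defaults can be read back directly. A small sketch, assuming the module-level dictionary keeps the name `default_model_config` and the nesting shown above (model name, then `"likelihoods"`, then the kind):

```python
from hssm.defaults import default_model_config

full_ddm_defaults = default_model_config["full_ddm"]
print(full_ddm_defaults["list_params"])
# ['v', 'a', 'z', 't', 'sv', 'sz', 'st']

blackbox_config = full_ddm_defaults["likelihoods"]["blackbox"]
print(blackbox_config["bounds"])          # ddm_sdv bounds plus (0, inf) for sz and st
print(blackbox_config["default_priors"])  # HalfNormal(sigma=2.0) prior on t, initval 0.1
```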
20 changes: 5 additions & 15 deletions src/hssm/distribution_utils/dist.py
@@ -59,13 +59,8 @@ def apply_param_bounds_to_loglik(
"""
dist_params_dict = dict(zip(list_params, dist_params))

bounds = {
k: (
pm.floatX(v[0]),
pm.floatX(v[1]),
)
for k, v in bounds.items()
}
bounds = {k: (pm.floatX(v[0]), pm.floatX(v[1])) for k, v in bounds.items()}
out_of_bounds_mask = pt.zeros_like(logp, dtype=bool)

for param_name, param in dist_params_dict.items():
# It cannot be assumed that each parameter will have bounds.
@@ -75,15 +70,10 @@

lower_bound, upper_bound = bounds[param_name]

out_of_bounds_mask = pt.bitwise_or(
pt.lt(param, lower_bound), pt.gt(param, upper_bound)
)

broadcasted_mask = pt.broadcast_to(
out_of_bounds_mask, logp.shape
) # typing: ignore
param_mask = pt.bitwise_or(pt.lt(param, lower_bound), pt.gt(param, upper_bound))
out_of_bounds_mask = pt.bitwise_or(out_of_bounds_mask, param_mask)

logp = pt.where(broadcasted_mask, OUT_OF_BOUNDS_VAL, logp)
logp = pt.where(out_of_bounds_mask, OUT_OF_BOUNDS_VAL, logp)

return logp

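The behavioral change in this file is that the out-of-bounds mask is now accumulated across parameters with `pt.bitwise_or` instead of being recomputed, and therefore overwritten, on each loop iteration; previously only the last parameter's violations survived. A NumPy analogue of the accumulation logic, where the sentinel value is a stand-in for the `OUT_OF_BOUNDS_VAL` constant defined in `dist.py`:

```python
import numpy as np

OUT_OF_BOUNDS_VAL = -66.1  # stand-in; use the constant defined in dist.py

logp = np.array([-1.2, -0.8, -2.5, -0.3])
params_and_bounds = {
    "a": (np.array([1.0, 5.0, 1.5, 1.2]), (0.0, 2.5)),   # element 1 out of bounds
    "t": (np.array([0.1, 0.2, -0.5, 0.3]), (0.0, 2.0)),  # element 2 out of bounds
}

# Accumulate violations across parameters; overwriting the mask on each
# iteration would keep only the last parameter's violations.
out_of_bounds_mask = np.zeros_like(logp, dtype=bool)
for values, (lower, upper) in params_and_bounds.values():
    out_of_bounds_mask |= (values < lower) | (values > upper)

logp = np.where(out_of_bounds_mask, OUT_OF_BOUNDS_VAL, logp)
print(logp)  # elements 1 and 2 are replaced with the sentinel value
```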
101 changes: 83 additions & 18 deletions src/hssm/hssm.py
@@ -27,7 +27,6 @@
from hssm.param import (
Param,
_make_default_prior,
_parse_bambi,
)
from hssm.utils import (
HSSMModelGraph,
@@ -64,7 +63,7 @@ class HSSM:
columns "rt" and "response".
model
The name of the model to use. Currently supported models are "ddm", "ddm_sdv",
"angle", "levy", "ornstein", "weibull", "race_no_bias_angle_4",
"full_ddm", "angle", "levy", "ornstein", "weibull", "race_no_bias_angle_4",
"ddm_seq2_no_bias". If any other string is passed, the model will be considered
custom, in which case all `model_config`, `loglik`, and `loglik_kind` have to be
provided by the user.
@@ -211,7 +210,6 @@ def __init__(
self.model_name = self.model_config.model_name
self.loglik = self.model_config.loglik
self.loglik_kind = self.model_config.loglik_kind
self._parent = self.list_params[0]

# Process lapse distribution
self.has_lapse = p_outlier is not None and p_outlier != 0
@@ -225,15 +223,17 @@
)

# Process parameter specifications in include
processed = self._process_include(include)
processed = self._preprocess_include(include)
# Process parameter specifications not in include
self.params = self._process_rest(processed)
self.params = self._preprocess_rest(processed)
# Find the parent parameter
self._parent, self._parent_param = self._find_parent()
assert self._parent_param is not None

# Get the bambi formula, priors, and link
self.formula, self.priors, self.link = _parse_bambi(self.params)
self._process_all()

self._parent_param = self.params[self.list_params[0]]
assert self._parent_param is not None
# Get the bambi formula, priors, and link
self.formula, self.priors, self.link = self._parse_bambi()

# For parameters that are regression, apply bounds at the likelihood level to
# ensure that the samples that are out of bounds are discarded (replaced with
@@ -674,7 +674,7 @@ def _add_kwargs_and_p_outlier_to_include(

return include, other_kwargs

def _process_include(self, include: list[dict | Param]) -> dict[str, Param]:
def _preprocess_include(self, include: list[dict | Param]) -> dict[str, Param]:
"""Turn parameter specs in include into Params."""
result: dict[str, Param] = {}

@@ -693,12 +693,10 @@ def _process_include(self, include: list[dict | Param]) -> dict[str, Param]:
if isinstance(param_with_default, dict)
else param_with_default
)
if name == self._parent:
result[name].set_parent()

return result

def _process_rest(self, processed: dict[str, Param]) -> dict[str, Param]:
def _preprocess_rest(self, processed: dict[str, Param]) -> dict[str, Param]:
"""Turn parameter specs not in include into Params."""
not_in_include = {}

@@ -716,20 +714,87 @@ def _process_rest(self, processed: dict[str, Param]) -> dict[str, Param]:
prior, bounds = self.model_config.get_defaults(param_str)
param = Param(param_str, prior=prior, bounds=bounds)
param.do_not_truncate()
if param_str == self._parent:
param.set_parent()
not_in_include[param_str] = param

processed |= not_in_include
sorted_params = {}

for param_name in self.list_params:
processed_param = processed[param_name]
processed_param.convert()
sorted_params[param_name] = processed_param
sorted_params[param_name] = processed[param_name]

return sorted_params

def _find_parent(self) -> tuple[str, Param]:
"""Find the parent param for the model.
The first param that has a regression will be set as parent. If none of the
params is a regression, then the first param will be set as parent.
Returns
-------
str
The name of the param as string
Param
The parent Param object
"""
for param_str in self.list_params:
param = self.params[param_str]
if param.is_regression:
param.set_parent()
return param_str, param

param_str = self.list_params[0]
param = self.params[param_str]
param.set_parent()
return param_str, param

def _process_all(self):
"""Process all params."""
for param in self.list_params:
self.params[param].convert()

def _parse_bambi(
self,
) -> tuple[bmb.Formula, dict | None, dict[str, str | bmb.Link] | str]:
"""Retrieve three items that helps with bambi model building.
Returns
-------
tuple
A tuple containing:
1. A bmb.Formula object.
2. A dictionary of priors, if any is specified.
3. A dictionary of link functions, if any is specified.
"""
# Handle the edge case where list_params is empty:
if not self.params:
return bmb.Formula("c(rt, response) ~ 1"), None, "identity"

parent_formula = None
other_formulas = []
priors: dict[str, Any] = {}
links: dict[str, str | bmb.Link] = {}

for _, param in self.params.items():
formula, prior, link = param.parse_bambi()

if param.is_parent:
parent_formula = formula
else:
if formula is not None:
other_formulas.append(formula)
if prior is not None:
priors |= prior
if link is not None:
links |= link

assert parent_formula is not None
result_formula: bmb.Formula = bmb.Formula(parent_formula, *other_formulas)
result_priors = None if not priors else priors
result_links: dict | str = "identity" if not links else links

return result_formula, result_priors, result_links

def _make_model_distribution(self) -> type[pm.Distribution]:
"""Make a pm.Distribution for the model."""
### Logic for different types of likelihoods:
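The new `_find_parent` method changes which parameter anchors the bambi formula: the first parameter that has a regression becomes the parent, falling back to the first entry of `list_params` when no parameter is a regression. A hedged usage sketch of the effect; the predictor column `x`, the toy data, and its response coding are hypothetical:

```python
import numpy as np
import pandas as pd

import hssm

rng = np.random.default_rng(2023)
data = pd.DataFrame(
    {
        "rt": rng.uniform(0.3, 1.2, size=20),
        "response": rng.choice([-1, 1], size=20),
        "x": rng.normal(size=20),  # hypothetical trial-level predictor
    }
)

# With a regression on "a", "a" (rather than "v", the first entry of
# `list_params` for the ddm model) becomes the parent parameter that
# anchors the bambi formula.
model = hssm.HSSM(
    data=data,
    model="ddm",
    include=[{"name": "a", "formula": "a ~ 1 + x"}],
)
```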
11 changes: 10 additions & 1 deletion src/hssm/likelihoods/__init__.py
@@ -1,5 +1,14 @@
"""Likelihood functions and distributions that use them."""

from .analytical import DDM, DDM_SDV, logp_ddm, logp_ddm_sdv
from .blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox, logp_full_ddm

__all__ = ["logp_ddm", "logp_ddm_sdv", "DDM", "DDM_SDV"]
__all__ = [
"logp_ddm",
"logp_ddm_sdv",
"DDM",
"DDM_SDV",
"logp_ddm_bbox",
"logp_ddm_sdv_bbox",
"logp_full_ddm",
]
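Because the blackbox log-likelihoods are now exported from `hssm.likelihoods`, they can also be passed to `HSSM` explicitly rather than picked up from the defaults. A sketch with the same caveats as above about the toy data:

```python
import pandas as pd

import hssm
from hssm.likelihoods import logp_ddm_bbox

data = pd.DataFrame({"rt": [0.52, 0.71, 0.88, 0.61], "response": [1, -1, 1, 1]})

# Pass the blackbox likelihood function explicitly; this mirrors what the
# default blackbox config for "ddm" now provides.
model = hssm.HSSM(
    data=data,
    model="ddm",
    loglik=logp_ddm_bbox,
    loglik_kind="blackbox",
)
```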
(Diffs for the remaining changed files are not shown here.)
