Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the ability to use safe priors for hierarchical models #331

Merged
merged 34 commits into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
130ecc6
modified hssm.py to add special case for HDDM
digicosmos86 Nov 28, 2023
1401861
plugged in a few functions from bambi to handle default priors
digicosmos86 Nov 28, 2023
06795f3
added logic to modify the class to add default priors
digicosmos86 Nov 28, 2023
0dc986a
moved merge_dict to param.py to avoid circular import
digicosmos86 Nov 28, 2023
08a92e1
ensures that tests pass
digicosmos86 Nov 28, 2023
2a4587c
update software versions
digicosmos86 Nov 28, 2023
ce3e067
adjust how bounds are passed
digicosmos86 Nov 30, 2023
c3eeb14
add ruff ignore item, place warnings
digicosmos86 Nov 30, 2023
76a8411
update ci workflow
digicosmos86 Dec 5, 2023
2725e05
exclude ddm_sdv and ddm_full
digicosmos86 Dec 5, 2023
5136e04
fixed minor bugs in param.py
digicosmos86 Dec 5, 2023
24ab25e
use deepcopy to avoid errors
digicosmos86 Dec 5, 2023
c8d1981
added tests for safe prior strategy
digicosmos86 Dec 5, 2023
68cacd5
suppress jax warning
digicosmos86 Dec 5, 2023
01ba4dc
specify float type for each test file
digicosmos86 Dec 5, 2023
bcda159
update hssm version
digicosmos86 Dec 5, 2023
2567e32
Updated default parameter specifications
digicosmos86 Dec 7, 2023
84f1d58
suppress some warnings
digicosmos86 Dec 7, 2023
7a50254
bump ssm-simulators version
digicosmos86 Dec 7, 2023
6c935d0
update ssm-simulators
digicosmos86 Dec 7, 2023
d9b2827
update ssm-simulators
digicosmos86 Dec 7, 2023
f457631
fix a test
digicosmos86 Dec 7, 2023
8d006a0
Merge branch 'safe-prior-strategy' of https://github.com/lnccbrown/HS…
digicosmos86 Dec 7, 2023
509c0e2
set default init to
digicosmos86 Dec 7, 2023
2cb9608
bump ssm-simulators
digicosmos86 Dec 12, 2023
e7da626
Merge branch 'safe-prior-strategy' into update-documentation-020
digicosmos86 Dec 13, 2023
e0bac72
added string representation for generalized logit
digicosmos86 Dec 13, 2023
67575a0
fixed a bug where link_settings does not work in hssm
digicosmos86 Dec 13, 2023
6ae87ad
added documentation for GPU support
digicosmos86 Dec 13, 2023
820ff00
fix bugs in param.py
digicosmos86 Dec 13, 2023
e605503
added documentation for hierachical modeling
digicosmos86 Dec 15, 2023
0d4c507
added changelog
digicosmos86 Dec 15, 2023
5061f44
Merge branch 'main' into safe-prior-strategy
digicosmos86 Dec 15, 2023
32b6af1
changed version to 0.2.0b1
digicosmos86 Dec 15, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ name: Run tests

on:
pull_request:
push:

jobs:
run_tests:
Expand Down
4 changes: 2 additions & 2 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,5 @@ markdown_extensions:
- pymdownx.superfences
- attr_list
- pymdownx.emoji:
emoji_index: !!python/name:materialx.emoji.twemoji
emoji_generator: !!python/name:materialx.emoji.to_svg
emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg
19 changes: 7 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "HSSM"
version = "0.1.5"
version = "0.2.0"
description = "Bayesian inference for hierarchical sequential sampling models."
authors = [
"Alexander Fengler <[email protected]>",
Expand All @@ -23,14 +23,14 @@ numpy = ">=1.23.4,<1.26"
onnx = "^1.12.0"
jax = "^0.4.0"
jaxlib = "^0.4.0"
ssm-simulators = "0.5.1"
ssm-simulators = "0.5.3"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be bumped to 0.6.1 now

huggingface-hub = "^0.15.1"
onnxruntime = "^1.15.0"
bambi = "^0.12.0"
numpyro = "^0.12.1"
hddm-wfpt = "^0.1.1"
seaborn = "^0.13.0"
pytensor = "<=2.17.3"
pytensor = "<2.17.4"

[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
Expand Down Expand Up @@ -69,7 +69,7 @@ profile = "black"

[tool.ruff]
line-length = 88
target-version = "py39"
target-version = "py310"
unfixable = ["E711"]

select = [
Expand Down Expand Up @@ -132,6 +132,8 @@ ignore = [
"B020",
# Function definition does not bind loop variable
"B023",
# zip()` without an explicit `strict=
"B905",
# Functions defined inside a loop must not use variables redefined in the loop
# "B301", # not yet implemented
# Too many arguments to function call
Expand Down Expand Up @@ -166,14 +168,7 @@ ignore = [
"TID252",
]

exclude = [
".github",
"docs",
"notebook",
"tests",
"src/hssm/likelihoods/hddm_wfpt/cdfdif_wrapper.c",
"src/hssm/likelihoods/hddm_wfpt/wfpt.cpp",
]
exclude = [".github", "docs", "notebook", "tests"]

[tool.ruff.pydocstyle]
convention = "numpy"
Expand Down
55 changes: 48 additions & 7 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import seaborn as sns
import xarray as xr
from bambi.model_components import DistributionalComponent
from bambi.transformations import transformations_namespace
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is that one for actually?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a default namespace that has to be included. This can be found in Bambi source code here

https://github.com/bambinos/bambi/blob/312afa24b25385f5fee9e0331e88052598c39b59/bambi/models.py#L149-L155


from hssm.defaults import (
LoglikKind,
Expand Down Expand Up @@ -164,6 +165,9 @@ class HSSM:
recommended when you are using hierarchical models.
The default value is `None` when `hierarchical` is `False` and `"safe"` when
`hierarchical` is `True`.
extra_namespace : optional
Additional user supplied variables with transformations or data to include in
the environment where the formula is evaluated. Defaults to `None`.
**kwargs
Additional arguments passed to the `bmb.Model` object.

Expand Down Expand Up @@ -214,6 +218,7 @@ def __init__(
hierarchical: bool = False,
link_settings: Literal["log_logit"] | None = None,
prior_settings: Literal["safe"] | None = None,
extra_namespace: dict[str, Any] | None = None,
**kwargs,
):
self.data = data
Expand All @@ -232,6 +237,11 @@ def __init__(
self.link_settings = link_settings
self.prior_settings = prior_settings

additional_namespace = transformations_namespace.copy()
if extra_namespace is not None:
additional_namespace.update(extra_namespace)
self.additional_namespace = additional_namespace

responses = self.data["response"].unique().astype(int)
self.n_responses = len(responses)
if self.n_responses == 2:
Expand Down Expand Up @@ -312,7 +322,12 @@ def __init__(
)

self.model = bmb.Model(
self.formula, data, family=self.family, priors=self.priors, **other_kwargs
self.formula,
data=data,
family=self.family,
priors=self.priors,
extra_namespace=extra_namespace,
**other_kwargs,
)

self._aliases = get_alias_dict(self.model, self._parent_param)
Expand All @@ -322,6 +337,7 @@ def sample(
self,
sampler: Literal["mcmc", "nuts_numpyro", "nuts_blackjax", "laplace", "vi"]
| None = None,
init: str | None = None,
**kwargs,
) -> az.InferenceData | pm.Approximation:
"""Perform sampling using the `fit` method via bambi.Model.
Expand All @@ -335,6 +351,9 @@ def sample(
sampler will automatically be chosen: when the model uses the
`approx_differentiable` likelihood, and `jax` backend, "nuts_numpyro" will
be used. Otherwise, "mcmc" (the default PyMC NUTS sampler) will be used.
init: optional
Initialization method to use for the sampler. If any of the NUTS samplers
is used, defaults to `"adapt_diag"`. Otherwise, defaults to `"auto"`.
kwargs
Other arguments passed to bmb.Model.fit(). Please see [here]
(https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit)
Expand Down Expand Up @@ -370,7 +389,7 @@ def sample(
)

if "step" not in kwargs:
kwargs["step"] = pm.Slice(model=self.pymc_model)
kwargs |= {"step": pm.Slice(model=self.pymc_model)}

if (
self.loglik_kind == "approx_differentiable"
Expand All @@ -387,7 +406,15 @@ def sample(
if self._check_extra_fields():
self._update_extra_fields()

self._inference_obj = self.model.fit(inference_method=sampler, **kwargs)
if init is None:
if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
init = "adapt_diag"
else:
init = "auto"

self._inference_obj = self.model.fit(
inference_method=sampler, init=init, **kwargs
)

return self.traces

Expand Down Expand Up @@ -629,11 +656,11 @@ def plot_trace(
data : optional
An ArviZ InferenceData object. If None, the traces stored in the model will
be used.
include deterministic : optional
include_deterministic : optional
Whether to include deterministic variables in the plot. Defaults to False.
Note that if include deterministic is set to False and and `var_names` is
provided, the `var_names` provided will be modified to also exclude the
deterministic values. If this is not desirable, set
deterministic values. If this is not desirable, set
`include deterministic` to True.
tight_layout : optional
Whether to call plt.tight_layout() after plotting. Defaults to True.
Expand Down Expand Up @@ -852,6 +879,8 @@ def _add_kwargs_and_p_outlier_to_include(
"""Process kwargs and p_outlier and add them to include."""
if include is None:
include = []
else:
include = include.copy()
params_in_include = [param["name"] for param in include]

# Process kwargs
Expand Down Expand Up @@ -913,7 +942,7 @@ def _preprocess_rest(self, processed: dict[str, Param]) -> dict[str, Param]:
bounds = self.model_config.bounds.get(param_str)
param = Param(
param_str,
formula="1 + (1|participant_id)",
formula=f"{param_str} ~ 1 + (1|participant_id)",
bounds=bounds,
)
else:
Expand Down Expand Up @@ -956,15 +985,27 @@ def _find_parent(self) -> tuple[str, Param]:

def _override_defaults(self):
"""Override the default priors or links."""
is_ddm = (
self.model_name in ["ddm", "ddm_sdv", "ddm_full"]
and self.loglik_kind != "approx_differentiable"
)
for param in self.list_params:
param_obj = self.params[param]
if self.prior_settings == "safe":
param_obj.override_default_priors(self.data)
if is_ddm:
param_obj.override_default_priors_ddm(
self.data, self.additional_namespace
)
else:
param_obj.override_default_priors(
self.data, self.additional_namespace
)
elif self.link_settings == "log_logit":
param_obj.override_default_link()

def _process_all(self):
"""Process all params."""
assert self.list_params is not None
for param in self.list_params:
self.params[param].convert()

Expand Down
Loading