-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added the ability to use safe priors for hierarchical models #331
Merged
Merged
Changes from 24 commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
130ecc6
modified hssm.py to add special case for HDDM
digicosmos86 1401861
plugged in a few functions from bambi to handle default priors
digicosmos86 06795f3
added logic to modify the class to add default priors
digicosmos86 0dc986a
moved merge_dict to param.py to avoid circular import
digicosmos86 08a92e1
ensures that tests pass
digicosmos86 2a4587c
update software versions
digicosmos86 ce3e067
adjust how bounds are passed
digicosmos86 c3eeb14
add ruff ignore item, place warnings
digicosmos86 76a8411
update ci workflow
digicosmos86 2725e05
exclude ddm_sdv and ddm_full
digicosmos86 5136e04
fixed minor bugs in param.py
digicosmos86 24ab25e
use deepcopy to avoid errors
digicosmos86 c8d1981
added tests for safe prior strategy
digicosmos86 68cacd5
suppress jax warning
digicosmos86 01ba4dc
specify float type for each test file
digicosmos86 bcda159
update hssm version
digicosmos86 2567e32
Updated default parameter specifications
digicosmos86 84f1d58
suppress some warnings
digicosmos86 7a50254
bump ssm-simulators version
digicosmos86 6c935d0
update ssm-simulators
digicosmos86 d9b2827
update ssm-simulators
digicosmos86 f457631
fix a test
digicosmos86 8d006a0
Merge branch 'safe-prior-strategy' of https://github.com/lnccbrown/HS…
digicosmos86 509c0e2
set default init to
digicosmos86 2cb9608
bump ssm-simulators
digicosmos86 e7da626
Merge branch 'safe-prior-strategy' into update-documentation-020
digicosmos86 e0bac72
added string representation for generalized logit
digicosmos86 67575a0
fixed a bug where link_settings does not work in hssm
digicosmos86 6ae87ad
added documentation for GPU support
digicosmos86 820ff00
fix bugs in param.py
digicosmos86 e605503
added documentation for hierachical modeling
digicosmos86 0d4c507
added changelog
digicosmos86 5061f44
Merge branch 'main' into safe-prior-strategy
digicosmos86 32b6af1
changed version to 0.2.0b1
digicosmos86 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,6 @@ name: Run tests | |
|
||
on: | ||
pull_request: | ||
push: | ||
|
||
jobs: | ||
run_tests: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "HSSM" | ||
version = "0.1.5" | ||
version = "0.2.0" | ||
description = "Bayesian inference for hierarchical sequential sampling models." | ||
authors = [ | ||
"Alexander Fengler <[email protected]>", | ||
|
@@ -23,14 +23,14 @@ numpy = ">=1.23.4,<1.26" | |
onnx = "^1.12.0" | ||
jax = "^0.4.0" | ||
jaxlib = "^0.4.0" | ||
ssm-simulators = "0.5.1" | ||
ssm-simulators = "0.5.3" | ||
huggingface-hub = "^0.15.1" | ||
onnxruntime = "^1.15.0" | ||
bambi = "^0.12.0" | ||
numpyro = "^0.12.1" | ||
hddm-wfpt = "^0.1.1" | ||
seaborn = "^0.13.0" | ||
pytensor = "<=2.17.3" | ||
pytensor = "<2.17.4" | ||
|
||
[tool.poetry.group.dev.dependencies] | ||
pytest = "^7.3.1" | ||
|
@@ -69,7 +69,7 @@ profile = "black" | |
|
||
[tool.ruff] | ||
line-length = 88 | ||
target-version = "py39" | ||
target-version = "py310" | ||
unfixable = ["E711"] | ||
|
||
select = [ | ||
|
@@ -132,6 +132,8 @@ ignore = [ | |
"B020", | ||
# Function definition does not bind loop variable | ||
"B023", | ||
# zip()` without an explicit `strict= | ||
"B905", | ||
# Functions defined inside a loop must not use variables redefined in the loop | ||
# "B301", # not yet implemented | ||
# Too many arguments to function call | ||
|
@@ -166,14 +168,7 @@ ignore = [ | |
"TID252", | ||
] | ||
|
||
exclude = [ | ||
".github", | ||
"docs", | ||
"notebook", | ||
"tests", | ||
"src/hssm/likelihoods/hddm_wfpt/cdfdif_wrapper.c", | ||
"src/hssm/likelihoods/hddm_wfpt/wfpt.cpp", | ||
] | ||
exclude = [".github", "docs", "notebook", "tests"] | ||
|
||
[tool.ruff.pydocstyle] | ||
convention = "numpy" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
import seaborn as sns | ||
import xarray as xr | ||
from bambi.model_components import DistributionalComponent | ||
from bambi.transformations import transformations_namespace | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is that one for actually? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there is a default namespace that has to be included. This can be found in Bambi source code here |
||
|
||
from hssm.defaults import ( | ||
LoglikKind, | ||
|
@@ -164,6 +165,9 @@ class HSSM: | |
recommended when you are using hierarchical models. | ||
The default value is `None` when `hierarchical` is `False` and `"safe"` when | ||
`hierarchical` is `True`. | ||
extra_namespace : optional | ||
Additional user supplied variables with transformations or data to include in | ||
the environment where the formula is evaluated. Defaults to `None`. | ||
**kwargs | ||
Additional arguments passed to the `bmb.Model` object. | ||
|
||
|
@@ -214,6 +218,7 @@ def __init__( | |
hierarchical: bool = False, | ||
link_settings: Literal["log_logit"] | None = None, | ||
prior_settings: Literal["safe"] | None = None, | ||
extra_namespace: dict[str, Any] | None = None, | ||
**kwargs, | ||
): | ||
self.data = data | ||
|
@@ -232,6 +237,11 @@ def __init__( | |
self.link_settings = link_settings | ||
self.prior_settings = prior_settings | ||
|
||
additional_namespace = transformations_namespace.copy() | ||
if extra_namespace is not None: | ||
additional_namespace.update(extra_namespace) | ||
self.additional_namespace = additional_namespace | ||
|
||
responses = self.data["response"].unique().astype(int) | ||
self.n_responses = len(responses) | ||
if self.n_responses == 2: | ||
|
@@ -312,7 +322,12 @@ def __init__( | |
) | ||
|
||
self.model = bmb.Model( | ||
self.formula, data, family=self.family, priors=self.priors, **other_kwargs | ||
self.formula, | ||
data=data, | ||
family=self.family, | ||
priors=self.priors, | ||
extra_namespace=extra_namespace, | ||
**other_kwargs, | ||
) | ||
|
||
self._aliases = get_alias_dict(self.model, self._parent_param) | ||
|
@@ -322,6 +337,7 @@ def sample( | |
self, | ||
sampler: Literal["mcmc", "nuts_numpyro", "nuts_blackjax", "laplace", "vi"] | ||
| None = None, | ||
init: str | None = None, | ||
**kwargs, | ||
) -> az.InferenceData | pm.Approximation: | ||
"""Perform sampling using the `fit` method via bambi.Model. | ||
|
@@ -335,6 +351,9 @@ def sample( | |
sampler will automatically be chosen: when the model uses the | ||
`approx_differentiable` likelihood, and `jax` backend, "nuts_numpyro" will | ||
be used. Otherwise, "mcmc" (the default PyMC NUTS sampler) will be used. | ||
init: optional | ||
Initialization method to use for the sampler. If any of the NUTS samplers | ||
is used, defaults to `"adapt_diag"`. Otherwise, defaults to `"auto"`. | ||
kwargs | ||
Other arguments passed to bmb.Model.fit(). Please see [here] | ||
(https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit) | ||
|
@@ -370,7 +389,7 @@ def sample( | |
) | ||
|
||
if "step" not in kwargs: | ||
kwargs["step"] = pm.Slice(model=self.pymc_model) | ||
kwargs |= {"step": pm.Slice(model=self.pymc_model)} | ||
|
||
if ( | ||
self.loglik_kind == "approx_differentiable" | ||
|
@@ -387,7 +406,15 @@ def sample( | |
if self._check_extra_fields(): | ||
self._update_extra_fields() | ||
|
||
self._inference_obj = self.model.fit(inference_method=sampler, **kwargs) | ||
if init is None: | ||
if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]: | ||
init = "adapt_diag" | ||
else: | ||
init = "auto" | ||
|
||
self._inference_obj = self.model.fit( | ||
inference_method=sampler, init=init, **kwargs | ||
) | ||
|
||
return self.traces | ||
|
||
|
@@ -629,11 +656,11 @@ def plot_trace( | |
data : optional | ||
An ArviZ InferenceData object. If None, the traces stored in the model will | ||
be used. | ||
include deterministic : optional | ||
include_deterministic : optional | ||
Whether to include deterministic variables in the plot. Defaults to False. | ||
Note that if include deterministic is set to False and and `var_names` is | ||
provided, the `var_names` provided will be modified to also exclude the | ||
deterministic values. If this is not desirable, set | ||
deterministic values. If this is not desirable, set | ||
`include deterministic` to True. | ||
tight_layout : optional | ||
Whether to call plt.tight_layout() after plotting. Defaults to True. | ||
|
@@ -852,6 +879,8 @@ def _add_kwargs_and_p_outlier_to_include( | |
"""Process kwargs and p_outlier and add them to include.""" | ||
if include is None: | ||
include = [] | ||
else: | ||
include = include.copy() | ||
params_in_include = [param["name"] for param in include] | ||
|
||
# Process kwargs | ||
|
@@ -913,7 +942,7 @@ def _preprocess_rest(self, processed: dict[str, Param]) -> dict[str, Param]: | |
bounds = self.model_config.bounds.get(param_str) | ||
param = Param( | ||
param_str, | ||
formula="1 + (1|participant_id)", | ||
formula=f"{param_str} ~ 1 + (1|participant_id)", | ||
bounds=bounds, | ||
) | ||
else: | ||
|
@@ -956,15 +985,27 @@ def _find_parent(self) -> tuple[str, Param]: | |
|
||
def _override_defaults(self): | ||
"""Override the default priors or links.""" | ||
is_ddm = ( | ||
self.model_name in ["ddm", "ddm_sdv", "ddm_full"] | ||
and self.loglik_kind != "approx_differentiable" | ||
) | ||
for param in self.list_params: | ||
param_obj = self.params[param] | ||
if self.prior_settings == "safe": | ||
param_obj.override_default_priors(self.data) | ||
if is_ddm: | ||
param_obj.override_default_priors_ddm( | ||
self.data, self.additional_namespace | ||
) | ||
else: | ||
param_obj.override_default_priors( | ||
self.data, self.additional_namespace | ||
) | ||
elif self.link_settings == "log_logit": | ||
param_obj.override_default_link() | ||
|
||
def _process_all(self): | ||
"""Process all params.""" | ||
assert self.list_params is not None | ||
for param in self.list_params: | ||
self.params[param].convert() | ||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can be bumped to 0.6.1 now