Skip to content

Commit

Permalink
Merge pull request #320 from lnccbrown/add-hssm-link
Browse files Browse the repository at this point in the history
Override defaults for link functions
  • Loading branch information
digicosmos86 authored Nov 28, 2023
2 parents b0c124e + dc926d9 commit de48984
Show file tree
Hide file tree
Showing 7 changed files with 328 additions and 37 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"]

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
pymc = ">=5.9.0"
pymc = "^5.9.0"
scipy = "1.10.1"
arviz = "^0.14.0"
numpy = ">=1.23.4,<1.26"
Expand All @@ -30,7 +30,7 @@ bambi = "^0.12.0"
numpyro = "^0.12.1"
hddm-wfpt = "^0.1.1"
seaborn = "^0.13.0"
xhistogram = "^0.3.2"
pytensor = "<=2.17.3"

[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
Expand Down
2 changes: 2 additions & 0 deletions src/hssm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .datasets import load_data
from .defaults import show_defaults
from .hssm import HSSM
from .link import Link
from .param import Param
from .prior import Prior
from .simulator import simulate_data
Expand All @@ -29,6 +30,7 @@

__all__ = [
"HSSM",
"Link",
"load_data",
"ModelConfig",
"Param",
Expand Down
52 changes: 49 additions & 3 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,29 @@ class HSSM:
hierarchical : optional
If True, and if there is a `participant_id` field in `data`, will by default
turn any unspecified parameter theta into a regression with
"theta ~ 1 + (1|participant_id)" and default priors set by `bambi`.
"theta ~ 1 + (1|participant_id)" and default priors set by `bambi`. Also changes
default values of `link_settings` and `prior_settings`. Defaults to False.
link_settings : optional
An optional string literal that indicates the link functions to use for each
parameter. Helpful for hierarchical models where sampling might get stuck/
very slow. Can be one of the following:
- `"log_logit"`: applies log link functions to positive parameters and
generalized logit link functions to parameters that have explicit bounds.
- `None`: unless otherwise specified, the `"identity"` link functions will be
used.
The default value is `None`.
prior_settings : optional
An optional string literal that indicates the prior distributions to use for
each parameter. Helpful for hierarchical models where sampling might get stuck/
very slow. Can be one of the following:
- `"safe"`: HSSM will scan all parameters in the model and apply safe priors to
all parameters that do not have explicit bounds.
- `None`: HSSM will use bambi to provide default priors for all parameters. Not
recommended when you are using hierarchical models.
The default value is `None` when `hierarchical` is `False` and `"safe"` when
`hierarchical` is `True`.
**kwargs
Additional arguments passed to the `bmb.Model` object.
Expand Down Expand Up @@ -190,6 +212,8 @@ def __init__(
p_outlier: float | dict | bmb.Prior | None = 0.05,
lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=10.0),
hierarchical: bool = False,
link_settings: Literal["log_logit"] | None = None,
prior_settings: Literal["safe"] | None = None,
**kwargs,
):
self.data = data
Expand All @@ -202,7 +226,20 @@ def __init__(
+ "`participant_id` field in the DataFrame that you have passed."
)

self.n_responses = len(self.data["response"].unique())
if self.hierarchical and prior_settings is None:
prior_settings = "safe"

self.link_settings = link_settings
self.prior_settings = prior_settings

responses = self.data["response"].unique().astype(int)
self.n_responses = len(responses)
if self.n_responses == 2:
if -1 not in responses or 1 not in responses:
raise ValueError(
"The response column must contain only -1 and 1 when there are "
+ "two responses."
)

# Construct a model_config from defaults
self.model_config = Config.from_defaults(model, loglik_kind)
Expand Down Expand Up @@ -246,6 +283,7 @@ def __init__(
self._parent, self._parent_param = self._find_parent()
assert self._parent_param is not None

self._override_defaults()
self._process_all()

# Get the bambi formula, priors, and link
Expand Down Expand Up @@ -876,7 +914,6 @@ def _preprocess_rest(self, processed: dict[str, Param]) -> dict[str, Param]:
param = Param(
param_str,
formula="1 + (1|participant_id)",
link="identity",
bounds=bounds,
)
else:
Expand Down Expand Up @@ -917,6 +954,15 @@ def _find_parent(self) -> tuple[str, Param]:
param.set_parent()
return param_str, param

def _override_defaults(self):
"""Override the default priors or links."""
for param in self.list_params:
param_obj = self.params[param]
if self.prior_settings == "safe":
param_obj.override_default_priors(self.data)
elif self.link_settings == "log_logit":
param_obj.override_default_link()

def _process_all(self):
"""Process all params."""
for param in self.list_params:
Expand Down
80 changes: 80 additions & 0 deletions src/hssm/link.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""A class that extends bmb.Link to allow for more generalized links with bounds."""

import bambi as bmb
import numpy as np

HSSM_LINKS = {"gen_logit"}


class Link(bmb.Link):
"""Representation of a generalized link function.
This object contains two main functions. One is the link function itself, the
function that maps values in the response scale to the linear predictor, and the
other is the inverse of the link function, that maps values of the linear predictor
to the response scale.
The great majority of users will never interact with this class unless they want to
create a custom ``Family`` with a custom ``Link``. This is automatically handled for
all the built-in families.
Parameters
----------
name
The name of the link function. If it is a known name, it's not necessary to pass
any other arguments because functions are already defined internally. If not
known, all of `link``, ``linkinv`` and ``linkinv_backend`` must be specified.
link : optional
A function that maps the response to the linear predictor. Known as the
:math:`g` function in GLM jargon. Does not need to be specified when ``name``
is a known name.
linkinv : optional
A function that maps the linear predictor to the response. Known as the
:math:`g^{-1}` function in GLM jargon. Does not need to be specified when
``name`` is a known name.
linkinv_backend : optional
Same than ``linkinv`` but must be something that works with PyMC backend
(i.e. it must work with PyTensor tensors). Does not need to be specified when
``name`` is a known name.
bounds : optional
Bounds of the response scale. Only needed when ``name`` is ``gen_logit``.
"""

def __init__(
self,
name,
link=None,
linkinv=None,
linkinv_backend=None,
bounds: tuple[float, float] | None = None,
):
if name in HSSM_LINKS:
self.name = name
if name == "gen_logit":
if bounds is None:
raise ValueError(
"Bounds must be specified for generalized log link function."
)
self.link = self._make_generalized_logit_simple(*bounds)
self.linkinv = self._make_generalized_sigmoid_simple(*bounds)
self.linkinv_backend = self._make_generalized_sigmoid_simple(*bounds)
else:
bmb.Link.__init__(name, link, linkinv, linkinv_backend)

self.bounds = bounds

def _make_generalized_sigmoid_simple(self, a, b):
"""Make a generalized sigmoid link function with bounds a and b."""

def invlink_(x):
return a + ((b - a) / (1 + np.exp(-x)))

return invlink_

def _make_generalized_logit_simple(self, a, b):
"""Make a generalized logit link function with bounds a and b."""

def link_(x):
return np.log((x - a) / (b - x))

return link_
67 changes: 67 additions & 0 deletions src/hssm/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

import bambi as bmb
import numpy as np
import pandas as pd

from .link import Link
from .prior import Prior

# PEP604 union operator "|" not supported by pylint
Expand Down Expand Up @@ -69,6 +71,15 @@ def __init__(
self._is_parent = False
self._is_converted = False
self._do_not_truncate = False
self._link_specified = link is not None

# Provides a convenient way to specify the link function
if self.link == "gen_logit":
if self.bounds is None:
raise ValueError(
"Bounds must be specified for generalized log link function."
)
self.link = Link("gen_logit", bounds=self.bounds)

# The initializer does not do anything immediately after the object is initiated
# because things could still change.
Expand All @@ -82,6 +93,62 @@ def update(self, **kwargs):
raise ValueError(f"{attr} does not exist.")
setattr(self, attr, value)

def override_default_link(self):
"""Override the default link function.
This is most likely because both default prior and default bounds are supplied.
"""
if self._is_converted:
raise ValueError(
(
"Cannot override the default link function for parameter %s."
+ " The object has already been processed."
)
% self.name,
)

if not self.is_regression or self._link_specified:
return # do nothing

if self.bounds is None:
raise ValueError(
(
"Cannot override the default link function. Bounds are not"
+ " specified for parameter %s."
)
% self.name,
)

lower, upper = self.bounds

if np.isneginf(lower) and np.isposinf(upper):
return
elif lower == 0.0 and np.isposinf(upper):
self.link = "log"
if not np.isneginf(lower) and not np.isposinf(upper):
self.link = Link("gen_logit", bounds=self.bounds)
else:
_logger.warning(
"The bounds for parameter %s (%f, %f) seems strange. Nothing is done to"
+ " the link function. Please check if they are correct.",
self.name,
lower,
upper,
)

def override_default_priors(self, data: pd.DataFrame):
"""Override the default priors.
By supplying priors for all parameters in the regression, we can override the
defaults that Bambi uses.
Parameters
----------
data
The data used to fit the model.
"""
return # Will implement in the next PR

def set_parent(self):
"""Set the Param as parent."""
self._is_parent = True
Expand Down
Loading

0 comments on commit de48984

Please sign in to comment.