From 7ca72e20031b170127811d5633306417f6bbc50b Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 14 Nov 2023 13:27:57 -0500 Subject: [PATCH 01/10] extend bmb.link --- src/hssm/link.py | 78 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 src/hssm/link.py diff --git a/src/hssm/link.py b/src/hssm/link.py new file mode 100644 index 00000000..9eea10d5 --- /dev/null +++ b/src/hssm/link.py @@ -0,0 +1,78 @@ +"""A class that extends bmb.Link to allow for more generalized links with bounds.""" + +import bambi as bmb +import numpy as np + +HSSM_LINKS = {"gen_logit"} + + +class Link(bmb.Link): + """Representation of a generalized link function. + + This object contains two main functions. One is the link function itself, the + function that maps values in the response scale to the linear predictor, and the + other is the inverse of the link function, that maps values of the linear predictor + to the response scale. + + The great majority of users will never interact with this class unless they want to + create a custom ``Family`` with a custom ``Link``. This is automatically handled for + all the built-in families. + + Parameters + ---------- + name : str + The name of the link function. If it is a known name, it's not necessary to pass + any other arguments because functions are already defined internally. If not + known, all of `link``, ``linkinv`` and ``linkinv_backend`` must be specified. + link : function + A function that maps the response to the linear predictor. Known as the + :math:`g` function in GLM jargon. Does not need to be specified when ``name`` + is a known name. + linkinv : function + A function that maps the linear predictor to the response. Known as the + :math:`g^{-1}` function in GLM jargon. Does not need to be specified when + ``name`` is a known name. + linkinv_backend : function + Same than ``linkinv`` but must be something that works with PyMC backend + (i.e. it must work with PyTensor tensors). Does not need to be specified when + ``name`` is a known name. + """ + + def __init__( + self, + name, + link=None, + linkinv=None, + linkinv_backend=None, + bounds: tuple[float, float] | None = None, + ): + if name in HSSM_LINKS: + self.name = name + if name == "gen_logit": + if bounds is None: + raise ValueError( + "Bounds must be specified for generalized log link function." + ) + self.link = self._make_generalized_logit_simple(*bounds) + self.linkinv = self._make_generalized_sigmoid_simple(*bounds) + self.linkinv_backend = self._make_generalized_sigmoid_simple(*bounds) + else: + bmb.Link.__init__(name, link, linkinv, linkinv_backend) + + self.bounds = bounds + + def _make_generalized_sigmoid_simple(self, a, b): + """Make a generalized sigmoid link function with bounds a and b.""" + + def invlink_(x): + return a + ((b - a) / (1 + np.exp(-x))) + + return invlink_ + + def _make_generalized_logit_simple(self, a, b): + """Make a generalized logit link function with bounds a and b.""" + + def link_(x): + return np.log((x - a) / (b - x)) + + return link_ From 5c749065e22143a2b18a9b10d59026e3a58ae937 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 14 Nov 2023 13:28:22 -0500 Subject: [PATCH 02/10] add override_default_link to hssm.Param --- src/hssm/param.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/src/hssm/param.py b/src/hssm/param.py index 05a161d1..0dfde27c 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -5,7 +5,9 @@ import bambi as bmb import numpy as np +import pandas as pd +from .link import Link from .prior import Prior # PEP604 union operator "|" not supported by pylint @@ -69,6 +71,15 @@ def __init__( self._is_parent = False self._is_converted = False self._do_not_truncate = False + self._link_specified = link is not None + + # Provides a convenient way to specify the link function + if self.link == "gen_logit": + if self.bounds is None: + raise ValueError( + "Bounds must be specified for generalized log link function." + ) + self.link = Link("gen_logit", bounds=self.bounds) # The initializer does not do anything immediately after the object is initiated # because things could still change. @@ -82,6 +93,62 @@ def update(self, **kwargs): raise ValueError(f"{attr} does not exist.") setattr(self, attr, value) + def override_default_link(self): + """Override the default link function. + + This is most likely because both default prior and default bounds are supplied. + """ + if self._is_converted: + raise ValueError( + ( + "Cannot override the default link function for parameter %s." + + " The object has already been processed." + ) + % self.name, + ) + + if not self.is_regression or self._link_specified: + return # do nothing + + if self.bounds is None: + raise ValueError( + ( + "Cannot override the default link function. Bounds are not" + + " specified for parameter %s." + ) + % self.name, + ) + + lower, upper = self.bounds + + if np.isneginf(lower) and np.isposinf(upper): + return + elif lower == 0.0 and np.isposinf(upper): + self.link = "log" + if not np.isneginf(lower) and not np.isposinf(upper): + self.link = Link("gen_logit", bounds=self.bounds) + else: + _logger.warning( + "The bounds for parameter %s (%f, %f) seems strange. Nothing is done to" + + " the link function. Please check if they are correct.", + self.name, + lower, + upper, + ) + + def override_default_priors(self, data: pd.DataFrame): + """Override the default priors. + + By supplying priors for all parameters in the regression, we can override the + defaults that Bambi uses. + + Parameters + ---------- + data + The data used to fit the model. + """ + return # Will implement in the next PR + def set_parent(self): """Set the Param as parent.""" self._is_parent = True From b688d68acb101e3ca0f1f25d47d0cca632c654be Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 14 Nov 2023 13:30:50 -0500 Subject: [PATCH 03/10] Add a default parameter to HSSM class --- src/hssm/hssm.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 93370fa2..64cb3111 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -142,6 +142,12 @@ class HSSM: If True, and if there is a `participant_id` field in `data`, will by default turn any unspecified parameter theta into a regression with "theta ~ 1 + (1|participant_id)" and default priors set by `bambi`. + default : optional + If "link", will override the default links for all parameters with a regression + with a generalized log link function defined by the bounds of the parameter. + If "prior", will override the bambi default priors for all parameters with a + regression with a set of priors defined by HSSM. If None, no default will be + overridden. Defaults to None if `hierarchical` is False, otherwise "prior". **kwargs Additional arguments passed to the `bmb.Model` object. @@ -190,6 +196,7 @@ def __init__( p_outlier: float | dict | bmb.Prior | None = 0.05, lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=10.0), hierarchical: bool = False, + default: Literal["prior", "link"] | None = None, **kwargs, ): self.data = data @@ -202,6 +209,14 @@ def __init__( + "`participant_id` field in the DataFrame that you have passed." ) + if self.hierarchical: + if default is None: + self.override_strategy: Literal["prior", "link"] | None = "prior" + else: + self.override_strategy = default + else: + self.override_strategy = default + self.n_responses = len(self.data["response"].unique()) # Construct a model_config from defaults @@ -246,6 +261,9 @@ def __init__( self._parent, self._parent_param = self._find_parent() assert self._parent_param is not None + if self.override_strategy is not None: + self._override_defaults(self.override_strategy) + self._process_all() # Get the bambi formula, priors, and link @@ -876,7 +894,6 @@ def _preprocess_rest(self, processed: dict[str, Param]) -> dict[str, Param]: param = Param( param_str, formula="1 + (1|participant_id)", - link="identity", bounds=bounds, ) else: @@ -917,6 +934,14 @@ def _find_parent(self) -> tuple[str, Param]: param.set_parent() return param_str, param + def _override_defaults(self, default: Literal["prior", "link"]): + """Override the default priors or links.""" + for param in self.list_params: + if default == "prior": + param.override_default_priors(self.data) + elif default == "link": + param.override_default_link() + def _process_all(self): """Process all params.""" for param in self.list_params: From 1f601b9514b2e9d6c1cdeee5a92184c9b1b52abe Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 14 Nov 2023 13:31:37 -0500 Subject: [PATCH 04/10] exports Link in __init__.py --- src/hssm/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hssm/__init__.py b/src/hssm/__init__.py index d7ba5d8f..68279901 100644 --- a/src/hssm/__init__.py +++ b/src/hssm/__init__.py @@ -15,6 +15,7 @@ from .datasets import load_data from .defaults import show_defaults from .hssm import HSSM +from .link import Link from .param import Param from .prior import Prior from .simulator import simulate_data @@ -29,6 +30,7 @@ __all__ = [ "HSSM", + "Link", "load_data", "ModelConfig", "Param", From 5bc47f669a22e6c4dbb4e5d63b633de8a3b3f8c4 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 14 Nov 2023 13:31:45 -0500 Subject: [PATCH 05/10] Add tests --- tests/test_hssm.py | 75 +++++++++++++++++++++++----------------- tests/test_param.py | 83 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 126 insertions(+), 32 deletions(-) diff --git a/tests/test_hssm.py b/tests/test_hssm.py index 65261f55..076aaaa7 100644 --- a/tests/test_hssm.py +++ b/tests/test_hssm.py @@ -5,7 +5,6 @@ import pandas as pd import pytensor import pytest -import ssms from hssm import HSSM from hssm.utils import download_hf @@ -13,45 +12,30 @@ pytensor.config.floatX = "float32" +param_v = { + "name": "v", + "prior": { + "Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0}, + "x": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, + "y": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, + }, + "formula": "v ~ 1 + x + y", +} + +param_a = param_v | dict(name="a", formula="a ~ 1 + x + y") + @pytest.mark.parametrize( "include, should_raise_exception", [ ( - [ - { - "name": "v", - "prior": { - "Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0}, - "x": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, - "y": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, - }, - "formula": "v ~ 1 + x + y", - "link": "identity", - } - ], + [param_v], False, ), ( [ - { - "name": "v", - "prior": { - "Intercept": {"name": "Uniform", "lower": -2.0, "upper": 3.0}, - "x": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, - "y": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, - }, - "formula": "v ~ 1 + x + y", - }, - { - "name": "a", - "prior": { - "Intercept": {"name": "Uniform", "lower": -2.0, "upper": 3.0}, - "x": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, - "y": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, - }, - "formula": "a ~ 1 + x + y", - }, + param_v, + param_a, ], False, ), @@ -230,6 +214,7 @@ def test_hierarchical(data_ddm): for name, param in model.params.items() if name != "p_outlier" ) + assert model.override_strategy == "prior" model = HSSM( data=data_ddm, @@ -262,3 +247,31 @@ def test_hierarchical(data_ddm): for name, param in model.params.items() if name not in ["v", "a", "p_outlier"] ) + + +def test_override_default_link(caplog, data_ddm_reg): + param_v = { + "name": "v", + "prior": { + "Intercept": {"name": "Uniform", "lower": -3.0, "upper": 3.0}, + "x": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, + "y": {"name": "Uniform", "lower": -0.50, "upper": 0.50}, + }, + "formula": "v ~ 1 + x + y", + } + param_v = param_v | dict(bounds=(-np.inf, np.inf)) + param_a = param_v | dict(name="a", formula="a ~ 1 + x + y", bounds=(0, np.inf)) + param_z = param_v | dict(name="z", formula="z ~ 1 + x + y", bounds=(0, 1)) + param_t = param_v | dict(name="t", formula="t ~ 1 + x + y", bounds=(0.1, np.inf)) + + model = HSSM( + data=data_ddm_reg, include=[param_v, param_a, param_z, param_t], default="link" + ) + + assert model.params["v"].link == "identity" + assert model.params["a"].link == "log" + assert model.params["z"].link.name == "gen_logit" + assert model.params["t"].link == "identity" + + assert "t" in caplog.records[0].message + assert "strange" in caplog.records[0].message diff --git a/tests/test_param.py b/tests/test_param.py index cbedf11d..6f4664a9 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -23,8 +23,16 @@ def test_param_creation_non_regression(): } param_v = Param(**v) + param_v.override_default_link() param_v.convert() + with pytest.raises( + ValueError, + match="Cannot override the default link function for parameter v." + + " The object has already been processed.", + ): + param_v.override_default_link() + assert param_v.name == "v" assert isinstance(param_v.prior, bmb.Prior) assert param_v.prior.args["mu"] == 0.0 @@ -43,9 +51,11 @@ def test_param_creation_non_regression(): } param_a = Param(**a) + param_a.override_default_link() param_a.convert() assert param_a.is_truncated + assert param_a.link is None assert not param_a.is_fixed assert param_a.prior.is_truncated param_a_output = param_a.__str__().split("\r\n")[1].split("Prior: ")[1] @@ -100,6 +110,18 @@ def test_param_creation_non_regression(): assert pt.is_fixed assert not ptheta.is_truncated + model_1 = hssm.HSSM( + model="angle", + data=hssm.simulate_data( + model="angle", theta=[0.5, 1.5, 0.5, 0.5, 0.3], size=10 + ), + include=[v, a, z, t], + default="link", + ) + + for param in model_1.params.values(): + assert param.link is None + def test_param_creation_regression(): v_reg = { @@ -111,16 +133,27 @@ def test_param_creation_regression(): "y": bmb.Prior("Uniform", lower=0.0, upper=1.0), "z": 0.1, }, - "link": "identity", } v_reg_param = Param(**v_reg) + with pytest.raises( + ValueError, + match="Cannot override the default link function. Bounds are not" + + " specified for parameter v.", + ): + v_reg_param.override_default_link() v_reg_param.convert() assert v_reg_param.is_regression assert not v_reg_param.is_fixed assert not v_reg_param.is_truncated assert v_reg_param.formula == v_reg["formula"] + with pytest.raises( + ValueError, + match="Cannot override the default link function for parameter v." + + " The object has already been processed.", + ): + v_reg_param.override_default_link() # Generate some fake simulation data intercept = 0.3 @@ -154,6 +187,15 @@ def test_param_creation_regression(): assert not v_reg_param.is_truncated assert v_reg_param.formula == v_reg["formula"] + model_reg_v = hssm.HSSM( + data=dataset_reg_v, + model="ddm", + include=[v_reg], + default="link", + ) + + assert model_reg_v.params["v"].link == "identity" + def test__make_default_prior(): prior1 = _make_default_prior((-10.0, 10.0)) @@ -336,3 +378,42 @@ def test__make_bounded_prior(caplog): _make_bounded_prior(name, prior7, bounds) assert caplog.records[0].msg == caplog.records[1].msg + + +some_forumla = "1 + x + y" + + +@pytest.mark.parametrize( + ("formula", "link", "bounds", "result"), + [ + (None, None, (0, 1), None), + (some_forumla, None, None, "Error"), + (some_forumla, None, (0, 1), "gen_logit"), + (some_forumla, None, (0, np.inf), "log"), + (some_forumla, None, (-np.inf, 1), "warning"), + (some_forumla, None, (-np.inf, np.inf), "identity"), + (some_forumla, "logit", None, "logit"), + (some_forumla, "gen_logit", None, "Error"), + ], +) +def test_param_override_default_link(caplog, formula, link, bounds, result): + if result == "Error": + with pytest.raises(ValueError): + param = Param("a", formula=formula, link=link, bounds=bounds) + param.override_default_link() + else: + param = Param("a", formula=formula, link=link, bounds=bounds) + param.override_default_link() + param.convert() + if result == "warning": + assert "strange" in caplog.records[0].msg + else: + if result == "gen_logit": + assert isinstance(param.link, hssm.Link) + elif result is None: + assert param.link is None + else: + assert param.link == result + + with pytest.raises(ValueError): + param.override_default_link() From bc3a3fc77b8c16539b64e3f59191707a0e3f034c Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 14 Nov 2023 14:53:52 -0500 Subject: [PATCH 06/10] fix copilot error that causes tests to fail --- src/hssm/hssm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 64cb3111..adcdc475 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -937,10 +937,11 @@ def _find_parent(self) -> tuple[str, Param]: def _override_defaults(self, default: Literal["prior", "link"]): """Override the default priors or links.""" for param in self.list_params: + param_obj = self.params[param] if default == "prior": - param.override_default_priors(self.data) + param_obj.override_default_priors(self.data) elif default == "link": - param.override_default_link() + param_obj.override_default_link() def _process_all(self): """Process all params.""" From c7c6eebe0acfe0f3a5ddf48af54a4abfb7c456f7 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Tue, 14 Nov 2023 15:02:10 -0500 Subject: [PATCH 07/10] remove xhistogram --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f6e79613..2b2e670e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ bambi = "^0.12.0" numpyro = "^0.12.1" hddm-wfpt = "^0.1.1" seaborn = "^0.13.0" -xhistogram = "^0.3.2" [tool.poetry.group.dev.dependencies] pytest = "^7.3.1" From 84ee5089cc88d71918b6cd11efdb1cee5e1349b6 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 15 Nov 2023 15:02:54 -0500 Subject: [PATCH 08/10] rename default to link_settings and prior_settings --- src/hssm/hssm.py | 64 +++++++++++++++++++++++++++++---------------- tests/test_hssm.py | 6 +++-- tests/test_param.py | 4 +-- 3 files changed, 48 insertions(+), 26 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index adcdc475..d17b2782 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -141,13 +141,29 @@ class HSSM: hierarchical : optional If True, and if there is a `participant_id` field in `data`, will by default turn any unspecified parameter theta into a regression with - "theta ~ 1 + (1|participant_id)" and default priors set by `bambi`. - default : optional - If "link", will override the default links for all parameters with a regression - with a generalized log link function defined by the bounds of the parameter. - If "prior", will override the bambi default priors for all parameters with a - regression with a set of priors defined by HSSM. If None, no default will be - overridden. Defaults to None if `hierarchical` is False, otherwise "prior". + "theta ~ 1 + (1|participant_id)" and default priors set by `bambi`. Also changes + default values of `link_settings` and `prior_settings`. Defaults to False. + link_settings : optional + An optional string literal that indicates the link functions to use for each + parameter. Helpful for hierarchical models where sampling might get stuck/ + very slow. Can be one of the following: + + - `"log_logit"`: applies log link functions to positive parameters and + generalized logit link functions to parameters that have explicit bounds. + - `None`: unless otherwise specified, the `"identity"` link functions will be + used. + The default value is `None`. + prior_settings : optional + An optional string literal that indicates the prior distributions to use for + each parameter. Helpful for hierarchical models where sampling might get stuck/ + very slow. Can be one of the following: + + - `"safe"`: HSSM will scan all parameters in the model and apply safe priors to + all parameters that do not have explicit bounds. + - `None`: HSSM will use bambi to provide default priors for all parameters. Not + recommended when you are using hierarchical models. + The default value is `None` when `hierarchical` is `False` and `"safe"` when + `hierarchical` is `True`. **kwargs Additional arguments passed to the `bmb.Model` object. @@ -196,7 +212,8 @@ def __init__( p_outlier: float | dict | bmb.Prior | None = 0.05, lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=10.0), hierarchical: bool = False, - default: Literal["prior", "link"] | None = None, + link_settings: Literal["log_logit"] | None = None, + prior_settings: Literal["safe"] | None = None, **kwargs, ): self.data = data @@ -209,15 +226,20 @@ def __init__( + "`participant_id` field in the DataFrame that you have passed." ) - if self.hierarchical: - if default is None: - self.override_strategy: Literal["prior", "link"] | None = "prior" - else: - self.override_strategy = default - else: - self.override_strategy = default + if self.hierarchical and prior_settings is None: + prior_settings = "safe" + + self.link_settings = link_settings + self.prior_settings = prior_settings - self.n_responses = len(self.data["response"].unique()) + responses = self.data["response"].unique().astype(int) + self.n_responses = len(responses) + if self.n_responses == 2: + if -1 not in responses or 1 not in responses: + raise ValueError( + "The response column must contain only -1 and 1 when there are " + + "two responses." + ) # Construct a model_config from defaults self.model_config = Config.from_defaults(model, loglik_kind) @@ -261,9 +283,7 @@ def __init__( self._parent, self._parent_param = self._find_parent() assert self._parent_param is not None - if self.override_strategy is not None: - self._override_defaults(self.override_strategy) - + self._override_defaults() self._process_all() # Get the bambi formula, priors, and link @@ -934,13 +954,13 @@ def _find_parent(self) -> tuple[str, Param]: param.set_parent() return param_str, param - def _override_defaults(self, default: Literal["prior", "link"]): + def _override_defaults(self): """Override the default priors or links.""" for param in self.list_params: param_obj = self.params[param] - if default == "prior": + if self.prior_settings == "safe": param_obj.override_default_priors(self.data) - elif default == "link": + elif self.link_settings == "log_logit": param_obj.override_default_link() def _process_all(self): diff --git a/tests/test_hssm.py b/tests/test_hssm.py index 076aaaa7..3d890fe9 100644 --- a/tests/test_hssm.py +++ b/tests/test_hssm.py @@ -214,7 +214,7 @@ def test_hierarchical(data_ddm): for name, param in model.params.items() if name != "p_outlier" ) - assert model.override_strategy == "prior" + assert model.prior_settings == "safe" model = HSSM( data=data_ddm, @@ -265,7 +265,9 @@ def test_override_default_link(caplog, data_ddm_reg): param_t = param_v | dict(name="t", formula="t ~ 1 + x + y", bounds=(0.1, np.inf)) model = HSSM( - data=data_ddm_reg, include=[param_v, param_a, param_z, param_t], default="link" + data=data_ddm_reg, + include=[param_v, param_a, param_z, param_t], + link_settings="log_logit", ) assert model.params["v"].link == "identity" diff --git a/tests/test_param.py b/tests/test_param.py index 6f4664a9..8b43588c 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -116,7 +116,7 @@ def test_param_creation_non_regression(): model="angle", theta=[0.5, 1.5, 0.5, 0.5, 0.3], size=10 ), include=[v, a, z, t], - default="link", + link_settings="log_logit", ) for param in model_1.params.values(): @@ -191,7 +191,7 @@ def test_param_creation_regression(): data=dataset_reg_v, model="ddm", include=[v_reg], - default="link", + link_settings="log_logit", ) assert model_reg_v.params["v"].link == "identity" From 9840c250bef94f2cff730ca49e7ee6c00dce3e40 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 16 Nov 2023 11:05:23 -0500 Subject: [PATCH 09/10] limit pytensor version --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2b2e670e..631d5826 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"] [tool.poetry.dependencies] python = ">=3.10,<3.12" -pymc = ">=5.9.0" +pymc = "^5.9.0" scipy = "1.10.1" arviz = "^0.14.0" numpy = ">=1.23.4,<1.26" @@ -30,6 +30,7 @@ bambi = "^0.12.0" numpyro = "^0.12.1" hddm-wfpt = "^0.1.1" seaborn = "^0.13.0" +pytensor = "<=2.17.3" [tool.poetry.group.dev.dependencies] pytest = "^7.3.1" From dc926d9576421f935ebed33060738c5f08143d4e Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 16 Nov 2023 11:05:54 -0500 Subject: [PATCH 10/10] Fix docstring for hssm.Link --- src/hssm/link.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/hssm/link.py b/src/hssm/link.py index 9eea10d5..1ad00c10 100644 --- a/src/hssm/link.py +++ b/src/hssm/link.py @@ -20,22 +20,24 @@ class Link(bmb.Link): Parameters ---------- - name : str + name The name of the link function. If it is a known name, it's not necessary to pass any other arguments because functions are already defined internally. If not known, all of `link``, ``linkinv`` and ``linkinv_backend`` must be specified. - link : function + link : optional A function that maps the response to the linear predictor. Known as the :math:`g` function in GLM jargon. Does not need to be specified when ``name`` is a known name. - linkinv : function + linkinv : optional A function that maps the linear predictor to the response. Known as the :math:`g^{-1}` function in GLM jargon. Does not need to be specified when ``name`` is a known name. - linkinv_backend : function + linkinv_backend : optional Same than ``linkinv`` but must be something that works with PyMC backend (i.e. it must work with PyTensor tensors). Does not need to be specified when ``name`` is a known name. + bounds : optional + Bounds of the response scale. Only needed when ``name`` is ``gen_logit``. """ def __init__(