diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml
index dc3e33ab..71b70fe5 100644
--- a/.github/workflows/run_tests.yml
+++ b/.github/workflows/run_tests.yml
@@ -2,7 +2,6 @@ name: Run tests
on:
pull_request:
- push:
jobs:
run_tests:
diff --git a/docs/changelog.md b/docs/changelog.md
index d641dcc2..18621e2c 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -1,5 +1,31 @@
# Changelog
+## 0.2.x
+
+### 0.2.0
+
+This is a major version update! This release includes many changes:
+
+#### Breaking changes
+
+When the `hierarchical` argument of `hssm.HSSM` is set to `True`, HSSM will look for a
+`participant_id` field in the `data` provided. If it does not exist, an error will
+be thrown.
+
+#### New features
+
+- Added `link_settings` and `prior_settings` arguments to `hssm.HSSM`, which allow HSSM
+ to use intelligent default priors and link functions for complex hierarchical models
+ (see the sketch at the end of this list).
+
+- Added an `hssm.plotting` submodule with `plot_posterior_predictive()` and
+ `plot_quantile_probability()` for creating posterior predictive plots and quantile
+ probability plots.
+
+- Added an `extra_fields` argument to `hssm.HSSM` to pass additional data to the
+ likelihood function computation.
+
+- Limited `PyMC`, `pytensor`, `numpy`, and `jax` dependency versions for compatibility.
+
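+Here is a minimal sketch of the new options in action (the settings shown are
+illustrative; the commented `extra_fields` line only indicates the argument shape):
+
+```python
+import hssm
+
+data = hssm.load_data("cavanagh_theta")  # any dataset with a participant_id column
+
+model = hssm.HSSM(
+    data=data,
+    hierarchical=True,
+    prior_settings="safe",      # intelligent default priors
+    link_settings="log_logit",  # log/logit links chosen from parameter bounds
+    # extra_fields=["theta"],   # pass additional data columns to the likelihood
+)
+```
+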
## 0.1.x
### 0.1.5
@@ -8,8 +34,8 @@ We fixed the errors in v0.1.4. Sorry for the inconvenience! If you have accidentally
downloaded v0.1.4, please make sure that you update hssm to the current version.
- We made Cython dependencies of this package available via pypi. We have also built
-wheels for (almost) all platforms so there is no need to build these Cython
-dependencies.
+ wheels for (almost) all platforms so there is no need to build these Cython
+ dependencies.
### 0.1.4
diff --git a/docs/getting_started/hierarchical_modeling.ipynb b/docs/getting_started/hierarchical_modeling.ipynb
new file mode 100644
index 00000000..04c81337
--- /dev/null
+++ b/docs/getting_started/hierarchical_modeling.ipynb
@@ -0,0 +1,640 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0aec427c-56d5-48bb-83c6-6bdc0c445ba2",
+ "metadata": {},
+ "source": [
+ "# Hierarchical Modeling\n",
+ "\n",
+ "This tutorial demonstrates how to take advantage of HSSM's hierarchical modeling capabilities. We will cover the following:\n",
+ "\n",
+ "- How to define a mixed-effect regression\n",
+ "- How to define a hierarchial HSSM model\n",
+ "- How to apply prior and link function settings to ensure successful sampling"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "36aafcf1-e703-40e3-b11d-4ab5ad74655f",
+ "metadata": {},
+ "source": [
+ "## Colab Instructions\n",
+ "\n",
+ "If you would like to run this tutorial on Google colab, please click this [link](https://github.com/lnccbrown/HSSM/blob/main/docs/tutorial_notebooks/no_execute/getting_started.ipynb). \n",
+ "\n",
+ "Once you are *in the colab*, follow the *installation instructions below* and then **restart your runtime**. \n",
+ "\n",
+ "Just **uncomment the code in the next code cell** and run it!\n",
+ "\n",
+ "**NOTE**:\n",
+ "\n",
+ "You may want to *switch your runtime* to have a GPU or TPU. To do so, go to *Runtime* > *Change runtime type* and select the desired hardware accelerator.\n",
+ "\n",
+ "Note that if you switch your runtime you have to follow the installation instructions again."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "61937b47-810d-41b6-a6b8-e461c5e5ae71",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !pip install hssm"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "650b011a-62b3-4243-9ed7-2087b2f232cd",
+ "metadata": {},
+ "source": [
+ "## Import Modules"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "11fc424e-2aff-49b0-b1a9-d54c7d7f67be",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import matplotlib as plt\n",
+ "\n",
+ "import hssm\n",
+ "\n",
+ "%matplotlib inline\n",
+ "%config InlineBackend.figure_format='retina'"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d671ab29-710a-47ab-af79-32fac7891318",
+ "metadata": {},
+ "source": [
+ "### Setting the global float type\n",
+ "\n",
+ "**Note**: Using the analytical DDM (Drift Diffusion Model) likelihood in PyMC without setting the float type in `PyTensor` may result in warning messages during sampling, which is a known bug in PyMC v5.6.0 and earlier versions. To avoid these warnings, we provide a convenience function:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "01314abb-6ee5-4fc5-975e-002768fde007",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Setting PyTensor floatX type to float32.\n",
+ "Setting \"jax_enable_x64\" to False. If this is not intended, please set `jax` to False.\n"
+ ]
+ }
+ ],
+ "source": [
+ "hssm.set_floatX(\"float32\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c2fcd33e-46e0-41fe-8fd6-26eedd39aec0",
+ "metadata": {},
+ "source": [
+ "## 1. Defining Regressions\n",
+ "\n",
+ "Under the hood, HSSM uses [`bambi`](https://bambinos.github.io/bambi/) for model creation. `bambi` takes inspiration from the [`lme4` package in R](https://www.rdocumentation.org/packages/lme4/versions/1.1-35.1/topics/lmer) and supports the definition of generalized linear mixed-effect models through\n",
+ "R-like formulas and concepts such as link functions. This makes it possible to create arbitrary mixed-effect regressions in HSSM, which is one advantage of HSSM over HDDM. Now let's walk through the ways to define a parameter with a regression in HSSM.\n",
+ "\n",
+ "### Specifying fixed- and random-effect terms\n",
+ "\n",
+ "Suppose that we want to define a parameter `v` that has a regression defined. There are two ways to define such a parameter - either through a dictionary\n",
+ "or through a `hssm.Param` object:\n",
+ "\n",
+ "```\n",
+ "# The following code are equivalent,\n",
+ "# including the definition of the formula.\n",
+ "\n",
+ "# The dictionary way:\n",
+ "param_v = {\n",
+ " \"name\": \"v\",\n",
+ " \"formula\": \"v ~ (1|participant_id) + x + y + x:y\",\n",
+ " \"link\": \"identity\",\n",
+ " \"prior\": {\n",
+ " \"Intercept\": {\"name\": \"Normal\", \"mu\": 0.0, \"sigma\": 0.25},\n",
+ " \"1|participant_id\": {\n",
+ " \"name\": \"Normal\",\n",
+ " \"mu\": 0.0,\n",
+ " \"sigma\": {\"name\": \"HalfNormal\", \"sigma\": 0.2}, # this is a hyperprior\n",
+ " },\n",
+ " \"x\": {\"name\": \"Normal\", \"mu\": 0.0, \"sigma\": 0.25},\n",
+ " },\n",
+ "}\n",
+ "\n",
+ "# The object-oriented way\n",
+ "param_v = hssm.Param(\n",
+ " \"v\",\n",
+ " formula=\"v ~ 1 + (1|participant_id) + x*y\",\n",
+ " link=\"identity\",\n",
+ " prior={\n",
+ " \"Intercept\": hssm.Prior(\"Normal\", mu=0.0, sigma=0.25),\n",
+ " \"1|participant_id\": hssm.Prior(\n",
+ " \"Normal\",\n",
+ " mu=0.0,\n",
+ " sigma=hssm.Prior(\"HalfNormal\", sigma=0.2), # this is a hyperprior\n",
+ " ),\n",
+ " \"x\": hssm.Prior(\"Normal\", mu=0.0, sigma=0.25),\n",
+ " },\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "The formula `\"v ~ (1|participant_id) + x + y + x:y\"` defines a random-intercept model. Like R, unless otherwise specified, a fixed-effect intercept term is added to the formula by default. You can make this explicit by adding a `1` to the formula. Or, if your regression does not have an intercept. you can explicitly remove the intercept term by using a `0` in the place of `1`: `\"v ~ 0 + (1|participant_id) + x * y\"`.\n",
+ "\n",
+ "Other fixed effect covariates are `x`, `y`, and the interaction term `x:y`. When all three terms are present, you can use the shortcut `x * y` in place of the three terms.\n",
+ "\n",
+ "The only random effect term in this model is `1|participant_id`. It is a random-intercept term with `participant_id` indicating the grouping variable. You can add another random-effect term in a similar way: `\"v ~ (1|participant_id) + (x|participant_id) + x + y + x:y\"`, or more briefly, `\"v ~ (1 + x|participant_id) + x + y + x:y\"`.\n",
+ "\n",
+ "### Specifying priors for fixed- and random-effect terms:\n",
+ "\n",
+ "As demonstrated in the above code, you can specify priors of each term through a dictionary, with the key being the name of each term, and the corresponding value being the prior specification, etiher through a dictionary, or a `hssm.Prior` object. There are a few things to note:\n",
+ "\n",
+ "* The prior of fixed-effect intercept is specified with `\"Intercept\"`, capitalized.\n",
+ "* For random effects, you can specify hyperpriors for the parameters of of their priors.\n",
+ "\n",
+ "### Specifying the link functions:\n",
+ "\n",
+ "Link functions is another concept in frequentist generalized linear models, which defines a transformation between the linear combination of the covariates and the response variable. This is helpful especially when the response variable is not normally distributed, e.g. in a logistic regression. In HSSM, the link function is identity by default. However, since some parameters of SSMs are defined on `(0, inf)` or `(0, 1)`, link function can be helpful in ensuring the result of the regression is defined for these parameters. We will come back to this later."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3cb603f1-dc4f-44d3-b12f-e4a460d2d9f9",
+ "metadata": {},
+ "source": [
+ "## 2. Defining a hierarchical HSSM model\n",
+ "\n",
+ "In fact, HSSM does not differentiate between a hierarchical or non-hierarchical model. A hierarchical model in HSSM is simply a model with one or more parameters defined as regressions. However, HSSM does provide some useful functionalities in creating hierarchical models.\n",
+ "\n",
+ "### Clarifying the use of `hierarchical` argument during model creation\n",
+ "\n",
+ "First, HSSM has a `hierarchical` argument which is a `bool`. It serves as a convenient switch to add a random-intercept regression to any parameter that is not explicitly defined by the user, using `participant_id` as a grouping variable. If there is not a `participant_id` column in the data, setting `hierarchical` to `True` will raise an error. Setting `hierarchical` to True will also change some default behavior in HSSM. Here's an example:\n",
+ "\n",
+ "
\n",
+ "
Note
\n",
+ "
\n",
+ " In HSSM, the default grouping variable is now `participant_id`, which is different from `subj_idx` in HDDM.\n",
+ "
\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "cab2a960-e9e6-4043-996c-57742832de0d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load a package-supplied dataset\n",
+ "cav_data = hssm.load_data(\"cavanagh_theta\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "28c75438-c6c4-4589-8244-dac7ffc18eb5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Hierarchical Sequential Sampling Model\n",
+ "Model: ddm\n",
+ "\n",
+ "Response variable: rt,response\n",
+ "Likelihood: analytical\n",
+ "Observations: 3988\n",
+ "\n",
+ "Parameters:\n",
+ "\n",
+ "v:\n",
+ " Prior: Normal(mu: 0.0, sigma: 2.0)\n",
+ " Explicit bounds: (-inf, inf)\n",
+ "a:\n",
+ " Prior: HalfNormal(sigma: 2.0)\n",
+ " Explicit bounds: (0.0, inf)\n",
+ "z:\n",
+ " Prior: Uniform(lower: 0.0, upper: 1.0)\n",
+ " Explicit bounds: (0.0, 1.0)\n",
+ "t:\n",
+ " Prior: HalfNormal(sigma: 2.0, initval: 0.10000000149011612)\n",
+ " Explicit bounds: (0.0, inf)\n",
+ "\n",
+ "Lapse probability: 0.05\n",
+ "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Define a basic non-hierarchical model\n",
+ "model_non_hierarchical = hssm.HSSM(data=cav_data)\n",
+ "model_non_hierarchical"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "7cde1989-8d6b-437a-a972-940f0fd84904",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Hierarchical Sequential Sampling Model\n",
+ "Model: ddm\n",
+ "\n",
+ "Response variable: rt,response\n",
+ "Likelihood: analytical\n",
+ "Observations: 3988\n",
+ "\n",
+ "Parameters:\n",
+ "\n",
+ "v:\n",
+ " Formula: v ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " v_Intercept ~ Normal(mu: 2.0, sigma: 3.0)\n",
+ " v_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (-inf, inf)\n",
+ "a:\n",
+ " Formula: a ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " a_Intercept ~ Gamma(mu: 1.5, sigma: 0.75)\n",
+ " a_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (0.0, inf)\n",
+ "z:\n",
+ " Formula: z ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " z_Intercept ~ Gamma(mu: 10.0, sigma: 10.0)\n",
+ " z_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (0.0, 1.0)\n",
+ "t:\n",
+ " Formula: t ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " t_Intercept ~ Gamma(mu: 0.4000000059604645, sigma: 0.20000000298023224)\n",
+ " t_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (0.0, inf)\n",
+ "\n",
+ "Lapse probability: 0.05\n",
+ "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Now let's set `hierarchical` to True\n",
+ "model_hierarchical = hssm.HSSM(data=cav_data, hierarchical=True, prior_settings=\"safe\")\n",
+ "model_hierarchical"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e4baeb2f-c49e-4d63-a962-4c6f47f3f848",
+ "metadata": {},
+ "source": [
+ "## 3. Intelligent defaults for complex hierarchical models\n",
+ "\n",
+ "`bambi` is not designed with HSSM in mind. Therefore, in cases where priors for certain parameters are not defined, the default priors supplied by `bambi` sometimes are not optimal. The same goes for link functions. `\"identity\"` link functions tend not to work well for certain parameters that are not defined on `(inf, inf)`. Therefore, we provide some default settings that the users can experiment to ensure that sampling is successful."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "11b405fc-d508-468c-8fd8-12283ed03945",
+ "metadata": {},
+ "source": [
+ "### `prior_settings`\n",
+ "\n",
+ "Currently we provide a `\"safe\"` strategy that uses HSSM default priors. This is turned on by default when `hierarchical` is set to `True`. One can compare the two models below, with `safe` strategy turned on and off:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "145545e3-4712-4929-84c5-53ab3f8ae051",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Hierarchical Sequential Sampling Model\n",
+ "Model: ddm\n",
+ "\n",
+ "Response variable: rt,response\n",
+ "Likelihood: approx_differentiable\n",
+ "Observations: 3988\n",
+ "\n",
+ "Parameters:\n",
+ "\n",
+ "v:\n",
+ " Formula: v ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " v_Intercept ~ Normal(mu: 0.0, sigma: 0.25)\n",
+ " v_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (-3.0, 3.0)\n",
+ "a:\n",
+ " Formula: a ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " a_Intercept ~ Normal(mu: 1.399999976158142, sigma: 0.25)\n",
+ " a_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (0.3, 2.5)\n",
+ "z:\n",
+ " Formula: z ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " z_Intercept ~ Normal(mu: 0.5, sigma: 0.25)\n",
+ " z_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (0.0, 1.0)\n",
+ "t:\n",
+ " Formula: t ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " t_Intercept ~ Normal(mu: 1.0, sigma: 0.25)\n",
+ " t_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (0.0, 2.0)\n",
+ "\n",
+ "Lapse probability: 0.05\n",
+ "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model_safe = hssm.HSSM(\n",
+ " data=cav_data,\n",
+ " hierarchical=True,\n",
+ " prior_settings=\"safe\",\n",
+ " loglik_kind=\"approx_differentiable\",\n",
+ ")\n",
+ "model_safe"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "8ea21e3d-e5e4-4bf3-a5a1-7aea1e9d9d00",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Hierarchical Sequential Sampling Model\n",
+ "Model: ddm\n",
+ "\n",
+ "Response variable: rt,response\n",
+ "Likelihood: approx_differentiable\n",
+ "Observations: 3988\n",
+ "\n",
+ "Parameters:\n",
+ "\n",
+ "v:\n",
+ " Formula: v ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " v_Intercept ~ Normal(mu: 0.0, sigma: 0.25)\n",
+ " v_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (-3.0, 3.0)\n",
+ "a:\n",
+ " Formula: a ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " a_Intercept ~ Normal(mu: 1.399999976158142, sigma: 0.25)\n",
+ " a_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (0.3, 2.5)\n",
+ "z:\n",
+ " Formula: z ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " z_Intercept ~ Normal(mu: 0.5, sigma: 0.25)\n",
+ " z_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (0.0, 1.0)\n",
+ "t:\n",
+ " Formula: t ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " t_Intercept ~ Normal(mu: 1.0, sigma: 0.25)\n",
+ " t_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (0.0, 2.0)\n",
+ "\n",
+ "Lapse probability: 0.05\n",
+ "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model_safe_off = hssm.HSSM(\n",
+ " data=cav_data,\n",
+ " hierarchical=True,\n",
+ " prior_settings=None,\n",
+ " loglik_kind=\"approx_differentiable\",\n",
+ ")\n",
+ "model_safe_off"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c99b6d77-b4a2-4813-8927-433b65d646a3",
+ "metadata": {},
+ "source": [
+ "### `link_settings`\n",
+ "\n",
+ "We also provide a `link_settings` switch, which changes default link functions for parameters according to their explicit bounds. See the model below with `link_settings` set to `\"log_logit\"`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "67d0f895-8188-4e44-8946-ef80af2c4b67",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Hierarchical Sequential Sampling Model\n",
+ "Model: ddm\n",
+ "\n",
+ "Response variable: rt,response\n",
+ "Likelihood: analytical\n",
+ "Observations: 3988\n",
+ "\n",
+ "Parameters:\n",
+ "\n",
+ "v:\n",
+ " Formula: v ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " v_Intercept ~ Normal(mu: 2.0, sigma: 3.0)\n",
+ " v_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (-inf, inf)\n",
+ "a:\n",
+ " Formula: a ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " a_Intercept ~ Gamma(mu: 1.5, sigma: 0.75)\n",
+ " a_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: log\n",
+ " Explicit bounds: (0.0, inf)\n",
+ "z:\n",
+ " Formula: z ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " z_Intercept ~ Gamma(mu: 10.0, sigma: 10.0)\n",
+ " z_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: Generalized logit link function with bounds (0.0, 1.0)\n",
+ " Explicit bounds: (0.0, 1.0)\n",
+ "t:\n",
+ " Formula: t ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " t_Intercept ~ Gamma(mu: 0.4000000059604645, sigma: 0.20000000298023224)\n",
+ " t_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: log\n",
+ " Explicit bounds: (0.0, inf)\n",
+ "\n",
+ "Lapse probability: 0.05\n",
+ "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model_log_logit = hssm.HSSM(\n",
+ " data=cav_data, hierarchical=True, prior_settings=None, link_settings=\"log_logit\"\n",
+ ")\n",
+ "model_log_logit"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bc82d284-7164-4072-9a25-67fa8cc77b17",
+ "metadata": {},
+ "source": [
+ "### Mixing strategies:\n",
+ "\n",
+ "It is possible to turn on both `prior_settings` and `link_settings`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "a6099bb5-2d55-4ef8-b08b-cee2edfa4bc7",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Hierarchical Sequential Sampling Model\n",
+ "Model: ddm\n",
+ "\n",
+ "Response variable: rt,response\n",
+ "Likelihood: analytical\n",
+ "Observations: 3988\n",
+ "\n",
+ "Parameters:\n",
+ "\n",
+ "v:\n",
+ " Formula: v ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " v_Intercept ~ Normal(mu: 2.0, sigma: 3.0)\n",
+ " v_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: identity\n",
+ " Explicit bounds: (-inf, inf)\n",
+ "a:\n",
+ " Formula: a ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " a_Intercept ~ Gamma(mu: 1.5, sigma: 0.75)\n",
+ " a_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: log\n",
+ " Explicit bounds: (0.0, inf)\n",
+ "z:\n",
+ " Formula: z ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " z_Intercept ~ Gamma(mu: 10.0, sigma: 10.0)\n",
+ " z_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: Generalized logit link function with bounds (0.0, 1.0)\n",
+ " Explicit bounds: (0.0, 1.0)\n",
+ "t:\n",
+ " Formula: t ~ 1 + (1|participant_id)\n",
+ " Priors:\n",
+ " t_Intercept ~ Gamma(mu: 0.4000000059604645, sigma: 0.20000000298023224)\n",
+ " t_1|participant_id ~ Normal(mu: 0.0, sigma: Weibull(alpha: 1.5, beta: 0.30000001192092896))\n",
+ " Link: log\n",
+ " Explicit bounds: (0.0, inf)\n",
+ "\n",
+ "Lapse probability: 0.05\n",
+ "Lapse distribution: Uniform(lower: 0.0, upper: 10.0)"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model_safe_loglogit = hssm.HSSM(\n",
+ " data=cav_data, hierarchical=True, prior_settings=\"safe\", link_settings=\"log_logit\"\n",
+ ")\n",
+ "model_safe_loglogit"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "762e94e1-66be-47c6-9024-59047790953a",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/getting_started/installation.md b/docs/getting_started/installation.md
index 2b495e41..7c5264b4 100644
--- a/docs/getting_started/installation.md
+++ b/docs/getting_started/installation.md
@@ -37,6 +37,31 @@ a dependency by default. You need to have `blackjax` installed if you want to us
pip install blackjax
```
+### Sampling with JAX support for GPU
+
+The `nuts_numpyro` sampler uses JAX as its backend and thus supports sampling on NVIDIA
+GPUs. To take advantage of this, install JAX with CUDA support before installing HSSM.
+Here's one example:
+
+```bash
+python -m venv .venv # Create a virtual environment
+source .venv/bin/activate # Activate the virtual environment
+
+pip install --upgrade pip
+
+# We need to limit the version of JAX for now due to some breaking
+# changes introduced in JAX 0.4.16.
+pip install --upgrade "jax[cuda11_pip]<0.4.16" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+pip install hssm
+```
+
+The example above shows how to install JAX with CUDA 11 support. Please refer to the
+[JAX Installation](https://jax.readthedocs.io/en/latest/installation.html) page for more
+details on installing JAX on different platforms with GPU or TPU support.
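+
+After installation, you can verify that JAX detects the GPU with a quick check
+(`jax.devices()` lists the devices JAX will use):
+
+```python
+import jax
+
+# Should list GPU devices rather than only CPU devices
+print(jax.devices())
+```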
+
+Note that on Google Colab, JAX support for GPU is enabled by default if the Colab backend
+has a GPU enabled. You only need to install HSSM.
+
### Visualizing the model
Model graphs are created with `model.graph()` through `graphviz`. In order to use it,
diff --git a/docs/overrides/main.html b/docs/overrides/main.html
index 38eda562..81fe76d3 100644
--- a/docs/overrides/main.html
+++ b/docs/overrides/main.html
@@ -1,22 +1,18 @@
-{% extends "base.html" %}
-
-{% block announce %}
+{% extends "base.html" %} {% block announce %}
-
- {% include ".icons/fontawesome/solid/angles-down.svg" %}
-
- Navigate the site here!
-
-
- v0.1.5 is released!
+
+ {% include ".icons/fontawesome/solid/angles-down.svg" %}
+
+ Navigate the site here!
+ v0.2.0b1 is released!
-
- {% include ".icons/material/head-question.svg" %}
-
- Questions?
-
- Open a discussion here!
-
+
+ {% include ".icons/material/head-question.svg" %}
+
+ Questions?
+
+ Open a discussion here!
+
{% endblock %}
diff --git a/mkdocs.yml b/mkdocs.yml
index c8414a10..11e5f9f7 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -12,6 +12,7 @@ nav:
- Getting Started:
- Installation: getting_started/installation.md
- Getting started: getting_started/getting_started.ipynb
+ - Hierarchical modeling: getting_started/hierarchical_modeling.ipynb
- API References:
- hssm: api/hssm.md
- hssm.plotting: api/plotting.md
@@ -33,6 +34,7 @@ plugins:
execute: true
execute_ignore:
- getting_started/getting_started.ipynb
+ - getting_started/hierarchical_modeling.ipynb
- tutorials/main_tutorial.ipynb
- tutorials/likelihoods.ipynb
- .ipynb_checkpoints/*.ipynb
@@ -126,5 +128,5 @@ markdown_extensions:
- pymdownx.superfences
- attr_list
- pymdownx.emoji:
- emoji_index: !!python/name:materialx.emoji.twemoji
- emoji_generator: !!python/name:materialx.emoji.to_svg
+ emoji_index: !!python/name:material.extensions.emoji.twemoji
+ emoji_generator: !!python/name:material.extensions.emoji.to_svg
diff --git a/pyproject.toml b/pyproject.toml
index f0b65346..c963895a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "HSSM"
-version = "0.1.5"
+version = "0.2.0b1"
description = "Bayesian inference for hierarchical sequential sampling models."
authors = [
"Alexander Fengler ",
@@ -23,14 +23,14 @@ numpy = ">=1.23.4,<1.26"
onnx = "^1.12.0"
jax = "^0.4.0"
jaxlib = "^0.4.0"
-ssm-simulators = "0.5.1"
+ssm-simulators = "0.6.1"
huggingface-hub = "^0.15.1"
onnxruntime = "^1.15.0"
bambi = "^0.12.0"
numpyro = "^0.12.1"
hddm-wfpt = "^0.1.1"
seaborn = "^0.13.0"
-pytensor = "<=2.17.3"
+pytensor = "<2.17.4"
[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
@@ -69,7 +69,7 @@ profile = "black"
[tool.ruff]
line-length = 88
-target-version = "py39"
+target-version = "py310"
unfixable = ["E711"]
select = [
@@ -132,6 +132,8 @@ ignore = [
"B020",
# Function definition does not bind loop variable
"B023",
+ # `zip()` without an explicit `strict=` parameter
+ "B905",
# Functions defined inside a loop must not use variables redefined in the loop
# "B301", # not yet implemented
# Too many arguments to function call
@@ -166,14 +168,7 @@ ignore = [
"TID252",
]
-exclude = [
- ".github",
- "docs",
- "notebook",
- "tests",
- "src/hssm/likelihoods/hddm_wfpt/cdfdif_wrapper.c",
- "src/hssm/likelihoods/hddm_wfpt/wfpt.cpp",
-]
+exclude = [".github", "docs", "notebook", "tests"]
[tool.ruff.pydocstyle]
convention = "numpy"
diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py
index c716f136..41830085 100644
--- a/src/hssm/hssm.py
+++ b/src/hssm/hssm.py
@@ -24,6 +24,7 @@
import seaborn as sns
import xarray as xr
from bambi.model_components import DistributionalComponent
+from bambi.transformations import transformations_namespace
from hssm.defaults import (
LoglikKind,
@@ -164,6 +165,9 @@ class HSSM:
recommended when you are using hierarchical models.
The default value is `None` when `hierarchical` is `False` and `"safe"` when
`hierarchical` is `True`.
+ extra_namespace : optional
+ Additional user-supplied variables with transformations or data to include in
+ the environment where the formula is evaluated. Defaults to `None`.
**kwargs
Additional arguments passed to the `bmb.Model` object.
@@ -214,6 +218,7 @@ def __init__(
hierarchical: bool = False,
link_settings: Literal["log_logit"] | None = None,
prior_settings: Literal["safe"] | None = None,
+ extra_namespace: dict[str, Any] | None = None,
**kwargs,
):
self.data = data
@@ -232,6 +237,11 @@ def __init__(
self.link_settings = link_settings
self.prior_settings = prior_settings
+ additional_namespace = transformations_namespace.copy()
+ if extra_namespace is not None:
+ additional_namespace.update(extra_namespace)
+ self.additional_namespace = additional_namespace
+
responses = self.data["response"].unique().astype(int)
self.n_responses = len(responses)
if self.n_responses == 2:
@@ -312,7 +322,12 @@ def __init__(
)
self.model = bmb.Model(
- self.formula, data, family=self.family, priors=self.priors, **other_kwargs
+ self.formula,
+ data=data,
+ family=self.family,
+ priors=self.priors,
+ extra_namespace=extra_namespace,
+ **other_kwargs,
)
self._aliases = get_alias_dict(self.model, self._parent_param)
@@ -322,6 +337,7 @@ def sample(
self,
sampler: Literal["mcmc", "nuts_numpyro", "nuts_blackjax", "laplace", "vi"]
| None = None,
+ init: str | None = None,
**kwargs,
) -> az.InferenceData | pm.Approximation:
"""Perform sampling using the `fit` method via bambi.Model.
@@ -335,6 +351,9 @@ def sample(
sampler will automatically be chosen: when the model uses the
`approx_differentiable` likelihood, and `jax` backend, "nuts_numpyro" will
be used. Otherwise, "mcmc" (the default PyMC NUTS sampler) will be used.
+ init : optional
+ Initialization method to use for the sampler. If any of the NUTS samplers
+ is used, defaults to `"adapt_diag"`. Otherwise, defaults to `"auto"`.
kwargs
Other arguments passed to bmb.Model.fit(). Please see [here]
(https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit)
@@ -370,7 +389,7 @@ def sample(
)
if "step" not in kwargs:
- kwargs["step"] = pm.Slice(model=self.pymc_model)
+ kwargs |= {"step": pm.Slice(model=self.pymc_model)}
if (
self.loglik_kind == "approx_differentiable"
@@ -387,7 +406,15 @@ def sample(
if self._check_extra_fields():
self._update_extra_fields()
- self._inference_obj = self.model.fit(inference_method=sampler, **kwargs)
+ if init is None:
+ if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
+ init = "adapt_diag"
+ else:
+ init = "auto"
+
+ self._inference_obj = self.model.fit(
+ inference_method=sampler, init=init, **kwargs
+ )
return self.traces
@@ -643,11 +670,11 @@ def plot_trace(
data : optional
An ArviZ InferenceData object. If None, the traces stored in the model will
be used.
- include deterministic : optional
+ include_deterministic : optional
Whether to include deterministic variables in the plot. Defaults to False.
Note that if `include_deterministic` is set to False and `var_names` is
provided, the `var_names` provided will be modified to also exclude the
- deterministic values. If this is not desirable, set
+ deterministic values. If this is not desirable, set
`include_deterministic` to True.
tight_layout : optional
Whether to call plt.tight_layout() after plotting. Defaults to True.
@@ -866,6 +893,8 @@ def _add_kwargs_and_p_outlier_to_include(
"""Process kwargs and p_outlier and add them to include."""
if include is None:
include = []
+ else:
+ include = include.copy()
params_in_include = [param["name"] for param in include]
# Process kwargs
@@ -927,7 +956,7 @@ def _preprocess_rest(self, processed: dict[str, Param]) -> dict[str, Param]:
bounds = self.model_config.bounds.get(param_str)
param = Param(
param_str,
- formula="1 + (1|participant_id)",
+ formula=f"{param_str} ~ 1 + (1|participant_id)",
bounds=bounds,
)
else:
@@ -970,15 +999,27 @@ def _find_parent(self) -> tuple[str, Param]:
def _override_defaults(self):
"""Override the default priors or links."""
+ is_ddm = (
+ self.model_name in ["ddm", "ddm_sdv", "ddm_full"]
+ and self.loglik_kind != "approx_differentiable"
+ )
for param in self.list_params:
param_obj = self.params[param]
if self.prior_settings == "safe":
- param_obj.override_default_priors(self.data)
- elif self.link_settings == "log_logit":
+ if is_ddm:
+ param_obj.override_default_priors_ddm(
+ self.data, self.additional_namespace
+ )
+ else:
+ param_obj.override_default_priors(
+ self.data, self.additional_namespace
+ )
+ if self.link_settings == "log_logit":
param_obj.override_default_link()
def _process_all(self):
"""Process all params."""
+ assert self.list_params is not None
for param in self.list_params:
self.params[param].convert()
diff --git a/src/hssm/link.py b/src/hssm/link.py
index 1ad00c10..68e6164e 100644
--- a/src/hssm/link.py
+++ b/src/hssm/link.py
@@ -78,3 +78,10 @@ def link_(x):
return np.log((x - a) / (b - x))
return link_
+
+ def __str__(self):
+ """Return a string representation of the link function."""
+ if self.name == "gen_logit":
+ lower, upper = self.bounds
+ return f"Generalized logit link function with bounds ({lower}, {upper})"
+ return super().__str__()
diff --git a/src/hssm/param.py b/src/hssm/param.py
index 0dfde27c..175920bd 100644
--- a/src/hssm/param.py
+++ b/src/hssm/param.py
@@ -1,14 +1,16 @@
"""The Param utility class."""
import logging
-from typing import Any, Union, cast
+from copy import deepcopy
+from typing import Any, Literal, Union, cast
import bambi as bmb
import numpy as np
import pandas as pd
+from formulae import design_matrices
from .link import Link
-from .prior import Prior
+from .prior import Prior, get_default_prior, get_hddm_default_prior
# PEP604 union operator "|" not supported by pylint
# Fall back to old syntax
@@ -98,14 +100,7 @@ def override_default_link(self):
This is most likely because both default prior and default bounds are supplied.
"""
- if self._is_converted:
- raise ValueError(
- (
- "Cannot override the default link function for parameter %s."
- + " The object has already been processed."
- )
- % self.name,
- )
+ self._ensure_not_converted(context="link")
if not self.is_regression or self._link_specified:
return # do nothing
@@ -125,7 +120,7 @@ def override_default_link(self):
return
elif lower == 0.0 and np.isposinf(upper):
self.link = "log"
- if not np.isneginf(lower) and not np.isposinf(upper):
+ elif not np.isneginf(lower) and not np.isposinf(upper):
self.link = Link("gen_logit", bounds=self.bounds)
else:
_logger.warning(
@@ -136,8 +131,8 @@ def override_default_link(self):
upper,
)
- def override_default_priors(self, data: pd.DataFrame):
- """Override the default priors.
+ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]):
+ """Override the default priors - the general case.
By supplying priors for all parameters in the regression, we can override the
defaults that Bambi uses.
@@ -146,8 +141,144 @@ def override_default_priors(self, data: pd.DataFrame):
----------
data
The data used to fit the model.
+ eval_env
+ The environment used to evaluate the formula.
+ """
+ self._ensure_not_converted(context="prior")
+
+ if not self.is_regression:
+ return
+
+ override_priors = {}
+ dm = self._get_design_matrices(data, eval_env)
+
+ has_common_intercept = False
+ if dm.common is not None:
+ for name, term in dm.common.terms.items():
+ if term.kind == "intercept":
+ has_common_intercept = True
+ override_priors[name] = get_default_prior(
+ "common_intercept", self.bounds
+ )
+ else:
+ override_priors[name] = get_default_prior("common", bounds=None)
+
+ if dm.group is not None:
+ for name, term in dm.group.terms.items():
+ if term.kind == "intercept":
+ if has_common_intercept:
+ override_priors[name] = get_default_prior(
+ "group_intercept_with_common", bounds=None
+ )
+ else:
+ # treat the term as any other group-specific term
+ _logger.warning(
+ f"No common intercept. Bounds for parameter {self.name} is"
+ + " not applied due to a current limitation of Bambi."
+ + " This will change in the future."
+ )
+ override_priors[name] = get_default_prior(
+ "group_intercept", bounds=None
+ )
+ else:
+ override_priors[name] = get_default_prior(
+ "group_specific", bounds=None
+ )
+
+ if not self.prior:
+ self.prior = override_priors
+ else:
+ prior = cast(dict[str, ParamSpec], self.prior)
+ self.prior = override_priors | prior
+
+ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, Any]):
+ """Override the default priors - the ddm case.
+
+ By supplying priors for all parameters in the regression, we can override the
+ defaults that Bambi uses.
+
+ Parameters
+ ----------
+ data
+ The data used to fit the model.
+ eval_env
+ The environment used to evaluate the formula.
+ """
+ self._ensure_not_converted(context="prior")
+ assert self.name is not None
+
+ if not self.is_regression:
+ return
+
+ override_priors = {}
+ dm = self._get_design_matrices(data, eval_env)
+
+ has_common_intercept = False
+ if dm.common is not None:
+ for name, term in dm.common.terms.items():
+ if term.kind == "intercept":
+ has_common_intercept = True
+ override_priors[name] = get_hddm_default_prior(
+ "common_intercept", self.name, self.bounds
+ )
+ else:
+ override_priors[name] = get_hddm_default_prior(
+ "common", self.name, bounds=None
+ )
+
+ if dm.group is not None:
+ for name, term in dm.group.terms.items():
+ if term.kind == "intercept":
+ if has_common_intercept:
+ override_priors[name] = get_default_prior(
+ "group_intercept_with_common", bounds=None
+ )
+ else:
+ # treat the term as any other group-specific term
+ _logger.warning(
+ f"No common intercept. Bounds for parameter {self.name} is"
+ + " not applied due to a current limitation of Bambi."
+ + " This will change in the future."
+ )
+ override_priors[name] = get_hddm_default_prior(
+ "group_intercept", self.name, bounds=None
+ )
+ else:
+ override_priors[name] = get_hddm_default_prior(
+ "group_specific", self.name, bounds=None
+ )
+
+ if not self.prior:
+ self.prior = override_priors
+ else:
+ prior = cast(dict[str, ParamSpec], self.prior)
+ self.prior = override_priors | prior
+
+ def _get_design_matrices(self, data: pd.DataFrame, extra_namespace: dict[str, Any]):
+ """Get the design matrices for the regression.
+
+ Parameters
+ ----------
+ data
+ A pandas DataFrame
+ extra_namespace
+ The additional namespace used when evaluating the formula.
"""
- return # Will implement in the next PR
+ formula = cast(str, self.formula)
+ rhs = formula.split("~")[1]
+ formula = "rt ~ " + rhs
+ dm = design_matrices(formula, data=data, extra_namespace=extra_namespace)
+
+ return dm
+
+ def _ensure_not_converted(self, context: Literal["link", "prior"]):
+ """Ensure that the object has not been converted."""
+ if self._is_converted:
+ context = "link function" if context == "link" else "priors"
+ raise ValueError(
+ f"Cannot override the default {context} for parameter {self.name}."
+ + " The object has already been processed."
+ )
def set_parent(self):
"""Set the Param as parent."""
@@ -187,11 +318,6 @@ def convert(self):
if self.formula is not None:
# The regression case
-
- self.formula = (
- self.formula if "~" in self.formula else f"{self.name} ~ {self.formula}"
- )
-
if isinstance(self.prior, (float, bmb.Prior)):
raise ValueError(
"Please specify priors for each individual parameter in the "
@@ -531,3 +657,14 @@ def _make_default_prior(bounds: tuple[float, float]) -> bmb.Prior:
return bmb.Prior("TruncatedNormal", mu=lower, lower=lower, sigma=2.0)
else:
return bmb.Prior(name="Uniform", lower=lower, upper=upper)
+
+
+def merge_dicts(dict1: dict, dict2: dict) -> dict:
+ """Recursively merge two dictionaries."""
+ merged = deepcopy(dict1)
+ for key, value in dict2.items():
+ if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
+ merged[key] = merge_dicts(merged[key], value)
+ else:
+ merged[key] = value
+ return merged
diff --git a/src/hssm/prior.py b/src/hssm/prior.py
index 3cc21244..4fa55d77 100644
--- a/src/hssm/prior.py
+++ b/src/hssm/prior.py
@@ -8,8 +8,9 @@
2. The ability to still print out the prior before the truncation.
3. The ability to shorten the output of bmb.Prior.
"""
-
-from typing import Callable
+from copy import deepcopy
+from statistics import mean
+from typing import Any, Callable
import bambi as bmb
import numpy as np
@@ -142,3 +143,176 @@ def TruncatedDist(name):
)
return TruncatedDist
+
+
+def generate_prior(
+ dist: str | dict | int | float | Prior,
+ bounds: tuple[float, float] | None = None,
+ **kwargs,
+):
+ """Generate a Prior distribution.
+
+ The parameter ``kwargs`` is used to pass hyperpriors that are assigned to the
+ parameters of the prior to be built.
+
+ This function is taken from bambi.priors.prior.py and modified to handle bounds.
+
+ Parameters
+ ----------
+ dist:
+ If a string, it is the name of the prior distribution with default values taken
+ from ``HSSM_SETTINGS_DISTRIBUTIONS``. If a number, it is used as a fixed value
+ for the parameter and must lie within ``bounds`` when they are provided. If a
+ `dict`, it must contain a ``"dist"`` key with the name of the distribution and
+ other keys as its parameters.
+ bounds: optional
+ A tuple of two floats indicating the lower and upper bounds of the prior.
+
+ Raises
+ ------
+ ValueError
+ If ``dist`` is not a string, number, or dict.
+
+ Returns
+ -------
+ Prior | int | float
+ The Prior instance, or the numeric value itself when ``dist`` is a number.
+ """
+ if isinstance(dist, str):
+ default_settings = deepcopy(HSSM_SETTINGS_DISTRIBUTIONS[dist])
+ if kwargs:
+ for k, v in kwargs.items():
+ default_settings[k] = generate_prior(v)
+ prior: Prior | int | float = Prior(dist, bounds=bounds, **default_settings)
+ elif isinstance(dist, dict):
+ prior_settings = deepcopy(dist)
+ dist_name: str = prior_settings.pop("dist")
+ for k, v in prior_settings.items():
+ prior_settings[k] = generate_prior(v)
+ prior = Prior(dist_name, bounds=bounds, **prior_settings)
+ elif isinstance(dist, Prior):
+ prior = dist
+ elif isinstance(dist, (int, float)):
+ if bounds is not None:
+ lower, upper = bounds
+ if dist < lower or dist > upper:
+ raise ValueError(
+ f"The prior value {dist} is outside the bounds {bounds}."
+ )
+ prior = dist
+ else:
+ raise ValueError(
+ "'dist' must be the name of a distribution or a numeric value."
+ )
+ return prior
+
+
+def get_default_prior(term_type: str, bounds: tuple[float, float] | None):
+ """Generate a Prior based on the default settings.
+
+ The following summarizes default priors for each type of term:
+
+ * common_intercept: Bounded Normal prior (N(mean(bounds), 0.25)).
+ * common: Normal prior (N(0, 0.25)).
+ * group_intercept: Normal prior N(N(0, 0.25), Weibull(1.5, 0.3)). It's supposed to
+  be bounded, but Bambi does not fully support that yet.
+ * group_specific: Normal prior N(N(0, 0.25), Weibull(1.5, 0.3)).
+
+ This function is taken from bambi.priors.prior.py and modified to handle hssm-
+ specific situations.
+
+ Parameters
+ ----------
+ term_type : str
+ The type of the term for which the default prior is wanted.
+ bounds : tuple[float, float] | None
+ A tuple of two floats indicating the lower and upper bounds of the prior.
+
+ Raises
+ ------
+ ValueError
+ If ``term_type`` is not within the values listed above.
+
+ Returns
+ -------
+ prior: Prior
+ The instance of Prior according to the ``term_type``.
+ """
+ if term_type == "common":
+ prior = generate_prior("Normal", bounds=None)
+ elif term_type == "common_intercept":
+ if bounds is not None:
+ if any(np.isinf(b) for b in bounds):
+ # TODO: Make it more specific.
+ prior = generate_prior("Normal", bounds=bounds)
+ else:
+ prior = generate_prior(
+ "Normal", mu=mean(bounds), sigma=0.25, bounds=bounds
+ )
+ else:
+ prior = generate_prior("Normal")
+ elif term_type == "group_intercept":
+ prior = generate_prior("Normal", mu="Normal", sigma="Weibull")
+ elif term_type == "group_specific":
+ prior = generate_prior("Normal", mu="Normal", sigma="Weibull")
+ elif term_type == "group_intercept_with_common":
+ prior = generate_prior("Normal", mu=0.0, sigma="Weibull")
+ else:
+ raise ValueError("Unrecognized term type.")
+ return prior
+
+
+def get_hddm_default_prior(
+ term_type: str, param: str, bounds: tuple[float, float] | None
+):
+ """Generate a Prior based on the default settings - the HDDM case."""
+ if term_type == "common":
+ prior = generate_prior("Normal", bounds=None)
+ elif term_type == "common_intercept":
+ prior = generate_prior(HDDM_MU[param], bounds=bounds)
+ elif term_type == "group_intercept":
+ prior = generate_prior(HDDM_SETTINGS_GROUP[param], bounds=None)
+ elif term_type == "group_specific":
+ prior = generate_prior("Normal", mu="Normal", sigma="Weibull", bounds=None)
+ else:
+ raise ValueError("Unrecognized term type.")
+ return prior
+
+
+HSSM_SETTINGS_DISTRIBUTIONS: dict[Any, Any] = {
+ "Normal": {"mu": 0.0, "sigma": 0.25},
+ "Weibull": {"alpha": 1.5, "beta": 0.3},
+ "HalfNormal": {"sigma": 0.25},
+ "Beta": {"alpha": 1.0, "beta": 1.0},
+ "Gamma": {"mu": 1.0, "sigma": 1.0},
+}
+
+HDDM_MU: dict[Any, Any] = {
+ "v": {"dist": "Normal", "mu": 2.0, "sigma": 3.0},
+ "a": {"dist": "Gamma", "mu": 1.5, "sigma": 0.75},
+ "z": {"dist": "Gamma", "mu": 10, "sigma": 10},
+ "t": {"dist": "Gamma", "mu": 0.4, "sigma": 0.2},
+ "sv": {"dist": "HalfNormal", "sigma": 2.0},
+ "st": {"dist": "HalfNormal", "sigma": 0.3},
+ "sz": {"dist": "HalfNormal", "sigma": 0.5},
+}
+
+HDDM_SIGMA: dict[Any, Any] = {
+ "v": {"dist": "HalfNormal", "sigma": 2.0},
+ "a": {"dist": "HalfNormal", "sigma": 0.1},
+ "z": {"dist": "Gamma", "mu": 10, "sigma": 10},
+ "t": {"dist": "HalfNormal", "sigma": 1.0},
+ "sv": {"dist": "Weibull", "alpha": 1.5, "beta": "0.3"},
+ "sz": {"dist": "Weibull", "alpha": 1.5, "beta": "0.3"},
+ "st": {"dist": "Weibull", "alpha": 1.5, "beta": "0.3"},
+}
+
+HDDM_SETTINGS_GROUP: dict[Any, Any] = {
+ "v": {"dist": "Normal", "mu": HDDM_MU["v"], "sigma": HDDM_SIGMA["v"]},
+ "a": {"dist": "Gamma", "mu": HDDM_MU["a"], "sigma": HDDM_SIGMA["a"]},
+ "z": {"dist": "Beta", "alpha": HDDM_MU["z"], "beta": HDDM_SIGMA["z"]},
+ "t": {"dist": "Normal", "mu": HDDM_MU["t"], "sigma": HDDM_SIGMA["t"]},
+ "sv": {"dist": "Gamma", "mu": HDDM_MU["sv"], "sigma": HDDM_SIGMA["sv"]},
+ "sz": {"dist": "Gamma", "mu": HDDM_MU["sz"], "sigma": HDDM_SIGMA["sz"]},
+ "st": {"dist": "Gamma", "mu": HDDM_MU["st"], "sigma": HDDM_SIGMA["st"]},
+}
diff --git a/src/hssm/utils.py b/src/hssm/utils.py
index 2adcc1e6..f06c3fdc 100644
--- a/src/hssm/utils.py
+++ b/src/hssm/utils.py
@@ -10,7 +10,6 @@
"""
import logging
-from copy import deepcopy
from typing import Any, Iterable, Literal, NewType
import bambi as bmb
@@ -19,7 +18,7 @@
import xarray as xr
from bambi.terms import CommonTerm, GroupSpecificTerm, HSGPTerm, OffsetTerm
from huggingface_hub import hf_hub_download
-from jax.config import config
+from jax import config
from pymc.model_graph import ModelGraph
from pytensor import function
@@ -54,17 +53,6 @@ def download_hf(path: str):
return hf_hub_download(repo_id=REPO_ID, filename=path)
-def merge_dicts(dict1: dict, dict2: dict) -> dict:
- """Recursively merge two dictionaries."""
- merged = deepcopy(dict1)
- for key, value in dict2.items():
- if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
- merged[key] = merge_dicts(merged[key], value)
- else:
- merged[key] = value
- return merged
-
-
def make_alias_dict_from_parent(parent: Param) -> dict[str, str]:
"""Make aliases from the parent parameter.
@@ -336,7 +324,7 @@ def _process_param_in_kwargs(
Raises
------
ValueError
- When `prior` is not a `float`, a `dict`, or a `b`mb.Prior` object.
+ When `prior` is not a `float`, a `dict`, or a `bmb.Prior` object.
"""
if isinstance(prior, (int, float, bmb.Prior)):
return {"name": name, "prior": prior}
diff --git a/tests/test_config.py b/tests/test_config.py
index 3cc77e44..b0261bde 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -4,6 +4,8 @@
import hssm
from hssm.config import Config, ModelConfig
+hssm.set_floatX("float32")
+
def test_from_defaults():
# Case 1: Has default prior
diff --git a/tests/test_distribution_utils.py b/tests/test_distribution_utils.py
index 8d12970d..f1efc5b1 100644
--- a/tests/test_distribution_utils.py
+++ b/tests/test_distribution_utils.py
@@ -9,6 +9,8 @@
from hssm.distribution_utils.dist import apply_param_bounds_to_loglik, make_distribution
from hssm.likelihoods.analytical import logp_ddm, DDM
+hssm.set_floatX("float32")
+
def test_make_ssm_rv():
params = ["v", "a", "z", "t"]
diff --git a/tests/test_hssm.py b/tests/test_hssm.py
index 3d890fe9..6169b44f 100644
--- a/tests/test_hssm.py
+++ b/tests/test_hssm.py
@@ -2,15 +2,14 @@
import bambi as bmb
import numpy as np
-import pandas as pd
-import pytensor
import pytest
+import hssm
from hssm import HSSM
from hssm.utils import download_hf
from hssm.likelihoods import DDM, logp_ddm
-pytensor.config.floatX = "float32"
+hssm.set_floatX("float32")
param_v = {
"name": "v",
@@ -190,7 +189,7 @@ def test_sample_prior_predictive(data_ddm_reg):
)
prior_predictive_5 = model_regression_multi.sample_prior_predictive(draws=10)
- data_ddm_reg["subject_id"] = np.arange(10)
+ data_ddm_reg.loc[:, "subject_id"] = np.arange(10)
model_regression_random_effect = HSSM(
data=data_ddm_reg,
diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py
index c3c457d2..18b7f441 100644
--- a/tests/test_likelihoods.py
+++ b/tests/test_likelihoods.py
@@ -9,10 +9,14 @@
import pytest
from numpy.random import rand
+import hssm
+
# pylint: disable=C0413
from hssm.likelihoods.analytical import compare_k, logp_ddm, logp_ddm_sdv
from hssm.likelihoods.blackbox import logp_ddm_bbox, logp_ddm_sdv_bbox
+hssm.set_floatX("float32")
+
def test_kterm(data_ddm):
"""This function defines a range of kterms and tests results to
diff --git a/tests/test_onnx.py b/tests/test_onnx.py
index 31e464b6..5474e7ee 100644
--- a/tests/test_onnx.py
+++ b/tests/test_onnx.py
@@ -7,9 +7,10 @@
import pytensor.tensor as pt
import pytest
+import hssm
from hssm.distribution_utils.onnx import *
-pytensor.config.floatX = "float32"
+hssm.set_floatX("float32")
DECIMAL = 4
diff --git a/tests/test_param.py b/tests/test_param.py
index 8b43588c..306172d2 100644
--- a/tests/test_param.py
+++ b/tests/test_param.py
@@ -9,6 +9,7 @@
_make_priors_recursive,
_make_bounded_prior,
)
+from hssm.defaults import default_model_config
def test_param_creation_non_regression():
@@ -285,7 +286,7 @@ def fake_func(x):
"x1": bmb.Prior("Normal", mu=0, sigma=0.5),
}
- param_reg_formula1 = Param("a", formula="1 + x1", prior=priors_dict)
+ param_reg_formula1 = Param("a", formula="a ~ 1 + x1", prior=priors_dict)
param_reg_formula2 = Param(
"a", formula="a ~ 1 + x1", prior=priors_dict, link=fake_link
)
@@ -417,3 +418,328 @@ def test_param_override_default_link(caplog, formula, link, bounds, result):
with pytest.raises(ValueError):
param.override_default_link()
+
+
+def _check_group_prior(group_prior):
+ assert isinstance(group_prior, bmb.Prior)
+ assert group_prior.dist is None
+ assert group_prior.name == "Normal"
+
+ mu = group_prior.args["mu"]
+ sigma = group_prior.args["sigma"]
+
+ assert isinstance(mu, bmb.Prior)
+ assert mu.name == "Normal"
+ assert mu.args["mu"] == 0.0
+ assert mu.args["sigma"] == 0.25
+
+ assert isinstance(sigma, bmb.Prior)
+ assert sigma.name == "Weibull"
+ assert sigma.args["alpha"] == 1.5
+ assert sigma.args["beta"] == 0.3
+
+
+def _check_group_prior_with_common(group_prior):
+ assert isinstance(group_prior, bmb.Prior)
+ assert group_prior.dist is None
+ assert group_prior.name == "Normal"
+
+ mu = group_prior.args["mu"]
+ sigma = group_prior.args["sigma"]
+
+ assert mu == 0.0
+
+ assert isinstance(sigma, bmb.Prior)
+ assert sigma.name == "Weibull"
+ assert sigma.args["alpha"] == 1.5
+ assert sigma.args["beta"] == 0.3
+
+
+angle_params = default_model_config["angle"]["list_params"]
+angle_bounds = default_model_config["angle"]["likelihoods"]["approx_differentiable"][
+ "bounds"
+].values()
+param_and_bounds = zip(angle_params, angle_bounds)
+
+
+@pytest.mark.parametrize(
+ ("param_name", "bounds"),
+ param_and_bounds,
+)
+def test_param_override_default_priors(cavanagh_test, caplog, param_name, bounds):
+ # Necessary for verifying the values of certain parameters of the priors
+ hssm.set_floatX("float64")
+ # Shouldn't do anything if the param is not a regression
+ param_non_reg = Param(
+ name=param_name,
+ prior={},
+ )
+ param_non_reg.override_default_priors(cavanagh_test, {})
+ assert not param_non_reg.prior
+
+ # The basic regression case, no group-specific terms
+ param = Param(
+ name=param_name,
+ formula=f"{param_name} ~ 1 + theta",
+ bounds=bounds,
+ )
+
+ param.override_default_priors(cavanagh_test, {})
+
+ assert param.prior is not None
+
+ intercept_prior = param.prior["Intercept"]
+ slope_prior = param.prior["theta"]
+
+ assert isinstance(intercept_prior, hssm.Prior)
+ assert intercept_prior.is_truncated
+ assert intercept_prior.bounds == bounds
+ assert intercept_prior.dist is not None
+ lower, upper = intercept_prior.bounds
+ _mu = intercept_prior._args["mu"]
+ if isinstance(_mu, np.ndarray):
+ assert _mu.item() == (lower + upper) / 2
+ else:
+ assert _mu == (lower + upper) / 2
+ assert intercept_prior._args["sigma"] == 0.25
+
+ assert isinstance(slope_prior, bmb.Prior)
+ assert slope_prior.dist is None
+ assert slope_prior.args["mu"] == 0.0
+ assert slope_prior.args["sigma"] == 0.25
+
+ unif_prior = {"name": "Uniform", "lower": 0.0, "upper": 1.0}
+ set_prior = {
+ "Intercept": unif_prior,
+ "theta": unif_prior,
+ }
+
+ param_with_prior = Param(
+ name=param_name,
+ formula=f"{param_name} ~ 1 + theta",
+ bounds=bounds,
+ prior=set_prior,
+ )
+
+ param_with_prior.override_default_priors(cavanagh_test, {})
+ assert param_with_prior.prior == set_prior
+
+ # The regression case, with group-specific terms
+ param_group = Param(
+ name=param_name,
+ formula=f"{param_name} ~ 1 + (1 + theta | participant_id)",
+ bounds=bounds,
+ )
+
+ param_group.override_default_priors(cavanagh_test, {})
+
+ assert all(
+ param in param_group.prior
+ for param in ["Intercept", "1|participant_id", "theta|participant_id"]
+ )
+
+ assert param_group.prior["Intercept"].is_truncated
+
+ group_intercept_prior = param_group.prior["1|participant_id"]
+ group_slope_prior = param_group.prior["theta|participant_id"]
+
+ _check_group_prior_with_common(group_intercept_prior)
+ _check_group_prior(group_slope_prior)
+
+ param_no_common_intercept = Param(
+ name=param_name,
+ formula=f"{param_name} ~ 0 + (1 + theta | participant_id)",
+ bounds=bounds,
+ )
+
+ param_no_common_intercept.override_default_priors(cavanagh_test, {})
+ assert "limitation" in caplog.records[-1].msg
+
+ assert "Intercept" not in param_no_common_intercept.prior
+ group_intercept_prior = param_no_common_intercept.prior["1|participant_id"]
+ group_slope_prior = param_no_common_intercept.prior["theta|participant_id"]
+
+ _check_group_prior(group_intercept_prior)
+ _check_group_prior(group_slope_prior)
+
+ # Change back after testing
+ hssm.set_floatX("float32")
+
+
+v_mu = {"name": "Normal", "mu": 2.0, "sigma": 3.0}
+v_sigma = {"name": "HalfNormal", "sigma": 2.0}
+v_prior = {"name": "Normal", "mu": v_mu, "sigma": v_sigma}
+
+a_mu = {"name": "Gamma", "mu": 1.5, "sigma": 0.75}
+a_sigma = {"name": "HalfNormal", "sigma": 0.1}
+a_prior = {"name": "Gamma", "mu": a_mu, "sigma": a_sigma}
+
+z_mu = {"name": "Gamma", "mu": 10.0, "sigma": 10.0}
+z_sigma = {"name": "Gamma", "mu": 10.0, "sigma": 10.0}
+z_prior = {"name": "Beta", "alpha": z_mu, "beta": z_sigma}
+
+t_mu = {"name": "Gamma", "mu": 0.4, "sigma": 0.2}
+t_sigma = {"name": "HalfNormal", "sigma": 1}
+t_prior = {"name": "Normal", "mu": t_mu, "sigma": t_sigma}
+
+
+@pytest.mark.parametrize(
+ ("param_name", "mu", "prior"),
+ [
+ ("v", v_mu, v_prior),
+ ("a", a_mu, a_prior),
+ ("z", z_mu, z_prior),
+ ("t", t_mu, t_prior),
+ ],
+)
+def test_param_override_default_priors_ddm(
+ cavanagh_test, caplog, param_name, mu, prior
+):
+ # Necessary for verifying the values of certain parameters of the priors
+ hssm.set_floatX("float64")
+ # Shouldn't do anything if the param is not a regression
+ param_non_reg = Param(
+ name=param_name,
+ prior={},
+ )
+ param_non_reg.override_default_priors_ddm(cavanagh_test, {})
+ assert not param_non_reg.prior
+
+ bounds = (-10, 10)
+
+ # The basic regression case, no group-specific terms
+ param = Param(
+ name=param_name,
+ formula=f"{param_name} ~ 1 + theta",
+ bounds=bounds, # invalid, just for testing
+ )
+
+ param.override_default_priors_ddm(cavanagh_test, {})
+
+ intercept_prior = param.prior["Intercept"]
+ slope_prior = param.prior["theta"]
+
+ assert isinstance(intercept_prior, hssm.Prior)
+ assert intercept_prior.bounds == bounds
+ assert intercept_prior.dist is not None
+ mu1 = mu.copy()
+ assert intercept_prior.name == mu1.pop("name")
+ for key, val in mu1.items():
+ val1 = intercept_prior._args[key]
+ np.testing.assert_almost_equal(val1, val)
+
+ assert isinstance(slope_prior, bmb.Prior)
+ assert slope_prior.dist is None
+ assert slope_prior.args["mu"] == 0.0
+ assert slope_prior.args["sigma"] == 0.25
+
+ # If prior is set, do not override
+ unif_prior = {"name": "Uniform", "lower": 0.0, "upper": 1.0}
+ set_prior = {
+ "Intercept": unif_prior,
+ "theta": unif_prior,
+ }
+
+ param_with_prior = Param(
+ name=param_name,
+ formula=f"{param_name} ~ 1 + theta",
+ bounds=bounds,
+ prior=set_prior,
+ )
+
+ param_with_prior.override_default_priors_ddm(cavanagh_test, {})
+ assert param_with_prior.prior == set_prior
+
+ # The regression case, with group-specific terms
+ param_group = Param(
+ name=param_name,
+ formula=f"{param_name} ~ 1 + (1 + theta | participant_id)",
+ bounds=bounds,
+ )
+
+ param_group.override_default_priors_ddm(cavanagh_test, {})
+
+ assert all(
+ param in param_group.prior
+ for param in ["Intercept", "1|participant_id", "theta|participant_id"]
+ )
+
+ assert param_group.prior["Intercept"].is_truncated
+
+ group_intercept_prior = param_group.prior["1|participant_id"]
+ group_slope_prior = param_group.prior["theta|participant_id"]
+
+ def _check_group_prior_intercept_ddm(group_prior, prior):
+ assert isinstance(group_prior, bmb.Prior)
+ assert group_prior.dist is None
+ prior1 = prior.copy()
+ assert group_prior.name == prior1.pop("name")
+ for key, val in prior1.items():
+ hyperprior = group_prior.args[key]
+ val1 = val.copy()
+ assert hyperprior.name == val1.pop("name")
+ for key2, val2 in val1.items():
+ assert hyperprior.args[key2] == val2
+
+ _check_group_prior_with_common(group_intercept_prior)
+ _check_group_prior(group_slope_prior)
+
+ param_no_common_intercept = Param(
+ name=param_name,
+ formula=f"{param_name} ~ 0 + (1 + theta | participant_id)",
+ bounds=bounds,
+ )
+
+ param_no_common_intercept.override_default_priors_ddm(cavanagh_test, {})
+ assert "limitation" in caplog.records[-1].msg
+
+ assert "Intercept" not in param_no_common_intercept.prior
+ group_intercept_prior = param_no_common_intercept.prior["1|participant_id"]
+ group_slope_prior = param_no_common_intercept.prior["theta|participant_id"]
+
+ _check_group_prior_intercept_ddm(group_intercept_prior, prior)
+ _check_group_prior(group_slope_prior)
+
+ # Change back after testing
+ hssm.set_floatX("float32")
+
+
+def test_hssm_override_default_prior(cavanagh_test):
+ model1 = hssm.HSSM(
+ model="angle",
+ data=cavanagh_test,
+ hierarchical=False,
+ include=[
+ {
+ "name": "v",
+ "formula": "v ~ 1 + C(conf)",
+ }
+ ],
+ prior_settings="safe",
+ )
+
+ param_v = model1.params["v"]
+ assert param_v.prior["Intercept"].name == "Normal"
+ assert param_v.prior["Intercept"].is_truncated
+
+ model2 = hssm.HSSM(
+ model="ddm",
+ data=cavanagh_test,
+ hierarchical=True,
+ include=[
+ {
+ "name": "v",
+ "formula": "v ~ 1 + theta",
+ "prior": {"Intercept": {"name": "Uniform", "lower": -10, "upper": 10}},
+ },
+ ],
+ prior_settings="safe",
+ )
+ param_v = model2.params["v"]
+ assert param_v.prior["Intercept"].name == "Uniform"
+ assert param_v.prior["theta"].name == "Normal"
+
+ param_a = model2.params["a"]
+ assert param_a.prior["Intercept"].name == a_mu["name"]
+ assert "1|participant_id" in param_a.prior
diff --git a/tests/test_plotting.py b/tests/test_plotting.py
index 234207dc..06f608ee 100644
--- a/tests/test_plotting.py
+++ b/tests/test_plotting.py
@@ -28,6 +28,8 @@
plot_quantile_probability,
)
+hssm.set_floatX("float32")
+
def test__get_title():
assert _get_title(("a"), ("b")) == "a = b"
diff --git a/tests/test_prior.py b/tests/test_prior.py
index 42d07611..cfdcd1f8 100644
--- a/tests/test_prior.py
+++ b/tests/test_prior.py
@@ -3,8 +3,11 @@
import bambi as bmb
import numpy as np
+import hssm
from hssm import Prior
+hssm.set_floatX("float32")
+
def test_truncation():
hssm_prior = Prior("Uniform", lower=0.0, upper=1.0)
diff --git a/tests/test_sample_posterior_predictive.py b/tests/test_sample_posterior_predictive.py
index d10a8cc9..08ebdb91 100644
--- a/tests/test_sample_posterior_predictive.py
+++ b/tests/test_sample_posterior_predictive.py
@@ -1,5 +1,7 @@
import hssm
+hssm.set_floatX("float32")
+
def test_sample_posterior_predictive(cav_idata, cavanagh_test):
model = hssm.HSSM(
diff --git a/tests/test_simulator.py b/tests/test_simulator.py
index af7d9605..0eaf1949 100644
--- a/tests/test_simulator.py
+++ b/tests/test_simulator.py
@@ -2,8 +2,11 @@
import pandas as pd
import pytest
+import hssm
from hssm.simulator import simulate_data
+hssm.set_floatX("float32")
+
def test_simulator():
theta = [0.5, 1.5, 0.5, 0.5]
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 5a004664..6e629aa8 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -3,7 +3,7 @@
import pytensor
import pytest
from ssms.basic_simulators.simulator import simulator
-from jax.config import config
+from jax import config
import hssm
from hssm.utils import (
@@ -12,6 +12,8 @@
_random_sample,
)
+hssm.set_floatX("float32")
+
def test_get_alias_dict():
# Simulate some data: