diff --git a/src/hssm/config.py b/src/hssm/config.py index 788855b1..57be85e3 100644 --- a/src/hssm/config.py +++ b/src/hssm/config.py @@ -25,6 +25,7 @@ class Config: loglik: LogLik | None = None backend: Literal["jax", "pytensor"] | None = None rv: RandomVariable | None = None + extra_fields: list[str] | None = None # Fields with dictionaries are automatically deepcopied default_priors: dict[str, ParamSpec] = field(default_factory=dict) bounds: dict[str, tuple[float, float]] = field(default_factory=dict) @@ -162,3 +163,4 @@ class ModelConfig: bounds: dict[str, tuple[float, float]] = field(default_factory=dict) backend: Literal["jax", "pytensor"] | None = None rv: RandomVariable | None = None + extra_fields: list[str] | None = None diff --git a/src/hssm/defaults.py b/src/hssm/defaults.py index 60fc134e..9c5e411d 100644 --- a/src/hssm/defaults.py +++ b/src/hssm/defaults.py @@ -42,6 +42,7 @@ class LoglikConfig(TypedDict): backend: Optional[Literal["jax", "pytensor"]] default_priors: dict[str, ParamSpec] bounds: dict[str, tuple[float, float]] + extra_fields: Optional[list[str]] LoglikConfigs = dict[LoglikKind, LoglikConfig] @@ -73,6 +74,7 @@ class DefaultConfig(TypedDict): "initval": 0.1, }, }, + "extra_fields": None, }, "approx_differentiable": { "loglik": "ddm.onnx", @@ -84,6 +86,7 @@ class DefaultConfig(TypedDict): "z": (0.0, 1.0), "t": (0.0, 2.0), }, + "extra_fields": None, }, "blackbox": { "loglik": logp_ddm_bbox, @@ -96,6 +99,7 @@ class DefaultConfig(TypedDict): "initval": 0.1, }, }, + "extra_fields": None, }, }, }, @@ -114,6 +118,7 @@ class DefaultConfig(TypedDict): "initval": 0.1, }, }, + "extra_fields": None, }, "approx_differentiable": { "loglik": "ddm_sdv.onnx", @@ -126,6 +131,7 @@ class DefaultConfig(TypedDict): "t": (0.0, 2.0), "sv": (0.0, 1.0), }, + "extra_fields": None, }, "blackbox": { "loglik": logp_ddm_sdv_bbox, @@ -138,6 +144,7 @@ class DefaultConfig(TypedDict): "initval": 0.1, }, }, + "extra_fields": None, }, }, }, @@ -156,6 +163,7 @@ class DefaultConfig(TypedDict): "initval": 0.1, }, }, + "extra_fields": None, } }, }, @@ -174,6 +182,7 @@ class DefaultConfig(TypedDict): "t": (0.001, 2.0), "theta": (-0.1, 1.3), }, + "extra_fields": None, }, }, }, @@ -192,6 +201,7 @@ class DefaultConfig(TypedDict): "alpha": (1.0, 2.0), "t": (1e-3, 2.0), }, + "extra_fields": None, }, }, }, @@ -210,6 +220,7 @@ class DefaultConfig(TypedDict): "g": (-1.0, 1.0), "t": (1e-3, 2.0), }, + "extra_fields": None, }, }, }, @@ -229,6 +240,7 @@ class DefaultConfig(TypedDict): "alpha": (0.31, 4.99), "beta": (0.31, 6.99), }, + "extra_fields": None, }, }, }, @@ -250,6 +262,7 @@ class DefaultConfig(TypedDict): "ndt": (0.0, 2.0), "theta": (-0.1, 1.45), }, + "extra_fields": None, }, }, }, @@ -268,6 +281,7 @@ class DefaultConfig(TypedDict): "a": (0.3, 2.5), "t": (0.0, 2.0), }, + "extra_fields": None, }, }, }, diff --git a/src/hssm/distribution_utils/dist.py b/src/hssm/distribution_utils/dist.py index 24d9c350..389c34a2 100644 --- a/src/hssm/distribution_utils/dist.py +++ b/src/hssm/distribution_utils/dist.py @@ -9,7 +9,7 @@ import logging from os import PathLike -from typing import Any, Callable, Type +from typing import Any, Callable, Iterable, Type import bambi as bmb import numpy as np @@ -199,7 +199,14 @@ def rng_fn( if not np.isscalar(size): size = np.squeeze(size) - arg_arrays = [np.asarray(arg) for arg in args] + num_params = len(cls._list_params) + + # TODO: We need to figure out what to do with extra_fields when + # doing posterior predictive sampling. Right now nothing is done. + if num_params < len(args): + arg_arrays = [np.asarray(arg) for arg in args[:num_params]] + else: + arg_arrays = [np.asarray(arg) for arg in args] p_outlier = None @@ -299,6 +306,7 @@ def make_distribution( list_params: list[str], bounds: dict | None = None, lapse: bmb.Prior | None = None, + extra_fields: list[np.ndarray] | None = None, ) -> Type[pm.Distribution]: """Make a `pymc.Distribution`. @@ -325,6 +333,9 @@ def make_distribution( Example: {"parameter": (lower_boundary, upper_boundary)}. lapse : optional A bmb.Prior object representing the lapse distribution. + extra_fields : optional + An optional list of arrays that are stored in the class created and will be + used in likelihood calculation. Defaults to None. Returns ------- @@ -337,13 +348,13 @@ def make_distribution( if list_params[-1] != "p_outlier": list_params.append("p_outlier") - data = pt.dvector() + data_vector = pt.dvector() lapse_logp = pm.logp( get_distribution_from_prior(lapse).dist(**lapse.args), - data, + data_vector, ) lapse_func = pytensor.function( - [data], + [data_vector], lapse_logp, ) @@ -356,29 +367,39 @@ class SSMDistribution(pm.Distribution): # NOTE: rv_op is an INSTANCE of RandomVariable rv_op = random_variable() params = list_params + _extra_fields = extra_fields @classmethod def dist(cls, **kwargs): # pylint: disable=arguments-renamed dist_params = [ pt.as_tensor_variable(pm.floatX(kwargs[param])) for param in cls.params ] + if cls._extra_fields: + dist_params += [pm.floatX(field) for field in cls._extra_fields] other_kwargs = {k: v for k, v in kwargs.items() if k not in cls.params} return super().dist(dist_params, **other_kwargs) def logp(data, *dist_params): # pylint: disable=E0213 + num_params = len(list_params) + extra_fields: Iterable[np.ndarray] = [] + + if num_params < len(dist_params): + extra_fields = dist_params[num_params:] + dist_params = dist_params[:num_params] + if list_params[-1] == "p_outlier": p_outlier = dist_params[-1] dist_params = dist_params[:-1] lapse_logp = lapse_func(data[:, 0].eval()) - logp = loglik(data, *dist_params) + logp = loglik(data, *dist_params, *extra_fields) logp = pt.log( (1.0 - p_outlier) * pt.exp(logp) + p_outlier * pt.exp(lapse_logp) + 1e-29 ) else: - logp = loglik(data, *dist_params) + logp = loglik(data, *dist_params, *extra_fields) if bounds is None: return logp @@ -398,6 +419,7 @@ def make_distribution_from_onnx( bounds: dict | None = None, params_is_reg: list[bool] | None = None, lapse: bmb.Prior | None = None, + extra_fields: list[np.ndarray] | None = None, ) -> Type[pm.Distribution]: """Make a PyMC distribution from an ONNX model. @@ -429,6 +451,9 @@ def make_distribution_from_onnx( corresponding position in `list_params` is a regression. lapse : optional A bmb.Prior object representing the lapse distribution. + extra_fields : optional + An optional list of arrays that are stored in the class created and will be + used in likelihood calculation. Defaults to None. Returns ------- @@ -446,21 +471,30 @@ def make_distribution_from_onnx( list_params, bounds=bounds, lapse=lapse, + extra_fields=extra_fields, ) if backend == "jax": if params_is_reg is None: params_is_reg = [False for param in list_params if param != "p_outlier"] + + # Extra fields are passed to the likelihood functions as vectors + # They do not need to be broadcast, so param_is_reg is padded with True + if extra_fields: + params_is_reg += [True for _ in extra_fields] + logp, logp_grad, logp_nojit = make_jax_logp_funcs_from_onnx( onnx_model, params_is_reg, ) lan_logp_jax = make_jax_logp_ops(logp, logp_grad, logp_nojit) + return make_distribution( rv, lan_logp_jax, list_params, bounds=bounds, lapse=lapse, + extra_fields=extra_fields, ) raise ValueError("Currently only 'pytensor' and 'jax' backends are supported.") @@ -581,6 +615,7 @@ def make_distribution_from_blackbox( loglik: Callable, list_params: list[str], bounds: dict | None = None, + extra_fields: list[np.ndarray] | None = None, ) -> Type[pm.Distribution]: """Make a `pymc.Distribution`. @@ -604,6 +639,9 @@ def make_distribution_from_blackbox( bounds : optional A dictionary with parameters as keys (a string) and its boundaries as values. Example: {"parameter": (lower_boundary, upper_boundary)}. + extra_fields : optional + An optional list of arrays that are stored in the class created and will be + used in likelihood calculation. Defaults to None. Returns ------- @@ -612,4 +650,6 @@ def make_distribution_from_blackbox( """ blackbox_op = make_blackbox_op(loglik) - return make_distribution(rv, blackbox_op, list_params, bounds) + return make_distribution( + rv, blackbox_op, list_params, bounds, extra_fields=extra_fields + ) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 8d41146f..377864e6 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -95,6 +95,10 @@ class HSSM: `ssm_simulators` package. If `model` is not supported in `ssm_simulators`, a warning will be raised letting the user know that sampling from the `RandomVariable` will result in errors. + - `"extra_fields"`: Optional. A list of strings indicating the additional + columns in `data` that will be passed to the likelihood function for + calculation. This is helpful if the likelihood function depends on data + other than the observed data and the parameter values. loglik : optional A likelihood function. Defaults to None. Requirements are: @@ -210,6 +214,9 @@ def __init__( self.model_name = self.model_config.model_name self.loglik = self.model_config.loglik self.loglik_kind = self.model_config.loglik_kind + self.extra_fields = self.model_config.extra_fields + + self._check_extra_fields() # Process lapse distribution self.has_lapse = p_outlier is not None and p_outlier != 0 @@ -330,6 +337,9 @@ def sample( + "`nuts_blackjax` sampler if that is a problem." ) + if self._check_extra_fields(): + self._update_extra_fields() + self._inference_obj = self.model.fit(inference_method=sampler, **kwargs) return self.traces @@ -384,6 +394,10 @@ def sample_posterior_predictive( + "Please either provide an idata object or sample the model first." ) idata = self._inference_obj + + if self._check_extra_fields(data): + self._update_extra_fields(data) + return self.model.predict(idata, kind, data, inplace, include_group_specific) def sample_prior_predictive( @@ -819,6 +833,9 @@ def _make_model_distribution(self) -> type[pm.Distribution]: list_params=self.list_params, bounds=self.bounds, lapse=self.lapse, + extra_fields=None + if not self.extra_fields + else [self.data[field].values for field in self.extra_fields], ) # type: ignore # If the user has provided a callable (an arbitrary likelihood function) # If `loglik_kind` is `blackbox`, wrap it in an op and then a distribution @@ -833,6 +850,9 @@ def _make_model_distribution(self) -> type[pm.Distribution]: list_params=self.list_params, bounds=self.bounds, lapse=self.lapse, + extra_fields=None + if not self.extra_fields + else [self.data[field].values for field in self.extra_fields], ) # type: ignore # All other situations if self.loglik_kind != "approx_differentiable": @@ -859,4 +879,36 @@ def _make_model_distribution(self) -> type[pm.Distribution]: params_is_reg=params_is_reg, bounds=self.bounds, lapse=self.lapse, + extra_fields=None + if not self.extra_fields + else [self.data[field].values for field in self.extra_fields], ) + + def _check_extra_fields(self, data: pd.DataFrame | None = None) -> bool: + """Check if every field in self.extra_fields exists in data.""" + if not self.extra_fields: + return False + + if not data: + data = self.data + + for field in self.extra_fields: + if field not in data.columns: + raise ValueError(f"Field {field} not found in data.") + + return True + + def _update_extra_fields(self, new_data: pd.DataFrame | None = None): + """Update the extra fields data in self.model_distribution. + + Parameters + ---------- + new_data + A DataFrame containing new data for update. + """ + if not new_data: + new_data = self.data + + self.model_distribution.extra_fields = [ + new_data[field].values for field in self.extra_fields + ] diff --git a/tests/test_distribution_utils.py b/tests/test_distribution_utils.py index 60a142ce..bef9e551 100644 --- a/tests/test_distribution_utils.py +++ b/tests/test_distribution_utils.py @@ -1,9 +1,12 @@ import bambi as bmb import numpy as np +import pymc as pm import pytest +import hssm from hssm import distribution_utils from hssm.distribution_utils.dist import apply_param_bounds_to_loglik, make_distribution +from hssm.likelihoods.analytical import logp_ddm, DDM def test_make_ssm_rv(): @@ -152,3 +155,53 @@ def fake_logp_function(data, param1, param2): * scalar_in_bound * random_vector[~out_of_bound_indices], ) + + +def test_extra_fields(data_ddm): + ones = np.ones(len(data_ddm)) + x = ones * 0.5 + y = ones * 2 + + def logp_ddm_extra_fields(data, v, a, z, t, x, y): + return logp_ddm(data, v, a, z, t) * x * y + + DDM_WITH_XY = make_distribution( + rv="ddm", + loglik=logp_ddm_extra_fields, + list_params=["v", "a", "z", "t"], + extra_fields=[x, y], + ) + + true_values = dict(v=0.5, a=1.5, z=0.5, t=0.5) + + np.testing.assert_almost_equal( + pm.logp(DDM.dist(**true_values), data_ddm).eval(), + pm.logp(DDM_WITH_XY.dist(**true_values), data_ddm).eval(), + ) + + data_ddm_copy = data_ddm.copy() + data_ddm_copy["x"] = x + data_ddm_copy["y"] = y + + ddm_model_xy = hssm.HSSM( + data=data_ddm_copy, model_config=dict(extra_fields=["x", "y"]), p_outlier=None + ) + + np.testing.assert_almost_equal( + pm.logp(DDM.dist(**true_values), data_ddm).eval(), + pm.logp(ddm_model_xy.model_distribution.dist(**true_values), data_ddm).eval(), + ) + + ddm_model = hssm.HSSM(data=data_ddm) + ddm_model_p = hssm.HSSM( + data=data_ddm_copy, model_config=dict(extra_fields=["x", "y"]) + ) + np.testing.assert_almost_equal( + pm.logp( + ddm_model.model_distribution.dist(**true_values, p_outlier=0.05), data_ddm + ).eval(), + pm.logp( + ddm_model_p.model_distribution.dist(**true_values, p_outlier=0.05), + data_ddm, + ).eval(), + )