From f46a1d982461a8963a2a52952264260707534764 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 1 Jul 2024 19:13:01 +0200 Subject: [PATCH] Add none target --- .../stan/targets/none/none_data.stan | 3 + .../stan/targets/none/none_functions.stan | 8 ++ .../targets/none/none_model_log_simplex.stan | 3 + .../stan/targets/none/none_model_simplex.stan | 3 + tests/stan/test_stan_transforms.py | 105 ++++++++++++------ 5 files changed, 85 insertions(+), 37 deletions(-) create mode 100644 simplex_transforms/stan/targets/none/none_data.stan create mode 100644 simplex_transforms/stan/targets/none/none_functions.stan create mode 100644 simplex_transforms/stan/targets/none/none_model_log_simplex.stan create mode 100644 simplex_transforms/stan/targets/none/none_model_simplex.stan diff --git a/simplex_transforms/stan/targets/none/none_data.stan b/simplex_transforms/stan/targets/none/none_data.stan new file mode 100644 index 0000000..6dd30ee --- /dev/null +++ b/simplex_transforms/stan/targets/none/none_data.stan @@ -0,0 +1,3 @@ +data { + int N; +} diff --git a/simplex_transforms/stan/targets/none/none_functions.stan b/simplex_transforms/stan/targets/none/none_functions.stan new file mode 100644 index 0000000..2f21e51 --- /dev/null +++ b/simplex_transforms/stan/targets/none/none_functions.stan @@ -0,0 +1,8 @@ +real none_lpdf(vector theta) { + return 0; +} + +real log_none_lpdf(vector log_theta) { + int N = rows(log_theta); + return sum(log_theta[1 : N - 1]); +} diff --git a/simplex_transforms/stan/targets/none/none_model_log_simplex.stan b/simplex_transforms/stan/targets/none/none_model_log_simplex.stan new file mode 100644 index 0000000..ae28a24 --- /dev/null +++ b/simplex_transforms/stan/targets/none/none_model_log_simplex.stan @@ -0,0 +1,3 @@ +model { + target += log_none_lpdf(log_x); +} diff --git a/simplex_transforms/stan/targets/none/none_model_simplex.stan b/simplex_transforms/stan/targets/none/none_model_simplex.stan new file mode 100644 index 0000000..6153fcb --- /dev/null +++ b/simplex_transforms/stan/targets/none/none_model_simplex.stan @@ -0,0 +1,3 @@ +model { + target += none_lpdf(x); +} diff --git a/tests/stan/test_stan_transforms.py b/tests/stan/test_stan_transforms.py index 245b142..fe33c35 100644 --- a/tests/stan/test_stan_transforms.py +++ b/tests/stan/test_stan_transforms.py @@ -1,3 +1,4 @@ +import json import os import arviz as az @@ -34,6 +35,8 @@ targets_dir = os.path.join(project_dir, "targets") transforms_dir = os.path.join(project_dir, "transforms") stan_models = {} +bridgestan_models = {} +bridgestan_make_args = ["STAN_THREADS=true", "BRIDGESTAN_AD_HESSIAN=true"] def make_dirichlet_data(N: int, seed: int = 638): @@ -77,19 +80,6 @@ def make_jax_distribution(target: str, params: dict): raise ValueError(f"Unknown target {target}") -def make_stan_model( - model_file: str, target_name: str, transform_name: str, log_scale: bool -) -> tuple[cmdstanpy.CmdStanModel, list[str]]: - model_code, include_paths = simplex_transforms.stan.make_stan_code( - target_name, transform_name, log_scale - ) - with open(model_file, "w") as f: - f.write(model_code) - stanc_options = {"include-paths": ",".join(include_paths)} - model = cmdstanpy.CmdStanModel(stan_file=model_file, stanc_options=stanc_options) - return model, include_paths - - @pytest.mark.parametrize("N", [3, 5]) @pytest.mark.parametrize("log_scale", [False, True]) @pytest.mark.parametrize("target_name", ["dirichlet", "multi-logit-normal"]) @@ -108,10 +98,6 @@ def test_stan_and_jax_transforms_consistent( if target_name != "dirichlet" and transform_name not in ["ALR", "ILR"]: pytest.skip(f"No need to test {transform_name} with {target_name}. Skipping.") - constrain_with_logdetjac_vec = jax.vmap( - jax.vmap(trans.constrain_with_logdetjac, 0), 0 - ) - data = make_model_data(target_name, N, seed=seed) dist = make_jax_distribution(target_name, data) log_prob = dist.log_prob @@ -119,34 +105,29 @@ def test_stan_and_jax_transforms_consistent( # get compiled model or compile and add to cache model_key = (target_name, transform_name, log_scale) if model_key not in stan_models: + stan_code, include_paths = simplex_transforms.stan.make_stan_code( + target_name, transform_name, log_scale + ) + # save Stan code to file stan_file = os.path.join( tmpdir, f"{target_name}_{transform_name}_{'log_simplex' if log_scale else 'simplex'}.stan", ) - model, include_paths = make_stan_model( - stan_file, - target_name, - transform_name, - log_scale, - ) + with open(stan_file, "w") as f: + f.write(stan_code) + + # compile cmdstanpy model + stanc_options = {"include-paths": ",".join(include_paths)} + model = cmdstanpy.CmdStanModel(stan_file=stan_file, stanc_options=stanc_options) stan_models[model_key] = model # check that we can compile the bridgestan model stanc_args = ["--include-paths=" + ",".join(include_paths)] - stan_version = cmdstanpy.cmdstan_version() - if stan_version is None: - raise ValueError( - "Could not determine cmdstan version. It must be installed." - ) - stan_version = ".".join([str(i) for i in stan_version]) - make_args = [ - "STAN_THREADS=true", - "BRIDGESTAN_AD_HESSIAN=true", - f"STANC3_VERSION={stan_version}", - ] - bridgestan.compile_model(stan_file, stanc_args=stanc_args, make_args=make_args) + bridgestan_models[model_key] = bridgestan.compile_model( + stan_file, stanc_args=stanc_args, make_args=bridgestan_make_args + ) else: - model = stan_models[(target_name, transform_name, log_scale)] + model = stan_models[model_key] result = model.sample(data=data, iter_sampling=100, sig_figs=18, seed=stan_seed) idata = az.convert_to_inference_data(result) @@ -155,10 +136,60 @@ def test_stan_and_jax_transforms_consistent( y = trans.unconstrain(idata.posterior.x.data) else: y = idata.posterior.y.data - x_expected, lp_expected = constrain_with_logdetjac_vec(y) + x_expected, lp_expected = trans.constrain_with_logdetjac(y) if transform_name in expanded_transforms: r_expected, x_expected = x_expected lp_expected += trans.default_prior(x_expected).log_prob(r_expected) lp_expected += log_prob(x_expected) assert jnp.allclose(x_expected, idata.posterior.x.data, rtol=1e-4) assert jnp.allclose(lp_expected, idata.sample_stats.lp.data, rtol=1e-4) + + +@pytest.mark.parametrize("N", [3, 5]) +@pytest.mark.parametrize("log_scale", [False, True]) +@pytest.mark.parametrize("transform_name", ["ALR", "ExpandedSoftmax"]) +def test_none_target(tmpdir, transform_name, N, log_scale, seed=638): + target_name = "none" + trans = getattr(jax_transforms, transform_name)() + + data = {"N": N} + data_str = json.dumps(data) + + # get compiled model or compile and add to cache + model_key = (target_name, transform_name, log_scale) + if model_key not in bridgestan_models: + stan_code, include_paths = simplex_transforms.stan.make_stan_code( + target_name, + transform_name, + log_scale, + ) + # save Stan code to file + stan_file = os.path.join( + tmpdir, + f"{target_name}_{transform_name}_{'log_simplex' if log_scale else 'simplex'}.stan", + ) + with open(stan_file, "w") as f: + f.write(stan_code) + + # check that we can compile the bridgestan model + stanc_args = ["--include-paths=" + ",".join(include_paths)] + model_file = bridgestan.compile_model( + stan_file, stanc_args=stanc_args, make_args=bridgestan_make_args + ) + bridgestan_models[model_key] = model_file + else: + model_file = bridgestan_models[model_key] + model = bridgestan.StanModel(model_file, data=data_str) + + M = N - (transform_name in basic_transforms) + + y = np.random.default_rng(seed).normal(size=(100, M)) + x_expected, lp_expected = trans.constrain_with_logdetjac(y) + if transform_name in expanded_transforms: + r_expected, x_expected = x_expected + lp_expected += trans.default_prior(x_expected).log_prob(r_expected) + + lp = np.apply_along_axis( + lambda y: model.log_density(y, propto=False, jacobian=True), -1, y + ) + assert np.allclose(lp, lp_expected, rtol=1e-4)