Skip to content

Commit

Permalink
Add none target
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Jul 1, 2024
1 parent 174d8c9 commit f46a1d9
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 37 deletions.
3 changes: 3 additions & 0 deletions simplex_transforms/stan/targets/none/none_data.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
data {
int<lower=1> N;
}
8 changes: 8 additions & 0 deletions simplex_transforms/stan/targets/none/none_functions.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
real none_lpdf(vector theta) {
return 0;
}

real log_none_lpdf(vector log_theta) {
int N = rows(log_theta);
return sum(log_theta[1 : N - 1]);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
model {
target += log_none_lpdf(log_x);
}
3 changes: 3 additions & 0 deletions simplex_transforms/stan/targets/none/none_model_simplex.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
model {
target += none_lpdf(x);
}
105 changes: 68 additions & 37 deletions tests/stan/test_stan_transforms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os

import arviz as az
Expand Down Expand Up @@ -34,6 +35,8 @@
targets_dir = os.path.join(project_dir, "targets")
transforms_dir = os.path.join(project_dir, "transforms")
stan_models = {}
bridgestan_models = {}
bridgestan_make_args = ["STAN_THREADS=true", "BRIDGESTAN_AD_HESSIAN=true"]


def make_dirichlet_data(N: int, seed: int = 638):
Expand Down Expand Up @@ -77,19 +80,6 @@ def make_jax_distribution(target: str, params: dict):
raise ValueError(f"Unknown target {target}")


def make_stan_model(
model_file: str, target_name: str, transform_name: str, log_scale: bool
) -> tuple[cmdstanpy.CmdStanModel, list[str]]:
model_code, include_paths = simplex_transforms.stan.make_stan_code(
target_name, transform_name, log_scale
)
with open(model_file, "w") as f:
f.write(model_code)
stanc_options = {"include-paths": ",".join(include_paths)}
model = cmdstanpy.CmdStanModel(stan_file=model_file, stanc_options=stanc_options)
return model, include_paths


@pytest.mark.parametrize("N", [3, 5])
@pytest.mark.parametrize("log_scale", [False, True])
@pytest.mark.parametrize("target_name", ["dirichlet", "multi-logit-normal"])
Expand All @@ -108,45 +98,36 @@ def test_stan_and_jax_transforms_consistent(
if target_name != "dirichlet" and transform_name not in ["ALR", "ILR"]:
pytest.skip(f"No need to test {transform_name} with {target_name}. Skipping.")

constrain_with_logdetjac_vec = jax.vmap(
jax.vmap(trans.constrain_with_logdetjac, 0), 0
)

data = make_model_data(target_name, N, seed=seed)
dist = make_jax_distribution(target_name, data)
log_prob = dist.log_prob

# get compiled model or compile and add to cache
model_key = (target_name, transform_name, log_scale)
if model_key not in stan_models:
stan_code, include_paths = simplex_transforms.stan.make_stan_code(
target_name, transform_name, log_scale
)
# save Stan code to file
stan_file = os.path.join(
tmpdir,
f"{target_name}_{transform_name}_{'log_simplex' if log_scale else 'simplex'}.stan",
)
model, include_paths = make_stan_model(
stan_file,
target_name,
transform_name,
log_scale,
)
with open(stan_file, "w") as f:
f.write(stan_code)

# compile cmdstanpy model
stanc_options = {"include-paths": ",".join(include_paths)}
model = cmdstanpy.CmdStanModel(stan_file=stan_file, stanc_options=stanc_options)
stan_models[model_key] = model

# check that we can compile the bridgestan model
stanc_args = ["--include-paths=" + ",".join(include_paths)]
stan_version = cmdstanpy.cmdstan_version()
if stan_version is None:
raise ValueError(
"Could not determine cmdstan version. It must be installed."
)
stan_version = ".".join([str(i) for i in stan_version])
make_args = [
"STAN_THREADS=true",
"BRIDGESTAN_AD_HESSIAN=true",
f"STANC3_VERSION={stan_version}",
]
bridgestan.compile_model(stan_file, stanc_args=stanc_args, make_args=make_args)
bridgestan_models[model_key] = bridgestan.compile_model(
stan_file, stanc_args=stanc_args, make_args=bridgestan_make_args
)
else:
model = stan_models[(target_name, transform_name, log_scale)]
model = stan_models[model_key]

result = model.sample(data=data, iter_sampling=100, sig_figs=18, seed=stan_seed)
idata = az.convert_to_inference_data(result)
Expand All @@ -155,10 +136,60 @@ def test_stan_and_jax_transforms_consistent(
y = trans.unconstrain(idata.posterior.x.data)
else:
y = idata.posterior.y.data
x_expected, lp_expected = constrain_with_logdetjac_vec(y)
x_expected, lp_expected = trans.constrain_with_logdetjac(y)
if transform_name in expanded_transforms:
r_expected, x_expected = x_expected
lp_expected += trans.default_prior(x_expected).log_prob(r_expected)
lp_expected += log_prob(x_expected)
assert jnp.allclose(x_expected, idata.posterior.x.data, rtol=1e-4)
assert jnp.allclose(lp_expected, idata.sample_stats.lp.data, rtol=1e-4)


@pytest.mark.parametrize("N", [3, 5])
@pytest.mark.parametrize("log_scale", [False, True])
@pytest.mark.parametrize("transform_name", ["ALR", "ExpandedSoftmax"])
def test_none_target(tmpdir, transform_name, N, log_scale, seed=638):
target_name = "none"
trans = getattr(jax_transforms, transform_name)()

data = {"N": N}
data_str = json.dumps(data)

# get compiled model or compile and add to cache
model_key = (target_name, transform_name, log_scale)
if model_key not in bridgestan_models:
stan_code, include_paths = simplex_transforms.stan.make_stan_code(
target_name,
transform_name,
log_scale,
)
# save Stan code to file
stan_file = os.path.join(
tmpdir,
f"{target_name}_{transform_name}_{'log_simplex' if log_scale else 'simplex'}.stan",
)
with open(stan_file, "w") as f:
f.write(stan_code)

# check that we can compile the bridgestan model
stanc_args = ["--include-paths=" + ",".join(include_paths)]
model_file = bridgestan.compile_model(
stan_file, stanc_args=stanc_args, make_args=bridgestan_make_args
)
bridgestan_models[model_key] = model_file
else:
model_file = bridgestan_models[model_key]
model = bridgestan.StanModel(model_file, data=data_str)

M = N - (transform_name in basic_transforms)

y = np.random.default_rng(seed).normal(size=(100, M))
x_expected, lp_expected = trans.constrain_with_logdetjac(y)
if transform_name in expanded_transforms:
r_expected, x_expected = x_expected
lp_expected += trans.default_prior(x_expected).log_prob(r_expected)

lp = np.apply_along_axis(
lambda y: model.log_density(y, propto=False, jacobian=True), -1, y
)
assert np.allclose(lp, lp_expected, rtol=1e-4)

0 comments on commit f46a1d9

Please sign in to comment.