Skip to content

Commit

Permalink
make lprior tags work also for non-coef priors
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Jan 24, 2025
1 parent df007fb commit da31932
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 15 deletions.
33 changes: 24 additions & 9 deletions R/stan-prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL,
if (nrow(upx) > 1L) {
# TODO: find a better solution to handle this case
# can only happen for SD parameters of the same ID
base_prior <- lb <- ub <- rep(NA, nrow(upx))
base_prior <- base_lprior_tag <- lb <- ub <- rep(NA, nrow(upx))
base_bounds <- data.frame(lb = lb, ub = ub)
for (i in seq_rows(upx)) {
sub_upx <- lapply(upx[i, ], function(x) c(x, ""))
sub_prior <- subset2(prior, ls = sub_upx)
base_prior[i] <- stan_base_prior(sub_prior)
base_lprior_tag[i] <- stan_base_prior(sub_prior, col = "lprior")
base_bounds[i, ] <- stan_base_prior(sub_prior, col = c("lb", "ub"))
}
if (length(unique(base_prior)) > 1L) {
Expand All @@ -61,15 +62,19 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL,
prior_of_coefs <- prior[take_coef_prior, vars_prefix()]
take_base_prior <- match_rows(prior_of_coefs, upx)
prior$prior[take_coef_prior] <- base_prior[take_base_prior]
prior$lprior[take_coef_prior] <- base_lprior_tag[take_base_prior]
}
base_prior <- base_prior[1]
base_lprior_tag <- base_lprior_tag[1]
if (nrow(unique(base_bounds)) > 1L) {
stop2("Conflicting boundary information for ",
"coefficients of class '", class, "'.")
}
base_bounds <- base_bounds[1, ]
} else {
# TODO: select base_prior together with tags and boundaries in one call?
base_prior <- stan_base_prior(prior)
base_lprior_tag <- stan_base_prior(prior, col = "lprior")
# select both bounds together so that they come from the same base prior
base_bounds <- stan_base_prior(prior, col = c("lb", "ub"))
}
Expand Down Expand Up @@ -132,13 +137,15 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL,
coef_prior, par_ij, broadcast = broadcast,
bound = bound, resp = px$resp[1], normalize = normalize
)
# add to the lprior
str_add(out$tpar_prior) <- paste0(lpp(), coef_prior, ";\n")
# add to the lprior of the tag if specified
if (!is.null(lprior_tag) && lprior_tag != "") {
str_add(out$tpar_prior) <- paste0(lpp(tag = lprior_tag), coef_prior, ";\n")
if (isTRUE(nzchar(lprior_tag))) {
# add to a local lprior variable if specified
str_add(out$tpar_prior) <- paste0(
lpp(tag = lprior_tag), coef_prior, ";\n"
)
} else {
# add to the global lprior variable directly
str_add(out$tpar_prior) <- paste0(lpp(), coef_prior, ";\n")
}

}
}
}
Expand Down Expand Up @@ -180,7 +187,15 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL,
base_prior, par = par, ncoef = ncoef, bound = bound,
broadcast = broadcast, resp = px$resp[1], normalize = normalize
)
str_add(out$tpar_prior) <- paste0(lpp(), target_base_prior, ";\n")
if (isTRUE(nzchar(base_lprior_tag))) {
# add to a local lprior variable if specified
str_add(out$tpar_prior) <- paste0(
lpp(tag = base_lprior_tag), target_base_prior, ";\n"
)
} else {
# add to the global lprior variable directly
str_add(out$tpar_prior) <- paste0(lpp(), target_base_prior, ";\n")
}
}
}

Expand Down Expand Up @@ -225,7 +240,7 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL,
# finding the base prior
# @return the 'col' columns of the identified base prior
stan_base_prior <- function(prior, col = "prior", sel_prior = NULL, ...) {
stopifnot(all(col %in% c("prior", "lb", "ub")))
stopifnot(all(col %in% c("prior", "lb", "ub", "lprior")))
if (!is.null(sel_prior)) {
# find the base prior using sel_prior for subsetting
stopifnot(is.brmsprior(sel_prior))
Expand Down
19 changes: 13 additions & 6 deletions R/stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,6 @@ stancode.default <- function(object, data, family = gaussian(),
backend = getOption("brms.backend", "rstan"),
silent = TRUE, save_model = NULL, ...) {

lprior_tags <- prior$lprior[prior$lprior != ""]

normalize <- as_one_logical(normalize)
parse <- as_one_logical(parse)
backend <- match.arg(backend, backend_choices())
Expand Down Expand Up @@ -278,17 +276,24 @@ stancode.default <- function(object, data, family = gaussian(),
"}\n"
)

# prepare lprior tags
lprior_tags <- unique(prior$lprior)
scode_lprior_def <- paste0(
" // prior contributions to the log posterior\n",
collapse(" real lprior", usc(lprior_tags), " = 0;\n")
)
lprior_tags <- lprior_tags[nzchar(lprior_tags)]
scode_lprior_assign <- str_if(length(lprior_tags),
collapse(" lprior += lprior", usc(lprior_tags), ";\n")
)

# generate transformed parameters block
scode_lprior_def <- " real lprior = 0; // prior contributions to the log posterior\n"
scode_lprior_tags_def <- paste0(
" real lprior_", unique(lprior_tags), " = 0;\n", collapse = "")
scode_transformed_parameters <- paste0(
"transformed parameters {\n",
scode_predictor[["tpar_def"]],
scode_re[["tpar_def"]],
scode_Xme[["tpar_def"]],
str_if(normalize, scode_lprior_def),
str_if(normalize, scode_lprior_tags_def),
collapse_stanvars(stanvars, "tparameters", "start"),
scode_predictor[["tpar_prior_const"]],
scode_re[["tpar_prior_const"]],
Expand All @@ -300,6 +305,7 @@ stancode.default <- function(object, data, family = gaussian(),
# lprior cannot contain _lupdf functions in transformed parameters
# as discussed on github.com/stan-dev/stan/issues/3094
str_if(normalize, scode_tpar_prior),
str_if(normalize, scode_lprior_assign),
collapse_stanvars(stanvars, "tparameters", "end"),
"}\n"
)
Expand All @@ -316,6 +322,7 @@ stancode.default <- function(object, data, family = gaussian(),
" }\n",
" // priors", not_const, " including constants\n",
str_if(!normalize, scode_tpar_prior),
str_if(!normalize, scode_lprior_assign),
" target += lprior;\n",
scode_predictor[["model_prior"]],
scode_re[["model_prior"]],
Expand Down
1 change: 1 addition & 0 deletions man/set_prior.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit da31932

Please sign in to comment.