From da319320241970760d0a3c8169918375f50db60f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul-Christian=20B=C3=BCrkner?= Date: Fri, 24 Jan 2025 14:12:45 +0100 Subject: [PATCH] make lprior tags work also for non-coef priors --- R/stan-prior.R | 33 ++++++++++++++++++++++++--------- R/stancode.R | 19 +++++++++++++------ man/set_prior.Rd | 1 + 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/R/stan-prior.R b/R/stan-prior.R index 92fd1585a..1b9241bb2 100644 --- a/R/stan-prior.R +++ b/R/stan-prior.R @@ -46,12 +46,13 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL, if (nrow(upx) > 1L) { # TODO: find a better solution to handle this case # can only happen for SD parameters of the same ID - base_prior <- lb <- ub <- rep(NA, nrow(upx)) + base_prior <- base_lprior_tag <- lb <- ub <- rep(NA, nrow(upx)) base_bounds <- data.frame(lb = lb, ub = ub) for (i in seq_rows(upx)) { sub_upx <- lapply(upx[i, ], function(x) c(x, "")) sub_prior <- subset2(prior, ls = sub_upx) base_prior[i] <- stan_base_prior(sub_prior) + base_lprior_tag[i] <- stan_base_prior(sub_prior, col = "lprior") base_bounds[i, ] <- stan_base_prior(sub_prior, col = c("lb", "ub")) } if (length(unique(base_prior)) > 1L) { @@ -61,15 +62,19 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL, prior_of_coefs <- prior[take_coef_prior, vars_prefix()] take_base_prior <- match_rows(prior_of_coefs, upx) prior$prior[take_coef_prior] <- base_prior[take_base_prior] + prior$lprior[take_coef_prior] <- base_lprior_tag[take_base_prior] } base_prior <- base_prior[1] + base_lprior_tag <- base_lprior_tag[1] if (nrow(unique(base_bounds)) > 1L) { stop2("Conflicting boundary information for ", "coefficients of class '", class, "'.") } base_bounds <- base_bounds[1, ] } else { + # TODO: select base_prior together with tags and boundaries in one call? base_prior <- stan_base_prior(prior) + base_lprior_tag <- stan_base_prior(prior, col = "lprior") # select both bounds together so that they come from the same base prior base_bounds <- stan_base_prior(prior, col = c("lb", "ub")) } @@ -132,13 +137,15 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL, coef_prior, par_ij, broadcast = broadcast, bound = bound, resp = px$resp[1], normalize = normalize ) - # add to the lprior - str_add(out$tpar_prior) <- paste0(lpp(), coef_prior, ";\n") - # add to the lprior of the tag if specified - if (!is.null(lprior_tag) && lprior_tag != "") { - str_add(out$tpar_prior) <- paste0(lpp(tag = lprior_tag), coef_prior, ";\n") + if (isTRUE(nzchar(lprior_tag))) { + # add to a local lprior variable if specified + str_add(out$tpar_prior) <- paste0( + lpp(tag = lprior_tag), coef_prior, ";\n" + ) + } else { + # add to the global lprior variable directly + str_add(out$tpar_prior) <- paste0(lpp(), coef_prior, ";\n") } - } } } @@ -180,7 +187,15 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL, base_prior, par = par, ncoef = ncoef, bound = bound, broadcast = broadcast, resp = px$resp[1], normalize = normalize ) - str_add(out$tpar_prior) <- paste0(lpp(), target_base_prior, ";\n") + if (isTRUE(nzchar(base_lprior_tag))) { + # add to a local lprior variable if specified + str_add(out$tpar_prior) <- paste0( + lpp(tag = base_lprior_tag), target_base_prior, ";\n" + ) + } else { + # add to the global lprior variable directly + str_add(out$tpar_prior) <- paste0(lpp(), target_base_prior, ";\n") + } } } @@ -225,7 +240,7 @@ stan_prior <- function(prior, class, coef = NULL, group = NULL, # finding the base prior # @return the 'col' columns of the identified base prior stan_base_prior <- function(prior, col = "prior", sel_prior = NULL, ...) { - stopifnot(all(col %in% c("prior", "lb", "ub"))) + stopifnot(all(col %in% c("prior", "lb", "ub", "lprior"))) if (!is.null(sel_prior)) { # find the base prior using sel_prior for subsetting stopifnot(is.brmsprior(sel_prior)) diff --git a/R/stancode.R b/R/stancode.R index 4b0ae17b5..897c1d031 100644 --- a/R/stancode.R +++ b/R/stancode.R @@ -114,8 +114,6 @@ stancode.default <- function(object, data, family = gaussian(), backend = getOption("brms.backend", "rstan"), silent = TRUE, save_model = NULL, ...) { - lprior_tags <- prior$lprior[prior$lprior != ""] - normalize <- as_one_logical(normalize) parse <- as_one_logical(parse) backend <- match.arg(backend, backend_choices()) @@ -278,17 +276,24 @@ stancode.default <- function(object, data, family = gaussian(), "}\n" ) + # prepare lprior tags + lprior_tags <- unique(prior$lprior) + scode_lprior_def <- paste0( + " // prior contributions to the log posterior\n", + collapse(" real lprior", usc(lprior_tags), " = 0;\n") + ) + lprior_tags <- lprior_tags[nzchar(lprior_tags)] + scode_lprior_assign <- str_if(length(lprior_tags), + collapse(" lprior += lprior", usc(lprior_tags), ";\n") + ) + # generate transformed parameters block - scode_lprior_def <- " real lprior = 0; // prior contributions to the log posterior\n" - scode_lprior_tags_def <- paste0( - " real lprior_", unique(lprior_tags), " = 0;\n", collapse = "") scode_transformed_parameters <- paste0( "transformed parameters {\n", scode_predictor[["tpar_def"]], scode_re[["tpar_def"]], scode_Xme[["tpar_def"]], str_if(normalize, scode_lprior_def), - str_if(normalize, scode_lprior_tags_def), collapse_stanvars(stanvars, "tparameters", "start"), scode_predictor[["tpar_prior_const"]], scode_re[["tpar_prior_const"]], @@ -300,6 +305,7 @@ stancode.default <- function(object, data, family = gaussian(), # lprior cannot contain _lupdf functions in transformed parameters # as discussed on github.com/stan-dev/stan/issues/3094 str_if(normalize, scode_tpar_prior), + str_if(normalize, scode_lprior_assign), collapse_stanvars(stanvars, "tparameters", "end"), "}\n" ) @@ -316,6 +322,7 @@ stancode.default <- function(object, data, family = gaussian(), " }\n", " // priors", not_const, " including constants\n", str_if(!normalize, scode_tpar_prior), + str_if(!normalize, scode_lprior_assign), " target += lprior;\n", scode_predictor[["model_prior"]], scode_re[["model_prior"]], diff --git a/man/set_prior.Rd b/man/set_prior.Rd index eaf678cdb..e0d3c0e7f 100644 --- a/man/set_prior.Rd +++ b/man/set_prior.Rd @@ -20,6 +20,7 @@ set_prior( nlpar = "", lb = NA, ub = NA, + lprior = "", check = TRUE )