diff --git a/NAMESPACE b/NAMESPACE index 98da81f54..177c4be3b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -565,6 +565,7 @@ export(ranef) export(rasym_laplace) export(rbeta_binomial) export(rdirichlet) +export(re) export(read_csv_as_stanfit) export(recompile_model) export(reloo) diff --git a/NEWS.md b/NEWS.md index d83e4ffb8..f2e58c57e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,11 @@ # brms 3.0 +### New Features + +* Use varying coefficients as predictors in other model parts via +`re` predictor terms. (#1687) + + # brms 2.23 ### New Features diff --git a/R/brmsframe.R b/R/brmsframe.R index 0c45521c1..ea7ee11c8 100644 --- a/R/brmsframe.R +++ b/R/brmsframe.R @@ -34,7 +34,6 @@ brmsframe.brmsterms <- function(x, data, frame = NULL, basis = NULL, ...) { # this must be a multivariate model stopifnot(is.list(frame)) x$frame <- frame - x$frame$re <- subset2(x$frame$re, resp = x$resp) } data <- subset_data(data, x) x$frame$resp <- frame_resp(x, data = data) @@ -51,6 +50,10 @@ brmsframe.brmsterms <- function(x, data, frame = NULL, basis = NULL, ...) { basis = basis$nlpars[[nlp]], ... ) } + # If this is a multivariate model, retain only the subset of random effects + # belonging to the current response variable. Subsetting is performed here + # rather than earlier to allow for correct validation of 're' terms in stan_sp. + x$frame$re <- subset2(x$frame$re, resp = x$resp) class(x) <- c("brmsframe", class(x)) x } @@ -79,8 +82,8 @@ brmsframe.btl <- function(x, data, frame = list(), basis = NULL, ...) 
{ x$frame$sp <- frame_sp(x, data = data) x$frame$gp <- frame_gp(x, data = data) x$frame$ac <- frame_ac(x, data = data) - # only store the ranefs of this specific linear formula - x$frame$re <- subset2(frame$re, ls = check_prefix(x)) + # only keep the ranefs of this specific linear formula + x$frame$re <- subset2(x$frame$re, ls = check_prefix(x)) class(x) <- c("bframel", class(x)) # these data_ functions may require the outputs of the corresponding # frame_ functions (but not vice versa) and are thus evaluated last diff --git a/R/brmsterms.R b/R/brmsterms.R index 15090769d..71f52d131 100644 --- a/R/brmsterms.R +++ b/R/brmsterms.R @@ -434,23 +434,22 @@ terms_cs <- function(formula) { # extract special effects terms terms_sp <- function(formula) { - types <- c("mo", "me", "mi") - out <- find_terms(formula, types, complete = FALSE) + out <- find_terms(formula, all_sp_types(), complete = FALSE) if (!length(out)) { return(NULL) } uni_mo <- get_matches_expr(regex_sp("mo"), out) uni_me <- get_matches_expr(regex_sp("me"), out) uni_mi <- get_matches_expr(regex_sp("mi"), out) + uni_re <- get_matches_expr(regex_sp("re"), out) # remove the intercept as it is handled separately out <- str2formula(c("0", out)) attr(out, "int") <- FALSE attr(out, "uni_mo") <- uni_mo attr(out, "uni_me") <- uni_me attr(out, "uni_mi") <- uni_mi + attr(out, "uni_re") <- uni_re attr(out, "allvars") <- str2formula(all_vars(out)) - # TODO: do we need sp_fake_formula at all? - # attr(out, "allvars") <- sp_fake_formula(uni_mo, uni_me, uni_mi) out } diff --git a/R/formula-re.R b/R/formula-re.R index aaae9c4d1..487e2aeb3 100644 --- a/R/formula-re.R +++ b/R/formula-re.R @@ -523,6 +523,7 @@ get_re.btl <- function(x, ...) 
{ # id: ID of the group-level effect # group: name of the grouping factor # gn: number of the grouping term within the respective formula +# gtype: type of the grouping term: 'gr' or 'mm' # coef: name of the group-level effect # cn: number of the effect within the ID # resp: name of the response variable @@ -729,7 +730,7 @@ frame_re_levels_only <- function(bterms, data) { empty_reframe <- function() { out <- data.frame( - id = numeric(0), group = character(0), gn = numeric(0), + id = numeric(0), group = character(0), gn = numeric(0), gtype = character(0), coef = character(0), cn = numeric(0), resp = character(0), dpar = character(0), nlpar = character(0), ggn = numeric(0), cor = logical(0), type = character(0), form = character(0), @@ -755,6 +756,17 @@ is.reframe <- function(x) { inherits(x, "reframe") } +# helper function to find matching rows in reframes +# @param x the reframe to be matched +# @param y the reference reframe to be matched against +# @return an integer vector of matching rows +which_rows_reframe <- function(x, y) { + stopifnot(is.reframe(x), is.reframe(y)) + # these columns define a row uniquely in reframes + cols <- c("group", "coef", "resp", "dpar", "nlpar") + which_rows(x, ls = y[cols]) +} + # extract names of all grouping variables get_group_vars <- function(x, ...) { UseMethod("get_group_vars") diff --git a/R/formula-sp.R b/R/formula-sp.R index da5d1a1ee..54ab3cd6b 100644 --- a/R/formula-sp.R +++ b/R/formula-sp.R @@ -211,6 +211,67 @@ mo <- function(x, id = NA) { out } +#' Group-level effects as predictors in \pkg{brms} Models +#' +#' Specify a group-level predictor term in \pkg{brms}. That is, +#' use group-level effects defined somewhere in the model as +#' predictors in another part of the model. The function does not +#' evaluate its arguments -- it exists purely to help set up a model. +#' +#' @param gr Name of the grouping factor of the group-level effect +#' to be used as predictor. 
+#' @param coef Optional name of the coefficient of the group-level effect. +#' Defaults to \code{"Intercept"}. +#' @param resp Optional name of the response variable of the group-level effect. +#' @param dpar Optional name of the distributional parameter of the group-level effect. +#' @param nlpar Optional name of the non-linear parameter of the group-level effect. +#' +#' @seealso \code{\link{brmsformula}} +#' +#' @examples +#' \dontrun{ +#' # use the group-level intercept of 'AY' for parameter 'ult' +#' # as predictor for the residual standard deviation 'sigma' +#' # multiplying by 1000 reduces the scale of 'ult' to roughly unity +#' bform <- bf( +#' cum ~ 1000 * ult * (1 - exp(-(dev/theta)^omega)), +#' ult ~ 1 + (1|AY), omega ~ 1, theta ~ 1, +#' sigma ~ re(AY, nlpar = "ult"), +#' nl = TRUE +#' ) +#' bprior <- c( +#' prior(normal(5, 1), nlpar = "ult"), +#' prior(normal(1, 2), nlpar = "omega"), +#' prior(normal(45, 10), nlpar = "theta"), +#' prior(normal(0, 0.5), dpar = "sigma") +#' ) +#' +#' fit <- brm( +#' bform, data = loss, +#' family = gaussian(), +#' prior = bprior, +#' control = list(adapt_delta = 0.9), +#' chains = 2 +#' ) +#' summary(fit) +#' +#' # shows how sigma varies as a function of the AY levels +#' conditional_effects(fit, "AY", dpar = "sigma", re_formula = NULL) +#' } +#' +#' @export +re <- function(gr, coef = "Intercept", resp = "", dpar = "", nlpar = "") { + term <- as_one_character(deparse_no_string(substitute(gr))) + coef <- as_one_character(coef) + resp <- as_one_character(resp) + dpar <- as_one_character(dpar) + nlpar <- as_one_character(nlpar) + label <- deparse0(match.call()) + out <- nlist(term, coef, resp, dpar, nlpar, label) + class(out) <- c("re_term", "sp_term") + out +} + # find variable names for which to keep NAs vars_keep_na <- function(x, ...) 
{ UseMethod("vars_keep_na") @@ -351,8 +412,7 @@ get_sp_vars <- function(x, type) { } # gather information of special effects terms -# @param x either a formula or a list containing an element "sp" -# @param data data frame containing the monotonic variables +# @param x a formula, brmsterms, or brmsframe object # @return a data.frame with one row per special term # TODO: refactor to store in long format to avoid several list columns? frame_sp <- function(x, data) { @@ -367,7 +427,7 @@ frame_sp <- function(x, data) { out <- data.frame(term = colnames(mm), stringsAsFactors = FALSE) out$coef <- rename(out$term) calls_cols <- c(paste0("calls_", all_sp_types()), "joint_call") - list_cols <- c("vars_mi", "idx_mi", "idx2_mi", "ids_mo", "Imo") + list_cols <- c("vars_mi", "idx_mi", "idx2_mi", "ids_mo", "Imo", "reframe") for (col in c(calls_cols, list_cols)) { out[[col]] <- vector("list", nrow(out)) } @@ -376,7 +436,7 @@ frame_sp <- function(x, data) { for (i in seq_rows(out)) { # prepare mo terms take_mo <- grepl_expr(regex_sp("mo"), terms_split[[i]]) - if (sum(take_mo)) { + if (any(take_mo)) { out$calls_mo[[i]] <- terms_split[[i]][take_mo] nmo <- length(out$calls_mo[[i]]) out$Imo[[i]] <- (kmo + 1):(kmo + nmo) @@ -393,7 +453,7 @@ frame_sp <- function(x, data) { } # prepare me terms take_me <- grepl_expr(regex_sp("me"), terms_split[[i]]) - if (sum(take_me)) { + if (any(take_me)) { out$calls_me[[i]] <- terms_split[[i]][take_me] # remove 'I' (identity) function calls that # were used solely to separate formula terms @@ -401,7 +461,7 @@ frame_sp <- function(x, data) { } # prepare mi terms take_mi <- grepl_expr(regex_sp("mi"), terms_split[[i]]) - if (sum(take_mi)) { + if (any(take_mi)) { mi_parts <- terms_split[[i]][take_mi] out$calls_mi[[i]] <- get_matches_expr(regex_sp("mi"), mi_parts) out$vars_mi[[i]] <- out$idx_mi[[i]] <- rep(NA, length(out$calls_mi[[i]])) @@ -415,6 +475,40 @@ frame_sp <- function(x, data) { # do it like terms_resp to ensure correct matching out$vars_mi[[i]] 
<- gsub("\\.|_", "", make.names(out$vars_mi[[i]])) } + take_re <- grepl_expr(regex_sp("re"), terms_split[[i]]) + if (any(take_re)) { + re_parts <- terms_split[[i]][take_re] + out$calls_re[[i]] <- get_matches_expr(regex_sp("re"), re_parts) + out$reframe[[i]] <- vector("list", length(out$calls_re[[i]])) + for (j in seq_along(out$calls_re[[i]])) { + re_call <- out$calls_re[[i]][[j]] + re_term <- eval2(re_call) + if (!is.null(x$frame$re)) { + stopifnot(is.reframe(x$frame$re)) + cols <- c("coef", "resp", "dpar", "nlpar") + rf <- subset2(x$frame$re, group = re_term$term, ls = re_term[cols]) + # Ideally we should check here if the required re term can be found. + # However this will lead to errors in post-processing even if the + # re terms are not actually evaluated. See prepare_predictions_sp + # for more details. The necessary pre-processing validity check + # is instead done in stan_sp. + # if (!NROW(rf)) { + # stop2("Cannot find varying coefficients belonging to ", re_call, ".") + # } + # there should theoretically never be more than one matching row + stopifnot(NROW(rf) <= 1L) + if (isTRUE(rf$gtype == "mm")) { + stop2("Multimembership terms are not yet supported by 're'.") + } + out$reframe[[i]][[j]] <- rf + } + } + if (!isNULL(out$reframe[[i]])) { + out$reframe[[i]] <- Reduce(rbind, out$reframe[[i]]) + } else { + out$reframe[[i]] <- empty_reframe() + } + } has_sp_calls <- grepl_expr(regex_sp(all_sp_types()), terms_split[[i]]) sp_calls <- sub("^I\\(", "(", terms_split[[i]][has_sp_calls]) out$joint_call[[i]] <- paste0(sp_calls, collapse = " * ") @@ -539,17 +633,6 @@ sp_model_matrix <- function(formula, data, types = all_sp_types(), ...) { out } -# formula of variables used in special effects terms -sp_fake_formula <- function(...) { - dots <- c(...) 
- out <- vector("list", length(dots)) - for (i in seq_along(dots)) { - tmp <- eval2(dots[[i]]) - out[[i]] <- all_vars(c(tmp$term, tmp$sdx, tmp$gr)) - } - str2formula(unique(unlist(out))) -} - # extract an me variable get_me_values <- function(term, data) { term <- get_sp_term(term) @@ -624,7 +707,7 @@ get_sp_term <- function(term) { # all effects which fall under the 'sp' category of brms all_sp_types <- function() { - c("mo", "me", "mi") + c("mo", "me", "mi", "re") } # classes used to set up special effects terms @@ -643,3 +726,7 @@ is.me_term <- function(x) { is.mi_term <- function(x) { inherits(x, "mi_term") } + +is.re_term <- function(x) { + inherits(x, "re_term") +} diff --git a/R/misc.R b/R/misc.R index 292187628..713b98402 100644 --- a/R/misc.R +++ b/R/misc.R @@ -143,6 +143,7 @@ find_elements <- function(x, ..., ls = list(), fun = '%in%') { # find rows of 'x' matching columns passed via 'ls' and '...' # similar to 'find_elements' but for matrix like objects +# TODO: rename ls and fun to .ls and .fun to prevent name clashing find_rows <- function(x, ..., ls = list(), fun = '%in%') { x <- as.data.frame(x) if (!nrow(x)) { @@ -162,6 +163,11 @@ find_rows <- function(x, ..., ls = list(), fun = '%in%') { out } +# short form of which(find_rows()) +which_rows <- function(x, ..., ls = list(), fun = '%in%') { + which(find_rows(x, ..., ls = ls, fun = fun)) +} + # subset 'x' using arguments passed via 'ls' and '...' 
subset2 <- function(x, ..., ls = list(), fun = '%in%') { x[find_rows(x, ..., ls = ls, fun = fun), , drop = FALSE] diff --git a/R/predictor.R b/R/predictor.R index f8bf2192c..42783787f 100644 --- a/R/predictor.R +++ b/R/predictor.R @@ -181,6 +181,13 @@ predictor_sp <- function(prep, i) { for (j in seq_along(sp[["idxl"]])) { eval_list[[names(sp[["idxl"]])[j]]] <- p(sp[["idxl"]][[j]], i, row = FALSE) } + for (j in seq_along(sp[["r"]])) { + # r is not subsetted here since subsetting is handled via Jr + # the advantages of this approach is a reduced memory requirement + # as only the draws per level instead of per observation need to be stored + eval_list[[paste0("r_", j)]] <- sp[["r"]][[j]] + eval_list[[paste0("Jr_", j)]] <- p(sp[["Jr"]][[j]], i) + } for (j in seq_along(sp[["Csp"]])) { eval_list[[paste0("Csp_", j)]] <- p(sp[["Csp"]][[j]], i, row = FALSE) } diff --git a/R/prepare_predictions.R b/R/prepare_predictions.R index 1b0d9b415..78ecf9eb4 100644 --- a/R/prepare_predictions.R +++ b/R/prepare_predictions.R @@ -252,17 +252,19 @@ prepare_predictions_fe <- function(bframe, draws, sdata, ...) { } # prepare predictions of special effects terms -prepare_predictions_sp <- function(bframe, draws, sdata, new = FALSE, ...) { +prepare_predictions_sp <- function(bframe, draws, sdata, prep_re = list(), + new = FALSE, ...) { stopifnot(is.bframel(bframe)) out <- list() spframe <- bframe$frame$sp - meframe <- bframe$frame$me if (!has_rows(spframe)) { return(out) } p <- usc(combine_prefix(bframe)) resp <- usc(bframe$resp) # prepare calls evaluated in sp_predictor + meframe <- bframe$frame$me + reframe <- unique(Reduce(rbind, spframe$reframe)) out$calls <- vector("list", nrow(spframe)) for (i in seq_along(out$calls)) { call <- spframe$joint_call[[i]] @@ -281,6 +283,27 @@ prepare_predictions_sp <- function(bframe, draws, sdata, new = FALSE, ...) 
{ new_mi <- paste0("Yl_", spframe$vars_mi[[i]], idx_mi) call <- rename(call, spframe$calls_mi[[i]], new_mi) } + if (!is.null(spframe$calls_re[[i]])) { + if (NROW(spframe$reframe[[i]]) < length(spframe$calls_re[[i]])) { + # this will lead to an error upon evaluation only which is important + # as parts of prepare_predictions may not actually be evaluated in the end + new_re <- paste0( + "stop2('Cannot find all varying coefficients required in ", + spframe$joint_call[[i]], ". ", + "Did you exclude them via argument re_formula?')" + ) + } else { + # all required varying coefficients are present + new_re <- rep(NA, length(spframe$calls_re[[i]])) + for (j in seq_along(spframe$calls_re[[i]])) { + # k is the row of the unique 'reframe' that matches the j-th re term + k <- which_rows_reframe(reframe, spframe$reframe[[i]][j, ]) + stopifnot(length(k) == 1L) + new_re[j] <- paste0("r_", k, "[, Jr_", k, ", drop = FALSE]") + } + } + call <- rename(call, spframe$calls_re[[i]], new_re) + } if (spframe$Ic[i] > 0) { str_add(call) <- paste0(" * Csp_", spframe$Ic[i]) } @@ -397,6 +420,20 @@ prepare_predictions_sp <- function(bframe, draws, sdata, new = FALSE, ...) { "Treating original data as if it was new data as a workaround." 
) } + # prepare predictions for 're' terms + if (NROW(reframe)) { + out$r <- out$Jr <- vector("list", nrow(reframe)) + for (i in seq_rows(reframe)) { + rf <- reframe[i, ] + pr <- prep_re[[rf$group]] + select <- which_rows_reframe(pr$reframe, rf) + nlevels <- length(pr$levels) + out$r[[i]] <- subset_matrix_ranefs(pr$rdraws, select, nlevels) + # the order of levels in pr$rdraws follows that of pr$levels + # so the matching approach below is always valid + out$Jr[[i]] <- match(pr$gf[[1]], pr$levels) + } + } # prepare covariates ncovars <- max(spframe$Ic) out$Csp <- vector("list", ncovars) @@ -569,7 +606,7 @@ prepare_predictions_re_global <- function(bframe, draws, sdata, old_reframe, res # used (new) levels are currently not available within the bframe argument # since it has been computed with the old data (but new formula) the likely # reason for this choice was to avoid running validate_newdata twice (in - # prepare_predictions and standata). Perhaps this choice can can be + # prepare_predictions and standata). 
Perhaps this choice can be # reconsidered in the future while avoiding multiple validate_newdata runs out <- named_list(groups, list()) for (g in groups) { @@ -590,8 +627,7 @@ prepare_predictions_re_global <- function(bframe, draws, sdata, old_reframe, res ) } # only prepare predictions of effects specified in the new formula - cols_match <- c("coef", "resp", "dpar", "nlpar") - used_rpars <- which(find_rows(old_reframe_g, ls = reframe_g[cols_match])) + used_rpars <- which_rows_reframe(old_reframe_g, reframe_g) used_rpars <- outer(seq_len(nlevels), (used_rpars - 1) * nlevels, "+") used_rpars <- as.vector(used_rpars) rdraws <- rdraws[, used_rpars, drop = FALSE] @@ -621,7 +657,7 @@ prepare_predictions_re_global <- function(bframe, draws, sdata, old_reframe, res rdraws <- cbind(rdraws, new_rdraws) # keep only those levels actually used in the current data levels <- unique(unlist(gf)) - rdraws <- subset_levels(rdraws, levels, nranef) + rdraws <- subset_matrix_levels(rdraws, levels, nranef) # store all information required in 'prepare_predictions_re' out[[g]]$reframe <- reframe_g out[[g]]$rdraws <- rdraws @@ -661,12 +697,13 @@ prepare_predictions_re <- function(bframe, sdata, prep_re = list(), rdraws <- prep_re[[g]]$rdraws nranef <- prep_re[[g]]$nranef levels <- prep_re[[g]]$levels + nlevels <- length(levels) max_level <- prep_re[[g]]$max_level gf <- prep_re[[g]]$gf weights <- prep_re[[g]]$weights - # TODO: define 'select' according to parameter names not by position + # TODO: define 'select' according to names instead of position? # store draws and corresponding data in the output - # special group-level terms (mo, me, mi) + # special group-level terms (mo, mi, etc.) 
reframe_g_px_sp <- subset2(reframe_g_px, type = "sp") if (nrow(reframe_g_px_sp)) { Z <- matrix(1, length(gf[[1]])) @@ -675,9 +712,7 @@ prepare_predictions_re <- function(bframe, sdata, prep_re = list(), # select from all varying effects of that group select <- find_rows(reframe_g, ls = px) & reframe_g$coef == co & reframe_g$type == "sp" - select <- which(select) - select <- select + nranef * (seq_along(levels) - 1) - out[["rsp"]][[co]][[g]] <- rdraws[, select, drop = FALSE] + out[["rsp"]][[co]][[g]] <- subset_matrix_ranefs(rdraws, select, nlevels) } } # category specific group-level terms @@ -693,9 +728,7 @@ prepare_predictions_re <- function(bframe, sdata, prep_re = list(), # select from all varying effects of that group select <- find_rows(reframe_g, ls = px) & grepl(index, reframe_g$coef) & reframe_g$type == "cs" - select <- which(select) - select <- as.vector(outer(select, nranef * (seq_along(levels) - 1), "+")) - out[["rcs"]][[g]][[i]] <- rdraws[, select, drop = FALSE] + out[["rcs"]][[g]][[i]] <- subset_matrix_ranefs(rdraws, select, nlevels) } } # basic group-level terms @@ -714,9 +747,7 @@ prepare_predictions_re <- function(bframe, sdata, prep_re = list(), out[["Z"]][[g]] <- prepare_Z(Z, gf, max_level, weights) # select from all varying effects of that group select <- find_rows(reframe_g, ls = px) & reframe_g$type %in% c("", "mmc") - select <- which(select) - select <- as.vector(outer(select, nranef * (seq_along(levels) - 1), "+")) - out[["r"]][[g]] <- rdraws[, select, drop = FALSE] + out[["r"]][[g]] <- subset_matrix_ranefs(rdraws, select, nlevels) } } out @@ -933,17 +964,35 @@ pseudo_prep_for_mixture <- function(prep, comp, draw_ids = NULL) { structure(out, class = "brmsprep") } -# take relevant cols of a matrix of group-level terms +# subset the columns of a matrix of group-level terms # if only a subset of levels is provided (for newdata) # @param x a matrix typically draws of r or Z design matrices -# draws need to be stored in row major order +# draws 
need to be stored in row major order that is +# all effects of the same level in consecutive columns # @param levels grouping factor levels to keep -# @param nranef number of group-level effects -subset_levels <- function(x, levels, nranef) { - take_levels <- ulapply(levels, - function(l) ((l - 1) * nranef + 1):(l * nranef) - ) - x[, take_levels, drop = FALSE] +# @param nranef total number of group-level effects +subset_matrix_levels <- function(x, levels, nranef) { + if (is.logical(levels)) { + levels <- which(levels) + } + take <- ulapply(levels, function(l) ((l - 1) * nranef + 1):(l * nranef)) + x[, take, drop = FALSE] +} + +# subset the columns of a matrix of group-level terms +# if only a subset of ranefs is required +# @param x a matrix typically draws of r or Z design matrices +# draws need to be stored in row major order that is +# all effects of the same level in consecutive columns +# @param ranefs group-level effects to keep +# @param nlevels total number of grouping factor levels +subset_matrix_ranefs <- function(x, ranefs, nlevels) { + if (is.logical(ranefs)) { + ranefs <- which(ranefs) + } + nranef <- ncol(x) / nlevels + take <- as.vector(outer(ranefs, nranef * (seq_len(nlevels) - 1), "+")) + x[, take, drop = FALSE] } # transform x from column to row major order @@ -963,7 +1012,7 @@ column_to_row_major_order <- function(x, nranef) { # @param gf (list of) vectors containing grouping factor values # @param weights optional (list of) weights of the same length as gf # @param max_level maximal level of 'gf' -# @return a sparse matrix representation of Z +# @return a sparse matrix representation of Z in row major order prepare_Z <- function(Z, gf, max_level = NULL, weights = NULL) { if (!is.list(Z)) { Z <- list(Z) @@ -987,7 +1036,7 @@ prepare_Z <- function(Z, gf, max_level = NULL, weights = NULL) { MoreArgs = nlist(max_level) ) Z <- Reduce("+", Z) - subset_levels(Z, levels, nranef) + subset_matrix_levels(Z, levels, nranef) } # expand a matrix into a 
sparse matrix of higher dimension diff --git a/R/priors.R b/R/priors.R index d3342aa66..0ff5e54e8 100644 --- a/R/priors.R +++ b/R/priors.R @@ -1495,7 +1495,7 @@ validate_special_prior.btl <- function(x, prior, allow_autoscale = TRUE, ...) { # it is still the same as the order in the Stan code special_classes <- c("b", "sds", "sdgp", "ar", "ma", "sderr", "sdcar", "sd") for (sc in special_classes) { - index <- which(find_rows(prior, class = sc, coef = "", group = "", ls = px)) + index <- which_rows(prior, class = sc, coef = "", group = "", ls = px) if (!length(index)) { next } diff --git a/R/stan-predictor.R b/R/stan-predictor.R index 2384ea1ca..fb2f7bb43 100644 --- a/R/stan-predictor.R +++ b/R/stan-predictor.R @@ -963,6 +963,17 @@ stan_sp <- function(bframe, prior, stanvars, threads, normalize, ...) { eta <- rename(eta, spframe$calls_mi[[i]], new_mi) str_add(out$pll_args) <- glue(", vector Yl_{spframe$vars_mi[[i]]}") } + if (!is.null(spframe$calls_re[[i]])) { + r <- spframe$reframe[[i]] + if (NROW(r) < length(spframe$calls_re[[i]])) { + stop2("Cannot find all varying coefficients required in ", + spframe$joint_call[[i]], ".") + } + idp <- paste0(r$id, usc(combine_prefix(r))) + idresp <- paste0(r$id, usc(r$resp)) + new_re <- glue("r_{idp}_{r$cn}[J_{idresp}{n}]") + eta <- rename(eta, spframe$calls_re[[i]], new_re) + } if (spframe$Ic[i] > 0) { str_add(eta) <- glue(" * Csp{p}_{spframe$Ic[i]}{n}") } diff --git a/man/re.Rd b/man/re.Rd new file mode 100644 index 000000000..7305f0bad --- /dev/null +++ b/man/re.Rd @@ -0,0 +1,62 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/formula-sp.R +\name{re} +\alias{re} +\title{Group-level effects as predictors in \pkg{brms} Models} +\usage{ +re(gr, coef = "Intercept", resp = "", dpar = "", nlpar = "") +} +\arguments{ +\item{gr}{Name of the grouping factor of the group-level effect +to be used as predictor.} + +\item{coef}{Optional name of the coefficient of the group-level effect. 
+Defaults to \code{"Intercept"}.} + +\item{resp}{Optional name of the response variable of the group-level effect.} + +\item{dpar}{Optional name of the distributional parameter of the group-level effect.} + +\item{nlpar}{Optional name of the non-linear parameter of the group-level effect.} +} +\description{ +Specify a group-level predictor term in \pkg{brms}. That is, +use group-level effects defined somewhere in the model as +predictors in another part of the model. The function does not +evaluate its arguments -- it exists purely to help set up a model. +} +\examples{ +\dontrun{ +# use the group-level intercept of 'AY' for parameter 'ult' +# as predictor for the residual standard deviation 'sigma' +# multiplying by 1000 reduces the scale of 'ult' to roughly unity +bform <- bf( + cum ~ 1000 * ult * (1 - exp(-(dev/theta)^omega)), + ult ~ 1 + (1|AY), omega ~ 1, theta ~ 1, + sigma ~ re(AY, nlpar = "ult"), + nl = TRUE +) +bprior <- c( + prior(normal(5, 1), nlpar = "ult"), + prior(normal(1, 2), nlpar = "omega"), + prior(normal(45, 10), nlpar = "theta"), + prior(normal(0, 0.5), dpar = "sigma") +) + +fit <- brm( + bform, data = loss, + family = gaussian(), + prior = bprior, + control = list(adapt_delta = 0.9), + chains = 2 +) +summary(fit) + +# shows how sigma varies as a function of the AY levels +conditional_effects(fit, "AY", dpar = "sigma", re_formula = NULL) +} + +} +\seealso{ +\code{\link{brmsformula}} +} diff --git a/tests/local/tests.models-5.R b/tests/local/tests.models-5.R index 7baadd3d2..ff65a96f6 100644 --- a/tests/local/tests.models-5.R +++ b/tests/local/tests.models-5.R @@ -198,6 +198,40 @@ test_that("alternative algorithms can be used", suppressWarnings({ expect_is(fit, "brmsfit") })) +test_that("Models with re-predictor terms yield sensible outputs", { + fit <- brm( + bf(cum ~ ult * (1 - exp(-(dev/theta)^omega)), + ult ~ 1 + (1|AY), omega ~ 1, theta ~ 1, + sigma ~ re(AY, nlpar = "ult"), + nl = TRUE + ), + data = loss, family = gaussian(), + prior = c( + 
prior(normal(5000, 1000), nlpar = "ult"), + prior(normal(1000, 300), class = "sd", nlpar = "ult"), + prior(normal(1, 2), nlpar = "omega"), + prior(normal(45, 10), nlpar = "theta"), + prior(normal(0, 0.05), dpar = "sigma") + ), + control = list(adapt_delta = 0.9), + chains = 1, seed = 125314 + ) + + summary(fit) + expect_range(loo(fit)$estimates[3, 1], 700, 730) + + ce <- conditional_effects(fit, dpar = "sigma", re_formula = NULL) + expect_ggplot(plot(ce, ask = FALSE)[[1]]) + expect_error( + conditional_effects(fit, dpar = "sigma"), + "Cannot find all varying coefficients required" + ) + # check if predictions without re terms can be performed + # while random effects are excluded + ce <- conditional_effects(fit, "dev", dpar = "mu", re_formula = NA) + expect_ggplot(plot(ce, ask = FALSE)[[1]]) +}) + test_that(paste( "Families sratio() and cratio() are equivalent for symmetric distribution", "functions (here only testing the logit link)" diff --git a/tests/testthat/tests.stancode.R b/tests/testthat/tests.stancode.R index 88b221c4a..49b5d2826 100644 --- a/tests/testthat/tests.stancode.R +++ b/tests/testthat/tests.stancode.R @@ -1134,6 +1134,42 @@ test_that("monotonic effects appear in the Stan code", { ) }) +test_that("Stan code for re predictor terms is correct", { + dat <- data.frame( + y = rnorm(100, mean = rep(1:10, each = 10)), + x = rnorm(100), gr = rep(1:10, each = 10) + ) + + bform <- bf(y ~ x + (1 + x|gr), sigma ~ re(gr, coef = "x")) + scode <- make_stancode(bform, dat) + expect_match2(scode, "sigma[n] += (bsp_sigma[1]) * r_1_2[J_1[n]];") + + bform <- bf(y ~ (1|gr)) + + bf(x ~ (1|gr) + re(gr, resp = "y")) + + set_rescor(FALSE) + scode <- make_stancode(bform, dat) + expect_match2(scode, + "mu_x[n] += (bsp_x[1]) * r_1_y_1[J_1_y[n]] + r_2_x_1[J_2_x[n]] * Z_2_x_1[n];" + ) + + bform <- bf( + y ~ a + b, + a ~ x + (1 + x |id| gr) + re(gr, coef = "Intercept", dpar = "sigma"), + b ~ (1 + x |id| gr) + re(gr, coef = "x", nlpar = "a"), + sigma ~ (1 |id| gr), + nl = TRUE + ) + 
scode <- make_stancode(bform, dat) + expect_match2(scode, "nlp_a[n] += (bsp_a[1]) * r_1_sigma_1[J_1[n]] +") + expect_match2(scode, "nlp_b[n] += (bsp_b[1]) * r_1_a_3[J_1[n]] +") + + bform <- bf(y ~ x + (1 + x|gr), sigma ~ re(gr, coef = "z")) + expect_error( + make_stancode(bform, dat), + "Cannot find all varying coefficients" + ) +}) + test_that("Stan code for non-linear models is correct", { flist <- list(a ~ x, b ~ z + (1|g)) data <- data.frame(