Skip to content

Commit

Permalink
prevent memory leaks in kfold and reloo
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Jan 26, 2024
1 parent 2aed221 commit 6ca2dfb
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 42 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.20.11
Version: 2.20.12
Date: 2024-01-23
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "[email protected]",
Expand Down Expand Up @@ -35,6 +35,7 @@ Imports:
glue (>= 1.3.0),
rlang (>= 1.0.0),
future (>= 1.19.0),
future.apply (>= 1.0.0),
matrixStats,
nleqslv,
nlme,
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ as backend. (#1544)
* Remove some remaining uses of Stan's old array syntax.
* Ensure compatibility with the latest `splines2` package version. (#1580)
* Fix output of `rmulti_normal` thanks to Ven Popov. (#1588)
* Prevent memory leaks when executing `kfold` or `reloo` in parallel.

# brms 2.20.3

Expand Down
49 changes: 24 additions & 25 deletions R/kfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ kfold.brmsfit <- function(x, ..., K = 10, Ksub = NULL, folds = NULL,
Ksub <- sort(Ksub)
}

# ensure that the model can be run in the current R session
x <- recompile_model(x, recompile = recompile)

# split dots for use in log_lik and update
dots <- list(...)
ll_arg_names <- arg_names("log_lik")
Expand All @@ -250,73 +253,69 @@ kfold.brmsfit <- function(x, ..., K = 10, Ksub = NULL, folds = NULL,
ll_args$resp <- resp
ll_args$combine <- TRUE
up_args <- dots[setdiff(names(dots), ll_arg_names)]
up_args$object <- x
up_args$refresh <- 0

# function to be run inside future::future
.kfold_k <- function(k) {
message("Fitting model ", k, " out of ", K)
if (fold_type == "loo" && !is.null(group)) {
omitted <- which(folds == folds[k])
predicted <- k
} else {
omitted <- predicted <- which(folds == k)
}
newdata_omitted <- newdata[-omitted, , drop = FALSE]
up_args$object <- x
up_args$newdata <- newdata_omitted
up_args$data2 <- subset_data2(newdata2, -omitted)
fit <- SW(do_call(update, up_args))
# rm() trys to avoid memory leaks during parallel use
rm(up_args)

ll_args$object <- fit
ll_args$newdata <- newdata[predicted, , drop = FALSE]
ll_args$newdata2 <- subset_data2(newdata2, predicted)
lppds <- do_call(log_lik, ll_args)
rm(ll_args)

out <- nlist(lppds, omitted, predicted)
if (save_fits) {
out$fit <- fit
}
rm(fit)
return(out)
}

futures <- vector("list", length(Ksub))
lppds <- obs_order <- vector("list", length(Ksub))
# TODO: separate parallel and non-parallel code to enable better printing?
future_args$X <- Ksub
future_args$FUN <- .kfold_k
future_args$future.seed <- TRUE
out <- do_call("future_lapply", future_args, pkg = "future.apply")

lppds <- pred_obs <- vector("list", length(Ksub))
if (save_fits) {
fits <- array(list(), dim = c(length(Ksub), 3))
dimnames(fits) <- list(NULL, c("fit", "omitted", "predicted"))
}

x <- recompile_model(x, recompile = recompile)
future_args$FUN <- .kfold_k
future_args$seed <- TRUE
for (k in Ksub) {
ks <- match(k, Ksub)
message("Fitting model ", k, " out of ", K)
future_args$args <- list(k)
futures[[ks]] <- do_call("futureCall", future_args, pkg = "future")
}
for (k in Ksub) {
ks <- match(k, Ksub)
tmp <- future::value(futures[[ks]])
for (i in seq_along(Ksub)) {
if (save_fits) {
fits[ks, ] <- tmp[c("fit", "omitted", "predicted")]
fits[i, ] <- out[[i]][c("fit", "omitted", "predicted")]
}
obs_order[[ks]] <- tmp$predicted
lppds[[ks]] <- tmp$lppds
pred_obs[[i]] <- out[[i]]$predicted
lppds[[i]] <- out[[i]]$lppds
}

lppds <- do_call(cbind, lppds)
elpds <- apply(lppds, 2, log_mean_exp)
# make sure elpds are put back in the right order
obs_order <- unlist(obs_order)
elpds <- elpds[order(obs_order)]
pred_obs <- unlist(pred_obs)
elpds <- elpds[order(pred_obs)]
# compute effective number of parameters
ll_args$object <- x
ll_args$newdata <- newdata
ll_args$newdata2 <- newdata2
if (length(Ksub) < K) {
# select the correct subset of predicted observations
pred_obs_sorted <- sort(pred_obs)
ll_args$newdata <- ll_args$newdata[pred_obs_sorted, , drop = FALSE]
ll_args$newdata2 <- subset_data2(ll_args$newdata2, pred_obs_sorted)
}
ll_full <- do_call(log_lik, ll_args)
lpds <- apply(ll_full, 2, log_mean_exp)
ps <- lpds - elpds
Expand Down
29 changes: 13 additions & 16 deletions R/reloo.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ reloo.brmsfit <- function(x, loo, k_threshold = 0.7, newdata = NULL,
return(loo)
}

# ensure that the model can be run in the current R session
x <- recompile_model(x, recompile = recompile)

# split dots for use in log_lik and update
dots <- list(...)
ll_arg_names <- arg_names("log_lik")
Expand All @@ -105,13 +108,16 @@ reloo.brmsfit <- function(x, loo, k_threshold = 0.7, newdata = NULL,
# cores is used in both log_lik and update
up_arg_names <- setdiff(names(dots), setdiff(ll_arg_names, "cores"))
up_args <- dots[up_arg_names]
up_args$object <- x
up_args$refresh <- 0

.reloo <- function(j) {
message(
"\nFitting model ", j, " out of ", J,
" (leaving out observation ", obs[j], ")"
)
omitted <- obs[j]
mf_omitted <- mf[-omitted, , drop = FALSE]
fit_j <- x
up_args$object <- fit_j
up_args$newdata <- mf_omitted
up_args$data2 <- subset_data2(x$data2, -omitted)
fit_j <- SW(do_call(update, up_args))
Expand All @@ -121,25 +127,16 @@ reloo.brmsfit <- function(x, loo, k_threshold = 0.7, newdata = NULL,
return(do_call(log_lik, ll_args))
}

lls <- futures <- vector("list", J)
message(
J, " problematic observation(s) found.",
"\nThe model will be refit ", J, " times."
)
x <- recompile_model(x, recompile = recompile)
# TODO: separate parallel and non-parallel code to enable better printing?
future_args$X <- seq_len(J)
future_args$FUN <- .reloo
future_args$seed <- TRUE
for (j in seq_len(J)) {
message(
"\nFitting model ", j, " out of ", J,
" (leaving out observation ", obs[j], ")"
)
future_args$args <- list(j)
futures[[j]] <- do_call("futureCall", future_args, pkg = "future")
}
for (j in seq_len(J)) {
lls[[j]] <- future::value(futures[[j]])
}
future_args$future.seed <- TRUE
lls <- do_call("future_lapply", future_args, pkg = "future.apply")

# most of the following code is taken from rstanarm:::reloo
# compute elpd_{loo,j} for each of the held out observations
elpd_loo <- ulapply(lls, log_mean_exp)
Expand Down

0 comments on commit 6ca2dfb

Please sign in to comment.