Skip to content

Commit

Permalink
Merge pull request #1624 from avehtari/update_loo_k_threshold
Browse files Browse the repository at this point in the history
update doc and loo recommendations
  • Loading branch information
paul-buerkner authored Mar 19, 2024
2 parents 9f94b1d + 5ed2b2e commit 332cd55
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 14 deletions.
26 changes: 19 additions & 7 deletions R/loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
#' details.
#' @param reloo Logical; Indicate whether \code{\link{reloo}}
#' should be applied on problematic observations. Defaults to \code{FALSE}.
#' @param k_threshold The threshold at which pareto \eqn{k}
#' estimates are treated as problematic. Defaults to \code{0.7}.
#' Only used if argument \code{reloo} is \code{TRUE}.
#' @param k_threshold The Pareto \eqn{k} threshold for which observations
#' \code{\link{loo_moment_match}} or \code{\link{reloo}} is applied if
#' argument \code{moment_match} or \code{reloo} is \code{TRUE}.
#' Defaults to \code{0.7}.
#' See \code{\link[loo:pareto-k-diagnostic]{pareto_k_ids}} for more details.
#' @param save_psis Should the \code{"psis"} object created internally be saved
#' in the returned object? For more details see \code{\link[loo:loo]{loo}}.
Expand Down Expand Up @@ -677,13 +678,18 @@ recommend_loo_options <- function(loo, k_threshold = 0.7, moment_match = FALSE,
} else {
model_name <- ""
}
n <- length(loo::pareto_k_ids(loo, threshold = k_threshold))
ndraws <- dim(loo)[1] %||% Inf
if (n > 0 && ndraws < 2200) {
n <- n2 <- length(loo::pareto_k_ids(loo, threshold = k_threshold))
# for small number of draws the threshold may be smaller than 0.7
k_threshold2 <- ps_khat_threshold(ndraws)
if (k_threshold2 < k_threshold) {
n2 <- length(loo::pareto_k_ids(loo, threshold = k_threshold2))
}
if (n2 > n && k_threshold2 <= 0.7) {
warning2(
"Found ", n, " observations with a pareto_k > ", k_threshold,
"Found ", n2, " observations with a pareto_k > ", round(k_threshold2, 2),
model_name, ". We recommend to run more iterations to get at least ",
"about 2200 posterior draws for more reliable pareteo_k estimation."
"about 2200 posterior draws to improve LOO-CV approximation accuracy."
)
out <- "loo_more_draws"
} else if (n > 0 && !moment_match) {
Expand Down Expand Up @@ -991,3 +997,9 @@ print.iclist <- function(x, digits = 2, ...) {
print(round(mat, digits = digits), na.print = "")
invisible(x)
}

# Pareto-smoothing k-hat threshold
# not yet exported by loo so copied over here for now
ps_khat_threshold <- function(S, ...) {
1 - 1 / log10(S)
}
4 changes: 2 additions & 2 deletions R/loo_moment_match.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#' @inheritParams predict.brmsfit
#' @param x An object of class \code{brmsfit}.
#' @param loo An object of class \code{loo} originally created from \code{x}.
#' @param k_threshold The threshold at which Pareto \eqn{k}
#' estimates are treated as problematic. Defaults to \code{0.7}.
#' @param k_threshold The Pareto \eqn{k} threshold for which observations
#' moment matching is applied. Defaults to \code{0.7}.
#' See \code{\link[loo:pareto-k-diagnostic]{pareto_k_ids}}
#' for more details.
#' @param check Logical; If \code{TRUE} (the default), some checks
Expand Down
7 changes: 4 additions & 3 deletions man/loo.brmsfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/loo_moment_match.brmsfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 332cd55

Please sign in to comment.