diff --git a/NEWS.md b/NEWS.md index 51076d0..543c642 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,7 @@ ## Documentation - More compact README. +- Updated function description. # kernelshap 0.7.0 diff --git a/R/additive_shap.R b/R/additive_shap.R index b8165f8..8c4fcec 100644 --- a/R/additive_shap.R +++ b/R/additive_shap.R @@ -9,19 +9,20 @@ #' - `gam::gam()`, #' - [survival::coxph()], and #' - [survival::survreg()]. -#' +#' #' The SHAP values are extracted via `predict(object, newdata = X, type = "terms")`, -#' a logic heavily inspired by `fastshap:::explain.lm(..., exact = TRUE)`. +#' a logic adopted from `fastshap:::explain.lm(..., exact = TRUE)`. #' Models with interactions (specified via `:` or `*`), or with terms of #' multiple features like `log(x1/x2)` are not supported. -#' +#' #' Note that the SHAP values obtained by [additive_shap()] are expected to #' match those of [permshap()] and [kernelshap()] as long as their background #' data equals the full training data (which is typically not feasible). #' -#' @inheritParams kernelshap -#' @param X Dataframe with rows to be explained. Will be used like +#' @param object Fitted additive model. +#' @param X Dataframe with rows to be explained. Passed to #' `predict(object, newdata = X, type = "terms")`. +#' @param verbose Set to `FALSE` to suppress messages. #' @param ... Currently unused. #' @returns #' An object of class "kernelshap" with the following components: @@ -38,7 +39,7 @@ #' fit <- lm(Sepal.Length ~ ., data = iris) #' s <- additive_shap(fit, head(iris)) #' s -#' +#' #' # MODEL TWO: More complicated (but not very clever) formula #' fit <- lm( #' Sepal.Length ~ poly(Sepal.Width, 2) + log(Petal.Length) + log(Sepal.Width), @@ -46,7 +47,7 @@ #' ) #' s_add <- additive_shap(fit, head(iris)) #' s_add -#' +#' #' # Equals kernelshap()/permshap() when background data is full training data #' s_kernel <- kernelshap( #' fit, head(iris[c("Sepal.Width", "Petal.Length")]), bg_X = iris @@ -59,28 +60,28 @@ additive_shap <- function(object, X, verbose = TRUE, ...) { if (any(attr(stats::terms(object), "order") > 1)) { stop("Additive SHAP not appropriate for models with interactions.") } - + txt <- "Exact additive SHAP via predict(..., type = 'terms')" if (verbose) { message(txt) } - + S <- stats::predict(object, newdata = X, type = "terms") rownames(S) <- NULL - + # Baseline value b <- as.vector(attr(S, "constant")) if (is.null(b)) { b <- 0 } - + # Which columns of X are used in each column of S? s_names <- colnames(S) cols_used <- lapply(s_names, function(z) all.vars(stats::reformulate(z))) if (any(lengths(cols_used) > 1L)) { stop("The formula contains terms with multiple features (not supported).") } - + # Collapse all columns in S using the same column in X and rename accordingly mapping <- split( s_names, factor(unlist(cols_used), levels = colnames(X)), drop = TRUE @@ -89,7 +90,7 @@ additive_shap <- function(object, X, verbose = TRUE, ...) { cbind, lapply(mapping, function(z) rowSums(S[, z, drop = FALSE], na.rm = TRUE)) ) - + structure( list( S = S, diff --git a/R/kernelshap.R b/R/kernelshap.R index e02c4f2..de65a5c 100644 --- a/R/kernelshap.R +++ b/R/kernelshap.R @@ -1,19 +1,26 @@ #' Kernel SHAP -#' -#' Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017), and +#' +#' @description +#' Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017), and #' Covert and Lee (2021), abbreviated by CL21. -#' For up to \eqn{p=8} features, the resulting Kernel SHAP values are exact regarding -#' the selected background data. For larger \eqn{p}, an almost exact -#' hybrid algorithm involving iterative sampling is used, see Details. -#' For up to eight features, however, we recomment to use [permshap()]. +#' For up to \eqn{p=8} features, the resulting Kernel SHAP values are exact regarding +#' the selected background data. For larger \eqn{p}, an almost exact +#' hybrid algorithm combining exact calculations and iterative sampling is used, +#' see Details. +#' +#' Note that (exact) Kernel SHAP is only an approximation of (exact) permutation SHAP. +#' Thus, for up to eight features, we recommend [permshap()]. For more features, +#' [permshap()] is slow compared the optimized hybrid strategy of our Kernel SHAP +#' implementation. #' -#' Pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this: -#' -#' 1. A binary "on-off" vector \eqn{z} is drawn from \eqn{\{0, 1\}^p} -#' such that its sum follows the SHAP Kernel weight distribution +#' @details +#' The pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this: +#' +#' 1. A binary "on-off" vector \eqn{z} is drawn from \eqn{\{0, 1\}^p} +#' such that its sum follows the SHAP Kernel weight distribution #' (normalized to the range \eqn{\{1, \dots, p-1\}}). -#' 2. For each \eqn{j} with \eqn{z_j = 1}, the \eqn{j}-th column of the -#' original background data is replaced by the corresponding feature value \eqn{x_j} +#' 2. For each \eqn{j} with \eqn{z_j = 1}, the \eqn{j}-th column of the +#' original background data is replaced by the corresponding feature value \eqn{x_j} #' of the observation to be explained. #' 3. The average prediction \eqn{v_z} on the data of Step 2 is calculated, and the #' average prediction \eqn{v_0} on the background data is subtracted. @@ -21,141 +28,141 @@ #' matrix \eqn{Z} (each row equals one of the \eqn{z}) and a vector \eqn{v} of #' shifted predictions. #' 5. \eqn{v} is regressed onto \eqn{Z} under the constraint that the sum of the -#' coefficients equals \eqn{v_1 - v_0}, where \eqn{v_1} is the prediction of the +#' coefficients equals \eqn{v_1 - v_0}, where \eqn{v_1} is the prediction of the #' observation to be explained. The resulting coefficients are the Kernel SHAP values. -#' +#' #' This is repeated multiple times until convergence, see CL21 for details. -#' -#' A drawback of this strategy is that many (at least 75%) of the \eqn{z} vectors will -#' have \eqn{\sum z \in \{1, p-1\}}, producing many duplicates. Similarly, at least 92% -#' of the mass will be used for the \eqn{p(p+1)} possible vectors with -#' \eqn{\sum z \in \{1, 2, p-2, p-1\}}. -#' This inefficiency can be fixed by a hybrid strategy, combining exact calculations +#' +#' A drawback of this strategy is that many (at least 75%) of the \eqn{z} vectors will +#' have \eqn{\sum z \in \{1, p-1\}}, producing many duplicates. Similarly, at least 92% +#' of the mass will be used for the \eqn{p(p+1)} possible vectors with +#' \eqn{\sum z \in \{1, 2, p-2, p-1\}}. +#' This inefficiency can be fixed by a hybrid strategy, combining exact calculations #' with sampling. -#' +#' #' The hybrid algorithm has two steps: -#' 1. Step 1 (exact part): There are \eqn{2p} different on-off vectors \eqn{z} with -#' \eqn{\sum z \in \{1, p-1\}}, covering a large proportion of the Kernel SHAP -#' distribution. The degree 1 hybrid will list those vectors and use them according -#' to their weights in the upcoming calculations. Depending on \eqn{p}, we can also go -#' a step further to a degree 2 hybrid by adding all \eqn{p(p-1)} vectors with -#' \eqn{\sum z \in \{2, p-2\}} to the process etc. The necessary predictions are +#' 1. Step 1 (exact part): There are \eqn{2p} different on-off vectors \eqn{z} with +#' \eqn{\sum z \in \{1, p-1\}}, covering a large proportion of the Kernel SHAP +#' distribution. The degree 1 hybrid will list those vectors and use them according +#' to their weights in the upcoming calculations. Depending on \eqn{p}, we can also go +#' a step further to a degree 2 hybrid by adding all \eqn{p(p-1)} vectors with +#' \eqn{\sum z \in \{2, p-2\}} to the process etc. The necessary predictions are #' obtained along with other calculations similar to those described in CL21. #' 2. Step 2 (sampling part): The remaining weight is filled by sampling vectors z -#' according to Kernel SHAP weights renormalized to the values not yet covered by Step 1. +#' according to Kernel SHAP weights renormalized to the values not yet covered by Step 1. #' Together with the results from Step 1 - correctly weighted - this now forms a -#' complete iteration as in CL21. The difference is that most mass is covered by exact -#' calculations. Afterwards, the algorithm iterates until convergence. -#' The output of Step 1 is reused in every iteration, leading to an extremely +#' complete iteration as in CL21. The difference is that most mass is covered by exact +#' calculations. Afterwards, the algorithm iterates until convergence. +#' The output of Step 1 is reused in every iteration, leading to an extremely #' efficient strategy. -#' -#' If \eqn{p} is sufficiently small, all possible \eqn{2^p-2} on-off vectors \eqn{z} can be -#' evaluated. In this case, no sampling is required and the algorithm returns exact -#' Kernel SHAP values with respect to the given background data. -#' Since [kernelshap()] calculates predictions on data with \eqn{MN} rows +#' +#' If \eqn{p} is sufficiently small, all possible \eqn{2^p-2} on-off vectors \eqn{z} can be +#' evaluated. In this case, no sampling is required and the algorithm returns exact +#' Kernel SHAP values with respect to the given background data. +#' Since [kernelshap()] calculates predictions on data with \eqn{MN} rows #' (\eqn{N} is the background data size and \eqn{M} the number of \eqn{z} vectors), \eqn{p} -#' should not be much higher than 10 for exact calculations. +#' should not be much higher than 10 for exact calculations. #' For similar reasons, degree 2 hybrids should not use \eqn{p} much larger than 40. -#' +#' #' @importFrom foreach %dopar% -#' +#' #' @param object Fitted model object. -#' @param X \eqn{(n \times p)} matrix or `data.frame` with rows to be explained. -#' The columns should only represent model features, not the response +#' @param X \eqn{(n \times p)} matrix or `data.frame` with rows to be explained. +#' The columns should only represent model features, not the response #' (but see `feature_names` on how to overrule this). -#' @param bg_X Background data used to integrate out "switched off" features, +#' @param bg_X Background data used to integrate out "switched off" features, #' often a subset of the training data (typically 50 to 500 rows). -#' In cases with a natural "off" value (like MNIST digits), +#' In cases with a natural "off" value (like MNIST digits), #' this can also be a single row with all values set to the off value. -#' If no `bg_X` is passed (the default) and if `X` is sufficiently large, +#' If no `bg_X` is passed (the default) and if `X` is sufficiently large, #' a random sample of `bg_n` rows from `X` serves as background data. #' @param pred_fun Prediction function of the form `function(object, X, ...)`, -#' providing \eqn{K \ge 1} predictions per row. Its first argument -#' represents the model `object`, its second argument a data structure like `X`. -#' Additional (named) arguments are passed via `...`. -#' The default, [stats::predict()], will work in most cases. -#' @param feature_names Optional vector of column names in `X` used to calculate +#' providing \eqn{K \ge 1} predictions per row. Its first argument +#' represents the model `object`, its second argument a data structure like `X`. +#' Additional (named) arguments are passed via `...`. +#' The default, [stats::predict()], will work in most cases. +#' @param feature_names Optional vector of column names in `X` used to calculate #' SHAP values. By default, this equals `colnames(X)`. Not supported if `X` #' is a matrix. #' @param bg_w Optional vector of case weights for each row of `bg_X`. #' If `bg_X = NULL`, must be of same length as `X`. Set to `NULL` for no weights. #' @param bg_n If `bg_X = NULL`: Size of background data to be sampled from `X`. #' @param exact If `TRUE`, the algorithm will produce exact Kernel SHAP values -#' with respect to the background data. In this case, the arguments `hybrid_degree`, +#' with respect to the background data. In this case, the arguments `hybrid_degree`, #' `m`, `paired_sampling`, `tol`, and `max_iter` are ignored. -#' The default is `TRUE` up to eight features, and `FALSE` otherwise. +#' The default is `TRUE` up to eight features, and `FALSE` otherwise. #' @param hybrid_degree Integer controlling the exactness of the hybrid strategy. For -#' \eqn{4 \le p \le 16}, the default is 2, otherwise it is 1. +#' \eqn{4 \le p \le 16}, the default is 2, otherwise it is 1. #' Ignored if `exact = TRUE`. #' - `0`: Pure sampling strategy not involving any exact part. It is strictly -#' worse than the hybrid strategy and should therefore only be used for +#' worse than the hybrid strategy and should therefore only be used for #' studying properties of the Kernel SHAP algorithm. #' - `1`: Uses all \eqn{2p} on-off vectors \eqn{z} with \eqn{\sum z \in \{1, p-1\}} -#' for the exact part, which covers at least 75% of the mass of the Kernel weight +#' for the exact part, which covers at least 75% of the mass of the Kernel weight #' distribution. The remaining mass is covered by random sampling. -#' - `2`: Uses all \eqn{p(p+1)} on-off vectors \eqn{z} with -#' \eqn{\sum z \in \{1, 2, p-2, p-1\}}. This covers at least 92% of the mass of the -#' Kernel weight distribution. The remaining mass is covered by sampling. +#' - `2`: Uses all \eqn{p(p+1)} on-off vectors \eqn{z} with +#' \eqn{\sum z \in \{1, 2, p-2, p-1\}}. This covers at least 92% of the mass of the +#' Kernel weight distribution. The remaining mass is covered by sampling. #' Convergence usually happens in the minimal possible number of iterations of two. -#' - `k>2`: Uses all on-off vectors with +#' - `k>2`: Uses all on-off vectors with #' \eqn{\sum z \in \{1, \dots, k, p-k, \dots, p-1\}}. #' @param paired_sampling Logical flag indicating whether to do the sampling in a paired -#' manner. This means that with every on-off vector \eqn{z}, also \eqn{1-z} is -#' considered. CL21 shows its superiority compared to standard sampling, therefore the +#' manner. This means that with every on-off vector \eqn{z}, also \eqn{1-z} is +#' considered. CL21 shows its superiority compared to standard sampling, therefore the #' default (`TRUE`) should usually not be changed except for studying properties #' of Kernel SHAP algorithms. Ignored if `exact = TRUE`. -#' @param m Even number of on-off vectors sampled during one iteration. -#' The default is \eqn{2p}, except when `hybrid_degree == 0`. +#' @param m Even number of on-off vectors sampled during one iteration. +#' The default is \eqn{2p}, except when `hybrid_degree == 0`. #' Then it is set to \eqn{8p}. Ignored if `exact = TRUE`. #' @param tol Tolerance determining when to stop. Following CL21, the algorithm keeps -#' iterating until \eqn{\textrm{max}(\sigma_n)/(\textrm{max}(\beta_n) - \textrm{min}(\beta_n)) < \textrm{tol}}, -#' where the \eqn{\beta_n} are the SHAP values of a given observation, -#' and \eqn{\sigma_n} their standard errors. -#' For multidimensional predictions, the criterion must be satisfied for each -#' dimension separately. The stopping criterion uses the fact that standard errors +#' iterating until \eqn{\textrm{max}(\sigma_n)/(\textrm{max}(\beta_n) - \textrm{min}(\beta_n)) < \textrm{tol}}, +#' where the \eqn{\beta_n} are the SHAP values of a given observation, +#' and \eqn{\sigma_n} their standard errors. +#' For multidimensional predictions, the criterion must be satisfied for each +#' dimension separately. The stopping criterion uses the fact that standard errors #' and SHAP values are all on the same scale. Ignored if `exact = TRUE`. -#' @param max_iter If the stopping criterion (see `tol`) is not reached after +#' @param max_iter If the stopping criterion (see `tol`) is not reached after #' `max_iter` iterations, the algorithm stops. Ignored if `exact = TRUE`. #' @param parallel If `TRUE`, use parallel [foreach::foreach()] to loop over rows -#' to be explained. Must register backend beforehand, e.g., via 'doFuture' package, +#' to be explained. Must register backend beforehand, e.g., via 'doFuture' package, #' see README for an example. Parallelization automatically disables the progress bar. -#' @param parallel_args Named list of arguments passed to [foreach::foreach()]. -#' Ideally, this is `NULL` (default). Only relevant if `parallel = TRUE`. -#' Example on Windows: if `object` is a GAM fitted with package 'mgcv', +#' @param parallel_args Named list of arguments passed to [foreach::foreach()]. +#' Ideally, this is `NULL` (default). Only relevant if `parallel = TRUE`. +#' Example on Windows: if `object` is a GAM fitted with package 'mgcv', #' then one might need to set `parallel_args = list(.packages = "mgcv")`. #' @param verbose Set to `FALSE` to suppress messages and the progress bar. #' @param survival Should cumulative hazards ("chf", default) or survival #' probabilities ("prob") per time be predicted? Only in `ranger()` survival models. #' @param ... Additional arguments passed to `pred_fun(object, X, ...)`. -#' @returns +#' @returns #' An object of class "kernelshap" with the following components: -#' - `S`: \eqn{(n \times p)} matrix with SHAP values or, if the model output has +#' - `S`: \eqn{(n \times p)} matrix with SHAP values or, if the model output has #' dimension \eqn{K > 1}, a list of \eqn{K} such matrices. #' - `X`: Same as input argument `X`. -#' - `baseline`: Vector of length K representing the average prediction on the +#' - `baseline`: Vector of length K representing the average prediction on the #' background data. #' - `bg_X`: The background data. #' - `bg_w`: The background case weights. #' - `SE`: Standard errors corresponding to `S` (and organized like `S`). -#' - `n_iter`: Integer vector of length n providing the number of iterations +#' - `n_iter`: Integer vector of length n providing the number of iterations #' per row of `X`. #' - `converged`: Logical vector of length n indicating convergence per row of `X`. -#' - `m`: Integer providing the effective number of sampled on-off vectors used +#' - `m`: Integer providing the effective number of sampled on-off vectors used #' per iteration. -#' - `m_exact`: Integer providing the effective number of exact on-off vectors used +#' - `m_exact`: Integer providing the effective number of exact on-off vectors used #' per iteration. -#' - `prop_exact`: Proportion of the Kernel SHAP weight distribution covered by +#' - `prop_exact`: Proportion of the Kernel SHAP weight distribution covered by #' exact calculations. #' - `exact`: Logical flag indicating whether calculations are exact or not. #' - `txt`: Summary text. #' - `predictions`: \eqn{(n \times K)} matrix with predictions of `X`. #' - `algorithm`: "kernelshap". #' @references -#' 1. Scott M. Lundberg and Su-In Lee. A unified approach to interpreting model -#' predictions. Proceedings of the 31st International Conference on Neural +#' 1. Scott M. Lundberg and Su-In Lee. A unified approach to interpreting model +#' predictions. Proceedings of the 31st International Conference on Neural #' Information Processing Systems, 2017. -#' 2. Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value -#' Estimation Using Linear Regression. Proceedings of The 24th International +#' 2. Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value +#' Estimation Using Linear Regression. Proceedings of The 24th International #' Conference on Artificial Intelligence and Statistics, PMLR 130:3457-3465, 2021. #' @export #' @examples @@ -221,12 +228,12 @@ kernelshap.default <- function( bg_w <- prep_bg$bg_w bg_n <- nrow(bg_X) n <- nrow(X) - + # Calculate v1 and v0 bg_preds <- align_pred(pred_fun(object, bg_X, ...)) v0 <- wcolMeans(bg_preds, bg_w) # Average pred of bg data: 1 x K v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K - + # For p = 1, exact Shapley values are returned if (p == 1L) { out <- case_p1( @@ -234,18 +241,18 @@ kernelshap.default <- function( ) return(out) } - + txt <- summarize_strategy(p, exact = exact, deg = hybrid_degree) if (verbose) { message(txt) } - + # Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant # In what follows, predictions will never be applied directly to bg_X anymore if (!identical(colnames(bg_X), feature_names)) { bg_X <- bg_X[, feature_names, drop = FALSE] } - + # Precalculations that are identical for each row to be explained if (exact || hybrid_degree >= 1L) { if (exact) { @@ -266,21 +273,21 @@ kernelshap.default <- function( if (!exact) { precalc[["bg_X_m"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m)) } - + if (max(m, m_exact) * bg_n > 2e5) { warning_burden(max(m, m_exact), bg_n = bg_n) } - + # Apply Kernel SHAP to each row of X if (isTRUE(parallel)) { parallel_args <- c(list(i = seq_len(n)), parallel_args) res <- do.call(foreach::foreach, parallel_args) %dopar% kernelshap_one( - x = X[i, , drop = FALSE], - v1 = v1[i, , drop = FALSE], + x = X[i, , drop = FALSE], + v1 = v1[i, , drop = FALSE], object = object, pred_fun = pred_fun, feature_names = feature_names, - bg_w = bg_w, + bg_w = bg_w, exact = exact, deg = hybrid_degree, paired = paired_sampling, @@ -293,17 +300,17 @@ kernelshap.default <- function( ) } else { if (verbose && n >= 2L) { - pb <- utils::txtProgressBar(max = n, style = 3) + pb <- utils::txtProgressBar(max = n, style = 3) } res <- vector("list", n) for (i in seq_len(n)) { res[[i]] <- kernelshap_one( - x = X[i, , drop = FALSE], - v1 = v1[i, , drop = FALSE], + x = X[i, , drop = FALSE], + v1 = v1[i, , drop = FALSE], object = object, pred_fun = pred_fun, feature_names = feature_names, - bg_w = bg_w, + bg_w = bg_w, exact = exact, deg = hybrid_degree, paired = paired_sampling, @@ -379,7 +386,7 @@ kernelshap.ranger <- function( } kernelshap.default( - object = object, + object = object, X = X, bg_X = bg_X, pred_fun = pred_fun, diff --git a/R/permshap.R b/R/permshap.R index 64acccd..df3c469 100644 --- a/R/permshap.R +++ b/R/permshap.R @@ -2,7 +2,7 @@ #' #' Exact permutation SHAP algorithm with respect to a background dataset, #' see Strumbelj and Kononenko. The function works for up to 14 features. -#' For eight or more features, we recomment to switch to [kernelshap()]. +#' For more than eight features, we recommend [kernelshap()] due to its higher speed. #' #' @inheritParams kernelshap #' @returns @@ -16,12 +16,12 @@ #' - `bg_w`: The background case weights. #' - `m_exact`: Integer providing the effective number of exact on-off vectors used. #' - `exact`: Logical flag indicating whether calculations are exact or not -#' (currently `TRUE`). +#' (currently always `TRUE`). #' - `txt`: Summary text. #' - `predictions`: \eqn{(n \times K)} matrix with predictions of `X`. #' - `algorithm`: "permshap". #' @references -#' 1. Erik Strumbelj and Igor Kononenko. Explaining prediction models and individual +#' 1. Erik Strumbelj and Igor Kononenko. Explaining prediction models and individual #' predictions with feature contributions. Knowledge and Information Systems 41, 2014. #' @export #' @examples @@ -80,7 +80,7 @@ permshap.default <- function( if (verbose) { message(txt) } - + basic_checks(X = X, feature_names = feature_names, pred_fun = pred_fun) prep_bg <- prepare_bg(X = X, bg_X = bg_X, bg_n = bg_n, bg_w = bg_w, verbose = verbose) bg_X <- prep_bg$bg_X @@ -92,32 +92,32 @@ permshap.default <- function( bg_preds <- align_pred(pred_fun(object, bg_X, ...)) v0 <- wcolMeans(bg_preds, w = bg_w) # Average pred of bg data: 1 x K v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K - + # Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant # Predictions will never be applied directly to bg_X anymore if (!identical(colnames(bg_X), feature_names)) { bg_X <- bg_X[, feature_names, drop = FALSE] } - + # Precalculations that are identical for each row to be explained Z <- exact_Z(p, feature_names = feature_names, keep_extremes = TRUE) m_exact <- nrow(Z) - 2L # We won't evaluate vz for first and last row precalc <- list( Z = Z, - Z_code = rowpaste(Z), + Z_code = rowpaste(Z), bg_X_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)) ) - + if (m_exact * bg_n > 2e5) { warning_burden(m_exact, bg_n = bg_n) } - + # Apply permutation SHAP to each row of X if (isTRUE(parallel)) { parallel_args <- c(list(i = seq_len(n)), parallel_args) res <- do.call(foreach::foreach, parallel_args) %dopar% permshap_one( x = X[i, , drop = FALSE], - v1 = v1[i, , drop = FALSE], + v1 = v1[i, , drop = FALSE], object = object, pred_fun = pred_fun, bg_w = bg_w, @@ -133,7 +133,7 @@ permshap.default <- function( for (i in seq_len(n)) { res[[i]] <- permshap_one( x = X[i, , drop = FALSE], - v1 = v1[i, , drop = FALSE], + v1 = v1[i, , drop = FALSE], object = object, pred_fun = pred_fun, bg_w = bg_w, diff --git a/README.md b/README.md index e13d455..b7165fe 100644 --- a/README.md +++ b/README.md @@ -15,19 +15,18 @@ The package contains three functions to crunch SHAP values: -- `permshap()`: Exact permutation SHAP algorithm of [1]. Recommended for models with up to 8 features. -- `kernelshap()`: Kernel SHAP algorithm of [2] and [3]. Recommended for models with more than 8 features. -- `additive_shap()`: For *additive models* fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`, `survival::coxph()`, or `survival::survreg()`. Exponentially faster than the model-agnostic options above, and recommended if possible. +- **`permshap()`**: Exact permutation SHAP algorithm of [1]. Recommended for models with up to 8 features. +- **`kernelshap()`**: Kernel SHAP algorithm of [2] and [3]. Recommended for models with more than 8 features. +- **`additive_shap()`**: For *additive models* fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`, `survival::coxph()`, or `survival::survreg()`. Exponentially faster than the model-agnostic options above, and recommended if possible. -To explain your model, select an explanation dataset `X` (up to 1000 rows from the training data) and apply the recommended function. Use {shapviz} to visualize the resulting SHAP values. +To explain your model, select an explanation dataset `X` (up to 1000 rows from the training data, feature columns only) and apply the recommended function. Use {shapviz} to visualize the resulting SHAP values. -**Remarks for `permshap()` and `kernelshap()`** +**Remarks to `permshap()` and `kernelshap()`** -- `X` should only contain feature columns. - Both algorithms need a representative background data `bg_X` to calculate marginal means (up to 500 rows from the training data). In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value. If unspecified, 200 rows are randomly sampled from `X`. -- By changing the defaults in `kernelshap()`, the iterative pure sampling approach of [3] can be enforced. -- `permshap()` vs. `kernelshap()`: For models with interactions of order up to two, exact Kernel SHAP agrees with exact permutation SHAP. -- `additive_shap()` vs. the model-agnostic explainers: The results would agree if the full training data would be used as background data. +- Exact Kernel SHAP is an approximation to exact permutation SHAP. Since exact calculations are usually sufficiently fast for up to eight features, we recommend `permshap()` in this case. With more features, `kernelshap()` switches to a comparably fast, almost exact algorithm. That is why we recommend `kernelshap()` in this case. +- For models with interactions of order up to two, SHAP values of exact permutation SHAP and exact Kernel SHAP agree. +- `permshap()` and `kernelshap()` give the same results as `additive_shap` as long as the full training data would be used as background data. ## Installation diff --git a/man/additive_shap.Rd b/man/additive_shap.Rd index bb4ed43..2b20188 100644 --- a/man/additive_shap.Rd +++ b/man/additive_shap.Rd @@ -7,12 +7,12 @@ additive_shap(object, X, verbose = TRUE, ...) } \arguments{ -\item{object}{Fitted model object.} +\item{object}{Fitted additive model.} -\item{X}{Dataframe with rows to be explained. Will be used like +\item{X}{Dataframe with rows to be explained. Passed to \code{predict(object, newdata = X, type = "terms")}.} -\item{verbose}{Set to \code{FALSE} to suppress messages and the progress bar.} +\item{verbose}{Set to \code{FALSE} to suppress messages.} \item{...}{Currently unused.} } @@ -43,7 +43,7 @@ works for models fitted via } \details{ The SHAP values are extracted via \code{predict(object, newdata = X, type = "terms")}, -a logic heavily inspired by \code{fastshap:::explain.lm(..., exact = TRUE)}. +a logic adopted from \code{fastshap:::explain.lm(..., exact = TRUE)}. Models with interactions (specified via \code{:} or \code{*}), or with terms of multiple features like \code{log(x1/x2)} are not supported. diff --git a/man/kernelshap.Rd b/man/kernelshap.Rd index a55691a..97fbbf4 100644 --- a/man/kernelshap.Rd +++ b/man/kernelshap.Rd @@ -169,11 +169,16 @@ Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017), and Covert and Lee (2021), abbreviated by CL21. For up to \eqn{p=8} features, the resulting Kernel SHAP values are exact regarding the selected background data. For larger \eqn{p}, an almost exact -hybrid algorithm involving iterative sampling is used, see Details. -For up to eight features, however, we recomment to use \code{\link[=permshap]{permshap()}}. +hybrid algorithm combining exact calculations and iterative sampling is used, +see Details. + +Note that (exact) Kernel SHAP is only an approximation of (exact) permutation SHAP. +Thus, for up to eight features, we recommend \code{\link[=permshap]{permshap()}}. For more features, +\code{\link[=permshap]{permshap()}} is slow compared the optimized hybrid strategy of our Kernel SHAP +implementation. } \details{ -Pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this: +The pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this: \enumerate{ \item A binary "on-off" vector \eqn{z} is drawn from \eqn{\{0, 1\}^p} such that its sum follows the SHAP Kernel weight distribution diff --git a/man/permshap.Rd b/man/permshap.Rd index 23d1c26..247ce71 100644 --- a/man/permshap.Rd +++ b/man/permshap.Rd @@ -94,7 +94,7 @@ background data. \item \code{bg_w}: The background case weights. \item \code{m_exact}: Integer providing the effective number of exact on-off vectors used. \item \code{exact}: Logical flag indicating whether calculations are exact or not -(currently \code{TRUE}). +(currently always \code{TRUE}). \item \code{txt}: Summary text. \item \code{predictions}: \eqn{(n \times K)} matrix with predictions of \code{X}. \item \code{algorithm}: "permshap". @@ -103,7 +103,7 @@ background data. \description{ Exact permutation SHAP algorithm with respect to a background dataset, see Strumbelj and Kononenko. The function works for up to 14 features. -For eight or more features, we recomment to switch to \code{\link[=kernelshap]{kernelshap()}}. +For more than eight features, we recommend \code{\link[=kernelshap]{kernelshap()}} due to its higher speed. } \section{Methods (by class)}{ \itemize{