Updated docstrings
mayer79 committed Sep 8, 2024
1 parent 6105947 commit 3277da8
Showing 8 changed files with 153 additions and 140 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -3,6 +3,7 @@
## Documentation

- More compact README.
- Updated function description.

# kernelshap 0.7.0

27 changes: 14 additions & 13 deletions R/additive_shap.R
@@ -9,19 +9,20 @@
#' - `gam::gam()`,
#' - [survival::coxph()], and
#' - [survival::survreg()].
#'
#'
#' The SHAP values are extracted via `predict(object, newdata = X, type = "terms")`,
#' a logic heavily inspired by `fastshap:::explain.lm(..., exact = TRUE)`.
#' a logic adopted from `fastshap:::explain.lm(..., exact = TRUE)`.
#' Models with interactions (specified via `:` or `*`), or with terms of
#' multiple features like `log(x1/x2)` are not supported.
#'
#'
#' Note that the SHAP values obtained by [additive_shap()] are expected to
#' match those of [permshap()] and [kernelshap()] as long as their background
#' data equals the full training data (which is typically not feasible).
#'
#' @inheritParams kernelshap
#' @param X Dataframe with rows to be explained. Will be used like
#' @param object Fitted additive model.
#' @param X Dataframe with rows to be explained. Passed to
#' `predict(object, newdata = X, type = "terms")`.
#' @param verbose Set to `FALSE` to suppress messages.
#' @param ... Currently unused.
#' @returns
#' An object of class "kernelshap" with the following components:
@@ -38,15 +39,15 @@
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' s <- additive_shap(fit, head(iris))
#' s
#'
#'
#' # MODEL TWO: More complicated (but not very clever) formula
#' fit <- lm(
#' Sepal.Length ~ poly(Sepal.Width, 2) + log(Petal.Length) + log(Sepal.Width),
#' data = iris
#' )
#' s_add <- additive_shap(fit, head(iris))
#' s_add
#'
#'
#' # Equals kernelshap()/permshap() when background data is full training data
#' s_kernel <- kernelshap(
#' fit, head(iris[c("Sepal.Width", "Petal.Length")]), bg_X = iris
@@ -59,28 +60,28 @@ additive_shap <- function(object, X, verbose = TRUE, ...) {
if (any(attr(stats::terms(object), "order") > 1)) {
stop("Additive SHAP not appropriate for models with interactions.")
}

txt <- "Exact additive SHAP via predict(..., type = 'terms')"
if (verbose) {
message(txt)
}

S <- stats::predict(object, newdata = X, type = "terms")
rownames(S) <- NULL

# Baseline value
b <- as.vector(attr(S, "constant"))
if (is.null(b)) {
b <- 0
}

# Which columns of X are used in each column of S?
s_names <- colnames(S)
cols_used <- lapply(s_names, function(z) all.vars(stats::reformulate(z)))
if (any(lengths(cols_used) > 1L)) {
stop("The formula contains terms with multiple features (not supported).")
}

# Collapse all columns in S using the same column in X and rename accordingly
mapping <- split(
s_names, factor(unlist(cols_used), levels = colnames(X)), drop = TRUE
@@ -89,7 +90,7 @@ additive_shap <- function(object, X, verbose = TRUE, ...) {
cbind,
lapply(mapping, function(z) rowSums(S[, z, drop = FALSE], na.rm = TRUE))
)

structure(
list(
S = S,
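The extraction step in `additive_shap()` above hinges on `predict(..., type = "terms")`, which returns a matrix of per-term contributions centered so that, together with the `"constant"` attribute (the baseline), they add back up to the ordinary prediction. A minimal base-R sketch of that identity, using only {stats} and the built-in `iris` data (illustrative, not part of the package):

```r
# Per-term contributions of an additive model, as used by additive_shap()
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length, data = iris)

S <- predict(fit, newdata = head(iris), type = "terms")
b <- attr(S, "constant")  # baseline = average training prediction

# Each row of S plus the baseline reproduces the ordinary prediction
recon <- rowSums(S) + b
all.equal(unname(recon), unname(predict(fit, newdata = head(iris))))
# should be TRUE
```

Because the terms are centered over the training data, `b` plays the role of the SHAP baseline `v0`, and the columns of `S` are exact SHAP values for a purely additive model.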
203 changes: 105 additions & 98 deletions R/kernelshap.R

Large diffs are not rendered by default.

22 changes: 11 additions & 11 deletions R/permshap.R
@@ -2,7 +2,7 @@
#'
#' Exact permutation SHAP algorithm with respect to a background dataset,
#' see Strumbelj and Kononenko. The function works for up to 14 features.
#' For eight or more features, we recomment to switch to [kernelshap()].
#' For more than eight features, we recommend [kernelshap()] due to its higher speed.
#'
#' @inheritParams kernelshap
#' @returns
@@ -16,12 +16,12 @@
#' - `bg_w`: The background case weights.
#' - `m_exact`: Integer providing the effective number of exact on-off vectors used.
#' - `exact`: Logical flag indicating whether calculations are exact or not
#' (currently `TRUE`).
#' (currently always `TRUE`).
#' - `txt`: Summary text.
#' - `predictions`: \eqn{(n \times K)} matrix with predictions of `X`.
#' - `algorithm`: "permshap".
#' @references
#' 1. Erik Strumbelj and Igor Kononenko. Explaining prediction models and individual
#' 1. Erik Strumbelj and Igor Kononenko. Explaining prediction models and individual
#' predictions with feature contributions. Knowledge and Information Systems 41, 2014.
#' @export
#' @examples
@@ -80,7 +80,7 @@ permshap.default <- function(
if (verbose) {
message(txt)
}

basic_checks(X = X, feature_names = feature_names, pred_fun = pred_fun)
prep_bg <- prepare_bg(X = X, bg_X = bg_X, bg_n = bg_n, bg_w = bg_w, verbose = verbose)
bg_X <- prep_bg$bg_X
@@ -92,32 +92,32 @@
bg_preds <- align_pred(pred_fun(object, bg_X, ...))
v0 <- wcolMeans(bg_preds, w = bg_w) # Average pred of bg data: 1 x K
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K

# Drop unnecessary columns in bg_X. If X is matrix, also column order is relevant
# Predictions will never be applied directly to bg_X anymore
if (!identical(colnames(bg_X), feature_names)) {
bg_X <- bg_X[, feature_names, drop = FALSE]
}

# Precalculations that are identical for each row to be explained
Z <- exact_Z(p, feature_names = feature_names, keep_extremes = TRUE)
m_exact <- nrow(Z) - 2L # We won't evaluate vz for first and last row
precalc <- list(
Z = Z,
Z_code = rowpaste(Z),
Z_code = rowpaste(Z),
bg_X_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact))
)

if (m_exact * bg_n > 2e5) {
warning_burden(m_exact, bg_n = bg_n)
}

# Apply permutation SHAP to each row of X
if (isTRUE(parallel)) {
parallel_args <- c(list(i = seq_len(n)), parallel_args)
res <- do.call(foreach::foreach, parallel_args) %dopar% permshap_one(
x = X[i, , drop = FALSE],
v1 = v1[i, , drop = FALSE],
v1 = v1[i, , drop = FALSE],
object = object,
pred_fun = pred_fun,
bg_w = bg_w,
@@ -133,7 +133,7 @@
for (i in seq_len(n)) {
res[[i]] <- permshap_one(
x = X[i, , drop = FALSE],
v1 = v1[i, , drop = FALSE],
v1 = v1[i, , drop = FALSE],
object = object,
pred_fun = pred_fun,
bg_w = bg_w,
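The precalculation step in `permshap.default()` above enumerates all on-off vectors over the `p` features via the internal `exact_Z()` helper, and skips the all-zero and all-one extremes when counting model evaluations (`m_exact <- nrow(Z) - 2L`). A hedged base-R sketch of the same idea (`make_Z` is an illustrative name, not the package's internal helper):

```r
# All 2^p on-off vectors over p features, extremes included,
# similar in spirit to the internal exact_Z() helper
make_Z <- function(feature_names) {
  p <- length(feature_names)
  Z <- as.matrix(expand.grid(rep(list(0:1), p)))
  colnames(Z) <- feature_names
  Z
}

Z <- make_Z(c("x1", "x2", "x3"))
nrow(Z)                  # 2^3 = 8 rows
m_exact <- nrow(Z) - 2L  # 6: the all-zero and all-one rows need no model call
```

The all-zero row corresponds to the background average `v0` and the all-one row to the prediction `v1`, both already known, which is why only the interior rows require model evaluations.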
17 changes: 8 additions & 9 deletions README.md
@@ -15,19 +15,18 @@

The package contains three functions to crunch SHAP values:

- `permshap()`: Exact permutation SHAP algorithm of [1]. Recommended for models with up to 8 features.
- `kernelshap()`: Kernel SHAP algorithm of [2] and [3]. Recommended for models with more than 8 features.
- `additive_shap()`: For *additive models* fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`, `survival::coxph()`, or `survival::survreg()`. Exponentially faster than the model-agnostic options above, and recommended if possible.
- **`permshap()`**: Exact permutation SHAP algorithm of [1]. Recommended for models with up to 8 features.
- **`kernelshap()`**: Kernel SHAP algorithm of [2] and [3]. Recommended for models with more than 8 features.
- **`additive_shap()`**: For *additive models* fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`, `survival::coxph()`, or `survival::survreg()`. Exponentially faster than the model-agnostic options above, and recommended if possible.

To explain your model, select an explanation dataset `X` (up to 1000 rows from the training data) and apply the recommended function. Use {shapviz} to visualize the resulting SHAP values.
To explain your model, select an explanation dataset `X` (up to 1000 rows from the training data, feature columns only) and apply the recommended function. Use {shapviz} to visualize the resulting SHAP values.

**Remarks for `permshap()` and `kernelshap()`**
**Remarks to `permshap()` and `kernelshap()`**

- `X` should only contain feature columns.
- Both algorithms need a representative background data `bg_X` to calculate marginal means (up to 500 rows from the training data). In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value. If unspecified, 200 rows are randomly sampled from `X`.
- By changing the defaults in `kernelshap()`, the iterative pure sampling approach of [3] can be enforced.
- `permshap()` vs. `kernelshap()`: For models with interactions of order up to two, exact Kernel SHAP agrees with exact permutation SHAP.
- `additive_shap()` vs. the model-agnostic explainers: The results would agree if the full training data would be used as background data.
- Exact Kernel SHAP is an approximation to exact permutation SHAP. Since exact calculations are usually fast enough for up to eight features, we recommend `permshap()` in that range. With more features, `kernelshap()` switches to a comparably fast, almost exact algorithm, which is why we recommend it there.
- For models with interactions of order up to two, SHAP values of exact permutation SHAP and exact Kernel SHAP agree.
- `permshap()` and `kernelshap()` give the same results as `additive_shap()` if the full training data is used as background data.
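A sketch of the recommended workflow described above, assuming the {kernelshap} and {shapviz} packages are installed (illustrative only; `lm()` on `iris` stands in for a real model):

```r
# Recommended workflow: crunch SHAP values, then visualize with {shapviz}
library(kernelshap)
library(shapviz)

fit <- lm(Sepal.Length ~ ., data = iris)

X <- iris[-1L]  # explanation data: feature columns only

# Four features (<= 8), so permshap() is the recommended explainer;
# X itself serves as background data here (150 rows <= 500)
s <- permshap(fit, X = head(X, 100), bg_X = X)

sv <- shapviz(s)
sv_importance(sv)  # SHAP importance plot
```

With more than eight features, swapping `permshap()` for `kernelshap()` in the same call is the only change needed.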

## Installation

8 changes: 4 additions & 4 deletions man/additive_shap.Rd


11 changes: 8 additions & 3 deletions man/kernelshap.Rd


4 changes: 2 additions & 2 deletions man/permshap.Rd

