Skip to content

Commit

Permalink
Merge pull request #101 from ModelOriented/align_pred
Browse files Browse the repository at this point in the history
Align pred
  • Loading branch information
mayer79 authored Sep 12, 2023
2 parents 1ab40a4 + 5ad8fe1 commit 4851434
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 25 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
## Maintenance

- Added explanation of sampling Kernel SHAP to help file.
- Internal code optimizations.

# kernelshap 0.3.7

Expand Down
6 changes: 2 additions & 4 deletions R/kernelshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,8 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
}

# Calculate v1 and v0
v1 <- check_pred(pred_fun(object, X, ...), n = n) # Predictions on X: n x K
bg_preds <- check_pred(
pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...), n = bg_n
)
v1 <- align_pred(pred_fun(object, X, ...)) # Predictions on X: n x K
bg_preds <- align_pred(pred_fun(object, bg_X[, colnames(X), drop = FALSE], ...))
v0 <- weighted_colMeans(bg_preds, bg_w) # Average pred of bg data: 1 x K

# For p = 1, exact Shapley values are returned
Expand Down
32 changes: 13 additions & 19 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ get_vz <- function(X, bg, Z, object, pred_fun, feature_names, w, ...) {
X[[nm]][s] <- bg[[nm]][s]
}
}
preds <- check_pred(pred_fun(object, X, ...), n = nrow(X))
preds <- align_pred(pred_fun(object, X, ...))

# Aggregate
if (is.null(w)) {
Expand Down Expand Up @@ -193,29 +193,23 @@ reorganize_list <- function(alist, nms) {
lapply(out, as.matrix)
}

# Checks and reshapes predictions to (n x K) matrix
check_pred <- function(x, n) {
if (
!is.vector(x) &&
!is.matrix(x) &&
!is.data.frame(x) &&
!(is.array(x) && length(dim(x)) <= 2L)
) {
stop("Predictions must be a vector, matrix, data.frame, or <=2D array")
}
if (is.data.frame(x) || is.array(x)) {
#' Aligns Predictions
#'
#' Turns predictions into matrix. Originally implemented in {hstats}.
#'
#' @noRd
#' @keywords internal
#'
#' @param x Object representing model predictions.
#' @returns Like `x`, but converted to matrix.
align_pred <- function(x) {
if (!is.matrix(x)) {
x <- as.matrix(x)
}
if (!is.numeric(x)) {
stop("Predictions must be numeric")
}
if (is.matrix(x) && nrow(x) == n) {
return(x)
}
if (length(x) == n) {
return(matrix(x, nrow = n))
}
stop("Predictions must be a length n vector or a matrix/data.frame/array with n rows.")
x
}

# Helper function in print() and summary()
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-kernelshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ preds <- unname(predict(fit, iris))
test_that("Special case p = 1 works", {
s <- kernelshap(fit, iris[1:5, x, drop = FALSE], bg_X = iris, verbose = FALSE)
expect_equal(s$baseline, mean(iris$Sepal.Length))
expect_equal(rowSums(s$S) + s$baseline, preds[1:5])
expect_equal(unname(rowSums(s$S)) + s$baseline, preds[1:5])
expect_equal(s$SE[1L], 0)
})

Expand Down Expand Up @@ -204,7 +204,7 @@ test_that("Special case p = 1 works with case weights", {
)

expect_equal(s$baseline, weighted.mean(iris$Sepal.Length, iris$Petal.Length))
expect_equal(rowSums(s$S) + s$baseline, preds[1:5])
expect_equal(unname(rowSums(s$S)) + s$baseline, preds[1:5])
})

fit <- lm(
Expand Down

0 comments on commit 4851434

Please sign in to comment.