Skip to content

Commit

Permalink
Merge pull request #102 from ModelOriented/code_optim
Browse files Browse the repository at this point in the history
Code optimization
  • Loading branch information
mayer79 authored Sep 12, 2023
2 parents 4851434 + be3c755 commit 14c430b
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 113 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
## Maintenance

- Added explanation of sampling Kernel SHAP to help file.
- Internal code optimizations.
- In internal calculations, use explicit `feature_names` as dimnames (https://github.com/ModelOriented/kernelshap/issues/96)

# kernelshap 0.3.7

Expand Down
71 changes: 46 additions & 25 deletions R/exact.R
Original file line number Diff line number Diff line change
@@ -1,40 +1,59 @@
# Functions required only for handling exact cases
# Functions required only for handling (partly) exact cases

# Provides fixed input for the exact case:
# - Z: Matrix with all 2^p-2 on-off vectors z
# - w: Vector with row weights of Z ensuring that the distribution of sum(z) matches
# the SHAP kernel distribution
# - A: Exact matrix A = Z'wZ
input_exact <- function(p) {
Z <- exact_Z(p)
input_exact <- function(p, feature_names) {
Z <- exact_Z(p, feature_names = feature_names)
# Each Kernel weight(j) is divided by the number of vectors z having sum(z) = j
w <- kernel_weights(p) / choose(p, 1:(p - 1L))
list(Z = Z, w = w[rowSums(Z)], A = exact_A(p))
list(Z = Z, w = w[rowSums(Z)], A = exact_A(p, feature_names = feature_names))
}

# Calculates exact A. Notice the difference to the off-diagnonals in the Supplement of
# Covert and Lee (2021). Credits to David Watson for figuring out the correct formula,
# see our discussions in https://github.com/ModelOriented/kernelshap/issues/22
exact_A <- function(p) {
#' Exact Matrix A
#'
#' Internal function that calculates exact A.
#' Notice the difference to the off-diagnonals in the Supplement of
#' Covert and Lee (2021). Credits to David Watson for figuring out the correct formula,
#' see our discussions in https://github.com/ModelOriented/kernelshap/issues/22
#'
#' @noRd
#' @keywords internal
#'
#' @param p Number of features.
#' @param feature_names Feature names.
#' @returns A (p x p) matrix.
exact_A <- function(p, feature_names) {
S <- 1:(p - 1L)
c_pr <- S * (S - 1) / p / (p - 1)
off_diag <- sum(kernel_weights(p) * c_pr)
A <- matrix(off_diag, nrow = p, ncol = p)
A <- matrix(
off_diag, nrow = p, ncol = p, dimnames = list(feature_names, feature_names)
)
diag(A) <- 0.5
A
}

# Creates (2^p-2) x p matrix with all on-off vectors z of length p
# Instead of calculating this object, we could evaluate it for different p <= p_max
# and store it as a list in the package.
exact_Z <- function(p) {
#' All on-off Vectors
#'
#' Internal function that creates matrix of all on-off vectors of length `p`.
#'
#' @noRd
#' @keywords internal
#'
#' @param p Number of features.
#' @param feature_names Feature names.
#' @returns An integer ((2^p - 2) x p) matrix of all on-off vectors of length `p`.
exact_Z <- function(p, feature_names) {
Z <- as.matrix(do.call(expand.grid, replicate(p, 0:1, simplify = FALSE)))
dimnames(Z) <- NULL
colnames(Z) <- feature_names
Z[2:(nrow(Z) - 1L), , drop = FALSE]
}

# List all length p vectors z with sum(z) in {k, p - k}
partly_exact_Z <- function(p, k) {
partly_exact_Z <- function(p, k, feature_names) {
if (k < 1L) {
stop("k must be at least 1")
}
Expand All @@ -48,17 +67,18 @@ partly_exact_Z <- function(p, k) {
utils::combn(seq_len(p), k, FUN = function(z) {x <- numeric(p); x[z] <- 1; x})
)
}
if (p == 2L * k) {
return(Z)
if (p != 2L * k) {
Z <- rbind(Z, 1 - Z)
}
return(rbind(Z, 1 - Z))
colnames(Z) <- feature_names
Z
}

# Create Z, w, A for vectors z with sum(z) in {k, p-k} for k in {1, ..., deg}.
# The total weights do not sum to one, except in the special (exact) case deg=p-deg.
# (The remaining weight will be added via input_sampling(p, deg=deg)).
# Note that for a given k, the weights are constant.
input_partly_exact <- function(p, deg) {
input_partly_exact <- function(p, deg, feature_names) {
if (deg < 1L) {
stop("deg must be at least 1")
}
Expand All @@ -70,7 +90,7 @@ input_partly_exact <- function(p, deg) {
Z <- w <- vector("list", deg)

for (k in seq_len(deg)) {
Z[[k]] <- partly_exact_Z(p, k = k)
Z[[k]] <- partly_exact_Z(p, k = k, feature_names = feature_names)
n <- nrow(Z[[k]])
w_tot <- kw[k] * (2 - (p == 2L * k))
w[[k]] <- rep(w_tot / n, n)
Expand All @@ -82,20 +102,21 @@ input_partly_exact <- function(p, deg) {
}

# Case p = 1 returns exact Shapley values
case_p1 <- function(n, nms, v0, v1, X, verbose) {
case_p1 <- function(n, feature_names, v0, v1, X, verbose) {
txt <- "Exact Shapley values (p = 1)"
if (verbose) {
message(txt)
}
S <- v1 - v0[rep(1L, n), , drop = FALSE]
SE <- matrix(numeric(n), dimnames = list(NULL, nms))
S <- v1 - v0[rep(1L, n), , drop = FALSE] # (n x K)
SE <- matrix(numeric(n), dimnames = list(NULL, feature_names)) # (n x 1)
if (ncol(v1) > 1L) {
SE <- replicate(ncol(v1), SE, simplify = FALSE)
S <- lapply(
asplit(S, MARGIN = 2L), function(M) as.matrix(M, dimnames = list(NULL, nms))
asplit(S, MARGIN = 2L), function(M)
as.matrix(M, dimnames = list(NULL, feature_names))
)
} else {
colnames(S) <- nms
colnames(S) <- feature_names
}
out <- list(
S = S,
Expand Down
14 changes: 10 additions & 4 deletions R/kernelshap.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
# For p = 1, exact Shapley values are returned
if (p == 1L) {
return(
case_p1(n = n, nms = feature_names, v0 = v0, v1 = v1, X = X, verbose = verbose)
case_p1(
n = n, feature_names = feature_names, v0 = v0, v1 = v1, X = X, verbose = verbose
)
)
}

Expand All @@ -238,7 +240,11 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,

# Precalculations for the real Kernel SHAP
if (exact || hybrid_degree >= 1L) {
precalc <- if (exact) input_exact(p) else input_partly_exact(p, hybrid_degree)
if (exact) {
precalc <- input_exact(p, feature_names = feature_names)
} else {
precalc <- input_partly_exact(p, deg = hybrid_degree, feature_names = feature_names)
}
m_exact <- nrow(precalc[["Z"]])
prop_exact <- sum(precalc[["w"]])
precalc[["bg_X_exact"]] <- bg_X[rep(seq_len(bg_n), times = m_exact), , drop = FALSE]
Expand Down Expand Up @@ -317,10 +323,10 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
warning("\nNon-convergence for ", sum(!converged), " rows.")
}
out <- list(
S = reorganize_list(lapply(res, `[[`, "beta"), nms = feature_names),
S = reorganize_list(lapply(res, `[[`, "beta")),
X = X,
baseline = as.vector(v0),
SE = reorganize_list(lapply(res, `[[`, "sigma"), nms = feature_names),
SE = reorganize_list(lapply(res, `[[`, "sigma")),
n_iter = vapply(res, `[[`, "n_iter", FUN.VALUE = integer(1L)),
converged = converged,
m = m,
Expand Down
10 changes: 6 additions & 4 deletions R/sampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Draw m binary vectors z of length p with sum(z) distributed according
# to Kernel SHAP weights -> (m x p) matrix.
# The argument S can be used to restrict the range of sum(z).
sample_Z <- function(p, m, S = 1:(p - 1L)) {
sample_Z <- function(p, m, feature_names, S = 1:(p - 1L)) {
# First draw s = sum(z) according to Kernel weights (renormalized to sum 1)
probs <- kernel_weights(p, S = S)
N <- S[sample.int(length(S), m, replace = TRUE, prob = probs)]
Expand All @@ -22,6 +22,7 @@ sample_Z <- function(p, m, S = 1:(p - 1L)) {
dim(out) <- c(p, m)
ord <- order(col(out), sample.int(m * p))
out[] <- out[ord]
rownames(out) <- feature_names
t(out)
}

Expand All @@ -46,17 +47,18 @@ conv_crit <- function(sig, bet) {
#
# If deg > 0, vectors z with sum(z) restricted to [deg+1, p-deg-1] are sampled.
# This case is used in combination with input_partly_hybrid(). Consequently, sum(w) < 1.
input_sampling <- function(p, m, deg, paired) {
input_sampling <- function(p, m, deg, paired, feature_names) {
if (p < 2L * deg + 2L) {
stop("p must be >=2*deg + 2")
}
S <- (deg + 1L):(p - deg - 1L)
Z <- sample_Z(m = if (paired) m / 2 else m, p = p, S = S)
Z <- sample_Z(
p = p, m = if (paired) m / 2 else m, feature_names = feature_names, S = S
)
if (paired) {
Z <- rbind(Z, 1 - Z)
}
w_total <- if (deg == 0L) 1 else 1 - 2 * sum(kernel_weights(p)[seq_len(deg)])
w <- w_total / m
list(Z = Z, w = rep(w, m), A = crossprod(Z) * w)
}

Loading

0 comments on commit 14c430b

Please sign in to comment.