Skip to content

Commit

Permalink
Fix explain_forecast and implement adaptive within that framework. (#405)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonlachmann authored Oct 18, 2024
1 parent 2a3aff9 commit 31b2d6f
Show file tree
Hide file tree
Showing 35 changed files with 1,030 additions and 695 deletions.
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,11 @@ export(compute_estimates)
export(compute_shapley_new)
export(compute_time)
export(compute_vS)
export(compute_vS_forecast)
export(correction_matrix_cpp)
export(create_coalition_table)
export(explain)
export(explain_forecast)
export(finalize_explanation)
export(finalize_explanation_forecast)
export(get_adaptive_arguments_default)
export(get_cov_mat)
export(get_data_specs)
Expand Down
6 changes: 3 additions & 3 deletions R/check_convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ check_convergence <- function(internal) {

n_sampled_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2 # Subtract the zero and full predictions

max_sd <- dt_shapley_sd[, max(.SD), .SDcols = -1, by = .I]$V1 # Max per prediction
max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = -1, by = .I]$V1 # Max per prediction
max_sd0 <- max_sd * sqrt(n_sampled_coalitions) # Scales UP the sd as it scales at this rate

dt_shapley_est0 <- copy(dt_shapley_est)
Expand All @@ -33,8 +33,8 @@ check_convergence <- function(internal) {
} else {
converged_exact <- FALSE
if (!is.null(convergence_tolerance)) {
dt_shapley_est0[, maxval := max(.SD), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, minval := min(.SD), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, max_sd0 := max_sd0]
dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tolerance))^2]
dt_shapley_est0[, conv_measure := max_sd0 / ((maxval - minval) * sqrt(n_sampled_coalitions))]
Expand Down
66 changes: 50 additions & 16 deletions R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#' @keywords internal
compute_estimates <- function(internal, vS_list) {
verbose <- internal$parameters$verbose
cli_id <- internal$parameter$cli_id
type <- internal$parameters$type

internal$timing_list$compute_vS <- Sys.time()

Expand Down Expand Up @@ -40,7 +40,7 @@ compute_estimates <- function(internal, vS_list) {
cli::cli_progress_step("Boostrapping Shapley value sds")
}

dt_shapley_sd <- bootstrap_shapley_new(internal, n_boot_samps = n_boot_samps, processed_vS_list$dt_vS)
dt_shapley_sd <- bootstrap_shapley(internal, n_boot_samps = n_boot_samps, processed_vS_list$dt_vS)

internal$timing_list$compute_bootstrap <- Sys.time()
} else {
Expand All @@ -50,10 +50,12 @@ compute_estimates <- function(internal, vS_list) {


# Adding explain_id to the output dt
dt_shapley_est[, explain_id := .I]
setcolorder(dt_shapley_est, "explain_id")
dt_shapley_sd[, explain_id := .I]
setcolorder(dt_shapley_sd, "explain_id")
if (type != "forecast") {
dt_shapley_est[, explain_id := .I]
setcolorder(dt_shapley_est, "explain_id")
dt_shapley_sd[, explain_id := .I]
setcolorder(dt_shapley_sd, "explain_id")
}


internal$iter_list[[iter]]$dt_shapley_est <- dt_shapley_est
Expand Down Expand Up @@ -137,9 +139,10 @@ compute_shapley_new <- function(internal, dt_vS) {

# If multiple horizons with explain_forecast are used, we only distribute value to those used at each horizon
if (type == "forecast") {
id_coalition_mapper_dt <- internal$objects$id_coalition_mapper_dt
id_coalition_mapper_dt <- internal$iter_list[[iter]]$id_coalition_mapper_dt
horizon <- internal$parameters$horizon
cols_per_horizon <- internal$objects$cols_per_horizon
shap_names <- internal$parameters$shap_names
W_list <- internal$objects$W_list

kshap_list <- list()
Expand Down Expand Up @@ -260,21 +263,47 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
return(dt_kshap_boot_sd)
}

bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
# Bootstrap the Shapley value estimates by delegating to `bootstrap_shapley_inner()`.
#
# For `type == "forecast"` the inner routine is run once per element of `X_list`
# (taken from the latest entry of `internal$iter_list`) and the per-element results
# are stacked with `rbindlist(fill = TRUE)`. For any other type it is run once on
# the latest `X`.
#
# Arguments:
#   internal      List. Read here: `parameters` (`type`, `is_groupwise`, `n_explain`,
#                 `horizon_features`, `n_shapley_values`, `shap_names`),
#                 `data$shap_names`, and the last element of `iter_list`
#                 (`X`, `X_list`). It is also passed on to the inner routine.
#   dt_vS         Table of v(S) values. In the forecast case it is subset to
#                 column 1 plus one block of `n_explain` columns per element of
#                 `X_list`; otherwise it is passed on unchanged.
#   n_boot_samps  Forwarded unchanged to `bootstrap_shapley_inner()`.
#   seed          Forwarded unchanged to `bootstrap_shapley_inner()`.
#
# Returns the value of `bootstrap_shapley_inner()`, or for forecasts the
# row-bound combination of its per-element values.
bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
  iter <- length(internal$iter_list)
  type <- internal$parameters$type

  if (type == "forecast") {
    is_groupwise <- internal$parameters$is_groupwise
    n_explain <- internal$parameters$n_explain
    X_list <- internal$iter_list[[iter]]$X_list

    # Preallocate instead of growing the list one element at a time
    result <- vector("list", length(X_list))
    for (i in seq_along(X_list)) {
      # NOTE(review): each element of X_list presumably corresponds to one
      # forecast horizon (cf. `horizon_features[[i]]`) -- confirm against caller.
      if (is_groupwise) {
        shap_names <- internal$data$shap_names
      } else {
        shap_names <- internal$parameters$horizon_features[[i]]
      }
      n_shapley_values <- length(shap_names)

      # Column 1 plus the i-th block of `n_explain` columns of dt_vS
      dt_cols <- c(1, seq_len(n_explain) + (i - 1) * n_explain + 1)
      dt_vS_this <- dt_vS[, ..dt_cols]

      result[[i]] <- bootstrap_shapley_inner(
        X_list[[i]], n_shapley_values, shap_names, internal, dt_vS_this,
        n_boot_samps, seed
      )
    }
    # fill = TRUE: the per-element results need not share the same columns
    result <- rbindlist(result, fill = TRUE)
  } else {
    result <- bootstrap_shapley_inner(
      internal$iter_list[[iter]]$X,
      internal$parameters$n_shapley_values,
      internal$parameters$shap_names,
      internal, dt_vS, n_boot_samps, seed
    )
  }
  return(result)
}

set.seed(seed)
bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, dt_vS, n_boot_samps = 100, seed = 123) {
type <- internal$parameters$type
iter <- length(internal$iter_list)

is_groupwise <- internal$parameters$is_groupwise
set.seed(seed)

n_explain <- internal$parameters$n_explain
paired_shap_sampling <- internal$parameters$paired_shap_sampling
shapley_reweight <- internal$parameters$shapley_reweighting
shap_names <- internal$parameters$shap_names
n_shapley_values <- internal$parameters$n_shapley_values


X_org <- copy(X)

Expand All @@ -300,7 +329,6 @@ bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 12

X_boot00[, boot_id := rep(seq(n_boot_samps), times = n_coalitions_boot / 2)]


X_boot00_paired <- copy(X_boot00[, .(coalitions, boot_id)])
X_boot00_paired[, coalitions := lapply(coalitions, function(x) seq(n_shapley_values)[-x])]
X_boot00_paired[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")]
Expand Down Expand Up @@ -338,7 +366,13 @@ bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 12
X_boot[, sample_freq := .N / n_coalitions_boot, by = .(id_coalition, boot_id)]
X_boot <- unique(X_boot, by = c("id_coalition", "boot_id"))
X_boot[, shapley_weight := sample_freq]
X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]]
if (type == "forecast") {
id_coalition_mapper_dt <- internal$iter_list[[iter]]$id_coalition_mapper_dt
full_ids <- id_coalition_mapper_dt$id_coalition[id_coalition_mapper_dt$full]
X_boot[coalition_size == 0 | id_coalition %in% full_ids, shapley_weight := X_org[1, shapley_weight]]
} else {
X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]]
}
}

for (i in seq_len(n_boot_samps)) {
Expand Down
39 changes: 2 additions & 37 deletions R/compute_vS.R
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ compute_preds <- function(
if (type == "forecast") {
dt[, (pred_cols) := predict_model(
x = model,
newdata = .SD[, 1:n_endo],
newreg = .SD[, -(1:n_endo)],
newdata = .SD[, .SD, .SDcols = seq_len(n_endo)],
newreg = .SD[, .SD, .SDcols = seq_len(length(feature_names) - n_endo) + n_endo],
horizon = horizon,
explain_idx = explain_idx[id],
explain_lags = explain_lags,
Expand All @@ -263,38 +263,3 @@ compute_MCint <- function(dt, pred_cols = "p_hat") {

return(dt_mat)
}


#' Computes `v(S)` for all features subsets `S`.
#'
#' Legacy entry point kept temporarily for the forecast code path. The batches
#' in `internal$objects$S_batch` are either handed to `future_compute_vS_batch()`
#' in one call, or processed one by one with `batch_compute_vS()`.
#'
#' @inheritParams default_doc
#' @inheritParams explain
#'
#' @param method Character.
#' `"future"` (default) delegates all batches to `future_compute_vS_batch()`.
#' Any other value processes the batches sequentially with `batch_compute_vS()`,
#' without progress bar or parallelization.
#'
#' @return For `method = "future"`, the value returned by
#' `future_compute_vS_batch()`. Otherwise an unnamed list with one
#' `batch_compute_vS()` result per element of `S_batch`.
#'
#' @export
compute_vS_forecast <- function(internal, model, predict_model, method = "future") {
  # old function used only for forecast temporary
  S_batch <- internal$objects$S_batch

  if (method == "future") {
    ret <- future_compute_vS_batch(
      S_batch = S_batch,
      internal = internal,
      model = model,
      predict_model = predict_model
    )
  } else {
    # Same computation as above, but sequential: no future, no progress bar and
    # no parallelization. `unname()` keeps the result unnamed, as an index-based
    # loop would.
    ret <- lapply(unname(S_batch), function(S) {
      batch_compute_vS(
        S = S,
        internal = internal,
        model = model,
        predict_model = predict_model
      )
    })
  }

  return(ret)
}
Loading

0 comments on commit 31b2d6f

Please sign in to comment.