Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix explain_forecast and implement adaptive within that framework. #405

Merged
merged 31 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9fb231d
Set shapley_reweighting to FALSE in shapley_setup_forecast for debugg…
jonlachmann Oct 1, 2024
f594996
* explain_forecast now passes tests which compare horizons.
jonlachmann Oct 3, 2024
3d16f14
* Disable compute_sd for forecasts for now. Fix it in the future.
jonlachmann Oct 3, 2024
e225b2d
* Stop using an old function specifically for explain_forecast.
jonlachmann Oct 3, 2024
2f50c4a
* Standard deviation working for explain_forecast
jonlachmann Oct 7, 2024
7002a52
* Set adaptive to false for now for explain_forecast
jonlachmann Oct 7, 2024
8bb20d5
* Adaptive working for forecast.
jonlachmann Oct 7, 2024
1e0ee0a
* Rename tests
jonlachmann Oct 8, 2024
2c8f537
man + run tests on GHA
martinju Oct 13, 2024
bcda193
Sync zzz with 1.0.0
jonlachmann Oct 15, 2024
64d0310
Ref man from 1.0.0
jonlachmann Oct 15, 2024
68a65b7
Merge branch 'shapr-1.0.0' into jon-fixes-rebase
jonlachmann Oct 15, 2024
bedc245
get shap_names back
martinju Oct 15, 2024
0880237
adding adaptive_arguments and shapley_reweighting
martinju Oct 16, 2024
a48da25
adding latest forecast testfiles
martinju Oct 16, 2024
5d30298
remove n_batches
martinju Oct 16, 2024
a608dd6
not adding explain_id for forecast (explain_idx is already there)
martinju Oct 16, 2024
79b6132
some forecast tests OK
martinju Oct 16, 2024
79853ee
style, lint and simplify full_ids code
martinju Oct 16, 2024
d46d501
fixes during meeting.
martinju Oct 16, 2024
a0ea854
* Add tests for adaptive in explain_forecast.
jonlachmann Oct 17, 2024
5ed90bb
Merge remote-tracking branch 'origin/jon-fixes-rebase' into jon-fixes…
jonlachmann Oct 17, 2024
9249abc
moving some iterative objects to iter_list
martinju Oct 17, 2024
6eaf071
tests
martinju Oct 17, 2024
92ce297
rename boostrap functions, style and lint
martinju Oct 17, 2024
d76ca47
fix group_names once an for all
martinju Oct 17, 2024
7d99836
+
martinju Oct 17, 2024
e9cb76d
* Rename data in forecast tests to avoid conflicts.
jonlachmann Oct 18, 2024
f03d20f
update tests after Jons fix
martinju Oct 18, 2024
421a4de
Robustifies the print function to pass tests
martinju Oct 18, 2024
0a71b3f
style + lint
martinju Oct 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,11 @@ export(compute_estimates)
export(compute_shapley_new)
export(compute_time)
export(compute_vS)
export(compute_vS_forecast)
export(correction_matrix_cpp)
export(create_coalition_table)
export(explain)
export(explain_forecast)
export(finalize_explanation)
export(finalize_explanation_forecast)
export(get_adaptive_arguments_default)
export(get_cov_mat)
export(get_data_specs)
Expand Down
6 changes: 3 additions & 3 deletions R/check_convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ check_convergence <- function(internal) {

n_sampled_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2 # Subtract the zero and full predictions

max_sd <- dt_shapley_sd[, max(.SD), .SDcols = -1, by = .I]$V1 # Max per prediction
max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = -1, by = .I]$V1 # Max per prediction
max_sd0 <- max_sd * sqrt(n_sampled_coalitions) # Scales UP the sd as it scales at this rate

dt_shapley_est0 <- copy(dt_shapley_est)
Expand All @@ -33,8 +33,8 @@ check_convergence <- function(internal) {
} else {
converged_exact <- FALSE
if (!is.null(convergence_tolerance)) {
dt_shapley_est0[, maxval := max(.SD), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, minval := min(.SD), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, max_sd0 := max_sd0]
dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tolerance))^2]
dt_shapley_est0[, conv_measure := max_sd0 / ((maxval - minval) * sqrt(n_sampled_coalitions))]
Expand Down
66 changes: 50 additions & 16 deletions R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#' @keywords internal
compute_estimates <- function(internal, vS_list) {
verbose <- internal$parameters$verbose
cli_id <- internal$parameter$cli_id
type <- internal$parameters$type

internal$timing_list$compute_vS <- Sys.time()

Expand Down Expand Up @@ -40,7 +40,7 @@ compute_estimates <- function(internal, vS_list) {
cli::cli_progress_step("Boostrapping Shapley value sds")
}

dt_shapley_sd <- bootstrap_shapley_new(internal, n_boot_samps = n_boot_samps, processed_vS_list$dt_vS)
dt_shapley_sd <- bootstrap_shapley(internal, n_boot_samps = n_boot_samps, processed_vS_list$dt_vS)

internal$timing_list$compute_bootstrap <- Sys.time()
} else {
Expand All @@ -50,10 +50,12 @@ compute_estimates <- function(internal, vS_list) {


# Adding explain_id to the output dt
dt_shapley_est[, explain_id := .I]
setcolorder(dt_shapley_est, "explain_id")
dt_shapley_sd[, explain_id := .I]
setcolorder(dt_shapley_sd, "explain_id")
if (type != "forecast") {
dt_shapley_est[, explain_id := .I]
setcolorder(dt_shapley_est, "explain_id")
dt_shapley_sd[, explain_id := .I]
setcolorder(dt_shapley_sd, "explain_id")
}


internal$iter_list[[iter]]$dt_shapley_est <- dt_shapley_est
Expand Down Expand Up @@ -137,9 +139,10 @@ compute_shapley_new <- function(internal, dt_vS) {

# If multiple horizons with explain_forecast are used, we only distribute value to those used at each horizon
if (type == "forecast") {
id_coalition_mapper_dt <- internal$objects$id_coalition_mapper_dt
id_coalition_mapper_dt <- internal$iter_list[[iter]]$id_coalition_mapper_dt
horizon <- internal$parameters$horizon
cols_per_horizon <- internal$objects$cols_per_horizon
shap_names <- internal$parameters$shap_names
W_list <- internal$objects$W_list

kshap_list <- list()
Expand Down Expand Up @@ -260,21 +263,47 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
return(dt_kshap_boot_sd)
}

bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
iter <- length(internal$iter_list)
type <- internal$parameters$type
is_groupwise <- internal$parameters$is_groupwise
X_list <- internal$iter_list[[iter]]$X_list

X <- internal$iter_list[[iter]]$X
result <- list()
if (type == "forecast") {
n_explain <- internal$parameters$n_explain
for (i in seq_along(X_list)) {
X <- X_list[[i]]
if (is_groupwise) {
n_shapley_values <- length(internal$data$shap_names)
shap_names <- internal$data$shap_names
} else {
n_shapley_values <- length(internal$parameters$horizon_features[[i]])
shap_names <- internal$parameters$horizon_features[[i]]
}
dt_cols <- c(1, seq_len(n_explain) + (i - 1) * n_explain + 1)
dt_vS_this <- dt_vS[, ..dt_cols]
result[[i]] <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS_this, n_boot_samps, seed)
}
result <- rbindlist(result, fill = TRUE)
} else {
X <- internal$iter_list[[iter]]$X
n_shapley_values <- internal$parameters$n_shapley_values
shap_names <- internal$parameters$shap_names
result <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS, n_boot_samps, seed)
}
return(result)
}

set.seed(seed)
bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, dt_vS, n_boot_samps = 100, seed = 123) {
type <- internal$parameters$type
iter <- length(internal$iter_list)

is_groupwise <- internal$parameters$is_groupwise
set.seed(seed)

n_explain <- internal$parameters$n_explain
paired_shap_sampling <- internal$parameters$paired_shap_sampling
shapley_reweight <- internal$parameters$shapley_reweighting
shap_names <- internal$parameters$shap_names
n_shapley_values <- internal$parameters$n_shapley_values


X_org <- copy(X)

Expand All @@ -300,7 +329,6 @@ bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 12

X_boot00[, boot_id := rep(seq(n_boot_samps), times = n_coalitions_boot / 2)]


X_boot00_paired <- copy(X_boot00[, .(coalitions, boot_id)])
X_boot00_paired[, coalitions := lapply(coalitions, function(x) seq(n_shapley_values)[-x])]
X_boot00_paired[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")]
Expand Down Expand Up @@ -338,7 +366,13 @@ bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 12
X_boot[, sample_freq := .N / n_coalitions_boot, by = .(id_coalition, boot_id)]
X_boot <- unique(X_boot, by = c("id_coalition", "boot_id"))
X_boot[, shapley_weight := sample_freq]
X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]]
if (type == "forecast") {
id_coalition_mapper_dt <- internal$iter_list[[iter]]$id_coalition_mapper_dt
full_ids <- id_coalition_mapper_dt$id_coalition[id_coalition_mapper_dt$full]
X_boot[coalition_size == 0 | id_coalition %in% full_ids, shapley_weight := X_org[1, shapley_weight]]
} else {
X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]]
}
}

for (i in seq_len(n_boot_samps)) {
Expand Down
39 changes: 2 additions & 37 deletions R/compute_vS.R
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ compute_preds <- function(
if (type == "forecast") {
dt[, (pred_cols) := predict_model(
x = model,
newdata = .SD[, 1:n_endo],
newreg = .SD[, -(1:n_endo)],
newdata = .SD[, .SD, .SDcols = seq_len(n_endo)],
newreg = .SD[, .SD, .SDcols = seq_len(length(feature_names) - n_endo) + n_endo],
horizon = horizon,
explain_idx = explain_idx[id],
explain_lags = explain_lags,
Expand All @@ -263,38 +263,3 @@ compute_MCint <- function(dt, pred_cols = "p_hat") {

return(dt_mat)
}


#' Computes `v(S)` for all features subsets `S`.
#'
#' @inheritParams default_doc
#' @inheritParams explain
#'
#' @param method Character
#' Indicates whether the lappy method (default) or loop method should be used.
#'
#' @export
compute_vS_forecast <- function(internal, model, predict_model, method = "future") {
# old function used only for forecast temporary
S_batch <- internal$objects$S_batch



if (method == "future") {
ret <- future_compute_vS_batch(S_batch = S_batch, internal = internal, model = model, predict_model = predict_model)
} else {
# Doing the same as above without future without progressbar or paralellization
ret <- list()
for (i in seq_along(S_batch)) {
S <- S_batch[[i]]
ret[[i]] <- batch_compute_vS(
S = S,
internal = internal,
model = model,
predict_model = predict_model
)
}
}

return(ret)
}
Loading
Loading