Skip to content

Commit

Permalink
* Add tests for adaptive in explain_forecast.
Browse files Browse the repository at this point in the history
* Fix bootstrapping for multiple horizons in explain_forecast.
  • Loading branch information
jonlachmann committed Oct 17, 2024
1 parent 68a65b7 commit a0ea854
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 24 deletions.
2 changes: 1 addition & 1 deletion R/check_convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ check_convergence <- function(internal) {

n_sampled_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2 # Subtract the zero and full predictions

max_sd <- dt_shapley_sd[, max(.SD), .SDcols = -1, by = .I]$V1 # Max per prediction
max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = -1, by = .I]$V1 # Max per prediction
max_sd0 <- max_sd * sqrt(n_sampled_coalitions) # Scales UP the sd as it scales at this rate

dt_shapley_est0 <- copy(dt_shapley_est)
Expand Down
53 changes: 33 additions & 20 deletions R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ compute_estimates <- function(internal, vS_list) {
cli::cli_progress_step("Boostrapping Shapley value sds")
}

dt_shapley_sd <- bootstrap_shapley_new(internal, n_boot_samps = n_boot_samps, processed_vS_list$dt_vS)
dt_shapley_sd <- bootstrap_shapley_outer(internal, n_boot_samps = n_boot_samps, processed_vS_list$dt_vS)

internal$timing_list$compute_bootstrap <- Sys.time()
} else {
Expand Down Expand Up @@ -261,25 +261,46 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
return(dt_kshap_boot_sd)
}

bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
bootstrap_shapley_outer <- function (internal, dt_vS, n_boot_samps = 100, seed = 123) {
iter <- length(internal$iter_list)
type <- internal$parameters$type
is_groupwise <- internal$parameters$is_groupwise

X <- internal$iter_list[[iter]]$X
result <- list()
if (type == "forecast") {
horizon <- internal$parameters$horizon
n_explain <- internal$parameters$n_explain
for (i in seq_along(internal$objects$X_list)) {
X <- internal$objects$X_list[[i]]
if (is_groupwise) {
n_shapley_values <- length(internal$data$shap_names)
shap_names <- internal$data$shap_names
} else {
n_shapley_values <- length(internal$parameters$horizon_features[[i]])
shap_names <- internal$parameters$horizon_features[[i]]
}
dt_cols <- c(1, seq_len(n_explain) + (i - 1) * n_explain + 1)
dt_vS_this <- dt_vS[, ..dt_cols]
result[[i]] <- bootstrap_shapley_new(X, n_shapley_values, shap_names, internal, dt_vS_this, n_boot_samps, seed)
}
result <- rbindlist(result, fill = TRUE)
} else {
X <- internal$iter_list[[iter]]$X
n_shapley_values <- internal$parameters$n_shapley_values
shap_names <- internal$parameters$shap_names
result <- bootstrap_shapley_new(X, n_shapley_values, shap_names, internal, dt_vS, n_boot_samps, seed)
}
return(result)
}

set.seed(seed)
bootstrap_shapley_new <- function(X, n_shapley_values, shap_names, internal, dt_vS, n_boot_samps = 100, seed = 123) {
type <- internal$parameters$type

is_groupwise <- internal$parameters$is_groupwise
set.seed(seed)

n_explain <- internal$parameters$n_explain
if (internal$parameters$type == "forecast") {
n_explain <- n_explain * internal$parameters$horizon
}
paired_shap_sampling <- internal$parameters$paired_shap_sampling
shapley_reweight <- internal$parameters$shapley_reweighting
shap_names <- internal$parameters$shap_names
n_shapley_values <- internal$parameters$n_shapley_values


X_org <- copy(X)

Expand All @@ -305,7 +326,6 @@ bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 12

X_boot00[, boot_id := rep(seq(n_boot_samps), times = n_coalitions_boot / 2)]


X_boot00_paired <- copy(X_boot00[, .(coalitions, boot_id)])
X_boot00_paired[, coalitions := lapply(coalitions, function(x) seq(n_shapley_values)[-x])]
X_boot00_paired[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")]
Expand All @@ -325,14 +345,7 @@ bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 12
X_boot[, sample_freq := .N / n_coalitions_boot, by = .(id_coalition, boot_id)]
X_boot <- unique(X_boot, by = c("id_coalition", "boot_id"))
X_boot[, shapley_weight := sample_freq]
if (type == "forecast") {
# Filter out everything which represents empty (i.e. no changed features) and everything which is all (i.e. all features are changed, will be multiple to filter out for horizon > 1).
full_ids <- internal$objects$id_coalition_mapper_dt$id_coalition[internal$objects$id_coalition_mapper_dt$full]
X_boot[coalition_size == 0 | id_coalition %in% full_ids, shapley_weight := X_org[1, shapley_weight]]
} else {
X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]]
}

X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]]
} else {
X_boot0 <- X_samp[
sample.int(
Expand Down
3 changes: 2 additions & 1 deletion R/shapley_setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ shapley_setup_forecast <- function(internal) {
W <- NULL # Included for consistency. Necessary weights are in W_list instead

coalition_map <- X[, .(id_coalition,
coalitions_str = sapply(coalitions, paste, collapse = " ")
coalitions_str = sapply(features, paste, collapse = " ")
)]

## Get feature matrix ---------
Expand Down Expand Up @@ -763,6 +763,7 @@ shapley_setup_forecast <- function(internal) {
internal$iter_list[[iter]]$W <- W
internal$iter_list[[iter]]$S <- S
internal$objects$id_coalition_mapper_dt <- id_coalition_mapper_dt
internal$objects$X_list <- X_list
internal$iter_list[[iter]]$coalition_map <- coalition_map
internal$iter_list[[iter]]$S_batch <- create_S_batch(internal)

Expand Down
5 changes: 3 additions & 2 deletions tests/testthat/helper-ar-arima.R
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
options(digits = 5) # To avoid round off errors when printing output on different systems



data <- data.table::as.data.table(airquality)
data[, Solar.R := ifelse(is.na(Solar.R), mean(Solar.R, na.rm = TRUE), Solar.R)]
data[, Ozone := ifelse(is.na(Ozone), mean(Ozone, na.rm = TRUE), Ozone)]

model_ar_temp <- ar(data$Temp, order = 2)
model_ar_temp$n.ahead <- 3

p0_ar <- rep(mean(data$Temp), 3)

model_arima_temp <- arima(data$Temp[1:150], c(2, 1, 0), xreg = data$Wind[1:150])
model_arima_temp2 <- arima(data$Temp[1:150], c(2, 1, 0), xreg = data[1:150, c("Wind", "Solar.R", "Ozone")])

model_arima_temp_noxreg <- arima(data$Temp[1:150], c(2, 1, 0))

Expand Down
46 changes: 46 additions & 0 deletions tests/testthat/test-forecast-output.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,52 @@ test_that("forecast_output_arima_numeric", {
)
})

test_that("forecast_output_arima_numeric_adaptive", {
expect_snapshot_rds(
explain_forecast(
testing = TRUE,
model = model_arima_temp,
y = data[1:150, "Temp"],
xreg = data[, "Wind"],
train_idx = 3:148,
explain_idx = 149:150,
explain_y_lags = 3,
explain_xreg_lags = 3,
horizon = 3,
approach = "empirical",
prediction_zero = p0_ar,
group_lags = FALSE,
max_n_coalitions = 150,
adaptive = TRUE,
adaptive_arguments = list(initial_n_coalitions = 10)
),
"forecast_output_arima_numeric"
)
})

test_that("forecast_output_arima_numeric_adaptive_groups", {
expect_snapshot_rds(
explain_forecast(
testing = TRUE,
model = model_arima_temp2,
y = data[1:150, "Temp"],
xreg = data[, c("Wind", "Solar.R", "Ozone")],
train_idx = 3:148,
explain_idx = 149:150,
explain_y_lags = 3,
explain_xreg_lags = c(3, 3, 3),
horizon = 3,
approach = "empirical",
prediction_zero = p0_ar,
group_lags = TRUE,
max_n_coalitions = 150,
adaptive = TRUE,
adaptive_arguments = list(initial_n_coalitions = 10, convergence_tolerance = 7e-3)
),
"forecast_output_arima_numeric"
)
})

test_that("forecast_output_arima_numeric_no_xreg", {
expect_snapshot_rds(
explain_forecast(
Expand Down

0 comments on commit a0ea854

Please sign in to comment.