diff --git a/R/check_convergence.R b/R/check_convergence.R index 40c9608f..4b071066 100644 --- a/R/check_convergence.R +++ b/R/check_convergence.R @@ -20,7 +20,7 @@ check_convergence <- function(internal) { n_sampled_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2 # Subtract the zero and full predictions - max_sd <- dt_shapley_sd[, max(.SD), .SDcols = -1, by = .I]$V1 # Max per prediction + max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = -1, by = .I]$V1 # Max per prediction max_sd0 <- max_sd * sqrt(n_sampled_coalitions) # Scales UP the sd as it scales at this rate dt_shapley_est0 <- copy(dt_shapley_est) diff --git a/R/compute_estimates.R b/R/compute_estimates.R index edc31583..7cb4a6e3 100644 --- a/R/compute_estimates.R +++ b/R/compute_estimates.R @@ -40,7 +40,7 @@ compute_estimates <- function(internal, vS_list) { cli::cli_progress_step("Boostrapping Shapley value sds") } - dt_shapley_sd <- bootstrap_shapley_new(internal, n_boot_samps = n_boot_samps, processed_vS_list$dt_vS) + dt_shapley_sd <- bootstrap_shapley_outer(internal, n_boot_samps = n_boot_samps, processed_vS_list$dt_vS) internal$timing_list$compute_bootstrap <- Sys.time() } else { @@ -261,25 +261,46 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) { return(dt_kshap_boot_sd) } -bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) { +bootstrap_shapley_outer <- function (internal, dt_vS, n_boot_samps = 100, seed = 123) { iter <- length(internal$iter_list) type <- internal$parameters$type + is_groupwise <- internal$parameters$is_groupwise - X <- internal$iter_list[[iter]]$X + result <- list() + if (type == "forecast") { + horizon <- internal$parameters$horizon + n_explain <- internal$parameters$n_explain + for (i in seq_along(internal$objects$X_list)) { + X <- internal$objects$X_list[[i]] + if (is_groupwise) { + n_shapley_values <- length(internal$data$shap_names) + shap_names <- internal$data$shap_names + } else { + n_shapley_values <- length(internal$parameters$horizon_features[[i]]) + shap_names <- internal$parameters$horizon_features[[i]] + } + dt_cols <- c(1, seq_len(n_explain) + (i - 1) * n_explain + 1) + dt_vS_this <- dt_vS[, ..dt_cols] + result[[i]] <- bootstrap_shapley_new(X, n_shapley_values, shap_names, internal, dt_vS_this, n_boot_samps, seed) + } + result <- rbindlist(result, fill = TRUE) + } else { + X <- internal$iter_list[[iter]]$X + n_shapley_values <- internal$parameters$n_shapley_values + shap_names <- internal$parameters$shap_names + result <- bootstrap_shapley_new(X, n_shapley_values, shap_names, internal, dt_vS, n_boot_samps, seed) + } + return(result) +} - set.seed(seed) +bootstrap_shapley_new <- function(X, n_shapley_values, shap_names, internal, dt_vS, n_boot_samps = 100, seed = 123) { + type <- internal$parameters$type - is_groupwise <- internal$parameters$is_groupwise + set.seed(seed) n_explain <- internal$parameters$n_explain - if (internal$parameters$type == "forecast") { - n_explain <- n_explain * internal$parameters$horizon - } paired_shap_sampling <- internal$parameters$paired_shap_sampling shapley_reweight <- internal$parameters$shapley_reweighting - shap_names <- internal$parameters$shap_names - n_shapley_values <- internal$parameters$n_shapley_values - X_org <- copy(X) @@ -305,7 +326,6 @@ bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 12 X_boot00[, boot_id := rep(seq(n_boot_samps), times = n_coalitions_boot / 2)] - X_boot00_paired <- copy(X_boot00[, .(coalitions, boot_id)]) X_boot00_paired[, coalitions := lapply(coalitions, function(x) seq(n_shapley_values)[-x])] X_boot00_paired[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")] @@ -325,14 +345,7 @@ bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 12 X_boot[, sample_freq := .N / n_coalitions_boot, by = .(id_coalition, boot_id)] X_boot <- unique(X_boot, by = c("id_coalition", "boot_id")) X_boot[, shapley_weight := sample_freq] - if (type == "forecast") { - # Filter out everything which represents empty (i.e. no changed features) and everything which is all (i.e. all features are changed, will be multiple to filter out for horizon > 1). - full_ids <- internal$objects$id_coalition_mapper_dt$id_coalition[internal$objects$id_coalition_mapper_dt$full] - X_boot[coalition_size == 0 | id_coalition %in% full_ids, shapley_weight := X_org[1, shapley_weight]] - } else { - X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]] - } - + X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]] } else { X_boot0 <- X_samp[ sample.int( diff --git a/R/shapley_setup.R b/R/shapley_setup.R index 088bfb0b..fa27cf9b 100644 --- a/R/shapley_setup.R +++ b/R/shapley_setup.R @@ -735,7 +735,7 @@ shapley_setup_forecast <- function(internal) { W <- NULL # Included for consistency. Necessary weights are in W_list instead coalition_map <- X[, .(id_coalition, - coalitions_str = sapply(coalitions, paste, collapse = " ") + coalitions_str = sapply(features, paste, collapse = " ") )] ## Get feature matrix --------- @@ -763,6 +763,7 @@ shapley_setup_forecast <- function(internal) { internal$iter_list[[iter]]$W <- W internal$iter_list[[iter]]$S <- S internal$objects$id_coalition_mapper_dt <- id_coalition_mapper_dt + internal$objects$X_list <- X_list internal$iter_list[[iter]]$coalition_map <- coalition_map internal$iter_list[[iter]]$S_batch <- create_S_batch(internal) diff --git a/tests/testthat/helper-ar-arima.R b/tests/testthat/helper-ar-arima.R index 47944e87..ed210d73 100644 --- a/tests/testthat/helper-ar-arima.R +++ b/tests/testthat/helper-ar-arima.R @@ -1,8 +1,8 @@ options(digits = 5) # To avoid round off errors when printing output on different systems - - data <- data.table::as.data.table(airquality) +data[, Solar.R := ifelse(is.na(Solar.R), mean(Solar.R, na.rm = TRUE), Solar.R)] +data[, Ozone := ifelse(is.na(Ozone), mean(Ozone, na.rm = TRUE), Ozone)] model_ar_temp <- ar(data$Temp, order = 2) model_ar_temp$n.ahead <- 3 @@ -10,6 +10,7 @@ model_ar_temp$n.ahead <- 3 p0_ar <- rep(mean(data$Temp), 3) model_arima_temp <- arima(data$Temp[1:150], c(2, 1, 0), xreg = data$Wind[1:150]) +model_arima_temp2 <- arima(data$Temp[1:150], c(2, 1, 0), xreg = data[1:150, c("Wind", "Solar.R", "Ozone")]) model_arima_temp_noxreg <- arima(data$Temp[1:150], c(2, 1, 0)) diff --git a/tests/testthat/test-forecast-output.R b/tests/testthat/test-forecast-output.R index 28f05d31..b7b251f1 100644 --- a/tests/testthat/test-forecast-output.R +++ b/tests/testthat/test-forecast-output.R @@ -39,6 +39,52 @@ test_that("forecast_output_arima_numeric", { ) }) +test_that("forecast_output_arima_numeric_adaptive", { + expect_snapshot_rds( + explain_forecast( + testing = TRUE, + model = model_arima_temp, + y = data[1:150, "Temp"], + xreg = data[, "Wind"], + train_idx = 3:148, + explain_idx = 149:150, + explain_y_lags = 3, + explain_xreg_lags = 3, + horizon = 3, + approach = "empirical", + prediction_zero = p0_ar, + group_lags = FALSE, + max_n_coalitions = 150, + adaptive = TRUE, + adaptive_arguments = list(initial_n_coalitions = 10) + ), + "forecast_output_arima_numeric" + ) +}) + +test_that("forecast_output_arima_numeric_adaptive_groups", { + expect_snapshot_rds( + explain_forecast( + testing = TRUE, + model = model_arima_temp2, + y = data[1:150, "Temp"], + xreg = data[, c("Wind", "Solar.R", "Ozone")], + train_idx = 3:148, + explain_idx = 149:150, + explain_y_lags = 3, + explain_xreg_lags = c(3, 3, 3), + horizon = 3, + approach = "empirical", + prediction_zero = p0_ar, + group_lags = TRUE, + max_n_coalitions = 150, + adaptive = TRUE, + adaptive_arguments = list(initial_n_coalitions = 10, convergence_tolerance = 7e-3) + ), + "forecast_output_arima_numeric" + ) +}) + test_that("forecast_output_arima_numeric_no_xreg", { expect_snapshot_rds( explain_forecast(