Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/jon-fixes-rebase' into jon-fixes…
Browse files Browse the repository at this point in the history
…-rebase

# Conflicts:
#	R/compute_estimates.R
  • Loading branch information
jonlachmann committed Oct 17, 2024
2 parents a0ea854 + d46d501 commit 5ed90bb
Show file tree
Hide file tree
Showing 16 changed files with 558 additions and 293 deletions.
13 changes: 7 additions & 6 deletions R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#' @keywords internal
compute_estimates <- function(internal, vS_list) {
verbose <- internal$parameters$verbose
cli_id <- internal$parameter$cli_id
type <- internal$parameters$type

internal$timing_list$compute_vS <- Sys.time()

Expand Down Expand Up @@ -50,10 +50,12 @@ compute_estimates <- function(internal, vS_list) {


# Adding explain_id to the output dt
dt_shapley_est[, explain_id := .I]
setcolorder(dt_shapley_est, "explain_id")
dt_shapley_sd[, explain_id := .I]
setcolorder(dt_shapley_sd, "explain_id")
if (type != "forecast") {
dt_shapley_est[, explain_id := .I]
setcolorder(dt_shapley_est, "explain_id")
dt_shapley_sd[, explain_id := .I]
setcolorder(dt_shapley_sd, "explain_id")
}


internal$iter_list[[iter]]$dt_shapley_est <- dt_shapley_est
Expand Down Expand Up @@ -268,7 +270,6 @@ bootstrap_shapley_outer <- function (internal, dt_vS, n_boot_samps = 100, seed =

result <- list()
if (type == "forecast") {
horizon <- internal$parameters$horizon
n_explain <- internal$parameters$n_explain
for (i in seq_along(internal$objects$X_list)) {
X <- internal$objects$X_list[[i]]
Expand Down
7 changes: 6 additions & 1 deletion R/explain_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ explain_forecast <- function(model,
prediction_zero,
max_n_coalitions = NULL,
adaptive = NULL,
adaptive_arguments = list(),
shapley_reweighting = "on_all_cond",
group_lags = TRUE,
group = NULL,
n_MC_samples = 1e3,
Expand Down Expand Up @@ -134,6 +136,9 @@ explain_forecast <- function(model,
type = "forecast",
horizon = horizon,
adaptive = adaptive,
adaptive_arguments = adaptive_arguments,
shapley_reweighting = shapley_reweighting,
init_time = init_time,
y = y,
xreg = xreg,
train_idx = train_idx,
Expand Down Expand Up @@ -397,7 +402,7 @@ lag_data <- function(x, lags) {
reg_forecast_setup <- function(x, horizon, group) {
fcast <- matrix(NA, nrow(x) - horizon + 1, 0)
names <- character()
horizon_group <- lapply(seq_len(horizon), function (i) names(group)[!(names(group) %in% colnames(x))])
horizon_group <- lapply(seq_len(horizon), function(i) names(group)[!(names(group) %in% colnames(x))])
for (i in seq_len(ncol(x))) {
names_i <- paste0(colnames(x)[i], ".F", seq_len(horizon))
names <- c(names, names_i)
Expand Down
3 changes: 2 additions & 1 deletion R/get_predict_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ test_predict_model <- function(x_test, predict_model, model, internal) {
tmp <- tryCatch(predict_model(
x = model,
newdata = x_test[, .SD, .SDcols = seq_len(internal$data$n_endo), drop = FALSE],
newreg = x_test[, .SD, .SDcols = seq_len(ncol(x_test) - internal$data$n_endo) + internal$data$n_endo, drop = FALSE],
newreg = x_test[, .SD, .SDcols = seq_len(ncol(x_test) - internal$data$n_endo) + internal$data$n_endo,
drop = FALSE],
horizon = internal$parameters$horizon,
explain_idx = rep(internal$parameters$explain_idx[1], 2),
y = internal$data$y,
Expand Down
7 changes: 5 additions & 2 deletions R/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ setup <- function(x_train,
internal$parameters$group_lags <- group_lags

# TODO: Consider handling this parameter update somewhere else (like in get_extra_parameters?)

} else {
internal$data <- get_data(x_train, x_explain)
}
Expand Down Expand Up @@ -487,7 +486,8 @@ get_extra_parameters <- function(internal, type) {
if (internal$parameters$group_lags) {
internal$parameters$group <- internal$data$group
}
internal$parameters$horizon_features <- lapply(internal$data$horizon_group, function (x) as.character(unlist(internal$data$group[x])))
internal$parameters$horizon_features <- lapply(internal$data$horizon_group,
function(x) as.character(unlist(internal$data$group[x])))
}

# get number of features and observations to explain
Expand Down Expand Up @@ -528,6 +528,9 @@ get_extra_parameters <- function(internal, type) {
} else {
internal$parameters$shap_names <- internal$parameters$group_names
}
} else {
# For normal explain
internal$parameters$shap_names <- internal$parameters$group_names
}

internal$parameters$n_groups <- length(group)
Expand Down
15 changes: 8 additions & 7 deletions R/shapley_setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,10 @@ create_S_batch <- function(internal, seed = NULL) {

coalition_map <- internal$iter_list[[iter]]$coalition_map

if (type == "forecast") {
id_coalition_mapper_dt <- internal$objects$id_coalition_mapper_dt
full_ids <- id_coalition_mapper_dt$id_coalition[id_coalition_mapper_dt$full]
}

X0 <- copy(internal$iter_list[[iter]]$X)

Expand All @@ -551,7 +555,6 @@ create_S_batch <- function(internal, seed = NULL) {

if (length(approach0) > 1) {
if (type == "forecast") {
full_ids <- internal$objects$id_coalition_mapper_dt$id_coalition[internal$objects$id_coalition_mapper_dt$full]
X0[!(coalition_size == 0 | id_coalition %in% full_ids), approach := approach0[coalition_size]]
} else {
X0[!(coalition_size %in% c(0, n_shapley_values)), approach := approach0[coalition_size]]
Expand Down Expand Up @@ -604,7 +607,6 @@ create_S_batch <- function(internal, seed = NULL) {
}
} else {
if (type == "forecast") {
full_ids <- internal$objects$id_coalition_mapper_dt$id_coalition[internal$objects$id_coalition_mapper_dt$full]
X0[!(coalition_size == 0 | id_coalition %in% full_ids), approach := approach0]
} else {
X0[!(coalition_size %in% c(0, n_shapley_values)), approach := approach0]
Expand All @@ -615,7 +617,6 @@ create_S_batch <- function(internal, seed = NULL) {
data.table::setorder(X0, randomorder)
data.table::setorder(X0, shapley_weight)
if (type == "forecast") {
full_ids <- internal$objects$id_coalition_mapper_dt$id_coalition[internal$objects$id_coalition_mapper_dt$full]
X0[!(coalition_size == 0 | id_coalition %in% full_ids), batch := ceiling(.I / .N * n_batches)]
} else {
X0[!(coalition_size %in% c(0, n_shapley_values)), batch := ceiling(.I / .N * n_batches)]
Expand All @@ -625,7 +626,6 @@ create_S_batch <- function(internal, seed = NULL) {
# Assigning batch 1 (which always is the smallest) to the full prediction.
X0[, randomorder := NULL]
if (type == "forecast") {
full_ids <- internal$objects$id_coalition_mapper_dt$id_coalition[internal$objects$id_coalition_mapper_dt$full]
X0[id_coalition %in% full_ids, batch := 1]
} else {
X0[id_coalition == max(id_coalition), batch := 1]
Expand Down Expand Up @@ -686,10 +686,11 @@ shapley_setup_forecast <- function(internal) {
# Apply create_coalition_table, weigth_matrix and coalition_matrix_cpp to each of the different horizons
for (i in seq_along(horizon_features)) {
if (is_groupwise && !is.null(horizon_group)) {
this_coal_feature_list <- coal_feature_list[sapply(names(coal_feature_list), function (x) x %in% horizon_group[[i]])]
this_coal_feature_list <- coal_feature_list[sapply(names(coal_feature_list),
function(x) x %in% horizon_group[[i]])]
} else {
this_coal_feature_list <- lapply(coal_feature_list, function(x) x[x %in% horizon_features[[i]]])
this_coal_feature_list <- this_coal_feature_list[sapply(this_coal_feature_list, function (x) length(x) != 0)]
this_coal_feature_list <- this_coal_feature_list[sapply(this_coal_feature_list, function(x) length(x) != 0)]
}

n_this_featcomb <- length(this_coal_feature_list)
Expand All @@ -704,7 +705,7 @@ shapley_setup_forecast <- function(internal) {
prev_coal_samples = prev_coal_samples,
coal_feature_list = this_coal_feature_list,
approach0 = approach,
shapley_reweighting = FALSE
shapley_reweighting = shapley_reweighting
)

W_list[[i]] <- weight_matrix(
Expand Down
19 changes: 19 additions & 0 deletions man/explain_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 5ed90bb

Please sign in to comment.