Skip to content

Commit

Permalink
Bugfix keep_samp_for_vS with iterative approach (#417)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju authored Nov 14, 2024
1 parent 2b3f118 commit 7f9b44f
Show file tree
Hide file tree
Showing 34 changed files with 99 additions and 7 deletions.
51 changes: 44 additions & 7 deletions R/compute_vS.R
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ compute_MCint <- function(dt, pred_cols = "p_hat") {
#' @keywords internal
append_vS_list <- function(vS_list, internal) {
iter <- length(internal$iter_list)
keep_samp_for_vS <- internal$parameters$output_args$keep_samp_for_vS

# Adds v_S output above to any vS_list already computed
if (iter > 1) {
Expand All @@ -249,17 +250,53 @@ append_vS_list <- function(vS_list, internal) {
prev_vS_list_new <- list()

# Applies the mapper to update the prev_vS_list ot the new id_coalition numbering
for (k in seq_along(prev_vS_list)) {
prev_vS_list_new[[k]] <- merge(prev_vS_list[[k]],
id_coalitions_mapper[, .(id_coalition, id_coalition_new)],
by = "id_coalition"
)
prev_vS_list_new[[k]][, id_coalition := id_coalition_new]
prev_vS_list_new[[k]][, id_coalition_new := NULL]
if (isFALSE(keep_samp_for_vS)) {
for (k in seq_along(prev_vS_list)) {
this_vS <- prev_vS_list[[k]]

this_vS_new <- merge(this_vS,
id_coalitions_mapper[, .(id_coalition, id_coalition_new)],
by = "id_coalition"
)

this_vS_new[, id_coalition := id_coalition_new]
this_vS_new[, id_coalition_new := NULL]


prev_vS_list_new[[k]] <- this_vS_new
}
} else {
for (k in seq_along(prev_vS_list)) {
this_vS <- prev_vS_list[[k]]$dt_vS
this_samp_for_vS <- prev_vS_list[[k]]$dt_samp_for_vS


this_vS_new <- merge(this_vS,
id_coalitions_mapper[, .(id_coalition, id_coalition_new)],
by = "id_coalition"
)

this_vS_new[, id_coalition := id_coalition_new]
this_vS_new[, id_coalition_new := NULL]

this_samp_for_vS_new <- merge(this_samp_for_vS,
id_coalitions_mapper[, .(id_coalition, id_coalition_new)],
by = "id_coalition"
)

this_samp_for_vS_new[, id_coalition := id_coalition_new]
this_samp_for_vS_new[, id_coalition_new := NULL]


prev_vS_list_new[[k]] <- list(dt_vS = this_vS_new, dt_samp_for_vS = this_samp_for_vS_new)
}
}
names(prev_vS_list_new) <- names(prev_vS_list)

# Merge the new vS_list with the old vS_list
vS_list <- c(prev_vS_list_new, vS_list)
}


return(vS_list)
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -982,3 +982,38 @@
2: 2 42.44 4.919 -4.878 -11.9086 -0.8405 -1.1714
3: 3 42.44 7.447 -25.748 0.0324 -0.1976 0.8978

# output_lm_numeric_independence_keep_samp_for_vS

Code
(out <- code)
Message
Success with message:
max_n_coalitions is NULL or larger than or 2^n_features = 32,
and is therefore set to 2^n_features = 32.
* Model class: <lm>
* Approach: independence
* Iterative estimation: TRUE
* Number of feature-wise Shapley values: 5
* Number of observations to explain: 3
-- iterative computation started --
-- Iteration 1 -----------------------------------------------------------------
i Using 5 of 32 coalitions, 5 new.
-- Iteration 2 -----------------------------------------------------------------
i Using 10 of 32 coalitions, 4 new.
-- Iteration 3 -----------------------------------------------------------------
i Using 12 of 32 coalitions, 2 new.
-- Iteration 4 -----------------------------------------------------------------
i Using 16 of 32 coalitions, 4 new.
Output
explain_id none Solar.R Wind Temp Month Day
<int> <num> <num> <num> <num> <num> <num>
1: 1 42.44 -4.541 8.330 17.491 -5.585 -3.093
2: 2 42.44 2.246 -3.285 -5.258 -5.585 -1.997
3: 3 42.44 3.704 -18.549 -1.467 -2.545 1.289

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File renamed without changes.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,23 @@ test_that("output_verbose_1_3_4_5", {
"output_verbose_1_3_4_5"
)
})


# Just checking that internal$output$dt_samp_for_vS works for iterative
test_that("output_lm_numeric_independence_keep_samp_for_vS", {
expect_snapshot_rds(
(out <- explain(
testing = TRUE,
model = model_lm_numeric,
x_explain = x_explain_numeric,
x_train = x_train_numeric,
approach = "independence",
phi0 = p0,
output_args = list(keep_samp_for_vS = TRUE),
iterative = TRUE
)),
"output_lm_numeric_independence_keep_samp_for_vS"
)

expect_false(is.null(out$internal$output$dt_samp_for_vS))
})
File renamed without changes.

0 comments on commit 7f9b44f

Please sign in to comment.