Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure explain() for iterative estimation with convergence detection, verbose arguments ++ #396

Merged
merged 177 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from 148 commits
Commits
Show all changes
177 commits
Select commit Hold shift + click to select a range
6dcb2e3
future testing added
martinju May 6, 2024
b677900
init covergence
martinju May 30, 2024
0847946
man
martinju May 30, 2024
f23a9a2
more work. seems to work all right now
martinju May 31, 2024
58c9fe2
starting to put the non-adaptive method into the same loop. Not done!
martinju May 31, 2024
d1a48bb
make tests for non-adaptive approach work out
martinju Jun 4, 2024
818d6e1
tests still pass on this
martinju Jun 4, 2024
7ac6df5
work on the way to iter list
martinju Jun 4, 2024
419c024
more work
martinju Jun 4, 2024
6f57b9b
where i am now. something is off, crashing in compute_vS
martinju Jun 4, 2024
09eb42b
move to iter_list and more cleanup
martinju Jun 7, 2024
56dd131
div + .Rd
martinju Jun 7, 2024
1201f03
cleanup
martinju Jun 7, 2024
9d82430
.
martinju Jun 7, 2024
2857ae9
[skip actions] rename and fix create_s_batch
martinju Jun 11, 2024
25d3960
remove all iteration relevant internals outside of iter_list
martinju Jun 11, 2024
f8d3a0e
[skip actions] forgot reverting. Non-adaptive shap value tests works
martinju Jun 11, 2024
769f64e
real data sim for NSM2024 (to be deleted)
martinju Aug 1, 2024
ec2b51e
various minor edits
martinju Aug 2, 2024
bed8eff
fix basic error in forecast with paried_shap_sampling
martinju Aug 2, 2024
aa67898
+adaptive
martinju Aug 2, 2024
d140325
create_S_batch_new
martinju Aug 2, 2024
ed4a79d
get old functions back to make forecast work
martinju Aug 2, 2024
21f9f55
adding approach to X
martinju Aug 2, 2024
95a2111
man ++
martinju Aug 2, 2024
8b08767
Revert "create_S_batch_new"
martinju Aug 2, 2024
b64cdf1
add sample_freq=NA to exact method
martinju Aug 2, 2024
917d0da
shapley values are correct at this stage
martinju Aug 2, 2024
962fbd5
make setup tests pass
martinju Aug 2, 2024
34dbcfe
.
martinju Aug 2, 2024
dfc8f72
accept regression snaps
martinju Aug 2, 2024
def449e
fix regression_separates req to call shapley_setup before setup_approach
martinju Aug 2, 2024
60ae3a5
all tests good, except two vaeacs
martinju Aug 2, 2024
30a6fc1
set seed again before shapley_setup to fix failing vaeac test
martinju Aug 2, 2024
73780af
fixed last failing test
martinju Aug 2, 2024
829db34
fixing checks NSE + documentation (not properly)
martinju Aug 2, 2024
175dcbd
exporting weight_matrix
martinju Aug 2, 2024
8fac649
upgraded all packages locally. Might be an issue with fit_times()
martinju Aug 3, 2024
5ca7153
deletes fit_times for regression_surrogate to pass tests
martinju Aug 3, 2024
345cb2c
styler
martinju Aug 3, 2024
151bba7
lintr
martinju Aug 3, 2024
361d8f9
.
martinju Aug 3, 2024
12ca661
apply name changes to test files
martinju Aug 5, 2024
a5c666b
rename regular output name
martinju Aug 5, 2024
86fed31
adding setup adaptive ++
martinju Aug 5, 2024
410c05d
update regular tests
martinju Aug 5, 2024
ae29313
bugfix, improve printing and init adaptive tests
martinju Aug 6, 2024
34c0905
update test files
martinju Aug 6, 2024
970d08d
div
martinju Aug 6, 2024
4e319e7
remove timing arg and add hidden testing arg
martinju Aug 6, 2024
a81a22b
fixing broken testing objects after updates
martinju Aug 6, 2024
8a4f0db
update tests with testing = TRUE, and remove timing = FALSE
martinju Aug 6, 2024
bb4e385
rds files
martinju Aug 7, 2024
643e5f0
styler
martinju Aug 7, 2024
24ae4d4
[skip actions] .
martinju Aug 7, 2024
d0a1ad6
move functions to appropriate files
martinju Aug 7, 2024
854fee7
[skip actions] doc + temporary and hiddenly adding unique_sampling
martinju Aug 7, 2024
17c94ef
add timing + experiment with improved bootstrapping code
martinju Aug 8, 2024
c7f3e2b
[skip actions] fix non-unique sampling
martinju Aug 8, 2024
f95e7ef
init moving to max_n_combinations
martinju Aug 9, 2024
5c5f436
add feature_samples to iter_list in setup for convenience
martinju Aug 9, 2024
21c43dc
simplifying explain view + improve max_n_combinations sets and checks
martinju Aug 9, 2024
2e7e864
man
martinju Aug 9, 2024
699bde0
Merge commit '2e7e86450686f61d6f4c7f63ac87c5857ff0094e' into convergence
martinju Aug 9, 2024
c9b679e
.
martinju Aug 9, 2024
95dda97
[skip actions] remaining stuff of max_n_combinations. Works, i think
martinju Aug 9, 2024
01b017e
[skip actions] remaining stuff of max_n_combinations. Works, i think
martinju Aug 9, 2024
d296e87
new bootstrap introduced with tests
martinju Aug 9, 2024
6dbcaff
making tests work
martinju Aug 9, 2024
303e323
tests OK
martinju Aug 10, 2024
fb6d050
some more ok tests. Forecast dont work as of now
martinju Aug 10, 2024
db6c221
apply the feature_combination stuff also to groups
martinju Aug 10, 2024
63281a6
Not 100% sure this actually works as it should
martinju Aug 10, 2024
1d6fb63
add and fix group tests
martinju Aug 12, 2024
5bc4efe
new
martinju Aug 12, 2024
6086d2a
adaptive OK
martinju Aug 16, 2024
11df7de
all tests pass
martinju Aug 16, 2024
ca58ce4
styler
martinju Aug 16, 2024
6b4931d
man
martinju Aug 16, 2024
803a181
fix checks
martinju Aug 16, 2024
14f360e
Disable rcpp approx solve warnings
martinju Aug 16, 2024
38be193
temporary fix forecast (not adaptive yet)
martinju Aug 16, 2024
45a657e
tests
martinju Aug 16, 2024
28acb00
combinations -> coalitions and merging all features/groups-code
martinju Aug 16, 2024
886f93b
add features to coalition table for both features and groups
martinju Aug 19, 2024
7ba0408
bugfix groups + some plot test updates
martinju Aug 19, 2024
48d6b81
more fixing
martinju Aug 19, 2024
b9bd1fe
adaptive tests
martinju Aug 19, 2024
8b0bd2c
remaining tests
martinju Aug 19, 2024
8bdb6be
forcast tests (something is up with forecast grouping, though)
martinju Aug 20, 2024
3de91ff
[skip actions] style
martinju Aug 20, 2024
7d747b8
adding reweighting strategy on all cond
martinju Aug 21, 2024
8f46e38
[skip actions] add reweighting strategies + non-unique paired sampling
martinju Aug 23, 2024
9daf94e
n_samples -> n_MC_samples
martinju Aug 23, 2024
b4125d2
tests OK
martinju Aug 23, 2024
cd7b0a8
fix iterative with paired sampling
martinju Sep 5, 2024
e3373d3
update tests after bootstrap change + .Rprofile for smoother testing
martinju Sep 5, 2024
fd7734f
add intermediate saving
martinju Sep 6, 2024
9a35c14
working version of continue training
martinju Sep 6, 2024
a064b9c
moves prev_shapr_object handling to setup and add validity test
martinju Sep 6, 2024
e46822a
working
martinju Sep 6, 2024
55434c1
[skip actions] Working
martinju Sep 6, 2024
f4e7931
man + testthat
martinju Sep 9, 2024
7f18f91
new adaptive output testfiles
martinju Sep 9, 2024
681d197
Fix cutting of coalition list per horizon in ```shapley_setup_forecas…
jonlachmann Sep 10, 2024
ccc39af
Merge remote-tracking branch 'jonlachmann/convergence' into convergence
martinju Sep 11, 2024
23016dc
update OK forecast test files
martinju Sep 11, 2024
db7c15d
extra forecast test file update
martinju Sep 11, 2024
a4d02fb
add max_batch_size og min_n_batches
martinju Sep 19, 2024
c04571c
apply the new n_batches settings in practice
martinju Sep 19, 2024
2118826
remove all traces of n_batches in the older code
martinju Sep 26, 2024
5bee4b9
adaptive tests ok
martinju Sep 26, 2024
56fa77b
more tests
martinju Sep 26, 2024
cba572c
adding checks for adaptive argument formats
martinju Sep 27, 2024
1e481e2
update tests
martinju Sep 27, 2024
c780b70
regression
martinju Sep 27, 2024
33301c3
temporary disabling the forecast tests
martinju Sep 27, 2024
2e0e09d
new test files
martinju Sep 27, 2024
cbd01f4
move reweighting and set new defaults
martinju Sep 27, 2024
cc737aa
moving towards new defaults
martinju Sep 27, 2024
e76d068
adpative-output at least OK
martinju Sep 27, 2024
0e4ba15
[skip tests] new test files
martinju Sep 27, 2024
99864b9
[skip actions] other tests ok
martinju Sep 30, 2024
882ed16
Merge remote-tracking branch 'origin/convergence' into convergence
martinju Sep 30, 2024
75ae9a4
[skip actions] .
martinju Sep 30, 2024
92f45c1
[skip actions] documenting explain
martinju Sep 30, 2024
f7f5a49
more documentation
martinju Sep 30, 2024
9095c64
.
martinju Sep 30, 2024
d6bc603
[skip actions] Slight restructure + update of main vignette
martinju Sep 30, 2024
1789783
checks for the adaptive argument
martinju Oct 1, 2024
1a5d034
NSE warnings
martinju Oct 1, 2024
35cd82f
test updates
martinju Oct 1, 2024
59e1c23
man and zzz
martinju Oct 1, 2024
dadbce2
man + tests
martinju Oct 1, 2024
231c862
tmp
martinju Oct 1, 2024
08a9e87
deal with cont estimation for non-adaptive
martinju Oct 1, 2024
3627112
first vignette
martinju Oct 1, 2024
21a2c29
vaeac also need X in setup
martinju Oct 1, 2024
7cb617b
vaeac vignette works
martinju Oct 1, 2024
55f5219
[skip actions] init update of regression vignette
martinju Oct 1, 2024
f56eb49
+ regression
martinju Oct 1, 2024
f093fa0
fix docs
martinju Oct 1, 2024
3f222b3
style
martinju Oct 2, 2024
61336bf
linting
martinju Oct 2, 2024
a02c9c9
fix man
martinju Oct 2, 2024
0d96221
fix vignette
martinju Oct 2, 2024
91de6ec
remove (>= 3.0.0) for testthat for tesitng
martinju Oct 2, 2024
fde5a3e
Merge branch 'verbose' into convergence
martinju Oct 2, 2024
54aa94b
Merge branch 'convergence' into verbose
martinju Oct 2, 2024
92936a4
replacing the old verbose syntax in vaeac and regresion
martinju Oct 2, 2024
971637b
move everything to string verbose
martinju Oct 2, 2024
bb2092f
playing around with cli progress
martinju Oct 2, 2024
26c0cff
more testing
martinju Oct 3, 2024
851a2b6
more work
martinju Oct 3, 2024
06c878a
working OK for now
martinju Oct 3, 2024
d581017
more work
martinju Oct 3, 2024
b8f4331
separate regression done
martinju Oct 3, 2024
4108811
also regression_surrogate done
martinju Oct 3, 2024
8f89af3
fixed no testthat package?
martinju Oct 4, 2024
941b95e
consider myself done for now
martinju Oct 4, 2024
31a9203
vignettes
martinju Oct 4, 2024
996a261
testfile updates
martinju Oct 4, 2024
2676501
styler
martinju Oct 4, 2024
758725f
lint and some checks
martinju Oct 4, 2024
16ff5b8
Merge branch 'verbose' into convergence
martinju Oct 4, 2024
e91963e
hoping to avoid the missing testthat package error on GHA
martinju Oct 4, 2024
d14ba46
annabelles comments
martinju Oct 5, 2024
3bd61ec
update main vignette
martinju Oct 5, 2024
ca936e0
rerun regression vignette
martinju Oct 5, 2024
eae46cc
styler
martinju Oct 5, 2024
0da771c
clean out unused messages
martinju Oct 5, 2024
a4b387e
minor bugfixing
martinju Oct 5, 2024
6ad3846
man
martinju Oct 5, 2024
13fd5c1
tests
martinju Oct 5, 2024
396c523
styler
martinju Oct 5, 2024
681f8d5
Started on NEWS
martinju Oct 6, 2024
d919a6b
Merge branch 'shapr-1.0.0' into convergence
martinju Oct 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
6 changes: 4 additions & 2 deletions .Rprofile
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
testthat::set_max_fails(Inf)

#' Helper function for package development
#'
#' This is a manual extension of [testthat::snapshot_review()] which works for the \code{.rds} files used in
Expand All @@ -7,7 +9,7 @@
#' @param ... Additional arguments passed to [waldo::compare()]
#' Gives the relative path to the test files to review
#'
snapshot_review_man <- function(path, tolerance = NULL, ...) {
snapshot_review_man <- function(path, tolerance = 10^(-5), max_diffs = 200, ...) {
changed <- testthat:::snapshot_meta(path)
these_rds <- (tools::file_ext(changed$name) == "rds")
if (any(these_rds)) {
Expand All @@ -16,7 +18,7 @@ snapshot_review_man <- function(path, tolerance = NULL, ...) {
new <- readRDS(changed[i, "new"])

cat(paste0("Difference for check ", changed[i, "name"], " in test ", changed[i, "test"], "\n"))
print(waldo::compare(old, new, max_diffs = 50, tolerance = tolerance, ...))
print(waldo::compare(old, new, max_diffs = max_diffs, tolerance = tolerance, ...))
browser()
}
}
Expand Down
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Encoding: UTF-8
LazyData: true
ByteCompile: true
Language: en-US
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Depends: R (>= 3.5.0)
Imports:
stats,
Expand All @@ -40,7 +40,7 @@ Suggests:
ranger,
xgboost,
mgcv,
testthat (>= 3.0.0),
testthat,
knitr,
rmarkdown,
roxygen2,
Expand Down
18 changes: 16 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,22 @@ S3method(setup_approach,regression_separate)
S3method(setup_approach,regression_surrogate)
S3method(setup_approach,timeseries)
S3method(setup_approach,vaeac)
export(additional_regression_setup)
export(aicc_full_single_cpp)
export(check_convergence)
export(coalition_matrix_cpp)
export(compute_estimates)
export(compute_shapley_new)
export(compute_time)
export(compute_vS)
export(compute_vS_forecast)
export(correction_matrix_cpp)
export(create_coalition_table)
export(explain)
export(explain_forecast)
export(feature_combinations)
export(feature_matrix_cpp)
export(finalize_explanation)
export(finalize_explanation_forecast)
export(get_adaptive_arguments_default)
export(get_cov_mat)
export(get_data_specs)
export(get_model_specs)
Expand All @@ -75,17 +82,23 @@ export(predict_model)
export(prepare_data)
export(prepare_data_copula_cpp)
export(prepare_data_gaussian_cpp)
export(prepare_next_iteration)
export(print_iter)
export(regression.train_model)
export(rss_cpp)
export(save_results)
export(setup)
export(setup_approach)
export(setup_computation)
export(shapley_setup)
export(testing_cleanup)
export(vaeac_get_evaluation_criteria)
export(vaeac_get_extra_para_default)
export(vaeac_plot_eval_crit)
export(vaeac_plot_imputed_ggpairs)
export(vaeac_train_model)
export(vaeac_train_model_continue)
export(weight_matrix)
export(weight_matrix_cpp)
importFrom(Rcpp,sourceCpp)
importFrom(data.table,":=")
Expand All @@ -110,6 +123,7 @@ importFrom(stats,as.formula)
importFrom(stats,contrasts)
importFrom(stats,embed)
importFrom(stats,formula)
importFrom(stats,median)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,predict)
Expand Down
40 changes: 20 additions & 20 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ inv_gaussian_transform_cpp <- function(z, x) {

#' Generate (Gaussian) Copula MC samples
#'
#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_samples`, `n_features`) containing samples from the
#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_MC_samples`, `n_features`) containing samples from the
#' univariate standard normal.
#' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations
#' to explain on the original scale.
#' @param x_explain_gaussian_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the
#' observations to explain after being transformed using the Gaussian transform, i.e., the samples have been
#' transformed to a standardized normal distribution.
#' @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations.
#' @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of
#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of
#' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones.
#' This is not a problem internally in shapr as the empty and grand coalitions treated differently.
#' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed
Expand All @@ -127,8 +127,8 @@ inv_gaussian_transform_cpp <- function(z, x) {
#' between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been
#' transformed to a standardized normal distribution.
#'
#' @return An arma::cube/3D array of dimension (`n_samples`, `n_explain` * `n_coalitions`, `n_features`), where
#' the columns (_,j,_) are matrices of dimension (`n_samples`, `n_features`) containing the conditional Gaussian
#' @return An arma::cube/3D array of dimension (`n_MC_samples`, `n_explain` * `n_coalitions`, `n_features`), where
#' the columns (_,j,_) are matrices of dimension (`n_MC_samples`, `n_features`) containing the conditional Gaussian
#' copula MC samples for each explicand and coalition on the original scale.
#'
#' @export
Expand All @@ -140,19 +140,19 @@ prepare_data_copula_cpp <- function(MC_samples_mat, x_explain_mat, x_explain_gau

#' Generate Gaussian MC samples
#'
#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_samples`, `n_features`) containing samples from the
#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_MC_samples`, `n_features`) containing samples from the
#' univariate standard normal.
#' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations
#' to explain.
#' @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of
#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of
#' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones.
#' This is not a problem internally in shapr as the empty and grand coalitions treated differently.
#' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature.
#' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance
#' between all pairs of features.
#'
#' @return An arma::cube/3D array of dimension (`n_samples`, `n_explain` * `n_coalitions`, `n_features`), where
#' the columns (_,j,_) are matrices of dimension (`n_samples`, `n_features`) containing the conditional Gaussian
#' @return An arma::cube/3D array of dimension (`n_MC_samples`, `n_explain` * `n_coalitions`, `n_features`), where
#' the columns (_,j,_) are matrices of dimension (`n_MC_samples`, `n_features`) containing the conditional Gaussian
#' MC samples for each explicand and coalition.
#'
#' @export
Expand Down Expand Up @@ -199,7 +199,7 @@ sample_features_cpp <- function(m, n_features) {
#'
#' @param xtest Numeric matrix. Represents a single test observation.
#'
#' @param S Integer matrix of dimension \code{n_combinations x m}, where \code{n_combinations} equals
#' @param S Integer matrix of dimension \code{n_coalitions x m}, where \code{n_coalitions} equals
#' the total number of sampled/non-sampled feature combinations and \code{m} equals
#' the total number of unique features. Note that \code{m = ncol(xtrain)}. See details
#' for more information.
Expand Down Expand Up @@ -228,34 +228,34 @@ observation_impute_cpp <- function(index_xtrain, index_s, xtrain, xtest, S) {

#' Calculate weight matrix
#'
#' @param subsets List. Each of the elements equals an integer
#' @param coalitions List. Each of the elements equals an integer
#' vector representing a valid combination of features/feature groups.
#' @param m Integer. Number of features/feature groups
#' @param n Integer. Number of combinations
#' @param w Numeric vector of length \code{n}, i.e. \code{w[i]} equals
#' the Shapley weight of feature/feature group combination \code{i}, represented by
#' \code{subsets[[i]]}.
#' \code{coalitions[[i]]}.
#'
#' @export
#' @keywords internal
#'
#' @return Matrix of dimension n x m + 1
#' @author Nikolai Sellereite
weight_matrix_cpp <- function(subsets, m, n, w) {
.Call(`_shapr_weight_matrix_cpp`, subsets, m, n, w)
#' @author Nikolai Sellereite, Martin Jullum
weight_matrix_cpp <- function(coalitions, m, n, w) {
.Call(`_shapr_weight_matrix_cpp`, coalitions, m, n, w)
}

#' Get feature matrix
#' Get coalition matrix
#'
#' @param features List
#' @param m Positive integer. Total number of features
#' @param coalitions List
#' @param m Positive integer. Total number of coalitions
#'
#' @export
#' @keywords internal
#'
#' @return Matrix
#' @author Nikolai Sellereite
feature_matrix_cpp <- function(features, m) {
.Call(`_shapr_feature_matrix_cpp`, features, m)
#' @author Nikolai Sellereite, Martin Jullum
coalition_matrix_cpp <- function(coalitions, m) {
.Call(`_shapr_coalition_matrix_cpp`, coalitions, m)
}

34 changes: 26 additions & 8 deletions R/approach.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,29 @@
setup_approach <- function(internal, ...) {
approach <- internal$parameters$approach

this_class <- ""
iter <- length(internal$iter_list)
X <- internal$iter_list[[iter]]$X

if (length(approach) > 1) {
class(this_class) <- "combined"
needs_X <- c("regression_surrogate", "vaeac")

run_now <- (isFALSE(any(needs_X %in% approach)) && isTRUE(is.null(X))) ||
(isTRUE(any(needs_X %in% approach)) && isFALSE(is.null(X)))

if (isFALSE(run_now)) { # Do nothing
return(internal)
} else {
class(this_class) <- approach
}
this_class <- ""

UseMethod("setup_approach", this_class)
if (length(approach) > 1) {
class(this_class) <- "combined"
} else {
class(this_class) <- approach
}

UseMethod("setup_approach", this_class)

internal$timing_list$setup_approach <- Sys.time()
}
}

#' @inheritParams default_doc
Expand Down Expand Up @@ -49,6 +63,10 @@ setup_approach.combined <- function(internal, ...) {
#' @export
#' @keywords internal
prepare_data <- function(internal, index_features = NULL, ...) {
iter <- length(internal$iter_list)

X <- internal$iter_list[[iter]]$X

# Extract the used approach(es)
approach <- internal$parameters$approach

Expand All @@ -57,9 +75,9 @@ prepare_data <- function(internal, index_features = NULL, ...) {

# Check if the user provided one or several approaches.
if (length(approach) > 1) {
# Picks the relevant approach from the internal$objects$X table which list the unique approach of the batch
# Picks the relevant approach from the X table which list the unique approach of the batch
# matches by index_features
class(this_class) <- internal$objects$X[id_combination == index_features[1], approach]
class(this_class) <- X[id_coalition == index_features[1], approach]
} else {
# Only one approach for all coalitions sizes
class(this_class) <- approach
Expand Down
39 changes: 21 additions & 18 deletions R/approach_categorical.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,11 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) {

joint_probability_dt <- internal$parameters$categorical.joint_prob_dt

X <- internal$objects$X
S <- internal$objects$S
iter <- length(internal$iter_list)

X <- internal$iter_list[[iter]]$X
S <- internal$iter_list[[iter]]$S


if (is.null(index_features)) { # 2,3
features <- X$features # list of [1], [2], [2, 3]
Expand All @@ -106,9 +109,9 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) {
}
feature_names <- internal$parameters$feature_names

# 3 id columns: id, id_combination, and id_all
# 3 id columns: id, id_coalition, and id_all
# id: for each x_explain observation
# id_combination: the rows of the S matrix
# id_coalition: the rows of the S matrix
# id_all: identifies the unique combinations of feature values from
# the training data (not necessarily the ones in the explain data)

Expand All @@ -118,9 +121,9 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) {

S_dt <- data.table::data.table(S)
S_dt[S_dt == 0] <- NA
S_dt[, id_combination := seq_len(nrow(S_dt))]
S_dt[, id_coalition := seq_len(nrow(S_dt))]

data.table::setnames(S_dt, c(feature_conditioned, "id_combination"))
data.table::setnames(S_dt, c(feature_conditioned, "id_coalition"))

# (1) Compute marginal probabilities

Expand Down Expand Up @@ -153,21 +156,21 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) {

cond_dt <- j_S_all_feat[marg_dt, on = feature_conditioned]
cond_dt[, cond_prob := joint_prob / marg_prob]
cond_dt[id_combination == 1, marg_prob := 0]
cond_dt[id_combination == 1, cond_prob := 1]
cond_dt[id_coalition == 1, marg_prob := 0]
cond_dt[id_coalition == 1, cond_prob := 1]

# check marginal probabilities
cond_dt_unique <- unique(cond_dt, by = feature_conditioned)
check <- cond_dt_unique[id_combination != 1][, .(sum_prob = sum(marg_prob)),
by = "id_combination"
check <- cond_dt_unique[id_coalition != 1][, .(sum_prob = sum(marg_prob)),
by = "id_coalition"
][["sum_prob"]]
if (!all(round(check) == 1)) {
print("Warning - not all marginal probabilities sum to 1. There could be a problem
with the joint probabilities. Consider checking.")
}

# make x_explain
data.table::setkeyv(cond_dt, c("id_combination", "id_all"))
data.table::setkeyv(cond_dt, c("id_coalition", "id_all"))
x_explain_with_id <- data.table::copy(x_explain)[, id := .I]
dt_just_explain <- cond_dt[x_explain_with_id, on = feature_names]

Expand All @@ -178,22 +181,22 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) {
dt <- cond_dt[dt_explain_just_conditioned, on = feature_conditioned, allow.cartesian = TRUE]

# check conditional probabilities
check <- dt[id_combination != 1][, .(sum_prob = sum(cond_prob)),
by = c("id_combination", "id")
check <- dt[id_coalition != 1][, .(sum_prob = sum(cond_prob)),
by = c("id_coalition", "id")
][["sum_prob"]]
if (!all(round(check) == 1)) {
print("Warning - not all conditional probabilities sum to 1. There could be a problem
with the joint probabilities. Consider checking.")
}

setnames(dt, "cond_prob", "w")
data.table::setkeyv(dt, c("id_combination", "id"))
data.table::setkeyv(dt, c("id_coalition", "id"))

# here we merge so that we only return the combintations found in our actual explain data
# this merge does not change the number of rows in dt
# dt <- merge(dt, x$X[, .(id_combination, n_features)], by = "id_combination")
# dt <- merge(dt, x$X[, .(id_coalition, n_features)], by = "id_coalition")
# dt[n_features %in% c(0, ncol(x_explain)), w := 1.0]
dt[id_combination %in% c(1, 2^ncol(x_explain)), w := 1.0]
ret_col <- c("id_combination", "id", feature_names, "w")
return(dt[id_combination %in% index_features, mget(ret_col)])
dt[id_coalition %in% c(1, 2^ncol(x_explain)), w := 1.0]
ret_col <- c("id_coalition", "id", feature_names, "w")
return(dt[id_coalition %in% index_features, mget(ret_col)])
}
Loading
Loading