Skip to content

Commit

Permalink
Update python examples ++ (#416)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju authored Nov 14, 2024
1 parent 4c7e223 commit 2b3f118
Show file tree
Hide file tree
Showing 15 changed files with 219 additions and 179 deletions.
15 changes: 9 additions & 6 deletions inst/code_paper/code_sec_3.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ library(xgboost)
library(data.table)
library(shapr)

path <- "inst/code_paper/"
path0 <- "https://raw.githubusercontent.com/NorskRegnesentral/shapr/refs/heads/"
path <- paste0(path0,"master/inst/code_paper/")
x_explain <- fread(paste0(path, "x_explain.csv"))
x_train <- fread(paste0(path, "x_train.csv"))
y_train <- unlist(fread(paste0(path, "y_train.csv")))
model <- readRDS(paste0(path, "model.rds"))
model <- readRDS(file(paste0(path, "model.rds")))


# We compute the SHAP values for the test data.
Expand Down Expand Up @@ -51,8 +52,7 @@ exp_20_ctree$MSEv$MSEv
#<num> <num>
# 1: 1224818 101680.4

exp_20_ctree

print(exp_20_ctree)
### Continued estimation

exp_iter_ctree <- explain(model = model,
Expand All @@ -71,7 +71,7 @@ library(ggplot2)

plot(exp_iter_ctree, plot_type = "scatter",scatter_features = c("atemp","windspeed"))

ggplot2::ggsave("inst/code_paper/scatter_ctree.pdf",width = 7, height = 4)
ggplot2::ggsave("inst/code_paper/scatter_ctree.pdf",width = 7, height = 3)

### Grouping

Expand Down Expand Up @@ -125,7 +125,10 @@ exp_g_reg_tuned$MSEv$MSEv

# Plot the best one

plot(exp_group_reg_sep_xgb_tuned,index_x_explain = 6,plot_type="waterfall")
exp_g_reg_tuned$shapley_values_est[6,]
x_explain[6,]

plot(exp_g_reg_tuned,index_x_explain = 6,plot_type="waterfall")

ggplot2::ggsave("inst/code_paper/waterfall_group.pdf",width = 7, height = 4)

Expand Down
3 changes: 2 additions & 1 deletion inst/code_paper/code_sec_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import pandas as pd
from shaprpy import explain

path = "inst/code_paper/"
path0 = "https://raw.githubusercontent.com/NorskRegnesentral/shapr/refs/heads/"
path = path0 + "master/inst/code_paper/"

# Read data
x_train = pd.read_csv(path + "x_train.csv")
Expand Down
12 changes: 7 additions & 5 deletions inst/code_paper/code_sec_6.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ library(xgboost)
library(data.table)
library(shapr)

path <- "inst/code_paper/"
path0 <- "https://raw.githubusercontent.com/NorskRegnesentral/shapr/refs/heads/"
path <- paste0(path0,"master/inst/code_paper/")
x_full <- fread(paste0(path, "x_full.csv"))

data_fit <- x_full[seq_len(729), ]

model_ar <- ar(x_full$temp, order = 2)
model_ar <- ar(data_fit$temp, order = 2)

phi0_ar <- rep(mean(x_full$temp), 3)
phi0_ar <- rep(mean(data_fit$temp), 3)

explain_forecast(
model = model_ar,
Expand All @@ -30,8 +32,8 @@ phi0_arimax <- rep(mean(data_fit$temp), 2)

explain_forecast(
model = model_arimax,
y = data_fit[, "temp"],
xreg = bike[, "windspeed"],
y = x_full[, "temp"],
xreg = x_full[, "windspeed"],
train_idx = 2:728,
explain_idx = 729,
explain_y_lags = 2,
Expand Down
Binary file modified inst/code_paper/scatter_ctree.pdf
Binary file not shown.
4 changes: 2 additions & 2 deletions python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ model = RandomForestRegressor()
model.fit(dfx_train, dfy_train.values.flatten())

## Shapr
df_shapley, pred_explain, internal, timing = explain(
explanation = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
approach = 'empirical',
phi0 = dfy_train.mean().item(),
)
print(df_shapley)
print(explanation["shapley_values_est"])
```

`shaprpy` knows how to explain predictions from models from `sklearn`, `keras` and `xgboost`.
Expand Down
35 changes: 17 additions & 18 deletions python/examples/keras_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,33 @@
epochs=10,
validation_data=(dfx_test, dfy_test))
## Shapr
df_shapley, pred_explain, internal, timing, MSEv = explain(
explanation = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
approach = 'empirical',
phi0 = dfy_train.mean().item(),
phi0 = dfy_train.mean().item()
)
print(df_shapley)
print(explanation["shapley_values_est"])

"""
none sepal length (cm) sepal width (cm) petal length (cm) \
1 0.494737 0.042263 0.037911 0.059232
2 0.494737 0.034217 0.029183 0.045027
3 0.494737 0.045776 0.031752 0.058278
4 0.494737 0.014977 0.032691 0.014280
5 0.494737 0.022742 0.025851 0.027427
petal width (cm)
1 0.058412
2 0.053639
3 0.070650
4 0.018697
5 0.026814
explain_id none sepal length (cm) sepal width (cm) \
1 1 0.494737 0.041518 0.037129
2 2 0.494737 0.033541 0.028414
3 3 0.494737 0.045033 0.031092
4 4 0.494737 0.014281 0.031831
5 5 0.494737 0.022155 0.025154
petal length (cm) petal width (cm)
1 0.058252 0.057664
2 0.044242 0.052839
3 0.057368 0.069891
4 0.013667 0.018016
5 0.026672 0.026181
"""

# Look at the (overall) MSEv
MSEv["MSEv"]
explanation["MSEv"]["MSEv"]

"""
MSEv MSEv_sd
Expand Down
52 changes: 35 additions & 17 deletions python/examples/pytorch_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,33 +36,51 @@ def forward(self, x):
optim.zero_grad()

## Shapr
df_shapley, pred_explain, internal, timing, MSEv = explain(
explanation = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
approach = 'empirical',
predict_model = lambda m, x: m(torch.from_numpy(x.values).float()).cpu().detach().numpy(),
phi0 = dfy_train.mean().item(),
)
print(df_shapley)
print(explanation["shapley_values_est"])
"""
none MedInc HouseAge AveRooms AveBedrms Population AveOccup \
1 2.205947 2.313935 5.774470 5.425240 4.194669 1.712164 3.546001
2 2.205947 4.477620 5.467266 2.904239 3.046492 1.484807 5.631292
3 2.205946 4.028013 1.168401 5.229893 1.719724 2.134012 3.426378
4 2.205948 4.230376 8.639265 1.138520 3.776463 3.786978 4.253034
5 2.205947 3.923747 1.483737 1.113199 4.963213 -3.645875 4.950775
explain_id none MedInc HouseAge AveRooms AveBedrms Population \
1 1 2.205951 3.531437 7.746453 6.985043 5.454877 3.287326
2 2 2.205951 6.004403 7.041080 4.254553 4.118677 3.162567
3 3 2.205950 5.497648 1.538680 6.750968 2.806428 3.687014
4 4 2.205951 5.761901 11.378609 2.112351 5.013451 5.754630
5 5 2.205951 5.325281 2.585713 2.224409 6.418153 -2.848570
Latitude Longitude
1 1.102239 2.906469
2 4.966465 2.178510
3 3.503413 2.909760
4 3.413727 3.795563
5 3.011126 4.016985
AveOccup Latitude Longitude
1 4.774873 2.273699 4.314784
2 7.386783 6.473623 3.318631
3 5.193341 4.875864 4.290797
4 5.866562 4.564957 5.139962
5 6.428984 4.280456 5.509226
"""

MSEv["MSEv"]
print(explanation["shapley_values_sd"])

"""
explain_id none MedInc HouseAge AveRooms AveBedrms \
1 1 3.523652e-08 0.122568 0.124885 0.163694 0.134910
2 2 3.501778e-08 0.125286 0.113064 0.123057 0.129869
3 3 1.805247e-08 0.098208 0.095959 0.115399 0.102265
4 4 3.227380e-08 0.110442 0.118524 0.124688 0.101476
5 5 3.650380e-08 0.125538 0.130427 0.136797 0.131515
Population AveOccup Latitude Longitude
1 0.133510 0.149141 0.132394 0.121605
2 0.113429 0.124539 0.122773 0.100871
3 0.092633 0.110790 0.090657 0.090542
4 0.114721 0.122266 0.103081 0.105613
5 0.113853 0.139291 0.135377 0.132476
"""

explanation["MSEv"]["MSEv"]
"""
MSEv MSEv_sd
1 27.046126 7.253933
MSEv MSEv_sd
1 33.143896 7.986808
"""
55 changes: 11 additions & 44 deletions python/examples/regression_paradigm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
x_train=dfx_train,
x_explain=dfx_test,
approach='empirical',
iterative = False,
phi0=dfy_train.mean().item()
)

Expand All @@ -38,8 +39,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model='parsnip::linear_reg()'
)

Expand All @@ -50,8 +49,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model='parsnip::linear_reg()',
regression_recipe_func='''function(regression_recipe) {
return(recipes::step_ns(regression_recipe, recipes::all_numeric_predictors(), deg_free = 3))
Expand All @@ -65,8 +62,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model='parsnip::linear_reg()',
regression_recipe_func='''function(regression_recipe) {
return(recipes::step_ns(regression_recipe, recipes::all_numeric_predictors(), deg_free = 3))
Expand All @@ -80,8 +75,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model="parsnip::decision_tree(tree_depth = hardhat::tune(), engine = 'rpart', mode = 'regression')",
regression_tune_values='dials::grid_regular(dials::tree_depth(), levels = 4)',
regression_vfold_cv_para={'v': 5}
Expand All @@ -94,8 +87,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model="parsnip::boost_tree(engine = 'xgboost', mode = 'regression')"
)

Expand All @@ -106,8 +97,6 @@
x_explain=dfx_test,
approach='regression_separate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model="parsnip::boost_tree(trees = hardhat::tune(), engine = 'xgboost', mode = 'regression')",
regression_tune_values='expand.grid(trees = c(10, 15, 25, 50, 100, 500))',
regression_vfold_cv_para={'v': 5}
Expand All @@ -121,8 +110,6 @@
x_explain=dfx_test,
approach='regression_surrogate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model='parsnip::linear_reg()'
)

Expand All @@ -133,8 +120,6 @@
x_explain=dfx_test,
approach='regression_surrogate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model="parsnip::rand_forest(engine = 'ranger', mode = 'regression')"
)

Expand All @@ -145,8 +130,6 @@
x_explain=dfx_test,
approach='regression_surrogate',
phi0=dfy_train.mean().item(),
verbose=2,
n_batches=1,
regression_model="""parsnip::rand_forest(
mtry = hardhat::tune(), trees = hardhat::tune(), engine = 'ranger', mode = 'regression'
)""",
Expand All @@ -161,34 +144,18 @@
# Print the MSEv evaluation criterion scores
print("Method", "MSEv", "Elapsed time (seconds)")
for i, (method, explanation) in enumerate(explanation_list.items()):
print(method, round(explanation[4]["MSEv"]["MSEv"].iloc[0], 3), round(explanation[3]["total_time_secs"], 3))
print(method, round(explanation["MSEv"]["MSEv"].iloc[0].iloc[0], 3), round(explanation["timing"]["total_time_secs"][0], 3))



"""
Method MSEv Time
empirical 0.826 1.096
sep_lm 1.623 12.093
sep_pca 1.626 16.435
sep_splines 1.626 15.072
sep_tree_cv 1.436 275.002
sep_xgboost 0.769 13.870
sep_xgboost_cv 0.802 312.758
sur_lm 1.772 0.548
sur_rf 0.886 41.250
"""

explanation_list["sep_xgboost"][0]
explanation_list["sep_xgboost"]["shapley_values_est"]

"""
none MedInc HouseAge AveRooms AveBedrms Population AveOccup \
1 2.205937 -0.496421 0.195272 -0.077923 0.010124 -0.219369 -0.316029
2 2.205938 -0.163246 0.014565 -0.415945 -0.114073 0.084315 0.144754
3 2.205938 0.574157 0.258926 0.090818 -0.665126 0.354005 0.869530
4 2.205938 0.311416 -0.105142 0.211300 0.031939 -0.180331 -0.059839
5 2.205938 0.077537 -0.150997 -0.117875 0.087118 -0.085118 0.414764
Latitude Longitude
1 -0.434240 -0.361774
2 -0.483618 -0.324016
3 0.276002 0.957242
4 0.028560 0.049815
5 -0.242943 0.006815
explain_id none MedInc HouseAge AveRooms AveBedrms Population AveOccup Latitude Longitude
1 1 2.205937 -0.498764 0.193443 -0.073068 0.005078 -0.216733 -0.313781 -0.433844 -0.362689
2 2 2.205938 -0.160032 0.014564 -0.417670 -0.117127 0.084102 0.151612 -0.486576 -0.326138
3 3 2.205938 0.585638 0.239399 0.103826 -0.656533 0.349671 0.859701 0.275356 0.958495
4 4 2.205938 0.311038 -0.114403 0.206639 0.041748 -0.178090 -0.061004 0.036681 0.045110
5 5 2.205938 0.079439 -0.156861 -0.118913 0.093746 -0.097861 0.433192 -0.239588 -0.003852
"""
Loading

0 comments on commit 2b3f118

Please sign in to comment.