forked from DrewWham/PSU_Stat_380
-
Notifications
You must be signed in to change notification settings - Fork 0
/
shap.R
100 lines (85 loc) · 3.89 KB
/
shap.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Note: The functions shap.score.rank, shap_long_hd (corresponding to shap.prep
# in this file) and plot.shap.summary were originally published at
# https://liuyanguu.github.io/post/2018/10/14/shap-visualization-for-xgboost/
# All credit goes to the original author.
## functions for plot
# Compute SHAP contributions of an xgboost model and rank the features by
# their mean absolute SHAP value.
#
# xgb_model:   fitted xgboost model (defaults to a global `xgb_mod`).
# shap_approx: forwarded to predict() as `approxcontrib`.
# X_train:     data to explain (defaults to a global `mydata$train_mm`).
# Returns a list with
#   shap_score      - data.table of SHAP values, one column per feature
#                     (the "BIAS" column removed);
#   mean_shap_score - named numeric vector of mean |SHAP| per feature,
#                     sorted in decreasing order.
shap.score.rank <- function(xgb_model = xgb_mod, shap_approx = TRUE,
                            X_train = mydata$train_mm){
  # library() stops immediately when a package is missing; require() would
  # only warn and let the function fail later with a less helpful error.
  library(xgboost)
  library(data.table)
  shap_contrib <- predict(xgb_model, X_train,
                          predcontrib = TRUE, approxcontrib = shap_approx)
  shap_contrib <- as.data.table(shap_contrib)
  # The bias column is not a feature and must not be ranked. It is named
  # "BIAS" in the xgboost version this code was written against; other
  # versions may name it differently -- TODO confirm for the installed one.
  if ("BIAS" %in% names(shap_contrib)) {
    shap_contrib[, BIAS := NULL]
  } else {
    warning("no 'BIAS' column in SHAP contributions; ",
            "bias column was not removed", call. = FALSE)
  }
  cat('make SHAP score by decreasing order\n\n')
  # compute the per-feature mean |SHAP| once, then order it (names are kept)
  abs_mean <- colMeans(abs(shap_contrib))
  mean_shap_score <- abs_mean[order(abs_mean, decreasing = TRUE)]
  list(shap_score = shap_contrib, mean_shap_score = mean_shap_score)
}
# a function to standardize feature values into same range
# Rescale a numeric vector linearly to [0, 1] (minimum -> 0, maximum -> 1).
# NAs are ignored when finding the range and stay NA in the result.
# Edge cases:
#   - all values missing (or empty input): returns NA_real_ of the same
#     length, instead of taking min/max of nothing (which warns);
#   - constant vector (zero range): every non-missing value maps to 0.5,
#     instead of NaN from dividing by zero.
std1 <- function(x){
  if (all(is.na(x))) {
    return(rep(NA_real_, length(x)))
  }
  rng <- range(x, na.rm = TRUE)
  span <- rng[2] - rng[1]
  # isTRUE() guards against a non-finite span (e.g. infinite inputs)
  if (isTRUE(span == 0)) {
    # no spread to express: place every observed value at the midpoint
    out <- rep(0.5, length(x))
    out[is.na(x)] <- NA
    return(out)
  }
  (x - rng[1]) / span
}
# prep shap data
# Reshape SHAP scores and the matching feature values into one long
# data.table (one row per observation x feature), ready for plotting.
#
# shap:    list as returned by shap.score.rank(), with `shap_score` (SHAP
#          values, one column per feature) and `mean_shap_score` (named mean
#          |SHAP| per feature, in decreasing order).
# X_train: the feature data the SHAP values were computed on. Its rows must
#          line up with the rows of shap$shap_score, because the two long
#          tables are combined by position (cbind), not by key.
# top_n:   keep only the top_n features by mean |SHAP|; by default all
#          ranked features are used.
# Returns a data.table keyed by `variable` with columns
#   variable   - feature name (factor; levels follow the ranking order)
#   value      - SHAP value
#   rfvalue    - raw feature value
#   stdfvalue  - feature value rescaled to [0, 1] per feature via std1()
#   mean_value - mean |SHAP| of that feature
shap.prep <- function(shap = shap_result, X_train = mydata$train_mm, top_n){
  # ggforce is only used later for plotting (geom_sina). It is still loaded
  # here for callers relying on it being attached, but a missing ggforce
  # should not block data preparation, hence require() rather than library().
  require(ggforce)
  library(data.table)
  ranked_vars <- names(shap$mean_shap_score)
  # top_n indexes the ranked feature names and selects X_train columns,
  # so it is bounded by both
  n_avail <- min(ncol(X_train), length(ranked_vars))
  if (missing(top_n)) top_n <- n_avail # by default, use all features
  if (length(top_n) != 1L || !top_n %in% seq_len(n_avail)) {
    stop('supply correct top_n')
  }
  # descending order of importance
  top_vars <- ranked_vars[seq_len(top_n)]
  absent <- setdiff(top_vars, colnames(X_train))
  if (length(absent) > 0) {
    stop("X_train has no column(s) for feature(s): ",
         paste(absent, collapse = ", "))
  }
  shap_score_sub <- as.data.table(shap$shap_score)[, top_vars, with = FALSE]
  shap_score_long <- melt.data.table(shap_score_sub,
                                     measure.vars = colnames(shap_score_sub))
  # feature values: the values in the original dataset
  fv_sub <- as.data.table(X_train)[, top_vars, with = FALSE]
  fv_sub_long <- melt.data.table(fv_sub, measure.vars = colnames(fv_sub))
  # standardize feature values within each feature
  fv_sub_long[, stdfvalue := std1(value), by = "variable"]
  # SHAP value: value; raw feature value: rfvalue; standardized: stdfvalue
  setnames(fv_sub_long, "value", "rfvalue")
  shap_long2 <- cbind(shap_score_long,
                      fv_sub_long[, c("rfvalue", "stdfvalue"), with = FALSE])
  shap_long2[, mean_value := mean(abs(value)), by = variable]
  setkey(shap_long2, variable)
  shap_long2
}
# Draw a SHAP summary (sina) plot from the long data produced by shap.prep().
#
# data_long: data.table with columns variable, value, stdfvalue and
#            mean_value (as returned by shap.prep()).
# Each point is one observation's SHAP value for a feature, coloured by the
# feature's standardized value. Each feature's mean |SHAP| is printed at the
# low end of the value axis. The first factor level of `variable` is drawn at
# the top. Returns the ggplot object.
plot.shap.summary <- function(data_long){
  # symmetric limits on the SHAP-value axis so that 0 sits in the middle
  x_bound <- max(abs(data_long$value), na.rm = TRUE)
  # geom_sina() comes from ggforce. library() fails loudly if it is missing,
  # instead of erroring later with "could not find function". ggplot() and
  # the other ggplot2 functions are presumably available because ggforce
  # attaches ggplot2 -- load ggplot2 explicitly if that changes.
  library(ggforce)
  # one label row per feature: its mean |SHAP|
  mean_labels <- unique(data_long[, c("variable", "mean_value"), with = FALSE])
  plot1 <- ggplot(data = data_long) +
    coord_flip() +
    # sina plot: one point per observation
    geom_sina(aes(x = variable, y = value, color = stdfvalue)) +
    # print the mean absolute value:
    geom_text(data = mean_labels,
              aes(x = variable, y = -Inf, label = sprintf("%.3f", mean_value)),
              size = 3, alpha = 0.7,
              hjust = -0.2,
              fontface = "bold") +
    scale_color_gradient(low = "#FFCC33", high = "#6600CC",
                         breaks = c(0, 1), labels = c("Low", "High")) +
    theme_bw() +
    theme(axis.line.y = element_blank(), axis.ticks.y = element_blank(), # remove axis line
          legend.position = "bottom") +
    geom_hline(yintercept = 0) + # the vertical line (after coord_flip)
    scale_y_continuous(limits = c(-x_bound, x_bound)) +
    # reverse the order of features so the first level ends up on top
    scale_x_discrete(limits = rev(levels(data_long$variable))) +
    labs(y = "SHAP value (impact on model output)", x = "", color = "Feature value")
  plot1
}
# Horizontal bar chart of the top_n features by mean |SHAP|.
#
# shap_result: list as returned by shap.score.rank(); only its
#              `mean_shap_score` element (named, sorted in decreasing order)
#              is used.
# top_n:       number of features to show. It is capped at the number of
#              available features, so the default of 10 also works for
#              smaller models.
# Returns a ggplot object with the most important feature at the top.
# Relies on tibble() and ggplot2 already being attached by the caller.
var_importance <- function(shap_result, top_n = 10) {
  scores <- shap_result$mean_shap_score
  importance_tbl <- tibble(var = names(scores), importance = scores)
  # Scores are already ranked, so keep the leading rows. Capping with min()
  # avoids indexing past the end, which errors for tibbles (or yields NA rows
  # in older versions); seq_len() also makes top_n = 0 select nothing.
  keep <- seq_len(min(top_n, nrow(importance_tbl)))
  importance_tbl <- importance_tbl[keep, ]
  ggplot(importance_tbl, aes(x = reorder(var, importance), y = importance)) +
    geom_bar(stat = "identity") +
    coord_flip() +
    theme_light() +
    theme(axis.title.y = element_blank())
}