-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add filter functions for explanations and shap_values #120
base: main
Are you sure you want to change the base?
Changes from 26 commits
5af9efb
5e52e86
5a8e9be
cceb6d1
ab1e3ab
719dd83
343be04
af32acb
3645e38
d1cf3e0
7e6bcb3
31e88b2
849b30f
9877374
38a9060
159bd17
5faf54b
c9e6aa2
3d2bf3b
a2e7700
3b8239f
c06e5e6
63c635f
6953517
e3a54ee
73f74f6
4ab46ad
e431803
bf64bfe
1047a2c
c1dd746
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,23 +25,26 @@ | |
# Train local classifier per parent node | ||
classifier.fit(X_train, Y_train) | ||
|
||
# Get predictions | ||
predictions = classifier.predict(X_test) | ||
|
||
# Define Explainer | ||
explainer = Explainer(classifier, data=X_train.values, mode="tree") | ||
explanations = explainer.explain(X_test.values) | ||
print(explanations) | ||
|
||
# Filter samples which only predicted "Respiratory" at first level | ||
respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory" | ||
|
||
# Specify additional filters to obtain only level 0 | ||
shap_filter = {"level": 0, "class": "Respiratory", "sample": respiratory_idx} | ||
# Filter samples which only predicted "Respiratory" | ||
respiratory_idx = explainer.get_sample_indices(predictions, "Respiratory") | ||
|
||
# Use .sel() method to apply the filter and obtain filtered results | ||
shap_val_respiratory = explanations.sel(shap_filter) | ||
shap_val_respiratory = explainer.filter_by_class( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Alternatively, I guess you can probably call the method There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. settled in this option for now |
||
explanations, class_name="Respiratory", sample_indices=respiratory_idx | ||
) | ||
|
||
|
||
# Plot feature importance on test set | ||
shap.plots.violin( | ||
shap_val_respiratory.shap_values, | ||
shap_val_respiratory, | ||
feature_names=X_train.columns.values, | ||
plot_size=(13, 8), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think about making it more pandas-like? For example
respiratory_idx = predictions == "Respiratory"
?