Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explainer API for Local Classifier per parent node #minor #106

Merged
merged 68 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 66 commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
62c218d
added initial implementation of explainer api for lcppn
Jan 9, 2024
ea1fff8
fixed lints
Jan 10, 2024
c4d75c5
fixed lints
Jan 14, 2024
299af62
added an _explain_lcppn implementation and some tests provided
iwan-tee Jan 14, 2024
0ea8956
modified docstrings
Jan 14, 2024
1efd946
explainer for lcpn implemented + tests added and some cases fixed
iwan-tee Jan 14, 2024
7dcb52f
Merge branch 'explainer_api_lcpn' into explainer_api
iwan-tee Jan 14, 2024
1de360c
tests added + some bugs fixed
iwan-tee Jan 15, 2024
933b1f6
base
iwan-tee Jan 15, 2024
c57abed
basic implementation
iwan-tee Jan 18, 2024
a829ce0
LCPL explanator implementation + test
iwan-tee Jan 23, 2024
33f2cbc
added tests for hierarchy without roots
Jan 26, 2024
c06d8a7
check on root node added
iwan-tee Jan 26, 2024
c597fce
minor updates
Jan 26, 2024
b79e5f4
codestyling
iwan-tee Jan 26, 2024
8a643f1
codestyling
iwan-tee Jan 26, 2024
606c1eb
Merge branch 'explainer_master' into explainer_api_lcpl
ashishpatel16 Jan 26, 2024
ca6c654
Update Explainer.py
ashishpatel16 Jan 26, 2024
9936dc3
Merge pull request #1 from ashishpatel16/explainer_api_lcpl
ashishpatel16 Jan 26, 2024
82573be
Merge pull request #2 from ashishpatel16/explainer_api_lcpn
ashishpatel16 Jan 26, 2024
d53e8d9
added support for xarray for lcppn
Jan 29, 2024
2449928
Merge branch 'explainer_master' into explainer_api
ashishpatel16 Jan 29, 2024
759489f
Update Explainer.py
ashishpatel16 Jan 29, 2024
0771c08
Update Explainer.py
ashishpatel16 Jan 29, 2024
4eb6f5c
fixed errors with classifier with single class
Jan 30, 2024
3955521
updated test cases and removed cached explainers
Feb 1, 2024
7c2f4d2
removed cached explainers
Feb 1, 2024
b12bdc3
modified predict proba to return dict
Feb 1, 2024
986b61c
Merge branch 'main' into explainer_api
ashishpatel16 Feb 2, 2024
eb11c0e
updated get_predict_proba to return only traversed prediction probabi…
Feb 3, 2024
8c700e4
updated fork
Feb 3, 2024
53a90a0
separate test file for explainer
Feb 3, 2024
5e74762
Update Explainer.py
ashishpatel16 Feb 5, 2024
b1f3656
_get_traversed_nodes edited
iwan-tee Feb 6, 2024
2a12087
fixed lints
Feb 12, 2024
84f6e39
fixed conflicts
Feb 12, 2024
b09f8da
refactored and cleaned up code
Feb 12, 2024
aecdd96
updated test cases and isolated lcppn code
Feb 12, 2024
9658c4a
Merge branch 'main' into lcppn_explainer
ashishpatel16 Feb 13, 2024
9a73b6c
added support for lcpn
Feb 16, 2024
c5b5a68
Merge branch 'main' into lcpn_explainer
ashishpatel16 Mar 14, 2024
dc99b44
updated explainer and tests, added docstrings
ashishpatel16 Mar 15, 2024
06acdca
updated readthedocs
ashishpatel16 Mar 15, 2024
7da5779
updated README with Explainer example
ashishpatel16 Mar 15, 2024
139ad11
fixed imports
ashishpatel16 Mar 16, 2024
707d51e
removed unecessary files
ashishpatel16 Mar 20, 2024
33e1548
added tests, updated dependencies in setup.py and docs/requirements.txt
ashishpatel16 Mar 22, 2024
2a19d40
fixed lints
ashishpatel16 Mar 22, 2024
53288f2
isolated lcpn code and removed lcppn code from explainer
ashishpatel16 Mar 22, 2024
253a8de
fixed shap version
ashishpatel16 Mar 24, 2024
59ba63c
merged lcpn_epxlainer
ashishpatel16 Mar 25, 2024
c2378bb
separated code for lcppn and added tests
ashishpatel16 Mar 25, 2024
10bad49
Update plot_lcppn_explainer.py
ashishpatel16 Mar 25, 2024
81817c3
removed get_predict_proba() method from LocalClassifierPerParentNode
ashishpatel16 Mar 25, 2024
9ba6f28
removed redundant dependencies from pipfile
ashishpatel16 Mar 25, 2024
04a9ed8
Merge remote-tracking branch 'origin/lcppn_explainer' into lcppn_expl…
ashishpatel16 Mar 25, 2024
1b90da8
used masking approach to calculate traversed nodes
ashishpatel16 Mar 25, 2024
d914fd4
handled cases for imbalanced hierarchy
ashishpatel16 Mar 25, 2024
17c96ef
removed hiclass separator from output
ashishpatel16 Mar 26, 2024
fde3040
Update tests/test_LocalClassifierPerParentNode.py
ashishpatel16 Mar 26, 2024
01bf32e
Update hiclass/Explainer.py
ashishpatel16 Mar 26, 2024
ada4e88
refactored _get_traversed_nodes, will be three distinct methods for …
ashishpatel16 Mar 26, 2024
8dbf6e5
Merge remote-tracking branch 'origin/lcppn_explainer' into lcppn_expl…
ashishpatel16 Mar 26, 2024
a0d3f59
fixed xarray dependency version
ashishpatel16 Mar 26, 2024
ae48bf0
updated documentation and fixed typos
ashishpatel16 Mar 26, 2024
820c1a5
updated plot_lcppn_explainer to use platypus dataset
ashishpatel16 Mar 26, 2024
af7e83a
updated README
ashishpatel16 Mar 27, 2024
3c93a98
updated url for platypus dataset
ashishpatel16 Mar 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ sphinx-rtd-theme = "0.5.2"

[extras]
ray = "*"
shap = "*"
shap = "0.44.1"
xarray = "*"
81 changes: 78 additions & 3 deletions Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ HiClass is an open-source Python library for hierarchical classification compati
- **[Hierarchical metrics](https://hiclass.readthedocs.io/en/latest/api/utilities.html#hierarchical-metrics):** HiClass supports the computation of hierarchical precision, recall and f-score, which are more appropriate for hierarchical data than traditional metrics.
- **[Compatible with pickle](https://hiclass.readthedocs.io/en/latest/auto_examples/plot_model_persistence.html):** Easily store trained models on disk for future use.
- **[BERT sklearn](https://hiclass.readthedocs.io/en/latest/auto_examples/plot_bert.html):** Compatible with the library [BERT sklearn](https://github.com/charles9n/bert-sklearn).
- **[Hierarchical Explanability]():** HiClass allows explaining hierarchical models using the [SHAP](https://github.com/shap/shap) package.
ashishpatel16 marked this conversation as resolved.
Show resolved Hide resolved

**Any feature missing on this list?** Search our [issue tracker](https://github.com/scikit-learn-contrib/hiclass/issues) to see if someone has already requested it and add a comment to it explaining your use-case. Otherwise, please open a new issue describing the requested feature and possible use-case scenario. We prioritize our roadmap based on user feedback, so we would love to hear from you.

Expand Down Expand Up @@ -113,7 +114,7 @@ pip install hiclass"[<extra_name>]"
Replace <extra_name> with one of the following options:

- ray: Installs the ray package, which is required for parallel processing support.
- xai: Installs the shap and xarray packages, which are required for explaining Hiclass predictions.
- xai: Installs the shap and xarray packages, which are required for explaining Hiclass' predictions.

### Option 2: Conda

Expand Down Expand Up @@ -199,6 +200,9 @@ pipeline.fit(X_train, Y_train)
predictions = pipeline.predict(X_test)
```

## Explaining Hierarchical Classifiers
ashishpatel16 marked this conversation as resolved.
Show resolved Hide resolved
Hierarchical classifiers can provide additional insights when combined with explainability methods. HiClass allows explaining hierarchical models using SHAP values. Different hierarchical models yield different insights. More information on explaining [Local classifier per parent node](https://colab.research.google.com/drive/1rVlYuRU_uO1jw5sD6qo2HoCpCz6E6z5J?usp=sharing), [Local classifier per node](), and [Local classifier per level]() is available on [Read the Docs](https://hiclass.readthedocs.io/en/latest/algorithms/explainer.html).

## Step-by-step walk-through

A step-by-step walk-through is available on our documentation hosted on [Read the Docs](https://hiclass.readthedocs.io/en/latest/index.html).
Expand Down
61 changes: 61 additions & 0 deletions docs/examples/plot_lcppn_explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""
============================================
Explaining Local Classifier Per Parent Node
============================================

A minimalist example showing how to use HiClass Explainer to obtain SHAP values of LCPPN model.
A detailed summary of the Explainer class has been given at Algorithms Overview Section for :ref:`Hierarchical Explainability`.
SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here <https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv>`_.
"""
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerParentNode, Explainer
import requests
import pandas as pd
import shap

# Download training data
url = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv"
path = "platypus_diseases.csv"
response = requests.get(url)
with open(path, "wb") as file:
file.write(response.content)

# Load training data into pandas dataframe
training_data = pd.read_csv(path).fillna(" ")

# Define data
X_train = training_data.drop(["label"], axis=1)
X_test = X_train[:100] # Use first 100 samples as test set
Y_train = training_data["label"]
Y_train = [eval(my) for my in Y_train]

# Use random forest classifiers for every node
rfc = RandomForestClassifier()
classifier = LocalClassifierPerParentNode(
local_classifier=rfc, replace_classifiers=False
)

# Train local classifier per parent node
classifier.fit(X_train, Y_train)

# Define Explainer
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test.values)
print(explanations)

# Filter samples which only predicted "Respiratory" at first level
respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory"

# Specify additional filters to obtain only level 0
shap_filter = {"level": 0, "class": "Respiratory", "sample": respiratory_idx}

# Use .sel() method to apply the filter and obtain filtered results
shap_val_respiratory = explanations.sel(shap_filter)

# Plot feature importance on test set
shap.plots.violin(
shap_val_respiratory.shap_values,
feature_names=X_train.columns.values,
plot_size=(13, 8),
)
2 changes: 2 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ pandas==1.4.2
ray==1.13.0
numpy
git+https://github.com/charles9n/bert-sklearn.git@master
shap==0.44.1
xarray==2023.1.0
Binary file added docs/source/algorithms/explainer-indexing.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading