
Explainer API for Local Classifier per parent node #minor #106

Merged Mar 27, 2024 · 68 commits
Changes from 58 commits
62c218d
added initial implementation of explainer api for lcppn
Jan 9, 2024
ea1fff8
fixed lints
Jan 10, 2024
c4d75c5
fixed lints
Jan 14, 2024
299af62
added an _explain_lcppn implementation and some tests provided
iwan-tee Jan 14, 2024
0ea8956
modified docstrings
Jan 14, 2024
1efd946
explainer for lcpn implemented + tests added and some cases fixed
iwan-tee Jan 14, 2024
7dcb52f
Merge branch 'explainer_api_lcpn' into explainer_api
iwan-tee Jan 14, 2024
1de360c
tests added + some bugs fixed
iwan-tee Jan 15, 2024
933b1f6
base
iwan-tee Jan 15, 2024
c57abed
basic implementation
iwan-tee Jan 18, 2024
a829ce0
LCPL explanator implementation + test
iwan-tee Jan 23, 2024
33f2cbc
added tests for hierarchy without roots
Jan 26, 2024
c06d8a7
check on root node added
iwan-tee Jan 26, 2024
c597fce
minor updates
Jan 26, 2024
b79e5f4
codestyling
iwan-tee Jan 26, 2024
8a643f1
codestyling
iwan-tee Jan 26, 2024
606c1eb
Merge branch 'explainer_master' into explainer_api_lcpl
ashishpatel16 Jan 26, 2024
ca6c654
Update Explainer.py
ashishpatel16 Jan 26, 2024
9936dc3
Merge pull request #1 from ashishpatel16/explainer_api_lcpl
ashishpatel16 Jan 26, 2024
82573be
Merge pull request #2 from ashishpatel16/explainer_api_lcpn
ashishpatel16 Jan 26, 2024
d53e8d9
added support for xarray for lcppn
Jan 29, 2024
2449928
Merge branch 'explainer_master' into explainer_api
ashishpatel16 Jan 29, 2024
759489f
Update Explainer.py
ashishpatel16 Jan 29, 2024
0771c08
Update Explainer.py
ashishpatel16 Jan 29, 2024
4eb6f5c
fixed errors with classifier with single class
Jan 30, 2024
3955521
updated test cases and removed cached explainers
Feb 1, 2024
7c2f4d2
removed cached explainers
Feb 1, 2024
b12bdc3
modified predict proba to return dict
Feb 1, 2024
986b61c
Merge branch 'main' into explainer_api
ashishpatel16 Feb 2, 2024
eb11c0e
updated get_predict_proba to return only traversed prediction probabi…
Feb 3, 2024
8c700e4
updated fork
Feb 3, 2024
53a90a0
separate test file for explainer
Feb 3, 2024
5e74762
Update Explainer.py
ashishpatel16 Feb 5, 2024
b1f3656
_get_traversed_nodes edited
iwan-tee Feb 6, 2024
2a12087
fixed lints
Feb 12, 2024
84f6e39
fixed conflicts
Feb 12, 2024
b09f8da
refactored and cleaned up code
Feb 12, 2024
aecdd96
updated test cases and isolated lcppn code
Feb 12, 2024
9658c4a
Merge branch 'main' into lcppn_explainer
ashishpatel16 Feb 13, 2024
9a73b6c
added support for lcpn
Feb 16, 2024
c5b5a68
Merge branch 'main' into lcpn_explainer
ashishpatel16 Mar 14, 2024
dc99b44
updated explainer and tests, added docstrings
ashishpatel16 Mar 15, 2024
06acdca
updated readthedocs
ashishpatel16 Mar 15, 2024
7da5779
updated README with Explainer example
ashishpatel16 Mar 15, 2024
139ad11
fixed imports
ashishpatel16 Mar 16, 2024
707d51e
removed unecessary files
ashishpatel16 Mar 20, 2024
33e1548
added tests, updated dependencies in setup.py and docs/requirements.txt
ashishpatel16 Mar 22, 2024
2a19d40
fixed lints
ashishpatel16 Mar 22, 2024
53288f2
isolated lcpn code and removed lcppn code from explainer
ashishpatel16 Mar 22, 2024
253a8de
fixed shap version
ashishpatel16 Mar 24, 2024
59ba63c
merged lcpn_epxlainer
ashishpatel16 Mar 25, 2024
c2378bb
separated code for lcppn and added tests
ashishpatel16 Mar 25, 2024
10bad49
Update plot_lcppn_explainer.py
ashishpatel16 Mar 25, 2024
81817c3
removed get_predict_proba() method from LocalClassifierPerParentNode
ashishpatel16 Mar 25, 2024
9ba6f28
removed redundant dependencies from pipfile
ashishpatel16 Mar 25, 2024
04a9ed8
Merge remote-tracking branch 'origin/lcppn_explainer' into lcppn_expl…
ashishpatel16 Mar 25, 2024
1b90da8
used masking approach to calculate traversed nodes
ashishpatel16 Mar 25, 2024
d914fd4
handled cases for imbalanced hierarchy
ashishpatel16 Mar 25, 2024
17c96ef
removed hiclass separator from output
ashishpatel16 Mar 26, 2024
fde3040
Update tests/test_LocalClassifierPerParentNode.py
ashishpatel16 Mar 26, 2024
01bf32e
Update hiclass/Explainer.py
ashishpatel16 Mar 26, 2024
ada4e88
refactored _get_traversed_nodes, will be three distinct methods for …
ashishpatel16 Mar 26, 2024
8dbf6e5
Merge remote-tracking branch 'origin/lcppn_explainer' into lcppn_expl…
ashishpatel16 Mar 26, 2024
a0d3f59
fixed xarray dependency version
ashishpatel16 Mar 26, 2024
ae48bf0
updated documentation and fixed typos
ashishpatel16 Mar 26, 2024
820c1a5
updated plot_lcppn_explainer to use platypus dataset
ashishpatel16 Mar 26, 2024
af7e83a
updated README
ashishpatel16 Mar 27, 2024
3c93a98
updated url for platypus dataset
ashishpatel16 Mar 27, 2024
2 changes: 1 addition & 1 deletion Pipfile
@@ -19,5 +19,5 @@ sphinx-rtd-theme = "0.5.2"

[extras]
ray = "*"
-shap = "*"
+shap = "0.44.1"
xarray = "*"
81 changes: 78 additions & 3 deletions Pipfile.lock


32 changes: 32 additions & 0 deletions README.md
@@ -199,6 +199,38 @@ pipeline.fit(X_train, Y_train)
predictions = pipeline.predict(X_test)
```

## Explaining Hierarchical Classifiers
Hierarchical classifiers can provide additional insights when combined with explainability methods such as SHAP values. Below is a simple example to demonstrate how to calculate hierarchical SHAP values:
```python
from hiclass import LocalClassifierPerParentNode, Explainer
from sklearn.ensemble import RandomForestClassifier
import numpy as np

# Define data
X_train = np.array([[1], [2], [3], [4]])
X_test = np.array([[4], [3], [2], [1]])
Y_train = np.array([
    ['Animal', 'Mammal', 'Sheep'],
    ['Animal', 'Mammal', 'Cow'],
    ['Animal', 'Reptile', 'Snake'],
    ['Animal', 'Reptile', 'Lizard'],
])

# Use random forest classifiers for every node
rf = RandomForestClassifier()
classifier = LocalClassifierPerParentNode(local_classifier=rf, replace_classifiers=False)

# Train local classifier per node
classifier.fit(X_train, Y_train)

# Predict
predictions = classifier.predict(X_test)

# Explain
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test)
```

## Step-by-step walk-through

A step-by-step walk-through is available on our documentation hosted on [Read the Docs](https://hiclass.readthedocs.io/en/latest/index.html).
45 changes: 45 additions & 0 deletions docs/examples/plot_lcppn_explainer.py
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
"""
============================================
Explaining Local Classifier Per Parent Node
============================================

A minimalist example showing how to use the HiClass Explainer to obtain SHAP values for an LCPPN model.
A detailed summary of the Explainer class is given in the Algorithms Overview section under :ref:`Hierarchical Explainability`.
"""
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerParentNode, Explainer

# Define data
X_train = np.array(
    [
        [40.7, 1.0, 1.0, 2.0, 5.0, 2.0, 1.0, 5.0, 34.3],
        [39.2, 0.0, 2.0, 4.0, 1.0, 3.0, 1.0, 2.0, 34.1],
        [40.6, 0.0, 3.0, 1.0, 4.0, 5.0, 0.0, 6.0, 27.7],
        [36.5, 0.0, 3.0, 1.0, 2.0, 2.0, 0.0, 2.0, 39.9],
    ]
)
X_test = np.array([[35.5, 0.0, 1.0, 1.0, 3.0, 3.0, 0.0, 2.0, 37.5]])
Y_train = np.array(
    [
        ["Gastrointestinal", "Norovirus", ""],
        ["Respiratory", "Covid", ""],
        ["Allergy", "External", "Bee Allergy"],
        ["Respiratory", "Cold", ""],
    ]
)

# Use random forest classifiers for every node
rfc = RandomForestClassifier()
classifier = LocalClassifierPerParentNode(
    local_classifier=rfc, replace_classifiers=False
)

# Train local classifier per node
classifier.fit(X_train, Y_train)

# Define Explainer
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test)
print(explanations)
2 changes: 2 additions & 0 deletions docs/requirements.txt
@@ -9,3 +9,5 @@ pandas==1.4.2
ray==1.13.0
numpy
git+https://github.com/charles9n/bert-sklearn.git@master
shap==0.44.1
xarray
Binary file added docs/source/algorithms/explainer-indexing.png
135 changes: 135 additions & 0 deletions docs/source/algorithms/explainer.rst
@@ -0,0 +1,135 @@
.. _explainer-overview:

===========================
Hierarchical Explainability
===========================
HiClass also provides support for eXplainable AI (XAI) using SHAP values. This section demonstrates the Explainer class along with examples and design principles.

++++++++++++++++++++++++++
Motivation
++++++++++++++++++++++++++

Explainability in machine learning refers to the ability to understand and interpret how a model arrives at a particular decision. Several explainability methods are described in the literature and have found use across a wide range of machine learning applications.

SHAP values are one such method: they provide a unified measure of feature importance that accounts for the contribution of each feature to the model's prediction. Grounded in cooperative game theory, they distribute credit for the prediction fairly among the features.

Integrating explainability methods into hierarchical classifiers can yield promising results depending on the application domain. Hierarchical explainability extends the concept of SHAP values to hierarchical classification models.

++++++++++++++++++++++++++
Dataset overview
++++++++++++++++++++++++++
For the remainder of this section, we will use a synthetically generated dataset representing platypus diseases. This tabular dataset was created to visualize and test explainability methods on hierarchical models. The diagram below illustrates the hierarchical structure of the dataset. Using nine symptoms as features (fever, diarrhea, stomach pain, skin rash, cough, sniffles, shortness of breath, headache, and body size), the objective is to predict the disease from these feature values.

.. figure:: ../algorithms/platypus_diseases_hierarchy.svg
   :align: center
   :width: 100%

   Hierarchical structure of the synthetic dataset representing platypus diseases.

++++++++++++++++++++++++++
Background
++++++++++++++++++++++++++
This section introduces two main concepts: hierarchical classification and SHAP values. Hierarchical classification leverages the hierarchical structure of data, breaking down the classification task into manageable sub-tasks using models organized in a DAG structure.

SHAP values, adapted from game theory, show the impact of features on model predictions, thus aiding model interpretation. The SHAP library offers practical implementation of these methods, supporting various machine learning algorithms for explanation generation.

To demonstrate how SHAP values provide insights into model prediction, consider the following sample from the platypus disease dataset.

.. code-block:: python

    test_sample = np.array([[35.5, 0.0, 1.0, 1.0, 3.0, 3.0, 0.0, 2.0, 37.5]])
    sample_target = np.array([['Respiratory', 'Cold', '']])

We can calculate SHAP values using the SHAP Python package and visualize them. SHAP values tell us how much each symptom "contributes" to the model's decision about which disease a platypus might have. The following diagram illustrates how SHAP values can be visualized using :literal:`shap.force_plot`.


.. figure:: ../algorithms/shap_explanation.png
   :align: center
   :width: 100%

   Force plot illustrating the influence of symptoms on predicting platypus diseases using SHAP values. Each bar represents a symptom, with its length indicating the magnitude of impact on disease prediction.
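
To build intuition for how these contributions are computed, here is a toy exact Shapley-value calculation for a two-feature additive model, using only NumPy and the classical permutation formula. This is illustrative only; the SHAP library computes these values efficiently for real models, and the model, background data, and sample here are made up for the example.

```python
import itertools
import math
import numpy as np

# Toy additive model: f(x) = x0 + 2 * x1
def f(x):
    return x[0] + 2 * x[1]

background = np.array([[0.0, 0.0], [2.0, 2.0]])  # reference/background data
x = np.array([1.0, 3.0])                         # sample to explain

def value(subset):
    # Expected model output when the features in `subset` are fixed to x
    # and the remaining features are drawn from the background data.
    z = background.copy()
    for i in subset:
        z[:, i] = x[i]
    return float(np.mean([f(row) for row in z]))

n = len(x)
phi = np.zeros(n)
for perm in itertools.permutations(range(n)):
    fixed = []
    for i in perm:
        before = value(fixed)
        fixed.append(i)
        phi[i] += (value(fixed) - before) / math.factorial(n)

# phi == [0.0, 4.0] here, and the attributions sum to the model output
# minus the expected output: phi.sum() == f(x) - value([]) == 7 - 3 == 4
```

The key property visible even in this toy setting is additivity: the per-feature attributions always sum to the difference between the model's output for the sample and its expected output over the background data.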


++++++++++++++++++++++++++
API Design
++++++++++++++++++++++++++

Designing an API for hierarchical classifiers and SHAP value computation presents several challenges: complex data structures, difficulty accessing the SHAP values that correspond to a particular classifier, and slow computation. We address these issues by using an xarray dataset for efficient organization, filtering, and storage of SHAP values, and by parallelizing computation with joblib. These enhancements provide a streamlined, user-friendly experience when working with hierarchical classifiers and SHAP values.

.. figure:: ../algorithms/explainer-indexing.png
   :align: center
   :width: 75%

   Pictorial representation of the dimensions along which indexing of hierarchical SHAP values is required.
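
The joblib parallelization can be sketched as follows. The per-node function below is a hypothetical stand-in for fitting a SHAP explainer on one local classifier and computing its SHAP values; it is not HiClass's actual internals, and the node names and data are made up for the example.

```python
import numpy as np
from joblib import Parallel, delayed

def explain_node(node_name, X):
    # Placeholder for the real work: fitting a SHAP explainer on the
    # local classifier at this node and computing its SHAP values.
    return node_name, X.sum(axis=1)

nodes = ["Root", "Respiratory", "Allergy"]   # hypothetical hierarchy nodes
X = np.ones((4, 9))                          # 4 samples, 9 symptom features

# Each local classifier can be explained independently, so the work
# parallelizes cleanly across the nodes of the hierarchy.
results = Parallel(n_jobs=2)(delayed(explain_node)(n, X) for n in nodes)
```

Because the local classifiers are independent of each other once the model is fitted, this kind of per-node parallelism scales with the number of nodes rather than the depth of the hierarchy.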

The Explainer class takes a fitted HiClass model, training data, and some named parameters as input. After creating an instance of the Explainer, the explain method can be called by providing the samples for which SHAP values need to be calculated.

.. code-block:: python

    explainer = Explainer(fitted_hiclass_model, data=training_data)

The Explainer returns an xarray.Dataset object, which allows users to intuitively access, filter, slice, and plot SHAP values. This Explanation dataset can also be explored interactively within a Jupyter notebook environment. The Explanation object, along with its attributes, is depicted in the following UML diagram.

.. figure:: ../algorithms/hiclass-uml.png
   :align: center
   :width: 100%

   UML diagram showing the relationship between the HiClass Explainer and the returned Explanation object.

The Explanation object can be obtained by calling the :literal:`explain` method of the Explainer.

.. code-block:: python

    explanations = explainer.explain(sample_data)


++++++++++++++++++++++++++
Code sample
++++++++++++++++++++++++++

.. code-block:: python

    from sklearn.ensemble import RandomForestClassifier
    import numpy as np
    from hiclass import LocalClassifierPerParentNode, Explainer

    rfc = RandomForestClassifier()
    lcppn = LocalClassifierPerParentNode(local_classifier=rfc, replace_classifiers=False)

    x_train = np.array([
        [40.7, 1.0, 1.0, 2.0, 5.0, 2.0, 1.0, 5.0, 34.3],
        [39.2, 0.0, 2.0, 4.0, 1.0, 3.0, 1.0, 2.0, 34.1],
        [40.6, 0.0, 3.0, 1.0, 4.0, 5.0, 0.0, 6.0, 27.7],
        [36.5, 0.0, 3.0, 1.0, 2.0, 2.0, 0.0, 2.0, 39.9],
    ])
    y_train = np.array([
        ['Gastrointestinal', 'Norovirus', ''],
        ['Respiratory', 'Covid', ''],
        ['Allergy', 'External', 'Bee Allergy'],
        ['Respiratory', 'Cold', ''],
    ])

    x_test = np.array([[35.5, 0.0, 1.0, 1.0, 3.0, 3.0, 0.0, 2.0, 37.5]])

    lcppn.fit(x_train, y_train)
    explainer = Explainer(lcppn, data=x_train, mode="tree")
    explanations = explainer.explain(x_test)


++++++++++++++++++++++++++
Filtering and Manipulation
++++++++++++++++++++++++++

The Explanation object returned by the Explainer is built on an xarray dataset, so any xarray Dataset operation can be applied to it. For example, specific values can easily be filtered out. To illustrate, suppose we have SHAP values stored in an Explanation object named :literal:`explanations`.

A common use case is to extract SHAP values for only the predicted nodes. In the Local Classifier per Parent Node approach, every node except the leaf nodes represents a classifier, so we can select the prediction up to the penultimate element to obtain the corresponding SHAP values.
To achieve this, we can use xarray's :literal:`.sel()` method:

.. code-block:: python

    mask = {'class': lcppn.predict(x_test).flatten()[:-1]}
    x = explanations.sel(mask).shap_values

More advanced usage and capabilities can be found at the `Xarray.Dataset <https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html>`_ documentation.
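
As a self-contained illustration of this kind of filtering, consider a small synthetic stand-in for the Explanation dataset. The dimension and variable names here (:literal:`class`, :literal:`sample`, :literal:`feature`, :literal:`shap_values`) are assumptions chosen to mirror the example above, not HiClass's exact schema.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the Explanation dataset: SHAP values for
# 3 classes, 1 sample, and 9 symptom features.
rng = np.random.default_rng(0)
explanations = xr.Dataset(
    {"shap_values": (("class", "sample", "feature"), rng.random((3, 1, 9)))},
    coords={
        "class": ["Respiratory", "Cold", "Covid"],
        "sample": [0],
        "feature": [f"symptom_{i}" for i in range(9)],
    },
)

# Select the SHAP values for the classifiers along a predicted path,
# exactly as with the .sel() call shown above.
predicted_path = ["Respiratory", "Cold"]
subset = explanations.sel({"class": predicted_path}).shap_values
```

Selecting by coordinate labels in this way keeps the remaining dimensions intact, so :literal:`subset` still carries the sample and feature axes for further slicing or plotting.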


Binary file added docs/source/algorithms/hiclass-uml.png
1 change: 1 addition & 0 deletions docs/source/algorithms/index.rst
@@ -16,3 +16,4 @@ HiClass provides implementations for the most popular machine learning models fo
local_classifier_per_level
multi_label
metrics
explainer
1 change: 1 addition & 0 deletions docs/source/algorithms/platypus_diseases_hierarchy.svg
Binary file added docs/source/algorithms/shap_explanation.png
10 changes: 10 additions & 0 deletions docs/source/api/explainer_api.rst
@@ -0,0 +1,10 @@
.. _explainer_api:

Explainer
========================

Explainer
-----------------------
.. autoclass:: Explainer.Explainer
   :members:
   :special-members: __init__
1 change: 1 addition & 0 deletions docs/source/api/index.rst
@@ -13,3 +13,4 @@ This is done in order to provide a complete list of the callable functions for e

classifiers
utilities
explainer_api