Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filter functions for explanations and shap_values #120

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5af9efb
basic implementation of filter_by_level function
iwan-tee Apr 14, 2024
5e52e86
basic implementation of filter_by_class function
iwan-tee Apr 14, 2024
5a8e9be
basic implementation of combine_filters function
iwan-tee Apr 14, 2024
cceb6d1
codestyling
iwan-tee Apr 14, 2024
ab1e3ab
helper function get_class_level added
iwan-tee Apr 14, 2024
719dd83
another helper functions added
iwan-tee Apr 14, 2024
343be04
some bugs fixed
iwan-tee Apr 14, 2024
af32acb
small plot_lcpl_explainer actualisation
iwan-tee Apr 14, 2024
3645e38
some small changes in plot_lcpl_explainer
iwan-tee Apr 14, 2024
d1cf3e0
another changes in plot_lcpl_explainer
iwan-tee Apr 14, 2024
7e6bcb3
functionality duplication removes + docstrings added (partially)
iwan-tee Apr 14, 2024
31e88b2
some actualization in plot_lcpl_explainer
iwan-tee Apr 14, 2024
849b30f
some refactoring
iwan-tee Apr 14, 2024
9877374
some refactoring
iwan-tee Apr 14, 2024
38a9060
some cases handled + tests written
iwan-tee Apr 15, 2024
159bd17
small changes
iwan-tee Apr 15, 2024
5faf54b
pydocstyle
iwan-tee Apr 15, 2024
c9e6aa2
documentation fixed
iwan-tee Apr 15, 2024
3d2bf3b
Algorithm overviem completed
iwan-tee Apr 15, 2024
a2e7700
plot_lcppn_explainer actualized
iwan-tee Apr 15, 2024
3b8239f
shap_multi_plot added
iwan-tee Apr 15, 2024
c06e5e6
part of the code substituted with newer methods
iwan-tee Apr 15, 2024
63c635f
pydocstyling
iwan-tee Apr 15, 2024
6953517
some cases handled and new tests added
iwan-tee Apr 15, 2024
e3a54ee
matplotlib requirement added
iwan-tee Apr 15, 2024
73f74f6
algorithm explaining updated
iwan-tee Apr 15, 2024
4ab46ad
small changes in indices selection
iwan-tee Apr 15, 2024
e431803
some skipiff added
iwan-tee Apr 16, 2024
bf64bfe
ray support added and used as a default (instead of joblib)
iwan-tee Apr 16, 2024
1047a2c
pydocs
iwan-tee Apr 16, 2024
c1dd746
pydocs
iwan-tee Apr 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ name = "pypi"
networkx = "*"
numpy = "*"
scikit-learn = "*"
matplotlib = "*"

[dev-packages]
pytest = "*"
Expand All @@ -20,4 +21,4 @@ sphinx-rtd-theme = "0.5.2"
[extras]
ray = "*"
shap = "0.44.1"
xarray = "*"
xarray = "*"
35 changes: 7 additions & 28 deletions docs/examples/plot_lcpl_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,14 @@

# Define Explainer
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test.values)
print(explanations)

# Let's filter the Shapley values corresponding to the Covid (level 1)
# and 'Respiratory' (level 0)
# Now, our task is to see how feature importance may vary from level to level
# We are going to calculate shap_values for 'Respiratory', 'Covid' and plot what we calculated
# This can be done with a single method .shap_multi_plot, which additionally returns calculated explanations

covid_idx = classifier.predict(X_test)[:, 1] == "Covid"

shap_filter_covid = {"level": 1, "class": "Covid", "sample": covid_idx}
shap_filter_resp = {"level": 0, "class": "Respiratory", "sample": covid_idx}
shap_val_covid = explanations.sel(**shap_filter_covid)
shap_val_resp = explanations.sel(**shap_filter_resp)


# This code snippet demonstrates how to visually compare the mean absolute SHAP values for 'Covid' vs. 'Respiratory' diseases.

# Feature names for the X-axis
feature_names = X_train.columns.values

# SHAP values for 'Covid'
shap_values_covid = shap_val_covid.shap_values.values

# SHAP values for 'Respiratory'
shap_values_resp = shap_val_resp.shap_values.values

shap.summary_plot(
[shap_values_covid, shap_values_resp],
features=X_test.iloc[covid_idx],
feature_names=X_train.columns.values,
plot_type="bar",
explanations = explainer.shap_multi_plot(
class_names=["Covid", "Respiratory"],
features=X_test.values,
pred_class="Respiratory",
features_names=X_train.columns.values,
)
17 changes: 10 additions & 7 deletions docs/examples/plot_lcppn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,26 @@
# Train local classifier per parent node
classifier.fit(X_train, Y_train)

# Get predictions
predictions = classifier.predict(X_test)

# Define Explainer
explainer = Explainer(classifier, data=X_train.values, mode="tree")
explanations = explainer.explain(X_test.values)
print(explanations)

# Filter samples which only predicted "Respiratory" at first level
respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory"

# Specify additional filters to obtain only level 0
shap_filter = {"level": 0, "class": "Respiratory", "sample": respiratory_idx}
# Filter samples which only predicted "Respiratory"
respiratory_idx = explainer.get_sample_indices(predictions, "Respiratory")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about making it more pandas-like? For example respiratory_idx = predictions == "Respiratory"?


# Use .sel() method to apply the filter and obtain filtered results
shap_val_respiratory = explanations.sel(shap_filter)
shap_val_respiratory = explainer.filter_by_class(
Copy link
Collaborator

@mirand863 mirand863 Apr 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, I guess you can probably call the method get_sample_indices inside this other method filter_by_class, simplifying it for the user

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

settled in this option for now

explanations, class_name="Respiratory", sample_indices=respiratory_idx
)


# Plot feature importance on test set
shap.plots.violin(
shap_val_respiratory.shap_values,
shap_val_respiratory,
feature_names=X_train.columns.values,
plot_size=(13, 8),
)
29 changes: 25 additions & 4 deletions docs/source/algorithms/explainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,42 @@ Code sample

lcppn.fit(x_train, y_train)
explainer = Explainer(lcppn, data=x_train, mode="tree")

# One of the possible ways to get explanations
explanations = explainer.explain(x_test)


++++++++++++++++++++++++++
Filtering and Manipulation
++++++++++++++++++++++++++

The Explanation object returned by the Explainer is built using the :literal:`xarray.Dataset` data structure, that enables the application of any xarray dataset operation. For example, filtering specific values can be quickly done. To illustrate the filtering operation, suppose we have SHAP values stored in the Explanation object named :literal:`explanation`.
When you work with the `Explanation` object generated by the `Explainer`, you're leveraging the power of the `xarray.Dataset`. This structure is not just robust but also flexible, allowing for comprehensive dataset operations—especially filtering.

**Practical Example: Filtering SHAP Values**

A common use case is to extract SHAP values for only the predicted nodes. In Local classifier per parent node approach, each node except the leaf nodes represents a classifier. Hence, to find the SHAP values, we can pass the prediction until the penultimate element to obtain the SHAP values.
To achieve this, we can use xarray's :literal:`.sel()` method:
Consider a scenario where you need to focus only on SHAP values corresponding to predicted nodes. In the context of our `LocalClassifierPerParentNode` model, each node—except for the leaf nodes—acts as a classifier. This setup is particularly useful when you're looking to isolate SHAP values up to the penultimate node in your predictions. Here’s how you can do this efficiently using the `sel()` method from xarray:

.. code-block:: python

# Creating a mask for selecting SHAP values for predicted classes
mask = {'class': lcppn.predict(x_test).flatten()[:-1]}
x = explanations.sel(mask).shap_values
selected_shap_values = explanations.sel(mask).shap_values

**Advanced Visualization: Multi-Plot SHAP Values**

For an even deeper analysis, you might want to visualize the SHAP values. The `shap_multi_plot()` method not only filters the data but also provides a visual representation of the SHAP values for specified classes. Below is an example that illustrates how to plot SHAP values for the classes "Covid" and "Respiratory":

.. code-block:: python

# Generating and plotting explanations for specific classes
explanations = explainer.shap_multi_plot(
class_names=["Covid", "Respiratory"],
features=x_test,
pred_class="Covid",
# Feature names specifiaction possible if x_train is a dataframe with specified columns_names
feature_names=x_train.columns.values
)



More advanced usage and capabilities can be found at the `Xarray.Dataset <https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html>`_ documentation.
225 changes: 225 additions & 0 deletions hiclass/Explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,228 @@ def _calculate_shap_values(self, X):
datasets.append(local_dataset)
sample_explanation = xr.concat(datasets, dim="level")
return sample_explanation

def filter_by_level(self, explanations, level):
"""
Return the explanations filtered by the given level.

Parameters
__________
explanations : xarray.DataArray
The explanations to filter
level : int
level in the hierarchy to filter

Returns
_______
filtered_explanations : xarray.Dataset
Explanations filtered by the given level

Examples
--------
>>> from sklearn.ensemble import RandomForestClassifier
>>> import numpy as np
>>> from hiclass import LocalClassifierPerParentNode, Explainer
>>> rfc = RandomForestClassifier()
>>> lcppn = LocalClassifierPerParentNode(local_classifier=rfc, replace_classifiers=False)
>>> x_train = np.array([[1, 3], [2, 5]])
>>> y_train = np.array([[1, 2], [3, 4]])
>>> x_test = np.array([[4, 6]])
>>> lcppn.fit(x_train, y_train)
>>> explainer = Explainer(lcppn, data=x_train, mode="tree")
>>> explanations = explainer.explain(x_test)
>>> explanations_level_1 = explainer.filter_by_level(explanations, level=1)
<xarray.Dataset>
Dimensions: (class: 3, sample: 1, feature: 2)
Coordinates:
* class (class) <U1 12B '1' '3' '4'
level int64 8B 1
Dimensions without coordinates: sample, feature
Data variables:
node (sample) <U13 52B '3'
predicted_class (sample) <U1 4B '4'
predict_proba (sample, class) float64 24B nan nan 1.0
classes (sample, class) object 24B nan nan '4'
shap_values (class, sample, feature) float64 48B nan nan ... 0.0 0.0
"""
filter_by_level = {"level": level}
filtered_explanations = explanations.sel(**filter_by_level)
return filtered_explanations

def filter_by_class(self, explanations, class_name, sample_indices=None):
"""
Filter SHAP values based on a specified class and optionally by sample indices.

This function filters the provided explanations data array to return SHAP values
for a specific class, determined both by its name and the level it appears in
the hierarchy, with an option to further filter by specific sample indices.

As long as class can belong to one level only, the function also provides filtration
by level which is computed based on the class location in the hierarchy.
Parameters
----------
explanations : xarray.Dataset
A dataset of explanations, where dimensions include 'class', 'level', and 'sample'.
class_name : str or int
The name of the class to filter the explanations by.
sample_indices : list of boolean, optional
A list of boolean indices specifying which samples to include in the filter.
If None, no sample-based filtering is applied.

Returns
-------
numpy.ndarray
An array of SHAP values filtered according to the specified class and optionally
by the provided sample indices.

Examples
________
>>> from sklearn.ensemble import RandomForestClassifier
>>> import numpy as np
>>> from hiclass import LocalClassifierPerParentNode, Explainer
>>> rfc = RandomForestClassifier()
>>> lcppn = LocalClassifierPerParentNode(local_classifier=rfc, replace_classifiers=False)
>>> x_train = np.array([[1, 3], [2, 5]])
>>> y_train = np.array([[1, 2], [3, 4]])
>>> x_test = np.array([[4, 6]])
>>> lcppn.fit(x_train, y_train)
>>> predictions = lcppn.predict(x_test)
>>> explainer = Explainer(lcppn, data=x_train, mode="tree")
>>> explanations = explainer.explain(x_test)
>>> filtered_shap = explainer.filter_by_class(explanations, level=3)
>>> print(filtered_shap)
[['3' '4']]
[[0.1 0.105]]
"""
# Ensure that explanations are provided and have the expected structure
if not isinstance(explanations, xr.Dataset):
raise ValueError("Explanations should be an xarray.Dataset!")

# Converting class_name to the string format
class_name = str(class_name)

if class_name == "":
raise ValueError("Empty class!")

# Define level
level = self.get_class_level(str(class_name))

# Handling with LocalClassifierPerNode case
if isinstance(self.hierarchical_model, LocalClassifierPerNode):
class_name = f"{class_name}_1"

# Shap filter
shap_filter = {"class": class_name, "level": level}
if sample_indices is not None:
shap_filter["sample"] = sample_indices

# Select the SHAP values according to the filter and handle possible errors
try:
filtered_explanations = explanations.sel(**shap_filter)
except KeyError as e:
raise KeyError(
f"Class name {class_name} with level {level} not found."
) from e

# Return the selected SHAP values as a NumPy array
return filtered_explanations.shap_values.values

def get_class_level(self, class_name):
"""
Return level of the class in the hierarchy.

Parameters
__________
class_name : int or str
Name of the class

Returns
_______
class_level : int
Level of the class in the hierarchy


"""
# Set the classifier
classifier = self.hierarchical_model

# Converting class_name to the string formatn
class_name = str(class_name)

# Iterating through the nodes of hierarchy
for node_ in classifier.hierarchy_.nodes:
if class_name in node_.split(classifier.separator_):
node_classes = node_.split(classifier.separator_)
return node_classes.index(class_name)

raise ValueError(f"Class '{class_name}' not found!")

def get_sample_indices(self, predictions, class_name):
"""
Return indices of predictions corresponding to the certain class.

Parameters
__________
predictions: array-like
Array of predictions of the hierarchical classificator
class_name: str
Name of class

Returns
_______
sample_indices: boolean array of indices
"""
class_level = self.get_class_level(class_name)
return predictions[:, class_level] == class_name

def shap_multi_plot(self, class_names, features, pred_class, features_names=None):
"""
Plot shap_values for multi-class case on a bar and return explanations.

"Lazy" function which does not require any additional actions from the user
apart from classifier fitting and explainer initialization.

Parameters
----------
class_names : list of str
A list of class names to calculate and visualize the Shapley values for.
features: array-like
Matrix of feature values with shape (# features) or (# samples x # features).
Typically, this would be the test set features (X_test).
pred_class : int or str
The class label that the classifier's predictions must match for a sample to be
included in the subset of data used for SHAP value calculation. If not provided,
no filtering is applied, and all samples are considered.
features_names : list, optional
A list of feature names to include in the bar plot for the shap_values.

Returns
-------
explanations: xarray.Dataset3
Whole explanations of data in features provided.
"""
classifier = self.hierarchical_model
predictions = classifier.predict(features)

if pred_class is not None and not any(pred_class in row for row in predictions):
raise ValueError(
f"The specified class '{pred_class}' was not found in the predictions."
)

explanations = self.explain(features)
sample_idx = self.get_sample_indices(predictions, pred_class)
shap_array = []
for class_name in class_names:
shap_val = self.filter_by_class(
explanations, class_name=class_name, sample_indices=sample_idx
)
shap_array.append(shap_val)

shap.summary_plot(
shap_array,
features=features[sample_idx],
feature_names=features_names,
plot_type="bar",
class_names=class_names,
)
return explanations
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
KEYWORDS = ["hierarchical classification"]
DACS_SOFTWARE = "https://gitlab.com/dacs-hpi"
# What packages are required for this module to be executed?
REQUIRED = ["networkx", "numpy", "scikit-learn", "scipy<1.13"]
REQUIRED = ["networkx", "numpy", "scikit-learn", "scipy<1.13", "matplotlib"]

# What packages are optional?
# 'fancy feature': ['django'],}
Expand Down
Loading
Loading