Skip to content

Commit

Permalink
updated plot_lcppn_explainer to use platypus dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ashishpatel16 committed Mar 26, 2024
1 parent ae48bf0 commit 820c1a5
Showing 1 changed file with 36 additions and 20 deletions.
56 changes: 36 additions & 20 deletions docs/examples/plot_lcppn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,56 @@
A minimalist example showing how to use HiClass Explainer to obtain SHAP values of LCPPN model.
A detailed summary of the Explainer class has been given at Algorithms Overview Section for :ref:`Hierarchical Explainability`.
SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here <https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv>`_.
"""
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from hiclass import LocalClassifierPerParentNode, Explainer
import requests
import pandas as pd
import shap

# Download training data
url = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv"
path = "platypus_diseases.csv"
response = requests.get(url)
with open(path, "wb") as file:
file.write(response.content)

# Load training data into pandas dataframe
training_data = pd.read_csv(path).fillna(" ")

# Define data
X_train = np.array(
[
[40.7, 1.0, 1.0, 2.0, 5.0, 2.0, 1.0, 5.0, 34.3],
[39.2, 0.0, 2.0, 4.0, 1.0, 3.0, 1.0, 2.0, 34.1],
[40.6, 0.0, 3.0, 1.0, 4.0, 5.0, 0.0, 6.0, 27.7],
[36.5, 0.0, 3.0, 1.0, 2.0, 2.0, 0.0, 2.0, 39.9],
]
)
X_test = np.array([[35.5, 0.0, 1.0, 1.0, 3.0, 3.0, 0.0, 2.0, 37.5]])
Y_train = np.array(
[
["Gastrointestinal", "Norovirus", ""],
["Respiratory", "Covid", ""],
["Allergy", "External", "Bee Allergy"],
["Respiratory", "Cold", ""],
]
)
X_train = training_data.drop(["label"], axis=1)
X_test = X_train[:100] # Use first 100 samples as test set
Y_train = training_data["label"]
Y_train = [eval(my) for my in Y_train]

# Use random forest classifiers for every node
rfc = RandomForestClassifier()
classifier = LocalClassifierPerParentNode(
local_classifier=rfc, replace_classifiers=False
)

# Train local classifier per node
# Train local classifier per parent node
classifier.fit(X_train, Y_train)

# Define Explainer
explainer = Explainer(classifier, data=X_train, mode="tree")
explanations = explainer.explain(X_test)
explanations = explainer.explain(X_test.values)
print(explanations)

# Filter samples which only predicted "Respiratory" at first level
respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory"

# Specify additional filters to obtain only level 0
shap_filter = {"level": 0, "class": "Respiratory", "sample": respiratory_idx}

# Use .sel() method to apply the filter and obtain filtered results
shap_val_respiratory = explanations.sel(shap_filter)

# Plot feature importance on test set
shap.plots.violin(
shap_val_respiratory.shap_values,
feature_names=X_train.columns.values,
plot_size=(13, 8),
)

0 comments on commit 820c1a5

Please sign in to comment.