From 820c1a52370d8e37d36d285ef059a62e789b1850 Mon Sep 17 00:00:00 2001 From: ashishpatel16 Date: Tue, 26 Mar 2024 21:12:49 +0100 Subject: [PATCH] updated plot_lcppn_explainer to use platypus dataset --- docs/examples/plot_lcppn_explainer.py | 56 +++++++++++++++++---------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/docs/examples/plot_lcppn_explainer.py b/docs/examples/plot_lcppn_explainer.py index e0fee33f..9f7d00f4 100644 --- a/docs/examples/plot_lcppn_explainer.py +++ b/docs/examples/plot_lcppn_explainer.py @@ -6,29 +6,29 @@ A minimalist example showing how to use HiClass Explainer to obtain SHAP values of LCPPN model. A detailed summary of the Explainer class has been given at Algorithms Overview Section for :ref:`Hierarchical Explainability`. +SHAP values are calculated based on a synthetic platypus diseases dataset that can be downloaded `here `_. """ -import numpy as np from sklearn.ensemble import RandomForestClassifier from hiclass import LocalClassifierPerParentNode, Explainer +import requests +import pandas as pd +import shap + +# Download training data +url = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/3f225c3f80dd8cbb1b6252f6c372a054ec968705/platypus_diseases.csv" +path = "platypus_diseases.csv" +response = requests.get(url) +with open(path, "wb") as file: + file.write(response.content) + +# Load training data into pandas dataframe +training_data = pd.read_csv(path).fillna(" ") # Define data -X_train = np.array( - [ - [40.7, 1.0, 1.0, 2.0, 5.0, 2.0, 1.0, 5.0, 34.3], - [39.2, 0.0, 2.0, 4.0, 1.0, 3.0, 1.0, 2.0, 34.1], - [40.6, 0.0, 3.0, 1.0, 4.0, 5.0, 0.0, 6.0, 27.7], - [36.5, 0.0, 3.0, 1.0, 2.0, 2.0, 0.0, 2.0, 39.9], - ] -) -X_test = np.array([[35.5, 0.0, 1.0, 1.0, 3.0, 3.0, 0.0, 2.0, 37.5]]) -Y_train = np.array( - [ - ["Gastrointestinal", "Norovirus", ""], - ["Respiratory", "Covid", ""], - ["Allergy", "External", "Bee Allergy"], - ["Respiratory", "Cold", ""], - ] -) +X_train = training_data.drop(["label"], axis=1) +X_test = X_train[:100] # Use first 100 samples as test set +Y_train = training_data["label"] +Y_train = [eval(my) for my in Y_train] # Use random forest classifiers for every node rfc = RandomForestClassifier() @@ -36,10 +36,26 @@ local_classifier=rfc, replace_classifiers=False ) -# Train local classifier per node +# Train local classifier per parent node classifier.fit(X_train, Y_train) # Define Explainer explainer = Explainer(classifier, data=X_train, mode="tree") -explanations = explainer.explain(X_test) +explanations = explainer.explain(X_test.values) print(explanations) + +# Filter samples which only predicted "Respiratory" at first level +respiratory_idx = classifier.predict(X_test)[:, 0] == "Respiratory" + +# Specify additional filters to obtain only level 0 +shap_filter = {"level": 0, "class": "Respiratory", "sample": respiratory_idx} + +# Use .sel() method to apply the filter and obtain filtered results +shap_val_respiratory = explanations.sel(shap_filter) + +# Plot feature importance on test set +shap.plots.violin( + shap_val_respiratory.shap_values, + feature_names=X_train.columns.values, + plot_size=(13, 8), +)