
Commit

Algorithm explanation updated
iwan-tee committed Apr 15, 2024
1 parent e3a54ee commit 73f74f6
Showing 1 changed file with 19 additions and 17 deletions.
docs/source/algorithms/explainer.rst (36 changes: 19 additions & 17 deletions)
@@ -111,39 +111,41 @@ Code sample
    lcppn.fit(x_train, y_train)
    explainer = Explainer(lcppn, data=x_train, mode="tree")

    # One of the possible ways to get explanations
    explanations = explainer.explain(x_test)

++++++++++++++++++++++++++
Filtering and Manipulation
++++++++++++++++++++++++++

The :literal:`Explanation` object returned by the :literal:`Explainer` is built on the :literal:`xarray.Dataset` data structure, so any xarray dataset operation can be applied to it. This makes the object both robust and flexible; filtering specific values, in particular, is quick. To illustrate the filtering operations below, suppose we have SHAP values stored in an :literal:`Explanation` object named :literal:`explanations`.
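
For orientation, the snippet below is a minimal sketch of inspecting that dataset with standard xarray calls; which variables and coordinates it prints depends on your model and data, so treat the comments as assumptions rather than guaranteed names.

.. code-block:: python

    # A minimal sketch: standard xarray introspection works on the
    # Explanation object created above.
    print(explanations.data_vars)  # stored arrays, e.g. the SHAP values
    print(explanations.coords)     # coordinates that can be used with .sel()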

**Practical Example: Filtering SHAP Values**

A common use case is to extract SHAP values only for the predicted nodes. In the :literal:`LocalClassifierPerParentNode` approach, every node except the leaf nodes represents a classifier, so the prediction up to its penultimate element identifies the classifiers involved in predicting a sample. We can therefore build a mask from those predicted classes and select their SHAP values efficiently with xarray's :literal:`.sel()` method:

.. code-block:: python

    # Create a mask that selects the SHAP values of the predicted classes
    mask = {'class': lcppn.predict(x_test).flatten()[:-1]}
    selected_shap_values = explanations.sel(mask).shap_values
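
The selected values can then be post-processed like any other array. The sketch below is an illustration, not a library feature; it assumes the selection reduces to a (samples, features) layout.

.. code-block:: python

    import numpy as np

    # Hedged sketch: rank features by mean absolute SHAP value,
    # assuming a (samples, features) layout after selection.
    values = np.asarray(selected_shap_values)
    mean_abs = np.abs(values).mean(axis=0)
    ranked = np.argsort(mean_abs)[::-1]  # strongest features first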

The :literal:`Explainer` class also provides helper methods that simplify standard explanation manipulation and filtering, such as filtering explanations by level, or filtering them by class and returning the corresponding Shapley values. The basic example below continues the example from the beginning of this section:

.. code-block:: python

    predictions = lcppn.predict(x_test)

    # Get the corresponding samples
    covid_idx = explainer.get_sample_indices(predictions, 'Covid')

    # Filter the SHAP values by class
    shap_values_covid = explainer.filter_by_class(explanations, 'Covid', covid_idx)
    print(shap_values_covid)

    # Filter explanations by level
    level = 1
    explanations_level_1 = explainer.filter_by_level(explanations, level)
    print(explanations_level_1)

**Advanced Visualization: Multi-Plot SHAP Values**

For a deeper analysis, you may want to visualize the SHAP values. The :literal:`shap_multi_plot()` method not only filters the data but also plots the SHAP values for the specified classes. The example below plots the SHAP values for the classes "Covid" and "Respiratory":

.. code-block:: python

    # Generate and plot explanations for specific classes
    explanations = explainer.shap_multi_plot(
        class_names=["Covid", "Respiratory"],
        features=x_test,
        pred_class="Covid",
        # Feature names can be specified if x_train is a dataframe with named columns
        feature_names=x_train.columns.values
    )
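
Since the example above assigns the return value of :literal:`shap_multi_plot()` back to :literal:`explanations`, and assuming that return value is again an :literal:`xarray.Dataset` with a :literal:`class` coordinate, the earlier filtering patterns still apply; the continuation below is a hedged sketch, not documented behavior.

.. code-block:: python

    # Hedged continuation: select a single class from the returned dataset,
    # assuming 'Covid' is a value on its 'class' coordinate.
    covid_only = explanations.sel({'class': 'Covid'})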
