-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
441 additions
and
146 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build" | |
|
||
[project] | ||
name = "scanvi-explainer" | ||
version = "0.1.0" | ||
version = "0.2.0" | ||
description = "Exapliner of scANVI using SHAP" | ||
authors = [ | ||
{name = "Martin Proks", email = "[email protected]"}, | ||
|
@@ -15,7 +15,8 @@ license = {file = "LICENSE"} | |
requires-python = ">=3.10" | ||
keywords = ["shap", "scanvi", "explainer", "interpretability"] | ||
dependencies = [ | ||
"ruff>=0.6.4", | ||
"anndata", | ||
"rich", | ||
"shap>=0.41.0", | ||
] | ||
classifiers = [ | ||
|
@@ -35,7 +36,7 @@ Issues = "https://github.com/brickmanlab/scanvi-explainer/issues" | |
Changelog = "https://github.com/brickmanlab/scanvi-explainer/blob/master/CHANGELOG.md" | ||
|
||
[project.optional-dependencies] | ||
dev = ["ruff", "huggingface_hub"] | ||
dev = ["ruff", "scvi-tools", "huggingface_hub"] | ||
doc = [ | ||
"setuptools", | ||
"sphinx", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,11 @@ | ||
import importlib.util | ||
|
||
from .scanvi_deep import SCANVIDeep | ||
|
||
try: | ||
import torch | ||
except ImportError: | ||
if not importlib.util.find_spec("torch"): | ||
raise ImportError("Missing torch package! Run pip install torch") | ||
|
||
try: | ||
from scvi.model import SCANVI | ||
except ImportError: | ||
raise ImportError("Missing scvi-tools package! Run pip install scvi-tools") | ||
if not importlib.util.find_spec("scvi"): | ||
raise ImportError("Missing torch package! Run pip install torch") | ||
|
||
__all__ = [ | ||
"SCANVIDeep" | ||
] | ||
__all__ = ["SCANVIDeep", "utils", "plots"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy.typing as npt | ||
import pandas as pd | ||
import seaborn as sns | ||
from scvi import REGISTRY_KEYS | ||
|
||
from .scanvi_deep import SCANVIDeep | ||
|
||
|
||
def feature_plot( | ||
explainer: SCANVIDeep, | ||
shap_values: npt.NDArray, | ||
subset: bool = False, | ||
top_n: int = 10, | ||
) -> None: | ||
"""Prints feature contribution (absolute mean SHAP value) for each cell type (top 10). | ||
Parameters | ||
---------- | ||
explainer : SCANVIDeep | ||
SCANVIDeep explainer | ||
shap_values : npt.NDArray | ||
Expected SHAP values | ||
subset : bool, optional | ||
When set to true, calculate contribution by subsetting for test cells which belong to that | ||
particual classifier. | ||
When set to false, be generic and return contributing features even when testing set has | ||
different cell types. | ||
""" | ||
|
||
groupby = explainer.labels_key | ||
classes = explainer.adata.obs[groupby].cat.categories | ||
features = explainer.adata.var_names | ||
|
||
nrows = classes.size // 2 + classes.size % 2 | ||
fig, ax = plt.subplots(nrows, 2, sharex=False, figsize=[20, 40]) | ||
|
||
for idx, ct in enumerate(classes): | ||
|
||
shaps = pd.DataFrame(shap_values[idx], columns=features) | ||
|
||
if subset: | ||
shaps[groupby] = explainer.test[REGISTRY_KEYS.LABELS_KEY] | ||
shaps = shaps[shaps[groupby] == idx].iloc[:, :-1] | ||
|
||
tmp_avg = ( | ||
shaps.mean(axis=0) | ||
.sort_values(ascending=False) | ||
.reset_index() | ||
.rename(columns={"index": "feature", 0: "weight"}) | ||
) | ||
positive = ( | ||
tmp_avg.query("weight > 0") | ||
.head(top_n // 2) | ||
.assign(contribution="positive") | ||
) | ||
negative = ( | ||
tmp_avg.query("weight < 0") | ||
.tail(top_n // 2) | ||
.assign(contribution="negative") | ||
) | ||
|
||
avg = pd.concat([positive, negative]) | ||
title = f"Mean(SHAP value average importance for: {ct}" | ||
|
||
else: | ||
avg = ( | ||
shaps.abs() | ||
.mean(axis=0) | ||
.sort_values(ascending=False) | ||
.reset_index() | ||
.rename(columns={"index": "feature", 0: "weight"}) | ||
.query("weight > 0") | ||
.head(10) | ||
) | ||
title = f"Mean(|SHAP value|) average importance for: {ct}" | ||
|
||
sns.barplot( | ||
x="weight", | ||
y="feature", | ||
hue="contribution", | ||
palette=["red", "blue"], | ||
data=avg, | ||
ax=ax[idx // 2, idx % 2], | ||
) | ||
ax[idx // 2, idx % 2].set_title(title) | ||
ax[idx // 2, idx % 2].legend(title="Contribution", loc="lower right") | ||
fig.tight_layout() |
Oops, something went wrong.