Skip to content

Commit

Permalink
feat: new api changes
Browse files Browse the repository at this point in the history
  • Loading branch information
matq007 committed Sep 13, 2024
1 parent 6fc15e8 commit d47f318
Show file tree
Hide file tree
Showing 6 changed files with 441 additions and 146 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "scanvi-explainer"
version = "0.1.0"
version = "0.2.0"
description = "Exapliner of scANVI using SHAP"
authors = [
{name = "Martin Proks", email = "[email protected]"},
Expand All @@ -15,7 +15,8 @@ license = {file = "LICENSE"}
requires-python = ">=3.10"
keywords = ["shap", "scanvi", "explainer", "interpretability"]
dependencies = [
"ruff>=0.6.4",
"anndata",
"rich",
"shap>=0.41.0",
]
classifiers = [
Expand All @@ -35,7 +36,7 @@ Issues = "https://github.com/brickmanlab/scanvi-explainer/issues"
Changelog = "https://github.com/brickmanlab/scanvi-explainer/blob/master/CHANGELOG.md"

[project.optional-dependencies]
dev = ["ruff", "huggingface_hub"]
dev = ["ruff", "scvi-tools", "huggingface_hub"]
doc = [
"setuptools",
"sphinx",
Expand Down
16 changes: 6 additions & 10 deletions src/scanvi_explainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import importlib.util

from .scanvi_deep import SCANVIDeep

try:
import torch
except ImportError:
if not importlib.util.find_spec("torch"):
raise ImportError("Missing torch package! Run pip install torch")

try:
from scvi.model import SCANVI
except ImportError:
raise ImportError("Missing scvi-tools package! Run pip install scvi-tools")
if not importlib.util.find_spec("scvi"):
raise ImportError("Missing torch package! Run pip install torch")

__all__ = [
"SCANVIDeep"
]
__all__ = ["SCANVIDeep", "utils", "plots"]
88 changes: 88 additions & 0 deletions src/scanvi_explainer/plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import matplotlib.pyplot as plt
import numpy.typing as npt
import pandas as pd
import seaborn as sns
from scvi import REGISTRY_KEYS

from .scanvi_deep import SCANVIDeep


def feature_plot(
explainer: SCANVIDeep,
shap_values: npt.NDArray,
subset: bool = False,
top_n: int = 10,
) -> None:
"""Prints feature contribution (absolute mean SHAP value) for each cell type (top 10).
Parameters
----------
explainer : SCANVIDeep
SCANVIDeep explainer
shap_values : npt.NDArray
Expected SHAP values
subset : bool, optional
When set to true, calculate contribution by subsetting for test cells which belong to that
particual classifier.
When set to false, be generic and return contributing features even when testing set has
different cell types.
"""

groupby = explainer.labels_key
classes = explainer.adata.obs[groupby].cat.categories
features = explainer.adata.var_names

nrows = classes.size // 2 + classes.size % 2
fig, ax = plt.subplots(nrows, 2, sharex=False, figsize=[20, 40])

for idx, ct in enumerate(classes):

shaps = pd.DataFrame(shap_values[idx], columns=features)

if subset:
shaps[groupby] = explainer.test[REGISTRY_KEYS.LABELS_KEY]
shaps = shaps[shaps[groupby] == idx].iloc[:, :-1]

tmp_avg = (
shaps.mean(axis=0)
.sort_values(ascending=False)
.reset_index()
.rename(columns={"index": "feature", 0: "weight"})
)
positive = (
tmp_avg.query("weight > 0")
.head(top_n // 2)
.assign(contribution="positive")
)
negative = (
tmp_avg.query("weight < 0")
.tail(top_n // 2)
.assign(contribution="negative")
)

avg = pd.concat([positive, negative])
title = f"Mean(SHAP value average importance for: {ct}"

else:
avg = (
shaps.abs()
.mean(axis=0)
.sort_values(ascending=False)
.reset_index()
.rename(columns={"index": "feature", 0: "weight"})
.query("weight > 0")
.head(10)
)
title = f"Mean(|SHAP value|) average importance for: {ct}"

sns.barplot(
x="weight",
y="feature",
hue="contribution",
palette=["red", "blue"],
data=avg,
ax=ax[idx // 2, idx % 2],
)
ax[idx // 2, idx % 2].set_title(title)
ax[idx // 2, idx % 2].legend(title="Contribution", loc="lower right")
fig.tight_layout()
Loading

0 comments on commit d47f318

Please sign in to comment.