Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/v3.0.1 #5

Merged
merged 3 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## Version 0.3.1

### Added

- `SCANVIBoostrapper.save` for saving bootstrap results into feather format [0ebd757]
- New parameter `gene_symbols` for `plots.feature_plot` and `SCANVIBoostrapper.feature_plot` [81349a7]

### Changed

- Empty subplots are removed from the figure [81349a7]

## Version 0.3.0

### Added
Expand Down
355 changes: 332 additions & 23 deletions docs/notebooks/Example.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "scanvi-explainer"
version = "0.3.0"
version = "0.3.1"
description = "Exapliner of scANVI using SHAP"
authors = [
{name = "Martin Proks", email = "[email protected]"},
Expand All @@ -17,6 +17,7 @@ keywords = ["shap", "scanvi", "explainer", "interpretability"]
dependencies = [
"anndata",
"pandas<=2.0", # required by scvi-tool for now
"pyarrow",
"rich",
"seaborn",
"scvi-tools",
Expand Down
55 changes: 45 additions & 10 deletions src/scanvi_explainer/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from scvi import REGISTRY_KEYS

from .scanvi_deep import SCANVIDeep
Expand All @@ -12,7 +13,11 @@ def feature_plot(
shap_values: np.ndarray,
subset: bool = False,
top_n: int = 10,
) -> None:
gene_symbols: None | str = None,
n_cols: int = 2,
figsize: tuple[int, int] = (20, 20),
return_fig: bool = False,
) -> Figure | None:
"""Prints feature contribution (absolute mean SHAP value) for each cell type (top 10).

Parameters
Expand All @@ -26,14 +31,33 @@ def feature_plot(
particual classifier.
When set to false, be generic and return contributing features even when testing set has
different cell types.
top_n: int
Subset for top N number of features
gene_symbols: None | str = None
Column name in `var` for gene symbols
n_cols: int
Number of columns in Figure
figsize : tuple[int, int]
Figure size, by default [20, 20]
return_fig : bool
Flag to return figure object, by default False
"""

if gene_symbols and gene_symbols not in explainer.adata.var.columns:
raise ValueError(
"Specified gene_symbol not present in the 'var' of model's adata!"
)

groupby = explainer.labels_key
classes = explainer.adata.obs[groupby].cat.categories
features = explainer.adata.var_names
features = (
explainer.adata.var[gene_symbols].values
if gene_symbols
else explainer.adata.var_names
)

nrows = classes.size // 2 + classes.size % 2
fig, ax = plt.subplots(nrows, 2, sharex=False, figsize=[20, 40])
nrows = round(classes.size / n_cols)
fig, ax = plt.subplots(nrows, n_cols, sharex=False, figsize=figsize)

for idx, ct in enumerate(classes):
shaps = pd.DataFrame(shap_values[idx], columns=features)
Expand All @@ -60,7 +84,7 @@ def feature_plot(
)

avg = pd.concat([positive, negative])
title = f"Average SHAP value importance for: {ct}"
title = f"Mean(|SHAP value|) importance for: {ct}"

else:
avg = (
Expand All @@ -72,16 +96,27 @@ def feature_plot(
.query("weight > 0")
.head(10)
)
title = f"Mean(|SHAP value|) average importance for: {ct}"
title = f"Mean(|SHAP value|) importance for: {ct}"

sns.barplot(
x="weight",
y="feature",
hue="contribution",
palette=["red", "blue"],
data=avg,
ax=ax[idx // 2, idx % 2],
ax=ax[idx // n_cols, idx % n_cols],
)
ax[idx // 2, idx % 2].set_title(title)
ax[idx // 2, idx % 2].legend(title="Contribution", loc="lower right")
fig.tight_layout()

ax[idx // n_cols, idx % n_cols].set_title(title)
ax[idx // n_cols, idx % n_cols].legend(title="Contribution", loc="lower right")

# clean axes which are empty
# from: https://stackoverflow.com/a/76269136
_ = [fig.delaxes(ax_) for ax_ in ax.flatten() if not ax_.has_data()]

fig.tight_layout()

if return_fig:
return fig

return None
63 changes: 59 additions & 4 deletions src/scanvi_explainer/scanvi_bootstrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def estimate(

@staticmethod
def _filter(
shap_values: list, features: list[str], top_n: int = 10
shap_values: list, features: list[str], top_n: None | int = 10
) -> pd.DataFrame:
"""Helper function for filtering top positive only SHAP features.

Expand All @@ -106,7 +106,8 @@ def _filter(
features : list[str]
Features (genes)
top_n : int
Number of top features to subset, by default 10
Number of top features to subset, by default 10. If None is specified, the filter does
not apply.

Returns
-------
Expand All @@ -125,6 +126,7 @@ def feature_plot(
n_features: int = 10,
metric_fn: Callable[..., ArrayLike] = np.mean,
kind: Literal["boxplot", "barplot"] = "boxplot",
gene_symbols: None | str = None,
n_cols: int = 3,
figsize: tuple[int, int] = (20, 20),
return_fig: bool = False,
Expand All @@ -141,6 +143,8 @@ def feature_plot(
Statistical measurement of each boostrap, by default np.mean
kind : Literal[&quot;boxplot&quot;, &quot;barplot&quot;]
Type of plot, by default "boxplot"
gene_symbols: None | str = None
Column name in `var` for gene symbols
n_cols : int
Number of columns for subplots, by default 3
figsize : tuple[int, int]
Expand All @@ -162,11 +166,20 @@ def feature_plot(
if kind not in ["boxplot", "barplot"]:
raise ValueError(f"Specified {kind} not supported!")

features = self.model.adata.var_names
if gene_symbols and gene_symbols not in self.model.adata.var.columns:
raise ValueError(
"Specified gene_symbol not present in the 'var' of model's adata!"
)

features = (
self.model.adata.var[gene_symbols]
if gene_symbols
else self.model.adata.var_names
)
sample_stat = self.estimate(metric_fn, shap_values)
labels = self.model.adata.obs[get_labels_key(self.model)].cat.categories

n_rows = labels.size // n_cols + labels.size % n_cols
n_rows = round(labels.size / n_cols)
fig, ax = plt.subplots(n_rows, n_cols, figsize=figsize)
for idx, label in enumerate(labels):
data = self._filter(sample_stat[idx], features, n_features).T
Expand All @@ -193,6 +206,10 @@ def feature_plot(
data=data,
).set(title=label)

# clean axes which are empty
# from: https://stackoverflow.com/a/76269136
_ = [fig.delaxes(ax_) for ax_ in ax.flatten() if not ax_.has_data()]

fig.suptitle(
f"Top {n_features} SHAP values based on boostrapping (n={self.NUM_OF_BOOTSRAPS})"
)
Expand All @@ -204,3 +221,41 @@ def feature_plot(
return fig

return None

def save(self, shap_values: list[np.ndarray], filename: str):
"""Save results to feather format.

Parameters
----------
shap_values : list[np.ndarray]
SHAP values
filename : str
Path to filename containing .feather extension

Examples
--------
>>> lvae = scvi.model.SCANVI.load("...")
>>> bootstrapper = SCANVIBoostrapper(lvae, n_bootstraps=10)
>>> shap_values = bootstrapper.run(train_size=0.8, batch_size=64)
>>> bootstrapper.save(shap_values, "./bootstrapped_shaps.feather")
"""

features = self.model.adata.var_names
sample_stat = self.estimate(np.mean, shap_values)
labels = self.model.adata.obs[get_labels_key(self.model)].cat.categories

res = []
for idx, label in enumerate(labels):
data = self._filter(sample_stat[idx], features, top_n=None)
data.columns = [f"n_{c}" for c in data.columns]
data["label"] = label

res.append(data)

if not filename.endswith(".feather"):
filename += ".feather"

try:
pd.concat(res).reset_index().to_feather(filename)
except IOError:
raise