From 12660308cfe926103f922240a443b3342899ddf1 Mon Sep 17 00:00:00 2001 From: dschaub95 Date: Fri, 27 Sep 2024 17:57:04 +0200 Subject: [PATCH 1/8] set radius expected dtype to float for distance graph --- src/nichepca/graph_construction/_spatial_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nichepca/graph_construction/_spatial_graph.py b/src/nichepca/graph_construction/_spatial_graph.py index 1e02673..72f69fd 100644 --- a/src/nichepca/graph_construction/_spatial_graph.py +++ b/src/nichepca/graph_construction/_spatial_graph.py @@ -209,7 +209,7 @@ def knn_graph( def distance_graph( adata: AnnData, - radius: int = 50, + radius: float = 50, obsm_key: str = "spatial", remove_self_loops: bool = False, p: int = 2, @@ -223,7 +223,7 @@ def distance_graph( ---------- adata : AnnData Annotated data object. - radius : int, default 50 + radius : float, default 50 Radius for the distance threshold. obsm_key : str, default "spatial" Key in `obsm` attribute where the spatial data is stored. 
From 883c4bca3cac46d148f53ac17d2469d65fa0d3b7 Mon Sep 17 00:00:00 2001 From: dschaub95 Date: Sat, 28 Sep 2024 17:20:41 +0200 Subject: [PATCH 2/8] generate random cell type labels --- tests/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index aa00dbf..4a40c2c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,7 @@ from sklearn.cluster import KMeans -def generate_dummy_adata(n_cells=100, n_genes=50, n_samples=2, seed=0): +def generate_dummy_adata(n_cells=100, n_genes=50, n_samples=2, n_celltypes=5, seed=0): random_state = np.random.RandomState(seed) X = random_state.randint(0, 400, size=(n_cells, n_genes)) @@ -23,4 +23,8 @@ def generate_dummy_adata(n_cells=100, n_genes=50, n_samples=2, seed=0): samples = kmeans.fit_predict(coords) adata.obs["sample"] = [str(s) for s in samples] + # create artificial cell type column + adata.obs["cell_type"] = random_state.randint(0, n_celltypes, size=n_cells) + adata.obs["cell_type"] = adata.obs["cell_type"].astype(str).astype("category") + return adata From b8d5b0f775c9a31db11173ebab847d03b7632800 Mon Sep 17 00:00:00 2001 From: dschaub95 Date: Sat, 28 Sep 2024 17:21:11 +0200 Subject: [PATCH 3/8] refactor nichepca function and generalize functionality --- src/nichepca/workflows/__init__.py | 2 +- src/nichepca/workflows/_nichepca.py | 165 ++++++++++++++++++++-------- 2 files changed, 121 insertions(+), 46 deletions(-) diff --git a/src/nichepca/workflows/__init__.py b/src/nichepca/workflows/__init__.py index c80f759..f4e1b54 100644 --- a/src/nichepca/workflows/__init__.py +++ b/src/nichepca/workflows/__init__.py @@ -1 +1 @@ -from ._nichepca import run_nichepca +from ._nichepca import nichepca diff --git a/src/nichepca/workflows/_nichepca.py b/src/nichepca/workflows/_nichepca.py index b1a2a68..844ecd0 100644 --- a/src/nichepca/workflows/_nichepca.py +++ b/src/nichepca/workflows/_nichepca.py @@ -2,87 +2,162 @@ from typing import TYPE_CHECKING +import numpy as np 
+import pandas as pd import scanpy as sc from nichepca.graph_construction import ( construct_multi_sample_graph, - distance_graph, - knn_graph, + resolve_graph_constructor, ) from nichepca.nhood_embedding import aggregate -from nichepca.utils import check_for_raw_counts +from nichepca.utils import check_for_raw_counts, normalize_per_sample if TYPE_CHECKING: from anndata import AnnData -def run_nichepca( +def nichepca( adata: AnnData, - knn: int = None, - radius: float = None, - sample_key: str = None, + knn: int | None = None, + radius: float | None = None, + delaunay: bool = False, n_comps: int = 30, - max_iter_harmony: int = 50, + obs_key: str | None = None, + sample_key: str | None = None, + pipeline: tuple | list = ("norm", "log1p", "agg", "pca"), norm_per_sample: bool = True, + backend: str = "pyg", + aggr: str = "mean", + allow_harmony: bool = True, + max_iter_harmony: int = 50, **kwargs, ): """ - Run the NichePCA workflow. + Run the general NichePCA workflow. Parameters ---------- adata : AnnData - Annotated data object. - knn : int - Number of nearest neighbors for the kNN graph. - sample_key : str, optional - Key in `adata.obs` that identifies distinct samples. If provided, harmony will be used to - integrate the data. - radius : float, optional - The radius of the neighborhood graph. + The input AnnData object. + knn : int | None, optional + Number of nearest neighbors to use for graph construction. + radius : float | None, optional + Radius for graph construction. + delaunay : bool, optional + Whether to use Delaunay triangulation for graph construction. n_comps : int, optional Number of principal components to compute. - max_iter_harmony : int, optional - Maximum number of iterations for harmony. + obs_key : str | None, optional + Observation key to use for generating a new AnnData object. + sample_key : str | None, optional + Sample key to use for multi-sample graph construction. + pipeline : tuple | list, optional + Pipeline of steps to perform. 
Must include 'agg'. norm_per_sample : bool, optional - Whether to normalize the data per sample. - kwargs : dict, optional - Additional keyword arguments for the graph construction. + Whether to normalize per sample. + backend : str, optional + Backend to use for aggregation. + aggr : str, optional + Aggregation method to use. + allow_harmony : bool, optional + Whether to allow Harmony integration. + max_iter_harmony : int, optional + Maximum number of iterations for Harmony. + **kwargs : dict + Additional keyword arguments. Returns ------- None """ - check_for_raw_counts(adata) + # we always need to use agg + assert "agg" in pipeline, "aggregation must be part of the pipeline" + # assert that the pca is behind norm and log1p + if "pca" in pipeline: + pca_after_norm = np.argmax(np.array(pipeline) == "pca") > np.argmax( + np.array(pipeline) == "norm" + ) + pca_after_log1p = np.argmax(np.array(pipeline) == "pca") > np.argmax( + np.array(pipeline) == "log1p" + ) + assert ( + pca_after_norm and pca_after_log1p + ), "pca must be executed after norm and log1p" + + # perform sanity check in case we are normalizing the data + if "norm" or "log1p" in pipeline and obs_key is None: + check_for_raw_counts(adata) + + # extract any additional kwargs that are not directed to the graph construction + target_sum = kwargs.pop("target_sum", None) + # construct the (multi-sample) graph if sample_key is not None: construct_multi_sample_graph( - adata, sample_key=sample_key, knn=knn, radius=radius, **kwargs + adata, + sample_key=sample_key, + knn=knn, + radius=radius, + delaunay=delaunay, + **kwargs, ) else: - if knn is not None: - knn_graph(adata, knn, **kwargs) - elif radius is not None: - distance_graph(adata, radius, **kwargs) - else: - raise ValueError("Either knn or radius must be provided.") + resolve_graph_constructor(radius, knn, delaunay)(adata, **kwargs) - if norm_per_sample and sample_key is not None: - for sample in adata.obs[sample_key].unique(): - mask = 
adata.obs[sample_key] == sample - sub_ad = adata[mask].copy() - sc.pp.normalize_total(sub_ad) - adata[mask].X = sub_ad.X + # if an obs_key is provided generate a new AnnData + if obs_key is not None: + df = pd.get_dummies(adata.obs[obs_key], dtype=np.int8) + ad_tmp = sc.AnnData( + X=df.values, + obs=adata.obs, + var=pd.DataFrame(index=df.columns), + uns=adata.uns, + ) + # remove normalization steps + pipeline = [p for p in pipeline if p not in ["norm", "log1p"]] + print(f"obs_key provided, running pipeline: {'->'.join(pipeline)}") else: - sc.pp.normalize_total(adata) + ad_tmp = adata + print(f"Running pipeline: {'->'.join(pipeline)}") - sc.pp.log1p(adata) - - aggregate(adata) + for fn in pipeline: + if fn == "norm": + if norm_per_sample and sample_key is not None: + normalize_per_sample( + ad_tmp, sample_key=sample_key, target_sum=target_sum + ) + else: + sc.pp.normalize_total(ad_tmp, target_sum=target_sum) + elif fn == "log1p": + sc.pp.log1p(ad_tmp) + elif fn == "agg": + aggregate(ad_tmp, backend=backend, aggr=aggr) + elif fn == "pca": + sc.tl.pca(ad_tmp, n_comps=n_comps) + # run harmony if sample_key is provided and obs key is None + if sample_key is not None and obs_key is None and allow_harmony: + sc.external.pp.harmony_integrate( + ad_tmp, key=sample_key, max_iter_harmony=max_iter_harmony + ) + else: + raise ValueError(f"Unknown step in the pipeline: {fn}") - sc.tl.pca(adata, n_comps=n_comps) + # extract the results and remove old keys + if "X_pca_harmony" in ad_tmp.obsm: + X_npca = ad_tmp.obsm["X_pca_harmony"].copy() + del ad_tmp.obsm["X_pca_harmony"] + else: + X_npca = ad_tmp.obsm["X_pca"].copy() + del ad_tmp.obsm["X_pca"] - if sample_key is not None: - sc.external.pp.harmony_integrate( - adata, key=sample_key, max_iter_harmony=max_iter_harmony - ) + # store the results + adata.obsm["X_npca"] = X_npca + adata.uns["npca"] = ad_tmp.uns["pca"].copy() + adata.uns["npca"]["PCs"] = pd.DataFrame( + data=ad_tmp.varm["PCs"], + index=ad_tmp.var_names, + 
columns=[f"PC{i}" for i in range(n_comps)], + ) + del ad_tmp.varm["PCs"] + del ad_tmp.uns["pca"] From 8ff658a0839db020b737216fa524b35e36b96906 Mon Sep 17 00:00:00 2001 From: dschaub95 Date: Sat, 28 Sep 2024 17:21:47 +0200 Subject: [PATCH 4/8] extend nichepca tests --- tests/test_workflows.py | 64 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 57 insertions(+), 7 deletions(-) diff --git a/tests/test_workflows.py b/tests/test_workflows.py index bb769c4..d654b92 100644 --- a/tests/test_workflows.py +++ b/tests/test_workflows.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd import scanpy as sc from utils import generate_dummy_adata @@ -7,7 +8,7 @@ def test_nichepca_single(): adata_1 = generate_dummy_adata() - npc.wf.run_nichepca(adata_1, knn=10, n_comps=30) + npc.wf.nichepca(adata_1, knn=10, n_comps=30) adata_2 = generate_dummy_adata() sc.pp.normalize_total(adata_2) @@ -16,12 +17,61 @@ def test_nichepca_single(): npc.ne.aggregate(adata_2) sc.tl.pca(adata_2, n_comps=30) - assert np.all(adata_1.obsm["X_pca"] == adata_2.obsm["X_pca"]) + assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca"]) + # test with obs_key + obs_key = "cell_type" + n_celltypes = 5 + adata_1 = generate_dummy_adata(n_celltypes=n_celltypes) + npc.wf.nichepca(adata_1, knn=10, n_comps=n_celltypes - 1, obs_key=obs_key) -def test_nichepca_multi(): - adata = generate_dummy_adata() - npc.wf.run_nichepca(adata, knn=10, sample_key="sample") + adata_2 = generate_dummy_adata(n_celltypes=n_celltypes) + npc.gc.knn_graph(adata_2, knn=10) + df = pd.get_dummies(adata_2.obs[obs_key], dtype=np.int8) + ad_tmp = sc.AnnData( + X=df.values, + obs=adata_2.obs, + var=pd.DataFrame(index=df.columns), + uns=adata_2.uns, + ) + npc.ne.aggregate(ad_tmp) + sc.tl.pca(ad_tmp, n_comps=n_celltypes - 1) + + assert np.all(adata_1.obsm["X_npca"] == ad_tmp.obsm["X_pca"]) + + +def test_nichepca_multi_sample(): + adata_1 = generate_dummy_adata() + npc.wf.nichepca(adata_1, knn=10, n_comps=30, sample_key="sample") 
+ + adata_2 = generate_dummy_adata() + npc.gc.construct_multi_sample_graph(adata_2, knn=10, sample_key="sample") + npc.utils.normalize_per_sample(adata_2, sample_key="sample") + sc.pp.log1p(adata_2) + npc.ne.aggregate(adata_2) + sc.tl.pca(adata_2, n_comps=30) + sc.external.pp.harmony_integrate(adata_2, key="sample", max_iter_harmony=50) + + assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca_harmony"]) + + # test with obs_key + obs_key = "cell_type" + n_celltypes = 5 + adata_1 = generate_dummy_adata(n_celltypes=n_celltypes) + npc.wf.nichepca( + adata_1, knn=10, n_comps=n_celltypes - 1, obs_key=obs_key, sample_key="sample" + ) + + adata_2 = generate_dummy_adata(n_celltypes=n_celltypes) + npc.gc.construct_multi_sample_graph(adata_2, knn=10, sample_key="sample") + df = pd.get_dummies(adata_2.obs[obs_key], dtype=np.int8) + ad_tmp = sc.AnnData( + X=df.values, + obs=adata_2.obs, + var=pd.DataFrame(index=df.columns), + uns=adata_2.uns, + ) + npc.ne.aggregate(ad_tmp) + sc.tl.pca(ad_tmp, n_comps=n_celltypes - 1) - assert "X_pca" in adata.obsm.keys() - assert "X_pca_harmony" in adata.obsm.keys() + assert np.all(adata_1.obsm["X_npca"] == ad_tmp.obsm["X_pca"]) From b8c3b08227d2904b93c372b49d36a71d3b3cf09c Mon Sep 17 00:00:00 2001 From: dschaub95 Date: Sun, 29 Sep 2024 18:29:35 +0200 Subject: [PATCH 5/8] update readme --- README.md | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 3aa05c6..6c19b71 100644 --- a/README.md +++ b/README.md @@ -15,17 +15,36 @@ Package for PCA-based spatial domain identification in single-cell spatial trans - [API documentation][link-api]. 
--> -Given an AnnData object `adata`, you can run nichepca as follows: +Given an AnnData object `adata`, you can run nichepca starting from raw counts as follows: ```python import scanpy as sc import nichepca as npc -npc.wf.run_nichepca(adata, knn=5) -sc.pp.neighbors(adata) +npc.wf.nichepca(adata, knn=25) +sc.pp.neighbors(adata, use_rep="X_npca") sc.tl.leiden(adata, resolution=0.5) ``` +If you have multiple samples in `adata.obs['sample']`, you can provide the key `sample` to `npc.wf.nichepca`: + +```python +npc.wf.nichepca(adata, knn=25, sample_key="sample") +``` + +If you have cell type labels in `adata.obs['cell_type']`, you can directly provide them to `nichepca` as follows: + +```python +npc.wf.nichepca(adata, knn=25, obs_key='cell_type') +``` + +The `nichepca` function also allows to customize the original `("norm", "log1p", "agg", "pca")` pipeline, e.g., without median normalization: +```python +npc.wf.nichepca(adata, knn=25, pipeline=["log1p", "agg", "pca"]) +``` + +We found that higher number of neighbors e.g., `knn=25` lead to better results in brain tissue, while `knn=10` works well for kidney data. We recommend to qualitatively optimize these parameters on a small subset of your data. + ## Installation You need to have Python 3.10 or newer installed on your system. If you don't have From d86058cc6d2e9bd8fdc6536a6cbf7c4781acd930 Mon Sep 17 00:00:00 2001 From: dschaub95 Date: Sun, 29 Sep 2024 18:36:55 +0200 Subject: [PATCH 6/8] update readme --- README.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 6c19b71..05e9181 100644 --- a/README.md +++ b/README.md @@ -42,9 +42,16 @@ The `nichepca` function also allows to customize the original `("norm", "log1p" ```python npc.wf.nichepca(adata, knn=25, pipeline=["log1p", "agg", "pca"]) ``` - -We found that higher number of neighbors e.g., `knn=25` lead to better results in brain tissue, while `knn=10` works well for kidney data. 
We recommend to qualitatively optimize these parameters on a small subset of your data. - +or with `"pca"` before `"agg"`: +```python +npc.wf.nichepca(adata, knn=25, pipeline=["norm", "log1p", "pca", "agg"]) +``` +or without `"pca"` at all: +```python +npc.wf.nichepca(adata, knn=25, pipeline=["norm", "log1p", "agg"]) +``` +## Setting parameters +We found that a higher number of neighbors, e.g., `knn=25`, leads to better results in brain tissue, while `knn=10` works well for kidney data. We recommend qualitatively optimizing these parameters on a small subset of your data. The number of PCs (`n_comps=30` by default) seems to have a negligible effect on the results. ## Installation You need to have Python 3.10 or newer installed on your system. If you don't have From 052169e532af16d9d1bf71167ad08a692a42318e Mon Sep 17 00:00:00 2001 From: dschaub95 Date: Mon, 30 Sep 2024 11:04:10 +0200 Subject: [PATCH 7/8] fix bug when running pca before agg and add obsm_key option --- src/nichepca/workflows/_nichepca.py | 52 ++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/src/nichepca/workflows/_nichepca.py b/src/nichepca/workflows/_nichepca.py index 844ecd0..a9d92e7 100644 --- a/src/nichepca/workflows/_nichepca.py +++ b/src/nichepca/workflows/_nichepca.py @@ -24,6 +24,7 @@ def nichepca( delaunay: bool = False, n_comps: int = 30, obs_key: str | None = None, + obsm_key: str | None = None, sample_key: str | None = None, pipeline: tuple | list = ("norm", "log1p", "agg", "pca"), norm_per_sample: bool = True, @@ -50,6 +51,8 @@ def nichepca( Number of principal components to compute. obs_key : str | None, optional Observation key to use for generating a new AnnData object. + obsm_key : str | None, optional + Observation matrix key to use as input. sample_key : str | None, optional Sample key to use for multi-sample graph construction. 
pipeline : tuple | list, optional @@ -86,7 +89,7 @@ def nichepca( ), "pca must be executed after norm and log1p" # perform sanity check in case we are normalizing the data - if "norm" or "log1p" in pipeline and obs_key is None: + if "norm" or "log1p" in pipeline and obs_key is None and obsm_key is None: check_for_raw_counts(adata) # extract any additional kwargs that are not directed to the graph construction @@ -108,19 +111,27 @@ def nichepca( # if an obs_key is provided generate a new AnnData if obs_key is not None: df = pd.get_dummies(adata.obs[obs_key], dtype=np.int8) - ad_tmp = sc.AnnData( - X=df.values, - obs=adata.obs, - var=pd.DataFrame(index=df.columns), - uns=adata.uns, - ) + X = df.values + var = pd.DataFrame(index=df.columns) # remove normalization steps pipeline = [p for p in pipeline if p not in ["norm", "log1p"]] print(f"obs_key provided, running pipeline: {'->'.join(pipeline)}") + elif obsm_key is not None: + X = adata.obsm[obsm_key] + var = adata.var[[]] else: - ad_tmp = adata + X = adata.X + var = adata.var[[]] print(f"Running pipeline: {'->'.join(pipeline)}") + # create intermediate AnnData + ad_tmp = sc.AnnData( + X=X, + obs=adata.obs, + var=var, + uns=adata.uns, + ) + for fn in pipeline: if fn == "norm": if norm_per_sample and sample_key is not None: @@ -132,7 +143,20 @@ def nichepca( elif fn == "log1p": sc.pp.log1p(ad_tmp) elif fn == "agg": - aggregate(ad_tmp, backend=backend, aggr=aggr) + # if pca is executed before agg, we need to aggregate the pca results + if "X_pca_harmony" in ad_tmp.obsm: + obsm_key_agg = "X_pca_harmony" + elif "X_pca" in ad_tmp.obsm: + obsm_key_agg = "X_pca" + else: + obsm_key_agg = None + aggregate( + ad_tmp, + backend=backend, + aggr=aggr, + obsm_key=obsm_key_agg, + suffix="", + ) elif fn == "pca": sc.tl.pca(ad_tmp, n_comps=n_comps) # run harmony if sample_key is provided and obs key is None @@ -145,19 +169,15 @@ def nichepca( # extract the results and remove old keys if "X_pca_harmony" in ad_tmp.obsm: - X_npca = 
ad_tmp.obsm["X_pca_harmony"].copy() - del ad_tmp.obsm["X_pca_harmony"] + X_npca = ad_tmp.obsm["X_pca_harmony"] else: - X_npca = ad_tmp.obsm["X_pca"].copy() - del ad_tmp.obsm["X_pca"] + X_npca = ad_tmp.obsm["X_pca"] # store the results adata.obsm["X_npca"] = X_npca - adata.uns["npca"] = ad_tmp.uns["pca"].copy() + adata.uns["npca"] = ad_tmp.uns["pca"] adata.uns["npca"]["PCs"] = pd.DataFrame( data=ad_tmp.varm["PCs"], index=ad_tmp.var_names, columns=[f"PC{i}" for i in range(n_comps)], ) - del ad_tmp.varm["PCs"] - del ad_tmp.uns["pca"] From bef5eeadfe0fad21734ec7ebf0af2fd6e7b76fda Mon Sep 17 00:00:00 2001 From: dschaub95 Date: Mon, 30 Sep 2024 11:04:37 +0200 Subject: [PATCH 8/8] test pca before agg case --- tests/test_workflows.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_workflows.py b/tests/test_workflows.py index d654b92..645153b 100644 --- a/tests/test_workflows.py +++ b/tests/test_workflows.py @@ -39,6 +39,21 @@ def test_nichepca_single(): assert np.all(adata_1.obsm["X_npca"] == ad_tmp.obsm["X_pca"]) + # test with pca before agg + adata_1 = generate_dummy_adata() + npc.wf.nichepca( + adata_1, knn=10, n_comps=30, pipeline=["norm", "log1p", "pca", "agg"] + ) + + adata_2 = generate_dummy_adata() + npc.gc.knn_graph(adata_2, knn=10) + sc.pp.normalize_total(adata_2) + sc.pp.log1p(adata_2) + sc.pp.pca(adata_2, n_comps=30) + npc.ne.aggregate(adata_2, obsm_key="X_pca", suffix="") + + assert np.all(adata_1.obsm["X_npca"] == adata_2.obsm["X_pca"]) + def test_nichepca_multi_sample(): adata_1 = generate_dummy_adata()