diff --git a/src/nichepca/utils/__init__.py b/src/nichepca/utils/__init__.py index b2d3d0b..2f2a155 100644 --- a/src/nichepca/utils/__init__.py +++ b/src/nichepca/utils/__init__.py @@ -1 +1 @@ -from ._helper import check_for_raw_counts, to_numpy, to_torch +from ._helper import check_for_raw_counts, normalize_per_sample, to_numpy, to_torch diff --git a/src/nichepca/utils/_helper.py b/src/nichepca/utils/_helper.py index f1a63a5..7329874 100644 --- a/src/nichepca/utils/_helper.py +++ b/src/nichepca/utils/_helper.py @@ -4,6 +4,7 @@ from warnings import warn import numpy as np +import scanpy as sc import scipy.sparse as sp import torch @@ -81,3 +82,32 @@ def check_for_raw_counts(adata: AnnData): UserWarning, stacklevel=1, ) + + +def normalize_per_sample(adata, sample_key, **kwargs): + """ + Normalize the per-sample counts in the `adata` object based on the given `sample_key`. + + Parameters + ---------- + adata : AnnData + The annotated data object. + sample_key : str + The key in `adata.obs` that identifies distinct samples. + kwargs : dict, optional + Additional keyword arguments to be passed to `sc.pp.normalize_total`. + + Returns + ------- + None + """ + if kwargs.get("target_sum", None) is not None: + # if target sum is provided, samples make no difference + sc.pp.normalize_total(adata, **kwargs) + else: + adata.X = adata.X.astype(np.float32) + for sample in adata.obs[sample_key].unique(): + mask = adata.obs[sample_key] == sample + sub_ad = adata[mask].copy() + sc.pp.normalize_total(sub_ad, **kwargs) + adata.X[mask.values] = sub_ad.X diff --git a/tests/test_utils.py b/tests/test_utils.py index 93f8a7b..86b90e8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import numpy as np import pytest +import scanpy as sc import torch from utils import generate_dummy_adata @@ -49,3 +50,38 @@ def test_check_for_raw_counts(): # Check for the specific warning with pytest.warns(UserWarning): npc.utils.check_for_raw_counts(adata) + + +def test_normalize_per_sample(): + sample_key = "sample" + + target_sum = 1e4 + + adata_1 = generate_dummy_adata() + npc.utils.normalize_per_sample( + adata_1, target_sum=target_sum, sample_key=sample_key + ) + + adata_2 = generate_dummy_adata() + sc.pp.normalize_total(adata_2, target_sum=target_sum) + + assert np.all(adata_1.X.toarray() == adata_2.X.toarray()) + + # second test without fixed target sum + target_sum = None + + adata_1 = generate_dummy_adata() + npc.utils.normalize_per_sample( + adata_1, target_sum=target_sum, sample_key=sample_key + ) + + adata_2 = generate_dummy_adata() + adata_2.X = adata_2.X.astype(np.float32).toarray() + + for sample in adata_2.obs[sample_key].unique(): + mask = adata_2.obs[sample_key] == sample + sub_ad = adata_2[mask].copy() + sc.pp.normalize_total(sub_ad) + adata_2.X[mask.values] = sub_ad.X + + assert np.all(adata_1.X.astype(np.float32).toarray() == adata_2.X)