Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #3061 on branch 1.2.x (feat: Support for minification in totalVI) #3077

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@ to [Semantic Versioning]. Full commit history is available in the
- Refactored code for minified models. {pr}`2883`.
- Add {class}`scvi.external.METHYLVI` for modeling methylation data from single-cell
bisulfite sequencing (scBS-seq) experiments {pr}`2834`.
- Add MuData Minification option to {class}`~scvi.model.TOTALVI` {pr}`3061`.

#### Fixed

4 changes: 3 additions & 1 deletion src/scvi/data/_utils.py
Original file line number Diff line number Diff line change
@@ -311,10 +311,12 @@ def _get_adata_minify_type(adata: AnnData) -> MinifiedDataType | None:
return adata.uns.get(_constants._ADATA_MINIFY_TYPE_UNS_KEY, None)


def _is_minified(adata: AnnData | str) -> bool:
def _is_minified(adata: AnnOrMuData | str) -> bool:
uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY
if isinstance(adata, AnnData):
return adata.uns.get(uns_key, None) is not None
elif isinstance(adata, MuData):
return adata.uns.get(uns_key, None) is not None
elif isinstance(adata, str):
with h5py.File(adata) as fp:
return uns_key in read_elem(fp["uns"]).keys()
15 changes: 11 additions & 4 deletions src/scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
@@ -101,8 +101,8 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass):

_module_cls = SCANVAE
_training_plan_cls = SemiSupervisedTrainingPlan
_LATENT_QZM = "scanvi_latent_qzm"
_LATENT_QZV = "scanvi_latent_qzv"
_LATENT_QZM_KEY = "scanvi_latent_qzm"
_LATENT_QZV_KEY = "scanvi_latent_qzv"

def __init__(
self,
@@ -132,7 +132,10 @@ def __init__(
n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
library_log_means, library_log_vars = None, None
if not use_size_factor_key and self.minified_data_type is None:
if (
not use_size_factor_key
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
):
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

self.module = self._module_cls(
@@ -237,6 +240,7 @@ def from_scvi_model(
cls.setup_anndata(
adata,
unlabeled_category=unlabeled_category,
use_minified=False,
**scvi_setup_args,
)
scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs)
@@ -443,6 +447,7 @@ def setup_anndata(
size_factor_key: str | None = None,
categorical_covariate_keys: list[str] | None = None,
continuous_covariate_keys: list[str] | None = None,
use_minified: bool = True,
**kwargs,
):
"""%(summary)s.
@@ -457,6 +462,8 @@ def setup_anndata(
%(param_size_factor_key)s
%(param_cat_cov_keys)s
%(param_cont_cov_keys)s
use_minified
If True, will register the minified version of the adata if possible.
"""
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
@@ -469,7 +476,7 @@ def setup_anndata(
]
# register new fields if the adata is minified
adata_minify_type = _get_adata_minify_type(adata)
if adata_minify_type is not None:
if adata_minify_type is not None and use_minified:
anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
10 changes: 7 additions & 3 deletions src/scvi/model/_scvi.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import ADATA_MINIFY_TYPE
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
CategoricalJointObsField,
@@ -100,8 +101,8 @@ class SCVI(
"""

_module_cls = VAE
_SCVI_LATENT_QZM = "scvi_latent_qzm"
_SCVI_LATENT_QZV = "scvi_latent_qzv"
_LATENT_QZM_KEY = "scvi_latent_qzm"
_LATENT_QZV_KEY = "scvi_latent_qzv"

def __init__(
self,
@@ -151,7 +152,10 @@ def __init__(
n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
library_log_means, library_log_vars = None, None
if not use_size_factor_key and self.minified_data_type is None:
if (
not use_size_factor_key
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
):
library_log_means, library_log_vars = _init_library_size(
self.adata_manager, n_batch
)
66 changes: 47 additions & 19 deletions src/scvi/model/_totalvi.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,8 @@

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager, fields
from scvi.data._utils import _check_nonnegative_integers
from scvi.data._constants import ADATA_MINIFY_TYPE
from scvi.data._utils import _check_nonnegative_integers, _get_adata_minify_type
from scvi.dataloaders import DataSplitter
from scvi.model._utils import (
_get_batch_code_from_category,
@@ -26,7 +27,13 @@
from scvi.train import AdversarialTrainingPlan, TrainRunner
from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp

from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin
from .base import (
ArchesMixin,
BaseMinifiedModeModelClass,
BaseMudataMinifiedModeModelClass,
RNASeqMixin,
VAEMixin,
)

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
@@ -35,18 +42,25 @@
from anndata import AnnData
from mudata import MuData

from scvi._types import Number
from scvi._types import AnnOrMuData, Number

logger = logging.getLogger(__name__)


class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
class TOTALVI(
RNASeqMixin,
VAEMixin,
ArchesMixin,
BaseMinifiedModeModelClass,
BaseMudataMinifiedModeModelClass,
):
"""total Variational Inference :cite:p:`GayosoSteier21`.

Parameters
----------
adata
AnnData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`.
AnnOrMuData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`
or :meth:`~scvi.model.TOTALVI.setup_mudata`.
n_latent
Dimensionality of the latent space.
gene_dispersion
@@ -84,13 +98,12 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):

Examples
--------
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> scvi.model.TOTALVI.setup_anndata(
adata, batch_key="batch", protein_expression_obsm_key="protein_expression"
)
>>> vae = scvi.model.TOTALVI(adata)
>>> mdata = mudata.read_h5mu(path_to_mudata)
>>> scvi.model.TOTALVI.setup_mudata(
mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"}
>>> vae = scvi.model.TOTALVI(mdata)
>>> vae.train()
>>> adata.obsm["X_totalVI"] = vae.get_latent_representation()
>>> mdata.obsm["X_totalVI"] = vae.get_latent_representation()

Notes
-----
@@ -102,13 +115,15 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
"""

_module_cls = TOTALVAE
_LATENT_QZM_KEY = "totalvi_latent_qzm"
_LATENT_QZV_KEY = "totalvi_latent_qzv"
_data_splitter_cls = DataSplitter
_training_plan_cls = AdversarialTrainingPlan
_train_runner_cls = TrainRunner

def __init__(
self,
adata: AnnData,
adata: AnnOrMuData,
n_latent: int = 20,
gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein",
@@ -129,10 +144,10 @@ def __init__(
batch_mask = self.protein_state_registry.protein_batch_mask
msg = (
"Some proteins have all 0 counts in some batches. "
+ "These proteins will be treated as missing measurements; however, "
+ "this can occur due to experimental design/biology. "
+ "Reinitialize the model with `override_missing_proteins=True`,"
+ "to override this behavior."
"These proteins will be treated as missing measurements; however, "
"this can occur due to experimental design/biology. "
"Reinitialize the model with `override_missing_proteins=True`,"
"to override this behavior."
)
warnings.warn(msg, UserWarning, stacklevel=settings.warnings_stacklevel)
self._use_adversarial_classifier = True
@@ -145,7 +160,7 @@ def __init__(
if empirical_protein_background_prior is not None
else (self.summary_stats.n_proteins > 10)
)
if emp_prior:
if emp_prior and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
prior_mean, prior_scale = self._get_totalvi_protein_priors(adata)
else:
prior_mean, prior_scale = None, None
@@ -161,7 +176,10 @@ def __init__(
n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
library_log_means, library_log_vars = None, None
if not use_size_factor_key:
if (
not use_size_factor_key
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
):
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

self.module = self._module_cls(
@@ -1004,6 +1022,7 @@ def get_feature_correlation_matrix(
batch_size=batch_size,
rna_size_factor=rna_size_factor,
transform_batch=b,
indices=indices,
)
flattened = np.zeros((denoised_data.shape[0] * n_samples, denoised_data.shape[1]))
for i in range(n_samples):
@@ -1214,6 +1233,12 @@ def setup_anndata(
-------
%(returns)s
"""
warnings.warn(
"We recommend using setup_mudata for multi-modal data."
"It does not influence model performance",
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)
setup_method_args = cls._get_setup_method_args(**locals())
batch_field = fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
anndata_fields = [
@@ -1275,7 +1300,7 @@ def setup_mudata(
--------
>>> mdata = muon.read_10x_h5("pbmc_10k_protein_v3_filtered_feature_bc_matrix.h5")
>>> scvi.model.TOTALVI.setup_mudata(
mdata, modalities={"rna_layer": "rna": "protein_layer": "prot"}
mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"}
)
>>> vae = scvi.model.TOTALVI(mdata)
"""
@@ -1330,6 +1355,9 @@ def setup_mudata(
mod_required=True,
),
]
mdata_minify_type = _get_adata_minify_type(mdata)
if mdata_minify_type is not None:
mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type)
adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(mdata, **kwargs)
cls.register_manager(adata_manager)
3 changes: 3 additions & 0 deletions src/scvi/model/_utils.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
from lightning.pytorch.trainer.connectors.accelerator_connector import (
_AcceleratorConnector,
)
from scipy.sparse import issparse

from scvi import REGISTRY_KEYS, settings
from scvi._types import Number
@@ -244,6 +245,8 @@ def cite_seq_raw_counts_properties(

nan = np.array([np.nan] * adata_manager.summary_stats.n_proteins)
protein_exp = adata_manager.get_from_registry(REGISTRY_KEYS.PROTEIN_EXP_KEY)
if issparse(protein_exp):
protein_exp = protein_exp.toarray()
mean1_pro = np.asarray(protein_exp[idx1].mean(0))
mean2_pro = np.asarray(protein_exp[idx2].mean(0))
nonz1_pro = np.asarray((protein_exp[idx1] > 0).mean(0))
7 changes: 6 additions & 1 deletion src/scvi/model/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from ._archesmixin import ArchesMixin
from ._base_model import BaseMinifiedModeModelClass, BaseModelClass
from ._base_model import (
BaseMinifiedModeModelClass,
BaseModelClass,
BaseMudataMinifiedModeModelClass,
)
from ._differential import DifferentialComputation
from ._embedding_mixin import EmbeddingMixin
from ._jaxmixin import JaxTrainingMixin
@@ -26,5 +30,6 @@
"DifferentialComputation",
"JaxTrainingMixin",
"BaseMinifiedModeModelClass",
"BaseMudataMinifiedModeModelClass",
"EmbeddingMixin",
]
Loading