Skip to content

Commit

Permalink
feat: Support for minification in totalVI (#3061)
Browse files Browse the repository at this point in the history
Add support for minification in totalVI

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ori Kronfeld <[email protected]>
  • Loading branch information
3 people authored Dec 4, 2024
1 parent 2bb56ee commit 337ec87
Show file tree
Hide file tree
Showing 13 changed files with 750 additions and 100 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ to [Semantic Versioning]. Full commit history is available in the
- Refactored code for minified models. {pr}`2883`.
- Add {class}`scvi.external.METHYLVI` for modeling methylation data from single-cell
bisulfite sequencing (scBS-seq) experiments {pr}`2834`.
- Add MuData minification option to {class}`~scvi.model.TOTALVI` {pr}`3061`.

#### Fixed

Expand Down
4 changes: 3 additions & 1 deletion src/scvi/data/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,12 @@ def _get_adata_minify_type(adata: AnnData) -> MinifiedDataType | None:
return adata.uns.get(_constants._ADATA_MINIFY_TYPE_UNS_KEY, None)


def _is_minified(adata: AnnData | str) -> bool:
def _is_minified(adata: AnnOrMuData | str) -> bool:
uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY
if isinstance(adata, AnnData):
return adata.uns.get(uns_key, None) is not None
elif isinstance(adata, MuData):
return adata.uns.get(uns_key, None) is not None
elif isinstance(adata, str):
with h5py.File(adata) as fp:
return uns_key in read_elem(fp["uns"]).keys()
Expand Down
15 changes: 11 additions & 4 deletions src/scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass):

_module_cls = SCANVAE
_training_plan_cls = SemiSupervisedTrainingPlan
_LATENT_QZM = "scanvi_latent_qzm"
_LATENT_QZV = "scanvi_latent_qzv"
_LATENT_QZM_KEY = "scanvi_latent_qzm"
_LATENT_QZV_KEY = "scanvi_latent_qzv"

def __init__(
self,
Expand Down Expand Up @@ -132,7 +132,10 @@ def __init__(
n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
library_log_means, library_log_vars = None, None
if not use_size_factor_key and self.minified_data_type is None:
if (
not use_size_factor_key
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
):
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

self.module = self._module_cls(
Expand Down Expand Up @@ -237,6 +240,7 @@ def from_scvi_model(
cls.setup_anndata(
adata,
unlabeled_category=unlabeled_category,
use_minified=False,
**scvi_setup_args,
)
scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs)
Expand Down Expand Up @@ -443,6 +447,7 @@ def setup_anndata(
size_factor_key: str | None = None,
categorical_covariate_keys: list[str] | None = None,
continuous_covariate_keys: list[str] | None = None,
use_minified: bool = True,
**kwargs,
):
"""%(summary)s.
Expand All @@ -457,6 +462,8 @@ def setup_anndata(
%(param_size_factor_key)s
%(param_cat_cov_keys)s
%(param_cont_cov_keys)s
use_minified
If True, will register the minified version of the adata if possible.
"""
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
Expand All @@ -469,7 +476,7 @@ def setup_anndata(
]
# register new fields if the adata is minified
adata_minify_type = _get_adata_minify_type(adata)
if adata_minify_type is not None:
if adata_minify_type is not None and use_minified:
anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
Expand Down
10 changes: 7 additions & 3 deletions src/scvi/model/_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import ADATA_MINIFY_TYPE
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
CategoricalJointObsField,
Expand Down Expand Up @@ -100,8 +101,8 @@ class SCVI(
"""

_module_cls = VAE
_SCVI_LATENT_QZM = "scvi_latent_qzm"
_SCVI_LATENT_QZV = "scvi_latent_qzv"
_LATENT_QZM_KEY = "scvi_latent_qzm"
_LATENT_QZV_KEY = "scvi_latent_qzv"

def __init__(
self,
Expand Down Expand Up @@ -151,7 +152,10 @@ def __init__(
n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
library_log_means, library_log_vars = None, None
if not use_size_factor_key and self.minified_data_type is None:
if (
not use_size_factor_key
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
):
library_log_means, library_log_vars = _init_library_size(
self.adata_manager, n_batch
)
Expand Down
66 changes: 47 additions & 19 deletions src/scvi/model/_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager, fields
from scvi.data._utils import _check_nonnegative_integers
from scvi.data._constants import ADATA_MINIFY_TYPE
from scvi.data._utils import _check_nonnegative_integers, _get_adata_minify_type
from scvi.dataloaders import DataSplitter
from scvi.model._utils import (
_get_batch_code_from_category,
Expand All @@ -26,7 +27,13 @@
from scvi.train import AdversarialTrainingPlan, TrainRunner
from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp

from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin
from .base import (
ArchesMixin,
BaseMinifiedModeModelClass,
BaseMudataMinifiedModeModelClass,
RNASeqMixin,
VAEMixin,
)

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
Expand All @@ -35,18 +42,25 @@
from anndata import AnnData
from mudata import MuData

from scvi._types import Number
from scvi._types import AnnOrMuData, Number

logger = logging.getLogger(__name__)


class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
class TOTALVI(
RNASeqMixin,
VAEMixin,
ArchesMixin,
BaseMinifiedModeModelClass,
BaseMudataMinifiedModeModelClass,
):
"""total Variational Inference :cite:p:`GayosoSteier21`.
Parameters
----------
adata
AnnData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`.
AnnOrMuData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`
or :meth:`~scvi.model.TOTALVI.setup_mudata`.
n_latent
Dimensionality of the latent space.
gene_dispersion
Expand Down Expand Up @@ -84,13 +98,12 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
Examples
--------
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> scvi.model.TOTALVI.setup_anndata(
adata, batch_key="batch", protein_expression_obsm_key="protein_expression"
)
>>> vae = scvi.model.TOTALVI(adata)
>>> mdata = mudata.read_h5mu(path_to_mudata)
>>> scvi.model.TOTALVI.setup_mudata(
mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"}
)
>>> vae = scvi.model.TOTALVI(mdata)
>>> vae.train()
>>> adata.obsm["X_totalVI"] = vae.get_latent_representation()
>>> mdata.obsm["X_totalVI"] = vae.get_latent_representation()
Notes
-----
Expand All @@ -102,13 +115,15 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
"""

_module_cls = TOTALVAE
_LATENT_QZM_KEY = "totalvi_latent_qzm"
_LATENT_QZV_KEY = "totalvi_latent_qzv"
_data_splitter_cls = DataSplitter
_training_plan_cls = AdversarialTrainingPlan
_train_runner_cls = TrainRunner

def __init__(
self,
adata: AnnData,
adata: AnnOrMuData,
n_latent: int = 20,
gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein",
Expand All @@ -129,10 +144,10 @@ def __init__(
batch_mask = self.protein_state_registry.protein_batch_mask
msg = (
"Some proteins have all 0 counts in some batches. "
+ "These proteins will be treated as missing measurements; however, "
+ "this can occur due to experimental design/biology. "
+ "Reinitialize the model with `override_missing_proteins=True`,"
+ "to override this behavior."
"These proteins will be treated as missing measurements; however, "
"this can occur due to experimental design/biology. "
"Reinitialize the model with `override_missing_proteins=True`,"
"to override this behavior."
)
warnings.warn(msg, UserWarning, stacklevel=settings.warnings_stacklevel)
self._use_adversarial_classifier = True
Expand All @@ -145,7 +160,7 @@ def __init__(
if empirical_protein_background_prior is not None
else (self.summary_stats.n_proteins > 10)
)
if emp_prior:
if emp_prior and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
prior_mean, prior_scale = self._get_totalvi_protein_priors(adata)
else:
prior_mean, prior_scale = None, None
Expand All @@ -161,7 +176,10 @@ def __init__(
n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
library_log_means, library_log_vars = None, None
if not use_size_factor_key:
if (
not use_size_factor_key
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
):
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

self.module = self._module_cls(
Expand Down Expand Up @@ -1004,6 +1022,7 @@ def get_feature_correlation_matrix(
batch_size=batch_size,
rna_size_factor=rna_size_factor,
transform_batch=b,
indices=indices,
)
flattened = np.zeros((denoised_data.shape[0] * n_samples, denoised_data.shape[1]))
for i in range(n_samples):
Expand Down Expand Up @@ -1214,6 +1233,12 @@ def setup_anndata(
-------
%(returns)s
"""
warnings.warn(
"We recommend using setup_mudata for multi-modal data."
"It does not influence model performance",
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)
setup_method_args = cls._get_setup_method_args(**locals())
batch_field = fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
anndata_fields = [
Expand Down Expand Up @@ -1275,7 +1300,7 @@ def setup_mudata(
--------
>>> mdata = muon.read_10x_h5("pbmc_10k_protein_v3_filtered_feature_bc_matrix.h5")
>>> scvi.model.TOTALVI.setup_mudata(
mdata, modalities={"rna_layer": "rna": "protein_layer": "prot"}
mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"}
)
>>> vae = scvi.model.TOTALVI(mdata)
"""
Expand Down Expand Up @@ -1330,6 +1355,9 @@ def setup_mudata(
mod_required=True,
),
]
mdata_minify_type = _get_adata_minify_type(mdata)
if mdata_minify_type is not None:
mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type)
adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(mdata, **kwargs)
cls.register_manager(adata_manager)
3 changes: 3 additions & 0 deletions src/scvi/model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lightning.pytorch.trainer.connectors.accelerator_connector import (
_AcceleratorConnector,
)
from scipy.sparse import issparse

from scvi import REGISTRY_KEYS, settings
from scvi._types import Number
Expand Down Expand Up @@ -244,6 +245,8 @@ def cite_seq_raw_counts_properties(

nan = np.array([np.nan] * adata_manager.summary_stats.n_proteins)
protein_exp = adata_manager.get_from_registry(REGISTRY_KEYS.PROTEIN_EXP_KEY)
if issparse(protein_exp):
protein_exp = protein_exp.toarray()
mean1_pro = np.asarray(protein_exp[idx1].mean(0))
mean2_pro = np.asarray(protein_exp[idx2].mean(0))
nonz1_pro = np.asarray((protein_exp[idx1] > 0).mean(0))
Expand Down
7 changes: 6 additions & 1 deletion src/scvi/model/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from ._archesmixin import ArchesMixin
from ._base_model import BaseMinifiedModeModelClass, BaseModelClass
from ._base_model import (
BaseMinifiedModeModelClass,
BaseModelClass,
BaseMudataMinifiedModeModelClass,
)
from ._differential import DifferentialComputation
from ._embedding_mixin import EmbeddingMixin
from ._jaxmixin import JaxTrainingMixin
Expand All @@ -26,5 +30,6 @@
"DifferentialComputation",
"JaxTrainingMixin",
"BaseMinifiedModeModelClass",
"BaseMudataMinifiedModeModelClass",
"EmbeddingMixin",
]
Loading

0 comments on commit 337ec87

Please sign in to comment.