diff --git a/CHANGELOG.md b/CHANGELOG.md index 551da1c88c..609e9cb74c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ to [Semantic Versioning]. Full commit history is available in the - Refactored code for minified models. {pr}`2883`. - Add {class}`scvi.external.METHYLVI` for modeling methylation data from single-cell bisulfite sequencing (scBS-seq) experiments {pr}`2834`. +- Add MuData Minification option to {class}`~scvi.model.TOTALVI` {pr}`3061`. #### Fixed diff --git a/src/scvi/data/_utils.py b/src/scvi/data/_utils.py index fc6228a29f..3fff82a13b 100644 --- a/src/scvi/data/_utils.py +++ b/src/scvi/data/_utils.py @@ -311,10 +311,12 @@ def _get_adata_minify_type(adata: AnnData) -> MinifiedDataType | None: return adata.uns.get(_constants._ADATA_MINIFY_TYPE_UNS_KEY, None) -def _is_minified(adata: AnnData | str) -> bool: +def _is_minified(adata: AnnOrMuData | str) -> bool: uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY if isinstance(adata, AnnData): return adata.uns.get(uns_key, None) is not None + elif isinstance(adata, MuData): + return adata.uns.get(uns_key, None) is not None elif isinstance(adata, str): with h5py.File(adata) as fp: return uns_key in read_elem(fp["uns"]).keys() diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 084d83be1f..87a14dbd54 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -101,8 +101,8 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): _module_cls = SCANVAE _training_plan_cls = SemiSupervisedTrainingPlan - _LATENT_QZM = "scanvi_latent_qzm" - _LATENT_QZV = "scanvi_latent_qzv" + _LATENT_QZM_KEY = "scanvi_latent_qzm" + _LATENT_QZV_KEY = "scanvi_latent_qzv" def __init__( self, @@ -132,7 +132,10 @@ def __init__( n_batch = self.summary_stats.n_batch use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry library_log_means, library_log_vars = None, None - if not use_size_factor_key and self.minified_data_type is None: + if ( + not use_size_factor_key + and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR + ): library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self.module = self._module_cls( @@ -237,6 +240,7 @@ def from_scvi_model( cls.setup_anndata( adata, unlabeled_category=unlabeled_category, + use_minified=False, **scvi_setup_args, ) scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs) @@ -443,6 +447,7 @@ def setup_anndata( size_factor_key: str | None = None, categorical_covariate_keys: list[str] | None = None, continuous_covariate_keys: list[str] | None = None, + use_minified: bool = True, **kwargs, ): """%(summary)s. @@ -457,6 +462,8 @@ def setup_anndata( %(param_size_factor_key)s %(param_cat_cov_keys)s %(param_cont_cov_keys)s + use_minified + If True, will register the minified version of the adata if possible. 
""" setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ @@ -469,7 +476,7 @@ def setup_anndata( ] # register new fields if the adata is minified adata_minify_type = _get_adata_minify_type(adata) - if adata_minify_type is not None: + if adata_minify_type is not None and use_minified: anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 36ad6f2c39..1eb23aa138 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -6,6 +6,7 @@ from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager +from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import ( CategoricalJointObsField, @@ -100,8 +101,8 @@ class SCVI( """ _module_cls = VAE - _SCVI_LATENT_QZM = "scvi_latent_qzm" - _SCVI_LATENT_QZV = "scvi_latent_qzv" + _LATENT_QZM_KEY = "scvi_latent_qzm" + _LATENT_QZV_KEY = "scvi_latent_qzv" def __init__( self, @@ -151,7 +152,10 @@ def __init__( n_batch = self.summary_stats.n_batch use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry library_log_means, library_log_vars = None, None - if not use_size_factor_key and self.minified_data_type is None: + if ( + not use_size_factor_key + and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR + ): library_log_means, library_log_vars = _init_library_size( self.adata_manager, n_batch ) diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index a50c56e3ee..5a62e8c8f3 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -12,7 +12,8 @@ from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager, fields -from scvi.data._utils import _check_nonnegative_integers +from scvi.data._constants import ADATA_MINIFY_TYPE +from scvi.data._utils import _check_nonnegative_integers, _get_adata_minify_type from scvi.dataloaders import DataSplitter from scvi.model._utils import ( _get_batch_code_from_category, @@ -26,7 +27,13 @@ from scvi.train import AdversarialTrainingPlan, TrainRunner from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp -from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin +from .base import ( + ArchesMixin, + BaseMinifiedModeModelClass, + BaseMudataMinifiedModeModelClass, + RNASeqMixin, + VAEMixin, +) if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -35,18 +42,25 @@ from anndata import AnnData from mudata import MuData - from scvi._types import Number + from scvi._types import AnnOrMuData, Number logger = logging.getLogger(__name__) -class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass): +class TOTALVI( + RNASeqMixin, + VAEMixin, + ArchesMixin, + BaseMinifiedModeModelClass, + BaseMudataMinifiedModeModelClass, +): """total Variational Inference :cite:p:`GayosoSteier21`. Parameters ---------- adata - AnnData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`. + AnnOrMuData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata` + or :meth:`~scvi.model.TOTALVI.setup_mudata`. n_latent Dimensionality of the latent space. 
    gene_dispersion
@@ -84,13 +98,12 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):

    Examples
    --------
-    >>> adata = anndata.read_h5ad(path_to_anndata)
-    >>> scvi.model.TOTALVI.setup_anndata(
-            adata, batch_key="batch", protein_expression_obsm_key="protein_expression"
-        )
-    >>> vae = scvi.model.TOTALVI(adata)
+    >>> mdata = mudata.read_h5mu(path_to_mudata)
+    >>> scvi.model.TOTALVI.setup_mudata(
+            mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"}
+        )
+    >>> vae = scvi.model.TOTALVI(mdata)
    >>> vae.train()
-    >>> adata.obsm["X_totalVI"] = vae.get_latent_representation()
+    >>> mdata.obsm["X_totalVI"] = vae.get_latent_representation()

    Notes
    -----
@@ -102,13 +115,15 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
    """

    _module_cls = TOTALVAE
+    _LATENT_QZM_KEY = "totalvi_latent_qzm"
+    _LATENT_QZV_KEY = "totalvi_latent_qzv"
    _data_splitter_cls = DataSplitter
    _training_plan_cls = AdversarialTrainingPlan
    _train_runner_cls = TrainRunner

    def __init__(
        self,
-        adata: AnnData,
+        adata: AnnOrMuData,
        n_latent: int = 20,
        gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
        protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein",
@@ -129,10 +144,10 @@ def __init__(
            batch_mask = self.protein_state_registry.protein_batch_mask
            msg = (
                "Some proteins have all 0 counts in some batches. "
-                + "These proteins will be treated as missing measurements; however, "
-                + "this can occur due to experimental design/biology. "
-                + "Reinitialize the model with `override_missing_proteins=True`,"
-                + "to override this behavior."
+                "These proteins will be treated as missing measurements; however, "
+                "this can occur due to experimental design/biology. "
+                "Reinitialize the model with `override_missing_proteins=True` "
+                "to override this behavior."
            )
            warnings.warn(msg, UserWarning, stacklevel=settings.warnings_stacklevel)
            self._use_adversarial_classifier = True
@@ -145,7 +160,7 @@ def __init__(
            if empirical_protein_background_prior is not None
            else (self.summary_stats.n_proteins > 10)
        )
-        if emp_prior:
+        if emp_prior and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
            prior_mean, prior_scale = self._get_totalvi_protein_priors(adata)
        else:
            prior_mean, prior_scale = None, None
@@ -161,7 +176,10 @@ def __init__(
        n_batch = self.summary_stats.n_batch
        use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
        library_log_means, library_log_vars = None, None
-        if not use_size_factor_key:
+        if (
+            not use_size_factor_key
+            and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
+        ):
            library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

        self.module = self._module_cls(
@@ -1004,6 +1022,7 @@ def get_feature_correlation_matrix(
                batch_size=batch_size,
                rna_size_factor=rna_size_factor,
                transform_batch=b,
+                indices=indices,
            )
            flattened = np.zeros((denoised_data.shape[0] * n_samples, denoised_data.shape[1]))
            for i in range(n_samples):
@@ -1214,6 +1233,12 @@ def setup_anndata(
        -------
        %(returns)s
        """
+        warnings.warn(
+            "We recommend using setup_mudata for multi-modal data."
+ "It does not influence model performance", + DeprecationWarning, + stacklevel=settings.warnings_stacklevel, + ) setup_method_args = cls._get_setup_method_args(**locals()) batch_field = fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key) anndata_fields = [ @@ -1275,7 +1300,7 @@ def setup_mudata( -------- >>> mdata = muon.read_10x_h5("pbmc_10k_protein_v3_filtered_feature_bc_matrix.h5") >>> scvi.model.TOTALVI.setup_mudata( - mdata, modalities={"rna_layer": "rna": "protein_layer": "prot"} + mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"} ) >>> vae = scvi.model.TOTALVI(mdata) """ @@ -1330,6 +1355,9 @@ def setup_mudata( mod_required=True, ), ] + mdata_minify_type = _get_adata_minify_type(mdata) + if mdata_minify_type is not None: + mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type) adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(mdata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/scvi/model/_utils.py b/src/scvi/model/_utils.py index 0759253006..77093b5323 100644 --- a/src/scvi/model/_utils.py +++ b/src/scvi/model/_utils.py @@ -12,6 +12,7 @@ from lightning.pytorch.trainer.connectors.accelerator_connector import ( _AcceleratorConnector, ) +from scipy.sparse import issparse from scvi import REGISTRY_KEYS, settings from scvi._types import Number @@ -244,6 +245,8 @@ def cite_seq_raw_counts_properties( nan = np.array([np.nan] * adata_manager.summary_stats.n_proteins) protein_exp = adata_manager.get_from_registry(REGISTRY_KEYS.PROTEIN_EXP_KEY) + if issparse(protein_exp): + protein_exp = protein_exp.toarray() mean1_pro = np.asarray(protein_exp[idx1].mean(0)) mean2_pro = np.asarray(protein_exp[idx2].mean(0)) nonz1_pro = np.asarray((protein_exp[idx1] > 0).mean(0)) diff --git a/src/scvi/model/base/__init__.py b/src/scvi/model/base/__init__.py index e8573f8d53..4b38494caf 100644 --- a/src/scvi/model/base/__init__.py +++ b/src/scvi/model/base/__init__.py @@ -1,5 +1,9 @@ from ._archesmixin import ArchesMixin -from ._base_model import BaseMinifiedModeModelClass, BaseModelClass +from ._base_model import ( + BaseMinifiedModeModelClass, + BaseModelClass, + BaseMudataMinifiedModeModelClass, +) from ._differential import DifferentialComputation from ._embedding_mixin import EmbeddingMixin from ._jaxmixin import JaxTrainingMixin @@ -26,5 +30,6 @@ "DifferentialComputation", "JaxTrainingMixin", "BaseMinifiedModeModelClass", + "BaseMudataMinifiedModeModelClass", "EmbeddingMixin", ] diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 64b0594e29..aaafca9dcc 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -35,7 +35,7 @@ _load_saved_files, _validate_var_names, ) -from scvi.model.utils import get_minified_adata_scrna +from scvi.model.utils import get_minified_adata_scrna, get_minified_mudata from scvi.utils import attrdict, setup_anndata_dsp from scvi.utils._docstrings import devices_dsp @@ -90,8 +90,6 @@ class BaseModelClass(metaclass=BaseModelMetaClass): 1. :doc:`/tutorials/notebooks/dev/model_user_guide` """ - _LATENT_QZM_KEY = "latent_qzm" - _LATENT_QZV_KEY = "latent_qzv" _OBSERVED_LIB_SIZE_KEY = "observed_lib_size" _data_loader_cls = AnnDataLoader @@ -343,7 +341,8 @@ def get_anndata_manager( if _SCVI_UUID_KEY not in adata.uns: if required: raise ValueError( - f"Please set up your AnnData with {cls.__name__}.setup_anndata first." 
+ f"Please set up your AnnData with {cls.__name__}.setup_anndata'" + "or {cls.__name__}.setup_mudata first." ) return None @@ -358,7 +357,7 @@ def get_anndata_manager( elif adata_id not in cls._per_instance_manager_store[self.id]: if required: raise AssertionError( - "Please call ``self._validate_anndata`` on this AnnData object." + "Please call ``self._validate_anndata`` on this AnnData or MuData object." ) return None @@ -893,11 +892,11 @@ def view_anndata_setup( class BaseMinifiedModeModelClass(BaseModelClass): - """Base class for models that can handle minified data.""" + """Abstract base class for scvi-tools models that can handle minified data.""" @property def minified_data_type(self) -> MinifiedDataType | None: - """Type of minified data associated with this model.""" + """The type of minified data associated with this model, if applicable.""" return ( self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry @@ -1019,3 +1018,128 @@ def summary_string(self): hasattr(self, "minified_data_type") and self.minified_data_type is not None ) return summary_string + + +class BaseMudataMinifiedModeModelClass(BaseModelClass): + """Abstract base class for scvi-tools models that can handle minified data.""" + + @property + def minified_data_type(self) -> MinifiedDataType | None: + """The type of minified data associated with this model, if applicable.""" + return ( + self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) + if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry + else None + ) + + def minify_mudata( + self, + minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + use_latent_qzm_key: str = "X_latent_qzm", + use_latent_qzv_key: str = "X_latent_qzv", + ) -> None: + """Minify the model's :attr:`~scvi.model.base.BaseModelClass.adata`. + + Minifies the :class:`~mudata.MuData` object associated with the model according to the + method specified by ``minified_data_type`` and registers the new fields with the model's + :class:`~scvi.data.AnnDataManager`. This also sets the ``minified_data_type`` attribute + of the underlying :class:`~scvi.module.base.BaseModuleClass` instance. + + Parameters + ---------- + minified_data_type + Method for minifying the data. One of the following: + + - ``"latent_posterior_parameters"``: Store the latent posterior mean and variance in + :attr:`~mudata.MuData.obsm` using the keys ``use_latent_qzm_key`` and + ``use_latent_qzv_key``. + - ``"latent_posterior_parameters_with_counts"``: Store the latent posterior mean and + variance in :attr:`~mudata.MuData.obsm` using the keys ``use_latent_qzm_key`` and + ``use_latent_qzv_key``, and the raw count data in :attr:`~mudata.MuData.X`. + use_latent_qzm_key + Key to use for storing the latent posterior mean in :attr:`~mudata.MuData.obsm` when + ``minified_data_type`` is ``"latent_posterior"``. + use_latent_qzv_key + Key to use for storing the latent posterior variance in :attr:`~mudata.MuData.obsm` + when ``minified_data_type`` is ``"latent_posterior"``. + + Notes + ----- + The modification is not done inplace -- instead the model is assigned a new (minified) + version of the :class:`~mudata.MuData`. + """ + if self.adata_manager._registry["setup_method_name"] != "setup_mudata": + raise ValueError( + f"MuData must be registered with {self.__name__}.setup_mudata to use this method." 
+ ) + if minified_data_type not in ADATA_MINIFY_TYPE: + raise NotImplementedError( + f"Minification method {minified_data_type} is not supported." + ) + elif not getattr(self.module, "use_observed_lib_size", True): + raise ValueError( + "Minification is not supported for models that do not use observed library size." + ) + + keep_count_data = minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS + mini_adata = get_minified_mudata( + adata_manager=self.adata_manager, + keep_count_data=keep_count_data, + ) + del mini_adata.uns[_SCVI_UUID_KEY] + mini_adata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type + mini_adata.obsm[self._LATENT_QZM_KEY] = self.adata.obsm[use_latent_qzm_key] + mini_adata.obsm[self._LATENT_QZV_KEY] = self.adata.obsm[use_latent_qzv_key] + mini_adata.obs[self._OBSERVED_LIB_SIZE_KEY] = np.squeeze( + np.asarray(self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY).sum(axis=-1)) + ) + self._update_mudata_and_manager_post_minification( + mini_adata, + minified_data_type, + ) + self.module.minified_data_type = minified_data_type + + @classmethod + def _get_fields_for_mudata_minification( + cls, + minified_data_type: MinifiedDataType, + ): + """Return the fields required for minification of the given type.""" + if minified_data_type not in ADATA_MINIFY_TYPE: + raise NotImplementedError( + f"Minification method {minified_data_type} is not supported." + ) + + mini_fields = [ + fields.ObsmField(REGISTRY_KEYS.LATENT_QZM_KEY, cls._LATENT_QZM_KEY), + fields.ObsmField(REGISTRY_KEYS.LATENT_QZV_KEY, cls._LATENT_QZV_KEY), + fields.NumericalObsField(REGISTRY_KEYS.OBSERVED_LIB_SIZE, cls._OBSERVED_LIB_SIZE_KEY), + fields.StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY), + ] + + return mini_fields + + def _update_mudata_and_manager_post_minification( + self, minified_adata: AnnOrMuData, minified_data_type: MinifiedDataType + ): + """Update the mudata and manager inplace after creating a minified adata.""" + # Register this new adata with the model, creating a new manager in the cache + self._validate_anndata(minified_adata) + new_adata_manager = self.get_anndata_manager(minified_adata, required=True) + # This inplace edits the manager + new_adata_manager.register_new_fields( + self._get_fields_for_mudata_minification(minified_data_type), + ) + new_adata_manager.registry["setup_method_name"] = "setup_mudata" + # We set the adata attribute of the model as this will update self.registry_ + # and self.adata_manager with the new adata manager + self.adata = minified_adata + + @property + def summary_string(self): + """Summary string of the model.""" + summary_string = super().summary_string + summary_string += "\nModel's adata is minified?: {}".format( + hasattr(self, "minified_data_type") and self.minified_data_type is not None + ) + return summary_string diff --git a/src/scvi/model/base/_de_core.py b/src/scvi/model/base/_de_core.py index bef79ad900..13a3738803 100644 --- a/src/scvi/model/base/_de_core.py +++ b/src/scvi/model/base/_de_core.py @@ -1,10 +1,13 @@ import logging +import warnings from collections.abc import Iterable as IterableClass import anndata import numpy as np import pandas as pd +from scvi import settings +from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.utils import track from ._differential import DifferentialComputation @@ -81,6 +84,15 @@ def _de_core( **kwargs, ): """Internal function for DE interface.""" + if ( + adata_manager.adata.uns.get(_ADATA_MINIFY_TYPE_UNS_KEY, None) + == 
ADATA_MINIFY_TYPE.LATENT_POSTERIOR
+    ):
+        warnings.warn(
+            "Count statistics make no sense for a minified model. Consider disabling all_stats.",
+            UserWarning,
+            stacklevel=settings.warnings_stacklevel,
+        )
    adata = adata_manager.adata
    if group1 is None and idx1 is None:
        group1 = adata.obs[groupby].astype("category").cat.categories.tolist()
diff --git a/src/scvi/model/utils/__init__.py b/src/scvi/model/utils/__init__.py
index 003b763e5e..0ee147802d 100644
--- a/src/scvi/model/utils/__init__.py
+++ b/src/scvi/model/utils/__init__.py
@@ -1,4 +1,4 @@
 from ._mde import mde
-from ._minification import get_minified_adata_scrna
+from ._minification import get_minified_adata_scrna, get_minified_mudata

-__all__ = ["mde", "get_minified_adata_scrna"]
+__all__ = ["mde", "get_minified_adata_scrna", "get_minified_mudata"]
diff --git a/src/scvi/model/utils/_minification.py b/src/scvi/model/utils/_minification.py
index 77fb737ba1..5ba553062a 100644
--- a/src/scvi/model/utils/_minification.py
+++ b/src/scvi/model/utils/_minification.py
@@ -2,12 +2,14 @@

 from typing import TYPE_CHECKING

-from anndata import AnnData
 from scipy.sparse import csr_matrix

 from scvi import REGISTRY_KEYS

 if TYPE_CHECKING:
+    from anndata import AnnData
+    from mudata import MuData
+
     from scvi.data import AnnDataManager

@@ -15,22 +17,51 @@ def get_minified_adata_scrna(
     adata_manager: AnnDataManager,
     keep_count_data: bool = False,
 ) -> AnnData:
-    """Get a minified version of an :class:`~anndata.AnnData` or :class:`~mudata.MuData` object."""
+    """Returns a minified AnnData.
+
+    Parameters
+    ----------
+    adata_manager
+        Manager with the original AnnData, of which we want to create a minified version.
+    keep_count_data
+        If True, the count data is kept in the minified data. If False, the count data is removed.
+    """
+    adata = adata_manager.adata.copy()
-    if keep_count_data:
-        return adata_manager.adata.copy()
-    else:
+    if not keep_count_data:
+        del adata.raw
         counts = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
         all_zeros = csr_matrix(counts.shape)
         X = all_zeros
         layers = {layer: all_zeros.copy() for layer in adata_manager.adata.layers}
-        return AnnData(
-            X=X,
-            layers=layers,
-            obs=adata_manager.adata.obs.copy(),
-            var=adata_manager.adata.var.copy(),
-            uns=adata_manager.adata.uns.copy(),
-            obsm=adata_manager.adata.obsm.copy(),
-            varm=adata_manager.adata.varm.copy(),
-            obsp=adata_manager.adata.obsp.copy(),
-            varp=adata_manager.adata.varp.copy(),
-        )
+        adata.X = X
+        adata.layers = layers
+    return adata
+
+
+def get_minified_mudata(
+    adata_manager: AnnDataManager,
+    keep_count_data: bool = False,
+) -> MuData:
+    """Returns a minified MuData that works for most multi-modality models (MULTIVI, TOTALVI).
+
+    Parameters
+    ----------
+    adata_manager
+        Manager with the original MuData, of which we want to create a minified version.
+    keep_count_data
+        If True, the count data is kept in the minified data. If False, the count data is removed.
+ """ + mdata = adata_manager.adata.copy() + if keep_count_data: + pass + else: + for modality in mdata.mod_names: + del mdata[modality].raw + all_zeros = csr_matrix(mdata[modality].X.shape) + mdata[modality].X = all_zeros.copy() + if len(mdata[modality].layers) > 0: + layers = {layer: all_zeros.copy() for layer in mdata[modality].layers} + mdata[modality].layers = layers + return mdata diff --git a/src/scvi/module/_totalvae.py b/src/scvi/module/_totalvae.py index d3fb5488da..643e6ba614 100644 --- a/src/scvi/module/_totalvae.py +++ b/src/scvi/module/_totalvae.py @@ -12,13 +12,15 @@ from scvi import REGISTRY_KEYS from scvi.data import _constants +from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.distributions import ( NegativeBinomial, NegativeBinomialMixture, ZeroInflatedNegativeBinomial, ) from scvi.model.base import BaseModelClass -from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data +from scvi.module._constants import MODULE_KEYS +from scvi.module.base import BaseMinifiedModeModuleClass, LossOutput, auto_move_data from scvi.nn import DecoderTOTALVI, EncoderTOTALVI from scvi.nn._utils import ExpActivation @@ -26,7 +28,7 @@ # VAE model -class TOTALVAE(BaseModuleClass): +class TOTALVAE(BaseMinifiedModeModuleClass): """Total variational inference for CITE-seq data. Implements the totalVI model of :cite:p:`GayosoSteier21`. @@ -324,25 +326,37 @@ def get_reconstruction_loss( return reconst_loss_gene, reconst_loss_protein - def _get_inference_input(self, tensors): - x = tensors[REGISTRY_KEYS.X_KEY] - y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] - batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - - cont_key = REGISTRY_KEYS.CONT_COVS_KEY - cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None - - cat_key = REGISTRY_KEYS.CAT_COVS_KEY - cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None - - input_dict = { - "x": x, - "y": y, - "batch_index": batch_index, - "cat_covs": cat_covs, - "cont_covs": cont_covs, - } - return input_dict + def _get_inference_input( + self, + tensors, + full_forward_pass: bool = False, + ) -> dict[str, torch.Tensor | None]: + """Get input tensors for the inference process.""" + if full_forward_pass or self.minified_data_type is None: + loader = "full_data" + elif self.minified_data_type in [ + ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS, + ]: + loader = "minified_data" + else: + raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}") + + if loader == "full_data": + return { + MODULE_KEYS.X_KEY: tensors[REGISTRY_KEYS.X_KEY], + MODULE_KEYS.Y_KEY: tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY], + MODULE_KEYS.BATCH_INDEX_KEY: tensors[REGISTRY_KEYS.BATCH_KEY], + MODULE_KEYS.CONT_COVS_KEY: tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None), + MODULE_KEYS.CAT_COVS_KEY: tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None), + } + else: + return { + MODULE_KEYS.QZM_KEY: tensors[REGISTRY_KEYS.LATENT_QZM_KEY], + MODULE_KEYS.QZV_KEY: tensors[REGISTRY_KEYS.LATENT_QZV_KEY], + REGISTRY_KEYS.OBSERVED_LIB_SIZE: tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE], + MODULE_KEYS.BATCH_INDEX_KEY: tensors[REGISTRY_KEYS.BATCH_KEY], + } def _get_generative_input(self, tensors, inference_outputs): z = inference_outputs["z"] @@ -433,7 +447,45 @@ def generative( } @auto_move_data - def inference( + def _cached_inference( + self, + qzm: torch.Tensor, + qzv: torch.Tensor, + batch_index: torch.Tensor, + observed_lib_size: torch.Tensor, + n_samples: int = 1, + ) -> dict[str, torch.Tensor | dict[str, 
+        """Run the cached inference process."""
+        # qzv stores the posterior variance; Normal expects the scale (std)
+        qz = Normal(qzm, qzv.sqrt())
+        untran_z = qz.sample() if n_samples == 1 else qz.sample((n_samples,))
+        z = self.encoder.z_transformation(untran_z)
+        # expand the observed library size when drawing multiple posterior samples
+        library_gene = observed_lib_size
+        if n_samples > 1:
+            library_gene = library_gene.unsqueeze(0).expand(
+                (n_samples, library_gene.size(0), library_gene.size(1))
+            )
+
+        if self.n_batch > 0:
+            py_back_alpha_prior = F.linear(
+                one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.background_pro_alpha
+            )
+            py_back_beta_prior = F.linear(
+                one_hot(batch_index.squeeze(-1), self.n_batch).float(),
+                torch.exp(self.background_pro_log_beta),
+            )
+        else:
+            py_back_alpha_prior = self.background_pro_alpha
+            py_back_beta_prior = torch.exp(self.background_pro_log_beta)
+        self.back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior)
+
+        return {
+            MODULE_KEYS.Z_KEY: z,
+            MODULE_KEYS.QZ_KEY: qz,
+            MODULE_KEYS.QL_KEY: None,
+            "library_gene": library_gene,
+        }
+
+    @auto_move_data
+    def _regular_inference(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
@@ -515,24 +567,6 @@ def inference(
        else:
            library_gene = self.encoder.l_transformation(untran_l)

-        # Background regularization
-        if self.gene_dispersion == "gene-label":
-            # px_r gets transposed - last dimension is nb genes
-            px_r = F.linear(one_hot(label.squeeze(-1), self.n_labels).float(), self.px_r)
-        elif self.gene_dispersion == "gene-batch":
-            px_r = F.linear(one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.px_r)
-        elif self.gene_dispersion == "gene":
-            px_r = self.px_r
-        px_r = torch.exp(px_r)
-
-        if self.protein_dispersion == "protein-label":
-            # py_r gets transposed - last dimension is n_proteins
-            py_r = F.linear(one_hot(label.squeeze(-1), self.n_labels).float(), self.py_r)
-        elif self.protein_dispersion == "protein-batch":
-            py_r = F.linear(one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.py_r)
-        elif self.protein_dispersion == "protein":
-            py_r = self.py_r
-        py_r = torch.exp(py_r)

        if self.n_batch > 0:
            py_back_alpha_prior = F.linear(
                one_hot(batch_index.squeeze(-1), self.n_batch).float(), self.background_pro_alpha
            )
            py_back_beta_prior = F.linear(
                one_hot(batch_index.squeeze(-1), self.n_batch).float(),
                torch.exp(self.background_pro_log_beta),
            )
        else:
            py_back_alpha_prior = self.background_pro_alpha
            py_back_beta_prior = torch.exp(self.background_pro_log_beta)
        self.back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior)

        return {
-            "qz": qz,
-            "z": z,
-            "untran_z": untran_z,
-            "ql": ql,
+            MODULE_KEYS.Z_KEY: z,
+            MODULE_KEYS.QZ_KEY: qz,
+            MODULE_KEYS.QL_KEY: ql,
            "library_gene": library_gene,
            "untran_l": untran_l,
        }
@@ -656,7 +689,7 @@ def sample(self, tensors, n_samples=1):
        inference_kwargs = {"n_samples": n_samples}
        with torch.inference_mode():
            (
-                inference_outputs,
+                _,
                generative_outputs,
            ) = self.forward(
                tensors,
@@ -691,12 +724,11 @@ def marginal_ll(self, tensors, n_mc_samples, return_mean: bool = True):
        # Distribution parameters and sampled variables
        inference_outputs, generative_outputs, losses = self.forward(tensors)
        # outputs = self.module.inference(x, y, batch_index, labels)
-        qz = inference_outputs["qz"]
-        ql = inference_outputs["ql"]
+        qz = inference_outputs[MODULE_KEYS.QZ_KEY]
+        ql = inference_outputs[MODULE_KEYS.QL_KEY]
+        z = inference_outputs[MODULE_KEYS.Z_KEY]
        py_ = generative_outputs["py_"]
-        log_library = inference_outputs["untran_l"]
        # really need not softmax transformed random variable
-        z = inference_outputs["untran_z"]
        log_pro_back_mean = generative_outputs["log_pro_back_mean"]

        # Reconstruction Loss
@@ -708,6 +740,7 @@ def marginal_ll(self, tensors, n_mc_samples, return_mean: bool = True):
        log_prob_sum = torch.zeros(qz.loc.shape[0]).to(self.device)

        if not
self.use_observed_lib_size: + log_library = inference_outputs["untran_l"] n_batch = self.library_log_means.shape[1] local_library_log_means = F.linear( one_hot(batch_index.squeeze(-1), n_batch).float(), self.library_log_means diff --git a/tests/model/test_models_with_mudata_minified_data.py b/tests/model/test_models_with_mudata_minified_data.py new file mode 100644 index 0000000000..e4d543b033 --- /dev/null +++ b/tests/model/test_models_with_mudata_minified_data.py @@ -0,0 +1,400 @@ +import numpy as np +import pytest +from mudata import MuData + +import scvi +from scvi.data import synthetic_iid +from scvi.data._constants import ADATA_MINIFY_TYPE +from scvi.data._utils import _is_minified +from scvi.model import TOTALVI + +OBSERVED_LIB_SIZE = "observed_lib_size" + + +def prep_model(cls=TOTALVI, use_size_factor=False): + # create a synthetic dataset + adata = synthetic_iid() + adata_counts = adata.X + if use_size_factor: + adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],)) + adata.var["n_counts"] = np.squeeze(np.asarray(np.sum(adata_counts, axis=0))) + adata.varm["my_varm"] = np.random.negative_binomial(5, 0.3, size=(adata.shape[1], 3)) + adata.layers["my_layer"] = np.ones_like(adata.X) + adata_before_setup = adata.copy() + + # run setup_anndata + setup_kwargs = { + "batch_key": "batch", + "protein_expression_obsm_key": "protein_expression", + "protein_names_uns_key": "protein_names", + } + if use_size_factor: + setup_kwargs["size_factor_key"] = "size_factor" + cls.setup_anndata( + adata, + **setup_kwargs, + ) + + # create and train the model + if cls == TOTALVI: + model = cls(adata, n_latent=5) + else: + model = cls(adata, n_latent=5, n_genes=50, n_regions=50) + model.train(1, check_val_every_n_epoch=1, train_size=0.5) + + # get the adata lib size + adata_lib_size = np.squeeze(np.asarray(adata_counts.sum(axis=1))) + assert ( + np.min(adata_lib_size) > 0 + ) # make sure it's not all zeros and there are no negative values + + return model, adata, adata_lib_size, adata_before_setup + + +def run_test_for_model_with_minified_adata( + cls=TOTALVI, + n_samples: int = 1, + give_mean: bool = False, + use_size_factor=False, +): + model, adata, adata_lib_size, _ = prep_model(cls, use_size_factor) + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + adata_orig = adata.copy() + + model.minify_adata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + assert model.adata_manager.registry is model.registry_ + + # make sure the original adata we set up the model with was not changed + assert adata is not model.adata + assert _is_minified(adata) is False + + assert adata_orig.layers.keys() == model.adata.layers.keys() + orig_obs_df = adata_orig.obs + obs_keys = OBSERVED_LIB_SIZE + orig_obs_df[obs_keys] = adata_lib_size + assert model.adata.obs.equals(orig_obs_df) + assert model.adata.var_names.equals(adata_orig.var_names) + assert model.adata.var.equals(adata_orig.var) + assert model.adata.varm.keys() == adata_orig.varm.keys() + np.testing.assert_array_equal(model.adata.varm["my_varm"], adata_orig.varm["my_varm"]) + + +def prep_model_mudata(cls=TOTALVI, use_size_factor=False, layer=None): + # create a synthetic dataset + mdata = synthetic_iid(return_mudata=True) + if use_size_factor: + mdata.obs["size_factor_rna"] = mdata["rna"].X.sum(1) + mdata.obs["size_factor_atac"] = (mdata["accessibility"].X.sum(1) + 1) / ( + 
np.max(mdata["accessibility"].X.sum(1)) + 1.01 + ) + if layer is not None: + for mod in mdata.mod_names: + mdata[mod].layers[layer] = mdata[mod].X.copy() + mdata[mod].X = np.zeros_like(mdata[mod].X) + mdata.var["n_counts"] = np.squeeze( + np.concatenate( + [ + np.asarray(np.sum(mdata["rna"].X, axis=0)), + np.asarray(np.sum(mdata["protein_expression"].X, axis=0)), + np.asarray(np.sum(mdata["accessibility"].X, axis=0)), + ] + ) + ) + mdata.varm["my_varm"] = np.random.negative_binomial(5, 0.3, size=(mdata.shape[1], 3)) + mdata_before_setup = mdata.copy() + + # run setup_anndata + setup_kwargs = { + "batch_key": "batch", + } + + if use_size_factor: + setup_kwargs["size_factor_key"] = "size_factor_rna" + + # create and train the model + if cls == TOTALVI: + mdata = MuData({"rna": mdata["rna"], "protein_expression": mdata["protein_expression"]}) + mdata.obs = mdata_before_setup.obs + cls.setup_mudata( + mdata, + modalities={"rna_layer": "rna", "protein_layer": "protein_expression"}, + **setup_kwargs, + ) + model = cls(mdata, n_latent=5) + else: + raise ValueError("Bad Model name as input to test") + model.train(1, check_val_every_n_epoch=1, train_size=0.5) + + # get the mdata lib size + mdata_lib_size = np.squeeze(np.asarray(mdata["rna"].X.sum(axis=1))) + assert ( + np.min(mdata_lib_size) > 0 + ) # make sure it's not all zeros and there are no negative values + + return model, mdata, mdata_lib_size, mdata_before_setup + + +def run_test_for_model_with_minified_mudata( + cls=TOTALVI, + use_size_factor=False, +): + model, mdata, mdata_lib_size, _ = prep_model_mudata(cls, use_size_factor) + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + mdata_orig = mdata.copy() + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + assert model.adata_manager.registry is model.registry_ + + # make sure the original mdata we set up the model with was not changed + assert mdata is not model.adata + assert _is_minified(mdata) is False + assert _is_minified(model.adata) is True + + assert mdata_orig["rna"].layers.keys() == model.adata["rna"].layers.keys() + orig_obs_df = mdata_orig.obs + obs_keys = OBSERVED_LIB_SIZE + orig_obs_df[obs_keys] = mdata_lib_size + assert model.adata.obs.equals(orig_obs_df) + assert model.adata.var_names.equals(mdata_orig.var_names) + assert model.adata.var.equals(mdata_orig.var) + assert model.adata.varm.keys() == mdata_orig.varm.keys() + np.testing.assert_array_equal(model.adata.varm["my_varm"], mdata_orig.varm["my_varm"]) + + +def assert_approx_equal(a, b): + # Allclose because on GPU, the values are not exactly the same + # as some values are moved to cpu during data minification + np.testing.assert_allclose(a, b, rtol=3e-1, atol=5e-1) + + +@pytest.mark.parametrize("cls", [TOTALVI]) +@pytest.mark.parametrize("use_size_factor", [True]) +def test_with_minified_adata(cls, use_size_factor: bool): + run_test_for_model_with_minified_adata(cls=cls, use_size_factor=use_size_factor) + + +@pytest.mark.parametrize("cls", [TOTALVI]) +def test_with_minified_mdata_get_normalized_expression(cls): + model, mdata, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + exprs_orig = model.get_normalized_expression(n_samples=500) + + 
model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + exprs_new = model.get_normalized_expression(n_samples=500) + + if type(exprs_new) is tuple: + for ii in range(len(exprs_new)): + assert exprs_new[ii].shape == mdata[mdata.mod_names[ii]].shape + for ii in range(len(exprs_new)): + assert_approx_equal(exprs_new[ii], exprs_orig[ii]) + else: + assert exprs_new.shape == exprs_orig.shape + assert_approx_equal(exprs_new, exprs_orig) + + +def test_totalvi_downstream_with_minified_mdata(): + model, mdata, _, _ = prep_model_mudata(cls=TOTALVI, use_size_factor=True) + # non-default gene list and n_samples > 1 + gl = mdata.var_names[:5].to_list() + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + assert model.get_normalized_expression(gene_list=gl, library_size="latent") + assert model.get_normalized_expression(gene_list=gl, library_size=1) + sample = model.posterior_predictive_sample() + assert sample.shape == mdata.shape + corr = model.get_feature_correlation_matrix() + assert corr.shape == (mdata.n_vars, mdata.n_vars) + fore = model.get_protein_foreground_probability() + assert fore.shape == (mdata.n_obs, mdata["protein_expression"].n_vars) + model.differential_expression(groupby="labels") + + +def test_totalvi_downstream_with_minified_mdata_keep_counts(): + model, mdata, _, _ = prep_model_mudata(cls=TOTALVI, use_size_factor=True) + + # non-default gene list and n_samples > 1 + gl = mdata.var_names[:5].to_list() + + scvi.settings.seed = 1 + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + + model.minify_mudata(minified_data_type="latent_posterior_parameters_with_counts") + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS + + scvi.settings.seed = 1 + assert model.get_normalized_expression(gene_list=gl, library_size="latent") + assert model.get_normalized_expression(gene_list=gl, library_size=1) + sample = model.posterior_predictive_sample() + assert sample.shape == mdata.shape + corr = model.get_feature_correlation_matrix() + assert corr.shape == (mdata.n_vars, mdata.n_vars) + fore = model.get_protein_foreground_probability() + assert fore.shape == (mdata.n_obs, mdata["protein_expression"].n_vars) + model.differential_expression(groupby="labels") + + +@pytest.mark.parametrize("cls", [TOTALVI]) +def test_validate_unsupported_if_minified(cls): + model, _, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + common_err_msg = "The {} function currently does not support minified data." 
+
+    with pytest.raises(ValueError) as e:
+        model.get_elbo()
+    assert str(e.value) == common_err_msg.format("VAEMixin.get_elbo")
+
+    with pytest.raises(ValueError) as e:
+        model.get_reconstruction_error()
+    assert str(e.value) == common_err_msg.format("VAEMixin.get_reconstruction_error")
+
+    with pytest.raises(ValueError) as e:
+        model.get_marginal_ll()
+    assert str(e.value) == common_err_msg.format("VAEMixin.get_marginal_ll")
+
+
+@pytest.mark.parametrize("cls", [TOTALVI])
+def test_with_minified_mdata_save_then_load(cls, save_path):
+    # create a model and minify its mdata, then save it and its mdata.
+    # Load it back up using the original (non-minified) mdata. Validate that
+    # the loaded model does not have the minified_data_type attribute set.
+    model, mdata, _, _ = prep_model_mudata(cls=cls, use_size_factor=True)
+
+    qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
+    model.adata.obsm["X_latent_qzm"] = qzm
+    model.adata.obsm["X_latent_qzv"] = qzv
+
+    model.minify_mudata()
+    assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR
+
+    model.save(save_path, overwrite=True, save_anndata=True)
+    model.view_setup_args(save_path)
+    # load saved model with the original (non-minified) mdata
+    loaded_model = cls.load(save_path, adata=mdata)
+
+    assert loaded_model.minified_data_type is None
+
+
+@pytest.mark.parametrize("cls", [TOTALVI])
+def test_with_minified_mdata_save_then_load_with_non_minified_mdata(cls, save_path):
+    # create a model and minify its mdata, then save it and its mdata.
+    # Load it back up using a non-minified mdata. Validate that the
+    # loaded model does not have the minified_data_type attribute set.
+    model, mdata, _, mdata_before_setup = prep_model_mudata(cls=cls, use_size_factor=True)
+
+    qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
+    model.adata.obsm["X_latent_qzm"] = qzm
+    model.adata.obsm["X_latent_qzv"] = qzv
+
+    model.minify_mudata()
+    assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR
+
+    model.save(save_path, overwrite=True, save_anndata=False, legacy_mudata_format=True)
+    # load saved model with a non-minified mdata
+    loaded_model = cls.load(save_path, adata=mdata_before_setup)
+
+    assert loaded_model.minified_data_type is None
+
+
+@pytest.mark.parametrize("cls", [TOTALVI])
+def test_save_then_load_with_minified_mdata(cls, save_path):
+    # create a model, then save it and its mdata (non-minified).
+    # Load it back up using a minified mdata.
Validate that this + # fails, as expected because we don't have a way to validate + # whether the minified-mdata was set up correctly + model, _, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + model.save(save_path, overwrite=True, save_anndata=False, legacy_mudata_format=True) + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + # loading this model with a minified mdata is not allowed because + # we don't have a way to validate whether the minified-mdata was + # set up correctly + with pytest.raises(KeyError): + cls.load(save_path, adata=model.adata) + + +@pytest.mark.parametrize("cls", [TOTALVI]) +def test_with_minified_mdata_get_latent_representation(cls): + model, _, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + latent_repr_orig = model.get_latent_representation() + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + latent_repr_new = model.get_latent_representation() + + assert_approx_equal(latent_repr_new, latent_repr_orig) + + +@pytest.mark.parametrize("cls", [TOTALVI]) +def test_with_minified_mdata_get_feature_correlation_matrix(cls): + model, _, _, _ = prep_model_mudata(cls=cls, use_size_factor=True) + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + fcm_orig = model.get_feature_correlation_matrix( + correlation_type="spearman", + transform_batch=["batch_0", "batch_1"], + ) + + model.minify_mudata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + fcm_new = model.get_feature_correlation_matrix( + correlation_type="spearman", + transform_batch=["batch_0", "batch_1"], + ) + + assert_approx_equal(fcm_new, fcm_orig)
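
Taken together, the changes enable the following end-to-end minification workflow for TOTALVI on MuData. This is a minimal sketch mirroring the tests above; the synthetic dataset, `n_latent`, epoch count, and save path are illustrative assumptions, not part of the diff:

```python
import scvi
from scvi.data import synthetic_iid

# Set up and train TOTALVI on a synthetic MuData with "rna" and
# "protein_expression" modalities, as in the tests.
mdata = synthetic_iid(return_mudata=True)
scvi.model.TOTALVI.setup_mudata(
    mdata,
    batch_key="batch",
    modalities={"rna_layer": "rna", "protein_layer": "protein_expression"},
)
model = scvi.model.TOTALVI(mdata, n_latent=5)
model.train(max_epochs=1)

# Cache the latent posterior parameters under the default keys that
# minify_mudata() reads, then replace the stored counts with zeros.
qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
model.adata.obsm["X_latent_qzm"] = qzm
model.adata.obsm["X_latent_qzv"] = qzv
model.minify_mudata()  # defaults to "latent_posterior_parameters"
assert model.minified_data_type is not None

# Downstream queries now run against the cached posterior rather than the
# original count data, and the minified model can be saved as usual.
latent = model.get_latent_representation()
model.save("totalvi_minified", overwrite=True, save_anndata=True)
```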