diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d72632cb51..92232aff37 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: )$ - repo: https://github.com/igorshubovych/markdownlint-cli - rev: v0.42.0 + rev: v0.43.0 hooks: - id: markdownlint-fix exclude: | @@ -41,7 +41,7 @@ repos: )$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.4 + rev: v0.8.1 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/CHANGELOG.md b/CHANGELOG.md index bec14bc8a0..fd4017ef1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,12 @@ to [Semantic Versioning]. Full commit history is available in the validation set, if available. {pr}`3036`. - Add `batch_key` and `labels_key` to `scvi.external.SCAR.setup_anndata`. - Implemented variance of ZINB distribution. {pr}`3044`. +- Support for minified mode while retaining counts to skip the encoder. +- New Trainingplan argument `update_only_decoder` to use stored latent codes and skip training of + the encoder. +- Refactored code for minified models. +- Add {class}`scvi.external.METHYLVI` for modeling methylation data from single-cell + bisulfite sequencing (scBS-seq) experiments {pr}`2834`. #### Fixed @@ -35,6 +41,8 @@ to [Semantic Versioning]. Full commit history is available in the to correctly compute the maxmimum log-density across in-sample cells rather than the aggregated posterior log-density {pr}`3007`. - Fix references to `scvi.external` in `scvi.external.SCAR.setup_anndata`. +- Fix gimVI to append mini batches first into CPU during get_imputed and get_latent operations {pr}`3058`. +- #### Changed @@ -88,8 +96,6 @@ to [Semantic Versioning]. Full commit history is available in the data {pr}`2756`. - Add support for reference mapping with {class}`mudata.MuData` models to {class}`scvi.model.base.ArchesMixin` {pr}`2578`. -- Add {class}`scvi.external.METHYLVI` for modeling methylation data from single-cell - bisulfite sequencing (scBS-seq) experiments {pr}`2834`. - Add argument `return_mean` to {meth}`scvi.model.base.VAEMixin.get_reconstruction_error` and {meth}`scvi.model.base.VAEMixin.get_elbo` to allow computation without averaging across cells {pr}`2362`. diff --git a/src/scvi/data/_built_in_data/_brain_large.py b/src/scvi/data/_built_in_data/_brain_large.py index fe75eb536c..21239e3b90 100644 --- a/src/scvi/data/_built_in_data/_brain_large.py +++ b/src/scvi/data/_built_in_data/_brain_large.py @@ -92,8 +92,8 @@ def _load_brainlarge_file( logger.info( f"loaded {i * loading_batch_size + n_cells_batch} / {n_cells_to_keep} cells" ) - logger.info("%d cells subsampled" % matrix.shape[0]) - logger.info("%d genes subsampled" % matrix.shape[1]) + logger.info(f"{matrix.shape[0]} cells subsampled") + logger.info(f"{matrix.shape[1]} genes subsampled") adata = anndata.AnnData(matrix) adata.obs["labels"] = np.zeros(matrix.shape[0]) adata.obs["batch"] = np.zeros(matrix.shape[0]) diff --git a/src/scvi/data/_constants.py b/src/scvi/data/_constants.py index 9efa664537..09ab91aeaa 100644 --- a/src/scvi/data/_constants.py +++ b/src/scvi/data/_constants.py @@ -37,6 +37,7 @@ class _ADATA_MINIFY_TYPE_NT(NamedTuple): LATENT_POSTERIOR: str = "latent_posterior_parameters" + LATENT_POSTERIOR_WITH_COUNTS: str = "latent_posterior_parameters_with_counts" ADATA_MINIFY_TYPE = _ADATA_MINIFY_TYPE_NT() diff --git a/src/scvi/external/gimvi/_model.py b/src/scvi/external/gimvi/_model.py index ac7b80f508..8bbde7326b 100644 --- a/src/scvi/external/gimvi/_model.py +++ b/src/scvi/external/gimvi/_model.py @@ -313,9 +313,11 @@ def get_latent_representation( self.module.sample_from_posterior_z( sample_batch, mode, deterministic=deterministic ) + .cpu() + .detach() ) - latent = torch.cat(latent).cpu().detach().numpy() + latent = torch.cat(latent).numpy() latents.append(latent) return latents @@ -372,6 +374,8 @@ def get_imputed_values( deterministic=deterministic, decode_mode=decode_mode, ) + .cpu() + .detach() ) else: imputed_value.append( @@ -383,9 +387,11 @@ def get_imputed_values( deterministic=deterministic, decode_mode=decode_mode, ) + .cpu() + .detach() ) - imputed_value = torch.cat(imputed_value).cpu().detach().numpy() + imputed_value = torch.cat(imputed_value).numpy() imputed_values.append(imputed_value) return imputed_values diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index b596e48afa..084d83be1f 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -8,12 +8,10 @@ import numpy as np import pandas as pd import torch -from anndata import AnnData from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager from scvi.data._constants import ( - _ADATA_MINIFY_TYPE_UNS_KEY, _SETUP_ARGS_KEY, ADATA_MINIFY_TYPE, ) @@ -25,12 +23,9 @@ LayerField, NumericalJointObsField, NumericalObsField, - ObsmField, - StringUnsField, ) from scvi.dataloaders import SemiSupervisedDataSplitter from scvi.model._utils import _init_library_size, get_max_epochs_heuristic -from scvi.model.utils import get_minified_adata_scrna from scvi.module import SCANVAE from scvi.train import SemiSupervisedTrainingPlan, TrainRunner from scvi.train._callbacks import SubSampleLabels @@ -45,17 +40,8 @@ from anndata import AnnData - from scvi._types import MinifiedDataType - from scvi.data.fields import ( - BaseAnnDataField, - ) - from ._scvi import SCVI -_SCANVI_LATENT_QZM = "_scanvi_latent_qzm" -_SCANVI_LATENT_QZV = "_scanvi_latent_qzv" -_SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" - logger = logging.getLogger(__name__) @@ -115,6 +101,8 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): _module_cls = SCANVAE _training_plan_cls = SemiSupervisedTrainingPlan + _LATENT_QZM = "scanvi_latent_qzm" + _LATENT_QZV = "scanvi_latent_qzv" def __init__( self, @@ -223,17 +211,18 @@ def from_scvi_model( ) del scanvi_kwargs[k] - if scvi_model.minified_data_type is not None: + if scvi_model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: raise ValueError( - "We cannot use the given scvi model to initialize scanvi because it has a " - "minified adata." + "We cannot use the given scVI model to initialize scANVI because it has " + "minified adata. Keep counts when minifying model using " + "minified_data_type='latent_posterior_parameters_with_counts'." ) if adata is None: adata = scvi_model.adata else: if _is_minified(adata): - raise ValueError("Please provide a non-minified `adata` to initialize scanvi.") + raise ValueError("Please provide a non-minified `adata` to initialize scANVI.") # validate new anndata against old model scvi_model._validate_anndata(adata) @@ -241,7 +230,7 @@ def from_scvi_model( scvi_labels_key = scvi_setup_args["labels_key"] if labels_key is None and scvi_labels_key is None: raise ValueError( - "A `labels_key` is necessary as the SCVI model was initialized without one." + "A `labels_key` is necessary as the scVI model was initialized without one." ) if scvi_labels_key is None: scvi_setup_args.update({"labels_key": labels_key}) @@ -485,79 +474,3 @@ def setup_anndata( adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) - - @staticmethod - def _get_fields_for_adata_minification( - minified_data_type: MinifiedDataType, - ) -> list[BaseAnnDataField]: - """Return the fields required for adata minification of the given minified_data_type.""" - if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - fields = [ - ObsmField( - REGISTRY_KEYS.LATENT_QZM_KEY, - _SCANVI_LATENT_QZM, - ), - ObsmField( - REGISTRY_KEYS.LATENT_QZV_KEY, - _SCANVI_LATENT_QZV, - ), - NumericalObsField( - REGISTRY_KEYS.OBSERVED_LIB_SIZE, - _SCANVI_OBSERVED_LIB_SIZE, - ), - ] - else: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - fields.append( - StringUnsField( - REGISTRY_KEYS.MINIFY_TYPE_KEY, - _ADATA_MINIFY_TYPE_UNS_KEY, - ), - ) - return fields - - def minify_adata( - self, - minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, - use_latent_qzm_key: str = "X_latent_qzm", - use_latent_qzv_key: str = "X_latent_qzv", - ): - """Minifies the model's adata. - - Minifies the adata, and registers new anndata fields: latent qzm, latent qzv, adata uns - containing minified-adata type, and library size. - This also sets the appropriate property on the module to indicate that the adata is - minified. - - Parameters - ---------- - minified_data_type - How to minify the data. Currently only supports `latent_posterior_parameters`. - If minified_data_type == `latent_posterior_parameters`: - - * the original count data is removed (`adata.X`, adata.raw, and any layers) - * the parameters of the latent representation of the original data is stored - * everything else is left untouched - use_latent_qzm_key - Key to use in `adata.obsm` where the latent qzm params are stored - use_latent_qzv_key - Key to use in `adata.obsm` where the latent qzv params are stored - - Notes - ----- - The modification is not done inplace -- instead the model is assigned a new (minified) - version of the adata. - """ - if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - - if self.module.use_observed_lib_size is False: - raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") - - minified_adata = get_minified_adata_scrna(self.adata, minified_data_type) - minified_adata.obsm[_SCANVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] - minified_adata.obsm[_SCANVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] - counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) - minified_adata.obs[_SCANVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1))) - self._update_adata_and_manager_post_minification(minified_adata, minified_data_type) - self.module.minified_data_type = minified_data_type diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index ee6b7765e3..36ad6f2c39 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -4,11 +4,8 @@ import warnings from typing import TYPE_CHECKING -import numpy as np - from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager -from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import ( CategoricalJointObsField, @@ -16,12 +13,9 @@ LayerField, NumericalJointObsField, NumericalObsField, - ObsmField, - StringUnsField, ) from scvi.model._utils import _init_library_size from scvi.model.base import EmbeddingMixin, UnsupervisedTrainingMixin -from scvi.model.utils import get_minified_adata_scrna from scvi.module import VAE from scvi.utils import setup_anndata_dsp @@ -32,15 +26,6 @@ from anndata import AnnData - from scvi._types import MinifiedDataType - from scvi.data.fields import ( - BaseAnnDataField, - ) - -_SCVI_LATENT_QZM = "_scvi_latent_qzm" -_SCVI_LATENT_QZV = "_scvi_latent_qzv" -_SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size" - logger = logging.getLogger(__name__) @@ -115,6 +100,8 @@ class SCVI( """ _module_cls = VAE + _SCVI_LATENT_QZM = "scvi_latent_qzm" + _SCVI_LATENT_QZV = "scvi_latent_qzv" def __init__( self, @@ -231,81 +218,3 @@ def setup_anndata( adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) - - @staticmethod - def _get_fields_for_adata_minification( - minified_data_type: MinifiedDataType, - ) -> list[BaseAnnDataField]: - """Return the fields required for adata minification of the given minified_data_type.""" - if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - fields = [ - ObsmField( - REGISTRY_KEYS.LATENT_QZM_KEY, - _SCVI_LATENT_QZM, - ), - ObsmField( - REGISTRY_KEYS.LATENT_QZV_KEY, - _SCVI_LATENT_QZV, - ), - NumericalObsField( - REGISTRY_KEYS.OBSERVED_LIB_SIZE, - _SCVI_OBSERVED_LIB_SIZE, - ), - ] - else: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - fields.append( - StringUnsField( - REGISTRY_KEYS.MINIFY_TYPE_KEY, - _ADATA_MINIFY_TYPE_UNS_KEY, - ), - ) - return fields - - def minify_adata( - self, - minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, - use_latent_qzm_key: str = "X_latent_qzm", - use_latent_qzv_key: str = "X_latent_qzv", - ) -> None: - """Minifies the model's adata. - - Minifies the adata, and registers new anndata fields: latent qzm, latent qzv, adata uns - containing minified-adata type, and library size. - This also sets the appropriate property on the module to indicate that the adata is - minified. - - Parameters - ---------- - minified_data_type - How to minify the data. Currently only supports `latent_posterior_parameters`. - If minified_data_type == `latent_posterior_parameters`: - - * the original count data is removed (`adata.X`, adata.raw, and any layers) - * the parameters of the latent representation of the original data is stored - * everything else is left untouched - use_latent_qzm_key - Key to use in `adata.obsm` where the latent qzm params are stored - use_latent_qzv_key - Key to use in `adata.obsm` where the latent qzv params are stored - - Notes - ----- - The modification is not done inplace -- instead the model is assigned a new (minified) - version of the adata. - """ - # TODO(adamgayoso): Add support for a scenario where we want to cache the latent posterior - # without removing the original counts. - if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - - if self.module.use_observed_lib_size is False: - raise ValueError("Cannot minify the data if `use_observed_lib_size` is False") - - minified_adata = get_minified_adata_scrna(self.adata, minified_data_type) - minified_adata.obsm[_SCVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] - minified_adata.obsm[_SCVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] - counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) - minified_adata.obs[_SCVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(counts.sum(axis=1))) - self._update_adata_and_manager_post_minification(minified_adata, minified_data_type) - self.module.minified_data_type = minified_data_type diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index fd47bf1926..64b0594e29 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -15,13 +15,15 @@ from mudata import MuData from scvi import REGISTRY_KEYS, settings -from scvi.data import AnnDataManager +from scvi.data import AnnDataManager, fields from scvi.data._compat import registry_from_setup_dict from scvi.data._constants import ( + _ADATA_MINIFY_TYPE_UNS_KEY, _MODEL_NAME_KEY, _SCVI_UUID_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME, + ADATA_MINIFY_TYPE, ) from scvi.data._utils import _assign_adata_uuid, _check_if_view, _get_adata_minify_type from scvi.dataloaders import AnnDataLoader @@ -33,6 +35,7 @@ _load_saved_files, _validate_var_names, ) +from scvi.model.utils import get_minified_adata_scrna from scvi.utils import attrdict, setup_anndata_dsp from scvi.utils._docstrings import devices_dsp @@ -87,6 +90,9 @@ class BaseModelClass(metaclass=BaseModelMetaClass): 1. :doc:`/tutorials/notebooks/dev/model_user_guide` """ + _LATENT_QZM_KEY = "latent_qzm" + _LATENT_QZV_KEY = "latent_qzv" + _OBSERVED_LIB_SIZE_KEY = "observed_lib_size" _data_loader_cls = AnnDataLoader def __init__(self, adata: AnnOrMuData | None = None): @@ -887,53 +893,122 @@ def view_anndata_setup( class BaseMinifiedModeModelClass(BaseModelClass): - """Abstract base class for scvi-tools models that can handle minified data.""" + """Base class for models that can handle minified data.""" @property def minified_data_type(self) -> MinifiedDataType | None: - """The type of minified data associated with this model, if applicable.""" + """Type of minified data associated with this model.""" return ( self.adata_manager.get_from_registry(REGISTRY_KEYS.MINIFY_TYPE_KEY) if REGISTRY_KEYS.MINIFY_TYPE_KEY in self.adata_manager.data_registry else None ) - @abstractmethod def minify_adata( self, - *args, - **kwargs, - ): - """Minifies the model's adata. + minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + use_latent_qzm_key: str = "X_latent_qzm", + use_latent_qzv_key: str = "X_latent_qzv", + ) -> None: + """Minify the model's :attr:`~scvi.model.base.BaseModelClass.adata`. + + Minifies the :class:`~anndata.AnnData` object associated with the model according to the + method specified by ``minified_data_type`` and registers the new fields with the model's + :class:`~scvi.data.AnnDataManager`. This also sets the ``minified_data_type`` attribute + of the underlying :class:`~scvi.module.base.BaseModuleClass` instance. - Minifies the adata, and registers new anndata fields as required (can be model-specific). - This also sets the appropriate property on the module to indicate that the adata is - minified. + Parameters + ---------- + minified_data_type + Method for minifying the data. One of the following: + + - ``"latent_posterior_parameters"``: Store the latent posterior mean and variance in + :attr:`~anndata.AnnData.obsm` using the keys ``use_latent_qzm_key`` and + ``use_latent_qzv_key``. + - ``"latent_posterior_parameters_with_counts"``: Store the latent posterior mean and + variance in :attr:`~anndata.AnnData.obsm` using the keys ``use_latent_qzm_key`` and + ``use_latent_qzv_key``, and the raw count data in :attr:`~anndata.AnnData.X`. + use_latent_qzm_key + Key to use for storing the latent posterior mean in :attr:`~anndata.AnnData.obsm` when + ``minified_data_type`` is ``"latent_posterior"``. + use_latent_qzv_key + Key to use for storing the latent posterior variance in :attr:`~anndata.AnnData.obsm` + when ``minified_data_type`` is ``"latent_posterior"``. Notes ----- The modification is not done inplace -- instead the model is assigned a new (minified) - version of the adata. + version of the :class:`~anndata.AnnData`. """ + if minified_data_type not in ADATA_MINIFY_TYPE: + raise NotImplementedError( + f"Minification method {minified_data_type} is not supported." + ) + elif not getattr(self.module, "use_observed_lib_size", True): + raise ValueError( + "Minification is not supported for models that do not use observed library size." + ) - @staticmethod - @abstractmethod - def _get_fields_for_adata_minification(minified_data_type: MinifiedDataType): - """Return the anndata fields required for adata minification of the given type.""" + keep_count_data = minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS + mini_adata = get_minified_adata_scrna( + adata_manager=self.adata_manager, + keep_count_data=keep_count_data, + ) + del mini_adata.uns[_SCVI_UUID_KEY] + mini_adata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type + mini_adata.obsm[self._LATENT_QZM_KEY] = self.adata.obsm[use_latent_qzm_key] + mini_adata.obsm[self._LATENT_QZV_KEY] = self.adata.obsm[use_latent_qzv_key] + mini_adata.obs[self._OBSERVED_LIB_SIZE_KEY] = np.squeeze( + np.asarray(self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY).sum(axis=-1)) + ) + self._update_adata_and_manager_post_minification( + mini_adata, + minified_data_type, + ) + self.module.minified_data_type = minified_data_type + + @classmethod + def _get_fields_for_adata_minification( + cls, + minified_data_type: MinifiedDataType, + ): + """Return the fields required for minification of the given type.""" + if minified_data_type not in ADATA_MINIFY_TYPE: + raise NotImplementedError( + f"Minification method {minified_data_type} is not supported." + ) + + mini_fields = [ + fields.ObsmField(REGISTRY_KEYS.LATENT_QZM_KEY, cls._LATENT_QZM_KEY), + fields.ObsmField(REGISTRY_KEYS.LATENT_QZV_KEY, cls._LATENT_QZV_KEY), + fields.NumericalObsField(REGISTRY_KEYS.OBSERVED_LIB_SIZE, cls._OBSERVED_LIB_SIZE_KEY), + fields.StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY), + ] + + return mini_fields def _update_adata_and_manager_post_minification( - self, minified_adata: AnnOrMuData, minified_data_type: MinifiedDataType + self, + minified_adata: AnnOrMuData, + minified_data_type: MinifiedDataType, ): - """Update the anndata and manager inplace after creating a minified adata.""" - # Register this new adata with the model, creating a new manager in the cache + """Update the :class:`~anndata.AnnData` and :class:`~scvi.data.AnnDataManager` in-place. + + Parameters + ---------- + minified_adata + Minified version of :attr:`~scvi.model.base.BaseModelClass.adata`. + minified_data_type + Method used for minifying the data. + keep_count_data + If ``True``, the full count matrix is kept in the minified + :attr:`~scvi.model.base.BaseModelClass.adata`. + """ self._validate_anndata(minified_adata) new_adata_manager = self.get_anndata_manager(minified_adata, required=True) - # This inplace edits the manager new_adata_manager.register_new_fields( self._get_fields_for_adata_minification(minified_data_type) ) - # We set the adata attribute of the model as this will update self.registry_ - # and self.adata_manager with the new adata manager self.adata = minified_adata @property diff --git a/src/scvi/model/base/_log_likelihood.py b/src/scvi/model/base/_log_likelihood.py index 97df6f557b..ea0552f61c 100644 --- a/src/scvi/model/base/_log_likelihood.py +++ b/src/scvi/model/base/_log_likelihood.py @@ -1,5 +1,6 @@ from __future__ import annotations +from inspect import signature from typing import TYPE_CHECKING import torch @@ -48,8 +49,14 @@ def compute_elbo( The evidence lower bound (ELBO) of the data. """ elbo = [] + if "full_forward_pass" in signature(module._get_inference_input).parameters: + get_inference_input_kwargs = {"full_forward_pass": True} + else: + get_inference_input_kwargs = {} for tensors in dataloader: - _, _, losses = module(tensors, **kwargs) + _, _, losses = module( + tensors, **kwargs, get_inference_input_kwargs=get_inference_input_kwargs + ) if isinstance(losses.reconstruction_loss, dict): reconstruction_loss = torch.stack(list(losses.reconstruction_loss.values())).sum(dim=0) else: @@ -99,9 +106,18 @@ def compute_reconstruction_error( A dictionary of the reconstruction error of the data. """ # Iterate once over the data and computes the reconstruction error + if "full_forward_pass" in signature(module._get_inference_input).parameters: + get_inference_input_kwargs = {"full_forward_pass": True} + else: + get_inference_input_kwargs = {} log_lkl = {} for tensors in dataloader: - _, _, loss_output = module(tensors, loss_kwargs={"kl_weight": 1}, **kwargs) + _, _, loss_output = module( + tensors, + loss_kwargs={"kl_weight": 1}, + get_inference_input_kwargs=get_inference_input_kwargs, + **kwargs, + ) if not isinstance(loss_output.reconstruction_loss, dict): rec_loss_dict = {"reconstruction_loss": loss_output.reconstruction_loss} else: diff --git a/src/scvi/model/base/_vaemixin.py b/src/scvi/model/base/_vaemixin.py index 7de4fe43c9..1a2ea85bfa 100644 --- a/src/scvi/model/base/_vaemixin.py +++ b/src/scvi/model/base/_vaemixin.py @@ -29,7 +29,7 @@ def get_elbo( adata: AnnData | None = None, indices: Sequence[int] | None = None, batch_size: int | None = None, - dataloader: Iterator[dict[str, Tensor | None]] = None, + dataloader: Iterator[dict[str, Tensor | None]] | None = None, return_mean: bool = True, **kwargs, ) -> float: @@ -167,7 +167,7 @@ def get_reconstruction_error( adata: AnnData | None = None, indices: Sequence[int] | None = None, batch_size: int | None = None, - dataloader: Iterator[dict[str, Tensor | None]] = None, + dataloader: Iterator[dict[str, Tensor | None]] | None = None, return_mean: bool = True, **kwargs, ) -> dict[str, float]: diff --git a/src/scvi/model/utils/_minification.py b/src/scvi/model/utils/_minification.py index cf84687bc5..77fb737ba1 100644 --- a/src/scvi/model/utils/_minification.py +++ b/src/scvi/model/utils/_minification.py @@ -1,43 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from anndata import AnnData from scipy.sparse import csr_matrix -from scvi._types import MinifiedDataType -from scvi.data._constants import ( - _ADATA_MINIFY_TYPE_UNS_KEY, - _SCVI_UUID_KEY, - ADATA_MINIFY_TYPE, -) +from scvi import REGISTRY_KEYS + +if TYPE_CHECKING: + from scvi.data import AnnDataManager def get_minified_adata_scrna( - adata: AnnData, - minified_data_type: MinifiedDataType, + adata_manager: AnnDataManager, + keep_count_data: bool = False, ) -> AnnData: - """Returns a minified adata that works for most scrna models (such as SCVI, SCANVI). - - Parameters - ---------- - adata - Original adata, of which we to create a minified version. - minified_data_type - How to minify the data. - """ - if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - - all_zeros = csr_matrix(adata.X.shape) - layers = {layer: all_zeros.copy() for layer in adata.layers} - bdata = AnnData( - X=all_zeros, - layers=layers, - uns=adata.uns.copy(), - obs=adata.obs, - var=adata.var, - varm=adata.varm, - obsm=adata.obsm, - obsp=adata.obsp, - ) - # Remove scvi uuid key to make bdata fresh w.r.t. the model's manager - del bdata.uns[_SCVI_UUID_KEY] - bdata.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = minified_data_type - return bdata + """Get a minified version of an :class:`~anndata.AnnData` or :class:`~mudata.MuData` object.""" + if keep_count_data: + return adata_manager.adata.copy() + else: + counts = adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) + all_zeros = csr_matrix(counts.shape) + X = all_zeros + layers = {layer: all_zeros.copy() for layer in adata_manager.adata.layers} + return AnnData( + X=X, + layers=layers, + obs=adata_manager.adata.obs.copy(), + var=adata_manager.adata.var.copy(), + uns=adata_manager.adata.uns.copy(), + obsm=adata_manager.adata.obsm.copy(), + varm=adata_manager.adata.varm.copy(), + obsp=adata_manager.adata.obsp.copy(), + varp=adata_manager.adata.varp.copy(), + ) diff --git a/src/scvi/module/_vae.py b/src/scvi/module/_vae.py index 920b65ca18..adfc2934ef 100644 --- a/src/scvi/module/_vae.py +++ b/src/scvi/module/_vae.py @@ -9,6 +9,7 @@ from torch.nn.functional import one_hot from scvi import REGISTRY_KEYS, settings +from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.module._constants import MODULE_KEYS from scvi.module.base import ( BaseMinifiedModeModuleClass, @@ -16,6 +17,7 @@ LossOutput, auto_move_data, ) +from scvi.utils import unsupported_if_adata_minified if TYPE_CHECKING: from collections.abc import Callable @@ -281,25 +283,32 @@ def __init__( def _get_inference_input( self, tensors: dict[str, torch.Tensor | None], + full_forward_pass: bool = False, ) -> dict[str, torch.Tensor | None]: """Get input tensors for the inference process.""" - from scvi.data._constants import ADATA_MINIFY_TYPE + if full_forward_pass or self.minified_data_type is None: + loader = "full_data" + elif self.minified_data_type in [ + ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS, + ]: + loader = "minified_data" + else: + raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}") - if self.minified_data_type is None: + if loader == "full_data": return { MODULE_KEYS.X_KEY: tensors[REGISTRY_KEYS.X_KEY], MODULE_KEYS.BATCH_INDEX_KEY: tensors[REGISTRY_KEYS.BATCH_KEY], MODULE_KEYS.CONT_COVS_KEY: tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None), MODULE_KEYS.CAT_COVS_KEY: tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None), } - elif self.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + else: return { MODULE_KEYS.QZM_KEY: tensors[REGISTRY_KEYS.LATENT_QZM_KEY], MODULE_KEYS.QZV_KEY: tensors[REGISTRY_KEYS.LATENT_QZV_KEY], REGISTRY_KEYS.OBSERVED_LIB_SIZE: tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE], } - else: - raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}") def _get_generative_input( self, @@ -414,14 +423,9 @@ def _cached_inference( """Run the cached inference process.""" from torch.distributions import Normal - from scvi.data._constants import ADATA_MINIFY_TYPE - - if self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}") - - dist = Normal(qzm, qzv.sqrt()) + qz = Normal(qzm, qzv.sqrt()) # use dist.sample() rather than rsample because we aren't optimizing the z here - untran_z = dist.sample() if n_samples == 1 else dist.sample((n_samples,)) + untran_z = qz.sample() if n_samples == 1 else qz.sample((n_samples,)) z = self.z_encoder.z_transformation(untran_z) library = torch.log(observed_lib_size) if n_samples > 1: @@ -429,8 +433,7 @@ def _cached_inference( return { MODULE_KEYS.Z_KEY: z, - MODULE_KEYS.QZM_KEY: qzm, - MODULE_KEYS.QZV_KEY: qzv, + MODULE_KEYS.QZ_KEY: qz, MODULE_KEYS.QL_KEY: None, MODULE_KEYS.LIBRARY_KEY: library, } @@ -541,6 +544,7 @@ def generative( MODULE_KEYS.PZ_KEY: pz, } + @unsupported_if_adata_minified def loss( self, tensors: dict[str, torch.Tensor], @@ -670,7 +674,9 @@ def marginal_ll( for _ in range(n_passes): # Distribution parameters and sampled variables inference_outputs, _, losses = self.forward( - tensors, inference_kwargs={"n_samples": n_mc_samples_per_pass} + tensors, + inference_kwargs={"n_samples": n_mc_samples_per_pass}, + get_inference_input_kwargs={"full_forward_pass": True}, ) qz = inference_outputs[MODULE_KEYS.QZ_KEY] ql = inference_outputs[MODULE_KEYS.QL_KEY] diff --git a/src/scvi/module/base/_base_module.py b/src/scvi/module/base/_base_module.py index 39097c9039..46395e8752 100644 --- a/src/scvi/module/base/_base_module.py +++ b/src/scvi/module/base/_base_module.py @@ -13,7 +13,6 @@ from torch import nn from scvi import settings -from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.utils._jax import device_selecting_PRNGKey from ._decorators import auto_move_data @@ -303,10 +302,7 @@ def inference(self, *args, **kwargs): Branches off to regular or cached inference depending on whether we have a minified adata that contains the latent posterior parameters. """ - if ( - self.minified_data_type is not None - and self.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR - ): + if "qzm" in kwargs.keys() and "qzv" in kwargs.keys(): return self._cached_inference(*args, **kwargs) else: return self._regular_inference(*args, **kwargs) @@ -743,6 +739,9 @@ def _generic_forward( loss_kwargs = _get_dict_if_none(loss_kwargs) get_inference_input_kwargs = _get_dict_if_none(get_inference_input_kwargs) get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs) + if not ("latent_qzm" in tensors.keys() and "latent_qzv" in tensors.keys()): + # Remove full_forward_pass if not minified model + get_inference_input_kwargs.pop("full_forward_pass", None) inference_inputs = module._get_inference_input(tensors, **get_inference_input_kwargs) inference_outputs = module.inference(**inference_inputs, **inference_kwargs) diff --git a/src/scvi/train/_trainingplans.py b/src/scvi/train/_trainingplans.py index 79aa4bf0e3..b1ab32e8f9 100644 --- a/src/scvi/train/_trainingplans.py +++ b/src/scvi/train/_trainingplans.py @@ -146,6 +146,7 @@ def __init__( optimizer: Literal["Adam", "AdamW", "Custom"] = "Adam", optimizer_creator: TorchOptimizerCreator | None = None, lr: float = 1e-3, + update_only_decoder: bool = False, weight_decay: float = 1e-6, eps: float = 0.01, n_steps_kl_warmup: int = None, @@ -180,6 +181,7 @@ def __init__( self.min_kl_weight = min_kl_weight self.max_kl_weight = max_kl_weight self.optimizer_creator = optimizer_creator + self.update_only_decoder = update_only_decoder if self.optimizer_name == "Custom" and self.optimizer_creator is None: raise ValueError("If optimizer is 'Custom', `optimizer_creator` must be provided.") @@ -275,7 +277,11 @@ def n_obs_validation(self, n_obs: int): def forward(self, *args, **kwargs): """Passthrough to the module's forward method.""" - return self.module(*args, **kwargs) + return self.module( + *args, + **kwargs, + get_inference_input_kwargs={"full_forward_pass": not self.update_only_decoder}, + ) @torch.inference_mode() def compute_and_log_metrics( diff --git a/src/scvi/utils/_decorators.py b/src/scvi/utils/_decorators.py index 156dc90e74..fe728ef33d 100644 --- a/src/scvi/utils/_decorators.py +++ b/src/scvi/utils/_decorators.py @@ -1,13 +1,15 @@ from collections.abc import Callable from functools import wraps +from scvi.data._constants import ADATA_MINIFY_TYPE + def unsupported_if_adata_minified(fn: Callable) -> Callable: """Decorator to raise an error if the model's `adata` is minified.""" @wraps(fn) def wrapper(self, *args, **kwargs): - if getattr(self, "minified_data_type", None) is not None: + if getattr(self, "minified_data_type", None) == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: raise ValueError( f"The {fn.__qualname__} function currently does not support minified data." ) diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index 5f87ed804c..52c9013362 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np import pytest @@ -6,51 +10,51 @@ from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _is_minified from scvi.model import SCANVI, SCVI +from scvi.model.base import BaseMinifiedModeModelClass + +if TYPE_CHECKING: + import numpy.typing as npt + from anndata import AnnData _SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size" _SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" -def prep_model(cls=SCVI, layer=None, use_size_factor=False): - # create a synthetic dataset +def prep_model( + cls: BaseMinifiedModeModelClass = SCVI, + layer: str | None = None, + use_size_factor: bool = False, + n_latent: int = 5, +) -> tuple[BaseMinifiedModeModelClass, AnnData, npt.NDArray, AnnData]: adata = synthetic_iid() - adata_counts = adata.X + counts = adata.X if use_size_factor: adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],)) if layer is not None: adata.layers[layer] = adata.X.copy() adata.X = np.zeros_like(adata.X) - adata.var["n_counts"] = np.squeeze(np.asarray(np.sum(adata_counts, axis=0))) - adata.varm["my_varm"] = np.random.negative_binomial(5, 0.3, size=(adata.shape[1], 3)) - adata.layers["my_layer"] = np.ones_like(adata.X) + adata_before_setup = adata.copy() - # run setup_anndata setup_kwargs = { "layer": layer, "batch_key": "batch", "labels_key": "labels", + "size_factor_key": "size_factor" if use_size_factor else None, } if cls == SCANVI: setup_kwargs["unlabeled_category"] = "unknown" - if use_size_factor: - setup_kwargs["size_factor_key"] = "size_factor" cls.setup_anndata( adata, **setup_kwargs, ) - # create and train the model - model = cls(adata, n_latent=5) + model = cls(adata, n_latent=n_latent) model.train(1, check_val_every_n_epoch=1, train_size=0.5) - # get the adata lib size - adata_lib_size = np.squeeze(np.asarray(adata_counts.sum(axis=1))) - assert ( - np.min(adata_lib_size) > 0 - ) # make sure it's not all zeros and there are no negative values + lib_size = np.squeeze(np.asarray(counts.sum(axis=-1))) - return model, adata, adata_lib_size, adata_before_setup + return model, adata, lib_size, adata_before_setup def assert_approx_equal(a, b): @@ -60,11 +64,11 @@ def assert_approx_equal(a, b): def run_test_for_model_with_minified_adata( - cls=SCVI, + cls: BaseMinifiedModeModelClass = SCVI, n_samples: int = 1, give_mean: bool = False, layer: str = None, - use_size_factor=False, + use_size_factor: bool = False, ): model, adata, adata_lib_size, _ = prep_model(cls, layer, use_size_factor) @@ -80,19 +84,14 @@ def run_test_for_model_with_minified_adata( assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR assert model.adata_manager.registry is model.registry_ - # make sure the original adata we set up the model with was not changed + assert not _is_minified(adata) assert adata is not model.adata - assert _is_minified(adata) is False - assert adata_orig.layers.keys() == model.adata.layers.keys() orig_obs_df = adata_orig.obs - obs_keys = _SCANVI_OBSERVED_LIB_SIZE if cls == SCANVI else _SCVI_OBSERVED_LIB_SIZE - orig_obs_df[obs_keys] = adata_lib_size + orig_obs_df[BaseMinifiedModeModelClass._OBSERVED_LIB_SIZE_KEY] = adata_lib_size assert model.adata.obs.equals(orig_obs_df) assert model.adata.var_names.equals(adata_orig.var_names) assert model.adata.var.equals(adata_orig.var) - assert model.adata.varm.keys() == adata_orig.varm.keys() - np.testing.assert_array_equal(model.adata.varm["my_varm"], adata_orig.varm["my_varm"]) scvi.settings.seed = 1 keys = ["mean", "dispersions", "dropout"] @@ -161,7 +160,9 @@ def test_scanvi_from_scvi(save_path): scvi.model.SCANVI.from_scvi_model(model, "label_0") msg = ( - "We cannot use the given scvi model to initialize scanvi because it has a minified adata." + "We cannot use the given scVI model to initialize scANVI because it has minified adata. " + "Keep counts when minifying model using minified_data_type=" + "'latent_posterior_parameters_with_counts'." ) assert str(e.value) == msg @@ -174,7 +175,7 @@ def test_scanvi_from_scvi(save_path): adata2.uns[_ADATA_MINIFY_TYPE_UNS_KEY] = ADATA_MINIFY_TYPE.LATENT_POSTERIOR with pytest.raises(ValueError) as e: scvi.model.SCANVI.from_scvi_model(loaded_model, "label_0", adata=adata2) - assert str(e.value) == "Please provide a non-minified `adata` to initialize scanvi." + assert str(e.value) == "Please provide a non-minified `adata` to initialize scANVI." scanvi_model = scvi.model.SCANVI.from_scvi_model(loaded_model, "label_0") scanvi_model.train(1) @@ -263,6 +264,43 @@ def test_validate_unsupported_if_minified(): model.get_latent_library_size() assert str(e.value) == common_err_msg.format("RNASeqMixin.get_latent_library_size") + with pytest.raises(ValueError) as e: + model.train() + assert str(e.value) == common_err_msg.format("VAE.loss") + + +def test_validate_supported_if_minified_keep_count(): + model, _, _, _ = prep_model() + model2, _, _, _ = prep_model() + + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + model.minify_adata(minified_data_type="latent_posterior_parameters_with_counts") + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS + assert model2.minified_data_type is None + + assert np.allclose(model2.get_elbo(), model.get_elbo(), rtol=5e-2) + assert np.allclose( + model2.get_reconstruction_error()["reconstruction_loss"], + model.get_reconstruction_error()["reconstruction_loss"], + rtol=5e-2, + ) + assert np.allclose(model2.get_marginal_ll(), model.get_marginal_ll(), rtol=5e-2) + + model.train(1, check_val_every_n_epoch=1, train_size=0.5) + model.train( + 1, check_val_every_n_epoch=1, train_size=0.5, plan_kwargs={"update_only_decoder": True} + ) + scanvi_model = scvi.model.SCANVI.from_scvi_model( + model, labels_key="labels", unlabeled_category="unknown" + ) + scanvi_model.train() + scanvi_model.train( + 1, check_val_every_n_epoch=1, train_size=0.5, plan_kwargs={"update_only_decoder": True} + ) + def test_scvi_with_minified_adata_save_then_load(save_path): # create a model and minify its adata, then save it and its adata.