Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support for minification in totalVI #3061

Merged
merged 14 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ to [Semantic Versioning]. Full commit history is available in the
- Refactored code for minified models. {pr}`2883`.
- Add {class}`scvi.external.METHYLVI` for modeling methylation data from single-cell
bisulfite sequencing (scBS-seq) experiments {pr}`2834`.
- Add MuData Minification option to {class}`~scvi.model.TOTALVI` {pr}`3061`.

#### Fixed

Expand Down
4 changes: 3 additions & 1 deletion src/scvi/data/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,12 @@ def _get_adata_minify_type(adata: AnnData) -> MinifiedDataType | None:
return adata.uns.get(_constants._ADATA_MINIFY_TYPE_UNS_KEY, None)


def _is_minified(adata: AnnData | str) -> bool:
def _is_minified(adata: AnnOrMuData | str) -> bool:
uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY
if isinstance(adata, AnnData):
return adata.uns.get(uns_key, None) is not None
elif isinstance(adata, MuData):
return adata.uns.get(uns_key, None) is not None
elif isinstance(adata, str):
with h5py.File(adata) as fp:
return uns_key in read_elem(fp["uns"]).keys()
Expand Down
15 changes: 11 additions & 4 deletions src/scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass):

_module_cls = SCANVAE
_training_plan_cls = SemiSupervisedTrainingPlan
_LATENT_QZM = "scanvi_latent_qzm"
_LATENT_QZV = "scanvi_latent_qzv"
_LATENT_QZM_KEY = "scanvi_latent_qzm"
_LATENT_QZV_KEY = "scanvi_latent_qzv"

def __init__(
self,
Expand Down Expand Up @@ -132,7 +132,10 @@ def __init__(
n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
library_log_means, library_log_vars = None, None
if not use_size_factor_key and self.minified_data_type is None:
if (
not use_size_factor_key
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
):
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

self.module = self._module_cls(
Expand Down Expand Up @@ -237,6 +240,7 @@ def from_scvi_model(
cls.setup_anndata(
adata,
unlabeled_category=unlabeled_category,
use_minified=False,
**scvi_setup_args,
)
scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs)
Expand Down Expand Up @@ -443,6 +447,7 @@ def setup_anndata(
size_factor_key: str | None = None,
categorical_covariate_keys: list[str] | None = None,
continuous_covariate_keys: list[str] | None = None,
use_minified: bool = True,
**kwargs,
):
"""%(summary)s.
Expand All @@ -457,6 +462,8 @@ def setup_anndata(
%(param_size_factor_key)s
%(param_cat_cov_keys)s
%(param_cont_cov_keys)s
use_minified
If True, will register the minified version of the adata if possible.
"""
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
Expand All @@ -469,7 +476,7 @@ def setup_anndata(
]
# register new fields if the adata is minified
adata_minify_type = _get_adata_minify_type(adata)
if adata_minify_type is not None:
if adata_minify_type is not None and use_minified:
canergen marked this conversation as resolved.
Show resolved Hide resolved
anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
Expand Down
10 changes: 7 additions & 3 deletions src/scvi/model/_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import ADATA_MINIFY_TYPE
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
CategoricalJointObsField,
Expand Down Expand Up @@ -100,8 +101,8 @@ class SCVI(
"""

_module_cls = VAE
_SCVI_LATENT_QZM = "scvi_latent_qzm"
_SCVI_LATENT_QZV = "scvi_latent_qzv"
_LATENT_QZM_KEY = "scvi_latent_qzm"
_LATENT_QZV_KEY = "scvi_latent_qzv"

def __init__(
self,
Expand Down Expand Up @@ -151,7 +152,10 @@ def __init__(
n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
library_log_means, library_log_vars = None, None
if not use_size_factor_key and self.minified_data_type is None:
if (
not use_size_factor_key
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
):
library_log_means, library_log_vars = _init_library_size(
self.adata_manager, n_batch
)
Expand Down
66 changes: 47 additions & 19 deletions src/scvi/model/_totalvi.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didnt we want to have a warning for setup_anndata: "TOTALVI is supposed to work with MuData"

Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager, fields
from scvi.data._utils import _check_nonnegative_integers
from scvi.data._constants import ADATA_MINIFY_TYPE
from scvi.data._utils import _check_nonnegative_integers, _get_adata_minify_type
from scvi.dataloaders import DataSplitter
from scvi.model._utils import (
_get_batch_code_from_category,
Expand All @@ -26,7 +27,13 @@
from scvi.train import AdversarialTrainingPlan, TrainRunner
from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp

from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin
from .base import (
ArchesMixin,
BaseMinifiedModeModelClass,
BaseMudataMinifiedModeModelClass,
RNASeqMixin,
VAEMixin,
)

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
Expand All @@ -35,18 +42,25 @@
from anndata import AnnData
from mudata import MuData

from scvi._types import Number
from scvi._types import AnnOrMuData, Number

logger = logging.getLogger(__name__)


class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
class TOTALVI(
RNASeqMixin,
VAEMixin,
ArchesMixin,
BaseMinifiedModeModelClass,
BaseMudataMinifiedModeModelClass,
):
"""total Variational Inference :cite:p:`GayosoSteier21`.
Parameters
canergen marked this conversation as resolved.
Show resolved Hide resolved
----------
adata
AnnData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`.
AnnOrMuData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`
or :meth:`~scvi.model.TOTALVI.setup_mudata`.
n_latent
Dimensionality of the latent space.
gene_dispersion
Expand Down Expand Up @@ -84,13 +98,12 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
Examples
--------
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> scvi.model.TOTALVI.setup_anndata(
adata, batch_key="batch", protein_expression_obsm_key="protein_expression"
)
>>> vae = scvi.model.TOTALVI(adata)
>>> mdata = mudata.read_h5mu(path_to_mudata)
>>> scvi.model.TOTALVI.setup_mudata(
mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"}
>>> vae = scvi.model.TOTALVI(mdata)
>>> vae.train()
>>> adata.obsm["X_totalVI"] = vae.get_latent_representation()
>>> mdata.obsm["X_totalVI"] = vae.get_latent_representation()
Notes
-----
Expand All @@ -102,13 +115,15 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
"""

_module_cls = TOTALVAE
_LATENT_QZM_KEY = "totalvi_latent_qzm"
_LATENT_QZV_KEY = "totalvi_latent_qzv"
_data_splitter_cls = DataSplitter
_training_plan_cls = AdversarialTrainingPlan
_train_runner_cls = TrainRunner

def __init__(
self,
adata: AnnData,
adata: AnnOrMuData,
n_latent: int = 20,
gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein",
Expand All @@ -129,10 +144,10 @@ def __init__(
batch_mask = self.protein_state_registry.protein_batch_mask
msg = (
"Some proteins have all 0 counts in some batches. "
+ "These proteins will be treated as missing measurements; however, "
+ "this can occur due to experimental design/biology. "
+ "Reinitialize the model with `override_missing_proteins=True`,"
+ "to override this behavior."
"These proteins will be treated as missing measurements; however, "
"this can occur due to experimental design/biology. "
"Reinitialize the model with `override_missing_proteins=True`,"
"to override this behavior."
)
warnings.warn(msg, UserWarning, stacklevel=settings.warnings_stacklevel)
self._use_adversarial_classifier = True
Expand All @@ -145,7 +160,7 @@ def __init__(
if empirical_protein_background_prior is not None
else (self.summary_stats.n_proteins > 10)
)
if emp_prior:
if emp_prior and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
prior_mean, prior_scale = self._get_totalvi_protein_priors(adata)
else:
prior_mean, prior_scale = None, None
Expand All @@ -161,7 +176,10 @@ def __init__(
n_batch = self.summary_stats.n_batch
use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
library_log_means, library_log_vars = None, None
if not use_size_factor_key:
if (
not use_size_factor_key
and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
):
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

self.module = self._module_cls(
Expand Down Expand Up @@ -1004,6 +1022,7 @@ def get_feature_correlation_matrix(
batch_size=batch_size,
rna_size_factor=rna_size_factor,
transform_batch=b,
indices=indices,
canergen marked this conversation as resolved.
Show resolved Hide resolved
)
flattened = np.zeros((denoised_data.shape[0] * n_samples, denoised_data.shape[1]))
for i in range(n_samples):
Expand Down Expand Up @@ -1214,6 +1233,12 @@ def setup_anndata(
-------
%(returns)s
"""
warnings.warn(
"We recommend using setup_mudata for multi-modal data."
"It does not influence model performance",
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)
setup_method_args = cls._get_setup_method_args(**locals())
batch_field = fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
anndata_fields = [
Expand Down Expand Up @@ -1275,7 +1300,7 @@ def setup_mudata(
--------
>>> mdata = muon.read_10x_h5("pbmc_10k_protein_v3_filtered_feature_bc_matrix.h5")
>>> scvi.model.TOTALVI.setup_mudata(
mdata, modalities={"rna_layer": "rna": "protein_layer": "prot"}
mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"}
)
>>> vae = scvi.model.TOTALVI(mdata)
"""
Expand Down Expand Up @@ -1330,6 +1355,9 @@ def setup_mudata(
mod_required=True,
),
]
mdata_minify_type = _get_adata_minify_type(mdata)
if mdata_minify_type is not None:
mudata_fields += cls._get_fields_for_mudata_minification(mdata_minify_type)
adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(mdata, **kwargs)
cls.register_manager(adata_manager)
3 changes: 3 additions & 0 deletions src/scvi/model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lightning.pytorch.trainer.connectors.accelerator_connector import (
_AcceleratorConnector,
)
from scipy.sparse import issparse

from scvi import REGISTRY_KEYS, settings
from scvi._types import Number
Expand Down Expand Up @@ -244,6 +245,8 @@ def cite_seq_raw_counts_properties(

nan = np.array([np.nan] * adata_manager.summary_stats.n_proteins)
protein_exp = adata_manager.get_from_registry(REGISTRY_KEYS.PROTEIN_EXP_KEY)
if issparse(protein_exp):
protein_exp = protein_exp.toarray()
mean1_pro = np.asarray(protein_exp[idx1].mean(0))
mean2_pro = np.asarray(protein_exp[idx2].mean(0))
nonz1_pro = np.asarray((protein_exp[idx1] > 0).mean(0))
Expand Down
7 changes: 6 additions & 1 deletion src/scvi/model/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from ._archesmixin import ArchesMixin
from ._base_model import BaseMinifiedModeModelClass, BaseModelClass
from ._base_model import (
BaseMinifiedModeModelClass,
BaseModelClass,
BaseMudataMinifiedModeModelClass,
)
from ._differential import DifferentialComputation
from ._embedding_mixin import EmbeddingMixin
from ._jaxmixin import JaxTrainingMixin
Expand All @@ -26,5 +30,6 @@
"DifferentialComputation",
"JaxTrainingMixin",
"BaseMinifiedModeModelClass",
"BaseMudataMinifiedModeModelClass",
"EmbeddingMixin",
]
Loading
Loading