diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index 02ccb667c1..f2085dd599 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -82,6 +82,9 @@ scvi-tools is composed of models that can perform one or many analysis tasks. In * - :doc:`/user_guide/models/methylvi` - Dimensionality reduction, removal of unwanted variation, integration across replicates, donors, and technologies, differential methylation, imputation, normalization of other cell- and sample-level confounding factors - :cite:p:`Weinberger2023a` + * - :doc:`/user_guide/models/methylanvi` + - MethylVI tasks along with cell type label transfer from reference, seed labeling + - :cite:p:`Weinberger2023a` ``` ## Multimodal analysis diff --git a/docs/user_guide/models/methylanvi.md b/docs/user_guide/models/methylanvi.md new file mode 100644 index 0000000000..e7986fea6e --- /dev/null +++ b/docs/user_guide/models/methylanvi.md @@ -0,0 +1,145 @@ +# MethylANVI + +**MethylANVI** [^ref1] (Python class {class}`~scvi.external.METHYLANVI`) is a semi-supervised generative model of scBS-seq data. +Similar to how scANVI extends scVI, MethylANVI can be treated as an extension of MethylVI that can leverage cell type annotations +for a subset of the cells present in the data sets to infer the states of the rest of the cells + +The advantages of MethylANVI are: + +- Comprehensive in capabilities. +- Scalable to very large datasets (>1 million cells). + +The limitations of MethylANVI include: + +- Effectively requires a GPU for fast inference. +- Latent space is not interpretable, unlike that of a linear method. +- May not scale to very large number of cell types. + +```{topic} Tutorials: + +- Work in progress. +``` + +## Preliminaries + +MethylANVI takes as input scBS-seq count matrices representing methylation measurements aggregated over pre-defined +regions of interest (e.g. gene bodies, known regulatory regions, etc.). Depending on the system being investigated, +such measurements may be separated based on methylation context (e.g. CpG methylation versus non-CpG methylation). +For each context, MethylANVI accepts two count matrices as input $Y^{C}_{mc}$ and $Y^{C}_{cov}$. Here $C$ refers to +an arbitrary methylation context, and each of these matrices has data from $N$ cells and $M$ genomic regions. +Each entry in $Y_{cov}$ represents the _total_ number of cytosines profiled at a given region in a cell, while the +entries in $Y_{mc}$ denote the number of _methylated_ cytosines in a region for a cell. + +In addition to methylation measurements, MethylANVI takes as input a vector of partially observed cell-type labels $\mathbf{l}$, +where $L$ denotes the total number of cell types. Additionally, a vector of categorical covariates $S$, representing batch, +donor, etc, is an optional input to the model. + +## Generative process + +MethylANVI posits that the observed number of methylated cytosines in context $C$ for cell $i$ in region $j$, +$y^{C}_{ij}$, is generated by the following process: + +```{math} +:nowrap: true + +\begin{align} + l_i &\sim \text{Categorical}(1/L, \ldots, 1/L) \\ + u_i &\sim \mathcal{N}(0, I_d) \\ + z_{i} &\sim \mathcal{N}(f_z^{\mu}(u_i, l_i), f_z^{\sigma}(u_i, l_i)) \\ + \mu^{C}_{ij} &= f_{\theta^{C}}(z_{i}, s_i)_j \\ + p^{C}_{ijk} &\sim \text{Beta}(\mu^{C}_{ij}, \gamma^{C}_j) \\ + y^{C}_{ijk} &\sim \text{Ber}(p^{C}_{ijk}) \\ + y^{C}_{ij} &= \sum_{k}y_{ijk} +\end{align} +``` + +Equivalently, we can express this process more compactly as + +```{math} +:nowrap: true + +\begin{align} + l_i &\sim \text{Categorical}(1/L, \ldots, 1/L) \\ + u_i &\sim \mathcal{N}(0, I_d) \\ + z_{i} &\sim \mathcal{N}(f_z^{\mu}(u_i, l_i), f_z^{\sigma}(u_i, l_i)) \\ + z_{i} &\sim \mathcal{N}(0, I_d) \\ + \mu^{C}_{ij} &= f_{\theta^{C}}(z_{i}, s_i)_j \\ + y^{C}_{ij} &\sim \text{BetaBinomial}(n^{C}_{ij}, \mu^{C}_{ij}, \gamma^{C}_{j}) +\end{align} +``` + +We assume no prior knowledge on the distribution of cell types in the data (i.e., we place a uniform prior on the +distribution of cell type labels). Within-cell-type variations $u_i$ are assumed to follow a fixed standard normal distribution, +while the distribution over the cell-type-aware latent variables $z_i$ depend on the learnable neural networks $f_z^{\mu}$ and +$f_z^{\sigma}$. The variables $z_i$ summarize a cell's state as a low-dimensional vector, and have a similar interpretation +as with MethylVI. However, by incorporating cell type labels into the model, MethylANVI may learn a better structured +latent space compared to MethylVI. + +The remainder of the model closely follows MethylVI. In particular, observed methylated cytosine counts are assumed +to follow a beta-binomial distribution conditioned on a cell's underlying state $z_i$ as well as batch covariates $s_i$. + +In addition to the variables defined for {doc}`/user_guide/models/methylvi`, we have the following variables for MethylANVI: + +```{eval-rst} +.. list-table:: + :widths: 20 90 15 + :header-rows: 1 + + * - Latent variable + - Description + - Code variable (if different) + * - :math:`l_i \in \Delta^{L-1}` + - Cell type label + - ``y`` + * - :math:`z_i \in \mathbb{R}^d` + - Latent cell state + - ``z_1`` + * - :math:`u_i \in \mathbb{R}^{d}` + - Latent cell-type specific state + - ``z_2`` +``` + +## Inference + +MethylANVI posits the following factorized distribution for posterior inference + +:nowrap: true + +\begin{align} + q_\phi(z_i, u_i, c_i \mid y_i, n_i, s_i) + = + q_\phi(z_i \mid y_i, n_i, s_i) + q_\phi(c_i \mid z_i) + q_\phi(u_i \mid c_i, z_i) +\end{align} + +Each of the individual variational distributions in our factorized expression is parameterized by neural +networks. Here $q_\phi(z_i \mid y_i, n_i, s_i)$ and $q_\phi(u_i \mid c_i, z_i)$ follow Gaussian distributions, while +$q_\phi(c_i \mid z_i)$ represents a Categorical distribution over cell types. Notably, $q_\phi(c_i \mid z_i)$ can be +leveraged post-training to predict cell types for an unlabeled cell. For this classification procedure, under the hood +we use as input the mean of the variational distribution $q_\phi(z_i \mid y_i, n_i, s_i)$. + +## Training details + +MethylANVI optimizes two evidence lower bounds (ELBOs) on the log evidence, with the two bounds corresponding to labeled +and unlabeled cells. These bounds largely mirror those of scANVI, with appropriate substitutions made to account for scBS-seq +observations. We refer the reader to the {doc}`/user_guide/models/scanvi` documentation for further details. + +## Tasks + +MethylANVI can perform the same tasks as MethylVI (see {doc}`/user_guide/models/methylvi`). In addition, MethylANVI can +do the following: + +### Cell type label prediction + +For cell type label prediction, MethylANVI returns the distribution $q_{\phi}(l_i \mid z_i)$ in the following +function: + +``` +>>> mdata.obs["methylanvi_prediction"] = model.predict() +``` + +[^ref1]: + Ethan Weinberger and Su-In Lee (2021), + _A deep generative model of single-cell methylomic data_, + [OpenReview](https://openreview.net/forum?id=Mg2DM0F3AY). diff --git a/src/scvi/data/fields/__init__.py b/src/scvi/data/fields/__init__.py index 9aa2bf394e..140bc4a834 100644 --- a/src/scvi/data/fields/__init__.py +++ b/src/scvi/data/fields/__init__.py @@ -26,7 +26,7 @@ from ._layer_field import LayerField, MuDataLayerField from ._mudata import BaseMuDataWrapperClass, MuDataWrapper from ._protein import MuDataProteinLayerField, ProteinObsmField -from ._scanvi import LabelsWithUnlabeledObsField +from ._scanvi import LabelsWithUnlabeledObsField, MuDataLabelsWithUnlabeledObsField from ._uns_field import StringUnsField __all__ = [ @@ -59,5 +59,6 @@ "MuDataCategoricalJointVarField", "ProteinObsmField", "LabelsWithUnlabeledObsField", + "MuDataLabelsWithUnlabeledObsField", "StringUnsField", ] diff --git a/src/scvi/data/fields/_scanvi.py b/src/scvi/data/fields/_scanvi.py index de264179fe..6e1d443b5a 100644 --- a/src/scvi/data/fields/_scanvi.py +++ b/src/scvi/data/fields/_scanvi.py @@ -8,6 +8,7 @@ from scvi.data._utils import _make_column_categorical, _set_data_in_registry from ._dataframe_field import CategoricalObsField +from ._mudata import MuDataWrapper class LabelsWithUnlabeledObsField(CategoricalObsField): @@ -107,3 +108,6 @@ def transfer_field( ) mapping = transfer_state_registry[self.CATEGORICAL_MAPPING_KEY] return self._remap_unlabeled_to_final_category(adata_target, mapping) + + +MuDataLabelsWithUnlabeledObsField = MuDataWrapper(LabelsWithUnlabeledObsField) diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index 9ea0146acb..5504fad503 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -410,6 +410,7 @@ def __init__( adata_manager.adata, adata_manager.data_registry.labels.attr_name, labels_state_registry.original_key, + mod_key=getattr(self.adata_manager.data_registry.labels, "mod_key", None), ).ravel() self.unlabeled_category = labels_state_registry.unlabeled_category self._unlabeled_indices = np.argwhere(labels == self.unlabeled_category).ravel() diff --git a/src/scvi/dataloaders/_semi_dataloader.py b/src/scvi/dataloaders/_semi_dataloader.py index 545f3c6d9b..4df1b591e1 100644 --- a/src/scvi/dataloaders/_semi_dataloader.py +++ b/src/scvi/dataloaders/_semi_dataloader.py @@ -59,6 +59,7 @@ def __init__( adata_manager.adata, adata_manager.data_registry.labels.attr_name, labels_state_registry.original_key, + mod_key=getattr(adata_manager.data_registry.labels, "mod_key", None), ).ravel() # save a nested list of the indices per labeled category diff --git a/src/scvi/external/__init__.py b/src/scvi/external/__init__.py index 8e9e449f43..a172fc2715 100644 --- a/src/scvi/external/__init__.py +++ b/src/scvi/external/__init__.py @@ -2,7 +2,7 @@ from .contrastivevi import ContrastiveVI from .decipher import Decipher from .gimvi import GIMVI -from .methylvi import METHYLVI +from .methylvi import METHYLANVI, METHYLVI from .mrvi import MRVI from .poissonvi import POISSONVI from .resolvi import RESOLVI @@ -28,5 +28,6 @@ "VELOVI", "MRVI", "METHYLVI", + "METHYLANVI", "RESOLVI", ] diff --git a/src/scvi/external/methylvi/__init__.py b/src/scvi/external/methylvi/__init__.py index 7ed81dc07f..87442bea9d 100644 --- a/src/scvi/external/methylvi/__init__.py +++ b/src/scvi/external/methylvi/__init__.py @@ -1,6 +1,7 @@ from ._base_components import DecoderMETHYLVI from ._constants import METHYLVI_REGISTRY_KEYS -from ._model import METHYLVI as METHYLVI -from ._module import METHYLVAE +from ._methylanvi_model import METHYLANVI as METHYLANVI +from ._methylvi_model import METHYLVI as METHYLVI +from ._methylvi_module import METHYLVAE -__all__ = ["METHYLVI_REGISTRY_KEYS", "DecoderMETHYLVI", "METHYLVAE", "METHYLVI"] +__all__ = ["METHYLVI_REGISTRY_KEYS", "DecoderMETHYLVI", "METHYLVAE", "METHYLVI", "METHYLANVI"] diff --git a/src/scvi/external/methylvi/_base_components.py b/src/scvi/external/methylvi/_base_components.py index c8690054b2..e618fabd9a 100644 --- a/src/scvi/external/methylvi/_base_components.py +++ b/src/scvi/external/methylvi/_base_components.py @@ -1,10 +1,414 @@ -from collections.abc import Iterable +from __future__ import annotations + +import logging +import warnings +from collections import defaultdict +from functools import partial +from typing import TYPE_CHECKING -import torch from torch import nn +from torch.distributions import Binomial +from scvi.distributions import BetaBinomial +from scvi.external.methylvi._utils import METHYLVI_REGISTRY_KEYS from scvi.nn import FCLayers +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from typing import Literal + + from mudata import MuData + + from scvi._types import Number + +import numpy as np +import pandas as pd +import torch + +from scvi import settings +from scvi.model.base._de_core import ( + _de_core, +) + +from ._utils import scmc_raw_counts_properties + +logger = logging.getLogger(__name__) + + +class BSSeqMixin: + """General purpose methods for BS-seq analysis.""" + + @torch.inference_mode() + def get_normalized_methylation( + self, + mdata: MuData | None = None, + indices: Sequence[int] | None = None, + region_list: Sequence[str] | None = None, + n_samples: int = 1, + n_samples_overall: int = None, + batch_size: int | None = None, + return_mean: bool = True, + return_numpy: bool | None = None, + context: str | None = None, + **importance_weighting_kwargs, + ) -> (np.ndarray | pd.DataFrame) | dict[str, np.ndarray | pd.DataFrame]: + r"""Returns the normalized (decoded) methylation. + + This is denoted as :math:`\mu_n` in the methylVI paper. + + Parameters + ---------- + mdata + MuData object with equivalent structure to initial Mudata. + If `None`, defaults to the MuData object used to initialize the model. + indices + Indices of cells in mdata to use. If `None`, all cells are used. + region_list + Return frequencies of expression for a subset of regions. + This can save memory when working with large datasets and few regions are + of interest. + n_samples + Number of posterior samples to use for estimation. + n_samples_overall + Number of posterior samples to use for estimation. Overrides `n_samples`. + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + return_mean + Whether to return the mean of the samples. + return_numpy + Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. + DataFrame includes region names as columns. If either `n_samples=1` or + `return_mean=True`, defaults to `False`. Otherwise, it defaults to `True`. + context + If not `None`, returns normalized methylation levels for the specified + methylation context. Otherwise, a dictionary with contexts as keys and normalized + methylation levels as values is returned. + + Returns + ------- + If `n_samples` is provided and `return_mean` is False, + this method returns a 3d tensor of shape (n_samples, n_cells, n_regions). + If `n_samples` is provided and `return_mean` is True, it returns a 2d tensor + of shape (n_cells, n_regions). + In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. + Otherwise, the method expects `n_samples_overall` to be provided and returns a 2d tensor + of shape (n_samples_overall, n_regions). + + If model was set up using a MuData object, a dictionary is returned with keys + corresponding to individual methylation contexts with values determined as + described above. + """ + mdata = self._validate_anndata(mdata) + + if context is not None and context not in self.contexts: + raise ValueError( + f"{context} is not a valid methylation context for this model. " + f"Valid contexts are {self.contexts}." + ) + + if indices is None: + indices = np.arange(mdata.n_obs) + if n_samples_overall is not None: + assert n_samples == 1 # default value + n_samples = n_samples_overall // len(indices) + 1 + scdl = self._make_data_loader(adata=mdata, indices=indices, batch_size=batch_size) + + region_mask = slice(None) if region_list is None else mdata.var_names.isin(region_list) + + if n_samples > 1 and return_mean is False: + if return_numpy is False: + warnings.warn( + "`return_numpy` must be `True` if `n_samples > 1` and `return_mean` " + "is`False`, returning an `np.ndarray`.", + UserWarning, + stacklevel=settings.warnings_stacklevel, + ) + return_numpy = True + + exprs = defaultdict(list) + + for tensors in scdl: + inference_kwargs = {"n_samples": n_samples} + inference_outputs, generative_outputs = self.module.forward( + tensors=tensors, + inference_kwargs=inference_kwargs, + generative_kwargs={}, + compute_loss=False, + ) + + for ctxt in self.contexts: + exp_ = generative_outputs["px_mu"][ctxt] + exp_ = exp_[..., region_mask] + exprs[ctxt].append(exp_.cpu()) + + cell_axis = 1 if n_samples > 1 else 0 + + for ctxt in self.contexts: + exprs[ctxt] = np.concatenate(exprs[ctxt], axis=cell_axis) + + if n_samples_overall is not None: + # Converts the 3d tensor to a 2d tensor + for ctxt in self.contexts: + exprs[ctxt] = exprs[ctxt].reshape(-1, exprs[ctxt].shape[-1]) + n_samples_ = exprs[ctxt].shape[0] + ind_ = np.random.choice(n_samples_, n_samples_overall, replace=True) + exprs[ctxt] = exprs[ctxt][ind_] + return_numpy = True + + elif n_samples > 1 and return_mean: + for ctxt in self.contexts: + exprs[ctxt] = exprs[ctxt].mean(0) + + if return_numpy is None or return_numpy is False: + exprs_dfs = {} + for ctxt in self.contexts: + exprs_dfs[ctxt] = pd.DataFrame( + exprs[ctxt], + columns=mdata[ctxt].var_names[region_mask], + index=mdata[ctxt].obs_names[indices], + ) + exprs_ = exprs_dfs + else: + exprs_ = exprs + + if context is not None: + return exprs_[context] + else: + return exprs_ + + @torch.inference_mode() + def get_specific_normalized_methylation( + self, + mdata: MuData | None = None, + context: str = None, + indices: Sequence[int] | None = None, + transform_batch: Sequence[Number | str] | None = None, + region_list: Sequence[str] | None = None, + n_samples: int = 1, + n_samples_overall: int = None, + weights: Literal["uniform", "importance"] | None = None, + batch_size: int | None = None, + return_mean: bool = True, + return_numpy: bool | None = None, + **importance_weighting_kwargs, + ) -> (np.ndarray | pd.DataFrame) | dict[str, np.ndarray | pd.DataFrame]: + r"""Convenience function to obtain normalized methylation values for a single context. + + Parameters + ---------- + mdata + MuData object with equivalent structure to initial MuData. If `None`, defaults to the + MuData object used to initialize the model. + context + Methylation context for which to obtain normalized methylation levels. + indices + Indices of cells in mdata to use. If `None`, all cells are used. + transform_batch + Batch to condition on. + If transform_batch is: + + - None, then real observed batch is used. + - int, then batch transform_batch is used. + region_list + Return frequencies of expression for a subset of regions. + This can save memory when working with large datasets and few regions are + of interest. + n_samples + Number of posterior samples to use for estimation. + n_samples_overall + Number of posterior samples to use for estimation. Overrides `n_samples`. + weights + Weights to use for sampling. If `None`, defaults to `"uniform"`. + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + return_mean + Whether to return the mean of the samples. + return_numpy + Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. + DataFrame includes region names as columns. If either `n_samples=1` or + `return_mean=True`, defaults to `False`. Otherwise, it defaults to `True`. + importance_weighting_kwargs + Keyword arguments passed into + :meth:`~scvi.model.base.RNASeqMixin._get_importance_weights`. + + Returns + ------- + If `n_samples` is provided and `return_mean` is False, + this method returns a 3d tensor of shape (n_samples, n_cells, n_regions). + If `n_samples` is provided and `return_mean` is True, it returns a 2d tensor + of shape (n_cells, n_regions). + In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. + Otherwise, the method expects `n_samples_overall` to be provided and returns a 2d tensor + of shape (n_samples_overall, n_regions). + """ + exprs = self.get_normalized_methylation( + mdata=mdata, + indices=indices, + transform_batch=transform_batch, + region_list=region_list, + n_samples=n_samples, + n_samples_overall=n_samples_overall, + weights=weights, + batch_size=batch_size, + return_mean=return_mean, + return_numpy=return_numpy, + **importance_weighting_kwargs, + ) + return exprs[context] + + def differential_methylation( + self, + mdata: MuData | None = None, + groupby: str | None = None, + group1: Iterable[str] | None = None, + group2: str | None = None, + idx1: Sequence[int] | Sequence[bool] | str | None = None, + idx2: Sequence[int] | Sequence[bool] | str | None = None, + mode: Literal["vanilla", "change"] = "vanilla", + delta: float = 0.05, + batch_size: int | None = None, + all_stats: bool = True, + batch_correction: bool = False, + batchid1: Iterable[str] | None = None, + batchid2: Iterable[str] | None = None, + fdr_target: float = 0.05, + silent: bool = False, + two_sided: bool = True, + **kwargs, + ) -> dict[str, pd.DataFrame] | pd.DataFrame: + r"""\. + + A unified method for differential methylation analysis. + + Implements `"vanilla"` DE :cite:p:`Lopez18`. and `"change"` mode DE :cite:p:`Boyeau19`. + + Parameters + ---------- + %(de_mdata)s + %(de_modality)s + %(de_groupby)s + %(de_group1)s + %(de_group2)s + %(de_idx1)s + %(de_idx2)s + %(de_mode)s + %(de_delta)s + %(de_batch_size)s + %(de_all_stats)s + %(de_batch_correction)s + %(de_batchid1)s + %(de_batchid2)s + %(de_fdr_target)s + %(de_silent)s + two_sided + Whether to perform a two-sided test, or a one-sided test. + **kwargs + Keyword args for :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors` + + Returns + ------- + Differential methylation DataFrame with the following columns: + proba_de + the probability of the region being differentially methylated + is_de_fdr + whether the region passes a multiple hypothesis correction procedure + with the target_fdr threshold + bayes_factor + Bayes Factor indicating the level of significance of the analysis + effect_size + the effect size, computed as (accessibility in population 2) - + (accessibility in population 1) + emp_effect + the empirical effect, based on observed detection rates instead of the estimated + accessibility scores from the methylVI model + scale1 + the estimated methylation level in population 1 + scale2 + the estimated methylation level in population 2 + emp_mean1 + the empirical (observed) methylation level in population 1 + emp_mean2 + the empirical (observed) methylation level in population 2 + + """ + mdata = self._validate_anndata(mdata) + + def change_fn(a, b): + return a - b + + if two_sided: + + def m1_domain_fn(samples): + return np.abs(samples) >= delta + + else: + + def m1_domain_fn(samples): + return samples >= delta + + result = {} + for context in self.contexts: + col_names = mdata[context].var_names + model_fn = partial( + self.get_specific_normalized_methylation, + batch_size=batch_size, + context=context, + ) + all_stats_fn = partial(scmc_raw_counts_properties, context=context) + + result[context] = _de_core( + adata_manager=self.get_anndata_manager(mdata, required=True), + model_fn=model_fn, + representation_fn=None, + groupby=groupby, + group1=group1, + group2=group2, + idx1=idx1, + idx2=idx2, + all_stats=all_stats, + all_stats_fn=all_stats_fn, + col_names=col_names, + mode=mode, + batchid1=batchid1, + batchid2=batchid2, + delta=delta, + batch_correction=batch_correction, + fdr=fdr_target, + silent=silent, + change_fn=change_fn, + m1_domain_fn=m1_domain_fn, + **kwargs, + ) + + return result + + +class BSSeqModuleMixin: + """Shared methods for BS-seq VAE modules.""" + + data_input_keys = [METHYLVI_REGISTRY_KEYS.MC_KEY, METHYLVI_REGISTRY_KEYS.COV_KEY] + + def _compute_minibatch_reconstruction_loss(self, minibatch_size, tensors, generative_outputs): + reconst_loss = torch.zeros(minibatch_size).to(self.device) + + for context in self.contexts: + px_mu = generative_outputs["px_mu"][context] + px_gamma = generative_outputs["px_gamma"][context] + mc = tensors[f"{context}_{METHYLVI_REGISTRY_KEYS.MC_KEY}"] + cov = tensors[f"{context}_{METHYLVI_REGISTRY_KEYS.COV_KEY}"] + + if self.dispersion == "region": + px_gamma = torch.sigmoid(self.px_gamma[context]) + + if self.likelihood == "binomial": + dist = Binomial(probs=px_mu, total_count=cov) + elif self.likelihood == "betabinomial": + dist = BetaBinomial(mu=px_mu, gamma=px_gamma, total_count=cov) + + reconst_loss += -dist.log_prob(mc).sum(dim=-1) + + return reconst_loss + class DecoderMETHYLVI(nn.Module): """Decodes data from latent space of ``n_input`` dimensions into ``n_output`` dimensions. diff --git a/src/scvi/external/methylvi/_methylanvi_model.py b/src/scvi/external/methylvi/_methylanvi_model.py new file mode 100644 index 0000000000..64a5e6fd0a --- /dev/null +++ b/src/scvi/external/methylvi/_methylanvi_model.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterable + from typing import Literal + + from anndata import AnnData + from mudata import MuData + +import numpy as np + +from scvi import REGISTRY_KEYS +from scvi.data import AnnDataManager, fields +from scvi.data._constants import _SETUP_ARGS_KEY +from scvi.external.methylvi._base_components import BSSeqMixin +from scvi.external.methylvi._utils import _context_cov_key, _context_mc_key +from scvi.model.base import ( + ArchesMixin, + BaseModelClass, + SemisupervisedTrainingMixin, + VAEMixin, +) +from scvi.train import SemiSupervisedTrainingPlan +from scvi.utils import setup_anndata_dsp + +from ._methylanvi_module import METHYLANVAE + +logger = logging.getLogger(__name__) + + +class METHYLANVI(VAEMixin, SemisupervisedTrainingMixin, BSSeqMixin, ArchesMixin, BaseModelClass): + """Methylation annotation using variational inference :cite:p:`Weinberger23`. + + Inspired from M1 + M2 model, as described in (https://arxiv.org/pdf/1406.5298.pdf). + + Parameters + ---------- + mdata + MuData object registered via :meth:`~scvi.external.methylvi.METHYLVI.setup_mudata`. + n_hidden + Number of nodes per hidden layer. + n_latent + Dimensionality of the latent space. + n_layers + Number of hidden layers used for encoder and decoder NNs. + dropout_rate + Dropout rate for neural networks. + likelihood + One of + * ``'betabinomial'`` - BetaBinomial distribution + * ``'binomial'`` - Binomial distribution + dispersion + One of the following + * ``'region'`` - dispersion parameter of BetaBinomial is constant per region across cells + * ``'region-cell'`` - dispersion can differ for every region in every cell + linear_classifier + If ``True``, uses a single linear layer for classification instead of a + multi-layer perceptron. + **model_kwargs + Keyword args for :class:`~scvi.module.SCANVAE` + + Examples + -------- + >>> mdata = mudata.read_h5mu(path_to_mudata) + >>> scvi.external.methylvi.METHYLANVI.setup_mudata( + ... mdata, labels_key="labels", unlabeled_category="Unknown" + ... ) + >>> vae = scvi.external.methylvi.METHYLANVI(mdata) + >>> vae.train() + >>> mdata.obsm["X_scVI"] = vae.get_latent_representation() + >>> mdata.obs["pred_label"] = vae.predict() + + """ + + _module_cls = METHYLANVAE + _training_plan_cls = SemiSupervisedTrainingPlan + + def __init__( + self, + mdata: MuData, + n_hidden: int = 128, + n_latent: int = 10, + n_layers: int = 1, + dropout_rate: float = 0.1, + likelihood: Literal["betabinomial", "binomial"] = "betabinomial", + dispersion: Literal["region", "region-cell"] = "region", + linear_classifier: bool = False, + **model_kwargs, + ): + super().__init__(mdata) + methylanvae_model_kwargs = dict(model_kwargs) + + self._set_indices_and_labels() + + # ignores unlabeled category + n_labels = self.summary_stats.n_labels - 1 + n_cats_per_cov = ( + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) + + n_batch = self.summary_stats.n_batch + + self.contexts = self.get_anndata_manager(mdata, required=True).registry[_SETUP_ARGS_KEY][ + "methylation_contexts" + ] + self.num_features_per_context = [mdata[context].shape[1] for context in self.contexts] + + n_input = np.sum(self.num_features_per_context) + + self.module = self._module_cls( + n_input=n_input, + n_batch=n_batch, + n_cats_per_cov=n_cats_per_cov, + n_labels=n_labels, + n_hidden=n_hidden, + n_latent=n_latent, + n_layers=n_layers, + dropout_rate=dropout_rate, + dispersion=dispersion, + likelihood=likelihood, + linear_classifier=linear_classifier, + contexts=self.contexts, + num_features_per_context=self.num_features_per_context, + **methylanvae_model_kwargs, + ) + + self.unsupervised_history_ = None + self.semisupervised_history_ = None + + self._model_summary_string = ( + f"MethylANVI Model with the following params: \nunlabeled_category: " + f"{self.unlabeled_category_}, n_hidden: {n_hidden}, n_latent: {n_latent}" + f", n_layers: {n_layers}, dropout_rate: {dropout_rate}, dispersion: " + f"{dispersion}, likelihood: {likelihood}" + ) + self.init_params_ = self._get_init_params(locals()) + self.was_pretrained = False + self.n_labels = n_labels + + @classmethod + @setup_anndata_dsp.dedent + def setup_anndata( + cls, + adata: AnnData, + **kwargs, + ) -> AnnData | None: + """ + %(summary)s. + + Parameters + ---------- + %(param_adata)s + + Returns + ------- + %(returns)s + """ + raise NotImplementedError("METHYLANVI must be used with a MuData object.") + + @classmethod + @setup_anndata_dsp.dedent + def setup_mudata( + cls, + mdata: MuData, + mc_layer: str, + cov_layer: str, + labels_key: str, + unlabeled_category: str, + methylation_contexts: Iterable[str], + batch_key: str | None = None, + categorical_covariate_keys: Iterable[str] | None = None, + modalities=None, + **kwargs, + ): + """%(summary_mdata)s. + + Parameters + ---------- + %(param_mdata)s + mc_layer + Layer containing methylated cytosine counts for each set of methylation features. + cov_layer + Layer containing total coverage counts for each set of methylation features. + labels_key + Obs field in `mdata` object containing cell type labels + unlabeled_category + Value of `mdata.obs[labels_key]` representing an unknown cell type label + methylation_contexts + List of modality fields in `mdata` object representing different methylation contexts. + Each context must be equipped with a layer containing the number of methylated counts + (specified by `mc_layer`) and total number of counts (specified by `cov_layer`) for + each genomic region feature. + %(param_batch_key)s + %(param_categorical_covariate_keys)s + %(param_modalities)s + + Examples + -------- + METHYLANVI.setup_mudata( + mdata, + mc_layer="mc", + cov_layer="cov", + labels_key="CellType", + unlabeled_category="Unknown", + methylation_contexts=["mCG", "mCH"], + categorical_covariate_keys=["Platform"], + modalities={ + "categorical_covariate_keys": "mCG" + }, + ) + + """ + if modalities is None: + modalities = {} + setup_method_args = METHYLANVI._get_setup_method_args(**locals()) + + if methylation_contexts is None: + raise ValueError("Methylation contexts cannot be None.") + + modalities_ = cls._create_modalities_attr_dict(modalities, setup_method_args) + + batch_field = fields.MuDataCategoricalObsField( + REGISTRY_KEYS.BATCH_KEY, + batch_key, + mod_key=modalities_.batch_key, + ) + + cat_cov_field = fields.MuDataCategoricalJointObsField( + REGISTRY_KEYS.CAT_COVS_KEY, + categorical_covariate_keys, + mod_key=modalities_.categorical_covariate_keys, + ) + + cell_type_field = fields.MuDataLabelsWithUnlabeledObsField( + REGISTRY_KEYS.LABELS_KEY, + labels_key, + unlabeled_category, + mod_key=modalities_.labels_key, + ) + + mc_fields = [] + cov_fields = [] + + for context in methylation_contexts: + mc_fields.append( + fields.MuDataLayerField( + _context_mc_key(context), + mc_layer, + mod_key=context, + is_count_data=True, + mod_required=True, + ) + ) + + cov_fields.append( + fields.MuDataLayerField( + _context_cov_key(context), + cov_layer, + mod_key=context, + is_count_data=True, + mod_required=True, + ) + ) + + mudata_fields = ( + mc_fields + cov_fields + [batch_field] + [cat_cov_field] + [cell_type_field] + ) + adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) + adata_manager.register_fields(mdata, **kwargs) + + cls.register_manager(adata_manager) diff --git a/src/scvi/external/methylvi/_methylanvi_module.py b/src/scvi/external/methylvi/_methylanvi_module.py new file mode 100644 index 0000000000..99f00f9562 --- /dev/null +++ b/src/scvi/external/methylvi/_methylanvi_module.py @@ -0,0 +1,361 @@ +"""PyTorch module for methylVI for single cell methylation data.""" + +from collections.abc import Iterable, Sequence +from typing import Literal + +import numpy as np +import torch +from torch.distributions import Categorical, Normal +from torch.distributions import kl_divergence as kl +from torch.nn import functional as F + +from scvi import REGISTRY_KEYS +from scvi.external.methylvi._base_components import BSSeqModuleMixin +from scvi.external.methylvi._methylvi_module import METHYLVAE +from scvi.module._classifier import Classifier +from scvi.module._utils import broadcast_labels +from scvi.module.base import LossOutput, auto_move_data +from scvi.nn import Decoder, Encoder + + +class METHYLANVAE(METHYLVAE, BSSeqModuleMixin): + """Methylation annotation using variational inference. + + This is an implementation of the MethylANVI model described in :cite:p:`Weinberger2023a`. + + Parameters + ---------- + n_input + Number of input genes + n_batch + Number of batches + n_labels + Number of labels + n_hidden + Number of nodes per hidden layer + n_latent + Dimensionality of the latent space + n_layers + Number of hidden layers used for encoder and decoder NNs + n_continuous_cov + Number of continuous covarites + n_cats_per_cov + Number of categories for each extra categorical covariate + dropout_rate + Dropout rate for neural networks + likelihood + One of + * ``'betabinomial'`` - BetaBinomial distribution + * ``'binomial'`` - Binomial distribution + dispersion + One of the following + * ``'region'`` - dispersion parameter of BetaBinomial is constant per region across cells + * ``'region-cell'`` - dispersion can differ for every region in every cell + log_variational + Log(data+1) prior to encoding for numerical stability. Not normalization. + y_prior + If None, initialized to uniform probability over cell types + labels_groups + Label group designations + use_labels_groups + Whether to use the label groups + linear_classifier + If `True`, uses a single linear layer for classification instead of a + multi-layer perceptron. + classifier_parameters + Keyword arguments passed into :class:`~scvi.module.Classifier`. + use_batch_norm + Whether to use batch norm in layers + use_layer_norm + Whether to use layer norm in layers + linear_classifier + If ``True``, uses a single linear layer for classification instead of a + multi-layer perceptron. + **model_kwargs + Keyword args for :class:`~scvi.external.methylvi.METHYLVAE` + """ + + def __init__( + self, + n_input: int, + contexts: Iterable[str], + num_features_per_context: Iterable[int], + n_batch: int = 0, + n_cats_per_cov: Iterable[int] | None = None, + n_labels: int = 0, + n_hidden: int = 128, + n_latent: int = 10, + n_layers: int = 1, + dropout_rate: float = 0.1, + likelihood: Literal["betabinomial", "binomial"] = "betabinomial", + dispersion: Literal["region", "region-cell"] = "region", + y_prior=None, + labels_groups: Sequence[int] = None, + use_labels_groups: bool = False, + linear_classifier: bool = False, + classifier_parameters: dict | None = None, + use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", + use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none", + **model_kwargs, + ): + super().__init__( + n_input=n_input, + n_hidden=n_hidden, + n_latent=n_latent, + n_layers=n_layers, + n_batch=n_batch, + n_cats_per_cov=n_cats_per_cov, + contexts=contexts, + num_features_per_context=num_features_per_context, + likelihood=likelihood, + dispersion=dispersion, + **model_kwargs, + ) + + classifier_parameters = classifier_parameters or {} + use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both" + use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both" + use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both" + use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both" + + self.n_labels = n_labels + # Classifier takes n_latent as input + cls_parameters = { + "n_layers": 0 if linear_classifier else n_layers, + "n_hidden": 0 if linear_classifier else n_hidden, + "dropout_rate": dropout_rate, + "logits": True, + } + cls_parameters.update(classifier_parameters) + self.classifier = Classifier( + n_latent, + n_labels=n_labels, + use_batch_norm=use_batch_norm_encoder, + use_layer_norm=use_layer_norm_encoder, + **cls_parameters, + ) + + self.encoder_z2_z1 = Encoder( + n_latent, + n_latent, + n_cat_list=[self.n_labels], + n_layers=n_layers, + n_hidden=n_hidden, + dropout_rate=dropout_rate, + use_batch_norm=use_batch_norm_encoder, + use_layer_norm=use_layer_norm_encoder, + return_dist=True, + ) + + self.decoder_z1_z2 = Decoder( + n_latent, + n_latent, + n_cat_list=[self.n_labels], + n_layers=n_layers, + n_hidden=n_hidden, + use_batch_norm=use_batch_norm_decoder, + use_layer_norm=use_layer_norm_decoder, + ) + + self.y_prior = torch.nn.Parameter( + y_prior if y_prior is not None else (1 / n_labels) * torch.ones(1, n_labels), + requires_grad=False, + ) + self.use_labels_groups = use_labels_groups + self.labels_groups = np.array(labels_groups) if labels_groups is not None else None + if self.use_labels_groups: + if labels_groups is None: + raise ValueError("Specify label groups") + unique_groups = np.unique(self.labels_groups) + self.n_groups = len(unique_groups) + if not (unique_groups == np.arange(self.n_groups)).all(): + raise ValueError() + self.classifier_groups = Classifier( + n_latent, n_hidden, self.n_groups, n_layers, dropout_rate + ) + self.groups_index = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.tensor( + (self.labels_groups == i).astype(np.uint8), + dtype=torch.uint8, + ), + requires_grad=False, + ) + for i in range(self.n_groups) + ] + ) + + @auto_move_data + def classify( + self, + mc: torch.Tensor, + cov: torch.Tensor, + batch_index: torch.Tensor | None = None, + cont_covs=None, + cat_covs=None, + use_posterior_mean: bool = True, + ) -> torch.Tensor: + """Forward pass through the encoder and classifier. + + Parameters + ---------- + x + Tensor of shape ``(n_obs, n_vars)``. + batch_index + Tensor of shape ``(n_obs,)`` denoting batch indices. + cont_covs + Tensor of shape ``(n_obs, n_continuous_covariates)``. + cat_covs + Tensor of shape ``(n_obs, n_categorical_covariates)``. + use_posterior_mean + Whether to use the posterior mean of the latent distribution for + classification. + + Returns + ------- + Tensor of shape ``(n_obs, n_labels)`` denoting logit scores per label. + Before v1.1, this method by default returned probabilities per label, + see #2301 for more details. + """ + # log the inputs to the variational distribution for numerical stability + mc_ = torch.log(1 + mc) + cov_ = torch.log(1 + cov) + + # get variational parameters via the encoder networks + # we input both the methylated reads (mc) and coverage (cov) + encoder_input = torch.cat((mc_, cov_), dim=-1) + if cont_covs is not None and self.encode_covariates: + encoder_input = torch.cat((encoder_input, cont_covs), dim=-1) + if cat_covs is not None and self.encode_covariates: + categorical_input = torch.split(cat_covs, 1, dim=1) + else: + categorical_input = () + + qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input) + z = qz.loc if use_posterior_mean else z + + if self.use_labels_groups: + w_g = self.classifier_groups(z) + unw_y = self.classifier(z) + w_y = torch.zeros_like(unw_y) + for i, group_index in enumerate(self.groups_index): + unw_y_g = unw_y[:, group_index] + w_y[:, group_index] = unw_y_g / (unw_y_g.sum(dim=-1, keepdim=True) + 1e-8) + w_y[:, group_index] *= w_g[:, [i]] + else: + w_y = self.classifier(z) + return w_y + + @auto_move_data + def classification_loss(self, labelled_dataset): + """Computes scANVI-style classification loss.""" + inference_inputs = self._get_inference_input(labelled_dataset) # (n_obs, n_vars) + data_inputs = {key: inference_inputs[key] for key in self.data_input_keys} + y = labelled_dataset[REGISTRY_KEYS.LABELS_KEY] # (n_obs, 1) + batch_idx = labelled_dataset[REGISTRY_KEYS.BATCH_KEY] + cat_covs = inference_inputs["cat_covs"] + + logits = self.classify( + **data_inputs, + batch_index=batch_idx, + cat_covs=cat_covs, + ) # (n_obs, n_labels) + ce_loss = F.cross_entropy( + logits, + y.view(-1).long(), + ) + return ce_loss, y, logits + + def loss( + self, + tensors, + inference_outputs, + generative_outputs, + feed_labels=False, + kl_weight=1, + labelled_tensors=None, + classification_ratio=None, + ): + """Compute the loss.""" + qz1 = inference_outputs["qz"] + z1 = inference_outputs["z"] + + if feed_labels: + y = tensors[REGISTRY_KEYS.LABELS_KEY] + else: + y = None + is_labelled = False if y is None else True + + # Enumerate choices of label + ys, z1s = broadcast_labels(z1, n_broadcast=self.n_labels) + qz2, z2 = self.encoder_z2_z1(z1s, ys) + pz1_m, pz1_v = self.decoder_z1_z2(z2, ys) + + minibatch_size = qz1.loc.size()[0] + reconst_loss = self._compute_minibatch_reconstruction_loss( + minibatch_size=minibatch_size, + tensors=tensors, + generative_outputs=generative_outputs, + ) + + # KL Divergence + mean = torch.zeros_like(qz2.loc) + scale = torch.ones_like(qz2.scale) + + kl_divergence_z2 = kl(qz2, Normal(mean, scale)).sum(dim=1) + loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1) + loss_z1_weight = qz1.log_prob(z1).sum(dim=-1) + + if is_labelled: + loss = reconst_loss + loss_z1_weight + loss_z1_unweight + kl_locals = { + "kl_divergence_z2": kl_divergence_z2, + } + if labelled_tensors is not None: + ce_loss, true_labels, logits = self.classification_loss(labelled_tensors) + loss += ce_loss * classification_ratio + return LossOutput( + loss=loss, + reconstruction_loss=reconst_loss, + kl_local=kl_locals, + classification_loss=ce_loss, + true_labels=true_labels, + logits=logits, + extra_metrics={ + "n_labelled_tensors": labelled_tensors[REGISTRY_KEYS.X_KEY].shape[0], + }, + ) + return LossOutput( + loss=loss, + reconstruction_loss=reconst_loss, + kl_local=kl_locals, + ) + + probs = F.softmax(self.classifier(z1), dim=-1) + + reconst_loss += loss_z1_weight + ( + (loss_z1_unweight).view(self.n_labels, -1).t() * probs + ).sum(dim=1) + + kl_divergence = (kl_divergence_z2.view(self.n_labels, -1).t() * probs).sum(dim=1) + kl_divergence += kl( + Categorical(probs=probs), + Categorical(probs=self.y_prior.repeat(probs.size(0), 1)), + ) + + loss = torch.mean(reconst_loss + kl_divergence * kl_weight) + + if labelled_tensors is not None: + ce_loss, true_labels, logits = self.classification_loss(labelled_tensors) + + loss += ce_loss * classification_ratio + return LossOutput( + loss=loss, + reconstruction_loss=reconst_loss, + kl_local=kl_divergence, + classification_loss=ce_loss, + true_labels=true_labels, + logits=logits, + ) + return LossOutput(loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence) diff --git a/src/scvi/external/methylvi/_methylvi_model.py b/src/scvi/external/methylvi/_methylvi_model.py new file mode 100644 index 0000000000..2e44b78666 --- /dev/null +++ b/src/scvi/external/methylvi/_methylvi_model.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import logging +from collections import defaultdict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterable + + from anndata import AnnData + from mudata import MuData + +import numpy as np +import sparse +import torch + +from scvi import REGISTRY_KEYS +from scvi.data import AnnDataManager, fields +from scvi.data._constants import _SETUP_ARGS_KEY +from scvi.external.methylvi._base_components import BSSeqMixin +from scvi.external.methylvi._utils import _context_cov_key, _context_mc_key +from scvi.model.base import ( + ArchesMixin, + BaseModelClass, + UnsupervisedTrainingMixin, + VAEMixin, +) +from scvi.utils import setup_anndata_dsp + +from ._methylvi_module import METHYLVAE + +logger = logging.getLogger(__name__) + + +class METHYLVI(VAEMixin, BSSeqMixin, UnsupervisedTrainingMixin, ArchesMixin, BaseModelClass): + """ + Model class for methylVI :cite:p:`Weinberger2023a` + + Parameters + ---------- + mdata + MuData object that has been registered via :meth:`~scvi.external.METHYLVI.setup_mudata`. + n_hidden + Number of nodes per hidden layer. + n_latent + Dimensionality of the latent space. + n_layers + Number of hidden layers used for encoder and decoder NNs. + **model_kwargs + Keyword args for :class:`~scvi.external.methylvi.METHYLVAE` + + Examples + -------- + >>> mdata = mudata.read_h5mu(path_to_mudata) + >>> MethylVI.setup_mudata(mdata, batch_key="batch") + >>> vae = MethylVI(mdata) + >>> vae.train() + >>> mdata.obsm["X_methylVI"] = vae.get_latent_representation() + """ + + def __init__( + self, + mdata: MuData, + n_hidden: int = 128, + n_latent: int = 10, + n_layers: int = 1, + **model_kwargs, + ): + super().__init__(mdata) + + n_batch = self.summary_stats.n_batch + n_cats_per_cov = ( + self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)[ + fields.CategoricalJointObsField.N_CATS_PER_KEY + ] + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) + + self.contexts = self.get_anndata_manager(mdata, required=True).registry[_SETUP_ARGS_KEY][ + "methylation_contexts" + ] + self.num_features_per_context = [mdata[context].shape[1] for context in self.contexts] + + n_input = np.sum(self.num_features_per_context) + + self.module = METHYLVAE( + n_input=n_input, + n_hidden=n_hidden, + n_latent=n_latent, + n_layers=n_layers, + n_batch=n_batch, + n_cats_per_cov=n_cats_per_cov, + contexts=self.contexts, + num_features_per_context=self.num_features_per_context, + **model_kwargs, + ) + self._model_summary_string = ( + "Overwrite this attribute to get an informative representation for your model" + ) + # necessary line to get params that will be used for saving/loading + self.init_params_ = self._get_init_params(locals()) + + logger.info("The model has been initialized") + + @classmethod + @setup_anndata_dsp.dedent + def setup_anndata( + cls, + adata: AnnData, + **kwargs, + ) -> AnnData | None: + """ + %(summary)s. + + Parameters + ---------- + %(param_adata)s + + Returns + ------- + %(returns)s + """ + raise NotImplementedError("METHYLVI must be used with a MuData object.") + + @classmethod + @setup_anndata_dsp.dedent + def setup_mudata( + cls, + mdata: MuData, + mc_layer: str, + cov_layer: str, + methylation_contexts: Iterable[str], + batch_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, + modalities=None, + **kwargs, + ): + """%(summary_mdata)s. + + Parameters + ---------- + %(param_mdata)s + mc_layer + Layer containing methylated cytosine counts for each set of methylation features. + cov_layer + Layer containing total coverage counts for each set of methylation features. + methylation_contexts + List of modality fields in `mdata` object representing different methylation contexts. + Each context must be equipped with a layer containing the number of methylated counts + (specified by `mc_layer`) and total number of counts (specified by `cov_layer`) for + each genomic region feature. + %(param_batch_key)s + %(param_categorical_covariate_keys)s + %(param_modalities)s + + Examples + -------- + MethylVI.setup_mudata( + mdata, + mc_layer="mc", + cov_layer="cov", + batch_key="Platform", + methylation_modalities=['mCG', 'mCH'], + modalities={ + "batch_key": "mCG" + }, + ) + + """ + if modalities is None: + modalities = {} + setup_method_args = METHYLVI._get_setup_method_args(**locals()) + + if methylation_contexts is None: + raise ValueError("Methylation contexts cannot be None.") + + modalities_ = cls._create_modalities_attr_dict(modalities, setup_method_args) + + batch_field = fields.MuDataCategoricalObsField( + REGISTRY_KEYS.BATCH_KEY, + batch_key, + mod_key=modalities_.batch_key, + ) + + cat_cov_field = fields.MuDataCategoricalJointObsField( + REGISTRY_KEYS.CAT_COVS_KEY, + categorical_covariate_keys, + mod_key=modalities_.categorical_covariate_keys, + ) + + mc_fields = [] + cov_fields = [] + + for context in methylation_contexts: + mc_fields.append( + fields.MuDataLayerField( + _context_mc_key(context), + mc_layer, + mod_key=context, + is_count_data=True, + mod_required=True, + ) + ) + + cov_fields.append( + fields.MuDataLayerField( + _context_cov_key(context), + cov_layer, + mod_key=context, + is_count_data=True, + mod_required=True, + ) + ) + + mudata_fields = mc_fields + cov_fields + [batch_field] + [cat_cov_field] + adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) + adata_manager.register_fields(mdata, **kwargs) + + cls.register_manager(adata_manager) + + @torch.inference_mode() + def posterior_predictive_sample( + self, + mdata: MuData | None = None, + n_samples: int = 1, + batch_size: int | None = None, + ) -> dict[str, sparse.GCXS] | sparse.GCXS: + r""" + Generate observation samples from the posterior predictive distribution. + + The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. + + Parameters + ---------- + mdata + MuData object with equivalent structure to initial MuData. If `None`, defaults to the + MuData object used to initialize the model. + n_samples + Number of samples for each cell. + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + + Returns + ------- + x_new : :py:class:`torch.Tensor` + tensor with shape (n_cells, n_regions, n_samples) + """ + mdata = self._validate_anndata(mdata) + + scdl = self._make_data_loader(adata=mdata, batch_size=batch_size) + + x_new = defaultdict(list) + for tensors in scdl: + samples = self.module.sample( + tensors, + n_samples=n_samples, + ) + + for context in self.contexts: + x_new[context].append(sparse.GCXS.from_numpy(samples[context].numpy())) + + for context in self.contexts: + x_new[context] = sparse.concatenate( + x_new[context] + ) # Shape (n_cells, n_regions, n_samples) + + return x_new diff --git a/src/scvi/external/methylvi/_module.py b/src/scvi/external/methylvi/_methylvi_module.py similarity index 91% rename from src/scvi/external/methylvi/_module.py rename to src/scvi/external/methylvi/_methylvi_module.py index bff6f4c771..24d5592fcc 100644 --- a/src/scvi/external/methylvi/_module.py +++ b/src/scvi/external/methylvi/_methylvi_module.py @@ -11,6 +11,7 @@ from scvi import REGISTRY_KEYS from scvi.distributions import BetaBinomial from scvi.external.methylvi import METHYLVI_REGISTRY_KEYS, DecoderMETHYLVI +from scvi.external.methylvi._base_components import BSSeqModuleMixin from scvi.external.methylvi._utils import _context_cov_key, _context_mc_key from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data from scvi.nn import Encoder @@ -18,7 +19,7 @@ TensorDict = dict[str, torch.Tensor] -class METHYLVAE(BaseModuleClass): +class METHYLVAE(BaseModuleClass, BSSeqModuleMixin): """PyTorch module for methylVI. Parameters @@ -206,24 +207,11 @@ def loss( weighted_kl_local = kl_weight * kl_local_for_warmup minibatch_size = qz.loc.size()[0] - reconst_loss = torch.zeros(minibatch_size).to(self.device) - - for context in self.contexts: - px_mu = generative_outputs["px_mu"][context] - px_gamma = generative_outputs["px_gamma"][context] - mc = tensors[f"{context}_{METHYLVI_REGISTRY_KEYS.MC_KEY}"] - cov = tensors[f"{context}_{METHYLVI_REGISTRY_KEYS.COV_KEY}"] - - if self.dispersion == "region": - px_gamma = torch.sigmoid(self.px_gamma[context]) - - if self.likelihood == "binomial": - dist = Binomial(probs=px_mu, total_count=cov) - elif self.likelihood == "betabinomial": - dist = BetaBinomial(mu=px_mu, gamma=px_gamma, total_count=cov) - - reconst_loss += -dist.log_prob(mc).sum(dim=-1) - + reconst_loss = self._compute_minibatch_reconstruction_loss( + minibatch_size=minibatch_size, + tensors=tensors, + generative_outputs=generative_outputs, + ) loss = torch.mean(reconst_loss + weighted_kl_local) kl_local = {"kl_divergence_z": kl_divergence_z} diff --git a/src/scvi/external/methylvi/_model.py b/src/scvi/external/methylvi/_model.py deleted file mode 100644 index 494a81658a..0000000000 --- a/src/scvi/external/methylvi/_model.py +++ /dev/null @@ -1,623 +0,0 @@ -from __future__ import annotations - -import logging -import warnings -from collections import defaultdict -from functools import partial -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from collections.abc import Iterable, Sequence - from typing import Literal - - from anndata import AnnData - from mudata import MuData - - from scvi._types import Number - -import numpy as np -import pandas as pd -import sparse -import torch - -from scvi import REGISTRY_KEYS, settings -from scvi.data import AnnDataManager, fields -from scvi.data._constants import _SETUP_ARGS_KEY -from scvi.external.methylvi._utils import _context_cov_key, _context_mc_key -from scvi.model.base import ( - ArchesMixin, - BaseModelClass, - UnsupervisedTrainingMixin, - VAEMixin, -) -from scvi.model.base._de_core import ( - _de_core, -) -from scvi.utils import setup_anndata_dsp - -from ._module import METHYLVAE -from ._utils import scmc_raw_counts_properties - -logger = logging.getLogger(__name__) - - -class METHYLVI(VAEMixin, UnsupervisedTrainingMixin, ArchesMixin, BaseModelClass): - """ - Model class for methylVI :cite:p:`Weinberger2023a` - - Parameters - ---------- - mdata - MuData object that has been registered via :meth:`~scvi.external.METHYLVI.setup_mudata`. - n_hidden - Number of nodes per hidden layer. - n_latent - Dimensionality of the latent space. - n_layers - Number of hidden layers used for encoder and decoder NNs. - **model_kwargs - Keyword args for :class:`~scvi.external.methylvi.METHYLVAE` - - Examples - -------- - >>> mdata = mudata.read_h5mu(path_to_mudata) - >>> MethylVI.setup_mudata(mdata, batch_key="batch") - >>> vae = MethylVI(mdata) - >>> vae.train() - >>> mdata.obsm["X_methylVI"] = vae.get_latent_representation() - """ - - def __init__( - self, - mdata: MuData, - n_hidden: int = 128, - n_latent: int = 10, - n_layers: int = 1, - **model_kwargs, - ): - super().__init__(mdata) - - n_batch = self.summary_stats.n_batch - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)[ - fields.CategoricalJointObsField.N_CATS_PER_KEY - ] - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) - - self.contexts = self.get_anndata_manager(mdata, required=True).registry[_SETUP_ARGS_KEY][ - "methylation_contexts" - ] - self.num_features_per_context = [mdata[context].shape[1] for context in self.contexts] - - n_input = np.sum(self.num_features_per_context) - - self.module = METHYLVAE( - n_input=n_input, - n_hidden=n_hidden, - n_latent=n_latent, - n_layers=n_layers, - n_batch=n_batch, - n_cats_per_cov=n_cats_per_cov, - contexts=self.contexts, - num_features_per_context=self.num_features_per_context, - **model_kwargs, - ) - self._model_summary_string = ( - "Overwrite this attribute to get an informative representation for your model" - ) - # necessary line to get params that will be used for saving/loading - self.init_params_ = self._get_init_params(locals()) - - logger.info("The model has been initialized") - - @classmethod - @setup_anndata_dsp.dedent - def setup_anndata( - cls, - adata: AnnData, - **kwargs, - ) -> AnnData | None: - """ - %(summary)s. - - Parameters - ---------- - %(param_adata)s - - Returns - ------- - %(returns)s - """ - raise NotImplementedError("METHYLVI must be used with a MuData object.") - - @classmethod - @setup_anndata_dsp.dedent - def setup_mudata( - cls, - mdata: MuData, - mc_layer: str, - cov_layer: str, - methylation_contexts: Iterable[str], - batch_key: str | None = None, - categorical_covariate_keys: list[str] | None = None, - modalities=None, - **kwargs, - ): - """%(summary_mdata)s. - - Parameters - ---------- - %(param_mdata)s - mc_layer - Layer containing methylated cytosine counts for each set of methylation features. - cov_layer - Layer containing total coverage counts for each set of methylation features. - methylation_contexts - List of modality fields in `mdata` object representing different methylation contexts. - Each context must be equipped with a layer containing the number of methylated counts - (specified by `mc_layer`) and total number of counts (specified by `cov_layer`) for - each genomic region feature. - %(param_batch_key)s - %(param_cat_cov_keys)s - %(param_modalities)s - - Examples - -------- - MethylVI.setup_mudata( - mdata, - mc_layer="mc", - cov_layer="cov", - batch_key="Platform", - methylation_modalities=['mCG', 'mCH'], - modalities={ - "batch_key": "mCG" - }, - ) - - """ - if modalities is None: - modalities = {} - setup_method_args = METHYLVI._get_setup_method_args(**locals()) - - if methylation_contexts is None: - raise ValueError("Methylation contexts cannot be None.") - - modalities_ = cls._create_modalities_attr_dict(modalities, setup_method_args) - - batch_field = fields.MuDataCategoricalObsField( - REGISTRY_KEYS.BATCH_KEY, - batch_key, - mod_key=modalities_.batch_key, - ) - - cat_cov_field = fields.MuDataCategoricalJointObsField( - REGISTRY_KEYS.CAT_COVS_KEY, - categorical_covariate_keys, - mod_key=modalities_.categorical_covariate_keys, - ) - - mc_fields = [] - cov_fields = [] - - for context in methylation_contexts: - mc_fields.append( - fields.MuDataLayerField( - _context_mc_key(context), - mc_layer, - mod_key=context, - is_count_data=True, - mod_required=True, - ) - ) - - cov_fields.append( - fields.MuDataLayerField( - _context_cov_key(context), - cov_layer, - mod_key=context, - is_count_data=True, - mod_required=True, - ) - ) - - mudata_fields = mc_fields + cov_fields + [batch_field] + [cat_cov_field] - adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) - adata_manager.register_fields(mdata, **kwargs) - - cls.register_manager(adata_manager) - - @torch.inference_mode() - def posterior_predictive_sample( - self, - mdata: MuData | None = None, - n_samples: int = 1, - batch_size: int | None = None, - ) -> dict[str, sparse.GCXS] | sparse.GCXS: - r""" - Generate observation samples from the posterior predictive distribution. - - The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. - - Parameters - ---------- - mdata - MuData object with equivalent structure to initial MuData. If `None`, defaults to the - MuData object used to initialize the model. - n_samples - Number of samples for each cell. - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - - Returns - ------- - x_new : :py:class:`torch.Tensor` - tensor with shape (n_cells, n_regions, n_samples) - """ - mdata = self._validate_anndata(mdata) - - scdl = self._make_data_loader(adata=mdata, batch_size=batch_size) - - x_new = defaultdict(list) - for tensors in scdl: - samples = self.module.sample( - tensors, - n_samples=n_samples, - ) - - for context in self.contexts: - x_new[context].append(sparse.GCXS.from_numpy(samples[context].numpy())) - - for context in self.contexts: - x_new[context] = sparse.concatenate( - x_new[context] - ) # Shape (n_cells, n_regions, n_samples) - - return x_new - - @torch.inference_mode() - def get_normalized_methylation( - self, - mdata: MuData | None = None, - indices: Sequence[int] | None = None, - region_list: Sequence[str] | None = None, - n_samples: int = 1, - n_samples_overall: int = None, - batch_size: int | None = None, - return_mean: bool = True, - return_numpy: bool | None = None, - context: str | None = None, - **importance_weighting_kwargs, - ) -> (np.ndarray | pd.DataFrame) | dict[str, np.ndarray | pd.DataFrame]: - r"""Returns the normalized (decoded) methylation. - - This is denoted as :math:`\mu_n` in the methylVI paper. - - Parameters - ---------- - mdata - MuData object with equivalent structure to initial Mudata. - If `None`, defaults to the MuData object used to initialize the model. - indices - Indices of cells in mdata to use. If `None`, all cells are used. - region_list - Return frequencies of expression for a subset of regions. - This can save memory when working with large datasets and few regions are - of interest. - n_samples - Number of posterior samples to use for estimation. - n_samples_overall - Number of posterior samples to use for estimation. Overrides `n_samples`. - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - return_mean - Whether to return the mean of the samples. - return_numpy - Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. - DataFrame includes region names as columns. If either `n_samples=1` or - `return_mean=True`, defaults to `False`. Otherwise, it defaults to `True`. - context - If not `None`, returns normalized methylation levels for the specified - methylation context. Otherwise, a dictionary with contexts as keys and normalized - methylation levels as values is returned. - - Returns - ------- - If `n_samples` is provided and `return_mean` is False, - this method returns a 3d tensor of shape (n_samples, n_cells, n_regions). - If `n_samples` is provided and `return_mean` is True, it returns a 2d tensor - of shape (n_cells, n_regions). - In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. - Otherwise, the method expects `n_samples_overall` to be provided and returns a 2d tensor - of shape (n_samples_overall, n_regions). - - If model was set up using a MuData object, a dictionary is returned with keys - corresponding to individual methylation contexts with values determined as - described above. - """ - mdata = self._validate_anndata(mdata) - - if context is not None and context not in self.contexts: - raise ValueError( - f"{context} is not a valid methylation context for this model. " - f"Valid contexts are {self.contexts}." - ) - - if indices is None: - indices = np.arange(mdata.n_obs) - if n_samples_overall is not None: - assert n_samples == 1 # default value - n_samples = n_samples_overall // len(indices) + 1 - scdl = self._make_data_loader(adata=mdata, indices=indices, batch_size=batch_size) - - region_mask = slice(None) if region_list is None else mdata.var_names.isin(region_list) - - if n_samples > 1 and return_mean is False: - if return_numpy is False: - warnings.warn( - "`return_numpy` must be `True` if `n_samples > 1` and `return_mean` " - "is`False`, returning an `np.ndarray`.", - UserWarning, - stacklevel=settings.warnings_stacklevel, - ) - return_numpy = True - - exprs = defaultdict(list) - - for tensors in scdl: - inference_kwargs = {"n_samples": n_samples} - inference_outputs, generative_outputs = self.module.forward( - tensors=tensors, - inference_kwargs=inference_kwargs, - generative_kwargs={}, - compute_loss=False, - ) - - for ctxt in self.contexts: - exp_ = generative_outputs["px_mu"][ctxt] - exp_ = exp_[..., region_mask] - exprs[ctxt].append(exp_.cpu()) - - cell_axis = 1 if n_samples > 1 else 0 - - for ctxt in self.contexts: - exprs[ctxt] = np.concatenate(exprs[ctxt], axis=cell_axis) - - if n_samples_overall is not None: - # Converts the 3d tensor to a 2d tensor - for ctxt in self.contexts: - exprs[ctxt] = exprs[ctxt].reshape(-1, exprs[ctxt].shape[-1]) - n_samples_ = exprs[ctxt].shape[0] - ind_ = np.random.choice(n_samples_, n_samples_overall, replace=True) - exprs[ctxt] = exprs[ctxt][ind_] - return_numpy = True - - elif n_samples > 1 and return_mean: - for ctxt in self.contexts: - exprs[ctxt] = exprs[ctxt].mean(0) - - if return_numpy is None or return_numpy is False: - exprs_dfs = {} - for ctxt in self.contexts: - exprs_dfs[ctxt] = pd.DataFrame( - exprs[ctxt], - columns=mdata[ctxt].var_names[region_mask], - index=mdata[ctxt].obs_names[indices], - ) - exprs_ = exprs_dfs - else: - exprs_ = exprs - - if context is not None: - return exprs_[context] - else: - return exprs_ - - @torch.inference_mode() - def get_specific_normalized_methylation( - self, - mdata: MuData | None = None, - context: str = None, - indices: Sequence[int] | None = None, - transform_batch: Sequence[Number | str] | None = None, - region_list: Sequence[str] | None = None, - n_samples: int = 1, - n_samples_overall: int = None, - weights: Literal["uniform", "importance"] | None = None, - batch_size: int | None = None, - return_mean: bool = True, - return_numpy: bool | None = None, - **importance_weighting_kwargs, - ) -> (np.ndarray | pd.DataFrame) | dict[str, np.ndarray | pd.DataFrame]: - r"""Convenience function to obtain normalized methylation values for a single context. - - Only applicable to MuData models. - - Parameters - ---------- - mdata - MuData object with equivalent structure to initial MuData. If `None`, defaults to the - MuData object used to initialize the model. - context - Methylation context for which to obtain normalized methylation levels. - indices - Indices of cells in mdata to use. If `None`, all cells are used. - transform_batch - Batch to condition on. - If transform_batch is: - - - None, then real observed batch is used. - - int, then batch transform_batch is used. - region_list - Return frequencies of expression for a subset of regions. - This can save memory when working with large datasets and few regions are - of interest. - n_samples - Number of posterior samples to use for estimation. - n_samples_overall - Number of posterior samples to use for estimation. Overrides `n_samples`. - weights - Weights to use for sampling. If `None`, defaults to `"uniform"`. - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - return_mean - Whether to return the mean of the samples. - return_numpy - Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. - DataFrame includes region names as columns. If either `n_samples=1` or - `return_mean=True`, defaults to `False`. Otherwise, it defaults to `True`. - importance_weighting_kwargs - Keyword arguments passed into - :meth:`~scvi.model.base.RNASeqMixin._get_importance_weights`. - - Returns - ------- - If `n_samples` is provided and `return_mean` is False, - this method returns a 3d tensor of shape (n_samples, n_cells, n_regions). - If `n_samples` is provided and `return_mean` is True, it returns a 2d tensor - of shape (n_cells, n_regions). - In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True. - Otherwise, the method expects `n_samples_overall` to be provided and returns a 2d tensor - of shape (n_samples_overall, n_regions). - """ - exprs = self.get_normalized_methylation( - mdata=mdata, - indices=indices, - transform_batch=transform_batch, - region_list=region_list, - n_samples=n_samples, - n_samples_overall=n_samples_overall, - weights=weights, - batch_size=batch_size, - return_mean=return_mean, - return_numpy=return_numpy, - **importance_weighting_kwargs, - ) - return exprs[context] - - def differential_methylation( - self, - mdata: MuData | None = None, - groupby: str | None = None, - group1: Iterable[str] | None = None, - group2: str | None = None, - idx1: Sequence[int] | Sequence[bool] | str | None = None, - idx2: Sequence[int] | Sequence[bool] | str | None = None, - mode: Literal["vanilla", "change"] = "vanilla", - delta: float = 0.05, - batch_size: int | None = None, - all_stats: bool = True, - batch_correction: bool = False, - batchid1: Iterable[str] | None = None, - batchid2: Iterable[str] | None = None, - fdr_target: float = 0.05, - silent: bool = False, - two_sided: bool = True, - **kwargs, - ) -> dict[str, pd.DataFrame] | pd.DataFrame: - r"""\. - - A unified method for differential methylation analysis. - - Implements `"vanilla"` DE :cite:p:`Lopez18`. and `"change"` mode DE :cite:p:`Boyeau19`. - - Parameters - ---------- - %(de_mdata)s - %(de_modality)s - %(de_groupby)s - %(de_group1)s - %(de_group2)s - %(de_idx1)s - %(de_idx2)s - %(de_mode)s - %(de_delta)s - %(de_batch_size)s - %(de_all_stats)s - %(de_batch_correction)s - %(de_batchid1)s - %(de_batchid2)s - %(de_fdr_target)s - %(de_silent)s - two_sided - Whether to perform a two-sided test, or a one-sided test. - **kwargs - Keyword args for :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors` - - Returns - ------- - Differential methylation DataFrame with the following columns: - proba_de - the probability of the region being differentially methylated - is_de_fdr - whether the region passes a multiple hypothesis correction procedure - with the target_fdr threshold - bayes_factor - Bayes Factor indicating the level of significance of the analysis - effect_size - the effect size, computed as (accessibility in population 2) - - (accessibility in population 1) - emp_effect - the empirical effect, based on observed detection rates instead of the estimated - accessibility scores from the methylVI model - scale1 - the estimated methylation level in population 1 - scale2 - the estimated methylation level in population 2 - emp_mean1 - the empirical (observed) methylation level in population 1 - emp_mean2 - the empirical (observed) methylation level in population 2 - - """ - mdata = self._validate_anndata(mdata) - - def change_fn(a, b): - return a - b - - if two_sided: - - def m1_domain_fn(samples): - return np.abs(samples) >= delta - - else: - - def m1_domain_fn(samples): - return samples >= delta - - result = {} - for context in self.contexts: - col_names = mdata[context].var_names - model_fn = partial( - self.get_specific_normalized_methylation, - batch_size=batch_size, - context=context, - ) - all_stats_fn = partial(scmc_raw_counts_properties, context=context) - - result[context] = _de_core( - adata_manager=self.get_anndata_manager(mdata, required=True), - model_fn=model_fn, - representation_fn=None, - groupby=groupby, - group1=group1, - group2=group2, - idx1=idx1, - idx2=idx2, - all_stats=all_stats, - all_stats_fn=all_stats_fn, - col_names=col_names, - mode=mode, - batchid1=batchid1, - batchid2=batchid2, - delta=delta, - batch_correction=batch_correction, - fdr=fdr_target, - silent=silent, - change_fn=change_fn, - m1_domain_fn=m1_domain_fn, - **kwargs, - ) - - return result diff --git a/src/scvi/external/methylvi/_utils.py b/src/scvi/external/methylvi/_utils.py index d880958570..e271b46726 100644 --- a/src/scvi/external/methylvi/_utils.py +++ b/src/scvi/external/methylvi/_utils.py @@ -4,7 +4,7 @@ from scvi.data import AnnDataManager from scvi.data._constants import _SETUP_ARGS_KEY -from scvi.external.methylvi import METHYLVI_REGISTRY_KEYS +from scvi.external.methylvi._constants import METHYLVI_REGISTRY_KEYS logger = logging.getLogger(__name__) diff --git a/src/scvi/model/base/__init__.py b/src/scvi/model/base/__init__.py index 4b38494caf..97f10675d6 100644 --- a/src/scvi/model/base/__init__.py +++ b/src/scvi/model/base/__init__.py @@ -14,7 +14,7 @@ PyroSviTrainMixin, ) from ._rnamixin import RNASeqMixin -from ._training_mixin import UnsupervisedTrainingMixin +from ._training_mixin import SemisupervisedTrainingMixin, UnsupervisedTrainingMixin from ._vaemixin import VAEMixin __all__ = [ @@ -32,4 +32,5 @@ "BaseMinifiedModeModelClass", "BaseMudataMinifiedModeModelClass", "EmbeddingMixin", + "SemisupervisedTrainingMixin", ] diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index ebace98445..00cd47cc5a 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -1,15 +1,30 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING -from scvi.dataloaders import DataSplitter +import numpy as np +import pandas as pd +import torch + +from scvi import REGISTRY_KEYS +from scvi.data._utils import get_anndata_attribute +from scvi.dataloaders import DataSplitter, SemiSupervisedDataSplitter from scvi.model._utils import get_max_epochs_heuristic, use_distributed_sampler -from scvi.train import TrainingPlan, TrainRunner +from scvi.train import SemiSupervisedTrainingPlan, TrainingPlan, TrainRunner +from scvi.train._callbacks import SubSampleLabels from scvi.utils._docstrings import devices_dsp if TYPE_CHECKING: + from collections.abc import Sequence + from lightning import LightningDataModule + from scvi._types import AnnOrMuData + + +logger = logging.getLogger(__name__) + class UnsupervisedTrainingMixin: """General purpose unsupervised train method.""" @@ -143,3 +158,195 @@ def train( **trainer_kwargs, ) return runner() + + +class SemisupervisedTrainingMixin: + _training_plan_cls = SemiSupervisedTrainingPlan + + def _set_indices_and_labels(self): + """Set indices for labeled and unlabeled cells.""" + labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + self.original_label_key = labels_state_registry.original_key + self.unlabeled_category_ = labels_state_registry.unlabeled_category + + labels = get_anndata_attribute( + self.adata, + self.adata_manager.data_registry.labels.attr_name, + self.original_label_key, + mod_key=getattr(self.adata_manager.data_registry.labels, "mod_key", None), + ).ravel() + self._label_mapping = labels_state_registry.categorical_mapping + + # set unlabeled and labeled indices + self._unlabeled_indices = np.argwhere(labels == self.unlabeled_category_).ravel() + self._labeled_indices = np.argwhere(labels != self.unlabeled_category_).ravel() + self._code_to_label = dict(enumerate(self._label_mapping)) + + def predict( + self, + adata: AnnOrMuData | None = None, + indices: Sequence[int] | None = None, + soft: bool = False, + batch_size: int | None = None, + use_posterior_mean: bool = True, + ) -> np.ndarray | pd.DataFrame: + """Return cell label predictions. + + Parameters + ---------- + adata + AnnData or MuData object that has been registered via corresponding setup + method in model class. + indices + Return probabilities for each class label. + soft + If True, returns per class probabilities + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + use_posterior_mean + If ``True``, uses the mean of the posterior distribution to predict celltype + labels. Otherwise, uses a sample from the posterior distribution - this + means that the predictions will be stochastic. + """ + adata = self._validate_anndata(adata) + + if indices is None: + indices = np.arange(adata.n_obs) + + scdl = self._make_data_loader( + adata=adata, + indices=indices, + batch_size=batch_size, + ) + y_pred = [] + for _, tensors in enumerate(scdl): + inference_inputs = self.module._get_inference_input(tensors) # (n_obs, n_vars) + data_inputs = {key: inference_inputs[key] for key in self.module.data_input_keys} + + batch = tensors[REGISTRY_KEYS.BATCH_KEY] + + cont_key = REGISTRY_KEYS.CONT_COVS_KEY + cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None + + cat_key = REGISTRY_KEYS.CAT_COVS_KEY + cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None + + pred = self.module.classify( + **data_inputs, + batch_index=batch, + cat_covs=cat_covs, + cont_covs=cont_covs, + use_posterior_mean=use_posterior_mean, + ) + if self.module.classifier.logits: + pred = torch.nn.functional.softmax(pred, dim=-1) + if not soft: + pred = pred.argmax(dim=1) + y_pred.append(pred.detach().cpu()) + + y_pred = torch.cat(y_pred).numpy() + if not soft: + predictions = [self._code_to_label[p] for p in y_pred] + return np.array(predictions) + else: + n_labels = len(pred[0]) + pred = pd.DataFrame( + y_pred, + columns=self._label_mapping[:n_labels], + index=adata.obs_names[indices], + ) + return pred + + @devices_dsp.dedent + def train( + self, + max_epochs: int | None = None, + n_samples_per_label: float | None = None, + check_val_every_n_epoch: int | None = None, + train_size: float = 0.9, + validation_size: float | None = None, + shuffle_set_split: bool = True, + batch_size: int = 128, + accelerator: str = "auto", + devices: int | list[int] | str = "auto", + datasplitter_kwargs: dict | None = None, + plan_kwargs: dict | None = None, + **trainer_kwargs, + ): + """Train the model. + + Parameters + ---------- + max_epochs + Number of passes through the dataset for semisupervised training. + n_samples_per_label + Number of subsamples for each label class to sample per epoch. By default, there + is no label subsampling. + check_val_every_n_epoch + Frequency with which metrics are computed on the data for validation set for both + the unsupervised and semisupervised trainers. If you'd like a different frequency for + the semisupervised trainer, set check_val_every_n_epoch in semisupervised_train_kwargs. + train_size + Size of training set in the range [0.0, 1.0]. + validation_size + Size of the test set. If `None`, defaults to 1 - `train_size`. If + `train_size + validation_size < 1`, the remaining cells belong to a test set. + shuffle_set_split + Whether to shuffle indices before splitting. If `False`, the val, train, + and test set are split in the sequential order of the data according to + `validation_size` and `train_size` percentages. + batch_size + Minibatch size to use during training. + %(param_accelerator)s + %(param_devices)s + datasplitter_kwargs + Additional keyword arguments passed into + :class:`~scvi.dataloaders.SemiSupervisedDataSplitter`. + plan_kwargs + Keyword args for :class:`~scvi.train.SemiSupervisedTrainingPlan`. Keyword + arguments passed to `train()` will overwrite values present in `plan_kwargs`, + when appropriate. + **trainer_kwargs + Other keyword args for :class:`~scvi.train.Trainer`. + """ + if max_epochs is None: + max_epochs = get_max_epochs_heuristic(self.adata.n_obs) + + if self.was_pretrained: + max_epochs = int(np.min([10, np.max([2, round(max_epochs / 3.0)])])) + + logger.info(f"Training for {max_epochs} epochs.") + + plan_kwargs = {} if plan_kwargs is None else plan_kwargs + datasplitter_kwargs = datasplitter_kwargs or {} + + # if we have labeled cells, we want to subsample labels each epoch + sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else [] + + data_splitter = SemiSupervisedDataSplitter( + adata_manager=self.adata_manager, + train_size=train_size, + validation_size=validation_size, + shuffle_set_split=shuffle_set_split, + n_samples_per_label=n_samples_per_label, + batch_size=batch_size, + **datasplitter_kwargs, + ) + training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs) + + if "callbacks" in trainer_kwargs.keys(): + trainer_kwargs["callbacks"] + [sampler_callback] + else: + trainer_kwargs["callbacks"] = sampler_callback + + runner = TrainRunner( + self, + training_plan=training_plan, + data_splitter=data_splitter, + max_epochs=max_epochs, + accelerator=accelerator, + devices=devices, + check_val_every_n_epoch=check_val_every_n_epoch, + **trainer_kwargs, + ) + return runner() diff --git a/tests/external/methylvi/test_methylanvi.py b/tests/external/methylvi/test_methylanvi.py new file mode 100644 index 0000000000..083832dc8d --- /dev/null +++ b/tests/external/methylvi/test_methylanvi.py @@ -0,0 +1,40 @@ +import pytest +from mudata import MuData + +from scvi.data import synthetic_iid +from scvi.external import METHYLANVI + + +def test_methylanvi(): + adata1 = synthetic_iid() + adata1.layers["mc"] = adata1.X + adata1.layers["cov"] = adata1.layers["mc"] + 10 + + adata2 = synthetic_iid() + adata2.layers["mc"] = adata2.X + adata2.layers["cov"] = adata2.layers["mc"] + 10 + + mdata = MuData({"mod1": adata1, "mod2": adata2}) + + METHYLANVI.setup_mudata( + mdata, + mc_layer="mc", + cov_layer="cov", + labels_key="labels", + unlabeled_category="unknown", + methylation_contexts=["mod1", "mod2"], + batch_key="batch", + modalities={"batch_key": "mod1", "labels_key": "mod1"}, + ) + vae = METHYLANVI( + mdata, + ) + vae.train(3) + vae.get_elbo(indices=vae.validation_indices) + vae.get_normalized_methylation() # Retrieve methylation for all contexts + vae.get_normalized_methylation(context="mod1") # Retrieve for specific context + with pytest.raises(ValueError): # Should fail when invalid context selected + vae.get_normalized_methylation(context="mod3") + vae.get_latent_representation() + vae.differential_methylation(groupby="mod1:labels", group1="label_1") + vae.predict()