diff --git a/pyproject.toml b/pyproject.toml
index 34c9d6b58b..87895625b8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,6 +37,7 @@ dependencies = [
     "flax",
     "jax",
     "jaxlib",
+    "jaxtyping",
     "lightning>=2.0",
     "ml-collections>=0.1.1",
     "mudata>=0.1.2",
@@ -200,6 +201,8 @@ ignore = [
     "D401",
     # We want docstrings to start immediately after the opening triple quote
     "D213",
+    # Ignore for jaxtyping
+    "F722",
     # Raising ValueError is sufficient in tests.
     "PT011",
     # We support np.random functions.
diff --git a/src/scvi/module/_vae.py b/src/scvi/module/_vae.py
index 920b65ca18..e5b71640ec 100644
--- a/src/scvi/module/_vae.py
+++ b/src/scvi/module/_vae.py
@@ -6,6 +6,7 @@

 import numpy as np
 import torch
+from torch.distributions import Distribution
 from torch.nn.functional import one_hot

 from scvi import REGISTRY_KEYS, settings
@@ -21,6 +22,8 @@
     from collections.abc import Callable
     from typing import Literal

+    from jaxtyping import Array, Float, Int
+    from torch import Tensor
     from torch.distributions import Distribution

 logger = logging.getLogger(__name__)
@@ -164,9 +167,9 @@ def __init__(
         use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
         use_size_factor_key: bool = False,
         use_observed_lib_size: bool = True,
-        library_log_means: np.ndarray | None = None,
-        library_log_vars: np.ndarray | None = None,
-        var_activation: Callable[[torch.Tensor], torch.Tensor] = None,
+        library_log_means: Float[Array, "1 n_batch"] | None = None,
+        library_log_vars: Float[Array, "1 n_batch"] | None = None,
+        var_activation: Callable[[Tensor], Tensor] = None,
         extra_encoder_kwargs: dict | None = None,
         extra_decoder_kwargs: dict | None = None,
         batch_embedding_kwargs: dict | None = None,
@@ -280,8 +283,8 @@ def __init__(

     def _get_inference_input(
         self,
-        tensors: dict[str, torch.Tensor | None],
-    ) -> dict[str, torch.Tensor | None]:
+        tensors: dict[str, Tensor | None],
+    ) -> dict[str, Tensor | None]:
         """Get input tensors for the inference process."""
         from scvi.data._constants import ADATA_MINIFY_TYPE

@@ -303,9 +306,9 @@

     def _get_generative_input(
         self,
-        tensors: dict[str, torch.Tensor],
-        inference_outputs: dict[str, torch.Tensor | Distribution | None],
-    ) -> dict[str, torch.Tensor | None]:
+        tensors: dict[str, Tensor],
+        inference_outputs: dict[str, Tensor | Distribution | None],
+    ) -> dict[str, Tensor | None]:
         """Get input tensors for the generative process."""
         size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None)
         if size_factor is not None:
@@ -323,8 +326,8 @@ def _get_generative_input(

     def _compute_local_library_params(
         self,
-        batch_index: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_index: Int[Tensor, "n 1"],
+    ) -> tuple[Float[Tensor, "n 1"], Float[Tensor, "n 1"]]:
         """Computes local library parameters.

         Compute two tensors of shape (batch_index.shape[0], 1) where each
@@ -347,54 +350,50 @@ def _compute_local_library_params(
     @auto_move_data
     def _regular_inference(
         self,
-        x: torch.Tensor,
-        batch_index: torch.Tensor,
-        cont_covs: torch.Tensor | None = None,
-        cat_covs: torch.Tensor | None = None,
+        x: Float[Tensor, "n g"],
+        batch_index: Int[Tensor, "n 1"],
+        cont_covs: Float[Tensor, "n c"] | None = None,
+        cat_covs: Int[Tensor, "n k _"] | None = None,
         n_samples: int = 1,
-    ) -> dict[str, torch.Tensor | Distribution | None]:
+    ) -> dict[str, Tensor | Distribution | None]:
         """Run the regular inference process."""
-        x_ = x
-        if self.use_observed_lib_size:
-            library = torch.log(x.sum(1)).unsqueeze(1)
-        if self.log_variational:
-            x_ = torch.log1p(x_)
+        x_: Float[Tensor, "n g"] = torch.log1p(x) if self.log_variational else x

+        encoder_input: Float[Tensor, "n g"] = x_
         if cont_covs is not None and self.encode_covariates:
-            encoder_input = torch.cat((x_, cont_covs), dim=-1)
-        else:
-            encoder_input = x_
+            encoder_input: Float[Tensor, "n g+c"] = torch.cat((x_, cont_covs), dim=-1)
+
+        categorical_input = ()
         if cat_covs is not None and self.encode_covariates:
-            categorical_input = torch.split(cat_covs, 1, dim=1)
-        else:
-            categorical_input = ()
+            categorical_input: tuple[Int[Tensor, "n 1 _"]] = torch.split(cat_covs, 1, dim=1)

         if self.batch_representation == "embedding" and self.encode_covariates:
-            batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index)
-            encoder_input = torch.cat([encoder_input, batch_rep], dim=-1)
+            batch_rep: Float[Tensor, "n b"] = self.compute_embedding(
+                REGISTRY_KEYS.BATCH_KEY,
+                batch_index,
+            )
+            encoder_input: Float[Tensor, "n g+c+b"] = torch.cat([encoder_input, batch_rep], dim=-1)
             qz, z = self.z_encoder(encoder_input, *categorical_input)
         else:
             qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)

-        ql = None
-        if not self.use_observed_lib_size:
-            if self.batch_representation == "embedding":
-                ql, library_encoded = self.l_encoder(encoder_input, *categorical_input)
-            else:
-                ql, library_encoded = self.l_encoder(
-                    encoder_input, batch_index, *categorical_input
-                )
-            library = library_encoded
+        if self.use_observed_lib_size:
+            ql = None
+            library: Float[Tensor, "n 1"] = torch.log(x.sum(dim=-1, keepdim=True))
+        elif self.batch_representation == "embedding":
+            ql, library = self.l_encoder(encoder_input, *categorical_input)
+        else:
+            ql, library = self.l_encoder(encoder_input, batch_index, *categorical_input)

         if n_samples > 1:
-            untran_z = qz.sample((n_samples,))
-            z = self.z_encoder.z_transformation(untran_z)
+            z: Float[Tensor, "s n d"] = qz.sample((n_samples,))
+            z = self.z_encoder.z_transformation(z)
             if self.use_observed_lib_size:
-                library = library.unsqueeze(0).expand(
+                library: Float[Tensor, "s n 1"] = library.unsqueeze(0).expand(
                     (n_samples, library.size(0), library.size(1))
                 )
             else:
-                library = ql.sample((n_samples,))
+                library: Float[Tensor, "s n 1"] = ql.sample((n_samples,))

         return {
             MODULE_KEYS.Z_KEY: z,
@@ -406,11 +405,11 @@ def _regular_inference(
     @auto_move_data
     def _cached_inference(
         self,
-        qzm: torch.Tensor,
-        qzv: torch.Tensor,
-        observed_lib_size: torch.Tensor,
+        qzm: Float[Tensor, "n d"],
+        qzv: Float[Tensor, "n d"],
+        observed_lib_size: Float[Tensor, "n 1"],
         n_samples: int = 1,
-    ) -> dict[str, torch.Tensor | None]:
+    ) -> dict[str, Tensor | None]:
         """Run the cached inference process."""
         from torch.distributions import Normal

@@ -438,14 +437,14 @@
     @auto_move_data
     def generative(
         self,
-        z: torch.Tensor,
-        library: torch.Tensor,
-        batch_index: torch.Tensor,
-        cont_covs: torch.Tensor | None = None,
-        cat_covs: torch.Tensor | None = None,
-        size_factor: torch.Tensor | None = None,
-        y: torch.Tensor | None = None,
-        transform_batch: torch.Tensor | None = None,
+        z: Tensor,
+        library: Tensor,
+        batch_index: Tensor,
+        cont_covs: Tensor | None = None,
+        cat_covs: Tensor | None = None,
+        size_factor: Tensor | None = None,
+        y: Tensor | None = None,
+        transform_batch: Tensor | None = None,
     ) -> dict[str, Distribution | None]:
         """Run the generative process."""
         from torch.nn.functional import linear
@@ -543,8 +542,8 @@ def generative(

     def loss(
         self,
-        tensors: dict[str, torch.Tensor],
-        inference_outputs: dict[str, torch.Tensor | Distribution | None],
+        tensors: dict[str, Tensor],
+        inference_outputs: dict[str, Tensor | Distribution | None],
         generative_outputs: dict[str, Distribution | None],
         kl_weight: float = 1.0,
     ) -> LossOutput:
@@ -583,10 +582,10 @@ def loss(
     @torch.inference_mode()
     def sample(
         self,
-        tensors: dict[str, torch.Tensor],
+        tensors: dict[str, Tensor],
         n_samples: int = 1,
         max_poisson_rate: float = 1e8,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         r"""Generate predictive samples from the posterior predictive distribution.

         The posterior predictive distribution is denoted as :math:`p(\hat{x} \mid x)`, where
@@ -634,7 +633,7 @@ def sample(
     @auto_move_data
     def marginal_ll(
         self,
-        tensors: dict[str, torch.Tensor],
+        tensors: dict[str, Tensor],
         n_mc_samples: int,
         return_mean: bool = False,
         n_mc_samples_per_pass: int = 1,
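
Reviewer note (illustrative, not part of the patch): the annotations above use jaxtyping's dtype-first syntax, e.g. `Float[Tensor, "n g"]` for a float tensor of n cells by g genes, with the shape written as a string. Those string shapes are what triggers ruff's F722 rule, hence the new ignore in pyproject.toml. A minimal, self-contained sketch of the same pattern follows; the helper names are hypothetical and only mirror what `_regular_inference` computes:

    import torch
    from jaxtyping import Float, Int
    from torch import Tensor


    def log_library_size(x: Float[Tensor, "n g"]) -> Float[Tensor, "n 1"]:
        # Observed log library size per cell, as computed in _regular_inference.
        return torch.log(x.sum(dim=-1, keepdim=True))


    def split_covariates(cat_covs: Int[Tensor, "n k"]) -> tuple[Int[Tensor, "n 1"], ...]:
        # One (n, 1) column per categorical covariate, i.e. torch.split(cat_covs, 1, dim=1).
        return torch.split(cat_covs, 1, dim=1)


    x = torch.rand(4, 10)  # 4 cells, 10 genes
    cat_covs = torch.randint(0, 3, (4, 2))  # 2 categorical covariates
    print(log_library_size(x).shape)  # torch.Size([4, 1])
    print(len(split_covariates(cat_covs)))  # 2

As far as this diff goes, the annotations stay declarative (no jaxtyping runtime checker is wired up), so they document expected shapes without changing behavior.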