From ebe6f70ccb454cb67f1d6a01b0c22545fb6f4583 Mon Sep 17 00:00:00 2001
From: Martin Kim
Date: Fri, 28 Jun 2024 13:52:42 -0700
Subject: [PATCH 1/3] style: add initial tensor shape typing to vae

---
 pyproject.toml          |   3 ++
 src/scvi/module/_vae.py | 114 ++++++++++++++++++++--------------------
 2 files changed, 59 insertions(+), 58 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 5c8e7d177c..b94a144434 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
     "flax",
     "jax",
     "jaxlib",
+    "jaxtyping",
     "lightning>=2.0",
     "ml-collections>=0.1.1",
     "mudata>=0.1.2",
@@ -200,6 +201,8 @@ ignore = [
     "D203",
     # We want docstrings to start immediately after the opening triple quote
     "D213",
+    # Ignore forward-annotation syntax errors raised by jaxtyping shape strings
+    "F722",
 ]

 [tool.ruff.lint.pydocstyle]
diff --git a/src/scvi/module/_vae.py b/src/scvi/module/_vae.py
index 26f61ba9d1..fb01c8d0a4 100644
--- a/src/scvi/module/_vae.py
+++ b/src/scvi/module/_vae.py
@@ -6,6 +6,8 @@

 import numpy as np
 import torch
+from jaxtyping import Array, Float, Int
+from torch import Tensor
 from torch.distributions import Distribution
 from torch.nn.functional import one_hot

@@ -159,9 +161,9 @@ def __init__(
         use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
         use_size_factor_key: bool = False,
         use_observed_lib_size: bool = True,
-        library_log_means: np.ndarray | None = None,
-        library_log_vars: np.ndarray | None = None,
-        var_activation: Callable[[torch.Tensor], torch.Tensor] = None,
+        library_log_means: Float[Array, "1 n_batch"] | None = None,
+        library_log_vars: Float[Array, "1 n_batch"] | None = None,
+        var_activation: Callable[[Tensor], Tensor] | None = None,
         extra_encoder_kwargs: dict | None = None,
         extra_decoder_kwargs: dict | None = None,
         batch_embedding_kwargs: dict | None = None,
@@ -275,8 +277,8 @@ def __init__(

     def _get_inference_input(
         self,
-        tensors: dict[str, torch.Tensor | None],
-    ) -> dict[str, torch.Tensor | None]:
+        tensors: dict[str, Tensor | None],
+    ) -> dict[str, Tensor | None]:
         """Get input tensors for the inference process."""
         from scvi.data._constants import ADATA_MINIFY_TYPE

@@ -298,9 +300,9 @@

     def _get_generative_input(
         self,
-        tensors: dict[str, torch.Tensor],
-        inference_outputs: dict[str, torch.Tensor | Distribution | None],
-    ) -> dict[str, torch.Tensor | None]:
+        tensors: dict[str, Tensor],
+        inference_outputs: dict[str, Tensor | Distribution | None],
+    ) -> dict[str, Tensor | None]:
         """Get input tensors for the generative process."""
         size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None)
         if size_factor is not None:
@@ -318,8 +320,8 @@

     def _compute_local_library_params(
         self,
-        batch_index: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_index: Int[Tensor, "n 1"],
+    ) -> tuple[Float[Tensor, "n 1"], Float[Tensor, "n 1"]]:
         """Computes local library parameters.

         Compute two tensors of shape (batch_index.shape[0], 1) where each
@@ -342,54 +344,50 @@ def _compute_local_library_params(
     @auto_move_data
     def _regular_inference(
         self,
-        x: torch.Tensor,
-        batch_index: torch.Tensor,
-        cont_covs: torch.Tensor | None = None,
-        cat_covs: torch.Tensor | None = None,
+        x: Float[Tensor, "n g"],
+        batch_index: Int[Tensor, "n 1"],
+        cont_covs: Float[Tensor, "n c"] | None = None,
+        cat_covs: Int[Tensor, "n k"] | None = None,
         n_samples: int = 1,
-    ) -> dict[str, torch.Tensor | Distribution | None]:
+    ) -> dict[str, Tensor | Distribution | None]:
         """Run the regular inference process."""
-        x_ = x
-        if self.use_observed_lib_size:
-            library = torch.log(x.sum(1)).unsqueeze(1)
-        if self.log_variational:
-            x_ = torch.log1p(x_)
+        x_: Float[Tensor, "n g"] = torch.log1p(x) if self.log_variational else x

+        encoder_input: Float[Tensor, "n g"] = x_
         if cont_covs is not None and self.encode_covariates:
-            encoder_input = torch.cat((x_, cont_covs), dim=-1)
-        else:
-            encoder_input = x_
+            encoder_input: Float[Tensor, "n g+c"] = torch.cat((x_, cont_covs), dim=-1)
+
+        categorical_input = ()
         if cat_covs is not None and self.encode_covariates:
-            categorical_input = torch.split(cat_covs, 1, dim=1)
-        else:
-            categorical_input = ()
+            categorical_input: tuple[Int[Tensor, "n 1"], ...] = torch.split(cat_covs, 1, dim=1)

         if self.batch_representation == "embedding" and self.encode_covariates:
-            batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index)
-            encoder_input = torch.cat([encoder_input, batch_rep], dim=-1)
+            batch_rep: Float[Tensor, "n b"] = self.compute_embedding(
+                REGISTRY_KEYS.BATCH_KEY,
+                batch_index,
+            )
+            encoder_input: Float[Tensor, "n g+c+b"] = torch.cat([encoder_input, batch_rep], dim=-1)
             qz, z = self.z_encoder(encoder_input, *categorical_input)
         else:
             qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)

-        ql = None
-        if not self.use_observed_lib_size:
-            if self.batch_representation == "embedding":
-                ql, library_encoded = self.l_encoder(encoder_input, *categorical_input)
-            else:
-                ql, library_encoded = self.l_encoder(
-                    encoder_input, batch_index, *categorical_input
-                )
-            library = library_encoded
+        if self.use_observed_lib_size:
+            ql = None
+            library: Float[Tensor, "n 1"] = torch.log(x.sum(dim=-1, keepdim=True))
+        elif self.batch_representation == "embedding":
+            ql, library = self.l_encoder(encoder_input, *categorical_input)
+        else:
+            ql, library = self.l_encoder(encoder_input, batch_index, *categorical_input)

         if n_samples > 1:
-            untran_z = qz.sample((n_samples,))
-            z = self.z_encoder.z_transformation(untran_z)
+            z: Float[Tensor, "s n d"] = qz.sample((n_samples,))
+            z = self.z_encoder.z_transformation(z)
             if self.use_observed_lib_size:
-                library = library.unsqueeze(0).expand(
+                library: Float[Tensor, "s n 1"] = library.unsqueeze(0).expand(
                     (n_samples, library.size(0), library.size(1))
                 )
             else:
-                library = ql.sample((n_samples,))
+                library: Float[Tensor, "s n 1"] = ql.sample((n_samples,))

         return {
             MODULE_KEYS.Z_KEY: z,
@@ -401,11 +399,11 @@
     @auto_move_data
     def _cached_inference(
         self,
-        qzm: torch.Tensor,
-        qzv: torch.Tensor,
-        observed_lib_size: torch.Tensor,
+        qzm: Float[Tensor, "n d"],
+        qzv: Float[Tensor, "n d"],
+        observed_lib_size: Float[Tensor, "n 1"],
         n_samples: int = 1,
-    ) -> dict[str, torch.Tensor | None]:
+    ) -> dict[str, Tensor | None]:
         """Run the cached inference process."""
         from torch.distributions import Normal

@@ -433,14 +431,14 @@
     @auto_move_data
     def generative(
         self,
-        z: torch.Tensor,
-        library: torch.Tensor,
-        batch_index: torch.Tensor,
-        cont_covs: torch.Tensor | None = None,
-        cat_covs: torch.Tensor | None = None,
-        size_factor: torch.Tensor | None = None,
-        y: torch.Tensor | None = None,
-        transform_batch: torch.Tensor | None = None,
+        z: Tensor,
+        library: Tensor,
+        batch_index: Tensor,
+        cont_covs: Tensor | None = None,
+        cat_covs: Tensor | None = None,
+        size_factor: Tensor | None = None,
+        y: Tensor | None = None,
+        transform_batch: Tensor | None = None,
     ) -> dict[str, Distribution | None]:
         """Run the generative process."""
         from torch.distributions import Normal
@@ -534,8 +532,8 @@ def generative(

     def loss(
         self,
-        tensors: dict[str, torch.Tensor],
-        inference_outputs: dict[str, torch.Tensor | Distribution | None],
+        tensors: dict[str, Tensor],
+        inference_outputs: dict[str, Tensor | Distribution | None],
         generative_outputs: dict[str, Distribution | None],
         kl_weight: float = 1.0,
     ) -> LossOutput:
@@ -574,10 +572,10 @@
     @torch.inference_mode()
     def sample(
         self,
-        tensors: dict[str, torch.Tensor],
+        tensors: dict[str, Tensor],
         n_samples: int = 1,
         max_poisson_rate: float = 1e8,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         r"""Generate predictive samples from the posterior predictive distribution.

         The posterior predictive distribution is denoted as :math:`p(\hat{x} \mid x)`, where
@@ -625,7 +623,7 @@
     @auto_move_data
     def marginal_ll(
         self,
-        tensors: dict[str, torch.Tensor],
+        tensors: dict[str, Tensor],
         n_mc_samples: int,
         return_mean: bool = False,
         n_mc_samples_per_pass: int = 1,

From 6b3097402f22f881e03126a874711890e6d263af Mon Sep 17 00:00:00 2001
From: Ori Kronfeld
Date: Thu, 26 Sep 2024 11:05:14 +0300
Subject: [PATCH 2/3] Update _vae.py

Fix type-checking ruff rules by moving the jaxtyping and torch
typing-only imports under the TYPE_CHECKING block.

---
 src/scvi/module/_vae.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/scvi/module/_vae.py b/src/scvi/module/_vae.py
index f2eeedba87..7e12f5e757 100644
--- a/src/scvi/module/_vae.py
+++ b/src/scvi/module/_vae.py
@@ -6,8 +6,6 @@

 import numpy as np
 import torch
-from jaxtyping import Array, Float, Int
-from torch import Tensor
 from torch.distributions import Distribution
 from torch.nn.functional import one_hot

@@ -23,7 +21,8 @@
 if TYPE_CHECKING:
     from collections.abc import Callable
     from typing import Literal
-
+    from jaxtyping import Array, Float, Int
+    from torch import Tensor
     from torch.distributions import Distribution

 logger = logging.getLogger(__name__)

From 993bad2c89e66b2fe24ab7982e20d2e3e6e308d2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 26 Sep 2024 08:05:25 +0000
Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---
 src/scvi/module/_vae.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/scvi/module/_vae.py b/src/scvi/module/_vae.py
index 7e12f5e757..e5b71640ec 100644
--- a/src/scvi/module/_vae.py
+++ b/src/scvi/module/_vae.py
@@ -21,6 +21,7 @@
 if TYPE_CHECKING:
     from collections.abc import Callable
     from typing import Literal
+
     from jaxtyping import Array, Float, Int
     from torch import Tensor
     from torch.distributions import Distribution
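
A note on the annotation pattern adopted in this series: jaxtyping composes as `Float[Tensor, "n g"]` — the dtype wrapper subscripts the array type, and the shape string names one symbolic axis per dimension, so an axis name such as `n` must resolve to the same size everywhere it appears in a signature. Because PATCH 2/3 moves the imports under `TYPE_CHECKING`, the annotations in `_vae.py` serve readers and static tooling only; nothing validates shapes when the model actually runs. The sketch below is a minimal, self-contained illustration of the same pattern with optional runtime enforcement; it assumes jaxtyping >= 0.2.23 (for the `typechecker` argument) plus beartype, and `observed_library` is a hypothetical stand-in modeled on the library-size computation in `_regular_inference`, not scvi-tools API.

```python
import torch
from beartype import beartype
from jaxtyping import Float, Int, jaxtyped
from torch import Tensor


@jaxtyped(typechecker=beartype)  # optional: check dtypes and axis bindings at call time
def observed_library(
    x: Float[Tensor, "n g"],  # n cells by g genes
    batch_index: Int[Tensor, "n 1"],  # one integer batch label per cell
) -> Float[Tensor, "n 1"]:
    """Hypothetical stand-in: log of the observed library size per cell."""
    # keepdim=True preserves the trailing axis, so the result matches "n 1"
    return torch.log(x.sum(dim=-1, keepdim=True))


x = torch.rand(128, 2000)  # binds n=128, g=2000
b = torch.zeros(128, 1, dtype=torch.long)  # Int[Tensor, "n 1"] with the same n
lib = observed_library(x, b)  # OK: result has shape (128, 1)
# observed_library(x, b.T) would raise: shape (1, 128) cannot bind "n 1" with n=128
```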