style: add initial tensor shape typing to vae #2868

Closed
wants to merge 9 commits
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -37,6 +37,7 @@ dependencies = [
"flax",
"jax",
"jaxlib",
"jaxtyping",
"lightning>=2.0",
"ml-collections>=0.1.1",
"mudata>=0.1.2",
@@ -200,6 +201,8 @@ ignore = [
"D401",
# We want docstrings to start immediately after the opening triple quote
"D213",
# jaxtyping uses string shape annotations that are not valid forward references
"F722",
# Raising ValueError is sufficient in tests.
"PT011",
# We support np.random functions.
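Why ignore F722: jaxtyping shape annotations are string literals inside subscripts, and Ruff parses them as forward references, which fails because a shape like "n g" is not valid Python. A minimal illustration (not part of this diff):

```python
from jaxtyping import Float
from torch import Tensor

# Without the F722 ignore, Ruff tries to parse the shape strings below as
# Python forward references; "n g" is not parseable, so F722 fires here.
def row_normalize(x: Float[Tensor, "n g"]) -> Float[Tensor, "n g"]:
    # Divide each row (cell) by its total count.
    return x / x.sum(dim=-1, keepdim=True)
```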
115 changes: 57 additions & 58 deletions src/scvi/module/_vae.py
@@ -6,6 +6,7 @@

import numpy as np
import torch
from torch.distributions import Distribution
from torch.nn.functional import one_hot

from scvi import REGISTRY_KEYS, settings
@@ -21,6 +22,8 @@
from collections.abc import Callable
from typing import Literal

from jaxtyping import Float, Int
from torch import Tensor
from torch.distributions import Distribution

logger = logging.getLogger(__name__)
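For readers new to jaxtyping, a short primer on the syntax used throughout this diff (an illustrative sketch, not project code): the dtype wrapper comes first, the concrete array class second, and the shape string last; dimension names are symbolic and may combine arithmetically, as in "g+c".

```python
import numpy as np
import torch
from jaxtyping import Float, Int

# dtype[ArrayType, "shape"]: the first subscript argument is the array class,
# the string names the dimensions.
Counts = Float[torch.Tensor, "n g"]        # 2-D float tensor with dims n and g
BatchIdx = Int[torch.Tensor, "n 1"]        # integer column vector
LibMeans = Float[np.ndarray, "1 n_batch"]  # the same syntax works for numpy

def with_covs(
    x: Float[torch.Tensor, "n g"],
    cont_covs: Float[torch.Tensor, "n c"],
) -> Float[torch.Tensor, "n g+c"]:
    # "g+c" is a symbolic dimension expression, checked against g and c when a
    # runtime typechecker is active; otherwise annotations are documentation only.
    return torch.cat((x, cont_covs), dim=-1)
```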
@@ -164,9 +167,9 @@ def __init__(
use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
use_size_factor_key: bool = False,
use_observed_lib_size: bool = True,
library_log_means: np.ndarray | None = None,
library_log_vars: np.ndarray | None = None,
var_activation: Callable[[torch.Tensor], torch.Tensor] = None,
library_log_means: Float[np.ndarray, "1 n_batch"] | None = None,
library_log_vars: Float[np.ndarray, "1 n_batch"] | None = None,
var_activation: Callable[[Tensor], Tensor] | None = None,
extra_encoder_kwargs: dict | None = None,
extra_decoder_kwargs: dict | None = None,
batch_embedding_kwargs: dict | None = None,
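Note that these annotations are inert unless a runtime typechecker is hooked up, and the diff does not add one. A hedged sketch of what enforcement could look like with jaxtyping's `jaxtyped` decorator plus beartype (both are real APIs, but wiring them in is an assumption, not something this PR does):

```python
import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped
from torch import Tensor

@jaxtyped(typechecker=beartype)
def scale_rows(x: Float[Tensor, "n g"], s: Float[Tensor, "n 1"]) -> Float[Tensor, "n g"]:
    # A mismatched n between x and s now raises at call time instead of
    # surfacing as a silent broadcast or a cryptic error deeper in the model.
    return x * s

scale_rows(torch.rand(4, 10), torch.rand(4, 1))    # passes
# scale_rows(torch.rand(4, 10), torch.rand(3, 1))  # raises a shape error
```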
@@ -280,8 +283,8 @@

def _get_inference_input(
self,
tensors: dict[str, torch.Tensor | None],
) -> dict[str, torch.Tensor | None]:
tensors: dict[str, Tensor | None],
) -> dict[str, Tensor | None]:
"""Get input tensors for the inference process."""
from scvi.data._constants import ADATA_MINIFY_TYPE

@@ -303,9 +306,9 @@ def _get_inference_input(

def _get_generative_input(
self,
tensors: dict[str, torch.Tensor],
inference_outputs: dict[str, torch.Tensor | Distribution | None],
) -> dict[str, torch.Tensor | None]:
tensors: dict[str, Tensor],
inference_outputs: dict[str, Tensor | Distribution | None],
) -> dict[str, Tensor | None]:
"""Get input tensors for the generative process."""
size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None)
if size_factor is not None:
@@ -323,8 +326,8 @@ def _get_generative_input(

def _compute_local_library_params(
self,
batch_index: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_index: Int[Tensor, "n 1"],
) -> tuple[Float[Tensor, "n 1"], Float[Tensor, "n 1"]]:
"""Computes local library parameters.

Compute two tensors of shape (batch_index.shape[0], 1) where each
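(The docstring is truncated here by the collapsed hunk.) For intuition about the "n 1" shapes: the method maps each cell's batch index through a one-hot encoding against per-batch parameters, yielding one value per cell. A self-contained sketch with assumed variable names:

```python
import torch
from torch.nn.functional import linear, one_hot

n, n_batch = 4, 3
library_log_means = torch.randn(1, n_batch)             # one value per batch
batch_index = torch.randint(0, n_batch, (n, 1))         # (n, 1) integer codes

oh = one_hot(batch_index.squeeze(-1), n_batch).float()  # (n, n_batch)
local_means = linear(oh, library_log_means)             # (n, 1): one per cell
```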
@@ -347,54 +350,50 @@ def _compute_local_library_params(
@auto_move_data
def _regular_inference(
self,
x: torch.Tensor,
batch_index: torch.Tensor,
cont_covs: torch.Tensor | None = None,
cat_covs: torch.Tensor | None = None,
x: Float[Tensor, "n g"],
batch_index: Int[Tensor, "n 1"],
cont_covs: Float[Tensor, "n c"] | None = None,
cat_covs: Int[Tensor, "n k"] | None = None,
n_samples: int = 1,
) -> dict[str, torch.Tensor | Distribution | None]:
) -> dict[str, Tensor | Distribution | None]:
"""Run the regular inference process."""
x_ = x
if self.use_observed_lib_size:
library = torch.log(x.sum(1)).unsqueeze(1)
if self.log_variational:
x_ = torch.log1p(x_)
x_: Float[Tensor, "n g"] = torch.log1p(x) if self.log_variational else x

encoder_input: Float[Tensor, "n g"] = x_
if cont_covs is not None and self.encode_covariates:
encoder_input = torch.cat((x_, cont_covs), dim=-1)
else:
encoder_input = x_
encoder_input: Float[Tensor, "n g+c"] = torch.cat((x_, cont_covs), dim=-1)

categorical_input = ()
if cat_covs is not None and self.encode_covariates:
categorical_input = torch.split(cat_covs, 1, dim=1)
else:
categorical_input = ()
categorical_input: tuple[Int[Tensor, "n 1"], ...] = torch.split(cat_covs, 1, dim=1)

if self.batch_representation == "embedding" and self.encode_covariates:
batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index)
encoder_input = torch.cat([encoder_input, batch_rep], dim=-1)
batch_rep: Float[Tensor, "n b"] = self.compute_embedding(
REGISTRY_KEYS.BATCH_KEY,
batch_index,
)
encoder_input: Float[Tensor, "n g+c+b"] = torch.cat([encoder_input, batch_rep], dim=-1)
qz, z = self.z_encoder(encoder_input, *categorical_input)
else:
qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)

ql = None
if not self.use_observed_lib_size:
if self.batch_representation == "embedding":
ql, library_encoded = self.l_encoder(encoder_input, *categorical_input)
else:
ql, library_encoded = self.l_encoder(
encoder_input, batch_index, *categorical_input
)
library = library_encoded
if self.use_observed_lib_size:
ql = None
library: Float[Tensor, "n 1"] = torch.log(x.sum(dim=-1, keepdim=True))
elif self.batch_representation == "embedding":
ql, library = self.l_encoder(encoder_input, *categorical_input)
else:
ql, library = self.l_encoder(encoder_input, batch_index, *categorical_input)

if n_samples > 1:
untran_z = qz.sample((n_samples,))
z = self.z_encoder.z_transformation(untran_z)
z: Float[Tensor, "s n d"] = qz.sample((n_samples,))
z = self.z_encoder.z_transformation(z)
if self.use_observed_lib_size:
library = library.unsqueeze(0).expand(
library: Float[Tensor, "s n 1"] = library.unsqueeze(0).expand(
(n_samples, library.size(0), library.size(1))
)
else:
library = ql.sample((n_samples,))
library: Float[Tensor, "s n 1"] = ql.sample((n_samples,))

return {
MODULE_KEYS.Z_KEY: z,
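To make the "s n d" and "s n 1" annotations above concrete: `Distribution.sample((n_samples,))` prepends a sample dimension, and an observed library tensor is expanded to match. A standalone illustration, not scvi code:

```python
import torch
from torch.distributions import Normal

n_samples, n, d = 5, 8, 10
qz = Normal(torch.zeros(n, d), torch.ones(n, d))

z = qz.sample((n_samples,))        # (s, n, d): the sample dim is prepended
library = torch.randn(n, 1)        # (n, 1) observed log-library sizes
library = library.unsqueeze(0).expand(n_samples, n, 1)  # (s, n, 1)
```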
@@ -406,11 +405,11 @@ def _regular_inference(
@auto_move_data
def _cached_inference(
self,
qzm: torch.Tensor,
qzv: torch.Tensor,
observed_lib_size: torch.Tensor,
qzm: Float[Tensor, "n d"],
qzv: Float[Tensor, "n d"],
observed_lib_size: Float[Tensor, "n 1"],
n_samples: int = 1,
) -> dict[str, torch.Tensor | None]:
) -> dict[str, Tensor | None]:
"""Run the cached inference process."""
from torch.distributions import Normal

@@ -438,14 +437,14 @@ def _cached_inference(
@auto_move_data
def generative(
self,
z: torch.Tensor,
library: torch.Tensor,
batch_index: torch.Tensor,
cont_covs: torch.Tensor | None = None,
cat_covs: torch.Tensor | None = None,
size_factor: torch.Tensor | None = None,
y: torch.Tensor | None = None,
transform_batch: torch.Tensor | None = None,
z: Tensor,
library: Tensor,
batch_index: Tensor,
cont_covs: Tensor | None = None,
cat_covs: Tensor | None = None,
size_factor: Tensor | None = None,
y: Tensor | None = None,
transform_batch: Tensor | None = None,
) -> dict[str, Distribution | None]:
"""Run the generative process."""
from torch.nn.functional import linear
@@ -543,8 +542,8 @@ def generative(

def loss(
self,
tensors: dict[str, torch.Tensor],
inference_outputs: dict[str, torch.Tensor | Distribution | None],
tensors: dict[str, Tensor],
inference_outputs: dict[str, Tensor | Distribution | None],
generative_outputs: dict[str, Distribution | None],
kl_weight: float = 1.0,
) -> LossOutput:
@@ -583,10 +582,10 @@ def loss(
@torch.inference_mode()
def sample(
self,
tensors: dict[str, torch.Tensor],
tensors: dict[str, Tensor],
n_samples: int = 1,
max_poisson_rate: float = 1e8,
) -> torch.Tensor:
) -> Tensor:
r"""Generate predictive samples from the posterior predictive distribution.

The posterior predictive distribution is denoted as :math:`p(\hat{x} \mid x)`, where
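(The docstring continues in the collapsed region.) For reference, the generic VAE Monte Carlo estimate of the posterior predictive that a `sample` method of this kind computes; this is standard background, not recovered text:

```latex
p(\hat{x} \mid x) \approx \frac{1}{S} \sum_{s=1}^{S}
    p_\theta\!\bigl(\hat{x} \mid z^{(s)}, \ell^{(s)}\bigr),
\qquad \bigl(z^{(s)}, \ell^{(s)}\bigr) \sim q_\phi(z, \ell \mid x)
```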
@@ -634,7 +633,7 @@ def sample(
@auto_move_data
def marginal_ll(
self,
tensors: dict[str, torch.Tensor],
tensors: dict[str, Tensor],
n_mc_samples: int,
return_mean: bool = False,
n_mc_samples_per_pass: int = 1,