Skip to content

Commit

Permalink
Merge branch 'main' into poissonmultivi
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis authored Dec 1, 2024
2 parents 88c30d6 + 5ab3372 commit 227da02
Show file tree
Hide file tree
Showing 14 changed files with 272 additions and 302 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ to [Semantic Versioning]. Full commit history is available in the
validation set, if available. {pr}`3036`.
- Add `batch_key` and `labels_key` to `scvi.external.SCAR.setup_anndata`.
- Implemented variance of ZINB distribution. {pr}`3044`.
- Support for minified mode while retaining counts to skip the encoder.
- New training plan argument `update_only_decoder` to reuse stored latent codes and skip
training of the encoder.
- Refactored code for minified models.
- Add {class}`scvi.external.METHYLVI` for modeling methylation data from single-cell
bisulfite sequencing (scBS-seq) experiments {pr}`2834`.

Expand All @@ -37,6 +41,8 @@ to [Semantic Versioning]. Full commit history is available in the
to correctly compute the maximum log-density across in-sample cells rather than the
aggregated posterior log-density {pr}`3007`.
- Fix references to `scvi.external` in `scvi.external.SCAR.setup_anndata`.
- Fix gimVI to move each mini-batch to the CPU before appending it in `get_imputed_values` and `get_latent_representation` {pr}`3058`.

#### Changed

Expand Down
1 change: 1 addition & 0 deletions src/scvi/data/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

class _ADATA_MINIFY_TYPE_NT(NamedTuple):
    """String identifiers naming the supported AnnData minification types."""

    # Identifier for minification that stores the latent posterior parameters.
    LATENT_POSTERIOR: str = "latent_posterior_parameters"
    # Identifier for the "with counts" variant.
    # NOTE(review): presumably the counts are retained alongside the latent
    # posterior parameters in this mode -- confirm against the minification code.
    LATENT_POSTERIOR_WITH_COUNTS: str = "latent_posterior_parameters_with_counts"


ADATA_MINIFY_TYPE = _ADATA_MINIFY_TYPE_NT()
Expand Down
10 changes: 8 additions & 2 deletions src/scvi/external/gimvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,11 @@ def get_latent_representation(
self.module.sample_from_posterior_z(
sample_batch, mode, deterministic=deterministic
)
.cpu()
.detach()
)

latent = torch.cat(latent).cpu().detach().numpy()
latent = torch.cat(latent).numpy()
latents.append(latent)

return latents
Expand Down Expand Up @@ -372,6 +374,8 @@ def get_imputed_values(
deterministic=deterministic,
decode_mode=decode_mode,
)
.cpu()
.detach()
)
else:
imputed_value.append(
Expand All @@ -383,9 +387,11 @@ def get_imputed_values(
deterministic=deterministic,
decode_mode=decode_mode,
)
.cpu()
.detach()
)

imputed_value = torch.cat(imputed_value).cpu().detach().numpy()
imputed_value = torch.cat(imputed_value).numpy()
imputed_values.append(imputed_value)

return imputed_values
Expand Down
103 changes: 8 additions & 95 deletions src/scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
import numpy as np
import pandas as pd
import torch
from anndata import AnnData

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import (
_ADATA_MINIFY_TYPE_UNS_KEY,
_SETUP_ARGS_KEY,
ADATA_MINIFY_TYPE,
)
Expand All @@ -25,12 +23,9 @@
LayerField,
NumericalJointObsField,
NumericalObsField,
ObsmField,
StringUnsField,
)
from scvi.dataloaders import SemiSupervisedDataSplitter
from scvi.model._utils import _init_library_size, get_max_epochs_heuristic
from scvi.model.utils import get_minified_adata_scrna
from scvi.module import SCANVAE
from scvi.train import SemiSupervisedTrainingPlan, TrainRunner
from scvi.train._callbacks import SubSampleLabels
Expand All @@ -45,17 +40,8 @@

from anndata import AnnData

from scvi._types import MinifiedDataType
from scvi.data.fields import (
BaseAnnDataField,
)

from ._scvi import SCVI

_SCANVI_LATENT_QZM = "_scanvi_latent_qzm"
_SCANVI_LATENT_QZV = "_scanvi_latent_qzv"
_SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size"

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -115,6 +101,8 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass):

_module_cls = SCANVAE
_training_plan_cls = SemiSupervisedTrainingPlan
_LATENT_QZM = "scanvi_latent_qzm"
_LATENT_QZV = "scanvi_latent_qzv"

def __init__(
self,
Expand Down Expand Up @@ -223,25 +211,26 @@ def from_scvi_model(
)
del scanvi_kwargs[k]

if scvi_model.minified_data_type is not None:
if scvi_model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
raise ValueError(
"We cannot use the given scvi model to initialize scanvi because it has a "
"minified adata."
"We cannot use the given scVI model to initialize scANVI because it has "
"minified adata. Keep counts when minifying model using "
"minified_data_type='latent_posterior_parameters_with_counts'."
)

if adata is None:
adata = scvi_model.adata
else:
if _is_minified(adata):
raise ValueError("Please provide a non-minified `adata` to initialize scanvi.")
raise ValueError("Please provide a non-minified `adata` to initialize scANVI.")
# validate new anndata against old model
scvi_model._validate_anndata(adata)

scvi_setup_args = deepcopy(scvi_model.adata_manager.registry[_SETUP_ARGS_KEY])
scvi_labels_key = scvi_setup_args["labels_key"]
if labels_key is None and scvi_labels_key is None:
raise ValueError(
"A `labels_key` is necessary as the SCVI model was initialized without one."
"A `labels_key` is necessary as the scVI model was initialized without one."
)
if scvi_labels_key is None:
scvi_setup_args.update({"labels_key": labels_key})
Expand Down Expand Up @@ -485,79 +474,3 @@ def setup_anndata(
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_fields_for_adata_minification(
    minified_data_type: MinifiedDataType,
) -> list[BaseAnnDataField]:
    """Build the AnnData fields to register for the given minification type.

    Parameters
    ----------
    minified_data_type
        Requested minification type. Only ``ADATA_MINIFY_TYPE.LATENT_POSTERIOR``
        is accepted.

    Returns
    -------
    Two ``ObsmField`` entries (latent qzm and qzv), one ``NumericalObsField``
    (observed library size) and one ``StringUnsField`` (minify type), in that
    order.

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is anything other than the latent-posterior type.
    """
    # Reject unsupported types up front so the happy path stays flat.
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

    return [
        ObsmField(REGISTRY_KEYS.LATENT_QZM_KEY, _SCANVI_LATENT_QZM),
        ObsmField(REGISTRY_KEYS.LATENT_QZV_KEY, _SCANVI_LATENT_QZV),
        NumericalObsField(
            REGISTRY_KEYS.OBSERVED_LIB_SIZE,
            _SCANVI_OBSERVED_LIB_SIZE,
        ),
        StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY),
    ]

def minify_adata(
    self,
    minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
    use_latent_qzm_key: str = "X_latent_qzm",
    use_latent_qzv_key: str = "X_latent_qzv",
):
    """Replace the model's adata with a minified copy.

    The minified adata carries the latent posterior parameters (qzm, qzv), the
    per-cell observed library size, and the minified-data type. The module is
    also flagged with the minified-data type.

    Parameters
    ----------
    minified_data_type
        How to minify the data. Only `latent_posterior_parameters` is supported:

        * the original count data is removed (`adata.X`, adata.raw, and any layers)
        * the parameters of the latent representation of the original data are stored
        * everything else is left untouched
    use_latent_qzm_key
        Key in `adata.obsm` holding the latent qzm parameters.
    use_latent_qzv_key
        Key in `adata.obsm` holding the latent qzv parameters.

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is not the latent-posterior type.
    ValueError
        If the module has ``use_observed_lib_size`` set to ``False``.

    Notes
    -----
    The modification is not done inplace -- instead the model is assigned a new
    (minified) version of the adata.
    """
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")
    if self.module.use_observed_lib_size is False:
        raise ValueError("Cannot minify the data if `use_observed_lib_size` is False")

    slim_adata = get_minified_adata_scrna(self.adata, minified_data_type)

    # Copy the latent posterior parameters over under the model-specific keys.
    full_obsm = self.adata.obsm
    slim_adata.obsm[_SCANVI_LATENT_QZM] = full_obsm[use_latent_qzm_key]
    slim_adata.obsm[_SCANVI_LATENT_QZV] = full_obsm[use_latent_qzv_key]

    # The library size is taken from the registered counts, summed per row.
    x_counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
    per_cell_total = x_counts.sum(axis=1)
    slim_adata.obs[_SCANVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(per_cell_total))

    self._update_adata_and_manager_post_minification(slim_adata, minified_data_type)
    self.module.minified_data_type = minified_data_type
95 changes: 2 additions & 93 deletions src/scvi/model/_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,18 @@
import warnings
from typing import TYPE_CHECKING

import numpy as np

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE
from scvi.data._utils import _get_adata_minify_type
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
ObsmField,
StringUnsField,
)
from scvi.model._utils import _init_library_size
from scvi.model.base import EmbeddingMixin, UnsupervisedTrainingMixin
from scvi.model.utils import get_minified_adata_scrna
from scvi.module import VAE
from scvi.utils import setup_anndata_dsp

Expand All @@ -32,15 +26,6 @@

from anndata import AnnData

from scvi._types import MinifiedDataType
from scvi.data.fields import (
BaseAnnDataField,
)

_SCVI_LATENT_QZM = "_scvi_latent_qzm"
_SCVI_LATENT_QZV = "_scvi_latent_qzv"
_SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size"

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -115,6 +100,8 @@ class SCVI(
"""

_module_cls = VAE
_SCVI_LATENT_QZM = "scvi_latent_qzm"
_SCVI_LATENT_QZV = "scvi_latent_qzv"

def __init__(
self,
Expand Down Expand Up @@ -231,81 +218,3 @@ def setup_anndata(
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

@staticmethod
def _get_fields_for_adata_minification(
    minified_data_type: MinifiedDataType,
) -> list[BaseAnnDataField]:
    """Build the AnnData fields to register for the given minification type.

    Parameters
    ----------
    minified_data_type
        Requested minification type. Only ``ADATA_MINIFY_TYPE.LATENT_POSTERIOR``
        is accepted.

    Returns
    -------
    Two ``ObsmField`` entries (latent qzm and qzv), one ``NumericalObsField``
    (observed library size) and one ``StringUnsField`` (minify type), in that
    order.

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is anything other than the latent-posterior type.
    """
    # Reject unsupported types up front so the happy path stays flat.
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")

    return [
        ObsmField(REGISTRY_KEYS.LATENT_QZM_KEY, _SCVI_LATENT_QZM),
        ObsmField(REGISTRY_KEYS.LATENT_QZV_KEY, _SCVI_LATENT_QZV),
        NumericalObsField(
            REGISTRY_KEYS.OBSERVED_LIB_SIZE,
            _SCVI_OBSERVED_LIB_SIZE,
        ),
        StringUnsField(REGISTRY_KEYS.MINIFY_TYPE_KEY, _ADATA_MINIFY_TYPE_UNS_KEY),
    ]

def minify_adata(
    self,
    minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
    use_latent_qzm_key: str = "X_latent_qzm",
    use_latent_qzv_key: str = "X_latent_qzv",
) -> None:
    """Replace the model's adata with a minified copy.

    The minified adata carries the latent posterior parameters (qzm, qzv), the
    per-cell observed library size, and the minified-data type. The module is
    also flagged with the minified-data type.

    Parameters
    ----------
    minified_data_type
        How to minify the data. Only `latent_posterior_parameters` is supported:

        * the original count data is removed (`adata.X`, adata.raw, and any layers)
        * the parameters of the latent representation of the original data are stored
        * everything else is left untouched
    use_latent_qzm_key
        Key in `adata.obsm` holding the latent qzm parameters.
    use_latent_qzv_key
        Key in `adata.obsm` holding the latent qzv parameters.

    Raises
    ------
    NotImplementedError
        If ``minified_data_type`` is not the latent-posterior type.
    ValueError
        If the module has ``use_observed_lib_size`` set to ``False``.

    Notes
    -----
    The modification is not done inplace -- instead the model is assigned a new
    (minified) version of the adata.
    """
    # TODO(adamgayoso): Add support for a scenario where we want to cache the latent posterior
    # without removing the original counts.
    if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR:
        raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}")
    if self.module.use_observed_lib_size is False:
        raise ValueError("Cannot minify the data if `use_observed_lib_size` is False")

    slim_adata = get_minified_adata_scrna(self.adata, minified_data_type)

    # Copy the latent posterior parameters over under the model-specific keys.
    full_obsm = self.adata.obsm
    slim_adata.obsm[_SCVI_LATENT_QZM] = full_obsm[use_latent_qzm_key]
    slim_adata.obsm[_SCVI_LATENT_QZV] = full_obsm[use_latent_qzv_key]

    # The library size is taken from the registered counts, summed per row.
    x_counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
    per_cell_total = x_counts.sum(axis=1)
    slim_adata.obs[_SCVI_OBSERVED_LIB_SIZE] = np.squeeze(np.asarray(per_cell_total))

    self._update_adata_and_manager_post_minification(slim_adata, minified_data_type)
    self.module.minified_data_type = minified_data_type
Loading

0 comments on commit 227da02

Please sign in to comment.