Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom dataloader registry support #2932

Open
wants to merge 96 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
96 commits
Select commit Hold shift + click to select a range
7088e4b
copying CZI custom dataloader into our repo
ori-kron-wis Jul 28, 2024
cc72b05
added some fixes to the custom dataloader stuff
ori-kron-wis Jul 30, 2024
46048e3
Some suggestions
canergen Jul 30, 2024
14f343d
Changes to datamodule pipeline
canergen Jul 31, 2024
17282cd
Fixed attr_dict
canergen Jul 31, 2024
a4143f5
added some fixes based on custom data loader test
ori-kron-wis Aug 1, 2024
69abc47
Changes to dataloader
canergen Aug 6, 2024
dc21a3d
copying CZI custom dataloader into our repo
ori-kron-wis Jul 28, 2024
a1098b3
added some fixes to the custom dataloader stuff
ori-kron-wis Jul 30, 2024
b07216b
Some suggestions
canergen Jul 30, 2024
a578af1
Changes to datamodule pipeline
canergen Jul 31, 2024
42434ec
Fixed attr_dict
canergen Jul 31, 2024
3d0c890
added some fixes based on custom data loader test
ori-kron-wis Aug 1, 2024
eff5b1e
Changes to dataloader
canergen Aug 6, 2024
cbdc26e
Merge remote-tracking branch 'origin/ori-2907-custom-dataloader-regis…
ori-kron-wis Aug 7, 2024
18d65a6
add changes to tests and some merging with main following custom data…
ori-kron-wis Aug 7, 2024
4fe3ee1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
1110966
just put the cutom dataloder2 test under remarks so hook tests will r…
ori-kron-wis Aug 7, 2024
7972bdc
fixes
ori-kron-wis Aug 7, 2024
2d86c43
additional external models fixes once there is a registry
ori-kron-wis Aug 7, 2024
3c44d86
fixed a few failed tests
ori-kron-wis Aug 11, 2024
c0889d8
fix archesmixin init and added new custom dataloader test and github …
ori-kron-wis Aug 11, 2024
8fe043c
fix again for from __future__ import annotations
ori-kron-wis Aug 11, 2024
d8cf0f6
fix for run custom dataloader in github action
ori-kron-wis Aug 11, 2024
c41e8b2
rollback
ori-kron-wis Aug 11, 2024
6ec5d4d
added label to the new githubaction for custom dataloader
ori-kron-wis Aug 11, 2024
6bce317
fix for github action for custom dataloaders
ori-kron-wis Aug 12, 2024
1f4ae9d
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
de1f30b
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
49fa01e
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
e33a935
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
48627d9
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
609094d
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
8cf3517
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
ba5a028
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
a7dc3fe
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
f3ff0f8
another fix to custom dataloder test and github action
ori-kron-wis Aug 12, 2024
083c76e
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 9, 2024
70bba69
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 15, 2024
8c75662
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 16, 2024
b6eb2f1
Returned REGISTRY_KEYS for import, after was drop in recent merges
ori-kron-wis Sep 16, 2024
2979ea2
It is ok to drop it after scarches categorial covariates fix
ori-kron-wis Sep 16, 2024
67e9b34
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 17, 2024
11fe33a
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 17, 2024
4a648ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2024
e3831cb
moved to type checking blocks beucase of ruff updates
ori-kron-wis Sep 17, 2024
e1837bd
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Sep 26, 2024
bf4d3bf
Merge remote-tracking branch 'origin/main' into ori-2907-custom-datal…
ori-kron-wis Oct 7, 2024
2cc8ff9
updated for CZI custom dataloader test and backend
ori-kron-wis Oct 9, 2024
e62dc3a
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Oct 9, 2024
41fd877
added cellxgene-census folder as well for debug (will not be merged)
ori-kron-wis Oct 9, 2024
10ada9c
added cellxgene-census packge to run test
ori-kron-wis Oct 9, 2024
dd3649c
added torchdata packge to run test
ori-kron-wis Oct 9, 2024
c6acb5a
fixed the test workwflow
ori-kron-wis Oct 9, 2024
b35c6eb
adding the lamindb as well
ori-kron-wis Oct 10, 2024
1801604
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
ed77a65
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
fc831d5
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
7400621
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
47376ca
fix the c.dataloders test
ori-kron-wis Oct 10, 2024
f94f7fa
removed redundat functions in code base
ori-kron-wis Oct 13, 2024
962f043
Added scanvi support, including CZI datamodule fix for it
ori-kron-wis Oct 15, 2024
5c21d71
Merge remote-tracking branch 'origin/main' into ori-2907-custom-datal…
ori-kron-wis Oct 20, 2024
a8aeffe
updates from main
ori-kron-wis Dec 25, 2024
1283616
more updates from main
ori-kron-wis Dec 25, 2024
624ee72
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Dec 25, 2024
6d4f368
Merge remote-tracking branch 'origin/ori-2907-custom-dataloader-regis…
ori-kron-wis Dec 25, 2024
8ab01a4
updated related to tests
ori-kron-wis Dec 25, 2024
31e1d44
updated related to tests
ori-kron-wis Dec 25, 2024
93666fa
Running DataLoader MappedCollection
canergen Dec 30, 2024
1d1d6d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2024
7695a8a
Fixed LaminDB dataloader
canergen Dec 31, 2024
e4d732a
Merge branch 'ori-2907-custom-dataloader-registry' of https://github.…
canergen Dec 31, 2024
a651442
LaminDB dataloader test.
canergen Dec 31, 2024
9767b8c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 31, 2024
719e740
Merge branch 'main' into ori-2907-custom-dataloader-registry
ori-kron-wis Dec 31, 2024
1a4c796
Merge remote-tracking branch 'origin/main' into ori-2907-custom-datal…
ori-kron-wis Jan 8, 2025
5666558
Changes for MappedCollection.
canergen Jan 8, 2025
c740dd2
Merge branch 'ori-2907-custom-dataloader-registry' of https://github.…
canergen Jan 8, 2025
61f2e27
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2025
874935b
Add other notebook for testing new dataloader
canergen Jan 9, 2025
f2c63bd
Merge branch 'ori-2907-custom-dataloader-registry' of https://github.…
canergen Jan 9, 2025
35d45c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
38c670f
Merge remote-tracking branch 'origin/main' into ori-2907-custom-datal…
ori-kron-wis Jan 16, 2025
c93fc97
updates to test script
ori-kron-wis Jan 16, 2025
5045fc3
remove old test nb
ori-kron-wis Jan 16, 2025
55775f9
update test
ori-kron-wis Jan 16, 2025
7ccdf8d
update test
ori-kron-wis Jan 16, 2025
f88dc50
updated czi cdl
ori-kron-wis Jan 16, 2025
1f3ea11
updated czi cdl
ori-kron-wis Jan 16, 2025
d0ec46f
Merge remote-tracking branch 'origin/main' into ori-2907-custom-datal…
ori-kron-wis Jan 20, 2025
e304922
merge with main + updates
ori-kron-wis Feb 9, 2025
5ccd1ed
more updates
ori-kron-wis Feb 9, 2025
96a09d8
more updates
ori-kron-wis Feb 9, 2025
601d86f
more updates
ori-kron-wis Feb 9, 2025
2485bb6
pyproject update
ori-kron-wis Feb 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
added some fixes to the custom dataloader stuff
  • Loading branch information
ori-kron-wis committed Jul 30, 2024
commit cc72b05f27f75f349b5946aeaca5e30a10bdcb21
49 changes: 49 additions & 0 deletions src/scvi/data/_manager.py
Original file line number Diff line number Diff line change
@@ -192,6 +192,55 @@ def register_fields(
self._assign_uuid()
self._assign_most_recent_manager_uuid()

def register_data_module_fields(
self,
datamodule,
source_registry: dict | None = None,
**transfer_kwargs,
):
"""Registers each field associated with this instance with the AnnData object.

Either registers or transfers the setup from `source_setup_dict` if passed in.
Sets ``self.adata``.

Parameters
----------
adata
AnnData object to be registered.
source_registry
Registry created after registering an AnnData using an
:class:`~scvi.data.AnnDataManager` object.
transfer_kwargs
Additional keywords which modify transfer behavior. Only applicable if
``source_registry`` is set.
"""
if self.adata is not None:
raise AssertionError("Existing AnnData object registered with this Manager instance.")

if source_registry is None and transfer_kwargs:
raise TypeError(
f"register_fields() got unexpected keyword arguments {transfer_kwargs} passed "
"without a source_registry."
)

self._validate_anndata_object(datamodule)

for field in self.fields:
self._add_field(
field=field,
adata=datamodule,
source_registry=source_registry,
**transfer_kwargs,
)

# Save arguments for register_fields.
self._source_registry = deepcopy(source_registry)
self._transfer_kwargs = deepcopy(transfer_kwargs)

self.adata = datamodule
self._assign_uuid()
self._assign_most_recent_manager_uuid()

def _add_field(
self,
field: AnnDataField,
66 changes: 65 additions & 1 deletion src/scvi/model/_scvi.py
Original file line number Diff line number Diff line change
@@ -140,8 +140,10 @@ def __init__(
f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}."
)

# in the next part we need to construct the same module no mather the way
# dataloader was given
if self._module_init_on_train:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove module_init_on_train. This is deprecated code with the new dataloader implementation.

# Here we need to adjust given the new custom data loader
# Here we need to adjust given the new custom data loader like CZI case
self.module = None
warnings.warn(
"Model was initialized without `adata`. The module will be initialized when "
@@ -225,6 +227,68 @@ def setup_anndata(
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)
# adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict()
# adata_manager.registry[_constants._FIELD_REGISTRIES_KEY]
# pprint(adata_manager.registry)

@classmethod
@setup_anndata_dsp.dedent
def setup_datamodule(
cls,
datamodule,
layer: str | None = None,
batch_key: str | None = None,
labels_key: str | None = None,
size_factor_key: str | None = None,
categorical_covariate_keys: list[str] | None = None,
continuous_covariate_keys: list[str] | None = None,
**kwargs,
):
"""%(summary)s.

Parameters
----------
%(param_adata)s
%(param_layer)s
%(param_batch_key)s
%(param_labels_key)s
%(param_size_factor_key)s
%(param_cat_cov_keys)s
%(param_cont_cov_keys)s
"""
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False),
CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
]
# register new fields if the adata is minified
# adata_minify_type = _get_adata_minify_type(adata)
# if adata_minify_type is not None:
# anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.registry["setup_method_name"] = "setup_datamodule"
adata_manager.registry["setup_args"]["layer"] = datamodule.datapipe.layer_name
adata_manager.registry["setup_args"]["batch_key"] = datamodule.batch_keys
adata_manager.registry["setup_args"]["labels_key"]
adata_manager.registry["setup_args"]["batch_key"]
adata_manager.registry["setup_args"]["batch_key"]
adata_manager.registry["setup_args"]["batch_key"]
# datamodule._datapipe.obs_column_names
# datamodule._datapipe.obs_encoders
# adata_manager.register_fields(adata, **kwargs)
# how to etract the information we need from the datamodule
adata_manager.register_data_module_fields(
datamodule, **kwargs
) # here we need a new function for data module

cls.register_manager(adata_manager)
# adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict()
# adata_manager.registry[_constants._FIELD_REGISTRIES_KEY]
# pprint(adata_manager.registry)

@staticmethod
def _get_fields_for_adata_minification(
17 changes: 17 additions & 0 deletions src/scvi/model/base/_base_model.py
Original file line number Diff line number Diff line change
@@ -815,6 +815,23 @@ def setup_anndata(
on a model-specific instance of :class:`~scvi.data.AnnDataManager`.
"""

@classmethod
@abstractmethod
@setup_anndata_dsp.dedent
def setup_datamodule(
cls,
datamodule,
*args,
**kwargs,
):
"""%(summary)s.

Each model class deriving from this class provides parameters to this method
according to its needs. To operate correctly with the model initialization,
the implementation must call :meth:`~scvi.model.base.BaseModelClass.register_manager`
on a model-specific instance of :class:`~scvi.data.AnnDataManager`.
"""

@staticmethod
def view_setup_args(dir_path: str, prefix: str | None = None) -> None:
"""Print args used to setup a saved model.
4 changes: 4 additions & 0 deletions src/scvi/model/base/_training_mixin.py
Original file line number Diff line number Diff line change
@@ -102,6 +102,7 @@ def train(
)

if datamodule is None:
# In the general case we enter here
datasplitter_kwargs = datasplitter_kwargs or {}
datamodule = self._data_splitter_cls(
self.adata_manager,
@@ -114,6 +115,7 @@ def train(
**datasplitter_kwargs,
)
elif self.module is None:
# in CZI case we enter here
self.module = self._module_cls(
datamodule.n_vars,
n_batch=datamodule.n_batch,
@@ -122,6 +124,8 @@ def train(
n_cats_per_cov=getattr(datamodule, "n_cats_per_cov", None),
**self._module_kwargs,
)
# after either of the cases we should be here with the same self.module
# and same datamodule

plan_kwargs = plan_kwargs or {}
training_plan = self._training_plan_cls(self.module, **plan_kwargs)
63 changes: 63 additions & 0 deletions tests/dataloaders/test_custom_dataloader.py
Original file line number Diff line number Diff line change
@@ -1 +1,64 @@
from __future__ import annotations

import os

import numpy as np
import scanpy as sc

import scvi
from scvi.data import _constants, synthetic_iid
from scvi.model import SCVI

# We will now create the SCVI model object:
# Its parameters:
n_layers = 1
n_latent = 10
batch_size = 1024
train_size = 0.9
max_epochs = 1


# COMAPRE TO THE ORIGINAL METHOD!!! - use the same data!!!
# We first create a registry using the orignal way of anndata in order to compare and add
# what is missing
adata = synthetic_iid()
adata.obs["size_factor"] = np.random.randint(1, 5, size=(adata.shape[0],))
SCVI.setup_anndata(
adata,
batch_key="batch",
labels_key="labels",
size_factor_key="size_factor",
)
#
model_orig = SCVI(adata, n_latent=n_latent)
model_orig.train(1, check_val_every_n_epoch=1, train_size=0.5)

# Saving the model
save_dir = "/Users/orikr/runs/290724/" # tempfile.TemporaryDirectory()
model_dir = os.path.join(save_dir, "scvi_orig_model")
model_orig.save(model_dir, overwrite=True)

# Loading the model (just as a compariosn)
model_orig_loaded = scvi.model.SCVI.load(model_dir, adata=adata)

# Obtaining model outputs
SCVI_LATENT_KEY = "X_scVI"
latent = model_orig.get_latent_representation()
adata.obsm[SCVI_LATENT_KEY] = latent
# latent.shape

# You can see all necessary entries and the structure at
adata_manager = model_orig.adata_manager
model_orig.view_anndata_setup(hide_state_registries=True)
# adata_manager.get_state_registry(SCVI.REGISTRY_KEYS.X_KEY).to_dict()
adata_manager.registry[_constants._FIELD_REGISTRIES_KEY]

# Plot UMAP and save the figure for later check
sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi")
sc.tl.umap(adata, neighbors_key="scvi")
sc.pl.umap(adata, color="dataset_id", title="SCVI")

# Now return and add all the registry stuff that we will need

# Now add the missing stuff from the current CZI implemenation in order for us to have the exact
# same steps like the original way (except than setup_anndata)
Loading