diff --git a/notebooks/additional/scprint_overfit.ipynb b/notebooks/additional/scprint_overfit.ipynb
index 16a7a9e..5fd0b16 100644
--- a/notebooks/additional/scprint_overfit.ipynb
+++ b/notebooks/additional/scprint_overfit.ipynb
@@ -959,8 +959,7 @@
    }
   ],
   "source": [
-    "trainer.fit(model, dat\n",
-    "amodule=datamodule)"
+    "trainer.fit(model, datamodule=datamodule)"
   ]
  },
  {
@@ -1255,7 +1254,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.10.0"
+   "version": "3.10.15"
  }
 },
 "nbformat": 4,
diff --git a/notebooks/additional/scprint_test.ipynb b/notebooks/scprint_train.ipynb
similarity index 100%
rename from notebooks/additional/scprint_test.ipynb
rename to notebooks/scprint_train.ipynb
diff --git a/pyproject.toml b/pyproject.toml
index b69027f..dae89c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
     "leidenalg>=0.10.0",
     "django>=4.0.0",
     "scikit-misc>=0.5.0",
-    "scDataLoader>=1.1.3",
+    "scDataLoader>=1.6.5",
     "GRnnData>=1.1.4",
     "BenGRN>=1.2.4",
     "gseapy>=0.10.0",
diff --git a/scprint/model/model.py b/scprint/model/model.py
index 28b38bb..3d428e5 100644
--- a/scprint/model/model.py
+++ b/scprint/model/model.py
@@ -132,6 +132,7 @@ def __init__(
         self.fused_adam = False
         self.lr_reduce_patience = 1
         self.lr_reduce_factor = 0.6
+        self.test_every = 1
         self.lr_reduce_monitor = "val_loss"
         self.name = ""
         self.lr = lr
@@ -1103,7 +1104,7 @@ def on_validation_epoch_end(self):
                 self.log_adata(
                     gtclass=self.info, name="validation_part_" + str(self.counter)
                 )
-        if (self.current_epoch + 1) % 30 == 0:
+        if (self.current_epoch + 1) % self.test_every == 0:
             self.on_test_epoch_end()
 
     def test_step(self, *args, **kwargs):
diff --git a/scprint/tasks/denoise.py b/scprint/tasks/denoise.py
index 7d2750d..1eab0cd 100644
--- a/scprint/tasks/denoise.py
+++ b/scprint/tasks/denoise.py
@@ -243,7 +243,6 @@ def default_benchmark(
     """
     adata = sc.read_h5ad(default_dataset)
     denoise = Denoiser(
-        model,
         batch_size=40,
         max_len=max_len,
         max_cells=10_000,
@@ -251,9 +250,8 @@ def default_benchmark(
         num_workers=8,
         predict_depth_mult=10,
         downsample=0.7,
-        devices=1,
     )
-    return denoise(adata)[0]
+    return denoise(model, adata)[0]
 
 
 def open_benchmark(model):
diff --git a/scprint/tasks/grn.py b/scprint/tasks/grn.py
index 857531c..701c688 100644
--- a/scprint/tasks/grn.py
+++ b/scprint/tasks/grn.py
@@ -504,7 +504,6 @@ def default_benchmark(
         max_cells=maxcells,
         doplot=False,
         batch_size=batch_size,
-        devices=1,
     )
     grn = grn_inferer(model, adata)
     grn.varp["all"] = grn.varp["GRN"]
@@ -634,7 +633,6 @@ def default_benchmark(
         doplot=False,
         num_workers=8,
         batch_size=batch_size,
-        devices=1,
     )
     grn = grn_inferer(model, nadata)
     grn.varp["all"] = grn.varp["GRN"]
@@ -695,29 +693,28 @@ def default_benchmark(
     adata.var["isTF"] = False
     adata.var.loc[adata.var.symbol.isin(grnutils.TF), "isTF"] = True
     for celltype in cell_types:
-        print(celltype)
-        grn_inferer = GNInfer(
-            layer=layers,
-            how="random expr",
-            preprocess="softmax",
-            head_agg="max",
-            filtration="none",
-            forward_mode="none",
-            num_workers=8,
-            num_genes=2200,
-            max_cells=maxcells,
-            doplot=False,
-            batch_size=batch_size,
-            devices=1,
-        )
-
-        grn = grn_inferer(model, adata[adata.X.sum(1) > 500], cell_type=celltype)
-        grn.var.index = make_index_unique(grn.var["symbol"].astype(str))
-        metrics[celltype + "_scprint"] = BenGRN(
-            grn, doplot=False
-        ).scprint_benchmark()
-        del grn
-        gc.collect()
+        # print(celltype)
+        # grn_inferer = GNInfer(
+        #     layer=layers,
+        #     how="random expr",
+        #     preprocess="softmax",
+        #     head_agg="max",
+        #     filtration="none",
+        #     forward_mode="none",
+        #     num_workers=8,
+        #     num_genes=2200,
+        #     max_cells=maxcells,
+        #     doplot=False,
+        #     batch_size=batch_size,
+        # )
+        #
+        # grn = grn_inferer(model, adata[adata.X.sum(1) > 500], cell_type=celltype)
+        # grn.var.index = make_index_unique(grn.var["symbol"].astype(str))
+        # metrics[celltype + "_scprint"] = BenGRN(
+        #     grn, doplot=False
+        # ).scprint_benchmark()
+        # del grn
+        # gc.collect()
         grn_inferer = GNInfer(
             layer=layers,
             how="most var across",
@@ -730,7 +727,6 @@ def default_benchmark(
             max_cells=maxcells,
             doplot=False,
             batch_size=batch_size,
-            devices=1,
         )
         grn = grn_inferer(model, adata[adata.X.sum(1) > 500], cell_type=celltype)
         grn.var.index = make_index_unique(grn.var["symbol"].astype(str))
diff --git a/scprint/trainer/trainer.py b/scprint/trainer/trainer.py
index 03704bc..1d0644a 100644
--- a/scprint/trainer/trainer.py
+++ b/scprint/trainer/trainer.py
@@ -21,6 +21,7 @@ def __init__(
         do_generate: bool = True,
         class_scale: float = 1.5,
         mask_ratio: List[float] = [],  # 0.3
+        test_every: int = 1,
         warmup_duration: int = 500,
         fused_adam: bool = False,
         adv_class_scale: float = 0.1,
@@ -69,6 +70,7 @@ def __init__(
             optim (str): Optimizer to use during training. Defaults to "adamW".
             weight_decay (float): Weight decay to apply during optimization. Defaults to 0.01.
             name (str): Name of the training mode. Defaults to an empty string. should be an ID for the model
+            test_every (int): Number of epochs between test runs. Defaults to 1.
         """
         super().__init__()
         self.do_denoise = do_denoise
@@ -100,6 +102,7 @@ def __init__(
         self.do_adv_batch = do_adv_batch
         self.run_full_forward = run_full_forward
         self.name = name
+        self.test_every = test_every
 
     def __repr__(self):
         return (
@@ -131,7 +134,8 @@ def __repr__(self):
             f"do_cls={self.do_cls}, "
             f"do_adv_batch={self.do_adv_batch}, "
             f"run_full_forward={self.run_full_forward}), "
-            f"name={self.name})"
+            f"name={self.name}, "
+            f"test_every={self.test_every})"
         )
 
     def setup(self, trainer, model, stage=None):
@@ -165,4 +169,5 @@ def setup(self, trainer, model, stage=None):
         model.optim = self.optim
         model.weight_decay = self.weight_decay
         model.name = self.name
+        model.test_every = self.test_every
         # model.configure_optimizers()
diff --git a/tests/test_base.py b/tests/test_base.py
index dfda97f..80999ac 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -5,12 +5,16 @@
 import pytest
 import scanpy as sc
 import torch
-from scdataloader import Preprocessor
+from scdataloader import Preprocessor, DataModule
 from scdataloader.utils import populate_my_ontology
 
 from scprint import scPrint
 from scprint.base import NAME
 from scprint.tasks import Denoiser, Embedder, GNInfer
+from scprint.trainer import TrainingMode
+
+import lamindb as ln
+from lightning.pytorch import Trainer
 
 
 def test_base():
@@ -102,7 +106,93 @@ def test_base():
     )
     grn_adata = grn_inferer(model, adata)
     assert "GRN" in grn_adata.varp, "GRN inference failed"
-    # fit scprint
+    # make a collection
+    file = ln.Artifact(adata, description="test file")
+    file.save()
+    col = ln.Collection(file, name="test dataset")
+    col.save()
+    datamodule = DataModule(
+        collection_name="test dataset",
+        gene_embeddings=os.path.join(os.path.dirname(__file__), "test_emb.parquet"),
+        all_clss=[
+            "sex_ontology_term_id",
+            "organism_ontology_term_id",
+        ],
+        hierarchical_clss=[],
+        organisms=["NCBITaxon:9606"],  # , "NCBITaxon:10090"],
+        how="most expr",
+        max_len=200,
+        add_zero_genes=0,
+        # how much more often the most frequent category is sampled vs the least frequent
+        weight_scaler=10,
+        clss_to_weight=["sex_ontology_term_id"],
+        clss_to_pred=[
+            "sex_ontology_term_id",
+            "organism_ontology_term_id",
+        ],
+        batch_size=1,
+        num_workers=1,
+        # train_oversampling=2,
+        validation_split=0.1,
+        do_gene_pos=False,
+        test_split=0.1,
+    )
+    _ = datamodule.setup()
+    model = scPrint(
+        genes=datamodule.genes,
+        d_model=64,
+        nhead=1,
+        nlayers=1,
+        # layers_cls = [d_model],
+        # labels = datamodule.labels,
+        # cls_hierarchy = datamodule.cls_hierarchy,
+        dropout=0,
+        transformer="normal",
+        precpt_gene_emb=os.path.join(os.path.dirname(__file__), "test_emb.parquet"),
+        mvc_decoder="inner product",
+        fused_dropout_add_ln=False,
+        checkpointing=False,
+    )
+    trainingmode = TrainingMode(
+        do_denoise=True,
+        noise=[0.1],
+        do_cce=False,
+        do_ecs=False,
+        do_cls=True,
+        do_mvc=True,
+        mask_ratio=[],
+        warmup_duration=10,
+        lr_reduce_patience=10,
+        test_every=10_000,
+    )
+    trainer = Trainer(
+        gradient_clip_val=500,
+        max_time={"minutes": 4},
+        limit_val_batches=1,
+        callbacks=[trainingmode],
+        accumulate_grad_batches=1,
+        check_val_every_n_epoch=1,
+        overfit_batches=1,
+        max_epochs=20,
+        reload_dataloaders_every_n_epochs=100_000,
+        logger=None,
+        num_sanity_val_steps=0,
+        max_steps=100,
+    )
+    initial_loss = None
+    for i in range(2):
+        trainer.fit(model, datamodule=datamodule)
+        trainer.fit_loop.max_epochs = 20 * (
+            i + 2
+        )  # Reset max_epochs for next iteration
+        current_loss = trainer.callback_metrics.get("train_loss")
+        if initial_loss is None:
+            initial_loss = current_loss
+        else:
+            assert (
+                current_loss < initial_loss
+            ), f"Loss not decreasing: initial {initial_loss}, current {current_loss}"
+            initial_loss = current_loss
     # cli
     # get_Seq
     # sinkhorn
diff --git a/tests/test_emb.parquet b/tests/test_emb.parquet
new file mode 100644
index 0000000..505fe5a
Binary files /dev/null and b/tests/test_emb.parquet differ
diff --git a/uv.lock b/uv.lock
index 11628dc..62001e9 100644
--- a/uv.lock
+++ b/uv.lock
@@ -5102,7 +5102,7 @@ wheels = [
 
 [[package]]
 name = "scprint"
-version = "1.6.1"
+version = "1.6.2"
 source = { editable = "." }
 dependencies = [
     { name = "anndata" },
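
Note on the `test_every` plumbing in this patch: the hard-coded `% 30` test cadence in `on_validation_epoch_end` becomes a model attribute, set either directly or by the `TrainingMode` callback's `setup` (`model.test_every = self.test_every`). A minimal sketch of the gating with a stand-in LightningModule — only the attribute name and the modulo check come from the patch, everything else is illustrative:

```python
from lightning.pytorch import LightningModule


class GatedTestModule(LightningModule):
    """Stand-in module: only the `test_every` gating mirrors the patch."""

    def __init__(self, test_every: int = 1):
        super().__init__()
        # In scPrint this defaults to 1 in __init__ and is overridden by
        # the TrainingMode callback: `model.test_every = self.test_every`.
        self.test_every = test_every

    def on_validation_epoch_end(self):
        # Before this patch the cadence was hard-coded as `% 30`.
        # test_every=1 runs the test hooks after every epoch; a huge value
        # (e.g. 10_000 in tests/test_base.py) effectively disables them.
        if (self.current_epoch + 1) % self.test_every == 0:
            self.on_test_epoch_end()
```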
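
The benchmark tasks also change call convention: `Denoiser` no longer takes the model at construction, and both it and `GNInfer` drop the `devices` argument. A usage sketch mirroring the updated `denoise.default_benchmark` — `model`, `adata`, and `max_len` are assumed to be in scope, and the constructor arguments falling between the two hunks are elided as in the diff:

```python
from scprint.tasks import Denoiser

denoiser = Denoiser(
    batch_size=40,
    max_len=max_len,  # assumed in scope, as in default_benchmark
    max_cells=10_000,
    # ...arguments between the two hunks are not shown in the patch...
    num_workers=8,
    predict_depth_mult=10,
    downsample=0.7,
)
# The model now goes to the call, not the constructor, and devices=1 is gone:
metrics = denoiser(model, adata)[0]
```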