From 6dc3af6561af6f0cdae247909c1a8ea18ff642be Mon Sep 17 00:00:00 2001 From: Nicolas Kuske Date: Sun, 19 May 2024 14:01:11 +0200 Subject: [PATCH 1/2] swap order of valid and train arguments for class DomainDataModule here the order of validation and training where mixed up. This error goes all the way through to the trainer. I.e., training data and validation data get swapped! Maybe adopt the code with keyword arguments so that the order does not matter? To double check the length of valid and train data used in the trainer, add in class DomainDataModule: def get_train_length(self): return len(self.train_dataset) def get_val_length(self): return len(self.val_dataset) Then add before the unimodal module the following class: from lightning.pytorch.callbacks import Callback class DatasetLengthLogger(Callback): def on_train_end(self, trainer, pl_module): train_loader = trainer.datamodule.train_dataloader() val_loader = trainer.datamodule.val_dataloader() train_length = len(train_loader.dataset) val_length = len(val_loader.dataset) print(f"Training Data Length: {train_length}") print(f"Validation Data Length: {val_length}") Finally, in def train module, before calling the trainer, add dataset_length_logger = DatasetLengthLogger() And in trainer, add dataset_length_logger to the callbacks: callbacks=[ ModelCheckpoint( dirpath="checkpoints", filename=module_name, monitor="val_loss", mode="min", save_top_k=1, ), dataset_length_logger ], Result: Training Data Length 128 Validation Data Length 256 --- docs/shimmer_basics.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/shimmer_basics.md b/docs/shimmer_basics.md index 1bba0524..1af9b629 100644 --- a/docs/shimmer_basics.md +++ b/docs/shimmer_basics.md @@ -65,8 +65,8 @@ from torch.utils.data import DataLoader, TensorDataset class DomainDataModule(LightningDataModule): def __init__( self, - val_dataset: torch.Tensor, train_dataset: torch.Tensor, + val_dataset: torch.Tensor, batch_size: int, ) -> None: super().__init__() From 77d96ccffd811cb034003cc484a27e4dc565f90e Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Thu, 19 Sep 2024 09:52:24 +0000 Subject: [PATCH 2/2] also change for GWDataModule --- docs/shimmer_basics.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/shimmer_basics.md b/docs/shimmer_basics.md index 1af9b629..a5d7281a 100644 --- a/docs/shimmer_basics.md +++ b/docs/shimmer_basics.md @@ -334,8 +334,8 @@ from shimmer import RepeatedDataset class GWDataModule(LightningDataModule): def __init__( self, - val_datasets: dict[frozenset[str], DomainDataset], train_datasets: dict[frozenset[str], DomainDataset], + val_datasets: dict[frozenset[str], DomainDataset], batch_size: int, ) -> None: super().__init__()