diff --git a/experiments/experiment_utils.py b/experiments/experiment_utils.py new file mode 100644 index 0000000..40dfbb4 --- /dev/null +++ b/experiments/experiment_utils.py @@ -0,0 +1,17 @@ +from typing import Literal + +from lightning.pytorch.loggers import WandbLogger + + +def get_wandb_logger( + task_name: Literal["orientability", "betti_numbers", "name"], + save_dir="./lightning_logs", + model_name: str = None, +): + wandb_logger = WandbLogger( + project="MANTRA", entity="aidos-labs", save_dir=save_dir + ) + wandb_logger.experiment.config["task"] = task_name + if model_name is not None: + wandb_logger.experiment.config["model_name"] = model_name + return wandb_logger diff --git a/experiments/orientability/graphs/GATSimplex2Vec.py b/experiments/orientability/graphs/GATSimplex2Vec.py index eb3ac68..4ddd15b 100644 --- a/experiments/orientability/graphs/GATSimplex2Vec.py +++ b/experiments/orientability/graphs/GATSimplex2Vec.py @@ -1,11 +1,12 @@ import lightning as L import torch import torchvision.transforms as transforms -from lightning.pytorch.loggers import CSVLogger +import wandb from torch.utils.data import Subset from torch_geometric.data import DataLoader from torch_geometric.transforms import FaceToEdge +from experiments.experiment_utils import get_wandb_logger from experiments.lightning_modules.GraphCommonModuleOrientability import ( GraphCommonModuleOrientability, ) @@ -89,8 +90,11 @@ def single_experiment_orientability_gat_simplex2vec(): test_dl = DataLoader( test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers ) - logger = CSVLogger(name="GATSimplex2Vec", save_dir="./lightning_logs") + logger = get_wandb_logger( + task_name="orientability", model_name="GATSimplex2Vec" + ) trainer = L.Trainer( max_epochs=max_epochs, log_every_n_steps=1, logger=logger ) trainer.fit(model, train_dl, test_dl) + wandb.finish() diff --git a/experiments/orientability/graphs/GCN.py b/experiments/orientability/graphs/GCN.py index f17db8d..461035d 100644 --- a/experiments/orientability/graphs/GCN.py +++ b/experiments/orientability/graphs/GCN.py @@ -1,11 +1,12 @@ import lightning as L import torch import torchvision.transforms as transforms -from lightning.pytorch.loggers import CSVLogger +import wandb from torch.utils.data import Subset from torch_geometric.loader import DataLoader from torch_geometric.transforms import FaceToEdge +from experiments.experiment_utils import get_wandb_logger from experiments.lightning_modules.GraphCommonModuleOrientability import ( GraphCommonModuleOrientability, ) @@ -83,7 +84,7 @@ def single_experiment_orientability_gnn(): test_dl = DataLoader( test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers ) - logger = CSVLogger(name="GCN", save_dir="./lightning_logs") + logger = get_wandb_logger(task_name="orientability", model_name="GCN") trainer = L.Trainer( max_epochs=max_epochs, log_every_n_steps=1, logger=logger ) @@ -93,3 +94,4 @@ def single_experiment_orientability_gnn(): train_dl, test_dl, ) + wandb.finish() diff --git a/experiments/orientability/simplicial_complexes/SCNN.py b/experiments/orientability/simplicial_complexes/SCNN.py index f176bbf..f68faa8 100644 --- a/experiments/orientability/simplicial_complexes/SCNN.py +++ b/experiments/orientability/simplicial_complexes/SCNN.py @@ -3,10 +3,11 @@ import lightning as L import torch import torchvision.transforms as transforms -from lightning.pytorch.loggers import CSVLogger +import wandb from torch import nn from torch.utils.data import Subset +from experiments.experiment_utils import get_wandb_logger from experiments.lightning_modules.BaseModuleOrientability import ( BaseOrientabilityModule, ) @@ -163,7 +164,7 @@ def single_experiment_orientability_scnn(): test_dl = SimplicialDataLoader( test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers ) - logger = CSVLogger(name="SCNN_rank_1", save_dir="./lightning_logs") + logger = get_wandb_logger(task_name="orientability", model_name="SCNN") # Use CPU acceleration: SCCNN does not support GPU acceleration because it creates matrices not placed in the # device of the network. trainer = L.Trainer( @@ -173,3 +174,4 @@ def single_experiment_orientability_scnn(): logger=logger, ) trainer.fit(model, train_dl, test_dl) + wandb.finish() diff --git a/main.py b/main.py index c2011fe..1fd09cb 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,14 @@ from experiments.orientability.graphs.GATSimplex2Vec import ( single_experiment_orientability_gat_simplex2vec, ) +from experiments.orientability.graphs.GCN import ( + single_experiment_orientability_gnn, +) +from experiments.orientability.simplicial_complexes.SCNN import ( + single_experiment_orientability_scnn, +) if __name__ == "__main__": - # single_experiment_orientability_gnn() - # single_experiment_orientability_scnn() + single_experiment_orientability_gnn() + single_experiment_orientability_scnn() single_experiment_orientability_gat_simplex2vec() diff --git a/pyproject.toml b/pyproject.toml index 1bea603..927f443 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,8 @@ name = "mantra" version = "0.0.1" dependencies = [ "gudhi", - "lightning" + "lightning", + "wandb", ] requires-python = ">=3.8" authors = [