Skip to content

Commit

Permalink
Adding wandb to the project configured for the aidos-lab organization…
Browse files Browse the repository at this point in the history
… and MANTRA project.
  • Loading branch information
rballeba committed May 4, 2024
1 parent 14e8a39 commit 609c648
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 9 deletions.
17 changes: 17 additions & 0 deletions experiments/experiment_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Literal

from lightning.pytorch.loggers import WandbLogger


def get_wandb_logger(
task_name: Literal["orientability", "betti_numbers", "name"],
save_dir="./lightning_logs",
model_name: str = None,
):
wandb_logger = WandbLogger(
project="MANTRA", entity="aidos-labs", save_dir=save_dir
)
wandb_logger.experiment.config["task"] = task_name
if model_name is not None:
wandb_logger.experiment.config["model_name"] = model_name
return wandb_logger
8 changes: 6 additions & 2 deletions experiments/orientability/graphs/GATSimplex2Vec.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import lightning as L
import torch
import torchvision.transforms as transforms
from lightning.pytorch.loggers import CSVLogger
import wandb
from torch.utils.data import Subset
from torch_geometric.data import DataLoader
from torch_geometric.transforms import FaceToEdge

from experiments.experiment_utils import get_wandb_logger
from experiments.lightning_modules.GraphCommonModuleOrientability import (
GraphCommonModuleOrientability,
)
Expand Down Expand Up @@ -89,8 +90,11 @@ def single_experiment_orientability_gat_simplex2vec():
test_dl = DataLoader(
test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
logger = CSVLogger(name="GATSimplex2Vec", save_dir="./lightning_logs")
logger = get_wandb_logger(
task_name="orientability", model_name="GATSimplex2Vec"
)
trainer = L.Trainer(
max_epochs=max_epochs, log_every_n_steps=1, logger=logger
)
trainer.fit(model, train_dl, test_dl)
wandb.finish()
6 changes: 4 additions & 2 deletions experiments/orientability/graphs/GCN.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import lightning as L
import torch
import torchvision.transforms as transforms
from lightning.pytorch.loggers import CSVLogger
import wandb
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import FaceToEdge

from experiments.experiment_utils import get_wandb_logger
from experiments.lightning_modules.GraphCommonModuleOrientability import (
GraphCommonModuleOrientability,
)
Expand Down Expand Up @@ -83,7 +84,7 @@ def single_experiment_orientability_gnn():
test_dl = DataLoader(
test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
logger = CSVLogger(name="GCN", save_dir="./lightning_logs")
logger = get_wandb_logger(task_name="orientability", model_name="GCN")
trainer = L.Trainer(
max_epochs=max_epochs, log_every_n_steps=1, logger=logger
)
Expand All @@ -93,3 +94,4 @@ def single_experiment_orientability_gnn():
train_dl,
test_dl,
)
wandb.finish()
6 changes: 4 additions & 2 deletions experiments/orientability/simplicial_complexes/SCNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import lightning as L
import torch
import torchvision.transforms as transforms
from lightning.pytorch.loggers import CSVLogger
import wandb
from torch import nn
from torch.utils.data import Subset

from experiments.experiment_utils import get_wandb_logger
from experiments.lightning_modules.BaseModuleOrientability import (
BaseOrientabilityModule,
)
Expand Down Expand Up @@ -163,7 +164,7 @@ def single_experiment_orientability_scnn():
test_dl = SimplicialDataLoader(
test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
logger = CSVLogger(name="SCNN_rank_1", save_dir="./lightning_logs")
logger = get_wandb_logger(task_name="orientability", model_name="SCNN")
# Use CPU acceleration: SCCNN does not support GPU acceleration because it creates matrices not placed in the
# device of the network.
trainer = L.Trainer(
Expand All @@ -173,3 +174,4 @@ def single_experiment_orientability_scnn():
logger=logger,
)
trainer.fit(model, train_dl, test_dl)
wandb.finish()
10 changes: 8 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from experiments.orientability.graphs.GATSimplex2Vec import (
single_experiment_orientability_gat_simplex2vec,
)
from experiments.orientability.graphs.GCN import (
single_experiment_orientability_gnn,
)
from experiments.orientability.simplicial_complexes.SCNN import (
single_experiment_orientability_scnn,
)

if __name__ == "__main__":
# single_experiment_orientability_gnn()
# single_experiment_orientability_scnn()
single_experiment_orientability_gnn()
single_experiment_orientability_scnn()
single_experiment_orientability_gat_simplex2vec()
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ name = "mantra"
version = "0.0.1"
dependencies = [
"gudhi",
"lightning"
"lightning",
"wandb",
]
requires-python = ">=3.8"
authors = [
Expand Down

0 comments on commit 609c648

Please sign in to comment.