Commit

Refactoring to execute experiments with only one method.
rballeba committed May 6, 2024
1 parent b6b6b47 commit 0636e95
Showing 13 changed files with 317 additions and 121 deletions.
31 changes: 9 additions & 22 deletions experiments/betti_numbers/graphs/GCN.py
@@ -1,12 +1,8 @@
-import lightning as L
 import torch
-import wandb
-from torch.utils.data import Subset
-from torch_geometric.loader import DataLoader
 from torch_geometric.transforms import FaceToEdge, OneHotDegree
 from torchvision import transforms
 
-from experiments.experiment_utils import get_wandb_logger
+from experiments.experiment_utils import perform_experiment
 from experiments.lightning_modules.GraphCommonModuleBettiNumbers import (
     GraphCommonModuleBettiNumbers,
 )
@@ -72,21 +68,12 @@ def single_experiment_betti_numbers_gnn():
         num_hidden_layers=num_hidden_layers,
         learning_rate=learning_rate,
     )
-    train_ds = Subset(dataset, dataset.train_betti_numbers_indices)
-    test_ds = Subset(dataset, dataset.test_betti_numbers_indices)
-    train_dl = DataLoader(
-        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
-    )
-    test_dl = DataLoader(
-        test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
-    )
-    logger = get_wandb_logger(task_name="betti_numbers", model_name="GCN")
-    trainer = L.Trainer(
-        max_epochs=max_epochs, log_every_n_steps=1, logger=logger
-    )
-    trainer.fit(
-        model,
-        train_dl,
-        test_dl,
-    )
-    wandb.finish()
+    perform_experiment(
+        task="betti_numbers",
+        model=model,
+        model_name="GCN",
+        dataset=dataset,
+        batch_size=batch_size,
+        max_epochs=max_epochs,
+        num_workers=num_workers,
+    )
103 changes: 103 additions & 0 deletions experiments/betti_numbers/others/MLPConstantShape.py
@@ -0,0 +1,103 @@
import torch
from torch import nn
from torch_geometric.transforms import OneHotDegree, FaceToEdge
from torchvision import transforms

from experiments.experiment_utils import perform_experiment
from experiments.lightning_modules.BaseModuleBettiNumbers import (
    BaseBettiNumbersModule,
)
from mantra.simplicial import SimplicialDataset
from mantra.transforms import DegreeTransform, TriangulationToFaceTransform
from models.others.MLPConstantShape import MLPConstantShape


class MLPModule(BaseBettiNumbersModule):
    def __init__(
        self,
        num_input_neurons,
        num_hidden_neurons,
        num_hidden_layers,
        num_out_neurons,
        learning_rate,
    ):
        super().__init__()
        self.base_model = MLPConstantShape(
            num_input_neurons=num_input_neurons,
            num_hidden_neurons=num_hidden_neurons,
            num_hidden_layers=num_hidden_layers,
            num_out_neurons=num_out_neurons,
        )
        self.learning_rate = learning_rate

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.base_model.parameters(), lr=self.learning_rate
        )
        return optimizer

    def forward(self, x, batch):
        x = self.base_model(x, batch)
        return x

    def general_step(self, batch, batch_idx, step: str):
        x_hat = self(batch.x, batch.batch)
        y = torch.tensor(
            batch.betti_numbers, device=x_hat.device, dtype=x_hat.dtype
        )
        batch_len = len(y)
        loss = nn.functional.mse_loss(x_hat, y)
        self.log(
            f"{step}_loss",
            loss,
            prog_bar=True,
            batch_size=batch_len,
            on_step=False,
            on_epoch=True,
        )
        self.log_scores(x_hat, y, batch_len, step)
        return loss


def load_dataset_with_transformations():
    tr = transforms.Compose(
        [
            TriangulationToFaceTransform(),
            FaceToEdge(remove_faces=False),
            DegreeTransform(),
            OneHotDegree(max_degree=8, cat=False),
        ]
    )
    dataset = SimplicialDataset(root="./data", transform=tr)
    return dataset


def single_experiment_betti_numbers_mlp_constant_shape():
    # ===============================
    # Training parameters
    # ===============================
    num_hidden_neurons = 64
    num_hidden_layers = 3
    num_out_neurons = 3
    batch_size = 32
    learning_rate = 0.1
    num_workers = 0
    max_epochs = 100
    # ===============================
    dataset = load_dataset_with_transformations()
    model = MLPModule(
        num_input_neurons=dataset.num_features,
        num_hidden_neurons=num_hidden_neurons,
        num_hidden_layers=num_hidden_layers,
        num_out_neurons=num_out_neurons,
        learning_rate=learning_rate,
    )
    perform_experiment(
        task="betti_numbers",
        model=model,
        model_name="MLPConstantShape",
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        max_epochs=max_epochs,
    )
Empty file.
46 changes: 46 additions & 0 deletions experiments/experiment_utils.py
@@ -1,6 +1,10 @@
 from typing import Literal
 
+import lightning as L
+import wandb
 from lightning.pytorch.loggers import WandbLogger
+from torch.utils.data import Subset
+from torch_geometric.loader import DataLoader
 
 
 def get_wandb_logger(
@@ -15,3 +19,45 @@ def get_wandb_logger(
     if model_name is not None:
         wandb_logger.experiment.config["model_name"] = model_name
     return wandb_logger
+
+
+def perform_experiment(
+    task: Literal["orientability", "betti_numbers", "name"],
+    model,
+    model_name,
+    dataset,
+    batch_size,
+    num_workers,
+    max_epochs,
+    data_loader_class=DataLoader,
+    accelerator="auto",
+):
+    if task == "orientability":
+        train_indices = dataset.train_orientability_indices
+        test_indices = dataset.test_orientability_indices
+    elif task == "betti_numbers":
+        train_indices = dataset.train_betti_numbers_indices
+        test_indices = dataset.test_betti_numbers_indices
+    elif task == "name":
+        train_indices = dataset.train_name_indices
+        test_indices = dataset.test_name_indices
+    else:
+        raise ValueError(f"Task {task} not recognized")
+    logger = get_wandb_logger(task_name=task, model_name=model_name)
+    train_ds = Subset(dataset, train_indices)
+    test_ds = Subset(dataset, test_indices)
+    train_dl = data_loader_class(
+        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
+    )
+    test_dl = data_loader_class(
+        test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
+    )
+    trainer = L.Trainer(
+        max_epochs=max_epochs, log_every_n_steps=1, logger=logger
+    )
+    trainer.fit(
+        model,
+        train_dl,
+        test_dl,
+    )
+    wandb.finish()
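The new helper centralizes what every experiment script previously set up by hand: picking the task-specific train/test split, creating the W&B logger, building the two data loaders, running the Lightning trainer, and closing the W&B run. Per-model quirks are passed through the data_loader_class and accelerator arguments (the SCNN script further down uses both). As a rough, hypothetical sketch of a future caller, not a file in this commit, a new experiment script would now only build its model and dataset and delegate the rest; load_dataset_with_transformations and NewModule below are placeholder names standing in for whatever pipeline and LightningModule that script defines:

    from experiments.experiment_utils import perform_experiment

    def single_experiment_orientability_new_model():
        # Placeholder dataset pipeline and module; only the call below is the
        # fixed part shared by every experiment after this refactor.
        dataset = load_dataset_with_transformations()
        model = NewModule(learning_rate=0.01)
        perform_experiment(
            task="orientability",
            model=model,
            model_name="NewModel",
            dataset=dataset,
            batch_size=32,
            num_workers=0,
            max_epochs=100,
        )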
1 change: 0 additions & 1 deletion experiments/lightning_modules/GraphCommonModuleName.py
@@ -1,4 +1,3 @@
-import torch
 from torch import nn
 
 from experiments.lightning_modules.BaseModelClassification import (
31 changes: 9 additions & 22 deletions experiments/name/graphs/GCN.py
@@ -1,10 +1,8 @@
-import lightning as L
 import torch
-import wandb
 from torch_geometric.transforms import FaceToEdge
 from torchvision import transforms
 
-from experiments.experiment_utils import get_wandb_logger
+from experiments.experiment_utils import perform_experiment
 from experiments.lightning_modules.GraphCommonModuleName import (
     GraphCommonModuleName,
 )
@@ -15,8 +13,6 @@
     TriangulationToFaceTransform,
 )
 from models.graphs.GCN import GCNetwork
-from torch.utils.data import Subset
-from torch_geometric.loader import DataLoader
 
 
 class GCNModule(GraphCommonModuleName):
@@ -76,21 +72,12 @@ def single_experiment_name_gnn():
         num_hidden_layers=num_hidden_layers,
         learning_rate=learning_rate,
     )
-    train_ds = Subset(dataset, dataset.train_betti_numbers_indices)
-    test_ds = Subset(dataset, dataset.test_betti_numbers_indices)
-    train_dl = DataLoader(
-        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
-    )
-    test_dl = DataLoader(
-        test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
-    )
-    logger = get_wandb_logger(task_name="betti_numbers", model_name="GCN")
-    trainer = L.Trainer(
-        max_epochs=max_epochs, log_every_n_steps=1, logger=logger
-    )
-    trainer.fit(
-        model,
-        train_dl,
-        test_dl,
-    )
-    wandb.finish()
+    perform_experiment(
+        task="name",
+        model=model,
+        model_name="GCN",
+        dataset=dataset,
+        batch_size=batch_size,
+        max_epochs=max_epochs,
+        num_workers=num_workers,
+    )
30 changes: 10 additions & 20 deletions experiments/orientability/graphs/GATSimplex2Vec.py
@@ -1,12 +1,9 @@
-import lightning as L
+import torch
 import torch
 import torchvision.transforms as transforms
-import wandb
-from torch.utils.data import Subset
-from torch_geometric.data import DataLoader
 from torch_geometric.transforms import FaceToEdge
 
-from experiments.experiment_utils import get_wandb_logger
+from experiments.experiment_utils import perform_experiment
 from experiments.lightning_modules.GraphCommonModuleOrientability import (
     GraphCommonModuleOrientability,
 )
@@ -82,19 +79,12 @@ def single_experiment_orientability_gat_simplex2vec():
         num_hidden_layers=num_hidden_layers,
         learning_rate=learning_rate,
     )
-    train_ds = Subset(dataset, dataset.train_orientability_indices)
-    test_ds = Subset(dataset, dataset.test_orientability_indices)
-    train_dl = DataLoader(
-        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
-    )
-    test_dl = DataLoader(
-        test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
-    )
-    logger = get_wandb_logger(
-        task_name="orientability", model_name="GATSimplex2Vec"
-    )
-    trainer = L.Trainer(
-        max_epochs=max_epochs, log_every_n_steps=1, logger=logger
-    )
-    trainer.fit(model, train_dl, test_dl)
-    wandb.finish()
+    perform_experiment(
+        task="orientability",
+        model=model,
+        model_name="GATSimplex2Vec",
+        dataset=dataset,
+        batch_size=batch_size,
+        max_epochs=max_epochs,
+        num_workers=num_workers,
+    )
33 changes: 10 additions & 23 deletions experiments/orientability/graphs/GCN.py
@@ -1,12 +1,9 @@
-import lightning as L
+import torch
 import torch
 import torchvision.transforms as transforms
-import wandb
-from torch.utils.data import Subset
-from torch_geometric.loader import DataLoader
 from torch_geometric.transforms import FaceToEdge
 
-from experiments.experiment_utils import get_wandb_logger
+from experiments.experiment_utils import perform_experiment
 from experiments.lightning_modules.GraphCommonModuleOrientability import (
     GraphCommonModuleOrientability,
 )
@@ -76,22 +73,12 @@ def single_experiment_orientability_gnn():
         num_hidden_layers=num_hidden_layers,
         learning_rate=learning_rate,
     )
-    train_ds = Subset(dataset, dataset.train_orientability_indices)
-    test_ds = Subset(dataset, dataset.test_orientability_indices)
-    train_dl = DataLoader(
-        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
-    )
-    test_dl = DataLoader(
-        test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
-    )
-    logger = get_wandb_logger(task_name="orientability", model_name="GCN")
-    trainer = L.Trainer(
-        max_epochs=max_epochs, log_every_n_steps=1, logger=logger
-    )
-
-    trainer.fit(
-        model,
-        train_dl,
-        test_dl,
-    )
-    wandb.finish()
+    perform_experiment(
+        task="orientability",
+        model=model,
+        model_name="GCN",
+        dataset=dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        max_epochs=max_epochs,
+    )
29 changes: 9 additions & 20 deletions experiments/orientability/simplicial_complexes/SCNN.py
@@ -1,13 +1,10 @@
 from typing import Literal
 
-import lightning as L
 import torch
 import torchvision.transforms as transforms
-import wandb
 from torch import nn
-from torch.utils.data import Subset
 
-from experiments.experiment_utils import get_wandb_logger
+from experiments.experiment_utils import perform_experiment
 from experiments.lightning_modules.BaseModelClassification import (
     BaseClassificationModule,
 )
@@ -156,22 +153,14 @@ def single_experiment_orientability_scnn():
         n_layers=num_layers,
         learning_rate=learning_rate,
     )
-    train_ds = Subset(dataset, dataset.train_orientability_indices)
-    test_ds = Subset(dataset, dataset.test_orientability_indices)
-    train_dl = SimplicialDataLoader(
-        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
-    )
-    test_dl = SimplicialDataLoader(
-        test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
-    )
-    logger = get_wandb_logger(task_name="orientability", model_name="SCNN")
-    # Use CPU acceleration: SCCNN does not support GPU acceleration because it creates matrices not placed in the
-    # device of the network.
-    trainer = L.Trainer(
-        accelerator="cpu",
-        max_epochs=max_epochs,
-        log_every_n_steps=1,
-        logger=logger,
-    )
-    trainer.fit(model, train_dl, test_dl)
-    wandb.finish()
+    perform_experiment(
+        task="orientability",
+        model=model,
+        model_name="SCNN",
+        dataset=dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        max_epochs=max_epochs,
+        data_loader_class=SimplicialDataLoader,
+        accelerator="cpu",
+    )
