-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Removed pdoc, added example notebook.
- Loading branch information
1 parent
073041c
commit a24f66b
Showing
6 changed files
with
410 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Training a GNN on the Mantra Dataset\n", | ||
"\n", | ||
"In this tutorial, we provide an example use-case for the mantra dataset. We show \n", | ||
"how to train a GNN to predict the orientability based on random node features. \n", | ||
"\n", | ||
"The `torch-geometric` interface for the MANTRA dataset can be installed with \n", | ||
"pip via the command \n", | ||
"```{python}\n", | ||
"pip install mantra\n", | ||
"```\n", | ||
"\n", | ||
"As a preprocessing step we apply three transforms to the base dataset.\n", | ||
"Since the dataset does not have intrinsic coordinates attached to the vertices, \n", | ||
"we first have to create a transform that generates random node features.\n", | ||
"Each manifold in MANTRA comes as a list of triples, where the integers in each \n", | ||
"triple are vertex id's. The starting id in each manifold is $1$ and has to be \n", | ||
"converted to a torch-geometric compliant $0$-based index.\n", | ||
"GNN's are typically trained on graphs and the FaceToEdge transform converts our\n", | ||
"manifold to a graph. \n", | ||
"\n", | ||
"For each of the transforms we use a single class and are succesively applied to\n", | ||
"form the final transformed dataset. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load all required packages. \n", | ||
"\n", | ||
"import torch \n", | ||
"import torch.nn.functional as F\n", | ||
"from torch import nn\n", | ||
"from torch.utils.data import random_split\n", | ||
"\n", | ||
"from torchvision.transforms import Compose\n", | ||
"\n", | ||
"from torch_geometric.loader import DataLoader\n", | ||
"from torch_geometric.transforms import Compose, FaceToEdge\n", | ||
"\n", | ||
"from torch_geometric.nn import GCNConv, global_mean_pool\n", | ||
"\n", | ||
"# Load the mantra dataset\n", | ||
"from mantra.datasets import ManifoldTriangulations\n", | ||
"\n", | ||
"class NodeIndex: \n", | ||
" def __call__(self,data):\n", | ||
" '''\n", | ||
" In the base dataset, the vertex start index is 1 and is provided as a\n", | ||
" list. The transform converts the list to a tensor and changes the start\n", | ||
" index to 0, in compliance with torch-geometric. \n", | ||
" '''\n", | ||
" data.face = torch.tensor(data.triangulation ).T- 1\n", | ||
" return data\n", | ||
"\n", | ||
"\n", | ||
"class RandomNodeFeatures: \n", | ||
" def __call__(self,data):\n", | ||
" \"\"\"\n", | ||
" We create an 8-dimensional vector with random numbers for each vertex. \n", | ||
" Often the coordinates of the graph or triangulation are tightly coupled \n", | ||
" with the structure of the graph, an assumtion we hope to tackle.\n", | ||
" \"\"\"\n", | ||
" data.x = torch.rand(size=(data.face.max()+1,8))\n", | ||
" return data\n", | ||
"\n", | ||
"\n", | ||
"# Instantiate the dataset. Following the `torch-geometric` API, we download the \n", | ||
"# dataset into the root directory. \n", | ||
"dataset = ManifoldTriangulations(root=\"./data\", manifold=\"2\", version=\"latest\",\n", | ||
" transform=Compose([\n", | ||
" NodeIndex(),\n", | ||
" RandomNodeFeatures(),\n", | ||
" FaceToEdge(remove_faces=True),\n", | ||
" ]\n", | ||
" )\n", | ||
" )\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_dataset, test_dataset = random_split(\n", | ||
" dataset,\n", | ||
" [0.8,0.2\n", | ||
" ],\n", | ||
" ) # type: ignore\n", | ||
"\n", | ||
"train_dataloader = DataLoader(train_dataset,batch_size=32)\n", | ||
"test_dataloader = DataLoader(test_dataset,batch_size=32)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class GCN(nn.Module):\n", | ||
" def __init__(self):\n", | ||
" super().__init__()\n", | ||
"\n", | ||
" self.conv_input = GCNConv(\n", | ||
" 8, 16\n", | ||
" )\n", | ||
" self.final_linear = nn.Linear(\n", | ||
" 16, 1\n", | ||
" )\n", | ||
"\n", | ||
" def forward(self, batch):\n", | ||
" x, edge_index, batch = batch.x, batch.edge_index, batch.batch\n", | ||
" \n", | ||
" # 1. Obtain node embeddings\n", | ||
" x = self.conv_input(x, edge_index)\n", | ||
" # 2. Readout layer\n", | ||
" x = global_mean_pool(x, batch) # [batch_size, hidden_channels]\n", | ||
" # 3. Apply a final classifier\n", | ||
" x = F.dropout(x, p=0.5, training=self.training)\n", | ||
" x = self.final_linear(x)\n", | ||
" return x" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Epoch 0, 0.2743515074253082\n", | ||
"Epoch 1, 0.24504387378692627\n", | ||
"Epoch 2, 0.2461807280778885\n", | ||
"Epoch 3, 0.24599741399288177\n", | ||
"Epoch 4, 0.2461780607700348\n", | ||
"Epoch 5, 0.24923910200595856\n", | ||
"Epoch 6, 0.24623213708400726\n", | ||
"Epoch 7, 0.24637295305728912\n", | ||
"Epoch 8, 0.24762295186519623\n", | ||
"Epoch 9, 0.24508829414844513\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", | ||
"model = GCN().to(device)\n", | ||
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n", | ||
"loss_fn = nn.BCEWithLogitsLoss()\n", | ||
"\n", | ||
"model.train()\n", | ||
"for epoch in range(10):\n", | ||
" for batch in train_dataloader: \n", | ||
" batch.orientable = batch.orientable.to(torch.float)\n", | ||
" batch.to(device)\n", | ||
" optimizer.zero_grad()\n", | ||
" out = model(batch)\n", | ||
" loss = loss_fn(out.squeeze(), batch.orientable)\n", | ||
" loss.backward()\n", | ||
" optimizer.step()\n", | ||
" print(f\"Epoch {epoch}, {loss.item()}\")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Accuracy: 0.0825\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"correct = 0\n", | ||
"total = 0\n", | ||
"model.eval()\n", | ||
"for testbatch in test_dataloader: \n", | ||
" testbatch.to(device)\n", | ||
" pred = model(testbatch)\n", | ||
" correct += ((pred.squeeze() < 0) == testbatch.orientable).sum()\n", | ||
" total += len(testbatch)\n", | ||
"\n", | ||
"acc = int(correct) / int(total)\n", | ||
"print(f'Accuracy: {acc:.4f}')" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.11" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.