Skip to content

Commit

Permalink
Removed pdoc, added example notebook.
Browse files Browse the repository at this point in the history
  • Loading branch information
ErnstRoell committed Oct 21, 2024
1 parent 073041c commit a24f66b
Show file tree
Hide file tree
Showing 6 changed files with 410 additions and 31 deletions.
1 change: 0 additions & 1 deletion .github/workflows/docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ jobs:
pip install torch==2.2.0
pip install numpy
pip install git+https://github.com/Lezcano/geotorch/
pip install pdoc
# ADJUST THIS: install all dependencies (including pdoc)
- name: Install mantra
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = ["sphinx.ext.autodoc", "sphinx_rtd_theme", "myst_parser"]
extensions = ["sphinx.ext.autodoc", "sphinx_rtd_theme", "myst_parser","nbsphinx"]

myst_enable_extensions = [
"amsmath",
Expand Down
12 changes: 10 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@

.. toctree::
:hidden:
:caption: Modules

:caption: Modules
datasets

.. toctree::
:hidden:
:caption: Examples:

notebooks/train_gnn.ipynb



.. toctree::
:hidden:
:caption: Licence
Expand Down
224 changes: 224 additions & 0 deletions docs/source/notebooks/train_gnn.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training a GNN on the Mantra Dataset\n",
"\n",
"In this tutorial, we provide an example use-case for the mantra dataset. We show \n",
"how to train a GNN to predict the orientability based on random node features. \n",
"\n",
"The `torch-geometric` interface for the MANTRA dataset can be installed with \n",
"pip via the command \n",
"```{python}\n",
"pip install mantra\n",
"```\n",
"\n",
"As a preprocessing step we apply three transforms to the base dataset.\n",
"Since the dataset does not have intrinsic coordinates attached to the vertices, \n",
"we first have to create a transform that generates random node features.\n",
"Each manifold in MANTRA comes as a list of triples, where the integers in each \n",
"triple are vertex id's. The starting id in each manifold is $1$ and has to be \n",
"converted to a torch-geometric compliant $0$-based index.\n",
"GNN's are typically trained on graphs and the FaceToEdge transform converts our\n",
"manifold to a graph. \n",
"\n",
"For each of the transforms we use a single class and are succesively applied to\n",
"form the final transformed dataset. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Load all required packages. \n",
"\n",
"import torch \n",
"import torch.nn.functional as F\n",
"from torch import nn\n",
"from torch.utils.data import random_split\n",
"\n",
"from torchvision.transforms import Compose\n",
"\n",
"from torch_geometric.loader import DataLoader\n",
"from torch_geometric.transforms import Compose, FaceToEdge\n",
"\n",
"from torch_geometric.nn import GCNConv, global_mean_pool\n",
"\n",
"# Load the mantra dataset\n",
"from mantra.datasets import ManifoldTriangulations\n",
"\n",
"class NodeIndex: \n",
" def __call__(self,data):\n",
" '''\n",
" In the base dataset, the vertex start index is 1 and is provided as a\n",
" list. The transform converts the list to a tensor and changes the start\n",
" index to 0, in compliance with torch-geometric. \n",
" '''\n",
" data.face = torch.tensor(data.triangulation ).T- 1\n",
" return data\n",
"\n",
"\n",
"class RandomNodeFeatures: \n",
" def __call__(self,data):\n",
" \"\"\"\n",
" We create an 8-dimensional vector with random numbers for each vertex. \n",
" Often the coordinates of the graph or triangulation are tightly coupled \n",
" with the structure of the graph, an assumtion we hope to tackle.\n",
" \"\"\"\n",
" data.x = torch.rand(size=(data.face.max()+1,8))\n",
" return data\n",
"\n",
"\n",
"# Instantiate the dataset. Following the `torch-geometric` API, we download the \n",
"# dataset into the root directory. \n",
"dataset = ManifoldTriangulations(root=\"./data\", manifold=\"2\", version=\"latest\",\n",
" transform=Compose([\n",
" NodeIndex(),\n",
" RandomNodeFeatures(),\n",
" FaceToEdge(remove_faces=True),\n",
" ]\n",
" )\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"train_dataset, test_dataset = random_split(\n",
" dataset,\n",
" [0.8,0.2\n",
" ],\n",
" ) # type: ignore\n",
"\n",
"train_dataloader = DataLoader(train_dataset,batch_size=32)\n",
"test_dataloader = DataLoader(test_dataset,batch_size=32)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class GCN(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" self.conv_input = GCNConv(\n",
" 8, 16\n",
" )\n",
" self.final_linear = nn.Linear(\n",
" 16, 1\n",
" )\n",
"\n",
" def forward(self, batch):\n",
" x, edge_index, batch = batch.x, batch.edge_index, batch.batch\n",
" \n",
" # 1. Obtain node embeddings\n",
" x = self.conv_input(x, edge_index)\n",
" # 2. Readout layer\n",
" x = global_mean_pool(x, batch) # [batch_size, hidden_channels]\n",
" # 3. Apply a final classifier\n",
" x = F.dropout(x, p=0.5, training=self.training)\n",
" x = self.final_linear(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0, 0.2743515074253082\n",
"Epoch 1, 0.24504387378692627\n",
"Epoch 2, 0.2461807280778885\n",
"Epoch 3, 0.24599741399288177\n",
"Epoch 4, 0.2461780607700348\n",
"Epoch 5, 0.24923910200595856\n",
"Epoch 6, 0.24623213708400726\n",
"Epoch 7, 0.24637295305728912\n",
"Epoch 8, 0.24762295186519623\n",
"Epoch 9, 0.24508829414844513\n"
]
}
],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model = GCN().to(device)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n",
"loss_fn = nn.BCEWithLogitsLoss()\n",
"\n",
"model.train()\n",
"for epoch in range(10):\n",
" for batch in train_dataloader: \n",
" batch.orientable = batch.orientable.to(torch.float)\n",
" batch.to(device)\n",
" optimizer.zero_grad()\n",
" out = model(batch)\n",
" loss = loss_fn(out.squeeze(), batch.orientable)\n",
" loss.backward()\n",
" optimizer.step()\n",
" print(f\"Epoch {epoch}, {loss.item()}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.0825\n"
]
}
],
"source": [
"correct = 0\n",
"total = 0\n",
"model.eval()\n",
"for testbatch in test_dataloader: \n",
" testbatch.to(device)\n",
" pred = model(testbatch)\n",
" correct += ((pred.squeeze() < 0) == testbatch.orientable).sum()\n",
" total += len(testbatch)\n",
"\n",
"acc = int(correct) / int(total)\n",
"print(f'Accuracy: {acc:.4f}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit a24f66b

Please sign in to comment.