From 6bd0489658ed2cc7166fefc2645fd35b4b949418 Mon Sep 17 00:00:00 2001 From: ErnstRoell Date: Thu, 2 May 2024 17:08:04 +0200 Subject: [PATCH] Added bunch of experiment notebooks --- experiment.ipynb | 203 +++++++ experiment_betti_TAG.ipynb | 396 ++++++++++++ experiment_betti_TransformerConv.ipynb | 396 ++++++++++++ experiment_betti_gnn.ipynb | 396 ++++++++++++ experiment_betti_mlp.ipynb | 567 ++++++++++++++++++ experiment_gnn.ipynb | 358 ----------- experiment_mlp.ipynb | 461 -------------- experiment_name_TAG.ipynb | 356 +++++++++++ experiment_name_Transformer.ipynb | 356 +++++++++++ experiment_name_gnn.ipynb | 384 ++++++++++++ experiment_name_mlp.ipynb | 356 +++++++++++ experiment_orientability_TAG.ipynb | 344 +++++++++++ ...riment_orientability_TransformerConv.ipynb | 358 +++++++++++ experiment_orientability_gnn.ipynb | 342 +++++++++++ experiment_orientability_mlp.ipynb | 387 ++++++++++++ 15 files changed, 4841 insertions(+), 819 deletions(-) create mode 100644 experiment.ipynb create mode 100644 experiment_betti_TAG.ipynb create mode 100644 experiment_betti_TransformerConv.ipynb create mode 100644 experiment_betti_gnn.ipynb create mode 100644 experiment_betti_mlp.ipynb delete mode 100644 experiment_gnn.ipynb delete mode 100644 experiment_mlp.ipynb create mode 100644 experiment_name_TAG.ipynb create mode 100644 experiment_name_Transformer.ipynb create mode 100644 experiment_name_gnn.ipynb create mode 100644 experiment_name_mlp.ipynb create mode 100644 experiment_orientability_TAG.ipynb create mode 100644 experiment_orientability_TransformerConv.ipynb create mode 100644 experiment_orientability_gnn.ipynb create mode 100644 experiment_orientability_mlp.ipynb diff --git a/experiment.ipynb b/experiment.ipynb new file mode 100644 index 0000000..5c56c9b --- /dev/null +++ b/experiment.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "import torch.nn.functional as F\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "class NameToClass: \n", + " def __init__(self):\n", + " self.class_dict = {'Klein bottle': 0, '': 1, 'RP^2': 2, 'T^2': 3, 'S^2': 4}\n", + " \n", + " def __call__(self,data):\n", + " data.y = F.one_hot(torch.tensor(self.class_dict[data.name]),num_classes=5)\n", + " return data\n", + "\n", + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " NameToClass()\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", + "\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\ernst\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'betti_numbers', 'name', 'torsion_coefficients', 'orientable', 'n_vertices', 'genus', 'dimension', 'face'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([0, 0, 0, 0, 1])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = dataset[0]\n", + "data.y" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'NoneType' object has no attribute 'name'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[19], line 5\u001b[0m\n\u001b[0;32m 3\u001b[0m cnt \u001b[38;5;241m=\u001b[39m Counter()\n\u001b[0;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m dataset:\n\u001b[1;32m----> 5\u001b[0m cnt[\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m] \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m 7\u001b[0m cnt\n", + "\u001b[1;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'name'" + ] + } + ], + "source": [ + "from collections import Counter\n", + "# Tally occurrences of words in a list\n", + "cnt = Counter()\n", + "for data in dataset:\n", + " cnt[data.name] += 1\n", + "\n", + "cnt" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0, 0, 0, 1, 1, 0, 0, 1, 0, 1])" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_betti_TAG.ipynb b/experiment_betti_TAG.ipynb new file mode 100644 index 0000000..6b9e5bf --- /dev/null +++ b/experiment_betti_TAG.ipynb @@ -0,0 +1,396 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 9\n", + "Number of classes: 2\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 9], y=[1])\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " OneHotDegree(max_degree=8,cat=False)\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[1, 0, 1]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.betti_numbers" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[-0.0592, -0.0860, 0.0107],\n", + " [-0.0745, -0.0838, -0.0963],\n", + " [-0.1234, -0.1174, 0.0508],\n", + " [-0.2945, -0.0624, 0.3497],\n", + " [-0.3387, -0.0551, -0.0335],\n", + " [-0.1517, -0.1775, -0.2044],\n", + " [-0.1485, -0.0703, -0.0104],\n", + " [-0.1371, -0.1050, -0.0641],\n", + " [-0.0389, -0.2190, 0.0169],\n", + " [ 0.0023, -0.2771, -0.1588]], grad_fn=)\n" + ] + } + ], + "source": [ + "from operator import concat\n", + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv, TAGConv,TransformerConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "\n", + "\n", + "class GCN(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super(GCN, self).__init__()\n", + " torch.manual_seed(12345)\n", + " self.conv1 = TAGConv(dataset.num_node_features, hidden_channels)\n", + " self.conv2 = TAGConv(hidden_channels, hidden_channels)\n", + " self.conv3 = TAGConv(hidden_channels, hidden_channels)\n", + " self.lin = Linear(hidden_channels, 3)\n", + "\n", + " def forward(self, batch):\n", + " # 1. Obtain node embeddings\n", + " x = self.conv1(batch.x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv2(x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv3(x, batch.edge_index)\n", + "\n", + " # 2. Readout layer\n", + " x = global_mean_pool(x, batch.batch) # [batch_size, hidden_channels]\n", + "\n", + " # 3. Apply a final classifier\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = self.lin(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "model = GCN(hidden_channels=64)\n", + "print(model(batch))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 6.5337, Test Acc: 2.0014\n", + "Epoch: 002, Train Acc: 6.1711, Test Acc: 1.7780\n", + "Epoch: 003, Train Acc: 5.5158, Test Acc: 1.5298\n", + "Epoch: 004, Train Acc: 5.0714, Test Acc: 1.4182\n", + "Epoch: 005, Train Acc: 7.9687, Test Acc: 2.0118\n", + "Epoch: 006, Train Acc: 5.9958, Test Acc: 1.6171\n", + "Epoch: 007, Train Acc: 4.3281, Test Acc: 1.3416\n", + "Epoch: 008, Train Acc: 4.8488, Test Acc: 1.4623\n", + "Epoch: 009, Train Acc: 4.3541, Test Acc: 1.3209\n", + "Epoch: 010, Train Acc: 4.5998, Test Acc: 1.2991\n", + "Epoch: 011, Train Acc: 4.9841, Test Acc: 1.4021\n", + "Epoch: 012, Train Acc: 4.2027, Test Acc: 1.3458\n", + "Epoch: 013, Train Acc: 4.1259, Test Acc: 1.2443\n", + "Epoch: 014, Train Acc: 4.9470, Test Acc: 1.5142\n", + "Epoch: 015, Train Acc: 4.3742, Test Acc: 1.3832\n", + "Epoch: 016, Train Acc: 4.3983, Test Acc: 1.3189\n", + "Epoch: 017, Train Acc: 4.8632, Test Acc: 1.4811\n", + "Epoch: 018, Train Acc: 4.1164, Test Acc: 1.2966\n", + "Epoch: 019, Train Acc: 4.2895, Test Acc: 1.3772\n", + "Epoch: 020, Train Acc: 4.3439, Test Acc: 1.4044\n", + "Epoch: 021, Train Acc: 4.5521, Test Acc: 1.4183\n", + "Epoch: 022, Train Acc: 4.1032, Test Acc: 1.3653\n", + "Epoch: 023, Train Acc: 7.5456, Test Acc: 1.5576\n", + "Epoch: 024, Train Acc: 5.2212, Test Acc: 1.5423\n", + "Epoch: 025, Train Acc: 4.2803, Test Acc: 1.2952\n", + "Epoch: 026, Train Acc: 4.8403, Test Acc: 1.4016\n", + "Epoch: 027, Train Acc: 5.2088, Test Acc: 1.4992\n", + "Epoch: 028, Train Acc: 4.1632, Test Acc: 1.2780\n", + "Epoch: 029, Train Acc: 4.1269, Test Acc: 1.3689\n", + "Epoch: 030, Train Acc: 4.8322, Test Acc: 1.4433\n", + "Epoch: 031, Train Acc: 4.7486, Test Acc: 1.4451\n", + "Epoch: 032, Train Acc: 3.9849, Test Acc: 1.2616\n", + "Epoch: 033, Train Acc: 3.9689, Test Acc: 1.3034\n", + "Epoch: 034, Train Acc: 4.3163, Test Acc: 1.3400\n", + "Epoch: 035, Train Acc: 4.1944, Test Acc: 1.3585\n", + "Epoch: 036, Train Acc: 4.7218, Test Acc: 1.5358\n", + "Epoch: 037, Train Acc: 4.0876, Test Acc: 1.3294\n", + "Epoch: 038, Train Acc: 4.2680, Test Acc: 1.4031\n", + "Epoch: 039, Train Acc: 4.0934, Test Acc: 1.3362\n" + ] + } + ], + "source": [ + "model = GCN(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", + "criterion = torch.nn.MSELoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data) # Perform a single forward pass.\n", + " loss = criterion(out, torch.tensor(data.betti_numbers,dtype=torch.float)) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " losses = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " \n", + " losses += criterion(out, torch.tensor(data.betti_numbers,dtype=torch.float))\n", + " return losses\n", + "\n", + "\n", + "for epoch in range(1, 40):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "predicted incorrect 391\n", + "predicted incorrect 59\n", + "percentage correct 0.8688888888888889\n" + ] + } + ], + "source": [ + "model.eval()\n", + "losses = 0\n", + "y_hat = []\n", + "y = []\n", + "for data in test_loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " y_hat.append(out.round().long())\n", + " y.append(torch.tensor(data.betti_numbers))\n", + "\n", + "y_hat = torch.vstack(y_hat)\n", + "y = torch.vstack(y)\n", + "\n", + "\n", + "incorrect = torch.count_nonzero(y-y_hat!=0).item()\n", + "correct = torch.count_nonzero(y-y_hat==0).item()\n", + "\n", + "print(\"predicted incorrect\",correct)\n", + "print(\"predicted incorrect\",incorrect)\n", + "print(\"percentage correct\", correct / (correct + incorrect))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_betti_TransformerConv.ipynb b/experiment_betti_TransformerConv.ipynb new file mode 100644 index 0000000..d818be7 --- /dev/null +++ b/experiment_betti_TransformerConv.ipynb @@ -0,0 +1,396 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 9\n", + "Number of classes: 2\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 9], y=[1])\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " OneHotDegree(max_degree=8,cat=False)\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[1, 0, 1]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.betti_numbers" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0, 0, 0, 1, 0, 0, 1, 0, 0, 0])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 0.0788, -0.0744, -0.0329],\n", + " [ 0.1500, -0.0435, -0.1150],\n", + " [ 0.1908, 0.1661, -0.0821],\n", + " [ 0.1692, 0.0347, -0.2018],\n", + " [-0.1191, -0.0769, 0.0184],\n", + " [ 0.2564, 0.0817, -0.2506],\n", + " [-0.0386, -0.1109, -0.0483],\n", + " [ 0.2014, -0.0658, 0.0422],\n", + " [ 0.2608, -0.0205, -0.1961],\n", + " [ 0.1721, 0.1232, -0.2513]], grad_fn=)\n" + ] + } + ], + "source": [ + "from operator import concat\n", + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv, TAGConv,TransformerConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "\n", + "\n", + "class GCN(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super(GCN, self).__init__()\n", + " torch.manual_seed(12345)\n", + " self.conv1 = TransformerConv(dataset.num_node_features, hidden_channels)\n", + " self.conv2 = TransformerConv(hidden_channels, hidden_channels)\n", + " self.conv3 = TransformerConv(hidden_channels, hidden_channels)\n", + " self.lin = Linear(hidden_channels, 3)\n", + "\n", + " def forward(self, batch):\n", + " # 1. Obtain node embeddings\n", + " x = self.conv1(batch.x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv2(x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv3(x, batch.edge_index)\n", + "\n", + " # 2. Readout layer\n", + " x = global_mean_pool(x, batch.batch) # [batch_size, hidden_channels]\n", + "\n", + " # 3. Apply a final classifier\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = self.lin(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "model = GCN(hidden_channels=64)\n", + "print(model(batch))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 14.3012, Test Acc: 3.9595\n", + "Epoch: 002, Train Acc: 7.9209, Test Acc: 2.2424\n", + "Epoch: 003, Train Acc: 7.3249, Test Acc: 2.1451\n", + "Epoch: 004, Train Acc: 7.0863, Test Acc: 2.0972\n", + "Epoch: 005, Train Acc: 6.5136, Test Acc: 1.9514\n", + "Epoch: 006, Train Acc: 5.7232, Test Acc: 1.6731\n", + "Epoch: 007, Train Acc: 5.6288, Test Acc: 1.6629\n", + "Epoch: 008, Train Acc: 5.3735, Test Acc: 1.6019\n", + "Epoch: 009, Train Acc: 5.5434, Test Acc: 1.6216\n", + "Epoch: 010, Train Acc: 5.0546, Test Acc: 1.4855\n", + "Epoch: 011, Train Acc: 5.0977, Test Acc: 1.4528\n", + "Epoch: 012, Train Acc: 5.1413, Test Acc: 1.5387\n", + "Epoch: 013, Train Acc: 5.0282, Test Acc: 1.4862\n", + "Epoch: 014, Train Acc: 4.5230, Test Acc: 1.3564\n", + "Epoch: 015, Train Acc: 4.5798, Test Acc: 1.3817\n", + "Epoch: 016, Train Acc: 4.7982, Test Acc: 1.4428\n", + "Epoch: 017, Train Acc: 5.1819, Test Acc: 1.5367\n", + "Epoch: 018, Train Acc: 5.0257, Test Acc: 1.5190\n", + "Epoch: 019, Train Acc: 4.4559, Test Acc: 1.3445\n", + "Epoch: 020, Train Acc: 4.2240, Test Acc: 1.2602\n", + "Epoch: 021, Train Acc: 4.5640, Test Acc: 1.3807\n", + "Epoch: 022, Train Acc: 4.7071, Test Acc: 1.4230\n", + "Epoch: 023, Train Acc: 4.5142, Test Acc: 1.3686\n", + "Epoch: 024, Train Acc: 4.2264, Test Acc: 1.2772\n", + "Epoch: 025, Train Acc: 4.0160, Test Acc: 1.2058\n", + "Epoch: 026, Train Acc: 4.4284, Test Acc: 1.3742\n", + "Epoch: 027, Train Acc: 4.1022, Test Acc: 1.2546\n", + "Epoch: 028, Train Acc: 4.1525, Test Acc: 1.2840\n", + "Epoch: 029, Train Acc: 4.2535, Test Acc: 1.3084\n", + "Epoch: 030, Train Acc: 4.3463, Test Acc: 1.3431\n", + "Epoch: 031, Train Acc: 3.9852, Test Acc: 1.2140\n", + "Epoch: 032, Train Acc: 4.1442, Test Acc: 1.2774\n", + "Epoch: 033, Train Acc: 4.0151, Test Acc: 1.2465\n", + "Epoch: 034, Train Acc: 4.9950, Test Acc: 1.5060\n", + "Epoch: 035, Train Acc: 3.9615, Test Acc: 1.2502\n", + "Epoch: 036, Train Acc: 4.5018, Test Acc: 1.4480\n", + "Epoch: 037, Train Acc: 3.8346, Test Acc: 1.1762\n", + "Epoch: 038, Train Acc: 3.7178, Test Acc: 1.1769\n", + "Epoch: 039, Train Acc: 3.8979, Test Acc: 1.2257\n" + ] + } + ], + "source": [ + "model = GCN(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = torch.nn.MSELoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data) # Perform a single forward pass.\n", + " loss = criterion(out, torch.tensor(data.betti_numbers,dtype=torch.float)) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " losses = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " \n", + " losses += criterion(out, torch.tensor(data.betti_numbers,dtype=torch.float))\n", + " return losses\n", + "\n", + "\n", + "for epoch in range(1, 40):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "predicted incorrect 395\n", + "predicted incorrect 55\n", + "percentage correct 0.8777777777777778\n" + ] + } + ], + "source": [ + "model.eval()\n", + "losses = 0\n", + "y_hat = []\n", + "y = []\n", + "for data in test_loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " y_hat.append(out.round().long())\n", + " y.append(torch.tensor(data.betti_numbers))\n", + "\n", + "y_hat = torch.vstack(y_hat)\n", + "y = torch.vstack(y)\n", + "\n", + "\n", + "incorrect = torch.count_nonzero(y-y_hat!=0).item()\n", + "correct = torch.count_nonzero(y-y_hat==0).item()\n", + "\n", + "print(\"predicted incorrect\",correct)\n", + "print(\"predicted incorrect\",incorrect)\n", + "print(\"percentage correct\", correct / (correct + incorrect))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_betti_gnn.ipynb b/experiment_betti_gnn.ipynb new file mode 100644 index 0000000..4b4b2f1 --- /dev/null +++ b/experiment_betti_gnn.ipynb @@ -0,0 +1,396 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 55, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 9\n", + "Number of classes: 2\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 9], y=[1])\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " OneHotDegree(max_degree=8,cat=False)\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[1, 0, 1]" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.betti_numbers" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0, 0, 0, 1, 1, 0, 0, 1, 0, 1])" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[-4.5819e-02, -3.9331e-02, -4.2941e-02],\n", + " [-3.7291e-02, -6.8504e-02, -5.4509e-03],\n", + " [-2.7116e-02, -2.7440e-02, 2.0817e-05],\n", + " [-3.9345e-02, -6.8072e-02, -8.1359e-02],\n", + " [-2.8689e-02, -4.1856e-02, -2.5476e-02],\n", + " [-8.1056e-03, -4.5064e-02, 2.9946e-02],\n", + " [-3.0017e-02, -4.4348e-02, -1.5939e-02],\n", + " [-3.4824e-03, -5.2788e-02, -1.7973e-02],\n", + " [-1.5153e-02, -6.5248e-02, -6.7883e-02],\n", + " [-6.2596e-02, -4.7315e-02, -9.9255e-03]], grad_fn=)\n" + ] + } + ], + "source": [ + "from operator import concat\n", + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv, TAGConv,TransformerConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "\n", + "\n", + "class GCN(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super(GCN, self).__init__()\n", + " torch.manual_seed(12345)\n", + " self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)\n", + " self.conv2 = GCNConv(hidden_channels, hidden_channels)\n", + " self.conv3 = GCNConv(hidden_channels, hidden_channels)\n", + " self.lin = Linear(hidden_channels, 3)\n", + "\n", + " def forward(self, batch):\n", + " # 1. Obtain node embeddings\n", + " x = self.conv1(batch.x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv2(x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv3(x, batch.edge_index)\n", + "\n", + " # 2. Readout layer\n", + " x = global_mean_pool(x, batch.batch) # [batch_size, hidden_channels]\n", + "\n", + " # 3. Apply a final classifier\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = self.lin(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "model = GCN(hidden_channels=64)\n", + "print(model(batch))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 8.8607, Test Acc: 1.9755\n", + "Epoch: 002, Train Acc: 6.2946, Test Acc: 1.3799\n", + "Epoch: 003, Train Acc: 7.3765, Test Acc: 1.8066\n", + "Epoch: 004, Train Acc: 7.7206, Test Acc: 1.9605\n", + "Epoch: 005, Train Acc: 5.7995, Test Acc: 1.4090\n", + "Epoch: 006, Train Acc: 6.5125, Test Acc: 1.5056\n", + "Epoch: 007, Train Acc: 6.4604, Test Acc: 1.6126\n", + "Epoch: 008, Train Acc: 6.6463, Test Acc: 1.9684\n", + "Epoch: 009, Train Acc: 4.9964, Test Acc: 1.2566\n", + "Epoch: 010, Train Acc: 5.0923, Test Acc: 1.2564\n", + "Epoch: 011, Train Acc: 4.7138, Test Acc: 1.1211\n", + "Epoch: 012, Train Acc: 5.6023, Test Acc: 1.5135\n", + "Epoch: 013, Train Acc: 4.6682, Test Acc: 1.1419\n", + "Epoch: 014, Train Acc: 4.7652, Test Acc: 1.2088\n", + "Epoch: 015, Train Acc: 5.6908, Test Acc: 1.3544\n", + "Epoch: 016, Train Acc: 5.1775, Test Acc: 1.2503\n", + "Epoch: 017, Train Acc: 4.4483, Test Acc: 1.0854\n", + "Epoch: 018, Train Acc: 4.8420, Test Acc: 1.2799\n", + "Epoch: 019, Train Acc: 5.3053, Test Acc: 1.5795\n", + "Epoch: 020, Train Acc: 4.7974, Test Acc: 1.1136\n", + "Epoch: 021, Train Acc: 4.6213, Test Acc: 1.1321\n", + "Epoch: 022, Train Acc: 4.9669, Test Acc: 1.1988\n", + "Epoch: 023, Train Acc: 4.5986, Test Acc: 1.1202\n", + "Epoch: 024, Train Acc: 4.6800, Test Acc: 1.1055\n", + "Epoch: 025, Train Acc: 4.9282, Test Acc: 1.2884\n", + "Epoch: 026, Train Acc: 4.5849, Test Acc: 1.1057\n", + "Epoch: 027, Train Acc: 4.6809, Test Acc: 1.2394\n", + "Epoch: 028, Train Acc: 4.6340, Test Acc: 1.2130\n", + "Epoch: 029, Train Acc: 4.5447, Test Acc: 1.0757\n", + "Epoch: 030, Train Acc: 4.6329, Test Acc: 1.1134\n", + "Epoch: 031, Train Acc: 5.0366, Test Acc: 1.2723\n", + "Epoch: 032, Train Acc: 4.3747, Test Acc: 1.1413\n", + "Epoch: 033, Train Acc: 4.6353, Test Acc: 1.0597\n", + "Epoch: 034, Train Acc: 4.6297, Test Acc: 1.2276\n", + "Epoch: 035, Train Acc: 4.8424, Test Acc: 1.2621\n", + "Epoch: 036, Train Acc: 5.4038, Test Acc: 1.3956\n", + "Epoch: 037, Train Acc: 4.7174, Test Acc: 1.1145\n", + "Epoch: 038, Train Acc: 4.5051, Test Acc: 1.0113\n", + "Epoch: 039, Train Acc: 4.6326, Test Acc: 1.2423\n" + ] + } + ], + "source": [ + "model = GCN(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", + "criterion = torch.nn.MSELoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data) # Perform a single forward pass.\n", + " loss = criterion(out, torch.tensor(data.betti_numbers,dtype=torch.float)) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " losses = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " \n", + " losses += criterion(out, torch.tensor(data.betti_numbers,dtype=torch.float))\n", + " return losses\n", + "\n", + "\n", + "for epoch in range(1, 40):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "predicted incorrect 411\n", + "predicted incorrect 39\n", + "percentage correct 0.9133333333333333\n" + ] + } + ], + "source": [ + "model.eval()\n", + "losses = 0\n", + "y_hat = []\n", + "y = []\n", + "for data in test_loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " y_hat.append(out.round().long())\n", + " y.append(torch.tensor(data.betti_numbers))\n", + "\n", + "y_hat = torch.vstack(y_hat)\n", + "y = torch.vstack(y)\n", + "\n", + "\n", + "incorrect = torch.count_nonzero(y-y_hat!=0).item()\n", + "correct = torch.count_nonzero(y-y_hat==0).item()\n", + "\n", + "print(\"predicted incorrect\",correct)\n", + "print(\"predicted incorrect\",incorrect)\n", + "print(\"percentage correct\", correct / (correct + incorrect))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_betti_mlp.ipynb b/experiment_betti_mlp.ipynb new file mode 100644 index 0000000..eb66e3b --- /dev/null +++ b/experiment_betti_mlp.ipynb @@ -0,0 +1,567 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 9\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\ernst\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'betti_numbers', 'dimension', 'n_vertices', 'name', 'torsion_coefficients', 'face', 'genus', 'orientable'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of classes: 2\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 9], y=[1])\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " OneHotDegree(max_degree=8,cat=False)\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[1, 0, 1]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.betti_numbers" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[-1.2394, -0.3322, 0.7674],\n", + " [-1.1591, -0.6497, 0.7996],\n", + " [-1.0816, -0.3613, 0.6788],\n", + " [-1.1017, -0.5386, 0.8528],\n", + " [-1.2463, -0.3314, 0.7317],\n", + " [-1.2013, -0.5507, 0.7876],\n", + " [-1.0896, -0.5982, 0.8156],\n", + " [-1.1058, -0.6580, 0.8609],\n", + " [-1.1820, -0.3744, 0.8095],\n", + " [-1.2314, -0.4554, 0.7725]], grad_fn=>)\n" + ] + } + ], + "source": [ + "from operator import concat\n", + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv, TAGConv,TransformerConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "from torch_scatter import segment_coo\n", + "import torch.nn as nn\n", + "\n", + "\n", + "\n", + "class PermInvariant(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super().__init__()\n", + " # torch.manual_seed(12345)\n", + " self.classification = nn.Sequential( \n", + " nn.Linear(dataset.num_node_features,hidden_channels),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_channels,hidden_channels),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_channels,3)\n", + " )\n", + "\n", + " def forward(self, batch):\n", + " x = self.classification(batch.x)\n", + " # print(batch.x)\n", + " # print(x)\n", + " return segment_coo(x,batch.batch,reduce=\"sum\")\n", + "\n", + "\n", + "model = PermInvariant(hidden_channels=64)\n", + "print(model(batch))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 6.9451, Test Acc: 1.6976\n", + "Epoch: 002, Train Acc: 6.9317, Test Acc: 1.6716\n", + "Epoch: 003, Train Acc: 6.9236, Test Acc: 1.6698\n", + "Epoch: 004, Train Acc: 6.9116, Test Acc: 1.6674\n", + "Epoch: 005, Train Acc: 6.9051, Test Acc: 1.6663\n", + "Epoch: 006, Train Acc: 6.9046, Test Acc: 1.6669\n", + "Epoch: 007, Train Acc: 6.9090, Test Acc: 1.6681\n", + "Epoch: 008, Train Acc: 6.9065, Test Acc: 1.6686\n", + "Epoch: 009, Train Acc: 6.9074, Test Acc: 1.6681\n", + "Epoch: 010, Train Acc: 6.9103, Test Acc: 1.6705\n", + "Epoch: 011, Train Acc: 6.8380, Test Acc: 1.6461\n", + "Epoch: 012, Train Acc: 6.9187, Test Acc: 1.6731\n", + "Epoch: 013, Train Acc: 6.9225, Test Acc: 1.6763\n", + "Epoch: 014, Train Acc: 6.9648, Test Acc: 1.6940\n", + "Epoch: 015, Train Acc: 6.9375, Test Acc: 1.6826\n", + "Epoch: 016, Train Acc: 6.9334, Test Acc: 1.6820\n", + "Epoch: 017, Train Acc: 6.9338, Test Acc: 1.6821\n", + "Epoch: 018, Train Acc: 6.9380, Test Acc: 1.6837\n", + "Epoch: 019, Train Acc: 6.9421, Test Acc: 1.6853\n", + "Epoch: 020, Train Acc: 6.9448, Test Acc: 1.6870\n", + "Epoch: 021, Train Acc: 6.9383, Test Acc: 1.6866\n", + "Epoch: 022, Train Acc: 6.9410, Test Acc: 1.6878\n", + "Epoch: 023, Train Acc: 6.9417, Test Acc: 1.6884\n", + "Epoch: 024, Train Acc: 6.9426, Test Acc: 1.6892\n", + "Epoch: 025, Train Acc: 6.9417, Test Acc: 1.6881\n", + "Epoch: 026, Train Acc: 6.9443, Test Acc: 1.6881\n", + "Epoch: 027, Train Acc: 6.9441, Test Acc: 1.6909\n", + "Epoch: 028, Train Acc: 6.9453, Test Acc: 1.6914\n", + "Epoch: 029, Train Acc: 6.9464, Test Acc: 1.6924\n", + "Epoch: 030, Train Acc: 6.9480, Test Acc: 1.6934\n", + "Epoch: 031, Train Acc: 6.9466, Test Acc: 1.6930\n", + "Epoch: 032, Train Acc: 6.9478, Test Acc: 1.6928\n", + "Epoch: 033, Train Acc: 6.9485, Test Acc: 1.6938\n", + "Epoch: 034, Train Acc: 6.9478, Test Acc: 1.6942\n", + "Epoch: 035, Train Acc: 6.9534, Test Acc: 1.6966\n", + "Epoch: 036, Train Acc: 6.9501, Test Acc: 1.6955\n", + "Epoch: 037, Train Acc: 6.9508, Test Acc: 1.6963\n", + "Epoch: 038, Train Acc: 6.9621, Test Acc: 1.6935\n", + "Epoch: 039, Train Acc: 6.9372, Test Acc: 1.6911\n" + ] + } + ], + "source": [ + "model = PermInvariant(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", + "criterion = torch.nn.MSELoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data) # Perform a single forward pass.\n", + " loss = criterion(out, torch.tensor(data.betti_numbers,dtype=torch.float)) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " losses = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " \n", + " losses += criterion(out, torch.tensor(data.betti_numbers,dtype=torch.float))\n", + " return losses\n", + "\n", + "\n", + "for epoch in range(1, 40):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "predicted incorrect 371\n", + "predicted incorrect 79\n", + "percentage correct 0.8244444444444444\n" + ] + } + ], + "source": [ + "model.eval()\n", + "losses = 0\n", + "y_hat = []\n", + "y = []\n", + "for data in test_loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " y_hat.append(out.round().long())\n", + " y.append(torch.tensor(data.betti_numbers))\n", + "\n", + "y_hat = torch.vstack(y_hat)\n", + "y = torch.vstack(y)\n", + "\n", + "\n", + "incorrect = torch.count_nonzero(y-y_hat!=0).item()\n", + "correct = torch.count_nonzero(y-y_hat==0).item()\n", + "\n", + "print(\"predicted incorrect\",correct)\n", + "print(\"predicted incorrect\",incorrect)\n", + "print(\"percentage correct\", correct / (correct + incorrect))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1, 2, 0],\n", + " [1, 2, 1],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 0, 0],\n", + " [1, 0, 0],\n", + " [1, 2, 0],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 0, 0],\n", + " [1, 2, 0],\n", + " [1, 1, 0],\n", + " [1, 2, 0],\n", + " [1, 0, 1],\n", + " [1, 0, 0],\n", + " [1, 2, 1],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 0, 1],\n", + " [1, 2, 1],\n", + " [1, 2, 1],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 0, 0],\n", + " [1, 1, 0],\n", + " [1, 2, 1],\n", + " [1, 2, 0],\n", + " [1, 0, 0],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 3, 0],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 0, 0],\n", + " [1, 0, 1],\n", + " [1, 2, 1],\n", + " [1, 2, 0],\n", + " [1, 0, 0],\n", + " [1, 0, 0],\n", + " [1, 0, 1],\n", + " [1, 2, 1],\n", + " [1, 2, 0],\n", + " [1, 2, 1],\n", + " [1, 1, 0],\n", + " [1, 2, 1],\n", + " [1, 1, 0],\n", + " [1, 2, 0],\n", + " [1, 0, 1],\n", + " [1, 0, 0],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 1, 0],\n", + " [1, 2, 1],\n", + " [1, 1, 0],\n", + " [1, 2, 0],\n", + " [1, 1, 0],\n", + " [1, 2, 0],\n", + " [1, 2, 1],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 1, 0],\n", + " [1, 3, 0],\n", + " [1, 3, 0],\n", + " [1, 0, 0],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 2, 1],\n", + " [1, 1, 0],\n", + " [1, 2, 0],\n", + " [1, 1, 0],\n", + " [1, 0, 1],\n", + " [1, 0, 1],\n", + " [1, 0, 1],\n", + " [1, 2, 0],\n", + " [1, 2, 0],\n", + " [1, 0, 1],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 2, 0],\n", + " [1, 0, 0],\n", + " [1, 2, 1],\n", + " [1, 0, 1],\n", + " [1, 2, 0],\n", + " [1, 0, 1],\n", + " [1, 2, 0],\n", + " [1, 0, 0],\n", + " [1, 2, 1],\n", + " [1, 0, 1],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 2, 0],\n", + " [1, 0, 0],\n", + " [1, 0, 0],\n", + " [1, 2, 1],\n", + " [1, 1, 0],\n", + " [1, 2, 1],\n", + " [1, 3, 0],\n", + " [1, 1, 0],\n", + " [1, 2, 1],\n", + " [1, 0, 1],\n", + " [1, 2, 0],\n", + " [1, 0, 0],\n", + " [1, 2, 0],\n", + " [1, 3, 0],\n", + " [1, 2, 0],\n", + " [1, 1, 0],\n", + " [1, 2, 0],\n", + " [1, 2, 1],\n", + " [1, 0, 0],\n", + " [1, 1, 0],\n", + " [1, 2, 0],\n", + " [1, 2, 0],\n", + " [1, 0, 1],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 2, 0],\n", + " [1, 0, 1],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 1, 0],\n", + " [1, 2, 1],\n", + " [1, 2, 1],\n", + " [1, 1, 0],\n", + " [1, 0, 1],\n", + " [1, 1, 0],\n", + " [1, 2, 1],\n", + " [1, 2, 0],\n", + " [1, 2, 0],\n", + " [1, 2, 0],\n", + " [1, 2, 1],\n", + " [1, 2, 1],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 1, 0],\n", + " [1, 0, 0],\n", + " [1, 0, 0]])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_gnn.ipynb b/experiment_gnn.ipynb deleted file mode 100644 index cf81c16..0000000 --- a/experiment_gnn.ipynb +++ /dev/null @@ -1,358 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "ExecuteTime": { - "end_time": "2024-04-25T17:47:32.798261Z", - "start_time": "2024-04-25T17:47:19.035622Z" - } - }, - "outputs": [], - "source": [ - "from torch_geometric.loader import DataLoader\n", - "from torch_geometric.transforms import FaceToEdge,OneHotDegree\n", - "import torchvision.transforms as transforms\n", - "\n", - "from mantra.simplicial import SimplicialDataset\n", - "from mantra.transforms import (\n", - " TriangulationToFaceTransform,\n", - " OrientableToClassTransform,\n", - " DegreeTransform,\n", - ")\n", - "from validation.validate_homology import validate_betti_numbers\n", - "\n", - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Dataset: SimplicialDataset(712):\n", - "====================\n", - "Number of graphs: 712\n", - "Number of features: 10\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\ernst\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'n_vertices', 'torsion_coefficients', 'face', 'betti_numbers', 'name', 'genus', 'orientable', 'dimension'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Number of classes: 2\n", - "\n", - "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 10], y=[1])\n", - "=============================================================\n", - "Number of nodes: 4\n", - "Number of edges: 12\n", - "Average node degree: 3.00\n", - "Has isolated nodes: False\n", - "Has self-loops: False\n", - "Is undirected: True\n", - "=============================================================\n", - "Number of orientable Manifolds: 193\n", - "Number of non-orientable Manifolds: 519\n", - "Percentage: 0.27, 0.73\n" - ] - } - ], - "source": [ - "tr = transforms.Compose(\n", - " [\n", - " TriangulationToFaceTransform(),\n", - " FaceToEdge(remove_faces=False),\n", - " DegreeTransform(),\n", - " OrientableToClassTransform(),\n", - " OneHotDegree(max_degree=8),\n", - " ]\n", - ")\n", - "\n", - "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", - "\n", - "\n", - "print()\n", - "print(f\"Dataset: {dataset}:\")\n", - "print(\"====================\")\n", - "print(f\"Number of graphs: {len(dataset)}\")\n", - "print(f\"Number of features: {dataset.num_features}\")\n", - "print(f\"Number of classes: {dataset.num_classes}\")\n", - "\n", - "data = dataset[0] # Get the first graph object.\n", - "\n", - "print()\n", - "print(data)\n", - "print(\"=============================================================\")\n", - "\n", - "# Gather some statistics about the first graph.\n", - "print(f\"Number of nodes: {len(data.x)}\")\n", - "print(f\"Number of edges: {data.num_edges}\")\n", - "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", - "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", - "print(f\"Has self-loops: {data.has_self_loops()}\")\n", - "print(f\"Is undirected: {data.is_undirected()}\")\n", - "\n", - "print(\"=============================================================\")\n", - "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", - "print(\n", - " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", - ")\n", - "print(\n", - " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Number of training graphs: 562\n", - "Number of test graphs: 150\n" - ] - } - ], - "source": [ - "dataset = dataset.shuffle()\n", - "\n", - "train_dataset = dataset[:-150]\n", - "test_dataset = dataset[-150:]\n", - "\n", - "print(f\"Number of training graphs: {len(train_dataset)}\")\n", - "print(f\"Number of test graphs: {len(test_dataset)}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "train_loader = DataLoader(train_dataset)\n", - "test_loader = DataLoader(test_dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([0])\n" - ] - } - ], - "source": [ - "print(dataset[0].y)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "GCN(\n", - " (conv1): GCNConv(10, 64)\n", - " (conv2): GCNConv(64, 64)\n", - " (conv3): GCNConv(64, 64)\n", - " (lin): Linear(in_features=64, out_features=2, bias=True)\n", - ")\n" - ] - } - ], - "source": [ - "from torch.nn import Linear\n", - "import torch.nn.functional as F\n", - "from torch_geometric.nn import GCNConv\n", - "from torch_geometric.nn import global_mean_pool\n", - "\n", - "\n", - "class GCN(torch.nn.Module):\n", - " def __init__(self, hidden_channels):\n", - " super(GCN, self).__init__()\n", - " torch.manual_seed(12345)\n", - " self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)\n", - " self.conv2 = GCNConv(hidden_channels, hidden_channels)\n", - " self.conv3 = GCNConv(hidden_channels, hidden_channels)\n", - " self.lin = Linear(hidden_channels, dataset.num_classes)\n", - "\n", - " def forward(self, x, edge_index, batch):\n", - " # 1. Obtain node embeddings\n", - " x = self.conv1(x, edge_index)\n", - " x = x.relu()\n", - " x = self.conv2(x, edge_index)\n", - " x = x.relu()\n", - " x = self.conv3(x, edge_index)\n", - "\n", - " # 2. Readout layer\n", - " x = global_mean_pool(x, batch) # [batch_size, hidden_channels]\n", - "\n", - " # 3. Apply a final classifier\n", - " x = F.dropout(x, p=0.5, training=self.training)\n", - " x = self.lin(x)\n", - "\n", - " return x\n", - "\n", - "\n", - "model = GCN(hidden_channels=64)\n", - "print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 001, Train Acc: 0.7135, Test Acc: 0.7867\n", - "Epoch: 002, Train Acc: 0.7135, Test Acc: 0.7867\n", - "Epoch: 003, Train Acc: 0.6673, Test Acc: 0.6800\n", - "Epoch: 004, Train Acc: 0.8203, Test Acc: 0.8400\n", - "Epoch: 005, Train Acc: 0.8132, Test Acc: 0.8400\n", - "Epoch: 006, Train Acc: 0.8256, Test Acc: 0.8533\n", - "Epoch: 007, Train Acc: 0.8256, Test Acc: 0.8533\n", - "Epoch: 008, Train Acc: 0.8238, Test Acc: 0.8467\n", - "Epoch: 009, Train Acc: 0.8256, Test Acc: 0.8533\n", - "Epoch: 010, Train Acc: 0.8238, Test Acc: 0.8533\n", - "Epoch: 011, Train Acc: 0.8203, Test Acc: 0.8400\n", - "Epoch: 012, Train Acc: 0.8238, Test Acc: 0.8533\n", - "Epoch: 013, Train Acc: 0.8256, Test Acc: 0.8533\n", - "Epoch: 014, Train Acc: 0.8256, Test Acc: 0.8533\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn[15], line 36\u001b[0m\n\u001b[0;32m 34\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m171\u001b[39m):\n\u001b[0;32m 35\u001b[0m train()\n\u001b[1;32m---> 36\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m \u001b[43mtest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 37\u001b[0m test_acc \u001b[38;5;241m=\u001b[39m test(test_loader)\n\u001b[0;32m 38\u001b[0m \u001b[38;5;28mprint\u001b[39m(\n\u001b[0;32m 39\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m03d\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Train Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Test Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 40\u001b[0m )\n", - "Cell \u001b[1;32mIn[15], line 24\u001b[0m, in \u001b[0;36mtest\u001b[1;34m(loader)\u001b[0m\n\u001b[0;32m 22\u001b[0m correct \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m 23\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m loader: \u001b[38;5;66;03m# Iterate in batches over the training/test dataset.\u001b[39;00m\n\u001b[1;32m---> 24\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 25\u001b[0m pred \u001b[38;5;241m=\u001b[39m out\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# Use the class with highest probability.\u001b[39;00m\n\u001b[0;32m 26\u001b[0m correct \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(\n\u001b[0;32m 27\u001b[0m (pred \u001b[38;5;241m==\u001b[39m data\u001b[38;5;241m.\u001b[39my)\u001b[38;5;241m.\u001b[39msum()\n\u001b[0;32m 28\u001b[0m ) \u001b[38;5;66;03m# Check against ground-truth labels.\u001b[39;00m\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "Cell \u001b[1;32mIn[14], line 22\u001b[0m, in \u001b[0;36mGCN.forward\u001b[1;34m(self, x, edge_index, batch)\u001b[0m\n\u001b[0;32m 20\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv2(x, edge_index)\n\u001b[0;32m 21\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrelu()\n\u001b[1;32m---> 22\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv3\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 24\u001b[0m \u001b[38;5;66;03m# 2. Readout layer\u001b[39;00m\n\u001b[0;32m 25\u001b[0m x \u001b[38;5;241m=\u001b[39m global_mean_pool(x, batch) \u001b[38;5;66;03m# [batch_size, hidden_channels]\u001b[39;00m\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\gcn_conv.py:263\u001b[0m, in \u001b[0;36mGCNConv.forward\u001b[1;34m(self, x, edge_index, edge_weight)\u001b[0m\n\u001b[0;32m 260\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlin(x)\n\u001b[0;32m 262\u001b[0m \u001b[38;5;66;03m# propagate_type: (x: Tensor, edge_weight: OptTensor)\u001b[39;00m\n\u001b[1;32m--> 263\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpropagate\u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43medge_weight\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 265\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 266\u001b[0m out \u001b[38;5;241m=\u001b[39m out \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbias\n", - "File \u001b[1;32m~\\AppData\\Local\\Temp\\torch_geometric.nn.conv.gcn_conv_GCNConv_propagate_7evef4br.py:245\u001b[0m, in \u001b[0;36mpropagate\u001b[1;34m(self, edge_index, x, edge_weight, size)\u001b[0m\n\u001b[0;32m 236\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m CollectArgs(\n\u001b[0;32m 237\u001b[0m x_j\u001b[38;5;241m=\u001b[39mkwargs\u001b[38;5;241m.\u001b[39mx_j,\n\u001b[0;32m 238\u001b[0m edge_weight\u001b[38;5;241m=\u001b[39mkwargs\u001b[38;5;241m.\u001b[39medge_weight,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 241\u001b[0m dim_size\u001b[38;5;241m=\u001b[39mhook_kwargs[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdim_size\u001b[39m\u001b[38;5;124m'\u001b[39m],\n\u001b[0;32m 242\u001b[0m )\n\u001b[0;32m 243\u001b[0m \u001b[38;5;66;03m# End Aggregate Forward Pre Hook #######################################\u001b[39;00m\n\u001b[1;32m--> 245\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maggregate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 246\u001b[0m \u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 247\u001b[0m \u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 248\u001b[0m \u001b[43m \u001b[49m\u001b[43mptr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mptr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 249\u001b[0m \u001b[43m \u001b[49m\u001b[43mdim_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdim_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 250\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 252\u001b[0m \u001b[38;5;66;03m# Begin Aggregate Forward Hook #########################################\u001b[39;00m\n\u001b[0;32m 253\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mjit\u001b[38;5;241m.\u001b[39mis_scripting() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_compiling():\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\message_passing.py:651\u001b[0m, in \u001b[0;36mMessagePassing.aggregate\u001b[1;34m(self, inputs, index, ptr, dim_size)\u001b[0m\n\u001b[0;32m 634\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21maggregate\u001b[39m(\n\u001b[0;32m 635\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m 636\u001b[0m inputs: Tensor,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 639\u001b[0m dim_size: Optional[\u001b[38;5;28mint\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m 640\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m 641\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Aggregates messages from neighbors as\u001b[39;00m\n\u001b[0;32m 642\u001b[0m \u001b[38;5;124;03m :math:`\\bigoplus_{j \\in \\mathcal{N}(i)}`.\u001b[39;00m\n\u001b[0;32m 643\u001b[0m \n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 649\u001b[0m \u001b[38;5;124;03m as specified in :meth:`__init__` by the :obj:`aggr` argument.\u001b[39;00m\n\u001b[0;32m 650\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m--> 651\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maggr_module\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mptr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mptr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdim_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 652\u001b[0m \u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode_dim\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\experimental.py:117\u001b[0m, in \u001b[0;36mdisable_dynamic_shapes..decorator..wrapper\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 114\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[0;32m 115\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m 116\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_experimental_mode_enabled(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdisable_dynamic_shapes\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[1;32m--> 117\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 119\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m required_arg \u001b[38;5;129;01min\u001b[39;00m required_args:\n\u001b[0;32m 120\u001b[0m index \u001b[38;5;241m=\u001b[39m required_args_pos[required_arg]\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\aggr\\base.py:133\u001b[0m, in \u001b[0;36mAggregation.__call__\u001b[1;34m(self, x, index, ptr, dim_size, dim, **kwargs)\u001b[0m\n\u001b[0;32m 130\u001b[0m dim_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(index\u001b[38;5;241m.\u001b[39mmax()) \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m index\u001b[38;5;241m.\u001b[39mnumel() \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m 132\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 133\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m(x, index\u001b[38;5;241m=\u001b[39mindex, ptr\u001b[38;5;241m=\u001b[39mptr, dim_size\u001b[38;5;241m=\u001b[39mdim_size,\n\u001b[0;32m 134\u001b[0m dim\u001b[38;5;241m=\u001b[39mdim, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 135\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mIndexError\u001b[39;00m, \u001b[38;5;167;01mRuntimeError\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m 136\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m index \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\aggr\\basic.py:22\u001b[0m, in \u001b[0;36mSumAggregation.forward\u001b[1;34m(self, x, index, ptr, dim_size, dim)\u001b[0m\n\u001b[0;32m 19\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: Tensor, index: Optional[Tensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m 20\u001b[0m ptr: Optional[Tensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m, dim_size: Optional[\u001b[38;5;28mint\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m 21\u001b[0m dim: \u001b[38;5;28mint\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m---> 22\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreduce\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mptr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreduce\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43msum\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\aggr\\base.py:187\u001b[0m, in \u001b[0;36mAggregation.reduce\u001b[1;34m(self, x, index, ptr, dim_size, dim, reduce)\u001b[0m\n\u001b[0;32m 184\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m index \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 185\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAggregation requires \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mindex\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m to be specified\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m--> 187\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mscatter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreduce\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\utils\\_scatter.py:75\u001b[0m, in \u001b[0;36mscatter\u001b[1;34m(src, index, dim, dim_size, reduce)\u001b[0m\n\u001b[0;32m 73\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m reduce \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msum\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m reduce \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124madd\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m 74\u001b[0m index \u001b[38;5;241m=\u001b[39m broadcast(index, src, dim)\n\u001b[1;32m---> 75\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43msrc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnew_zeros\u001b[49m\u001b[43m(\u001b[49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscatter_add_\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msrc\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 77\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m reduce \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m 78\u001b[0m count \u001b[38;5;241m=\u001b[39m src\u001b[38;5;241m.\u001b[39mnew_zeros(dim_size)\n", - "\u001b[1;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "model = GCN(hidden_channels=64)\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", - "criterion = torch.nn.CrossEntropyLoss()\n", - "\n", - "\n", - "def train():\n", - " model.train()\n", - "\n", - " for data in train_loader: # Iterate in batches over the training dataset.\n", - " out = model(\n", - " data.x, data.edge_index, data.batch\n", - " ) # Perform a single forward pass.\n", - " loss = criterion(out, data.y) # Compute the loss.\n", - " loss.backward() # Derive gradients.\n", - " optimizer.step() # Update parameters based on gradients.\n", - " optimizer.zero_grad() # Clear gradients.\n", - "\n", - "\n", - "def test(loader):\n", - " model.eval()\n", - "\n", - " correct = 0\n", - " for data in loader: # Iterate in batches over the training/test dataset.\n", - " out = model(data.x, data.edge_index, data.batch)\n", - " pred = out.argmax(dim=1) # Use the class with highest probability.\n", - " correct += int(\n", - " (pred == data.y).sum()\n", - " ) # Check against ground-truth labels.\n", - " return correct / len(\n", - " loader.dataset\n", - " ) # Derive ratio of correct predictions.\n", - "\n", - "\n", - "for epoch in range(1, 171):\n", - " train()\n", - " train_acc = test(train_loader)\n", - " test_acc = test(test_loader)\n", - " print(\n", - " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/experiment_mlp.ipynb b/experiment_mlp.ipynb deleted file mode 100644 index f3dfd64..0000000 --- a/experiment_mlp.ipynb +++ /dev/null @@ -1,461 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "ExecuteTime": { - "end_time": "2024-04-25T17:47:32.798261Z", - "start_time": "2024-04-25T17:47:19.035622Z" - } - }, - "outputs": [], - "source": [ - "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", - "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", - "import torchvision.transforms as transforms\n", - "\n", - "from mantra.simplicial import SimplicialDataset\n", - "from mantra.transforms import (\n", - " TriangulationToFaceTransform,\n", - " OrientableToClassTransform,\n", - " DegreeTransform,\n", - ")\n", - "from validation.validate_homology import validate_betti_numbers\n", - "\n", - "import torch" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Dataset: SimplicialDataset(712):\n", - "====================\n", - "Number of graphs: 712\n", - "Number of features: 9\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Users\\Ernst\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'name', 'face', 'orientable', 'torsion_coefficients', 'dimension', 'betti_numbers', 'genus', 'n_vertices'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Number of classes: 2\n", - "\n", - "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 9], y=[1])\n", - "=============================================================\n", - "Number of nodes: 4\n", - "Number of edges: 12\n", - "Average node degree: 3.00\n", - "Has isolated nodes: False\n", - "Has self-loops: False\n", - "Is undirected: True\n", - "=============================================================\n", - "Number of orientable Manifolds: 193\n", - "Number of non-orientable Manifolds: 519\n", - "Percentage: 0.27, 0.73\n" - ] - } - ], - "source": [ - "tr = transforms.Compose(\n", - " [\n", - " TriangulationToFaceTransform(),\n", - " FaceToEdge(remove_faces=False),\n", - " DegreeTransform(),\n", - " OrientableToClassTransform(),\n", - " OneHotDegree(max_degree=8,cat=False)\n", - " ]\n", - ")\n", - "\n", - "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", - "\n", - "\n", - "print()\n", - "print(f\"Dataset: {dataset}:\")\n", - "print(\"====================\")\n", - "print(f\"Number of graphs: {len(dataset)}\")\n", - "print(f\"Number of features: {dataset.num_features}\")\n", - "print(f\"Number of classes: {dataset.num_classes}\")\n", - "\n", - "data = dataset[0] # Get the first graph object.\n", - "\n", - "print()\n", - "print(data)\n", - "print(\"=============================================================\")\n", - "\n", - "# Gather some statistics about the first graph.\n", - "print(f\"Number of nodes: {len(data.x)}\")\n", - "print(f\"Number of edges: {data.num_edges}\")\n", - "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", - "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", - "print(f\"Has self-loops: {data.has_self_loops()}\")\n", - "print(f\"Is undirected: {data.is_undirected()}\")\n", - "\n", - "print(\"=============================================================\")\n", - "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", - "print(\n", - " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", - ")\n", - "print(\n", - " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Users\\Ernst\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'name', 'face', 'orientable', 'torsion_coefficients', 'dimension', 'betti_numbers', 'genus', 'n_vertices'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "text/plain": [ - "tensor([ True, True, True, False, True, True, True, False, True, False,\n", - " False, True, True, True, True, True, True, False, True, False,\n", - " False, False, True, False, False, False, False, False, False, False,\n", - " False, False, False, False, True, True, True, True, True, False,\n", - " False, False, True, True, True, False, True, True, True, True,\n", - " False, False, True, False, True, True, True, True, True, True,\n", - " True, False, True, True, True, True, False, False, False, False,\n", - " True, True, True, False, False, False, False, False, False, False,\n", - " True, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, True, True,\n", - " True, False, False, False, False, False, False, True, False, False,\n", - " True, True, True, False, False, True, False, False, False, False,\n", - " False, False, False, False, False, True, True, False, False, False,\n", - " True, False, False, False, False, False, False, True, False, False,\n", - " False, False, False, True, True, False, True, True, True, True,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, True, True, True, False, False, False, False, False, False,\n", - " False, False, True, True, True, False, False, True, True, True,\n", - " True, True, False, False, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, False, False, False, False, False, True,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, True, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " True, True, True, True, True, False, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, True,\n", - " True, False, False, False, True, False, False, False, True, False,\n", - " False, False, True, False, False, False, False, False, False, False,\n", - " False, True, True, True, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, True, False, False, True, False, False, True, True,\n", - " True, False, False, False, False, False, False, True, True, False,\n", - " False, True, False, False, False, False, False, False, False, False,\n", - " False, True, True, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, True, True, True, True, True,\n", - " True, False, False, False, False, False, True, True, True, True,\n", - " True, False, False, False, True, True, True, True, True, True,\n", - " True, True, False, False, True, False, False, True, True, True,\n", - " False, True, True, True, True, False, False, False, False, True,\n", - " False, False, False, False, False, False, False, False, True, True,\n", - " True, True, True, True, True, True, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, True,\n", - " True, False, False, False, False, False, False, True, True, False,\n", - " False, False, False, False, False, False, True, True, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, False, False, False, False, False, False, False, False,\n", - " False, False, True, True, True, False, False, False, False, False,\n", - " True, True, True, False, False, True, True, True, True, True,\n", - " False, False, False, True, True, True, True, True, True, True,\n", - " True, False, False, False, False, False, True, False, True, False,\n", - " True, False, False, False, True, True, True, False, False, False,\n", - " False, False, False, False, False, True, False, False, False, True,\n", - " True, False, False, False, False, True, False, False, False, False,\n", - " False, False, False, False, False, False, True, True, True, True,\n", - " True, True, True, True, True, True, True, False, True, False,\n", - " False, False, False, False, False, True, True, True, True, False,\n", - " True, True])" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.tensor([data.y for data in dataset])" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Number of training graphs: 562\n", - "Number of test graphs: 150\n" - ] - } - ], - "source": [ - "dataset = dataset.shuffle()\n", - "\n", - "train_dataset = dataset[:-150]\n", - "test_dataset = dataset[-150:]\n", - "\n", - "print(f\"Number of training graphs: {len(train_dataset)}\")\n", - "print(f\"Number of test graphs: {len(test_dataset)}\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Users\\Ernst\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'betti_numbers', 'face', 'n_vertices', 'name', 'dimension', 'orientable', 'torsion_coefficients', 'genus'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "text/plain": [ - "tensor([1, 1, 0, 0, 1, 1, 0, 0, 0, 1])" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", - "test_loader = DataLoader(test_dataset,batch_size=10)\n", - "\n", - "\n", - "for batch in train_loader:\n", - " break\n", - "\n", - "batch.y" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "from torch.nn import Linear\n", - "import torch.nn.functional as F\n", - "import torch.nn as nn\n", - "from torch_geometric.nn import GCNConv\n", - "from torch_geometric.nn import global_mean_pool\n", - "from torch_scatter import segment_coo\n", - "\n", - "class PermInvariant(torch.nn.Module):\n", - " def __init__(self, hidden_channels):\n", - " super().__init__()\n", - " # torch.manual_seed(12345)\n", - " self.classification = nn.Sequential( \n", - " nn.Linear(9,hidden_channels),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_channels,hidden_channels),\n", - " nn.ReLU(),\n", - " nn.Linear(hidden_channels,2),\n", - " nn.ReLU()\n", - " )\n", - "\n", - " def forward(self, batch):\n", - " x = self.classification(batch.x)\n", - " # print(batch.x)\n", - " # print(x)\n", - " return segment_coo(x,batch.batch,reduce=\"sum\")\n", - "\n", - "\n", - "model = PermInvariant(hidden_channels=64)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 001, Train Acc: 0.7313, Test Acc: 0.7200\n", - "Epoch: 002, Train Acc: 0.7384, Test Acc: 0.7333\n", - "Epoch: 003, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 004, Train Acc: 0.8025, Test Acc: 0.7800\n", - "Epoch: 005, Train Acc: 0.7972, Test Acc: 0.7733\n", - "Epoch: 006, Train Acc: 0.7972, Test Acc: 0.7733\n", - "Epoch: 007, Train Acc: 0.7900, Test Acc: 0.7667\n", - "Epoch: 008, Train Acc: 0.7918, Test Acc: 0.7667\n", - "Epoch: 009, Train Acc: 0.7865, Test Acc: 0.7667\n", - "Epoch: 010, Train Acc: 0.7794, Test Acc: 0.7600\n", - "Epoch: 011, Train Acc: 0.7794, Test Acc: 0.7600\n", - "Epoch: 012, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 013, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 014, Train Acc: 0.7794, Test Acc: 0.7600\n", - "Epoch: 015, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 016, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 017, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 018, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 019, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 020, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 021, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 022, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 023, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 024, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 025, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 026, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 027, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 028, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 029, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 030, Train Acc: 0.7829, Test Acc: 0.7667\n", - "Epoch: 031, Train Acc: 0.7829, Test Acc: 0.7667\n", - "Epoch: 032, Train Acc: 0.7811, Test Acc: 0.7600\n", - "Epoch: 033, Train Acc: 0.7829, Test Acc: 0.7667\n", - "Epoch: 034, Train Acc: 0.7829, Test Acc: 0.7667\n", - "Epoch: 035, Train Acc: 0.7829, Test Acc: 0.7667\n", - "Epoch: 036, Train Acc: 0.7829, Test Acc: 0.7667\n", - "Epoch: 037, Train Acc: 0.7829, Test Acc: 0.7667\n", - "Epoch: 038, Train Acc: 0.7829, Test Acc: 0.7667\n", - "Epoch: 039, Train Acc: 0.7829, Test Acc: 0.7667\n" - ] - } - ], - "source": [ - "model = PermInvariant(hidden_channels=64)\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", - "criterion = torch.nn.CrossEntropyLoss()\n", - "\n", - "\n", - "def train():\n", - " model.train()\n", - "\n", - " for data in train_loader: # Iterate in batches over the training dataset.\n", - " out = model(data\n", - " ) # Perform a single forward pass.\n", - " loss = criterion(out, data.y) # Compute the loss.\n", - " loss.backward() # Derive gradients.\n", - " optimizer.step() # Update parameters based on gradients.\n", - " optimizer.zero_grad() # Clear gradients.\n", - "\n", - "\n", - "def test(loader):\n", - " model.eval()\n", - "\n", - " correct = 0\n", - " for data in loader: # Iterate in batches over the training/test dataset.\n", - " out = model(data)\n", - " pred = out.argmax(dim=1) # Use the class with highest probability.\n", - " correct += int(\n", - " (pred == data.y).sum()\n", - " ) # Check against ground-truth labels.\n", - " return correct / len(\n", - " loader.dataset\n", - " ) # Derive ratio of correct predictions.\n", - "\n", - "\n", - "for epoch in range(1, 40):\n", - " train()\n", - " train_acc = test(train_loader)\n", - " test_acc = test(test_loader)\n", - " print(\n", - " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([0.9635, 0.8106, 0.9215, 0.9906, 1.5490, 1.0627, 1.0325, 1.2307, 1.2846,\n", - " 2.0083], grad_fn=)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model(batch)[:,0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/experiment_name_TAG.ipynb b/experiment_name_TAG.ipynb new file mode 100644 index 0000000..509b830 --- /dev/null +++ b/experiment_name_TAG.ipynb @@ -0,0 +1,356 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class NameToClass: \n", + " def __init__(self):\n", + " self.class_dict = {'Klein bottle': 0, '': 1, 'RP^2': 2, 'T^2': 3, 'S^2': 4}\n", + " \n", + " def __call__(self,data):\n", + " # data.y = F.one_hot(torch.tensor(self.class_dict[data.name]),num_classes=5)\n", + " data.y = torch.tensor(self.class_dict[data.name])\n", + " return data\n", + "\n", + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " NameToClass()\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 1\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 1], y=4)\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\ernst\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'n_vertices', 'betti_numbers', 'name', 'dimension', 'face', 'torsion_coefficients', 'orientable', 'genus'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "# print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([3, 0, 0, 1, 1, 0, 4, 2, 1, 2])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GCN(\n", + " (conv1): TAGConv(1, 64, K=3)\n", + " (conv2): TAGConv(64, 64, K=3)\n", + " (conv3): TAGConv(64, 64, K=3)\n", + " (lin): Linear(in_features=64, out_features=5, bias=True)\n", + ")\n" + ] + } + ], + "source": [ + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv, TAGConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "\n", + "\n", + "class GCN(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super(GCN, self).__init__()\n", + " torch.manual_seed(12345)\n", + " self.conv1 = TAGConv(dataset.num_node_features, hidden_channels)\n", + " self.conv2 = TAGConv(hidden_channels, hidden_channels)\n", + " self.conv3 = TAGConv(hidden_channels, hidden_channels)\n", + " self.lin = Linear(hidden_channels, 5)\n", + "\n", + " def forward(self, batch):\n", + " # 1. Obtain node embeddings\n", + " x = self.conv1(batch.x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv2(x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv3(x, batch.edge_index)\n", + "\n", + " # 2. Readout layer\n", + " x = global_mean_pool(x, batch.batch) # [batch_size, hidden_channels]\n", + "\n", + " # 3. Apply a final classifier\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = self.lin(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "model = GCN(hidden_channels=64)\n", + "print(model)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 0.2847, Test Acc: 0.2200\n", + "Epoch: 002, Train Acc: 0.2847, Test Acc: 0.2200\n", + "Epoch: 003, Train Acc: 0.2847, Test Acc: 0.2200\n", + "Epoch: 004, Train Acc: 0.2847, Test Acc: 0.2200\n", + "Epoch: 005, Train Acc: 0.2865, Test Acc: 0.2267\n", + "Epoch: 006, Train Acc: 0.3399, Test Acc: 0.2733\n", + "Epoch: 007, Train Acc: 0.2972, Test Acc: 0.2467\n", + "Epoch: 008, Train Acc: 0.2847, Test Acc: 0.2200\n", + "Epoch: 009, Train Acc: 0.2972, Test Acc: 0.2467\n", + "Epoch: 010, Train Acc: 0.2349, Test Acc: 0.2800\n", + "Epoch: 011, Train Acc: 0.4235, Test Acc: 0.4067\n", + "Epoch: 012, Train Acc: 0.3541, Test Acc: 0.3667\n", + "Epoch: 013, Train Acc: 0.4448, Test Acc: 0.4267\n", + "Epoch: 014, Train Acc: 0.5231, Test Acc: 0.5067\n", + "Epoch: 015, Train Acc: 0.3559, Test Acc: 0.3733\n", + "Epoch: 016, Train Acc: 0.2616, Test Acc: 0.2867\n", + "Epoch: 017, Train Acc: 0.2687, Test Acc: 0.2867\n", + "Epoch: 018, Train Acc: 0.3096, Test Acc: 0.3333\n", + "Epoch: 019, Train Acc: 0.4235, Test Acc: 0.4467\n", + "Epoch: 020, Train Acc: 0.5836, Test Acc: 0.5533\n", + "Epoch: 021, Train Acc: 0.4199, Test Acc: 0.4533\n", + "Epoch: 022, Train Acc: 0.4484, Test Acc: 0.4800\n", + "Epoch: 023, Train Acc: 0.5587, Test Acc: 0.5200\n", + "Epoch: 024, Train Acc: 0.5996, Test Acc: 0.5600\n", + "Epoch: 025, Train Acc: 0.7100, Test Acc: 0.6667\n", + "Epoch: 026, Train Acc: 0.6993, Test Acc: 0.6467\n", + "Epoch: 027, Train Acc: 0.7135, Test Acc: 0.6600\n", + "Epoch: 028, Train Acc: 0.7206, Test Acc: 0.6867\n", + "Epoch: 029, Train Acc: 0.7295, Test Acc: 0.7000\n", + "Epoch: 030, Train Acc: 0.7064, Test Acc: 0.6733\n", + "Epoch: 031, Train Acc: 0.7278, Test Acc: 0.6933\n", + "Epoch: 032, Train Acc: 0.7295, Test Acc: 0.6933\n", + "Epoch: 033, Train Acc: 0.8363, Test Acc: 0.8067\n", + "Epoch: 034, Train Acc: 0.8078, Test Acc: 0.7733\n", + "Epoch: 035, Train Acc: 0.8363, Test Acc: 0.8067\n", + "Epoch: 036, Train Acc: 0.8345, Test Acc: 0.8000\n", + "Epoch: 037, Train Acc: 0.8292, Test Acc: 0.7933\n", + "Epoch: 038, Train Acc: 0.8256, Test Acc: 0.7933\n", + "Epoch: 039, Train Acc: 0.8363, Test Acc: 0.8067\n" + ] + } + ], + "source": [ + "model = GCN(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data\n", + " ) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int(\n", + " (pred == data.y).sum()\n", + " ) # Check against ground-truth labels.\n", + " return correct / len(\n", + " loader.dataset\n", + " ) # Derive ratio of correct predictions.\n", + "\n", + "\n", + "for epoch in range(1, 40):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_name_Transformer.ipynb b/experiment_name_Transformer.ipynb new file mode 100644 index 0000000..80afc5e --- /dev/null +++ b/experiment_name_Transformer.ipynb @@ -0,0 +1,356 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class NameToClass: \n", + " def __init__(self):\n", + " self.class_dict = {'Klein bottle': 0, '': 1, 'RP^2': 2, 'T^2': 3, 'S^2': 4}\n", + " \n", + " def __call__(self,data):\n", + " # data.y = F.one_hot(torch.tensor(self.class_dict[data.name]),num_classes=5)\n", + " data.y = torch.tensor(self.class_dict[data.name])\n", + " return data\n", + "\n", + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " NameToClass()\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 1\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 1], y=4)\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\ernst\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'torsion_coefficients', 'n_vertices', 'betti_numbers', 'dimension', 'orientable', 'genus', 'face', 'name'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "# print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([4, 4, 0, 3, 4, 2, 2, 1, 4, 2])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GCN(\n", + " (conv1): TransformerConv(1, 64, heads=1)\n", + " (conv2): TransformerConv(64, 64, heads=1)\n", + " (conv3): TransformerConv(64, 64, heads=1)\n", + " (lin): Linear(in_features=64, out_features=5, bias=True)\n", + ")\n" + ] + } + ], + "source": [ + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv, TAGConv,TransformerConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "\n", + "\n", + "class GCN(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super(GCN, self).__init__()\n", + " torch.manual_seed(12345)\n", + " self.conv1 = TransformerConv(dataset.num_node_features, hidden_channels)\n", + " self.conv2 = TransformerConv(hidden_channels, hidden_channels)\n", + " self.conv3 = TransformerConv(hidden_channels, hidden_channels)\n", + " self.lin = Linear(hidden_channels, 5)\n", + "\n", + " def forward(self, batch):\n", + " # 1. Obtain node embeddings\n", + " x = self.conv1(batch.x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv2(x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv3(x, batch.edge_index)\n", + "\n", + " # 2. Readout layer\n", + " x = global_mean_pool(x, batch.batch) # [batch_size, hidden_channels]\n", + "\n", + " # 3. Apply a final classifier\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = self.lin(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "model = GCN(hidden_channels=64)\n", + "print(model)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 0.3327, Test Acc: 0.2933\n", + "Epoch: 002, Train Acc: 0.4751, Test Acc: 0.4533\n", + "Epoch: 003, Train Acc: 0.5089, Test Acc: 0.4933\n", + "Epoch: 004, Train Acc: 0.5125, Test Acc: 0.4600\n", + "Epoch: 005, Train Acc: 0.5996, Test Acc: 0.5667\n", + "Epoch: 006, Train Acc: 0.6637, Test Acc: 0.6200\n", + "Epoch: 007, Train Acc: 0.7046, Test Acc: 0.6067\n", + "Epoch: 008, Train Acc: 0.8167, Test Acc: 0.7867\n", + "Epoch: 009, Train Acc: 0.8345, Test Acc: 0.8000\n", + "Epoch: 010, Train Acc: 0.8310, Test Acc: 0.8000\n", + "Epoch: 011, Train Acc: 0.8185, Test Acc: 0.7600\n", + "Epoch: 012, Train Acc: 0.8114, Test Acc: 0.7667\n", + "Epoch: 013, Train Acc: 0.8363, Test Acc: 0.7933\n", + "Epoch: 014, Train Acc: 0.8363, Test Acc: 0.8000\n", + "Epoch: 015, Train Acc: 0.8363, Test Acc: 0.7933\n", + "Epoch: 016, Train Acc: 0.8381, Test Acc: 0.8000\n", + "Epoch: 017, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 018, Train Acc: 0.8381, Test Acc: 0.8000\n", + "Epoch: 019, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 020, Train Acc: 0.8256, Test Acc: 0.7933\n", + "Epoch: 021, Train Acc: 0.8381, Test Acc: 0.8000\n", + "Epoch: 022, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 023, Train Acc: 0.8381, Test Acc: 0.8000\n", + "Epoch: 024, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 025, Train Acc: 0.8381, Test Acc: 0.8000\n", + "Epoch: 026, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 027, Train Acc: 0.8381, Test Acc: 0.7933\n", + "Epoch: 028, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 029, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 030, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 031, Train Acc: 0.8381, Test Acc: 0.7933\n", + "Epoch: 032, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 033, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 034, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 035, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 036, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 037, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 038, Train Acc: 0.8399, Test Acc: 0.8000\n", + "Epoch: 039, Train Acc: 0.8399, Test Acc: 0.8000\n" + ] + } + ], + "source": [ + "model = GCN(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data\n", + " ) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int(\n", + " (pred == data.y).sum()\n", + " ) # Check against ground-truth labels.\n", + " return correct / len(\n", + " loader.dataset\n", + " ) # Derive ratio of correct predictions.\n", + "\n", + "\n", + "for epoch in range(1, 40):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_name_gnn.ipynb b/experiment_name_gnn.ipynb new file mode 100644 index 0000000..7144b65 --- /dev/null +++ b/experiment_name_gnn.ipynb @@ -0,0 +1,384 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader\n", + "from torch_geometric.transforms import FaceToEdge,OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "import torch.nn.functional as F\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "class NameToClass: \n", + " def __init__(self):\n", + " self.class_dict = {'Klein bottle': 0, '': 1, 'RP^2': 2, 'T^2': 3, 'S^2': 4}\n", + " \n", + " def __call__(self,data):\n", + " # data.y = F.one_hot(torch.tensor(self.class_dict[data.name]),num_classes=5)\n", + " data.y = torch.tensor(self.class_dict[data.name])\n", + " return data\n", + "\n", + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " NameToClass()\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(4)" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data = dataset[0] # Get the first graph object.\n", + "data.y" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Percentage: 0.27, 0.73\n", + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 1\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 1], y=4)\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")\n", + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "# print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)\n", + "test_loader = DataLoader(test_dataset,batch_size=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0, 1, 2, 1, 1, 4, 4, 0, 3, 3])" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "for batch in train_loader:\n", + " break\n", + "batch.y " + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[-0.3993, -0.0838, -0.4797, 0.1843, 0.1530],\n", + " [ 0.2563, -0.4330, -1.0375, 0.6955, -0.1932],\n", + " [-0.4545, -0.3014, 0.0632, 0.2620, -0.0970],\n", + " [-0.5419, 0.1279, -0.4979, 0.2693, 0.0587],\n", + " [ 0.3741, -0.0875, -0.0427, 0.6432, 0.7742],\n", + " [-0.0616, -0.1327, -0.1828, 0.2206, 0.0559],\n", + " [ 0.4510, 0.2749, 0.0296, 0.2863, 0.4205],\n", + " [ 0.2428, 0.1032, -0.1836, 0.9172, 0.7433],\n", + " [ 0.2160, 0.6167, -0.4848, 0.4869, 0.3303],\n", + " [ 0.0614, 0.1672, 0.0395, 0.5327, 0.7759]],\n", + " grad_fn=)\n" + ] + } + ], + "source": [ + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "\n", + "\n", + "class GCN(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super(GCN, self).__init__()\n", + " torch.manual_seed(12345)\n", + " self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)\n", + " self.conv2 = GCNConv(hidden_channels, hidden_channels)\n", + " self.conv3 = GCNConv(hidden_channels, hidden_channels)\n", + " self.lin = Linear(hidden_channels, 5)\n", + "\n", + " def forward(self, x, edge_index, batch):\n", + " # 1. Obtain node embeddings\n", + " x = self.conv1(x, edge_index)\n", + " x = x.relu()\n", + " x = self.conv2(x, edge_index)\n", + " x = x.relu()\n", + " x = self.conv3(x, edge_index)\n", + "\n", + " # 2. Readout layer\n", + " x = global_mean_pool(x, batch) # [batch_size, hidden_channels]\n", + "\n", + " # 3. Apply a final classifier\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = self.lin(x)\n", + "\n", + " return x\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "model = GCN(hidden_channels=64)\n", + "print(model(batch.x,batch.edge_index,batch.batch))" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 0.2758, Test Acc: 0.2533\n", + "Epoch: 002, Train Acc: 0.5125, Test Acc: 0.5133\n", + "Epoch: 003, Train Acc: 0.8238, Test Acc: 0.8467\n", + "Epoch: 004, Train Acc: 0.8256, Test Acc: 0.8467\n", + "Epoch: 005, Train Acc: 0.7491, Test Acc: 0.7867\n", + "Epoch: 006, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 007, Train Acc: 0.8256, Test Acc: 0.8467\n", + "Epoch: 008, Train Acc: 0.8256, Test Acc: 0.8467\n", + "Epoch: 009, Train Acc: 0.8256, Test Acc: 0.8467\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[42], line 35\u001b[0m\n\u001b[0;32m 29\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m correct \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mlen\u001b[39m(\n\u001b[0;32m 30\u001b[0m loader\u001b[38;5;241m.\u001b[39mdataset\n\u001b[0;32m 31\u001b[0m ) \u001b[38;5;66;03m# Derive ratio of correct predictions.\u001b[39;00m\n\u001b[0;32m 34\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m40\u001b[39m):\n\u001b[1;32m---> 35\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 36\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m test(train_loader)\n\u001b[0;32m 37\u001b[0m test_acc \u001b[38;5;241m=\u001b[39m test(test_loader)\n", + "Cell \u001b[1;32mIn[42], line 10\u001b[0m, in \u001b[0;36mtrain\u001b[1;34m()\u001b[0m\n\u001b[0;32m 7\u001b[0m model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[0;32m 9\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m train_loader: \u001b[38;5;66;03m# Iterate in batches over the training dataset.\u001b[39;00m\n\u001b[1;32m---> 10\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch\u001b[49m\n\u001b[0;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Perform a single forward pass.\u001b[39;00m\n\u001b[0;32m 13\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(out, data\u001b[38;5;241m.\u001b[39my) \u001b[38;5;66;03m# Compute the loss.\u001b[39;00m\n\u001b[0;32m 14\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward() \u001b[38;5;66;03m# Derive gradients.\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "Cell \u001b[1;32mIn[41], line 22\u001b[0m, in \u001b[0;36mGCN.forward\u001b[1;34m(self, x, edge_index, batch)\u001b[0m\n\u001b[0;32m 20\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv2(x, edge_index)\n\u001b[0;32m 21\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrelu()\n\u001b[1;32m---> 22\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv3\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 24\u001b[0m \u001b[38;5;66;03m# 2. Readout layer\u001b[39;00m\n\u001b[0;32m 25\u001b[0m x \u001b[38;5;241m=\u001b[39m global_mean_pool(x, batch) \u001b[38;5;66;03m# [batch_size, hidden_channels]\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\gcn_conv.py:241\u001b[0m, in \u001b[0;36mGCNConv.forward\u001b[1;34m(self, x, edge_index, edge_weight)\u001b[0m\n\u001b[0;32m 239\u001b[0m cache \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cached_edge_index\n\u001b[0;32m 240\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cache \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 241\u001b[0m edge_index, edge_weight \u001b[38;5;241m=\u001b[39m \u001b[43mgcn_norm\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# yapf: disable\u001b[39;49;00m\n\u001b[0;32m 242\u001b[0m \u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode_dim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 243\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimproved\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_self_loops\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mflow\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 244\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcached:\n\u001b[0;32m 245\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cached_edge_index \u001b[38;5;241m=\u001b[39m (edge_index, edge_weight)\n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "model = GCN(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(\n", + " data.x, data.edge_index, data.batch\n", + " ) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data.x, data.edge_index, data.batch)\n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int(\n", + " (pred == data.y).sum()\n", + " ) # Check against ground-truth labels.\n", + " return correct / len(\n", + " loader.dataset\n", + " ) # Derive ratio of correct predictions.\n", + "\n", + "\n", + "for epoch in range(1, 40):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_name_mlp.ipynb b/experiment_name_mlp.ipynb new file mode 100644 index 0000000..5e78fb4 --- /dev/null +++ b/experiment_name_mlp.ipynb @@ -0,0 +1,356 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class NameToClass: \n", + " def __init__(self):\n", + " self.class_dict = {'Klein bottle': 0, '': 1, 'RP^2': 2, 'T^2': 3, 'S^2': 4}\n", + " \n", + " def __call__(self,data):\n", + " # data.y = F.one_hot(torch.tensor(self.class_dict[data.name]),num_classes=5)\n", + " data.y = torch.tensor(self.class_dict[data.name])\n", + " return data\n", + "\n", + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " NameToClass()\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 1\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 1], y=4)\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\ernst\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'torsion_coefficients', 'n_vertices', 'genus', 'orientable', 'name', 'dimension', 'betti_numbers', 'face'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "# print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([2, 0, 2, 0, 4, 2, 2, 0, 0, 0])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PermInvariant(\n", + " (classification): Sequential(\n", + " (0): Linear(in_features=1, out_features=64, bias=True)\n", + " (1): ReLU()\n", + " (2): Linear(in_features=64, out_features=64, bias=True)\n", + " (3): ReLU()\n", + " (4): Linear(in_features=64, out_features=5, bias=True)\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv, TAGConv,TransformerConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "import torch.nn as nn\n", + "\n", + "from torch_scatter import segment_coo\n", + "\n", + "class PermInvariant(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super().__init__()\n", + " # torch.manual_seed(12345)\n", + " self.classification = nn.Sequential( \n", + " nn.Linear(dataset.num_node_features,hidden_channels),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_channels,hidden_channels),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_channels,5)\n", + " )\n", + "\n", + " def forward(self, batch):\n", + " x = self.classification(batch.x)\n", + " # print(batch.x)\n", + " # print(x)\n", + " return segment_coo(x,batch.batch,reduce=\"sum\")\n", + "\n", + "\n", + "model = PermInvariant(hidden_channels=64)\n", + "\n", + "\n", + "model = PermInvariant(hidden_channels=64)\n", + "print(model)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 0.3470, Test Acc: 0.3933\n", + "Epoch: 002, Train Acc: 0.3523, Test Acc: 0.3867\n", + "Epoch: 003, Train Acc: 0.7580, Test Acc: 0.7867\n", + "Epoch: 004, Train Acc: 0.7562, Test Acc: 0.7800\n", + "Epoch: 005, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 006, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 007, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 008, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 009, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 010, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 011, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 012, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 013, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 014, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 015, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 016, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 017, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 018, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 019, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 020, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 021, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 022, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 023, Train Acc: 0.8256, Test Acc: 0.8467\n", + "Epoch: 024, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 025, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 026, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 027, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 028, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 029, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 030, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 031, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 032, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 033, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 034, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 035, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 036, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 037, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 038, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 039, Train Acc: 0.8274, Test Acc: 0.8467\n" + ] + } + ], + "source": [ + "model = PermInvariant(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data\n", + " ) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int(\n", + " (pred == data.y).sum()\n", + " ) # Check against ground-truth labels.\n", + " return correct / len(\n", + " loader.dataset\n", + " ) # Derive ratio of correct predictions.\n", + "\n", + "\n", + "for epoch in range(1, 40):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_orientability_TAG.ipynb b/experiment_orientability_TAG.ipynb new file mode 100644 index 0000000..af33432 --- /dev/null +++ b/experiment_orientability_TAG.ipynb @@ -0,0 +1,344 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\ernst\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'torsion_coefficients', 'dimension', 'name', 'n_vertices', 'orientable', 'genus', 'betti_numbers', 'face'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of classes: 2\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 10], y=[1])\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " OneHotDegree(max_degree=9,cat=False)\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1, 0, 0, 0, 0, 0, 1, 1, 1, 1])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GCN(\n", + " (conv1): TAGConv(10, 64, K=3)\n", + " (conv2): TAGConv(64, 64, K=3)\n", + " (conv3): TAGConv(64, 64, K=3)\n", + " (lin): Linear(in_features=64, out_features=2, bias=True)\n", + ")\n" + ] + } + ], + "source": [ + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv, TAGConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "\n", + "\n", + "class GCN(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super(GCN, self).__init__()\n", + " torch.manual_seed(12345)\n", + " self.conv1 = TAGConv(dataset.num_node_features, hidden_channels)\n", + " self.conv2 = TAGConv(hidden_channels, hidden_channels)\n", + " self.conv3 = TAGConv(hidden_channels, hidden_channels)\n", + " self.lin = Linear(hidden_channels, dataset.num_classes)\n", + "\n", + " def forward(self, batch):\n", + " # 1. Obtain node embeddings\n", + " x = self.conv1(batch.x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv2(x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv3(x, batch.edge_index)\n", + "\n", + " # 2. Readout layer\n", + " x = global_mean_pool(x, batch.batch) # [batch_size, hidden_channels]\n", + "\n", + " # 3. Apply a final classifier\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = self.lin(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "model = GCN(hidden_channels=64)\n", + "print(model)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 0.7829, Test Acc: 0.8067\n", + "Epoch: 002, Train Acc: 0.8274, Test Acc: 0.8400\n", + "Epoch: 003, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 004, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 005, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 006, Train Acc: 0.8327, Test Acc: 0.8467\n", + "Epoch: 007, Train Acc: 0.8274, Test Acc: 0.8467\n", + "Epoch: 008, Train Acc: 0.8345, Test Acc: 0.8400\n", + "Epoch: 009, Train Acc: 0.8345, Test Acc: 0.8400\n" + ] + } + ], + "source": [ + "model = GCN(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data\n", + " ) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int(\n", + " (pred == data.y).sum()\n", + " ) # Check against ground-truth labels.\n", + " return correct / len(\n", + " loader.dataset\n", + " ) # Derive ratio of correct predictions.\n", + "\n", + "\n", + "for epoch in range(1, 10):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 0.3365, 2.6768, 3.0322, 0.2097, 0.4206, 2.7076, -3.6068, -3.1436,\n", + " -7.1983, 0.3847], grad_fn=)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(batch)[:,0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_orientability_TransformerConv.ipynb b/experiment_orientability_TransformerConv.ipynb new file mode 100644 index 0000000..e99ecea --- /dev/null +++ b/experiment_orientability_TransformerConv.ipynb @@ -0,0 +1,358 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 9\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\ernst\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'torsion_coefficients', 'betti_numbers', 'n_vertices', 'dimension', 'name', 'orientable', 'genus', 'face'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of classes: 2\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 9], y=[1])\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " OneHotDegree(max_degree=8,cat=False)\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 0.4272, 0.1329],\n", + " [ 0.2431, 0.0324],\n", + " [ 0.2296, 0.1512],\n", + " [ 0.0760, -0.0063],\n", + " [ 0.4862, 0.0779],\n", + " [ 0.3347, 0.1233],\n", + " [ 0.1409, 0.1253],\n", + " [ 0.2714, 0.2711],\n", + " [ 0.0957, 0.2088],\n", + " [ 0.2210, 0.0088]], grad_fn=)\n" + ] + } + ], + "source": [ + "from operator import concat\n", + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv, TAGConv,TransformerConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "\n", + "\n", + "class GCN(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super(GCN, self).__init__()\n", + " torch.manual_seed(12345)\n", + " self.conv1 = TransformerConv(dataset.num_node_features, hidden_channels)\n", + " self.conv2 = TransformerConv(hidden_channels, hidden_channels)\n", + " self.conv3 = TransformerConv(hidden_channels, hidden_channels,concat=False)\n", + " self.lin = Linear(hidden_channels, dataset.num_classes)\n", + "\n", + " def forward(self, batch):\n", + " # 1. Obtain node embeddings\n", + " x = self.conv1(batch.x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv2(x, batch.edge_index)\n", + " x = x.relu()\n", + " x = self.conv3(x, batch.edge_index)\n", + "\n", + " # 2. Readout layer\n", + " x = global_mean_pool(x, batch.batch) # [batch_size, hidden_channels]\n", + "\n", + " # 3. Apply a final classifier\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = self.lin(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "model = GCN(hidden_channels=64)\n", + "print(model(batch))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 002, Train Acc: 0.7509, Test Acc: 0.7600\n", + "Epoch: 003, Train Acc: 0.8060, Test Acc: 0.8133\n", + "Epoch: 004, Train Acc: 0.8256, Test Acc: 0.8267\n", + "Epoch: 005, Train Acc: 0.8310, Test Acc: 0.8333\n", + "Epoch: 006, Train Acc: 0.8310, Test Acc: 0.8333\n", + "Epoch: 007, Train Acc: 0.8310, Test Acc: 0.8333\n", + "Epoch: 008, Train Acc: 0.8310, Test Acc: 0.8333\n", + "Epoch: 009, Train Acc: 0.8310, Test Acc: 0.8333\n", + "Epoch: 010, Train Acc: 0.8310, Test Acc: 0.8267\n", + "Epoch: 011, Train Acc: 0.8310, Test Acc: 0.8267\n", + "Epoch: 012, Train Acc: 0.8310, Test Acc: 0.8333\n", + "Epoch: 013, Train Acc: 0.8310, Test Acc: 0.8200\n", + "Epoch: 014, Train Acc: 0.8381, Test Acc: 0.8200\n", + "Epoch: 015, Train Acc: 0.8452, Test Acc: 0.8133\n", + "Epoch: 016, Train Acc: 0.8488, Test Acc: 0.8400\n", + "Epoch: 017, Train Acc: 0.8505, Test Acc: 0.8400\n", + "Epoch: 018, Train Acc: 0.8488, Test Acc: 0.8400\n", + "Epoch: 019, Train Acc: 0.8470, Test Acc: 0.8400\n", + "Epoch: 020, Train Acc: 0.8470, Test Acc: 0.8400\n", + "Epoch: 021, Train Acc: 0.8523, Test Acc: 0.8467\n", + "Epoch: 022, Train Acc: 0.8559, Test Acc: 0.8533\n", + "Epoch: 023, Train Acc: 0.8559, Test Acc: 0.8400\n", + "Epoch: 024, Train Acc: 0.8577, Test Acc: 0.8467\n", + "Epoch: 025, Train Acc: 0.8665, Test Acc: 0.8533\n", + "Epoch: 026, Train Acc: 0.8737, Test Acc: 0.8467\n", + "Epoch: 027, Train Acc: 0.8701, Test Acc: 0.8533\n", + "Epoch: 028, Train Acc: 0.8701, Test Acc: 0.8600\n", + "Epoch: 029, Train Acc: 0.8737, Test Acc: 0.8533\n", + "Epoch: 030, Train Acc: 0.8719, Test Acc: 0.8533\n", + "Epoch: 031, Train Acc: 0.8737, Test Acc: 0.8533\n", + "Epoch: 032, Train Acc: 0.8665, Test Acc: 0.8467\n", + "Epoch: 033, Train Acc: 0.8772, Test Acc: 0.8533\n", + "Epoch: 034, Train Acc: 0.8754, Test Acc: 0.8400\n", + "Epoch: 035, Train Acc: 0.8772, Test Acc: 0.8600\n", + "Epoch: 036, Train Acc: 0.8737, Test Acc: 0.8467\n", + "Epoch: 037, Train Acc: 0.8665, Test Acc: 0.8533\n", + "Epoch: 038, Train Acc: 0.8790, Test Acc: 0.8400\n", + "Epoch: 039, Train Acc: 0.8754, Test Acc: 0.8533\n" + ] + } + ], + "source": [ + "model = GCN(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data\n", + " ) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int(\n", + " (pred == data.y).sum()\n", + " ) # Check against ground-truth labels.\n", + " return correct / len(\n", + " loader.dataset\n", + " ) # Derive ratio of correct predictions.\n", + "\n", + "\n", + "for epoch in range(1, 40):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_orientability_gnn.ipynb b/experiment_orientability_gnn.ipynb new file mode 100644 index 0000000..239bf67 --- /dev/null +++ b/experiment_orientability_gnn.ipynb @@ -0,0 +1,342 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader\n", + "from torch_geometric.transforms import FaceToEdge,OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 10\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\ernst\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\data\\storage.py:450: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'name', 'orientable', 'torsion_coefficients', 'n_vertices', 'face', 'dimension', 'genus', 'betti_numbers'}'. Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of classes: 2\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 10], y=[1])\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " OneHotDegree(max_degree=9,cat=False),\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "train_loader = DataLoader(train_dataset)\n", + "test_loader = DataLoader(test_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([1])\n" + ] + } + ], + "source": [ + "print(dataset[0].y)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GCN(\n", + " (conv1): GCNConv(10, 64)\n", + " (conv2): GCNConv(64, 64)\n", + " (conv3): GCNConv(64, 64)\n", + " (lin): Linear(in_features=64, out_features=2, bias=True)\n", + ")\n" + ] + } + ], + "source": [ + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "\n", + "\n", + "class GCN(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super(GCN, self).__init__()\n", + " torch.manual_seed(12345)\n", + " self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)\n", + " self.conv2 = GCNConv(hidden_channels, hidden_channels)\n", + " self.conv3 = GCNConv(hidden_channels, hidden_channels)\n", + " self.lin = Linear(hidden_channels, dataset.num_classes)\n", + "\n", + " def forward(self, x, edge_index, batch):\n", + " # 1. Obtain node embeddings\n", + " x = self.conv1(x, edge_index)\n", + " x = x.relu()\n", + " x = self.conv2(x, edge_index)\n", + " x = x.relu()\n", + " x = self.conv3(x, edge_index)\n", + "\n", + " # 2. Readout layer\n", + " x = global_mean_pool(x, batch) # [batch_size, hidden_channels]\n", + "\n", + " # 3. Apply a final classifier\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = self.lin(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "model = GCN(hidden_channels=64)\n", + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 0.7278, Test Acc: 0.7533\n", + "Epoch: 002, Train Acc: 0.7972, Test Acc: 0.8333\n", + "Epoch: 003, Train Acc: 0.8238, Test Acc: 0.8600\n", + "Epoch: 004, Train Acc: 0.8238, Test Acc: 0.8600\n", + "Epoch: 005, Train Acc: 0.8238, Test Acc: 0.8600\n", + "Epoch: 006, Train Acc: 0.8238, Test Acc: 0.8600\n", + "Epoch: 007, Train Acc: 0.8238, Test Acc: 0.8600\n", + "Epoch: 008, Train Acc: 0.8238, Test Acc: 0.8600\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[7], line 36\u001b[0m\n\u001b[0;32m 34\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m171\u001b[39m):\n\u001b[0;32m 35\u001b[0m train()\n\u001b[1;32m---> 36\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m \u001b[43mtest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 37\u001b[0m test_acc \u001b[38;5;241m=\u001b[39m test(test_loader)\n\u001b[0;32m 38\u001b[0m \u001b[38;5;28mprint\u001b[39m(\n\u001b[0;32m 39\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m03d\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Train Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Test Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 40\u001b[0m )\n", + "Cell \u001b[1;32mIn[7], line 24\u001b[0m, in \u001b[0;36mtest\u001b[1;34m(loader)\u001b[0m\n\u001b[0;32m 22\u001b[0m correct \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m 23\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m loader: \u001b[38;5;66;03m# Iterate in batches over the training/test dataset.\u001b[39;00m\n\u001b[1;32m---> 24\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 25\u001b[0m pred \u001b[38;5;241m=\u001b[39m out\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# Use the class with highest probability.\u001b[39;00m\n\u001b[0;32m 26\u001b[0m correct \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(\n\u001b[0;32m 27\u001b[0m (pred \u001b[38;5;241m==\u001b[39m data\u001b[38;5;241m.\u001b[39my)\u001b[38;5;241m.\u001b[39msum()\n\u001b[0;32m 28\u001b[0m ) \u001b[38;5;66;03m# Check against ground-truth labels.\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "Cell \u001b[1;32mIn[6], line 18\u001b[0m, in \u001b[0;36mGCN.forward\u001b[1;34m(self, x, edge_index, batch)\u001b[0m\n\u001b[0;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, edge_index, batch):\n\u001b[0;32m 17\u001b[0m \u001b[38;5;66;03m# 1. Obtain node embeddings\u001b[39;00m\n\u001b[1;32m---> 18\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 19\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrelu()\n\u001b[0;32m 20\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv2(x, edge_index)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "model = GCN(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(\n", + " data.x, data.edge_index, data.batch\n", + " ) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data.x, data.edge_index, data.batch)\n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int(\n", + " (pred == data.y).sum()\n", + " ) # Check against ground-truth labels.\n", + " return correct / len(\n", + " loader.dataset\n", + " ) # Derive ratio of correct predictions.\n", + "\n", + "\n", + "for epoch in range(1, 171):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_orientability_mlp.ipynb b/experiment_orientability_mlp.ipynb new file mode 100644 index 0000000..eda7c4d --- /dev/null +++ b/experiment_orientability_mlp.ipynb @@ -0,0 +1,387 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2024-04-25T17:47:32.798261Z", + "start_time": "2024-04-25T17:47:19.035622Z" + } + }, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader, ImbalancedSampler\n", + "from torch_geometric.transforms import FaceToEdge, OneHotDegree\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from mantra.transforms import (\n", + " TriangulationToFaceTransform,\n", + " OrientableToClassTransform,\n", + " DegreeTransform,\n", + ")\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 13\n", + "Number of classes: 2\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 13], y=[1])\n", + "=============================================================\n", + "Number of nodes: 4\n", + "Number of edges: 12\n", + "Average node degree: 3.00\n", + "Has isolated nodes: False\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " OneHotDegree(max_degree=12,cat=False)\n", + " ]\n", + ")\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\", transform=tr)\n", + "\n", + "\n", + "print()\n", + "print(f\"Dataset: {dataset}:\")\n", + "print(\"====================\")\n", + "print(f\"Number of graphs: {len(dataset)}\")\n", + "print(f\"Number of features: {dataset.num_features}\")\n", + "print(f\"Number of classes: {dataset.num_classes}\")\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + "print()\n", + "print(data)\n", + "print(\"=============================================================\")\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f\"Number of nodes: {len(data.x)}\")\n", + "print(f\"Number of edges: {data.num_edges}\")\n", + "print(f\"Average node degree: {data.num_edges / len(data.x):.2f}\")\n", + "print(f\"Has isolated nodes: {data.has_isolated_nodes()}\")\n", + "print(f\"Has self-loops: {data.has_self_loops()}\")\n", + "print(f\"Is undirected: {data.is_undirected()}\")\n", + "\n", + "print(\"=============================================================\")\n", + "print(f\"Number of orientable Manifolds: {sum(dataset.orientable)}\")\n", + "print(\n", + " f\"Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}\"\n", + ")\n", + "print(\n", + " f\"Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1,\n", + " 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1,\n", + " 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1,\n", + " 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,\n", + " 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,\n", + " 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,\n", + " 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,\n", + " 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0,\n", + " 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0,\n", + " 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.tensor([data.y for data in dataset])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 562\n", + "Number of test graphs: 150\n" + ] + } + ], + "source": [ + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:-150]\n", + "test_dataset = dataset[-150:]\n", + "\n", + "print(f\"Number of training graphs: {len(train_dataset)}\")\n", + "print(f\"Number of test graphs: {len(test_dataset)}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0, 0, 0, 0, 0, 0, 1, 0, 1, 0])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_loader = DataLoader(train_dataset,batch_size=10)#,sampler=ImbalancedSampler(train_dataset))\n", + "test_loader = DataLoader(test_dataset,batch_size=10)\n", + "\n", + "\n", + "for batch in train_loader:\n", + " break\n", + "\n", + "batch.y" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "import torch.nn as nn\n", + "from torch_geometric.nn import GCNConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "from torch_scatter import segment_coo\n", + "\n", + "class PermInvariant(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super().__init__()\n", + " # torch.manual_seed(12345)\n", + " self.classification = nn.Sequential( \n", + " nn.Linear(dataset.num_node_features,hidden_channels),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_channels,hidden_channels),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_channels,2),\n", + " nn.ReLU()\n", + " )\n", + "\n", + " def forward(self, batch):\n", + " x = self.classification(batch.x)\n", + " # print(batch.x)\n", + " # print(x)\n", + " return segment_coo(x,batch.batch,reduce=\"sum\")\n", + "\n", + "\n", + "model = PermInvariant(hidden_channels=64)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 002, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 003, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 004, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 005, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 006, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 007, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 008, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 009, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 010, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 011, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 012, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 013, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 014, Train Acc: 0.7278, Test Acc: 0.7333\n", + "Epoch: 015, Train Acc: 0.7278, Test Acc: 0.7333\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[14], line 34\u001b[0m\n\u001b[0;32m 28\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m correct \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mlen\u001b[39m(\n\u001b[0;32m 29\u001b[0m loader\u001b[38;5;241m.\u001b[39mdataset\n\u001b[0;32m 30\u001b[0m ) \u001b[38;5;66;03m# Derive ratio of correct predictions.\u001b[39;00m\n\u001b[0;32m 33\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m40\u001b[39m):\n\u001b[1;32m---> 34\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 35\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m test(train_loader)\n\u001b[0;32m 36\u001b[0m test_acc \u001b[38;5;241m=\u001b[39m test(test_loader)\n", + "Cell \u001b[1;32mIn[14], line 9\u001b[0m, in \u001b[0;36mtrain\u001b[1;34m()\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrain\u001b[39m():\n\u001b[0;32m 7\u001b[0m model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[1;32m----> 9\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m train_loader: \u001b[38;5;66;03m# Iterate in batches over the training dataset.\u001b[39;00m\n\u001b[0;32m 10\u001b[0m out \u001b[38;5;241m=\u001b[39m model(data\n\u001b[0;32m 11\u001b[0m ) \u001b[38;5;66;03m# Perform a single forward pass.\u001b[39;00m\n\u001b[0;32m 12\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(out, data\u001b[38;5;241m.\u001b[39my) \u001b[38;5;66;03m# Compute the loss.\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\utils\\data\\dataloader.py:630\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 627\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 628\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 630\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m 632\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[0;32m 633\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[0;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\utils\\data\\dataloader.py:674\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 672\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m 673\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 674\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m 675\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[0;32m 676\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:51\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:51\u001b[0m, in \u001b[0;36m\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[0;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[0;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\data\\dataset.py:290\u001b[0m, in \u001b[0;36mDataset.__getitem__\u001b[1;34m(self, idx)\u001b[0m\n\u001b[0;32m 285\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28misinstance\u001b[39m(idx, (\u001b[38;5;28mint\u001b[39m, np\u001b[38;5;241m.\u001b[39minteger))\n\u001b[0;32m 286\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m (\u001b[38;5;28misinstance\u001b[39m(idx, Tensor) \u001b[38;5;129;01mand\u001b[39;00m idx\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m 287\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m (\u001b[38;5;28misinstance\u001b[39m(idx, np\u001b[38;5;241m.\u001b[39mndarray) \u001b[38;5;129;01mand\u001b[39;00m np\u001b[38;5;241m.\u001b[39misscalar(idx))):\n\u001b[0;32m 289\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mindices()[idx])\n\u001b[1;32m--> 290\u001b[0m data \u001b[38;5;241m=\u001b[39m data \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 291\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n\u001b[0;32m 293\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torchvision\\transforms\\transforms.py:95\u001b[0m, in \u001b[0;36mCompose.__call__\u001b[1;34m(self, img)\u001b[0m\n\u001b[0;32m 93\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, img):\n\u001b[0;32m 94\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransforms:\n\u001b[1;32m---> 95\u001b[0m img \u001b[38;5;241m=\u001b[39m \u001b[43mt\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 96\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m img\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\transforms\\base_transform.py:32\u001b[0m, in \u001b[0;36mBaseTransform.__call__\u001b[1;34m(self, data)\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, data: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m 31\u001b[0m \u001b[38;5;66;03m# Shallow-copy the data so that we prevent in-place data modification.\u001b[39;00m\n\u001b[1;32m---> 32\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcopy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\transforms\\face_to_edge.py:26\u001b[0m, in \u001b[0;36mFaceToEdge.forward\u001b[1;34m(self, data)\u001b[0m\n\u001b[0;32m 24\u001b[0m face \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mface\n\u001b[0;32m 25\u001b[0m edge_index \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([face[:\u001b[38;5;241m2\u001b[39m], face[\u001b[38;5;241m1\u001b[39m:], face[::\u001b[38;5;241m2\u001b[39m]], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m---> 26\u001b[0m edge_index \u001b[38;5;241m=\u001b[39m \u001b[43mto_undirected\u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_nodes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_nodes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 28\u001b[0m data\u001b[38;5;241m.\u001b[39medge_index \u001b[38;5;241m=\u001b[39m edge_index\n\u001b[0;32m 29\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mremove_faces:\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\utils\\undirected.py:209\u001b[0m, in \u001b[0;36mto_undirected\u001b[1;34m(edge_index, edge_attr, num_nodes, reduce)\u001b[0m\n\u001b[0;32m 206\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(edge_attr, (\u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mtuple\u001b[39m)):\n\u001b[0;32m 207\u001b[0m edge_attr \u001b[38;5;241m=\u001b[39m [torch\u001b[38;5;241m.\u001b[39mcat([e, e], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m e \u001b[38;5;129;01min\u001b[39;00m edge_attr]\n\u001b[1;32m--> 209\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcoalesce\u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_nodes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreduce\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\utils\\_coalesce.py:175\u001b[0m, in \u001b[0;36mcoalesce\u001b[1;34m(edge_index, edge_attr, num_nodes, reduce, is_sorted, sort_by_row)\u001b[0m\n\u001b[0;32m 173\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(edge_index, Tensor):\n\u001b[0;32m 174\u001b[0m edge_index \u001b[38;5;241m=\u001b[39m edge_index[:, mask]\n\u001b[1;32m--> 175\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjit\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_scripting\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(edge_index, EdgeIndex):\n\u001b[0;32m 176\u001b[0m edge_index\u001b[38;5;241m.\u001b[39m_is_undirected \u001b[38;5;241m=\u001b[39m is_undirected\n\u001b[0;32m 177\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(edge_index, \u001b[38;5;28mtuple\u001b[39m):\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\_jit_internal.py:1109\u001b[0m, in \u001b[0;36mis_scripting\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1105\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m7\u001b[39m):\n\u001b[0;32m 1106\u001b[0m \u001b[38;5;28mglobals\u001b[39m()[\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBroadcastingList\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m BroadcastingList1\n\u001b[1;32m-> 1109\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mis_scripting\u001b[39m() \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mbool\u001b[39m:\n\u001b[0;32m 1110\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 1111\u001b[0m \u001b[38;5;124;03m Function that returns True when in compilation and False otherwise. This\u001b[39;00m\n\u001b[0;32m 1112\u001b[0m \u001b[38;5;124;03m is useful especially with the @unused decorator to leave code in your\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 1126\u001b[0m \u001b[38;5;124;03m return unsupported_linear_op(x)\u001b[39;00m\n\u001b[0;32m 1127\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m 1128\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "model = PermInvariant(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data\n", + " ) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data)\n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int(\n", + " (pred == data.y).sum()\n", + " ) # Check against ground-truth labels.\n", + " return correct / len(\n", + " loader.dataset\n", + " ) # Derive ratio of correct predictions.\n", + "\n", + "\n", + "for epoch in range(1, 40):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(\n", + " f\"Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(batch)[:,0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}