diff --git a/experiment_gnn.ipynb b/experiment_gnn.ipynb new file mode 100644 index 0000000..4dac831 --- /dev/null +++ b/experiment_gnn.ipynb @@ -0,0 +1,409 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from torch_geometric.loader import DataLoader\n", + "from torch_geometric.utils import degree\n", + "from torch_geometric.transforms import FaceToEdge\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from validation.validate_homology import validate_betti_numbers\n", + "\n", + "import torch\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset: SimplicialDataset(712):\n", + "====================\n", + "Number of graphs: 712\n", + "Number of features: 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\ernst\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\data\\storage.py:327: UserWarning: Unable to accurately infer 'num_nodes' from the attribute set '{'dimension', 'torsion_coefficients', 'name', 'n_vertices', 'betti_numbers', 'orientable', 'face', 'genus'}'. 
Please explicitly set 'num_nodes' as an attribute of 'data' to suppress this warning\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of classes: 2\n", + "\n", + "Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[5, 1], y=[1])\n", + "=============================================================\n", + "Number of nodes: 5\n", + "Number of edges: 12\n", + "Average node degree: 2.40\n", + "Has isolated nodes: True\n", + "Has self-loops: False\n", + "Is undirected: True\n", + "=============================================================\n", + "Number of orientable Manifolds: 193\n", + "Number of non-orientable Manifolds: 519\n", + "Percentage: 0.27, 0.73\n" + ] + } + ], + "source": [ + "class DegreeTransform(object):\n", + " def __call__(self, data):\n", + " deg = degree(data.edge_index[0], dtype=torch.float)\n", + " data.x = deg.view(-1, 1)\n", + " return data\n", + "\n", + "class TriangulationToFaceTransform(object):\n", + " def __call__(self, data):\n", + " data.face = torch.tensor(data.triangulation).T - 1\n", + " data.triangulation = None\n", + " return data\n", + "\n", + "class OrientableToClassTransform(object):\n", + " def __call__(self, data):\n", + " data.y = data.orientable.long()\n", + " return data\n", + "\n", + "\n", + "\n", + "tr = transforms.Compose(\n", + " [TriangulationToFaceTransform(),FaceToEdge(remove_faces=False),DegreeTransform(),OrientableToClassTransform()]\n", + " )\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\",transform=tr)\n", + "\n", + "\n", + "print()\n", + "print(f'Dataset: {dataset}:')\n", + "print('====================')\n", + "print(f'Number of graphs: {len(dataset)}')\n", + "print(f'Number of features: {dataset.num_features}')\n", + "print(f'Number of classes: {dataset.num_classes}')\n", + "\n", + "data = dataset[0] # Get the first graph object.\n", + "\n", + 
"print()\n", + "print(data)\n", + "print('=============================================================')\n", + "\n", + "# Gather some statistics about the first graph.\n", + "print(f'Number of nodes: {len(data.x)}')\n", + "print(f'Number of edges: {data.num_edges}')\n", + "print(f'Average node degree: {data.num_edges / len(data.x):.2f}')\n", + "print(f'Has isolated nodes: {data.has_isolated_nodes()}')\n", + "print(f'Has self-loops: {data.has_self_loops()}')\n", + "print(f'Is undirected: {data.is_undirected()}')\n", + "\n", + "print('=============================================================')\n", + "print(f'Number of orientable Manifolds: {sum(dataset.orientable)}')\n", + "print(f'Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}')\n", + "print(f'Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}')" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of training graphs: 150\n", + "Number of test graphs: 562\n" + ] + } + ], + "source": [ + "\n", + "\n", + "dataset = dataset.shuffle()\n", + "\n", + "train_dataset = dataset[:150]\n", + "test_dataset = dataset[150:]\n", + "\n", + "print(f'Number of training graphs: {len(train_dataset)}')\n", + "print(f'Number of test graphs: {len(test_dataset)}')" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [], + "source": [ + "train_loader = DataLoader(train_dataset)\n", + "test_loader = DataLoader(test_dataset)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GCN(\n", + " (conv1): GCNConv(1, 64)\n", + " (conv2): GCNConv(64, 64)\n", + " (conv3): GCNConv(64, 64)\n", + " (lin): Linear(in_features=64, out_features=2, bias=True)\n", + ")\n" + ] + } + ], + 
"source": [ + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "\n", + "\n", + "class GCN(torch.nn.Module):\n", + " def __init__(self, hidden_channels):\n", + " super(GCN, self).__init__()\n", + " torch.manual_seed(12345)\n", + " self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)\n", + " self.conv2 = GCNConv(hidden_channels, hidden_channels)\n", + " self.conv3 = GCNConv(hidden_channels, hidden_channels)\n", + " self.lin = Linear(hidden_channels, dataset.num_classes)\n", + "\n", + " def forward(self, x, edge_index, batch):\n", + " # 1. Obtain node embeddings \n", + " x = self.conv1(x, edge_index)\n", + " x = x.relu()\n", + " x = self.conv2(x, edge_index)\n", + " x = x.relu()\n", + " x = self.conv3(x, edge_index)\n", + "\n", + " # 2. Readout layer\n", + " x = global_mean_pool(x, batch) # [batch_size, hidden_channels]\n", + "\n", + " # 3. Apply a final classifier\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = self.lin(x)\n", + " \n", + " return x\n", + "\n", + "model = GCN(hidden_channels=64)\n", + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 002, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 003, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 004, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 005, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 006, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 007, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 008, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 009, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 010, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 011, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 012, Train Acc: 0.7333, Test Acc: 
0.7278\n", + "Epoch: 013, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 014, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 015, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 016, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 017, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 018, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 019, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 020, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 021, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 022, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 023, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 024, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 025, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 026, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 027, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 028, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 029, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 030, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 031, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 032, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 033, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 034, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 035, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 036, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 037, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 038, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 039, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 040, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 041, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 042, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 043, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 044, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 045, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 046, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 047, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 048, Train Acc: 0.7333, Test Acc: 0.7278\n", + 
"Epoch: 049, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 050, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 051, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 052, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 053, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 054, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 055, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 056, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 057, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 058, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 059, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 060, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 061, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 062, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 063, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 064, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 065, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 066, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 067, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 068, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 069, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 070, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 071, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 072, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 073, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 074, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 075, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 076, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 077, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 078, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 079, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 080, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 081, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 082, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 083, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 084, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 085, Train 
Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 086, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 087, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 088, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 089, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 090, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 091, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 092, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 093, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 094, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 095, Train Acc: 0.7333, Test Acc: 0.7278\n", + "Epoch: 096, Train Acc: 0.7333, Test Acc: 0.7278\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[77], line 29\u001b[0m\n\u001b[0;32m 27\u001b[0m train()\n\u001b[0;32m 28\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m test(train_loader)\n\u001b[1;32m---> 29\u001b[0m test_acc \u001b[38;5;241m=\u001b[39m \u001b[43mtest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_loader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 30\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mEpoch: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m03d\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Train Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Test Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n", + "Cell \u001b[1;32mIn[77], line 20\u001b[0m, in 
\u001b[0;36mtest\u001b[1;34m(loader)\u001b[0m\n\u001b[0;32m 18\u001b[0m correct \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m loader: \u001b[38;5;66;03m# Iterate in batches over the training/test dataset.\u001b[39;00m\n\u001b[1;32m---> 20\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m \n\u001b[0;32m 21\u001b[0m pred \u001b[38;5;241m=\u001b[39m out\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# Use the class with highest probability.\u001b[39;00m\n\u001b[0;32m 22\u001b[0m correct \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m((pred \u001b[38;5;241m==\u001b[39m data\u001b[38;5;241m.\u001b[39my)\u001b[38;5;241m.\u001b[39msum()) \u001b[38;5;66;03m# Check against ground-truth labels.\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "Cell \u001b[1;32mIn[76], line 18\u001b[0m, in \u001b[0;36mGCN.forward\u001b[1;34m(self, x, edge_index, batch)\u001b[0m\n\u001b[0;32m 16\u001b[0m 
\u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, edge_index, batch):\n\u001b[0;32m 17\u001b[0m \u001b[38;5;66;03m# 1. Obtain node embeddings \u001b[39;00m\n\u001b[1;32m---> 18\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 19\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrelu()\n\u001b[0;32m 20\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv2(x, edge_index)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic 
in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\gcn_conv.py:222\u001b[0m, in \u001b[0;36mGCNConv.forward\u001b[1;34m(self, x, edge_index, edge_weight)\u001b[0m\n\u001b[0;32m 220\u001b[0m cache \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cached_edge_index\n\u001b[0;32m 221\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cache \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m--> 222\u001b[0m edge_index, edge_weight \u001b[38;5;241m=\u001b[39m \u001b[43mgcn_norm\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# 
yapf: disable\u001b[39;49;00m\n\u001b[0;32m 223\u001b[0m \u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode_dim\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 224\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimproved\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madd_self_loops\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mflow\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 225\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcached:\n\u001b[0;32m 226\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_cached_edge_index \u001b[38;5;241m=\u001b[39m (edge_index, edge_weight)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\gcn_conv.py:91\u001b[0m, in \u001b[0;36mgcn_norm\u001b[1;34m(edge_index, edge_weight, num_nodes, improved, add_self_loops, flow, dtype)\u001b[0m\n\u001b[0;32m 88\u001b[0m num_nodes \u001b[38;5;241m=\u001b[39m maybe_num_nodes(edge_index, num_nodes)\n\u001b[0;32m 90\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m add_self_loops:\n\u001b[1;32m---> 91\u001b[0m edge_index, edge_weight \u001b[38;5;241m=\u001b[39m \u001b[43madd_remaining_self_loops\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 92\u001b[0m 
\u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_weight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_nodes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 94\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m edge_weight \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 95\u001b[0m edge_weight \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mones((edge_index\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m1\u001b[39m), ), dtype\u001b[38;5;241m=\u001b[39mdtype,\n\u001b[0;32m 96\u001b[0m device\u001b[38;5;241m=\u001b[39medge_index\u001b[38;5;241m.\u001b[39mdevice)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\utils\\loop.py:342\u001b[0m, in \u001b[0;36madd_remaining_self_loops\u001b[1;34m(edge_index, edge_attr, fill_value, num_nodes)\u001b[0m\n\u001b[0;32m 339\u001b[0m N \u001b[38;5;241m=\u001b[39m maybe_num_nodes(edge_index, num_nodes)\n\u001b[0;32m 340\u001b[0m mask \u001b[38;5;241m=\u001b[39m edge_index[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m!=\u001b[39m edge_index[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m--> 342\u001b[0m loop_index \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marange\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mN\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlong\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 343\u001b[0m loop_index \u001b[38;5;241m=\u001b[39m loop_index\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mrepeat(\u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m 345\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m edge_attr \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "\u001b[1;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "model = GCN(hidden_channels=64)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "def train():\n", + " model.train()\n", + "\n", + " for data in train_loader: # Iterate in batches over the training dataset.\n", + " out = model(data.x, data.edge_index, data.batch) # Perform a single forward pass.\n", + " loss = criterion(out, data.y) # Compute the loss.\n", + " loss.backward() # Derive gradients.\n", + " optimizer.step() # Update parameters based on gradients.\n", + " optimizer.zero_grad() # Clear gradients.\n", + "\n", + "def test(loader):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader: # Iterate in batches over the training/test dataset.\n", + " out = model(data.x, data.edge_index, data.batch) \n", + " pred = out.argmax(dim=1) # Use the class with highest probability.\n", + " correct += int((pred == data.y).sum()) # Check against ground-truth labels.\n", + " return correct / len(loader.dataset) # Derive ratio of correct predictions.\n", + "\n", + "\n", + "for epoch in range(1, 171):\n", + " train()\n", + " train_acc = test(train_loader)\n", + " test_acc = test(test_loader)\n", + " print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiment_simplex2vec.ipynb b/experiment_simplex2vec.ipynb new file mode 100644 index 0000000..5ee9b6d --- /dev/null +++ b/experiment_simplex2vec.ipynb @@ -0,0 +1,444 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "import gudhi\n", + "\n", + "from torch_geometric.loader import DataLoader\n", + "from torch_geometric.utils import degree\n", + "from torch_geometric.transforms import FaceToEdge, ToUndirected\n", + "import torchvision.transforms as transforms\n", + "\n", + "from mantra.simplicial import SimplicialDataset\n", + "from validation.validate_homology import validate_betti_numbers\n", + "from torch.nn import Linear\n", + "import torch.nn.functional as F\n", + "from torch_geometric.nn import GCNConv, GATConv\n", + "from torch_geometric.nn import global_mean_pool\n", + "\n", + "\n", + "import k_simplex2vec as ks2v\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "int(True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"Data(dimension=[1], n_vertices=[1], torsion_coefficients=[3], betti_numbers=[3], orientable=[1], genus=[1], name='S^2', face=[3, 4], edge_index=[2, 12], x=[4, 1], y=[1], edge_attr=[12, 20], num_nodes=4)\n" + ] + } + ], + "source": [ + "class DegreeTransform(object):\n", + " def __call__(self, data):\n", + " deg = degree(data.edge_index[0], dtype=torch.float)\n", + " data.x = deg.view(-1, 1)\n", + " return data\n", + "\n", + "class TriangulationToFaceTransform(object):\n", + " def __call__(self, data):\n", + " data.face = torch.tensor(data.triangulation).T - 1\n", + " data.triangulation = None\n", + " return data\n", + "\n", + "class OrientableToClassTransform(object):\n", + " def __call__(self, data):\n", + " data.y = int(data.orientable)\n", + " return data\n", + "\n", + "class SetNumNodes:\n", + " def __call__(self,data):\n", + " data.num_nodes = data.n_vertices\n", + " return data\n", + "\n", + "\n", + "class Simplex2Vec:\n", + " def __call__(self,data):\n", + " st = gudhi.SimplexTree()\n", + "\n", + " ei = [[edge[0],edge[1]] for edge in data.edge_index.T.tolist() if edge[0] < edge[1]]\n", + " data.edge_index = torch.tensor(ei).T\n", + " # Say hi to bad programming\n", + " for edge in ei:\n", + " st.insert(edge)\n", + " st.expansion(3)\n", + "\n", + " p1 = ks2v.assemble(cplx =st, k= 1, scheme = \"uniform\", laziness =None)\n", + " P1 = p1.toarray()\n", + "\n", + " Simplices = list()\n", + " for simplex in st.get_filtration():\n", + " if simplex[1]!= np.inf:\n", + " Simplices.append(simplex[0])\n", + " else: \n", + " break \n", + "\n", + " ## Perform random walks on the edges\n", + " L = 20\n", + " N = 40\n", + " Walks = ks2v.RandomWalks(walk_length=L, number_walks=N, P=P1,seed = 3)\n", + " # to save the walks in a text file \n", + " ks2v.save_random_walks(Walks,'RandomWalks_Edges.txt')\n", + "\n", + " ## Embed the edges \n", + " Emb = ks2v.Embedding(Walks = Walks, emb_dim = 20 , epochs = 5 ,filename ='k-simplex2vec_Edge_embedding.model')\n", + " data.edge_attr 
= torch.tensor(Emb.wv.vectors)\n", + " toundirected = ToUndirected()\n", + " data = toundirected(data)\n", + " return data\n", + "\n", + "\n", + "\n", + "tr = transforms.Compose(\n", + " [\n", + " TriangulationToFaceTransform(),\n", + " SetNumNodes(),\n", + " FaceToEdge(remove_faces=False),\n", + " DegreeTransform(),\n", + " OrientableToClassTransform(),\n", + " Simplex2Vec(),\n", + " ]\n", + " )\n", + "\n", + "dataset = SimplicialDataset(root=\"./data\",pre_transform=tr)\n", + "\n", + "\n", + "print(dataset[0])\n", + "\n", + "# print()\n", + "# print(f'Dataset: {dataset}:')\n", + "# print('====================')\n", + "# print(f'Number of graphs: {len(dataset)}')\n", + "# print(f'Number of features: {dataset.num_features}')\n", + "# print(f'Number of classes: {dataset.num_classes}')\n", + "\n", + "# data = dataset[0] # Get the first graph object.\n", + "\n", + "# print()\n", + "# print(data)\n", + "# print('=============================================================')\n", + "\n", + "# # Gather some statistics about the first graph.\n", + "# print(f'Number of nodes: {len(data.x)}')\n", + "# print(f'Number of edges: {data.num_edges}')\n", + "# print(f'Average node degree: {data.num_edges / len(data.x):.2f}')\n", + "# print(f'Has isolated nodes: {data.has_isolated_nodes()}')\n", + "# print(f'Has self-loops: {data.has_self_loops()}')\n", + "# print(f'Is undirected: {data.is_undirected()}')\n", + "\n", + "# print('=============================================================')\n", + "# print(f'Number of orientable Manifolds: {sum(dataset.orientable)}')\n", + "# print(f'Number of non-orientable Manifolds: {len(dataset) - sum(dataset.orientable)}')\n", + "# print(f'Percentage: {sum(dataset.orientable) / len(dataset):.2f}, {(len(dataset) - sum(dataset.orientable)) / len(dataset):.2f}')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0, 0, 0, 1],\n", + " [1, 1, 2, 2],\n", + " [2, 
class GCN(torch.nn.Module):
    """Graph classifier: three attention convolutions, mean-pool readout,
    and a linear head producing 2 class logits.

    NOTE(review): despite the name, the layers are GATConv, not GCNConv —
    presumably a leftover from an earlier experiment; confirm before renaming.
    """

    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)  # fixed seed so weight init is reproducible
        # Attribute names kept as conv1/conv2/conv3/lin so print(model)
        # renders the same module tree.
        self.conv1 = GATConv(1, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.conv3 = GATConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 2)

    def forward(self, x, edge_index, edge_attr, batch):
        # Node embeddings: three attention convs, ReLU between the first two
        # pairs, last conv left un-activated before pooling.
        h = F.relu(self.conv1(x, edge_index, edge_attr))
        h = F.relu(self.conv2(h, edge_index, edge_attr))
        h = self.conv3(h, edge_index, edge_attr)

        # Per-graph readout: [num_nodes, hidden] -> [batch_size, hidden].
        h = global_mean_pool(h, batch)

        # Classifier head with dropout (active only in training mode).
        h = F.dropout(h, p=0.5, training=self.training)
        return self.lin(h)

model = GCN(hidden_channels=64)
print(model)
Acc: 0.7509\n", + "Epoch: 025, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 026, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 027, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 028, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 029, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 030, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 031, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 032, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 033, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 034, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 035, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 036, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 037, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 038, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 039, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 040, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 041, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 042, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 043, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 044, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 045, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 046, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 047, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 048, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 049, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 050, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 051, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 052, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 053, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 054, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 055, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 056, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 057, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 058, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 059, Train Acc: 0.6467, Test Acc: 0.7509\n", + "Epoch: 060, Train Acc: 0.6467, Test Acc: 0.7509\n" + ] 
+ }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[10], line 29\u001b[0m\n\u001b[0;32m 27\u001b[0m train()\n\u001b[0;32m 28\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m test(train_loader)\n\u001b[1;32m---> 29\u001b[0m test_acc \u001b[38;5;241m=\u001b[39m \u001b[43mtest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_loader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 30\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mEpoch: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m03d\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Train Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Test Acc: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_acc\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n", + "Cell \u001b[1;32mIn[10], line 20\u001b[0m, in \u001b[0;36mtest\u001b[1;34m(loader)\u001b[0m\n\u001b[0;32m 18\u001b[0m correct \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m 19\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data \u001b[38;5;129;01min\u001b[39;00m loader: \u001b[38;5;66;03m# Iterate in batches over the training/test dataset.\u001b[39;00m\n\u001b[1;32m---> 20\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medge_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m \n\u001b[0;32m 21\u001b[0m pred \u001b[38;5;241m=\u001b[39m out\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;66;03m# Use the class with highest probability.\u001b[39;00m\n\u001b[0;32m 22\u001b[0m correct \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m((pred \u001b[38;5;241m==\u001b[39m data\u001b[38;5;241m.\u001b[39my)\u001b[38;5;241m.\u001b[39msum()) \u001b[38;5;66;03m# Check against ground-truth labels.\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the 
logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "Cell \u001b[1;32mIn[7], line 14\u001b[0m, in \u001b[0;36mGCN.forward\u001b[1;34m(self, x, edge_index, edge_attr, batch)\u001b[0m\n\u001b[0;32m 12\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv1(x, edge_index,edge_attr)\n\u001b[0;32m 13\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrelu()\n\u001b[1;32m---> 14\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv2\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43medge_attr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 15\u001b[0m x 
\u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mrelu()\n\u001b[0;32m 16\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv3(x, edge_index,edge_attr)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks 
\u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\nn\\conv\\gat_conv.py:239\u001b[0m, in \u001b[0;36mGATConv.forward\u001b[1;34m(self, x, edge_index, edge_attr, size, return_attention_weights)\u001b[0m\n\u001b[0;32m 236\u001b[0m num_nodes \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmin\u001b[39m(size) \u001b[38;5;28;01mif\u001b[39;00m size \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m num_nodes\n\u001b[0;32m 237\u001b[0m edge_index, edge_attr \u001b[38;5;241m=\u001b[39m remove_self_loops(\n\u001b[0;32m 238\u001b[0m edge_index, edge_attr)\n\u001b[1;32m--> 239\u001b[0m edge_index, edge_attr \u001b[38;5;241m=\u001b[39m \u001b[43madd_self_loops\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 240\u001b[0m \u001b[43m \u001b[49m\u001b[43medge_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medge_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mfill_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfill_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 241\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_nodes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_nodes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 242\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(edge_index, SparseTensor):\n\u001b[0;32m 243\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39medge_dim \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[1;32m~\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torch_geometric\\utils\\loop.py:263\u001b[0m, in \u001b[0;36madd_self_loops\u001b[1;34m(edge_index, edge_attr, fill_value, num_nodes)\u001b[0m\n\u001b[0;32m 261\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(fill_value, \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m 262\u001b[0m col \u001b[38;5;241m=\u001b[39m edge_index[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m is_sparse \u001b[38;5;28;01melse\u001b[39;00m edge_index[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m--> 263\u001b[0m loop_attr \u001b[38;5;241m=\u001b[39m \u001b[43mscatter\u001b[49m\u001b[43m(\u001b[49m\u001b[43medge_attr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mN\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 264\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 265\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m 
model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss()

def train():
    """Run one epoch of optimization over ``train_loader``."""
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
        optimizer.zero_grad()  # Clear gradients before accumulating new ones.
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)  # Forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

def test(loader):
    """Return the classification accuracy of ``model`` over ``loader``."""
    model.eval()

    correct = 0
    # FIX: evaluate under torch.no_grad() — the original built an autograd
    # graph for every eval batch, wasting memory and time for nothing.
    with torch.no_grad():
        for data in loader:  # Iterate in batches over the given dataset.
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            pred = out.argmax(dim=1)  # Use the class with highest probability.
            correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
"file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}