diff --git a/GCN/GCN.fsproj b/GCN/GCN.fsproj index 95024f2..b2ce6e9 100644 --- a/GCN/GCN.fsproj +++ b/GCN/GCN.fsproj @@ -6,8 +6,8 @@ - + diff --git a/GCN/GCNModel.fs b/GCN/GCNModel.fs index 6918688..16d674d 100644 --- a/GCN/GCNModel.fs +++ b/GCN/GCNModel.fs @@ -5,6 +5,7 @@ open type TorchSharp.NN.Modules open TorchSharp.Fun let inline (!>) (x:^a) : ^b = ((^a or ^b) : (static member op_Implicit : ^a -> ^b) x) +///single graph convolutional layer let gcnLayer in_features out_features hasBias (adj:TorchTensor) = let weight = Parameter(randName(),Float32Tensor.empty([|in_features; out_features|],requiresGrad=true)) let bias = if hasBias then Parameter(randName(),Float32Tensor.empty([|out_features|],requiresGrad=true)) |> Some else None @@ -23,7 +24,7 @@ let gcnLayer in_features out_features hasBias (adj:TorchTensor) = let create nfeat nhid nclass dropout adj = let gc1 = gcnLayer nfeat nhid true adj let gc2 = gcnLayer nhid nclass true adj - let drp = if dropout > 0.0 then Dropout(dropout) |> M else Model.nop + let drp = Dropout(dropout) fwd3 gc1 gc2 drp (fun t g1 g2 drp -> use t = gc1.forward(t) diff --git a/GCN/Program.fs b/GCN/Program.fs index 6e401a1..52f6439 100644 --- a/GCN/Program.fs +++ b/GCN/Program.fs @@ -57,9 +57,6 @@ module Defs = [] let main args = let runParms = Defs.parse args - try - Train.run runParms - with ex -> - printfn "%s" ex.Message + Train.run runParms System.Console.ReadLine() |> ignore 0 diff --git a/GCN/Train.fs b/GCN/Train.fs index 7c09737..38700e0 100644 --- a/GCN/Train.fs +++ b/GCN/Train.fs @@ -22,34 +22,38 @@ let run (datafolder,no_cuda,fastmode,epochs,dropout,lr,hidden,seed,weight_decay) let loss = NN.Functions.nll_loss() if cuda then - model.Module.cuda() |> ignore + model.Module.cuda() |> ignore let optimizer = NN.Optimizer.Adam(model.Module.parameters(), learningRate = lr, weight_decay=weight_decay) let train epoch = let t = DateTime.Now - model.Module.Train() - let parms = model.Module.parameters() + model.Module.Train() 
optimizer.zero_grad() let output = model.forward(features) let loss_train = loss.Invoke(output.[ idx_train], labels.[idx_train]) - let ls = float loss_train let acc_train = Utils.accuracy(output.[idx_train], labels.[idx_train]) - printfn $"training - loss: {ls}, acc: {acc_train}" loss_train.backward() optimizer.step() - if fastmode then - model.Module.Eval() - let y' = model.forward(features) - let loss_val = loss.Invoke(y'.[idx_val], labels.[idx_val]) - let acc_val = Utils.accuracy(y'.[idx_val], labels.[idx_val]) - printfn - $""" - Epoch: {epoch}, loss_train: {float loss_train}, - acc_train: {acc_train}, loss_val: {float loss_val}, - acc_val: {acc_val}, time: {t} - """ + let parms = model.Module.parameters() + let data = parms |> Array.map TorchSharp.Fun.Tensor.getData + let i = 1 + + let loss_val,acc_val = + if not fastmode then + model.Module.Eval() + let y' = model.forward(features) + let loss_val = loss.Invoke(y'.[idx_val], labels.[idx_val]) + let acc_val = Utils.accuracy(y'.[idx_val], labels.[idx_val]) + loss_val,acc_val + else + let loss_val = loss.Invoke(output.[idx_val], labels.[idx_val]) + let acc_val = Utils.accuracy(output.[idx_val], labels.[idx_val]) + loss_val,acc_val + + printf $"Epoch: {epoch}, loss_train: %0.4f{float loss_train}, acc_train: %0.4f{acc_train}, " + printfn $"loss_val: %0.4f{float loss_val}, acc_val: %0.4f{acc_val}" let test() = model.Module.Eval() @@ -62,10 +66,9 @@ let run (datafolder,no_cuda,fastmode,epochs,dropout,lr,hidden,seed,weight_decay) let t_total = DateTime.Now for i in 1 .. 
epochs-1 do - printfn $"epoch {i}" train i printfn "Optimization done" - printfn $"Time elapsed: {(DateTime.Now - t_total).TotalMinutes} minutes" + printfn $"Time elapsed: %0.2f{(DateTime.Now - t_total).TotalSeconds} seconds" test() diff --git a/GCN/Utils.fs b/GCN/Utils.fs index 8adac97..a7a3fe1 100644 --- a/GCN/Utils.fs +++ b/GCN/Utils.fs @@ -30,13 +30,13 @@ let normalize (m:Matrix) = let sparse_mx_to_torch_sparse_tensor (m:Matrix) = let coo = m.EnumerateIndexed(Zeros.AllowSkip) - let rows = coo |> Seq.map (fun (r,c,v) -> int64 r) + let rows = coo |> Seq.map (fun (r,c,v) -> int64 r) let cols = coo |> Seq.map (fun (r,c,v) -> int64 c) + let vals = coo |> Seq.map (fun (r,c,v) -> v) let idxs = Seq.append rows cols |> Seq.toArray - let idx1 = idxs |> Int64Tensor.from |> fun x -> x.view(2L,-1L) - let vals = coo |> Seq.map(fun (r,c,v) -> v) |> Seq.toArray |> Float32Tensor.from - let t = Float32Tensor.sparse(idx1,vals,[|int64 m.RowCount; int64 m.ColumnCount|]) - let dt = TorchSharp.Fun.Tensor.getData(t.to_dense()) + let idxT = idxs |> Int64Tensor.from |> fun x -> x.view(2L, idxs.Length / 2 |> int64) + let valsT = vals |> Seq.toArray |> Float32Tensor.from + let t = Float32Tensor.sparse(idxT,valsT,[|int64 m.RowCount; int64 m.ColumnCount|]) t let accuracy(output:TorchTensor, labels:TorchTensor) = @@ -64,27 +64,26 @@ let loadData (dataFolder:string) dataset = Label = xs.[xs.Length-1] |}) - let edges = + let idx_map = dataFeatures |> Seq.mapi (fun i x-> x.Id,i) |> Map.ofSeq + + let edges_unordered = edgesFile |> File.ReadLines |> Seq.map (fun x->x.Split('\t')) |> Seq.map (fun xs -> xs.[0],xs.[1]) |> Seq.toArray - let edgeIdx = - edges - |> Seq.collect (fun (a,b)->[a;b]) - |> Seq.distinct - |> Seq.mapi (fun i x->x,i) - |> dict + let edges = + edges_unordered + |> Array.map (fun (a,b) -> idx_map.[a],idx_map.[b]) - let ftrs = Matrix.Build.DenseOfRows(dataFeatures |> Seq.map (fun x->Array.toSeq x.Features)) + let ftrs = Matrix.Build.SparseOfRowArrays(dataFeatures |> Seq.map (fun 
x-> x.Features) |> Seq.toArray) let graph = Matrix.Build.SparseFromCoordinateFormat ( - edgeIdx.Count, edgeIdx.Count, edges.Length, //rows,cols,num vals - edges |> Array.map (fun x -> edgeIdx.[fst x]), //hot row idx - edges |> Array.map (fun x -> edgeIdx.[snd x]), //hot col idx + idx_map.Count, idx_map.Count, edges.Length, //rows,cols,num vals + edges |> Array.map fst, //hot row idx + edges |> Array.map snd, //hot col idx edges |> Array.map (fun _ -> 1.0f) //values ) diff --git a/GCN/scripts/gcn.fsx b/GCN/scripts/gcn.fsx index 32d3715..91692f7 100644 --- a/GCN/scripts/gcn.fsx +++ b/GCN/scripts/gcn.fsx @@ -1,100 +1,24 @@ #load "packages.fsx" -open System -open System.IO -open MathNet.Numerics.LinearAlgebra +#load "../Utils.fs" +open TorchSharp.Fun -let dataFolder = @"C:\Users\fwaris\Downloads\pygcn-master\data\cora" -let contentFile = $"{dataFolder}/cora.content" -let citesFile = $"{dataFolder}/cora.cites" -let yourself x = x +let datafolder = @"C:\s\Repos\gcn\data\cora" +let adj, features, labels, idx_train, idx_val, idx_test = Utils.loadData datafolder None -let dataCntnt = - contentFile - |> File.ReadLines - |> Seq.map(fun x -> x.Split('\t')) - |> Seq.map(fun xs -> - {| - Id = xs.[0] - Features = xs.[1 .. 
xs.Length-2] |> Array.map float32 - Label = xs.[xs.Length-1] - |}) +let v1 = adj.[0L,50L] |> float -let dataCites = - citesFile - |> File.ReadLines - |> Seq.map (fun x->x.Split('\t')) - |> Seq.map (fun xs -> xs.[0],xs.[1]) - |> Seq.toArray +let idx = adj.SparseIndices |> Tensor.getData +let rc = idx |> Array.chunkBySize (idx.Length/2) +let vals = adj.SparseValues |> Tensor.getData -let citationIdx = dataCites |> Seq.collect (fun (a,b)->[a;b]) |> Seq.distinct |> Seq.mapi (fun i x->x,i) |> dict +let i = 500 +let r,c = rc.[0].[i],rc.[1].[i] +let vx = adj.[r,c] |> float -let ftrs = Matrix.Build.DenseOfRows(dataCntnt |> Seq.map (fun x->Array.toSeq x.Features)) +let df = features |> Tensor.getData |> Array.chunkBySize (int features.shape.[1]) -let graph = Matrix.Build.SparseFromCoordinateFormat - ( - dataCites.Length, dataCites.Length, dataCites.Length, - dataCites |> Array.map (fun x -> citationIdx.[fst x]), - dataCites |> Array.map (fun x -> citationIdx.[snd x]), - dataCites |> Array.map (fun _ -> 1.0f) - ) +let f1 = features.[1L,12L] |> float -let normalize (m:Matrix) = - let rowsum = m.RowSums() - let r_inv = rowsum.PointwisePower(-1.0f) - let r_inv = r_inv.Map(fun x-> if Single.IsInfinity x then 0.0f else x) - let r_mat_inv = Matrix.Build.SparseOfDiagonalVector(r_inv) - let mx = r_mat_inv.Multiply(m) - mx - -let graph_n = Matrix.Build.SparseIdentity(graph.RowCount) + graph |> normalize -let ftrs_n = normalize ftrs - -open TorchSharp.Tensor -let sparse_mx_to_torch_sparse_tensor (m:Matrix) = - let coo = m.EnumerateIndexed(Zeros.AllowSkip) - let rows = coo |> Seq.map (fun (r,c,v) -> int64 r) - let cols = coo |> Seq.map (fun (r,c,v) -> int64 c) - let idxs = Seq.append rows cols |> Seq.toArray - let idx1 = idxs |> Int64Tensor.from |> fun x -> x.view(2L,-1L) - let vals = coo |> Seq.map(fun (r,c,v) -> v) |> Seq.toArray |> Float32Tensor.from - Float32Tensor.sparse(idx1,vals,[|int64 m.RowCount; int64 m.ColumnCount|]) - -let adj = sparse_mx_to_torch_sparse_tensor(graph_n) - 
-module GCNModel = - open TorchSharp.Tensor - open TorchSharp.NN - open type TorchSharp.NN.Modules - open TorchSharp.Fun - let inline (!>) (x:^a) : ^b = ((^a or ^b) : (static member op_Implicit : ^a -> ^b) x) - - let gcnLayer in_features out_features hasBias (adj:TorchTensor) = - let weight = Parameter(randName(),Float32Tensor.empty([|in_features; out_features|])) - let bias = if hasBias then Parameter(randName(),Float32Tensor.empty([|out_features|])) |> Some else None - let parms = [| yield weight; if hasBias then yield bias.Value|] - Init.kaiming_uniform(weight.Tensor) |> ignore - - Model.create(parms,fun wts t -> - let support = t.mm(wts.[0]) - let output = adj.mm(support) - if hasBias then - output.add(wts.[1]) - else - output) - - let create nfeat nhid nclass dropout adj = - let gc1 = gcnLayer nfeat nhid true adj - let gc2 = gcnLayer nhid nclass true adj - let relu = ReLU() - let logm = LogSoftmax(1L) - let drp = if dropout then Dropout() |> M else Model.nop - fwd3 gc1 gc2 drp (fun t g1 g2 drp -> - use t = gc1.forward(t) - use t = relu.forward(t) - use t = drp.forward(t) - use t = gc2.forward(t) - let t = logm.forward(t) - t) diff --git a/GCN/scripts/packages.fsx b/GCN/scripts/packages.fsx index 538069b..1e50fca 100644 --- a/GCN/scripts/packages.fsx +++ b/GCN/scripts/packages.fsx @@ -15,8 +15,7 @@ //#r "nuget: libtorch-cuda-11.1-win-x64, 1.8.0.7" System.Runtime.InteropServices.NativeLibrary.Load(@"D:\s\libtorch\lib\torch_cuda.dll") -#load @"..\MLUtils.fs" -#load @"..\MathUtils.fs" + #load @"..\TorchSharp.Fun.fs" #I @"C:\Program Files\dotnet\shared\Microsoft.WindowsDesktop.App\5.0.4" #r "System.Windows.Forms" diff --git a/README.md b/README.md index 173c63e..2a7c612 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,78 @@ -### Design thoughts +# Graph Convolutional Networks in TorchSharp.Fun -#### What is avaialble +TorchSharp.Fun is a thin functional wrapper in F# over TorchSharp (a .NET binding of PyTorch). 
-- V1 clusters -- V2 clusters +## TorchSharp.Fun Example -#### Model -- each micro cluster is a graph -- extract the micrograph -- note time to wait for root alarm -- +Below is a simple sequential model. It is a composition over standard TorchSharp 'modules'. The composition is performed with the '->>' operator. + +```F# +let model = + Linear(10L,5L) + ->> Dropout(0.5) + ->> Linear(5L,1L) + ->> ReLU() +``` + +## GCN Model + +The Graph Convolutional Network (GCN) model presented in this repo is based on the work of Thomas Kipf, [Graph Convolutional Networks](http://tkipf.github.io/graph-convolutional-networks/) (2016). + +It is a port of the [PyTorch GCN model](http://github.com/tkipf/pygcn). + +## TorchSharp.Fun + +The code for TorchSharp.Fun is included in the repo. At this stage it is expected to undergo considerable churn and therefore is not released as an independent package. + +## Training the model + +The data for the model is included; however, two changes to the source are required to train the model. Both are in the Program.fs file. These are: + +- Path to libtorch native library - [download link](https://pytorch.org/) +- Path to the data folder + +It is recommended to use Visual Studio Code with the F# / Ionide plug-in - just start the project after making the above changes. + +## Why TorchSharp.Fun? + +A function-compositional approach to deep learning models arose when I could not easily create a deep ResNet model with 'standard' TorchSharp. + +An alternative F# library was also tried. The library supports an elegant API; it was easy to create a deep ResNet model. Unfortunately at its current stage of development, the training performance for deep models is not on par with that of basic TorchSharp. + +TorchSharp.Fun is a very thin wrapper over TorchSharp and does not suffer any noticeable performance hits when compared with TorchSharp (or PyTorch for that matter). 
+ +Below is an example of a 30 layer ResNet regression model: + +```F# +module Resnet = + let RESNET_DIM = 50L + let RESNET_DEPTH = 30 + let FTR_DIM = 340L + + let act() = SELU() + + let resnetCell (input: Model) = + let cell = + act() + ->> Linear(RESNET_DIM, RESNET_DIM) + ->> act() + ->> Linear(RESNET_DIM, RESNET_DIM) + + let join = + fwd2 input cell (fun ``input tensor`` inputModel cellModel -> + use t1 = inputModel.forward (``input tensor``) + use t2 = cellModel.forward (t1) + t1 + t2) + + join ->> act() + + let create() = + let emb = Linear(FTR_DIM, RESNET_DIM) |> M + let rsLayers = + (emb, [ 1 .. RESNET_DEPTH ]) + ||> List.fold (fun emb _ -> resnetCell emb) + rsLayers + ->> Linear(RESNET_DIM,10L) + ->> Linear(10L, 1L) + +```