diff --git a/Project.toml b/Project.toml index 3ba2c62..f609f25 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,7 @@ ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" [compat] DataStructures = "0.18.9" KernelFunctions = "0.10.5" -LIBSVM = "0.6" +LIBSVM = "0.7" LightGraphs = "1.3" SimpleValueGraphs = "0.3" ThreadsX = "0.1" diff --git a/src/GraphKernels.jl b/src/GraphKernels.jl index 4d4ed5c..ee65e01 100644 --- a/src/GraphKernels.jl +++ b/src/GraphKernels.jl @@ -55,7 +55,7 @@ include("integrations/LIBSVM.jl") Simple k-fold cross validation implementation for quick testing during development. """ -function k_fold_cross_validation(kernel::AbstractGraphKernel, graphs; k_folds=5, class_key=1, kwargs...) +function k_fold_cross_validation(kernel::KernelFunctions.Kernel, graphs::AbstractVector{<:AbstractGraph}; k_folds=5, class_key=1, kwargs...) n = length(graphs) indices = randperm(MersenneTwister(123), n) diff --git a/src/integrations/LIBSVM.jl b/src/integrations/LIBSVM.jl index 6bf02df..de44db4 100644 --- a/src/integrations/LIBSVM.jl +++ b/src/integrations/LIBSVM.jl @@ -1,17 +1,14 @@ struct GraphSVMModel svm::LIBSVM.SVM - kernel::AbstractGraphKernel + kernel::KernelFunctions.Kernel graphs::AbstractVector end -function svmtrain(graphs::AbstractVector{<:AbstractGraph}, labels, kernel::AbstractGraphKernel; kwargs...) - - n = length(graphs) - - X = vcat(transpose(1:n), kernelmatrix(kernel, graphs)) +function svmtrain(graphs::AbstractVector{<:AbstractGraph}, labels, kernel::KernelFunctions.Kernel; kwargs...) + X = kernelmatrix(kernel, graphs) svm = svmtrain(X, labels, kernel=Kernel.Precomputed; kwargs...) return GraphSVMModel(svm, kernel, graphs) @@ -22,12 +19,7 @@ function svmpredict(model::GraphSVMModel, unpredicted_graphs::AbstractVector{<:A graphs = model.graphs kernel = model.kernel - m = length(graphs) - n = length(unpredicted_graphs) - - X = Matrix{Float64}(undef, m + 1, n) - X[1, :] = 1:n - X[2:end, :] = kernelmatrix(kernel, graphs, unpredicted_graphs) - + # TODO might only be necessary to do the calculations for support vectors + X = kernelmatrix(kernel, graphs, unpredicted_graphs) return svmpredict(model.svm, X)[1] # for simplicity return only the labels for now end