Skip to content

Commit

Permalink
Bump LIBSVM version to v0.7
Browse files Browse the repository at this point in the history
The svmtrain and svmpredict functions slightly changed in 0.7 for custom
kernels, therefore some code had to be changed.
  • Loading branch information
simonschoelly authored Jun 24, 2021
1 parent eed1aa4 commit 3aa1e6d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
[compat]
DataStructures = "0.18.9"
KernelFunctions = "0.10.5"
LIBSVM = "0.6"
LIBSVM = "0.7"
LightGraphs = "1.3"
SimpleValueGraphs = "0.3"
ThreadsX = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion src/GraphKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ include("integrations/LIBSVM.jl")
Simple k-fold cross validation implementation for quick testing during development.
"""
function k_fold_cross_validation(kernel::AbstractGraphKernel, graphs; k_folds=5, class_key=1, kwargs...)
function k_fold_cross_validation(kernel::KernelFunctions.Kernel, graphs::AbstractVector{<:AbstractGraph}; k_folds=5, class_key=1, kwargs...)

n = length(graphs)
indices = randperm(MersenneTwister(123), n)
Expand Down
18 changes: 5 additions & 13 deletions src/integrations/LIBSVM.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@

struct GraphSVMModel
svm::LIBSVM.SVM
kernel::AbstractGraphKernel
kernel::KernelFunctions.Kernel
graphs::AbstractVector
end


function svmtrain(graphs::AbstractVector{<:AbstractGraph}, labels, kernel::AbstractGraphKernel; kwargs...)

n = length(graphs)

X = vcat(transpose(1:n), kernelmatrix(kernel, graphs))
function svmtrain(graphs::AbstractVector{<:AbstractGraph}, labels, kernel::KernelFunctions.Kernel; kwargs...)

X = kernelmatrix(kernel, graphs)
svm = svmtrain(X, labels, kernel=Kernel.Precomputed; kwargs...)

return GraphSVMModel(svm, kernel, graphs)
Expand All @@ -22,12 +19,7 @@ function svmpredict(model::GraphSVMModel, unpredicted_graphs::AbstractVector{<:A
graphs = model.graphs
kernel = model.kernel

m = length(graphs)
n = length(unpredicted_graphs)

X = Matrix{Float64}(undef, m + 1, n)
X[1, :] = 1:n
X[2:end, :] = kernelmatrix(kernel, graphs, unpredicted_graphs)

# TODO might only be necessary to do the calculations for support vectors
X = kernelmatrix(kernel, graphs, unpredicted_graphs)
return svmpredict(model.svm, X)[1] # for simplicity return only the labels for now
end

0 comments on commit 3aa1e6d

Please sign in to comment.