Skip to content

Commit

Permalink
Merge pull request #205 from SciML/fm/tr
Browse files Browse the repository at this point in the history
Changes to training
  • Loading branch information
MartinuzziFrancesco authored Mar 1, 2024
2 parents 0a1bccc + 9133fb9 commit 21948f7
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 110 deletions.
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,22 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CellularAutomata = "878138dc-5b27-11ea-1a71-cb95d38d6b29"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[weakdeps]
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"

[extensions]
RCMLJLinearModelsExt = "MLJLinearModels"
RCLIBSVMExt = "LIBSVM"

[compat]
Adapt = "3.3.3, 4"
Aqua = "0.8"
Expand Down Expand Up @@ -46,4 +52,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations"]
test = ["Aqua", "Test", "SafeTestsets", "Random", "DifferentialEquations", "MLJLinearModels", "LIBSVM"]
32 changes: 32 additions & 0 deletions ext/RCLIBSVMExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module RCLIBSVMExt
using ReservoirComputing
using LIBSVM

function ReservoirComputing.train(svr::LIBSVM.AbstractSVR, states, target)
out_size = size(target, 1)
output_matrix = []

if out_size == 1
output_matrix = LIBSVM.fit!(svr, states', vec(target))
else
for i in 1:out_size
push!(output_matrix, LIBSVM.fit!(svr, states', target[i, :]))
end
end

return OutputLayer(svr, output_matrix, out_size, target[:, end])
end

function ReservoirComputing.get_prediction(
training_method::LIBSVM.AbstractSVR, output_layer, x)
out = zeros(output_layer.out_size)

for i in 1:(output_layer.out_size)
x_new = reshape(x, 1, length(x))
out[i] = LIBSVM.predict(output_layer.output_matrix[i], x_new)[1]
end

return out
end

end #module
25 changes: 25 additions & 0 deletions ext/RCMLJLinearModelsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module RCMLJLinearModelsExt
using ReservoirComputing
using MLJLinearModels

function ReservoirComputing.train(regressor::MLJLinearModels.GeneralizedLinearRegression,
states::AbstractArray{T},
target::AbstractArray{T};
kwargs...) where {T <: Number}
out_size = size(target, 1)
output_layer = similar(target, size(target, 1), size(states, 1))

if regressor.fit_intercept
throw(ArgumentError("fit_intercept=true is not yet supported.
Please add fit_intercept=false to the MLJ regressor"))
end

for i in axes(target, 1)
output_layer[i, :] = MLJLinearModels.fit(regressor, states',
target[i, :]; kwargs...)
end

return OutputLayer(regressor, output_layer, out_size, target[:, end])
end

end #module
15 changes: 7 additions & 8 deletions src/ReservoirComputing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ using Adapt
using CellularAutomata
using Distances
using Distributions
using LIBSVM
using LinearAlgebra
using MLJLinearModels
using NNlib
using Optim
using PartialFunctions
Expand All @@ -16,7 +14,7 @@ using WeightInitializers

export NLADefault, NLAT1, NLAT2, NLAT3
export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
export StandardRidge, LinearModel
export StandardRidge
export scaled_rand, weighted_init, informed_init, minimal_init
export rand_sparse, delay_line, delay_line_backward, cycle_jumps, simple_cycle, pseudo_svd
export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
Expand All @@ -31,11 +29,7 @@ export Generative, Predictive, OutputLayer
abstract type AbstractReservoirComputer end
abstract type AbstractOutputLayer end
abstract type AbstractPrediction end
#training methods
abstract type AbstractLinearModel end
abstract type AbstractSupportVector end
#should probably move some of these
abstract type AbstractVariation end
abstract type AbstractGRUVariant end

#general output layer struct
Expand Down Expand Up @@ -104,7 +98,6 @@ include("predict.jl")

#general training
include("train/linear_regression.jl")
include("train/supportvector_regression.jl")

#esn
include("esn/esn_input_layers.jl")
Expand All @@ -119,4 +112,10 @@ include("esn/esn_predict.jl")
include("reca/reca.jl")
include("reca/reca_input_encodings.jl")

# Julia < 1.9 support
if !isdefined(Base, :get_extension)
include("../ext/RCMLJLinearModelsExt.jl")
include("../ext/RCLIBSVMExt.jl")
end

end #module
5 changes: 3 additions & 2 deletions src/esn/esn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,11 @@ trained_esn = train(esn, target_data, training_method = StandardRidge(1.0))
"""
function train(esn::AbstractEchoStateNetwork,
target_data,
training_method = StandardRidge(0.0))
training_method = StandardRidge();
kwargs...)
states_new = esn.states_type(esn.nla_type, esn.states, esn.train_data[:, 1:end])

return _train(states_new, target_data, training_method)
return train(training_method, states_new, target_data; kwargs...)
end

#function pad_esnstate(variation::Hybrid, states_type, x_pad, x, model_prediction_data)
Expand Down
5 changes: 3 additions & 2 deletions src/esn/hybridesn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ end

function train(hesn::HybridESN,
target_data,
training_method = StandardRidge(0.0))
training_method = StandardRidge();
kwargs...)
states = vcat(hesn.states, hesn.model.model_data[:, 2:end])
states_new = hesn.states_type(hesn.nla_type, states, hesn.train_data[:, 1:end])

return _train(states_new, target_data, training_method)
return train(training_method, states_new, target_data; kwargs...)
end
14 changes: 1 addition & 13 deletions src/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,10 @@ function obtain_prediction(rc::AbstractReservoirComputer,
end

#linear models
function get_prediction(training_method::AbstractLinearModel, output_layer, x)
function get_prediction(training_method, output_layer, x)
return output_layer.output_matrix * x
end

#support vector regression
function get_prediction(training_method::LIBSVM.AbstractSVR, output_layer, x)
out = zeros(output_layer.out_size)

for i in 1:(output_layer.out_size)
x_new = reshape(x, 1, length(x))
out[i] = LIBSVM.predict(output_layer.output_matrix[i], x_new)[1]
end

return out
end

#single matrix for other training methods
function output_storing(training_method, out_size, prediction_len, storing_type)
return Adapt.adapt(storing_type, zeros(out_size, prediction_len))
Expand Down
4 changes: 2 additions & 2 deletions src/reca/reca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ function RECA(train_data,
end

#training dispatch
function train(reca::AbstractReca, target_data, training_method = StandardRidge(0.0))
function train(reca::AbstractReca, target_data, training_method = StandardRidge; kwargs...)
states_new = reca.states_type(reca.nla_type, reca.states, reca.train_data)
return _train(states_new, target_data, training_method)
return train(training_method, states_new, target_data; kwargs...)
end

#predict dispatch
Expand Down
72 changes: 15 additions & 57 deletions src/train/linear_regression.jl
Original file line number Diff line number Diff line change
@@ -1,64 +1,22 @@
struct StandardRidge{T} <: AbstractLinearModel
regularization_coeff::T
struct StandardRidge
reg::Number
end

"""
StandardRidge(regularization_coeff)
StandardRidge(;regularization_coeff=0.0)
Ridge regression training for all the models in the library. The
`regularization_coeff` is the regularization, it can be passed as an arg or kwarg.
"""
function StandardRidge(; regularization_coeff = 0.0)
return StandardRidge(regularization_coeff)
end

#default training - OLS
function _train(states, target_data, sr::StandardRidge = StandardRidge(0.0))
output_layer = ((states * states' + sr.regularization_coeff * I) \
(states * target_data'))'
#output_layer = (target_data*states')*inv(states*states'+sr.regularization_coeff*I)
return OutputLayer(sr, output_layer, size(target_data, 1), target_data[:, end])
function StandardRidge(::Type{T}, reg) where {T <: Number}
return StandardRidge(T.(reg))
end

#mlj interface
struct LinearModel{T, S, K} <: AbstractLinearModel
regression::T
solver::S
regression_kwargs::K
function StandardRidge()
return StandardRidge(0.0)
end

"""
LinearModel(;regression=LinearRegression,
solver=Analytical(),
regression_kwargs=(;))
Linear regression training based on
[MLJLinearModels](https://juliaai.github.io/MLJLinearModels.jl/stable/) for all the
models in the library. All the parameters have to be passed into `regression_kwargs`,
apart from the solver choice. MLJLinearModels.jl needs to be called in order
to use these models.
"""
function LinearModel(; regression = LinearRegression,
solver = Analytical(),
regression_kwargs = (;))
return LinearModel(regression, solver, regression_kwargs)
end

function LinearModel(regression;
solver = Analytical(),
regression_kwargs = (;))
return LinearModel(regression, solver, regression_kwargs)
end

function _train(states, target_data, linear::LinearModel)
out_size = size(target_data, 1)
output_layer = zeros(size(target_data, 1), size(states, 1))
for i in 1:size(target_data, 1)
regressor = linear.regression(; fit_intercept = false, linear.regression_kwargs...)
output_layer[i, :] = MLJLinearModels.fit(regressor, states',
target_data[i, :], solver = linear.solver)
end

return OutputLayer(linear, output_layer, out_size, target_data[:, end])
function train(sr::StandardRidge,
states::AbstractArray{T},
target_data::AbstractArray{T}) where {T <: Number}
#A = states * states' + sr.reg * I
#b = states * target_data
#output_layer = (A \ b)'
output_layer = Matrix(((states * states' + sr.reg * I) \
(states * target_data'))')
return OutputLayer(sr, output_layer, size(target_data, 1), target_data[:, end])
end
11 changes: 0 additions & 11 deletions src/train/supportvector_regression.jl

This file was deleted.

30 changes: 18 additions & 12 deletions test/esn/test_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,24 @@ const reg = 10e-6
Random.seed!(77)
res = rand_sparse(; radius = 1.2, sparsity = 0.1)
esn = ESN(input_data, 1, res_size;
reservoir = rand_sparse)

training_methods = [
StandardRidge(regularization_coeff = reg),
LinearModel(RidgeRegression, regression_kwargs = (; lambda = reg)),
LinearModel(regression = RidgeRegression, regression_kwargs = (; lambda = reg)),
EpsilonSVR()
]
reservoir = res)
# different models that implement a train dispatch
# TODO add classification
linear_training = [StandardRidge(0.0), LinearRegression(; fit_intercept = false),
RidgeRegression(; fit_intercept = false), LassoRegression(; fit_intercept = false),
ElasticNetRegression(; fit_intercept = false), HuberRegression(; fit_intercept = false),
QuantileRegression(; fit_intercept = false), LADRegression(; fit_intercept = false)]
svm_training = [EpsilonSVR(), NuSVR()]

# TODO check types
@testset "Training Algo Tests: $ta" for ta in training_methods
output_layer = train(esn, target_data, ta)
output = esn(Predictive(input_data), output_layer)
@test mean(abs.(target_data .- output)) ./ mean(abs.(target_data)) < 0.22
@testset "Linear training: $lt" for lt in linear_training
output_layer = train(esn, target_data, lt)
@test output_layer isa OutputLayer
@test output_layer.output_matrix isa AbstractArray
end

@testset "SVM training: $st" for st in svm_training
output_layer = train(esn, target_data, st)
@test output_layer isa OutputLayer
@test output_layer.output_matrix isa typeof(st)
end

0 comments on commit 21948f7

Please sign in to comment.