-
-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #205 from SciML/fm/tr
Changes to training
- Loading branch information
Showing
11 changed files
with
115 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
module RCLIBSVMExt | ||
using ReservoirComputing | ||
using LIBSVM | ||
|
||
function ReservoirComputing.train(svr::LIBSVM.AbstractSVR, states, target) | ||
out_size = size(target, 1) | ||
output_matrix = [] | ||
|
||
if out_size == 1 | ||
output_matrix = LIBSVM.fit!(svr, states', vec(target)) | ||
else | ||
for i in 1:out_size | ||
push!(output_matrix, LIBSVM.fit!(svr, states', target[i, :])) | ||
end | ||
end | ||
|
||
return OutputLayer(svr, output_matrix, out_size, target[:, end]) | ||
end | ||
|
||
function ReservoirComputing.get_prediction( | ||
training_method::LIBSVM.AbstractSVR, output_layer, x) | ||
out = zeros(output_layer.out_size) | ||
|
||
for i in 1:(output_layer.out_size) | ||
x_new = reshape(x, 1, length(x)) | ||
out[i] = LIBSVM.predict(output_layer.output_matrix[i], x_new)[1] | ||
end | ||
|
||
return out | ||
end | ||
|
||
end #module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
module RCMLJLinearModelsExt | ||
using ReservoirComputing | ||
using MLJLinearModels | ||
|
||
function ReservoirComputing.train(regressor::MLJLinearModels.GeneralizedLinearRegression, | ||
states::AbstractArray{T}, | ||
target::AbstractArray{T}; | ||
kwargs...) where {T <: Number} | ||
out_size = size(target, 1) | ||
output_layer = similar(target, size(target, 1), size(states, 1)) | ||
|
||
if regressor.fit_intercept | ||
throw(ArgumentError("fit_intercept=true is not yet supported. | ||
Please add fit_intercept=false to the MLJ regressor")) | ||
end | ||
|
||
for i in axes(target, 1) | ||
output_layer[i, :] = MLJLinearModels.fit(regressor, states', | ||
target[i, :]; kwargs...) | ||
end | ||
|
||
return OutputLayer(regressor, output_layer, out_size, target[:, end]) | ||
end | ||
|
||
end #module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,64 +1,22 @@ | ||
struct StandardRidge{T} <: AbstractLinearModel | ||
regularization_coeff::T | ||
struct StandardRidge | ||
reg::Number | ||
end | ||
|
||
""" | ||
StandardRidge(regularization_coeff) | ||
StandardRidge(;regularization_coeff=0.0) | ||
Ridge regression training for all the models in the library. The | ||
`regularization_coeff` is the regularization, it can be passed as an arg or kwarg. | ||
""" | ||
function StandardRidge(; regularization_coeff = 0.0) | ||
return StandardRidge(regularization_coeff) | ||
end | ||
|
||
#default training - OLS | ||
function _train(states, target_data, sr::StandardRidge = StandardRidge(0.0)) | ||
output_layer = ((states * states' + sr.regularization_coeff * I) \ | ||
(states * target_data'))' | ||
#output_layer = (target_data*states')*inv(states*states'+sr.regularization_coeff*I) | ||
return OutputLayer(sr, output_layer, size(target_data, 1), target_data[:, end]) | ||
function StandardRidge(::Type{T}, reg) where {T <: Number} | ||
return StandardRidge(T.(reg)) | ||
end | ||
|
||
#mlj interface | ||
struct LinearModel{T, S, K} <: AbstractLinearModel | ||
regression::T | ||
solver::S | ||
regression_kwargs::K | ||
function StandardRidge() | ||
return StandardRidge(0.0) | ||
end | ||
|
||
""" | ||
LinearModel(;regression=LinearRegression, | ||
solver=Analytical(), | ||
regression_kwargs=(;)) | ||
Linear regression training based on | ||
[MLJLinearModels](https://juliaai.github.io/MLJLinearModels.jl/stable/) for all the | ||
models in the library. All the parameters have to be passed into `regression_kwargs`, | ||
apart from the solver choice. MLJLinearModels.jl needs to be called in order | ||
to use these models. | ||
""" | ||
function LinearModel(; regression = LinearRegression, | ||
solver = Analytical(), | ||
regression_kwargs = (;)) | ||
return LinearModel(regression, solver, regression_kwargs) | ||
end | ||
|
||
function LinearModel(regression; | ||
solver = Analytical(), | ||
regression_kwargs = (;)) | ||
return LinearModel(regression, solver, regression_kwargs) | ||
end | ||
|
||
function _train(states, target_data, linear::LinearModel) | ||
out_size = size(target_data, 1) | ||
output_layer = zeros(size(target_data, 1), size(states, 1)) | ||
for i in 1:size(target_data, 1) | ||
regressor = linear.regression(; fit_intercept = false, linear.regression_kwargs...) | ||
output_layer[i, :] = MLJLinearModels.fit(regressor, states', | ||
target_data[i, :], solver = linear.solver) | ||
end | ||
|
||
return OutputLayer(linear, output_layer, out_size, target_data[:, end]) | ||
function train(sr::StandardRidge, | ||
states::AbstractArray{T}, | ||
target_data::AbstractArray{T}) where {T <: Number} | ||
#A = states * states' + sr.reg * I | ||
#b = states * target_data | ||
#output_layer = (A \ b)' | ||
output_layer = Matrix(((states * states' + sr.reg * I) \ | ||
(states * target_data'))') | ||
return OutputLayer(sr, output_layer, size(target_data, 1), target_data[:, end]) | ||
end |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters