Skip to content

Commit

Permalink
Merge pull request #93 from SciML/esnfitted
Browse files Browse the repository at this point in the history
Added ESNfitted function
  • Loading branch information
MartinuzziFrancesco authored May 19, 2021
2 parents a229fa0 + 6f7647e commit 6fc79b2
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ReservoirComputing"
uuid = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294"
authors = ["Francesco Martinuzzi"]
version = "0.6.2"
version = "0.6.3"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Expand Down
2 changes: 1 addition & 1 deletion src/ReservoirComputing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ include("esn_reservoirs.jl")
export init_reservoir_givendeg, init_reservoir_givensp, pseudoSVD, DLR, DLRB, SCR, CRJ

include("echostatenetwork.jl")
export ESN, ESNpredict, ESNpredict_h_steps
export ESN, ESNpredict, ESNpredict_h_steps, ESNfitted

include("dafesn.jl")
export dafESN, dafESNpredict, dafESNpredict_h_steps
Expand Down
60 changes: 60 additions & 0 deletions src/echostatenetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,63 @@ function ESNpredict_h_steps(esn::AbstractLeakyESN,
end
return output
end

"""
ESNfitted(esn::AbstractLeakyESN, W_out::Matrix; autonomous=false)
Return the prediction for the training data using the trained output layer. The autonomous trigger can be used to have have it return an autonomous prediction starting from the first point if true, or a point by point prediction if false.
"""

function ESNfitted(esn::AbstractLeakyESN, W_out::Matrix; autonomous=false)
train_len = size(esn.train_data, 2)
output = zeros(Float64, esn.in_size, train_len)
x = zeros(size(esn.states, 1))

if autonomous
out = esn.train_data[:,1]
return _fitted!(output, esn, x, train_len, W_out, out)
else
return _fitted!(output, esn, x, train_len, W_out, esn.train_data)
end
end

function _fitted!(output, esn, state, train_len, W_out, vector::Vector)
if esn.extended_states == false
for i=1:train_len
state = leaky_fixed_rnn(esn.activation, esn.alpha, esn.W, esn.W_in, state, vector)
x_new = nla(esn.nla_type, state)
vector = (W_out*x_new)
output[:, i] = vector
end
elseif esn.extended_states == true
for i=1:train_len
state = vcat(leaky_fixed_rnn(esn.activation, esn.alpha, esn.W, esn.W_in, state[1:esn.res_size], vector), vector)
x_new = nla(esn.nla_type, state)
vector = (W_out*x_new)
output[:, i] = vector
end
end
return output
end

function _fitted!(output, esn, state, train_len, W_out, vector::Matrix)
if esn.extended_states == false
for i=1:train_len
state = leaky_fixed_rnn(esn.activation, esn.alpha, esn.W, esn.W_in, state, vector[:,i])
x_new = nla(esn.nla_type, state)
out = (W_out*x_new)
output[:, i] = out
end
elseif esn.extended_states == true
for i=1:train_len
state = vcat(leaky_fixed_rnn(esn.activation, esn.alpha, esn.W, esn.W_in, state[1:esn.res_size], vector[:,i]), vector[:,i])
x_new = nla(esn.nla_type, state)
out = (W_out*x_new)
output[:, i] = out
end
end
return output
end



7 changes: 7 additions & 0 deletions test/extras/test_extended_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ output = ESNpredict(esn, predict_len, W_out)
output = ESNpredict_h_steps(esn, predict_len, h_steps, test, W_out)
@test size(output) == (out_size, predict_len)

#test esnfitted
fit1 = ESNfitted(esn, W_out; autonomous=false)
@test size(fit1) == size(train)

fit2 = ESNfitted(esn, W_out; autonomous=true)
@test size(fit1) == size(train)

#test esgp
mean = MeanZero()
kernel = Lin(1.0)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ using SafeTestsets
@time @safetestset "reca gol predict" begin include("training/test_recagol.jl") end
@time @safetestset "RMM constructors" begin include("constructors/test_rmm_constructors.jl") end
@time @safetestset "GRUESN constructors" begin include("constructors/test_gruesn_constructors.jl") end
@time @safetestset "ESN fitted" begin include("training/test_esnfitted.jl") end
41 changes: 41 additions & 0 deletions test/training/test_esnfitted.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using ReservoirComputing
using MLJLinearModels
#model parameters
const approx_res_size = 30
const radius = 1.2
const activation = tanh
const degree = 6
const sigma = 0.1
const beta = 0.0
const alpha = 1.0
const nla_type = NLADefault()
const in_size = 3
const out_size = 3
const extended_states = false
const delta = 0.5


const train_len = 50
const predict_len = 12
data = ones(Float64, in_size, 100)
train = data[:, 1:1+train_len-1]
test = data[:, train_len:train_len+predict_len-1]

#constructor 1
esn = ESN(approx_res_size,
train,
degree,
radius,
activation = activation,
sigma = sigma,
alpha = alpha,
nla_type = nla_type,
extended_states = extended_states)

W_out = ESNtrain(esn, beta)

fit1 = ESNfitted(esn, W_out; autonomous=false)
@test size(fit1) == size(train)

fit2 = ESNfitted(esn, W_out; autonomous=true)
@test size(fit1) == size(train)

0 comments on commit 6fc79b2

Please sign in to comment.