-
-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
195f61b
commit 7fd059d
Showing
5 changed files
with
76 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,24 @@ | ||
using ReservoirComputing, Random | ||
using ReservoirComputing | ||
|
||
const res_size = 20 | ||
const ts = 0.0:0.1:50.0 | ||
const data = sin.(ts) | ||
const train_len = 400 | ||
const input_data = reduce(hcat, data[1:(train_len - 1)]) | ||
const target_data = reduce(hcat, data[2:train_len]) | ||
const predict_len = 100 | ||
const test = reduce(hcat, data[(train_len + 1):(train_len + predict_len)]) | ||
const training_method = StandardRidge(10e-6) | ||
states = [1, 2, 3, 4, 5, 6, 7, 8, 9] | ||
nla1_states = [1, 2, 9, 4, 25, 6, 49, 8, 81] | ||
nla2_states = [1, 2, 2, 4, 12, 6, 30, 8, 9] | ||
nla3_states = [1, 2, 8, 4, 24, 6, 48, 8, 9] | ||
|
||
nlas = [NLADefault(), NLAT1(), NLAT2(), NLAT3()] | ||
|
||
for n in nlas | ||
Random.seed!(77) | ||
esn = ESN(input_data; | ||
reservoir = RandSparseReservoir(res_size, 1.2, 0.1), | ||
nla_type = n) | ||
output_layer = train(esn, target_data, training_method) | ||
output = esn(Generative(predict_len), output_layer) | ||
@test maximum(abs.(test .- output)) ./ maximum(abs.(test)) < 0.1 | ||
test_types = [Float64, Float32, Float16] | ||
|
||
for tt in test_types | ||
# test default | ||
nla_states = ReservoirComputing.nla(NLADefault(), tt.(states)) | ||
@test nla_states == tt.(states) | ||
# test NLAT1 | ||
nla_states = ReservoirComputing.nla(NLAT1(), tt.(states)) | ||
@test nla_states = tt.(nla1_states) | ||
# test nlat2 | ||
nla_states = ReservoirComputing.nla(NLAT2(), tt.(states)) | ||
@test nla_states = tt.(nla2_states) | ||
# test nlat3 | ||
nla_states = ReservoirComputing.nla(NLAT3(), tt.(states)) | ||
@test nla_states = tt.(nla3_states) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,50 @@ | ||
using ReservoirComputing, Random | ||
|
||
const res_size = 20 | ||
const ts = 0.0:0.1:50.0 | ||
const data = sin.(ts) | ||
const train_len = 400 | ||
const input_data = reduce(hcat, data[1:(train_len - 1)]) | ||
const target_data = reduce(hcat, data[2:train_len]) | ||
const predict_len = 100 | ||
const test_data = reduce(hcat, data[(train_len + 1):(train_len + predict_len)]) | ||
const training_method = StandardRidge(10e-6) | ||
using ReservoirComputing | ||
|
||
test_types = [Float64, Float32, Float16] | ||
states = [1, 2, 3, 4, 5, 6, 7, 8, 9] | ||
in_data = fill(1, 3) | ||
|
||
states_types = [StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates] | ||
|
||
for t in states_types | ||
Random.seed!(77) | ||
esn = ESN(input_data; | ||
reservoir = RandSparseReservoir(res_size, 1.2, 0.1)) | ||
output_layer = train(esn, target_data, training_method) | ||
output = esn(Generative(predict_len), output_layer) | ||
@test maximum(abs.(test_data .- output)) ./ maximum(abs.(test_data)) < 0.1 | ||
# testing extension and padding | ||
for tt in test_types | ||
st_states = StandardStates()(NLADefault(), tt.(states), tt.(in_data)) | ||
@test length(st_states) == length(states) | ||
@test typeof(st_states) == typeof(tt.(states)) | ||
|
||
st_states = ExtendedStates()(NLADefault(), tt.(states), tt.(in_data)) | ||
@test length(st_states) == length(states) + length(in_data) | ||
@test typeof(st_states) == typeof(tt.(states)) | ||
|
||
st_states = PaddedStates()(NLADefault(), tt.(states), tt.(in_data)) | ||
@test length(st_states) == length(states) + 1 | ||
@test typeof(st_states[1]) == typeof(tt.(states)[1]) | ||
|
||
st_states = PaddedExtendedStates()(NLADefault(), tt.(states), tt.(in_data)) | ||
@test length(st_states) == length(states) + length(in_data) + 1 | ||
@test typeof(st_states[1]) == typeof(tt.(states)[1]) | ||
end | ||
|
||
|
||
|
||
## testing non linear algos | ||
nla1_states = [1, 2, 9, 4, 25, 6, 49, 8, 81] | ||
nla2_states = [1, 2, 2, 4, 12, 6, 30, 8, 9] | ||
nla3_states = [1, 2, 8, 4, 24, 6, 48, 8, 9] | ||
|
||
|
||
|
||
for tt in test_types | ||
# test default | ||
nla_states = ReservoirComputing.nla(NLADefault(), tt.(states)) | ||
@test nla_states == tt.(states) | ||
# test NLAT1 | ||
nla_states = ReservoirComputing.nla(NLAT1(), tt.(states)) | ||
@test nla_states == tt.(nla1_states) | ||
# test nlat2 | ||
nla_states = ReservoirComputing.nla(NLAT2(), tt.(states)) | ||
@test nla_states == tt.(nla2_states) | ||
# test nlat3 | ||
nla_states = ReservoirComputing.nla(NLAT3(), tt.(states)) | ||
@test nla_states == tt.(nla3_states) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters