From a1066bab03dc95b3b0916ae7bec6bc0c344fc31e Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco <martinuzzi.francesco@gmail.com>
Date: Sat, 13 Jan 2024 19:18:40 +0100
Subject: [PATCH 1/6] initial work for initializers

---
 Project.toml              |   1 +
 src/esn/esn_reservoirs.jl | 423 +++++++-------------------------------
 test/runtests.jl          |  36 +++-
 3 files changed, 97 insertions(+), 363 deletions(-)

diff --git a/Project.toml b/Project.toml
index b3207f6e..b003aecd 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
+PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl
index e014e4e7..047fb0bf 100644
--- a/src/esn/esn_reservoirs.jl
+++ b/src/esn/esn_reservoirs.jl
@@ -1,393 +1,108 @@
-abstract type AbstractReservoir end
-
-function get_ressize(reservoir::AbstractReservoir)
-    return reservoir.res_size
-end
-
 function get_ressize(reservoir)
     return size(reservoir, 1)
 end
 
-struct RandSparseReservoir{T, C} <: AbstractReservoir
-    res_size::Int
-    radius::T
-    sparsity::C
-end
-
-"""
-    RandSparseReservoir(res_size, radius, sparsity)
-    RandSparseReservoir(res_size; radius=1.0, sparsity=0.1)
-
-
-Returns a random sparse reservoir initializer, which generates a matrix of size `res_size x res_size` with the specified `sparsity` and scaled spectral radius according to `radius`. This type of reservoir initializer is commonly used in Echo State Networks (ESNs) for capturing complex temporal dependencies.
-
-# Arguments
-- `res_size`: The size of the reservoir matrix.
-- `radius`: The desired spectral radius of the reservoir. By default, it is set to 1.0.
-- `sparsity`: The sparsity level of the reservoir matrix, controlling the fraction of zero elements. By default, it is set to 0.1.
-
-# Returns
-A RandSparseReservoir object that can be used as a reservoir initializer in ESN construction.
-
-# References
-This type of reservoir initialization is a common choice in ESN construction for its ability to capture temporal dependencies in data. However, there is no specific reference associated with this function.
-"""
-function RandSparseReservoir(res_size; radius = 1.0, sparsity = 0.1)
-    return RandSparseReservoir(res_size, radius, sparsity)
-end
-
 """
-    create_reservoir(reservoir::AbstractReservoir, res_size)
-    create_reservoir(reservoir, args...)
+    rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...; radius=1.0, sparsity=0.1)
 
-Given an `AbstractReservoir` constructor and the size of the reservoir (`res_size`), this function returns the corresponding reservoir matrix. Alternatively, it accepts a pre-generated matrix.
+Create and return a random sparse reservoir matrix for use in Echo State Networks (ESNs). The matrix will be of size specified by `dims`, with specified `sparsity` and scaled spectral radius according to `radius`.
 
 # Arguments
-- `reservoir`: An `AbstractReservoir` object or constructor.
-- `res_size`: The size of the reservoir matrix.
-- `matrix_type`: The type of the resulting matrix. By default, it is set to `Matrix{Float64}`.
+- `rng`: An instance of `AbstractRNG` for random number generation.
+- `T`: The data type for the elements of the matrix.
+- `dims`: Dimensions of the reservoir matrix.
+- `radius`: The desired spectral radius of the reservoir. Defaults to 1.0.
+- `sparsity`: The sparsity level of the reservoir matrix, controlling the fraction of zero elements. Defaults to 0.1.
 
 # Returns
-A matrix representing the reservoir, generated based on the properties of the specified `reservoir` object or constructor.
+A matrix representing the random sparse reservoir.
 
 # References
-The choice of reservoir initialization is crucial in Echo State Networks (ESNs) for achieving effective temporal modeling. Specific references for reservoir initialization methods may vary based on the type of reservoir used, but the practice of initializing reservoirs for ESNs is widely documented in the ESN literature.
-"""
-function create_reservoir(reservoir::RandSparseReservoir,
-        res_size;
-        matrix_type = Matrix{Float64})
-    reservoir_matrix = Matrix(sprand(res_size, res_size, reservoir.sparsity))
+This type of reservoir initialization is commonly used in ESNs for capturing temporal dependencies in data.
+"""
+function rand_sparse(rng::AbstractRNG,
+        ::Type{T},
+        dims::Integer...;
+        radius = 1.0,
+        sparsity = 0.1) where {T <: Number}
+    reservoir_matrix = Matrix{T}(sprand(rng, dims..., sparsity))
     reservoir_matrix = 2.0 .* (reservoir_matrix .- 0.5)
     replace!(reservoir_matrix, -1.0 => 0.0)
     rho_w = maximum(abs.(eigvals(reservoir_matrix)))
-    reservoir_matrix .*= reservoir.radius / rho_w
-    #TODO: change to explicit if
-    Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix) ?
-    error("Sparsity too low for size of the matrix.
-          Increase res_size or increase sparsity") : nothing
-    return Adapt.adapt(matrix_type, reservoir_matrix)
-end
-
-function create_reservoir(reservoir, args...; kwargs...)
-    return reservoir
-end
-
-#=
-function create_reservoir(res_size, reservoir::RandReservoir)
-    sparsity = degree/res_size
-    W = Matrix(sprand(Float64, res_size, res_size, sparsity))
-    W = 2.0 .*(W.-0.5)
-    replace!(W, -1.0=>0.0)
-    rho_w = maximum(abs.(eigvals(W)))
-    W .*= radius/rho_w
-    W
-end
-=#
-
-struct PseudoSVDReservoir{T, C} <: AbstractReservoir
-    res_size::Int
-    max_value::T
-    sparsity::C
-    sorted::Bool
-    reverse_sort::Bool
-end
-
-function PseudoSVDReservoir(res_size;
-        max_value = 1.0,
-        sparsity = 0.1,
-        sorted = true,
-        reverse_sort = false)
-    return PseudoSVDReservoir(res_size, max_value, sparsity, sorted, reverse_sort)
+    reservoir_matrix .*= radius / rho_w
+    if Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix)
+        error("Sparsity too low for size of the matrix. Increase res_size or increase sparsity")
+    end
+    return reservoir_matrix
 end
 
 """
-    PseudoSVDReservoir(max_value, sparsity, sorted, reverse_sort)
-    PseudoSVDReservoir(max_value, sparsity; sorted=true, reverse_sort=false)
+    delay_line(rng::AbstractRNG, ::Type{T}, dims::Integer...; weight=0.1) where {T <: Number}
 
-Returns an initializer to build a sparse reservoir matrix with the given `sparsity` by using a pseudo-SVD approach as described in [^yang].
+Create and return a delay line reservoir matrix for use in Echo State Networks (ESNs). A delay line reservoir is a deterministic structure where each unit is connected only to its immediate predecessor with a specified weight. This method is particularly useful for tasks that require specific temporal processing.
 
 # Arguments
-- `res_size`: The size of the reservoir matrix.
-- `max_value`: The maximum absolute value of elements in the matrix.
-- `sparsity`: The desired sparsity level of the reservoir matrix.
-- `sorted`: A boolean indicating whether to sort the singular values before creating the diagonal matrix. By default, it is set to `true`.
-- `reverse_sort`: A boolean indicating whether to reverse the sorted singular values. By default, it is set to `false`.
+- `rng`: An instance of `AbstractRNG` for random number generation. This argument is not used in the current implementation but is included for consistency with other initialization functions.
+- `T`: The data type for the elements of the matrix.
+- `dims`: Dimensions of the reservoir matrix. Typically, this should be a tuple of two equal integers representing a square matrix.
+- `weight`: The weight determines the absolute value of all connections in the reservoir. Defaults to 0.1.
 
 # Returns
-A PseudoSVDReservoir object that can be used as a reservoir initializer in ESN construction.
-
-# References
-This reservoir initialization method, based on a pseudo-SVD approach, is inspired by the work in [^yang], which focuses on designing polynomial echo state networks for time series prediction.
-
-[^yang]: Yang, Cuili, et al. "_Design of polynomial echo state networks for time series prediction._" Neurocomputing 290 (2018): 148-160.
-"""
-function PseudoSVDReservoir(res_size, max_value, sparsity; sorted = true,
-        reverse_sort = false)
-    return PseudoSVDReservoir(res_size, max_value, sparsity, sorted, reverse_sort)
-end
+A delay line reservoir matrix with dimensions specified by `dims`. The matrix is initialized such that each element in the `i+1`th row and `i`th column is set to `weight`, and all other elements are zeros.
 
-function create_reservoir(reservoir::PseudoSVDReservoir,
-        res_size;
-        matrix_type = Matrix{Float64})
-    sorted = reservoir.sorted
-    reverse_sort = reservoir.reverse_sort
-    reservoir_matrix = create_diag(res_size, reservoir.max_value, sorted = sorted,
-        reverse_sort = reverse_sort)
-    tmp_sparsity = get_sparsity(reservoir_matrix, res_size)
+# Example
+```julia
+reservoir = delay_line(Float64, 100, 100; weight=0.2)
+```
 
-    while tmp_sparsity <= reservoir.sparsity
-        reservoir_matrix *= create_qmatrix(res_size, rand(1:res_size), rand(1:res_size),
-            rand() * 2 - 1)
-        tmp_sparsity = get_sparsity(reservoir_matrix, res_size)
+# References
+This type of reservoir initialization is described in:
+Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." IEEE Transactions on Neural Networks 22.1 (2010): 131-144.
+"""
+function delay_line(rng::AbstractRNG,
+        ::Type{T},
+        dims::Integer...;
+        weight = 0.1) where {T <: Number}
+    reservoir_matrix = zeros(T, dims...)
+    @assert length(dims) == 2 && dims[1] == dims[2],
+    "The dimensions must define a square matrix (e.g., (100, 100))"
+
+    for i in 1:(dims[1] - 1)
+        reservoir_matrix[i + 1, i] = weight
     end
 
-    return Adapt.adapt(matrix_type, reservoir_matrix)
+    return reservoir_matrix
 end
 
-function create_diag(dim, max_value; sorted = true, reverse_sort = false)
-    diagonal_matrix = zeros(dim, dim)
-    if sorted == true
-        if reverse_sort == true
-            diagonal_values = sort(rand(dim) .* max_value, rev = true)
-            diagonal_values[1] = max_value
-        else
-            diagonal_values = sort(rand(dim) .* max_value)
-            diagonal_values[end] = max_value
-        end
-    else
-        diagonal_values = rand(dim) .* max_value
+for initializer in (:rand_sparse, :delay_line)
+    NType = ifelse(initializer === :rand_sparse, Real, Number)
+    @eval function ($initializer)(dims::Integer...; kwargs...)
+        return $initializer(_default_rng(), Float32, dims...; kwargs...)
     end
-
-    for i in 1:dim
-        diagonal_matrix[i, i] = diagonal_values[i]
+    @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
+        return $initializer(rng, Float32, dims...; kwargs...)
     end
-
-    return diagonal_matrix
-end
-
-function create_qmatrix(dim, coord_i, coord_j, theta)
-    qmatrix = zeros(dim, dim)
-
-    for i in 1:dim
-        qmatrix[i, i] = 1.0
+    @eval function ($initializer)(::Type{T},
+            dims::Integer...; kwargs...) where {T <: $NType}
+        return $initializer(_default_rng(), T, dims...; kwargs...)
     end
-
-    qmatrix[coord_i, coord_i] = cos(theta)
-    qmatrix[coord_j, coord_j] = cos(theta)
-    qmatrix[coord_i, coord_j] = -sin(theta)
-    qmatrix[coord_j, coord_i] = sin(theta)
-    return qmatrix
-end
-
-function get_sparsity(M, dim)
-    return size(M[M .!= 0], 1) / (dim * dim - size(M[M .!= 0], 1)) #nonzero/zero elements
-end
-
-#from "minimum complexity echo state network" Rodan
-# Delay Line Reservoir
-
-struct DelayLineReservoir{T} <: AbstractReservoir
-    res_size::Int
-    weight::T
-end
-
-"""
-    DelayLineReservoir(res_size, weight)
-    DelayLineReservoir(res_size; weight=0.1)
-
-Returns a Delay Line Reservoir matrix constructor to obtain a deterministic reservoir as
-described in [^Rodan2010].
-
-# Arguments
-- `res_size::Int`: The size of the reservoir.
-- `weight::T`: The weight determines the absolute value of all the connections in the reservoir.
-
-# Returns
-A `DelayLineReservoir` object.
-
-# References
-[^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network."
-IEEE transactions on neural networks 22.1 (2010): 131-144.
-"""
-function DelayLineReservoir(res_size; weight = 0.1)
-    return DelayLineReservoir(res_size, weight)
-end
-
-function create_reservoir(reservoir::DelayLineReservoir,
-        res_size;
-        matrix_type = Matrix{Float64})
-    reservoir_matrix = zeros(res_size, res_size)
-
-    for i in 1:(res_size - 1)
-        reservoir_matrix[i + 1, i] = reservoir.weight
+    @eval function ($initializer)(rng::AbstractRNG; kwargs...)
+        return __partial_apply($initializer, (rng, (; kwargs...)))
     end
-
-    return Adapt.adapt(matrix_type, reservoir_matrix)
-end
-
-#from "minimum complexity echo state network" Rodan
-# Delay Line Reservoir with backward connections
-struct DelayLineBackwardReservoir{T} <: AbstractReservoir
-    res_size::Int
-    weight::T
-    fb_weight::T
-end
-
-"""
-    DelayLineBackwardReservoir(res_size, weight, fb_weight)
-    DelayLineBackwardReservoir(res_size; weight=0.1, fb_weight=0.2)
-
-Returns a Delay Line Reservoir constructor to create a matrix with backward connections
-as described in [^Rodan2010]. The `weight` and `fb_weight` can be passed as either arguments or
-keyword arguments, and they determine the absolute values of the connections in the reservoir.
-
-# Arguments
-- `res_size::Int`: The size of the reservoir.
-- `weight::T`: The weight determines the absolute value of forward connections in the reservoir.
-- `fb_weight::T`: The `fb_weight` determines the absolute value of backward connections in the reservoir.
-
-# Returns
-A `DelayLineBackwardReservoir` object.
-
-# References
-[^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network."
-IEEE transactions on neural networks 22.1 (2010): 131-144.
-"""
-function DelayLineBackwardReservoir(res_size; weight = 0.1, fb_weight = 0.2)
-    return DelayLineBackwardReservoir(res_size, weight, fb_weight)
-end
-
-function create_reservoir(reservoir::DelayLineBackwardReservoir,
-        res_size;
-        matrix_type = Matrix{Float64})
-    reservoir_matrix = zeros(res_size, res_size)
-
-    for i in 1:(res_size - 1)
-        reservoir_matrix[i + 1, i] = reservoir.weight
-        reservoir_matrix[i, i + 1] = reservoir.fb_weight
-    end
-
-    return Adapt.adapt(matrix_type, reservoir_matrix)
-end
-
-#from "minimum complexity echo state network" Rodan
-# Simple cycle reservoir
-struct SimpleCycleReservoir{T} <: AbstractReservoir
-    res_size::Int
-    weight::T
-end
-
-"""
-    SimpleCycleReservoir(res_size, weight)
-    SimpleCycleReservoir(res_size; weight=0.1)
-
-Returns a Simple Cycle Reservoir constructor to build a reservoir matrix as
-described in [^Rodan2010]. The `weight` can be passed as an argument or a keyword argument, and it determines the
-absolute value of all the connections in the reservoir.
-
-# Arguments
-- `res_size::Int`: The size of the reservoir.
-- `weight::T`: The weight determines the absolute value of connections in the reservoir.
-
-# Returns
-A `SimpleCycleReservoir` object.
-
-# References
-[^Rodan2010]: Rodan, Ali, and Peter Tino. "Minimum complexity echo state network."
-IEEE transactions on neural networks 22.1 (2010): 131-144.
-"""
-function SimpleCycleReservoir(res_size; weight = 0.1)
-    return SimpleCycleReservoir(res_size, weight)
-end
-
-function create_reservoir(reservoir::SimpleCycleReservoir,
-        res_size;
-        matrix_type = Matrix{Float64})
-    reservoir_matrix = zeros(Float64, res_size, res_size)
-
-    for i in 1:(res_size - 1)
-        reservoir_matrix[i + 1, i] = reservoir.weight
+    @eval function ($initializer)(rng::AbstractRNG,
+            ::Type{T}; kwargs...) where {T <: $NType}
+        return __partial_apply($initializer, ((rng, T), (; kwargs...)))
     end
-
-    reservoir_matrix[1, res_size] = reservoir.weight
-    return Adapt.adapt(matrix_type, reservoir_matrix)
-end
-
-#from "simple deterministically constructed cycle reservoirs with regular jumps" by Rodan and Tino
-# Cycle Reservoir with Jumps
-struct CycleJumpsReservoir{T} <: AbstractReservoir
-    res_size::Int
-    cycle_weight::T
-    jump_weight::T
-    jump_size::Int
-end
-
-"""
-    CycleJumpsReservoir(res_size; cycle_weight=0.1, jump_weight=0.1, jump_size=3)
-    CycleJumpsReservoir(res_size, cycle_weight, jump_weight, jump_size)
-
-Return a Cycle Reservoir with Jumps constructor to create a reservoir matrix as described
-in [^Rodan2012]. The `cycle_weight`, `jump_weight`, and `jump_size` can be passed as arguments or keyword arguments, and they
-determine the absolute values of connections in the reservoir. The `jump_size` determines the jumps between `jump_weight`s.
-
-# Arguments
-- `res_size::Int`: The size of the reservoir.
-- `cycle_weight::T`: The weight of cycle connections.
-- `jump_weight::T`: The weight of jump connections.
-- `jump_size::Int`: The number of steps between jump connections.
-
-# Returns
-A `CycleJumpsReservoir` object.
-
-# References
-[^Rodan2012]: Rodan, Ali, and Peter Tiňo. "Simple deterministically constructed cycle reservoirs
-with regular jumps." Neural computation 24.7 (2012): 1822-1852.
-"""
-function CycleJumpsReservoir(res_size; cycle_weight = 0.1, jump_weight = 0.1, jump_size = 3)
-    return CycleJumpsReservoir(res_size, cycle_weight, jump_weight, jump_size)
+    @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...))
 end
 
-function create_reservoir(reservoir::CycleJumpsReservoir,
-        res_size;
-        matrix_type = Matrix{Float64})
-    reservoir_matrix = zeros(res_size, res_size)
-
-    for i in 1:(res_size - 1)
-        reservoir_matrix[i + 1, i] = reservoir.cycle_weight
-    end
-
-    reservoir_matrix[1, res_size] = reservoir.cycle_weight
-
-    for i in 1:(reservoir.jump_size):(res_size - reservoir.jump_size)
-        tmp = (i + reservoir.jump_size) % res_size
-        if tmp == 0
-            tmp = res_size
-        end
-        reservoir_matrix[i, tmp] = reservoir.jump_weight
-        reservoir_matrix[tmp, i] = reservoir.jump_weight
+# from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package
+function _default_rng()
+    @static if VERSION >= v"1.7"
+        return Xoshiro(1234)
+    else
+        return MersenneTwister(1234)
     end
-
-    return Adapt.adapt(matrix_type, reservoir_matrix)
 end
 
-"""
-    NullReservoir()
-
-Return a constructor for a matrix of zeros with dimensions `res_size x res_size`.
-
-# Arguments
-- None
-
-# Returns
-A `NullReservoir` object.
-
-# References
-- None
-"""
-struct NullReservoir <: AbstractReservoir end
-
-function create_reservoir(reservoir::NullReservoir,
-        res_size;
-        matrix_type = Matrix{Float64})
-    return Adapt.adapt(matrix_type, zeros(res_size, res_size))
-end
+__partial_apply(fn, inp) = fn$inp
diff --git a/test/runtests.jl b/test/runtests.jl
index b1b28ad1..2d114e99 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,19 +2,37 @@ using SafeTestsets
 using Test
 
 @testset "Common Utilities" begin
-    @safetestset "Quality Assurance" begin include("qa.jl") end
-    @safetestset "States" begin include("test_states.jl") end
+    @safetestset "Quality Assurance" begin
+        include("qa.jl")
+    end
+    @safetestset "States" begin
+        include("test_states.jl")
+    end
 end
 
 @testset "Echo State Networks" begin
-    @safetestset "ESN Input Layers" begin include("esn/test_input_layers.jl") end
-    @safetestset "ESN Reservoirs" begin include("esn/test_reservoirs.jl") end
-    @safetestset "ESN States" begin include("esn/test_states.jl") end
-    @safetestset "ESN Train and Predict" begin include("esn/test_train.jl") end
-    @safetestset "ESN Drivers" begin include("esn/test_drivers.jl") end
-    @safetestset "Hybrid ESN" begin include("esn/test_hybrid.jl") end
+    @safetestset "ESN Input Layers" begin
+        include("esn/test_input_layers.jl")
+    end
+    @safetestset "ESN Reservoirs" begin
+        include("esn/test_reservoirs.jl")
+    end
+    @safetestset "ESN States" begin
+        include("esn/test_states.jl")
+    end
+    @safetestset "ESN Train and Predict" begin
+        include("esn/test_train.jl")
+    end
+    @safetestset "ESN Drivers" begin
+        include("esn/test_drivers.jl")
+    end
+    @safetestset "Hybrid ESN" begin
+        include("esn/test_hybrid.jl")
+    end
 end
 
 @testset "CA based Reservoirs" begin
-    @safetestset "RECA" begin include("reca/test_predictive.jl") end
+    @safetestset "RECA" begin
+        include("reca/test_predictive.jl")
+    end
 end

From 7530157e5417be4fb6c2d6695976a9c8f2166fab Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco <martinuzzi.francesco@gmail.com>
Date: Sat, 13 Jan 2024 19:19:53 +0100
Subject: [PATCH 2/6] exports

---
 src/ReservoirComputing.jl | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl
index 4c24427a..84b0d145 100644
--- a/src/ReservoirComputing.jl
+++ b/src/ReservoirComputing.jl
@@ -18,9 +18,7 @@ export StandardRidge, LinearModel
 export AbstractLayer, create_layer
 export WeightedLayer, DenseLayer, SparseLayer, MinimumLayer, InformedLayer, NullLayer
 export BernoulliSample, IrrationalSample
-export AbstractReservoir, create_reservoir
-export RandSparseReservoir, PseudoSVDReservoir, DelayLineReservoir
-export DelayLineBackwardReservoir, SimpleCycleReservoir, CycleJumpsReservoir, NullReservoir
+export rand_sparse, delay_line
 export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
 export ESN, Default, Hybrid, train
 export RECA, train

From ba9ec649002d195c25c18d4d52281fa78fd2ef7a Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco <martinuzzi.francesco@gmail.com>
Date: Thu, 18 Jan 2024 18:17:21 +0100
Subject: [PATCH 3/6] fixing types and starting tests

---
 Project.toml                |  2 +
 src/ReservoirComputing.jl   |  2 +
 src/esn/esn_reservoirs.jl   | 10 ++--
 test/esn/test_reservoirs.jl | 99 +++++++++++--------------------------
 test/utils.jl               |  5 ++
 5 files changed, 42 insertions(+), 76 deletions(-)
 create mode 100644 test/utils.jl

diff --git a/Project.toml b/Project.toml
index b003aecd..fb57c37a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -14,6 +14,7 @@ MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
@@ -29,6 +30,7 @@ LinearAlgebra = "1.10"
 MLJLinearModels = "0.9.2"
 NNlib = "0.8.4, 0.9"
 Optim = "1"
+PartialFunctions = "1.2"
 Random = "1"
 SafeTestsets = "0.1"
 SparseArrays = "1.10"
diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl
index 84b0d145..c826b3f3 100644
--- a/src/ReservoirComputing.jl
+++ b/src/ReservoirComputing.jl
@@ -9,6 +9,8 @@ using LinearAlgebra
 using MLJLinearModels
 using NNlib
 using Optim
+using PartialFunctions
+using Random
 using SparseArrays
 using Statistics
 
diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl
index 047fb0bf..fa57109b 100644
--- a/src/esn/esn_reservoirs.jl
+++ b/src/esn/esn_reservoirs.jl
@@ -23,11 +23,11 @@ This type of reservoir initialization is commonly used in ESNs for capturing tem
 function rand_sparse(rng::AbstractRNG,
         ::Type{T},
         dims::Integer...;
-        radius = 1.0,
-        sparsity = 0.1) where {T <: Number}
+        radius = T(1.0),
+        sparsity = T(0.1)) where {T <: Number}
     reservoir_matrix = Matrix{T}(sprand(rng, dims..., sparsity))
-    reservoir_matrix = 2.0 .* (reservoir_matrix .- 0.5)
-    replace!(reservoir_matrix, -1.0 => 0.0)
+    reservoir_matrix = T(2.0) .* (reservoir_matrix .- T(0.5))
+    replace!(reservoir_matrix, T(-1.0) => T(0.0))
     rho_w = maximum(abs.(eigvals(reservoir_matrix)))
     reservoir_matrix .*= radius / rho_w
     if Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix)
@@ -62,7 +62,7 @@ Rodan, Ali, and Peter Tino. "Minimum complexity echo state network." IEEE Transa
 function delay_line(rng::AbstractRNG,
         ::Type{T},
         dims::Integer...;
-        weight = 0.1) where {T <: Number}
+        weight = T(0.1)) where {T <: Number}
     reservoir_matrix = zeros(T, dims...)
     @assert length(dims) == 2 && dims[1] == dims[2],
     "The dimensions must define a square matrix (e.g., (100, 100))"
diff --git a/test/esn/test_reservoirs.jl b/test/esn/test_reservoirs.jl
index ac751712..df58bd3c 100644
--- a/test/esn/test_reservoirs.jl
+++ b/test/esn/test_reservoirs.jl
@@ -1,79 +1,36 @@
 using ReservoirComputing
+using LinearAlgebra
+using Random
+include("../utils.jl")
 
 const res_size = 20
 const radius = 1.0
 const sparsity = 0.1
 const weight = 0.2
 const jump_size = 3
+const rng = Random.default_rng()
+
+dtypes = [Float16, Float32, Float64]
+reservoir_inits = [rand_sparse]
+
+@testset "Sizes and types" begin
+    for init in reservoir_inits
+        for dt in dtypes
+            #sizes
+            @test size(init(res_size, res_size)) == (res_size, res_size)
+            @test size(init(rng, res_size, res_size)) == (res_size, res_size)
+            #types
+            @test eltype(init(dt, res_size, res_size)) == dt
+            @test eltype(init(rng, dt, res_size, res_size)) == dt
+            #closure
+            cl = init(rng)
+            @test cl(dt, res_size, res_size) isa AbstractArray{dt}
+        end
+    end
+end
+
+@testset "rand_sparse" begin
+    sp = rand_sparse(res_size, res_size)
+    @test check_radius(sp, radius)
+end
 
-#testing RandSparseReservoir implicit and esplicit constructors
-reservoir_constructor = RandSparseReservoir(res_size, radius, sparsity)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-
-reservoir_constructor = RandSparseReservoir(res_size, radius = radius, sparsity = sparsity)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-
-#testing PseudoSVDReservoir implicit and esplicit constructors
-reservoir_constructor = PseudoSVDReservoir(res_size, radius, sparsity)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-@test maximum(reservoir_matrix) <= radius
-
-reservoir_constructor = PseudoSVDReservoir(res_size, max_value = radius,
-    sparsity = sparsity)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-@test maximum(reservoir_matrix) <= radius
-
-#testing DelayLineReservoir implicit and esplicit constructors
-reservoir_constructor = DelayLineReservoir(res_size, weight)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-@test maximum(reservoir_matrix) == weight
-
-reservoir_constructor = DelayLineReservoir(res_size, weight = weight)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-@test maximum(reservoir_matrix) == weight
-
-#testing DelayLineReservoir implicit and esplicit constructors
-reservoir_constructor = DelayLineBackwardReservoir(res_size, weight, weight)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-@test maximum(reservoir_matrix) == weight
-
-reservoir_constructor = DelayLineBackwardReservoir(res_size, weight = weight,
-    fb_weight = weight)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-@test maximum(reservoir_matrix) == weight
-
-#testing SimpleCycleReservoir implicit and esplicit constructors
-reservoir_constructor = SimpleCycleReservoir(res_size, weight)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-@test maximum(reservoir_matrix) == weight
-
-reservoir_constructor = SimpleCycleReservoir(res_size, weight = weight)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-@test maximum(reservoir_matrix) == weight
-
-#testing CycleJumpsReservoir implicit and esplicit constructors
-reservoir_constructor = CycleJumpsReservoir(res_size, weight, weight, jump_size)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-@test maximum(reservoir_matrix) == weight
-
-reservoir_constructor = CycleJumpsReservoir(res_size, cycle_weight = weight,
-    jump_weight = weight, jump_size = jump_size)
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
-@test maximum(reservoir_matrix) == weight
-
-#testing NullReservoir constructors
-reservoir_constructor = NullReservoir()
-reservoir_matrix = create_reservoir(reservoir_constructor, res_size)
-@test size(reservoir_matrix) == (res_size, res_size)
diff --git a/test/utils.jl b/test/utils.jl
new file mode 100644
index 00000000..9ef6f360
--- /dev/null
+++ b/test/utils.jl
@@ -0,0 +1,5 @@
+function check_radius(matrix, target_radius; tolerance=1e-5)
+    eigenvalues = eigvals(matrix)
+    spectral_radius = maximum(abs.(eigenvalues))
+    return isapprox(spectral_radius, target_radius, atol=tolerance)
+end
\ No newline at end of file

From 3a5a62202623a14f2cac040f83bf357cfdbc55e9 Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco <martinuzzi.francesco@gmail.com>
Date: Sat, 20 Jan 2024 18:10:11 +0100
Subject: [PATCH 4/6] start of input_layers, start of streamline to new api

---
 Project.toml                |   1 +
 src/ReservoirComputing.jl   |  27 ++-
 src/esn/echostatenetwork.jl |  37 ++--
 src/esn/esn_input_layers.jl | 381 ++----------------------------------
 src/esn/esn_reservoirs.jl   |  29 +--
 test/esn/test_reservoirs.jl |   7 +-
 6 files changed, 71 insertions(+), 411 deletions(-)

diff --git a/Project.toml b/Project.toml
index fb57c37a..e0d6bb2b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,6 +17,7 @@ PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 
 [compat]
 Adapt = "3.3.3, 4"
diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl
index c826b3f3..d1e47609 100644
--- a/src/ReservoirComputing.jl
+++ b/src/ReservoirComputing.jl
@@ -13,13 +13,13 @@ using PartialFunctions
 using Random
 using SparseArrays
 using Statistics
+using WeightInitializers
 
 export NLADefault, NLAT1, NLAT2, NLAT3
 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
 export StandardRidge, LinearModel
 export AbstractLayer, create_layer
-export WeightedLayer, DenseLayer, SparseLayer, MinimumLayer, InformedLayer, NullLayer
-export BernoulliSample, IrrationalSample
+export scaled_rand
 export rand_sparse, delay_line
 export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
 export ESN, Default, Hybrid, train
@@ -72,6 +72,29 @@ function Predictive(prediction_data)
     Predictive(prediction_data, prediction_len)
 end
 
+#fallbacks for initializers
+for initializer in (:rand_sparse, :delay_line, :scaled_rand)
+    NType = ifelse(initializer === :rand_sparse, Real, Number)
+    @eval function ($initializer)(dims::Integer...; kwargs...)
+        return $initializer(_default_rng(), Float32, dims...; kwargs...)
+    end
+    @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
+        return $initializer(rng, Float32, dims...; kwargs...)
+    end
+    @eval function ($initializer)(::Type{T},
+            dims::Integer...; kwargs...) where {T <: $NType}
+        return $initializer(_default_rng(), T, dims...; kwargs...)
+    end
+    @eval function ($initializer)(rng::AbstractRNG; kwargs...)
+        return __partial_apply($initializer, (rng, (; kwargs...)))
+    end
+    @eval function ($initializer)(rng::AbstractRNG,
+            ::Type{T}; kwargs...) where {T <: $NType}
+        return __partial_apply($initializer, ((rng, T), (; kwargs...)))
+    end
+    @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...))
+end
+
 #general
 include("states.jl")
 include("predict.jl")
diff --git a/src/esn/echostatenetwork.jl b/src/esn/echostatenetwork.jl
index 42fab481..adbf85a4 100644
--- a/src/esn/echostatenetwork.jl
+++ b/src/esn/echostatenetwork.jl
@@ -90,33 +90,30 @@ train_data = rand(10, 100)  # 10 features, 100 time steps
 esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10)
 ```
 """
-function ESN(train_data;
-        variation = Default(),
-        input_layer = DenseLayer(),
-        reservoir = RandSparseReservoir(100),
-        bias = NullLayer(),
-        reservoir_driver = RNN(),
-        nla_type = NLADefault(),
-        states_type = StandardStates(),
-        washout = 0,
-        matrix_type = typeof(train_data))
-    if variation isa Hybrid
-        train_data = vcat(train_data, variation.model_data[:, 1:(end - 1)])
-    end
+function ESN(
+    train_data,
+    in_size,
+    res_size;
+    input_layer = scaled_rand,
+    reservoir = rand_sparse,
+    bias = zeros64,
+    reservoir_driver = RNN(),
+    nla_type = NLADefault(),
+    states_type = StandardStates(),
+    washout = 0,
+    rng = _default_rng(),
+    matrix_type = typeof(train_data)
+) where {T <: Number}
 
     if states_type isa AbstractPaddedStates
         in_size = size(train_data, 1) + 1
         train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
             train_data)
-    else
-        in_size = size(train_data, 1)
     end
 
-    input_matrix, reservoir_matrix, bias_vector, res_size = obtain_layers(in_size,
-        input_layer,
-        reservoir, bias;
-        matrix_type = matrix_type)
-
+    reservoir_matrix = reservoir(rng, T, res_size, res_size)
+    input_matrix = input_layer(rng, T, res_size, in_size)
+    bias_vector = bias(rng, T, res_size)
     inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size)
     states = create_states(inner_res_driver, train_data, washout, reservoir_matrix,
         input_matrix, bias_vector)
diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl
index e7bb950c..42b82282 100644
--- a/src/esn/esn_input_layers.jl
+++ b/src/esn/esn_input_layers.jl
@@ -1,371 +1,32 @@
-abstract type AbstractLayer end
-
-struct WeightedLayer{T} <: AbstractLayer
-    scaling::T
-end
-
-"""
-    WeightedInput(scaling)
-    WeightedInput(;scaling=0.1)
-
-Creates a `WeightedInput` layer initializer for Echo State Networks.
-This initializer generates a weighted input matrix with random non-zero
-elements distributed uniformly within the range [-`scaling`, `scaling`],
-following the approach in [^Lu].
-
-# Parameters
-- `scaling`: The scaling factor for the weight distribution (default: 0.1).
-
-# Returns
-- A `WeightedInput` instance to be used for initializing the input layer of an ESN.
-
-Reference:
-[^Lu]: Lu, Zhixin, et al.
-    "Reservoir observers: Model-free inference of unmeasured variables in chaotic systems."
-    Chaos: An Interdisciplinary Journal of Nonlinear Science 27.4 (2017): 041102.
-"""
-function WeightedLayer(; scaling = 0.1)
-    return WeightedLayer(scaling)
-end
-
-function create_layer(input_layer::WeightedLayer,
-        approx_res_size,
-        in_size;
-        matrix_type = Matrix{Float64})
-    scaling = input_layer.scaling
-    res_size = Int(floor(approx_res_size / in_size) * in_size)
-    layer_matrix = zeros(res_size, in_size)
-    q = floor(Int, res_size / in_size)
-
-    for i in 1:in_size
-        layer_matrix[((i - 1) * q + 1):((i) * q), i] = rand(Uniform(-scaling, scaling), 1,
-            q)
-    end
-
-    return Adapt.adapt(matrix_type, layer_matrix)
-end
-
-function create_layer(layer, args...; kwargs...)
-    return layer
-end
-
-"""
-    DenseLayer(scaling)
-    DenseLayer(;scaling=0.1)
-
-Creates a `DenseLayer` initializer for Echo State Networks, generating a fully connected input layer.
-The layer is initialized with random weights uniformly distributed within [-`scaling`, `scaling`].
-This scaling factor can be provided either as an argument or a keyword argument.
-The `DenseLayer` is the default input layer in `ESN` construction.
-
-# Parameters
-- `scaling`: The scaling factor for weight distribution (default: 0.1).
-
-# Returns
-- A `DenseLayer` instance for initializing the ESN's input layer.
-"""
-struct DenseLayer{T} <: AbstractLayer
-    scaling::T
-end
-
-function DenseLayer(; scaling = 0.1)
-    return DenseLayer(scaling)
-end
-
-"""
-    create_layer(input_layer::AbstractLayer, res_size, in_size)
-
-Generates a matrix layer of size `res_size` x `in_size`, constructed according to the specifications of the `input_layer`.
-
-# Parameters
-- `input_layer`: An instance of `AbstractLayer` determining the layer construction.
-- `res_size`: The number of rows (reservoir size) for the layer.
-- `in_size`: The number of columns (input size) for the layer.
-
-# Returns
-- A matrix representing the constructed layer.
-"""
-function create_layer(input_layer::DenseLayer,
-        res_size,
-        in_size;
-        matrix_type = Matrix{Float64})
-    scaling = input_layer.scaling
-    layer_matrix = rand(Uniform(-scaling, scaling), res_size, in_size)
-    return Adapt.adapt(matrix_type, layer_matrix)
-end
-
-"""
-    SparseLayer(scaling, sparsity)
-    SparseLayer(scaling; sparsity=0.1)
-    SparseLayer(;scaling=0.1, sparsity=0.1)
-
-Creates a `SparseLayer` initializer for Echo State Networks, generating a sparse input layer.
-The layer is initialized with weights distributed within [-`scaling`, `scaling`]
-and a specified `sparsity` level. Both `scaling` and `sparsity` can be set as arguments or keyword arguments.
-
-# Parameters
-- `scaling`: Scaling factor for weight distribution (default: 0.1).
-- `sparsity`: Sparsity level of the layer (default: 0.1).
-
-# Returns
-- A `SparseLayer` instance for initializing ESN's input layer with sparse connections.
-"""
-struct SparseLayer{T} <: AbstractLayer
-    scaling::T
-    sparsity::T
-end
-
-function SparseLayer(; scaling = 0.1, sparsity = 0.1)
-    return SparseLayer(scaling, sparsity)
-end
-
-function SparseLayer(scaling_arg; scaling = scaling_arg, sparsity = 0.1)
-    return SparseLayer(scaling, sparsity)
-end
-
-function create_layer(input_layer::SparseLayer,
-        res_size,
-        in_size;
-        matrix_type = Matrix{Float64})
-    layer_matrix = Matrix(sprand(res_size, in_size, input_layer.sparsity))
-    layer_matrix = 2.0 .* (layer_matrix .- 0.5)
-    replace!(layer_matrix, -1.0 => 0.0)
-    layer_matrix = input_layer.scaling .* layer_matrix
-    return Adapt.adapt(matrix_type, layer_matrix)
-end
-
-#from "minimum complexity echo state network" Rodan
-#and "simple deterministically constructed cycle reservoirs with regular jumps"
-#by Rodan and Tino
-struct BernoulliSample{T}
-    p::T
-end
-
-"""
-    BernoulliSample(p)
-    BernoulliSample(;p=0.5)
-
-Creates a `BernoulliSample` constructor for the `MinimumLayer`.
-It uses a Bernoulli distribution to determine the sign of weights in the input layer.
-The parameter `p` sets the probability of a weight being positive, as per the `Distributions` package.
-This method of sign weight determination for input layers is based on the approach in [^Rodan].
-
-# Parameters
-- `p`: Probability of a positive weight (default: 0.5).
-
-# Returns
-- A `BernoulliSample` instance for generating sign weights in `MinimumLayer`.
-
-Reference:
-[^Rodan]: Rodan, Ali, and Peter Tino.
-    "Minimum complexity echo state network." 
-    IEEE Transactions on Neural Networks 22.1 (2010): 131-144.
-"""
-function BernoulliSample(; p = 0.5)
-    return BernoulliSample(p)
-end
-
-struct IrrationalSample{K}
-    irrational::Irrational
-    start::K
-end
-
-"""
-    IrrationalSample(irrational, start)
-    IrrationalSample(;irrational=pi, start=1)
-
-Creates an `IrrationalSample` constructor for the `MinimumLayer`.
-It determines the sign of weights in the input layer based on the decimal expansion of an `irrational` number.
-The `start` parameter sets the starting point in the decimal sequence.
-The signs are assigned based on the thresholding of each decimal digit against 4.5, as described in [^Rodan].
-
-# Parameters
-- `irrational`: An irrational number for weight sign determination (default: π).
-- `start`: Starting index in the decimal expansion (default: 1).
-
-# Returns
-- An `IrrationalSample` instance for generating sign weights in `MinimumLayer`.
-
-Reference:
-[^Rodan]: Rodan, Ali, and Peter Tiňo.
-    "Simple deterministically constructed cycle reservoirs with regular jumps."
-    Neural Computation 24.7 (2012): 1822-1852.
-"""
-function IrrationalSample(; irrational = pi, start = 1)
-    return IrrationalSample(irrational, start)
-end
-
-struct MinimumLayer{T, K} <: AbstractLayer
-    weight::T
-    sampling::K
-end
-
-"""
-    MinimumLayer(weight, sampling)
-    MinimumLayer(weight; sampling=BernoulliSample(0.5))
-    MinimumLayer(;weight=0.1, sampling=BernoulliSample(0.5))
-
-Creates a `MinimumLayer` initializer for Echo State Networks, generating a fully connected input layer.
-This layer has a uniform absolute weight value (`weight`) with the sign of each
-weight determined by the `sampling` method. This approach, as detailed in [^Rodan1] and [^Rodan2],
-allows for controlled weight distribution in the layer.
-
-# Parameters
-- `weight`: Absolute value of weights in the layer.
-- `sampling`: Method for determining the sign of weights (default: `BernoulliSample(0.5)`).
-
-# Returns
-- A `MinimumLayer` instance for initializing the ESN's input layer.
-
-References:
-[^Rodan1]: Rodan, Ali, and Peter Tino.
-    "Minimum complexity echo state network."
-    IEEE Transactions on Neural Networks 22.1 (2010): 131-144.
-[^Rodan2]: Rodan, Ali, and Peter Tiňo.
-    "Simple deterministically constructed cycle reservoirs with regular jumps."
-    Neural Computation 24.7 (2012): 1822-1852.
-"""
-function MinimumLayer(weight; sampling = BernoulliSample(0.5))
-    return MinimumLayer(weight, sampling)
-end
-
-function MinimumLayer(; weight = 0.1, sampling = BernoulliSample(0.5))
-    return MinimumLayer(weight, sampling)
-end
-
-function create_layer(input_layer::MinimumLayer,
-        res_size,
-        in_size;
-        matrix_type = Matrix{Float64})
-    sampling = input_layer.sampling
-    weight = input_layer.weight
-    layer_matrix = create_minimum_input(sampling, res_size, in_size, weight)
-    return Adapt.adapt(matrix_type, layer_matrix)
-end
-
-function create_minimum_input(sampling::BernoulliSample, res_size, in_size, weight)
-    p = sampling.p
-    input_matrix = zeros(res_size, in_size)
-    for i in 1:res_size
-        for j in 1:in_size
-            rand(Bernoulli(p)) ? input_matrix[i, j] = weight : input_matrix[i, j] = -weight
-        end
-    end
-
-    return input_matrix
-end
-
-function create_minimum_input(sampling::IrrationalSample, res_size, in_size, weight)
-    setprecision(BigFloat, Int(ceil(log2(10) * (res_size * in_size + sampling.start + 1))))
-    ir_string = string(BigFloat(sampling.irrational)) |> collect
-    deleteat!(ir_string, findall(x -> x == '.', ir_string))
-    ir_array = zeros(length(ir_string))
-    input_matrix = zeros(res_size, in_size)
-
-    for i in 1:length(ir_string)
-        ir_array[i] = parse(Int, ir_string[i])
-    end
-
-    co = sampling.start
-    counter = 1
-
-    for i in 1:res_size
-        for j in 1:in_size
-            ir_array[counter] < 5 ? input_matrix[i, j] = -weight :
-            input_matrix[i, j] = weight
-            counter += 1
-        end
-    end
-
-    return input_matrix
-end
-
-struct InformedLayer{T, K, M} <: AbstractLayer
-    scaling::T
-    gamma::K
-    model_in_size::M
-end
-
 """
-    InformedLayer(model_in_size; scaling=0.1, gamma=0.5)
+    scaled_rand(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1)) where {T <: Number}
 
-Creates an `InformedLayer` initializer for Echo State Networks (ESNs) that generates
-a weighted input layer matrix. The matrix contains random non-zero elements drawn from
-the range [-```scaling```, ```scaling```]. This initializer ensures that a fraction (`gamma`)
-of reservoir nodes are exclusively connected to the raw inputs, while the rest are
-connected to the outputs of a prior knowledge model, as described in [^Pathak].
+Create and return a matrix with random values, uniformly distributed within a range defined by `scaling`. This function is useful for initializing matrices, such as the layers of a neural network, with scaled random values.
 
 # Arguments
-- `model_in_size`: The size of the prior knowledge model's output,
-    which determines the number of columns in the input layer matrix.
-
-# Keyword Arguments
-- `scaling`: The absolute value of the weights (default: 0.1).
-- `gamma`: The fraction of reservoir nodes connected exclusively to raw inputs (default: 0.5).
+- `rng`: An instance of `AbstractRNG` for random number generation.
+- `T`: The data type for the elements of the matrix.
+- `dims`: Dimensions of the matrix. It must be a 2-element tuple specifying the number of rows and columns (e.g., `(res_size, in_size)`).
+- `scaling`: A scaling factor to define the range of the uniform distribution. The matrix elements will be randomly chosen from the range `[-scaling, scaling]`. Defaults to `T(0.1)`.
 
 # Returns
-- An `InformedLayer` instance for initializing the ESN's input layer matrix.
-
-Reference:
-[^Pathak]: Jaideep Pathak et al. 
-    "Hybrid Forecasting of Chaotic Processes: Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018).
-"""
-function InformedLayer(model_in_size; scaling = 0.1, gamma = 0.5)
-    return InformedLayer(scaling, gamma, model_in_size)
-end
-
-function create_layer(input_layer::InformedLayer,
-        res_size,
-        in_size;
-        matrix_type = Matrix{Float64})
-    scaling = input_layer.scaling
-    state_size = in_size - input_layer.model_in_size
-
-    if state_size <= 0
-        throw(DimensionMismatch("in_size must be greater than model_in_size"))
-    end
-
-    input_matrix = zeros(res_size, in_size)
-    #Vector used to find res nodes not yet connected
-    zero_connections = zeros(in_size)
-    #Num of res nodes allotted for raw states
-    num_for_state = floor(Int, res_size * input_layer.gamma)
-    #Num of res nodes allotted for prior model input
-    num_for_model = floor(Int, (res_size * (1 - input_layer.gamma)))
-
-    for i in 1:num_for_state
-        #find res nodes with no connections
-        idxs = findall(Bool[zero_connections == input_matrix[i, :]
-                            for i in 1:size(input_matrix, 1)])
-        random_row_idx = idxs[rand(1:end)]
-        random_clm_idx = range(1, state_size, step = 1)[rand(1:end)]
-        input_matrix[random_row_idx, random_clm_idx] = rand(Uniform(-scaling, scaling))
-    end
-
-    for i in 1:num_for_model
-        idxs = findall(Bool[zero_connections == input_matrix[i, :]
-                            for i in 1:size(input_matrix, 1)])
-        random_row_idx = idxs[rand(1:end)]
-        random_clm_idx = range(state_size + 1, in_size, step = 1)[rand(1:end)]
-        input_matrix[random_row_idx, random_clm_idx] = rand(Uniform(-scaling, scaling))
-    end
-
-    return Adapt.adapt(matrix_type, input_matrix)
-end
+A matrix of type with dimensions specified by `dims`. Each element of the matrix is a random number uniformly distributed between `-scaling` and `scaling`.
 
+# Example
+```julia
+rng = Random.default_rng()
+matrix = scaled_rand(rng, Float64, (100, 50); scaling=0.2)
 """
-    NullLayer()
+function scaled_rand(
+    rng::AbstractRNG,
+    ::Type{T},
+    dims::Integer...;
+    scaling=T(0.1)
+) where {T <: Number}
 
-Creates a `NullLayer` initializer for Echo State Networks (ESNs) that generates a vector of zeros.
-
-# Returns
-- A `NullLayer` instance for initializing the ESN's input layer matrix.
-"""
-struct NullLayer <: AbstractLayer end
+    @assert length(dims) == 2, "The dimensions must define a matrix (e.g., (res_size, in_size))"
 
-function create_layer(input_layer::NullLayer,
-        res_size,
-        in_size;
-        matrix_type = Matrix{Float64})
-    return Adapt.adapt(matrix_type, zeros(res_size, in_size))
+    res_size, in_size = dims
+    layer_matrix = rand(rng, Uniform(-scaling, scaling), res_size, in_size)
+    return layer_matrix
 end
diff --git a/src/esn/esn_reservoirs.jl b/src/esn/esn_reservoirs.jl
index fa57109b..ab2eaf42 100644
--- a/src/esn/esn_reservoirs.jl
+++ b/src/esn/esn_reservoirs.jl
@@ -1,7 +1,3 @@
-function get_ressize(reservoir)
-    return size(reservoir, 1)
-end
-
 """
     rand_sparse(rng::AbstractRNG, ::Type{T}, dims::Integer...; radius=1.0, sparsity=0.1)
 
@@ -64,8 +60,7 @@ function delay_line(rng::AbstractRNG,
         dims::Integer...;
         weight = T(0.1)) where {T <: Number}
     reservoir_matrix = zeros(T, dims...)
-    @assert length(dims) == 2 && dims[1] == dims[2],
-    "The dimensions must define a square matrix (e.g., (100, 100))"
+    @assert length(dims) == 2 && dims[1] == dims[2] "The dimensions must define a square matrix (e.g., (100, 100))"
 
     for i in 1:(dims[1] - 1)
         reservoir_matrix[i + 1, i] = weight
@@ -74,28 +69,6 @@ function delay_line(rng::AbstractRNG,
     return reservoir_matrix
 end
 
-for initializer in (:rand_sparse, :delay_line)
-    NType = ifelse(initializer === :rand_sparse, Real, Number)
-    @eval function ($initializer)(dims::Integer...; kwargs...)
-        return $initializer(_default_rng(), Float32, dims...; kwargs...)
-    end
-    @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
-        return $initializer(rng, Float32, dims...; kwargs...)
-    end
-    @eval function ($initializer)(::Type{T},
-            dims::Integer...; kwargs...) where {T <: $NType}
-        return $initializer(_default_rng(), T, dims...; kwargs...)
-    end
-    @eval function ($initializer)(rng::AbstractRNG; kwargs...)
-        return __partial_apply($initializer, (rng, (; kwargs...)))
-    end
-    @eval function ($initializer)(rng::AbstractRNG,
-            ::Type{T}; kwargs...) where {T <: $NType}
-        return __partial_apply($initializer, ((rng, T), (; kwargs...)))
-    end
-    @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...))
-end
-
 # from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package
 function _default_rng()
     @static if VERSION >= v"1.7"
diff --git a/test/esn/test_reservoirs.jl b/test/esn/test_reservoirs.jl
index df58bd3c..debd9be0 100644
--- a/test/esn/test_reservoirs.jl
+++ b/test/esn/test_reservoirs.jl
@@ -11,7 +11,7 @@ const jump_size = 3
 const rng = Random.default_rng()
 
 dtypes = [Float16, Float32, Float64]
-reservoir_inits = [rand_sparse]
+reservoir_inits = [rand_sparse, delay_line]
 
 @testset "Sizes and types" begin
     for init in reservoir_inits
@@ -34,3 +34,8 @@ end
     @test check_radius(sp, radius)
 end
 
+@testset "delay_line" begin
+    dl = delay_line(res_size, res_size)
+    @test unique(dl) == Float32.([0.0, 0.1])
+end
+

From ab3337cad9bdd669e35d52e57bb3933bc088cabb Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco <martinuzzi.francesco@gmail.com>
Date: Sun, 21 Jan 2024 16:41:52 +0100
Subject: [PATCH 5/6] made ESN work with new initilaizers, started separation
 of different models

---
 README.md                   |   9 +-
 src/ReservoirComputing.jl   |  11 +-
 src/esn/deepesn.jl          |  84 ++++++++++++
 src/esn/echostatenetwork.jl | 266 ------------------------------------
 src/esn/esn.jl              | 143 +++++++++++++++++++
 src/esn/esn_input_layers.jl |  40 +++++-
 src/esn/esn_predict.jl      |  35 +----
 src/esn/hybridesn.jl        |  82 +++++++++++
 8 files changed, 366 insertions(+), 304 deletions(-)
 create mode 100644 src/esn/deepesn.jl
 delete mode 100644 src/esn/echostatenetwork.jl
 create mode 100644 src/esn/esn.jl
 create mode 100644 src/esn/hybridesn.jl

diff --git a/README.md b/README.md
index 9172725e..60f12963 100644
--- a/README.md
+++ b/README.md
@@ -51,14 +51,15 @@ test = data[:, (shift + train_len):(shift + train_len + predict_len - 1)]
 Now that we have the data we can initialize the ESN with the chosen parameters. Given that this is a quick example we are going to change the least amount of possible parameters. For more detailed examples and explanations of the functions please refer to the documentation.
 
 ```julia
+input_size = 3
 res_size = 300
-esn = ESN(input_data;
-    reservoir = RandSparseReservoir(res_size, radius = 1.2, sparsity = 6 / res_size),
-    input_layer = WeightedLayer(),
+esn = ESN(input_data, input_size, res_size;
+    reservoir = rand_sparse(;radius = 1.2, sparsity = 6 / res_size),
+    input_layer = weighted_init,
     nla_type = NLAT2())
 ```
 
-The echo state network can now be trained and tested. If not specified, the training will always be Ordinary Least Squares regression. The full range of training methods is detailed in the documentation.
+The echo state network can now be trained and tested. If not specified, the training will always be ordinary least squares regression. The full range of training methods is detailed in the documentation.
 
 ```julia
 output_layer = train(esn, target_data)
diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl
index d1e47609..f8668b54 100644
--- a/src/ReservoirComputing.jl
+++ b/src/ReservoirComputing.jl
@@ -19,10 +19,11 @@ export NLADefault, NLAT1, NLAT2, NLAT3
 export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
 export StandardRidge, LinearModel
 export AbstractLayer, create_layer
-export scaled_rand
+export scaled_rand, weighted_init
 export rand_sparse, delay_line
 export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
-export ESN, Default, Hybrid, train
+export ESN, train
+export DeepESN, HybridESN
 export RECA, train
 export RandomMapping, RandomMaps
 export Generative, Predictive, OutputLayer
@@ -73,7 +74,7 @@ function Predictive(prediction_data)
 end
 
 #fallbacks for initializers
-for initializer in (:rand_sparse, :delay_line, :scaled_rand)
+for initializer in (:rand_sparse, :delay_line, :scaled_rand, :weighted_init)
     NType = ifelse(initializer === :rand_sparse, Real, Number)
     @eval function ($initializer)(dims::Integer...; kwargs...)
         return $initializer(_default_rng(), Float32, dims...; kwargs...)
@@ -107,7 +108,9 @@ include("train/supportvector_regression.jl")
 include("esn/esn_input_layers.jl")
 include("esn/esn_reservoirs.jl")
 include("esn/esn_reservoir_drivers.jl")
-include("esn/echostatenetwork.jl")
+include("esn/esn.jl")
+include("esn/deepesn.jl")
+include("esn/hybridesn.jl")
 include("esn/esn_predict.jl")
 
 #reca
diff --git a/src/esn/deepesn.jl b/src/esn/deepesn.jl
new file mode 100644
index 00000000..4ab05f39
--- /dev/null
+++ b/src/esn/deepesn.jl
@@ -0,0 +1,84 @@
+struct DeepESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork
+    res_size::I
+    train_data::S
+    variation::V
+    nla_type::N
+    input_matrix::T
+    reservoir_driver::O
+    reservoir_matrix::M
+    bias_vector::B
+    states_type::ST
+    washout::W
+    states::IS
+end
+
+function DeepESN(
+    train_data,
+    in_size::Int,
+    res_size::AbstractArray;
+    input_layer = scaled_rand,
+    reservoir = rand_sparse,
+    bias = zeros64,
+    reservoir_driver = RNN(),
+    nla_type = NLADefault(),
+    states_type = StandardStates(),
+    washout = 0,
+    rng = _default_rng(),
+    T=Float64,
+    matrix_type = typeof(train_data)
+)
+
+    if states_type isa AbstractPaddedStates
+        in_size = size(train_data, 1) + 1
+        train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
+            train_data)
+    end
+
+    reservoir_matrix = reservoir(rng, T, res_size, res_size)
+    input_matrix = input_layer(rng, T, res_size, in_size)
+    bias_vector = bias(rng, T, res_size)
+    inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size)
+    states = create_states(inner_res_driver, train_data, washout, reservoir_matrix,
+        input_matrix, bias_vector)
+    train_data = train_data[:, (washout + 1):end]
+
+    ESN(sum(res_size), train_data, variation, nla_type, input_matrix,
+        inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
+        states)
+end
+
+function obtain_layers(in_size,
+    input_layer,
+    reservoir::Vector,
+    bias;
+    matrix_type = Matrix{Float64})
+esn_depth = length(reservoir)
+input_res_sizes = [get_ressize(reservoir[i]) for i in 1:esn_depth]
+in_sizes = zeros(Int, esn_depth)
+in_sizes[2:end] = input_res_sizes[1:(end - 1)]
+in_sizes[1] = in_size
+
+if input_layer isa Array
+    input_matrix = [create_layer(input_layer[j], input_res_sizes[j], in_sizes[j],
+        matrix_type = matrix_type) for j in 1:esn_depth]
+else
+    _input_layer = fill(input_layer, esn_depth)
+    input_matrix = [create_layer(_input_layer[k], input_res_sizes[k], in_sizes[k],
+        matrix_type = matrix_type) for k in 1:esn_depth]
+end
+
+res_sizes = [get_ressize(input_matrix[j]) for j in 1:esn_depth]
+reservoir_matrix = [create_reservoir(reservoir[k], res_sizes[k],
+    matrix_type = matrix_type) for k in 1:esn_depth]
+
+if bias isa Array
+    bias_vector = [create_layer(bias[j], res_sizes[j], 1, matrix_type = matrix_type)
+                   for j in 1:esn_depth]
+else
+    _bias = fill(bias, esn_depth)
+    bias_vector = [create_layer(_bias[k], res_sizes[k], 1, matrix_type = matrix_type)
+                   for k in 1:esn_depth]
+end
+
+return input_matrix, reservoir_matrix, bias_vector, res_sizes
+end
\ No newline at end of file
diff --git a/src/esn/echostatenetwork.jl b/src/esn/echostatenetwork.jl
deleted file mode 100644
index adbf85a4..00000000
--- a/src/esn/echostatenetwork.jl
+++ /dev/null
@@ -1,266 +0,0 @@
-abstract type AbstractEchoStateNetwork <: AbstractReservoirComputer end
-struct ESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork
-    res_size::I
-    train_data::S
-    variation::V
-    nla_type::N
-    input_matrix::T
-    reservoir_driver::O
-    reservoir_matrix::M
-    bias_vector::B
-    states_type::ST
-    washout::W
-    states::IS
-end
-
-"""
-    Default()
-
-The `Default` struct specifies the use of the standard model in Echo State Networks (ESNs).
-It requires no parameters and is used when no specific variations or customizations of the ESN model are needed.
-This struct is ideal for straightforward applications where the default ESN settings are sufficient.
-"""
-struct Default <: AbstractVariation end
-struct Hybrid{T, K, O, I, S, D} <: AbstractVariation
-    prior_model::T
-    u0::K
-    tspan::O
-    dt::I
-    datasize::S
-    model_data::D
-end
-
-"""
-    Hybrid(prior_model, u0, tspan, datasize)
-
-Constructs a `Hybrid` variation of Echo State Networks (ESNs) integrating a knowledge-based model
-(`prior_model`) with ESNs for advanced training and prediction in chaotic systems. 
-
-# Parameters
-- `prior_model`: A knowledge-based model function for integration with ESNs.
-- `u0`: Initial conditions for the model.
-- `tspan`: Time span as a tuple, indicating the duration for model operation.
-- `datasize`: The size of the data to be processed.
-
-# Returns
-- A `Hybrid` struct instance representing the combined ESN and knowledge-based model.
-
-This method is effective for chaotic processes as highlighted in [^Pathak].
-
-Reference:
-[^Pathak]: Jaideep Pathak et al.
-    "Hybrid Forecasting of Chaotic Processes:
-    Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018).
-"""
-function Hybrid(prior_model, u0, tspan, datasize)
-    trange = collect(range(tspan[1], tspan[2], length = datasize))
-    dt = trange[2] - trange[1]
-    tsteps = push!(trange, dt + trange[end])
-    tspan_new = (tspan[1], dt + tspan[2])
-    model_data = prior_model(u0, tspan_new, tsteps)
-    return Hybrid(prior_model, u0, tspan, dt, datasize, model_data)
-end
-
-"""
-    ESN(train_data; kwargs...) -> ESN
-
-Creates an Echo State Network (ESN) using specified parameters and training data, suitable for various machine learning tasks.
-
-# Parameters
-- `train_data`: Matrix of training data (columns as time steps, rows as features).
-- `variation`: Variation of ESN (default: `Default()`).
-- `input_layer`: Input layer of ESN (default: `DenseLayer()`).
-- `reservoir`: Reservoir of the ESN (default: `RandSparseReservoir(100)`).
-- `bias`: Bias vector for each time step (default: `NullLayer()`).
-- `reservoir_driver`: Mechanism for evolving reservoir states (default: `RNN()`).
-- `nla_type`: Non-linear activation type (default: `NLADefault()`).
-- `states_type`: Format for storing states (default: `StandardStates()`).
-- `washout`: Initial time steps to discard (default: `0`).
-- `matrix_type`: Type of matrices used internally (default: type of `train_data`).
-
-# Returns
-- An initialized ESN instance with specified parameters.
-
-# Examples
-```julia
-using ReservoirComputing
-
-train_data = rand(10, 100)  # 10 features, 100 time steps
-
-esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10)
-```
-"""
-function ESN(
-    train_data,
-    in_size,
-    res_size;
-    input_layer = scaled_rand,
-    reservoir = rand_sparse,
-    bias = zeros64,
-    reservoir_driver = RNN(),
-    nla_type = NLADefault(),
-    states_type = StandardStates(),
-    washout = 0,
-    rng = _default_rng(),
-    matrix_type = typeof(train_data)
-) where {T <: Number}
-
-    if states_type isa AbstractPaddedStates
-        in_size = size(train_data, 1) + 1
-        train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
-            train_data)
-    end
-
-    reservoir_matrix = reservoir(rng, T, res_size, res_size)
-    input_matrix = input_layer(rng, T, res_size, in_size)
-    bias_vector = bias(rng, T, res_size)
-    inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size)
-    states = create_states(inner_res_driver, train_data, washout, reservoir_matrix,
-        input_matrix, bias_vector)
-    train_data = train_data[:, (washout + 1):end]
-
-    ESN(sum(res_size), train_data, variation, nla_type, input_matrix,
-        inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
-        states)
-end
-
-#shallow esn construction
-function obtain_layers(in_size,
-        input_layer,
-        reservoir,
-        bias;
-        matrix_type = Matrix{Float64})
-    input_res_size = get_ressize(reservoir)
-    input_matrix = create_layer(input_layer, input_res_size, in_size,
-        matrix_type = matrix_type)
-    res_size = size(input_matrix, 1) #WeightedInput actually changes the res size
-    reservoir_matrix = create_reservoir(reservoir, res_size, matrix_type = matrix_type)
-    @assert size(reservoir_matrix, 1) == res_size
-    bias_vector = create_layer(bias, res_size, 1, matrix_type = matrix_type)
-    return input_matrix, reservoir_matrix, bias_vector, res_size
-end
-
-#deep esn construction
-#there is a bug going on with WeightedLayer in this construction.
-#it works for eny other though
-function obtain_layers(in_size,
-        input_layer,
-        reservoir::Vector,
-        bias;
-        matrix_type = Matrix{Float64})
-    esn_depth = length(reservoir)
-    input_res_sizes = [get_ressize(reservoir[i]) for i in 1:esn_depth]
-    in_sizes = zeros(Int, esn_depth)
-    in_sizes[2:end] = input_res_sizes[1:(end - 1)]
-    in_sizes[1] = in_size
-
-    if input_layer isa Array
-        input_matrix = [create_layer(input_layer[j], input_res_sizes[j], in_sizes[j],
-            matrix_type = matrix_type) for j in 1:esn_depth]
-    else
-        _input_layer = fill(input_layer, esn_depth)
-        input_matrix = [create_layer(_input_layer[k], input_res_sizes[k], in_sizes[k],
-            matrix_type = matrix_type) for k in 1:esn_depth]
-    end
-
-    res_sizes = [get_ressize(input_matrix[j]) for j in 1:esn_depth]
-    reservoir_matrix = [create_reservoir(reservoir[k], res_sizes[k],
-        matrix_type = matrix_type) for k in 1:esn_depth]
-
-    if bias isa Array
-        bias_vector = [create_layer(bias[j], res_sizes[j], 1, matrix_type = matrix_type)
-                       for j in 1:esn_depth]
-    else
-        _bias = fill(bias, esn_depth)
-        bias_vector = [create_layer(_bias[k], res_sizes[k], 1, matrix_type = matrix_type)
-                       for k in 1:esn_depth]
-    end
-
-    return input_matrix, reservoir_matrix, bias_vector, res_sizes
-end
-
-function (esn::ESN)(prediction::AbstractPrediction,
-        output_layer::AbstractOutputLayer;
-        last_state = esn.states[:, [end]],
-        kwargs...)
-    variation = esn.variation
-    pred_len = prediction.prediction_len
-
-    if variation isa Hybrid
-        model = variation.prior_model
-        predict_tsteps = [variation.tspan[2] + variation.dt]
-        [append!(predict_tsteps, predict_tsteps[end] + variation.dt) for i in 1:pred_len]
-        tspan_new = (variation.tspan[2] + variation.dt, predict_tsteps[end])
-        u0 = variation.model_data[:, end]
-        model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end]
-        return obtain_esn_prediction(esn, prediction, last_state, output_layer,
-            model_pred_data;
-            kwargs...)
-    else
-        return obtain_esn_prediction(esn, prediction, last_state, output_layer;
-            kwargs...)
-    end
-end
-
-#training dispatch on esn
-"""
-    train(esn::AbstractEchoStateNetwork, target_data, training_method = StandardRidge(0.0))
-
-Trains an Echo State Network (ESN) using the provided target data and a specified training method.
-
-# Parameters
-- `esn::AbstractEchoStateNetwork`: The ESN instance to be trained.
-- `target_data`: Supervised training data for the ESN.
-- `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`).
-
-# Returns
-- The trained ESN model. Its type and structure depend on `training_method` and the ESN's implementation.
-
-
-# Returns
-The trained ESN model. The exact type and structure of the return value depends on the
-`training_method` and the specific ESN implementation.
-
-```julia
-using ReservoirComputing
-
-# Initialize an ESN instance and target data
-esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10)
-target_data = rand(size(train_data, 2))
-
-# Train the ESN using the default training method
-trained_esn = train(esn, target_data)
-
-# Train the ESN using a custom training method
-trained_esn = train(esn, target_data, training_method=StandardRidge(1.0))
-```
-
-# Notes
-- When using a `Hybrid` variation, the function extends the state matrix with data from the
-    physical model included in the `variation`.
-- The training is handled by a lower-level `_train` function which takes the new state matrix
-    and performs the actual training using the specified `training_method`.
-"""
-function train(esn::AbstractEchoStateNetwork,
-        target_data,
-        training_method = StandardRidge(0.0))
-    variation = esn.variation
-
-    if esn.variation isa Hybrid
-        states = vcat(esn.states, esn.variation.model_data[:, 2:end])
-    else
-        states = esn.states
-    end
-    states_new = esn.states_type(esn.nla_type, states, esn.train_data[:, 1:end])
-
-    return _train(states_new, target_data, training_method)
-end
-
-function pad_esnstate(variation::Hybrid, states_type, x_pad, x, model_prediction_data)
-    x_tmp = vcat(x, model_prediction_data)
-    x_pad = pad_state!(states_type, x_pad, x_tmp)
-end
-
-function pad_esnstate!(variation, states_type, x_pad, x, args...)
-    x_pad = pad_state!(states_type, x_pad, x)
-end
diff --git a/src/esn/esn.jl b/src/esn/esn.jl
new file mode 100644
index 00000000..3592ed8d
--- /dev/null
+++ b/src/esn/esn.jl
@@ -0,0 +1,143 @@
+abstract type AbstractEchoStateNetwork <: AbstractReservoirComputer end
+struct ESN{I, S, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork
+    res_size::I
+    train_data::S
+    nla_type::N
+    input_matrix::T
+    reservoir_driver::O
+    reservoir_matrix::M
+    bias_vector::B
+    states_type::ST
+    washout::W
+    states::IS
+end
+
+"""
+    ESN(train_data; kwargs...) -> ESN
+
+Creates an Echo State Network (ESN) using specified parameters and training data, suitable for various machine learning tasks.
+
+# Parameters
+- `train_data`: Matrix of training data (columns as time steps, rows as features).
+- `variation`: Variation of ESN (default: `Default()`).
+- `input_layer`: Input layer of ESN (default: `DenseLayer()`).
+- `reservoir`: Reservoir of the ESN (default: `RandSparseReservoir(100)`).
+- `bias`: Bias vector for each time step (default: `NullLayer()`).
+- `reservoir_driver`: Mechanism for evolving reservoir states (default: `RNN()`).
+- `nla_type`: Non-linear activation type (default: `NLADefault()`).
+- `states_type`: Format for storing states (default: `StandardStates()`).
+- `washout`: Initial time steps to discard (default: `0`).
+- `matrix_type`: Type of matrices used internally (default: type of `train_data`).
+
+# Returns
+- An initialized ESN instance with specified parameters.
+
+# Examples
+```julia
+using ReservoirComputing
+
+train_data = rand(10, 100)  # 10 features, 100 time steps
+
+esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10)
+```
+"""
+function ESN(
+    train_data,
+    in_size::Int,
+    res_size::Int;
+    input_layer = scaled_rand,
+    reservoir = rand_sparse,
+    bias = zeros64,
+    reservoir_driver = RNN(),
+    nla_type = NLADefault(),
+    states_type = StandardStates(),
+    washout = 0,
+    rng = _default_rng(),
+    T = Float32,
+    matrix_type = typeof(train_data)
+)
+
+    if states_type isa AbstractPaddedStates
+        in_size = size(train_data, 1) + 1
+        train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
+            train_data)
+    end
+
+    reservoir_matrix = reservoir(rng, T, res_size, res_size)
+    input_matrix = input_layer(rng, T, in_size, res_size)
+    bias_vector = bias(rng, res_size)
+    inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size)
+    states = create_states(inner_res_driver, train_data, washout, reservoir_matrix,
+        input_matrix, bias_vector)
+    train_data = train_data[:, (washout + 1):end]
+
+    ESN(res_size, train_data, nla_type, input_matrix,
+        inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
+        states)
+end
+
+function (esn::ESN)(prediction::AbstractPrediction,
+        output_layer::AbstractOutputLayer;
+        last_state = esn.states[:, [end]],
+        kwargs...)
+    pred_len = prediction.prediction_len
+
+    return obtain_esn_prediction(esn, prediction, last_state, output_layer;
+        kwargs...)
+end
+
+#training dispatch on esn
+"""
+    train(esn::AbstractEchoStateNetwork, target_data, training_method = StandardRidge(0.0))
+
+Trains an Echo State Network (ESN) using the provided target data and a specified training method.
+
+# Parameters
+- `esn::AbstractEchoStateNetwork`: The ESN instance to be trained.
+- `target_data`: Supervised training data for the ESN.
+- `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`).
+
+# Returns
+- The trained ESN model. Its type and structure depend on `training_method` and the ESN's implementation.
+
+
+# Returns
+The trained ESN model. The exact type and structure of the return value depends on the
+`training_method` and the specific ESN implementation.
+
+```julia
+using ReservoirComputing
+
+# Initialize an ESN instance and target data
+esn = ESN(train_data, reservoir=RandSparseReservoir(200), washout=10)
+target_data = rand(size(train_data, 2))
+
+# Train the ESN using the default training method
+trained_esn = train(esn, target_data)
+
+# Train the ESN using a custom training method
+trained_esn = train(esn, target_data, training_method=StandardRidge(1.0))
+```
+
+# Notes
+- When using a `Hybrid` variation, the function extends the state matrix with data from the
+    physical model included in the `variation`.
+- The training is handled by a lower-level `_train` function which takes the new state matrix
+    and performs the actual training using the specified `training_method`.
+"""
+function train(esn::ESN,
+        target_data,
+        training_method = StandardRidge(0.0))
+    states_new = esn.states_type(esn.nla_type, esn.states, esn.train_data[:, 1:end])
+
+    return _train(states_new, target_data, training_method)
+end
+
+#function pad_esnstate(variation::Hybrid, states_type, x_pad, x, model_prediction_data)
+#    x_tmp = vcat(x, model_prediction_data)
+#    x_pad = pad_state!(states_type, x_pad, x_tmp)
+#end
+
+function pad_esnstate!(variation, states_type, x_pad, x, args...)
+    x_pad = pad_state!(states_type, x_pad, x)
+end
diff --git a/src/esn/esn_input_layers.jl b/src/esn/esn_input_layers.jl
index 42b82282..e79df1d5 100644
--- a/src/esn/esn_input_layers.jl
+++ b/src/esn/esn_input_layers.jl
@@ -24,9 +24,45 @@ function scaled_rand(
     scaling=T(0.1)
 ) where {T <: Number}
 
-    @assert length(dims) == 2, "The dimensions must define a matrix (e.g., (res_size, in_size))"
-
     res_size, in_size = dims
     layer_matrix = rand(rng, Uniform(-scaling, scaling), res_size, in_size)
     return layer_matrix
 end
+
+"""
+    weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1)) where {T <: Number}
+
+Create and return a matrix representing a weighted input layer for Echo State Networks (ESNs). This initializer generates a weighted input matrix with random non-zero elements distributed uniformly within the range [-`scaling`, `scaling`], inspired by the approach in [^Lu].
+
+# Arguments
+- `rng`: An instance of `AbstractRNG` for random number generation.
+- `T`: The data type for the elements of the matrix.
+- `dims`: A 2-element tuple specifying the approximate reservoir size and input size (e.g., `(approx_res_size, in_size)`).
+- `scaling`: The scaling factor for the weight distribution. Defaults to `T(0.1)`.
+
+# Returns
+A matrix representing the weighted input layer as defined in [^Lu2017]. The matrix dimensions will be adjusted to ensure each input unit connects to an equal number of reservoir units.
+
+# Example
+```julia
+rng = Random.default_rng()
+input_layer = weighted_init(rng, Float64, (3, 300); scaling=0.2)
+```
+# References
+[^Lu2017]: Lu, Zhixin, et al.
+    "Reservoir observers: Model-free inference of unmeasured variables in chaotic systems."
+    Chaos: An Interdisciplinary Journal of Nonlinear Science 27.4 (2017): 041102.
+"""
+function weighted_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1)) where {T <: Number}
+
+    in_size, approx_res_size = dims
+    res_size = Int(floor(approx_res_size / in_size) * in_size)
+    layer_matrix = zeros(T, res_size, in_size)
+    q = floor(Int, res_size / in_size)
+
+    for i in 1:in_size
+        layer_matrix[((i - 1) * q + 1):((i) * q), i] = rand(rng, Uniform(-scaling, scaling), q)
+    end
+
+    return layer_matrix
+end
diff --git a/src/esn/esn_predict.jl b/src/esn/esn_predict.jl
index daa6fc34..5955c762 100644
--- a/src/esn/esn_predict.jl
+++ b/src/esn/esn_predict.jl
@@ -13,7 +13,7 @@ function obtain_esn_prediction(esn,
     out = initial_conditions
     states = similar(esn.states, size(esn.states, 1), prediction_len)
 
-    out_pad = allocate_outpad(esn.variation, esn.states_type, out)
+    out_pad = allocate_outpad(esn, esn.states_type, out)
     tmp_array = allocate_tmp(esn.reservoir_driver, typeof(esn.states), esn.res_size)
     x_new = esn.states_type(esn.nla_type, x, out_pad)
 
@@ -43,7 +43,7 @@ function obtain_esn_prediction(esn,
     out = initial_conditions
     states = similar(esn.states, size(esn.states, 1), prediction_len)
 
-    out_pad = allocate_outpad(esn.variation, esn.states_type, out)
+    out_pad = allocate_outpad(esn, esn.states_type, out)
     tmp_array = allocate_tmp(esn.reservoir_driver, typeof(esn.states), esn.res_size)
     x_new = esn.states_type(esn.nla_type, x, out_pad)
 
@@ -60,20 +60,6 @@ end
 
 #prediction dispatch on esn 
 function next_state_prediction!(esn::ESN, x, x_new, out, out_pad, i, tmp_array, args...)
-    return _variation_prediction!(esn.variation, esn, x, x_new, out, out_pad, i, tmp_array,
-        args...)
-end
-
-#dispatch the prediction on the esn variation
-function _variation_prediction!(variation,
-        esn,
-        x,
-        x_new,
-        out,
-        out_pad,
-        i,
-        tmp_array,
-        args...)
     out_pad = pad_state!(esn.states_type, out_pad, out)
     xv = @view x[1:(esn.res_size)]
     x = next_state!(x, esn.reservoir_driver, xv, out_pad,
@@ -82,15 +68,8 @@ function _variation_prediction!(variation,
     return x, x_new
 end
 
-function _variation_prediction!(variation::Hybrid,
-        esn,
-        x,
-        x_new,
-        out,
-        out_pad,
-        i,
-        tmp_array,
-        model_prediction_data)
+#TODO fixme @MatrinuzziFra
+function next_state_prediction!(hesn::HybridESN, x, x_new, out, out_pad, i, tmp_array, args...)
     out_tmp = vcat(out, model_prediction_data[:, i])
     out_pad = pad_state!(esn.states_type, out_pad, out_tmp)
     x = next_state!(x, esn.reservoir_driver, x[1:(esn.res_size)], out_pad,
@@ -100,12 +79,12 @@ function _variation_prediction!(variation::Hybrid,
     return x, x_new
 end
 
-function allocate_outpad(variation, states_type, out)
+function allocate_outpad(ens::ESN, states_type, out)
     return allocate_singlepadding(states_type, out)
 end
 
-function allocate_outpad(variation::Hybrid, states_type, out)
-    pad_length = length(out) + size(variation.model_data[:, 1], 1)
+function allocate_outpad(hesn::HybridESN, states_type, out)
+    pad_length = length(out) + size(hesn.model.model_data[:, 1], 1)
     out_tmp = Adapt.adapt(typeof(out), zeros(pad_length))
     return allocate_singlepadding(states_type, out_tmp)
 end
diff --git a/src/esn/hybridesn.jl b/src/esn/hybridesn.jl
new file mode 100644
index 00000000..d1cfdac9
--- /dev/null
+++ b/src/esn/hybridesn.jl
@@ -0,0 +1,82 @@
+struct HybridESN{I, S, V, N, T, O, M, B, ST, W, IS} <: AbstractEchoStateNetwork
+    res_size::I
+    train_data::S
+    model::V
+    nla_type::N
+    input_matrix::T
+    reservoir_driver::O
+    reservoir_matrix::M
+    bias_vector::B
+    states_type::ST
+    washout::W
+    states::IS
+end
+
+struct KnowledgeModel{T, K, O, I, S, D}
+    prior_model::T
+    u0::K
+    tspan::O
+    dt::I
+    datasize::S
+    model_data::D
+end
+
+"""
+    Hybrid(prior_model, u0, tspan, datasize)
+
+Constructs a `Hybrid` variation of Echo State Networks (ESNs) integrating a knowledge-based model
+(`prior_model`) with ESNs for advanced training and prediction in chaotic systems. 
+
+# Parameters
+- `prior_model`: A knowledge-based model function for integration with ESNs.
+- `u0`: Initial conditions for the model.
+- `tspan`: Time span as a tuple, indicating the duration for model operation.
+- `datasize`: The size of the data to be processed.
+
+# Returns
+- A `Hybrid` struct instance representing the combined ESN and knowledge-based model.
+
+This method is effective for chaotic processes as highlighted in [^Pathak].
+
+Reference:
+[^Pathak]: Jaideep Pathak et al.
+    "Hybrid Forecasting of Chaotic Processes:
+    Using Machine Learning in Conjunction with a Knowledge-Based Model" (2018).
+"""
+function KnowledgeModel(prior_model, u0, tspan, datasize)
+    trange = collect(range(tspan[1], tspan[2], length = datasize))
+    dt = trange[2] - trange[1]
+    tsteps = push!(trange, dt + trange[end])
+    tspan_new = (tspan[1], dt + tspan[2])
+    model_data = prior_model(u0, tspan_new, tsteps)
+    return Hybrid(prior_model, u0, tspan, dt, datasize, model_data)
+end
+
+function (hesn::HybridESN)(prediction::AbstractPrediction,
+    output_layer::AbstractOutputLayer;
+    last_state = esn.states[:, [end]],
+    kwargs...)
+
+    pred_len = prediction.prediction_len
+
+    model = variation.prior_model
+    predict_tsteps = [variation.tspan[2] + variation.dt]
+    [append!(predict_tsteps, predict_tsteps[end] + variation.dt) for i in 1:pred_len]
+    tspan_new = (variation.tspan[2] + variation.dt, predict_tsteps[end])
+    u0 = variation.model_data[:, end]
+    model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end]
+
+    return obtain_esn_prediction(esn, prediction, last_state, output_layer,
+        model_pred_data;
+        kwargs...)
+end
+
+function train(hesn::HybridESN,
+    target_data,
+    training_method = StandardRidge(0.0))
+
+    states = vcat(esn.states, esn.variation.model_data[:, 2:end])
+    states_new = esn.states_type(esn.nla_type, states, esn.train_data[:, 1:end])
+
+    return _train(states_new, target_data, training_method)
+end
\ No newline at end of file

From ab444232597caf8e36590309f47dfd6d2ae23e57 Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco <martinuzzi.francesco@gmail.com>
Date: Sun, 21 Jan 2024 17:41:41 +0100
Subject: [PATCH 6/6] HybridESN working, modified docs and readme to follow
 changes

---
 README.md                              |  3 ++
 docs/src/esn_tutorials/hybrid.md       | 19 +++++---
 docs/src/esn_tutorials/lorenz_basic.md | 14 +++---
 src/ReservoirComputing.jl              |  3 +-
 src/esn/esn.jl                         |  6 +--
 src/esn/esn_predict.jl                 | 10 ++---
 src/esn/hybridesn.jl                   | 62 +++++++++++++++++++++-----
 7 files changed, 84 insertions(+), 33 deletions(-)

diff --git a/README.md b/README.md
index 60f12963..2b123d66 100644
--- a/README.md
+++ b/README.md
@@ -104,3 +104,6 @@ If you use this library in your work, please cite:
   url     = {http://jmlr.org/papers/v23/22-0611.html}
 }
 ```
+## Acknowledgements
+
+This project was possible thanks to initial funding through the [Google summer of code](https://summerofcode.withgoogle.com/) 2020 program. Francesco M. further acknowledges [ScaDS.AI](https://scads.ai/) and [RSC4Earth](https://rsc4earth.de/) for supporting the current progress on the library.
diff --git a/docs/src/esn_tutorials/hybrid.md b/docs/src/esn_tutorials/hybrid.md
index bf274f01..5682e9db 100644
--- a/docs/src/esn_tutorials/hybrid.md
+++ b/docs/src/esn_tutorials/hybrid.md
@@ -1,6 +1,6 @@
 # Hybrid Echo State Networks
 
-Following the idea of giving physical information to machine learning models, the hybrid echo state networks [^1] try to achieve this results by feeding model data into the ESN. In this example, it is explained how to create and leverage such models in ReservoirComputing.jl. The full script for this example is available [here](https://github.com/MartinuzziFrancesco/reservoir-computing-examples/blob/main/hybrid/hybrid.jl). This example was run on Julia v1.7.2.
+Following the idea of giving physical information to machine learning models, the hybrid echo state networks [^1] try to achieve this results by feeding model data into the ESN. In this example, it is explained how to create and leverage such models in ReservoirComputing.jl.
 
 ## Generating the data
 
@@ -47,17 +47,22 @@ function prior_model_data_generator(u0, tspan, tsteps, model = lorenz)
 end
 ```
 
-Given the initial condition, time span, and time steps, this function returns the data for the chosen model. Now, using the `Hybrid` method, it is possible to input all this information to the model.
+Given the initial condition, time span, and time steps, this function returns the data for the chosen model. Now, using the `KnowledgeModel` method, it is possible to input all this information to `HybridESN`.
 
 ```@example hybrid
 using ReservoirComputing, Random
 Random.seed!(42)
 
-hybrid = Hybrid(prior_model_data_generator, u0, tspan_train, train_len)
+km = KnowledgeModel(prior_model_data_generator, u0, tspan_train, train_len)
 
-esn = ESN(input_data,
-    reservoir = RandSparseReservoir(300),
-    variation = hybrid)
+in_size = 3
+res_size = 300
+hesn = HybridESN(
+    km,
+    input_data,
+    in_size,
+    res_size;
+    reservoir = rand_sparse)
 ```
 
 ## Training and Prediction
@@ -65,7 +70,7 @@ esn = ESN(input_data,
 The training and prediction of the Hybrid ESN can proceed as usual:
 
 ```@example hybrid
-output_layer = train(esn, target_data, StandardRidge(0.3))
+output_layer = train(hesn, target_data, StandardRidge(0.3))
 output = esn(Generative(predict_len), output_layer)
 ```
 
diff --git a/docs/src/esn_tutorials/lorenz_basic.md b/docs/src/esn_tutorials/lorenz_basic.md
index f820a36d..364c2c4b 100644
--- a/docs/src/esn_tutorials/lorenz_basic.md
+++ b/docs/src/esn_tutorials/lorenz_basic.md
@@ -1,6 +1,6 @@
 # Lorenz System Forecasting
 
-This example expands on the readme Lorenz system forecasting to better showcase how to use methods and functions provided in the library for Echo State Networks. Here the prediction method used is `Generative`, for a more detailed explanation of the differences between `Generative` and `Predictive` please refer to the other examples given in the documentation. The full script for this example is available [here](https://github.com/MartinuzziFrancesco/reservoir-computing-examples/blob/main/lorenz_basic/lorenz_basic.jl). This example was run on Julia v1.7.2.
+This example expands on the readme Lorenz system forecasting to better showcase how to use methods and functions provided in the library for Echo State Networks. Here the prediction method used is `Generative`, for a more detailed explanation of the differences between `Generative` and `Predictive` please refer to the other examples given in the documentation.
 
 ## Generating the data
 
@@ -46,15 +46,15 @@ using ReservoirComputing
 
 #define ESN parameters
 res_size = 300
+in_size = 3
 res_radius = 1.2
 res_sparsity = 6 / 300
 input_scaling = 0.1
 
 #build ESN struct
-esn = ESN(input_data;
-    variation = Default(),
-    reservoir = RandSparseReservoir(res_size, radius = res_radius, sparsity = res_sparsity),
-    input_layer = WeightedLayer(scaling = input_scaling),
+esn = ESN(input_data, in_size, res_size;
+    reservoir = rand_sparse(;radius = res_radius, sparsity = res_sparsity),
+    input_layer = weighted_init(;scaling = input_scaling),
     reservoir_driver = RNN(),
     nla_type = NLADefault(),
     states_type = StandardStates())
@@ -62,9 +62,9 @@ esn = ESN(input_data;
 
 Most of the parameters chosen here mirror the default ones, so a direct call is not necessary. The readme example is identical to this one, except for the explicit call. Going line by line to see what is happening, starting from `res_size`: this value determines the dimensions of the reservoir matrix. In this case, a size of 300 has been chosen, so the reservoir matrix will be 300 x 300. This is not always the case, since some input layer constructions can modify the dimensions of the reservoir, but in that case, everything is taken care of internally.
 
-The `res_radius` determines the scaling of the spectral radius of the reservoir matrix; a proper scaling is necessary to assure the Echo State Property. The default value in the `RandSparseReservoir()` method is 1.0 in accordance with the most commonly followed guidelines found in the literature (see [^2] and references therein). The `sparsity` of the reservoir matrix in this case is obtained by choosing a degree of connections and dividing that by the reservoir size. Of course, it is also possible to simply choose any value between 0.0 and 1.0 to test behaviors for different sparsity values. In this example, the call to the parameters inside `RandSparseReservoir()` was done explicitly to showcase the meaning of each of them, but it is also possible to simply pass the values directly, like so `RandSparseReservoir(1.2, 6/300)`.
+The `res_radius` determines the scaling of the spectral radius of the reservoir matrix; a proper scaling is necessary to assure the Echo State Property. The default value in the `rand_sparse` method is 1.0 in accordance with the most commonly followed guidelines found in the literature (see [^2] and references therein). The `sparsity` of the reservoir matrix in this case is obtained by choosing a degree of connections and dividing that by the reservoir size. Of course, it is also possible to simply choose any value between 0.0 and 1.0 to test behaviors for different sparsity values.
 
-The value of `input_scaling` determines the upper and lower bounds of the uniform distribution of the weights in the `WeightedLayer()`. Like before, this value can be passed either as an argument or as a keyword argument `WeightedLayer(0.1)`. The value of 0.1 represents the default. The default input layer is the `DenseLayer`, a fully connected layer. The details of the weighted version can be found in [^3], for this example, this version returns the best results.
+The value of `input_scaling` determines the upper and lower bounds of the uniform distribution of the weights in the `weighted_init`. The value of 0.1 represents the default. The default input layer is the `scaled_rand`, a dense matrix. The details of the weighted version can be found in [^3], for this example, this version returns the best results.
 
 The reservoir driver represents the dynamics of the reservoir. In the standard ESN definition, these dynamics are obtained through a Recurrent Neural Network (RNN), and this is reflected by calling the `RNN` driver for the `ESN` struct. This option is set as the default, and unless there is the need to change parameters, it is not needed. The full equation is the following:
 
diff --git a/src/ReservoirComputing.jl b/src/ReservoirComputing.jl
index f8668b54..aa38de0c 100644
--- a/src/ReservoirComputing.jl
+++ b/src/ReservoirComputing.jl
@@ -23,7 +23,8 @@ export scaled_rand, weighted_init
 export rand_sparse, delay_line
 export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
 export ESN, train
-export DeepESN, HybridESN
+export HybridESN, KnowledgeModel
+export DeepESN
 export RECA, train
 export RandomMapping, RandomMaps
 export Generative, Predictive, OutputLayer
diff --git a/src/esn/esn.jl b/src/esn/esn.jl
index 3592ed8d..2beb552f 100644
--- a/src/esn/esn.jl
+++ b/src/esn/esn.jl
@@ -138,6 +138,6 @@ end
 #    x_pad = pad_state!(states_type, x_pad, x_tmp)
 #end
 
-function pad_esnstate!(variation, states_type, x_pad, x, args...)
-    x_pad = pad_state!(states_type, x_pad, x)
-end
+#function pad_esnstate!(variation, states_type, x_pad, x, args...)
+#    x_pad = pad_state!(states_type, x_pad, x)
+#end
diff --git a/src/esn/esn_predict.jl b/src/esn/esn_predict.jl
index 5955c762..1b4cd462 100644
--- a/src/esn/esn_predict.jl
+++ b/src/esn/esn_predict.jl
@@ -69,13 +69,13 @@ function next_state_prediction!(esn::ESN, x, x_new, out, out_pad, i, tmp_array,
 end
 
 #TODO fixme @MatrinuzziFra
-function next_state_prediction!(hesn::HybridESN, x, x_new, out, out_pad, i, tmp_array, args...)
+function next_state_prediction!(hesn::HybridESN, x, x_new, out, out_pad, i, tmp_array, model_prediction_data)
     out_tmp = vcat(out, model_prediction_data[:, i])
-    out_pad = pad_state!(esn.states_type, out_pad, out_tmp)
-    x = next_state!(x, esn.reservoir_driver, x[1:(esn.res_size)], out_pad,
-        esn.reservoir_matrix, esn.input_matrix, esn.bias_vector, tmp_array)
+    out_pad = pad_state!(hesn.states_type, out_pad, out_tmp)
+    x = next_state!(x, hesn.reservoir_driver, x[1:(hesn.res_size)], out_pad,
+    hesn.reservoir_matrix, hesn.input_matrix, hesn.bias_vector, tmp_array)
     x_tmp = vcat(x, model_prediction_data[:, i])
-    x_new = esn.states_type(esn.nla_type, x_tmp, out_pad)
+    x_new = hesn.states_type(hesn.nla_type, x_tmp, out_pad)
     return x, x_new
 end
 
diff --git a/src/esn/hybridesn.jl b/src/esn/hybridesn.jl
index d1cfdac9..f29028fc 100644
--- a/src/esn/hybridesn.jl
+++ b/src/esn/hybridesn.jl
@@ -49,24 +49,66 @@ function KnowledgeModel(prior_model, u0, tspan, datasize)
     tsteps = push!(trange, dt + trange[end])
     tspan_new = (tspan[1], dt + tspan[2])
     model_data = prior_model(u0, tspan_new, tsteps)
-    return Hybrid(prior_model, u0, tspan, dt, datasize, model_data)
+    return KnowledgeModel(prior_model, u0, tspan, dt, datasize, model_data)
+end
+
+function HybridESN(
+    model,
+    train_data,
+    in_size::Int,
+    res_size::Int;
+    input_layer = scaled_rand,
+    reservoir = rand_sparse,
+    bias = zeros64,
+    reservoir_driver = RNN(),
+    nla_type = NLADefault(),
+    states_type = StandardStates(),
+    washout = 0,
+    rng = _default_rng(),
+    T = Float32,
+    matrix_type = typeof(train_data)
+)
+
+    train_data = vcat(train_data, model.model_data[:, 1:(end - 1)])
+
+    if states_type isa AbstractPaddedStates
+        in_size = size(train_data, 1) + 1
+        train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
+            train_data)
+    else
+        in_size = size(train_data, 1)
+    end
+
+    reservoir_matrix = reservoir(rng, T, res_size, res_size)
+    #different from ESN, why?
+    input_matrix = input_layer(rng, T, res_size, in_size)
+    bias_vector = bias(rng, res_size)
+    inner_res_driver = reservoir_driver_params(reservoir_driver, res_size, in_size)
+    states = create_states(inner_res_driver, train_data, washout, reservoir_matrix,
+        input_matrix, bias_vector)
+    train_data = train_data[:, (washout + 1):end]
+
+    HybridESN(res_size, train_data, model, nla_type, input_matrix,
+        inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
+        states)
 end
 
 function (hesn::HybridESN)(prediction::AbstractPrediction,
     output_layer::AbstractOutputLayer;
-    last_state = esn.states[:, [end]],
+    last_state = hesn.states[:, [end]],
     kwargs...)
 
+    km = hesn.model
     pred_len = prediction.prediction_len
 
-    model = variation.prior_model
-    predict_tsteps = [variation.tspan[2] + variation.dt]
-    [append!(predict_tsteps, predict_tsteps[end] + variation.dt) for i in 1:pred_len]
-    tspan_new = (variation.tspan[2] + variation.dt, predict_tsteps[end])
-    u0 = variation.model_data[:, end]
+    model = km.prior_model
+    predict_tsteps = [km.tspan[2] + km.dt]
+    [append!(predict_tsteps, predict_tsteps[end] + km.dt) for i in 1:pred_len]
+    tspan_new = (km.tspan[2] + km.dt, predict_tsteps[end])
+    u0 = km.model_data[:, end]
     model_pred_data = model(u0, tspan_new, predict_tsteps)[:, 2:end]
 
-    return obtain_esn_prediction(esn, prediction, last_state, output_layer,
+    return obtain_esn_prediction(hesn, prediction, last_state, output_layer,
         model_pred_data;
         kwargs...)
 end
@@ -75,8 +117,8 @@ function train(hesn::HybridESN,
     target_data,
     training_method = StandardRidge(0.0))
 
-    states = vcat(esn.states, esn.variation.model_data[:, 2:end])
-    states_new = esn.states_type(esn.nla_type, states, esn.train_data[:, 1:end])
+    states = vcat(hesn.states, hesn.model.model_data[:, 2:end])
+    states_new = hesn.states_type(hesn.nla_type, states, hesn.train_data[:, 1:end])
 
     return _train(states_new, target_data, training_method)
 end
\ No newline at end of file