Skip to content

Commit

Permalink
up WeightInitializers to 0.1.6, rm SparseArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Feb 27, 2024
1 parent 2afa8a1 commit fbb0ceb
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 74 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

Expand All @@ -34,10 +33,9 @@ Optim = "1"
PartialFunctions = "1.2"
Random = "1.10"
SafeTestsets = "0.1"
SparseArrays = "1.10"
Statistics = "1.10"
Test = "1"
WeightInitializers = "0.1.5"
WeightInitializers = "0.1.6"
julia = "1.10"

[extras]
Expand Down
15 changes: 7 additions & 8 deletions src/ReservoirComputing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ using NNlib
using Optim
using PartialFunctions
using Random
using SparseArrays
using Statistics
using WeightInitializers

export NLADefault, NLAT1, NLAT2, NLAT3
export StandardStates, ExtendedStates, PaddedStates, PaddedExtendedStates
export StandardRidge, LinearModel
export scaled_rand, weighted_init, sparse_init, informed_init, minimal_init
export scaled_rand, weighted_init, informed_init, minimal_init
export rand_sparse, delay_line, delay_line_backward, cycle_jumps, simple_cycle, pseudo_svd
export RNN, MRNN, GRU, GRUParams, FullyGated, Minimal
export ESN, train
Expand Down Expand Up @@ -76,26 +75,26 @@ end
#fallbacks for initializers
for initializer in (:rand_sparse, :delay_line, :delay_line_backward, :cycle_jumps,
:simple_cycle, :pseudo_svd,
:scaled_rand, :weighted_init, :sparse_init, :informed_init, :minimal_init)
:scaled_rand, :weighted_init, :informed_init, :minimal_init)
NType = ifelse(initializer === :rand_sparse, Real, Number)
@eval function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), Float32, dims...; kwargs...)
return $initializer(WeightInitializers._default_rng(), Float32, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
return $initializer(rng, Float32, dims...; kwargs...)
end
@eval function ($initializer)(::Type{T},
dims::Integer...; kwargs...) where {T <: $NType}
return $initializer(_default_rng(), T, dims...; kwargs...)
return $initializer(WeightInitializers._default_rng(), T, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG; kwargs...)
return __partial_apply($initializer, (rng, (; kwargs...)))
return WeightInitializers.__partial_apply($initializer, (rng, (; kwargs...)))
end
@eval function ($initializer)(rng::AbstractRNG,
::Type{T}; kwargs...) where {T <: $NType}
return __partial_apply($initializer, ((rng, T), (; kwargs...)))
return WeightInitializers.__partial_apply($initializer, ((rng, T), (; kwargs...)))
end
@eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...))
@eval ($initializer)(; kwargs...) = WeightInitializers.__partial_apply($initializer, (; kwargs...))
end

#general
Expand Down
2 changes: 1 addition & 1 deletion src/esn/deepesn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function DeepESN(train_data,
nla_type = NLADefault(),
states_type = StandardStates(),
washout::Int = 0,
rng = _default_rng(),
rng = WeightInitializers._default_rng(),
T = Float64,
matrix_type = typeof(train_data))
if states_type isa AbstractPaddedStates
Expand Down
9 changes: 1 addition & 8 deletions src/esn/esn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function ESN(train_data,
nla_type = NLADefault(),
states_type = StandardStates(),
washout = 0,
rng = _default_rng(),
rng = WeightInitializers._default_rng(),
T = Float32,
matrix_type = typeof(train_data))
if states_type isa AbstractPaddedStates
Expand Down Expand Up @@ -120,13 +120,6 @@ trained_esn = train(esn, target_data)
# Train the ESN using a custom training method
trained_esn = train(esn, target_data, training_method = StandardRidge(1.0))
```
# Notes
- When using a `Hybrid` variation, the function extends the state matrix with data from the
physical model included in the `variation`.
- The training is handled by a lower-level `_train` function which takes the new state matrix
and performs the actual training using the specified `training_method`.
"""
function train(esn::AbstractEchoStateNetwork,
target_data,
Expand Down
37 changes: 0 additions & 37 deletions src/esn/esn_input_layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,43 +77,6 @@ function weighted_init(rng::AbstractRNG,
return layer_matrix
end

# TODO: @MartinuzziFrancesco remove when pr gets into WeightInitializers
"""
sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), sparsity=T(0.1)) where {T <: Number}
Create and return a sparse layer matrix for use in neural network models.
The matrix will be of size specified by `dims`, with the specified `sparsity` and `scaling`.
# Arguments
- `rng`: An instance of `AbstractRNG` for random number generation.
- `T`: The data type for the elements of the matrix.
- `dims`: Dimensions of the resulting sparse layer matrix.
- `scaling`: The scaling factor for the sparse layer matrix. Defaults to 0.1.
- `sparsity`: The sparsity level of the sparse layer matrix, controlling the fraction of zero elements. Defaults to 0.1.
# Returns
A sparse layer matrix.
# Example
```julia
rng = Random.default_rng()
input_layer = sparse_init(rng, Float64, (3, 300); scaling = 0.2, sparsity = 0.1)
```
"""
function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
scaling = T(0.1), sparsity = T(0.1)) where {T <: Number}
res_size, in_size = dims
layer_matrix = Matrix(sprand(rng, T, res_size, in_size, sparsity))
layer_matrix = T.(2.0) .* (layer_matrix .- T.(0.5))
replace!(layer_matrix, T(-1.0) => T(0.0))
layer_matrix = scaling .* layer_matrix

return layer_matrix
end

"""
informed_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; scaling=T(0.1), model_in_size, gamma=T(0.5)) where {T <: Number}
Expand Down
20 changes: 5 additions & 15 deletions src/esn/esn_reservoirs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ function rand_sparse(rng::AbstractRNG,
::Type{T},
dims::Integer...;
radius = T(1.0),
sparsity = T(0.1)) where {T <: Number}
reservoir_matrix = Matrix{T}(sprand(rng, dims..., sparsity))
reservoir_matrix = T(2.0) .* (reservoir_matrix .- T(0.5))
replace!(reservoir_matrix, T(-1.0) => T(0.0))
sparsity = T(0.1),
std = T(1.0)) where {T <: Number}

lcl_sparsity = T(1)-sparsity #consistency with current implementations
reservoir_matrix = sparse_init(rng, T, dims...; sparsity=lcl_sparsity, std=std)
rho_w = maximum(abs.(eigvals(reservoir_matrix)))
reservoir_matrix .*= radius / rho_w
if Inf in unique(reservoir_matrix) || -Inf in unique(reservoir_matrix)
Expand Down Expand Up @@ -299,14 +300,3 @@ end
function get_sparsity(M, dim)
return size(M[M .!= 0], 1) / (dim * dim - size(M[M .!= 0], 1)) #nonzero/zero elements
end

# from WeightInitializers.jl, TODO @MartinuzziFrancesco consider importing package
function _default_rng()
@static if VERSION >= v"1.7"
return Xoshiro(1234)
else
return MersenneTwister(1234)
end
end

__partial_apply(fn, inp) = fn$inp
2 changes: 1 addition & 1 deletion src/esn/hybridesn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ function HybridESN(model,
nla_type = NLADefault(),
states_type = StandardStates(),
washout = 0,
rng = _default_rng(),
rng = WeightInitializers._default_rng(),
T = Float32,
matrix_type = typeof(train_data))
train_data = vcat(train_data, model.model_data[:, 1:(end - 1)])
Expand Down
1 change: 0 additions & 1 deletion test/esn/test_inits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ reservoir_inits = [
input_inits = [
scaled_rand,
weighted_init,
sparse_init,
minimal_init,
minimal_init(; sampling_type = :irrational)
]
Expand Down

0 comments on commit fbb0ceb

Please sign in to comment.