Skip to content

Commit

Permalink
changed all defaults to Float32
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Jan 6, 2025
1 parent f0849f3 commit 7c50ee2
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 16 deletions.
8 changes: 3 additions & 5 deletions src/esn/deepesn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ temporal features.
- `input_layer`: A function or an array of functions to initialize the input
matrices for each layer. Default is `scaled_rand` for each layer.
- `bias`: A function or an array of functions to initialize the bias vectors
for each layer. Default is `zeros64` for each layer.
for each layer. Default is `zeros32` for each layer.
- `reservoir`: A function or an array of functions to initialize the reservoir
matrices for each layer. Default is `rand_sparse` for each layer.
- `reservoir_driver`: The driving system for the reservoir.
Expand All @@ -50,8 +50,6 @@ temporal features.
Default is 0.
- `rng`: Random number generator used for initializing weights. Default is
`Utils.default_rng()`.
- `T`: The data type for the matrices (e.g., `Float64`). Influences computational
efficiency and precision.
- `matrix_type`: The type of matrix used for storing the training data.
Default is inferred from `train_data`.
Expand All @@ -74,21 +72,21 @@ function DeepESN(train_data,
res_size::Int;
depth::Int=2,
input_layer=fill(scaled_rand, depth),
bias=fill(zeros64, depth),
bias=fill(zeros32, depth),
reservoir=fill(rand_sparse, depth),
reservoir_driver=RNN(),
nla_type=NLADefault(),
states_type=StandardStates(),
washout::Int=0,
rng=Utils.default_rng(),
T=Float64,
matrix_type=typeof(train_data))
if states_type isa AbstractPaddedStates
in_size = size(train_data, 1) + 1
train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
train_data)
end

T = eltype(train_data)
reservoir_matrix = [reservoir[i](rng, T, res_size, res_size) for i in 1:depth]
input_matrix = [i == 1 ? input_layer[i](rng, T, res_size, in_size) :
input_layer[i](rng, T, res_size, res_size) for i in 1:depth]
Expand Down
4 changes: 2 additions & 2 deletions src/esn/esn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,20 @@ function ESN(train_data,
res_size::Int;
input_layer=scaled_rand,
reservoir=rand_sparse,
bias=zeros64,
bias=zeros32,
reservoir_driver=RNN(),
nla_type=NLADefault(),
states_type=StandardStates(),
washout=0,
rng=Utils.default_rng(),
T=Float32,
matrix_type=typeof(train_data))
if states_type isa AbstractPaddedStates
in_size = size(train_data, 1) + 1
train_data = vcat(Adapt.adapt(matrix_type, ones(1, size(train_data, 2))),
train_data)
end

T = eltype(train_data)
reservoir_matrix = reservoir(rng, T, res_size, res_size)
input_matrix = input_layer(rng, T, res_size, in_size)
bias_vector = bias(rng, res_size)
Expand Down
2 changes: 1 addition & 1 deletion src/esn/hybridesn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ function HybridESN(model,
res_size::Int;
input_layer=scaled_rand,
reservoir=rand_sparse,
bias=zeros64,
bias=zeros32,
reservoir_driver=RNN(),
nla_type=NLADefault(),
states_type=StandardStates(),
Expand Down
21 changes: 13 additions & 8 deletions test/esn/deepesn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@ const train_len = 400
const predict_len = 100
const input_data = reduce(hcat, data[1:(train_len - 1)])
const target_data = reduce(hcat, data[2:train_len])
const test = reduce(hcat, data[(train_len + 1):(train_len + predict_len)])
const test_data = reduce(hcat, data[(train_len + 1):(train_len + predict_len)])
const reg = 10e-6
#test_types = [Float64, Float32, Float16]

Random.seed!(77)
res = rand_sparse(; radius=1.2, sparsity=0.1)
esn = DeepESN(input_data, 1, res_size)
test_types = [Float64, Float32, Float16]
zeros_types = [zeros64, zeros32, zeros16]

output_layer = train(esn, target_data)
output = esn(Generative(length(test)), output_layer)
@test mean(abs.(test .- output)) ./ mean(abs.(test)) < 0.22
for (tidx,t) in enumerate(test_types)
Random.seed!(77)
res = rand_sparse(; radius=1.2, sparsity=0.1)
esn = DeepESN(t.(input_data), 1, res_size;
bias=fill(zeros_types[tidx], 2))

output_layer = train(esn, t.(target_data))
output = esn(Generative(length(test_data)), output_layer)
@test eltype(output) == t
end

0 comments on commit 7c50ee2

Please sign in to comment.