From 136addf3d70141f5f49dd434431328c1237e840d Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Wed, 8 Jan 2025 16:26:18 +0100
Subject: [PATCH] format

---
 src/esn/esn.jl                   | 6 ++----
 src/esn/esn_reservoir_drivers.jl | 8 +-------
 src/predict.jl                   | 6 +++---
 3 files changed, 6 insertions(+), 14 deletions(-)

diff --git a/src/esn/esn.jl b/src/esn/esn.jl
index 7c3010d..6e88e02 100644
--- a/src/esn/esn.jl
+++ b/src/esn/esn.jl
@@ -50,7 +50,6 @@ julia> train_data = rand(Float32, 10, 100) # 10 features, 100 time steps
 
 julia> esn = ESN(train_data, 10, 300; washout=10)
 ESN(10 => 300)
-
 ```
 """
 function ESN(train_data,
@@ -95,8 +94,9 @@ function (esn::AbstractEchoStateNetwork)(prediction::AbstractPrediction,
         kwargs...)
 end
 
-Base.show(io::IO, esn::ESN) =
+function Base.show(io::IO, esn::ESN)
     print(io, "ESN(", size(esn.train_data, 1), " => ", size(esn.reservoir_matrix, 1), ")")
+end
 
 #training dispatch on esn
 """
@@ -110,7 +110,6 @@ Trains an Echo State Network (ESN) using the provided target data and a specified
   - `target_data`: Supervised training data for the ESN.
   - `training_method`: The method for training the ESN (default: `StandardRidge(0.0)`).
 
-
 # Example
 
 ```jldoctest
@@ -132,7 +131,6 @@ ESN(10 => 300)
 
 julia> output_layer = train(esn, rand(Float32, 3, 90))
 OutputLayer successfully trained with output size: 3
-
 ```
 """
 function train(esn::AbstractEchoStateNetwork,
diff --git a/src/esn/esn_reservoir_drivers.jl b/src/esn/esn_reservoir_drivers.jl
index 7601a09..edbed91 100644
--- a/src/esn/esn_reservoir_drivers.jl
+++ b/src/esn/esn_reservoir_drivers.jl
@@ -20,7 +20,6 @@ specified reservoir driver.
     and reservoir nodes.
   - `bias_vector`: The bias vector to be added at each time step during the
     reservoir update.
-
 """
 function create_states(reservoir_driver::AbstractReservoirDriver,
         train_data,
@@ -108,8 +107,6 @@ echo state networks (`ESN`).
     Defaults to `tanh_fast`.
  - `leaky_coefficient`: The leaky coefficient used in the RNN.
    Defaults to 1.0.
-
-
 """
 function RNN(; activation_function=NNlib.fast_act(tanh), leaky_coefficient=1.0)
     RNN(activation_function, leaky_coefficient)
@@ -185,7 +182,6 @@ This function creates an MRNN object with the specified activation functions,
 leaky coefficient, and scaling factors, which can be used as a reservoir driver
 in the ESN.
 
-
 [^Lun2015]: Lun, Shu-Xian, et al.
     "_A novel model of leaky integrator echo state network for time-series prediction._"
     Neurocomputing 159 (2015): 58-66.
@@ -234,10 +230,9 @@ end
 Returns a Fully Gated Recurrent Unit (FullyGated) initializer
 for the Echo State Network (ESN).
 
-Returns the standard gated recurrent unit [^Cho2014] as a driver for the 
+Returns the standard gated recurrent unit [^Cho2014] as a driver for the
 echo state network (`ESN`).
 
-
 [^Cho2014]: Cho, Kyunghyun, et al.
     "_Learning phrase representations using RNN encoder-decoder for statistical
     machine translation._"
@@ -281,7 +276,6 @@ This driver is based on the GRU architecture [^Cho2014].
   - `variant`: The GRU variant to use.
     By default, it uses the "FullyGated" variant.
 
-
 [^Cho2014]: Cho, Kyunghyun, et al.
     "_Learning phrase representations using RNN encoder-decoder for statistical
     machine translation._" arXiv preprint arXiv:1406.1078 (2014).
diff --git a/src/predict.jl b/src/predict.jl
index 41a226c..4ba275e 100644
--- a/src/predict.jl
+++ b/src/predict.jl
@@ -9,8 +9,9 @@ struct OutputLayer{T, I, S, L} <: AbstractOutputLayer
     last_value::L
 end
 
-Base.show(io::IO, ol::OutputLayer) =
+function Base.show(io::IO, ol::OutputLayer)
     print(io, "OutputLayer successfully trained with output size: ", ol.out_size)
+end
 
 #prediction types
 """
@@ -58,14 +59,13 @@ of input features (`prediction_data`).
 
 The `Predictive` prediction method uses the provided input data
 (`prediction_data`) to produce corresponding labels or outputs based
-on the learned relationships in the model. 
+on the learned relationships in the model.
 """
 function Predictive(prediction_data)
     prediction_len = size(prediction_data, 2)
     Predictive(prediction_data, prediction_len)
 end
 
-
 function obtain_prediction(rc::AbstractReservoirComputer,
         prediction::Generative,
         x,
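
Note, not part of the patch: the two `Base.show` methods that this patch converts from assignment form to `function ... end` blocks print the same compact summaries as before. The sketch below, assuming this is ReservoirComputing.jl and reusing the toy sizes from the doctests in the patch, shows the output both methods produce.

```julia
using ReservoirComputing

train_data = rand(Float32, 10, 100)              # 10 features, 100 time steps
esn = ESN(train_data, 10, 300; washout=10)       # 300-node reservoir
show(stdout, esn)                                # prints: ESN(10 => 300)

output_layer = train(esn, rand(Float32, 3, 90))  # targets: 3 features, 100 - washout steps
show(stdout, output_layer)                       # prints: OutputLayer successfully trained with output size: 3
```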