From 2a43506cd86b6fabdc993913938a822c1aac7b01 Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Fri, 7 Feb 2025 19:01:00 +0100
Subject: [PATCH 1/3] first work on lem

---
 src/cells/lem_cell.jl | 82 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)
 create mode 100644 src/cells/lem_cell.jl

diff --git a/src/cells/lem_cell.jl b/src/cells/lem_cell.jl
new file mode 100644
index 0000000..8e096c7
--- /dev/null
+++ b/src/cells/lem_cell.jl
@@ -0,0 +1,82 @@
+#https://arxiv.org/pdf/2110.04744
+@doc raw"""
+    LEMCell(input_size => hidden_size, [dt];
+        init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform,
+        bias = true)
+
+[Long expressive memory unit](https://arxiv.org/pdf/2110.04744).
+See [`LEM`](@ref) for a layer that processes entire sequences.
+
+# Arguments
+
+- `input_size => hidden_size`: input and inner dimension of the layer
+- `dt`: timestep. Default is `1.0`
+
+# Keyword arguments
+
+- `init_kernel`: initializer for the input to hidden weights
+- `init_recurrent_kernel`: initializer for the hidden to hidden weights
+- `bias`: include a bias or not. Default is `true`
+
+# Equations
+```math
+\begin{aligned}
+
+\end{aligned}
+```
+
+# Forward
+
+    lemcell(inp, (state, zstate))
+    lemcell(inp)
+
+## Arguments
+- `inp`: The input to the lemcell. It should be a vector of size `input_size`
+  or a matrix of size `input_size x batch_size`.
+- `(state, zstate)`: A tuple containing the hidden and auxiliary states of the
+  LEMCell. They should be vectors of size `hidden_size` or matrices of size
+  `hidden_size x batch_size`. If not provided, they are assumed to be
+  vectors of zeros, initialized by [`Flux.initialstates`](@extref).
+
+## Returns
+- A tuple `(output, state)`, where both elements are given by the updated state
+  `new_state`, a tensor of size `hidden_size` or `hidden_size x batch_size`.
+"""
+struct LEMCell{I, H, V, D} <: AbstractRecurrentCell
+    Wi::I
+    Wh::H
+    bias::V
+    dt::D
+end
+
+@layer LEMCell
+
+function LEMCell((input_size, hidden_size)::Pair{<:Int, <:Int}, dt::Number=1.0;
+        init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform,
+        bias::Bool=true)
+    Wi = init_kernel(hidden_size * 4, input_size)
+    Wh = init_recurrent_kernel(hidden_size * 3, hidden_size)
+    Wz = init_recurrent_kernel(hidden_size, hidden_size)
+    b = create_bias(Wi, bias, size(Wi, 1))
+
+    return LEMCell(Wi, Wh, b)
+end
+
+function (lem::LEMCell)(inp::AbstractVecOrMat, (state, z_state))
+    _size_check(lem, inp, 1 => size(lem.Wi, 2))
+    Wi, Wh, b = lem.Wi, lem.Wh, lem.bias
+    #split
+    gxs = chunk(Wi * inp .+ b, 4; dims=1)
+    ghs = chunk(Wh * state, 3; dims=1)
+
+    msdt_bar = lem.dt .* sigmoid_fast.(gxs[1] .+ ghs[1])
+    ms_dt = lem.dt .* sigmoid_fast.(gxs[2] .+ ghs[2])
+    new_zstate = (1.0 .- ms_dt) .* z_state .+ ms_dt .* tanh_fast.(gxs[3] .+ ghs[3])
+    new_state = (1.0 .- msdt_bar) .* state .+ msdt_bar .* tanh_fast.(gxs[4] .+ Wz*z_state)
+    return new_zstate, (new_state, new_zstate)
+end
+
+function Base.show(io::IO, lem::LEMCell)
+    print(io, "LEMCell(", size(lem.Wi, 2), " => ", size(lem.Wi, 1) ÷ 2, ")")
+end
\ No newline at end of file

From f4162049dc8869d9e45c15b9be66a2d73399c21f Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Sat, 8 Feb 2025 16:04:47 +0100
Subject: [PATCH 2/3] tests and docs for lem

---
 docs/src/api/cells.md          |   1 +
 docs/src/api/layers.md         |   1 +
 src/RecurrentLayers.jl         |   9 +--
 src/cells/lem_cell.jl          | 111 +++++++++++++++++++++++++++++----
 src/cells/peepholelstm_cell.jl |   4 +-
 src/generics.jl                |   2 +-
 test/test_cells.jl             |  14 +++++
 test/test_layers.jl            |   2 +-
 test/test_wrappers.jl          |   2 +-
 9 files changed, 125 insertions(+), 21 deletions(-)

diff --git a/docs/src/api/cells.md b/docs/src/api/cells.md
index c1a96c1..04d5643 100644
--- a/docs/src/api/cells.md
+++ b/docs/src/api/cells.md
@@ -17,4 +17,5 @@ PeepholeLSTMCell
 FastRNNCell
 FastGRNNCell
 FSRNNCell
+LEMCell
 ```
\ No newline at end of file
diff --git a/docs/src/api/layers.md b/docs/src/api/layers.md
index eac0ad1..306f31e 100644
--- a/docs/src/api/layers.md
+++ b/docs/src/api/layers.md
@@ -16,4 +16,5 @@ PeepholeLSTM
 FastRNN
 FastGRNN
 FSRNN
+LEM
 ```
\ No newline at end of file
diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl
index c70c792..1140686 100644
--- a/src/RecurrentLayers.jl
+++ b/src/RecurrentLayers.jl
@@ -9,9 +9,9 @@ using NNlib: fast_act
 
 export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
        RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
-       FastRNNCell, FastGRNNCell, FSRNNCell
+       FastRNNCell, FastGRNNCell, FSRNNCell, LEMCell
 export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3,
-       SCRN, PeepholeLSTM, FastRNN, FastGRNN, FSRNN
+       SCRN, PeepholeLSTM, FastRNN, FastGRNN, FSRNN, LEM
 export StackedRNN
 
 @compat(public, (initialstates))
@@ -30,16 +30,17 @@ include("cells/scrn_cell.jl")
 include("cells/peepholelstm_cell.jl")
 include("cells/fastrnn_cell.jl")
 include("cells/fsrnn_cell.jl")
+include("cells/lem_cell.jl")
 
 include("wrappers/stackedrnn.jl")
 
 ### fallbacks for functors ###
 rlayers = (:FastRNN, :FastGRNN, :IndRNN, :LightRU, :LiGRU, :MGU, :MUT1,
-    :MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN, :FSRNN)
+    :MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN, :FSRNN, :LEM)
 
 rcells = (:FastRNNCell, :FastGRNNCell, :IndRNNCell, :LightRUCell, :LiGRUCell,
     :MGUCell, :MUT1Cell, :MUT2Cell, :MUT3Cell, :NASCell, :PeepholeLSTMCell,
-    :RANCell, :SCRNCell, :FSRNNCell)
+    :RANCell, :SCRNCell, :FSRNNCell, :LEMCell)
 
 for (rlayer, rcell) in zip(rlayers, rcells)
     @eval begin
diff --git a/src/cells/lem_cell.jl b/src/cells/lem_cell.jl
index 8e096c7..99e8479 100644
--- a/src/cells/lem_cell.jl
+++ b/src/cells/lem_cell.jl
@@ -1,8 +1,7 @@
 #https://arxiv.org/pdf/2110.04744
 @doc raw"""
     LEMCell(input_size => hidden_size, [dt];
-        init_kernel = glorot_uniform,
-        init_recurrent_kernel = glorot_uniform,
+        init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform,
         bias = true)
 
 [Long expressive memory unit](https://arxiv.org/pdf/2110.04744).
@@ -22,7 +21,14 @@ See [`LEM`](@ref) for a layer that processes entire sequences.
 # Equations
 ```math
 \begin{aligned}
-
+\boldsymbol{\Delta t_n} &= \Delta \hat{t} \hat{\sigma}
+    (W_1 y_{n-1} + V_1 u_n + b_1) \\
+\overline{\boldsymbol{\Delta t_n}} &= \Delta \hat{t}
+    \hat{\sigma} (W_2 y_{n-1} + V_2 u_n + b_2) \\
+z_n &= (1 - \boldsymbol{\Delta t_n}) \odot z_{n-1} +
+    \boldsymbol{\Delta t_n} \odot \sigma (W_z y_{n-1} + V_z u_n + b_z) \\
+y_n &= (1 - \overline{\boldsymbol{\Delta t_n}}) \odot y_{n-1} +
+    \overline{\boldsymbol{\Delta t_n}} \odot \sigma (W_y z_n + V_y u_n + b_y)
 \end{aligned}
 ```
 
@@ -43,16 +49,17 @@ See [`LEM`](@ref) for a layer that processes entire sequences.
 - A tuple `(output, state)`, where both elements are given by the updated state
   `new_state`, a tensor of size `hidden_size` or `hidden_size x batch_size`.
 """
-struct LEMCell{I, H, V, D} <: AbstractRecurrentCell
+struct LEMCell{I, H, Z, V, D} <: AbstractDoubleRecurrentCell
     Wi::I
     Wh::H
+    Wz::Z
     bias::V
     dt::D
 end
 
 @layer LEMCell
 
-function LEMCell((input_size, hidden_size)::Pair{<:Int, <:Int}, dt::Number=1.0;
+function LEMCell((input_size, hidden_size)::Pair{<:Int, <:Int}, dt::Number=1.0f0;
     init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform,
     bias::Bool=true)
     Wi = init_kernel(hidden_size * 4, input_size)
@@ -60,23 +67,103 @@ function LEMCell((input_size, hidden_size)::Pair{<:Int, <:Int}, dt::Number=1.0;
     Wz = init_recurrent_kernel(hidden_size, hidden_size)
     b = create_bias(Wi, bias, size(Wi, 1))
 
-    return LEMCell(Wi, Wh, b)
+    return LEMCell(Wi, Wh, Wz, b, eltype(Wi)(dt))
 end
 
 function (lem::LEMCell)(inp::AbstractVecOrMat, (state, z_state))
     _size_check(lem, inp, 1 => size(lem.Wi, 2))
-    Wi, Wh, b = lem.Wi, lem.Wh, lem.bias
+    Wi, Wh, Wz, b = lem.Wi, lem.Wh, lem.Wz, lem.bias
+    T = eltype(Wi)
     #split
     gxs = chunk(Wi * inp .+ b, 4; dims=1)
     ghs = chunk(Wh * state, 3; dims=1)
 
     msdt_bar = lem.dt .* sigmoid_fast.(gxs[1] .+ ghs[1])
     ms_dt = lem.dt .* sigmoid_fast.(gxs[2] .+ ghs[2])
-    new_zstate = (1.0 .- ms_dt) .* z_state .+ ms_dt .* tanh_fast.(gxs[3] .+ ghs[3])
-    new_state = (1.0 .- msdt_bar) .* state .+ msdt_bar .* tanh_fast.(gxs[4] .+ Wz*z_state)
-    return new_zstate, (new_state, new_zstate)
+    new_zstate = (T(1.0f0) .- ms_dt) .* z_state .+ ms_dt .* tanh_fast.(gxs[3] .+ ghs[3])
+    new_state = (T(1.0f0) .- msdt_bar) .* state .+
+                msdt_bar .* tanh_fast.(gxs[4] .+ Wz * z_state)
+    return new_state, (new_state, new_zstate)
 end
 
 function Base.show(io::IO, lem::LEMCell)
-    print(io, "LEMCell(", size(lem.Wi, 2), " => ", size(lem.Wi, 1) ÷ 2, ")")
-end
\ No newline at end of file
+    print(io, "LEMCell(", size(lem.Wi, 2), " => ", size(lem.Wi, 1) ÷ 4, ")")
+end
+
+@doc raw"""
+    LEM(input_size => hidden_size, [dt];
+        return_state=false, init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform, bias = true)
+
+[Long expressive memory network](https://arxiv.org/pdf/2110.04744).
+See [`LEMCell`](@ref) for a layer that processes a single sequence.
+
+# Arguments
+
+- `input_size => hidden_size`: input and inner dimension of the layer
+- `dt`: timestep. Default is `1.0`
+
+# Keyword arguments
+
+- `init_kernel`: initializer for the input to hidden weights
+- `init_recurrent_kernel`: initializer for the hidden to hidden weights
+- `bias`: include a bias or not. Default is `true`
+- `return_state`: Option to return the last state together with the output.
+  Default is `false`.
+
+# Equations
+
+```math
+\begin{aligned}
+\boldsymbol{\Delta t_n} &= \Delta \hat{t} \hat{\sigma}
+    (W_1 y_{n-1} + V_1 u_n + b_1) \\
+\overline{\boldsymbol{\Delta t_n}} &= \Delta \hat{t}
+    \hat{\sigma} (W_2 y_{n-1} + V_2 u_n + b_2) \\
+z_n &= (1 - \boldsymbol{\Delta t_n}) \odot z_{n-1} +
+    \boldsymbol{\Delta t_n} \odot \sigma (W_z y_{n-1} + V_z u_n + b_z) \\
+y_n &= (1 - \overline{\boldsymbol{\Delta t_n}}) \odot y_{n-1} +
+    \overline{\boldsymbol{\Delta t_n}} \odot \sigma (W_y z_n + V_y u_n + b_y)
+\end{aligned}
+```
+
+# Forward
+
+    LEM(inp, (state, zstate))
+    LEM(inp)
+
+## Arguments
+- `inp`: The input to the LEM. It should be a matrix of size `input_size x len`
+  or an array of size `input_size x len x batch_size`.
+- `(state, zstate)`: A tuple containing the hidden and auxiliary states of the LEM.
+  They should be vectors of size `hidden_size` or matrices of size
+  `hidden_size x batch_size`. If not provided, they are assumed to be vectors of zeros,
+  initialized by [`Flux.initialstates`](@extref).
+
+## Returns
+- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
+  When `return_state = true` it returns a tuple of the hidden states `new_states` and
+  the last state of the iteration.
+"""
+struct LEM{S, M} <: AbstractRecurrentLayer{S}
+    cell::M
+end
+
+@layer :noexpand LEM
+
+function LEM((input_size, hidden_size)::Pair{<:Int, <:Int}, dt::Number=1.0;
+        return_state::Bool=false, kwargs...)
+    cell = LEMCell(input_size => hidden_size, dt; kwargs...)
+    return LEM{return_state, typeof(cell)}(cell)
+end
+
+function functor(rnn::LEM{S}) where {S}
+    params = (cell=rnn.cell,)
+    reconstruct = p -> LEM{S, typeof(p.cell)}(p.cell)
+    return params, reconstruct
+end
+
+function Base.show(io::IO, lem::LEM)
+    print(io, "LEM(", size(lem.cell.Wi, 2),
+        " => ", size(lem.cell.Wi, 1) ÷ 4)
+    print(io, ")")
+end
diff --git a/src/cells/peepholelstm_cell.jl b/src/cells/peepholelstm_cell.jl
index f378711..ffc9118 100644
--- a/src/cells/peepholelstm_cell.jl
+++ b/src/cells/peepholelstm_cell.jl
@@ -73,7 +73,7 @@ function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat, (state, c_state))
     input, forget, cell, output = chunk(g, 4; dims=1)
     new_cstate = @. sigmoid_fast(forget) * c_state + sigmoid_fast(input) * tanh_fast(cell)
     new_state = @. sigmoid_fast(output) * tanh_fast(new_cstate)
-    return new_cstate, (new_state, new_cstate)
+    return new_state, (new_state, new_cstate)
 end
 
 function Base.show(io::IO, lstm::PeepholeLSTMCell)
@@ -150,6 +150,6 @@ end
 
 function Base.show(io::IO, peepholelstm::PeepholeLSTM)
     print(io, "PeepholeLSTM(", size(peepholelstm.cell.Wi, 2),
-        " => ", size(peepholelstm.cell.Wi, 1))
+        " => ", size(peepholelstm.cell.Wi, 1) ÷ 4)
     print(io, ")")
 end
diff --git a/src/generics.jl b/src/generics.jl
index 5ce499a..fbf7147 100644
--- a/src/generics.jl
+++ b/src/generics.jl
@@ -22,7 +22,7 @@ function initialstates(rlayer::AbstractRecurrentLayer)
     return initialstates(rlayer.cell)
 end
 
-function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat)
+function (rlayer::AbstractRecurrentLayer)(inp::AbstractArray)
     state = initialstates(rlayer)
     return rlayer(inp, state)
 end
diff --git a/test/test_cells.jl b/test/test_cells.jl
index 4488ac9..b97471e 100644
--- a/test/test_cells.jl
+++ b/test/test_cells.jl
@@ -67,3 +67,17 @@ end
     inp = rand(Float32, 3)
     @test rnncell(inp) == rnncell(inp, zeros(Float32, 5))
 end
+
+@testset "LEMCell" begin
+    rnncell = LEMCell(3 => 5)
+    @test length(Flux.trainables(rnncell)) == 4
+
+    inp = rand(Float32, 3)
+    @test rnncell(inp) == rnncell(inp, (zeros(Float32, 5), zeros(Float32, 5)))
+
+    rnncell = LEMCell(3 => 5; bias=false)
+    @test length(Flux.trainables(rnncell)) == 3
+
+    inp = rand(Float32, 3)
+    @test rnncell(inp) == rnncell(inp, (zeros(Float32, 5), zeros(Float32, 5)))
+end
diff --git a/test/test_layers.jl b/test/test_layers.jl
index c74ac9d..9c07fae 100644
--- a/test/test_layers.jl
+++ b/test/test_layers.jl
@@ -2,7 +2,7 @@ using RecurrentLayers, Flux, Test
 import Flux: initialstates
 
 layers = [MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3,
-    SCRN, PeepholeLSTM, FastRNN, FastGRNN]
+    SCRN, PeepholeLSTM, FastRNN, FastGRNN, LEM]
 #IndRNN handles internal states diffrently
 #RHN should be checked more for consistency for initialstates
 
diff --git a/test/test_wrappers.jl b/test/test_wrappers.jl
index b83d958..ca960b4 100644
--- a/test/test_wrappers.jl
+++ b/test/test_wrappers.jl
@@ -1,7 +1,7 @@
 using RecurrentLayers, Flux, Test
 
 layers = [RNN, GRU, GRUv3, LSTM, MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3,
-    SCRN, PeepholeLSTM, FastRNN, FastGRNN]
+    SCRN, PeepholeLSTM, FastRNN, FastGRNN, LEM]
 
 @testset "Sizes for StackedRNN with layer: $layer" for layer in layers
     wrap = StackedRNN(layer, 2 => 4)

From ba0d3f795aa52fd513135edccf906c7483d0cfc2 Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Sat, 8 Feb 2025 16:23:14 +0100
Subject: [PATCH 3/3] up version

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 23c5d08..7f798b8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "RecurrentLayers"
 uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c"
 authors = ["Francesco Martinuzzi"]
-version = "0.2.6"
+version = "0.2.7"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
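
For context, a minimal usage sketch of the API these patches add, assembled from the docstrings and tests above (the `3 => 5` sizes and `Float32` inputs mirror `test/test_cells.jl`; the output shapes follow the `LEM` docstring and are not part of the patch series itself):

```julia
using RecurrentLayers, Flux

# Single-step cell: input size 3, hidden size 5, default dt.
cell = LEMCell(3 => 5)
inp = rand(Float32, 3)
# LEMCell carries two states: the hidden state y and the auxiliary state z.
out, (y, z) = cell(inp, (zeros(Float32, 5), zeros(Float32, 5)))
# Without an explicit state, zero states come from Flux.initialstates.
out_default, _ = cell(inp)

# Sequence layer: input_size x len (x batch_size) in,
# hidden_size x len (x batch_size) out.
lem = LEM(3 => 5)
seq = rand(Float32, 3, 10)
feats = lem(seq)                          # 5 x 10
feats_b = lem(rand(Float32, 3, 10, 8))    # 5 x 10 x 8
lem_rs = LEM(3 => 5; return_state=true)
feats2, last_state = lem_rs(seq)          # also returns the final (y, z)
```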
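To connect the docstring equations to the packed implementation, here is a standalone sketch of one LEM step in the paper's notation. All parameter names (`W1`, `V1`, `b1`, ...) are hypothetical stand-ins: in `LEMCell` the `V` blocks live stacked in `Wi`, the `W` blocks in `Wh`, `chunk` splits them back out, and the cell's extra `Wz` field plays the role of `W_y` here:

```julia
using NNlib: sigmoid_fast, tanh_fast

# One LEM step written directly from the equations:
# Δt_n gates the auxiliary state z, Δt̄_n gates the hidden state y.
function lem_step(y, z, u, p; dt=1.0f0)
    dtn     = dt .* sigmoid_fast.(p.W1 * y .+ p.V1 * u .+ p.b1)  # Δt_n
    dtn_bar = dt .* sigmoid_fast.(p.W2 * y .+ p.V2 * u .+ p.b2)  # Δt̄_n
    z_new = (1 .- dtn) .* z .+ dtn .* tanh_fast.(p.Wz * y .+ p.Vz * u .+ p.bz)
    y_new = (1 .- dtn_bar) .* y .+
            dtn_bar .* tanh_fast.(p.Wy * z_new .+ p.Vy * u .+ p.by)
    return y_new, z_new
end

# Random parameters for hidden size 5, input size 3:
rec(n) = randn(Float32, n, n)
inm(n, m) = randn(Float32, n, m)
p = (W1=rec(5), W2=rec(5), Wz=rec(5), Wy=rec(5),
     V1=inm(5, 3), V2=inm(5, 3), Vz=inm(5, 3), Vy=inm(5, 3),
     b1=zeros(Float32, 5), b2=zeros(Float32, 5),
     bz=zeros(Float32, 5), by=zeros(Float32, 5))
y1, z1 = lem_step(zeros(Float32, 5), zeros(Float32, 5), rand(Float32, 3), p)
```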
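Finally, a hypothetical end-to-end smoke test in the spirit of `test_layers.jl`, checking that `LEM` composes and trains inside a `Chain`; model shape, loss, and hyperparameters are illustrative only, assuming the layer differentiates like the others in that test file:

```julia
using RecurrentLayers, Flux

# Toy sequence classification: 2 features, length 6, batch of 8.
model = Chain(LEM(2 => 4), x -> x[:, end, :], Dense(4 => 1, sigmoid))
x = rand(Float32, 2, 6, 8)
y = Float32.(rand(Bool, 1, 8))

loss(m, x, y) = Flux.binarycrossentropy(m(x), y)
opt_state = Flux.setup(Adam(1.0f-3), model)
grads = gradient(m -> loss(m, x, y), model)[1]
Flux.update!(opt_state, model, grads)
@assert loss(model, x, y) isa Float32
```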