Merge pull request #56 from MartinuzziFrancesco/fm/lem
[CELL] Long Expressive Memory (LEM)
MartinuzziFrancesco authored Feb 8, 2025
2 parents 79f4b6f + ba0d3f7 commit 54b0227
Showing 10 changed files with 196 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "RecurrentLayers"
uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c"
authors = ["Francesco Martinuzzi"]
-version = "0.2.6"
+version = "0.2.7"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1 change: 1 addition & 0 deletions docs/src/api/cells.md
@@ -17,4 +17,5 @@ PeepholeLSTMCell
FastRNNCell
FastGRNNCell
FSRNNCell
+LEMCell
```
1 change: 1 addition & 0 deletions docs/src/api/layers.md
@@ -16,4 +16,5 @@ PeepholeLSTM
FastRNN
FastGRNN
FSRNN
+LEM
```
9 changes: 5 additions & 4 deletions src/RecurrentLayers.jl
@@ -9,9 +9,9 @@ using NNlib: fast_act

export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
-FastRNNCell, FastGRNNCell, FSRNNCell
+FastRNNCell, FastGRNNCell, FSRNNCell, LEMCell
export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3,
-SCRN, PeepholeLSTM, FastRNN, FastGRNN, FSRNN
+SCRN, PeepholeLSTM, FastRNN, FastGRNN, FSRNN, LEM
export StackedRNN

@compat(public, (initialstates))
@@ -30,16 +30,17 @@ include("cells/scrn_cell.jl")
include("cells/peepholelstm_cell.jl")
include("cells/fastrnn_cell.jl")
include("cells/fsrnn_cell.jl")
+include("cells/lem_cell.jl")

include("wrappers/stackedrnn.jl")

### fallbacks for functors ###
rlayers = (:FastRNN, :FastGRNN, :IndRNN, :LightRU, :LiGRU, :MGU, :MUT1,
-:MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN, :FSRNN)
+:MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN, :FSRNN, :LEM)

rcells = (:FastRNNCell, :FastGRNNCell, :IndRNNCell, :LightRUCell, :LiGRUCell,
:MGUCell, :MUT1Cell, :MUT2Cell, :MUT3Cell, :NASCell, :PeepholeLSTMCell,
-:RANCell, :SCRNCell, :FSRNNCell)
+:RANCell, :SCRNCell, :FSRNNCell, :LEMCell)

for (rlayer, rcell) in zip(rlayers, rcells)
@eval begin
169 changes: 169 additions & 0 deletions src/cells/lem_cell.jl
@@ -0,0 +1,169 @@
#https://arxiv.org/pdf/2110.04744
@doc raw"""
LEMCell(input_size => hidden_size, [dt];
init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform,
bias = true)
[Long expressive memory unit](https://arxiv.org/pdf/2110.04744).
See [`LEM`](@ref) for a layer that processes entire sequences.
# Arguments
- `input_size => hidden_size`: input and inner dimension of the layer
- `dt`: timestep. Default is 1.0
# Keyword arguments
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
- `bias`: include a bias or not. Default is `true`
# Equations
```math
\begin{aligned}
\boldsymbol{\Delta t_n} &= \Delta \hat{t} \hat{\sigma}
(W_1 y_{n-1} + V_1 u_n + b_1) \\
\overline{\boldsymbol{\Delta t_n}} &= \Delta \hat{t}
\hat{\sigma} (W_2 y_{n-1} + V_2 u_n + b_2) \\
z_n &= (1 - \boldsymbol{\Delta t_n}) \odot z_{n-1} +
\boldsymbol{\Delta t_n} \odot \sigma (W_z y_{n-1} + V_z u_n + b_z) \\
y_n &= (1 - \overline{\boldsymbol{\Delta t_n}}) \odot y_{n-1} +
\overline{\boldsymbol{\Delta t_n}} \odot \sigma (W_y z_n + V_y u_n + b_y)
\end{aligned}
```
# Forward
lemcell(inp, (state, zstate))
lemcell(inp)
## Arguments
- `inp`: The input to the lemcell. It should be a vector of size `input_size`
or a matrix of size `input_size x batch_size`.
- `(state, zstate)`: A tuple containing the hidden and `z` states of the LEMCell.
They should be vectors of size `hidden_size` or matrices of size
`hidden_size x batch_size`. If not provided, they are assumed to be vectors of zeros,
initialized by [`Flux.initialstates`](@extref).
## Returns
- A tuple `(output, state)`: `output` is the updated hidden state `new_state`, a tensor of
size `hidden_size` or `hidden_size x batch_size`, and `state` is the tuple
`(new_state, new_zstate)` carried to the next step.
"""
struct LEMCell{I, H, Z, V, D} <: AbstractDoubleRecurrentCell
Wi::I
Wh::H
Wz::Z
bias::V
dt::D
end

@layer LEMCell

function LEMCell((input_size, hidden_size)::Pair{<:Int, <:Int}, dt::Number=1.0f0;
init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform,
bias::Bool=true)
Wi = init_kernel(hidden_size * 4, input_size)
Wh = init_recurrent_kernel(hidden_size * 3, hidden_size)
Wz = init_recurrent_kernel(hidden_size, hidden_size)
b = create_bias(Wi, bias, size(Wi, 1))

return LEMCell(Wi, Wh, Wz, b, eltype(Wi)(dt))
end

function (lem::LEMCell)(inp::AbstractVecOrMat, (state, z_state))
_size_check(lem, inp, 1 => size(lem.Wi, 2))
Wi, Wh, Wz, b = lem.Wi, lem.Wh, lem.Wz, lem.bias
T = eltype(Wi)
#split the stacked input and recurrent projections into per-gate blocks
gxs = chunk(Wi * inp .+ b, 4; dims=1)
ghs = chunk(Wh * state, 3; dims=1)

msdt_bar = lem.dt .* sigmoid_fast.(gxs[1] .+ ghs[1])
ms_dt = lem.dt .* sigmoid_fast.(gxs[2] .+ ghs[2])
new_zstate = (T(1.0f0) .- ms_dt) .* z_state .+ ms_dt .* tanh_fast.(gxs[3] .+ ghs[3])
new_state = (T(1.0f0) .- msdt_bar) .* state .+
msdt_bar .* tanh_fast.(gxs[4] .+ Wz * z_state)
return new_state, (new_state, new_zstate)
end

function Base.show(io::IO, lem::LEMCell)
print(io, "LEMCell(", size(lem.Wi, 2), " => ", size(lem.Wi, 1) ÷ 4, ")")
end
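
For orientation, a minimal usage sketch of the cell interface documented above. It is illustrative only and not part of this commit; the shapes follow the docstring and the input data is arbitrary.

```julia
using Flux, RecurrentLayers

lemcell = LEMCell(3 => 5)        # 3 input features -> 5 hidden units, dt = 1.0

# Single step, unbatched: with no state given, zero states come from initialstates.
inp = rand(Float32, 3)
out, (state, zstate) = lemcell(inp)
size(out)                        # (5,)

# Batched step with explicit states of size hidden_size x batch_size.
batch = rand(Float32, 3, 8)
states = (zeros(Float32, 5, 8), zeros(Float32, 5, 8))
out, (state, zstate) = lemcell(batch, states)
size(out)                        # (5, 8)
```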

@doc raw"""
LEM(input_size => hidden_size, [dt];
return_state=false, init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)
[Long expressive memory network](https://arxiv.org/pdf/2110.04744).
See [`LEMCell`](@ref) for a layer that processes a single time step.
# Arguments
- `input_size => hidden_size`: input and inner dimension of the layer
- `dt`: timestep. Default is 1.0
# Keyword arguments
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
- `bias`: include a bias or not. Default is `true`
- `return_state`: Option to return the last state together with the output.
Default is `false`.
# Equations
```math
\begin{aligned}
\boldsymbol{\Delta t_n} &= \Delta \hat{t} \hat{\sigma}
(W_1 y_{n-1} + V_1 u_n + b_1) \\
\overline{\boldsymbol{\Delta t_n}} &= \Delta \hat{t}
\hat{\sigma} (W_2 y_{n-1} + V_2 u_n + b_2) \\
z_n &= (1 - \boldsymbol{\Delta t_n}) \odot z_{n-1} +
\boldsymbol{\Delta t_n} \odot \sigma (W_z y_{n-1} + V_z u_n + b_z) \\
y_n &= (1 - \overline{\boldsymbol{\Delta t_n}}) \odot y_{n-1} +
\overline{\boldsymbol{\Delta t_n}} \odot \sigma (W_y z_n + V_y u_n + b_y)
\end{aligned}
```
# Forward
LEM(inp, (state, zstate))
LEM(inp)
## Arguments
- `inp`: The input to the LEM. It should be a matrix of size `input_size x len`
or an array of size `input_size x len x batch_size`.
- `(state, zstate)`: A tuple containing the hidden and `z` states of the LEM.
They should be vectors of size `hidden_size` or matrices of size
`hidden_size x batch_size`. If not provided, they are assumed to be vectors of zeros,
initialized by [`Flux.initialstates`](@extref).
## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
When `return_state = true` it returns a tuple of the hidden states `new_states` and
the last state of the iteration.
"""
struct LEM{S, M} <: AbstractRecurrentLayer{S}
cell::M
end

@layer :noexpand LEM

function LEM((input_size, hidden_size)::Pair{<:Int, <:Int}, dt::Number=1.0;
return_state::Bool=false, kwargs...)
cell = LEMCell(input_size => hidden_size, dt; kwargs...)
return LEM{return_state, typeof(cell)}(cell)
end

function functor(rnn::LEM{S}) where {S}
params = (cell=rnn.cell,)
reconstruct = p -> LEM{S, typeof(p.cell)}(p.cell)
return params, reconstruct
end

function Base.show(io::IO, lem::LEM)
print(io, "LEM(", size(lem.cell.Wi, 2),
" => ", size(lem.cell.Wi, 1) ÷ 4)
print(io, ")")
end
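
A corresponding sketch for the sequence-level layer, again illustrative rather than part of the commit. The no-state call on a 3-D input relies on the `generics.jl` change further down, which widens that method from `AbstractVecOrMat` to `AbstractArray`.

```julia
using Flux, RecurrentLayers

lem = LEM(3 => 5)                # default dt = 1.0

# A batch of sequences: input_size x len x batch_size.
inp = rand(Float32, 3, 10, 4)
out = lem(inp)                   # zero initial states via initialstates
size(out)                        # (5, 10, 4)

# Optionally return the last state together with the per-step outputs.
lem_rs = LEM(3 => 5; return_state=true)
out, last_state = lem_rs(inp)
```
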
4 changes: 2 additions & 2 deletions src/cells/peepholelstm_cell.jl
@@ -73,7 +73,7 @@ function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat, (state, c_state))
input, forget, cell, output = chunk(g, 4; dims=1)
new_cstate = @. sigmoid_fast(forget) * c_state + sigmoid_fast(input) * tanh_fast(cell)
new_state = @. sigmoid_fast(output) * tanh_fast(new_cstate)
-return new_cstate, (new_state, new_cstate)
+return new_state, (new_state, new_cstate)
end

function Base.show(io::IO, lstm::PeepholeLSTMCell)
@@ -150,6 +150,6 @@ end

function Base.show(io::IO, peepholelstm::PeepholeLSTM)
print(io, "PeepholeLSTM(", size(peepholelstm.cell.Wi, 2),
-" => ", size(peepholelstm.cell.Wi, 1))
+" => ", size(peepholelstm.cell.Wi, 1) ÷ 4)
print(io, ")")
end
2 changes: 1 addition & 1 deletion src/generics.jl
@@ -22,7 +22,7 @@ function initialstates(rlayer::AbstractRecurrentLayer)
return initialstates(rlayer.cell)
end

-function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat)
+function (rlayer::AbstractRecurrentLayer)(inp::AbstractArray)
state = initialstates(rlayer)
return rlayer(inp, state)
end
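
To show what this widened signature enables, a small sketch (layer choice and shapes are illustrative): previously the no-state call only accepted vectors or matrices, so a batched sequence required passing the state explicitly; now the zero state is created automatically for any array input.

```julia
using Flux, RecurrentLayers

mgu = MGU(2 => 4)

# 3-D input (input_size x len x batch_size) without an explicit state:
# dispatch now reaches the AbstractArray method, which calls initialstates.
seq = rand(Float32, 2, 6, 3)
out = mgu(seq)
size(out)                        # (4, 6, 3)
```
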
14 changes: 14 additions & 0 deletions test/test_cells.jl
@@ -67,3 +67,17 @@ end
inp = rand(Float32, 3)
@test rnncell(inp) == rnncell(inp, zeros(Float32, 5))
end

@testset "LEMCell" begin
rnncell = LEMCell(3 => 5)
@test length(Flux.trainables(rnncell)) == 4

inp = rand(Float32, 3)
@test rnncell(inp) == rnncell(inp, (zeros(Float32, 5), zeros(Float32, 5)))

rnncell = LEMCell(3 => 5; bias=false)
@test length(Flux.trainables(rnncell)) == 3

inp = rand(Float32, 3)
@test rnncell(inp) == rnncell(inp, (zeros(Float32, 5), zeros(Float32, 5)))
end
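
Beyond the shape checks above, a hedged sketch of taking gradients through the new cell (not part of the test suite; the loss and data are arbitrary), just to illustrate that the cell composes with Flux's explicit-gradient API:

```julia
using Flux, RecurrentLayers

lemcell = LEMCell(3 => 5)
inp = rand(Float32, 3, 8)

# The cell returns (output, (state, zstate)); take the output for a dummy loss.
loss(m, x) = sum(abs2, first(m(x)))
grads = Flux.gradient(m -> loss(m, inp), lemcell)
```
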
2 changes: 1 addition & 1 deletion test/test_layers.jl
@@ -2,7 +2,7 @@ using RecurrentLayers, Flux, Test
import Flux: initialstates

layers = [MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3,
-SCRN, PeepholeLSTM, FastRNN, FastGRNN]
+SCRN, PeepholeLSTM, FastRNN, FastGRNN, LEM]
#IndRNN handles internal states differently
#RHN's initialstates should be checked further for consistency

2 changes: 1 addition & 1 deletion test/test_wrappers.jl
@@ -1,7 +1,7 @@
using RecurrentLayers, Flux, Test

layers = [RNN, GRU, GRUv3, LSTM, MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3,
-SCRN, PeepholeLSTM, FastRNN, FastGRNN]
+SCRN, PeepholeLSTM, FastRNN, FastGRNN, LEM]

@testset "Sizes for StackedRNN with layer: $layer" for layer in layers
wrap = StackedRNN(layer, 2 => 4)

2 comments on commit 54b0227

@MartinuzziFrancesco (Owner, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/124598

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.7 -m "<description of version>" 54b022723951f45f2750bc83d1549f15769cc9a8
git push origin v0.2.7
