Skip to content

Commit

Permalink
Adding return state option to recurrent layers (#2557)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco authored Jan 9, 2025
1 parent 4eb4454 commit d79811a
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 33 deletions.
145 changes: 113 additions & 32 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function scan(cell, x, state)
yt, state = cell(x_t, state)
y = vcat(y, [yt])
end
return stack(y, dims = 2)
return stack(y, dims = 2), state
end

"""
Expand Down Expand Up @@ -58,16 +58,27 @@ julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size
julia> y = rnn(x); # out x len x batch_size
```
"""
struct Recurrence{S,M}
    # `S` carries the `return_state` flag chosen at construction and selects the
    # forward-pass method; `M` is the concrete type of `cell`.
    cell::M
end

@layer Recurrence

# The default initial state of a `Recurrence` is that of the cell it wraps.
function initialstates(rnn::Recurrence)
    return initialstates(rnn.cell)
end

# Wrap `cell` in a `Recurrence`, recording `return_state` as the first type parameter.
# The forward pass is specialised on `Recurrence{false}` / `Recurrence{true}`, so the
# keyword is restricted to `Bool`, matching the `RNN`/`LSTM`/`GRU`/`GRUv3` cell constructors.
function Recurrence(cell; return_state::Bool = false)
    return Recurrence{return_state, typeof(cell)}(cell)
end

# Rebuild with the same `S`, mirroring the `functor` methods of `RNN`, `LSTM`, `GRU`
# and `GRUv3`: a reconstruction through `Recurrence(cell)` would fall back to
# `return_state = false`.
function Functors.functor(rnn::Recurrence{S}) where {S}
    children = (cell = rnn.cell,)
    rebuild = nt -> Recurrence{S, typeof(nt.cell)}(nt.cell)
    return children, rebuild
end

# Forward pass without an explicit state: start from `initialstates(rnn)` and defer to
# the two-argument method selected by the `return_state` type parameter.
# (The former untyped two-argument fallback is dropped: the `Recurrence{false}` and
# `Recurrence{true}` methods supersede it.)
(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))

# `return_state = false`: scan the cell over `x` and keep only the first element of the
# `(outputs, state)` pair that `scan` returns.
function (rnn::Recurrence{false})(x::AbstractArray, state)
    outputs, _ = scan(rnn.cell, x, state)
    return outputs
end

# `return_state = true`: scan the cell over `x` and return both the outputs and the
# state left after the last step.
function (rnn::Recurrence{true})(x::AbstractArray, state)
    outputs, last_state = scan(rnn.cell, x, state)
    return outputs, last_state
end

# Vanilla RNN
@doc raw"""
Expand Down Expand Up @@ -193,8 +204,8 @@ function Base.show(io::IO, m::RNNCell)
end

@doc raw"""
RNN(in => out, σ = tanh; init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)
RNN(in => out, σ = tanh; return_state = false,
init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true)
The most basic recurrent layer. Essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.
Expand All @@ -212,6 +223,7 @@ See [`RNNCell`](@ref) for a layer that processes a single time step.
- `in => out`: The input and output dimensions of the layer.
- `σ`: The non-linearity to apply to the output. Default is `tanh`.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
Expand All @@ -227,7 +239,8 @@ The arguments of the forward pass are:
If given, it is a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true` it returns
a tuple of the hidden states `h_t` and the last state of the iteration.
# Examples
Expand Down Expand Up @@ -260,26 +273,43 @@ Flux.@layer Model
model = Model(RNN(32 => 64), zeros(Float32, 64))
```
"""
struct RNN{S,M}
    # `S` carries the `return_state` flag chosen at construction and selects the
    # forward-pass method; `M` is the concrete type of `cell`.
    cell::M
end

@layer :noexpand RNN

# The default initial state of an `RNN` is that of the cell it wraps.
function initialstates(rnn::RNN)
    return initialstates(rnn.cell)
end

# Build an `RNNCell` from `in => out`, `σ` and the remaining keywords, then wrap it.
# `return_state` is stored as the first type parameter; returning through `RNN(cell)`
# instead would discard it and always yield the `return_state = false` layer.
# The keyword is typed `Bool`, as in the `RNN(cell::RNNCell; ...)` constructor.
function RNN((in, out)::Pair, σ = tanh; return_state::Bool = false, cell_kwargs...)
    cell = RNNCell(in => out, σ; cell_kwargs...)
    return RNN{return_state, typeof(cell)}(cell)
end

# Wrap an existing `RNNCell`, recording `return_state` as the first type parameter.
function RNN(cell::RNNCell; return_state::Bool=false)
    CellType = typeof(cell)
    return RNN{return_state, CellType}(cell)
end

# Forward pass without an explicit hidden state: start from `initialstates(rnn)`.
function (rnn::RNN)(x::AbstractArray)
    h0 = initialstates(rnn)
    return rnn(x, h0)
end

# `return_state = false`: run the cell over the sequence and return only the outputs,
# i.e. the first element of the `(outputs, state)` pair produced by `scan`.
# The receiver is named `rnn`; no `m` is bound in this method.
function (rnn::RNN{false})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    # [x] = [in, L] or [in, L, B]
    # [h] = [out] or [out, B]
    return first(scan(rnn.cell, x, h))
end

# `return_state = true`: run the cell over the sequence and return the outputs together
# with the state left after the last step.
function (rnn::RNN{true})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    # [x] = [in, L] or [in, L, B]
    # [h] = [out] or [out, B]
    outputs, last_state = scan(rnn.cell, x, h)
    return outputs, last_state
end

# Expose `cell` as the only child and rebuild with the same `S`, so the
# `return_state` flag is carried over to the reconstructed layer.
function Functors.functor(rnn::RNN{S}) where {S}
    children = (cell = rnn.cell,)
    rebuild = nt -> RNN{S, typeof(nt.cell)}(nt.cell)
    return children, rebuild
end

function Base.show(io::IO, m::RNN)
Expand Down Expand Up @@ -391,7 +421,7 @@ Base.show(io::IO, m::LSTMCell) =


@doc raw"""
LSTM(in => out; init_kernel = glorot_uniform,
LSTM(in => out; return_state = false, init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)
[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
Expand All @@ -415,6 +445,7 @@ See [`LSTMCell`](@ref) for a layer that processes a single time step.
# Arguments
- `in => out`: The input and output dimensions of the layer.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
Expand All @@ -430,7 +461,8 @@ The arguments of the forward pass are:
They should be vectors of size `out` or matrices of size `out x batch_size`.
If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`.
Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`. When `return_state = true` it returns
a tuple of the hidden states `h_t` and the last state of the iteration.
# Examples
Expand All @@ -452,24 +484,39 @@ h = model(x)
size(h) # out x len x batch_size
```
"""
struct LSTM{S,M}
    # `S` carries the `return_state` flag chosen at construction and selects the
    # forward-pass method; `M` is the concrete type of `cell`.
    cell::M
end

@layer :noexpand LSTM

# The default initial state of an `LSTM` is that of the cell it wraps.
function initialstates(lstm::LSTM)
    return initialstates(lstm.cell)
end

# Build an `LSTMCell` from `in => out` and the remaining keywords, then wrap it.
# `return_state` is stored as the first type parameter; returning through `LSTM(cell)`
# instead would discard it and always yield the `return_state = false` layer.
# The keyword is typed `Bool`, as in the `LSTM(cell::LSTMCell; ...)` constructor.
function LSTM((in, out)::Pair; return_state::Bool = false, cell_kwargs...)
    cell = LSTMCell(in => out; cell_kwargs...)
    return LSTM{return_state, typeof(cell)}(cell)
end

# Wrap an existing `LSTMCell`, recording `return_state` as the first type parameter.
function LSTM(cell::LSTMCell; return_state::Bool=false)
    CellType = typeof(cell)
    return LSTM{return_state, CellType}(cell)
end

# Forward pass without an explicit state: start from `initialstates(lstm)`.
function (lstm::LSTM)(x::AbstractArray)
    state0 = initialstates(lstm)
    return lstm(x, state0)
end

# `return_state = false`: run the cell over the sequence and return only the outputs,
# i.e. the first element of the `(outputs, state)` pair produced by `scan`.
# The receiver is named `lstm`; no `m` is bound in this method.
function (lstm::LSTM{false})(x::AbstractArray, state0)
    @assert ndims(x) == 2 || ndims(x) == 3
    return first(scan(lstm.cell, x, state0))
end

# `return_state = true`: run the cell over the sequence and return the outputs together
# with the state left after the last step.
function (lstm::LSTM{true})(x::AbstractArray, state0)
    @assert ndims(x) == 2 || ndims(x) == 3
    outputs, last_state = scan(lstm.cell, x, state0)
    return outputs, last_state
end

# Expose `cell` as the only child and rebuild with the same `S`, so the
# `return_state` flag is carried over to the reconstructed layer.
function Functors.functor(lstm::LSTM{S}) where {S}
    children = (cell = lstm.cell,)
    rebuild = nt -> LSTM{S, typeof(nt.cell)}(nt.cell)
    return children, rebuild
end

function Base.show(io::IO, m::LSTM)
Expand Down Expand Up @@ -578,7 +625,7 @@ Base.show(io::IO, m::GRUCell) =
print(io, "GRUCell(", size(m.Wi, 2), " => ", size(m.Wi, 1) ÷ 3, ")")

@doc raw"""
GRU(in => out; init_kernel = glorot_uniform,
GRU(in => out; return_state = false, init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
Expand All @@ -599,6 +646,7 @@ See [`GRUCell`](@ref) for a layer that processes a single time step.
# Arguments
- `in => out`: The input and output dimensions of the layer.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
Expand All @@ -613,7 +661,8 @@ The arguments of the forward pass are:
- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true` it returns
a tuple of the hidden states `h_t` and the last state of the iteration.
# Examples
Expand All @@ -625,24 +674,39 @@ h0 = zeros(Float32, d_out)
h = gru(x, h0) # out x len x batch_size
```
"""
struct GRU{S,M}
    # `S` carries the `return_state` flag chosen at construction and selects the
    # forward-pass method; `M` is the concrete type of `cell`.
    cell::M
end

@layer :noexpand GRU

# The default initial state of a `GRU` is that of the cell it wraps.
function initialstates(gru::GRU)
    return initialstates(gru.cell)
end

# Build a `GRUCell` from `in => out` and the remaining keywords, then wrap it.
# `return_state` is stored as the first type parameter; returning through `GRU(cell)`
# instead would discard it and always yield the `return_state = false` layer.
# The keyword is typed `Bool`, as in the `GRU(cell::GRUCell; ...)` constructor.
function GRU((in, out)::Pair; return_state::Bool = false, cell_kwargs...)
    cell = GRUCell(in => out; cell_kwargs...)
    return GRU{return_state, typeof(cell)}(cell)
end

# Wrap an existing `GRUCell`, recording `return_state` as the first type parameter.
function GRU(cell::GRUCell; return_state::Bool=false)
    CellType = typeof(cell)
    return GRU{return_state, CellType}(cell)
end

# Forward pass without an explicit hidden state: start from `initialstates(gru)`.
function (gru::GRU)(x::AbstractArray)
    h0 = initialstates(gru)
    return gru(x, h0)
end

# `return_state = false`: run the cell over the sequence and return only the outputs,
# i.e. the first element of the `(outputs, state)` pair produced by `scan`.
function (gru::GRU{false})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    return first(scan(gru.cell, x, h))
end

# `return_state = true`: run the cell over the sequence and return the outputs together
# with the state left after the last step.
# The receiver is named `gru`; no `m` is bound in this method.
function (gru::GRU{true})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    return scan(gru.cell, x, h)
end

# Expose `cell` as the only child and rebuild with the same `S`, so the
# `return_state` flag is carried over to the reconstructed layer.
function Functors.functor(gru::GRU{S}) where {S}
    children = (cell = gru.cell,)
    rebuild = nt -> GRU{S, typeof(nt.cell)}(nt.cell)
    return children, rebuild
end

function Base.show(io::IO, m::GRU)
Expand Down Expand Up @@ -739,7 +803,7 @@ Base.show(io::IO, m::GRUv3Cell) =


@doc raw"""
GRUv3(in => out; init_kernel = glorot_uniform,
GRUv3(in => out; return_state = false, init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
Expand All @@ -764,6 +828,7 @@ but only a less popular variant.
# Arguments
- `in => out`: The input and output dimensions of the layer.
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
Expand All @@ -778,7 +843,8 @@ The arguments of the forward pass are:
- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true` it returns
a tuple of the hidden states `h_t` and the last state of the iteration.
# Examples
Expand All @@ -790,24 +856,39 @@ h0 = zeros(Float32, d_out)
h = gruv3(x, h0) # out x len x batch_size
```
"""
struct GRUv3{S,M}
    # `S` carries the `return_state` flag chosen at construction and selects the
    # forward-pass method; `M` is the concrete type of `cell`.
    cell::M
end

@layer :noexpand GRUv3

# The default initial state of a `GRUv3` is that of the cell it wraps.
function initialstates(gru::GRUv3)
    return initialstates(gru.cell)
end

# Build a `GRUv3Cell` from `in => out` and the remaining keywords, then wrap it.
# `return_state` is stored as the first type parameter; returning through `GRUv3(cell)`
# instead would discard it and always yield the `return_state = false` layer.
# The keyword is typed `Bool`, as in the `GRUv3(cell::GRUv3Cell; ...)` constructor.
function GRUv3((in, out)::Pair; return_state::Bool = false, cell_kwargs...)
    cell = GRUv3Cell(in => out; cell_kwargs...)
    return GRUv3{return_state, typeof(cell)}(cell)
end

# Wrap an existing `GRUv3Cell`, recording `return_state` as the first type parameter.
function GRUv3(cell::GRUv3Cell; return_state::Bool=false)
    CellType = typeof(cell)
    return GRUv3{return_state, CellType}(cell)
end

# Forward pass without an explicit hidden state: start from `initialstates(gru)`.
function (gru::GRUv3)(x::AbstractArray)
    h0 = initialstates(gru)
    return gru(x, h0)
end

# `return_state = false`: run the cell over the sequence and return only the outputs,
# i.e. the first element of the `(outputs, state)` pair produced by `scan`.
# The receiver is named `gru`; no `m` is bound in this method.
function (gru::GRUv3{false})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    return first(scan(gru.cell, x, h))
end

# `return_state = true`: run the cell over the sequence and return the outputs together
# with the state left after the last step.
function (gru::GRUv3{true})(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    outputs, last_state = scan(gru.cell, x, h)
    return outputs, last_state
end

# Expose `cell` as the only child and rebuild with the same `S`, so the
# `return_state` flag is carried over to the reconstructed layer.
function Functors.functor(gru::GRUv3{S}) where {S}
    children = (cell = gru.cell,)
    rebuild = nt -> GRUv3{S, typeof(nt.cell)}(nt.cell)
    return children, rebuild
end

function Base.show(io::IO, m::GRUv3)
Expand Down
Loading

0 comments on commit d79811a

Please sign in to comment.