Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added tests for layers #29

Merged
merged 6 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
name = "RecurrentLayers"
uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c"
authors = ["Francesco Martinuzzi"]
version = "0.2.1"
version = "0.2.2"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"

[compat]
Compat = "4.16.0"
Flux = "0.16"
julia = "1.10"

Expand Down
6 changes: 5 additions & 1 deletion src/RecurrentLayers.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module RecurrentLayers

using Flux
using Compat: @compat #for @compat public
import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like
import Flux: glorot_uniform
#TODO add interlinks to initialstates in docstrings https://juliadocs.org/DocumenterInterLinks.jl/stable/
Expand Down Expand Up @@ -35,7 +36,9 @@ function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat)
return rlayer(inp, state)
end

function (rlayer::AbstractRecurrentLayer)(inp::AbstractArray, state::AbstractVecOrMat)
function (rlayer::AbstractRecurrentLayer)(
inp::AbstractArray,
state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}})
@assert ndims(inp) == 2 || ndims(inp) == 3
return scan(rlayer.cell, inp, state)
end
Expand All @@ -46,6 +49,7 @@ FastRNNCell, FastGRNNCell
export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3,
SCRN, PeepholeLSTM, FastRNN, FastGRNN

@compat(public, (initialstates))

include("mgu_cell.jl")
include("ligru_cell.jl")
Expand Down
7 changes: 1 addition & 6 deletions src/fastrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,6 @@ function FastRNN((input_size, hidden_size)::Pair, activation = tanh_fast;
cell = FastRNNCell(input_size => hidden_size, activation; kwargs...)
return FastRNN(cell)
end

function (fastrnn::FastRNN)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
return scan(fastrnn.cell, inp, state)
end


struct FastGRNNCell{I, H, V, A, B, F} <: AbstractRecurrentCell
Expand Down Expand Up @@ -221,7 +216,7 @@ function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat, state)
# perform computations
gate = fastgrnn.activation.(partial_gate .+ bz)
candidate_state = tanh_fast.(partial_gate .+ bh)
new_state = (zeta .* (ones(size(gate)) .- gate) .+ nu) .* candidate_state .+ gate .* state
new_state = (zeta .* (ones(Float32, size(gate)) .- gate) .+ nu) .* candidate_state .+ gate .* state

return new_state, new_state
end
Expand Down
2 changes: 1 addition & 1 deletion src/indrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,6 @@ See [`IndRNNCell`](@ref) for a layer that processes a single sequence.
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
"""
function IndRNN((input_size, hidden_size)::Pair, σ = tanh; kwargs...)
cell = IndRNNCell(input_size, hidden_size, σ; kwargs...)
cell = IndRNNCell(input_size => hidden_size, σ; kwargs...)
return IndRNN(cell)
end
7 changes: 3 additions & 4 deletions src/peepholelstm_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ function PeepholeLSTMCell(
Wi = init_kernel(hidden_size * 4, input_size)
Wh = init_recurrent_kernel(hidden_size * 4, hidden_size)
b = create_bias(Wi, bias, hidden_size * 4)
cell = PeepholeLSTMCell(Wi, Wh, b)
return cell
return PeepholeLSTMCell(Wi, Wh, b)
end

function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat,
Expand All @@ -74,7 +73,7 @@ function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat,
input, forget, cell, output = chunk(g, 4; dims = 1)
new_cstate = @. sigmoid_fast(forget) * c_state + sigmoid_fast(input) * tanh_fast(cell)
new_state = @. sigmoid_fast(output) * tanh_fast(new_cstate)
return new_state, (new_state, new_cstate)
return new_cstate, (new_state, new_cstate)
end

Base.show(io::IO, lstm::PeepholeLSTMCell) =
Expand Down Expand Up @@ -128,6 +127,6 @@ h_t &= o_t \odot \sigma_h(c_t).
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
"""
function PeepholeLSTM((input_size, hidden_size)::Pair; kwargs...)
cell = PeepholeLSTM(input_size => hidden_size; kwargs...)
cell = PeepholeLSTMCell(input_size => hidden_size; kwargs...)
return PeepholeLSTM(cell)
end
4 changes: 2 additions & 2 deletions src/scrn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function SCRNCell((input_size, hidden_size)::Pair;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias::Bool = true,
alpha = 0.0)
alpha = 0.f0)

Wi = init_kernel(2 * hidden_size, input_size)
Wh = init_recurrent_kernel(2 * hidden_size, hidden_size)
Expand All @@ -78,7 +78,7 @@ function (scrn::SCRNCell)(inp::AbstractVecOrMat, (state, c_state))
gcs = chunk(Wc * c_state .+ b, 2; dims=1)

#compute
context_layer = (1 .- scrn.alpha) .* gxs[1] .+ scrn.alpha .* c_state
context_layer = (1.f0 .- scrn.alpha) .* gxs[1] .+ scrn.alpha .* c_state
hidden_layer = sigmoid_fast(gxs[2] .+ ghs[1] * state .+ gcs[1])
new_state = tanh_fast(ghs[2] * hidden_layer .+ gcs[2])
return new_state, (new_state, context_layer)
Expand Down
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ using Test
include("qa.jl")
end

@safetestset "Sizes and parameters" begin
@safetestset "Cells" begin
include("test_cells.jl")
end

@safetestset "Layers" begin
include("test_layers.jl")
end
34 changes: 34 additions & 0 deletions test/test_layers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using RecurrentLayers
using Flux
using Test

import Flux: initialstates

layers = [MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3,
SCRN, PeepholeLSTM, FastRNN, FastGRNN]
#IndRNN handles internal states differently
#RHN's initialstates behavior should be checked further for consistency

# For every exported layer: check that initialstates yields zeroed state(s)
# of length hidden_size, and that the forward pass maps
# input_size x len (x batch) inputs to hidden_size x len (x batch) outputs.
@testset "Sizes for layer: $layer" for layer in layers
    recurrent_layer = layer(2 => 4)

    # initialstates returns either a single state array or a (state, c_state) tuple
    initial_state = initialstates(recurrent_layer)
    states_to_check = initial_state isa AbstractArray ? (initial_state,) :
                      (initial_state[1], initial_state[2])
    for single_state in states_to_check
        @test single_state ≈ zeros(Float32, 4)
    end

    # 3D input: input_size x len x batch -> hidden_size x len x batch
    batched_input = rand(Float32, 2, 3, 1)
    batched_output = recurrent_layer(batched_input, initial_state)
    @test batched_output isa Array{Float32, 3}
    @test size(batched_output) == (4, 3, 1)

    # 2D input: input_size x len -> hidden_size x len
    flat_input = rand(Float32, 2, 3)
    flat_output = recurrent_layer(flat_input, initial_state)
    @test flat_output isa Array{Float32, 2}
    @test size(flat_output) == (4, 3)
end
Loading