From 07b2106afaca4e3f86cdd5bcccab7b2465884d70 Mon Sep 17 00:00:00 2001
From: MartinuzziFrancesco
Date: Tue, 29 Oct 2024 15:28:01 +0100
Subject: [PATCH] fixes and tests

---
 README.md                            | 111 ++++++++++++++++++++-------
 src/RecurrentLayers.jl               |   3 +-
 src/indrnn_cell.jl                   |   4 +-
 src/{lru_cell.jl => lightru_cell.jl} |   4 +-
 src/ligru_cell.jl                    |  18 +++--
 src/mgu_cell.jl                      |   9 ++-
 src/nas_cell.jl                      |   8 +-
 src/ran_cell.jl                      |  19 ++---
 src/rhn_cell.jl                      |   4 +
 src/sru_cell.jl                      |   9 ++-
 test/indrnn_cell.jl                  |  22 ++++++
 test/lightru_cell.jl                 |  22 ++++++
 test/ligru_cell.jl                   |  22 ++++++
 test/mgu_cell.jl                     |  22 ++++++
 test/ran_cell.jl                     |  22 ++++++
 test/runtests.jl                     |  20 +++++
 16 files changed, 258 insertions(+), 61 deletions(-)
 rename src/{lru_cell.jl => lightru_cell.jl} (96%)
 create mode 100644 test/indrnn_cell.jl
 create mode 100644 test/lightru_cell.jl
 create mode 100644 test/ligru_cell.jl
 create mode 100644 test/mgu_cell.jl
 create mode 100644 test/ran_cell.jl

diff --git a/README.md b/README.md
index f544522..24fc77e 100644
--- a/README.md
+++ b/README.md
@@ -41,43 +41,102 @@ Pkg.add(url="https://github.com/MartinuzziFrancesco/RecurrentLayers.jl")
 The workflow is identical to any recurrent Flux layer:
 
 ```julia
-using Flux
 using RecurrentLayers
-input_size = 2
-hidden_size = 5
-output_size = 3
-sequence_length = 100
-epochs = 10
+using Flux
+using MLUtils: DataLoader
+using Statistics
+using Random
+
+# Parameters
+input_size = 1     # Each element in the sequence is a scalar
+hidden_size = 64   # Size of the hidden state in the recurrent layer
+num_classes = 2    # Binary classification
+seq_length = 10    # Length of each sequence
+batch_size = 16    # Batch size
+num_epochs = 50    # Number of epochs for training
+num_samples = 1000 # Number of samples in dataset
+
+# Create dataset
+function create_dataset(seq_length, num_samples)
+    data = randn(input_size, seq_length, num_samples)
+    labels = sum(data, dims=(1, 2)) .>= 0
+    labels = Int.(labels)
+    return data, labels
+end
+
+# Generate training data
+train_data, train_labels = create_dataset(seq_length, num_samples)
+train_loader = DataLoader((train_data, train_labels), batchsize=batch_size, shuffle=true)
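+
+# Optional sanity check (illustrative, not required by the example; the names
+# `sample_x`/`sample_y` are only used here): each batch from `train_loader` is a
+# (input_size, seq_length, batch_size) data array and a (1, 1, batch_size) label
+# array, since `DataLoader` batches along the last dimension.
+sample_x, sample_y = first(train_loader)
+@show size(sample_x) size(sample_y)
+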
+# Define the model
 model = Chain(
-    MGU(input_size, hidden_size),
-    Dense(hidden_size, output_size)
+    RAN(input_size => hidden_size),
+    x -> x[:, end, :], # Extract the last hidden state
+    Dense(hidden_size, num_classes)
 )
 
-# dummy data
-X = rand(Float32, input_size, sequence_length)
-Y = rand(1:output_size)
+# Adjust labels to 1-based indexing for Julia
+function adjust_labels(labels)
+    return labels .+ 1
+end
+
+# Define the loss function
+function loss_fn(batch_data, batch_labels)
+    # Adjust labels
+    batch_labels = adjust_labels(batch_labels)
+    # One-hot encode labels and remove any extra singleton dimensions
+    batch_labels_oh = dropdims(Flux.onehotbatch(batch_labels, 1:num_classes), dims=(2, 3))
+    # Forward pass
+    y_pred = model(batch_data)
+    # Compute loss
+    loss = Flux.logitcrossentropy(y_pred, batch_labels_oh)
+    return loss
+end
 
-# loss function
-loss_fn(x, y) = Flux.mse(model(x), y)
-# optimizer
-opt = Adam()
+# Define the optimizer
+opt = Adam(0.01)
 
-# training
-for epoch in 1:epochs
-    # gradients
-    gs = gradient(Flux.params(model)) do
-        loss = loss_fn(X, Y)
-        return loss
+# Training loop
+for epoch in 1:num_epochs
+    total_loss = 0.0
+    for (batch_data, batch_labels) in train_loader
+        # Compute gradients and update parameters
+        grads = gradient(() -> loss_fn(batch_data, batch_labels), Flux.params(model))
+        Flux.Optimise.update!(opt, Flux.params(model), grads)
+
+        # Accumulate loss
+        total_loss += loss_fn(batch_data, batch_labels)
     end
-    # update parameters
-    Flux.update!(opt, Flux.params(model), gs)
-    # loss at epoch
-    current_loss = loss_fn(X, Y)
-    println("Epoch $epoch, Loss: $(current_loss)")
+    avg_loss = total_loss / length(train_loader)
+    println("Epoch $epoch/$num_epochs, Loss: $(round(avg_loss, digits=4))")
+end
+
+# Generate test data
+test_data, test_labels = create_dataset(seq_length, 200)
+test_loader = DataLoader((test_data, test_labels), batchsize=batch_size, shuffle=false)
+
+# Evaluation
+correct = 0
+total = 0
+for (batch_data, batch_labels) in test_loader
+    global correct, total
+    # Adjust labels
+    batch_labels = adjust_labels(batch_labels)
+    # Forward pass
+    y_pred = model(batch_data)
+    # Decode predictions
+    predicted = Flux.onecold(y_pred, 1:num_classes)
+    # Flatten and compare
+    correct += sum(vec(predicted) .== vec(batch_labels))
+    total += length(batch_labels)
 end
+
+accuracy = 100 * correct / total
+println("Test Accuracy: $(round(accuracy, digits=2))%")
 ```
 
 ## License
diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl
index fede6cf..c1ebe95 100644
--- a/src/RecurrentLayers.jl
+++ b/src/RecurrentLayers.jl
@@ -8,11 +8,12 @@
 export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
 RHNCellUnit, NASCell
 export MGU, LiGRU, IndRNN, RAN, LightRU, NAS
 
+#TODO add double bias
 include("mgu_cell.jl")
 include("ligru_cell.jl")
 include("indrnn_cell.jl")
 include("ran_cell.jl")
-include("lru_cell.jl")
+include("lightru_cell.jl")
 include("rhn_cell.jl")
 include("nas_cell.jl")
diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl
index b7f05aa..05c5c30 100644
--- a/src/indrnn_cell.jl
+++ b/src/indrnn_cell.jl
@@ -27,8 +27,8 @@ function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat)
     return state
 end
 
-function Base.show(io::IO, m::IndRNNCell)
-    print(io, "IndRNNCell(", size(m.Wi, 2), " => ", size(indrnn.Wi, 1))
+function Base.show(io::IO, indrnn::IndRNNCell)
+    print(io, "IndRNNCell(", size(indrnn.Wi, 2), " => ", size(indrnn.Wi, 1))
     print(io, ", ", indrnn.σ)
     print(io, ")")
 end
diff --git a/src/lru_cell.jl b/src/lightru_cell.jl
similarity index 96%
rename from src/lru_cell.jl
rename to src/lightru_cell.jl
index 4401a28..71df40a 100644
--- a/src/lru_cell.jl
+++ b/src/lightru_cell.jl
@@ -5,6 +5,8 @@ struct LightRUCell{I,H,V}
     bias::V
 end
 
+Flux.@layer LightRUCell
+
 function LightRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
     Wi = init(2 * out, in)
     Wh = init(out, out)
@@ -43,7 +45,7 @@ struct LightRU{M}
     cell::M
 end
 
-Flux.@layer :expand LRU
+Flux.@layer :expand LightRU
 
 function LightRU((in, out)::Pair; init = glorot_uniform, bias = true)
     cell = LightRUCell(in => out; init, bias)
diff --git a/src/ligru_cell.jl b/src/ligru_cell.jl
index 9e29126..de4726d 100644
--- a/src/ligru_cell.jl
+++ b/src/ligru_cell.jl
@@ -5,6 +5,8 @@ struct LiGRUCell{I, H, V}
     bias::V
 end
 
+Flux.@layer LiGRUCell
+
 function LiGRUCell((in, out)::Pair;
     init = glorot_uniform,
     bias = true)
@@ -18,17 +20,21 @@
 end
 
 LiGRUCell(in, out; kwargs...) = LiGRUCell(in => out; kwargs...)
 
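+# Calling the cell with only an input starts from a zero hidden state,
+# mirroring the single-argument methods of Flux's built-in recurrent cells.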
+function (ligru::LiGRUCell)(inp::AbstractVecOrMat)
+    state = zeros_like(inp, size(ligru.Wh, 2))
+    return ligru(inp, state)
+end
+
 function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state)
     _size_check(ligru, inp, 1 => size(ligru.Wi,2))
-    Wi, Wh, b = ligru.Wi, ligru.Wh, ligru.b
+    Wi, Wh, b = ligru.Wi, ligru.Wh, ligru.bias
     #split
     gxs = chunk(Wi * inp, 2, dims=1)
-    ghs = chunk(Wh * state, 2, dims=1)
-    bs = chunk(b, 2, dims=1)
+    ghs = chunk(Wh * state .+ b, 2, dims=1)
     #compute
-    forget_gate = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
-    candidate_hidden = @. tanh_fast(gxs[2] + ghs[2] + bs[2])
-    new_state = forget_gate .* hidden .+ (1 .- forget_gate) .* candidate_hidden
+    forget_gate = @. sigmoid_fast(gxs[1] + ghs[1])
+    candidate_hidden = @. tanh_fast(gxs[2] + ghs[2])
+    new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_hidden
     return new_state
 end
diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl
index d8a063a..fe8f313 100644
--- a/src/mgu_cell.jl
+++ b/src/mgu_cell.jl
@@ -5,6 +5,8 @@ struct MGUCell{I, H, V}
     bias::V
 end
 
+Flux.@layer MGUCell
+
 function MGUCell((in, out)::Pair;
     init = glorot_uniform,
     bias = true)
@@ -27,12 +29,11 @@ function (mgu::MGUCell)(inp::AbstractVecOrMat, state)
     _size_check(mgu, inp, 1 => size(mgu.Wi,2))
     Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias
     #split
-    gxs = chunk(Wi * inp, 2, dims=1)
-    bs = chunk(b, 2, dims=1)
+    gxs = chunk(Wi * inp .+ b, 2, dims=1)
     ghs = chunk(Wh, 2, dims=1)
 
-    forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1]*state .+ bs[1])
-    candidate_state = tanh_fast.(gxs[2] .+ ghs[2]*(forget_gate.*state) .+ bs[2])
+    forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1]*state)
+    candidate_state = tanh_fast.(gxs[2] .+ ghs[2]*(forget_gate.*state))
     new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_state
     return new_state
 end
diff --git a/src/nas_cell.jl b/src/nas_cell.jl
index 5f8bf59..339545c 100644
--- a/src/nas_cell.jl
+++ b/src/nas_cell.jl
@@ -29,6 +29,8 @@ struct NASCell{I,H,V}
     bias::V
 end
 
+Flux.@layer NASCell
+
 function NASCell((in, out)::Pair; init = glorot_uniform, bias = true)
     Wi = init(8 * out, in)
     Wh = init(8 * out, out)
@@ -47,10 +49,8 @@ function (nas::NASCell)(inp::AbstractVecOrMat, (state, c_state))
     Wi, Wh, b = nas.Wi, nas.Wh, nas.bias
 
     #matmul and split
-    inputs = Wi * inp
-    recurrents = Wh * state .+ b
-    im = chunk(inputs, 8; dims=1)
-    mm = chunk(recurrents, 8; dims=1)
+    im = chunk(Wi * inp, 8; dims=1)
+    mm = chunk(Wh * state .+ b, 8; dims=1)
 
     #first layer
     layer1_1 = sigmoid_fast(im[1] .+ mm[1])
diff --git a/src/ran_cell.jl b/src/ran_cell.jl
index fc338d7..a9bd0c9 100644
--- a/src/ran_cell.jl
+++ b/src/ran_cell.jl
@@ -5,6 +5,8 @@ struct RANCell{I,H,V}
     bias::V
 end
 
+Flux.@layer RANCell
+
 """
     RANCell(in => out; init = glorot_uniform, bias = true)
 
@@ -13,14 +15,6 @@
 The `RANCell`, introduced in [this paper](https://arxiv.org/pdf/1705.07393),
 is a recurrent cell layer which provides additional memory through the use of gates.
 
-The forward pass consists of:
-```math
-\tilde{c}_t = W_{cx} x_t \\
-i_t = \sigma\left(W_{ih} h_{t-1} + W_{ix} x_t + b_i\right) \\
-f_t = \sigma\left(W_{fh} h_{t-1} + W_{fx} x_t + b_f\right) \\
-c_t = i_t \circ \tilde{c}_t + f_t \circ c_{t-1} \\
-h_t = tanh(c_t)
-```
 and returns both h_t and c_t.
 
 See [`RAN`](@ref) for a layer that processes entire sequences.
@@ -77,13 +71,12 @@ function (ran::RANCell)(inp::AbstractVecOrMat, (state, c_state))
     Wi, Wh, b = ran.Wi, ran.Wh, ran.bias
 
     #split
-    gxs = chunk(Wi * inp, 3, dims=1)
-    bs = chunk(b, 2, dims=1)
-    ghs = chunk(Wh * state, 2, dims=1)
+    gxs = chunk(Wi * inp, 3; dims=1)
+    ghs = chunk(Wh * state .+ b, 2; dims=1)
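+    # adding the bias to `Wh * state` before chunking is equivalent to the old
+    # code, which chunked the bias separately and added one half to each gate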
 
     #compute
-    input_gate = @. sigmoid_fast(gxs[2] + ghs[1] + bs[1])
-    forget_gate = @. sigmoid_fast(gxs[3] + ghs[2] + bs[2])
+    input_gate = @. sigmoid_fast(gxs[2] + ghs[1])
+    forget_gate = @. sigmoid_fast(gxs[3] + ghs[2])
     candidate_state = @. input_gate * gxs[1] + forget_gate * c_state
     new_state = tanh_fast(candidate_state)
     return new_state, candidate_state
diff --git a/src/rhn_cell.jl b/src/rhn_cell.jl
index 1db66d0..9403b36 100644
--- a/src/rhn_cell.jl
+++ b/src/rhn_cell.jl
@@ -6,6 +6,8 @@ struct RHNCellUnit{I,V}
     bias::V
 end
 
+Flux.@layer RHNCellUnit
+
 function RHNCellUnit((in, out)::Pair; init = glorot_uniform, bias = true)
     weight = init(3 * out, in)
     b = create_bias(weight, bias, size(weight, 1))
@@ -37,6 +39,8 @@ struct RHNCell{C}
     couple_carry::Bool
 end
 
+Flux.@layer RHNCell
+
 function RHNCell((in, out), depth=3;
     couple_carry::Bool = true,
     cell_kwargs...)
diff --git a/src/sru_cell.jl b/src/sru_cell.jl
index fe04575..ae56c20 100644
--- a/src/sru_cell.jl
+++ b/src/sru_cell.jl
@@ -6,6 +6,8 @@ struct SRUCell{I,H,B,V}
     bias::V
 end
 
+Flux.@layer SRUCell
+
 function SRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
     Wi = init(2 * out, in)
     Wh = init(2 * out, out)
@@ -29,13 +31,12 @@ function (sru::SRUCell)(inp::AbstractVecOrMat, (state, c_state))
 
     #split
     gxs = chunk(Wi * inp, 3, dims=1)
-    ghs = chunk(Wh * state, 2, dims=1)
-    bs = chunk(b, 2, dims=1)
+    ghs = chunk(Wh * state .+ b, 2, dims=1)
     vs = chunk(v, 2, dims=1)
 
     #compute
-    input_gate = @. sigmoid_fast(gxs[2] + ghs[1] + bs[1])
-    forget_gate = @. sigmoid_fast(gxs[3] + ghs[2] + bs[2])
+    input_gate = @. sigmoid_fast(gxs[2] + ghs[1])
+    forget_gate = @. sigmoid_fast(gxs[3] + ghs[2])
     candidate_state = @. input_gate * gxs[1] + forget_gate * c_state
     new_state = tanh_fast(candidate_state)
     return new_state, candidate_state
diff --git a/test/indrnn_cell.jl b/test/indrnn_cell.jl
new file mode 100644
index 0000000..c467846
--- /dev/null
+++ b/test/indrnn_cell.jl
@@ -0,0 +1,22 @@
+using Test
+using RecurrentLayers
+using Flux
+
+@testset "IndRNNCell" begin
+    @testset "Sizes and parameters" begin
+        indrnn = IndRNNCell(3 => 5)
+        @test length(Flux.trainables(indrnn)) == 3
+
+        inp = rand(Float32, 3)
+        @test indrnn(inp) == indrnn(inp, zeros(Float32, 5))
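+        # suggested extra check: the output should have the hidden size
+        @test size(indrnn(inp)) == (5,)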
+
+        indrnn = IndRNNCell(3 => 5; bias=false)
+        @test length(Flux.trainables(indrnn)) == 2
+
+        inp = rand(Float32, 3)
+        @test indrnn(inp) == indrnn(inp, zeros(Float32, 5))
+    end
+end
+
+#@testset "IndRNN" begin
+#end
\ No newline at end of file
diff --git a/test/lightru_cell.jl b/test/lightru_cell.jl
new file mode 100644
index 0000000..0248632
--- /dev/null
+++ b/test/lightru_cell.jl
@@ -0,0 +1,22 @@
+using Test
+using RecurrentLayers
+using Flux
+
+@testset "LightRUCell" begin
+    @testset "Sizes and parameters" begin
+        lightru = LightRUCell(3 => 5)
+        @test length(Flux.trainables(lightru)) == 3
+
+        inp = rand(Float32, 3)
+        @test lightru(inp) == lightru(inp, zeros(Float32, 5))
+
+        lightru = LightRUCell(3 => 5; bias=false)
+        @test length(Flux.trainables(lightru)) == 2
+
+        inp = rand(Float32, 3)
+        @test lightru(inp) == lightru(inp, zeros(Float32, 5))
+    end
+end
+
+#@testset "LightRU" begin
+#end
\ No newline at end of file
diff --git a/test/ligru_cell.jl b/test/ligru_cell.jl
new file mode 100644
index 0000000..787f4e9
--- /dev/null
+++ b/test/ligru_cell.jl
@@ -0,0 +1,22 @@
+using Test
+using RecurrentLayers
+using Flux
+
+@testset "LiGRUCell" begin
+    @testset "Sizes and parameters" begin
+        ligru = LiGRUCell(3 => 5)
+        @test length(Flux.trainables(ligru)) == 3
+
+        inp = rand(Float32, 3)
+        @test ligru(inp) == ligru(inp, zeros(Float32, 5))
+
+        ligru = LiGRUCell(3 => 5; bias=false)
+        @test length(Flux.trainables(ligru)) == 2
+
+        inp = rand(Float32, 3)
+        @test ligru(inp) == ligru(inp, zeros(Float32, 5))
+    end
+end
+
+#@testset "LiGRU" begin
+#end
\ No newline at end of file
diff --git a/test/mgu_cell.jl b/test/mgu_cell.jl
new file mode 100644
index 0000000..dee7779
--- /dev/null
+++ b/test/mgu_cell.jl
@@ -0,0 +1,22 @@
+using Test
+using RecurrentLayers
+using Flux
+
+@testset "MGUCell" begin
+    @testset "Sizes and parameters" begin
+        mgu = MGUCell(3 => 5)
+        @test length(Flux.trainables(mgu)) == 3
+
+        inp = rand(Float32, 3)
+        @test mgu(inp) == mgu(inp, zeros(Float32, 5))
+
+        mgu = MGUCell(3 => 5; bias=false)
+        @test length(Flux.trainables(mgu)) == 2
+
+        inp = rand(Float32, 3)
+        @test mgu(inp) == mgu(inp, zeros(Float32, 5))
+    end
+end
+
+#@testset "MGU" begin
+#end
\ No newline at end of file
diff --git a/test/ran_cell.jl b/test/ran_cell.jl
new file mode 100644
index 0000000..dbde3b4
--- /dev/null
+++ b/test/ran_cell.jl
@@ -0,0 +1,22 @@
+using Test
+using RecurrentLayers
+using Flux
+
+@testset "RANCell" begin
+    @testset "Sizes and parameters" begin
+        rancell = RANCell(3 => 5)
+        @test length(Flux.trainables(rancell)) == 3
+
+        inp = rand(Float32, 3)
+        @test rancell(inp) == rancell(inp, (zeros(Float32, 5), zeros(Float32, 5)))
+
+        rancell = RANCell(3 => 5; bias=false)
+        @test length(Flux.trainables(rancell)) == 2
+
+        inp = rand(Float32, 3)
+        @test rancell(inp) == rancell(inp, (zeros(Float32, 5), zeros(Float32, 5)))
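+
+        # suggested extra check: the cell returns a (state, candidate state) pair
+        @test size.(rancell(inp)) == ((5,), (5,))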
+    end
+end
+
+#@testset "RAN" begin
+#end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 0408f83..59fdd9c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -4,3 +4,23 @@ using Test
 @safetestset "Quality Assurance" begin
     include("qa.jl")
 end
+
+@safetestset "Minimal gated unit" begin
+    include("mgu_cell.jl")
+end
+
+@safetestset "Independently recurrent neural network" begin
+    include("indrnn_cell.jl")
+end
+
+@safetestset "Light gated recurrent unit" begin
+    include("ligru_cell.jl")
+end
+
+@safetestset "Light recurrent unit" begin
+    include("lightru_cell.jl")
+end
+
+@safetestset "Recurrent additive network" begin
+    include("ran_cell.jl")
+end
\ No newline at end of file