Skip to content

Commit

Permalink
fixes and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Oct 29, 2024
1 parent 0d4dc2a commit 07b2106
Show file tree
Hide file tree
Showing 16 changed files with 258 additions and 61 deletions.
111 changes: 85 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,43 +41,102 @@ Pkg.add(url="https://github.com/MartinuzziFrancesco/RecurrentLayers.jl")
The workflow is identical to any recurrent Flux layer:

```julia
using Flux
using RecurrentLayers

input_size = 2
hidden_size = 5
output_size = 3
sequence_length = 100
epochs = 10
using Flux
using MLUtils: DataLoader
using Statistics
using Random

# Parameters
input_size = 1 # Each element in the sequence is a scalar
hidden_size = 64 # Size of the hidden state in MGU
num_classes = 2 # Binary classification
seq_length = 10 # Length of each sequence
batch_size = 16 # Batch size
num_epochs = 50 # Number of epochs for training
num_samples = 1000 # Number of samples in dataset

# Create dataset
function create_dataset(seq_length, num_samples)
data = randn(input_size, seq_length, num_samples)
labels = sum(data, dims=(1,2)) .>= 0
labels = Int.(labels)
return data, labels
end

# Generate training data
train_data, train_labels = create_dataset(seq_length, num_samples)
train_loader = DataLoader((train_data, train_labels), batchsize=batch_size, shuffle=true)

# Define the model
model = Chain(
MGU(input_size, hidden_size),
Dense(hidden_size, output_size)
RAN(input_size => hidden_size),
x -> x[:, end, :], # Extract the last hidden state
Dense(hidden_size, num_classes)
)

# dummy data
X = rand(Float32, input_size, sequence_length)
Y = rand(1:output_size)
# Adjust labels to 1-based indexing for Julia
function adjust_labels(labels)
return labels .+ 1
end

# Define the loss function
# Define the loss function
function loss_fn(batch_data, batch_labels)
# Adjust labels
batch_labels = adjust_labels(batch_labels)
# One-hot encode labels and remove any extra singleton dimensions
batch_labels_oh = dropdims(Flux.onehotbatch(batch_labels, 1:num_classes), dims=(2, 3))
# Forward pass
y_pred = model(batch_data)
# Compute loss
loss = Flux.logitcrossentropy(y_pred, batch_labels_oh)
return loss
end

# loss function
loss_fn(x, y) = Flux.mse(model(x), y)

# optimizer
opt = Adam()
# Define the optimizer
opt = Adam(0.01)

# training
for epoch in 1:epochs
# gradients
gs = gradient(Flux.params(model)) do
loss = loss_fn(X, Y)
return loss
# Training loop
for epoch in 1:num_epochs
total_loss = 0.0
for (batch_data, batch_labels) in train_loader
# Compute gradients and update parameters
grads = gradient(() -> loss_fn(batch_data, batch_labels), Flux.params(model))
Flux.Optimise.update!(opt, Flux.params(model), grads)

# Accumulate loss
total_loss += loss_fn(batch_data, batch_labels)
end
# update parameters
Flux.update!(opt, Flux.params(model), gs)
# loss at epoch
current_loss = loss_fn(X, Y)
println("Epoch $epoch, Loss: $(current_loss)")
avg_loss = total_loss / length(train_loader)
println("Epoch $epoch/$num_epochs, Loss: $(round(avg_loss, digits=4))")
end

# Generate test data
test_data, test_labels = create_dataset(seq_length, 200)
test_loader = DataLoader((test_data, test_labels), batchsize=batch_size, shuffle=false)

# Evaluation
correct = 0
total = 0
for (batch_data, batch_labels) in test_loader
# Adjust labels
batch_labels = adjust_labels(batch_labels)
# Forward pass
y_pred = model(batch_data)
# Decode predictions
predicted = Flux.onecold(y_pred, 1:num_classes)
# Flatten and compare
correct += sum(vec(predicted) .== vec(batch_labels))
total += length(batch_labels)
end

accuracy = 100 * correct / total
println("Test Accuracy: $(round(accuracy, digits=2))%")


```
## License

Expand Down
3 changes: 2 additions & 1 deletion src/RecurrentLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
RHNCellUnit, NASCell
export MGU, LiGRU, IndRNN, RAN, LightRU, NAS

#TODO add double bias
include("mgu_cell.jl")
include("ligru_cell.jl")
include("indrnn_cell.jl")
include("ran_cell.jl")
include("lru_cell.jl")
include("lightru_cell.jl")
include("rhn_cell.jl")
include("nas_cell.jl")

Expand Down
4 changes: 2 additions & 2 deletions src/indrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat)
return state
end

function Base.show(io::IO, m::IndRNNCell)
print(io, "IndRNNCell(", size(m.Wi, 2), " => ", size(indrnn.Wi, 1))
function Base.show(io::IO, indrnn::IndRNNCell)
print(io, "IndRNNCell(", size(indrnn.Wi, 2), " => ", size(indrnn.Wi, 1))
print(io, ", ", indrnn.σ)
print(io, ")")
end
Expand Down
4 changes: 3 additions & 1 deletion src/lru_cell.jl → src/lightru_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ struct LightRUCell{I,H,V}
bias::V
end

Flux.@layer LightRUCell

function LightRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
Wi = init(2 * out, in)
Wh = init(out, out)
Expand Down Expand Up @@ -43,7 +45,7 @@ struct LightRU{M}
cell::M
end

Flux.@layer :expand LRU
Flux.@layer :expand LightRU

function LightRU((in, out)::Pair; init = glorot_uniform, bias = true)
cell = LightRUCell(in => out; init, bias)
Expand Down
18 changes: 12 additions & 6 deletions src/ligru_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ struct LiGRUCell{I, H, V}
bias::V
end

Flux.@layer LiGRUCell

function LiGRUCell((in, out)::Pair;
init = glorot_uniform,
bias = true)
Expand All @@ -18,17 +20,21 @@ end

LiGRUCell(in, out; kwargs...) = LiGRUCell(in => out; kwargs...)

function (ligru::LiGRUCell)(inp::AbstractVecOrMat)
state = zeros_like(inp, size(ligru.Wh, 2))
return ligru(inp, state)
end

function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state)
_size_check(ligru, inp, 1 => size(ligru.Wi,2))
Wi, Wh, b = ligru.Wi, ligru.Wh, ligru.b
Wi, Wh, b = ligru.Wi, ligru.Wh, ligru.bias
#split
gxs = chunk(Wi * inp, 2, dims=1)
ghs = chunk(Wh * state, 2, dims=1)
bs = chunk(b, 2, dims=1)
ghs = chunk(Wh * state .+ b, 2, dims=1)
#compute
forget_gate = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
candidate_hidden = @. tanh_fast(gxs[2] + ghs[2] + bs[2])
new_state = forget_gate .* hidden .+ (1 .- forget_gate) .* candidate_hidden
forget_gate = @. sigmoid_fast(gxs[1] + ghs[1])
candidate_hidden = @. tanh_fast(gxs[2] + ghs[2])
new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_hidden
return new_state
end

Expand Down
9 changes: 5 additions & 4 deletions src/mgu_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ struct MGUCell{I, H, V}
bias::V
end

Flux.@layer MGUCell

function MGUCell((in, out)::Pair;
init = glorot_uniform,
bias = true)
Expand All @@ -27,12 +29,11 @@ function (mgu::MGUCell)(inp::AbstractVecOrMat, state)
_size_check(mgu, inp, 1 => size(mgu.Wi,2))
Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias
#split
gxs = chunk(Wi * inp, 2, dims=1)
bs = chunk(b, 2, dims=1)
gxs = chunk(Wi * inp .+ b, 2, dims=1)
ghs = chunk(Wh, 2, dims=1)

forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1]*state .+ bs[1])
candidate_state = tanh_fast.(gxs[2] .+ ghs[2]*(forget_gate.*state) .+ bs[2])
forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1]*state)
candidate_state = tanh_fast.(gxs[2] .+ ghs[2]*(forget_gate.*state))
new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_state
return new_state
end
Expand Down
8 changes: 4 additions & 4 deletions src/nas_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ struct NASCell{I,H,V}
bias::V
end

Flux.@layer NASCell

function NASCell((in, out)::Pair; init = glorot_uniform, bias = true)
Wi = init(8 * out, in)
Wh = init(8 * out, out)
Expand All @@ -47,10 +49,8 @@ function (nas::NASCell)(inp::AbstractVecOrMat, (state, c_state))
Wi, Wh, b = nas.Wi, nas.Wh, nas.bias

#matmul and split
inputs = Wi * inp
recurrents = Wh * state .+ b
im = chunk(inputs, 8; dims=1)
mm = chunk(recurrents, 8; dims=1)
im = chunk(Wi * inp, 8; dims=1)
mm = chunk(Wh * state .+ b, 8; dims=1)

#first layer
layer1_1 = sigmoid_fast(im[1] .+ mm[1])
Expand Down
19 changes: 6 additions & 13 deletions src/ran_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ struct RANCell{I,H,V}
bias::V
end

Flux.@layer RANCell


"""
RANCell(in => out; init = glorot_uniform, bias = true)
Expand All @@ -13,14 +15,6 @@ The `RANCell`, introduced in [this paper](https://arxiv.org/pdf/1705.07393),
is a recurrent cell layer which provides additional memory through the
use of gates.
The forward pass consists of:
```math
\tilde{c}_t = W_{cx} x_t \\
i_t = \sigma\left(W_{ih} h_{t-1} + W_{ix} x_t + b_i\right) \\
f_t = \sigma\left(W_{fh} h_{t-1} + W_{fx} x_t + b_f\right) \\
c_t = i_t \circ \tilde{c}_t + f_t \circ c_{t-1} \\
h_t = tanh(c_t)
```
and returns both h_t anf c_t.
See [`RAN`](@ref) for a layer that processes entire sequences.
Expand Down Expand Up @@ -77,13 +71,12 @@ function (ran::RANCell)(inp::AbstractVecOrMat, (state, c_state))
Wi, Wh, b = ran.Wi, ran.Wh, ran.bias

#split
gxs = chunk(Wi * inp, 3, dims=1)
bs = chunk(b, 2, dims=1)
ghs = chunk(Wh * state, 2, dims=1)
gxs = chunk(Wi * inp, 3; dims=1)
ghs = chunk(Wh * state .+ b, 2; dims=1)

#compute
input_gate = @. sigmoid_fast(gxs[2] + ghs[1] + bs[1])
forget_gate = @. sigmoid_fast(gxs[3] + ghs[2] + bs[2])
input_gate = @. sigmoid_fast(gxs[2] + ghs[1])
forget_gate = @. sigmoid_fast(gxs[3] + ghs[2])
candidate_state = @. input_gate * gxs[1] + forget_gate * c_state
new_state = tanh_fast(candidate_state)
return new_state, candidate_state
Expand Down
4 changes: 4 additions & 0 deletions src/rhn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ struct RHNCellUnit{I,V}
bias::V
end

Flux.@layer RHNCellUnit

function RHNCellUnit((in, out)::Pair; init = glorot_uniform, bias = true)
weight = init(3 * out, in)
b = create_bias(weight, bias, size(weight, 1))
Expand Down Expand Up @@ -37,6 +39,8 @@ struct RHNCell{C}
couple_carry::Bool
end

Flux.@layer RHNCell

function RHNCell((in, out), depth=3;
couple_carry::Bool = true,
cell_kwargs...)
Expand Down
9 changes: 5 additions & 4 deletions src/sru_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ struct SRUCell{I,H,B,V}
bias::V
end

Flux.@layer SRUCell

function SRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
Wi = init(2 * out, in)
Wh = init(2 * out, out)
Expand All @@ -29,13 +31,12 @@ function (sru::SRUCell)(inp::AbstractVecOrMat, (state, c_state))

#split
gxs = chunk(Wi * inp, 3, dims=1)
ghs = chunk(Wh * state, 2, dims=1)
bs = chunk(b, 2, dims=1)
ghs = chunk(Wh * state .+ b, 2, dims=1)
vs = chunk(v, 2, dims=1)

#compute
input_gate = @. sigmoid_fast(gxs[2] + ghs[1] + bs[1])
forget_gate = @. sigmoid_fast(gxs[3] + ghs[2] + bs[2])
input_gate = @. sigmoid_fast(gxs[2] + ghs[1])
forget_gate = @. sigmoid_fast(gxs[3] + ghs[2])
candidate_state = @. input_gate * gxs[1] + forget_gate * c_state
new_state = tanh_fast(candidate_state)
return new_state, candidate_state
Expand Down
22 changes: 22 additions & 0 deletions test/indrnn_cell.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Test
using RecurrentLayers
using Flux

@testset "IndRNNCell" begin
@testset "Sizes and parameters" begin
indrnn = IndRNNCell(3 => 5)
@test length(Flux.trainables(indrnn)) == 3

inp = rand(Float32, 3)
@test indrnn(inp) == indrnn(inp, zeros(Float32, 5))

indrnn = IndRNNCell(3 => 5; bias=false)
@test length(Flux.trainables(indrnn)) == 2

inp = rand(Float32, 3)
@test indrnn(inp) == indrnn(inp, zeros(Float32, 5))
end
end

#@testset "IndRNN" begin
#end
22 changes: 22 additions & 0 deletions test/lightru_cell.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Test
using RecurrentLayers
using Flux

@testset "LightRUCell" begin
@testset "Sizes and parameters" begin
lightru = LightRUCell(3 => 5)
@test length(Flux.trainables(lightru)) == 3

inp = rand(Float32, 3)
@test lightru(inp) == lightru(inp, zeros(Float32, 5))

lightru = LightRUCell(3 => 5; bias=false)
@test length(Flux.trainables(lightru)) == 2

inp = rand(Float32, 3)
@test lightru(inp) == lightru(inp, zeros(Float32, 5))
end
end

#@testset "LightRU" begin
#end
Loading

0 comments on commit 07b2106

Please sign in to comment.