From 5e8cadf9b98d6ca67d48255a1ab15a9902961716 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 16 Dec 2024 11:10:38 +0100 Subject: [PATCH 1/3] small fixes and abstraction --- src/RecurrentLayers.jl | 5 +++++ src/fastrnn_cell.jl | 9 ++------- src/indrnn_cell.jl | 7 +------ src/lightru_cell.jl | 9 ++------- src/ligru_cell.jl | 15 ++++----------- src/mgu_cell.jl | 9 ++------- src/mut_cell.jl | 21 +++------------------ src/nas_cell.jl | 9 ++------- src/peepholelstm_cell.jl | 9 ++------- src/ran_cell.jl | 10 ++-------- src/rhn_cell.jl | 2 +- src/scrn_cell.jl | 7 +------ 12 files changed, 27 insertions(+), 85 deletions(-) diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index 4172e48..10f5aa7 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -35,6 +35,11 @@ function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat) return rlayer(inp, state) end +function (rlayer::AbstractRecurrentLayer)(inp, state) + @assert ndims(inp) == 2 || ndims(inp) == 3 + return scan(rlayer.cell, inp, state) +end + export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell, RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell, FastRNNCell, FastGRNNCell diff --git a/src/fastrnn_cell.jl b/src/fastrnn_cell.jl index 67e3ca2..6d6f513 100644 --- a/src/fastrnn_cell.jl +++ b/src/fastrnn_cell.jl @@ -88,7 +88,7 @@ struct FastRNN{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand FastRNN +Flux.@layer :noexpand FastRNN @doc raw""" FastRNN((input_size => hidden_size), [activation]; kwargs...) @@ -234,7 +234,7 @@ struct FastGRNN{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand FastGRNN +Flux.@layer :noexpand FastGRNN @doc raw""" FastGRNN((input_size => hidden_size), [activation]; kwargs...) @@ -279,9 +279,4 @@ function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast; kwargs...) cell = FastGRNNCell(input_size => hidden_size, activation; kwargs...) return FastGRNN(cell) -end - -function (fastgrnn::FastGRNN)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(fastgrnn.call, inp, state) end \ No newline at end of file diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl index 94e5b06..ecd1738 100644 --- a/src/indrnn_cell.jl +++ b/src/indrnn_cell.jl @@ -74,7 +74,7 @@ struct IndRNN{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand IndRNN +Flux.@layer :noexpand IndRNN @doc raw""" IndRNN((input_size, hidden_size)::Pair, σ = tanh, σ=relu; @@ -113,9 +113,4 @@ See [`IndRNNCell`](@ref) for a layer that processes a single sequence. function IndRNN((input_size, hidden_size)::Pair, σ = tanh; kwargs...) cell = IndRNNCell(input_size, hidden_size, σ; kwargs...) return IndRNN(cell) -end - -function (indrnn::IndRNN)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(indrnn.cell, inp, state) end \ No newline at end of file diff --git a/src/lightru_cell.jl b/src/lightru_cell.jl index 37de018..5eb4a87 100644 --- a/src/lightru_cell.jl +++ b/src/lightru_cell.jl @@ -82,7 +82,7 @@ struct LightRU{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand LightRU +Flux.@layer :noexpand LightRU @doc raw""" LightRU((input_size => hidden_size)::Pair; kwargs...) @@ -124,9 +124,4 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t. function LightRU((input_size, hidden_size)::Pair; kwargs...) cell = LightRUCell(input_size => hidden_size; kwargs...) return LightRU(cell) -end - -function (lightru::LightRU)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(lightru.cell, inp, state) -end +end \ No newline at end of file diff --git a/src/ligru_cell.jl b/src/ligru_cell.jl index 259a901..e8e30c1 100644 --- a/src/ligru_cell.jl +++ b/src/ligru_cell.jl @@ -75,12 +75,14 @@ function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state) return new_state, new_state end +Base.show(io::IO, ligru::LiGRUCell) = + print(io, "LiGRUCell(", size(ligru.Wi, 2), " => ", size(ligru.Wi, 1) ÷ 2, ")") struct LiGRU{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand LiGRU +Flux.@layer :noexpand LiGRU @doc raw""" LiGRU((input_size => hidden_size)::Pair; kwargs...) @@ -124,13 +126,4 @@ h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t function LiGRU((input_size, hidden_size)::Pair; kwargs...) cell = LiGRUCell(input_size => hidden_size; kwargs...) return LiGRU(cell) -end - -function (ligru::LiGRU)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(ligru.cell, inp, state) -end - - -Base.show(io::IO, ligru::LiGRUCell) = - print(io, "LiGRUCell(", size(ligru.Wi, 2), " => ", size(ligru.Wi, 1) ÷ 2, ")") +end \ No newline at end of file diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index 1a55e87..ce78015 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -81,7 +81,7 @@ struct MGU{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand MGU +Flux.@layer :noexpand MGU @doc raw""" MGU((input_size => hidden_size)::Pair; kwargs...) @@ -123,9 +123,4 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t function MGU((input_size, hidden_size)::Pair; kwargs...) cell = MGUCell(input_size => hidden_size; kwargs...) return MGU(cell) -end - -function (mgu::MGU)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(mgu.cell, inp, state) -end +end \ No newline at end of file diff --git a/src/mut_cell.jl b/src/mut_cell.jl index 2820474..0c9c60b 100644 --- a/src/mut_cell.jl +++ b/src/mut_cell.jl @@ -84,7 +84,7 @@ struct MUT1{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand MUT1 +Flux.@layer :noexpand MUT1 @doc raw""" MUT1((input_size => hidden_size); kwargs...) @@ -129,11 +129,6 @@ function MUT1((input_size, hidden_size)::Pair; kwargs...) return MUT1(cell) end -function (mut::MUT1)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(mut.cell, inp, state) -end - struct MUT2Cell{I, H, V} <: AbstractRecurrentCell Wi::I @@ -220,7 +215,7 @@ struct MUT2{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand MUT2 +Flux.@layer :noexpand MUT2 @doc raw""" MUT2Cell((input_size => hidden_size); kwargs...) @@ -264,11 +259,6 @@ function MUT2((input_size, hidden_size)::Pair; kwargs...) cell = MUT2Cell(input_size => hidden_size; kwargs...) return MUT2(cell) end - -function (mut::MUT2)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(mut.cell, inp, state) -end struct MUT3Cell{I, H, V} <: AbstractRecurrentCell @@ -354,7 +344,7 @@ struct MUT3{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand MUT3 +Flux.@layer :noexpand MUT3 @doc raw""" MUT3((input_size => hidden_size); kwargs...) @@ -397,9 +387,4 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\ function MUT3((input_size, hidden_size)::Pair; kwargs...) cell = MUT3Cell(input_size => hidden_size; kwargs...) return MUT3(cell) -end - -function (mut::MUT3)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(mut.cell, inp, state) end \ No newline at end of file diff --git a/src/nas_cell.jl b/src/nas_cell.jl index e27f92d..868fda0 100644 --- a/src/nas_cell.jl +++ b/src/nas_cell.jl @@ -148,7 +148,7 @@ struct NAS{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand NAS +Flux.@layer :noexpand NAS @doc raw""" NAS((input_size => hidden_size)::Pair; kwargs...) @@ -211,9 +211,4 @@ h_{\text{new}} &= \tanh(c_{\text{new}} \cdot l_5) function NAS((input_size, hidden_size)::Pair; kwargs...) cell = NASCell(input_size => hidden_size; kwargs...) return NAS(cell) -end - -function (nas::NAS)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(nas.cell, inp, state) -end +end \ No newline at end of file diff --git a/src/peepholelstm_cell.jl b/src/peepholelstm_cell.jl index 001180b..ac71e1a 100644 --- a/src/peepholelstm_cell.jl +++ b/src/peepholelstm_cell.jl @@ -86,7 +86,7 @@ struct PeepholeLSTM{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand PeepholeLSTM +Flux.@layer :noexpand PeepholeLSTM @doc raw""" PeepholeLSTM((input_size => hidden_size)::Pair; kwargs...) @@ -130,9 +130,4 @@ h_t &= o_t \odot \sigma_h(c_t). function PeepholeLSTM((input_size, hidden_size)::Pair; kwargs...) cell = PeepholeLSTM(input_size => hidden_size; kwargs...) return PeepholeLSTM(cell) -end - -function (lstm::PeepholeLSTM)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(lstm.cell, inp, state) -end +end \ No newline at end of file diff --git a/src/ran_cell.jl b/src/ran_cell.jl index c94bb7f..c2f393f 100644 --- a/src/ran_cell.jl +++ b/src/ran_cell.jl @@ -88,7 +88,7 @@ struct RAN{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand RAN +Flux.@layer :noexpand RAN @doc raw""" RAN(input_size => hidden_size; kwargs...) @@ -137,10 +137,4 @@ h_t &= g(c_t) function RAN((input_size, hidden_size)::Pair; kwargs...) cell = RANCell(input_size => hidden_size; kwargs...) return RAN(cell) -end - -function (ran::RAN)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(ran.cell, inp, state) -end - +end \ No newline at end of file diff --git a/src/rhn_cell.jl b/src/rhn_cell.jl index 3df2bd9..6a5866b 100644 --- a/src/rhn_cell.jl +++ b/src/rhn_cell.jl @@ -138,7 +138,7 @@ struct RHN{M} cell::M end -Flux.@layer :expand RHN +Flux.@layer :noexpand RHN @doc raw""" RHN((input_size => hidden_size)::Pair depth=3; kwargs...) diff --git a/src/scrn_cell.jl b/src/scrn_cell.jl index 117ba1c..4b0b6e5 100644 --- a/src/scrn_cell.jl +++ b/src/scrn_cell.jl @@ -92,7 +92,7 @@ struct SCRN{M} <: AbstractRecurrentLayer cell::M end -Flux.@layer :expand SCRN +Flux.@layer :noexpand SCRN @doc raw""" SCRN((input_size => hidden_size)::Pair; @@ -139,9 +139,4 @@ y_t &= f(U_y h_t + W_y s_t) function SCRN((input_size, hidden_size)::Pair; kwargs...) cell = SCRNCell(input_size => hidden_size; kwargs...) return SCRN(cell) -end - -function (scrn::SCRN)(inp, state) - @assert ndims(inp) == 2 || ndims(inp) == 3 - return scan(scrn.cell, inp, state) end \ No newline at end of file From 487426c389939b8a576edcde6efcfdad38a2c8ab Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 16 Dec 2024 11:11:17 +0100 Subject: [PATCH 2/3] up version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fa5a441..7a3d20d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecurrentLayers" uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c" authors = ["Francesco Martinuzzi"] -version = "0.2.0" +version = "0.2.1" [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" From 3c4eff584e0a8c936f1cbe3d39661314aa4f588b Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 16 Dec 2024 11:35:13 +0100 Subject: [PATCH 3/3] typing --- src/RecurrentLayers.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index 10f5aa7..dec1ab4 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -35,7 +35,7 @@ function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat) return rlayer(inp, state) end -function (rlayer::AbstractRecurrentLayer)(inp, state) +function (rlayer::AbstractRecurrentLayer)(inp::AbstractArray, state::AbstractVecOrMat) @assert ndims(inp) == 2 || ndims(inp) == 3 return scan(rlayer.cell, inp, state) end @@ -47,7 +47,6 @@ export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3, SCRN, PeepholeLSTM, FastRNN, FastGRNN -#TODO add double bias include("mgu_cell.jl") include("ligru_cell.jl") include("indrnn_cell.jl")