From 801996016fa4682485ab0628f284d0d5a05e411b Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Wed, 16 Oct 2024 20:46:51 +0200 Subject: [PATCH 1/4] changing layers to novel flux design --- src/RecurrentLayers.jl | 2 +- src/ligru_cell.jl | 78 ++++++++++++++++++++++++----------------- src/mgu_cell.jl | 79 ++++++++++++++++++++++++------------------ 3 files changed, 93 insertions(+), 66 deletions(-) diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index 15f1951..6b8bb0b 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -1,7 +1,7 @@ module RecurrentLayers using Flux -import Flux: _size_check, _match_eltype, multigate, reshape_cell_output +import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like import Flux: glorot_uniform export MGUCell, LiGRUCell diff --git a/src/ligru_cell.jl b/src/ligru_cell.jl index 96a2c00..b092c9f 100644 --- a/src/ligru_cell.jl +++ b/src/ligru_cell.jl @@ -1,49 +1,63 @@ -struct LiGRUCell{I, H, V, S, F1, F2} - Wf::I +struct LiGRUCell{I, H, V} + Wi::I Wh::H b::V - state0::S - activation_fn::F1 - gate_activation_fn::F2 end function LiGRUCell((in, out)::Pair; - init=glorot_uniform, - initb=zeros32, - init_state=zeros32, - activation_fn=tanh_fast, - gate_activation_fn=sigmoid_fast) + init = glorot_uniform, + bias = true) - Wf = init(out * 2, in) + Wi = init(out * 2, in) Wh = init(out * 2, out) - b = initb(out * 2) - state0 = init_state(out, 1) - return LiGRUCell(Wf, Wh, b, state0, activation_fn, gate_activation_fn) + b = create_bias(Wi, bias, size(Wi, 1)) + + return LiGRUCell(Wi, Wh, b) end LiGRUCell(in, out; kwargs...) = LiGRUCell(in => out; kwargs...) -function (ligru::LiGRUCell{I,H,V,<:AbstractMatrix{T},F1, F2})(hidden, inp::AbstractVecOrMat) where {I,H,V,T,F1,F2} - _size_check(ligru, inp, 1 => size(ligru.Wf,2)) - Wf, Wh, bias, o = ligru.Wf, ligru.Wh, ligru.b, size(hidden, 1) - inp_t = _match_eltype(ligru, T, inp) - gxs, ghs, bs = multigate(Wf*inp_t, o, Val(2)), multigate(Wh*(hidden), o, Val(2)), multigate(bias, o, Val(2)) - forget_gate = @. ligru.gate_activation_fn(gxs[1] + ghs[1] + bs[1]) - +function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state) + _size_check(ligru, inp, 1 => size(ligru.Wi,2)) + Wi, Wh, b = ligru.Wi, ligru.Wh, ligru.b + #split + gxs = chunk(Wi * inp, 2, dims=1) + ghs = chunk(Wh * state, 2, dims=1) + bs = chunk(b, 2, dims=1) + #compute + forget_gate = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1]) candidate_hidden = @. tanh_fast(gxs[2] + ghs[2] + bs[2]) - new_h = forget_gate .* hidden .+ (1 .- forget_gate) .* candidate_hidden - return new_h, reshape_cell_output(new_h, inp) + new_state = forget_gate .* hidden .+ (1 .- forget_gate) .* candidate_hidden + return new_state end -Flux.@layer LiGRUCell -Base.show(io::IO, ligru::LiGRUCell) = - print(io, "LiGRUCell(", size(ligru.Wf, 2), " => ", size(ligru.Wf, 1) ÷ 2, ")") +struct LiGRU{M} + cell::M +end + +Flux.@layer :expand LiGRU -function LiGRU(args...; kwargs...) - return Flux.Recur(LiGRUCell(args...; kwargs...)) -end +function LiGRU((in, out)::Pair; init = glorot_uniform, bias = true) + cell = LiGRUCell(in => out; init, bias) + return LiGRU(cell) +end + +function (ligru::LiGRU)(inp) + state = zeros_like(inp, size(ligru.cell.Wh, 2)) + return ligru(inp, state) +end + +function (ligru::LiGRU)(inp, state) + @assert ndims(inp) == 2 || ndims(inp) == 3 + new_state = [] + for inp_t in eachslice(inp, dims=2) + state = ligru.cell(inp_t, state) + new_state = vcat(new_state, [state]) + end + return stack(new_state, dims=2) +end -function Flux.Recur(ligru::LiGRUCell) - return Flux.Recur(ligru, ligru.state0) -end \ No newline at end of file + +Base.show(io::IO, ligru::LiGRUCell) = + print(io, "LiGRUCell(", size(ligru.Wi, 2), " => ", size(ligru.Wi, 1) ÷ 2, ")") diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index eecb4f0..5bd1353 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -1,50 +1,63 @@ # Define the MGU cell in Flux.jl -struct MGUCell{I, H, V, S, F1, F2} - Wf::I +struct MGUCell{I, H, V} + Wi::I Wh::H b::V - state0::S - activation_fn::F1 - gate_activation_fn::F2 end function MGUCell((in, out)::Pair; - init=glorot_uniform, - initb=zeros32, - init_state=zeros32, - activation_fn=tanh_fast, - gate_activation_fn=sigmoid_fast) + init = glorot_uniform, + bias = true) - Wf = init(out * 2, in) + Wi = init(out * 2, in) Wh = init(out * 2, out) - b = initb(out * 2) - state0 = init_state(out, 1) - return MGUCell(Wf, Wh, b, state0, activation_fn, gate_activation_fn) + b = create_bias(Wi, bias, size(Wi, 1)) + + return MGUCell(Wi, Wh, b) end MGUCell(in, out; kwargs...) = MGUCell(in => out; kwargs...) -function (mgu::MGUCell{I,H,V,<:AbstractMatrix{T},F1, F2})(hidden, inp::AbstractVecOrMat) where {I,H,V,T,F1,F2} - _size_check(mgu, inp, 1 => size(mgu.Wf,2)) - Wf, Wh, bias, o = mgu.Wf, mgu.Wh, mgu.b, size(hidden, 1) - inp_t = _match_eltype(mgu, T, inp) - gxs, ghs, bs = multigate(Wf*inp_t, o, Val(2)), multigate(Wh*(hidden), o, Val(2)), multigate(bias, o, Val(2)) - forget_gate = @. mgu.gate_activation_fn(gxs[1] + ghs[1] + bs[1]) - - candidate_hidden = @. tanh_fast(gxs[2] + forget_gate * (ghs[2]*hidden) + bs[2]) - new_h = forget_gate .* hidden .+ (1 .- forget_gate) .* candidate_hidden - return new_h, reshape_cell_output(new_h, inp) +function (mgu::MGUCell)(inp::AbstractVecOrMat, state) + _size_check(mgu, inp, 1 => size(mgu.Wi,2)) + Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.b + #split + gxs = chunk(Wi * inp, 2, dims=1) + bs = chunk(b, 2, dims=1) + ghs = chunk(Wh, 2, dims=1) + + forget_gate = @. sigmoid_fast(gxs[1] + ghs[1]*state + bs[1]) + candidate_state = @. tanh_fast(gxs[2] + ghs[2]*(forget_gate*state) + bs[2]) + new_state = @. forget_gate .* state .+ (1 .- forget_gate) .* candidate_state + return new_state end -Flux.@layer MGUCell -Base.show(io::IO, l::MGUCell) = - print(io, "MGUCell(", size(l.Wf, 2), " => ", size(l.Wf, 1) ÷ 2, ")") +struct MGU{M} + cell::M +end + +Flux.@layer :expand MGU -function MGU(args...; kwargs...) - return Flux.Recur(MGUCell(args...; kwargs...)) -end +function MGU((in, out)::Pair; init = glorot_uniform, bias = true) + cell = MGUCell(in => out; init, bias) + return MGU(cell) +end + +function (mgu::MGU)(inp) + state = zeros_like(inp, size(mgu.cell.Wh, 2)) + return mgu(inp, state) +end + +function (mgu::MGU)(inp, state) + @assert ndims(inp) == 2 || ndims(inp) == 3 + new_state = [] + for inp_t in eachslice(inp, dims=2) + state = mgu.cell(inp_t, state) + new_state = vcat(new_state, [state]) + end + return stack(new_state, dims=2) +end -function Flux.Recur(mgu::MGUCell) - return Flux.Recur(mgu, mgu.state0) -end \ No newline at end of file +Base.show(io::IO, mgu::MGUCell) = + print(io, "MGUCell(", size(mgu.Wi, 2), " => ", size(mgu.Wi, 1) ÷ 2, ")") From a28c24f22722d0aa7e2e9fa0a3c5f8781ed530dd Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 18 Oct 2024 17:37:37 +0200 Subject: [PATCH 2/4] added IndRNNCell and IndRNN --- src/RecurrentLayers.jl | 5 ++-- src/indrnn_cell.jl | 58 ++++++++++++++++++++++++++++++++++++++++++ src/mgu_cell.jl | 6 ++--- 3 files changed, 64 insertions(+), 5 deletions(-) create mode 100644 src/indrnn_cell.jl diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index 6b8bb0b..ef373da 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -4,10 +4,11 @@ using Flux import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like import Flux: glorot_uniform -export MGUCell, LiGRUCell -export MGU, LiGRU +export MGUCell, LiGRUCell, IndRNNCell +export MGU, LiGRU, IndRNN include("mgu_cell.jl") include("ligru_cell.jl") +include("indrnn_cell.jl") end #module diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl new file mode 100644 index 0000000..5e901fb --- /dev/null +++ b/src/indrnn_cell.jl @@ -0,0 +1,58 @@ +struct IndRNNCell{F,I,H,V} + σ::F + Wi::I + u::H + b::V +end + +Flux.@layer IndRNNCell + +function IndRNNCell((in, out)::Pair, σ=relu; init = glorot_uniform, bias = true) + Wi = init(out, in) + u = init(out) + b = create_bias(Wi, bias, size(Wi, 1)) + return IndRNNCell(σ, Wi, u, b) +end + +function (indrnn::IndRNNCell)(x::AbstractVecOrMat) + state = zeros_like(x, size(indrnn.u, 1)) + return indrnn(x, state) + +function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat) + _size_check(indrnn, inp, 1 => size(indrnn.Wi, 2)) + σ = NNlib.fast_act(indrnn.σ, inp) + state = σ.(indrnn.Wi*inp .+ indrnn.u.*state .+ indrnn.b) + return state +end + +function Base.show(io::IO, m::IndRNNCell) + print(io, "IndRNNCell(", size(m.Wi, 2), " => ", size(indrnn.Wi, 1)) + print(io, ", ", indrnn.σ) + print(io, ")") +end + +struct IndRNN{M} + cell::M +end + +Flux.@layer :expand IndRNN + +function IndRNN((in, out)::Pair, σ = tanh; bias = true, init = glorot_uniform) + cell = IndRNNCell(in => out, σ; bias=bias, init=init) + return IndRNN(cell) +end + +function (indrnn::IndRNN)(inp) + state = zeros_like(inp, size(indrnn.cell.u, 1)) + return indrnn(inp, state) +end + +function (indrnn::IndRNN)(inp, state) + @assert ndims(inp) == 2 || ndims(inp) == 3 + new_state = [] + for inp_t in eachslice(inp, dims=2) + state = indrnn.cell(inp_t, state) + new_state = vcat(new_state, [state]) + end + return stack(new_state, dims=2) +end \ No newline at end of file diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index 5bd1353..9c7349e 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -32,6 +32,9 @@ function (mgu::MGUCell)(inp::AbstractVecOrMat, state) return new_state end +Base.show(io::IO, mgu::MGUCell) = + print(io, "MGUCell(", size(mgu.Wi, 2), " => ", size(mgu.Wi, 1) ÷ 2, ")") + struct MGU{M} cell::M @@ -58,6 +61,3 @@ function (mgu::MGU)(inp, state) end return stack(new_state, dims=2) end - -Base.show(io::IO, mgu::MGUCell) = - print(io, "MGUCell(", size(mgu.Wi, 2), " => ", size(mgu.Wi, 1) ÷ 2, ")") From 0f2b1f9abf32898fdb62c021519d096a703d9c05 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Wed, 23 Oct 2024 20:59:37 +0200 Subject: [PATCH 3/4] added ran, pt1 --- src/RecurrentLayers.jl | 5 +-- src/indrnn_cell.jl | 1 + src/ligru_cell.jl | 2 +- src/mgu_cell.jl | 4 +-- src/ran_cell.jl | 74 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 81 insertions(+), 5 deletions(-) create mode 100644 src/ran_cell.jl diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index ef373da..628afe7 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -4,11 +4,12 @@ using Flux import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like import Flux: glorot_uniform -export MGUCell, LiGRUCell, IndRNNCell -export MGU, LiGRU, IndRNN +export MGUCell, LiGRUCell, IndRNNCell, RANCell +export MGU, LiGRU, IndRNN, RAN include("mgu_cell.jl") include("ligru_cell.jl") include("indrnn_cell.jl") +include("ran_cell.jl") end #module diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl index 5e901fb..eef483f 100644 --- a/src/indrnn_cell.jl +++ b/src/indrnn_cell.jl @@ -17,6 +17,7 @@ end function (indrnn::IndRNNCell)(x::AbstractVecOrMat) state = zeros_like(x, size(indrnn.u, 1)) return indrnn(x, state) +end function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat) _size_check(indrnn, inp, 1 => size(indrnn.Wi, 2)) diff --git a/src/ligru_cell.jl b/src/ligru_cell.jl index b092c9f..11aba26 100644 --- a/src/ligru_cell.jl +++ b/src/ligru_cell.jl @@ -1,7 +1,7 @@ struct LiGRUCell{I, H, V} Wi::I Wh::H - b::V + bias::V end function LiGRUCell((in, out)::Pair; diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index 9c7349e..de6e6a1 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -2,7 +2,7 @@ struct MGUCell{I, H, V} Wi::I Wh::H - b::V + bias::V end function MGUCell((in, out)::Pair; @@ -20,7 +20,7 @@ MGUCell(in, out; kwargs...) = MGUCell(in => out; kwargs...) function (mgu::MGUCell)(inp::AbstractVecOrMat, state) _size_check(mgu, inp, 1 => size(mgu.Wi,2)) - Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.b + Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias #split gxs = chunk(Wi * inp, 2, dims=1) bs = chunk(b, 2, dims=1) diff --git a/src/ran_cell.jl b/src/ran_cell.jl new file mode 100644 index 0000000..bae207c --- /dev/null +++ b/src/ran_cell.jl @@ -0,0 +1,74 @@ +#https://arxiv.org/pdf/1705.07393 +struct RANCell{I,H,V} + Wi::I + Wh::H + bias::V +end + +function RANCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true) + Wi = init(3 * out, in) + Wh = init(2 * out, out) + b = create_bias(Wi, bias, size(Wh, 1)) + + return RANCell(Wi, Wh, b) +end + +RANCell(in, out; kwargs...) = RANCell(in => out; kwargs...) + +function (ran::RANCell)(inp::AbstractVecOrMat) + state = zeros_like(inp, size(ran.Wh, 2)) + c_state = zeros_like(state) + return ran(inp, (state, c_state)) +end + +function (ran::RANCell)(inp::AbstractVecOrMat, (state, c_state)) + _size_check(ran, inp, 1 => size(ran.Wi,2)) + Wi, Wh, b = ran.Wi, ran.Wh, ran.bias + + #split + gxs = chunk(Wi * inp, 3, dims=1) + bs = chunk(b, 2, dims=1) + ghs = chunk(Wh * state, 2, dims=1) + + #compute + input_gate = @. sigmoid_fast(gxs[2] + ghs[1] + bs[1]) + forget_gate = @. sigmoid_fast(gxs[3] + ghs[2] + bs[2]) + candidate_state = @. input_gate * gxs[1] + forget_gate * c_state + new_state = tanh_fast(candidate_state) + return new_state, candidate_state +end + +Base.show(io::IO, ran::RANCell) = + print(io, "RANCell(", size(ran.Wi, 2), " => ", size(ran.Wi, 1)÷2, ")") + + +struct RAN{M} + cell::M +end + +Flux.@layer :expand RAN + +function RAN((in, out)::Pair; init = glorot_uniform, bias = true) + cell = RANCell(in => out; init, bias) + return RAN(cell) +end + +function (ran::RAN)(inp) + state = zeros_like(inp, size(ran.cell.Wh, 2)) + c_state = zeros_like(state) + return ran(inp, (state, c_state)) +end + +function (ran::RAN)(inp, (state, c_state)) + @assert ndims(inp) == 2 || ndims(inp) == 3 + new_state = [] + new_cstate = [] + for inp_t in eachslice(inp, dims=2) + state, c_state = ran.cell(inp_t, (state, c_state)) + new_state = vcat(new_state, [state]) + new_cstate = vcat(new_cstate, [c_state]) + end + return stack(new_state, dims=2), stack(new_cstate, dims=2) +end + + From 8558f1059c705ea6ca583dda1864b2999caeb8b3 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 25 Oct 2024 17:22:09 +0200 Subject: [PATCH 4/4] added lru --- src/RecurrentLayers.jl | 3 +- src/indrnn_cell.jl | 1 + src/ligru_cell.jl | 1 + src/lru_cell.jl | 66 ++++++++++++++++++++++++++++++++++++++++++ src/mgu_cell.jl | 13 ++++++--- src/ran_cell.jl | 3 +- src/sru_cell.jl | 45 ++++++++++++++++++++++++++++ 7 files changed, 125 insertions(+), 7 deletions(-) create mode 100644 src/lru_cell.jl create mode 100644 src/sru_cell.jl diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index 628afe7..5165237 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -4,12 +4,13 @@ using Flux import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like import Flux: glorot_uniform -export MGUCell, LiGRUCell, IndRNNCell, RANCell +export MGUCell, LiGRUCell, IndRNNCell, RANCell, LRUCell export MGU, LiGRU, IndRNN, RAN include("mgu_cell.jl") include("ligru_cell.jl") include("indrnn_cell.jl") include("ran_cell.jl") +include("lru_cell.jl") end #module diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl index eef483f..b7f05aa 100644 --- a/src/indrnn_cell.jl +++ b/src/indrnn_cell.jl @@ -1,3 +1,4 @@ +#https://arxiv.org/pdf/1803.04831 struct IndRNNCell{F,I,H,V} σ::F Wi::I diff --git a/src/ligru_cell.jl b/src/ligru_cell.jl index 11aba26..9e29126 100644 --- a/src/ligru_cell.jl +++ b/src/ligru_cell.jl @@ -1,3 +1,4 @@ +#https://arxiv.org/pdf/1803.10225 struct LiGRUCell{I, H, V} Wi::I Wh::H diff --git a/src/lru_cell.jl b/src/lru_cell.jl new file mode 100644 index 0000000..83db9d7 --- /dev/null +++ b/src/lru_cell.jl @@ -0,0 +1,66 @@ +#https://www.mdpi.com/2079-9292/13/16/3204 +struct LRUCell{I,H,V} + Wi::I + Wh::H + bias::V +end + +function LRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true) + Wi = init(2 * out, in) + Wh = init(out, out) + b = create_bias(Wi, bias, size(Wh, 1)) + + return LRUCell(Wi, Wh, b) +end + +LRUCell(in, out; kwargs...) = LRUCell(in => out; kwargs...) + +function (lru::LRUCell)(inp::AbstractVecOrMat) + state = zeros_like(inp, size(lru.Wh, 2)) + return lru(inp, state) +end + +function (lru::LRUCell)(inp::AbstractVecOrMat, state) + _size_check(lru, inp, 1 => size(lru.Wi,2)) + Wi, Wh, b = lru.Wi, lru.Wh, lru.bias + + #split + gxs = chunk(Wi * inp, 2, dims=1) + + #compute + candidate_state = @. tanh_fast(gxs[1]) + forget_gate = sigmoid_fast(gxs[2] .+ Wh * state .+ b) + new_state = @. (1 - forget_gate) * state + forget_gate * candidate_state + return new_state +end + +Base.show(io::IO, lru::LRUCell) = + print(io, "LRUCell(", size(lru.Wi, 2), " => ", size(lru.Wi, 1)÷2, ")") + + + +struct LRU{M} + cell::M +end + +Flux.@layer :expand LRU + +function LRU((in, out)::Pair; init = glorot_uniform, bias = true) + cell = LRUCell(in => out; init, bias) + return LRU(cell) +end + +function (lru::LRU)(inp) + state = zeros_like(inp, size(lru.cell.Wh, 2)) + return lru(inp, state) +end + +function (lru::LRU)(inp, state) + @assert ndims(inp) == 2 || ndims(inp) == 3 + new_state = [] + for inp_t in eachslice(inp, dims=2) + state = lru.cell(inp_t, state) + new_state = vcat(new_state, [state]) + end + return stack(new_state, dims=2) +end diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index de6e6a1..d8a063a 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -1,4 +1,4 @@ -# Define the MGU cell in Flux.jl +#https://arxiv.org/pdf/1603.09420 struct MGUCell{I, H, V} Wi::I Wh::H @@ -18,6 +18,11 @@ end MGUCell(in, out; kwargs...) = MGUCell(in => out; kwargs...) +function (mgu::MGUCell)(inp::AbstractVecOrMat) + state = zeros_like(inp, size(mgu.Wh, 2)) + return mgu(inp, state) +end + function (mgu::MGUCell)(inp::AbstractVecOrMat, state) _size_check(mgu, inp, 1 => size(mgu.Wi,2)) Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias @@ -26,9 +31,9 @@ function (mgu::MGUCell)(inp::AbstractVecOrMat, state) bs = chunk(b, 2, dims=1) ghs = chunk(Wh, 2, dims=1) - forget_gate = @. sigmoid_fast(gxs[1] + ghs[1]*state + bs[1]) - candidate_state = @. tanh_fast(gxs[2] + ghs[2]*(forget_gate*state) + bs[2]) - new_state = @. forget_gate .* state .+ (1 .- forget_gate) .* candidate_state + forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1]*state .+ bs[1]) + candidate_state = tanh_fast.(gxs[2] .+ ghs[2]*(forget_gate.*state) .+ bs[2]) + new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_state return new_state end diff --git a/src/ran_cell.jl b/src/ran_cell.jl index bae207c..49277ff 100644 --- a/src/ran_cell.jl +++ b/src/ran_cell.jl @@ -39,7 +39,7 @@ function (ran::RANCell)(inp::AbstractVecOrMat, (state, c_state)) end Base.show(io::IO, ran::RANCell) = - print(io, "RANCell(", size(ran.Wi, 2), " => ", size(ran.Wi, 1)÷2, ")") + print(io, "RANCell(", size(ran.Wi, 2), " => ", size(ran.Wi, 1)÷3, ")") struct RAN{M} @@ -71,4 +71,3 @@ function (ran::RAN)(inp, (state, c_state)) return stack(new_state, dims=2), stack(new_cstate, dims=2) end - diff --git a/src/sru_cell.jl b/src/sru_cell.jl new file mode 100644 index 0000000..fe04575 --- /dev/null +++ b/src/sru_cell.jl @@ -0,0 +1,45 @@ +#https://arxiv.org/pdf/1709.02755 +struct SRUCell{I,H,B,V} + Wi::I + Wh::H + v::B + bias::V +end + +function SRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true) + Wi = init(2 * out, in) + Wh = init(2 * out, out) + v = init(2 * out) + b = create_bias(Wi, bias, size(Wh, 1)) + + return SRUCell(Wi, Wh, v, b) +end + +SRUCell(in, out; kwargs...) = SRUCell(in => out; kwargs...) + +function (sru::SRUCell)(inp::AbstractVecOrMat) + state = zeros_like(inp, size(sru.Wh, 2)) + c_state = zeros_like(state) + return sru(inp, (state, c_state)) +end + +function (sru::SRUCell)(inp::AbstractVecOrMat, (state, c_state)) + _size_check(sru, inp, 1 => size(sru.Wi,2)) + Wi, Wh, v, b = sru.Wi, sru.Wh, sru.v, sru.bias + + #split + gxs = chunk(Wi * inp, 3, dims=1) + ghs = chunk(Wh * state, 2, dims=1) + bs = chunk(b, 2, dims=1) + vs = chunk(v, 2, dims=1) + + #compute + input_gate = @. sigmoid_fast(gxs[2] + ghs[1] + bs[1]) + forget_gate = @. sigmoid_fast(gxs[3] + ghs[2] + bs[2]) + candidate_state = @. input_gate * gxs[1] + forget_gate * c_state + new_state = tanh_fast(candidate_state) + return new_state, candidate_state +end + +Base.show(io::IO, sru::SRUCell) = + print(io, "SRUCell(", size(sru.Wi, 2), " => ", size(sru.Wi, 1)÷2, ")") \ No newline at end of file