Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hotfix LSTM output #2547

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ CUDA = "5"
ChainRulesCore = "1.12"
Compat = "4.10.0"
Enzyme = "0.13"
Functors = "0.5"
EnzymeCore = "0.7.7, 0.8.4"
Functors = "0.5"
MLDataDevices = "1.4.2"
MLUtils = "0.4"
MPI = "0.20.19"
Expand Down
99 changes: 39 additions & 60 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
# Extract the per-step output from a recurrent state.
# A state that is not a tuple is returned unchanged.
function out_from_state(state)
    return state
end

# When the state is a tuple, the output is its first element.
function out_from_state(state::Tuple)
    return state[1]
end

"""
    scan(cell, x, state0)

Apply `cell` to every slice of `x` taken along dimension 2, threading the state through.

Starting from `state0`, each slice `x_t` is passed as `cell(x_t, state)` and the returned
value becomes the state for the next slice. The output of every step is obtained from the
new state with `out_from_state`, and the collected outputs are stacked along dimension 2.
"""
function scan(cell, x, state0)
    outputs = []
    carry = state0
    for slice in eachslice(x, dims = 2)
        carry = cell(slice, carry)
        # Rebind to a freshly concatenated vector instead of growing `outputs` in place.
        outputs = vcat(outputs, [out_from_state(carry)])
    end
    return stack(outputs, dims = 2)
end


# Vanilla RNN

# Vanilla RNN
@doc raw"""
RNNCell(in => out, σ = tanh; init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)
Expand Down Expand Up @@ -215,13 +229,7 @@
@assert ndims(x) == 2 || ndims(x) == 3
# [x] = [in, L] or [in, L, B]
# [h] = [out] or [out, B]
y = []
for x_t in eachslice(x, dims = 2)
h = m.cell(x_t, h)
# y = [y..., h]
y = vcat(y, [h])
end
return stack(y, dims = 2)
return scan(m.cell, x, h)
end


Expand Down Expand Up @@ -297,22 +305,20 @@
end

function LSTMCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)
(in, out)::Pair;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)

Wi = init_kernel(out * 4, in)
Wh = init_recurrent_kernel(out * 4, out)
b = create_bias(Wi, bias, out * 4)
cell = LSTMCell(Wi, Wh, b)
return cell
end

function (lstm::LSTMCell)(x::AbstractVecOrMat)
state, cstate = initialstates(lstm)
return lstm(x, (state, cstate))
end
(lstm::LSTMCell)(x::AbstractVecOrMat) = lstm(x, initialstates(lstm))

Check warning on line 321 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L321

Added line #L321 was not covered by tests

function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
_size_check(m, x, 1 => size(m.Wi, 2))
Expand Down Expand Up @@ -368,15 +374,14 @@
They should be vectors of size `out` or matrices of size `out x batch_size`.
If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).

Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t`
in tensors of size `out x len` or `out x len x batch_size`.
Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`.

# Examples

```julia
struct Model
lstm::LSTM
h0::AbstractVector
h0::AbstractVector # trainable initial hidden state
c0::AbstractVector
end

Expand All @@ -387,7 +392,7 @@
d_in, d_out, len, batch_size = 2, 3, 4, 5
x = rand(Float32, (d_in, len, batch_size))
model = Model(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out))
h, c = model(x)
h = model(x)
size(h) # out x len x batch_size
```
"""
Expand All @@ -404,21 +409,11 @@
return LSTM(cell)
end

function (lstm::LSTM)(x::AbstractArray)
state, cstate = initialstates(lstm)
return lstm(x, (state, cstate))
end
(lstm::LSTM)(x::AbstractArray) = lstm(x, initialstates(lstm))

Check warning on line 412 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L412

Added line #L412 was not covered by tests

function (m::LSTM)(x::AbstractArray, (h, c))
function (m::LSTM)(x::AbstractArray, state0)
@assert ndims(x) == 2 || ndims(x) == 3
h′ = []
c′ = []
for x_t in eachslice(x, dims = 2)
h, c = m.cell(x_t, (h, c))
h′ = vcat(h′, [h])
c′ = vcat(c′, [c])
end
return stack(h′, dims = 2), stack(c′, dims = 2)
return scan(m.cell, x, state0)
end

# GRU
Expand Down Expand Up @@ -485,11 +480,12 @@
initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2))

function GRUCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)
(in, out)::Pair;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)

Wi = init_kernel(out * 3, in)
Wh = init_recurrent_kernel(out * 3, out)
b = create_bias(Wi, bias, size(Wi, 1))
Expand Down Expand Up @@ -581,20 +577,11 @@
return GRU(cell)
end

function (gru::GRU)(x::AbstractArray)
state = initialstates(gru)
return gru(x, state)
end
(gru::GRU)(x::AbstractArray) = gru(x, initialstates(gru))

Check warning on line 580 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L580

Added line #L580 was not covered by tests

function (m::GRU)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
h′ = []
# [x] = [in, L] or [in, L, B]
for x_t in eachslice(x, dims = 2)
h = m.cell(x_t, h)
h′ = vcat(h′, [h])
end
return stack(h′, dims = 2)
return scan(m.cell, x, h)
end

# GRU v3
Expand Down Expand Up @@ -750,17 +737,9 @@
return GRUv3(cell)
end

function (gru::GRUv3)(x::AbstractArray)
state = initialstates(gru)
return gru(x, state)
end
(gru::GRUv3)(x::AbstractArray) = gru(x, initialstates(gru))

Check warning on line 740 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L740

Added line #L740 was not covered by tests

function (m::GRUv3)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
h′ = []
for x_t in eachslice(x, dims = 2)
h = m.cell(x_t, h)
h′ = vcat(h′, [h])
end
return stack(h′, dims = 2)
return scan(m.cell, x, h)
end
100 changes: 40 additions & 60 deletions test/ext_common/recurrent_gpu_ad.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@

@testset "RNNCell GPU AD" begin
function loss(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = vcat(y, [h])
end
# return mean(h)
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y)
# Per-step output of a cell: the state itself, or its first element when the state is a tuple.
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]

# Loss used by the multi-step cell tests: feed the elements of `seq` through `cell` one at a
# time, carrying `state` along, and return the mean over the outputs of all steps.
function recurrent_cell_loss(cell, seq, state)
    outputs = []
    for step_input in seq
        state = cell(step_input, state)
        # Non-mutating accumulation: build a new vector on every step.
        outputs = vcat(outputs, [out_from_state(state)])
    end
    stacked = stack(outputs, dims = 2)
    return mean(stacked)
end

@testset "RNNCell GPU AD" begin
d_in, d_out, len, batch_size = 2, 3, 4, 5
r = RNNCell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
# Single Step
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS
@test test_gradients(r, x[1], h; test_gpu=true,
compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS
# Multiple Steps
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true,
compare_finite_diff=false,
loss=recurrent_cell_loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
end

@testset "RNN GPU AD" begin
Expand All @@ -40,21 +44,6 @@ end
end

@testset "LSTMCell" begin

function loss(r, x, hc)
h, c = hc
h′ = []
c′ = []
for x_t in x
h, c = r(x_t, (h, c))
h′ = vcat(h′, [h])
c′ = [c′..., c]
end
hnew = stack(h′, dims=2)
cnew = stack(c′, dims=2)
return mean(hnew) + mean(cnew)
end

d_in, d_out, len, batch_size = 2, 3, 4, 5
cell = LSTMCell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
Expand All @@ -64,7 +53,9 @@ end
@test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single ∈ BROKEN_TESTS
# Multiple Steps
@test test_gradients(cell, x, (h, c); test_gpu=true, compare_finite_diff=false, loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
@test test_gradients(cell, x, (h, c); test_gpu=true,
compare_finite_diff = false,
loss = recurrent_cell_loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
end

@testset "LSTM" begin
Expand All @@ -81,30 +72,22 @@ end
d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelLSTM(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false,
loss = (m, x) -> mean(m(x)[1])) broken = :lstm_nobatch ∈ BROKEN_TESTS
@test test_gradients(model, x_nobatch; test_gpu=true,
compare_finite_diff=false) broken = :lstm_nobatch ∈ BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false,
loss = (m, x) -> mean(m(x)[1])) broken = :lstm ∈ BROKEN_TESTS
@test test_gradients(model, x; test_gpu=true,
compare_finite_diff=false) broken = :lstm ∈ BROKEN_TESTS
end

@testset "GRUCell" begin
function loss(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = vcat(y, [h])
end
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y)
end

d_in, d_out, len, batch_size = 2, 3, 4, 5
r = GRUCell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single ∈ BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :grucell_multiple ∈ BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true,
compare_finite_diff = false,
loss = recurrent_cell_loss) broken = :grucell_multiple ∈ BROKEN_TESTS
end

@testset "GRU GPU AD" begin
Expand All @@ -120,28 +103,23 @@ end
d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelGRU(GRU(d_in => d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gru_nobatch ∈ BROKEN_TESTS
@test test_gradients(model, x_nobatch; test_gpu=true,
compare_finite_diff=false) broken = :gru_nobatch ∈ BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gru ∈ BROKEN_TESTS
@test test_gradients(model, x; test_gpu=true,
compare_finite_diff=false) broken = :gru ∈ BROKEN_TESTS
end

@testset "GRUv3Cell GPU AD" begin
function loss(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = vcat(y, [h])
end
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y)
end

d_in, d_out, len, batch_size = 2, 3, 4, 5
r = GRUv3Cell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :gruv3cell_single ∈ BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
@test test_gradients(r, x[1], h; test_gpu=true,
compare_finite_diff=false) broken = :gruv3cell_single ∈ BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true,
compare_finite_diff=false,
loss = recurrent_cell_loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
end

@testset "GRUv3 GPU AD" begin
Expand All @@ -157,7 +135,9 @@ end
d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelGRUv3(GRUv3(d_in => d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gruv3_nobatch ∈ BROKEN_TESTS
@test test_gradients(model, x_nobatch; test_gpu=true,
compare_finite_diff=false) broken = :gruv3_nobatch ∈ BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gruv3 ∈ BROKEN_TESTS
@test test_gradients(model, x; test_gpu=true,
compare_finite_diff=false) broken = :gruv3 ∈ BROKEN_TESTS
end
20 changes: 7 additions & 13 deletions test/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,37 +156,31 @@ end
model = ModelLSTM(LSTM(2 => 4), zeros(Float32, 4), zeros(Float32, 4))

x = rand(Float32, 2, 3, 1)
h, c = model(x)
h = model(x)
@test h isa Array{Float32, 3}
@test size(h) == (4, 3, 1)
@test c isa Array{Float32, 3}
@test size(c) == (4, 3, 1)
test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))
test_gradients(model, x)

x = rand(Float32, 2, 3)
h, c = model(x)
h = model(x)
@test h isa Array{Float32, 2}
@test size(h) == (4, 3)
@test c isa Array{Float32, 2}
@test size(c) == (4, 3)
test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))

# test default initial states
lstm = model.lstm
h, c = lstm(x)
h = lstm(x)
@test h isa Array{Float32, 2}
@test size(h) == (4, 3)
@test c isa Array{Float32, 2}
@test size(c) == (4, 3)


# initial states are zero
h0, c0 = Flux.initialstates(lstm)
@test h0 ≈ zeros(Float32, 4)
@test c0 ≈ zeros(Float32, 4)

# no initial state same as zero initial state
h1, c1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
h1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
@test h ≈ h1
@test c ≈ c1
end

@testset "GRUCell" begin
Expand Down
Loading