diff --git a/Project.toml b/Project.toml
index 6926c5fb4..0aac87933 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal and contributors"]
-version = "0.4.20"
+version = "0.4.21"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/src/autodiff.jl b/src/autodiff.jl
index e815b51f8..ccead668d 100644
--- a/src/autodiff.jl
+++ b/src/autodiff.jl
@@ -10,17 +10,23 @@ ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
 ChainRulesCore.@non_differentiable check_use_cuda()
 ChainRulesCore.@non_differentiable istraining(::Any)
 ChainRulesCore.@non_differentiable _get_norm_except_dims(::Any, ::Any)
+ChainRulesCore.@non_differentiable _affine(::Any)
+ChainRulesCore.@non_differentiable _track_stats(::Any)
+ChainRulesCore.@non_differentiable _copy_autodiff_barrier(::Any)
 
 # NNlib Functions
-function ChainRulesCore.rrule(::typeof(batchnorm), g::CuArray{T}, b::CuArray{T},
-                              x::Union{CuArray{T, 4}, CuArray{T, 5}}, running_mean,
-                              running_var, momentum; kwargs...) where {T <: CUDNNFloat}
-    y = batchnorm(g, b, x, running_mean, running_var, momentum; kwargs...)
-    function batchnorm_pullback(dy)
-        dg, db, dx = ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kwargs...)
-        return NoTangent(), dg, db, dx, NoTangent(), NoTangent(), NoTangent()
+function ChainRulesCore.rrule(::typeof(_batchnorm), g::CuArray{T}, b::CuArray{T},
+                              x::Union{CuArray{T, 2}, CuArray{T, 4}, CuArray{T, 5}},
+                              running_mean, running_var, momentum, epsilon,
+                              training) where {T <: CUDNNFloat}
+    y = _batchnorm(g, b, x, running_mean, running_var, momentum, epsilon, training)
+    function _batchnorm_pullback(dy)
+        dg, db, dx = ∇batchnorm(g, b, x, unthunk(dy), running_mean, running_var, momentum;
+                                eps=epsilon, training=training)
+        return NoTangent(), dg, db, dx, NoTangent(), NoTangent(), NoTangent(), NoTangent(),
+               NoTangent()
     end
-    return y, batchnorm_pullback
+    return y, _batchnorm_pullback
 end
 
 function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, x::AbstractArray{T, N},
@@ -59,6 +65,9 @@ function ChainRulesCore.rrule(::typeof(merge), nt1::NamedTuple{F1},
         dnt2 = NamedTuple((f2 => getproperty(dy, f2) for f2 in F2))
         return (NoTangent(), dnt1, dnt2)
     end
+    function merge_pullback(dy::Union{NoTangent, ZeroTangent})
+        return (NoTangent(), NoTangent(), NoTangent())
+    end
     return y, merge_pullback
 end
 
@@ -89,6 +98,11 @@ function ChainRulesCore.rrule(::typeof(collect), v::Vector)
     return y, collect_pullback
 end
 
+function ChainRulesCore.rrule(::typeof(copy), x)
+    copy_pullback(dy) = (NoTangent(), dy)
+    return copy(x), copy_pullback
+end
+
 # Zygote Fixes
 function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
     return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
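# --- Example (illustrative, not part of the diff) ---------------------------------
# Why the new `@non_differentiable` rules above matter: marking a helper
# non-differentiable makes its pullback return `NoTangent()`, so gradients stop at
# that call. A minimal sketch assuming only ChainRulesCore and Zygote; `_stop_grad`
# is a hypothetical stand-in for `_copy_autodiff_barrier`.
using ChainRulesCore, Zygote

_stop_grad(x) = copy(x)
ChainRulesCore.@non_differentiable _stop_grad(::Any)

# Only the direct `x` term contributes to the gradient; the copied branch is
# treated as a constant, so this yields ones(3) rather than 2 .* ones(3).
Zygote.gradient(x -> sum(x .+ _stop_grad(x)), ones(3))[1]
# -----------------------------------------------------------------------------------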
diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl
index 0874e7edc..09afb6d32 100644
--- a/src/layers/normalize.jl
+++ b/src/layers/normalize.jl
@@ -1,5 +1,8 @@
 abstract type AbstractNormalizationLayer{affine, track_stats} <: AbstractExplicitLayer end
 
+@inline _affine(l::AbstractNormalizationLayer{A, T}) where {A, T} = A
+@inline _track_stats(l::AbstractNormalizationLayer{A, T}) where {A, T} = T
+
 """
     BatchNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32,
               affine=true, track_stats=true, epsilon=1f-5, momentum=0.1f0)
@@ -93,28 +96,25 @@ function BatchNorm(chs::Int, activation=identity; init_bias=zeros32, init_scale=
                      chs, init_bias, init_scale)
 end
 
-function initialparameters(rng::AbstractRNG, l::BatchNorm{affine}) where {affine}
-    if affine
+function initialparameters(rng::AbstractRNG, l::BatchNorm)
+    if _affine(l)
         return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
     else
         return (scale=nothing, bias=nothing)
     end
 end
 
-function initialstates(rng::AbstractRNG,
-                       l::BatchNorm{affine, track_stats}) where {affine, track_stats}
-    return if track_stats
-        (running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs),
-         training=Val(true))
+function initialstates(rng::AbstractRNG, l::BatchNorm)
+    if _track_stats(l)
+        return (running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs),
+                training=Val(true))
     else
-        (running_mean=nothing, running_var=nothing, training=Val(true))
+        return (running_mean=nothing, running_var=nothing, training=Val(true))
     end
 end
 
-parameterlength(l::BatchNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0
-function statelength(l::BatchNorm{affine, track_stats}) where {affine, track_stats}
-    return (track_stats ? 2 * l.chs : 0) + 1
-end
+parameterlength(l::BatchNorm) = _affine(l) ? (l.chs * 2) : 0
+statelength(l::BatchNorm) = (_track_stats(l) ? 2 * l.chs : 0) + 1
 
 function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
     x_normalized, xmean, xvar = normalization(x, st.running_mean, st.running_var, ps.scale,
@@ -127,21 +127,21 @@ function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N
     return x_normalized, st
 end
 
-function (BN::BatchNorm{affine, track_stats})(x::Union{CuArray{T, 2}, CuArray{T, 4},
-                                                       CuArray{T, 5}}, ps,
-                                              st::NamedTuple) where {
-                                                                     T <:
-                                                                     Union{Float32, Float64
-                                                                           }, affine,
-                                                                     track_stats}
+function _batchnorm(scale, bias, x, running_mean, running_var, momentum, epsilon, training)
+    return batchnorm(scale, bias, x, running_mean, running_var, momentum; eps=epsilon,
+                     training=training)
+end
+
+function (BN::BatchNorm)(x::Union{CuArray{T, 2}, CuArray{T, 4}, CuArray{T, 5}}, ps,
+                         st::NamedTuple) where {T <: Union{Float32, Float64}}
     # NNlibCUDA silently updates running_mean and running_var so copying them
     if istraining(st)
-        running_mean2 = track_stats ? copy(st.running_mean) : nothing
-        running_var2 = track_stats ? copy(st.running_var) : nothing
+        running_mean2 = _track_stats(BN) ? _copy_autodiff_barrier(st.running_mean) : nothing
+        running_var2 = _track_stats(BN) ? _copy_autodiff_barrier(st.running_var) : nothing
     else
-        if track_stats
-            running_mean2 = copy(st.running_mean)
-            running_var2 = copy(st.running_var)
+        if _track_stats(BN)
+            running_mean2 = _copy_autodiff_barrier(st.running_mean)
+            running_var2 = _copy_autodiff_barrier(st.running_var)
         else
             N = ndims(x)
             reduce_dims = collect([1:(N - 2); N])
@@ -149,20 +149,20 @@ function (BN::BatchNorm{affine, track_stats})(x::Union{CuArray{T,
             running_var2 = var(x; mean=running_mean2, dims=reduce_dims, corrected=false)
         end
     end
-    res = BN.activation.(batchnorm(affine ? ps.scale : nothing, affine ? ps.bias : nothing,
-                                   x, running_mean2, running_var2, BN.momentum;
-                                   eps=BN.epsilon, training=istraining(st)))
-    if track_stats
+    res = BN.activation.(_batchnorm(_affine(BN) ? ps.scale : nothing,
+                                    _affine(BN) ? ps.bias : nothing, x, running_mean2,
+                                    running_var2, BN.momentum, BN.epsilon, istraining(st)))
+    if _track_stats(BN)
         st = merge(st, (running_mean=running_mean2, running_var=running_var2))
     end
     return res, st
 end
 
-function Base.show(io::IO, l::BatchNorm{affine, track_stats}) where {affine, track_stats}
+function Base.show(io::IO, l::BatchNorm)
     print(io, "BatchNorm($(l.chs)")
     (l.activation == identity) || print(io, ", $(l.activation)")
-    print(io, ", affine=$(affine)")
-    print(io, ", track_stats=$(track_stats)")
+    print(io, ", affine=$(_affine(l))")
+    print(io, ", track_stats=$(_track_stats(l))")
     return print(io, ")")
 end
 
@@ -281,17 +281,16 @@ function GroupNorm(chs::Integer, groups::Integer, activation=identity; init_bias
                      groups)
 end
 
-function initialparameters(rng::AbstractRNG, l::GroupNorm{affine}) where {affine}
-    if affine
+function initialparameters(rng::AbstractRNG, l::GroupNorm)
+    if _affine(l)
         return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
     else
         return (scale=nothing, bias=nothing)
     end
 end
 
-function initialstates(rng::AbstractRNG,
-                       l::GroupNorm{affine, track_stats}) where {affine, track_stats}
-    return if track_stats
+function initialstates(rng::AbstractRNG, l::GroupNorm)
+    return if _track_stats(l)
         (running_mean=zeros32(rng, l.groups), running_var=ones32(rng, l.groups),
          training=Val(true))
     else
@@ -299,11 +298,9 @@ function initialstates(rng::AbstractRNG,
     end
 end
 
-parameterlength(l::GroupNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0
+parameterlength(l::GroupNorm) = _affine(l) ? (l.chs * 2) : 0
 
-function statelength(l::GroupNorm{affine, track_stats}) where {affine, track_stats}
-    return (track_stats ? 2 * l.groups : 0) + 1
-end
+statelength(l::GroupNorm) = (_track_stats(l) ? 2 * l.groups : 0) + 1
 
 function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
     sz = size(x)
@@ -318,11 +315,11 @@ function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N
     return reshape(x_normalized, sz), st
 end
 
-function Base.show(io::IO, l::GroupNorm{affine, track_stats}) where {affine, track_stats}
+function Base.show(io::IO, l::GroupNorm)
     print(io, "GroupNorm($(l.chs), $(l.groups)")
     (l.activation == identity) || print(io, ", $(l.activation)")
-    print(io, ", affine=$(affine)")
-    print(io, ", track_stats=$(track_stats)")
+    print(io, ", affine=$(_affine(l))")
+    print(io, ", track_stats=$(_track_stats(l))")
     return print(io, ")")
 end
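# --- Example (illustrative, not part of the diff) ---------------------------------
# The normalize.jl changes replace dispatch on the `{affine, track_stats}` type
# parameters with the `_affine`/`_track_stats` accessors, which keeps method
# signatures short and gives autodiff a single function to mark non-differentiable.
# A self-contained sketch of the same pattern; `MyNorm` and `paramlength` are
# hypothetical names, not part of Lux.
abstract type AbstractNorm{affine, track_stats} end

@inline _affine(::AbstractNorm{A, T}) where {A, T} = A
@inline _track_stats(::AbstractNorm{A, T}) where {A, T} = T

struct MyNorm{A, T} <: AbstractNorm{A, T}
    chs::Int
end

# Reads the type parameter through the accessor instead of a `where` clause.
paramlength(l::MyNorm) = _affine(l) ? l.chs * 2 : 0

paramlength(MyNorm{true, true}(4))   # 8: one scale and one bias per channel
paramlength(MyNorm{false, true}(4))  # 0: no learnable parameters
# -----------------------------------------------------------------------------------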
diff --git a/src/utils.jl b/src/utils.jl
index da11f6262..02b3f0892 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -189,3 +189,6 @@ Split up `x` into `N` equally sized chunks (along dimension `1`).
 
 # Val utilities
 get_known(::Val{T}) where {T} = T
+
+# Copy and don't allow gradient propagation
+_copy_autodiff_barrier(x) = copy(x)
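# --- Example (illustrative, not part of the diff) ---------------------------------
# End-to-end sanity check of the updated BatchNorm, sketched against the Lux 0.4
# explicit parameter/state API (`Lux.setup`, `model(x, ps, st)`); on a GPU the same
# call routes through cuDNN via the `_batchnorm` wrapper above.
using Lux, Random

rng = Random.default_rng()
model = BatchNorm(4)            # affine=true, track_stats=true by default
ps, st = Lux.setup(rng, model)

x = randn(rng, Float32, 4, 8)   # (channels, batch); the diff adds 2D CuArray support
y, st = model(x, ps, st)        # returns the output and an updated state

# `st.running_mean`/`st.running_var` are replaced rather than mutated, which is why
# the CUDA method copies them through `_copy_autodiff_barrier` before calling cuDNN.
# -----------------------------------------------------------------------------------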