Make normalization more AD friendly
avik-pal committed Sep 5, 2022
1 parent 3d6c75c commit 73a790b
Showing 4 changed files with 66 additions and 52 deletions.
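
For orientation, the core of this commit is replacing dispatch on the `BatchNorm{affine, track_stats}` / `GroupNorm{affine, track_stats}` type parameters with small accessor functions (`_affine`, `_track_stats`) that are marked `@non_differentiable`. The following is a minimal, self-contained sketch of that pattern; the type and function names here are illustrative stand-ins, not the actual Lux definitions:

using ChainRulesCore

# Illustrative stand-in for AbstractNormalizationLayer{affine, track_stats};
# the real Lux types live in src/layers/normalize.jl.
abstract type AbstractNormLayer{affine, track_stats} end

struct ToyNorm{A, T} <: AbstractNormLayer{A, T}
    chs::Int
end

# Accessors read the flags off the type parameters ...
_affine(::AbstractNormLayer{A, T}) where {A, T} = A
_track_stats(::AbstractNormLayer{A, T}) where {A, T} = T

# ... and are excluded from differentiation, so AD never tries to trace them.
ChainRulesCore.@non_differentiable _affine(::Any)
ChainRulesCore.@non_differentiable _track_stats(::Any)

parameterlength(l::ToyNorm) = _affine(l) ? 2 * l.chs : 0  # mirrors the BatchNorm change
parameterlength(ToyNorm{true, true}(4))                   # == 8
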
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.4.20"
version = "0.4.21"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
30 changes: 22 additions & 8 deletions src/autodiff.jl
@@ -10,17 +10,23 @@ ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
ChainRulesCore.@non_differentiable check_use_cuda()
ChainRulesCore.@non_differentiable istraining(::Any)
ChainRulesCore.@non_differentiable _get_norm_except_dims(::Any, ::Any)
ChainRulesCore.@non_differentiable _affine(::Any)
ChainRulesCore.@non_differentiable _track_stats(::Any)
ChainRulesCore.@non_differentiable _copy_autodiff_barrier(::Any)

# NNlib Functions
function ChainRulesCore.rrule(::typeof(batchnorm), g::CuArray{T}, b::CuArray{T},
x::Union{CuArray{T, 4}, CuArray{T, 5}}, running_mean,
running_var, momentum; kwargs...) where {T <: CUDNNFloat}
y = batchnorm(g, b, x, running_mean, running_var, momentum; kwargs...)
function batchnorm_pullback(dy)
dg, db, dx = ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kwargs...)
return NoTangent(), dg, db, dx, NoTangent(), NoTangent(), NoTangent()
function ChainRulesCore.rrule(::typeof(_batchnorm), g::CuArray{T}, b::CuArray{T},
x::Union{CuArray{T, 2}, CuArray{T, 4}, CuArray{T, 5}},
running_mean, running_var, momentum, epsilon,
training) where {T <: CUDNNFloat}
y = _batchnorm(g, b, x, running_mean, running_var, momentum, epsilon, training)
function _batchnorm_pullback(dy)
dg, db, dx = ∇batchnorm(g, b, x, unthunk(dy), running_mean, running_var, momentum;
eps=epsilon, training=training)
return NoTangent(), dg, db, dx, NoTangent(), NoTangent(), NoTangent(), NoTangent(),
NoTangent()
end
return y, batchnorm_pullback
return y, _batchnorm_pullback
end

function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, x::AbstractArray{T, N},
@@ -59,6 +65,9 @@ function ChainRulesCore.rrule(::typeof(merge), nt1::NamedTuple{F1},
dnt2 = NamedTuple((f2 => getproperty(dy, f2) for f2 in F2))
return (NoTangent(), dnt1, dnt2)
end
function merge_pullback(dy::Union{NoTangent, ZeroTangent})
return (NoTangent(), NoTangent(), NoTangent())
end
return y, merge_pullback
end

@@ -89,6 +98,11 @@ function ChainRulesCore.rrule(::typeof(collect), v::Vector)
return y, collect_pullback
end

function ChainRulesCore.rrule(::typeof(copy), x)
copy_pullback(dy) = (NoTangent(), dy)
return copy(x), copy_pullback
end

# Zygote Fixes
function Zygote.accum(x::ComponentArray, ys::ComponentArray...)
return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x))
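
As a quick illustration of the rrule pattern used in this file, here is a minimal sketch assuming Zygote's ChainRules integration; `mycopy` is a hypothetical stand-in, not part of Lux:

using ChainRulesCore, Zygote

# Stand-in mirroring the `copy` rrule added above: the pullback forwards the
# incoming cotangent unchanged, so gradients flow straight through the copy.
mycopy(x) = copy(x)
function ChainRulesCore.rrule(::typeof(mycopy), x)
    mycopy_pullback(dy) = (NoTangent(), dy)
    return mycopy(x), mycopy_pullback
end

Zygote.gradient(x -> sum(abs2, mycopy(x)), [1.0, 2.0])  # expected ([2.0, 4.0],)
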
83 changes: 40 additions & 43 deletions src/layers/normalize.jl
@@ -1,5 +1,8 @@
abstract type AbstractNormalizationLayer{affine, track_stats} <: AbstractExplicitLayer end

@inline _affine(l::AbstractNormalizationLayer{A, T}) where {A, T} = A
@inline _track_stats(l::AbstractNormalizationLayer{A, T}) where {A, T} = T

"""
BatchNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32,
affine=true, track_stats=true, epsilon=1f-5, momentum=0.1f0)
@@ -93,28 +96,25 @@ function BatchNorm(chs::Int, activation=identity; init_bias=zeros32, init_scale=
chs, init_bias, init_scale)
end

function initialparameters(rng::AbstractRNG, l::BatchNorm{affine}) where {affine}
if affine
function initialparameters(rng::AbstractRNG, l::BatchNorm)
if _affine(l)
return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
else
return (scale=nothing, bias=nothing)
end
end

function initialstates(rng::AbstractRNG,
l::BatchNorm{affine, track_stats}) where {affine, track_stats}
return if track_stats
(running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs),
training=Val(true))
function initialstates(rng::AbstractRNG, l::BatchNorm)
if _track_stats(l)
return (running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs),
training=Val(true))
else
(running_mean=nothing, running_var=nothing, training=Val(true))
return (running_mean=nothing, running_var=nothing, training=Val(true))
end
end

parameterlength(l::BatchNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0
function statelength(l::BatchNorm{affine, track_stats}) where {affine, track_stats}
return (track_stats ? 2 * l.chs : 0) + 1
end
parameterlength(l::BatchNorm) = _affine(l) ? (l.chs * 2) : 0
statelength(l::BatchNorm) = (_track_stats(l) ? 2 * l.chs : 0) + 1

function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
x_normalized, xmean, xvar = normalization(x, st.running_mean, st.running_var, ps.scale,
@@ -127,42 +127,42 @@ function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N
return x_normalized, st
end

function (BN::BatchNorm{affine, track_stats})(x::Union{CuArray{T, 2}, CuArray{T, 4},
CuArray{T, 5}}, ps,
st::NamedTuple) where {
T <:
Union{Float32, Float64
}, affine,
track_stats}
function _batchnorm(scale, bias, x, running_mean, running_var, momentum, epsilon, training)
return batchnorm(scale, bias, x, running_mean, running_var, momentum; eps=epsilon,
training=training)
end

function (BN::BatchNorm)(x::Union{CuArray{T, 2}, CuArray{T, 4}, CuArray{T, 5}}, ps,
st::NamedTuple) where {T <: Union{Float32, Float64}}
# NNlibCUDA silently updates running_mean and running_var so copying them
if istraining(st)
running_mean2 = track_stats ? copy(st.running_mean) : nothing
running_var2 = track_stats ? copy(st.running_var) : nothing
running_mean2 = _track_stats(BN) ? _copy_autodiff_barrier(st.running_mean) : nothing
running_var2 = _track_stats(BN) ? _copy_autodiff_barrier(st.running_var) : nothing
else
if track_stats
running_mean2 = copy(st.running_mean)
running_var2 = copy(st.running_var)
if _track_stats(BN)
running_mean2 = _copy_autodiff_barrier(st.running_mean)
running_var2 = _copy_autodiff_barrier(st.running_var)
else
N = ndims(x)
reduce_dims = collect([1:(N - 2); N])
running_mean2 = mean(x; dims=reduce_dims)
running_var2 = var(x; mean=running_mean2, dims=reduce_dims, corrected=false)
end
end
res = BN.activation.(batchnorm(affine ? ps.scale : nothing, affine ? ps.bias : nothing,
x, running_mean2, running_var2, BN.momentum;
eps=BN.epsilon, training=istraining(st)))
if track_stats
res = BN.activation.(_batchnorm(_affine(BN) ? ps.scale : nothing,
_affine(BN) ? ps.bias : nothing, x, running_mean2,
running_var2, BN.momentum, BN.epsilon, istraining(st)))
if _track_stats(BN)
st = merge(st, (running_mean=running_mean2, running_var=running_var2))
end
return res, st
end

function Base.show(io::IO, l::BatchNorm{affine, track_stats}) where {affine, track_stats}
function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(l.chs)")
(l.activation == identity) || print(io, ", $(l.activation)")
print(io, ", affine=$(affine)")
print(io, ", track_stats=$(track_stats)")
print(io, ", affine=$(_affine(l))")
print(io, ", track_stats=$(_track_stats(l))")
return print(io, ")")
end

@@ -281,29 +281,26 @@ function GroupNorm(chs::Integer, groups::Integer, activation=identity; init_bias
groups)
end

function initialparameters(rng::AbstractRNG, l::GroupNorm{affine}) where {affine}
if affine
function initialparameters(rng::AbstractRNG, l::GroupNorm)
if _affine(l)
return (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs))
else
return (scale=nothing, bias=nothing)
end
end

function initialstates(rng::AbstractRNG,
l::GroupNorm{affine, track_stats}) where {affine, track_stats}
return if track_stats
function initialstates(rng::AbstractRNG, l::GroupNorm)
return if _track_stats(l)
(running_mean=zeros32(rng, l.groups), running_var=ones32(rng, l.groups),
training=Val(true))
else
(running_mean=nothing, running_var=nothing, training=Val(true))
end
end

parameterlength(l::GroupNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0
parameterlength(l::GroupNorm) = _affine(l) ? (l.chs * 2) : 0

function statelength(l::GroupNorm{affine, track_stats}) where {affine, track_stats}
return (track_stats ? 2 * l.groups : 0) + 1
end
statelength(l::GroupNorm) = (_track_stats(l) ? 2 * l.groups : 0) + 1

function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
sz = size(x)
@@ -318,11 +315,11 @@ function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N
return reshape(x_normalized, sz), st
end

function Base.show(io::IO, l::GroupNorm{affine, track_stats}) where {affine, track_stats}
function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(l.chs), $(l.groups)")
(l.activation == identity) || print(io, ", $(l.activation)")
print(io, ", affine=$(affine)")
print(io, ", track_stats=$(track_stats)")
print(io, ", affine=$(_affine(l))")
print(io, ", track_stats=$(_track_stats(l))")
return print(io, ")")
end

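For context, a minimal usage sketch of the refactored layer through the standard Lux 0.4 API; the shapes and hyperparameters here are illustrative, not taken from the repository:

using Lux, Random

rng = Random.default_rng()
bn = BatchNorm(4; affine=true, track_stats=true)

ps, st = Lux.setup(rng, bn)   # ps: scale/bias (affine=true); st: running stats + training flag
x = randn(Float32, 4, 16)     # (channels, batch)

y, st_new = bn(x, ps, st)     # forward pass; st_new carries the updated running statistics
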
3 changes: 3 additions & 0 deletions src/utils.jl
@@ -189,3 +189,6 @@ Split up `x` into `N` equally sized chunks (along dimension `1`).

# Val utilities
get_known(::Val{T}) where {T} = T

# Copy and don't allow gradient propagation
_copy_autodiff_barrier(x) = copy(x)
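
Together with the `@non_differentiable _copy_autodiff_barrier(::Any)` rule added in src/autodiff.jl, this copies a value while keeping gradients from flowing back through the copy. A minimal sketch of the effect, where `_barrier` is a hypothetical stand-alone version rather than the Lux function:

using ChainRulesCore, Zygote

# Copy the value, and tell AD to treat the result as a constant:
_barrier(x) = copy(x)
ChainRulesCore.@non_differentiable _barrier(::Any)

# Only the direct `x` factor is differentiated; the barrier branch contributes nothing.
Zygote.gradient(x -> sum(x .* _barrier(x)), [1.0, 2.0, 3.0])
# expected ([1.0, 2.0, 3.0],) rather than ([2.0, 4.0, 6.0],) without the barrier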
