Support scalar numbers #33

Closed · wants to merge 1 commit
src/interface.jl (2 changes: 2 additions & 0 deletions)

@@ -39,6 +39,8 @@ end
# default all rules to first order calls
apply!(o, state, x, dx, dxs...) = apply!(o, state, x, dx)

+isnumeric(x::AbstractFloat) = true
+isnumeric(x::Complex{<:AbstractFloat}) = true
Comment on lines +42 to +43
@mcabbott (Member) · Jan 28, 2022:
Maybe it's cleaner to say what's excluded:

Suggested change:
-isnumeric(x::AbstractFloat) = true
-isnumeric(x::Complex{<:AbstractFloat}) = true
+isnumeric(x::Number) = true
+isnumeric(x::Integer) = false

Then when someone wants Unitful numbers, they may just work.

If this is going to be applied everywhere, not as an opt-in, then I think it ought to exclude integers, even though AD mostly draws the line only at Bool. Perhaps arrays of integers should similarly be excluded from isnumeric? Then all cases can just be:

isnumeric(x::Numeric) = isleaf(x)
isnumeric(x::Numeric{<:Integer}) = false
isnumeric(_) = false
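
Under those suggested definitions, a Unitful scalar would indeed be classified as trainable (an illustrative sketch, assuming Unitful is loaded; not part of the suggestion itself):

using Unitful

isnumeric(1.0u"m")  # true  — Unitful's Quantity is a Number, and a scalar is its own leaf
isnumeric(5)        # false — integers are excluded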

@ToucheSir (Member Author):

Maybe, though if in some distant future Flux is actually capable of quantization-aware training, you might well see Int8/Int4 arrays. Re: Unitful quantities, is there any reason to consider them any "safer" to optimize than integers?

@mcabbott (Member):

I think the graphics people do some things with low-bit numbers, but wrap them up to disguise the integer inside --- presumably it would make sense to do something like that before trying to AD with Int8?

julia> using ColorTypes  # needed for Gray and the N0f8 fixed-point type

julia> g = Gray{ColorTypes.N0f8}(128/255)
Gray{N0f8}(0.502)

julia> 0 < g < 1
true

julia> g isa Integer
false

julia> dump(g)
Gray{FixedPointNumbers.N0f8}
  val: FixedPointNumbers.N0f8
    i: UInt8 0x80

> Re: Unitful quantities, is there any reason to consider them any "safer" to optimize than integers?

My thinking is that if you add float(zero(x)) to them, which is what some trivial gradient could produce, nothing much will happen. Whereas if you do this to a size or an index, it'll break.
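
A quick illustration of that distinction (a sketch, assuming Unitful is loaded; not part of this PR):

using Unitful

w = 3.0u"m"
w + float(zero(w))     # 3.0 m — a zero update from a trivial gradient is harmless

v = [10, 20, 30]
i = 2                  # an Int used as an index
v[i + float(zero(i))]  # throws ArgumentError: 2.0 of type Float64 is not a valid index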

@ToucheSir (Member Author) · Feb 8, 2022:

Here's a thought: if the concern is about indices and bools, can we just exclude Int64, Int32 and Bool? Maybe the UInts down to 8 and Int16 as well. I've not seen anyone using those types as model params, and it would eliminate 99% of the confusion with indices + true/false.
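
A minimal sketch of that exclusion list (the names here are hypothetical, not code from this PR):

# Types routinely used for indices and flags, never as trainable params:
const IndexLike = Union{Bool, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64}

isnumeric(x::Number) = true
isnumeric(x::IndexLike) = false  # the more specific method wins dispatch
isnumeric(x) = false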

@ToucheSir (Member Author):

After playing around for a bit with Unitful, I'm still not comfortable with giving carte blanche to all non-integer number types. Even with floats, it's not clear that a scalar param should be trainable by default.

isnumeric(x::AbstractArray{<:Number}) = isleaf(x) # isleaf to allow for e.g. transposed shared weights
isnumeric(x::AbstractArray{<:Bool}) = false # convention of ChainRules is that Bool is non-differentiable
isnumeric(x) = false
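
For reference, how these final definitions classify a few values (illustrative, not part of the diff):

isnumeric(rand(Float32, 3))  # true  — numeric array that is its own leaf
isnumeric([true, false])     # false — Bool arrays are non-differentiable
isnumeric(1.0)               # false — plain scalars fall through to the catch-all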
src/rules.jl (49 changes: 25 additions & 24 deletions)

@@ -13,11 +13,11 @@ struct Descent{T}
end
Descent() = Descent(1f-1)

-init(o::Descent, x::AbstractArray) = nothing
+init(o::Descent, x::Union{AbstractArray,Number}) = nothing
@ToucheSir (Member Author):

This may be too broad a type. An alternative would be Union{AbstractArray,AbstractFloat,Complex{<:AbstractFloat}}, with type aliases as appropriate.
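
For example (TrainableScalar is a hypothetical alias, not code from this PR):

const TrainableScalar = Union{AbstractFloat, Complex{<:AbstractFloat}}

init(o::Descent, x::Union{AbstractArray,TrainableScalar}) = nothing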

@mcabbott (Member):

I think init is called unconditionally by setup, so there's no need to filter out integers at this stage; that filtering has already happened (or not), and the signature here just needs to throw an error for anything it really can't handle.

Zygote et al. have Numeric{T<:Number} = Union{T, AbstractArray{T}}, and then you can do Numeric{<:Real} etc.?
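
That pattern would look something like this here (a sketch; the alias is Zygote's, the Descent method is illustrative):

const Numeric{T<:Number} = Union{T, AbstractArray{T}}

init(o::Descent, x::Numeric{<:Real}) = nothing  # covers real scalars and arrays of them alike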


function apply!(o::Descent, state, x, dx)
  η = convert(float(eltype(dx)), o.eta)

  return state, @.. dx * η
end

@@ -40,12 +40,12 @@ struct Momentum{T}
end
Momentum(η = 1f-2, ρ = 9f-1) = Momentum{typeof(η)}(η, ρ)

-init(o::Momentum, x::AbstractArray) = zero(x)
+init(o::Momentum, x::Union{AbstractArray,Number}) = zero(x)

function apply!(o::Momentum, state, x, dx)
  η, ρ, v = o.eta, o.rho, state
  v′ = @.. v = ρ * v - η * dx

  return v′, @.. -v′
end

@@ -68,15 +68,15 @@ struct Nesterov{T}
end
Nesterov(η = 1f-3, ρ = 9f-1) = Nesterov{typeof(η)}(η, ρ)

-init(o::Nesterov, x::AbstractArray) = zero(x)
+init(o::Nesterov, x::Union{AbstractArray,Number}) = zero(x)

(o::Nesterov)(state, m, dm) = update(o, state, m, dm)

function apply!(o::Nesterov, state, x, dx)
  η, ρ, v = o.eta, o.rho, state
  d = @.. ρ^2 * v - (1+ρ) * η * dx
  v′ = @.. v = ρ * v - η * dx

  return v′, @.. -d
end

@@ -103,13 +103,13 @@ struct RMSProp{T}
end
RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η))) = RMSProp{typeof(η)}(η, ρ, ϵ)

-init(o::RMSProp, x::AbstractArray) = zero(x)
+init(o::RMSProp, x::Union{AbstractArray,Number}) = zero(x)

function apply!(o::RMSProp, state, x, dx)
  η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state
  acc′ = @.. acc = ρ * acc + (1 - ρ) * dx^2
  dx′ = @.. dx * (η / (sqrt(acc) + ϵ))

  return acc′, dx′
end

@@ -135,7 +135,7 @@ struct ADAM{T}
end
ADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = ADAM{typeof(η)}(η, β, ϵ)

-init(o::ADAM, x::AbstractArray) = (zero(x), zero(x), o.beta)
+init(o::ADAM, x::Union{AbstractArray,Number}) = (zero(x), zero(x), o.beta)

(o::ADAM)(state, m, dm) = update(o, state, m, dm)

@@ -170,7 +170,7 @@ struct RADAM{T}
end
RADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = RADAM{typeof(η)}(η, β, ϵ)

-init(o::RADAM, x::AbstractArray) = (zero(x), zero(x), o.beta, 1)
+init(o::RADAM, x::Union{AbstractArray,Number}) = (zero(x), zero(x), o.beta, 1)

(o::RADAM)(state, m, dm) = update(o, state, m, dm)

@@ -213,7 +213,7 @@ struct AdaMax{T}
end
AdaMax(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = AdaMax{typeof(η)}(η, β, ϵ)

-init(o::AdaMax, x::AbstractArray) = (zero(x), zero(x), o.beta)
+init(o::AdaMax, x::Union{AbstractArray,Number}) = (zero(x), zero(x), o.beta)
@ToucheSir (Member Author):

Also on types: this could be narrower since max does not support complex numbers. Should it be handled on a case-by-case basis?
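
A case-by-case restriction might look like this (an illustrative sketch, not code from this PR):

init(o::AdaMax, x::Union{AbstractArray,Real}) = (zero(x), zero(x), o.beta)  # no Complex method, since max requires an ordering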


(o::AdaMax)(state, m, dm) = update(o, state, m, dm)

@@ -250,7 +250,7 @@ struct OADAM{T}
end
OADAM(η = 1f-3, β = (5f-1, 9f-1), ϵ = eps(typeof(η))) = OADAM{typeof(η)}(η, β, ϵ)

-init(o::OADAM, x::AbstractArray) = (zero(x), zero(x), o.beta, zero(x))
+init(o::OADAM, x::Union{AbstractArray,Number}) = (zero(x), zero(x), o.beta, zero(x))

(o::OADAM)(state, m, dm) = update(o, state, m, dm)

@@ -287,6 +287,7 @@ struct ADAGrad{T}
end
ADAGrad(η = 1f-1, ϵ = eps(typeof(η))) = ADAGrad{typeof(η)}(η, ϵ)

+init(o::ADAGrad, ::T) where T <: Number = convert(T, o.epsilon)
init(o::ADAGrad, x::AbstractArray) = fill!(similar(x), o.epsilon)

(o::ADAGrad)(state, m, dm) = update(o, state, m, dm)
@@ -319,7 +320,7 @@ struct ADADelta{T}
end
ADADelta(ρ = 9f-1, ϵ = eps(typeof(ρ))) = ADADelta{typeof(ρ)}(ρ, ϵ)

-init(o::ADADelta, x::AbstractArray) = (zero(x), zero(x))
+init(o::ADADelta, x::Union{AbstractArray,Number}) = (zero(x), zero(x))

(o::ADADelta)(state, m, dm) = update(o, state, m, dm)

@@ -332,7 +333,7 @@ function apply!(o::ADADelta, state, x, dx)
  # or even out of the square roots
  dx′ = @.. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)
  Δacc′ = @.. Δacc = ρ * Δacc + (1 - ρ) * dx^2

  return (acc′, Δacc′), dx′
end

@@ -357,8 +358,8 @@ struct AMSGrad{T}
end
AMSGrad(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = AMSGrad{typeof(η)}(η, β, ϵ)

-init(o::AMSGrad, x::AbstractArray) =
-  (fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon))
+init(o::AMSGrad, ::T) where T <: Number = ntuple(_ -> convert(T, o.epsilon), 3)
+init(o::AMSGrad, x::AbstractArray) = ntuple(_ -> fill!(similar(x), o.epsilon), 3)

(o::AMSGrad)(state, m, dm) = update(o, state, m, dm)

@@ -396,7 +397,7 @@ struct NADAM{T}
end
NADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = NADAM{typeof(η)}(η, β, ϵ)

-init(o::NADAM, x::AbstractArray) = (zero(x), zero(x), o.beta)
+init(o::NADAM, x::Union{AbstractArray,Number}) = (zero(x), zero(x), o.beta)

(o::NADAM)(state, m, dm) = update(o, state, m, dm)

@@ -407,7 +408,7 @@ function apply!(o::NADAM, state, x, dx)

  mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
  vt′ = @.. vt = β[2] * vt + (1 - β[2]) * dx^2
  dx′ = @.. (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
            (sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η

  return (mt′, vt′, βt .* β), dx′
@@ -452,7 +453,7 @@ struct AdaBelief{T}
end
AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = AdaBelief{typeof(η)}(η, β, ϵ)

-init(o::AdaBelief, x::AbstractArray) = (zero(x), zero(x))
+init(o::AdaBelief, x::Union{AbstractArray,Number}) = (zero(x), zero(x))

(o::AdaBelief)(state, m, dm) = update(o, state, m, dm)

@@ -463,7 +464,7 @@ function apply!(o::AdaBelief, state, x, dx)
  mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
  st′ = @.. st = β[2] * st + (1 - β[2]) * (dx - mt)^2
  dx′ = @.. η * mt / (sqrt(st) + ϵ)

  return (mt′, st′), dx′
end

@@ -480,7 +481,7 @@ struct WeightDecay{T}
end
WeightDecay() = WeightDecay(5f-4)

-init(o::WeightDecay, x::AbstractArray) = nothing
+init(o::WeightDecay, x::Union{AbstractArray,Number}) = nothing

(o::WeightDecay)(state, m, dm) = update(o, state, m, dm)

@@ -502,7 +503,7 @@ struct ClipGrad{T<:Real}
end
ClipGrad() = ClipGrad(10f0)

-init(o::ClipGrad, x::AbstractArray) = nothing
+init(o::ClipGrad, x::Union{AbstractArray,Number}) = nothing

(o::ClipGrad)(state::Nothing, m, dm) = update(o, state, m, dm)

@@ -531,7 +532,7 @@ struct ClipNorm{T<:Real}
end
ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{typeof(ω)}(ω, p, throw)

-init(o::ClipNorm, x::AbstractArray) = nothing
+init(o::ClipNorm, x::Union{AbstractArray,Number}) = nothing

(o::ClipNorm)(state::Nothing, m, dm) = update(o, state, m, dm)

@@ -556,7 +557,7 @@ struct OptimiserChain{O}
end
OptimiserChain(opts...) = OptimiserChain(opts)

-init(o::OptimiserChain, x::AbstractArray) = [init(opt, x) for opt in o.opts]
+init(o::OptimiserChain, x::Union{AbstractArray,Number}) = [init(opt, x) for opt in o.opts]

(o::OptimiserChain)(state, m, dms...) = update(o, state, m, dms...)

test/runtests.jl (93 changes: 60 additions & 33 deletions)

@@ -12,7 +12,7 @@ using Optimisers: @..
g = ([25, 33],)
o = Descent(0.1)
s = Optimisers.state(o, m)

s2, m2 = Optimisers.update(o, s, m, g)
@test m[1] == 1:2 # not mutated
@test Optimisers.iswriteable(m[1])
@@ -23,42 +23,69 @@
@test m3[1] ≈ [1,2] .- 0.1 .* [25, 33]
end

-@testset for o in (Descent(), ADAM(), Momentum(), Nesterov(), RMSProp(),
-                   ADAGrad(), AdaMax(), ADADelta(), AMSGrad(), NADAM(),
-                   ADAMW(), RADAM(), OADAM(), AdaBelief())
-  w′ = (α = rand(3, 3), β = rand(3, 3))
-
-  # Original example
-  w = (α = 5rand(3, 3), β = rand(3, 3))
-  st = Optimisers.state(o, w)
-  loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
-  @test loss(w, w′) > 1
-  for i = 1:10^4
-    gs = gradient(x -> loss(x, w′), w)
-    st, w = Optimisers.update(o, st, w, gs...)
-  end
-  lw = loss(w, w′)
-  if o isa ADADelta
-    @test_broken lw < 0.001
-  else
-    @test lw < 0.001
-  end
+ALL_OPTS = (Descent(), ADAM(), Momentum(), Nesterov(), RMSProp(),
+            ADAGrad(), AdaMax(), ADADelta(), AMSGrad(), NADAM(),
+            ADAMW(), RADAM(), OADAM(), AdaBelief())
+ALL_TRANSFORMS = (ALL_OPTS..., ClipGrad(), ClipNorm(), WeightDecay())

-  # Slightly harder variant
-  m = (α = randn(3), β = transpose(5rand(3,3)), γ = (rand(2), tanh)) # issue 28
-  st = Optimisers.state(o, m)
-  @test loss(m, w′) > 1
-  for i = 1:10^4
-    gs = gradient(x -> loss(x, w′), m)
-    st, m = o(st, m, gs...)
@testset "Training with gradients" begin
@testset for o in ALL_OPTS
w′ = (α = rand(3, 3), β = rand(3, 3))

# Original example
w = (α = 5rand(3, 3), β = rand(3, 3))
st = Optimisers.state(o, w)
loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
@test loss(w, w′) > 1
for i = 1:10^4
gs = gradient(x -> loss(x, w′), w)
st, w = Optimisers.update(o, st, w, gs...)
end
lw = loss(w, w′)
if o isa ADADelta
@test_broken lw < 0.001
else
@test lw < 0.001
end

# Slightly harder variant
m = (α = randn(3), β = transpose(5rand(3,3)), γ = (rand(2), tanh)) # issue 28
st = Optimisers.state(o, m)
@test loss(m, w′) > 1
for i = 1:10^4
gs = gradient(x -> loss(x, w′), m)
st, m = o(st, m, gs...)
end
lm = loss(m, w′)
if lm < 0.1
@test lm < 0.1
else
@test_broken lm < 0.1 # @test keyword broken doesn't exist on Julia 1.6
end
end
-  lm = loss(m, w′)
-  if lm < 0.1
-    @test lm < 0.1
-  else
-    @test_broken lm < 0.1 # @test keyword broken doesn't exist on Julia 1.6
-  end

@testset "Scalar Params" begin
m = (α = 3, β = 5f0, γ = 2.0+2.0im)
gs = (α = nothing, β = 1f0, γ = 1.0+1.0im)

# General compatibility
@testset for t in ALL_TRANSFORMS
m2, gs2 = m, gs
# These can't handle complex numbers
if t isa Union{AdaMax,AMSGrad,ClipGrad}
m2, gs2 = m[(:α, :β)], gs[(:α, :β)]
end
st = Optimisers.state(t, m2)
Optimisers.update(t, st, m2, gs2)
end

# End-to-end
o = Descent(0.1)
st = Optimisers.state(o, m)
_, m′ = Optimisers.update(o, st, m, gs)
@test m.β - 0.1gs.β ≈ m′.β
@test m.γ - 0.1gs.γ ≈ m′.γ
end

@testset "OptimiserChain with $pre" for pre in (WeightDecay(), ClipGrad(), ClipNorm())