-
-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support scalar numbers #33
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,11 +13,11 @@ struct Descent{T} | |
end | ||
Descent() = Descent(1f-1) | ||
|
||
init(o::Descent, x::AbstractArray) = nothing | ||
init(o::Descent, x::Union{AbstractArray,Number}) = nothing | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This may be too broad a type. An alternative would be There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think Zygote et al. have |
||
|
||
function apply!(o::Descent, state, x, dx) | ||
η = convert(float(eltype(dx)), o.eta) | ||
|
||
return state, @.. dx * η | ||
end | ||
|
||
|
@@ -40,12 +40,12 @@ struct Momentum{T} | |
end | ||
Momentum(η = 1f-2, ρ = 9f-1) = Momentum{typeof(η)}(η, ρ) | ||
|
||
init(o::Momentum, x::AbstractArray) = zero(x) | ||
init(o::Momentum, x::Union{AbstractArray,Number}) = zero(x) | ||
|
||
function apply!(o::Momentum, state, x, dx) | ||
η, ρ, v = o.eta, o.rho, state | ||
v′ = @.. v = ρ * v - η * dx | ||
|
||
return v′, @.. -v′ | ||
end | ||
|
||
|
@@ -68,15 +68,15 @@ struct Nesterov{T} | |
end | ||
Nesterov(η = 1f-3, ρ = 9f-1) = Nesterov{typeof(η)}(η, ρ) | ||
|
||
init(o::Nesterov, x::AbstractArray) = zero(x) | ||
init(o::Nesterov, x::Union{AbstractArray,Number}) = zero(x) | ||
|
||
(o::Nesterov)(state, m, dm) = update(o, state, m, dm) | ||
|
||
function apply!(o::Nesterov, state, x, dx) | ||
η, ρ, v = o.eta, o.rho, state | ||
d = @.. ρ^2 * v - (1+ρ) * η * dx | ||
v′ = @.. v = ρ * v - η * dx | ||
|
||
return v′, @.. -d | ||
end | ||
|
||
|
@@ -103,13 +103,13 @@ struct RMSProp{T} | |
end | ||
RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η))) = RMSProp{typeof(η)}(η, ρ, ϵ) | ||
|
||
init(o::RMSProp, x::AbstractArray) = zero(x) | ||
init(o::RMSProp, x::Union{AbstractArray,Number}) = zero(x) | ||
|
||
function apply!(o::RMSProp, state, x, dx) | ||
η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state | ||
acc′ = @.. acc = ρ * acc + (1 - ρ) * dx^2 | ||
dx′ = @.. dx * (η / (sqrt(acc) + ϵ)) | ||
|
||
return acc′, dx′ | ||
end | ||
|
||
|
@@ -135,7 +135,7 @@ struct ADAM{T} | |
end | ||
ADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = ADAM{typeof(η)}(η, β, ϵ) | ||
|
||
init(o::ADAM, x::AbstractArray) = (zero(x), zero(x), o.beta) | ||
init(o::ADAM, x::Union{AbstractArray,Number}) = (zero(x), zero(x), o.beta) | ||
|
||
(o::ADAM)(state, m, dm) = update(o, state, m, dm) | ||
|
||
|
@@ -170,7 +170,7 @@ struct RADAM{T} | |
end | ||
RADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = RADAM{typeof(η)}(η, β, ϵ) | ||
|
||
init(o::RADAM, x::AbstractArray) = (zero(x), zero(x), o.beta, 1) | ||
init(o::RADAM, x::Union{AbstractArray,Number}) = (zero(x), zero(x), o.beta, 1) | ||
|
||
(o::RADAM)(state, m, dm) = update(o, state, m, dm) | ||
|
||
|
@@ -213,7 +213,7 @@ struct AdaMax{T} | |
end | ||
AdaMax(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = AdaMax{typeof(η)}(η, β, ϵ) | ||
|
||
init(o::AdaMax, x::AbstractArray) = (zero(x), zero(x), o.beta) | ||
init(o::AdaMax, x::Union{AbstractArray,Number}) = (zero(x), zero(x), o.beta) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also on types: this could be narrower since |
||
|
||
(o::AdaMax)(state, m, dm) = update(o, state, m, dm) | ||
|
||
|
@@ -250,7 +250,7 @@ struct OADAM{T} | |
end | ||
OADAM(η = 1f-3, β = (5f-1, 9f-1), ϵ = eps(typeof(η))) = OADAM{typeof(η)}(η, β, ϵ) | ||
|
||
init(o::OADAM, x::AbstractArray) = (zero(x), zero(x), o.beta, zero(x)) | ||
init(o::OADAM, x::Union{AbstractArray,Number}) = (zero(x), zero(x), o.beta, zero(x)) | ||
|
||
(o::OADAM)(state, m, dm) = update(o, state, m, dm) | ||
|
||
|
@@ -287,6 +287,7 @@ struct ADAGrad{T} | |
end | ||
ADAGrad(η = 1f-1, ϵ = eps(typeof(η))) = ADAGrad{typeof(η)}(η, ϵ) | ||
|
||
init(o::ADAGrad, ::T) where T <: Number = convert(T, o.epsilon) | ||
init(o::ADAGrad, x::AbstractArray) = fill!(similar(x), o.epsilon) | ||
|
||
(o::ADAGrad)(state, m, dm) = update(o, state, m, dm) | ||
|
@@ -319,7 +320,7 @@ struct ADADelta{T} | |
end | ||
ADADelta(ρ = 9f-1, ϵ = eps(typeof(ρ))) = ADADelta{typeof(ρ)}(ρ, ϵ) | ||
|
||
init(o::ADADelta, x::AbstractArray) = (zero(x), zero(x)) | ||
init(o::ADADelta, x::Union{AbstractArray,Number}) = (zero(x), zero(x)) | ||
|
||
(o::ADADelta)(state, m, dm) = update(o, state, m, dm) | ||
|
||
|
@@ -332,7 +333,7 @@ function apply!(o::ADADelta, state, x, dx) | |
# or even out of the square roots | ||
dx′ = @.. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ) | ||
Δacc′ = @.. Δacc = ρ * Δacc + (1 - ρ) * dx^2 | ||
|
||
return (acc′, Δacc′), dx′ | ||
end | ||
|
||
|
@@ -357,8 +358,8 @@ struct AMSGrad{T} | |
end | ||
AMSGrad(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = AMSGrad{typeof(η)}(η, β, ϵ) | ||
|
||
init(o::AMSGrad, x::AbstractArray) = | ||
(fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon)) | ||
init(o::AMSGrad, ::T) where T <: Number = ntuple(_ -> convert(T, o.epsilon), 3) | ||
init(o::AMSGrad, x::AbstractArray) = ntuple(_ -> fill!(similar(x), o.epsilon), 3) | ||
|
||
(o::AMSGrad)(state, m, dm) = update(o, state, m, dm) | ||
|
||
|
@@ -396,7 +397,7 @@ struct NADAM{T} | |
end | ||
NADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = NADAM{typeof(η)}(η, β, ϵ) | ||
|
||
init(o::NADAM, x::AbstractArray) = (zero(x), zero(x), o.beta) | ||
init(o::NADAM, x::Union{AbstractArray,Number}) = (zero(x), zero(x), o.beta) | ||
|
||
(o::NADAM)(state, m, dm) = update(o, state, m, dm) | ||
|
||
|
@@ -407,7 +408,7 @@ function apply!(o::NADAM, state, x, dx) | |
|
||
mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx | ||
vt′ = @.. vt = β[2] * vt + (1 - β[2]) * dx^2 | ||
dx′ = @.. (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) / | ||
dx′ = @.. (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) / | ||
(sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η | ||
|
||
return (mt′, vt′, βt .* β), dx′ | ||
|
@@ -452,7 +453,7 @@ struct AdaBelief{T} | |
end | ||
AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = AdaBelief{typeof(η)}(η, β, ϵ) | ||
|
||
init(o::AdaBelief, x::AbstractArray) = (zero(x), zero(x)) | ||
init(o::AdaBelief, x::Union{AbstractArray,Number}) = (zero(x), zero(x)) | ||
|
||
(o::AdaBelief)(state, m, dm) = update(o, state, m, dm) | ||
|
||
|
@@ -463,7 +464,7 @@ function apply!(o::AdaBelief, state, x, dx) | |
mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx | ||
st′ = @.. st = β[2] * st + (1 - β[2]) * (dx - mt)^2 | ||
dx′ = @.. η * mt / (sqrt(st) + ϵ) | ||
|
||
return (mt′, st′), dx′ | ||
end | ||
|
||
|
@@ -480,7 +481,7 @@ struct WeightDecay{T} | |
end | ||
WeightDecay() = WeightDecay(5f-4) | ||
|
||
init(o::WeightDecay, x::AbstractArray) = nothing | ||
init(o::WeightDecay, x::Union{AbstractArray,Number}) = nothing | ||
|
||
(o::WeightDecay)(state, m, dm) = update(o, state, m, dm) | ||
|
||
|
@@ -502,7 +503,7 @@ struct ClipGrad{T<:Real} | |
end | ||
ClipGrad() = ClipGrad(10f0) | ||
|
||
init(o::ClipGrad, x::AbstractArray) = nothing | ||
init(o::ClipGrad, x::Union{AbstractArray,Number}) = nothing | ||
|
||
(o::ClipGrad)(state::Nothing, m, dm) = update(o, state, m, dm) | ||
|
||
|
@@ -531,7 +532,7 @@ struct ClipNorm{T<:Real} | |
end | ||
ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{typeof(ω)}(ω, p, throw) | ||
|
||
init(o::ClipNorm, x::AbstractArray) = nothing | ||
init(o::ClipNorm, x::Union{AbstractArray,Number}) = nothing | ||
|
||
(o::ClipNorm)(state::Nothing, m, dm) = update(o, state, m, dm) | ||
|
||
|
@@ -556,7 +557,7 @@ struct OptimiserChain{O} | |
end | ||
OptimiserChain(opts...) = OptimiserChain(opts) | ||
|
||
init(o::OptimiserChain, x::AbstractArray) = [init(opt, x) for opt in o.opts] | ||
init(o::OptimiserChain, x::Union{AbstractArray,Number}) = [init(opt, x) for opt in o.opts] | ||
|
||
(o::OptimiserChain)(state, m, dms...) = update(o, state, m, dms...) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe it's cleaner to say what's excluded:
Then when someone wants Unitful numbers, they may just work.
If this is going to be applied everywhere, not as an opt-in, then I think it ought to exclude integers, even though AD mostly draws the line only at Bool. It's possible that arrays of integers should similarly be excluded from `isnumeric`? Then all cases can just be:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe, though if in some distant future Flux is actually capable of quantization-aware training you might well see int8/int4 arrays. RE Unitful quantities, is there any reason to consider them any "safer" to optimize than integers?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the graphics people do some things with low-bit numbers, but wrap them up to disguise the integer inside --- presumably it would make sense to do something like that before trying to AD with Int8?
My thinking is that if you add `float(zero(x))` to them, which is what some trivial gradient could produce, nothing much will happen, whereas if you do this to a size or an index, it'll break.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here's a thought: if the concern is about indices and bools, can we just exclude Int64, Int32 and Bool? Maybe the UInts down to 8 and Int16 as well. I've not seen anyone using those types as model params, and it would eliminate 99% of the confusion with indices + true/false.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After playing around for a bit with Unitful, I'm still not comfortable with giving carte blanche to all non-integer number types. Even with floats, it's not clear that a scalar param should be trainable by default.