diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index c6eee8d..c543db3 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -20,7 +20,6 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.3'
           - '1.6' # LTS
           - '1'
           - 'nightly'
diff --git a/Project.toml b/Project.toml
index 08da26d..cdf0a18 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,16 +1,18 @@
 name = "Tracker"
 uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-version = "0.2.20"
+version = "0.2.21"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -18,16 +20,18 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
-Adapt = "1, 2, 3"
+Adapt = "3"
 DiffRules = "1.4"
+Functors = "0.3.0"
 ForwardDiff = "0.10"
 LogExpFunctions = "0.3"
 MacroTools = "0.5"
-NNlib = "0.7.18, 0.8" # 0.7.18 is the last version which supports Julia 1.3
-NaNMath = "0.3, 1"
-Requires = "0.5, 1.0"
-SpecialFunctions = "0.10, 1, 2"
-julia = "1.3"
+NNlib = "0.8"
+NaNMath = "1"
+Optimisers = "0.2.9"
+Requires = "1.0"
+SpecialFunctions = "1, 2"
+julia = "1.6"
 
 [extras]
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
diff --git a/src/Tracker.jl b/src/Tracker.jl
index b269fc4..354fdea 100644
--- a/src/Tracker.jl
+++ b/src/Tracker.jl
@@ -13,7 +13,7 @@ import Printf
 import Base: ==
 
 export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
-  jacobian, hessian, param, back!
+  jacobian, hessian, param, back!, withgradient
 
 tracker(x) = nothing
 
@@ -70,10 +70,10 @@ end
 include("idset.jl")
 include("params.jl")
-include("back.jl")
-include("numeric.jl")
 include("lib/real.jl")
 include("lib/array.jl")
+include("back.jl")
+include("numeric.jl")
 include("forward.jl")
 
 @init @require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("lib/pdmats.jl")
 
diff --git a/src/back.jl b/src/back.jl
index e638e1a..b7f07db 100644
--- a/src/back.jl
+++ b/src/back.jl
@@ -178,3 +178,71 @@ function jacobian(f, x::AbstractVector)
 end
 
 hessian(f, x) = jacobian(x -> gradient(f, x, nest=true)[1], x)
+
+using Functors: fmap, functor
+using Optimisers: _trainable, isnumeric
+
+"""
+    withgradient(f, xs...)
+
+This computes the value `f(xs...)` and the gradient with respect to `xs`.
+However, it differs from `gradient` in several other respects:
+* It will recurse into `xs` using `fmap`, and thus like Zygote's "explicit mode" it
+  returns a tree-like gradient matching the shape of a Flux model.
+  This recursion obeys restrictions imposed by `Optimisers.trainable`, if defined.
+* Only objects satisfying `Optimisers.isnumeric` are regarded as parameters,
+  thus in particular integers are ignored.
+* Returns plain arrays, not tracked. Uses `nothing` as a strong zero gradient, like Zygote.
+
+# Examples
+```
+julia> nt = (vec = [1.0, 2.0], mat = [4.0;;], fun = sin);
+
+julia> withgradient(nt, 2) do x, p
+         sum(abs2, x.vec) ^ p
+       end
+(val = 25.0, grad = ((vec = [20.0, 40.0], mat = [0.0;;], fun = nothing), nothing))
+
+julia> using Flux
+
+julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1, bias=false));
+
+julia> withgradient(model, rand(Float32, 2)) do m, x
+         sum(abs2, m(x))
+       end
+(val = 0.035716165f0, grad = ((layers = ((weight = Float32[-0.4241869 -0.16741231], bias = Float32[-0.5529184], σ = nothing), (weight = Float32[-0.04804218;;], bias = nothing, σ = nothing)),), Float32[0.12706584, -0.08858479]))
+```
+"""
+function withgradient(f, xs...)
+  pxs = fmap(param, xs; exclude = isnumeric, walk = _trainable_walk)
+  l = f(pxs...)
+  losscheck(l)
+  l isa TrackedReal || return (val = l, grad = nothing)
+  @interrupts back!(l)
+  (val = data(l), grad = rec_grad(pxs))
+end
+
+function _trainable_walk(f, x)
+  func, re = functor(x)
+  isempty(func) && return x
+  done = map(f, _trainable(x))  # recurse only into trainable fields, this contains `nothing` elsewhere
+  map(func, merge(func, done)) do n, t
+    isnothing(t) ? n : t
+  end |> re  # reconstruct the whole thing
+end
+_trainable_walk(f, x::Tuple) = map(f, x)
+
+# Easier to write the recursion to extract the gradients without using fmap:
+rec_grad(x::TrackedArray) = grad(x)
+rec_grad(x::TrackedReal) = grad(x)
+rec_grad(x::AbstractArray{<:Number}) = nothing
+rec_grad(x::Number) = nothing
+
+rec_grad(x::Union{Tuple,NamedTuple,AbstractArray}) = map(rec_grad, x)
+rec_grad(::Tuple{}) = nothing
+rec_grad(::NamedTuple{(), Tuple{}}) = nothing
+function rec_grad(x::T) where {T}
+  F = fieldnames(T)
+  isempty(F) && return nothing
+  map(f -> rec_grad(getfield(x, f)), NamedTuple{F}(F))
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index d9790fc..b2ce9d6 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -17,4 +17,20 @@ using Tracker: jacobian
   @test J ≈ A.data
 end
 
+using Optimisers, Functors
+struct TwoThirds a; b; c; end # evil test from Optimisers.jl
+@eval Functors.@functor TwoThirds (a, c)
+Optimisers.trainable(x::TwoThirds) = (a = x.a,)
+
+@testset "withgradient" begin
+  nt = (vec = [1.0, 2.0], mat = [4.0;;], fun = sin);
+  @test withgradient((x, p) -> sum(abs2, x.vec) ^ p, nt, 2) == (val = 25.0, grad = ((vec = [20.0, 40.0], mat = [0.0;;], fun = nothing), nothing))
+
+  @test withgradient(x -> sum(x.v), (v = [1, 2], w = [3.0])) == (val = 3, grad = nothing)
+
+  m = TwoThirds([1.0], [2.0], [3.0]) # only the first should be tracked, but all should survive
+  g = withgradient(m -> only(m.a::AbstractVector + m.b::Vector + m.c::Vector), m)
+  @test g == (val = 6.0, grad = ((a = [1.0], b = nothing, c = nothing),))
+end
+
 end # overall @testset
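
To try the new function without pulling in Flux, here is a minimal sketch of the tree-shaped gradient on a hand-made nested NamedTuple standing in for a model. Not part of the patch: the names (`m`, `affine`, `n`) and values are illustrative, chosen so the expected gradients are easy to verify by hand.

```julia
using Tracker

# A "model" as a nested NamedTuple; the integer field `n` fails
# Optimisers.isnumeric and so is not treated as a parameter.
m = (affine = (w = [2.0 0.0; 0.0 2.0], b = [0.0, 0.0]), n = 3)

val, grads = Tracker.withgradient(m, [1.0, 2.0]) do m, x
    sum(m.affine.w * x .+ m.affine.b)
end

val       # 6.0
grads[1]  # (affine = (w = [1.0 2.0; 1.0 2.0], b = [1.0, 1.0]), n = nothing)
grads[2]  # [2.0, 2.0] -- a plain Vector{Float64}, not a TrackedArray
```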
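And since the gradient tree mirrors the model's functor structure, it should plug straight into Optimisers.jl, the same way Zygote's explicit-mode gradients do. A hedged sketch of one optimisation step, assuming Flux plus this branch of Tracker; the model, input, and `Adam` learning rate are made up for illustration:

```julia
using Flux, Optimisers, Tracker

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1, bias=false))
x = rand(Float32, 2)

state = Optimisers.setup(Optimisers.Adam(0.01), model)  # per-leaf optimiser state

# Qualified call, since Flux also brings a `withgradient` into scope:
val, grads = Tracker.withgradient(model, x) do m, x
    sum(abs2, m(x))
end

# grads[1] has the same nested (layers = ...) shape as `model`, with
# `nothing` for non-parameters such as σ, which Optimisers.update accepts:
state, model = Optimisers.update(state, model, grads[1])
```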