From fdff08be436e032fd24fccad0f8b1b129a969f2a Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 30 May 2023 16:29:26 -0400 Subject: [PATCH] Give Diffractor a ForwardDiff inspired interface --- Manifest.toml | 4 +-- src/interface.jl | 79 +++++++---------------------------------- src/stage1/broadcast.jl | 7 ++-- src/tangent.jl | 15 ++++---- test/Project.toml | 5 +++ test/runtests.jl | 6 ++-- 6 files changed, 35 insertions(+), 81 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 7de02ba8..4510e7f1 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -126,9 +126,9 @@ uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[deps.JuliaSyntax]] -git-tree-sha1 = "3884259b6852ed89c7036c455551a556d8a3a124" +git-tree-sha1 = "3379908bd15b3ae86b24de22efbb1e6813864078" uuid = "70703baa-626e-46a2-a12c-08ffd08c73b4" -version = "0.4.1" +version = "0.4.3" [[deps.LibGit2]] deps = ["Base64", "NetworkOptions", "Printf", "SHA"] diff --git a/src/interface.jl b/src/interface.jl index 9b565ee3..f27d15a0 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -59,17 +59,14 @@ dx(x::Complex) = error("Tried to take the gradient of a complex-valued function. dx(x) = error("Cotangent space not defined for `$(typeof(x))`. Try a real-valued function.") """ - ∂x(x) + ∂xⁿ{N}(x) -For `x` in a one dimensional manifold, map x to the trivial, unital, 1st order -tangent bundle. It should hold that `∀x ⟨∂x(x), dx(x)⟩ = 1` +For `x` in a one dimensional manifold, map x to the trivial, unital, Nth order +tangent bundle. It should hold that `∀x ⟨∂ⁿ{1}x(x), dx(x)⟩ = 1` """ -∂x(x::Real) = ExplicitTangentBundle{1}(x, (one(x),)) -∂x(x) = error("Tangent space not defined for `$(typeof(x)).") - struct ∂xⁿ{N}; end -(::∂xⁿ{N})(x::Real) where {N} = TaylorBundle{N}(x, (one(x), (zero(x) for i = 1:(N-1))...,)) +(::∂xⁿ{N})(x::Real) where {N} = TaylorBundle{N}(x, ntuple(i->i==1 ? one(x) : zero(x), N)) (::∂xⁿ)(x) = error("Tangent space not defined for `$(typeof(x)).") function ChainRules.rrule(∂::∂xⁿ, x) @@ -130,8 +127,6 @@ function (::Type{∇})(f, x1, args...) unthunk.(∇(f)(x1, args...)) end -const gradient = ∇ - # Star Trek has their prime directive. We have the... abstract type AbstractPrimeDerivative{N, T}; end @@ -172,66 +167,18 @@ lower_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = (error(); PrimeDerivativeFwd{ raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{N+1,T}(getfield(f, :f)) (f::PrimeDerivativeFwd{0})(x) = getfield(f, :f)(x) - -function (f::PrimeDerivativeFwd{1})(x) - z = ∂☆¹(ZeroBundle{1}(getfield(f, :f)), ∂x(x)) - z[TaylorTangentIndex(1)] -end - function (f::PrimeDerivativeFwd{N})(x) where N z = ∂☆{N}()(ZeroBundle{N}(getfield(f, :f)), ∂xⁿ{N}()(x)) z[TaylorTangentIndex(N)] end -# Polyalgorithm prime derivative -struct PrimeDerivative{N, T} - f::T -end - -function (f::PrimeDerivative{N, T})(x) where {N, T} - # For now, this is backwards mode, since that's more fully implemented - return PrimeDerivativeBack{N, T}(f.f)(x) -end - -""" - f' - -This is a convenience syntax for taking the derivative of a function f: ℝ -> ℝ. -In particular, for such a function f'(x) will be the first derivative of `f` -at `x` (and similar for `f''(x)` and second derivatives and so on.) - -Note that the syntax conflicts with the Base definition for the adjoint of a -matrix and thus is not enabled by default. To use it, add the following to the -top of your module: - -```julia -using Diffractor: var"'" -``` - -It is also available using the @∂ macro: -```julia -@∂ f'(x) -``` -""" -var"'"(f) = PrimeDerivativeBack(f) - -""" - @∂ - -Convenice macro for writing partial derivatives. E.g. The expression: - -```julia -@∂ f(∂x, ∂y) -``` - -Will compute the partial derivative ∂^2 f/∂x∂y at `(x, y)``. And similarly - -```julia -@∂ f(∂²x, ∂y) -``` - -will compute the derivative `∂^3 f/∂x^2 ∂y` at `(x,y)`. -""" -macro ∂(expr) - error("Write me") +derivative(f, x) = Diffractor.PrimeDerivativeFwd(f)(x) +function gradient(f, x::AbstractVector) + map(eachindex(x)) do i + derivative(ξ -> f(vcat(x[begin:i-1], ξ, x[i+1:end])), x[i]) + end end +gradient(f, x::AbstractArray) = reshape(gradient(v -> f(reshape(v, size(x))), vec(x)), size(x)) +gradient(f, xs...) = unthunk.(∇(f)(xs...)) +jacobian(f, x::AbstractArray) = reduce(hcat, vec.(gradient(f, x))) +hessian(f, x::AbstractArray) = jacobian(y -> gradient(f, y), float(x)) diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index e04bd60a..52c997d2 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -17,10 +17,11 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)}, FwdMap(n_getfield(∂ₙ, bc, :f)), ntuple(length(primal(args))) do i val = n_getfield(∂ₙ, args, i) - if ndims(primal(val)) == 0 - return Ref(∂ₙ(ZeroBundle{N}(getindex), val)) - else + p = primal(val) + if p isa AbstractArray && ndims(p) != 0 return unbundle(val) + else + return Ref(∂ₙ(ZeroBundle{N}(getindex), val)) end end)) if isa(r, AbstractArray) diff --git a/src/tangent.jl b/src/tangent.jl index f5505527..e12e4c1c 100644 --- a/src/tangent.jl +++ b/src/tangent.jl @@ -334,10 +334,6 @@ function unbundle(atb::TaylorBundle{Order, A}) where {Order, Dim, T, A<:Abstract StructArray{TaylorBundle{Order, T}}((atb.primal, atb.tangent.coeffs...)) end -function ChainRulesCore.rrule(::typeof(unbundle), atb::TaylorBundle) - unbundle(atb), Δ->throw(Δ) -end - function StructArrays.staticschema(::Type{<:TaylorBundle{N, B, T}}) where {N, B, T} Tuple{B, T.parameters...} end @@ -355,11 +351,11 @@ function StructArrays.createinstance(T::Type{<:TaylorBundle}, args...) T(first(args), Base.tail(args)) end -function unbundle(zb::ZeroBundle{N, A}) where {N,T,Dim,A<:AbstractArray{T, Dim}} - StructArray{ZeroBundle{N, T}}((zb.primal, fill(zb.tangent.val, size(zb.primal)...))) +function unbundle(u::UniformBundle{N, A}) where {N,T,Dim,A<:AbstractArray{T, Dim}} + StructArray{UniformBundle{N, T}}((u.primal, fill(u.tangent.val, size(u.primal)...))) end -function ChainRulesCore.rrule(::typeof(unbundle), atb::ZeroBundle) +function ChainRulesCore.rrule(::typeof(unbundle), atb::AbstractTangentBundle) unbundle(atb), Δ->throw(Δ) end @@ -383,6 +379,11 @@ function rebundle(A::AbstractArray{<:TaylorBundle{N}}) where {N} end) end +function rebundle(A::AbstractArray{<:UniformBundle{N}}) where {N} + @assert all(x->getfield(x, :tangent)==(first(A).tangent), A) + UniformBundle{N}(map(x->x.primal, A), first(A).tangent.val) +end + function ChainRulesCore.rrule(::typeof(rebundle), atb) rebundle(atb), Δ->throw(Δ) end diff --git a/test/Project.toml b/test/Project.toml index 6135bbbb..1ab6f008 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,11 +2,15 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" @@ -16,6 +20,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ChainRules = "1.44.5" ChainRulesCore = "1.15.3" Combinatorics = "1" +DiffTests = "0.1.1" StaticArrays = "1" StatsBase = "0.33" StructArrays = "0.6.12" diff --git a/test/runtests.jl b/test/runtests.jl index 9c8ea16c..67b2ae37 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Diffractor -using Diffractor: var"'", ∂⃖, DiffractorRuleConfig +using Diffractor: ∂⃖, derivative, DiffractorRuleConfig using ChainRules using ChainRulesCore using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad @@ -13,7 +13,7 @@ const bwd = Diffractor.PrimeDerivativeBack @testset verbose=true "Diffractor.jl" begin # overall testset, ensures all tests run -@testset "$file" for file in ("stage2_fwd.jl", "tangent.jl") +@testset "$file" for file in ("stage2_fwd.jl", "tangent.jl")#, "forwarddiff_tests.jl", ) include(file) end @@ -55,7 +55,7 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent() # Minimal 2-nd order forward smoke test @test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin), - Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0) + Diffractor.ExplicitTangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == derivative(sin, 1.0) function simple_control_flow(b, x) if b