Skip to content

Commit

Permalink
Merge pull request #49 from JuliaDiff/sds/fix_constprop
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Sep 10, 2021
2 parents 2059a35 + 9709486 commit 0a98045
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 320 deletions.
274 changes: 0 additions & 274 deletions Manifest.toml

This file was deleted.

14 changes: 7 additions & 7 deletions src/runtime.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
using ChainRulesCore

@Base.aggressive_constprop accum(a, b) = a + b
@Base.aggressive_constprop accum(a::Tuple, b::Tuple) = map(accum, a, b)
@Base.aggressive_constprop @generated function accum(x::NamedTuple, y::NamedTuple)
@Base.constprop :aggressive accum(a, b) = a + b
@Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b)
@Base.constprop :aggressive @generated function accum(x::NamedTuple, y::NamedTuple)
fnames = union(fieldnames(x), fieldnames(y))
gradx(f) = f in fieldnames(x) ? :(getfield(x, $(quot(f)))) : :(ZeroTangent())
grady(f) = f in fieldnames(y) ? :(getfield(y, $(quot(f)))) : :(ZeroTangent())
Expr(:tuple, [:($f=accum($(gradx(f)), $(grady(f)))) for f in fnames]...)
end
@Base.aggressive_constprop accum(a, b, c, args...) = accum(accum(a, b), c, args...)
@Base.aggressive_constprop accum(a::NoTangent, b) = b
@Base.aggressive_constprop accum(a, b::NoTangent) = a
@Base.aggressive_constprop accum(a::NoTangent, b::NoTangent) = NoTangent()
@Base.constprop :aggressive accum(a, b, c, args...) = accum(accum(a, b), c, args...)
@Base.constprop :aggressive accum(a::NoTangent, b) = b
@Base.constprop :aggressive accum(a, b::NoTangent) = a
@Base.constprop :aggressive accum(a::NoTangent, b::NoTangent) = NoTangent()
14 changes: 7 additions & 7 deletions src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,37 +134,37 @@ end
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)

# Special case rules for performance
@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
s = primal(s)
TangentBundle{N}(getfield(primal(x), s),
map(x->lifted_getfield(x, s), x.partials))
end

@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
s = primal(s)
TaylorBundle{N}(getfield(primal(x), s),
map(y->lifted_getfield(y, s), x.coeffs))
end

@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
x.tup[primal(s)]
end

@Base.aggressive_constprop function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
x.tup[Base.fieldindex(B, primal(s))]
end

@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ATB{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
s = primal(s)
TangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
map(x->lifted_getfield(x, s), x.partials))
end

@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.partial)
end

@Base.aggressive_constprop function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U}
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}, inbounds::AbstractTangentBundle{N}) where {N, U}
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s), primal(inbounds)), x.partial)
end

Expand Down
Loading

0 comments on commit 0a98045

Please sign in to comment.