Skip to content

Commit

Permalink
add RuleConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Sep 9, 2021
1 parent b95eb68 commit ad02eee
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/runtime.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using ChainRulesCore

struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}} end

@Base.aggressive_constprop accum(a, b) = a + b
@Base.aggressive_constprop accum(a::Tuple, b::Tuple) = map(accum, a, b)
@Base.aggressive_constprop @generated function accum(x::NamedTuple, y::NamedTuple)
Expand Down
4 changes: 3 additions & 1 deletion src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ first_partial(x::CompositeBundle) = map(first_partial, getfield(x, :tup))

# TODO: Which version do we want in ChainRules?
function my_frule(args::ATB{1}...)
frule(map(first_partial, args), map(primal, args)...)
frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...)
end

# Fast path for some hot cases
Expand Down Expand Up @@ -118,6 +118,8 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
end
end

ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, args...) = ∂☆internal{1}()(args...)

function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}
∂☆p = ∂☆{minus1(N)}()
∂☆p(ZeroBundle{minus1(N)}(my_frule), map(shuffle_down, args)...)
Expand Down
6 changes: 5 additions & 1 deletion src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
if N == 1
# Base case (inlined to avoid ambiguities with manually specified
# higher order rules)
z = rrule(f, args...)
z = rrule(DiffractorRuleConfig(), f, args...)
if z === nothing
return ∂⃖recurse{1}()(f, args...)
end
Expand All @@ -225,6 +225,10 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
end
end

function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T}
∂⃖{1}()(f, args...)
end

@Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...)
return rrule(Core.apply_type, head, args...)
end
Expand Down

0 comments on commit ad02eee

Please sign in to comment.