Skip to content

Commit

Permalink
forward_demand: Implement bundle truncation
Browse files Browse the repository at this point in the history
Also make the canonical 1-bundle use the TaylorBundle type. The
types are isomorphic, but since `∂xⁿ{1}()` returns a TaylorBundle,
this helps type stability.
  • Loading branch information
Keno committed Jan 11, 2023
1 parent b293bbd commit dade754
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 19 deletions.
60 changes: 43 additions & 17 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Core.Compiler: IRInterpretationState, construct_postdomtree, PiNode,
is_known_call, argextype, postdominates
is_known_call, argextype, postdominates, userefs

#=
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff::Vector{Pair{SSAValue, Int}}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}())
Expand Down Expand Up @@ -93,12 +93,6 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState,
return Δtangent
else # general frule handling
info = inst[:info]
if !isa(info, FRuleCallInfo)
@show info
@show inst[:inst]
display(ir)
error()
end
if isexpr(stmt, :invoke)
args = stmt.args[2:end]
else
Expand Down Expand Up @@ -196,22 +190,50 @@ function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_
forward_visit!(ir, ssa, order, ssa_orders, visit_custom!)
end

truncation_map = Dict{Pair{SSAValue, Int}, SSAValue}()

# Step 2: Transform
function maparg(arg, ssa, order)
if isa(arg, Argument)
if isa(arg, SSAValue)
if arg.id > length(ssa_orders)
# This is possible if the custom transform touched another statement.
# In that case just pass this through and assume the `transform!` did
# it correctly.
return arg
end
(argorder, _) = ssa_orders[arg.id]
if argorder != order
@assert order < argorder
return get!(truncation_map, arg=>order) do
# TODO: Other orders
@assert order == 0
insert_node!(ir, arg, NewInstruction(Expr(:call, primal, arg), Any), #=attach_after=#true)
end
end
return arg
elseif order == 0
return arg
elseif isa(arg, Argument)
# TODO: Should we remember whether the callbacks wanted the arg?
return transform!(ir, arg, order)
elseif isa(arg, SSAValue)
# TODO: Bundle truncation if necessary
return arg
elseif isa(arg, GlobalRef)
return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any))
elseif isa(arg, QuoteNode)
return ZeroBundle{order}(arg.value)
end
@assert !isa(arg, Expr)
return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any))
return ZeroBundle{order}(arg)
end

for (ssa, (order, custom)) in enumerate(ssa_orders)
if order == 0
# TODO: Bundle truncation?
inst = ir[SSAValue(ssa)]
stmt = inst[:inst]
urs = userefs(stmt)
for ur in urs
ur[] = maparg(ur[], SSAValue(ssa), order)
end
inst[:inst] = urs[]
continue
end
if custom
Expand All @@ -222,12 +244,16 @@ function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_
if isexpr(stmt, :invoke)
inst[:inst] = Expr(:call, ∂☆{order}(), map(arg->maparg(arg, SSAValue(ssa), order), stmt.args[2:end])...)
inst[:type] = Any
elseif !isa(stmt, Expr)
inst[:inst] = maparg(stmt, ssa, order)
elseif isexpr(stmt, :call)
inst[:inst] = Expr(:call, ∂☆{order}(), map(arg->maparg(arg, SSAValue(ssa), order), stmt.args)...)
inst[:type] = Any
else
@show stmt
error()
urs = userefs(stmt)
for ur in urs
ur[] = maparg(ur[], SSAValue(ssa), order)
end
inst[:inst] = urs[]
inst[:type] = Any
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{plus1(N),T

function (f::PrimeDerivativeFwd{1})(x)
z = ∂☆¹(ZeroBundle{1}(getfield(f, :f)), ∂x(x))
z.tangent.partials[1]
z[TaylorTangentIndex(1)]
end

function (f::PrimeDerivativeFwd{N})(x) where N
Expand Down
2 changes: 1 addition & 1 deletion src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ end
struct ∂☆internal{N}; end
struct ∂☆shuffle{N}; end

shuffle_base(r) = ExplicitTangentBundle{1}(r[1], (r[2],))
shuffle_base(r) = TaylorBundle{1}(r[1], (r[2],))

function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
r = my_frule(args...)
Expand Down
7 changes: 7 additions & 0 deletions src/tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,13 @@ function TaylorBundle{N}(primal, coeffs) where {N}
TaylorBundle{N, Core.Typeof(primal)}(primal, coeffs)
end

function Base.show(io::IO, x::TaylorBundle{1})
print(io, x.primal)
print(io, " + ")
x = x.tangent
print(io, x.coeffs[1], " ∂₁")
end

Base.getindex(tb::TaylorBundle, tti::TaylorTangentIndex) = tb.tangent.coeffs[tti.i]
function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex)
tb.tangent.coeffs[count_ones(tti.i)]
Expand Down

0 comments on commit dade754

Please sign in to comment.