From 6495ed1f2a38ffd793d8d356530ff2a0c5fa5796 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Sat, 31 Dec 2022 20:25:20 +0000 Subject: [PATCH 1/2] Hookup demand-driven forward mode to the Diffractor runtime Tests currently depend on https://github.com/JuliaLang/julia/pull/48045 and https://github.com/JuliaLang/julia/pull/48059, so we should either get those merged first, or mark them here as broken. --- src/codegen/forward_demand.jl | 137 +++++++++++++++++++++++++++++----- src/higher_fwd_rules.jl | 17 +++++ src/stage1/forward.jl | 8 +- src/stage2/forward.jl | 38 +++++++--- src/stage2/interpreter.jl | 79 ++++++++++++++------ src/stage2/lattice.jl | 4 + test/stage2_fwd.jl | 17 ++++- 7 files changed, 246 insertions(+), 54 deletions(-) diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index 031115ae..87c727e8 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -1,12 +1,13 @@ using Core.Compiler: IRInterpretationState, construct_postdomtree, PiNode, is_known_call, argextype, postdominates -function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, pantelides::Vector{SSAValue}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}()) +#= +function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff::Vector{Pair{SSAValue, Int}}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}()) Δs = SSAValue[] rets = findall(@nospecialize(x)->isa(x, ReturnNode) && isdefined(x, :val), ir.stmts.inst) postdomtree = construct_postdomtree(ir.cfg.blocks) - for ssa in pantelides - Δssa = forward_diff!(ir, interp, irsv, ssa; custom_diff!, diff_cache) + for (ssa, order) in to_diff + Δssa = forward_diff!(ir, interp, irsv, ssa, order; custom_diff!, diff_cache) Δblock = block_for_inst(ir, Δssa.id) for idx in rets retblock = block_for_inst(ir, idx) @@ -18,31 +19,24 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, pantelid end return (ir, Δs) end +=# -function diff_unassigned_variable!(ir, ssa) - return insert_node!(ir, ssa, NewInstruction( - Expr(:call, GlobalRef(Intrinsics, :state_ddt), ssa), Float64), #=attach_after=#true) -end - -function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue; custom_diff!, diff_cache) +function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, order::Int; custom_diff!, diff_cache) if haskey(diff_cache, ssa) return diff_cache[ssa] end inst = ir[ssa] stmt = inst[:inst] - if isa(stmt, SSAValue) - return forward_diff!(ir, interp, irsv, stmt; custom_diff!, diff_cache) - end - Δssa = forward_diff_uncached!(ir, interp, irsv, ssa, inst; custom_diff!, diff_cache) + Δssa = forward_diff_uncached!(ir, interp, irsv, ssa, inst, order::Int; custom_diff!, diff_cache) @assert Δssa !== nothing if isa(Δssa, SSAValue) diff_cache[ssa] = Δssa end return Δssa end -forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, val::Union{Integer, AbstractFloat}; custom_diff!, diff_cache) = zero(val) -forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, @nospecialize(arg); custom_diff!, diff_cache) = ChainRulesCore.NoTangent() -function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Argument; custom_diff!, diff_cache) +forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, val::Union{Integer, AbstractFloat}, order::Int; custom_diff!, diff_cache) = zero(val) +forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, @nospecialize(arg), order::Int; custom_diff!, diff_cache) = ChainRulesCore.NoTangent() +function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Argument, order::Int; custom_diff!, diff_cache) recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache) val = custom_diff!(ir, SSAValue(0), arg, recurse) if val !== nothing @@ -51,13 +45,15 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Arg return ChainRulesCore.NoTangent() end -function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, inst::Core.Compiler.Instruction; custom_diff!, diff_cache) +function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, inst::Core.Compiler.Instruction, order::Int; custom_diff!, diff_cache) stmt = inst[:inst] - recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache) + recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache) if (val = custom_diff!(ir, ssa, stmt, recurse)) !== nothing return val elseif isa(stmt, PiNode) return recurse(stmt.val) + elseif isa(stmt, SSAValue) + return recurse(stmt) elseif isa(stmt, PhiNode) Δphi = PhiNode(copy(stmt.edges), similar(stmt.values)) T = Union{} @@ -152,3 +148,108 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState, return Δssa end end + +function forward_visit!(ir::IRCode, ssa::SSAValue, order::Int, ssa_orders::Vector{Pair{Int, Bool}}, visit_custom!) + if ssa_orders[ssa.id][1] >= order + return + end + ssa_orders[ssa.id] = order => ssa_orders[ssa.id][2] + inst = ir[ssa] + stmt = inst[:inst] + recurse(@nospecialize(val)) = forward_visit!(ir, val, order, ssa_orders, visit_custom!) + if visit_custom!(ir, stmt, order, recurse) + ssa_orders[ssa.id] = order => true + return + elseif isa(stmt, PiNode) + return recurse(stmt.val) + elseif isa(stmt, PhiNode) + for i = 1:length(stmt.values) + isassigned(stmt.values, i) || continue + recurse(stmt.values[i]) + end + return + elseif isexpr(stmt, :new) || isexpr(stmt, :invoke) + foreach(recurse, stmt.args[2:end]) + elseif isexpr(stmt, :call) + foreach(recurse, stmt.args) + elseif isa(stmt, SSAValue) + recurse(stmt) + elseif !isa(stmt, Expr) + return + else + @show stmt + error() + end +end +forward_visit!(ir::IRCode, _, order::Int, ssa_orders::Vector{Pair{Int, Bool}}, visit_custom!) = nothing +function forward_visit!(ir::IRCode, a::Argument, order::Int, ssa_orders::Vector{Pair{Int, Bool}}, visit_custom!) + recurse(@nospecialize(val)) = forward_visit!(ir, val, order, ssa_orders, visit_custom!) + return visit_custom!(ir, a, order, recurse) +end + + +function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::Vector{Pair{SSAValue, Int}}; + visit_custom! = (args...)->false, transform! = (args...)->error()) + # Step 1: For each SSAValue in the IR, keep track of the differentiation order needed + ssa_orders = [0=>false for i = 1:length(ir.stmts)] + for (ssa, order) in to_diff + forward_visit!(ir, ssa, order, ssa_orders, visit_custom!) + end + + # Step 2: Transform + function maparg(arg, ssa, order) + if isa(arg, Argument) + # TODO: Should we remember whether the callbacks wanted the arg? + return transform!(ir, arg, order) + elseif isa(arg, SSAValue) + # TODO: Bundle truncation if necessary + return arg + end + @assert !isa(arg, Expr) + return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any)) + end + + for (ssa, (order, custom)) in enumerate(ssa_orders) + if order == 0 + # TODO: Bundle truncation? + continue + end + if custom + transform!(ir, SSAValue(ssa), order) + else + inst = ir[SSAValue(ssa)] + stmt = inst[:inst] + if isexpr(stmt, :invoke) + inst[:inst] = Expr(:call, ∂☆{order}(), map(arg->maparg(arg, SSAValue(ssa), order), stmt.args[2:end])...) + inst[:type] = Any + elseif !isa(stmt, Expr) + inst[:inst] = maparg(stmt, ssa, order) + inst[:type] = Any + else + @show stmt + error() + end + end + end + +end + +function forward_diff!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::Vector{Pair{SSAValue, Int}}; kwargs...) + forward_diff_no_inf!(ir, interp, mi, world, to_diff; kwargs...) + + # Step 3: Re-inference + ir = compact!(ir) + + extra_reprocess = CC.BitSet() + for i = 1:length(ir.stmts) + if ir[SSAValue(i)][:type] == Any + CC.push!(extra_reprocess, i) + end + end + + interp′ = enable_reinference(interp) + irsv = IRInterpretationState(interp′, ir, mi, world, ir.argtypes[1:mi.def.nargs]) + rt = CC._ir_abstract_constant_propagation(interp′, irsv; extra_reprocess) + + return ir +end diff --git a/src/higher_fwd_rules.jl b/src/higher_fwd_rules.jl index aeb7f7e5..8bb91a52 100644 --- a/src/higher_fwd_rules.jl +++ b/src/higher_fwd_rules.jl @@ -28,6 +28,23 @@ for f in (sin, cos, exp) end end +# TODO: It's a bit embarassing that we need to write these out, but currently the +# compiler is not strong enough to automatically lift the frule. Let's hope we +# can delete these in the near future. +function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} + TaylorBundle{N}(primal(a) + primal(b), + map(+, a.tangent.coeffs, b.tangent.coeffs)) +end + +function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::ZeroBundle{N}) where {N} + TaylorBundle{N}(primal(a) + primal(b), a.tangent.coeffs) +end + +function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} + TaylorBundle{N}(primal(a) - primal(b), + map(-, a.tangent.coeffs, b.tangent.coeffs)) +end + function (::Diffractor.∂☆new{N})(B::ATB{N, Type{T}}, args::ATB{N}...) where {N, T<:SArray} error("Should have intercepted the constructor") end diff --git a/src/stage1/forward.jl b/src/stage1/forward.jl index dd1da09d..3a90ffaf 100644 --- a/src/stage1/forward.jl +++ b/src/stage1/forward.jl @@ -205,13 +205,15 @@ struct FwdIterate{N, T<:AbstractTangentBundle{N}} end function (f::FwdIterate)(arg::ATB{N}) where {N} r = ∂☆{N}()(f.f, arg) - primal(r) === nothing && return nothing + # `primal(r) === nothing` would work, but doesn't create `Conditional` in inference + isa(r, ATB{N, Nothing}) && return nothing (∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(1)), primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2)))) end -function (f::FwdIterate)(arg::ATB{N}, st) where {N} +@Base.constprop :aggressive function (f::FwdIterate)(arg::ATB{N}, st) where {N} r = ∂☆{N}()(f.f, arg, ZeroBundle{N}(st)) - primal(r) === nothing && return nothing + # `primal(r) === nothing` would work, but doesn't create `Conditional` in inference + isa(r, ATB{N, Nothing}) && return nothing (∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(1)), primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2)))) end diff --git a/src/stage2/forward.jl b/src/stage2/forward.jl index a87de612..8d03ff8f 100644 --- a/src/stage2/forward.jl +++ b/src/stage2/forward.jl @@ -2,7 +2,7 @@ using .CC: compact! # Engineering entry point for the 2nd-order forward AD functionality. This is # unlikely to be the actual interface. For now, it is used for testing. -function dontuse_nth_order_forward_stage2(tt::Type) +function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1) interp = ADInterpreter(; forward=true, backward=false) match = Base._which(tt) frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true) @@ -10,26 +10,44 @@ function dontuse_nth_order_forward_stage2(tt::Type) ir = copy((interp.opt[0][frame.linfo].inferred).ir::IRCode) # Find all Return Nodes - vals = SSAValue[] + vals = Pair{SSAValue, Int}[] for i = 1:length(ir.stmts) if isa(ir[SSAValue(i)][:inst], ReturnNode) - push!(vals, SSAValue(i)) + push!(vals, SSAValue(i)=>order) end end - function custom_diff!(ir, ssa, stmt, recurse) + function visit_custom!(ir::IRCode, @nospecialize(stmt), order, recurse) if isa(stmt, ReturnNode) - r = recurse(stmt.val) - ir[ssa][:inst] = ReturnNode(r) - return ssa + recurse(stmt.val) + return true elseif isa(stmt, Argument) - return 1.0 + return true + else + return false end - return nothing end + function transform!(ir::IRCode, ssa::SSAValue, _) + inst = ir[ssa] + stmt = inst[:inst] + if isa(stmt, ReturnNode) + nr = insert_node!(ir, ssa, NewInstruction(Expr(:call, getindex, stmt.val, TaylorTangentIndex(order)), Any)) + inst[:inst] = ReturnNode(nr) + else + error() + end + end + + function transform!(ir::IRCode, arg::Argument, _) + return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, ∂xⁿ{order}(), arg), typeof(∂xⁿ{order}()(1.0)))) + end + + irsv = CC.IRInterpretationState(interp, ir, frame.linfo, CC.get_world_counter(interp), ir.argtypes[1:frame.linfo.def.nargs]) - forward_diff!(ir, interp, irsv, vals; custom_diff!) + ir = forward_diff!(ir, interp, frame.linfo, CC.get_world_counter(interp), vals; visit_custom!, transform!) + + display(ir) ir = compact!(ir) return OpaqueClosure(ir) diff --git a/src/stage2/interpreter.jl b/src/stage2/interpreter.jl index 29182107..ecf7f335 100644 --- a/src/stage2/interpreter.jl +++ b/src/stage2/interpreter.jl @@ -32,12 +32,13 @@ end =# using Core.Compiler: AbstractInterpreter, NativeInterpreter, InferenceState, - InferenceResult, CodeInstance, WorldRange + InferenceResult, CodeInstance, WorldRange, ArgInfo, StmtInfo struct ADInterpreter <: AbstractInterpreter # Modes settings forward::Bool backward::Bool + reinference::Bool # This cache is stratified by AD nesting level. Depending on the # nesting level of the derivative, The AD primitives may behave @@ -52,19 +53,25 @@ struct ADInterpreter <: AbstractInterpreter native_interpreter::NativeInterpreter current_level::Int - msgs::Vector{Tuple{Int, MethodInstance, Int, String}} + remarks::OffsetVector{Dict{Union{MethodInstance,InferenceResult}, Cthulhu.PC2Remarks}} end -change_level(interp::ADInterpreter, new_level::Int) = ADInterpreter(interp.opt, interp.unopt, interp.transformed, interp.native_interpreter, new_level, interp.msgs) +change_level(interp::ADInterpreter, new_level::Int) = ADInterpreter(interp.opt, interp.unopt, interp.transformed, interp.native_interpreter, new_level, interp.remarks) raise_level(interp::ADInterpreter) = change_level(interp, interp.current_level + 1) lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level - 1) -disable_forward(interp::ADInterpreter) = ADInterpreter(false, interp.backward, interp.opt, interp.unopt, interp.transformed, interp.native_interpreter, interp.current_level, interp.msgs) +disable_forward(interp::ADInterpreter) = ADInterpreter(false, interp.backward, interp.reinference, interp.opt, interp.unopt, interp.transformed, interp.native_interpreter, interp.current_level, interp.remarks) +disable_reinference(interp::ADInterpreter) = ADInterpreter(interp.forward, interp.backward, false, interp.opt, interp.unopt, interp.transformed, interp.native_interpreter, interp.current_level, interp.remarks) +enable_reinference(interp::ADInterpreter) = ADInterpreter(interp.forward, interp.backward, true, interp.opt, interp.unopt, interp.transformed, interp.native_interpreter, interp.current_level, interp.remarks) -Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor) = (curs.transformed ? interp.transformed : interp.opt)[curs.level][curs.mi] +function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor) + @show curs + (curs.transformed ? interp.transformed : interp.opt)[curs.level][curs.mi] +end Cthulhu.AbstractCursor(interp::ADInterpreter, mi::MethodInstance) = ADCursor(0, mi, false) + # This is a lie, but let's clean this up later -Cthulhu.can_descend(interp::ADInterpreter, mi::MethodInstance, optimize::Bool) = true +Cthulhu.can_descend(interp::ADInterpreter, @nospecialize(key), optimize::Bool) = true function Cthulhu.lookup(interp::ADInterpreter, curs::ADCursor, optimize::Bool; allow_no_src::Bool=false) if !optimize @@ -187,7 +194,7 @@ function Cthulhu.navigate(curs::ADCursor, callsite::Cthulhu.Callsite) return ADCursor(curs.level, Cthulhu.get_mi(callsite)) end -function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info), argtypes::Cthulhu.ArgTypes, @nospecialize(rt), optimize::Bool) +function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::Core.Compiler.CallInfo), argtypes::Cthulhu.ArgTypes, @nospecialize(rt), optimize::Bool) if isa(info, RecurseInfo) newargtypes = argtypes[2:end] callinfos = Cthulhu.process_info(interp, info.info, newargtypes, Cthulhu.unwrapType(widenconst(rt)), optimize) @@ -215,17 +222,17 @@ function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info), argtyp elseif isa(info, CompClosInfo) return Any[CompClosCallInfo(rt)] end - return invoke(Cthulhu.process_info, Tuple{AbstractInterpreter, Any, Cthulhu.ArgTypes, Any, Bool}, + return invoke(Cthulhu.process_info, Tuple{AbstractInterpreter, Core.Compiler.CallInfo, Cthulhu.ArgTypes, Any, Bool}, interp, info, argtypes, rt, optimize) end -ADInterpreter(;forward = false, backward=true) = ADInterpreter(forward, backward, +ADInterpreter(;forward = false, backward=true, reinference=false) = ADInterpreter(forward, backward, reinference, OffsetVector([Dict{MethodInstance, CodeInstance}(), Dict{MethodInstance, CodeInstance}()], 0:1), OffsetVector([Dict{MethodInstance, Cthulhu.InferredSource}(), Dict{MethodInstance, Cthulhu.InferredSource}()], 0:1), OffsetVector([Dict{MethodInstance, CodeInstance}(), Dict{MethodInstance, CodeInstance}()], 0:1), NativeInterpreter(), 0, - Vector{Tuple{Int, MethodInstance, Int, String}}() + OffsetVector([Dict{Union{MethodInstance,InferenceResult}, Cthulhu.PC2Remarks}()], 0:0) ) ADInterpreter(fg::ADGraph, level) = @@ -246,7 +253,7 @@ end function Core.Compiler.code_cache(ei::ADInterpreter) while ei.current_level > lastindex(ei.opt) - push!(ei.opt, Dict{MethodInstance, Any}())`` + push!(ei.opt, Dict{MethodInstance, Any}()) end ei.opt[ei.current_level] end @@ -264,10 +271,14 @@ function Core.Compiler.get(view::CodeInfoView, mi::MethodInstance, default) return r::CodeInfo end -function Core.Compiler.add_remark!(ei::ADInterpreter, sv::InferenceState, msg) - push!(ei.msgs, (ei.current_level, sv.linfo, sv.currpc, msg)) +function CC.add_remark!(interp::ADInterpreter, sv::InferenceState, msg) + key = CC.any(sv.result.overridden_by_const) ? sv.result : sv.linfo + push!(get!(Cthulhu.PC2Remarks, interp.remarks[interp.current_level], key), sv.currpc=>msg) end +# TODO: `get_remarks` should get a cursor? +Cthulhu.get_remarks(interp::ADInterpreter, key::Union{MethodInstance,InferenceResult}) = get(interp.remarks[interp.current_level], key, nothing) + #= function Core.Compiler.const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance) return true @@ -290,24 +301,48 @@ function Core.Compiler.transform_result_for_cache(interp::ADInterpreter, return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects) end -#@static if isdefined(Compiler, :is_stmt_inline) -function Core.Compiler.inlining_policy( - interp::ADInterpreter, @nospecialize(src), stmt_flag::UInt8, - mi::MethodInstance, argtypes::Vector{Any}) +function CC.inlining_policy(interp::ADInterpreter, + @nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) + # Disallow inlining things away that have an frule + if isa(info, FRuleCallInfo) + return nothing + end + if isdefined(CC, :SemiConcreteResult) && isa(src, CC.SemiConcreteResult) + return src + end @assert isa(src, Cthulhu.OptimizedSource) || isnothing(src) if isa(src, Cthulhu.OptimizedSource) - if Core.Compiler.is_stmt_inline(stmt_flag) || src.isinlineable + if CC.is_stmt_inline(stmt_flag) || src.isinlineable return src.ir end else # the default inlining policy may try additional effor to find the source in a local cache - return @invoke Core.Compiler.inlining_policy( - interp::AbstractInterpreter, nothing, stmt_flag::UInt8, - mi::MethodInstance, argtypes::Vector{Any}) + return @invoke CC.inlining_policy(interp::AbstractInterpreter, + nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) end return nothing end -#end # @static if isdefined(Compiler, :is_stmt_inline) + +function dummy() end +const dummym = first(methods(dummy)) + +function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype), + sv::IRCode, max_methods::Int) + + if interp.reinference + # Create a dummy inference state to serve as the root + # TODO: This is terrible - how can we refactor this to do better? + mi = CC.specialize_method(dummym, Tuple{typeof(dummy)}, Core.svec()) + result = InferenceResult(mi) + interp′ = disable_forward(disable_reinference(interp)) + sv′ = InferenceState(result, :no, interp′) + r = abstract_call_gf_by_type(interp′, f, arginfo, si, atype, sv′, -1) + return r + end + + return CallMeta(Any, CC.Effects(), CC.NoCallInfo()) +end #= function Core.Compiler.optimize(interp::ADInterpreter, opt::OptimizationState, diff --git a/src/stage2/lattice.jl b/src/stage2/lattice.jl index a6c11c08..431106bd 100644 --- a/src/stage2/lattice.jl +++ b/src/stage2/lattice.jl @@ -76,6 +76,10 @@ function Base.show(io::IO, info::FRuleCallInfo) print(io, "FRuleCallInfo(", typeof(info.info), ", ", typeof(info.frule_call.info), ")") end +function Cthulhu.process_info(interp::AbstractInterpreter, info::FRuleCallInfo, argtypes::Cthulhu.ArgTypes, @nospecialize(rt), optimize::Bool) + return Cthulhu.process_info(interp, info.info, argtypes, rt, optimize) +end + # Helpers tuple_type_fields(rt) = isa(rt, PartialStruct) ? rt.fields : widenconst(rt).parameters diff --git a/test/stage2_fwd.jl b/test/stage2_fwd.jl index 6aebc5de..56b60bc2 100644 --- a/test/stage2_fwd.jl +++ b/test/stage2_fwd.jl @@ -1,7 +1,22 @@ module stage2_fwd - using Diffractor, Test + using Diffractor, Test, ChainRulesCore mysin(x) = sin(x) let sin′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(mysin), Float64}) @test sin′(1.0) == cos(1.0) end + let sin′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(mysin), Float64}, 2) + @test isa(sin′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + @test sin′′(1.0) == -sin(1.0) + end + + myminus(a, b) = a - b + @ChainRulesCore.scalar_rule myminus(x, y) (true, -1) + + self_minus(a) = myminus(a, a) + let self_minus′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}, 2) + # TODO: The IR for this currently contains Union{Diffractor.TangentBundle{2, Float64, Diffractor.ExplicitTangent{Tuple{Float64, Float64, Float64}}}, Diffractor.TangentBundle{2, Float64, Diffractor.TaylorTangent{Tuple{Float64, Float64}}}} + # We should have Diffractor be able to prove uniformity + @test isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + @test sin′′(1.0) == -sin(1.0) + end end From 064fc724f6c9b71f6a0d14ce1e00ffcc729c4ce7 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Thu, 5 Jan 2023 20:34:26 +0000 Subject: [PATCH 2/2] Mark test as broken --- src/stage2/forward.jl | 2 -- test/stage2_fwd.jl | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/stage2/forward.jl b/src/stage2/forward.jl index 8d03ff8f..a4074a88 100644 --- a/src/stage2/forward.jl +++ b/src/stage2/forward.jl @@ -47,8 +47,6 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1) irsv = CC.IRInterpretationState(interp, ir, frame.linfo, CC.get_world_counter(interp), ir.argtypes[1:frame.linfo.def.nargs]) ir = forward_diff!(ir, interp, frame.linfo, CC.get_world_counter(interp), vals; visit_custom!, transform!) - display(ir) - ir = compact!(ir) return OpaqueClosure(ir) end diff --git a/test/stage2_fwd.jl b/test/stage2_fwd.jl index 56b60bc2..06f21bd0 100644 --- a/test/stage2_fwd.jl +++ b/test/stage2_fwd.jl @@ -16,7 +16,7 @@ module stage2_fwd let self_minus′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}, 2) # TODO: The IR for this currently contains Union{Diffractor.TangentBundle{2, Float64, Diffractor.ExplicitTangent{Tuple{Float64, Float64, Float64}}}, Diffractor.TangentBundle{2, Float64, Diffractor.TaylorTangent{Tuple{Float64, Float64}}}} # We should have Diffractor be able to prove uniformity - @test isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) - @test sin′′(1.0) == -sin(1.0) + @test_broken isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64}) + @test self_minus′′(1.0) == 0. end end