diff --git a/src/codegen/forward_demand.jl b/src/codegen/forward_demand.jl index baa61a9e..acd00bd8 100644 --- a/src/codegen/forward_demand.jl +++ b/src/codegen/forward_demand.jl @@ -240,10 +240,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}}; if argorder != order @assert order < argorder return get!(truncation_map, arg=>order) do + # identify where to insert. Must be after phi blocks + pos = SSAValue(find_end_of_phi_block(ir, arg.id)) if order == 0 - insert_node!(ir, arg, NewInstruction(Expr(:call, primal, arg), Any), #=attach_after=#true) + insert_node!(ir, pos, NewInstruction(Expr(:call, primal, arg), Any), #=attach_after=#true) else - insert_node!(ir, arg, NewInstruction(Expr(:call, truncate, arg, Val{order}()), Any), #=attach_after=#true) + insert_node!(ir, pos, NewInstruction(Expr(:call, truncate, arg, Val{order}()), Any), #=attach_after=#true) end end end diff --git a/src/stage1/compiler_utils.jl b/src/stage1/compiler_utils.jl index b237632b..7286792f 100644 --- a/src/stage1/compiler_utils.jl +++ b/src/stage1/compiler_utils.jl @@ -64,3 +64,29 @@ Base.lastindex(x::Core.Compiler.InstructionStream) = if isdefined(Core.Compiler, :CallInfo) Base.convert(::Type{Core.Compiler.CallInfo}, ::Nothing) = Core.Compiler.NoCallInfo() end + + +""" + find_end_of_phi_block(ir, start_search_idx) + +Finds the last index within the same basic block, on or after the `start_search_idx` which is not within a phi block. +A phi-block is a run on PhiNodes or nothings that must be the first statements within the basic block. + +If `start_search_idx` is not within a phi block to begin with, then just returns `start_search_idx` +""" +function find_end_of_phi_block(ir, start_search_idx::Int) + # Short-cut for early exit: + stmt = ir.stmts[start_search_idx][:inst] + stmt !== nothing && !isa(stmt, PhiNode) && return start_search_idx + + # Actually going to have to go digging throught the IR to out if were are in a phi block + # TODO: this is not so efficient. maybe preconstruct CFG then use block_for_inst? + bb=CC.block_for_inst(ir.cfg, start_search_idx) + end_search_idx=ir.cfg.blocks[bb].stmts[end] + for idx in (start_search_idx):(end_search_idx-1) + stmt = ir.stmts[idx+1][:inst] + # next statment is no longer in a phi block, so safe to insert + stmt !== nothing && !isa(stmt, PhiNode) && return idx + end + return end_search_idx +end \ No newline at end of file diff --git a/test/stage2_fwd.jl b/test/stage2_fwd.jl index 996ea93c..4f507cc5 100644 --- a/test/stage2_fwd.jl +++ b/test/stage2_fwd.jl @@ -43,19 +43,23 @@ module stage2_fwd g(x) = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), Diffractor.TaylorBundle{1}(x, (1.0,))) Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10f0, (1.0,))) end +end + +module forward_diff_no_inf # todo: move this to a seperate file + using Diffractor, Test + # this is needed as transform! is *always* called on Arguments regardless of what visit_custom says + identity_transform!(ir, ssa::Core.SSAValue, order) = ir[ssa] + function identity_transform!(ir, arg::Core.Argument, order) + return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any)) + end + @testset "Constructors in forward_diff_no_inf!" begin struct Bar148 v end foo_148(x) = Bar148(x) - # this is needed as transform! is *always* called on Arguments regardless of what visit_custom says - identity_transform!(ir, ssa::Core.SSAValue, order) = ir[ssa] - function identity_transform!(ir, arg::Core.Argument, order) - return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any)) - end - ir = first(only(Base.code_ircode(foo_148, Tuple{Float64}))) Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!) ir2 = Core.Compiler.compact!(ir) @@ -67,12 +71,6 @@ module stage2_fwd @eval global _coeff::Float64=24.5 plus_a_global(x) = x + _coeff - # this is needed as transform! is *always* called on Arguments regardless of what visit_custom says - identity_transform!(ir, ssa::Core.SSAValue, order) = ir[ssa] - function identity_transform!(ir, arg::Core.Argument, order) - return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any)) - end - ir = first(only(Base.code_ircode(plus_a_global, Tuple{Float64}))) Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!) ir2 = Core.Compiler.compact!(ir) @@ -80,4 +78,32 @@ module stage2_fwd f = Core.OpaqueClosure(ir2; do_compile=false) @test f(3.5) == 28.0 end + + @testset "runs of phi nodes" begin + function phi_run(x::Float64) + a = 2.0 + b = 2.0 + if (@noinline rand()) < 0 # this branch will never actually be taken + a = -100.0 + b = 200.0 + end + return x - a + b + end + + input_ir = first(only(Base.code_ircode(phi_run, Tuple{Float64}))) + ir = copy(input_ir) + #Workout where to diff to trigger error + diff_ssa = Core.SSAValue[] + for idx in 1:length(ir.stmts) + if ir.stmts[idx][:inst] isa Core.PhiNode + push!(diff_ssa, Core.SSAValue(idx)) + end + end + + Diffractor.forward_diff_no_inf!(ir, diff_ssa .=> 1; transform! = identity_transform!) + ir2 = Core.Compiler.compact!(ir) + Core.Compiler.verify_ir(ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158) + f = Core.OpaqueClosure(ir2; do_compile=false) + @test f(3.5) == 3.5 # this will segfault if we are not handling phi nodes correctly + end end