Skip to content

Commit

Permalink
Don't insert truncation within phi blocks (#162)
Browse files Browse the repository at this point in the history
* Don't insert truncation within phi blocks

* more efficient search

* clearer test

* change test to check all ading phi nodes

* Start search at start
  • Loading branch information
oxinabox authored Jun 15, 2023
1 parent 10cef62 commit 022865e
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 14 deletions.
6 changes: 4 additions & 2 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
if argorder != order
@assert order < argorder
return get!(truncation_map, arg=>order) do
# identify where to insert. Must be after phi blocks
pos = SSAValue(find_end_of_phi_block(ir, arg.id))
if order == 0
insert_node!(ir, arg, NewInstruction(Expr(:call, primal, arg), Any), #=attach_after=#true)
insert_node!(ir, pos, NewInstruction(Expr(:call, primal, arg), Any), #=attach_after=#true)
else
insert_node!(ir, arg, NewInstruction(Expr(:call, truncate, arg, Val{order}()), Any), #=attach_after=#true)
insert_node!(ir, pos, NewInstruction(Expr(:call, truncate, arg, Val{order}()), Any), #=attach_after=#true)
end
end
end
Expand Down
26 changes: 26 additions & 0 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,29 @@ Base.lastindex(x::Core.Compiler.InstructionStream) =
if isdefined(Core.Compiler, :CallInfo)
Base.convert(::Type{Core.Compiler.CallInfo}, ::Nothing) = Core.Compiler.NoCallInfo()
end


"""
find_end_of_phi_block(ir, start_search_idx)
Finds the last index within the same basic block, on or after the `start_search_idx` which is not within a phi block.
A phi-block is a run on PhiNodes or nothings that must be the first statements within the basic block.
If `start_search_idx` is not within a phi block to begin with, then just returns `start_search_idx`
"""
function find_end_of_phi_block(ir, start_search_idx::Int)
# Short-cut for early exit:
stmt = ir.stmts[start_search_idx][:inst]
stmt !== nothing && !isa(stmt, PhiNode) && return start_search_idx

# Actually going to have to go digging throught the IR to out if were are in a phi block
# TODO: this is not so efficient. maybe preconstruct CFG then use block_for_inst?
bb=CC.block_for_inst(ir.cfg, start_search_idx)
end_search_idx=ir.cfg.blocks[bb].stmts[end]
for idx in (start_search_idx):(end_search_idx-1)
stmt = ir.stmts[idx+1][:inst]
# next statment is no longer in a phi block, so safe to insert
stmt !== nothing && !isa(stmt, PhiNode) && return idx
end
return end_search_idx
end
50 changes: 38 additions & 12 deletions test/stage2_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,23 @@ module stage2_fwd
g(x) = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), Diffractor.TaylorBundle{1}(x, (1.0,)))
Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(g), Diffractor.TaylorBundle{1}(10f0, (1.0,)))
end
end


module forward_diff_no_inf # todo: move this to a seperate file
using Diffractor, Test
# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
identity_transform!(ir, ssa::Core.SSAValue, order) = ir[ssa]
function identity_transform!(ir, arg::Core.Argument, order)
return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any))
end

@testset "Constructors in forward_diff_no_inf!" begin
struct Bar148
v
end
foo_148(x) = Bar148(x)

# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
identity_transform!(ir, ssa::Core.SSAValue, order) = ir[ssa]
function identity_transform!(ir, arg::Core.Argument, order)
return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any))
end

ir = first(only(Base.code_ircode(foo_148, Tuple{Float64})))
Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!)
ir2 = Core.Compiler.compact!(ir)
Expand All @@ -67,17 +71,39 @@ module stage2_fwd
@eval global _coeff::Float64=24.5
plus_a_global(x) = x + _coeff

# this is needed as transform! is *always* called on Arguments regardless of what visit_custom says
identity_transform!(ir, ssa::Core.SSAValue, order) = ir[ssa]
function identity_transform!(ir, arg::Core.Argument, order)
return Core.Compiler.insert_node!(ir, Core.SSAValue(1), Core.Compiler.NewInstruction(Expr(:call, Diffractor.ZeroBundle{1}, arg), Any))
end

ir = first(only(Base.code_ircode(plus_a_global, Tuple{Float64})))
Diffractor.forward_diff_no_inf!(ir, [Core.SSAValue(1) => 1]; transform! = identity_transform!)
ir2 = Core.Compiler.compact!(ir)
Core.Compiler.verify_ir(ir2) # This would error if we were not handling nonconst globals correctly
f = Core.OpaqueClosure(ir2; do_compile=false)
@test f(3.5) == 28.0
end

@testset "runs of phi nodes" begin
function phi_run(x::Float64)
a = 2.0
b = 2.0
if (@noinline rand()) < 0 # this branch will never actually be taken
a = -100.0
b = 200.0
end
return x - a + b
end

input_ir = first(only(Base.code_ircode(phi_run, Tuple{Float64})))
ir = copy(input_ir)
#Workout where to diff to trigger error
diff_ssa = Core.SSAValue[]
for idx in 1:length(ir.stmts)
if ir.stmts[idx][:inst] isa Core.PhiNode
push!(diff_ssa, Core.SSAValue(idx))
end
end

Diffractor.forward_diff_no_inf!(ir, diff_ssa .=> 1; transform! = identity_transform!)
ir2 = Core.Compiler.compact!(ir)
Core.Compiler.verify_ir(ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158)
f = Core.OpaqueClosure(ir2; do_compile=false)
@test f(3.5) == 3.5 # this will segfault if we are not handling phi nodes correctly
end
end

0 comments on commit 022865e

Please sign in to comment.