Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: Model type propagation through exceptions #51754

Merged
merged 6 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,13 +479,13 @@ eval(Core, quote
end)

function CodeInstance(
mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const),
mi::MethodInstance, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
@nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt,
ipo_effects::UInt32, effects::UInt32, @nospecialize(analysis_results),
relocatability::UInt8)
return ccall(:jl_new_codeinst, Ref{CodeInstance},
(Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world,
(Any, Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
mi, rettype, exctype, inferred_const, inferred, const_flags, min_world, max_world,
ipo_effects, effects, analysis_results,
relocatability)
end
Expand Down
309 changes: 202 additions & 107 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

62 changes: 62 additions & 0 deletions base/compiler/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,68 @@ function Effects(effects::Effects = _EFFECTS_UNKNOWN;
nonoverlayed)
end

function is_better_effects(new::Effects, old::Effects)
any_improved = false
if new.consistent == ALWAYS_TRUE
any_improved |= old.consistent != ALWAYS_TRUE
else
if !iszero(new.consistent & CONSISTENT_IF_NOTRETURNED)
old.consistent == ALWAYS_TRUE && return false
any_improved |= iszero(old.consistent & CONSISTENT_IF_NOTRETURNED)
elseif !iszero(new.consistent & CONSISTENT_IF_INACCESSIBLEMEMONLY)
old.consistent == ALWAYS_TRUE && return false
any_improved |= iszero(old.consistent & CONSISTENT_IF_INACCESSIBLEMEMONLY)
else
return false
end
end
if new.effect_free == ALWAYS_TRUE
any_improved |= old.consistent != ALWAYS_TRUE
elseif new.effect_free == EFFECT_FREE_IF_INACCESSIBLEMEMONLY
old.effect_free == ALWAYS_TRUE && return false
any_improved |= old.effect_free != EFFECT_FREE_IF_INACCESSIBLEMEMONLY
elseif new.effect_free != old.effect_free
return false
end
if new.nothrow
any_improved |= !old.nothrow
elseif new.nothrow != old.nothrow
return false
end
if new.terminates
any_improved |= !old.terminates
elseif new.terminates != old.terminates
return false
end
if new.notaskstate
any_improved |= !old.notaskstate
elseif new.notaskstate != old.notaskstate
return false
end
if new.inaccessiblememonly == ALWAYS_TRUE
any_improved |= old.inaccessiblememonly != ALWAYS_TRUE
elseif new.inaccessiblememonly == INACCESSIBLEMEM_OR_ARGMEMONLY
old.inaccessiblememonly == ALWAYS_TRUE && return false
any_improved |= old.inaccessiblememonly != INACCESSIBLEMEM_OR_ARGMEMONLY
elseif new.inaccessiblememonly != old.inaccessiblememonly
return false
end
if new.noub == ALWAYS_TRUE
any_improved |= old.noub != ALWAYS_TRUE
elseif new.noub == NOUB_IF_NOINBOUNDS
old.noub == ALWAYS_TRUE && return false
any_improved |= old.noub != NOUB_IF_NOINBOUNDS
elseif new.noub != old.noub
return false
end
if new.nonoverlayed
any_improved |= !old.nonoverlayed
elseif new.nonoverlayed != old.nonoverlayed
return false
end
return any_improved
end

function merge_effects(old::Effects, new::Effects)
return Effects(
merge_effectbits(old.consistent, new.consistent),
Expand Down
59 changes: 37 additions & 22 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization allowed
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization allowed
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed

mutable struct TryCatchFrame
exct
const enter_idx::Int
TryCatchFrame(@nospecialize(exct), enter_idx::Int) = new(exct, enter_idx)
end

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand All @@ -218,7 +224,8 @@ mutable struct InferenceState
currbb::Int
currpc::Int
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
handler_at::Vector{Int} # current exception handler info
handlers::Vector{TryCatchFrame}
handler_at::Vector{Tuple{Int, Int}} # tuple of current (handler, exception stack) value at the pc
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
Expand All @@ -239,6 +246,7 @@ mutable struct InferenceState
unreachable::BitSet # statements that were found to be statically unreachable
valid_worlds::WorldRange
bestguess #::Type
exc_bestguess
ipo_effects::Effects

#= flags =#
Expand Down Expand Up @@ -266,7 +274,7 @@ mutable struct InferenceState

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
handler_at = compute_trycatch(code, BitSet())
handler_at, handlers = compute_trycatch(code, BitSet())
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
nstmts = length(code)
Expand Down Expand Up @@ -296,6 +304,7 @@ mutable struct InferenceState

valid_worlds = WorldRange(src.min_world, src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
bestguess = Bottom
exc_bestguess = Bottom
ipo_effects = EFFECTS_TOTAL

insert_coverage = should_insert_coverage(mod, src)
Expand All @@ -315,9 +324,9 @@ mutable struct InferenceState

return new(
linfo, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
currbb, currpc, ip, handlers, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, unreachable, valid_worlds, bestguess, ipo_effects,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
interp)
end
Expand Down Expand Up @@ -347,16 +356,19 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
empty!(ip)
ip.offset = 0 # for _bits_findnext
push!(ip, n + 1)
handler_at = fill(0, n)
handler_at = fill((0, 0), n)
handlers = TryCatchFrame[]

# start from all :enter statements and record the location of the try
for pc = 1:n
stmt = code[pc]
if isexpr(stmt, :enter)
l = stmt.args[1]::Int
handler_at[pc + 1] = pc
push!(handlers, TryCatchFrame(Bottom, pc))
handler_id = length(handlers)
handler_at[pc + 1] = (handler_id, 0)
push!(ip, pc + 1)
handler_at[l] = pc
handler_at[l] = (handler_id, handler_id)
push!(ip, l)
end
end
Expand All @@ -369,25 +381,26 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
pc´ = pc + 1 # next program-counter (after executing instruction)
delete!(ip, pc)
cur_hand = handler_at[pc]
@assert cur_hand != 0 "unbalanced try/catch"
cur_stacks = handler_at[pc]
@assert cur_stacks != (0, 0) "unbalanced try/catch"
stmt = code[pc]
if isa(stmt, GotoNode)
pc´ = stmt.label
elseif isa(stmt, GotoIfNot)
l = stmt.dest::Int
if handler_at[l] != cur_hand
@assert handler_at[l] == 0 "unbalanced try/catch"
handler_at[l] = cur_hand
if handler_at[l] != cur_stacks
@assert handler_at[l][1] == 0 || handler_at[l][1] == cur_stacks[1] "unbalanced try/catch"
handler_at[l] = cur_stacks
push!(ip, l)
end
elseif isa(stmt, ReturnNode)
@assert !isdefined(stmt, :val) "unbalanced try/catch"
@assert !isdefined(stmt, :val) || cur_stacks[1] == 0 "unbalanced try/catch"
break
elseif isa(stmt, Expr)
head = stmt.head
if head === :enter
cur_hand = pc
# Already set above
cur_stacks = (handler_at[pc´][1], cur_stacks[2])
elseif head === :leave
l = 0
for j = 1:length(stmt.args)
Expand All @@ -403,19 +416,21 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
end
l += 1
end
cur_hand = cur_stacks[1]
for i = 1:l
cur_hand = handler_at[cur_hand]
cur_hand = handler_at[handlers[cur_hand].enter_idx][1]
end
cur_hand == 0 && break
cur_stacks = (cur_hand, cur_stacks[2])
cur_stacks == (0, 0) && break
elseif head === :pop_exception
cur_stacks = (cur_stacks[1], handler_at[(stmt.args[1]::SSAValue).id][2])
cur_stacks == (0, 0) && break
end
end

pc´ > n && break # can't proceed with the fast-path fall-through
if handler_at[pc´] != cur_hand
if handler_at[pc´] != 0
@assert false "unbalanced try/catch"
end
handler_at[pc´] = cur_hand
if handler_at[pc´] != cur_stacks
handler_at[pc´] = cur_stacks
elseif !in(pc´, ip)
break # already visited
end
Expand All @@ -424,7 +439,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
end

@assert first(ip) == n + 1
return handler_at
return handler_at, handlers
end

# check if coverage mode is enabled
Expand Down
6 changes: 4 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -925,8 +925,10 @@ function run_passes_ipo_safe(
# @timeit "verify 2" verify_ir(ir)
@pass "compact 2" ir = compact!(ir)
@pass "SROA" ir = sroa_pass!(ir, sv.inlining)
@pass "ADCE" ir = adce_pass!(ir, sv.inlining)
@pass "compact 3" ir = compact!(ir, true)
@pass "ADCE" (ir, made_changes) = adce_pass!(ir, sv.inlining)
if made_changes
@pass "compact 3" ir = compact!(ir, true)
end
if JLOptions().debug_level == 2
@timeit "verify 3" (verify_ir(ir, true, false, optimizer_lattice(sv.inlining.interp)); verify_linetable(ir.linetable))
end
Expand Down
36 changes: 27 additions & 9 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ struct CFGTransformState
result_bbs::Vector{BasicBlock}
bb_rename_pred::Vector{Int}
bb_rename_succ::Vector{Int}
domtree::Union{Nothing, DomTree}
end

# N.B.: Takes ownership of the CFG array
Expand Down Expand Up @@ -622,11 +623,14 @@ function CFGTransformState!(blocks::Vector{BasicBlock}, allow_cfg_transforms::Bo
let blocks = blocks, bb_rename = bb_rename
result_bbs = BasicBlock[blocks[i] for i = 1:length(blocks) if bb_rename[i] != -1]
end
# TODO: This could be done by just renaming the domtree
domtree = construct_domtree(result_bbs)
else
bb_rename = Vector{Int}()
result_bbs = blocks
domtree = nothing
end
return CFGTransformState(allow_cfg_transforms, allow_cfg_transforms, result_bbs, bb_rename, bb_rename)
return CFGTransformState(allow_cfg_transforms, allow_cfg_transforms, result_bbs, bb_rename, bb_rename, domtree)
end

mutable struct IncrementalCompact
Expand Down Expand Up @@ -681,7 +685,7 @@ mutable struct IncrementalCompact
bb_rename = Vector{Int}()
pending_nodes = NewNodeStream()
pending_perm = Int[]
return new(code, parent.result, CFGTransformState(false, false, parent.cfg_transform.result_bbs, bb_rename, bb_rename),
return new(code, parent.result, CFGTransformState(false, false, parent.cfg_transform.result_bbs, bb_rename, bb_rename, nothing),
ssa_rename, parent.used_ssas,
parent.late_fixup, perm, 1,
parent.new_new_nodes, parent.new_new_used_ssas, pending_nodes, pending_perm,
Expand Down Expand Up @@ -942,6 +946,14 @@ function insert_node_here!(compact::IncrementalCompact, newinst::NewInstruction,
return inst
end

function delete_inst_here!(compact)
# Delete the statement, update refcounts etc
compact[SSAValue(compact.result_idx-1)] = nothing
# Pretend that we never compacted this statement in the first place
compact.result_idx -= 1
return nothing
end

function getindex(view::TypesView, v::OldSSAValue)
id = v.id
ir = view.ir.ir
Expand Down Expand Up @@ -1222,19 +1234,25 @@ end

# N.B.: from and to are non-renamed indices
function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to::Int)
# Note: We recursively kill as many edges as are obviously dead. However, this
# may leave dead loops in the IR. We kill these later in a CFG cleanup pass (or
# worstcase during codegen).
(; bb_rename_pred, bb_rename_succ, result_bbs) = compact.cfg_transform
# Note: We recursively kill as many edges as are obviously dead.
(; bb_rename_pred, bb_rename_succ, result_bbs, domtree) = compact.cfg_transform
preds = result_bbs[bb_rename_succ[to]].preds
succs = result_bbs[bb_rename_pred[from]].succs
deleteat!(preds, findfirst(x::Int->x==bb_rename_pred[from], preds)::Int)
deleteat!(succs, findfirst(x::Int->x==bb_rename_succ[to], succs)::Int)
if domtree !== nothing
domtree_delete_edge!(domtree, result_bbs, bb_rename_pred[from], bb_rename_succ[to])
end
# Check if the block is now dead
if length(preds) == 0
for succ in copy(result_bbs[bb_rename_succ[to]].succs)
kill_edge!(compact, active_bb, to, findfirst(x::Int->x==succ, bb_rename_pred)::Int)
if length(preds) == 0 || (domtree !== nothing && bb_unreachable(domtree, bb_rename_succ[to]))
to_succs = result_bbs[bb_rename_succ[to]].succs
for succ in copy(to_succs)
new_succ = findfirst(x::Int->x==succ, bb_rename_pred)
new_succ === nothing && continue
kill_edge!(compact, active_bb, to, new_succ)
end
empty!(preds)
empty!(to_succs)
if to < active_bb
# Kill all statements in the block
stmts = result_bbs[bb_rename_succ[to]].stmts
Expand Down
Loading