Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Allow inlining method matches with unmatched type parameters #44656

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 42 additions & 27 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,17 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
boundscheck = :off
end
end
if !validate_sparams(sparam_vals)
sparam_vals = insert_node_here!(compact,
effect_free(NewInstruction(Expr(:call, Core._compute_sparams, item.mi.def, argexprs...), SimpleVector, topline)))
end
# If the iterator already moved on to the next basic block,
# temporarily re-open in again.
local return_value
sig = def.sig
# Special case inlining that maintains the current basic block if there's only one BB in the target
new_new_offset = length(compact.new_new_nodes)
late_fixup_offset = length(compact.late_fixup)
if spec.linear_inline_eligible
#compact[idx] = nothing
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
Expand All @@ -368,7 +374,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
# face of rename_arguments! mutating in place - should figure out
# something better eventually.
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
if isa(stmt′, ReturnNode)
val = stmt′.val
isa(val, SSAValue) && (compact.used_ssas[val.id] += 1)
Expand All @@ -380,7 +386,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
end
inline_compact[idx′] = stmt′
end
just_fixup!(inline_compact)
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
compact.result_idx = inline_compact.result_idx
else
bb_offset, post_bb_id = popfirst!(todo_bbs)
Expand All @@ -394,7 +400,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
for ((_, idx′), stmt′) in inline_compact
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact)
if isa(stmt′, ReturnNode)
if isdefined(stmt′, :val)
val = stmt′.val
Expand Down Expand Up @@ -425,7 +431,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
end
inline_compact[idx′] = stmt′
end
just_fixup!(inline_compact)
just_fixup!(inline_compact, new_new_offset, late_fixup_offset)
compact.result_idx = inline_compact.result_idx
compact.active_result_bb = inline_compact.active_result_bb
for i = 1:length(pn.values)
Expand Down Expand Up @@ -840,8 +846,7 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
end
end

# Bail out if any static parameters are left as TypeVar
validate_sparams(match.sparams) || return nothing
#validate_sparams(match.sparams) || return nothing

et = state.et

Expand Down Expand Up @@ -1048,7 +1053,7 @@ function inline_invoke!(
argtypes = invoke_rewrite(sig.argtypes)
if isa(result, InferenceResult)
(; mi) = item = InliningTodo(result, argtypes)
validate_sparams(mi.sparam_vals) || return nothing
#validate_sparams(mi.sparam_vals) || return nothing
if argtypes_to_type(argtypes) <: mi.def.sig
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
handle_single_case!(ir, idx, stmt, item, todo, state.params, true)
Expand Down Expand Up @@ -1324,7 +1329,7 @@ function handle_inf_result!(
(; mi) = item = InliningTodo(result, argtypes)
spec_types = mi.specTypes
allow_abstract || isdispatchtuple(spec_types) || return false
validate_sparams(mi.sparam_vals) || return false
#validate_sparams(mi.sparam_vals) || return false
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
Expand Down Expand Up @@ -1360,7 +1365,6 @@ function handle_const_opaque_closure_call!(
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
item = InliningTodo(result, sig.argtypes)
isdispatchtuple(item.mi.specTypes) || return
validate_sparams(item.mi.sparam_vals) || return
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
handle_single_case!(ir, idx, stmt, item, todo, state.params)
return nothing
Expand Down Expand Up @@ -1539,38 +1543,49 @@ function late_inline_special_case!(
end

function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{Any},
@nospecialize(spsig), spvals::SimpleVector,
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue},
linetable_offset::Int32, boundscheck::Symbol, compact::IncrementalCompact)
compact.result[idx][:flag] &= ~IR_FLAG_INBOUNDS
compact.result[idx][:line] += linetable_offset
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck)
return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck, compact, idx)
end

function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
@nospecialize(spsig), spvals::SimpleVector, boundscheck::Symbol)
@nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, boundscheck::Symbol,
compact::IncrementalCompact, idx::Int)
if isa(val, Argument)
return arg_replacements[val.n]
end
if isa(val, Expr)
e = val::Expr
head = e.head
if head === :static_parameter
return quoted(spvals[e.args[1]::Int])
if isa(spvals, SimpleVector)
return quoted(spvals[e.args[1]::Int])
else
ret = insert_node!(compact, SSAValue(idx),
effect_free(NewInstruction(Expr(:call, Core._svec_ref, false, spvals, e.args[1]), Any)))
return ret
end
elseif head === :cfunction
@assert !isa(spsig, UnionAll) || !isempty(spvals)
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
e.args[4] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[4]::SimpleVector ]...)
if isa(spvals, SimpleVector)
@assert !isa(spsig, UnionAll) || !isempty(spvals)
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
e.args[4] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[4]::SimpleVector ]...)
end
elseif head === :foreigncall
@assert !isa(spsig, UnionAll) || !isempty(spvals)
for i = 1:length(e.args)
if i == 2
e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals)
elseif i == 3
e.args[3] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[3]::SimpleVector ]...)
if isa(spvals, SimpleVector)
@assert !isa(spsig, UnionAll) || !isempty(spvals)
for i = 1:length(e.args)
if i == 2
e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals)
elseif i == 3
e.args[3] = svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
for argt in e.args[3]::SimpleVector ]...)
end
end
end
elseif head === :boundscheck
Expand All @@ -1585,7 +1600,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any},
end
urs = userefs(val)
for op in urs
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck)
op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck, compact, idx)
end
return urs[]
end
79 changes: 56 additions & 23 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,8 @@ mutable struct IncrementalCompact
pending_perm = Int[]
return new(code, parent.result,
parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas,
late_fixup, perm, 1,
new_new_nodes, pending_nodes, pending_perm,
parent.late_fixup, perm, 1,
parent.new_new_nodes, pending_nodes, pending_perm,
1, result_offset, parent.active_result_bb, false, false, false)
end
end
Expand Down Expand Up @@ -1356,8 +1356,15 @@ function maybe_erase_unused!(
return false
end

function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any})
struct FixedNode
node::Any
needs_fixup::Bool
FixedNode(@nospecialize(node), needs_fixup::Bool) = new(node, needs_fixup)
end

function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any}, reify_new_nodes::Bool)
values = Vector{Any}(undef, length(old_values))
needs_fixup = false
for i = 1:length(old_values)
isassigned(old_values, i) || continue
val = old_values[i]
Expand All @@ -1367,28 +1374,43 @@ function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{A
compact.used_ssas[val.id] += 1
end
elseif isa(val, NewSSAValue)
val = SSAValue(length(compact.result) + val.id)
if reify_new_nodes
val = SSAValue(length(compact.result) + val.id)
else
needs_fixup = true
end
end
values[i] = val
end
values
FixedNode(values, needs_fixup)
end

function fixup_node(compact::IncrementalCompact, @nospecialize(stmt))
function fixup_node(compact::IncrementalCompact, @nospecialize(stmt), reify_new_nodes::Bool)
if isa(stmt, PhiNode)
return PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values))
(;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
return FixedNode(PhiNode(stmt.edges, node), needs_fixup)
elseif isa(stmt, PhiCNode)
return PhiCNode(fixup_phinode_values!(compact, stmt.values))
(;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
return FixedNode(PhiCNode(node), needs_fixup)
elseif isa(stmt, NewSSAValue)
return SSAValue(length(compact.result) + stmt.id)
if reify_new_nodes
return FixedNode(SSAValue(length(compact.result) + stmt.id), false)
else
return FixedNode(stmt, true)
end
elseif isa(stmt, OldSSAValue)
return compact.ssa_rename[stmt.id]
return FixedNode(compact.ssa_rename[stmt.id], false)
else
urs = userefs(stmt)
needs_fixup = false
for ur in urs
val = ur[]
if isa(val, NewSSAValue)
val = SSAValue(length(compact.result) + val.id)
if reify_new_nodes
val = SSAValue(length(compact.result) + val.id)
else
needs_fixup = true
end
elseif isa(val, OldSSAValue)
val = compact.ssa_rename[val.id]
end
Expand All @@ -1400,22 +1422,33 @@ function fixup_node(compact::IncrementalCompact, @nospecialize(stmt))
end
ur[] = val
end
return urs[]
return FixedNode(urs[], needs_fixup)
end
end

function just_fixup!(compact::IncrementalCompact)
for idx in compact.late_fixup
function just_fixup!(compact::IncrementalCompact, new_new_nodes_offset::Union{Int, Nothing} = nothing, late_fixup_offset::Union{Int, Nothing}=nothing)
off = late_fixup_offset === nothing ? 1 : (late_fixup_offset+1)
set_off = off
for i in off:length(compact.late_fixup)
idx = compact.late_fixup[i]
stmt = compact.result[idx][:inst]
new_stmt = fixup_node(compact, stmt)
(stmt === new_stmt) || (compact.result[idx][:inst] = new_stmt)
end
for idx in 1:length(compact.new_new_nodes)
node = compact.new_new_nodes.stmts[idx]
stmt = node[:inst]
new_stmt = fixup_node(compact, stmt)
if new_stmt !== stmt
node[:inst] = new_stmt
(;node, needs_fixup) = fixup_node(compact, stmt, late_fixup_offset === nothing)
(stmt === node) || (compact.result[idx][:inst] = node)
if needs_fixup
compact.late_fixup[set_off] = idx
set_off += 1
end
end
if late_fixup_offset !== nothing
resize!(compact.late_fixup, set_off-1)
end
off = new_new_nodes_offset === nothing ? 1 : (new_new_nodes_offset+1)
for idx in off:length(compact.new_new_nodes)
new_node = compact.new_new_nodes.stmts[idx]
stmt = new_node[:inst]
(;node) = fixup_node(compact, stmt, late_fixup_offset === nothing)
if node !== stmt
new_node[:inst] = node
end
end
end
Expand Down
82 changes: 82 additions & 0 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,86 @@ function perform_lifting!(compact::IncrementalCompact,
return stmt_val # N.B. should never happen
end

function lift_svec_ref!(compact, idx, stmt)
if length(stmt.args) != 4
return
end

vec = stmt.args[3]
val = stmt.args[4]
valT = argextype(val, compact)
(isa(valT, Const) && isa(valT.val, Int)) || return
valI = valT.val
(1 <= valI) || return

if isa(vec, SimpleVector)
if valI <= length(val)
compact[idx] = vec[valI]
end
return
end

if isa(vec, SSAValue)
# TODO: We could do the whole lifing machinery here, but really all
# we want to do is clean this up when it got inserted by inlining,
# which always
def = compact[vec]
if is_known_call(def, Core.svec, compact)
nargs = length(def.args)
if valI <= nargs-1
compact[idx] = def.args[valI+1]
end
return
elseif is_known_call(def, Core._compute_sparams, compact)
m = argextype(def.args[2], compact)
isa(m, Const) || return
m = m.val
isa(m, Method) || return
# For now, just pattern match the benchmark case
# TODO: More general structural analysis of the intersection
length(def.args) == 3 || return
sig = m.sig
isa(sig, UnionAll) || return
tvar = sig.var
sig = sig.body
isa(sig, DataType) || return
sig.name === Tuple.name
length(sig.parameters) == 1 || return

arg = sig.parameters[1]
isa(arg, DataType) || return
arg.name === typename(Type) || return
arg = arg.parameters[1]

isa(arg, DataType) || return

rarg = def.args[3]
isa(rarg, SSAValue) || return
argdef = compact[rarg]

is_known_call(argdef, Core.apply_type, compact) || return
length(argdef.args) == 3 || return

applyT = argextype(argdef.args[2], compact)
isa(applyT, Const) || return
applyT = applyT.val

isa(applyT, UnionAll) || return
applyTvar = applyT.var
applyTbody = applyT.body

isa(applyTbody, DataType) || return
applyTbody.name == arg.name || return
length(applyTbody.parameters) == length(arg.parameters) == 1 || return
applyTbody.parameters[1] === applyTvar || return
arg.parameters[1] === tvar || return

compact[idx] = argdef.args[3]
return
end
end
end

# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}
Expand Down Expand Up @@ -746,6 +826,8 @@ function sroa_pass!(ir::IRCode)
else # TODO: This isn't the best place to put these
if is_known_call(stmt, typeassert, compact)
canonicalize_typeassert!(compact, idx, stmt)
elseif is_known_call(stmt, Core._svec_ref, compact)
lift_svec_ref!(compact, idx, stmt)
elseif is_known_call(stmt, (===), compact)
lift_comparison!(===, compact, idx, stmt, lifting_cache)
elseif is_known_call(stmt, isa, compact)
Expand Down
Loading