From 6a7e9ccd34b01421e861c5a03ddd7fe74b3b0a2e Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Thu, 17 Mar 2022 14:21:35 +0000 Subject: [PATCH] WIP: Allow inlining method matches with unmatched type parameters Currently we do not allow inlining any methods that have unmatched type parameters. The original reason for this restriction is that I didn't really know what to put for an inlined :static_parameter, so I just had inlining bail. As a result, in code like: ``` f(x) = Val{x}() ``` the call to `Val{x}()` would not be inlined unless `x` was known through constant propagation. This PR attempts to remedy that. A new builtin is added that computes the static parameters for a given method/argument list. Additionally, sroa gains the ability to simplify and fold this builtin. As a result, inlining can insert an expression that computes the correct values for the inlinee's static parameters. The change benchmarks favorably: Before: ``` julia> function foo() for i = 1:10000 Base.donotdelete(Val{i}()) end end foo (generic function with 1 method) julia> @time foo() 0.375567 seconds (4.24 M allocations: 274.440 MiB, 14.67% gc time, 72.96% compilation time) julia> @time foo() 0.012387 seconds (9.49 k allocations: 148.266 KiB) ``` After: ``` julia> function foo() for i = 1:10000 Base.donotdelete(Val{i}()) end end foo (generic function with 1 method) julia> @time foo() 0.003058 seconds (29.47 k allocations: 1.546 MiB) julia> @time foo() 0.001200 seconds (9.49 k allocations: 148.266 KiB) ``` Note that this particular benchmark could also be fixed by #44654, but this change is more general. There is a potential downside, which is that we remove a specialization barrier here. We already do that in the case when all type parameters are matched, so it's not egregious. However, there is anecdotal knowledge in the community that extra type parameters force specialization. 
Some of that is due to the handling of type parameters in the specialization code, but some of it might also be due to inlining's prior refusal to perform this inlining. We'll have to keep an eye out for any regressions. --- base/compiler/ssair/inlining.jl | 69 ++++++++++++++++----------- base/compiler/ssair/ir.jl | 79 ++++++++++++++++++++++--------- base/compiler/ssair/passes.jl | 82 +++++++++++++++++++++++++++++++++ base/essentials.jl | 8 +--- src/builtin_proto.h | 4 ++ src/builtins.c | 31 +++++++++++++ src/staticdata.c | 2 +- 7 files changed, 217 insertions(+), 58 deletions(-) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index ab93375db4d0e..650acabd8683c 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -355,11 +355,17 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector boundscheck = :off end end + if !validate_sparams(sparam_vals) + sparam_vals = insert_node_here!(compact, + effect_free(NewInstruction(Expr(:call, Core._compute_sparams, item.mi.def, argexprs...), SimpleVector, topline))) + end # If the iterator already moved on to the next basic block, # temporarily re-open in again. local return_value sig = def.sig # Special case inlining that maintains the current basic block if there's only one BB in the target + new_new_offset = length(compact.new_new_nodes) + late_fixup_offset = length(compact.late_fixup) if spec.linear_inline_eligible #compact[idx] = nothing inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) @@ -368,7 +374,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector # face of rename_arguments! mutating in place - should figure out # something better eventually. 
inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact) if isa(stmt′, ReturnNode) val = stmt′.val isa(val, SSAValue) && (compact.used_ssas[val.id] += 1) @@ -380,7 +386,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector end inline_compact[idx′] = stmt′ end - just_fixup!(inline_compact) + just_fixup!(inline_compact, new_new_offset, late_fixup_offset) compact.result_idx = inline_compact.result_idx else bb_offset, post_bb_id = popfirst!(todo_bbs) @@ -394,7 +400,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) for ((_, idx′), stmt′) in inline_compact inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, inline_compact) if isa(stmt′, ReturnNode) if isdefined(stmt′, :val) val = stmt′.val @@ -425,7 +431,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector end inline_compact[idx′] = stmt′ end - just_fixup!(inline_compact) + just_fixup!(inline_compact, new_new_offset, late_fixup_offset) compact.result_idx = inline_compact.result_idx compact.active_result_bb = inline_compact.active_result_bb for i = 1:length(pn.values) @@ -840,8 +846,7 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any}, end end - # Bail out if any static parameters are left as TypeVar - validate_sparams(match.sparams) || return nothing + #validate_sparams(match.sparams) || return nothing et = state.et @@ -1048,7 +1053,7 @@ function inline_invoke!( argtypes = invoke_rewrite(sig.argtypes) if isa(result, InferenceResult) (; mi) = item = 
InliningTodo(result, argtypes) - validate_sparams(mi.sparam_vals) || return nothing + #validate_sparams(mi.sparam_vals) || return nothing if argtypes_to_type(argtypes) <: mi.def.sig state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) handle_single_case!(ir, idx, stmt, item, todo, state.params, true) @@ -1324,7 +1329,7 @@ function handle_inf_result!( (; mi) = item = InliningTodo(result, argtypes) spec_types = mi.specTypes allow_abstract || isdispatchtuple(spec_types) || return false - validate_sparams(mi.sparam_vals) || return false + #validate_sparams(mi.sparam_vals) || return false state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) item === nothing && return false push!(cases, InliningCase(spec_types, item)) @@ -1360,7 +1365,6 @@ function handle_const_opaque_closure_call!( sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}}) item = InliningTodo(result, sig.argtypes) isdispatchtuple(item.mi.specTypes) || return - validate_sparams(item.mi.sparam_vals) || return state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) handle_single_case!(ir, idx, stmt, item, todo, state.params) return nothing @@ -1539,15 +1543,16 @@ function late_inline_special_case!( end function ssa_substitute!(idx::Int, @nospecialize(val), arg_replacements::Vector{Any}, - @nospecialize(spsig), spvals::SimpleVector, + @nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, linetable_offset::Int32, boundscheck::Symbol, compact::IncrementalCompact) compact.result[idx][:flag] &= ~IR_FLAG_INBOUNDS compact.result[idx][:line] += linetable_offset - return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck) + return ssa_substitute_op!(val, arg_replacements, spsig, spvals, boundscheck, compact, idx) end function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, - @nospecialize(spsig), spvals::SimpleVector, boundscheck::Symbol) + @nospecialize(spsig), spvals::Union{SimpleVector, SSAValue}, 
boundscheck::Symbol, + compact::IncrementalCompact, idx::Int) if isa(val, Argument) return arg_replacements[val.n] end @@ -1555,22 +1560,32 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, e = val::Expr head = e.head if head === :static_parameter - return quoted(spvals[e.args[1]::Int]) + if isa(spvals, SimpleVector) + return quoted(spvals[e.args[1]::Int]) + else + ret = insert_node!(compact, SSAValue(idx), + effect_free(NewInstruction(Expr(:call, Core._svec_ref, false, spvals, e.args[1]), Any))) + return ret + end elseif head === :cfunction - @assert !isa(spsig, UnionAll) || !isempty(spvals) - e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals) - e.args[4] = svec(Any[ - ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals) - for argt in e.args[4]::SimpleVector ]...) + if isa(spvals, SimpleVector) + @assert !isa(spsig, UnionAll) || !isempty(spvals) + e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals) + e.args[4] = svec(Any[ + ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals) + for argt in e.args[4]::SimpleVector ]...) + end elseif head === :foreigncall - @assert !isa(spsig, UnionAll) || !isempty(spvals) - for i = 1:length(e.args) - if i == 2 - e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals) - elseif i == 3 - e.args[3] = svec(Any[ - ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals) - for argt in e.args[3]::SimpleVector ]...) 
+ if isa(spvals, SimpleVector) + @assert !isa(spsig, UnionAll) || !isempty(spvals) + for i = 1:length(e.args) + if i == 2 + e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals) + elseif i == 3 + e.args[3] = svec(Any[ + ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals) + for argt in e.args[3]::SimpleVector ]...) + end end end elseif head === :boundscheck @@ -1585,7 +1600,7 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, end urs = userefs(val) for op in urs - op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck) + op[] = ssa_substitute_op!(op[], arg_replacements, spsig, spvals, boundscheck, compact, idx) end return urs[] end diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index b3f01e8a4a415..4f94f33c5e86a 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -635,8 +635,8 @@ mutable struct IncrementalCompact pending_perm = Int[] return new(code, parent.result, parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas, - late_fixup, perm, 1, - new_new_nodes, pending_nodes, pending_perm, + parent.late_fixup, perm, 1, + parent.new_new_nodes, pending_nodes, pending_perm, 1, result_offset, parent.active_result_bb, false, false, false) end end @@ -1356,8 +1356,15 @@ function maybe_erase_unused!( return false end -function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any}) +struct FixedNode + node::Any + needs_fixup::Bool + FixedNode(@nospecialize(node), needs_fixup::Bool) = new(node, needs_fixup) +end + +function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any}, reify_new_nodes::Bool) values = Vector{Any}(undef, length(old_values)) + needs_fixup = false for i = 1:length(old_values) isassigned(old_values, i) || continue val = old_values[i] @@ -1367,28 +1374,43 @@ function fixup_phinode_values!(compact::IncrementalCompact, 
old_values::Vector{A compact.used_ssas[val.id] += 1 end elseif isa(val, NewSSAValue) - val = SSAValue(length(compact.result) + val.id) + if reify_new_nodes + val = SSAValue(length(compact.result) + val.id) + else + needs_fixup = true + end end values[i] = val end - values + FixedNode(values, needs_fixup) end -function fixup_node(compact::IncrementalCompact, @nospecialize(stmt)) +function fixup_node(compact::IncrementalCompact, @nospecialize(stmt), reify_new_nodes::Bool) if isa(stmt, PhiNode) - return PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values)) + (;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes) + return FixedNode(PhiNode(stmt.edges, node), needs_fixup) elseif isa(stmt, PhiCNode) - return PhiCNode(fixup_phinode_values!(compact, stmt.values)) + (;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes) + return FixedNode(PhiCNode(node), needs_fixup) elseif isa(stmt, NewSSAValue) - return SSAValue(length(compact.result) + stmt.id) + if reify_new_nodes + return FixedNode(SSAValue(length(compact.result) + stmt.id), false) + else + return FixedNode(stmt, true) + end elseif isa(stmt, OldSSAValue) - return compact.ssa_rename[stmt.id] + return FixedNode(compact.ssa_rename[stmt.id], false) else urs = userefs(stmt) + needs_fixup = false for ur in urs val = ur[] if isa(val, NewSSAValue) - val = SSAValue(length(compact.result) + val.id) + if reify_new_nodes + val = SSAValue(length(compact.result) + val.id) + else + needs_fixup = true + end elseif isa(val, OldSSAValue) val = compact.ssa_rename[val.id] end @@ -1400,22 +1422,33 @@ function fixup_node(compact::IncrementalCompact, @nospecialize(stmt)) end ur[] = val end - return urs[] + return FixedNode(urs[], needs_fixup) end end -function just_fixup!(compact::IncrementalCompact) - for idx in compact.late_fixup +function just_fixup!(compact::IncrementalCompact, new_new_nodes_offset::Union{Int, Nothing} = nothing, late_fixup_offset::Union{Int, 
Nothing}=nothing) + off = late_fixup_offset === nothing ? 1 : (late_fixup_offset+1) + set_off = off + for i in off:length(compact.late_fixup) + idx = compact.late_fixup[i] stmt = compact.result[idx][:inst] - new_stmt = fixup_node(compact, stmt) - (stmt === new_stmt) || (compact.result[idx][:inst] = new_stmt) - end - for idx in 1:length(compact.new_new_nodes) - node = compact.new_new_nodes.stmts[idx] - stmt = node[:inst] - new_stmt = fixup_node(compact, stmt) - if new_stmt !== stmt - node[:inst] = new_stmt + (;node, needs_fixup) = fixup_node(compact, stmt, late_fixup_offset === nothing) + (stmt === node) || (compact.result[idx][:inst] = node) + if needs_fixup + compact.late_fixup[set_off] = idx + set_off += 1 + end + end + if late_fixup_offset !== nothing + resize!(compact.late_fixup, set_off-1) + end + off = new_new_nodes_offset === nothing ? 1 : (new_new_nodes_offset+1) + for idx in off:length(compact.new_new_nodes) + new_node = compact.new_new_nodes.stmts[idx] + stmt = new_node[:inst] + (;node) = fixup_node(compact, stmt, late_fixup_offset === nothing) + if node !== stmt + new_node[:inst] = node end end end diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 7aeb303bc03a2..2410ce71bb28c 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -652,6 +652,86 @@ function perform_lifting!(compact::IncrementalCompact, return stmt_val # N.B. 
should never happen end +function lift_svec_ref!(compact, idx, stmt) + if length(stmt.args) != 4 + return + end + + vec = stmt.args[3] + val = stmt.args[4] + valT = argextype(val, compact) + (isa(valT, Const) && isa(valT.val, Int)) || return + valI = valT.val + (1 <= valI) || return + + if isa(vec, SimpleVector) + if valI <= length(val) + compact[idx] = vec[valI] + end + return + end + + if isa(vec, SSAValue) + # TODO: We could do the whole lifing machinery here, but really all + # we want to do is clean this up when it got inserted by inlining, + # which always + def = compact[vec] + if is_known_call(def, Core.svec, compact) + nargs = length(def.args) + if valI <= nargs-1 + compact[idx] = def.args[valI+1] + end + return + elseif is_known_call(def, Core._compute_sparams, compact) + m = argextype(def.args[2], compact) + isa(m, Const) || return + m = m.val + isa(m, Method) || return + # For now, just pattern match the benchmark case + # TODO: More general structural analysis of the intersection + length(def.args) == 3 || return + sig = m.sig + isa(sig, UnionAll) || return + tvar = sig.var + sig = sig.body + isa(sig, DataType) || return + sig.name === Tuple.name + length(sig.parameters) == 1 || return + + arg = sig.parameters[1] + isa(arg, DataType) || return + arg.name === typename(Type) || return + arg = arg.parameters[1] + + isa(arg, DataType) || return + + rarg = def.args[3] + isa(rarg, SSAValue) || return + argdef = compact[rarg] + + is_known_call(argdef, Core.apply_type, compact) || return + length(argdef.args) == 3 || return + + applyT = argextype(argdef.args[2], compact) + isa(applyT, Const) || return + applyT = applyT.val + + isa(applyT, UnionAll) || return + applyTvar = applyT.var + applyTbody = applyT.body + + isa(applyTbody, DataType) || return + applyTbody.name == arg.name || return + length(applyTbody.parameters) == length(arg.parameters) == 1 || return + applyTbody.parameters[1] === applyTvar || return + arg.parameters[1] === tvar || return + + 
compact[idx] = argdef.args[3] + return + end + end +end + # NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining, # which can be very large sometimes, and program counters in question are often very sparse const SPCSet = IdSet{Int} @@ -746,6 +826,8 @@ function sroa_pass!(ir::IRCode) else # TODO: This isn't the best place to put these if is_known_call(stmt, typeassert, compact) canonicalize_typeassert!(compact, idx, stmt) + elseif is_known_call(stmt, Core._svec_ref, compact) + lift_svec_ref!(compact, idx, stmt) elseif is_known_call(stmt, (===), compact) lift_comparison!(===, compact, idx, stmt, lifting_cache) elseif is_known_call(stmt, isa, compact) diff --git a/base/essentials.jl b/base/essentials.jl index b837b556ed910..d8852c4c36368 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -605,13 +605,7 @@ end # SimpleVector -function getindex(v::SimpleVector, i::Int) - @boundscheck if !(1 <= i <= length(v)) - throw(BoundsError(v,i)) - end - return ccall(:jl_svec_ref, Any, (Any, Int), v, i - 1) -end - +@eval getindex(v::SimpleVector, i::Int) = Core._svec_ref($(Expr(:boundscheck)), v, i) function length(v::SimpleVector) return ccall(:jl_svec_len, Int, (Any,), v) end diff --git a/src/builtin_proto.h b/src/builtin_proto.h index 46adef8444aa9..eb23ef09bee3a 100644 --- a/src/builtin_proto.h +++ b/src/builtin_proto.h @@ -56,6 +56,8 @@ DECLARE_BUILTIN(typeof); DECLARE_BUILTIN(_typevar); DECLARE_BUILTIN(donotdelete); DECLARE_BUILTIN(getglobal); +DECLARE_BUILTIN(_compute_sparams); +DECLARE_BUILTIN(_svec_ref); JL_CALLABLE(jl_f_invoke_kwsorter); #ifdef DEFINE_BUILTIN_GLOBALS @@ -72,6 +74,8 @@ JL_CALLABLE(jl_f_get_binding_type); JL_CALLABLE(jl_f_set_binding_type); JL_CALLABLE(jl_f_donotdelete); JL_CALLABLE(jl_f_setglobal); +JL_CALLABLE(jl_f__compute_sparams); +JL_CALLABLE(jl_f__svec_ref); #ifdef __cplusplus } diff --git a/src/builtins.c b/src/builtins.c index f81069424d784..7669098f94d61 100644 --- a/src/builtins.c +++ 
b/src/builtins.c @@ -1586,6 +1586,35 @@ JL_CALLABLE(jl_f_donotdelete) return jl_nothing; } +JL_CALLABLE(jl_f__compute_sparams) +{ + JL_NARGSV(_compute_sparams, 1); + jl_method_t *m = (jl_method_t*)args[0]; + JL_TYPECHK(_compute_sparams, method, (jl_value_t*)m); + jl_datatype_t *tt = jl_inst_arg_tuple_type(args[1], &args[2], nargs-1, 1); + jl_svec_t *env = jl_emptysvec; + JL_GC_PUSH2(&env, &tt); + jl_type_intersection_env((jl_value_t*)tt, m->sig, &env); + JL_GC_POP(); + return (jl_value_t*)env; +} + +JL_CALLABLE(jl_f__svec_ref) +{ + JL_NARGS(_svec_ref, 3, 3); + jl_value_t *b = args[0]; + jl_svec_t *s = (jl_svec_t*)args[1]; + jl_value_t *i = (jl_value_t*)args[2]; + JL_TYPECHK(_svec_ref, bool, b); + JL_TYPECHK(_svec_ref, simplevector, (jl_value_t*)s); + JL_TYPECHK(_svec_ref, long, i); + ssize_t idx = jl_unbox_long(i); + if (idx < 1 || idx > jl_svec_len(s)) { + jl_bounds_error_int(s, i); + } + return jl_svec_ref(s, jl_unbox_long(i)-1); +} + static int equiv_field_types(jl_value_t *old, jl_value_t *ft) { size_t nf = jl_svec_len(ft); @@ -1956,6 +1985,8 @@ void jl_init_primitives(void) JL_GC_DISABLED jl_builtin__typebody = add_builtin_func("_typebody!", jl_f__typebody); add_builtin_func("_equiv_typedef", jl_f__equiv_typedef); jl_builtin_donotdelete = add_builtin_func("donotdelete", jl_f_donotdelete); + add_builtin_func("_compute_sparams", jl_f__compute_sparams); + add_builtin_func("_svec_ref", jl_f__svec_ref); // builtin types add_builtin("Any", (jl_value_t*)jl_any_type); diff --git a/src/staticdata.c b/src/staticdata.c index e72f29257ce19..0d5ccb00b0d79 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -311,7 +311,7 @@ static const jl_fptr_args_t id_to_fptrs[] = { &jl_f_ifelse, &jl_f__structtype, &jl_f__abstracttype, &jl_f__primitivetype, &jl_f__typebody, &jl_f__setsuper, &jl_f__equiv_typedef, &jl_f_get_binding_type, &jl_f_set_binding_type, &jl_f_opaque_closure_call, &jl_f_donotdelete, - &jl_f_getglobal, &jl_f_setglobal, + &jl_f_getglobal, &jl_f_setglobal, 
&jl_f__compute_sparams, &jl_f__svec_ref, NULL }; typedef struct {