diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index b88b426aaa4549..1dec4763c4ea48 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -52,6 +52,7 @@ end function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), arginfo::ArgInfo, @nospecialize(atype), sv::InferenceState, max_methods::Int) + ⊑ᵢ = ⊑(typeinf_lattice(interp)) if !should_infer_this_call(sv) add_remark!(interp, sv, "Skipped call in throw block") nonoverlayed = false @@ -132,7 +133,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), f, this_arginfo, match, sv) const_result = nothing if const_call_result !== nothing - if const_call_result.rt ⊑ rt + if const_call_result.rt ⊑ᵢ rt rt = const_call_result.rt (; effects, const_result) = const_call_result end @@ -179,7 +180,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), this_const_rt = widenwrappedconditional(const_call_result.rt) # return type of const-prop' inference can be wider than that of non const-prop' inference # e.g. in cases when there are cycles but cached result is still accurate - if this_const_rt ⊑ this_rt + if this_const_rt ⊑ᵢ this_rt this_conditional = this_const_conditional this_rt = this_const_rt (; effects, const_result) = const_call_result @@ -191,8 +192,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end @assert !(this_conditional isa Conditional) "invalid lattice element returned from inter-procedural context" seen += 1 - rettype = tmerge(rettype, this_rt) - if this_conditional !== Bottom && is_lattice_bool(rettype) && fargs !== nothing + rettype = tmerge(ipo_lattice(interp), rettype, this_rt) + if this_conditional !== Bottom && is_lattice_bool(ipo_lattice(interp), rettype) && fargs !== nothing if conditionals === nothing conditionals = Any[Bottom for _ in 1:length(argtypes)], Any[Bottom for _ in 1:length(argtypes)] @@ -223,7 +224,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), all_effects = Effects(all_effects; nothrow=false) end - rettype = from_interprocedural!(rettype, sv, arginfo, conditionals) + rettype = from_interprocedural!(ipo_lattice(interp), rettype, sv, arginfo, conditionals) if call_result_unused(sv) && !(rettype === Bottom) add_remark!(interp, sv, "Call result type was widened because the return value is unused") @@ -346,7 +347,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth end """ - from_interprocedural!(rt, sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> newrt + from_interprocedural!(ipo_lattice::AbstractLattice, rt, sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> newrt Converts inter-procedural return type `rt` into a local lattice element `newrt`, that is appropriate in the context of current local analysis frame `sv`, especially: @@ -365,13 +366,13 @@ In such cases `maybecondinfo` should be either of: When we deal with multiple `MethodMatch`es, it's better to precompute `maybecondinfo` by `tmerge`ing argument signature type of each method call. """ -function from_interprocedural!(@nospecialize(rt), sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo)) +function from_interprocedural!(ipo_lattice::AbstractLattice, @nospecialize(rt), sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo)) rt = collect_limitations!(rt, sv) - if is_lattice_bool(rt) + if is_lattice_bool(ipo_lattice, rt) if maybecondinfo === nothing rt = widenconditional(rt) else - rt = from_interconditional(rt, sv, arginfo, maybecondinfo) + rt = from_interconditional(ipo_lattice, rt, sv, arginfo, maybecondinfo) end end @assert !(rt isa InterConditional) "invalid lattice element returned from inter-procedural context" @@ -386,7 +387,9 @@ function collect_limitations!(@nospecialize(typ), sv::InferenceState) return typ end -function from_interconditional(@nospecialize(typ), sv::InferenceState, (; fargs, argtypes)::ArgInfo, @nospecialize(maybecondinfo)) +function from_interconditional(ipo_lattice::AbstractLattice, @nospecialize(typ), + sv::InferenceState, (; fargs, argtypes)::ArgInfo, @nospecialize(maybecondinfo)) + lattice = widen(ipo_lattice) fargs === nothing && return widenconditional(typ) slot = 0 thentype = elsetype = Any @@ -412,21 +415,21 @@ function from_interconditional(@nospecialize(typ), sv::InferenceState, (; fargs, end if condval === false thentype = Bottom - elseif new_thentype ⊑ thentype + elseif ⊑(lattice, new_thentype, thentype) thentype = new_thentype else - thentype = tmeet(thentype, widenconst(new_thentype)) + thentype = tmeet(lattice, thentype, widenconst(new_thentype)) end if condval === true elsetype = Bottom - elseif new_elsetype ⊑ elsetype + elseif ⊑(lattice, new_elsetype, elsetype) elsetype = new_elsetype else - elsetype = tmeet(elsetype, widenconst(new_elsetype)) + elsetype = tmeet(lattice, elsetype, widenconst(new_elsetype)) end - if (slot > 0 || condval !== false) && thentype ⋤ old + if (slot > 0 || condval !== false) && ⋤(lattice, thentype, old) slot = id - elseif (slot > 0 || condval !== true) && elsetype ⋤ old + elseif (slot > 0 || condval !== true) && ⋤(lattice, elsetype, old) slot = id else # reset: no new useful information for this slot thentype = elsetype = Any @@ -1038,8 +1041,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method: end function const_prop_function_heuristic( - _::AbstractInterpreter, @nospecialize(f), (; argtypes)::ArgInfo, + interp::AbstractInterpreter, @nospecialize(f), (; argtypes)::ArgInfo, nargs::Int, all_overridden::Bool, still_nothrow::Bool, _::InferenceState) + ⊑ᵢ = ⊑(typeinf_lattice(interp)) if nargs > 1 if istopfunction(f, :getindex) || istopfunction(f, :setindex!) arrty = argtypes[2] @@ -1050,12 +1054,12 @@ function const_prop_function_heuristic( if !still_nothrow || ismutabletype(arrty) return false end - elseif arrty ⊑ Array + elseif arrty ⊑ᵢ Array return false end elseif istopfunction(f, :iterate) itrty = argtypes[2] - if itrty ⊑ Array + if itrty ⊑ᵢ Array return false end end @@ -1283,7 +1287,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n nstatetype = getfield_tfunc(stateordonet, Const(2)) # If there's no new information in this statetype, don't bother continuing, # the iterator won't be finite. - if nstatetype ⊑ statetype + if ⊑(typeinf_lattice(interp), nstatetype, statetype) return Any[Bottom], nothing end valtype = getfield_tfunc(stateordonet, Const(1)) @@ -1453,6 +1457,8 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs sv::InferenceState, max_methods::Int) @nospecialize f la = length(argtypes) + lattice = typeinf_lattice(interp) + ⊑ᵢ = ⊑(lattice) if f === Core.ifelse && fargs isa Vector{Any} && la == 4 cnd = argtypes[2] if isa(cnd, Conditional) @@ -1467,12 +1473,12 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs a = ssa_def_slot(fargs[3], sv) b = ssa_def_slot(fargs[4], sv) if isa(a, SlotNumber) && cnd.slot == slot_id(a) - tx = (cnd.thentype ⊑ tx ? cnd.thentype : tmeet(tx, widenconst(cnd.thentype))) + tx = (cnd.thentype ⊑ᵢ tx ? cnd.thentype : tmeet(lattice, tx, widenconst(cnd.thentype))) end if isa(b, SlotNumber) && cnd.slot == slot_id(b) - ty = (cnd.elsetype ⊑ ty ? cnd.elsetype : tmeet(ty, widenconst(cnd.elsetype))) + ty = (cnd.elsetype ⊑ᵢ ty ? cnd.elsetype : tmeet(lattice, ty, widenconst(cnd.elsetype))) end - return tmerge(tx, ty) + return tmerge(lattice, tx, ty) end end end @@ -1652,12 +1658,12 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn overlayed ? nothing : singleton_type(ft′), arginfo, match, sv) const_result = nothing if const_call_result !== nothing - if const_call_result.rt ⊑ rt + if ⊑(typeinf_lattice(interp), const_call_result.rt, rt) (; rt, effects, const_result) = const_call_result end end effects = Effects(effects; nonoverlayed=!overlayed) - return CallMeta(from_interprocedural!(rt, sv, arginfo, sig), effects, InvokeCallInfo(match, const_result)) + return CallMeta(from_interprocedural!(ipo_lattice(interp), rt, sv, arginfo, sig), effects, InvokeCallInfo(match, const_result)) end function invoke_rewrite(xs::Vector{Any}) @@ -1805,15 +1811,17 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, end end info = OpaqueClosureCallInfo(match, const_result) + ipo = ipo_lattice(interp) + ⊑ₚ = ⊑(ipo) if check # analyze implicit type asserts on argument and return type ftt = closure.typ (aty, rty) = (unwrap_unionall(ftt)::DataType).parameters rty = rewrap_unionall(rty isa TypeVar ? rty.lb : rty, ftt) - if !(rt ⊑ rty && tuple_tfunc(arginfo.argtypes[2:end]) ⊑ rewrap_unionall(aty, ftt)) + if !(rt ⊑ₚ rty && tuple_tfunc(arginfo.argtypes[2:end]) ⊑ₚ rewrap_unionall(aty, ftt)) effects = Effects(effects; nothrow=false) end end - rt = from_interprocedural!(rt, sv, arginfo, match.spec_types) + rt = from_interprocedural!(ipo, rt, sv, arginfo, match.spec_types) return CallMeta(rt, effects, info) end @@ -1969,6 +1977,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), end e = e::Expr ehead = e.head + ⊑ᵢ = ⊑(typeinf_lattice(interp)) if ehead === :call ea = e.args argtypes = collect_argtypes(interp, ea, vtypes, sv) @@ -1994,8 +2003,8 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), for i = 1:nargs at = widenconditional(abstract_eval_value(interp, e.args[i+1], vtypes, sv)) ft = fieldtype(t, i) - nothrow && (nothrow = at ⊑ ft) - at = tmeet(at, ft) + nothrow && (nothrow = at ⊑ᵢ ft) + at = tmeet(typeinf_lattice(interp), at, ft) at === Bottom && @goto always_throw if ismutable && !isconst(t, i) ats[i] = ft # can't constrain this field (as it may be modified later) @@ -2048,8 +2057,8 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), let t = t, at = at; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end nothrow = isexact && isconcretedispatch(t) t = Const(ccall(:jl_new_structt, Any, (Any, Any), t, at.val)) - elseif isa(at, PartialStruct) && at ⊑ Tuple && n == length(at.fields::Vector{Any}) && - let t = t, at = at; _all(i->(at.fields::Vector{Any})[i] ⊑ fieldtype(t, i), 1:n); end + elseif isa(at, PartialStruct) && at ⊑ᵢ Tuple && n == length(at.fields::Vector{Any}) && + let t = t, at = at; _all(i->(at.fields::Vector{Any})[i] ⊑ᵢ fieldtype(t, i), 1:n); end nothrow = isexact && isconcretedispatch(t) t = PartialStruct(t, at.fields::Vector{Any}) end @@ -2215,8 +2224,11 @@ function abstract_eval_ssavalue(s::SSAValue, ssavaluetypes::Vector{Any}) return typ end -function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nargs::Int, slottypes::Vector{Any}, changes::VarTable) - if !(bestguess ⊑ Bool) || bestguess === Bool +function widenreturn(ipo_lattice::AbstractLattice, @nospecialize(rt), @nospecialize(bestguess), nargs::Int, slottypes::Vector{Any}, changes::VarTable) + ⊑ₚ = ⊑(ipo_lattice) + inner_lattice = widen(ipo_lattice) + ⊑ᵢ = ⊑(inner_lattice) + if !(bestguess ⊑ₚ Bool) || bestguess === Bool # give up inter-procedural constraint back-propagation # when tmerge would widen the result anyways (as an optimization) rt = widenconditional(rt) @@ -2225,8 +2237,8 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nargs::Int, sl id = rt.slot if 1 ≤ id ≤ nargs old_id_type = widenconditional(slottypes[id]) # same as `(states[1]::VarTable)[id].typ` - if (!(rt.thentype ⊑ old_id_type) || old_id_type ⊑ rt.thentype) && - (!(rt.elsetype ⊑ old_id_type) || old_id_type ⊑ rt.elsetype) + if (!(rt.thentype ⊑ᵢ old_id_type) || old_id_type ⊑ᵢ rt.thentype) && + (!(rt.elsetype ⊑ᵢ old_id_type) || old_id_type ⊑ᵢ rt.elsetype) # discard this `Conditional` since it imposes # no new constraint on the argument type # (the caller will recreate it if needed) @@ -2241,7 +2253,7 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nargs::Int, sl end if isa(rt, Conditional) rt = InterConditional(rt.slot, rt.thentype, rt.elsetype) - elseif is_lattice_bool(rt) + elseif is_lattice_bool(inner_lattice, rt) if isa(bestguess, InterConditional) # if the bestguess so far is already `Conditional`, try to convert # this `rt` into `Conditional` on the slot to avoid overapproximation @@ -2263,10 +2275,10 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nargs::Int, sl # and is valid and good inter-procedurally isa(rt, Conditional) && return InterConditional(rt) isa(rt, InterConditional) && return rt - return widenreturn_noconditional(rt) + return widenreturn_noconditional(widen(ipo_lattice), rt) end -function widenreturn_noconditional(@nospecialize(rt)) +function widenreturn_noconditional(inner_lattice::AbstractLattice, @nospecialize(rt)) isa(rt, Const) && return rt isa(rt, Type) && return rt if isa(rt, PartialStruct) @@ -2274,11 +2286,11 @@ function widenreturn_noconditional(@nospecialize(rt)) local anyrefine = false for i in 1:length(fields) a = fields[i] - a = isvarargtype(a) ? a : widenreturn_noconditional(a) + a = isvarargtype(a) ? a : widenreturn_noconditional(inner_lattice, a) if !anyrefine # TODO: consider adding && const_prop_profitable(a) here? anyrefine = has_const_info(a) || - a ⊏ fieldtype(rt.typ, i) + ⊏(inner_lattice, a, fieldtype(rt.typ, i)) end fields[i] = a end @@ -2350,7 +2362,7 @@ end end end -function update_bbstate!(frame::InferenceState, bb::Int, vartable::VarTable) +function update_bbstate!(lattice::AbstractLattice, frame::InferenceState, bb::Int, vartable::VarTable) bbtable = frame.bb_vartables[bb] if bbtable === nothing # if a basic block hasn't been analyzed yet, @@ -2358,7 +2370,7 @@ function update_bbstate!(frame::InferenceState, bb::Int, vartable::VarTable) frame.bb_vartables[bb] = copy(vartable) return true else - return stupdate!(bbtable, vartable) + return stupdate!(lattice, bbtable, vartable) end end @@ -2455,13 +2467,13 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) else false_vartable = currstate end - changed = update_bbstate!(frame, falsebb, false_vartable) + changed = update_bbstate!(typeinf_lattice(interp), frame, falsebb, false_vartable) then_change = conditional_change(currstate, condt.thentype, condt.slot) if then_change !== nothing stoverwrite1!(currstate, then_change) end else - changed = update_bbstate!(frame, falsebb, currstate) + changed = update_bbstate!(typeinf_lattice(interp), frame, falsebb, currstate) end if changed handle_control_backedge!(frame, currpc, stmt.dest) @@ -2473,7 +2485,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) elseif isa(stmt, ReturnNode) bestguess = frame.bestguess rt = abstract_eval_value(interp, stmt.val, currstate, frame) - rt = widenreturn(rt, bestguess, nargs, slottypes, currstate) + rt = widenreturn(ipo_lattice(interp), rt, bestguess, nargs, slottypes, currstate) # narrow representation of bestguess slightly to prepare for tmerge with rt if rt isa InterConditional && bestguess isa Const let slot_id = rt.slot @@ -2493,9 +2505,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) if !isempty(frame.limitations) rt = LimitedAccuracy(rt, copy(frame.limitations)) end - if tchanged(rt, bestguess) + if tchanged(ipo_lattice(interp), rt, bestguess) # new (wider) return type for frame - bestguess = tmerge(bestguess, rt) + bestguess = tmerge(ipo_lattice(interp), bestguess, rt) # TODO: if bestguess isa InterConditional && !interesting(bestguess); bestguess = widenconditional(bestguess); end frame.bestguess = bestguess for (caller, caller_pc) in frame.cycle_backedges @@ -2511,7 +2523,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) # Propagate entry info to exception handler l = stmt.args[1]::Int catchbb = block_for_inst(frame.cfg, l) - if update_bbstate!(frame, catchbb, currstate) + if update_bbstate!(typeinf_lattice(interp), frame, catchbb, currstate) push!(W, catchbb) end ssavaluetypes[currpc] = Any @@ -2536,7 +2548,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) # propagate new type info to exception handler # the handling for Expr(:enter) propagates all changes from before the try/catch # so this only needs to propagate any changes - if stupdate1!(states[exceptbb]::VarTable, changes) + if stupdate1!(typeinf_lattice(interp), states[exceptbb]::VarTable, changes) push!(W, exceptbb) end cur_hand = frame.handler_at[cur_hand] @@ -2561,7 +2573,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) # Case 2: Directly branch to a different BB begin @label branch - if update_bbstate!(frame, nextbb, currstate) + if update_bbstate!(typeinf_lattice(interp), frame, nextbb, currstate) push!(W, nextbb) end end diff --git a/base/compiler/abstractlattice.jl b/base/compiler/abstractlattice.jl new file mode 100644 index 00000000000000..613e7a9202eb91 --- /dev/null +++ b/base/compiler/abstractlattice.jl @@ -0,0 +1,163 @@ +abstract type AbstractLattice; end +function widen end + +""" + struct JLTypeLattice + +A singleton type representing the lattice of Julia types, without any inference +extensions. +""" +struct JLTypeLattice <: AbstractLattice; end +widen(::JLTypeLattice) = error("Type lattice is the least-precise lattice available") +is_valid_lattice(::JLTypeLattice, @nospecialize(elem)) = isa(elem, Type) + +""" + struct ConstsLattice + +A lattice extending `JLTypeLattice` and adjoining `Const` and `PartialTypeVar`. +""" +struct ConstsLattice <: AbstractLattice; end +widen(::ConstsLattice) = JLTypeLattice() +is_valid_lattice(lattice::ConstsLattice, @nospecialize(elem)) = + is_valid_lattice(widen(lattice), elem) || isa(elem, Const) || isa(elem, PartialTypeVar) + +""" + struct PartialsLattice{L} + +A lattice extending lattice `L` and adjoining `PartialStruct` and `PartialOpaque`. +""" +struct PartialsLattice{L <: AbstractLattice} <: AbstractLattice + parent::L +end +widen(L::PartialsLattice) = L.parent +is_valid_lattice(lattice::PartialsLattice, @nospecialize(elem)) = + is_valid_lattice(widen(lattice), elem) || + isa(elem, PartialStruct) || isa(elem, PartialOpaque) + +""" + struct ConditionalsLattice{L} + +A lattice extending lattice `L` and adjoining `Conditional`. +""" +struct ConditionalsLattice{L <: AbstractLattice} <: AbstractLattice + parent::L +end +widen(L::ConditionalsLattice) = L.parent +is_valid_lattice(lattice::ConditionalsLattice, @nospecialize(elem)) = + is_valid_lattice(widen(lattice), elem) || isa(elem, Conditional) + +struct InterConditionalsLattice{L <: AbstractLattice} <: AbstractLattice + parent::L +end +widen(L::InterConditionalsLattice) = L.parent +is_valid_lattice(lattice::InterConditionalsLattice, @nospecialize(elem)) = + is_valid_lattice(widen(lattice), elem) || isa(elem, InterConditional) + +const AnyConditionalsLattice{L} = Union{ConditionalsLattice{L}, InterConditionalsLattice{L}} +const BaseInferenceLattice = typeof(ConditionalsLattice(PartialsLattice(ConstsLattice()))) +const IPOResultLattice = typeof(InterConditionalsLattice(PartialsLattice(ConstsLattice()))) + +""" + struct OptimizerLattice + +The lattice used by the optimizer. Extends +`BaseInferenceLattice` with `MaybeUndef`. +""" +struct OptimizerLattice <: AbstractLattice; end +widen(L::OptimizerLattice) = BaseInferenceLattice.instance +is_valid_lattice(lattice::OptimizerLattice, @nospecialize(elem)) = + is_valid_lattice(widen(lattice), elem) || isa(elem, MaybeUndef) + +""" + struct InferenceLattice{L} + +The full lattice used for abstract interpration during inference. Takes +a base lattice and adjoins `LimitedAccuracy`. +""" +struct InferenceLattice{L} <: AbstractLattice + parent::L +end +widen(L::InferenceLattice) = L.parent +is_valid_lattice(lattice::InferenceLattice, @nospecialize(elem)) = + is_valid_lattice(widen(lattice), elem) || isa(elem, LimitedAccuracy) + +""" + tmeet(lattice, a, b::Type) + +Compute the lattice meet of lattice elements `a` and `b` over the lattice +`lattice`. If `lattice` is `JLTypeLattice`, this is equiavalent to type +intersection. Note that currently `b` is restricted to being a type (interpreted +as a lattice element in the JLTypeLattice sub-lattice of `lattice`). +""" +function tmeet end + +function tmeet(::JLTypeLattice, @nospecialize(a::Type), @nospecialize(b::Type)) + ti = typeintersect(a, b) + valid_as_lattice(ti) || return Bottom + return ti +end + +""" + tmerge(lattice, a, b) + +Compute a lattice join of elements `a` and `b` over the lattice `lattice`. +Note that the computed element need not be the least upper bound of `a` and +`b`, but rather, we impose some heuristic limits on the complexity of the +joined element, ideally without losing too much precision in common cases and +remaining mostly associative and commutative. +""" +function tmerge end + +""" + ⊑(lattice, a, b) + +Compute the lattice ordering (i.e. less-than-or-equal) relationship between +lattice elements `a` and `b` over the lattice `lattice`. If `lattice` is +`JLTypeLattice`, this is equiavalent to subtyping. +""" +function ⊑ end + +⊑(::JLTypeLattice, @nospecialize(a::Type), @nospecialize(b::Type)) = a <: b + +""" + ⊏(lattice, a, b) -> Bool + +The strict partial order over the type inference lattice. +This is defined as the irreflexive kernel of `⊑`. +""" +⊏(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b)) = ⊑(lattice, a, b) && !⊑(lattice, b, a) + +""" + ⋤(lattice, a, b) -> Bool + +This order could be used as a slightly more efficient version of the strict order `⊏`, +where we can safely assume `a ⊑ b` holds. +""" +⋤(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b)) = !⊑(lattice, b, a) + +""" + is_lattice_equal(a, b) -> Bool + +Check if two lattice elements are partial order equivalent. +This is basically `a ⊑ b && b ⊑ a` but (optionally) with extra performance optimizations. +""" +function is_lattice_equal(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b)) + a === b && return true + ⊑(lattice, a, b) && ⊑(lattice, b, a) +end + +# Curried versions +⊑(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊑(lattice, a, b) +⊏(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊏(lattice, a, b) +⋤(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⋤(lattice, a, b) + +# Fallbacks for external packages using these methods +const fallback_lattice = InferenceLattice(BaseInferenceLattice.instance) +const fallback_ipo_lattice = InferenceLattice(IPOResultLattice.instance) + +⊑(@nospecialize(a), @nospecialize(b)) = ⊑(fallback_lattice, a, b) +tmeet(@nospecialize(a), @nospecialize(b)) = tmeet(fallback_lattice, a, b) +tmerge(@nospecialize(a), @nospecialize(b)) = tmerge(fallback_lattice, a, b) +⊏(@nospecialize(a), @nospecialize(b)) = ⊏(fallback_lattice, a, b) +⋤(@nospecialize(a), @nospecialize(b)) = ⋤(fallback_lattice, a, b) +is_lattice_equal(@nospecialize(a), @nospecialize(b)) = is_lattice_equal(fallback_lattice, a, b) diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 5b3a83c325499d..c37d67e65cc76f 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -156,6 +156,7 @@ include("compiler/ssair/ir.jl") include("compiler/inferenceresult.jl") include("compiler/inferencestate.jl") +include("compiler/abstractlattice.jl") include("compiler/typeutils.jl") include("compiler/typelimits.jl") include("compiler/typelattice.jl") diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index f19fbd014a04ea..41d72a773b8a56 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -211,6 +211,8 @@ Returns a tuple of (effect_free_and_nothrow, nothrow) for a given statement. """ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IRCode,IncrementalCompact}) # TODO: We're duplicating analysis from inference here. + lattice = OptimizerLattice() + ⊑ₒ = ⊑(lattice) isa(stmt, PiNode) && return (true, true) isa(stmt, PhiNode) && return (true, true) isa(stmt, ReturnNode) && return (false, true) @@ -248,7 +250,7 @@ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IR return (total, total) end rt === Bottom && return (false, false) - nothrow = _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt) + nothrow = _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt, lattice) nothrow || return (false, false) return (contains_is(_EFFECT_FREE_BUILTINS, f), nothrow) elseif head === :new @@ -277,11 +279,11 @@ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IR typ = argextype(args[1], src) typ, isexact = instanceof_tfunc(typ) isexact || return (false, false) - typ ⊑ Tuple || return (false, false) + typ ⊑ₒ Tuple || return (false, false) rt_lb = argextype(args[2], src) rt_ub = argextype(args[3], src) source = argextype(args[4], src) - if !(rt_lb ⊑ Type && rt_ub ⊑ Type && source ⊑ Method) + if !(rt_lb ⊑ₒ Type && rt_ub ⊑ₒ Type && source ⊑ₒ Method) return (false, false) end return (true, true) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 84864404f008e3..80074c458a39d9 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -114,6 +114,8 @@ function CFGInliningState(ir::IRCode) ) end +⊑ₒ(@nospecialize(a), @nospecialize(b)) = ⊑(OptimizerLattice(), a, b) + # Tells the inliner that we're now inlining into block `block`, meaning # all previous blocks have been processed and can be added to the new cfg function inline_into_block!(state::CFGInliningState, block::Int) @@ -1080,8 +1082,8 @@ function inline_apply!( nonempty_idx = 0 for i = (arg_start + 1):length(argtypes) ti = argtypes[i] - ti ⊑ Tuple{} && continue - if ti ⊑ Tuple && nonempty_idx == 0 + ti ⊑ₒ Tuple{} && continue + if ti ⊑ₒ Tuple && nonempty_idx == 0 nonempty_idx = i continue end @@ -1123,9 +1125,9 @@ end # TODO: this test is wrong if we start to handle Unions of function types later is_builtin(s::Signature) = isa(s.f, IntrinsicFunction) || - s.ft ⊑ IntrinsicFunction || + s.ft ⊑ₒ IntrinsicFunction || isa(s.f, Builtin) || - s.ft ⊑ Builtin + s.ft ⊑ₒ Builtin function inline_invoke!( ir::IRCode, idx::Int, stmt::Expr, info::InvokeCallInfo, flag::UInt8, @@ -1165,7 +1167,7 @@ function narrow_opaque_closure!(ir::IRCode, stmt::Expr, @nospecialize(info), sta ub, exact = instanceof_tfunc(ubt) exact || return # Narrow opaque closure type - newT = widenconst(tmeet(tmerge(lb, info.unspec.rt), ub)) + newT = widenconst(tmeet(OptimizerLattice(), tmerge(OptimizerLattice(), lb, info.unspec.rt), ub)) if newT != ub # N.B.: Narrowing the ub requires a backdge on the mi whose type # information we're using, since a change in that function may @@ -1222,7 +1224,7 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto ir.stmts[idx][:inst] = earlyres.val return nothing end - if (sig.f === modifyfield! || sig.ft ⊑ typeof(modifyfield!)) && 5 <= length(stmt.args) <= 6 + if (sig.f === modifyfield! || sig.ft ⊑ₒ typeof(modifyfield!)) && 5 <= length(stmt.args) <= 6 let info = ir.stmts[idx][:info] info isa MethodResultPure && (info = info.info) info isa ConstCallInfo && (info = info.call) @@ -1240,7 +1242,7 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto end if check_effect_free!(ir, idx, stmt, rt) - if sig.f === typeassert || sig.ft ⊑ typeof(typeassert) + if sig.f === typeassert || sig.ft ⊑ₒ typeof(typeassert) # typeassert is a no-op if effect free ir.stmts[idx][:inst] = stmt.args[2] return nothing @@ -1637,7 +1639,7 @@ function early_inline_special_case( elseif ispuretopfunction(f) || contains_is(_PURE_BUILTINS, f) return SomeCase(quoted(val)) elseif contains_is(_EFFECT_FREE_BUILTINS, f) - if _builtin_nothrow(f, argtypes[2:end], type) + if _builtin_nothrow(f, argtypes[2:end], type, OptimizerLattice()) return SomeCase(quoted(val)) end elseif f === Core.get_binding_type @@ -1683,17 +1685,17 @@ function late_inline_special_case!( elseif length(argtypes) == 3 && istopfunction(f, :(>:)) # special-case inliner for issupertype # that works, even though inference generally avoids inferring the `>:` Method - if isa(type, Const) && _builtin_nothrow(<:, Any[argtypes[3], argtypes[2]], type) + if isa(type, Const) && _builtin_nothrow(<:, Any[argtypes[3], argtypes[2]], type, OptimizerLattice()) return SomeCase(quoted(type.val)) end subtype_call = Expr(:call, GlobalRef(Core, :(<:)), stmt.args[3], stmt.args[2]) return SomeCase(subtype_call) - elseif f === TypeVar && 2 <= length(argtypes) <= 4 && (argtypes[2] ⊑ Symbol) + elseif f === TypeVar && 2 <= length(argtypes) <= 4 && (argtypes[2] ⊑ₒ Symbol) typevar_call = Expr(:call, GlobalRef(Core, :_typevar), stmt.args[2], length(stmt.args) < 4 ? Bottom : stmt.args[3], length(stmt.args) == 2 ? Any : stmt.args[end]) return SomeCase(typevar_call) - elseif f === UnionAll && length(argtypes) == 3 && (argtypes[2] ⊑ TypeVar) + elseif f === UnionAll && length(argtypes) == 3 && (argtypes[2] ⊑ₒ TypeVar) unionall_call = Expr(:foreigncall, QuoteNode(:jl_type_unionall), Any, svec(Any, Any), 0, QuoteNode(:ccall), stmt.args[2], stmt.args[3]) return SomeCase(unionall_call) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 27b145d82c2101..14a1f04c7a1c57 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -289,7 +289,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe # path, with a different type constraint. We may have # to redo some work here with the wider typeconstraint push!(worklist_defs, new_def) - push!(worklist_constraints, tmerge(new_constraint, visited_constraints[new_def])) + push!(worklist_constraints, tmerge(OptimizerLattice(), new_constraint, visited_constraints[new_def])) end continue end @@ -348,7 +348,7 @@ function is_getfield_captures(@nospecialize(def), compact::IncrementalCompact) isa(which, Const) || return false which.val === :captures || return false oc = argextype(def.args[2], compact) - return oc ⊑ Core.OpaqueClosure + return ⊑(OptimizerLattice(), oc, Core.OpaqueClosure) end struct LiftedValue @@ -528,13 +528,15 @@ function lift_comparison!(::typeof(===), compact::IncrementalCompact, lift_comparison_leaves!(egal_tfunc, compact, val, cmp, lifting_cache, idx) end +isa_tfunc_opt(@nospecialize(v), @nospecialize(t)) = isa_tfunc(OptimizerLattice(), v, t) + function lift_comparison!(::typeof(isa), compact::IncrementalCompact, idx::Int, stmt::Expr, lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue}) args = stmt.args length(args) == 3 || return cmp = argextype(args[3], compact) val = args[2] - lift_comparison_leaves!(isa_tfunc, compact, val, cmp, lifting_cache, idx) + lift_comparison_leaves!(isa_tfunc_opt, compact, val, cmp, lifting_cache, idx) end function lift_comparison!(::typeof(isdefined), compact::IncrementalCompact, @@ -1446,7 +1448,7 @@ function adce_pass!(ir::IRCode) r = searchsorted(unionphis, val.id; by = first) if !isempty(r) unionphi = unionphis[first(r)] - t = tmerge(unionphi[2], stmt.typ) + t = tmerge(OptimizerLattice(), unionphi[2], stmt.typ) unionphis[first(r)] = Pair{Int,Any}(unionphi[1], t) end end @@ -1454,7 +1456,7 @@ function adce_pass!(ir::IRCode) if is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3 # nullify safe `typeassert` calls ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact)) - if isexact && argextype(stmt.args[2], compact) ⊑ ty + if isexact && ⊑(OptimizerLattice(), argextype(stmt.args[2], compact), ty) compact[idx] = nothing continue end @@ -1483,7 +1485,7 @@ function adce_pass!(ir::IRCode) if !isempty(r) unionphi = unionphis[first(r)] unionphis[first(r)] = Pair{Int,Any}(unionphi[1], - tmerge(unionphi[2], inst[:type])) + tmerge(OptimizerLattice(), unionphi[2], inst[:type])) end end end @@ -1499,7 +1501,7 @@ function adce_pass!(ir::IRCode) continue elseif t === Any continue - elseif compact.result[phi][:type] ⊑ t + elseif ⊑(OptimizerLattice(), compact.result[phi][:type], t) continue end to_drop = Int[] diff --git a/base/compiler/ssair/slot2ssa.jl b/base/compiler/ssair/slot2ssa.jl index 29a3b093c9abf4..e7298833f7b5e6 100644 --- a/base/compiler/ssair/slot2ssa.jl +++ b/base/compiler/ssair/slot2ssa.jl @@ -576,7 +576,7 @@ function recompute_type(node::Union{PhiNode, PhiCNode}, ci::CodeInfo, ir::IRCode while isa(typ, DelayedTyp) typ = types(ir)[new_to_regular(typ.phi::NewSSAValue, nstmts)] end - new_typ = tmerge(new_typ, was_maybe_undef ? MaybeUndef(typ) : typ) + new_typ = tmerge(OptimizerLattice.instance, new_typ, was_maybe_undef ? MaybeUndef(typ) : typ) end return new_typ end @@ -586,6 +586,8 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, code = ir.stmts.inst cfg = ir.cfg catch_entry_blocks = Tuple{Int, Int}[] + lattice = OptimizerLattice.instance + ⊑ₒ = ⊑(lattice) for idx in 1:length(code) stmt = code[idx] if isexpr(stmt, :enter) @@ -719,7 +721,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, if isa(typ, DelayedTyp) push!(type_refine_phi, ssaval.id) end - new_typ = isa(typ, DelayedTyp) ? Union{} : tmerge(old_entry[:type], typ) + new_typ = isa(typ, DelayedTyp) ? Union{} : tmerge(lattice, old_entry[:type], typ) old_entry[:type] = new_typ old_entry[:inst] = node incoming_vals[slot] = ssaval @@ -853,7 +855,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, while isa(typ, DelayedTyp) typ = types(ir)[new_to_regular(typ.phi::NewSSAValue, nstmts)] end - new_typ = tmerge(new_typ, typ) + new_typ = tmerge(lattice, new_typ, typ) end node[:type] = new_typ end @@ -867,7 +869,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, for new_idx in type_refine_phi node = new_nodes.stmts[new_idx] new_typ = recompute_type(node[:inst]::Union{PhiNode,PhiCNode}, ci, ir, ir.sptypes, slottypes, nstmts) - if !(node[:type] ⊑ new_typ) || !(new_typ ⊑ node[:type]) + if !(node[:type] ⊑ₒ new_typ) || !(new_typ ⊑ₒ node[:type]) node[:type] = new_typ changed = true end diff --git a/base/compiler/ssair/verify.jl b/base/compiler/ssair/verify.jl index ec43a0e142699b..fcbed694bc2dbc 100644 --- a/base/compiler/ssair/verify.jl +++ b/base/compiler/ssair/verify.jl @@ -72,7 +72,8 @@ function count_int(val::Int, arr::Vector{Int}) n end -function verify_ir(ir::IRCode, print::Bool=true, allow_frontend_forms::Bool=false) +function verify_ir(ir::IRCode, print::Bool=true, allow_frontend_forms::Bool=false, + lattice = OptimizerLattice()) # For now require compact IR # @assert isempty(ir.new_nodes) # Verify CFG @@ -182,7 +183,7 @@ function verify_ir(ir::IRCode, print::Bool=true, allow_frontend_forms::Bool=fals val = stmt.values[i] phiT = ir.stmts[idx][:type] if isa(val, SSAValue) - if !(types(ir)[val] ⊑ phiT) + if !⊑(lattice, types(ir)[val], phiT) #@verify_error """ # PhiNode $idx, has operand $(val.id), whose type is not a sub lattice element. # PhiNode type was $phiT diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 5616b50ef39135..3412e4e8dd31cb 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -254,6 +254,7 @@ function isdefined_nothrow(argtypes::Array{Any, 1}) return a2 ⊑ Symbol || a2 ⊑ Int end end + isdefined_tfunc(arg1, sym, order) = (@nospecialize; isdefined_tfunc(arg1, sym)) function isdefined_tfunc(@nospecialize(arg1), @nospecialize(sym)) if isa(arg1, Const) @@ -307,11 +308,14 @@ function isdefined_tfunc(@nospecialize(arg1), @nospecialize(sym)) end end elseif isa(a1, Union) - return tmerge(isdefined_tfunc(rewrap_unionall(a1.a, arg1t), sym), + # Results can only be `Const` or `Bool` + return tmerge(fallback_lattice, + isdefined_tfunc(rewrap_unionall(a1.a, arg1t), sym), isdefined_tfunc(rewrap_unionall(a1.b, arg1t), sym)) end return Bool end + add_tfunc(isdefined, 2, 3, isdefined_tfunc, 1) function sizeof_nothrow(@nospecialize(x)) @@ -653,7 +657,7 @@ function typeassert_tfunc(@nospecialize(v), @nospecialize(t)) end add_tfunc(typeassert, 2, 2, typeassert_tfunc, 4) -function isa_tfunc(@nospecialize(v), @nospecialize(tt)) +function isa_tfunc(𝕃::AbstractLattice, @nospecialize(v), @nospecialize(tt)) t, isexact = instanceof_tfunc(tt) if t === Bottom # check if t could be equivalent to typeof(Bottom), since that's valid in `isa`, but the set of `v` is empty @@ -662,7 +666,7 @@ function isa_tfunc(@nospecialize(v), @nospecialize(tt)) return Const(false) end if !has_free_typevars(t) - if v ⊑ t + if ⊑(𝕃, v, t) if isexact && isnotbrokensubtype(v, t) return Const(true) end @@ -686,6 +690,7 @@ function isa_tfunc(@nospecialize(v), @nospecialize(tt)) # TODO: handle non-leaftype(t) by testing against lower and upper bounds return Bool end +isa_tfunc(@nospecialize(v), @nospecialize(t)) = isa_tfunc(fallback_lattice, v, t) add_tfunc(isa, 2, 2, isa_tfunc, 1) function subtype_tfunc(@nospecialize(a), @nospecialize(b)) @@ -1764,7 +1769,8 @@ function arrayset_typecheck(@nospecialize(arytype), @nospecialize(elmtype)) end # Query whether the given builtin is guaranteed not to throw given the argtypes -function _builtin_nothrow(@nospecialize(f), argtypes::Array{Any,1}, @nospecialize(rt)) +function _builtin_nothrow(@nospecialize(f), argtypes::Array{Any,1}, @nospecialize(rt), lattice::AbstractLattice) + ⊑ₗ = ⊑(lattice) if f === arrayset array_builtin_common_nothrow(argtypes, 4) || return false # Additionally check element type compatibility @@ -1775,7 +1781,7 @@ function _builtin_nothrow(@nospecialize(f), argtypes::Array{Any,1}, @nospecializ return arraysize_nothrow(argtypes) elseif f === Core._expr length(argtypes) >= 1 || return false - return argtypes[1] ⊑ Symbol + return argtypes[1] ⊑ₗ Symbol elseif f === Core._typevar length(argtypes) == 3 || return false return typevar_nothrow(argtypes[1], argtypes[2], argtypes[3]) @@ -1809,12 +1815,12 @@ function _builtin_nothrow(@nospecialize(f), argtypes::Array{Any,1}, @nospecializ return isa(rt, Const) elseif f === Core.ifelse length(argtypes) == 3 || return false - return argtypes[1] ⊑ Bool + return argtypes[1] ⊑ₗ Bool elseif f === typeassert length(argtypes) == 2 || return false a3 = argtypes[2] - if (isType(a3) && !has_free_typevars(a3) && argtypes[1] ⊑ a3.parameters[1]) || - (isa(a3, Const) && isa(a3.val, Type) && argtypes[1] ⊑ a3.val) + if (isType(a3) && !has_free_typevars(a3) && argtypes[1] ⊑ₗ a3.parameters[1]) || + (isa(a3, Const) && isa(a3.val, Type) && argtypes[1] ⊑ₗ a3.val) return true end return false @@ -2023,7 +2029,7 @@ end function builtin_nothrow(@nospecialize(f), argtypes::Vector{Any}, @nospecialize(rt)) rt === Bottom && return false contains_is(_PURE_BUILTINS, f) && return true - return _builtin_nothrow(f, argtypes, rt) + return _builtin_nothrow(f, argtypes, rt, fallback_lattice) end function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index b391840ed5ed4d..c83d4248a465dc 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -187,7 +187,7 @@ end is_same_conditionals(a::C, b::C) where C<:AnyConditional = a.slot == b.slot -is_lattice_bool(@nospecialize(typ)) = typ !== Bottom && typ ⊑ Bool +is_lattice_bool(lattice::AbstractLattice, @nospecialize(typ)) = typ !== Bottom && ⊑(lattice, typ, Bool) maybe_extract_const_bool(c::Const) = (val = c.val; isa(val, Bool)) ? val : nothing function maybe_extract_const_bool(c::AnyConditional) @@ -206,12 +206,9 @@ ignorelimited(typ::LimitedAccuracy) = typ.typ # lattice order # ============= -""" - a ⊑ b -> Bool - -The non-strict partial order over the type inference lattice. -""" -@nospecialize(a) ⊑ @nospecialize(b) = begin +function ⊑(lattice::InferenceLattice, @nospecialize(a), @nospecialize(b)) + @assert is_valid_lattice(lattice, a) + @assert is_valid_lattice(lattice, b) if isa(b, LimitedAccuracy) if !isa(a, LimitedAccuracy) return false @@ -222,17 +219,28 @@ The non-strict partial order over the type inference lattice. b = b.typ end isa(a, LimitedAccuracy) && (a = a.typ) + return ⊑(widen(lattice), a, b) +end + +function ⊑(lattice::OptimizerLattice, @nospecialize(a), @nospecialize(b)) + @assert is_valid_lattice(lattice, a) + @assert is_valid_lattice(lattice, b) if isa(a, MaybeUndef) && !isa(b, MaybeUndef) return false end isa(a, MaybeUndef) && (a = a.typ) isa(b, MaybeUndef) && (b = b.typ) + return ⊑(widen(lattice), a, b) +end + +function ⊑(lattice::AnyConditionalsLattice, @nospecialize(a), @nospecialize(b)) + @assert is_valid_lattice(lattice, a) + @assert is_valid_lattice(lattice, b) + # Fast paths for common cases b === Any && return true a === Any && return false a === Union{} && return true b === Union{} && return false - @assert !isa(a, TypeVar) "invalid lattice item" - @assert !isa(b, TypeVar) "invalid lattice item" if isa(a, AnyConditional) if isa(b, AnyConditional) return issubconditional(a, b) @@ -243,6 +251,12 @@ The non-strict partial order over the type inference lattice. elseif isa(b, AnyConditional) return false end + ⊑(widen(lattice), a, b) +end + +function ⊑(lattice::PartialsLattice, @nospecialize(a), @nospecialize(b)) + @assert is_valid_lattice(lattice, a) + @assert is_valid_lattice(lattice, b) if isa(a, PartialStruct) if isa(b, PartialStruct) if !(length(a.fields) == length(b.fields) && a.typ <: b.typ) @@ -262,7 +276,7 @@ The non-strict partial order over the type inference lattice. continue end end - ⊑(af, bf) || return false + ⊑(lattice, af, bf) || return false end return true end @@ -283,7 +297,7 @@ The non-strict partial order over the type inference lattice. if i == nf bf = unwrapva(bf) end - ⊑(Const(getfield(a.val, i)), bf) || return false + ⊑(lattice, Const(getfield(a.val, i)), bf) || return false end return true end @@ -293,12 +307,16 @@ The non-strict partial order over the type inference lattice. if isa(b, PartialOpaque) (a.parent === b.parent && a.source === b.source) || return false return (widenconst(a) <: widenconst(b)) && - ⊑(a.env, b.env) + ⊑(lattice, a.env, b.env) end - return widenconst(a) ⊑ b + return ⊑(widen(lattice), widenconst(a), b) elseif isa(b, PartialOpaque) return false end + ⊑(widen(lattice), a, b) +end + +function ⊑(lattice::ConstsLattice, @nospecialize(a), @nospecialize(b)) if isa(a, Const) if isa(b, Const) return a.val === b.val @@ -315,84 +333,82 @@ The non-strict partial order over the type inference lattice. elseif isa(a, PartialTypeVar) return b === TypeVar || a === b elseif isa(b, PartialTypeVar) - return a === b - else - return a <: b + return false end + return ⊑(widen(lattice), a, b) end -""" - a ⊏ b -> Bool - -The strict partial order over the type inference lattice. -This is defined as the irreflexive kernel of `⊑`. -""" -@nospecialize(a) ⊏ @nospecialize(b) = a ⊑ b && !⊑(b, a) - -""" - a ⋤ b -> Bool +function is_lattice_equal(lattice::InferenceLattice, @nospecialize(a), @nospecialize(b)) + if isa(a, LimitedAccuracy) || isa(b, LimitedAccuracy) + # TODO: Unwrap these and recurse to is_lattice_equal + return ⊑(lattice, a, b) && ⊑(lattice, b, a) + end + return is_lattice_equal(widen(lattice), a, b) +end -This order could be used as a slightly more efficient version of the strict order `⊏`, -where we can safely assume `a ⊑ b` holds. -""" -@nospecialize(a) ⋤ @nospecialize(b) = !⊑(b, a) +function is_lattice_equal(lattice::OptimizerLattice, @nospecialize(a), @nospecialize(b)) + if isa(a, MaybeUndef) || isa(b, MaybeUndef) + # TODO: Unwrap these and recurse to is_lattice_equal + return ⊑(lattice, a, b) && ⊑(lattice, b, a) + end + return is_lattice_equal(widen(lattice), a, b) +end -""" - is_lattice_equal(a, b) -> Bool +function is_lattice_equal(lattice::AnyConditionalsLattice, @nospecialize(a), @nospecialize(b)) + if isa(a, AnyConditional) || isa(b, AnyConditional) + # TODO: Unwrap these and recurse to is_lattice_equal + return ⊑(lattice, a, b) && ⊑(lattice, b, a) + end + return is_lattice_equal(widen(lattice), a, b) +end -Check if two lattice elements are partial order equivalent. -This is basically `a ⊑ b && b ⊑ a` but with extra performance optimizations. -""" -function is_lattice_equal(@nospecialize(a), @nospecialize(b)) - a === b && return true +function is_lattice_equal(lattice::PartialsLattice, @nospecialize(a), @nospecialize(b)) if isa(a, PartialStruct) isa(b, PartialStruct) || return false length(a.fields) == length(b.fields) || return false widenconst(a) == widenconst(b) || return false + a.fields === b.fields && return true # fast path for i in 1:length(a.fields) is_lattice_equal(a.fields[i], b.fields[i]) || return false end return true end isa(b, PartialStruct) && return false + if isa(a, PartialOpaque) + isa(b, PartialOpaque) || return false + widenconst(a) == widenconst(b) || return false + a.source === b.source || return false + a.parent === b.parent || return false + return is_lattice_equal(lattice, a.env, b.env) + end + isa(b, PartialOpaque) && return false + return is_lattice_equal(widen(lattice), a, b) +end + +function is_lattice_equal(lattice::ConstsLattice, @nospecialize(a), @nospecialize(b)) + a === b && return true if a isa Const if issingletontype(b) return a.val === b.instance end + # N.B. Assumes a === b checked above return false end if b isa Const if issingletontype(a) return a.instance === b.val end + # N.B. Assumes a === b checked above return false end - if isa(a, PartialOpaque) - isa(b, PartialOpaque) || return false - widenconst(a) == widenconst(b) || return false - a.source === b.source || return false - a.parent === b.parent || return false - return is_lattice_equal(a.env, b.env) - end - return a ⊑ b && b ⊑ a + return is_lattice_equal(widen(lattice), a, b) end # lattice operations # ================== -""" - tmeet(v, t::Type) -> x - -Computes typeintersect over the extended inference lattice, as precisely as we can, -where `v` is in the extended lattice, and `t` is a `Type`. -""" -function tmeet(@nospecialize(v), @nospecialize(t::Type)) - if isa(v, Const) - if !has_free_typevars(t) && !isa(v.val, t) - return Bottom - end - return v - elseif isa(v, PartialStruct) +function tmeet(lattice::PartialsLattice, @nospecialize(v), @nospecialize(t::Type)) + if isa(v, PartialStruct) has_free_typevars(t) && return v widev = widenconst(v) if widev <: t @@ -407,7 +423,7 @@ function tmeet(@nospecialize(v), @nospecialize(t::Type)) if isvarargtype(vfi) new_fields[i] = vfi else - new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i)))) + new_fields[i] = tmeet(lattice, vfi, widenconst(getfield_tfunc(t, Const(i)))) if new_fields[i] === Bottom return Bottom end @@ -423,15 +439,46 @@ function tmeet(@nospecialize(v), @nospecialize(t::Type)) ti = typeintersect(widev, t) valid_as_lattice(ti) || return Bottom return PartialOpaque(ti, v.env, v.parent, v.source) - elseif isa(v, Conditional) + end + return tmeet(widen(lattice), v, t) +end + +function tmeet(lattice::ConstsLattice, @nospecialize(v), @nospecialize(t::Type)) + if isa(v, Const) + if !has_free_typevars(t) && !isa(v.val, t) + return Bottom + end + return v + end + tmeet(widen(lattice), widenconst(v), t) +end + +function tmeet(lattice::ConditionalsLattice, @nospecialize(v), @nospecialize(t::Type)) + if isa(v, Conditional) if !(Bool <: t) return Bottom end return v end - ti = typeintersect(widenconst(v), t) - valid_as_lattice(ti) || return Bottom - return ti + tmeet(widen(lattice), v, t) +end + +function tmeet(lattice::InferenceLattice, @nospecialize(v), @nospecialize(t::Type)) + # TODO: This can probably happen and should be handled + @assert !isa(v, LimitedAccuracy) + tmeet(widen(lattice), v, t) +end + +function tmeet(lattice::InterConditionalsLattice, @nospecialize(v), @nospecialize(t::Type)) + # TODO: This can probably happen and should be handled + @assert !isa(v, AnyConditional) + tmeet(widen(lattice), v, t) +end + +function tmeet(lattice::OptimizerLattice, @nospecialize(v), @nospecialize(t::Type)) + # TODO: This can probably happen and should be handled + @assert !isa(v, MaybeUndef) + tmeet(widen(lattice), v, t) end """ @@ -454,19 +501,22 @@ widenconst(::LimitedAccuracy) = error("unhandled LimitedAccuracy") # state management # #################### -issubstate(a::VarState, b::VarState) = (a.typ ⊑ b.typ && a.undef <= b.undef) +issubstate(lattice::AbstractLattice, a::VarState, b::VarState) = + ⊑(lattice, a.typ, b.typ) && a.undef <= b.undef -function smerge(sa::Union{NotFound,VarState}, sb::Union{NotFound,VarState}) +function smerge(lattice::AbstractLattice, sa::Union{NotFound,VarState}, sb::Union{NotFound,VarState}) sa === sb && return sa sa === NOT_FOUND && return sb sb === NOT_FOUND && return sa - issubstate(sa, sb) && return sb - issubstate(sb, sa) && return sa - return VarState(tmerge(sa.typ, sb.typ), sa.undef | sb.undef) + issubstate(lattice, sa, sb) && return sb + issubstate(lattice, sb, sa) && return sa + return VarState(tmerge(lattice, sa.typ, sb.typ), sa.undef | sb.undef) end -@inline tchanged(@nospecialize(n), @nospecialize(o)) = o === NOT_FOUND || (n !== NOT_FOUND && !(n ⊑ o)) -@inline schanged(@nospecialize(n), @nospecialize(o)) = (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(n::VarState, o::VarState))) +@inline tchanged(lattice::AbstractLattice, @nospecialize(n), @nospecialize(o)) = + o === NOT_FOUND || (n !== NOT_FOUND && !⊑(lattice, n, o)) +@inline schanged(lattice::AbstractLattice, @nospecialize(n), @nospecialize(o)) = + (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(lattice, n::VarState, o::VarState))) # remove any lattice elements that wrap the reassigned slot object from the vartable function invalidate_slotwrapper(vt::VarState, changeid::Int, ignore_conditional::Bool) @@ -478,7 +528,7 @@ function invalidate_slotwrapper(vt::VarState, changeid::Int, ignore_conditional: return nothing end -function stupdate!(state::VarTable, changes::StateUpdate) +function stupdate!(lattice::AbstractLattice, state::VarTable, changes::StateUpdate) changed = false changeid = slot_id(changes.var) for i = 1:length(state) @@ -492,28 +542,28 @@ function stupdate!(state::VarTable, changes::StateUpdate) newtype = invalidated end oldtype = state[i] - if schanged(newtype, oldtype) - state[i] = smerge(oldtype, newtype) + if schanged(lattice, newtype, oldtype) + state[i] = smerge(lattice, oldtype, newtype) changed = true end end return changed end -function stupdate!(state::VarTable, changes::VarTable) +function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable) changed = false for i = 1:length(state) newtype = changes[i] oldtype = state[i] - if schanged(newtype, oldtype) - state[i] = smerge(oldtype, newtype) + if schanged(lattice, newtype, oldtype) + state[i] = smerge(lattice, oldtype, newtype) changed = true end end return changed end -function stupdate1!(state::VarTable, change::StateUpdate) +function stupdate1!(lattice::AbstractLattice, state::VarTable, change::StateUpdate) changeid = slot_id(change.var) for i = 1:length(state) invalidated = invalidate_slotwrapper(state[i], changeid, change.conditional) @@ -524,8 +574,8 @@ function stupdate1!(state::VarTable, change::StateUpdate) # and update the type of it newtype = change.vtype oldtype = state[changeid] - if schanged(newtype, oldtype) - state[changeid] = smerge(oldtype, newtype) + if schanged(lattice, newtype, oldtype) + state[changeid] = smerge(lattice, oldtype, newtype) return true end return false diff --git a/base/compiler/typelimits.jl b/base/compiler/typelimits.jl index d44619fa508df6..58355a5b82da6a 100644 --- a/base/compiler/typelimits.jl +++ b/base/compiler/typelimits.jl @@ -349,20 +349,37 @@ function issimplertype(@nospecialize(typea), @nospecialize(typeb)) return true end -# pick a wider type that contains both typea and typeb, -# with some limits on how "large" it can get, -# but without losing too much precision in common cases -# and also trying to be mostly associative and commutative -function tmerge(@nospecialize(typea), @nospecialize(typeb)) +@inline function tmerge_fast_path(lattice::AbstractLattice, @nospecialize(typea), @nospecialize(typeb)) + # Fast paths typea === Union{} && return typeb typeb === Union{} && return typea typea === typeb && return typea - suba = typea ⊑ typeb + suba = ⊑(lattice, typea, typeb) suba && issimplertype(typeb, typea) && return typeb - subb = typeb ⊑ typea + subb = ⊑(lattice, typeb, typea) suba && subb && return typea subb && issimplertype(typea, typeb) && return typea + return nothing +end + + +function tmerge(lattice::OptimizerLattice, @nospecialize(typea), @nospecialize(typeb)) + r = tmerge_fast_path(lattice, typea, typeb) + r !== nothing && return r + + # type-lattice for MaybeUndef wrapper + if isa(typea, MaybeUndef) || isa(typeb, MaybeUndef) + return MaybeUndef(tmerge( + isa(typea, MaybeUndef) ? typea.typ : typea, + isa(typeb, MaybeUndef) ? typeb.typ : typeb)) + end + return tmerge(widen(lattice), typea, typeb) +end + +function tmerge(lattice::InferenceLattice, @nospecialize(typea), @nospecialize(typeb)) + r = tmerge_fast_path(lattice, typea, typeb) + r !== nothing && return r # type-lattice for LimitedAccuracy wrapper # the merge create a slightly narrower type than needed, but we can't @@ -383,13 +400,10 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb)) return LimitedAccuracy(tmerge(typea, typeb.typ), typeb.causes) end - # type-lattice for MaybeUndef wrapper - if isa(typea, MaybeUndef) || isa(typeb, MaybeUndef) - return MaybeUndef(tmerge( - isa(typea, MaybeUndef) ? typea.typ : typea, - isa(typeb, MaybeUndef) ? typeb.typ : typeb)) - end + return tmerge(widen(lattice), typea, typeb) +end +function tmerge(lattice::ConditionalsLattice, @nospecialize(typea), @nospecialize(typeb)) # type-lattice for Conditional wrapper (NOTE never be merged with InterConditional) if isa(typea, Conditional) && isa(typeb, Const) if typeb.val === true @@ -419,6 +433,10 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb)) end return Bool end + tmerge(widen(lattice), typea, typeb) +end + +function tmerge(lattice::InterConditionalsLattice, @nospecialize(typea), @nospecialize(typeb)) # type-lattice for InterConditional wrapper (NOTE never be merged with Conditional) if isa(typea, InterConditional) && isa(typeb, Const) if typeb.val === true @@ -448,7 +466,10 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb)) end return Bool end + tmerge(widen(lattice), typea, typeb) +end +function tmerge(latice::PartialsLattice, @nospecialize(typea), @nospecialize(typeb)) # type-lattice for Const and PartialStruct wrappers if ((isa(typea, PartialStruct) || isa(typea, Const)) && (isa(typeb, PartialStruct) || isa(typeb, Const))) @@ -513,10 +534,12 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb)) # no special type-inference lattice, join the types typea, typeb = widenconst(typea), widenconst(typeb) - if !isa(typea, Type) || !isa(typeb, Type) - # XXX: this should never happen - return Any - end + @assert (isa(typea, Type) && isa(typeb, Type)) + + return tmerge(JLTypeLattice(), typea, typeb) +end + +function tmerge(::JLTypeLattice, @nospecialize(typea), @nospecialize(typeb)) typea == typeb && return typea # it's always ok to form a Union of two concrete types if (isconcretetype(typea) || isType(typea)) && (isconcretetype(typeb) || isType(typeb)) diff --git a/base/compiler/types.jl b/base/compiler/types.jl index 3b968f6f844779..870546f105f4a0 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -273,3 +273,7 @@ to the call site signature. """ infer_compilation_signature(::AbstractInterpreter) = false infer_compilation_signature(::NativeInterpreter) = true + +typeinf_lattice(::AbstractInterpreter) = InferenceLattice(BaseInferenceLattice.instance) +ipo_lattice(::AbstractInterpreter) = InferenceLattice(IPOResultLattice.instance) +optimizer_lattice(::AbstractInterpreter) = OptimizerLattice.instance diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 8dbb4b4359b655..19436df4ff67b2 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -163,7 +163,7 @@ tmerge_test(Tuple{}, Tuple{Complex, Vararg{Union{ComplexF32, ComplexF64}}}, @test Core.Compiler.tmerge(Vector{Int}, Core.Compiler.tmerge(Vector{String}, Union{Vector{Bool}, Vector{Symbol}})) == Vector @test Core.Compiler.tmerge(Base.BitIntegerType, Union{}) === Base.BitIntegerType @test Core.Compiler.tmerge(Union{}, Base.BitIntegerType) === Base.BitIntegerType -@test Core.Compiler.tmerge(Core.Compiler.InterConditional(1, Int, Union{}), Core.Compiler.InterConditional(2, String, Union{})) === Core.Compiler.Const(true) +@test Core.Compiler.tmerge(Core.Compiler.fallback_ipo_lattice, Core.Compiler.InterConditional(1, Int, Union{}), Core.Compiler.InterConditional(2, String, Union{})) === Core.Compiler.Const(true) struct SomethingBits x::Base.BitIntegerType