Skip to content

Commit

Permalink
Refactor lattice code to expose layering and enable easy extension
Browse files Browse the repository at this point in the history
There's been two threads of work involving the compiler's notion of
the inference lattice. One is that the lattice has gotten to complicated
and with too many internal constraints that are not manifest in the
type system. #42596 attempted to address this, but it's quite disruptive
as it changes the lattice types and all the signatures of the lattice
operations, which are used quite extensively throughout the ecosystem
(despite being internal), so that change is quite disruptive (and
something we'd ideally only make the ecosystem do once).

The other thread of work is that people would like to experiment with
a variety of extended lattices outside of base (either to prototype
potential additions to the lattice in base or to do custom abstract
interpretation over the julia code). At the moment, the lattice is
quite closely interwoven with the rest of the abstract interpreter.
In response to this request in #40992, I had proposed a `CustomLattice`
element with callbacks, but this doesn't compose particularly well,
is cumbersome and imposes overhead on some of the hottest parts of
the compiler, so it's a bit of a tough sell to merge into Base.

In this PR, I'd like to propose a refactoring that is relatively
non-invasive to non-Base users, but I think would allow easier
experimentation with changes to the lattice for these two use
cases. In essence, we're splitting the lattice into a ladder of
5 different lattices, each containing the previous lattice as a
sub-lattice. These 5 lattices are:

- JLTypeLattice (Anything that's a `Type`)
- ConstsLattice ( + `Const`, `PartialTypeVar`)
- PartialsLattice ( + `PartialStruct` )
- ConditionalsLattice ( + `Conditional` )
- InferenceLattice ( + `LimitedAccuracy`, `MaybeUndef` )

The idea is that where a lattice element contains another lattice
element (e.g. in `PartialStruct` or `Conditional`), the element
contained may only be from a wider lattice. In this PR, this
is not enforced by the type system. This is quite deliberate, as
I want to retain the types and object layouts of the lattice elements,
but of course a future #42596-like change could add such type
enforcement.

Of particular note is that the `PartialsLattice` and `ConditionalsLattice`
is parameterized and additional layers may be added in the stack.
For example, in #40992, I had proposed a lattice element that refines
`Int` and tracks symbolic expressions. In this setup, this could
be accomplished by adding an appropriate lattice in between the
`ConstsLattice` and the `PartialsLattice` (of course, additional
hooks would be required to make the tfuncs work, but that is
outside the scope of this PR).

I don't think this is a full solution, but I think it'll help us
play with some of these extended lattice options over the next
6-12 months in the packages that want to do this sort of thing.
Presumably once we know what all the potential lattice extensions
look like, we will want to take another look at this (likely
together with whatever solution we come up with for the
AbstractInterpreter composability problem and a rebase of #42596).

WIP because I didn't bother updating and plumbing through the lattice
in all the call sites yet, but that's mostly mechanical, so if we
like this direction, I will make that change and hope to merge this
in short order (because otherwise it'll accumulate massive merge
conflicts).
  • Loading branch information
Keno committed Aug 31, 2022
1 parent 682ae8a commit 59f14a1
Show file tree
Hide file tree
Showing 13 changed files with 374 additions and 158 deletions.
110 changes: 61 additions & 49 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

160 changes: 160 additions & 0 deletions base/compiler/abstractlattice.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
abstract type AbstractLattice; end
function widen end

"""
struct JLTypeLattice
A singleton type representing the lattice of Julia types, without any inference
extensions.
"""
struct JLTypeLattice <: AbstractLattice; end
widen(::JLTypeLattice) = error("Type lattice is the least-precise lattice available")
is_valid_lattice(::JLTypeLattice, @nospecialize(elem)) = isa(elem, Type)

"""
struct ConstsLattice
A lattice extending `JLTypeLattice` and adjoining `Const` and `PartialTypeVar`.
"""
struct ConstsLattice <: AbstractLattice; end
widen(::ConstsLattice) = JLTypeLattice()
is_valid_lattice(lattice::ConstsLattice, @nospecialize(elem)) =
is_valid_lattice(widen(lattice), elem) || isa(elem, Const) || isa(elem, PartialTypeVar)

"""
struct PartialsLattice{L}
A lattice extending lattice `L` and adjoining `PartialStruct` and `PartialOpaque`.
"""
struct PartialsLattice{L <: AbstractLattice} <: AbstractLattice
parent::L
end
widen(L::PartialsLattice) = L.parent
is_valid_lattice(lattice::PartialsLattice, @nospecialize(elem)) =
is_valid_lattice(widen(lattice), elem) ||
isa(elem, PartialStruct) || isa(elem, PartialOpaque)

"""
struct ConditionalsLattice{L}
A lattice extending lattice `L` and adjoining `Conditional`.
"""
struct ConditionalsLattice{L <: AbstractLattice} <: AbstractLattice
parent::L
end
widen(L::ConditionalsLattice) = L.parent
is_valid_lattice(lattice::ConditionalsLattice, @nospecialize(elem)) =
is_valid_lattice(widen(lattice), elem) || isa(elem, Conditional)

struct InterConditionalsLattice{L <: AbstractLattice} <: AbstractLattice
parent::L
end
widen(L::InterConditionalsLattice) = L.parent
is_valid_lattice(lattice::InterConditionalsLattice, @nospecialize(elem)) =
is_valid_lattice(widen(lattice), elem) || isa(elem, InterConditional)

const AnyConditionalsLattice{L} = Union{ConditionalsLattice{L}, InterConditionalsLattice{L}}
const BaseInferenceLattice = typeof(ConditionalsLattice(PartialsLattice(ConstsLattice())))
const IPOResultLattice = typeof(InterConditionalsLattice(PartialsLattice(ConstsLattice())))

"""
struct OptimizerLattice
The lattice used by the optimizer. Extends
`BaseInferenceLattice` with `MaybeUndef`.
"""
struct OptimizerLattice <: AbstractLattice; end
widen(L::OptimizerLattice) = BaseInferenceLattice.instance
is_valid_lattice(lattice::OptimizerLattice, @nospecialize(elem)) =
is_valid_lattice(widen(lattice), elem) || isa(elem, MaybeUndef)

"""
struct InferenceLattice{L}
The full lattice used for abstract interpration during inference. Takes
a base lattice and adjoins `LimitedAccuracy`.
"""
struct InferenceLattice{L} <: AbstractLattice
parent::L
end
widen(L::InferenceLattice) = L.parent
is_valid_lattice(lattice::InferenceLattice, @nospecialize(elem)) =
is_valid_lattice(widen(lattice), elem) || isa(elem, LimitedAccuracy)

"""
tmeet(lattice, a, b::Type)
Compute the lattice meet of lattice elements `a` and `b` over the lattice
`lattice`. If `lattice` is `JLTypeLattice`, this is equiavalent to type
intersection. Note that currently `b` is restricted to being a type (interpreted
as a lattice element in the JLTypeLattice sub-lattice of `lattice`).
"""
function tmeet end

"""
tmerge(lattice, a, b)
Compute a lattice join of elements `a` and `b` over the lattice `lattice`.
Note that the computed element need not be the least upper bound of `a` and
`b`, but rather, we impose some heuristic limits on the complexity of the
joined element, ideally without losing too much precision in common cases and
remaining mostly associative and commutative.
"""
function tmerge end

"""
⊑(lattice, a, b)
Compute the lattice ordering (i.e. less-than-or-equal) relationship between
lattice elements `a` and `b` over the lattice `lattice`. If `lattice` is
`JLTypeLattice`, this is equiavalent to subtyping.
"""
function end

function (::JLTypeLattice, @nospecialize(a::Type), @nospecialize(b::Type))
a <: b
end


"""
⊏(lattice, a, b) -> Bool
The strict partial order over the type inference lattice.
This is defined as the irreflexive kernel of `⊑`.
"""
(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b)) = (lattice, a, b) && !(lattice, b, a)

"""
⋤(lattice, a, b) -> Bool
This order could be used as a slightly more efficient version of the strict order `⊏`,
where we can safely assume `a ⊑ b` holds.
"""
(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b)) = !(lattice, b, a)

"""
is_lattice_equal(a, b) -> Bool
Check if two lattice elements are partial order equivalent.
This is basically `a ⊑ b && b ⊑ a` but (optionally) with extra performance optimizations.
"""
function is_lattice_equal(lattice::AbstractLattice, @nospecialize(a), @nospecialize(b))
a === b && return true
(lattice, a, b) && (lattice, b, a)
end

# Curried versions
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)

# Fallbacks for external packages using these methods
const fallback_lattice = InferenceLattice(BaseInferenceLattice.instance)
const fallback_ipo_lattice = InferenceLattice(IPOResultLattice.instance)

(@nospecialize(a), @nospecialize(b)) = (fallback_lattice, a, b)
tmeet(@nospecialize(a), @nospecialize(b)) = tmeet(fallback_lattice, a, b)
tmerge(@nospecialize(a), @nospecialize(b)) = tmerge(fallback_lattice, a, b)
(@nospecialize(a), @nospecialize(b)) = (fallback_lattice, a, b)
(@nospecialize(a), @nospecialize(b)) = (fallback_lattice, a, b)
is_lattice_equal(@nospecialize(a), @nospecialize(b)) = is_lattice_equal(fallback_lattice, a, b)
1 change: 1 addition & 0 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ include("compiler/ssair/ir.jl")
include("compiler/inferenceresult.jl")
include("compiler/inferencestate.jl")

include("compiler/abstractlattice.jl")
include("compiler/typeutils.jl")
include("compiler/typelimits.jl")
include("compiler/typelattice.jl")
Expand Down
8 changes: 5 additions & 3 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ Returns a tuple of (effect_free_and_nothrow, nothrow) for a given statement.
"""
function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IRCode,IncrementalCompact})
# TODO: We're duplicating analysis from inference here.
lattice = OptimizerLattice()
= (lattice)
isa(stmt, PiNode) && return (true, true)
isa(stmt, PhiNode) && return (true, true)
isa(stmt, ReturnNode) && return (false, true)
Expand Down Expand Up @@ -248,7 +250,7 @@ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IR
return (total, total)
end
rt === Bottom && return (false, false)
nothrow = _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt)
nothrow = _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt, lattice)
nothrow || return (false, false)
return (contains_is(_EFFECT_FREE_BUILTINS, f), nothrow)
elseif head === :new
Expand Down Expand Up @@ -277,11 +279,11 @@ function stmt_effect_flags(@nospecialize(stmt), @nospecialize(rt), src::Union{IR
typ = argextype(args[1], src)
typ, isexact = instanceof_tfunc(typ)
isexact || return (false, false)
typ Tuple || return (false, false)
typ Tuple || return (false, false)
rt_lb = argextype(args[2], src)
rt_ub = argextype(args[3], src)
source = argextype(args[4], src)
if !(rt_lb Type && rt_ub Type && source Method)
if !(rt_lb Type && rt_ub Type && source Method)
return (false, false)
end
return (true, true)
Expand Down
24 changes: 13 additions & 11 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ function CFGInliningState(ir::IRCode)
)
end

(@nospecialize(a), @nospecialize(b)) = (OptimizerLattice(), a, b)

# Tells the inliner that we're now inlining into block `block`, meaning
# all previous blocks have been processed and can be added to the new cfg
function inline_into_block!(state::CFGInliningState, block::Int)
Expand Down Expand Up @@ -1080,8 +1082,8 @@ function inline_apply!(
nonempty_idx = 0
for i = (arg_start + 1):length(argtypes)
ti = argtypes[i]
ti Tuple{} && continue
if ti Tuple && nonempty_idx == 0
ti Tuple{} && continue
if ti Tuple && nonempty_idx == 0
nonempty_idx = i
continue
end
Expand Down Expand Up @@ -1123,9 +1125,9 @@ end
# TODO: this test is wrong if we start to handle Unions of function types later
is_builtin(s::Signature) =
isa(s.f, IntrinsicFunction) ||
s.ft IntrinsicFunction ||
s.ft IntrinsicFunction ||
isa(s.f, Builtin) ||
s.ft Builtin
s.ft Builtin

function inline_invoke!(
ir::IRCode, idx::Int, stmt::Expr, info::InvokeCallInfo, flag::UInt8,
Expand Down Expand Up @@ -1165,7 +1167,7 @@ function narrow_opaque_closure!(ir::IRCode, stmt::Expr, @nospecialize(info), sta
ub, exact = instanceof_tfunc(ubt)
exact || return
# Narrow opaque closure type
newT = widenconst(tmeet(tmerge(lb, info.unspec.rt), ub))
newT = widenconst(tmeet(tmerge(OptimizerLattice(), lb, info.unspec.rt), ub))
if newT != ub
# N.B.: Narrowing the ub requires a backdge on the mi whose type
# information we're using, since a change in that function may
Expand Down Expand Up @@ -1222,7 +1224,7 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
ir.stmts[idx][:inst] = earlyres.val
return nothing
end
if (sig.f === modifyfield! || sig.ft typeof(modifyfield!)) && 5 <= length(stmt.args) <= 6
if (sig.f === modifyfield! || sig.ft typeof(modifyfield!)) && 5 <= length(stmt.args) <= 6
let info = ir.stmts[idx][:info]
info isa MethodResultPure && (info = info.info)
info isa ConstCallInfo && (info = info.call)
Expand All @@ -1240,7 +1242,7 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
end

if check_effect_free!(ir, idx, stmt, rt)
if sig.f === typeassert || sig.ft typeof(typeassert)
if sig.f === typeassert || sig.ft typeof(typeassert)
# typeassert is a no-op if effect free
ir.stmts[idx][:inst] = stmt.args[2]
return nothing
Expand Down Expand Up @@ -1637,7 +1639,7 @@ function early_inline_special_case(
elseif ispuretopfunction(f) || contains_is(_PURE_BUILTINS, f)
return SomeCase(quoted(val))
elseif contains_is(_EFFECT_FREE_BUILTINS, f)
if _builtin_nothrow(f, argtypes[2:end], type)
if _builtin_nothrow(f, argtypes[2:end], type, OptimizerLattice())
return SomeCase(quoted(val))
end
elseif f === Core.get_binding_type
Expand Down Expand Up @@ -1683,17 +1685,17 @@ function late_inline_special_case!(
elseif length(argtypes) == 3 && istopfunction(f, :(>:))
# special-case inliner for issupertype
# that works, even though inference generally avoids inferring the `>:` Method
if isa(type, Const) && _builtin_nothrow(<:, Any[argtypes[3], argtypes[2]], type)
if isa(type, Const) && _builtin_nothrow(<:, Any[argtypes[3], argtypes[2]], type, OptimizerLattice())
return SomeCase(quoted(type.val))
end
subtype_call = Expr(:call, GlobalRef(Core, :(<:)), stmt.args[3], stmt.args[2])
return SomeCase(subtype_call)
elseif f === TypeVar && 2 <= length(argtypes) <= 4 && (argtypes[2] Symbol)
elseif f === TypeVar && 2 <= length(argtypes) <= 4 && (argtypes[2] Symbol)
typevar_call = Expr(:call, GlobalRef(Core, :_typevar), stmt.args[2],
length(stmt.args) < 4 ? Bottom : stmt.args[3],
length(stmt.args) == 2 ? Any : stmt.args[end])
return SomeCase(typevar_call)
elseif f === UnionAll && length(argtypes) == 3 && (argtypes[2] TypeVar)
elseif f === UnionAll && length(argtypes) == 3 && (argtypes[2] TypeVar)
unionall_call = Expr(:foreigncall, QuoteNode(:jl_type_unionall), Any, svec(Any, Any),
0, QuoteNode(:ccall), stmt.args[2], stmt.args[3])
return SomeCase(unionall_call)
Expand Down
16 changes: 9 additions & 7 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
# path, with a different type constraint. We may have
# to redo some work here with the wider typeconstraint
push!(worklist_defs, new_def)
push!(worklist_constraints, tmerge(new_constraint, visited_constraints[new_def]))
push!(worklist_constraints, tmerge(OptimizerLattice(), new_constraint, visited_constraints[new_def]))
end
continue
end
Expand Down Expand Up @@ -348,7 +348,7 @@ function is_getfield_captures(@nospecialize(def), compact::IncrementalCompact)
isa(which, Const) || return false
which.val === :captures || return false
oc = argextype(def.args[2], compact)
return oc Core.OpaqueClosure
return (OptimizerLattice(), oc, Core.OpaqueClosure)
end

struct LiftedValue
Expand Down Expand Up @@ -528,13 +528,15 @@ function lift_comparison!(::typeof(===), compact::IncrementalCompact,
lift_comparison_leaves!(egal_tfunc, compact, val, cmp, lifting_cache, idx)
end

isa_tfunc_opt(@nospecialize(v), @nospecialize(t)) = isa_tfunc(OptimizerLattice(), v, t)

function lift_comparison!(::typeof(isa), compact::IncrementalCompact,
idx::Int, stmt::Expr, lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue})
args = stmt.args
length(args) == 3 || return
cmp = argextype(args[3], compact)
val = args[2]
lift_comparison_leaves!(isa_tfunc, compact, val, cmp, lifting_cache, idx)
lift_comparison_leaves!(isa_tfunc_opt, compact, val, cmp, lifting_cache, idx)
end

function lift_comparison!(::typeof(isdefined), compact::IncrementalCompact,
Expand Down Expand Up @@ -1446,15 +1448,15 @@ function adce_pass!(ir::IRCode)
r = searchsorted(unionphis, val.id; by = first)
if !isempty(r)
unionphi = unionphis[first(r)]
t = tmerge(unionphi[2], stmt.typ)
t = tmerge(OptimizerLattice(), unionphi[2], stmt.typ)
unionphis[first(r)] = Pair{Int,Any}(unionphi[1], t)
end
end
else
if is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3
# nullify safe `typeassert` calls
ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact))
if isexact && argextype(stmt.args[2], compact) ty
if isexact && (OptimizerLattice(), argextype(stmt.args[2], compact), ty)
compact[idx] = nothing
continue
end
Expand Down Expand Up @@ -1483,7 +1485,7 @@ function adce_pass!(ir::IRCode)
if !isempty(r)
unionphi = unionphis[first(r)]
unionphis[first(r)] = Pair{Int,Any}(unionphi[1],
tmerge(unionphi[2], inst[:type]))
tmerge(OptimizerLattice(), unionphi[2], inst[:type]))
end
end
end
Expand All @@ -1499,7 +1501,7 @@ function adce_pass!(ir::IRCode)
continue
elseif t === Any
continue
elseif compact.result[phi][:type] t
elseif (OptimizerLattice(), compact.result[phi][:type], t)
continue
end
to_drop = Int[]
Expand Down
Loading

0 comments on commit 59f14a1

Please sign in to comment.