Skip to content

Commit

Permalink
lattice overhaul step 2: convert existing extended lattice wrappers t…
Browse files Browse the repository at this point in the history
…o `LatticeElement` attributes

- pack `PartialStruct` into `LatticeElement.fields`
- pack `Conditional`/`InterConditional` into `LatticeElement.conditional`
- pack `Const` into `LatticeElement.constant`
- pack `PartialTypeVar` into `LatticeElement.partialtypevar`
- pack `LimitedAccuracy` into `LatticeElement.causes`
- pack `PartialOpaque` into `LatticeElement.partialopaque`
- pack `MaybeUndef` into `LatticeElement.maybeundef`
- merge `LatticeElement.partialopaque` and `LatticeElement.partialopaque`
  There is not much value in keeping them separate, since a variable
  usually doesn't have these "special" attributes at the same time.
- wrap `Vararg` in `LatticeElement.special::Vararg`
- add HACK to allow `DelayedTyp` to sneak in `LatticeElement` system
  And now we can eliminate `AbstractLattice`, and our inference code
  works with `LatticeElement` (mostly).
- define `SSAValueType(s)` / `Argtypes` aliases
  • Loading branch information
aviatesk committed Dec 17, 2021
1 parent 0df985b commit 0b60f1a
Show file tree
Hide file tree
Showing 29 changed files with 1,405 additions and 1,082 deletions.
9 changes: 2 additions & 7 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -421,10 +421,7 @@ eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecial
min_world::UInt, max_world::UInt) =
ccall(:jl_new_codeinst, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt),
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))
eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v))))
eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields))))
eval(Core, :(PartialOpaque(@nospecialize(typ), @nospecialize(env), isva::Bool, parent::MethodInstance, source::Method) = $(Expr(:new, :PartialOpaque, :typ, :env, :isva, :parent, :source))))
eval(Core, :(InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype))))
eval(Core, :(MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) =
$(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers))))

Expand Down Expand Up @@ -496,13 +493,11 @@ Symbol(s::Symbol) = s
module IR
export CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot, Argument,
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode,
Const, PartialStruct
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode

import Core: CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot, Argument,
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode,
Const, PartialStruct
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode

end

Expand Down
462 changes: 238 additions & 224 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

94 changes: 64 additions & 30 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,35 +112,11 @@ something(x::Any, y...) = x
# compiler #
############

# TODO remove me in the future, this is just to check the coverage of the overhaul
import Core: Const, PartialStruct, InterConditional, PartialOpaque, TypeofVararg
abstract type _AbstractLattice end
const AbstractLattice = Union{
Const, PartialStruct, InterConditional, PartialOpaque, TypeofVararg,
_AbstractLattice}
const Argtypes = Vector{AbstractLattice}

macro latticeop(mode, def)
@assert is_function_def(def)
sig, body = def.args
if mode === :args || mode === :op
nospecs = Symbol[]
for arg in sig.args
if isexpr(arg, :macrocall) && arg.args[1] === Symbol("@nospecialize")
push!(nospecs, arg.args[3])
end
end
idx = findfirst(x->!isa(x, LineNumberNode), body.args)
for var in nospecs
insert!(body.args, idx, Expr(:(=), var, Expr(:(::), var, :AbstractLattice)))
end
end
if mode === :ret || mode === :op
sig = Expr(:(::), sig, :AbstractLattice)
end
return esc(Expr(def.head, sig, body))
end
anymap(f::Function, a::Vector{AbstractLattice}) = Any[ f(a[i]) for i in 1:length(a) ]
include("compiler/typelattice.jl")

const Argtypes = Vector{LatticeElement}
const EMPTY_SLOTTYPES = Argtypes()
anymap(f::Function, a::Argtypes) = Any[ f(a[i]) for i in 1:length(a) ]

include("compiler/cicache.jl")
include("compiler/types.jl")
Expand All @@ -153,14 +129,72 @@ include("compiler/inferencestate.jl")

include("compiler/typeutils.jl")
include("compiler/typelimits.jl")
include("compiler/typelattice.jl")
include("compiler/tfuncs.jl")
include("compiler/stmtinfo.jl")

include("compiler/abstractinterpretation.jl")
include("compiler/typeinfer.jl")
include("compiler/optimize.jl") # TODO: break this up further + extract utilities

# function show(io::IO, xs::Vector)
# print(io, eltype(xs), '[')
# show_itr(io, xs)
# print(io, ']')
# end
# function show(io::IO, xs::Tuple)
# print(io, '(')
# show_itr(io, xs)
# print(io, ')')
# end
# function show_itr(io::IO, xs)
# n = length(xs)
# for i in 1:n
# show(io, xs[i])
# i == n || print(io, ", ")
# end
# end
# function show(io::IO, typ′::LatticeElement)
# function name(x)
# if isLimitedAccuracy(typ′)
# return (nameof(x), '′',)
# else
# return (nameof(x),)
# end
# end
# typ = ignorelimited(typ′)
# if isConditional(typ)
# show(io, conditional(typ))
# elseif isConst(typ)
# print(io, name(Const)..., '(', constant(typ), ')')
# elseif isPartialStruct(typ)
# print(io, name(PartialStruct)..., '(', widenconst(typ), ", [")
# n = length(partialfields(typ))
# for i in 1:n
# show(io, partialfields(typ)[i])
# i == n || print(io, ", ")
# end
# print(io, "])")
# elseif isPartialTypeVar(typ)
# print(io, name(PartialTypeVar)..., '(')
# show(io, typ.partialtypevar.tv)
# print(io, ')')
# else
# print(io, name(NativeType)..., '(', widenconst(typ), ')')
# end
# end
# function show(io::IO, typ::ConditionalInfo)
# if typ === __NULL_CONDITIONAL__
# return print(io, "__NULL_CONDITIONAL__")
# end
# print(io, nameof(Conditional), '(')
# show(io, typ.var)
# print(io, ", ")
# show(io, typ.vtype)
# print(io, ", ")
# show(io, typ.elsetype)
# print(io, ')')
# end

include("compiler/bootstrap.jl")
ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)

Expand Down
51 changes: 26 additions & 25 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

@latticeop args function is_argtype_match(@nospecialize(given_argtype),
@nospecialize(cache_argtype),
function is_argtype_match(given_argtype::LatticeElement,
cache_argtype::LatticeElement,
overridden_by_const::Bool)
if is_forwardable_argtype(given_argtype)
return is_lattice_equal(given_argtype, cache_argtype)
Expand All @@ -10,10 +10,10 @@
end

function is_forwardable_argtype(@nospecialize x)
return isa(x, Const) ||
isa(x, Conditional) ||
isa(x, PartialStruct) ||
isa(x, PartialOpaque)
return isConst(x) ||
isConditional(x) ||
isPartialStruct(x) ||
isPartialOpaque(x)
end

# In theory, there could be a `cache` containing a matching `InferenceResult`
Expand All @@ -26,13 +26,13 @@ function matching_cache_argtypes(
@assert isa(linfo.def, Method) # ensure the next line works
nargs::Int = linfo.def.nargs
cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override)
given_argtypes = Vector{AbstractLattice}(undef, length(argtypes))
given_argtypes = Vector{LatticeElement}(undef, length(argtypes))
local condargs = nothing
for i in 1:length(argtypes)
argtype = argtypes[i]
# forward `Conditional` if it conveys a constraint on any other argument
if isa(argtype, Conditional) && fargs !== nothing
cnd = argtype
if isConditional(argtype) && fargs !== nothing
cnd = conditional(argtype)
slotid = find_constrained_arg(cnd, fargs, sv)
if slotid !== nothing
# using union-split signature, we may be able to narrow down `Conditional`
Expand All @@ -48,26 +48,26 @@ function matching_cache_argtypes(
condargs = Tuple{Int,Int}[]
end
push!(condargs, (slotid, i))
given_argtypes[i] = Conditional(SlotNumber(slotid), vtype, elsetype)
given_argtypes[i] = Conditional(slotid, vtype, elsetype)
end
continue
end
end
given_argtypes[i] = widenconditional(argtype)
end
isva = va_override || linfo.def.isva
if isva || isvarargtype(unwraptype(given_argtypes[end]))
isva_given_argtypes = Vector{Any}(undef, nargs)
if isva || isVararg(given_argtypes[end])
isva_given_argtypes = Vector{LatticeElement}(undef, nargs)
for i = 1:(nargs - isva)
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
end
if isva
if length(given_argtypes) < nargs && isvarargtype(unwraptype(given_argtypes[end]))
if length(given_argtypes) < nargs && isVararg(given_argtypes[end])
last = length(given_argtypes)
else
last = nargs
end
isva_given_argtypes[nargs] = TypeLattice(tuple_tfunc(anymap(unwraptype, given_argtypes[last:end])))
isva_given_argtypes[nargs] = LatticeElement(tuple_tfunc(given_argtypes[last:end]))
# invalidate `Conditional` imposed on varargs
if condargs !== nothing
for (slotid, i) in condargs
Expand Down Expand Up @@ -101,7 +101,7 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
# For opaque closure, the closure environment is processed elsewhere
nargs -= 1
end
cache_argtypes = Vector{AbstractLattice}(undef, nargs)
cache_argtypes = Vector{LatticeElement}(undef, nargs)
# First, if we're dealing with a varargs method, then we set the last element of `args`
# to the appropriate `Tuple` type or `PartialStruct` instance.
if !toplevel && isva
Expand All @@ -110,23 +110,24 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
linfo_argtypes = Any[Any for i = 1:nargs]
linfo_argtypes[end] = Vararg{Any}
end
vargtype = Tuple
vargtype = NativeType(Tuple)
else
linfo_argtypes_length = length(linfo_argtypes)
if nargs > linfo_argtypes_length
va = linfo_argtypes[linfo_argtypes_length]
if isvarargtype(va)
new_va = rewrap_unionall(unconstrain_vararg_length(va), specTypes)
vargtype = Tuple{new_va}
vargtype = NativeType(Tuple{new_va})
else
vargtype = Tuple{}
vargtype = NativeType(Tuple{})
end
else
vargtype_elements = Any[]
vargtype_elements = LatticeElement[]
for i in nargs:linfo_argtypes_length
p = linfo_argtypes[i]
p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p)
push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes)))
p = elim_free_typevars(rewrap_unionall(p, specTypes))
push!(vargtype_elements, isvarargtype(p) ? mkVararg(p) : NativeType(p))
end
for i in 1:length(vargtype_elements)
atyp = vargtype_elements[i]
Expand All @@ -140,7 +141,7 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
vargtype = tuple_tfunc(vargtype_elements)
end
end
cache_argtypes[nargs] = TypeLattice(vargtype)
cache_argtypes[nargs] = vargtype
nargs -= 1
end
# Now, we propagate type info from `linfo_argtypes` into `cache_argtypes`, improving some
Expand Down Expand Up @@ -168,10 +169,10 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
end
i == n && (lastatype = atyp)
cache_argtypes[i] = TypeLattice(atyp)
cache_argtypes[i] = LatticeElement(atyp)
end
for i = (tail_index + 1):nargs
cache_argtypes[i] = TypeLattice(lastatype)
cache_argtypes[i] = LatticeElement(lastatype)
end
else
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
Expand Down Expand Up @@ -199,7 +200,7 @@ function matching_cache_argtypes(linfo::MethodInstance, ::Nothing, va_override::
return cache_argtypes, falses(length(cache_argtypes))
end

function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{AbstractLattice}, cache::Vector{InferenceResult})
function cache_lookup(linfo::MethodInstance, given_argtypes::Argtypes, cache::Vector{InferenceResult})
method = linfo.def::Method
nargs::Int = method.nargs
method.isva && (nargs -= 1)
Expand All @@ -218,7 +219,7 @@ function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{AbstractLatt
end
end
if method.isva && cache_match
cache_match = is_argtype_match(TypeLattice(tuple_tfunc(anymap(unwraptype, given_argtypes[(nargs + 1):end]))),
cache_match = is_argtype_match(LatticeElement(tuple_tfunc(anymap(unwraptype, given_argtypes[(nargs + 1):end]))),
cache_argtypes[end],
cache_overridden_by_const[end])
end
Expand Down
42 changes: 13 additions & 29 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,12 @@

const LineNum = Int

# The type of a variable load is either a value or an UndefVarError
# (only used in abstractinterpret, doesn't appear in optimize)
struct VarState
typ::AbstractLattice
undef::Bool
@latticeop args VarState(@nospecialize(typ), undef::Bool) = new(typ, undef)
end

"""
const VarTable = Vector{VarState}
The extended lattice that maps local variables to inferred type represented as `AbstractLattice`.
Each index corresponds to the `id` of `SlotNumber` which identifies each local variable.
Note that `InferenceState` will maintain multiple `VarTable`s at each SSA statement
to enable flow-sensitive analysis.
"""
const VarTable = Vector{VarState}

mutable struct InferenceState
params::InferenceParams
result::InferenceResult # remember where to put the result
linfo::MethodInstance
sptypes::Vector{AbstractLattice} # types of static parameter
slottypes::Vector{AbstractLattice}
sptypes::Argtypes # types of static parameter
slottypes::Argtypes
mod::Module
currpc::LineNum
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
Expand All @@ -40,7 +22,7 @@ mutable struct InferenceState
stmt_edges::Vector{Union{Nothing, Vector{Any}}}
stmt_info::Vector{Any}
# return type
bestguess::AbstractLattice
bestguess::LatticeElement
# current active instruction pointers
ip::BitSet
pc´´::LineNum
Expand Down Expand Up @@ -79,7 +61,9 @@ mutable struct InferenceState
sp = sptypes_from_meth_instance(linfo::MethodInstance)

nssavalues = src.ssavaluetypes::Int
src.ssavaluetypes = AbstractLattice[ NOT_FOUND for i = 1:nssavalues ]
# NOTE we can't initialize `src.ssavaluetypes` as `Argtypes` to avoid
# an allocation within `ir_to_codeinf!(src)` where we widen all ssavaluetypes to native Julia types
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
stmt_info = Any[ nothing for i = 1:length(code) ]

n = length(code)
Expand All @@ -91,9 +75,9 @@ mutable struct InferenceState
argtypes = result.argtypes
nargs = length(argtypes)
s_argtypes = VarTable(undef, nslots)
slottypes = Vector{AbstractLattice}(undef, nslots)
slottypes = Vector{LatticeElement}(undef, nslots)
for i in 1:nslots
at = (i > nargs) ?: TypeLattice(argtypes[i])
at = (i > nargs) ?: LatticeElement(argtypes[i])
s_argtypes[i] = VarState(at, i > nargs)
slottypes[i] = at
end
Expand Down Expand Up @@ -316,9 +300,9 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
else
ty = Const(v)
end
sp[i] = TypeLattice(ty)
sp[i] = LatticeElement(ty)
end
return collect(AbstractLattice, sp)
return collect(LatticeElement, sp)
end

_topmod(sv::InferenceState) = _topmod(sv.mod)
Expand All @@ -332,9 +316,9 @@ end

update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds)

@latticeop args function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
ssavaluetypes = frame.src.ssavaluetypes::Vector{AbstractLattice}
old = ssavaluetypes[ssa_id]
function record_ssa_assign(ssa_id::Int, new::LatticeElement, frame::InferenceState)
ssavaluetypes = frame.src.ssavaluetypes::SSAValueTypes
old = ssavaluetypes[ssa_id]::SSAValueType
if old === NOT_FOUND || !(new old)
# typically, we expect that old ⊑ new (that output information only
# gets less precise with worse input information), but to actually
Expand Down
Loading

0 comments on commit 0b60f1a

Please sign in to comment.