# This file is a part of Julia. License is MIT: https://julialang.org/license

"""
This struct keeps track of all uses of some mutable struct allocated in the
current function. `uses` are all instances of `getfield` on the struct.
`defs` are all instances of `setfield!` on the struct. The terminology refers
to the uses/defs of the "slot bundle" that the mutable struct represents.

In addition we keep track of all instances of a `:foreigncall` preserve of this
mutable struct. Somewhat counterintuitively, we don't actually need to make sure
that the struct itself is live (or even allocated) at a ccall site. If there are
no other places where the struct escapes (and thus e.g. where its address is
taken), it need not be allocated. We do, however, need to make sure to preserve
any elements of this struct.
"""
struct SSADefUse
    uses::Vector{Int}
    defs::Vector{Int}
    ccall_preserve_uses::Vector{Int}
end
SSADefUse() = SSADefUse(Int[], Int[], Int[])

function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
    field = stmt.args[3]
    # fields are usually literals, handle them manually
    if isa(field, QuoteNode)
        field = field.value
    elseif isa(field, Int)
        # try to resolve other constants, e.g. global reference
    else
        field = compact_exprtype(compact, field)
        if isa(field, Const)
            field = field.val
        else
            return nothing
        end
    end
    isa(field, Union{Int, Symbol}) || return nothing
    return field
end

function try_compute_fieldidx_stmt(compact::IncrementalCompact, stmt::Expr, typ::DataType)
    field = try_compute_field_stmt(compact, stmt)
    return try_compute_fieldidx(typ, field)
end

function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int)
    # TODO: This can be much faster by looking at current level and only
    # searching for those blocks in a sorted order
    while !(curblock in allblocks)
        curblock = domtree.idoms_bb[curblock]
    end
    return curblock
end

function val_for_def_expr(ir::IRCode, def::Int, fidx::Int)
    ex = ir[SSAValue(def)]
    if isexpr(ex, :new)
        return ex.args[1+fidx]
    else
        @assert isa(ex, Expr)
        # The use is whatever the setfield was
        return ex.args[4]
    end
end

function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int)
    curblock = find_curblock(domtree, allblocks, curblock)
    def = 0
    for stmt in du.defs
        if block_for_inst(ir.cfg, stmt) == curblock
            def = max(def, stmt)
        end
    end
    def == 0 ? phinodes[curblock] : val_for_def_expr(ir, def, fidx)
end

function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use_idx::Int)
    def, stmtblock, curblock = find_def_for_use(ir, domtree, allblocks, du, use_idx)
    if def == 0
        if !haskey(phinodes, curblock)
            # If this happens, we need to search the predecessors for defs. Which
            # one doesn't matter - if it did, we'd have had a phinode
            return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[stmtblock].preds))
        end
        # The use is the phinode
        return phinodes[curblock]
    else
        return val_for_def_expr(ir, def, fidx)
    end
end
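# Note (explanatory, added for readability): the two helpers above resolve a field
# use of a tracked mutable struct to a concrete value. For a field allocated as
# `%1 = new(M, a)` and later mutated via `setfield!(%1, f, b)`, a use dominated by
# the `setfield!` resolves to `b` (the 4th argument of the `setfield!` expression,
# see `val_for_def_expr`), a use dominated only by the allocation resolves to `a`,
# and a use reachable from several defs reads from a `PhiNode` recorded in
# `phinodes`. The SSA numbers and names here are illustrative only.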
# find the first dominating def for the given use
function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, use_idx::Int)
    stmtblock = block_for_inst(ir.cfg, use_idx)
    curblock = find_curblock(domtree, allblocks, stmtblock)
    local def = 0
    for idx in du.defs
        if block_for_inst(ir.cfg, idx) == curblock
            if curblock != stmtblock
                # Find the last def in this block
                def = max(def, idx)
            else
                # Find the last def before our use
                def = max(def, idx >= use_idx ? 0 : idx)
            end
        end
    end
    return def, stmtblock, curblock
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
                     callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
    while true
        if isa(defssa, OldSSAValue)
            if already_inserted(compact, defssa)
                rename = compact.ssa_rename[defssa.id]
                if isa(rename, AnySSAValue)
                    defssa = rename
                    continue
                end
                return rename
            end
        end
        def = compact[defssa]
        if isa(def, PiNode)
            if callback(def, defssa)
                return defssa
            end
            def = def.val
            if isa(def, SSAValue)
                is_old(compact, defssa) && (def = OldSSAValue(def.id))
            else
                return def
            end
            defssa = def
        elseif isa(def, AnySSAValue)
            callback(def, defssa)
            if isa(def, SSAValue)
                is_old(compact, defssa) && (def = OldSSAValue(def.id))
            end
            defssa = def
        elseif isa(def, Union{PhiNode, PhiCNode, Expr, GlobalRef})
            return defssa
        else
            return def
        end
    end
end

function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
                                @nospecialize(typeconstraint) = types(compact)[defssa])
    callback = function (@nospecialize(pi), @nospecialize(idx))
        if isa(pi, PiNode)
            typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
        end
        return false
    end
    def = simple_walk(compact, defssa, callback)
    return Pair{Any, Any}(def, typeconstraint)
end

"""
    walk_to_defs(compact, defssa, typeconstraint, visited_phinodes)

Starting at `defssa`, walk use-def chains to collect all the leaves feeding into
this value (pruning those leaves ruled out by path conditions).
"""
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint), visited_phinodes::Vector{Any}=Any[])
    isa(defssa, AnySSAValue) || return Any[defssa]
    def = compact[defssa]
    isa(def, PhiNode) || return Any[defssa]
    # Step 2: Figure out what the struct is defined as
    ## Track definitions through PiNode/PhiNode
    found_def = false
    ## Track which PhiNodes, SSAValue intermediaries
    ## we forwarded through.
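    # (descriptive note) The search below is a worklist traversal of the phi
    # graph: `visited` maps every def we have already expanded to the type
    # constraint it was expanded under, `worklist_defs`/`worklist_constraints`
    # are parallel stacks of defs still to visit, and `leaves` collects the
    # non-phi definitions the walk bottoms out at.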
    visited = IdDict{Any, Any}()
    worklist_defs = Any[]
    worklist_constraints = Any[]
    leaves = Any[]
    push!(worklist_defs, defssa)
    push!(worklist_constraints, typeconstraint)
    while !isempty(worklist_defs)
        defssa = pop!(worklist_defs)
        typeconstraint = pop!(worklist_constraints)
        visited[defssa] = typeconstraint
        def = compact[defssa]
        if isa(def, PhiNode)
            push!(visited_phinodes, defssa)
            possible_predecessors = Int[]
            for n in 1:length(def.edges)
                isassigned(def.values, n) || continue
                val = def.values[n]
                if is_old(compact, defssa) && isa(val, SSAValue)
                    val = OldSSAValue(val.id)
                end
                edge_typ = widenconst(compact_exprtype(compact, val))
                hasintersect(edge_typ, typeconstraint) || continue
                push!(possible_predecessors, n)
            end
            for n in possible_predecessors
                pred = def.edges[n]
                val = def.values[n]
                if is_old(compact, defssa) && isa(val, SSAValue)
                    val = OldSSAValue(val.id)
                end
                if isa(val, AnySSAValue)
                    new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint)
                    if isa(new_def, AnySSAValue)
                        if !haskey(visited, new_def)
                            push!(worklist_defs, new_def)
                            push!(worklist_constraints, new_constraint)
                        elseif !(new_constraint <: visited[new_def])
                            # We have reached the same definition via a different
                            # path, with a different type constraint. We may have
                            # to redo some work here with the wider typeconstraint
                            push!(worklist_defs, new_def)
                            push!(worklist_constraints, tmerge(new_constraint, visited[new_def]))
                        end
                        continue
                    end
                    val = new_def
                end
                if def === val
                    # This shouldn't really ever happen, but
                    # patterns like this can occur in dead code,
                    # so bail out.
                    break
                else
                    push!(leaves, val)
                end
                continue
            end
        else
            push!(leaves, defssa)
        end
    end
    leaves
end
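# (illustrative sketch only) For IR like
#   %1 = new(Some{Int}, 1)
#   %2 = new(Some{Int}, 2)
#   %3 = φ(%1, %2)
# `walk_to_defs(compact, %3, Some{Int})` returns the SSA values of the two `:new`
# allocations as leaves and records %3 in `visited_phinodes`, which is the shape
# consumed by `lift_leaves`/`perform_lifting!` below. The SSA numbers are made up
# for illustration.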
function process_immutable_preserve(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr)
    for arg in (isexpr(def, :new) ? def.args : def.args[2:end])
        if !isbitstype(widenconst(compact_exprtype(compact, arg)))
            push!(new_preserves, arg)
        end
    end
end

function already_inserted(compact::IncrementalCompact, old::OldSSAValue)
    id = old.id
    if id < length(compact.ir.stmts)
        return id < compact.idx
    end
    id -= length(compact.ir.stmts)
    if id < length(compact.ir.new_nodes)
        error("")
    end
    id -= length(compact.ir.new_nodes)
    @assert id <= length(compact.pending_nodes)
    return !(id in compact.pending_perm)
end

function is_pending(compact::IncrementalCompact, old::OldSSAValue)
    return old.id > length(compact.ir.stmts) + length(compact.ir.new_nodes)
end

function is_getfield_captures(@nospecialize(def), compact::IncrementalCompact)
    isa(def, Expr) || return false
    length(def.args) >= 3 || return false
    is_known_call(def, getfield, compact) || return false
    which = compact_exprtype(compact, def.args[3])
    isa(which, Const) || return false
    which.val === :captures || return false
    oc = compact_exprtype(compact, def.args[2])
    return oc ⊑ Core.OpaqueClosure
end

# try to compute lifted values that can replace a `getfield(x, field)` call,
# where `x` is an immutable struct that is defined at any of `leaves`
function lift_leaves(compact::IncrementalCompact, @nospecialize(result_t), field::Int, leaves::Vector{Any})
    # For every leaf, the lifted value
    lifted_leaves = IdDict{Any, Any}()
    maybe_undef = false
    for leaf in leaves
        leaf_key = leaf
        if isa(leaf, AnySSAValue)
            function lift_arg(ref::Core.Compiler.UseRef)
                lifted = ref[]
                if is_old(compact, leaf) && isa(lifted, SSAValue)
                    lifted = OldSSAValue(lifted.id)
                end
                if isa(lifted, GlobalRef) || isa(lifted, Expr)
                    lifted = insert_node!(compact, leaf, effect_free(NewInstruction(lifted, compact_exprtype(compact, lifted))))
                    ref[] = lifted
                    (isa(leaf, SSAValue) && (leaf.id < compact.result_idx)) && push!(compact.late_fixup, leaf.id)
                end
                lifted_leaves[leaf_key] = RefValue{Any}(lifted)
                nothing
            end
            function walk_leaf(@nospecialize(leaf))
                if isa(leaf, OldSSAValue) && already_inserted(compact, leaf)
                    leaf = compact.ssa_rename[leaf.id]
                    if isa(leaf, AnySSAValue)
                        leaf = simple_walk(compact, leaf)
                    end
                    if isa(leaf, AnySSAValue)
                        def = compact[leaf]
                    else
                        def = leaf
                    end
                elseif isa(leaf, AnySSAValue)
                    def = compact[leaf]
                else
                    def = leaf
                end
                return Pair{Any, Any}(def, leaf)
            end
            (def, leaf) = walk_leaf(leaf)
            if is_tuple_call(compact, def) && 1 <= field < length(def.args)
                lift_arg(UseRef(def, 1 + field))
                continue
            elseif isexpr(def, :new)
                typ = widenconst(types(compact)[leaf])
                if isa(typ, UnionAll)
                    typ = unwrap_unionall(typ)
                end
                (isa(typ, DataType) && !isabstracttype(typ)) || return nothing
                @assert !ismutabletype(typ)
                if length(def.args) < 1 + field
                    if field > fieldcount(typ)
                        return nothing
                    end
                    ftyp = fieldtype(typ, field)
                    if !isbitstype(ftyp)
                        # On this branch, this will be a guaranteed UndefRefError.
                        # We use the regular undef mechanic to lift this to a boolean slot
                        maybe_undef = true
                        lifted_leaves[leaf_key] = nothing
                        continue
                    end
                    return nothing
                    # Expand the Expr(:new) to include its element Expr(:new) nodes up until the one we want
                    compact[leaf] = nothing
                    for i = (length(def.args) + 1):(1+field)
                        ftyp = fieldtype(typ, i - 1)
                        isbitstype(ftyp) || return nothing
                        ninst = effect_free(NewInstruction(Expr(:new, ftyp), result_t))
                        push!(def.args, insert_node!(compact, leaf, ninst))
                    end
                    compact[leaf] = def
                end
                lifted = def.args[1+field]
                if is_old(compact, leaf) && isa(lifted, SSAValue)
                    lifted = OldSSAValue(lifted.id)
                end
                if isa(lifted, GlobalRef) || isa(lifted, Expr)
                    lifted = insert_node!(compact, leaf, effect_free(NewInstruction(lifted, compact_exprtype(compact, lifted))))
                    def.args[1+field] = lifted
                    (isa(leaf, SSAValue) && (leaf.id < compact.result_idx)) && push!(compact.late_fixup, leaf.id)
                end
                lifted_leaves[leaf_key] = RefValue{Any}(lifted)
                continue
            elseif is_getfield_captures(def, compact)
                # Walk to new_opaque_closure
                ocleaf = def.args[2]
                if isa(ocleaf, AnySSAValue)
                    ocleaf = simple_walk(compact, ocleaf)
                end
                ocdef, _ = walk_leaf(ocleaf)
                if isexpr(ocdef, :new_opaque_closure) && isa(field, Int) && 1 <= field <= length(ocdef.args)-5
                    lift_arg(UseRef(ocdef, 5 + field))
                    continue
                end
                return nothing
            else
                typ = compact_exprtype(compact, leaf)
                if !isa(typ, Const)
                    # TODO: (disabled since #27126)
                    # If the leaf is an old ssa value, insert a getfield here
                    # We will revisit this getfield later when compaction gets
                    # to the appropriate point.
                    # N.B.: This can be a bit dangerous because it can lead to
                    # infinite loops if we accidentally insert a node just ahead
                    # of where we are
                    return nothing
                end
                leaf = typ.val
                # Fall through to below
            end
        elseif isa(leaf, QuoteNode)
            leaf = leaf.value
        elseif isa(leaf, GlobalRef)
            mod, name = leaf.mod, leaf.name
            if isdefined(mod, name) && isconst(mod, name)
                leaf = getfield(mod, name)
            else
                return nothing
            end
        elseif isa(leaf, Union{Argument, Expr})
            return nothing
        end
        ismutable(leaf) && return nothing
        isdefined(leaf, field) || return nothing
        val = getfield(leaf, field)
        is_inlineable_constant(val) || return nothing
        lifted_leaves[leaf_key] = RefValue{Any}(quoted(val))
    end
    return lifted_leaves, maybe_undef
end

make_MaybeUndef(@nospecialize(typ)) = isa(typ, MaybeUndef) ? typ : MaybeUndef(typ)

function lift_comparison!(compact::IncrementalCompact, idx::Int,
                          @nospecialize(c1), @nospecialize(c2), stmt::Expr,
                          lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue})
    if isa(c1, Const)
        cmp = c1
        typeconstraint = widenconst(c2)
        val = stmt.args[3]
    else
        cmp = c2::Const
        typeconstraint = widenconst(c1)
        val = stmt.args[2]
    end

    if isa(val, Union{OldSSAValue, SSAValue})
        val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
    end

    visited_phinodes = Any[]
    leaves = walk_to_defs(compact, val, typeconstraint, visited_phinodes)

    # Let's check if we can evaluate the comparison for each one of the leaves
    lifted_leaves = IdDict{Any, Any}()
    for leaf in leaves
        r = egal_tfunc(compact_exprtype(compact, leaf), cmp)
        if isa(r, Const)
            lifted_leaves[leaf] = RefValue{Any}(r.val)
        else
            # TODO: In some cases it might be profitable to hoist the ===
            # here.
            return
        end
    end

    lifted_val = perform_lifting!(compact, visited_phinodes, cmp, lifting_cache, Bool, lifted_leaves, val)
    @assert lifted_val !== nothing

    #global assertion_counter
    #assertion_counter::Int += 1
    #insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true)
    #return
    compact[idx] = lifted_val.x
end

struct LiftedPhi
    ssa::AnySSAValue
    node::PhiNode
    need_argupdate::Bool
end

function is_old(compact, @nospecialize(old_node_ssa))
    isa(old_node_ssa, OldSSAValue) &&
        !is_pending(compact, old_node_ssa) &&
        !already_inserted(compact, old_node_ssa)
end

function perform_lifting!(compact::IncrementalCompact,
                          visited_phinodes::Vector{Any}, @nospecialize(cache_key),
                          lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue},
                          @nospecialize(result_t), lifted_leaves::IdDict{Any, Any}, @nospecialize(stmt_val))
    reverse_mapping = IdDict{Any, Any}(ssa => id for (id, ssa) in enumerate(visited_phinodes))

    # Insert PhiNodes
    lifted_phis = LiftedPhi[]
    for item in visited_phinodes
        if (item, cache_key) in keys(lifting_cache)
            ssa = lifting_cache[Pair{AnySSAValue, Any}(item, cache_key)]
            push!(lifted_phis, LiftedPhi(ssa, compact[ssa]::PhiNode, false))
            continue
        end
        n = PhiNode()
        ssa = insert_node!(compact, item, effect_free(NewInstruction(n, result_t)))
        lifting_cache[Pair{AnySSAValue, Any}(item, cache_key)] = ssa
        push!(lifted_phis, LiftedPhi(ssa, n, true))
    end

    # Fix up arguments
    for (old_node_ssa, lf) in zip(visited_phinodes, lifted_phis)
        old_node = compact[old_node_ssa]::PhiNode
        new_node = lf.node
        lf.need_argupdate || continue
        for i = 1:length(old_node.edges)
            edge = old_node.edges[i]
            isassigned(old_node.values, i) || continue
            val = old_node.values[i]
            orig_val = val
            if is_old(compact, old_node_ssa) && isa(val, SSAValue)
                val = OldSSAValue(val.id)
            end
            if isa(val, AnySSAValue)
                val = simple_walk(compact, val)
            end
            if val in keys(lifted_leaves)
                push!(new_node.edges, edge)
                lifted_val = lifted_leaves[val]
                if lifted_val === nothing
                    resize!(new_node.values, length(new_node.values)+1)
                    continue
                end
                lifted_val = lifted_val.x
                if isa(lifted_val, AnySSAValue)
                    callback = (@nospecialize(pi), @nospecialize(idx)) -> true
                    lifted_val = simple_walk(compact, lifted_val, callback)
                end
                push!(new_node.values, lifted_val)
            elseif isa(val, AnySSAValue) && val in keys(reverse_mapping)
                push!(new_node.edges, edge)
                push!(new_node.values, lifted_phis[reverse_mapping[val]].ssa)
            else
                # Probably ignored by path condition, skip this
            end
        end
    end

    for lf in lifted_phis
        count_added_node!(compact, lf.node)
    end

    # Fixup the stmt itself
    if isa(stmt_val, Union{SSAValue, OldSSAValue})
        stmt_val = simple_walk(compact, stmt_val)
    end
    if stmt_val in keys(lifted_leaves)
        stmt_val = lifted_leaves[stmt_val]
    elseif isa(stmt_val, AnySSAValue) && stmt_val in keys(reverse_mapping)
        stmt_val = RefValue{Any}(lifted_phis[reverse_mapping[stmt_val]].ssa)
    end

    return stmt_val
end
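# (illustrative sketch, not executed code) Taken together, `walk_to_defs`,
# `lift_leaves` and `perform_lifting!` rewrite field loads through phis of
# immutable allocations, roughly:
#
#   %1 = new(Pair{Int,Int}, a, b)
#   %2 = new(Pair{Int,Int}, c, d)
#   %3 = φ(%1, %2)
#   %4 = getfield(%3, :first)
#
# becomes, with a new φ of the field values inserted next to %3,
#
#   %5 = φ(a, c)
#   ... uses of %4 now refer to %5 ...
#
# The names `a`/`b`/`c`/`d` and the SSA numbers are made up for illustration.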
"""
    sroa_pass!(ir::IRCode) -> newir::IRCode

`getfield` elimination pass, a.k.a. Scalar Replacements of Aggregates optimization.

This pass is based on a local alias analysis that collects field information by
def-use chain walking. It looks for struct allocation sites ("definitions") and
`getfield` calls as well as `:foreigncall`s that preserve the structs ("usages").
If "definitions" have enough information, then this pass will replace the
corresponding usages with lifted values.

`mutable struct`s require additional care and need to be handled separately from
immutables. For `mutable struct`s, `setfield!` calls also count as "definitions",
and the pass conservatively gives up on lifting when there are any "intermediate
usages" through which the mutable struct may escape (e.g. a non-inlined generic
function call that takes the mutable struct as its argument). When all usages are
fully eliminated, the `struct` allocation may also be erased as a result of dead
code elimination.
"""
function sroa_pass!(ir::IRCode)
    compact = IncrementalCompact(ir)
    defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}()
    lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
    for ((_, idx), stmt) in compact
        isa(stmt, Expr) || continue
        result_t = compact_exprtype(compact, SSAValue(idx))
        is_getfield = is_setfield = false
        field_ordering = :unspecified
        # Step 1: Check whether the statement we're looking at is a getfield/setfield!
        if is_known_call(stmt, setfield!, compact)
            is_setfield = true
            4 <= length(stmt.args) <= 5 || continue
            if length(stmt.args) == 5
                field_ordering = compact_exprtype(compact, stmt.args[5])
            end
        elseif is_known_call(stmt, getfield, compact)
            is_getfield = true
            3 <= length(stmt.args) <= 5 || continue
            if length(stmt.args) == 5
                field_ordering = compact_exprtype(compact, stmt.args[5])
            elseif length(stmt.args) == 4
                field_ordering = compact_exprtype(compact, stmt.args[4])
                widenconst(field_ordering) === Bool && (field_ordering = :unspecified)
            end
        elseif is_known_call(stmt, isa, compact)
            # TODO
            continue
        elseif is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3
            # Canonicalize
            #   X = typeassert(Y, T)::S
            # into
            #   typeassert(Y, T)
            #   X = PiNode(Y, S)
            # N.B.: Inference may have a more precise type for `S` than just `T`,
            #       but from here on out, there's no problem with just using that,
            #       so subsequent analysis only has to deal with the latter form.
            # TODO: This isn't the best place to put this.
            # Also, we should probably have a version of typeassert
            # that's defined not to return its value to make life easier
            # for the backend.
            pi = insert_node_here!(compact,
                NewInstruction(
                    PiNode(stmt.args[2], compact.result[idx][:type]),
                    compact.result[idx][:type],
                    compact.result[idx][:line]), true)
            compact.ssa_rename[compact.idx-1] = pi
            continue
        elseif is_known_call(stmt, (===), compact) && length(stmt.args) == 3
            c1 = compact_exprtype(compact, stmt.args[2])
            c2 = compact_exprtype(compact, stmt.args[3])
            if !(isa(c1, Const) || isa(c2, Const))
                continue
            end
            (isa(c1, Const) && isa(c2, Const)) && continue
            lift_comparison!(compact, idx, c1, c2, stmt, lifting_cache)
            continue
        elseif isexpr(stmt, :foreigncall)
            nccallargs = length(stmt.args[3]::SimpleVector)
            new_preserves = Any[]
            old_preserves = stmt.args[(6+nccallargs):end]
            for (pidx, preserved_arg) in enumerate(old_preserves)
                isa(preserved_arg, SSAValue) || continue
                let intermediaries = IdSet{Int}()
                    callback = function (@nospecialize(pi), @nospecialize(ssa))
                        push!(intermediaries, ssa.id)
                        return false
                    end
                    def = simple_walk(compact, preserved_arg, callback)
                    isa(def, SSAValue) || continue
                    defidx = def.id
                    def = compact[defidx]
                    if is_tuple_call(compact, def)
                        process_immutable_preserve(new_preserves, compact, def)
                        old_preserves[pidx] = nothing
                        continue
                    elseif isexpr(def, :new)
                        typ = widenconst(compact_exprtype(compact, SSAValue(defidx)))
                        if isa(typ, UnionAll)
                            typ = unwrap_unionall(typ)
                        end
                        if typ isa DataType && !ismutabletype(typ)
                            process_immutable_preserve(new_preserves, compact, def)
                            old_preserves[pidx] = nothing
                            continue
                        end
                    else
                        continue
                    end
                    mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse()))
                    push!(defuse.ccall_preserve_uses, idx)
                    union!(mid, intermediaries)
                end
                continue
            end
            if !isempty(new_preserves)
                old_preserves = filter(ssa->ssa !== nothing, old_preserves)
                new_expr = Expr(:foreigncall, stmt.args[1:(6+nccallargs-1)]...,
                    old_preserves..., new_preserves...)
                compact[idx] = new_expr
            end
            continue
        else
            continue
        end

        field = try_compute_field_stmt(compact, stmt)
        field === nothing && continue

        struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, stmt.args[2])))
        if isa(struct_typ, Union) && struct_typ <: Tuple
            struct_typ = unswitchtupleunion(struct_typ)
        end
        isa(struct_typ, DataType) || continue

        struct_typ.name.atomicfields == C_NULL || continue # TODO: handle more
        if !(field_ordering === :unspecified ||
             (field_ordering isa Const && field_ordering.val === :not_atomic))
            continue
        end

        def, typeconstraint = stmt.args[2], struct_typ

        if ismutabletype(struct_typ)
            isa(def, SSAValue) || continue
            let intermediaries = IdSet{Int}()
                callback = function (@nospecialize(pi), @nospecialize(ssa))
                    push!(intermediaries, ssa.id)
                    return false
                end
                def = simple_walk(compact, def, callback)
                # Mutable stuff here
                isa(def, SSAValue) || continue
                mid, defuse = get!(defuses, def.id, (IdSet{Int}(), SSADefUse()))
                if is_setfield
                    push!(defuse.defs, idx)
                else
                    push!(defuse.uses, idx)
                end
                union!(mid, intermediaries)
            end
            continue
        elseif is_setfield
            continue
        end

        # perform SROA on immutable structs from here on
        if isa(def, Union{OldSSAValue, SSAValue})
            def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint)
        end

        visited_phinodes = Any[]
        leaves = walk_to_defs(compact, def, typeconstraint, visited_phinodes)
        isempty(leaves) && continue

        field = try_compute_fieldidx(struct_typ, field)
        field === nothing && continue

        r = lift_leaves(compact, result_t, field, leaves)
        r === nothing && continue
        lifted_leaves, any_undef = r

        if any_undef
            result_t = make_MaybeUndef(result_t)
        end

        val = perform_lifting!(compact, visited_phinodes, field, lifting_cache, result_t, lifted_leaves, stmt.args[2])

        # Insert the undef check if necessary
        if any_undef
            if val === nothing
                insert_node!(compact, SSAValue(idx),
                    non_effect_free(NewInstruction(Expr(:throw_undef_if_not, Symbol("##getfield##"), false), Nothing)))
            else
                # val must be defined
            end
        else
            @assert val !== nothing
        end

        #global assertion_counter
        #assertion_counter::Int += 1
        #insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
        #continue
        compact[idx] = val === nothing ? nothing : val.x
    end

    non_dce_finish!(compact)
    # Copy the use count, `simple_dce!` may modify it and for our predicate
    # below we need it consistent with the state of the IR here (after tracking
    # phi node arguments, but before dce).
    used_ssas = copy(compact.used_ssas)
    simple_dce!(compact)
    ir = complete(compact)
    # Compute domtree, needed below, now that we have finished compacting the IR.
    # This needs to be after we iterate through the IR with `IncrementalCompact`
    # because removing dead blocks can invalidate the domtree.
    @timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)
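    # (explanatory note) The loop below decides, for each tracked mutable
    # allocation, whether every one of its uses was recorded in its `SSADefUse`.
    # As a hypothetical example, an allocation with two `getfield`s, one
    # `setfield!` and one ccall preserve has nleaves == 4; unless the use counts
    # of the allocation and its intermediaries add up to the same nuses_total,
    # some use escaped the analysis and the allocation must be kept.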
    # Now go through any mutable structs and see which ones we can eliminate
    for (idx, (intermediaries, defuse)) in defuses
        intermediaries = collect(intermediaries)
        # Check if there are any uses we did not account for. If so, the variable
        # escapes and we cannot eliminate the allocation. This works, because we're
        # guaranteed not to include any intermediaries that have dead uses. As a
        # result, missing uses will only ever show up in the nuses_total count.
        nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses)
        nuses = 0
        for idx in intermediaries
            nuses += used_ssas[idx]
        end
        nuses_total = used_ssas[idx] + nuses - length(intermediaries)
        nleaves == nuses_total || continue
        # Find the type for this allocation
        defexpr = ir[SSAValue(idx)]
        isexpr(defexpr, :new) || continue
        typ = ir.stmts[idx][:type]
        if isa(typ, UnionAll)
            typ = unwrap_unionall(typ)
        end
        # Could still end up here if we tried to setfield! on an immutable, which
        # would error at runtime, but is not illegal to have in the IR.
        ismutabletype(typ) || continue
        typ = typ::DataType
        # Partition defuses by field
        fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
        for use in defuse.uses
            stmt = ir[SSAValue(use)]
            # We may have discovered above that this use is dead
            # after the getfield elim of immutables. In that case,
            # it would have been deleted. That's fine, just ignore
            # the use in that case.
            stmt === nothing && continue
            field = try_compute_fieldidx_stmt(compact, stmt::Expr, typ)
            field === nothing && @goto skip
            push!(fielddefuse[field].uses, use)
        end
        for use in defuse.defs
            field = try_compute_fieldidx_stmt(compact, ir[SSAValue(use)]::Expr, typ)
            field === nothing && @goto skip
            push!(fielddefuse[field].defs, use)
        end
        # Check that the defexpr has defined values for all the fields we're
        # accessing. In the future, we may want to relax this, but we should
        # come up with well-defined semantics for uninitialized fields first.
        ndefuse = length(fielddefuse)
        blocks = Vector{Tuple{#=phiblocks=# Vector{Int}, #=allblocks=# Vector{Int}}}(undef, ndefuse)
        for fidx in 1:ndefuse
            du = fielddefuse[fidx]
            isempty(du.uses) && continue
            push!(du.defs, idx)
            ldu = compute_live_ins(ir.cfg, du)
            phiblocks = Int[]
            if !isempty(ldu.live_in_bbs)
                phiblocks = idf(ir.cfg, ldu, domtree)
            end
            allblocks = sort(vcat(phiblocks, ldu.def_bbs))
            blocks[fidx] = phiblocks, allblocks
            if fidx + 1 > length(defexpr.args)
                # even if the allocation contains an uninitialized field, we make an
                # extra effort to check whether all uses are dominated by some "solid"
                # `setfield!` call that defines the field, although we give up in the
                # cases below:
                # `def == idx`: this field can only be defined at the allocation site
                #               (thus this case will throw)
                # `def == 0`: this field comes from a `PhiNode`
                # we may be able to traverse control flow of PhiNode values, but it
                # sounds more complicated than beneficial under the current implementation
                for use in du.uses
                    def = find_def_for_use(ir, domtree, allblocks, du, use)[1]
                    (def == 0 || def == idx) && @goto skip
                end
            end
        end
        preserve_uses = IdDict{Int, Vector{Any}}((idx=>Any[] for idx in IdSet{Int}(defuse.ccall_preserve_uses)))
        # Everything accounted for. Go field by field and perform idf
        for fidx in 1:ndefuse
            du = fielddefuse[fidx]
            ftyp = fieldtype(typ, fidx)
            if !isempty(du.uses)
                phiblocks, allblocks = blocks[fidx]
                phinodes = IdDict{Int, SSAValue}()
                for b in phiblocks
                    n = PhiNode()
                    phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
                        NewInstruction(n, ftyp))
                end
                # Now go through all uses and rewrite them
                for stmt in du.uses
                    ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
                end
                if !isbitstype(ftyp)
                    for (use, list) in preserve_uses
                        push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
                    end
                end
                for b in phiblocks
                    for p in ir.cfg.blocks[b].preds
                        n = ir[phinodes[b]]::PhiNode
                        push!(n.edges, p)
                        push!(n.values, compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, p))
                    end
                end
            end
            for stmt in du.defs
                stmt == idx && continue
                ir[SSAValue(stmt)] = nothing
            end
        end
        isempty(defuse.ccall_preserve_uses) && continue
        push!(intermediaries, idx)
        # Insert the new preserves
        for (use, new_preserves) in preserve_uses
            useexpr = ir[SSAValue(use)]::Expr
            nccallargs = length(useexpr.args[3]::SimpleVector)
            old_preserves = let intermediaries = intermediaries
                filter(ssa->!isa(ssa, SSAValue) || !(ssa.id in intermediaries), useexpr.args[(6+nccallargs):end])
            end
            new_expr = Expr(:foreigncall, useexpr.args[1:(6+nccallargs-1)]...,
                old_preserves..., new_preserves...)
            ir[SSAValue(use)] = new_expr
        end

        @label skip
    end

    return ir
end

# assertion_counter = 0

function adce_erase!(phi_uses::Vector{Int}, extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int)
    # return whether this made a change
    if isa(compact.result[idx][:inst], PhiNode)
        return maybe_erase_unused!(extra_worklist, compact, idx, val::SSAValue -> phi_uses[val.id] -= 1)
    else
        return maybe_erase_unused!(extra_worklist, compact, idx)
    end
end

function count_uses(@nospecialize(stmt), uses::Vector{Int})
    for ur in userefs(stmt)
        use = ur[]
        if isa(use, SSAValue)
            uses[use.id] += 1
        end
    end
end

function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::BitSet, phi::Int)
    worklist = Int[]
    push!(worklist, phi)
    while !isempty(worklist)
        phi = pop!(worklist)
        push!(safe_phis, phi)
        for ur in userefs(compact.result[phi][:inst])
            val = ur[]
            isa(val, SSAValue) || continue
            isa(compact[val], PhiNode) || continue
            (val.id in safe_phis) && continue
            push!(worklist, val.id)
        end
    end
end

"""
    adce_pass!(ir::IRCode) -> newir::IRCode

Aggressive Dead Code Elimination pass.

In addition to a simple DCE for unused values and allocations,
this pass also nullifies `typeassert` calls that can be proved to be no-op,
in order to allow LLVM to emit simpler code down the road.

Note that this pass is more effective after SROA optimization (i.e. `sroa_pass!`),
since SROA often allows this pass to:
- eliminate allocation of objects whose field references are all replaced with scalar values, and
- nullify `typeassert` calls whose first operand has been replaced with a scalar value
  (which may have introduced new type information that inference did not understand)

Also note that currently this pass _needs_ to run after `sroa_pass!`, because
the `typeassert` elimination depends on the transformation within `sroa_pass!`
which redirects references of a `typeassert`ed value to the corresponding `PiNode`.
"""
function adce_pass!(ir::IRCode)
    phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes))
    all_phis = Int[]
    compact = IncrementalCompact(ir)
    for ((_, idx), stmt) in compact
        if isa(stmt, PhiNode)
            push!(all_phis, idx)
        elseif isexpr(stmt, :call)
            # nullify safe `typeassert` calls
            if is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3
                ty, isexact = instanceof_tfunc(compact_exprtype(compact, stmt.args[3]))
                if isexact && compact_exprtype(compact, stmt.args[2]) ⊑ ty
                    compact[idx] = nothing
                end
            end
        end
    end
    non_dce_finish!(compact)
    for phi in all_phis
        count_uses(compact.result[phi][:inst]::PhiNode, phi_uses)
    end
    # Perform simple DCE for unused values
    extra_worklist = Int[]
    for (idx, nused) in Iterators.enumerate(compact.used_ssas)
        idx >= compact.result_idx && break
        nused == 0 || continue
        adce_erase!(phi_uses, extra_worklist, compact, idx)
    end
    while !isempty(extra_worklist)
        adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist))
    end
    # Go back and erase any phi cycles
    changed = true
    while changed
        changed = false
        safe_phis = BitSet()
        for phi in all_phis
            # Save any phi cycles that have non-phi uses
            if compact.used_ssas[phi] - phi_uses[phi] != 0
                mark_phi_cycles!(compact, safe_phis, phi)
            end
        end
        for phi in all_phis
            if !(phi in safe_phis)
                push!(extra_worklist, phi)
            end
        end
        while !isempty(extra_worklist)
            if adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist))
                changed = true
            end
        end
    end
    return complete(compact)
end
function type_lift_pass!(ir::IRCode)
    lifted_undef = IdDict{Int, Any}()
    insts = ir.stmts
    for idx in 1:length(insts)
        stmt = insts[idx][:inst]
        stmt isa Expr || continue
        if (stmt.head === :isdefined || stmt.head === :undefcheck)
            # after optimization, undef can only show up by being introduced in
            # a phi node (or an UpsilonNode() argument to a PhiC node), so lift
            # all these nodes that have maybe undef values
            val = stmt.args[(stmt.head === :isdefined) ? 1 : 2]
            if stmt.head === :isdefined && (val isa Slot || val isa GlobalRef ||
                    isexpr(val, :static_parameter) || val isa Argument || val isa Symbol)
                # this is a legal node, so assume it was not introduced by
                # slot2ssa (at worst, we might leave in a runtime check that
                # shouldn't have been there)
                continue
            end
            # otherwise, we definitely have a corrupt node from slot2ssa, and
            # must fix or delete that now
            processed = IdDict{Int, Union{SSAValue, Bool}}()
            def = val
            while true
                # peek through PiNodes
                isa(val, SSAValue) || break
                def = insts[val.id][:inst]
                isa(def, PiNode) || break
                val = def.val
            end
            if !isa(val, SSAValue) || (!isa(def, PhiNode) && !isa(def, PhiCNode))
                # in most cases, reaching this statement implies we had a value
                if stmt.head === :undefcheck
                    insts[idx][:inst] = nothing
                else
                    insts[idx][:inst] = true
                end
                continue
            end
            stmt_id = val.id
            worklist = Tuple{Int, Int, SSAValue, Int}[(stmt_id, 0, SSAValue(0), 0)]
            if !haskey(lifted_undef, stmt_id)
                first = true
                while !isempty(worklist)
                    item, w_up_id, which, use = pop!(worklist)
                    def = insts[item][:inst]
                    if isa(def, PhiNode)
                        edges = copy(def.edges)
                        values = Vector{Any}(undef, length(edges))
                        new_phi = if length(values) == 0
                            false
                        else
                            insert_node!(ir, item, NewInstruction(PhiNode(edges, values), Bool))
                        end
                    else
                        def = def::PhiCNode
                        values = Vector{Any}(undef, length(def.values))
                        new_phi = if length(values) == 0
                            false
                        else
                            insert_node!(ir, item, NewInstruction(PhiCNode(values), Bool))
                        end
                    end
                    processed[item] = new_phi
                    if first
                        lifted_undef[stmt_id] = new_phi
                        first = false
                    end
                    local id::Int = 0
                    for i = 1:length(values)
                        if !isassigned(def.values, i)
                            val = false
                        elseif !isa(def.values[i], SSAValue)
                            val = true
                        else
                            up_id = id = (def.values[i]::SSAValue).id
                            @label restart
                            if !isa(ir.stmts[id][:type], MaybeUndef)
                                val = true
                            else
                                node = insts[id][:inst]
                                if isa(node, UpsilonNode)
                                    if !isdefined(node, :val)
                                        val = false
                                    elseif !isa(node.val, SSAValue)
                                        val = true
                                    else
                                        id = (node.val::SSAValue).id
                                        @goto restart
                                    end
                                else
                                    while isa(node, PiNode)
                                        id = node.val.id
                                        node = insts[id][:inst]
                                    end
                                    if isa(node, Union{PhiNode, PhiCNode})
                                        if haskey(processed, id)
                                            val = processed[id]
                                        else
                                            push!(worklist, (id, up_id, new_phi::SSAValue, i))
                                            continue
                                        end
                                    else
                                        val = true
                                    end
                                end
                            end
                        end
                        if isa(def, PhiNode)
                            values[i] = val
                        else
                            values[i] = insert_node!(ir, up_id, NewInstruction(UpsilonNode(val), Bool))
                        end
                    end
                    if which !== SSAValue(0)
                        phi = ir[which]
                        if isa(phi, PhiNode)
                            phi.values[use] = new_phi
                        else
                            phi = phi::PhiCNode
                            phi.values[use] = insert_node!(ir, w_up_id, NewInstruction(UpsilonNode(new_phi), Bool))
                        end
                    end
                end
            end
            inst = lifted_undef[stmt_id]
            if stmt.head === :undefcheck
                inst = Expr(:throw_undef_if_not, stmt.args[1], inst)
            end
            insts[idx][:inst] = inst
        end
    end
    ir
end

function cfg_simplify!(ir::IRCode)
    bbs = ir.cfg.blocks
    merge_into = zeros(Int, length(bbs))
    merged_succ = zeros(Int, length(bbs))
    function follow_merge_into(idx::Int)
        while merge_into[idx] != 0
            idx = merge_into[idx]
        end
        return idx
    end
    function follow_merged_succ(idx::Int)
        while merged_succ[idx] != 0
            idx = merged_succ[idx]
        end
        return idx
    end

    # Walk the CFG from the entry block and aggressively combine blocks
    for (idx, bb) in enumerate(bbs)
        if length(bb.succs) == 1
            succ = bb.succs[1]
            if length(bbs[succ].preds) == 1
                # Prevent cycles by making sure we don't end up back at `idx`
                # by following what is to be merged into `succ`
                if follow_merged_succ(succ) != idx
                    merge_into[succ] = idx
                    merged_succ[idx] = succ
                end
            end
        end
    end

    # Assign new BB numbers
    max_bb_num = 1
    bb_rename_succ = zeros(Int, length(bbs))
    for i = 1:length(bbs)
        # Drop blocks that will be merged away
        if merge_into[i] != 0
            bb_rename_succ[i] = -1
        end
        # Drop blocks with no predecessors
        if i != 1 && length(ir.cfg.blocks[i].preds) == 0
            bb_rename_succ[i] = -1
        end
        bb_rename_succ[i] != 0 && continue

        curr = i
        while true
            bb_rename_succ[curr] = max_bb_num
            max_bb_num += 1
            # Now walk the chain of blocks we merged.
            # If we end in something that may fall through,
            # we have to schedule that block next
            curr = follow_merged_succ(curr)
            terminator = ir.stmts[ir.cfg.blocks[curr].stmts[end]][:inst]
            if isa(terminator, GotoNode) || isa(terminator, ReturnNode)
                break
            end
            curr += 1
        end
    end

    # Figure out how predecessors should be renamed
    bb_rename_pred = zeros(Int, length(bbs))
    for i = 1:length(bbs)
        if merged_succ[i] != 0
            # Block `i` should no longer be a predecessor (before renaming)
            # because it is being merged with its sole successor
            bb_rename_pred[i] = -1
            continue
        end
        bbnum = follow_merge_into(i)
        bb_rename_pred[i] = bb_rename_succ[bbnum]
    end

    # Compute map from new to old blocks
    result_bbs = Int[findfirst(j->i==j, bb_rename_succ) for i = 1:max_bb_num-1]

    # Compute new block lengths
    result_bbs_lengths = zeros(Int, max_bb_num-1)
    for (idx, orig_bb) in enumerate(result_bbs)
        ms = orig_bb
        while ms != 0
            result_bbs_lengths[idx] += length(bbs[ms].stmts)
            ms = merged_succ[ms]
        end
    end

    # Compute statement indices the new blocks start at
    bb_starts = Vector{Int}(undef, 1+length(result_bbs_lengths))
    bb_starts[1] = 1
    for i = 1:length(result_bbs_lengths)
        bb_starts[i+1] = bb_starts[i] + result_bbs_lengths[i]
    end

    cresult_bbs = let result_bbs = result_bbs,
                      merged_succ = merged_succ,
                      merge_into = merge_into,
                      bbs = bbs,
                      bb_rename_succ = bb_rename_succ

        # Compute (renamed) successors and predecessors given (renamed) block
        function compute_succs(i)
            orig_bb = follow_merged_succ(result_bbs[i])
            return Int[bb_rename_succ[i] for i in bbs[orig_bb].succs]
        end
        function compute_preds(i)
            orig_bb = result_bbs[i]
            preds = bbs[orig_bb].preds
            return Int[bb_rename_pred[pred] for pred in preds]
        end

        BasicBlock[
            BasicBlock(StmtRange(bb_starts[i],
                                 i+1 > length(bb_starts) ?
                                    length(compact.result) : bb_starts[i+1]-1),
                       compute_preds(i),
                       compute_succs(i))
            for i = 1:length(result_bbs)]
    end

    compact = IncrementalCompact(ir, true)
    # Run instruction compaction to produce the result,
    # but we're messing with the CFG
    # so we don't want compaction to do so independently
    compact.fold_constant_branches = false
    compact.bb_rename_succ = bb_rename_succ
    compact.bb_rename_pred = bb_rename_pred
    compact.result_bbs = cresult_bbs
    result_idx = 1
    for (idx, orig_bb) in enumerate(result_bbs)
        ms = orig_bb
        while ms != 0
            for i in bbs[ms].stmts
                node = ir.stmts[i]
                compact.result[compact.result_idx] = node
                if isa(node[:inst], GotoNode) && merged_succ[ms] != 0
                    # If we merged a basic block, we need to remove the trailing GotoNode (if any)
                    compact.result[compact.result_idx][:inst] = nothing
                else
                    process_node!(compact, compact.result_idx, node, i, i, ms, true)
                end
                # We always increase the result index to ensure a predictable
                # placement of the resulting nodes.
                compact.result_idx += 1
            end
            ms = merged_succ[ms]
        end
    end
    compact.active_result_bb = length(bb_starts)
    return finish(compact)
end
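# (illustrative sketch only) `cfg_simplify!` merges straight-line chains of
# blocks. For a CFG `#1 -> #2 -> #3` in which #2 and #3 each have a single
# predecessor, the loop above records merged_succ[1] = 2 / merge_into[2] = 1
# (and likewise for #3), renames the surviving block, and then emits the
# statements of #1, #2 and #3 back to back as one block, dropping the now
# redundant trailing GotoNodes.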