Skip to content

Commit

Permalink
minor followups on recent CodeInstance refactors
Browse files Browse the repository at this point in the history
- simplifies the signature of `transform_result_for_cache`
- make `jl_uncompress_ir` take `MethodInstance` instead of `CodeInstance`
  and simplifies the inlining algorithm
  • Loading branch information
aviatesk committed Mar 5, 2024
1 parent 6745160 commit 5c8e5d3
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 73 deletions.
4 changes: 2 additions & 2 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -797,10 +797,10 @@ end

function IRInterpretationState(interp::AbstractInterpreter,
code::CodeInstance, mi::MethodInstance, argtypes::Vector{Any}, world::UInt)
@assert code.def === mi
@assert code.def === mi "method instance is not synced with code instance"
src = @atomic :monotonic code.inferred
if isa(src, String)
src = _uncompressed_ir(code, src)
src = _uncompressed_ir(mi, src)
else
isa(src, CodeInfo) || return nothing
end
Expand Down
12 changes: 5 additions & 7 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -907,8 +907,7 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)

add_inlining_backedge!(et, mi)
ir = inferred_result isa CodeInstance ? retrieve_ir_for_inlining(inferred_result, src) :
retrieve_ir_for_inlining(mi, src, preserve_local_sources)
ir = retrieve_ir_for_inlining(mi, src, preserve_local_sources)
return InliningTodo(mi, ir, effects)
end

Expand Down Expand Up @@ -938,8 +937,7 @@ function resolve_todo(mi::MethodInstance, @nospecialize(info::CallInfo), flag::U

preserve_local_sources = true
src_inlining_policy(state.interp, src, info, flag) || return nothing
ir = cached_result isa CodeInstance ? retrieve_ir_for_inlining(cached_result, src) :
retrieve_ir_for_inlining(mi, src, preserve_local_sources)
ir = retrieve_ir_for_inlining(mi, src, preserve_local_sources)
add_inlining_backedge!(et, mi)
return InliningTodo(mi, ir, effects)
end
Expand Down Expand Up @@ -994,9 +992,9 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
return resolve_todo(mi, volatile_inf_result, info, flag, state; invokesig)
end

function retrieve_ir_for_inlining(cached_result::CodeInstance, src::MaybeCompressed)
src = _uncompressed_ir(cached_result, src)::CodeInfo
return inflate_ir!(src, cached_result.def)
function retrieve_ir_for_inlining(mi::MethodInstance, src::String, ::Bool)
src = _uncompressed_ir(mi, src)
return inflate_ir!(src, mi)
end
function retrieve_ir_for_inlining(mi::MethodInstance, src::CodeInfo, preserve_local_sources::Bool)
if preserve_local_sources
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,7 @@ function try_inline_finalizer!(ir::IRCode, argexprs::Vector{Any}, idx::Int,
end

src_inlining_policy(inlining.interp, src, info, IR_FLAG_NULL) || return false
src = retrieve_ir_for_inlining(code, src)
src = retrieve_ir_for_inlining(mi, src, #=preserve_local_sources=#true)

# For now: Require finalizer to only have one basic block
length(src.cfg.blocks) == 1 || return false
Expand Down
50 changes: 22 additions & 28 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,6 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState)
if opt isa OptimizationState
result.src = opt = ir_to_codeinf!(opt)
end
if opt isa CodeInfo
caller.src = opt
else
# In this case `caller.src` is invalid for clients (such as `typeinf_ext`) to use
# but that is what's permitted by `caller.cache_mode`.
# This is hopefully unreachable from such clients using `NativeInterpreter`.
end
return nothing
end

Expand Down Expand Up @@ -277,7 +270,7 @@ function is_result_constabi_eligible(result::InferenceResult)
return isa(result_type, Const) && is_foldable_nothrow(result.ipo_effects) && is_inlineable_constant(result_type.val)
end
function CodeInstance(interp::AbstractInterpreter, result::InferenceResult;
can_discard_trees::Bool=may_discard_trees(interp))
can_discard_trees::Bool=may_discard_trees(interp))
local const_flags::Int32
result_type = result.result
@assert !(result_type === nothing || result_type isa LimitedAccuracy)
Expand Down Expand Up @@ -315,7 +308,7 @@ function CodeInstance(interp::AbstractInterpreter, result::InferenceResult;
inferred_result = nothing
relocatability = 0x1
else
inferred_result = transform_result_for_cache(interp, result.linfo, result.valid_worlds, result, can_discard_trees)
inferred_result = transform_result_for_cache(interp, result)
if inferred_result isa CodeInfo
uncompressed = inferred_result
inferred_result = maybe_compress_codeinfo(interp, result.linfo, inferred_result, can_discard_trees)
Expand All @@ -341,19 +334,17 @@ function CodeInstance(interp::AbstractInterpreter, result::InferenceResult;
relocatability)
end

function transform_result_for_cache(interp::AbstractInterpreter,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult,
can_discard_trees::Bool=may_discard_trees(interp))
function transform_result_for_cache(interp::AbstractInterpreter, result::InferenceResult)
return result.src
end

function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInstance, ci::CodeInfo,
can_discard_trees::Bool=may_discard_trees(interp))
def = linfo.def
function maybe_compress_codeinfo(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInfo,
can_discard_trees::Bool=may_discard_trees(interp))
def = mi.def
isa(def, Method) || return ci # don't compress toplevel code
cache_the_tree = true
if can_discard_trees
cache_the_tree = is_inlineable(ci) || isa_compileable_sig(linfo.specTypes, linfo.sparam_vals, def)
cache_the_tree = is_inlineable(ci) || isa_compileable_sig(mi.specTypes, mi.sparam_vals, def)
end
if cache_the_tree
if may_compress(interp)
Expand Down Expand Up @@ -572,13 +563,13 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
# annotate fulltree with type information,
# either because we are the outermost code, or we might use this later
type_annotate!(interp, me)
doopt = (me.cache_mode != CACHE_MODE_NULL || me.parent !== nothing)
# Disable the optimizer if we've already determined that there's nothing for
# it to do.
if may_discard_trees(interp) && is_result_constabi_eligible(me.result)
doopt = false
end
if doopt && may_optimize(interp)
mayopt = may_optimize(interp)
doopt = mayopt &&
# disable optimization if we don't use this later
(me.cache_mode != CACHE_MODE_NULL || me.parent !== nothing) &&
# disable optimization if we've already obtained very accurate result
!result_is_constabi(interp, me.result, mayopt)
if doopt
me.result.src = OptimizationState(me, interp)
else
me.result.src = me.src # for reflection etc.
Expand Down Expand Up @@ -948,7 +939,8 @@ function codeinstance_for_const_with_code(interp::AbstractInterpreter, code::Cod
code.relocatability)
end

result_is_constabi(interp::AbstractInterpreter, run_optimizer::Bool, result::InferenceResult) =
result_is_constabi(interp::AbstractInterpreter, result::InferenceResult,
run_optimizer::Bool=may_optimize(interp)) =
run_optimizer && may_discard_trees(interp) && is_result_constabi_eligible(result)

# compute an inferred AST and return type
Expand All @@ -961,7 +953,7 @@ function typeinf_code(interp::AbstractInterpreter, mi::MethodInstance, run_optim
frame = typeinf_frame(interp, mi, run_optimizer)
frame === nothing && return nothing, Any
is_inferred(frame) || return nothing, Any
if result_is_constabi(interp, run_optimizer, frame.result)
if result_is_constabi(interp, frame.result, run_optimizer)
rt = frame.result.result::Const
return codeinfo_for_const(interp, frame.linfo, rt.val), widenconst(rt)
end
Expand Down Expand Up @@ -1091,7 +1083,8 @@ function ci_meets_requirement(code::CodeInstance, source_mode::UInt8, ci_is_cach
return false
end

_uncompressed_ir(ci::Core.CodeInstance, s::String) = ccall(:jl_uncompress_ir, Any, (Any, Any, Any), ci.def.def::Method, ci, s)::CodeInfo
_uncompressed_ir(mi::MethodInstance, s::String) =
ccall(:jl_uncompress_ir, Ref{CodeInfo}, (Any, Any, Any), mi.def::Method, mi, s)

# compute (and cache) an inferred AST and return type
function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance, source_mode::UInt8)
Expand Down Expand Up @@ -1137,8 +1130,9 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance, source_mod
# Inference result is not cacheable or is was cacheable, but we do not want to
# store the source in the cache, but the caller wanted it anyway (e.g. for reflection).
# We construct a new CodeInstance for it that is not part of the cache hierarchy.
code = CodeInstance(interp, result, can_discard_trees=(
source_mode != SOURCE_MODE_FORCE_SOURCE && source_mode != SOURCE_MODE_FORCE_SOURCE_UNCACHED))
can_discard_trees = !(source_mode == SOURCE_MODE_FORCE_SOURCE ||
source_mode == SOURCE_MODE_FORCE_SOURCE_UNCACHED)
code = CodeInstance(interp, result; can_discard_trees)

# If the caller cares about the code and this is constabi, still use our synthesis function
# anyway, because we will have not finished inferring the code inside the CodeInstance once
Expand Down
18 changes: 6 additions & 12 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,29 +127,23 @@ function get_staged(mi::MethodInstance, world::UInt)
end
end

function retrieve_code_info(linfo::MethodInstance, world::UInt)
def = linfo.def
if !isa(def, Method)
return linfo.uninferred
end
c = nothing
if isdefined(def, :generator)
# user code might throw errors – ignore them
c = get_staged(linfo, world)
end
function retrieve_code_info(mi::MethodInstance, world::UInt)
def = mi.def
isa(def, Method) || return mi.uninferred
c = isdefined(def, :generator) ? get_staged(mi, world) : nothing
if c === nothing && isdefined(def, :source)
src = def.source
if src === nothing
# can happen in images built with --strip-ir
return nothing
elseif isa(src, String)
c = ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any), def, C_NULL, src)
c = _uncompressed_ir(mi, src)
else
c = copy(src::CodeInfo)
end
end
if c isa CodeInfo
c.parent = linfo
c.parent = mi
return c
end
return nothing
Expand Down
6 changes: 3 additions & 3 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance
if ((jl_value_t*)*src_out == jl_nothing)
*src_out = NULL;
if (*src_out && jl_is_method(def))
*src_out = jl_uncompress_ir(def, codeinst, (jl_value_t*)*src_out);
*src_out = jl_uncompress_ir(def, codeinst->def, (jl_value_t*)*src_out);
}
if (*src_out == NULL || !jl_is_code_info(*src_out)) {
if (cgparams.lookup != jl_rettype_inferred_addr) {
Expand Down Expand Up @@ -1950,7 +1950,7 @@ extern "C" JL_DLLEXPORT_CODEGEN jl_code_info_t *jl_gdbdumpcode(jl_method_instanc
src = (jl_code_info_t*)jl_atomic_load_relaxed(&codeinst->inferred);
if ((jl_value_t*)src != jl_nothing && !jl_is_code_info(src) && jl_is_method(mi->def.method)) {
JL_GC_PUSH2(&codeinst, &src);
src = jl_uncompress_ir(mi->def.method, codeinst, (jl_value_t*)src);
src = jl_uncompress_ir(mi->def.method, codeinst->def, (jl_value_t*)src);
JL_GC_POP();
}
}
Expand Down Expand Up @@ -1989,7 +1989,7 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, siz
}
if (src) {
if ((jl_value_t*)src != jl_nothing && !jl_is_code_info(src) && jl_is_method(mi->def.method))
src = jl_uncompress_ir(mi->def.method, codeinst, (jl_value_t*)src);
src = jl_uncompress_ir(mi->def.method, codeinst->def, (jl_value_t*)src);
}

// emit this function into a new llvm module
Expand Down
4 changes: 2 additions & 2 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6012,7 +6012,7 @@ static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_met

if (it == ctx.emission_context.compiled_functions.end()) {
++EmittedOpaqueClosureFunctions;
jl_code_info_t *ir = jl_uncompress_ir(closure_method, ci, (jl_value_t*)inferred);
jl_code_info_t *ir = jl_uncompress_ir(closure_method, ci->def, (jl_value_t*)inferred);
JL_GC_PUSH1(&ir);
// TODO: Emit this inline and outline it late using LLVM's coroutine support.
orc::ThreadSafeModule closure_m = jl_create_ts_module(
Expand Down Expand Up @@ -9565,7 +9565,7 @@ jl_llvm_functions_t jl_emit_codeinst(
return jl_emit_oc_wrapper(m, params, codeinst->def, codeinst->rettype);
}
if (src && (jl_value_t*)src != jl_nothing && jl_is_method(def))
src = jl_uncompress_ir(def, codeinst, (jl_value_t*)src);
src = jl_uncompress_ir(def, codeinst->def, (jl_value_t*)src);
if (!src || !jl_is_code_info(src)) {
JL_GC_POP();
m = orc::ThreadSafeModule();
Expand Down
4 changes: 2 additions & 2 deletions src/ircode.c
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ JL_DLLEXPORT jl_string_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
return v;
}

JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t *metadata, jl_string_t *data)
JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_method_instance_t *metadata, jl_string_t *data)
{
if (jl_is_code_info(data))
return (jl_code_info_t*)data;
Expand Down Expand Up @@ -980,7 +980,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
JL_UNLOCK(&m->writelock); // Might GC
JL_GC_POP();
if (metadata) {
code->parent = metadata->def;
code->parent = metadata;
}

return code;
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -2136,7 +2136,7 @@ JL_DLLEXPORT jl_value_t *jl_copy_ast(jl_value_t *expr JL_MAYBE_UNROOTED);

// IR representation
JL_DLLEXPORT jl_value_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code);
JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t *metadata, jl_value_t *data);
JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_method_instance_t *metadata, jl_value_t *data);
JL_DLLEXPORT uint8_t jl_ir_flag_inlining(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_flag_has_fcall(jl_value_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t jl_ir_inlining_cost(jl_value_t *data) JL_NOTSAFEPOINT;
Expand Down
26 changes: 17 additions & 9 deletions test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ CC.may_optimize(::AbsIntOnlyInterp1) = false
# it should work even if the interpreter discards inferred source entirely
@newinterp AbsIntOnlyInterp2
CC.may_optimize(::AbsIntOnlyInterp2) = false
CC.transform_result_for_cache(::AbsIntOnlyInterp2, ::Core.MethodInstance, ::CC.WorldRange, ::CC.InferenceResult) = nothing
CC.transform_result_for_cache(::AbsIntOnlyInterp2, ::CC.InferenceResult) = nothing
@test Base.infer_return_type(Base.init_stdio, (Ptr{Cvoid},); interp=AbsIntOnlyInterp2()) >: IO

# OverlayMethodTable
Expand Down Expand Up @@ -463,7 +463,7 @@ let # generate cache
target_mi = CC.specialize_method(only(methods(custom_lookup_target)), Tuple{typeof(custom_lookup_target),Bool,Int}, Core.svec())
target_ci = custom_lookup(target_mi, CONST_INVOKE_INTERP_WORLD, CONST_INVOKE_INTERP_WORLD)
@test target_ci.rettype == Tuple{Float64,Nothing} # constprop'ed source
# display(@ccall jl_uncompress_ir(target_ci.def.def::Any, C_NULL::Ptr{Cvoid}, target_ci.inferred::Any)::Any)
# display(Core.Compiler._uncompressed_ir(target_ci.def, target_ci.inferred::String))

raw = false
lookup = @cfunction(custom_lookup, Any, (Any,Csize_t,Csize_t))
Expand All @@ -486,31 +486,39 @@ end
@newinterp CustomDataInterp
struct CustomDataInterpToken end
CC.cache_owner(::CustomDataInterp) = CustomDataInterpToken()
global custom_data_interp_transformed_sin::Bool = global custom_data_interp_transformed_cos::Bool = false
struct CustomData
inferred
CustomData(@nospecialize inferred) = new(inferred)
end
function CC.transform_result_for_cache(interp::CustomDataInterp,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
inferred_result = @invoke CC.transform_result_for_cache(interp::CC.AbstractInterpreter,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
function CC.transform_result_for_cache(interp::CustomDataInterp, result::CC.InferenceResult)
inferred_result = @invoke CC.transform_result_for_cache(
interp::CC.AbstractInterpreter, result::CC.InferenceResult)
def = result.linfo.def
if def isa Method
meth_name = def.name
if meth_name === :sin
global custom_data_interp_transformed_sin = true
elseif meth_name === :cos
global custom_data_interp_transformed_cos = true
end
end
return CustomData(inferred_result)
end
function CC.src_inlining_policy(interp::CustomDataInterp, @nospecialize(src),
@nospecialize(info::CC.CallInfo), stmt_flag::UInt32)
@nospecialize(info::CC.CallInfo), stmt_flag::UInt32)
if src isa CustomData
src = src.inferred
end
return @invoke CC.src_inlining_policy(interp::CC.AbstractInterpreter, src::Any,
info::CC.CallInfo, stmt_flag::UInt32)
end
CC.retrieve_ir_for_inlining(cached_result::CodeInstance, src::CustomData) =
CC.retrieve_ir_for_inlining(cached_result, src.inferred)
CC.retrieve_ir_for_inlining(mi::MethodInstance, src::CustomData, preserve_local_sources::Bool) =
CC.retrieve_ir_for_inlining(mi, src.inferred, preserve_local_sources)
let src = code_typed((Int,); interp=CustomDataInterp()) do x
return sin(x) + cos(x)
end |> only |> first
@test custom_data_interp_transformed_sin && custom_data_interp_transformed_cos
@test count(isinvoke(:sin), src.code) == 1
@test count(isinvoke(:cos), src.code) == 1
@test count(isinvoke(:+), src.code) == 0
Expand Down
9 changes: 3 additions & 6 deletions test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1787,10 +1787,9 @@ let newinterp_path = abspath("compiler/newinterp.jl")
inferred
CustomData(@nospecialize inferred) = new(inferred)
end
function CC.transform_result_for_cache(interp::PrecompileInterpreter,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
inferred_result = @invoke CC.transform_result_for_cache(interp::CC.AbstractInterpreter,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
function CC.transform_result_for_cache(interp::PrecompileInterpreter, result::CC.InferenceResult)
inferred_result = @invoke CC.transform_result_for_cache(
interp::CC.AbstractInterpreter, result::CC.InferenceResult)
return CustomData(inferred_result)
end
function CC.src_inlining_policy(interp::PrecompileInterpreter, @nospecialize(src),
Expand All @@ -1801,8 +1800,6 @@ let newinterp_path = abspath("compiler/newinterp.jl")
return @invoke CC.src_inlining_policy(interp::CC.AbstractInterpreter, src::Any,
info::CC.CallInfo, stmt_flag::UInt32)
end
CC.retrieve_ir_for_inlining(cached_result::Core.CodeInstance, src::CustomData) =
CC.retrieve_ir_for_inlining(cached_result, src.inferred)
CC.retrieve_ir_for_inlining(mi::Core.MethodInstance, src::CustomData, preserve_local_sources::Bool) =
CC.retrieve_ir_for_inlining(mi, src.inferred, preserve_local_sources)
end
Expand Down

0 comments on commit 5c8e5d3

Please sign in to comment.