Skip to content

Commit

Permalink
Scoped execution for foreign AbsInt code
Browse files Browse the repository at this point in the history
1. Introduced a task-local inherited (nee ScopedValue) compiler field
2. Allow compilation and execution of code generated by foreign abstract
   interpreters

A new primitive `call_within` is introduced that switches the compiler
instance. The compiler instance is used for cache-lookups, compilation
request, and dispatch.
  • Loading branch information
vchuravy committed Mar 21, 2024
1 parent 8425b0e commit 4f16213
Show file tree
Hide file tree
Showing 32 changed files with 872 additions and 138 deletions.
1 change: 1 addition & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
# exception::Any
# backtrace::Any
# scope::Any
# compiler::Any
# code::Any
#end

Expand Down
29 changes: 29 additions & 0 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2072,6 +2072,33 @@ function abstract_throw(interp::AbstractInterpreter, argtypes::Vector{Any}, ::Ab
return CallMeta(Union{}, exct, EFFECTS_THROWS, NoCallInfo())
end

function abstract_call_within(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, si::StmtInfo,
sv::AbsIntState, max_methods::Int=get_max_methods(interp, sv))
if length(argtypes) < 2

return CallMeta(Union{}, Effects(), NoCallInfo())
end
CT = argtypes[2]
other_compiler = singleton_type(CT)
if other_compiler === nothing
if CT isa Const
other_compiler = CT.val
else
# Compiler is not a singleton type result may depend on runtime configuration
add_remark!(interp, sv, "Skipped call_within since compiler plugin not constant")
return CallMeta(Any, Effects(), NoCallInfo())
end
end
# Change world to one where our methods exist.
cworld = invokelatest(compiler_world, other_compiler)::UInt
other_interp = Core._call_in_world(cworld, abstract_interpreter, other_compiler, get_inference_world(interp))
other_fargs = fargs === nothing ? nothing : fargs[3:end]
other_arginfo = ArgInfo(other_fargs, argtypes[3:end])
call = Core._call_in_world(cworld, abstract_call, other_interp, other_arginfo, si, sv, max_methods)
# TODO: Edges? Effects?
return CallMeta(call.rt, call.exct, call.effects, WithinCallInfo(other_compiler, call.info))
end

# call where the function is known exactly
function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState,
Expand All @@ -2092,6 +2119,8 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
return abstract_applicable(interp, argtypes, sv, max_methods)
elseif f === throw
return abstract_throw(interp, argtypes, sv)
elseif f === Core._call_within
return abstract_call_within(interp, arginfo, si, sv, max_methods)
end
rt = abstract_call_builtin(interp, f, arginfo, sv)
ft = popfirst!(argtypes)
Expand Down
24 changes: 24 additions & 0 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,30 @@ macro _boundscheck() Expr(:boundscheck) end
convert(::Type{Any}, Core.@nospecialize x) = x
convert(::Type{T}, x::T) where {T} = x

abstract type AbstractCompiler end
const CompilerInstance = Union{AbstractCompiler, Nothing}
const NativeCompiler = Nothing

current_compiler() = ccall(:jl_get_current_task, Ref{Task}, ()).compiler::CompilerInstance

"""
abstract_interpreter(::CompilerInstance, world::UInt)
Construct an appropriate abstract interpreter for the given compiler instance.
"""
function abstract_interpreter end

abstract_interpreter(::Nothing, world::UInt) = NativeInterpreter(world)

"""
compiler_world(::CompilerInstance)
The compiler world to execute this compiler instance in.
"""

compiler_world(::Nothing) = unsafe_load(cglobal(:jl_typeinf_world, UInt))
compiler_world(::AbstractCompiler) = get_world_counter() # equivalent to invokelatest

# These types are used by reflection.jl and expr.jl too, so declare them here.
# Note that `@assume_effects` is available only after loading namedtuple.jl.
abstract type MethodTableView end
Expand Down
5 changes: 5 additions & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,4 +236,9 @@ struct ModifyOpInfo <: CallInfo
info::CallInfo # the callinfo for the `op(getval(), x)` call
end

struct WithinCallInfo <: CallInfo
compiler::CompilerInstance
info::CallInfo
end

@specialize
13 changes: 12 additions & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1174,7 +1174,18 @@ function typeinf_type(interp::AbstractInterpreter, mi::MethodInstance)
end

# This is a bridge for the C code calling `jl_typeinf_func()`
typeinf_ext_toplevel(mi::MethodInstance, world::UInt, source_mode::UInt8) = typeinf_ext_toplevel(NativeInterpreter(world), mi, source_mode)
# typeinf_ext_toplevel is going to be executed within `jl_typeinf_world`
function typeinf_ext_toplevel(compiler::CompilerInstance, mi::MethodInstance, world::UInt, source_mode::UInt8)
if compiler === nothing
return typeinf_ext_toplevel(abstract_interpreter(compiler, world), mi, source_mode)
else
# Change world to one where our methods exist.
cworld = invokelatest(compiler_world, compiler)::UInt
absint = Core._call_in_world(cworld, abstract_interpreter, compiler, world)
return Core._call_in_world(cworld, typeinf_ext_toplevel, absint, mi, source_mode)
end
end

function typeinf_ext_toplevel(interp::AbstractInterpreter, mi::MethodInstance, source_mode::UInt8)
return typeinf_ext(interp, mi, source_mode)
end
Expand Down
13 changes: 13 additions & 0 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,19 @@ function invoke_in_world(world::UInt, @nospecialize(f), @nospecialize args...; k
return Core._call_in_world(world, Core.kwcall, kwargs, f, args...)
end

"""
invoke_within(compiler, f, args...; kwargs...)
Call `f(args...; kwargs...)` within the compiler context provided by `compiler`.
"""
function invoke_within(compiler::Core.Compiler.CompilerInstance, @nospecialize(f), @nospecialize args...; kwargs...)
kwargs = Base.merge(NamedTuple(), kwargs)
if isempty(kwargs)
return Core._call_within(compiler, f, args...)
end
return Core._call_within(compiler, Core.kwcall, kwargs, f, args...)
end

inferencebarrier(@nospecialize(x)) = compilerbarrier(:type, x)

"""
Expand Down
15 changes: 6 additions & 9 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1424,29 +1424,26 @@ struct CodegenParams
use_jlplt::Cint

"""
A pointer of type
An instance of type Core.Compiler.CompilerInstance
typedef jl_value_t *(*jl_codeinstance_lookup_t)(jl_method_instance_t *mi JL_PROPAGATES_ROOT,
size_t min_world, size_t max_world);
that may be used by external compilers as a callback to look up the code instance corresponding
to a particular method instance.
Used to look up the code instance corresponding to a particular method instance,
from the particular compiler instance.
"""
lookup::Ptr{Cvoid}
compiler::Core.Compiler.CompilerInstance

function CodegenParams(; track_allocations::Bool=true, code_coverage::Bool=true,
prefer_specsig::Bool=false,
gnu_pubnames::Bool=true, debug_info_kind::Cint = default_debug_info_kind(),
debug_info_level::Cint = Cint(JLOptions().debug_level), safepoint_on_entry::Bool=true,
gcstack_arg::Bool=true, use_jlplt::Bool=true,
lookup::Ptr{Cvoid}=unsafe_load(cglobal(:jl_rettype_inferred_addr, Ptr{Cvoid})))
compiler::Core.Compiler.CompilerInstance=nothing)
return new(
Cint(track_allocations), Cint(code_coverage),
Cint(prefer_specsig),
Cint(gnu_pubnames), debug_info_kind,
debug_info_level, Cint(safepoint_on_entry),
Cint(gcstack_arg), Cint(use_jlplt),
lookup)
compiler)
end
end

Expand Down
87 changes: 87 additions & 0 deletions doc/src/devdocs/compilerplugins.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Compiler plugins

!!! warning
Compiler plugins are an unstable feature and depend on compiler internals.
Users of compiler plugins are encouraged to actively keep track of Julia upstream,
and to contribute necessary enhancements.

Compiler plugins use the `AbstractInterpreter` interface to customize Julia's high-level compiler,
and the `AbstractCompiler` interface to make the results from a foreign abstract interpreter executable.

Each task has a private field `compiler` that is like a scoped value to set and propagate the compiler plugin
across a task-graph. It can be accessed through `Core.Compiler.current_compiler()` and set through `Base.invoke_within`.

Instances of `AbstractCompiler` are used a cache lookup tokens and are compared with `jl_egal`.

!!! note
The native compiler of Julia uses `nothing` as a compiler instance and cache token. Switching back to the native compiler
can thus be done with `Base.invoke_within(nothing, f, args...)`.

## Interface

```@docs
Core.Compiler.AbstractCompiler
Core.Compiler.current_compiler
Core.Compiler.abstract_interpreter
Core.Compiler.compiler_world
Base.invoke_within
```

!!! important
While the abstract interpreter interface can be used independently of compiler plugins, if it is used for a compiler plugin.
`CC.cache_owner(interp::AbstractInterpreter)` must return as a cache token the `AbstractCompiler` instance that was used
to construct the `AbstractInterpreter` with `Core.Compiler.abstract_interpreter`.

## Intermediate representations and compiler pipeline

### Pre-inference IR
### Post-inference IR


## Freezing compiler world-ages

It is often undesirable for compiler code to be invalidated by user code. Julia's `Core.Compiler`
as executed by `jl_type_infer` is run in a frozen world-age (`jl_typeinf_world`). Compiler plugins
are not reachable from that world-age and need to be executed in a newer one.

The API [`Core.Compiler.compiler_world`](@ref) can be used to return a frozen world-age.

```julia
const CC = Core.Compiler

struct MyCompiler <: CC.AbstractCompiler end
CC.abstract_interpreter(compiler::MyCompiler, world::UInt) = # ...

const COMPILER_WORLD = Ref{UInt}(0)

function __init__()
COMPILER_WORLD[] = Base.get_world_counter()
end

CC.compiler_world(::MyCompiler) = COMPILER_WORLD[]
```

By default, `compiler_world` returns `get_world_counter()`, thus `invoke_in_world(compiler_world(...), ...)`
being equivalent to `invokelatest`. This default definition can be useful during development.

## Runtime executed implicit functions

The Julia runtime executes implicit functions like finalizers and `__init__` functions.
Currently, these are all executed within the native compiler.

!!! warning
This is currently not fully implemented and might cause spooky action at a distance.

## Accidental recursive instrumentation

## Future work

- Compiler standard library
- Cross-compiler inference
- Inlining?

## Examples

### Tracing

### Broadcast fusion
20 changes: 9 additions & 11 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ static void makeSafeName(GlobalObject &G)
static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance_t *mi, size_t world, jl_code_instance_t **ci_out, jl_code_info_t **src_out)
{
++CICacheLookups;
jl_value_t *ci = cgparams.lookup(mi, world, world);
jl_value_t *compiler = cgparams.compiler;
jl_value_t *ci = jl_rettype_inferred(compiler, mi, world, world);
JL_GC_PROMISE_ROOTED(ci);
jl_code_instance_t *codeinst = NULL;
if (ci != jl_nothing) {
Expand All @@ -304,12 +305,7 @@ static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance
*src_out = jl_uncompress_ir(def, codeinst, (jl_value_t*)*src_out);
}
if (*src_out == NULL || !jl_is_code_info(*src_out)) {
if (cgparams.lookup != jl_rettype_inferred_addr) {
jl_error("Refusing to automatically run type inference with custom cache lookup.");
}
else {
*ci_out = jl_type_infer(mi, world, 0, SOURCE_MODE_ABI);
}
*ci_out = jl_type_infer(compiler, mi, world, 0, SOURCE_MODE_ABI);
}
*ci_out = codeinst;
}
Expand Down Expand Up @@ -364,6 +360,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
params.imaging_mode = imaging;
params.debug_level = cgparams->debug_info_level;
params.external_linkage = _external_linkage;
params.compiler = cgparams->compiler;
size_t compile_for[] = { jl_typeinf_world, _world };
for (int worlds = 0; worlds < 2; worlds++) {
JL_TIMING(NATIVE_AOT, NATIVE_Codegen);
Expand Down Expand Up @@ -1939,9 +1936,9 @@ extern "C" JL_DLLEXPORT_CODEGEN jl_code_info_t *jl_gdbdumpcode(jl_method_instanc
jl_printf(stream, "----\n");

jl_code_info_t *src = NULL;
jl_value_t *ci = jl_default_cgparams.lookup(mi, world, world);
jl_value_t *ci = jl_rettype_inferred(jl_default_cgparams.compiler, mi, world, world);
if (ci == jl_nothing) {
ci = (jl_value_t*)jl_type_infer(mi, world, 0, SOURCE_MODE_FORCE_SOURCE_UNCACHED);
ci = (jl_value_t*)jl_type_infer(jl_default_cgparams.compiler, mi, world, 0, SOURCE_MODE_FORCE_SOURCE_UNCACHED);
} else {
ci = NULL;
}
Expand Down Expand Up @@ -1976,13 +1973,13 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, siz
jl_code_info_t *src = NULL;
jl_code_instance_t *codeinst = NULL;
JL_GC_PUSH2(&src, &codeinst);
jl_value_t *ci = params.lookup(mi, world, world);
jl_value_t *ci = jl_rettype_inferred(params.compiler, mi, world, world);
if (ci && ci != jl_nothing) {
codeinst = (jl_code_instance_t*)ci;
src = (jl_code_info_t*)jl_atomic_load_relaxed(&codeinst->inferred);
}
if (!src || (jl_value_t*)src == jl_nothing) {
codeinst = jl_type_infer(mi, world, 0, SOURCE_MODE_FORCE_SOURCE_UNCACHED);
codeinst = jl_type_infer(params.compiler, mi, world, 0, SOURCE_MODE_FORCE_SOURCE_UNCACHED);
if (codeinst) {
src = (jl_code_info_t*)jl_atomic_load_relaxed(&codeinst->inferred);
}
Expand All @@ -2007,6 +2004,7 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, siz
jl_codegen_params_t output(*ctx, std::move(target_info.first), std::move(target_info.second));
output.params = &params;
output.imaging_mode = imaging_default();
output.compiler = params.compiler;
// This would be nice, but currently it causes some assembly regressions that make printed output
// differ very significantly from the actual non-imaging mode code.
// // Force imaging mode for names of pointers
Expand Down
1 change: 1 addition & 0 deletions src/builtin_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ DECLARE_BUILTIN(_apply_iterate);
DECLARE_BUILTIN(_apply_pure);
DECLARE_BUILTIN(_call_in_world);
DECLARE_BUILTIN(_call_in_world_total);
DECLARE_BUILTIN(_call_within);
DECLARE_BUILTIN(_call_latest);
DECLARE_BUILTIN(_compute_sparams);
DECLARE_BUILTIN(_expr);
Expand Down
20 changes: 20 additions & 0 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,25 @@ JL_CALLABLE(jl_f__call_in_world_total)
return ret;
}

JL_CALLABLE(jl_f__call_within)
{
JL_NARGSV(_apply_within, 2);
jl_task_t *ct = jl_current_task;
jl_value_t *last_compiler = ct->compiler;
jl_value_t *compiler = args[0];
jl_value_t *ret = NULL;
JL_TRY {
ct->compiler = compiler;
ret = jl_apply(&args[1], nargs - 1);
ct->compiler = last_compiler;
}
JL_CATCH {
ct->compiler = last_compiler;
jl_rethrow();
}
return ret;
}

// tuples ---------------------------------------------------------------------

JL_CALLABLE(jl_f_tuple)
Expand Down Expand Up @@ -2444,6 +2463,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
add_builtin_func("_call_latest", jl_f__call_latest);
add_builtin_func("_call_in_world", jl_f__call_in_world);
add_builtin_func("_call_in_world_total", jl_f__call_in_world_total);
add_builtin_func("_call_within", jl_f__call_within);
add_builtin_func("_typevar", jl_f__typevar);
add_builtin_func("_structtype", jl_f__structtype);
add_builtin_func("_abstracttype", jl_f__abstracttype);
Expand Down
2 changes: 1 addition & 1 deletion src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4041,7 +4041,7 @@ static int compare_cgparams(const jl_cgparams_t *a, const jl_cgparams_t *b)
(a->safepoint_on_entry == b->safepoint_on_entry) &&
(a->gcstack_arg == b->gcstack_arg) &&
(a->use_jlplt == b->use_jlplt) &&
(a->lookup == b->lookup);
(a->compiler == b->compiler);
}
#endif

Expand Down
Loading

0 comments on commit 4f16213

Please sign in to comment.