Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow constant-propagation to be disabled #42125

Merged
merged 3 commits into from
Sep 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions base/char.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ represents a valid Unicode character.
"""
Char

@aggressive_constprop (::Type{T})(x::Number) where {T<:AbstractChar} = T(UInt32(x))
@aggressive_constprop AbstractChar(x::Number) = Char(x)
@aggressive_constprop (::Type{T})(x::AbstractChar) where {T<:Union{Number,AbstractChar}} = T(codepoint(x))
@aggressive_constprop (::Type{T})(x::AbstractChar) where {T<:Union{Int32,Int64}} = codepoint(x) % T
@constprop :aggressive (::Type{T})(x::Number) where {T<:AbstractChar} = T(UInt32(x))
@constprop :aggressive AbstractChar(x::Number) = Char(x)
@constprop :aggressive (::Type{T})(x::AbstractChar) where {T<:Union{Number,AbstractChar}} = T(codepoint(x))
@constprop :aggressive (::Type{T})(x::AbstractChar) where {T<:Union{Int32,Int64}} = codepoint(x) % T
(::Type{T})(x::T) where {T<:AbstractChar} = x

"""
Expand All @@ -75,7 +75,7 @@ return a different-sized integer (e.g. `UInt8`).
"""
function codepoint end

@aggressive_constprop codepoint(c::Char) = UInt32(c)
@constprop :aggressive codepoint(c::Char) = UInt32(c)

struct InvalidCharError{T<:AbstractChar} <: Exception
char::T
Expand Down Expand Up @@ -124,7 +124,7 @@ See also [`decode_overlong`](@ref) and [`show_invalid`](@ref).
"""
isoverlong(c::AbstractChar) = false

@aggressive_constprop function UInt32(c::Char)
@constprop :aggressive function UInt32(c::Char)
# TODO: use optimized inline LLVM
u = bitcast(UInt32, c)
u < 0x80000000 && return u >> 24
Expand All @@ -148,7 +148,7 @@ that support overlong encodings should implement `Base.decode_overlong`.
"""
function decode_overlong end

@aggressive_constprop function decode_overlong(c::Char)
@constprop :aggressive function decode_overlong(c::Char)
u = bitcast(UInt32, c)
l1 = leading_ones(u)
t0 = trailing_zeros(u) & 56
Expand All @@ -158,7 +158,7 @@ function decode_overlong end
((u & 0x007f0000) >> 4) | ((u & 0x7f000000) >> 6)
end

@aggressive_constprop function Char(u::UInt32)
@constprop :aggressive function Char(u::UInt32)
u < 0x80 && return bitcast(Char, u << 24)
u < 0x00200000 || throw_code_point_err(u)
c = ((u << 0) & 0x0000003f) | ((u << 2) & 0x00003f00) |
Expand All @@ -169,14 +169,14 @@ end
bitcast(Char, c)
end

@aggressive_constprop @noinline UInt32_cold(c::Char) = UInt32(c)
@aggressive_constprop function (T::Union{Type{Int8},Type{UInt8}})(c::Char)
@constprop :aggressive @noinline UInt32_cold(c::Char) = UInt32(c)
@constprop :aggressive function (T::Union{Type{Int8},Type{UInt8}})(c::Char)
i = bitcast(Int32, c)
i ≥ 0 ? ((i >>> 24) % T) : T(UInt32_cold(c))
end

@aggressive_constprop @noinline Char_cold(b::UInt32) = Char(b)
@aggressive_constprop function Char(b::Union{Int8,UInt8})
@constprop :aggressive @noinline Char_cold(b::UInt32) = Char(b)
@constprop :aggressive function Char(b::Union{Int8,UInt8})
0 ≤ b ≤ 0x7f ? bitcast(Char, (b % UInt32) << 24) : Char_cold(UInt32(b))
end

Expand Down
6 changes: 5 additions & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,10 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
return nothing
end
method = match.method
if method.constprop == 0x02
add_remark!(interp, sv, "[constprop] Disabled by method parameter")
return nothing
end
force = force_const_prop(interp, f, method)
force || const_prop_entry_heuristic(interp, result, sv) || return nothing
nargs::Int = method.nargs
Expand Down Expand Up @@ -649,7 +653,7 @@ function is_allconst(argtypes::Vector{Any})
end

function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method::Method)
return method.aggressive_constprop ||
return method.constprop == 0x01 ||
InferenceParams(interp).aggressive_constant_propagation ||
istopfunction(f, :getproperty) ||
istopfunction(f, :setproperty!)
Expand Down
30 changes: 20 additions & 10 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,16 +349,26 @@ macro pure(ex)
end

"""
@aggressive_constprop ex
@aggressive_constprop(ex)

`@aggressive_constprop` requests more aggressive interprocedural constant
propagation for the annotated function. For a method where the return type
depends on the value of the arguments, this can yield improved inference results
at the cost of additional compile time.
"""
macro aggressive_constprop(ex)
esc(isa(ex, Expr) ? pushmeta!(ex, :aggressive_constprop) : ex)
@constprop setting ex
@constprop(setting, ex)

`@constprop` controls the mode of interprocedural constant propagation for the
annotated function. Two `setting`s are supported:

- `@constprop :aggressive ex`: apply constant propagation aggressively.
For a method where the return type depends on the value of the arguments,
this can yield improved inference results at the cost of additional compile time.
- `@constprop :none ex`: disable constant propagation. This can reduce compile
times for functions that Julia might otherwise deem worthy of constant-propagation.
Common cases are for functions with `Bool`- or `Symbol`-valued arguments or keyword arguments.
"""
macro constprop(setting, ex)
if isa(setting, QuoteNode)
setting = setting.value
end
setting === :aggressive && return esc(isa(ex, Expr) ? pushmeta!(ex, :aggressive_constprop) : ex)
setting === :none && return esc(isa(ex, Expr) ? pushmeta!(ex, :no_constprop) : ex)
throw(ArgumentError("@constprop $setting not supported"))
end

"""
Expand Down
3 changes: 2 additions & 1 deletion src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jl_sym_t *static_parameter_sym; jl_sym_t *inline_sym;
jl_sym_t *noinline_sym; jl_sym_t *generated_sym;
jl_sym_t *generated_only_sym; jl_sym_t *isdefined_sym;
jl_sym_t *propagate_inbounds_sym; jl_sym_t *specialize_sym;
jl_sym_t *aggressive_constprop_sym;
jl_sym_t *aggressive_constprop_sym; jl_sym_t *no_constprop_sym;
jl_sym_t *nospecialize_sym; jl_sym_t *macrocall_sym;
jl_sym_t *colon_sym; jl_sym_t *hygienicscope_sym;
jl_sym_t *throw_undef_if_not_sym; jl_sym_t *getfield_undefref_sym;
Expand Down Expand Up @@ -399,6 +399,7 @@ void jl_init_common_symbols(void)
polly_sym = jl_symbol("polly");
propagate_inbounds_sym = jl_symbol("propagate_inbounds");
aggressive_constprop_sym = jl_symbol("aggressive_constprop");
no_constprop_sym = jl_symbol("no_constprop");
isdefined_sym = jl_symbol("isdefined");
nospecialize_sym = jl_symbol("nospecialize");
specialize_sym = jl_symbol("specialize");
Expand Down
4 changes: 2 additions & 2 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
write_int8(s->s, m->isva);
write_int8(s->s, m->pure);
write_int8(s->s, m->is_for_opaque_closure);
write_int8(s->s, m->aggressive_constprop);
write_int8(s->s, m->constprop);
jl_serialize_value(s, (jl_value_t*)m->slot_syms);
jl_serialize_value(s, (jl_value_t*)m->roots);
jl_serialize_value(s, (jl_value_t*)m->ccallable);
Expand Down Expand Up @@ -1525,7 +1525,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
m->isva = read_int8(s->s);
m->pure = read_int8(s->s);
m->is_for_opaque_closure = read_int8(s->s);
m->aggressive_constprop = read_int8(s->s);
m->constprop = read_int8(s->s);
m->slot_syms = jl_deserialize_value(s, (jl_value_t**)&m->slot_syms);
jl_gc_wb(m, m->slot_syms);
m->roots = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&m->roots);
Expand Down
47 changes: 29 additions & 18 deletions src/ircode.c
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,17 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal)
}
}

static jl_code_info_flags_t code_info_flags(uint8_t pure, uint8_t propagate_inbounds, uint8_t inlineable, uint8_t inferred, uint8_t constprop)
{
jl_code_info_flags_t flags;
flags.bits.pure = pure;
flags.bits.propagate_inbounds = propagate_inbounds;
flags.bits.inlineable = inlineable;
flags.bits.inferred = inferred;
flags.bits.constprop = constprop;
return flags;
}

// --- decoding ---

static jl_value_t *jl_decode_value(jl_ircode_state *s) JL_GC_DISABLED;
Expand Down Expand Up @@ -702,12 +713,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
jl_current_task->ptls
};

uint8_t flags = (code->aggressive_constprop << 4)
| (code->inferred << 3)
| (code->inlineable << 2)
| (code->propagate_inbounds << 1)
| (code->pure << 0);
write_uint8(s.s, flags);
jl_code_info_flags_t flags = code_info_flags(code->pure, code->propagate_inbounds, code->inlineable, code->inferred, code->constprop);
write_uint8(s.s, flags.packed);

size_t nslots = jl_array_len(code->slotflags);
assert(nslots >= m->nargs && nslots < INT32_MAX); // required by generated functions
Expand Down Expand Up @@ -787,12 +794,13 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
};

jl_code_info_t *code = jl_new_code_info_uninit();
uint8_t flags = read_uint8(s.s);
code->aggressive_constprop = !!(flags & (1 << 4));
code->inferred = !!(flags & (1 << 3));
code->inlineable = !!(flags & (1 << 2));
code->propagate_inbounds = !!(flags & (1 << 1));
code->pure = !!(flags & (1 << 0));
jl_code_info_flags_t flags;
flags.packed = read_uint8(s.s);
code->constprop = flags.bits.constprop;
code->inferred = flags.bits.inferred;
code->inlineable = flags.bits.inlineable;
code->propagate_inbounds = flags.bits.propagate_inbounds;
code->pure = flags.bits.pure;

size_t nslots = read_int32(&src);
code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots);
Expand Down Expand Up @@ -847,26 +855,29 @@ JL_DLLEXPORT uint8_t jl_ir_flag_inferred(jl_array_t *data)
if (jl_is_code_info(data))
return ((jl_code_info_t*)data)->inferred;
assert(jl_typeis(data, jl_array_uint8_type));
uint8_t flags = ((uint8_t*)data->data)[0];
return !!(flags & (1 << 3));
jl_code_info_flags_t flags;
flags.packed = ((uint8_t*)data->data)[0];
return flags.bits.inferred;
}

JL_DLLEXPORT uint8_t jl_ir_flag_inlineable(jl_array_t *data)
{
if (jl_is_code_info(data))
return ((jl_code_info_t*)data)->inlineable;
assert(jl_typeis(data, jl_array_uint8_type));
uint8_t flags = ((uint8_t*)data->data)[0];
return !!(flags & (1 << 2));
jl_code_info_flags_t flags;
flags.packed = ((uint8_t*)data->data)[0];
return flags.bits.inlineable;
}

JL_DLLEXPORT uint8_t jl_ir_flag_pure(jl_array_t *data)
{
if (jl_is_code_info(data))
return ((jl_code_info_t*)data)->pure;
assert(jl_typeis(data, jl_array_uint8_type));
uint8_t flags = ((uint8_t*)data->data)[0];
return !!(flags & (1 << 0));
jl_code_info_flags_t flags;
flags.packed = ((uint8_t*)data->data)[0];
return flags.bits.pure;
}

JL_DLLEXPORT jl_value_t *jl_compress_argnames(jl_array_t *syms)
Expand Down
8 changes: 4 additions & 4 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2348,7 +2348,7 @@ void jl_init_types(void) JL_GC_DISABLED
"inlineable",
"propagate_inbounds",
"pure",
"aggressive_constprop"),
"constprop"),
jl_svec(19,
jl_array_any_type,
jl_array_int32_type,
Expand All @@ -2368,7 +2368,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type),
jl_uint8_type),
jl_emptysvec,
0, 1, 19);

Expand Down Expand Up @@ -2401,7 +2401,7 @@ void jl_init_types(void) JL_GC_DISABLED
"isva",
"pure",
"is_for_opaque_closure",
"aggressive_constprop"),
"constprop"),
jl_svec(26,
jl_symbol_type,
jl_module_type,
Expand All @@ -2428,7 +2428,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type),
jl_uint8_type),
jl_emptysvec,
0, 1, 10);

Expand Down
6 changes: 4 additions & 2 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ typedef struct _jl_code_info_t {
uint8_t inlineable;
uint8_t propagate_inbounds;
uint8_t pure;
uint8_t aggressive_constprop;
// uint8 settings
uint8_t constprop; // 0 = use heuristic; 1 = aggressive; 2 = none
} jl_code_info_t;

// This type describes a single method definition, and stores data
Expand Down Expand Up @@ -326,7 +327,8 @@ typedef struct _jl_method_t {
uint8_t isva;
uint8_t pure;
uint8_t is_for_opaque_closure;
uint8_t aggressive_constprop;
// uint8 settings
uint8_t constprop; // 0x00 = use heuristic; 0x01 = aggressive; 0x02 = none

// hidden fields:
// lock for modifications to the method
Expand Down
17 changes: 16 additions & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,24 @@ STATIC_INLINE jl_value_t *undefref_check(jl_datatype_t *dt, jl_value_t *v) JL_NO
return v;
}

// -- helper types -- //

typedef struct {
uint8_t pure:1;
uint8_t propagate_inbounds:1;
uint8_t inlineable:1;
uint8_t inferred:1;
uint8_t constprop:2; // 0 = use heuristic; 1 = aggressive; 2 = none
} jl_code_info_flags_bitfield_t;

typedef union {
jl_code_info_flags_bitfield_t bits;
uint8_t packed;
} jl_code_info_flags_t;

// -- functions -- //

// jl_code_info_flag_t code_info_flags(uint8_t pure, uint8_t propagate_inbounds, uint8_t inlineable, uint8_t inferred, uint8_t constprop);
jl_code_info_t *jl_type_infer(jl_method_instance_t *li, size_t world, int force);
jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *meth JL_PROPAGATES_ROOT, size_t world);
jl_code_instance_t *jl_generate_fptr(jl_method_instance_t *mi JL_PROPAGATES_ROOT, size_t world);
Expand Down Expand Up @@ -1376,7 +1391,7 @@ extern jl_sym_t *static_parameter_sym; extern jl_sym_t *inline_sym;
extern jl_sym_t *noinline_sym; extern jl_sym_t *generated_sym;
extern jl_sym_t *generated_only_sym; extern jl_sym_t *isdefined_sym;
extern jl_sym_t *propagate_inbounds_sym; extern jl_sym_t *specialize_sym;
extern jl_sym_t *aggressive_constprop_sym;
extern jl_sym_t *aggressive_constprop_sym; extern jl_sym_t *no_constprop_sym;
extern jl_sym_t *nospecialize_sym; extern jl_sym_t *macrocall_sym;
extern jl_sym_t *colon_sym; extern jl_sym_t *hygienicscope_sym;
extern jl_sym_t *throw_undef_if_not_sym; extern jl_sym_t *getfield_undefref_sym;
Expand Down
10 changes: 6 additions & 4 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
else if (ma == (jl_value_t*)propagate_inbounds_sym)
li->propagate_inbounds = 1;
else if (ma == (jl_value_t*)aggressive_constprop_sym)
li->aggressive_constprop = 1;
li->constprop = 1;
else if (ma == (jl_value_t*)no_constprop_sym)
li->constprop = 2;
else
jl_array_ptr_set(meta, ins++, ma);
}
Expand Down Expand Up @@ -443,7 +445,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
src->propagate_inbounds = 0;
src->pure = 0;
src->edges = jl_nothing;
src->aggressive_constprop = 0;
src->constprop = 0;
return src;
}

Expand Down Expand Up @@ -630,7 +632,7 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
}
m->called = called;
m->pure = src->pure;
m->aggressive_constprop = src->aggressive_constprop;
m->constprop = src->constprop;
jl_add_function_name_to_lineinfo(src, (jl_value_t*)m->name);

jl_array_t *copy = NULL;
Expand Down Expand Up @@ -746,7 +748,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
m->primary_world = 1;
m->deleted_world = ~(size_t)0;
m->is_for_opaque_closure = 0;
m->aggressive_constprop = 0;
m->constprop = 0;
JL_MUTEX_INIT(&m->writelock);
return m;
}
Expand Down
Loading