Skip to content

Commit 79bea13

Browse files
committed
do union split and concrete compilation search
1 parent 8593792 commit 79bea13

File tree

3 files changed

+77
-57
lines changed

3 files changed

+77
-57
lines changed

src/gf.c

+7-1
Original file line numberDiff line numberDiff line change
@@ -3188,6 +3188,12 @@ JL_DLLEXPORT void jl_compile_method_instance(jl_method_instance_t *mi, jl_tuplet
31883188
}
31893189
}
31903190

3191+
JL_DLLEXPORT void jl_compile_method_sig(jl_method_t *m, jl_value_t *types, jl_svec_t *env, size_t world)
3192+
{
3193+
jl_method_instance_t *mi = jl_specializations_get_linfo(m, types, env);
3194+
jl_compile_method_instance(mi, NULL, world);
3195+
}
3196+
31913197
JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types)
31923198
{
31933199
size_t world = jl_atomic_load_acquire(&jl_world_counter);
@@ -3197,7 +3203,7 @@ JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types)
31973203
if (mi == NULL)
31983204
return 0;
31993205
JL_GC_PROMISE_ROOTED(mi);
3200-
jl_compile_method_instance(mi, types, world);
3206+
jl_compile_method_instance(mi, NULL, world);
32013207
return 1;
32023208
}
32033209

src/julia_internal.h

+1
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,7 @@ JL_DLLEXPORT jl_module_t *jl_debuginfo_module1(jl_value_t *debuginfo_def) JL_NOT
695695
JL_DLLEXPORT const char *jl_debuginfo_name(jl_value_t *func) JL_NOTSAFEPOINT;
696696

697697
JL_DLLEXPORT void jl_compile_method_instance(jl_method_instance_t *mi, jl_tupletype_t *types, size_t world);
698+
JL_DLLEXPORT void jl_compile_method_sig(jl_method_t *m, jl_value_t *types, jl_svec_t *sparams, size_t world);
698699
JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types);
699700
JL_DLLEXPORT int jl_add_entrypoint(jl_tupletype_t *types);
700701
jl_code_info_t *jl_code_for_interpreter(jl_method_instance_t *lam JL_PROPAGATES_ROOT, size_t world);

src/precompile_utils.c

+69-56
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
// f{<:Union{...}}(...) is a common pattern
2-
// and expanding the Union may give a leaf function
3-
static void _compile_all_tvar_union(jl_value_t *methsig)
1+
// This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
// f(...) where {T<:Union{...}} is a common pattern
4+
// and expanding the Union may give some leaf functions
5+
static int _compile_all_tvar_union(jl_value_t *methsig)
46
{
57
int tvarslen = jl_subtype_env_size(methsig);
68
jl_value_t *sigbody = methsig;
@@ -13,86 +15,93 @@ static void _compile_all_tvar_union(jl_value_t *methsig)
1315
assert(jl_is_unionall(sigbody));
1416
idx[i] = 0;
1517
env[2 * i] = (jl_value_t*)((jl_unionall_t*)sigbody)->var;
16-
env[2 * i + 1] = jl_bottom_type; // initialize the list with Union{}, since T<:Union{} is always a valid option
18+
jl_value_t *tv = env[2 * i];
19+
while (jl_is_typevar(tv))
20+
tv = ((jl_tvar_t*)tv)->ub;
21+
if (jl_is_abstracttype(tv) && !jl_is_type_type(tv)) {
22+
JL_GC_POP();
23+
return 0; // Any as TypeVar is common and not useful here to try to analyze further
24+
}
25+
env[2 * i + 1] = tv;
1726
sigbody = ((jl_unionall_t*)sigbody)->body;
1827
}
1928

20-
for (i = 0; i < tvarslen; /* incremented by inner loop */) {
29+
int all = 1;
30+
int incr = 0;
31+
while (!incr) {
32+
for (i = 0, incr = 1; i < tvarslen; i++) {
33+
jl_value_t *tv = env[2 * i];
34+
while (jl_is_typevar(tv))
35+
tv = ((jl_tvar_t*)tv)->ub;
36+
if (jl_is_uniontype(tv)) {
37+
size_t l = jl_count_union_components(tv);
38+
size_t j = idx[i];
39+
env[2 * i + 1] = jl_nth_union_component(tv, j);
40+
++j;
41+
if (incr) {
42+
if (j == l) {
43+
idx[i] = 0;
44+
}
45+
else {
46+
idx[i] = j;
47+
incr = 0;
48+
}
49+
}
50+
}
51+
}
2152
jl_value_t **sig = &roots[0];
2253
JL_TRY {
2354
// TODO: wrap in UnionAll for each tvar in env[2*i + 1] ?
2455
// currently doesn't matter much, since jl_compile_hint doesn't work on abstract types
2556
*sig = (jl_value_t*)jl_instantiate_type_with(sigbody, env, tvarslen);
2657
}
2758
JL_CATCH {
28-
goto getnext; // sigh, we found an invalid type signature. should we warn the user?
29-
}
30-
if (!jl_has_concrete_subtype(*sig))
31-
goto getnext; // signature wouldn't be callable / is invalid -- skip it
32-
if (jl_is_concrete_type(*sig)) {
33-
if (jl_compile_hint((jl_tupletype_t *)*sig))
34-
goto getnext; // success
59+
*sig = NULL;
3560
}
36-
37-
getnext:
38-
for (i = 0; i < tvarslen; i++) {
39-
jl_tvar_t *tv = (jl_tvar_t*)env[2 * i];
40-
if (jl_is_uniontype(tv->ub)) {
41-
size_t l = jl_count_union_components(tv->ub);
42-
size_t j = idx[i];
43-
if (j == l) {
44-
env[2 * i + 1] = jl_bottom_type;
45-
idx[i] = 0;
46-
}
47-
else {
48-
jl_value_t *ty = jl_nth_union_component(tv->ub, j);
49-
if (!jl_is_concrete_type(ty))
50-
ty = (jl_value_t*)jl_new_typevar(tv->name, tv->lb, ty);
51-
env[2 * i + 1] = ty;
52-
idx[i] = j + 1;
53-
break;
54-
}
55-
}
56-
else {
57-
env[2 * i + 1] = (jl_value_t*)tv;
58-
}
61+
if (*sig) {
62+
if (jl_is_datatype(*sig) && jl_has_concrete_subtype(*sig))
63+
all = all && jl_compile_hint((jl_tupletype_t*)*sig);
64+
else
65+
all = 0;
5966
}
6067
}
6168
JL_GC_POP();
69+
return all;
6270
}
6371

6472
// f(::Union{...}, ...) is a common pattern
6573
// and expanding the Union may give a leaf function
66-
static void _compile_all_union(jl_value_t *sig)
74+
static int _compile_all_union(jl_value_t *sig)
6775
{
6876
jl_tupletype_t *sigbody = (jl_tupletype_t*)jl_unwrap_unionall(sig);
6977
size_t count_unions = 0;
78+
size_t union_size = 1;
7079
size_t i, l = jl_svec_len(sigbody->parameters);
7180
jl_svec_t *p = NULL;
7281
jl_value_t *methsig = NULL;
7382

7483
for (i = 0; i < l; i++) {
7584
jl_value_t *ty = jl_svecref(sigbody->parameters, i);
76-
if (jl_is_uniontype(ty))
77-
++count_unions;
78-
else if (ty == jl_bottom_type)
79-
return; // why does this method exist?
80-
else if (jl_is_datatype(ty) && !jl_has_free_typevars(ty) &&
81-
((!jl_is_kind(ty) && ((jl_datatype_t*)ty)->isconcretetype) ||
82-
((jl_datatype_t*)ty)->name == jl_type_typename))
83-
return; // no amount of union splitting will make this a leaftype signature
85+
if (jl_is_uniontype(ty)) {
86+
count_unions += 1;
87+
union_size *= jl_count_union_components(ty);
88+
}
89+
else if (jl_is_datatype(ty) &&
90+
((!((jl_datatype_t*)ty)->isconcretetype || jl_is_kind(ty)) &&
91+
((jl_datatype_t*)ty)->name != jl_type_typename))
92+
return 0; // no amount of union splitting will make this a dispatch signature
8493
}
8594

86-
if (count_unions == 0 || count_unions >= 6) {
87-
_compile_all_tvar_union(sig);
88-
return;
95+
if (union_size <= 1 || union_size > 8) {
96+
return _compile_all_tvar_union(sig);
8997
}
9098

9199
int *idx = (int*)alloca(sizeof(int) * count_unions);
92100
for (i = 0; i < count_unions; i++) {
93101
idx[i] = 0;
94102
}
95103

104+
int all = 1;
96105
JL_GC_PUSH2(&p, &methsig);
97106
int idx_ctr = 0, incr = 0;
98107
while (!incr) {
@@ -122,10 +131,12 @@ static void _compile_all_union(jl_value_t *sig)
122131
}
123132
methsig = jl_apply_tuple_type(p, 1);
124133
methsig = jl_rewrap_unionall(methsig, sig);
125-
_compile_all_tvar_union(methsig);
134+
if (!_compile_all_tvar_union(methsig))
135+
all = 0;
126136
}
127137

128138
JL_GC_POP();
139+
return all;
129140
}
130141

131142
static int compile_all_collect__(jl_typemap_entry_t *ml, void *env)
@@ -147,29 +158,32 @@ static int compile_all_collect_(jl_methtable_t *mt, void *env)
147158
return 1;
148159
}
149160

150-
static void jl_compile_all_defs(jl_array_t *mis)
161+
static void jl_compile_all_defs(jl_array_t *mis, int all)
151162
{
152163
jl_array_t *allmeths = jl_alloc_vec_any(0);
153164
JL_GC_PUSH1(&allmeths);
154165

155166
jl_foreach_reachable_mtable(compile_all_collect_, allmeths);
156167

168+
size_t world = jl_atomic_load_acquire(&jl_world_counter);
157169
size_t i, l = jl_array_nrows(allmeths);
158170
for (i = 0; i < l; i++) {
159171
jl_method_t *m = (jl_method_t*)jl_array_ptr_ref(allmeths, i);
160172
if (jl_is_datatype(m->sig) && jl_isa_compileable_sig((jl_tupletype_t*)m->sig, jl_emptysvec, m)) {
161173
// method has a single compilable specialization, e.g. its definition
162174
// signature is concrete. in this case we can just hint it.
163-
jl_compile_hint((jl_tupletype_t*)m->sig);
175+
jl_compile_method_sig(m, m->sig, jl_emptysvec, world);
164176
}
165177
else {
166178
// first try to create leaf signatures from the signature declaration and compile those
167179
_compile_all_union(m->sig);
168180

169-
// finally, compile a fully generic fallback that can work for all arguments
170-
jl_method_instance_t *unspec = jl_get_unspecialized(m);
171-
if (unspec)
172-
jl_array_ptr_1d_push(mis, (jl_value_t*)unspec);
181+
if (all) {
182+
// finally, compile a fully generic fallback that can work for all arguments (even invoke)
183+
jl_method_instance_t *unspec = jl_get_unspecialized(m);
184+
if (unspec)
185+
jl_array_ptr_1d_push(mis, (jl_value_t*)unspec);
186+
}
173187
}
174188
}
175189

@@ -273,8 +287,7 @@ static void *jl_precompile(int all)
273287
// array of MethodInstances and ccallable aliases to include in the output
274288
jl_array_t *m = jl_alloc_vec_any(0);
275289
JL_GC_PUSH1(&m);
276-
if (all)
277-
jl_compile_all_defs(m);
290+
jl_compile_all_defs(m, all);
278291
jl_foreach_reachable_mtable(precompile_enq_all_specializations_, m);
279292
void *native_code = jl_precompile_(m, 0);
280293
JL_GC_POP();

0 commit comments

Comments
 (0)