From 8652f261b9d487e1b5ed0fe1282d1209ab910a7b Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Sun, 1 Jan 2023 20:21:09 +0800 Subject: [PATCH] Only merge vars occur in the local union decision. If we always merge the whole env, then the output bounds would be widen than input if different Union decision touch different vars. Also add missing `occurs_inv/cov`'s merge (by max). --- src/subtype.c | 132 ++++++++++++++++++++++++++++++++++++++---------- test/subtype.jl | 7 +++ 2 files changed, 113 insertions(+), 26 deletions(-) diff --git a/src/subtype.c b/src/subtype.c index 0413b8f9402a05..366de04fef0d2f 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -65,6 +65,7 @@ typedef struct jl_varbinding_t { jl_value_t *lb; jl_value_t *ub; int8_t right; // whether this variable came from the right side of `A <: B` + int8_t occurs; // occurs in any position int8_t occurs_inv; // occurs in invariant position int8_t occurs_cov; // # of occurrences in covariant position int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete @@ -161,7 +162,7 @@ static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT typedef struct { int8_t *buf; int rdepth; - int8_t _space[16]; + int8_t _space[24]; } jl_savedenv_t; static void save_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se) @@ -174,9 +175,9 @@ static void save_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se) } if (root) *root = (jl_value_t*)jl_alloc_svec(len * 3); - se->buf = (int8_t*)(len > 8 ? malloc_s(len * 2) : &se->_space); + se->buf = (int8_t*)(len > 8 ? malloc_s(len * 3) : &se->_space); #ifdef __clang_gcanalyzer__ - memset(se->buf, 0, len * 2); + memset(se->buf, 0, len * 3); #endif int i=0, j=0; v = e->vars; while (v != NULL) { @@ -185,6 +186,7 @@ static void save_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se) jl_svecset(*root, i++, v->ub); jl_svecset(*root, i++, (jl_value_t*)v->innervars); } + se->buf[j++] = v->occurs; se->buf[j++] = v->occurs_inv; se->buf[j++] = v->occurs_cov; v = v->prev; @@ -207,6 +209,7 @@ static void restore_env(jl_stenv_t *e, jl_value_t *root, jl_savedenv_t *se) JL_N if (root) v->lb = jl_svecref(root, i++); if (root) v->ub = jl_svecref(root, i++); if (root) v->innervars = (jl_array_t*)jl_svecref(root, i++); + v->occurs = se->buf[j++]; v->occurs_inv = se->buf[j++]; v->occurs_cov = se->buf[j++]; v = v->prev; @@ -227,6 +230,15 @@ static int current_env_length(jl_stenv_t *e) return len; } +static void clean_occurs(jl_stenv_t *e) +{ + jl_varbinding_t *v = e->vars; + while (v) { + v->occurs = 0; + v = v->prev; + } +} + // type utilities // quickly test that two types are identical @@ -601,6 +613,8 @@ static int subtype_left_var(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int par // of determining whether the variable is concrete. static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param) JL_NOTSAFEPOINT { + if (vb != NULL) + vb->occurs = 1; if (vb != NULL && param) { // saturate counters at 2; we don't need values bigger than that if (param == 2 && (vb->right ? e->Rinvdepth : e->invdepth) > vb->depth0) { @@ -793,7 +807,7 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e) static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param) { u = unalias_unionall(u, e); - jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, + jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars }; JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars); e->vars = &vb; @@ -2752,7 +2766,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_ { jl_value_t *res=NULL, *save=NULL; jl_savedenv_t se; - jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, + jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars }; JL_GC_PUSH5(&res, &vb.lb, &vb.ub, &save, &vb.innervars); save_env(e, &save, &se); @@ -2765,13 +2779,13 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_ else if (res != jl_bottom_type) { if (vb.concrete || vb.occurs_inv>1 || vb.intvalued > 1 || u->var->lb != jl_bottom_type || (vb.occurs_inv && vb.occurs_cov)) { restore_env(e, NULL, &se); - vb.occurs_cov = vb.occurs_inv = 0; + vb.occurs = vb.occurs_cov = vb.occurs_inv = 0; vb.constraintkind = vb.concrete ? 1 : 2; res = intersect_unionall_(t, u, e, R, param, &vb); } else if (vb.occurs_cov && !var_occurs_invariant(u->body, u->var, 0)) { restore_env(e, save, &se); - vb.occurs_cov = vb.occurs_inv = 0; + vb.occurs = vb.occurs_cov = vb.occurs_inv = 0; vb.constraintkind = 1; res = intersect_unionall_(t, u, e, R, param, &vb); } @@ -3282,29 +3296,42 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa static int merge_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se, int count) { - if (!count) { - save_env(e, root, se); - return 1; + if (count == 0) { + int len = current_env_length(e); + *root = (jl_value_t*)jl_alloc_svec(len * 3); + se->buf = (int8_t*)(len > 8 ? malloc_s(len * 3) : &se->_space); + memset(se->buf, 0, len * 3); } int n = 0; jl_varbinding_t *v = e->vars; jl_value_t *b1 = NULL, *b2 = NULL; JL_GC_PUSH2(&b1, &b2); // clang-sagc does not understand that *root is rooted already + v = e->vars; while (v != NULL) { - b1 = jl_svecref(*root, n); - b2 = v->lb; - jl_svecset(*root, n, simple_meet(b1, b2)); - b1 = jl_svecref(*root, n+1); - b2 = v->ub; - jl_svecset(*root, n+1, simple_join(b1, b2)); - b1 = jl_svecref(*root, n+2); - b2 = (jl_value_t*)v->innervars; - if (b2 && b1 != b2) { - if (b1) - jl_array_ptr_1d_append((jl_array_t*)b2, (jl_array_t*)b1); - else - jl_svecset(*root, n+2, b2); + if (v->occurs) { + // only merge lb/ub/innervars if this var occurs. + b1 = jl_svecref(*root, n); + b2 = v->lb; + jl_svecset(*root, n, b1 ? simple_meet(b1, b2) : b2); + b1 = jl_svecref(*root, n+1); + b2 = v->ub; + jl_svecset(*root, n+1, b1 ? simple_join(b1, b2) : b2); + b1 = jl_svecref(*root, n+2); + b2 = (jl_value_t*)v->innervars; + if (b2 && b1 != b2) { + if (b1) + jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2); + else + jl_svecset(*root, n+2, b2); + } + // record the meeted vars. + se->buf[n] = 1; } + // always merge occurs_inv/cov by max (never decrease) + if (v->occurs_inv > se->buf[n+1]) + se->buf[n+1] = v->occurs_inv; + if (v->occurs_cov > se->buf[n+2]) + se->buf[n+2] = v->occurs_cov; n = n + 3; v = v->prev; } @@ -3312,6 +3339,52 @@ static int merge_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se, int co return count + 1; } +// merge untouched vars' info. +static void final_merge_env(jl_value_t **merged, jl_savedenv_t *me, jl_value_t **saved, jl_savedenv_t *se) +{ + int l = jl_svec_len(*merged); + assert(l == jl_svec_len(*saved) && l%3 == 0); + jl_value_t *b1 = NULL, *b2 = NULL; + JL_GC_PUSH2(&b1, &b2); + for (int n = 0; n < l; n = n + 3) { + if (jl_svecref(*merged, n) == NULL) + jl_svecset(*merged, n, jl_svecref(*saved, n)); + if (jl_svecref(*merged, n+1) == NULL) + jl_svecset(*merged, n+1, jl_svecref(*saved, n+1)); + b1 = jl_svecref(*merged, n+2); + b2 = jl_svecref(*saved , n+2); + if (b1 == NULL) + jl_svecset(*merged, n+2, b2); + else if (b2 && b1 != b2) + jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2); + me->buf[n] |= se->buf[n]; + } + JL_GC_POP(); +} + +static void expand_local_env(jl_stenv_t *e, jl_value_t *res) +{ + jl_varbinding_t *v = e->vars; + // Here we pull in some typevar missed in fastpath. + while (v != NULL) { + v->occurs = v->occurs || jl_has_typevar(res, v->var); + assert(v->occurs == 0 || v->occurs == 1); + v = v->prev; + } + v = e->vars; + while (v != NULL) { + if (v->occurs == 1) { + jl_varbinding_t *v2 = e->vars; + while (v2 != NULL) { + if (v2 != v && v2->occurs == 0) + v2->occurs = -(jl_has_typevar(v->lb, v2->var) || jl_has_typevar(v->ub, v2->var)); + v2 = v2->prev; + } + } + v = v->prev; + } +} + static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) { e->Runions.depth = 0; @@ -3324,10 +3397,13 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) jl_savedenv_t se, me; save_env(e, saved, &se); int lastset = 0, niter = 0, total_iter = 0; + clean_occurs(e); jl_value_t *ii = intersect(x, y, e, 0); is[0] = ii; // root - if (is[0] != jl_bottom_type) + if (is[0] != jl_bottom_type) { + expand_local_env(e, is[0]); niter = merge_env(e, merged, &me, niter); + } restore_env(e, *saved, &se); while (e->Runions.more) { if (e->emptiness_only && ii != jl_bottom_type) @@ -3341,9 +3417,12 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) lastset = set; is[0] = ii; + clean_occurs(e); is[1] = intersect(x, y, e, 0); - if (is[1] != jl_bottom_type) + if (is[1] != jl_bottom_type) { + expand_local_env(e, is[1]); niter = merge_env(e, merged, &me, niter); + } restore_env(e, *saved, &se); if (is[0] == jl_bottom_type) ii = is[1]; @@ -3359,7 +3438,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e) break; } } - if (niter){ + if (niter) { + final_merge_env(merged, &me, saved, &se); restore_env(e, *merged, &me); free_env(&me); } diff --git a/test/subtype.jl b/test/subtype.jl index b6e30ce6771466..f1d0b09c5f19a8 100644 --- a/test/subtype.jl +++ b/test/subtype.jl @@ -2319,6 +2319,13 @@ let S = Tuple{T2, V2} where {T2, N2, V2<:(Array{S2, N2} where {S2 <: T2})}, @testintersect(S, T, !Union{}) end +# A simple case which has a small local union. +# make sure the env is not widened too much when we intersect(Int8, Int8). +struct T48006{A1,A2,A3} end +@testintersect(Tuple{T48006{Float64, Int, S1}, Int} where {F1<:Real, S1<:Union{Int8, Val{F1}}}, + Tuple{T48006{F2, I, S2}, I} where {F2<:Real, I<:Int, S2<:Union{Int8, Val{F2}}}, + Tuple{T48006{Float64, Int, S1}, Int} where S1<:Union{Val{Float64}, Int8}) + @testset "known subtype/intersect issue" begin #issue 45874 # Causes a hang due to jl_critical_error calling back into malloc...