Skip to content

Commit

Permalink
improve NTuple show and cleanup code for Vararg and NTuple type param…
Browse files Browse the repository at this point in the history
…eter construction (#51244)

Various simplifications and improvements from investigating #51228.
Improves the logic for showing of NTuple to handle constant lengths.
Improves the logic for showing NTuple of bound length (e.g. NTuple
itself). Also makes a choice to avoid showing non-types as NTuple, but
instead try to write them out, to make it more visually obvious when the
parameters have been swapped.

---------

Co-authored-by: Jameson Nash <[email protected]>
  • Loading branch information
maleadt and vtjnash authored Sep 25, 2023
1 parent b44a95b commit 0287a00
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 59 deletions.
65 changes: 52 additions & 13 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1084,29 +1084,68 @@ function show_datatype(io::IO, x::DataType, wheres::Vector{TypeVar}=TypeVar[])

# Print tuple types with homogeneous tails longer than max_n compactly using `NTuple` or `Vararg`
if istuple
if n == 0
print(io, "Tuple{}")
return
end

# find the length of the homogeneous tail
max_n = 3
taillen = 1
for i in (n-1):-1:1
if parameters[i] === parameters[n]
taillen += 1
pn = parameters[n]
fulln = n
vakind = :none
vaN = 0
if pn isa Core.TypeofVararg
if isdefined(pn, :N)
vaN = pn.N
if vaN isa Int
taillen = vaN
fulln += taillen - 1
vakind = :fixed
else
vakind = :bound
end
else
break
vakind = :unbound
end
pn = unwrapva(pn)
end
if !(pn isa TypeVar || pn isa Type)
# prefer Tuple over NTuple if it contains something other than types
# (e.g. if the user has switched the N and T accidentally)
taillen = 0
elseif vakind === :none || vakind === :fixed
for i in (n-1):-1:1
if parameters[i] === pn
taillen += 1
else
break
end
end
end
if n == taillen > max_n
print(io, "NTuple{", n, ", ")
show(io, parameters[1])

# prefer NTuple over Tuple if it is a Vararg without a fixed length
# and prefer Tuple for short lists of elements
if (vakind == :bound && n == 1 == taillen) || (vakind === :fixed && taillen == fulln > max_n) ||
(vakind === :none && taillen == fulln > max_n)
print(io, "NTuple{")
vakind === :bound ? show(io, vaN) : print(io, fulln)
print(io, ", ")
show(io, pn)
print(io, "}")
else
print(io, "Tuple{")
for i = 1:(taillen > max_n ? n-taillen : n)
headlen = (taillen > max_n ? fulln - taillen : fulln)
for i = 1:headlen
i > 1 && print(io, ", ")
show(io, parameters[i])
show(io, vakind === :fixed && i >= n ? pn : parameters[i])
end
if taillen > max_n
print(io, ", Vararg{")
show(io, parameters[n])
print(io, ", ", taillen, "}")
if headlen < fulln
headlen > 0 && print(io, ", ")
print(io, "Vararg{")
show(io, pn)
print(io, ", ", fulln - headlen, "}")
end
print(io, "}")
end
Expand Down
4 changes: 2 additions & 2 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6738,7 +6738,7 @@ static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, con
sigt = NULL;
}
else {
sigt = jl_apply_tuple_type((jl_svec_t*)sigt);
sigt = jl_apply_tuple_type((jl_svec_t*)sigt, 1);
}
if (sigt && !(unionall_env && jl_has_typevar_from_unionall(rt, unionall_env))) {
unionall_env = NULL;
Expand Down Expand Up @@ -7242,7 +7242,7 @@ static jl_datatype_t *compute_va_type(jl_method_instance_t *lam, size_t nreq)
}
jl_svecset(tupargs, i-nreq, argType);
}
jl_value_t *typ = jl_apply_tuple_type(tupargs);
jl_value_t *typ = jl_apply_tuple_type(tupargs, 1);
JL_GC_POP();
return (jl_datatype_t*)typ;
}
Expand Down
8 changes: 4 additions & 4 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,7 @@ static jl_method_instance_t *cache_method(
intptr_t max_varargs = get_max_varargs(definition, kwmt, mt, NULL);
jl_compilation_sig(tt, sparams, definition, max_varargs, &newparams);
if (newparams) {
temp2 = jl_apply_tuple_type(newparams);
temp2 = jl_apply_tuple_type(newparams, 1);
// Now there may be a problem: the widened signature is more general
// than just the given arguments, so it might conflict with another
// definition that does not have cache instances yet. To fix this, we
Expand Down Expand Up @@ -1389,7 +1389,7 @@ static jl_method_instance_t *cache_method(
}
}
if (newparams) {
simplett = (jl_datatype_t*)jl_apply_tuple_type(newparams);
simplett = (jl_datatype_t*)jl_apply_tuple_type(newparams, 1);
temp2 = (jl_value_t*)simplett;
}

Expand Down Expand Up @@ -2579,7 +2579,7 @@ JL_DLLEXPORT jl_value_t *jl_normalize_to_compilable_sig(jl_methtable_t *mt, jl_t
jl_compilation_sig(ti, env, m, max_varargs, &newparams);
int is_compileable = ((jl_datatype_t*)ti)->isdispatchtuple;
if (newparams) {
tt = (jl_datatype_t*)jl_apply_tuple_type(newparams);
tt = (jl_datatype_t*)jl_apply_tuple_type(newparams, 1);
if (!is_compileable) {
// compute new env, if used below
jl_value_t *ti = jl_type_intersection_env((jl_value_t*)tt, (jl_value_t*)m->sig, &newparams);
Expand Down Expand Up @@ -2834,7 +2834,7 @@ jl_value_t *jl_argtype_with_function_type(jl_value_t *ft JL_MAYBE_UNROOTED, jl_v
jl_svecset(tt, 0, ft);
for (size_t i = 0; i < l; i++)
jl_svecset(tt, i+1, jl_tparam(types,i));
tt = (jl_value_t*)jl_apply_tuple_type((jl_svec_t*)tt);
tt = (jl_value_t*)jl_apply_tuple_type((jl_svec_t*)tt, 1);
tt = jl_rewrap_unionall_(tt, types0);
JL_GC_POP();
return tt;
Expand Down
78 changes: 47 additions & 31 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ JL_DLLEXPORT int jl_get_size(jl_value_t *val, size_t *pnt)
if (jl_is_long(val)) {
ssize_t slen = jl_unbox_long(val);
if (slen < 0)
jl_errorf("size or dimension is negative: %d", slen);
jl_errorf("size or dimension is negative: %zd", slen);
*pnt = slen;
return 1;
}
Expand Down Expand Up @@ -1435,17 +1435,6 @@ jl_datatype_t *jl_apply_cmpswap_type(jl_value_t *ty)
return rettyp;
}

// used to expand an NTuple to a flat representation
static jl_value_t *jl_tupletype_fill(size_t n, jl_value_t *v)
{
jl_value_t *p = NULL;
JL_GC_PUSH1(&p);
p = (jl_value_t*)jl_svec_fill(n, v);
p = jl_apply_tuple_type((jl_svec_t*)p);
JL_GC_POP();
return p;
}

JL_EXTENSION struct _jl_typestack_t {
jl_datatype_t *tt;
struct _jl_typestack_t *prev;
Expand Down Expand Up @@ -1724,7 +1713,7 @@ static void check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, si
JL_GC_POP();
}

jl_value_t *extract_wrapper(jl_value_t *t JL_PROPAGATES_ROOT) JL_GLOBALLY_ROOTED
jl_value_t *extract_wrapper(jl_value_t *t JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT JL_GLOBALLY_ROOTED
{
t = jl_unwrap_unionall(t);
if (jl_is_datatype(t))
Expand Down Expand Up @@ -1796,13 +1785,13 @@ int _may_substitute_ub(jl_value_t *v, jl_tvar_t *var, int inside_inv, int *cov_c
// * `var` does not appear in invariant position
// * `var` appears at most once (in covariant position) and not in a `Vararg`
// unless the upper bound is concrete (diagonal rule)
int may_substitute_ub(jl_value_t *v, jl_tvar_t *var) JL_NOTSAFEPOINT
static int may_substitute_ub(jl_value_t *v, jl_tvar_t *var) JL_NOTSAFEPOINT
{
int cov_count = 0;
return _may_substitute_ub(v, var, 0, &cov_count);
}

jl_value_t *normalize_unionalls(jl_value_t *t)
static jl_value_t *normalize_unionalls(jl_value_t *t)
{
if (jl_is_uniontype(t)) {
jl_uniontype_t *u = (jl_uniontype_t*)t;
Expand Down Expand Up @@ -1840,6 +1829,31 @@ jl_value_t *normalize_unionalls(jl_value_t *t)
return t;
}

// used to expand an NTuple to a flat representation
static jl_value_t *jl_tupletype_fill(size_t n, jl_value_t *t, int check)
{
jl_value_t *p = NULL;
JL_GC_PUSH1(&p);
if (check) {
// Since we are skipping making the Vararg and skipping checks later,
// we inline the checks from jl_wrap_vararg here now
if (!jl_valid_type_param(t))
jl_type_error_rt("Vararg", "type", (jl_value_t*)jl_type_type, t);
// jl_wrap_vararg sometimes simplifies the type, so we only do this 1 time, instead of for each n later
t = normalize_unionalls(t);
p = t;
jl_value_t *tw = extract_wrapper(t);
if (tw && t != tw && jl_types_equal(t, tw))
t = tw;
p = t;
check = 0; // remember that checks are already done now
}
p = (jl_value_t*)jl_svec_fill(n, t);
p = jl_apply_tuple_type((jl_svec_t*)p, check);
JL_GC_POP();
return p;
}

static jl_value_t *_jl_instantiate_type_in_env(jl_value_t *ty, jl_unionall_t *env, jl_value_t **vals, jl_typeenv_t *prev, jl_typestack_t *stack);

static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value_t **iparams, size_t ntp,
Expand Down Expand Up @@ -1962,7 +1976,7 @@ static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value
if (nt == 0 || !jl_has_free_typevars(va0)) {
if (ntp == 1) {
JL_GC_POP();
return jl_tupletype_fill(nt, va0);
return jl_tupletype_fill(nt, va0, 0);
}
size_t i, l;
p = jl_alloc_svec(ntp - 1 + nt);
Expand All @@ -1971,7 +1985,7 @@ static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value
l = ntp - 1 + nt;
for (; i < l; i++)
jl_svecset(p, i, va0);
jl_value_t *ndt = jl_apply_tuple_type(p);
jl_value_t *ndt = jl_apply_tuple_type(p, check);
JL_GC_POP();
return ndt;
}
Expand Down Expand Up @@ -2136,19 +2150,19 @@ static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value
return (jl_value_t*)ndt;
}

static jl_value_t *jl_apply_tuple_type_v_(jl_value_t **p, size_t np, jl_svec_t *params)
static jl_value_t *jl_apply_tuple_type_v_(jl_value_t **p, size_t np, jl_svec_t *params, int check)
{
return inst_datatype_inner(jl_anytuple_type, params, p, np, NULL, NULL, 1);
return inst_datatype_inner(jl_anytuple_type, params, p, np, NULL, NULL, check);
}

JL_DLLEXPORT jl_value_t *jl_apply_tuple_type(jl_svec_t *params)
JL_DLLEXPORT jl_value_t *jl_apply_tuple_type(jl_svec_t *params, int check)
{
return jl_apply_tuple_type_v_(jl_svec_data(params), jl_svec_len(params), params);
return jl_apply_tuple_type_v_(jl_svec_data(params), jl_svec_len(params), params, check);
}

JL_DLLEXPORT jl_value_t *jl_apply_tuple_type_v(jl_value_t **p, size_t np)
{
return jl_apply_tuple_type_v_(p, np, NULL);
return jl_apply_tuple_type_v_(p, np, NULL, 1);
}

jl_tupletype_t *jl_lookup_arg_tuple_type(jl_value_t *arg1, jl_value_t **args, size_t nargs, int leaf)
Expand Down Expand Up @@ -2211,13 +2225,15 @@ static jl_value_t *inst_tuple_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_
jl_datatype_t *tt = (jl_datatype_t*)t;
jl_svec_t *tp = tt->parameters;
size_t ntp = jl_svec_len(tp);
// Instantiate NTuple{3,Int}
// Instantiate Tuple{Vararg{T,N}} where T is fixed and N is known, such as Dims{3}
// And avoiding allocating the intermediate steps
// Note this does not instantiate Tuple{Vararg{Int,3}}; that's done in inst_datatype_inner
// Note this does not instantiate NTuple{N,T}, since it is unnecessary and inefficient to expand that now
if (jl_is_va_tuple(tt) && ntp == 1) {
// If this is a Tuple{Vararg{T,N}} with known N, expand it to
// If this is a Tuple{Vararg{T,N}} with known N and T, expand it to
// a fixed-length tuple
jl_value_t *T=NULL, *N=NULL;
jl_value_t *va = jl_unwrap_unionall(jl_tparam0(tt));
jl_value_t *va = jl_tparam0(tt);
jl_value_t *ttT = jl_unwrap_vararg(va);
jl_value_t *ttN = jl_unwrap_vararg_num(va);
jl_typeenv_t *e = env;
Expand All @@ -2228,11 +2244,12 @@ static jl_value_t *inst_tuple_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_
N = e->val;
e = e->prev;
}
if (T != NULL && N != NULL && jl_is_long(N)) {
if (T != NULL && N != NULL && jl_is_long(N)) { // TODO: && !jl_has_free_typevars(T) to match inst_datatype_inner, or even && jl_is_concrete_type(T)
// Since this is skipping jl_wrap_vararg, we inline the checks from it here
ssize_t nt = jl_unbox_long(N);
if (nt < 0)
jl_errorf("size or dimension is negative: %zd", nt);
return jl_tupletype_fill(nt, T);
jl_errorf("Vararg length is negative: %zd", nt);
return jl_tupletype_fill(nt, T, check);
}
}
jl_value_t **iparams;
Expand Down Expand Up @@ -2428,9 +2445,8 @@ jl_vararg_t *jl_wrap_vararg(jl_value_t *t, jl_value_t *n, int check)
}
}
if (t) {
if (!jl_valid_type_param(t)) {
if (!jl_valid_type_param(t))
jl_type_error_rt("Vararg", "type", (jl_value_t*)jl_type_type, t);
}
t = normalize_unionalls(t);
jl_value_t *tw = extract_wrapper(t);
if (tw && t != tw && jl_types_equal(t, tw))
Expand Down Expand Up @@ -2735,7 +2751,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_anytuple_type->layout = NULL;

jl_typeofbottom_type->super = jl_wrap_Type(jl_bottom_type);
jl_emptytuple_type = (jl_datatype_t*)jl_apply_tuple_type(jl_emptysvec);
jl_emptytuple_type = (jl_datatype_t*)jl_apply_tuple_type(jl_emptysvec, 0);
jl_emptytuple = jl_gc_permobj(0, jl_emptytuple_type);
jl_emptytuple_type->instance = jl_emptytuple;

Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,7 @@ JL_DLLEXPORT jl_value_t *jl_apply_type1(jl_value_t *tc, jl_value_t *p1);
JL_DLLEXPORT jl_value_t *jl_apply_type2(jl_value_t *tc, jl_value_t *p1, jl_value_t *p2);
JL_DLLEXPORT jl_datatype_t *jl_apply_modify_type(jl_value_t *dt);
JL_DLLEXPORT jl_datatype_t *jl_apply_cmpswap_type(jl_value_t *dt);
JL_DLLEXPORT jl_value_t *jl_apply_tuple_type(jl_svec_t *params);
JL_DLLEXPORT jl_value_t *jl_apply_tuple_type(jl_svec_t *params, int check); // if uncertain, set check=1
JL_DLLEXPORT jl_value_t *jl_apply_tuple_type_v(jl_value_t **p, size_t np);
JL_DLLEXPORT jl_datatype_t *jl_new_datatype(jl_sym_t *name,
jl_module_t *module,
Expand Down
2 changes: 1 addition & 1 deletion src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ JL_DLLEXPORT jl_method_t* jl_method_def(jl_svec_t *argdata,
JL_GC_PUSH3(&f, &m, &argtype);
size_t i, na = jl_svec_len(atypes);

argtype = jl_apply_tuple_type(atypes);
argtype = jl_apply_tuple_type(atypes, 1);
if (!jl_is_datatype(argtype))
jl_error("invalid type in method definition (Union{})");

Expand Down
2 changes: 1 addition & 1 deletion src/precompile_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ static void _compile_all_union(jl_value_t *sig)
jl_svecset(p, i, ty);
}
}
methsig = jl_apply_tuple_type(p);
methsig = jl_apply_tuple_type(p, 1);
methsig = jl_rewrap_unionall(methsig, sig);
_compile_all_tvar_union(methsig);
}
Expand Down
4 changes: 2 additions & 2 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -3393,7 +3393,7 @@ static jl_value_t *intersect_tuple(jl_datatype_t *xd, jl_datatype_t *yd, jl_sten
else if (isy)
res = (jl_value_t*)yd;
else if (p)
res = jl_apply_tuple_type(p);
res = jl_apply_tuple_type(p, 1);
else
res = jl_apply_tuple_type_v(params, np);
}
Expand Down Expand Up @@ -4130,7 +4130,7 @@ static jl_value_t *switch_union_tuple(jl_value_t *a, jl_value_t *b)
ts[1] = jl_tparam(b, i);
jl_svecset(vec, i, jl_type_union(ts, 2));
}
jl_value_t *ans = jl_apply_tuple_type(vec);
jl_value_t *ans = jl_apply_tuple_type(vec, 1);
JL_GC_POP();
return ans;
}
Expand Down
4 changes: 2 additions & 2 deletions test/docs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1028,7 +1028,7 @@ struct $(curmod_prefix)Undocumented.st3{T<:Integer, N}

# Fields
```
a :: Tuple{Vararg{T<:Integer, N}}
a :: NTuple{N, T<:Integer}
b :: Array{Int64, N}
c :: Int64
```
Expand All @@ -1052,7 +1052,7 @@ struct $(curmod_prefix)Undocumented.st4{T, N}
# Fields
```
a :: T
b :: Tuple{Vararg{T, N}}
b :: NTuple{N, T}
```

# Supertype Hierarchy
Expand Down
Loading

0 comments on commit 0287a00

Please sign in to comment.