From e8d527e15c8a6e60227d6f2c67e656d3ea8f8dfe Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Tue, 19 Jan 2021 13:19:56 -0600 Subject: [PATCH] Improve inferability of shape::Dims for cat (#39294) `cat` is often called with Varargs or heterogenous inputs, and inference almost always fails. Even when all the arrays are of the same type, if the number of varargs isn't known inference typically fails. The culprit is probably #36454. This reduces the number of failures considerably, by avoiding creation of vararg length tuples in the shape-inference pipeline. (cherry picked from commit 815076b392821609815d5d77077e042ef08d2e14) --- base/abstractarray.jl | 8 +++++++- test/abstractarray.jl | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 30072363a34c3..1f1120740e99a 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1580,6 +1580,7 @@ cat_indices(A::AbstractArray, d) = axes(A, d) cat_similar(A, ::Type{T}, shape) where T = Array{T}(undef, shape) cat_similar(A::AbstractArray, ::Type{T}, shape) where T = similar(A, T, shape) +# These are for backwards compatibility (even though internal) cat_shape(dims, shape::Tuple{Vararg{Int}}) = shape function cat_shape(dims, shapes::Tuple) out_shape = () @@ -1588,6 +1589,11 @@ function cat_shape(dims, shapes::Tuple) end return out_shape end +# The new way to compute the shape (more inferrable than combining cat_size & cat_shape, due to Varargs + issue#36454) +cat_size_shape(dims) = ntuple(zero, Val(length(dims))) +@inline cat_size_shape(dims, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, (), cat_size(X)), tail...) +_cat_size_shape(dims, shape) = shape +@inline _cat_size_shape(dims, shape, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, shape, cat_size(X)), tail...) _cshp(ndim::Int, ::Tuple{}, ::Tuple{}, ::Tuple{}) = () _cshp(ndim::Int, ::Tuple{}, ::Tuple{}, nshape) = nshape @@ -1631,7 +1637,7 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims) @inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...) @inline function _cat_t(dims, ::Type{T}, X...) where {T} catdims = dims2cat(dims) - shape = cat_shape(catdims, map(cat_size, X)) + shape = cat_size_shape(catdims, X...) A = cat_similar(X[1], T, shape) if count(!iszero, catdims)::Int > 1 fill!(A, zero(T)) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 804f8ffed468e..bc7c4d646fd8e 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -678,6 +678,12 @@ function test_cat(::Type{TestAbstractArray}) # 36041 @test_throws MethodError cat(["a"], ["b"], dims=[1, 2]) @test cat([1], [1], dims=[1, 2]) == I(2) + + # inferrability + As = [zeros(2, 2) for _ = 1:2] + @test @inferred(cat(As...; dims=Val(3))) == zeros(2, 2, 2) + cat3v(As) = cat(As...; dims=Val(3)) + @test @inferred(cat3v(As)) == zeros(2, 2, 2) end function test_ind2sub(::Type{TestAbstractArray})