Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve inferability of shape::Dims for cat #39294

Merged
merged 2 commits into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,7 @@ cat_indices(A::AbstractArray, d) = axes(A, d)
cat_similar(A, T, shape) = Array{T}(undef, shape)
cat_similar(A::AbstractArray, T, shape) = similar(A, T, shape)

# These are for backwards compatibility (even though internal)
cat_shape(dims, shape::Tuple{Vararg{Int}}) = shape
function cat_shape(dims, shapes::Tuple)
out_shape = ()
Expand All @@ -1588,6 +1589,11 @@ function cat_shape(dims, shapes::Tuple)
end
return out_shape
end
# The new way to compute the shape (more inferrable than combining cat_size & cat_shape, due to Varargs + issue#36454)
cat_size_shape(dims) = ntuple(zero, Val(length(dims)))
@inline cat_size_shape(dims, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, (), cat_size(X)), tail...)
_cat_size_shape(dims, shape) = shape
@inline _cat_size_shape(dims, shape, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, shape, cat_size(X)), tail...)

_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, nshape) = nshape
Expand Down Expand Up @@ -1631,7 +1637,7 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...)
@inline function _cat_t(dims, ::Type{T}, X...) where {T}
catdims = dims2cat(dims)
shape = cat_shape(catdims, map(cat_size, X))
shape = cat_size_shape(catdims, X...)
A = cat_similar(X[1], T, shape)
if count(!iszero, catdims)::Int > 1
fill!(A, zero(T))
Expand Down
6 changes: 6 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,12 @@ function test_cat(::Type{TestAbstractArray})
# 36041
@test_throws MethodError cat(["a"], ["b"], dims=[1, 2])
@test cat([1], [1], dims=[1, 2]) == I(2)

# inferrability
As = [zeros(2, 2) for _ = 1:2]
@test @inferred(cat(As...; dims=Val(3))) == zeros(2, 2, 2)
cat3v(As) = cat(As...; dims=Val(3))
@test @inferred(cat3v(As)) == zeros(2, 2, 2)
end

function test_ind2sub(::Type{TestAbstractArray})
Expand Down