Skip to content

Commit

Permalink
More efficient hvcat of scalars and arrays of numbers (#39729)
Browse files Browse the repository at this point in the history
First attempt to address #39713

Original:
```
const a, b, c, d = zeros(Int, 2, 2), [3 4], [2 ; 4], 5
using BenchmarkTools
@Btime [a c ; b d]   # 31 allocations and 1.25 kb
```

New:
```
@Btime [a c ; b d]   # 15 allocations and 656 bytes
```

Others unchanged, as expected. ~~Though if different types of numbers
are mixed, it still takes the longer path. I tried expanding the
definition but there's some weird stuff going on that increases
allocations in the other situations I posted in that issue.~~ Works for
any Number element type.

Fixes #39713
  • Loading branch information
BioTurboNick authored Jan 13, 2025
1 parent 3b629f1 commit f135f3a
Showing 1 changed file with 5 additions and 54 deletions.
59 changes: 5 additions & 54 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2189,48 +2189,8 @@ true
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractArray...) = typed_hvcat(promote_eltype(xs...), rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractArray{T}...) where {T} = typed_hvcat(T, rows, xs...)

function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as::AbstractVecOrMat...) where T
nbr = length(rows) # number of block rows

nc = 0
for i=1:rows[1]
nc += size(as[i],2)
end

nr = 0
a = 1
for i = 1:nbr
nr += size(as[a],1)
a += rows[i]
end

out = similar(as[1], T, nr, nc)

a = 1
r = 1
for i = 1:nbr
c = 1
szi = size(as[a],1)
for j = 1:rows[i]
Aj = as[a+j-1]
szj = size(Aj,2)
if size(Aj,1) != szi
throw(DimensionMismatch("mismatched height in block row $(i) (expected $szi, got $(size(Aj,1)))"))
end
if c-1+szj > nc
throw(DimensionMismatch("block row $(i) has mismatched number of columns (expected $nc, got $(c-1+szj))"))
end
out[r:r-1+szi, c:c-1+szj] = Aj
c += szj
end
if c != nc+1
throw(DimensionMismatch("block row $(i) has mismatched number of columns (expected $nc, got $(c-1))"))
end
r += szi
a += rows[i]
end
out
end
rows_to_dimshape(rows::Tuple{Vararg{Int}}) = all(==(rows[1]), rows) ? (length(rows), rows[1]) : (rows, (sum(rows),))
typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as::AbstractVecOrMat...) where T = typed_hvncat(T, rows_to_dimshape(rows), true, as...)

hvcat(rows::Tuple{Vararg{Int}}) = []
typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}) where {T} = Vector{T}()
Expand Down Expand Up @@ -2288,16 +2248,7 @@ function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, xs::Number...) where T
hvcat_fill!(Matrix{T}(undef, nr, nc), xs)
end

function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as...) where T
nbr = length(rows) # number of block rows
rs = Vector{Any}(undef, nbr)
a = 1
for i = 1:nbr
rs[i] = typed_hcat(T, as[a:a-1+rows[i]]...)
a += rows[i]
end
T[rs...;]
end
typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as...) where T = typed_hvncat(T, rows_to_dimshape(rows), true, as...)

## N-dimensional concatenation ##

Expand Down Expand Up @@ -2645,7 +2596,7 @@ function _typed_hvncat_dims(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as
throw(DimensionMismatch("mismatched number of elements; expected $(outlen), got $(elementcount)"))

# copy into final array
A = cat_similar(as[1], T, outdims)
A = cat_similar(as[1], T, ntuple(i -> outdims[i], N))
# @assert all(==(0), currentdims)
outdims .= 0
hvncat_fill!(A, currentdims, outdims, d1, d2, as)
Expand Down Expand Up @@ -2739,7 +2690,7 @@ function _typed_hvncat_shape(::Type{T}, shape::NTuple{N, Tuple}, row_first, as::
# @assert all(==(0), blockcounts)

# copy into final array
A = cat_similar(as[1], T, outdims)
A = cat_similar(as[1], T, ntuple(i -> outdims[i], nd))
hvncat_fill!(A, currentdims, blockcounts, d1, d2, as)
return A
end
Expand Down

0 comments on commit f135f3a

Please sign in to comment.