diff --git a/jl/abstractarray.jl b/jl/abstractarray.jl index 4fc5939403297..c317478e63cbb 100644 --- a/jl/abstractarray.jl +++ b/jl/abstractarray.jl @@ -328,7 +328,7 @@ function vcat{T}(V::AbstractVector{T}...) for Vk in V n += length(Vk) end - a = similar(V[1], n) + a = similar(full(V[1]), n) pos = 1 for k=1:length(V) Vk = V[k] @@ -438,7 +438,7 @@ function cat(catdim::Integer, X...) ndimsC = max(catdim, d_max) dimsC = ntuple(ndimsC, compute_dims)::(Int...) typeC = promote_type(map(x->isa(x,AbstractArray) ? eltype(x) : typeof(x), X)...) - C = similar(isa(X[1],AbstractArray) ? X[1] : [X[1]], typeC, dimsC) + C = similar(isa(X[1],AbstractArray) ? full(X[1]) : [X[1]], typeC, dimsC) range = 1 for k=1:nargs @@ -501,7 +501,7 @@ function cat(catdim::Integer, A::AbstractArray...) ndimsC = max(catdim, d_max) dimsC = ntuple(ndimsC, compute_dims)::(Int...) typeC = promote_type(map(eltype, A)...) - C = similar(A[1], typeC, dimsC) + C = similar(full(A[1]), typeC, dimsC) range = 1 for k=1:nargs diff --git a/jl/sparse.jl b/jl/sparse.jl index e6785fa90ed56..2c95db52388eb 100644 --- a/jl/sparse.jl +++ b/jl/sparse.jl @@ -730,7 +730,7 @@ end # Sparse concatenation -function vcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...) +function vcat(X::SparseMatrixCSC...) num = length(X) mX = [ size(x, 1) | x = X ] nX = [ size(x, 2) | x = X ] @@ -740,6 +740,9 @@ function vcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...) end m = sum(mX) + Tv = promote_type(map(x->eltype(x.nzval), X)...) + Ti = promote_type(map(x->eltype(x.rowval), X)...) + colptr = Array(Ti, n + 1) nnzX = [ nnz(x) | x = X ] nnz_res = sum(nnzX) @@ -765,7 +768,7 @@ function vcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...) SparseMatrixCSC(m, n, colptr, rowval, nzval) end -function hcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...) +function hcat(X::SparseMatrixCSC...) num = length(X) mX = [ size(x, 1) | x = X ] nX = [ size(x, 2) | x = X ] @@ -775,6 +778,9 @@ function hcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...) end n = sum(nX) + Tv = promote_type(map(x->eltype(x.nzval), X)...) + Ti = promote_type(map(x->eltype(x.rowval), X)...) + colptr = Array(Ti, n + 1) nnzX = [ nnz(x) | x = X ] nnz_res = sum(nnzX) @@ -794,10 +800,10 @@ function hcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...) SparseMatrixCSC(m, n, colptr, rowval, nzval) end -function hvcat{Tv, Ti}(rows::(Int...), X::SparseMatrixCSC{Tv, Ti}...) +function hvcat(rows::(Int...), X::SparseMatrixCSC...) nbr = length(rows) # number of block rows - tmp_rows = Array(SparseMatrixCSC{Tv,Ti}, nbr) + tmp_rows = Array(SparseMatrixCSC, nbr) k = 0 for i = 1 : nbr tmp_rows[i] = hcat(X[(1 : rows[i]) + k]...) @@ -851,7 +857,7 @@ function _jl_spa_store_reset{T}(S::SparseAccumulator{T}, col, colptr, rowval, nz else offs += 1 end - flags[i] = false + flags[pos] = false end colptr[col+1] = start + nvals diff --git a/test/sparse.jl b/test/sparse.jl index ac6c58af15057..07d2bd7ec310e 100644 --- a/test/sparse.jl +++ b/test/sparse.jl @@ -1,6 +1,40 @@ -s = speye(3) -o = ones(3) -@assert s * s == s -@assert s \ o == o -@assert [s s] == sparse([1, 2, 3, 1, 2, 3], [1, 2, 3, 4, 5, 6], ones(6)) -@assert [s; s] == sparse([1, 4, 2, 5, 3, 6], [1, 1, 2, 2, 3, 3], ones(6)) +# check matrix operations +se33 = speye(3) +@assert se33 * se33 == se33 + +# check mixed sparse-dense matrix operations +do33 = ones(3) +@assert se33 \ do33 == do33 + +# check horiz concatenation +@assert [se33 se33] == sparse([1, 2, 3, 1, 2, 3], [1, 2, 3, 4, 5, 6], ones(6)) + +# check vert concatenation +@assert [se33; se33] == sparse([1, 4, 2, 5, 3, 6], [1, 1, 2, 2, 3, 3], ones(6)) + +# check h+v concatenation +se44 = speye(4) +sz42 = spzeros(4, 2) +sz41 = spzeros(4, 1) +sz34 = spzeros(3, 4) +se77 = speye(7) +@assert [se44 sz42 sz41; sz34 se33] == se77 + +# check concatenation promotion +sz41_f32 = spzeros(Float32, 4, 1) +se33_i32 = speye(Int32, 3, 3) +@assert [se44 sz42 sz41_f32; sz34 se33_i32] == se77 + +# check mixed sparse-dense concatenation +sz33 = spzeros(3) +de33 = eye(3) +@assert [se33 de33; sz33 se33] == full([se33 se33; sz33 se33 ]) + +# check splicing + concatenation on +# random instances, with nested vcat +# (also side-checks sparse ref, which uses +# sparse multiplication) +for i = 1 : 10 + a = sprand(5, 4, 0.5) + @assert [a[1:2,1:2] a[1:2,3:4]; a[3:5,1] [a[3:4,2:4]; a[5,2:4]]] == a +end