Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse: bugfix + improvements in [hv]cat + more tests #592

Merged
merged 3 commits into from
Mar 15, 2012
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions jl/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ function vcat{T}(V::AbstractVector{T}...)
for Vk in V
n += length(Vk)
end
a = similar(V[1], n)
a = similar(full(V[1]), n)
pos = 1
for k=1:length(V)
Vk = V[k]
Expand Down Expand Up @@ -438,7 +438,7 @@ function cat(catdim::Integer, X...)
ndimsC = max(catdim, d_max)
dimsC = ntuple(ndimsC, compute_dims)::(Int...)
typeC = promote_type(map(x->isa(x,AbstractArray) ? eltype(x) : typeof(x), X)...)
C = similar(isa(X[1],AbstractArray) ? X[1] : [X[1]], typeC, dimsC)
C = similar(isa(X[1],AbstractArray) ? full(X[1]) : [X[1]], typeC, dimsC)

range = 1
for k=1:nargs
Expand Down Expand Up @@ -501,7 +501,7 @@ function cat(catdim::Integer, A::AbstractArray...)
ndimsC = max(catdim, d_max)
dimsC = ntuple(ndimsC, compute_dims)::(Int...)
typeC = promote_type(map(eltype, A)...)
C = similar(A[1], typeC, dimsC)
C = similar(full(A[1]), typeC, dimsC)

range = 1
for k=1:nargs
Expand Down
16 changes: 11 additions & 5 deletions jl/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ end

# Sparse concatenation

function vcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...)
function vcat(X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) | x = X ]
nX = [ size(x, 2) | x = X ]
Expand All @@ -740,6 +740,9 @@ function vcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...)
end
m = sum(mX)

Tv = promote_type(map(x->eltype(x.nzval), X)...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)

colptr = Array(Ti, n + 1)
nnzX = [ nnz(x) | x = X ]
nnz_res = sum(nnzX)
Expand All @@ -765,7 +768,7 @@ function vcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...)
SparseMatrixCSC(m, n, colptr, rowval, nzval)
end

function hcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...)
function hcat(X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) | x = X ]
nX = [ size(x, 2) | x = X ]
Expand All @@ -775,6 +778,9 @@ function hcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...)
end
n = sum(nX)

Tv = promote_type(map(x->eltype(x.nzval), X)...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)

colptr = Array(Ti, n + 1)
nnzX = [ nnz(x) | x = X ]
nnz_res = sum(nnzX)
Expand All @@ -794,10 +800,10 @@ function hcat{Tv, Ti}(X::SparseMatrixCSC{Tv, Ti}...)
SparseMatrixCSC(m, n, colptr, rowval, nzval)
end

function hvcat{Tv, Ti}(rows::(Int...), X::SparseMatrixCSC{Tv, Ti}...)
function hvcat(rows::(Int...), X::SparseMatrixCSC...)
nbr = length(rows) # number of block rows

tmp_rows = Array(SparseMatrixCSC{Tv,Ti}, nbr)
tmp_rows = Array(SparseMatrixCSC, nbr)
k = 0
for i = 1 : nbr
tmp_rows[i] = hcat(X[(1 : rows[i]) + k]...)
Expand Down Expand Up @@ -851,7 +857,7 @@ function _jl_spa_store_reset{T}(S::SparseAccumulator{T}, col, colptr, rowval, nz
else
offs += 1
end
flags[i] = false
flags[pos] = false
end

colptr[col+1] = start + nvals
Expand Down
46 changes: 40 additions & 6 deletions test/sparse.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,40 @@
s = speye(3)
o = ones(3)
@assert s * s == s
@assert s \ o == o
@assert [s s] == sparse([1, 2, 3, 1, 2, 3], [1, 2, 3, 4, 5, 6], ones(6))
@assert [s; s] == sparse([1, 4, 2, 5, 3, 6], [1, 1, 2, 2, 3, 3], ones(6))
# check matrix operations
se33 = speye(3)
@assert se33 * se33 == se33

# check mixed sparse-dense matrix operations
do33 = ones(3)
@assert se33 \ do33 == do33

# check horiz concatenation
@assert [se33 se33] == sparse([1, 2, 3, 1, 2, 3], [1, 2, 3, 4, 5, 6], ones(6))

# check vert concatenation
@assert [se33; se33] == sparse([1, 4, 2, 5, 3, 6], [1, 1, 2, 2, 3, 3], ones(6))

# check h+v concatenation
se44 = speye(4)
sz42 = spzeros(4, 2)
sz41 = spzeros(4, 1)
sz34 = spzeros(3, 4)
se77 = speye(7)
@assert [se44 sz42 sz41; sz34 se33] == se77

# check concatenation promotion
sz41_f32 = spzeros(Float32, 4, 1)
se33_i32 = speye(Int32, 3, 3)
@assert [se44 sz42 sz41_f32; sz34 se33_i32] == se77

# check mixed sparse-dense concatenation
sz33 = spzeros(3)
de33 = eye(3)
@assert [se33 de33; sz33 se33] == full([se33 se33; sz33 se33 ])

# check splicing + concatenation on
# random instances, with nested vcat
# (also side-checks sparse ref, which uses
# sparse multiplication)
for i = 1 : 10
a = sprand(5, 4, 0.5)
@assert [a[1:2,1:2] a[1:2,3:4]; a[3:5,1] [a[3:4,2:4]; a[5,2:4]]] == a
end