Skip to content

Commit

Permalink
Add generic matmatmul for inplace sparse x sparse (#486)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Feb 16, 2024
1 parent 95575c0 commit cb602d7
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Base: +, -, *, \, /, &, |, xor, ==, zero, @propagate_inbounds
import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot,
issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, isbanded,
cond, diagm, factorize, ishermitian, norm, opnorm, lmul!, rmul!, tril, triu,
matprod_dest
matprod_dest, generic_matvecmul!, generic_matmatmul!

import Base: adjoint, argmin, argmax, Array, broadcast, circshift!, complex, Complex,
conj, conj!, convert, copy, copy!, copyto!, count, diff, findall, findmax, findmin,
Expand Down
203 changes: 194 additions & 9 deletions src/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLowerTriangular,
RealHermSymComplexHerm, checksquare, sym_uplo
RealHermSymComplexHerm, checksquare, sym_uplo, wrap
using Random: rand!

const tilebufsize = 10800 # Approximately 32k/3

# In matrix-vector multiplication, the correct orientation of the vector is assumed.
const DenseMatrixUnion = Union{StridedMatrix, BitMatrix}
const DenseTriangular = UpperOrLowerTriangular{<:Any,<:DenseMatrixUnion}
Expand Down Expand Up @@ -45,28 +47,28 @@ for op ∈ (:+, :-)
end
end

LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) =
generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) =
spdensemul!(C, tA, tB, A, B, _add)
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) =
generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) =
spdensemul!(C, tA, tB, A, B, _add)
LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) =
generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) =
spdensemul!(C, tA, 'N', A, B, _add)

Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add)
if tA == 'N'
_spmatmul!(C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta)
_spmatmul!(C, A, wrap(B, tB), _add.alpha, _add.beta)
elseif tA == 'T'
_At_or_Ac_mul_B!(transpose, C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta)
_At_or_Ac_mul_B!(transpose, C, A, wrap(B, tB), _add.alpha, _add.beta)
elseif tA == 'C'
_At_or_Ac_mul_B!(adjoint, C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta)
_At_or_Ac_mul_B!(adjoint, C, A, wrap(B, tB), _add.alpha, _add.beta)
elseif tA in ('S', 's', 'H', 'h') && tB == 'N'
rangefun = isuppercase(tA) ? nzrangeup : nzrangelo
diagop = tA in ('S', 's') ? identity : real
odiagop = tA in ('S', 's') ? transpose : adjoint
T = eltype(C)
_mul!(rangefun, diagop, odiagop, C, A, B, T(_add.alpha), T(_add.beta))
else
LinearAlgebra._generic_matmatmul!(C, 'N', 'N', LinearAlgebra.wrap(A, tA), LinearAlgebra.wrap(B, tB), _add)
_generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
end
return C
end
Expand Down Expand Up @@ -114,7 +116,7 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
C
end

Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul)
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul)
transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
if tB == 'N'
_spmul!(C, transA(A), B, _add.alpha, _add.beta)
Expand Down Expand Up @@ -316,6 +318,189 @@ function estimate_mulsize(m::Integer, nnzA::Integer, n::Integer, nnzB::Integer,
p >= 1 ? m*k : p > 0 ? Int(ceil(-expm1(log1p(-p) * n)*m*k)) : 0 # (1-(1-p)^n)*m*k
end

# In-place product with sparse destination `C` and sparse operands `A` and `B`.
# The tags 'H', 'h', 'S', 's' are resolved here by replacing the operand with
# `wrap(operand, tag)` and resetting its tag to 'N', so the kernel below only ever
# sees 'N', 'T' or 'C'.
# NOTE(review): `wrap` presumably yields the matching Symmetric/Hermitian view — confirm.
Base.@constprop :aggressive function generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB,
        A::SparseMatrixCSCUnion2, B::SparseMatrixCSCUnion2, _add::MulAddMul)
    wrapped_tags = ('H', 'h', 'S', 's')
    if tA in wrapped_tags
        A, tA = wrap(A, tA), 'N'
    end
    if tB in wrapped_tags
        B, tB = wrap(B, tB), 'N'
    end
    _generic_matmatmul!(C, tA, tB, A, B, _add)
end
# Element-wise fallback kernel that accumulates `op(A) * op(B)` into the sparse matrix `C`
# through `_add` (which carries the `alpha`/`beta` scaling factors).
#
# `tA` and `tB` must each be one of 'N', 'T', 'C'; they select whether the entries of
# `A`/`B` are read as stored, transposed, or adjointed (see the index patterns below).
# Throws `DimensionMismatch` if the operand or destination sizes do not agree.
# Returns `C`.
function _generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::AbstractVecOrMat,
                             B::AbstractVecOrMat, _add::MulAddMul)
    @assert tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C')
    require_one_based_indexing(C, A, B)
    TC = eltype(C)
    TA = eltype(A)
    TB = eltype(B)

    mA, nA = LinearAlgebra.lapack_size(tA, A)
    mB, nB = LinearAlgebra.lapack_size(tB, B)
    mB == nA ||
        throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), matrix B has dimensions ($mB,$nB)"))
    (size(C, 1) == mA && size(C, 2) == nB) ||
        throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs ($mA,$nB)"))

    # Nothing to accumulate: only the `beta` part of the update remains.
    if iszero(_add.alpha) || isempty(A) || isempty(B)
        return LinearAlgebra._rmul_or_fill!(C, _add.beta)
    end

    # Tiling is only attempted for isbits element types, and not for the combination
    # (tA != 'N', tB == 'N'). The tile side is derived from `tilebufsize` and the largest
    # element size; a value of 0 selects the naive loops.
    tile_size = if isbitstype(TC) && isbitstype(TA) && isbitstype(TB) &&
                   (tA == 'N' || tB != 'N')
        floor(Int, sqrt(tilebufsize / max(sizeof(TC), sizeof(TB), sizeof(TA), 1)))
    else
        0
    end

    @inbounds begin
        if tile_size > 0
            tiledims = (tile_size, tile_size)
            Atile = Array{TA}(undef, tiledims)
            Btile = Array{TB}(undef, tiledims)

            # Zero of the accumulation type, promoted with the destination eltype.
            zprobe = zero(A[1, 1]*B[1, 1] + A[1, 1]*B[1, 1])
            acc0 = convert(promote_type(typeof(zprobe), TC), zprobe)

            if mA < tile_size && nA < tile_size && nB < tile_size
                # Everything fits into one tile. `Atile` holds op(A) transposed, so both
                # tiles are walked down a column via linear indexing.
                copy_transpose!(Atile, 1:nA, 1:mA, tA, A, 1:mA, 1:nA)
                copyto!(Btile, 1:mB, 1:nB, tB, B, 1:mB, 1:nB)
                for j in 1:nB
                    bcol = (j - 1) * tile_size
                    for i in 1:mA
                        acol = (i - 1) * tile_size
                        acc = acc0
                        for k in 1:nA
                            acc += Atile[acol + k] * Btile[bcol + k]
                        end
                        LinearAlgebra._modify!(_add, acc, C, (i, j))
                    end
                end
            else
                # Blocked product: each (ib, jb) block of the result is accumulated in
                # `Ctile` over all kb blocks and then merged into `C`.
                Ctile = Array{TC}(undef, tiledims)
                for jb in 1:tile_size:nB
                    jlim = min(jb + tile_size - 1, nB)
                    jlen = jlim - jb + 1
                    for ib in 1:tile_size:mA
                        ilim = min(ib + tile_size - 1, mA)
                        ilen = ilim - ib + 1
                        fill!(Ctile, acc0)
                        for kb in 1:tile_size:nA
                            klim = min(kb + tile_size - 1, mB)
                            klen = klim - kb + 1
                            copy_transpose!(Atile, 1:klen, 1:ilen, tA, A, ib:ilim, kb:klim)
                            copyto!(Btile, 1:klen, 1:jlen, tB, B, kb:klim, jb:jlim)
                            for j in 1:jlen
                                coloff = (j - 1) * tile_size
                                for i in 1:ilen
                                    rowoff = (i - 1) * tile_size
                                    acc = acc0
                                    for k in 1:klen
                                        acc += Atile[rowoff + k] * Btile[coloff + k]
                                    end
                                    Ctile[coloff + i] += acc
                                end
                            end
                        end
                        if isone(_add.alpha) && iszero(_add.beta)
                            # Plain product: the block can be copied over directly.
                            copyto!(C, ib:ilim, jb:jlim, Ctile, 1:ilen, 1:jlen)
                        else
                            C[ib:ilim, jb:jlim] .=
                                @views _add.(Ctile[1:ilen, 1:jlen], C[ib:ilim, jb:jlim])
                        end
                    end
                end
            end
        else
            # Naive triple loops, one specialised variant per (tA, tB) combination.
            # Reached for non-isbits element types and for (tA != 'N', tB == 'N').
            if tA == 'N' && tB == 'N'
                for i in 1:mA, j in 1:nB
                    zprobe = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j])
                    acc = convert(promote_type(TC, typeof(zprobe)), zprobe)
                    for k in 1:nA
                        acc += A[i, k] * B[k, j]
                    end
                    LinearAlgebra._modify!(_add, acc, C, (i, j))
                end
            elseif tA == 'N' && tB == 'T'
                for i in 1:mA, j in 1:nB
                    zprobe = zero(A[i, 1]*transpose(B[j, 1]) + A[i, 1]*transpose(B[j, 1]))
                    acc = convert(promote_type(TC, typeof(zprobe)), zprobe)
                    for k in 1:nA
                        acc += A[i, k] * transpose(B[j, k])
                    end
                    LinearAlgebra._modify!(_add, acc, C, (i, j))
                end
            elseif tA == 'N'
                for i in 1:mA, j in 1:nB
                    zprobe = zero(A[i, 1]*B[j, 1]' + A[i, 1]*B[j, 1]')
                    acc = convert(promote_type(TC, typeof(zprobe)), zprobe)
                    for k in 1:nA
                        acc += A[i, k] * B[j, k]'
                    end
                    LinearAlgebra._modify!(_add, acc, C, (i, j))
                end
            elseif tA == 'T' && tB == 'N'
                for i in 1:mA, j in 1:nB
                    zprobe = zero(transpose(A[1, i])*B[1, j] + transpose(A[1, i])*B[1, j])
                    acc = convert(promote_type(TC, typeof(zprobe)), zprobe)
                    for k in 1:nA
                        acc += transpose(A[k, i]) * B[k, j]
                    end
                    LinearAlgebra._modify!(_add, acc, C, (i, j))
                end
            elseif tA == 'T' && tB == 'T'
                for i in 1:mA, j in 1:nB
                    zprobe = zero(transpose(A[1, i])*transpose(B[j, 1]) +
                                  transpose(A[1, i])*transpose(B[j, 1]))
                    acc = convert(promote_type(TC, typeof(zprobe)), zprobe)
                    for k in 1:nA
                        acc += transpose(A[k, i]) * transpose(B[j, k])
                    end
                    LinearAlgebra._modify!(_add, acc, C, (i, j))
                end
            elseif tA == 'T'
                for i in 1:mA, j in 1:nB
                    zprobe = zero(transpose(A[1, i])*B[j, 1]' + transpose(A[1, i])*B[j, 1]')
                    acc = convert(promote_type(TC, typeof(zprobe)), zprobe)
                    for k in 1:nA
                        acc += transpose(A[k, i]) * adjoint(B[j, k])
                    end
                    LinearAlgebra._modify!(_add, acc, C, (i, j))
                end
            elseif tB == 'N'
                for i in 1:mA, j in 1:nB
                    zprobe = zero(A[1, i]'*B[1, j] + A[1, i]'*B[1, j])
                    acc = convert(promote_type(TC, typeof(zprobe)), zprobe)
                    for k in 1:nA
                        acc += A[k, i]' * B[k, j]
                    end
                    LinearAlgebra._modify!(_add, acc, C, (i, j))
                end
            elseif tB == 'T'
                for i in 1:mA, j in 1:nB
                    zprobe = zero(A[1, i]'*transpose(B[j, 1]) + A[1, i]'*transpose(B[j, 1]))
                    acc = convert(promote_type(TC, typeof(zprobe)), zprobe)
                    for k in 1:nA
                        acc += adjoint(A[k, i]) * transpose(B[j, k])
                    end
                    LinearAlgebra._modify!(_add, acc, C, (i, j))
                end
            else
                for i in 1:mA, j in 1:nB
                    zprobe = zero(A[1, i]'*B[j, 1]' + A[1, i]'*B[j, 1]')
                    acc = convert(promote_type(TC, typeof(zprobe)), zprobe)
                    for k in 1:nA
                        acc += A[k, i]' * B[j, k]'
                    end
                    LinearAlgebra._modify!(_add, acc, C, (i, j))
                end
            end
        end
    end # @inbounds
    return C
end

if VERSION < v"1.10.0-DEV.299"
top_set_bit(x::Base.BitInteger) = 8 * sizeof(x) - leading_zeros(x)
else
Expand Down
10 changes: 5 additions & 5 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1858,7 +1858,7 @@ function (*)(A::_StridedOrTriangularMatrix{Ta}, x::AbstractSparseVector{Tx}) whe
mul!(y, A, x)
end

Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::StridedMatrix, x::AbstractSparseVector,
Base.@constprop :aggressive function generic_matvecmul!(y::AbstractVector, tA, A::StridedMatrix, x::AbstractSparseVector,
_add::MulAddMul = MulAddMul())
if tA == 'N'
_spmul!(y, A, x, _add.alpha, _add.beta)
Expand All @@ -1867,11 +1867,11 @@ Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::Abstrac
elseif tA == 'C'
_At_or_Ac_mul_B!(adjoint, y, A, x, _add.alpha, _add.beta)
else
_spmul!(y, LinearAlgebra.wrap(A, tA), x, _add.alpha, _add.beta)
_spmul!(y, wrap(A, tA), x, _add.alpha, _add.beta)
end
return y
end
function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::UpperOrLowerTriangular, x::AbstractSparseVector,
function generic_matvecmul!(y::AbstractVector, tA, A::UpperOrLowerTriangular, x::AbstractSparseVector,
_add::MulAddMul = MulAddMul())
@assert tA == 'N'
Adata = parent(A)
Expand Down Expand Up @@ -1989,7 +1989,7 @@ function densemv(A::AbstractSparseMatrixCSC, x::AbstractSparseVector; trans::Abs
end

# * and mul!
Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::AbstractSparseMatrixCSC, x::AbstractSparseVector,
Base.@constprop :aggressive function generic_matvecmul!(y::AbstractVector, tA, A::AbstractSparseMatrixCSC, x::AbstractSparseVector,
_add::MulAddMul = MulAddMul())
if tA == 'N'
_spmul!(y, A, x, _add.alpha, _add.beta)
Expand All @@ -1998,7 +1998,7 @@ Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::Abstrac
elseif tA == 'C'
_At_or_Ac_mul_B!((a,b) -> adjoint(a) * b, y, A, x, _add.alpha, _add.beta)
else
LinearAlgebra._generic_matvecmul!(y, 'N', LinearAlgebra.wrap(A, tA), x, _add)
LinearAlgebra._generic_matvecmul!(y, 'N', wrap(A, tA), x, _add)
end
return y
end
Expand Down
17 changes: 17 additions & 0 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,23 @@ end
end
end

# Compare in-place `mul!` into a sparse destination against the dense reference result,
# for every identity/adjoint/transpose combination of the operands, with `sA` given both
# as a plain sparse matrix and as a column view, in the 3- and 5-argument forms.
@testset "in-place sparse-sparse mul!" begin
    for n in (20, 30)
        sA = sprandn(ComplexF64, n, n, 0.1); A = Array(sA)
        sB = sprandn(ComplexF64, n, n, 0.1); B = Array(sB)
        sC = sprandn(ComplexF64, n, n, 0.1); C = Array(sC)
        a = randn(ComplexF64); b = randn(ComplexF64)
        for (sA, A) in ((sA, A), (view(sA, :, 1:1:n), A[:,1:1:n]))
            for trA in (identity, adjoint, transpose), trB in (identity, adjoint, transpose)
                # Floating-point results are compared approximately; without the `≈`
                # operator the `@test` call is not a valid expression.
                @test mul!(copy(sC), trA(sA), trB(sB)) ≈ trA(A) * trB(B)
                for α in (true, false, a), β in (true, false, b)
                    @test mul!(copy(sC), trA(sA), trB(sB), α, β) ≈ C*β + trA(A) * trB(B) * α
                end
            end
        end
    end
end

@testset "UniformScaling" begin
local A = sprandn(10, 10, 0.5)
MA = Array(A)
Expand Down

0 comments on commit cb602d7

Please sign in to comment.