From cb602d7b7cf46057ddc87d23cda2bdd168a548ac Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 16 Feb 2024 13:12:20 +0100 Subject: [PATCH] Add generic matmatmul for inplace sparse x sparse (#486) --- src/SparseArrays.jl | 2 +- src/linalg.jl | 203 ++++++++++++++++++++++++++++++++++++++++++-- src/sparsevector.jl | 10 +-- test/linalg.jl | 17 ++++ 4 files changed, 217 insertions(+), 15 deletions(-) diff --git a/src/SparseArrays.jl b/src/SparseArrays.jl index aaf2f95b..ff554c8e 100644 --- a/src/SparseArrays.jl +++ b/src/SparseArrays.jl @@ -18,7 +18,7 @@ import Base: +, -, *, \, /, &, |, xor, ==, zero, @propagate_inbounds import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot, issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, isbanded, cond, diagm, factorize, ishermitian, norm, opnorm, lmul!, rmul!, tril, triu, - matprod_dest + matprod_dest, generic_matvecmul!, generic_matmatmul! import Base: adjoint, argmin, argmax, Array, broadcast, circshift!, complex, Complex, conj, conj!, convert, copy, copy!, copyto!, count, diff, findall, findmax, findmin, diff --git a/src/linalg.jl b/src/linalg.jl index 5dd9e9a3..ca70ac4c 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -1,9 +1,11 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLowerTriangular, - RealHermSymComplexHerm, checksquare, sym_uplo + RealHermSymComplexHerm, checksquare, sym_uplo, wrap using Random: rand! +const tilebufsize = 10800 # Approximately 32k/3 + # In matrix-vector multiplication, the correct orientation of the vector is assumed. 
const DenseMatrixUnion = Union{StridedMatrix, BitMatrix} const DenseTriangular = UpperOrLowerTriangular{<:Any,<:DenseMatrixUnion} @@ -45,20 +47,20 @@ for op ∈ (:+, :-) end end -LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) = +generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) = spdensemul!(C, tA, tB, A, B, _add) -LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) = +generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) = spdensemul!(C, tA, tB, A, B, _add) -LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) = +generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) = spdensemul!(C, tA, 'N', A, B, _add) Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add) if tA == 'N' - _spmatmul!(C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta) + _spmatmul!(C, A, wrap(B, tB), _add.alpha, _add.beta) elseif tA == 'T' - _At_or_Ac_mul_B!(transpose, C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta) + _At_or_Ac_mul_B!(transpose, C, A, wrap(B, tB), _add.alpha, _add.beta) elseif tA == 'C' - _At_or_Ac_mul_B!(adjoint, C, A, LinearAlgebra.wrap(B, tB), _add.alpha, _add.beta) + _At_or_Ac_mul_B!(adjoint, C, A, wrap(B, tB), _add.alpha, _add.beta) elseif tA in ('S', 's', 'H', 'h') && tB == 'N' rangefun = isuppercase(tA) ? nzrangeup : nzrangelo diagop = tA in ('S', 's') ? 
identity : real @@ -66,7 +68,7 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add) T = eltype(C) _mul!(rangefun, diagop, odiagop, C, A, B, T(_add.alpha), T(_add.beta)) else - LinearAlgebra._generic_matmatmul!(C, 'N', 'N', LinearAlgebra.wrap(A, tA), LinearAlgebra.wrap(B, tB), _add) + LinearAlgebra._generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add) end return C end @@ -114,7 +116,7 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) C end -Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul) +Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul) transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint if tB == 'N' _spmul!(C, transA(A), B, _add.alpha, _add.beta) @@ -316,6 +318,189 @@ function estimate_mulsize(m::Integer, nnzA::Integer, n::Integer, nnzB::Integer, p >= 1 ? m*k : p > 0 ? Int(ceil(-expm1(log1p(-p) * n)*m*k)) : 0 # (1-(1-p)^n)*m*k end +Base.@constprop :aggressive function generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::SparseMatrixCSCUnion2, + B::SparseMatrixCSCUnion2, _add::MulAddMul) + A, tA = tA in ('H', 'h', 'S', 's') ? 
(wrap(B, tB), 'N') : (B, tB) + _generic_matmatmul!(C, tA, tB, A, B, _add) +end +function _generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::AbstractVecOrMat, + B::AbstractVecOrMat, _add::MulAddMul) + @assert tA in ('N', 'T', 'C') && tB in ('N', 'T', 'C') # generic_matmatmul! above rewrites 'S'/'s'/'H'/'h' flags to 'N' via wrap before calling in + require_one_based_indexing(C, A, B) + R = eltype(C) + T = eltype(A) + S = eltype(B) + + mA, nA = LinearAlgebra.lapack_size(tA, A) + mB, nB = LinearAlgebra.lapack_size(tB, B) + if mB != nA + throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), matrix B has dimensions ($mB,$nB)")) + end + if size(C,1) != mA || size(C,2) != nB + throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs ($mA,$nB)")) + end + + if iszero(_add.alpha) || isempty(A) || isempty(B) + return LinearAlgebra._rmul_or_fill!(C, _add.beta) + end + + tile_size = 0 # stays 0 unless the condition below holds; 0 selects the naive loops further down + if isbitstype(R) && isbitstype(T) && isbitstype(S) && (tA == 'N' || tB != 'N') + tile_size = floor(Int, sqrt(tilebufsize / max(sizeof(R), sizeof(S), sizeof(T), 1))) + end + @inbounds begin + if tile_size > 0 + sz = (tile_size, tile_size) + Atile = Array{T}(undef, sz) + Btile = Array{S}(undef, sz) + + z1 = zero(A[1, 1]*B[1, 1] + A[1, 1]*B[1, 1]) + z = convert(promote_type(typeof(z1), R), z1) + + if mA < tile_size && nA < tile_size && nB < tile_size # all three dimensions fit into a single tile + copy_transpose!(Atile, 1:nA, 1:mA, tA, A, 1:mA, 1:nA) + copyto!(Btile, 1:mB, 1:nB, tB, B, 1:mB, 1:nB) + for j = 1:nB + boff = (j-1)*tile_size + for i = 1:mA + aoff = (i-1)*tile_size + s = z + for k = 1:nA + s += Atile[aoff+k] * Btile[boff+k] + end + LinearAlgebra._modify!(_add, s, C, (i,j)) + end + end + else + Ctile = Array{R}(undef, sz) + for jb = 1:tile_size:nB + jlim = min(jb+tile_size-1,nB) + jlen = jlim-jb+1 + for ib = 1:tile_size:mA + ilim = min(ib+tile_size-1,mA) + ilen = ilim-ib+1 + fill!(Ctile, z) # restart accumulation for the (ib, jb) output tile + for kb = 1:tile_size:nA + klim = min(kb+tile_size-1,mB) + klen = klim-kb+1 + copy_transpose!(Atile, 1:klen, 1:ilen, tA, A, ib:ilim, kb:klim) + copyto!(Btile, 1:klen, 1:jlen, tB, B, kb:klim, jb:jlim) + for 
j=1:jlen + bcoff = (j-1)*tile_size + for i = 1:ilen + aoff = (i-1)*tile_size + s = z + for k = 1:klen + s += Atile[aoff+k] * Btile[bcoff+k] + end + Ctile[bcoff+i] += s + end + end + end + if isone(_add.alpha) && iszero(_add.beta) # plain product: the finished tile is copied into C as is + copyto!(C, ib:ilim, jb:jlim, Ctile, 1:ilen, 1:jlen) + else + C[ib:ilim, jb:jlim] .= @views _add.(Ctile[1:ilen, 1:jlen], C[ib:ilim, jb:jlim]) + end + end + end + end + else + # Multiplication for non-plain-data uses the naive algorithm + if tA == 'N' + if tB == 'N' + for i = 1:mA, j = 1:nB + z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) + Ctmp = convert(promote_type(R, typeof(z2)), z2) + for k = 1:nA + Ctmp += A[i, k]*B[k, j] + end + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) + end + elseif tB == 'T' + for i = 1:mA, j = 1:nB + z2 = zero(A[i, 1]*transpose(B[j, 1]) + A[i, 1]*transpose(B[j, 1])) + Ctmp = convert(promote_type(R, typeof(z2)), z2) + for k = 1:nA + Ctmp += A[i, k] * transpose(B[j, k]) + end + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) + end + else + for i = 1:mA, j = 1:nB + z2 = zero(A[i, 1]*B[j, 1]' + A[i, 1]*B[j, 1]') + Ctmp = convert(promote_type(R, typeof(z2)), z2) + for k = 1:nA + Ctmp += A[i, k]*B[j, k]' + end + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) + end + end + elseif tA == 'T' + if tB == 'N' + for i = 1:mA, j = 1:nB + z2 = zero(transpose(A[1, i])*B[1, j] + transpose(A[1, i])*B[1, j]) + Ctmp = convert(promote_type(R, typeof(z2)), z2) + for k = 1:nA + Ctmp += transpose(A[k, i]) * B[k, j] + end + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) + end + elseif tB == 'T' + for i = 1:mA, j = 1:nB + z2 = zero(transpose(A[1, i])*transpose(B[j, 1]) + transpose(A[1, i])*transpose(B[j, 1])) + Ctmp = convert(promote_type(R, typeof(z2)), z2) + for k = 1:nA + Ctmp += transpose(A[k, i]) * transpose(B[j, k]) + end + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) + end + else + for i = 1:mA, j = 1:nB + z2 = zero(transpose(A[1, i])*B[j, 1]' + transpose(A[1, i])*B[j, 1]') + Ctmp = convert(promote_type(R, typeof(z2)), z2) + for k = 
1:nA + Ctmp += transpose(A[k, i]) * adjoint(B[j, k]) + end + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) + end + end + else + if tB == 'N' + for i = 1:mA, j = 1:nB + z2 = zero(A[1, i]'*B[1, j] + A[1, i]'*B[1, j]) + Ctmp = convert(promote_type(R, typeof(z2)), z2) + for k = 1:nA + Ctmp += A[k, i]'B[k, j] + end + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) + end + elseif tB == 'T' + for i = 1:mA, j = 1:nB + z2 = zero(A[1, i]'*transpose(B[j, 1]) + A[1, i]'*transpose(B[j, 1])) + Ctmp = convert(promote_type(R, typeof(z2)), z2) + for k = 1:nA + Ctmp += adjoint(A[k, i]) * transpose(B[j, k]) + end + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) + end + else + for i = 1:mA, j = 1:nB + z2 = zero(A[1, i]'*B[j, 1]' + A[1, i]'*B[j, 1]') + Ctmp = convert(promote_type(R, typeof(z2)), z2) + for k = 1:nA + Ctmp += A[k, i]'B[j, k]' + end + LinearAlgebra._modify!(_add, Ctmp, C, (i,j)) + end + end + end + end + end # @inbounds + C +end + if VERSION < v"1.10.0-DEV.299" top_set_bit(x::Base.BitInteger) = 8 * sizeof(x) - leading_zeros(x) else diff --git a/src/sparsevector.jl b/src/sparsevector.jl index 234b5a7b..206a1e06 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -1858,7 +1858,7 @@ function (*)(A::_StridedOrTriangularMatrix{Ta}, x::AbstractSparseVector{Tx}) whe mul!(y, A, x) end -Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::StridedMatrix, x::AbstractSparseVector, +Base.@constprop :aggressive function generic_matvecmul!(y::AbstractVector, tA, A::StridedMatrix, x::AbstractSparseVector, _add::MulAddMul = MulAddMul()) if tA == 'N' _spmul!(y, A, x, _add.alpha, _add.beta) @@ -1867,11 +1867,11 @@ Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::Abstrac elseif tA == 'C' _At_or_Ac_mul_B!(adjoint, y, A, x, _add.alpha, _add.beta) else - _spmul!(y, LinearAlgebra.wrap(A, tA), x, _add.alpha, _add.beta) + _spmul!(y, wrap(A, tA), x, _add.alpha, _add.beta) end return y end -function 
LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::UpperOrLowerTriangular, x::AbstractSparseVector, +function generic_matvecmul!(y::AbstractVector, tA, A::UpperOrLowerTriangular, x::AbstractSparseVector, _add::MulAddMul = MulAddMul()) @assert tA == 'N' Adata = parent(A) @@ -1989,7 +1989,7 @@ function densemv(A::AbstractSparseMatrixCSC, x::AbstractSparseVector; trans::Abs end # * and mul! -Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::AbstractVector, tA, A::AbstractSparseMatrixCSC, x::AbstractSparseVector, +Base.@constprop :aggressive function generic_matvecmul!(y::AbstractVector, tA, A::AbstractSparseMatrixCSC, x::AbstractSparseVector, _add::MulAddMul = MulAddMul()) if tA == 'N' _spmul!(y, A, x, _add.alpha, _add.beta) @@ -1998,7 +1998,7 @@ Base.@constprop :aggressive function LinearAlgebra.generic_matvecmul!(y::Abstrac elseif tA == 'C' _At_or_Ac_mul_B!((a,b) -> adjoint(a) * b, y, A, x, _add.alpha, _add.beta) else - LinearAlgebra._generic_matvecmul!(y, 'N', LinearAlgebra.wrap(A, tA), x, _add) + LinearAlgebra._generic_matvecmul!(y, 'N', wrap(A, tA), x, _add) end return y end diff --git a/test/linalg.jl b/test/linalg.jl index 48029084..45d42d9f 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -228,6 +228,23 @@ end end end +@testset "in-place sparse-sparse mul!" begin + for n in (20, 30) # 20 and 30 straddle floor(sqrt(tilebufsize/16)) = 25, the tile size computed for 16-byte eltypes such as ComplexF64 + sA = sprandn(ComplexF64, n, n, 0.1); A = Array(sA) + sB = sprandn(ComplexF64, n, n, 0.1); B = Array(sB) + sC = sprandn(ComplexF64, n, n, 0.1); C = Array(sC) + a = randn(ComplexF64); b = randn(ComplexF64) + for (sA, A) in ((sA, A), (view(sA, :, 1:1:n), A[:,1:1:n])) + for trA in (identity, adjoint, transpose), trB in (identity, adjoint, transpose) + @test mul!(copy(sC), trA(sA), trB(sB)) ≈ trA(A) * trB(B) + for α in (true, false, a), β in (true, false, b) + @test mul!(copy(sC), trA(sA), trB(sB), α, β) ≈ C*β + trA(A) * trB(B) * α + end + end + end + end +end + @testset "UniformScaling" begin local A = sprandn(10, 10, 0.5) MA = Array(A)