Skip to content

Commit

Permalink
Adjust matvec and matmatmul! to new internal LinAlg interface (#519)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored May 8, 2024
1 parent 3b30333 commit a09f90b
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "SparseArrays"
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
version = "1.11.0"
version = "1.12.0"

[deps]
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand Down
6 changes: 3 additions & 3 deletions src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ Support for sparse arrays. Provides `AbstractSparseArray` and subtypes.
module SparseArrays

using Base: ReshapedArray, promote_op, setindex_shape_check, to_shape, tail,
require_one_based_indexing, promote_eltype
require_one_based_indexing, promote_eltype, @propagate_inbounds, &, |
using Base.Order: Forward
using LinearAlgebra
using LinearAlgebra: AdjOrTrans, AdjointFactorization, TransposeFactorization, matprod,
AbstractQ, AdjointQ, HessenbergQ, QRCompactWYQ, QRPackedQ, LQPackedQ, MulAddMul,
UpperOrLowerTriangular
UpperOrLowerTriangular, @stable_muladdmul


import Base: +, -, *, \, /, &, |, xor, ==, zero, @propagate_inbounds
import Base: +, -, *, \, /, ==, zero
import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot,
issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, isbanded,
cond, diagm, factorize, ishermitian, norm, opnorm, lmul!, rmul!, tril, triu,
Expand Down
56 changes: 29 additions & 27 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,29 @@ for op ∈ (:+, :-)
end
end

generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) =
spdensemul!(C, tA, tB, A, B, _add)
generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) =
spdensemul!(C, tA, tB, A, B, _add)
generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) =
spdensemul!(C, tA, 'N', A, B, _add)

Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add)
if tA == 'N'
_spmatmul!(C, A, wrap(B, tB), _add.alpha, _add.beta)
elseif tA == 'T'
_At_or_Ac_mul_B!(transpose, C, A, wrap(B, tB), _add.alpha, _add.beta)
elseif tA == 'C'
_At_or_Ac_mul_B!(adjoint, C, A, wrap(B, tB), _add.alpha, _add.beta)
elseif tA in ('S', 's', 'H', 'h') && tB == 'N'
generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, alpha::Number, beta::Number) =
spdensemul!(C, tA, tB, A, B, alpha, beta)
generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, alpha::Number, beta::Number) =
spdensemul!(C, tA, tB, A, B, alpha, beta)
generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, alpha::Number, beta::Number) =
spdensemul!(C, tA, 'N', A, B, alpha, beta)

Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta)
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
if tA_uc == 'N'
_spmatmul!(C, A, wrap(B, tB), alpha, beta)
elseif tA_uc == 'T'
_At_or_Ac_mul_B!(transpose, C, A, wrap(B, tB), alpha, beta)
elseif tA_uc == 'C'
_At_or_Ac_mul_B!(adjoint, C, A, wrap(B, tB), alpha, beta)
elseif tA_uc in ('S', 'H') && tB_uc == 'N'
rangefun = isuppercase(tA) ? nzrangeup : nzrangelo
diagop = tA in ('S', 's') ? identity : real
odiagop = tA in ('S', 's') ? transpose : adjoint
diagop = tA_uc == 'S' ? identity : real
odiagop = tA_uc == 'S' ? transpose : adjoint
T = eltype(C)
_mul!(rangefun, diagop, odiagop, C, A, B, T(_add.alpha), T(_add.beta))
_mul!(rangefun, diagop, odiagop, C, A, B, T(alpha), T(beta))
else
_generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), _add)
@stable_muladdmul LinearAlgebra._generic_matmatmul!(C, 'N', 'N', wrap(A, tA), wrap(B, tB), MulAddMul(alpha, beta))
end
return C
end
Expand Down Expand Up @@ -116,14 +117,14 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
C
end

Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul)
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number)
transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
if tB == 'N'
_spmul!(C, transA(A), B, _add.alpha, _add.beta)
_spmul!(C, transA(A), B, alpha, beta)
elseif tB == 'T'
_A_mul_Bt_or_Bc!(transpose, C, transA(A), B, _add.alpha, _add.beta)
_A_mul_Bt_or_Bc!(transpose, C, transA(A), B, alpha, beta)
else # tB == 'C'
_A_mul_Bt_or_Bc!(adjoint, C, transA(A), B, _add.alpha, _add.beta)
_A_mul_Bt_or_Bc!(adjoint, C, transA(A), B, alpha, beta)
end
return C
end
Expand Down Expand Up @@ -319,10 +320,11 @@ function estimate_mulsize(m::Integer, nnzA::Integer, n::Integer, nnzB::Integer,
end

Base.@constprop :aggressive function generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::SparseMatrixCSCUnion2,
B::SparseMatrixCSCUnion2, _add::MulAddMul)
A, tA = tA in ('H', 'h', 'S', 's') ? (wrap(A, tA), 'N') : (A, tA)
B, tB = tB in ('H', 'h', 'S', 's') ? (wrap(B, tB), 'N') : (B, tB)
_generic_matmatmul!(C, tA, tB, A, B, _add)
B::SparseMatrixCSCUnion2, alpha::Number, beta::Number)
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
Bnew, tb = tB_uc in ('S', 'H') ? (wrap(B, tB), oftype(tB, 'N')) : (B, tB)
@stable_muladdmul _generic_matmatmul!(C, ta, tb, Anew, Bnew, MulAddMul(alpha, beta))
end
function _generic_matmatmul!(C::SparseMatrixCSCUnion2, tA, tB, A::AbstractVecOrMat,
B::AbstractVecOrMat, _add::MulAddMul)
Expand Down
39 changes: 25 additions & 14 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1859,29 +1859,36 @@ function (*)(A::_StridedOrTriangularMatrix{Ta}, x::AbstractSparseVector{Tx}) whe
mul!(y, A, x)
end

# TODO: remove
Base.@constprop :aggressive generic_matvecmul!(y::AbstractVector, tA, A::StridedMatrix, x::AbstractSparseVector,
_add::MulAddMul = MulAddMul()) =
generic_matvecmul!(y, tA, A, x, _add.alpha, _add.beta)
Base.@constprop :aggressive function generic_matvecmul!(y::AbstractVector, tA, A::StridedMatrix, x::AbstractSparseVector,
_add::MulAddMul = MulAddMul())
alpha::Number, beta::Number)
if tA == 'N'
_spmul!(y, A, x, _add.alpha, _add.beta)
_spmul!(y, A, x, alpha, beta)
elseif tA == 'T'
_At_or_Ac_mul_B!(transpose, y, A, x, _add.alpha, _add.beta)
_At_or_Ac_mul_B!(transpose, y, A, x, alpha, beta)
elseif tA == 'C'
_At_or_Ac_mul_B!(adjoint, y, A, x, _add.alpha, _add.beta)
_At_or_Ac_mul_B!(adjoint, y, A, x, alpha, beta)
else
_spmul!(y, wrap(A, tA), x, _add.alpha, _add.beta)
_spmul!(y, wrap(A, tA), x, alpha, beta)
end
return y
end
# TODO: remove
generic_matvecmul!(y::AbstractVector, tA, A::UpperOrLowerTriangular, x::AbstractSparseVector, _add::MulAddMul = MulAddMul()) =
generic_matvecmul!(y, tA, A, x, _add.alpha, _add.beta)
function generic_matvecmul!(y::AbstractVector, tA, A::UpperOrLowerTriangular, x::AbstractSparseVector,
_add::MulAddMul = MulAddMul())
alpha::Number, beta::Number)
@assert tA == 'N'
Adata = parent(A)
if Adata isa Transpose
_At_or_Ac_mul_B!(transpose, y, _fliptri(A), x, _add.alpha, _add.beta)
_At_or_Ac_mul_B!(transpose, y, _fliptri(A), x, alpha, beta)
elseif Adata isa Adjoint
_At_or_Ac_mul_B!(adjoint, y, _fliptri(A), x, _add.alpha, _add.beta)
_At_or_Ac_mul_B!(adjoint, y, _fliptri(A), x, alpha, beta)
else # Adata is plain
_spmul!(y, A, x, _add.alpha, _add.beta)
_spmul!(y, A, x, alpha, beta)
end
return y
end
Expand Down Expand Up @@ -1990,16 +1997,20 @@ function densemv(A::AbstractSparseMatrixCSC, x::AbstractSparseVector; trans::Abs
end

# * and mul!
# TODO: remove
Base.@constprop :aggressive generic_matvecmul!(y::AbstractVector, tA, A::AbstractSparseMatrixCSC, x::AbstractSparseVector,
_add::MulAddMul = MulAddMul()) =
generic_matvecmul!(y, tA, A, x, _add.alpha, _add.beta)
Base.@constprop :aggressive function generic_matvecmul!(y::AbstractVector, tA, A::AbstractSparseMatrixCSC, x::AbstractSparseVector,
_add::MulAddMul = MulAddMul())
alpha::Number, beta::Number)
if tA == 'N'
_spmul!(y, A, x, _add.alpha, _add.beta)
_spmul!(y, A, x, alpha, beta)
elseif tA == 'T'
_At_or_Ac_mul_B!((a,b) -> transpose(a) * b, y, A, x, _add.alpha, _add.beta)
_At_or_Ac_mul_B!((a,b) -> transpose(a) * b, y, A, x, alpha, beta)
elseif tA == 'C'
_At_or_Ac_mul_B!((a,b) -> adjoint(a) * b, y, A, x, _add.alpha, _add.beta)
_At_or_Ac_mul_B!((a,b) -> adjoint(a) * b, y, A, x, alpha, beta)
else
LinearAlgebra._generic_matvecmul!(y, 'N', wrap(A, tA), x, _add)
@stable_muladdmul LinearAlgebra._generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(alpha, beta))
end
return y
end
Expand Down

0 comments on commit a09f90b

Please sign in to comment.