diff --git a/NEWS.md b/NEWS.md index 9205c406092b18..ac6c5a799299c3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -29,7 +29,7 @@ Build system changes New library functions --------------------- - +* New function `Base.kron!` and corresponding overloads for various matrix types for performing Kronecker product in-place. ([#31069]). New library features -------------------- diff --git a/base/exports.jl b/base/exports.jl index 316025db9ce6ca..3d11cc04819316 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -463,6 +463,7 @@ export adjoint, transpose, kron, + kron!, # bitarrays falses, diff --git a/base/operators.jl b/base/operators.jl index c304530dfef800..c7998b51756c97 100644 --- a/base/operators.jl +++ b/base/operators.jl @@ -542,6 +542,8 @@ for op in (:+, :*, :&, :|, :xor, :min, :max, :kron) end end +function kron! end + const var"'" = adjoint """ diff --git a/stdlib/LinearAlgebra/docs/src/index.md b/stdlib/LinearAlgebra/docs/src/index.md index b78ed785080e00..c5f0448bfa629b 100644 --- a/stdlib/LinearAlgebra/docs/src/index.md +++ b/stdlib/LinearAlgebra/docs/src/index.md @@ -409,6 +409,7 @@ Base.inv(::AbstractMatrix) LinearAlgebra.pinv LinearAlgebra.nullspace Base.kron +Base.kron! LinearAlgebra.exp(::StridedMatrix{<:LinearAlgebra.BlasFloat}) Base.:^(::AbstractMatrix, ::Number) Base.:^(::Number, ::AbstractMatrix) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index bb1dcb3c17ea70..e9476645d5e89d 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -11,7 +11,7 @@ import Base: \, /, *, ^, +, -, == import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech, asin, asinh, atan, atanh, axes, big, broadcast, ceil, conj, convert, copy, copyto!, cos, cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat, - getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, length, log, map, ndims, + getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, kron!, length, log, map, ndims, oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech, setindex!, show, similar, sin, sincos, sinh, size, sqrt, strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec diff --git a/stdlib/LinearAlgebra/src/bitarray.jl b/stdlib/LinearAlgebra/src/bitarray.jl index 3e38b073a992b0..d1857c3c386599 100644 --- a/stdlib/LinearAlgebra/src/bitarray.jl +++ b/stdlib/LinearAlgebra/src/bitarray.jl @@ -92,22 +92,29 @@ qr(A::BitMatrix) = qr(float(A)) ## kron -function kron(a::BitVector, b::BitVector) +@inline function kron!(R::BitVector, a::BitVector, b::BitVector) m = length(a) n = length(b) - R = falses(n * m) + @boundscheck length(R) == n*m || throw(DimensionMismatch()) Rc = R.chunks bc = b.chunks for j = 1:m a[j] && Base.copy_chunks!(Rc, (j-1)*n+1, bc, 1, n) end - R + return R end -function kron(a::BitMatrix, b::BitMatrix) +function kron(a::BitVector, b::BitVector) + m = length(a) + n = length(b) + R = falses(n * m) + return @inbounds kron!(R, a, b) +end + +function kron!(R::BitMatrix, a::BitMatrix, b::BitMatrix) mA,nA = size(a) mB,nB = size(b) - R = falses(mA*mB, nA*nB) + @boundscheck size(R) == (mA*mB, nA*nB) || throw(DimensionMismatch()) for i = 1:mA ri = (1:mB) .+ ((i-1)*mB) @@ -118,7 +125,14 @@ function kron(a::BitMatrix, b::BitMatrix) end end end - R + return R +end + +function kron(a::BitMatrix, b::BitMatrix) + mA,nA = size(a) + mB,nB = size(b) + R = falses(mA*mB, nA*nB) + return @inbounds kron!(R, a, b) end ## Structure query functions diff --git a/stdlib/LinearAlgebra/src/dense.jl b/stdlib/LinearAlgebra/src/dense.jl index 26d6915aad0cf7..10ce1ea44b6872 100644 --- a/stdlib/LinearAlgebra/src/dense.jl +++ b/stdlib/LinearAlgebra/src/dense.jl @@ -336,6 +336,29 @@ function tr(A::Matrix{T}) where T t end +""" + kron!(C, A, B) + +`kron!` is the in-place version of [`kron`](@ref). Computes `kron(A, B)` and stores the result in `C` +overwriting the existing value of `C`. + +!!! tip + Bounds checking can be disabled by [`@inbounds`](@ref), but you need to take care of the shape + of `C`, `A`, `B` yourself. +""" +@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) + require_one_based_indexing(A, B) + @boundscheck (size(C) == (size(A,1)*size(B,1), size(A,2)*size(B,2))) || throw(DimensionMismatch()) + m = 0 + @inbounds for j = 1:size(A,2), l = 1:size(B,2), i = 1:size(A,1) + Aij = A[i,j] + for k = 1:size(B,1) + C[m += 1] = Aij*B[k,l] + end + end + return C +end + """ kron(A, B) @@ -383,18 +406,23 @@ julia> reshape(kron(v,w), (length(w), length(v))) ``` """ function kron(a::AbstractMatrix{T}, b::AbstractMatrix{S}) where {T,S} - require_one_based_indexing(a, b) R = Matrix{promote_op(*,T,S)}(undef, size(a,1)*size(b,1), size(a,2)*size(b,2)) - m = 0 - @inbounds for j = 1:size(a,2), l = 1:size(b,2), i = 1:size(a,1) - aij = a[i,j] - for k = 1:size(b,1) - R[m += 1] = aij*b[k,l] - end - end - R + return @inbounds kron!(R, a, b) end +kron!(c::AbstractVecOrMat, a::AbstractVecOrMat, b::Number) = mul!(c, a, b) + +Base.@propagate_inbounds function kron!(c::AbstractVector, a::AbstractVector, b::AbstractVector) + C = reshape(c, length(a)*length(b), 1) + A = reshape(a ,length(a), 1) + B = reshape(b, length(b), 1) + kron!(C, A, B) + return c +end + +Base.@propagate_inbounds kron!(C::AbstractMatrix, a::AbstractMatrix, b::AbstractVector) = kron!(C, a, reshape(b, length(b), 1)) +Base.@propagate_inbounds kron!(C::AbstractMatrix, a::AbstractVector, b::AbstractMatrix) = kron!(C, reshape(a, length(a), 1), b) + kron(a::Number, b::Union{Number, AbstractVecOrMat}) = a * b kron(a::AbstractVecOrMat, b::Number) = a * b kron(a::AbstractVector, b::AbstractVector) = vec(kron(reshape(a ,length(a), 1), reshape(b, length(b), 1))) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index f3b4ac17eec782..27e8ba4c80f3d3 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -493,52 +493,80 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} = (\)(A::Union{QR,QRCompactWY,QRPivoted}, B::Diagonal) = invoke(\, Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B) -function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number} + +@inline function kron!(C::AbstractMatrix{T}, A::Diagonal, B::Diagonal) where T + fill!(C, zero(T)) valA = A.diag; nA = length(valA) valB = B.diag; nB = length(valB) - valC = Vector{typeof(zero(T1)*zero(T2))}(undef,nA*nB) + nC = checksquare(C) + @boundscheck nC == nA*nB || + throw(DimensionMismatch("expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)")) + @inbounds for i = 1:nA, j = 1:nB - valC[(i-1)*nB+j] = valA[i] * valB[j] + idx = (i-1)*nB+j + C[idx, idx] = valA[i] * valB[j] end - return Diagonal(valC) + return C end -function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number} +function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number} + valA = A.diag; nA = length(valA) + valB = B.diag; nB = length(valB) + valC = Vector{typeof(zero(T1)*zero(T2))}(undef,nA*nB) + C = Diagonal(valC) + return @inbounds kron!(C, A, B) +end + +@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix) Base.require_one_based_indexing(B) - (mA, nA) = size(A); (mB, nB) = size(B) - R = zeros(Base.promote_op(*, T, S), mA * mB, nA * nB) + (mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C); + @boundscheck (mC, nC) == (mA * mB, nA * nB) || + throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)")) m = 1 - for j = 1:nA + @inbounds for j = 1:nA A_jj = A[j,j] for k = 1:nB for l = 1:mB - R[m] = A_jj * B[l,k] + C[m] = A_jj * B[l,k] m += 1 end m += (nA - 1) * mB end m += mB end - return R + return C end -function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number} +@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::Diagonal) require_one_based_indexing(A) - (mA, nA) = size(A); (mB, nB) = size(B) - R = zeros(promote_op(*, T, S), mA * mB, nA * nB) + (mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C); + @boundscheck (mC, nC) == (mA * mB, nA * nB) || + throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)")) m = 1 - for j = 1:nA + @inbounds for j = 1:nA for l = 1:mB Bll = B[l,l] for k = 1:mA - R[m] = A[k,j] * Bll + C[m] = A[k,j] * Bll m += nB end m += 1 end m -= nB end - return R + return C +end + +function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number} + (mA, nA) = size(A); (mB, nB) = size(B) + R = zeros(Base.promote_op(*, T, S), mA * mB, nA * nB) + return @inbounds kron!(R, A, B) +end + +function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number} + (mA, nA) = size(A); (mB, nB) = size(B) + R = zeros(promote_op(*, T, S), mA * mB, nA * nB) + return @inbounds kron!(R, A, B) end conj(D::Diagonal) = Diagonal(conj(D.diag))