Skip to content

Commit

Permalink
add inplace kron
Browse files Browse the repository at this point in the history
  • Loading branch information
Roger-luo committed May 14, 2020
1 parent 8ef29e6 commit 23bfba1
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 33 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Build system changes

New library functions
---------------------

* New function `Base.kron!` and corresponding overloads for various matrix types for performing Kronecker product in-place. ([#31069]).

New library features
--------------------
Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ export
adjoint,
transpose,
kron,
kron!,

# bitarrays
falses,
Expand Down
2 changes: 2 additions & 0 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ for op in (:+, :*, :&, :|, :xor, :min, :max, :kron)
end
end

function kron! end

const var"'" = adjoint

"""
Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ Base.inv(::AbstractMatrix)
LinearAlgebra.pinv
LinearAlgebra.nullspace
Base.kron
Base.kron!
LinearAlgebra.exp(::StridedMatrix{<:LinearAlgebra.BlasFloat})
Base.:^(::AbstractMatrix, ::Number)
Base.:^(::Number, ::AbstractMatrix)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Base: \, /, *, ^, +, -, ==
import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, asec, asech,
asin, asinh, atan, atanh, axes, big, broadcast, ceil, conj, convert, copy, copyto!, cos,
cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat,
getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, length, log, map, ndims,
getproperty, imag, inv, isapprox, isone, iszero, IndexStyle, kron, kron!, length, log, map, ndims,
oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
Expand Down
26 changes: 20 additions & 6 deletions stdlib/LinearAlgebra/src/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,29 @@ qr(A::BitMatrix) = qr(float(A))

## kron

function kron(a::BitVector, b::BitVector)
@inline function kron!(R::BitVector, a::BitVector, b::BitVector)
m = length(a)
n = length(b)
R = falses(n * m)
@boundscheck length(R) == n*m || throw(DimensionMismatch())
Rc = R.chunks
bc = b.chunks
for j = 1:m
a[j] && Base.copy_chunks!(Rc, (j-1)*n+1, bc, 1, n)
end
R
return R
end

function kron(a::BitMatrix, b::BitMatrix)
function kron(a::BitVector, b::BitVector)
m = length(a)
n = length(b)
R = falses(n * m)
return @inbounds kron!(R, a, b)
end

function kron!(R::BitMatrix, a::BitMatrix, b::BitMatrix)
mA,nA = size(a)
mB,nB = size(b)
R = falses(mA*mB, nA*nB)
@boundscheck size(R) == (mA*mB, nA*nB) || throw(DimensionMismatch())

for i = 1:mA
ri = (1:mB) .+ ((i-1)*mB)
Expand All @@ -118,7 +125,14 @@ function kron(a::BitMatrix, b::BitMatrix)
end
end
end
R
return R
end

function kron(a::BitMatrix, b::BitMatrix)
mA,nA = size(a)
mB,nB = size(b)
R = falses(mA*mB, nA*nB)
return @inbounds kron!(R, a, b)
end

## Structure query functions
Expand Down
46 changes: 37 additions & 9 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,29 @@ function tr(A::Matrix{T}) where T
t
end

"""
kron!(C, A, B)
`kron!` is the in-place version of [`kron`](@ref). Computes `kron(A, B)` and stores the result in `C`
overwriting the existing value of `C`.
!!! tip
Bounds checking can be disabled by [`@inbounds`](@ref), but you need to take care of the shape
of `C`, `A`, `B` yourself.
"""
@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
require_one_based_indexing(A, B)
@boundscheck (size(C) == (size(A,1)*size(B,1), size(A,2)*size(B,2))) || throw(DimensionMismatch())
m = 0
@inbounds for j = 1:size(A,2), l = 1:size(B,2), i = 1:size(A,1)
Aij = A[i,j]
for k = 1:size(B,1)
C[m += 1] = Aij*B[k,l]
end
end
return C
end

"""
kron(A, B)
Expand Down Expand Up @@ -383,18 +406,23 @@ julia> reshape(kron(v,w), (length(w), length(v)))
```
"""
function kron(a::AbstractMatrix{T}, b::AbstractMatrix{S}) where {T,S}
require_one_based_indexing(a, b)
R = Matrix{promote_op(*,T,S)}(undef, size(a,1)*size(b,1), size(a,2)*size(b,2))
m = 0
@inbounds for j = 1:size(a,2), l = 1:size(b,2), i = 1:size(a,1)
aij = a[i,j]
for k = 1:size(b,1)
R[m += 1] = aij*b[k,l]
end
end
R
return @inbounds kron!(R, a, b)
end

kron!(c::AbstractVecOrMat, a::AbstractVecOrMat, b::Number) = mul!(c, a, b)

Base.@propagate_inbounds function kron!(c::AbstractVector, a::AbstractVector, b::AbstractVector)
C = reshape(c, length(a)*length(b), 1)
A = reshape(a ,length(a), 1)
B = reshape(b, length(b), 1)
kron!(C, A, B)
return c
end

Base.@propagate_inbounds kron!(C::AbstractMatrix, a::AbstractMatrix, b::AbstractVector) = kron!(C, a, reshape(b, length(b), 1))
Base.@propagate_inbounds kron!(C::AbstractMatrix, a::AbstractVector, b::AbstractMatrix) = kron!(C, reshape(a, length(a), 1), b)

kron(a::Number, b::Union{Number, AbstractVecOrMat}) = a * b
kron(a::AbstractVecOrMat, b::Number) = a * b
kron(a::AbstractVector, b::AbstractVector) = vec(kron(reshape(a ,length(a), 1), reshape(b, length(b), 1)))
Expand Down
60 changes: 44 additions & 16 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -493,52 +493,80 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
(\)(A::Union{QR,QRCompactWY,QRPivoted}, B::Diagonal) =
invoke(\, Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B)

function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number}

@inline function kron!(C::AbstractMatrix{T}, A::Diagonal, B::Diagonal) where T
fill!(C, zero(T))
valA = A.diag; nA = length(valA)
valB = B.diag; nB = length(valB)
valC = Vector{typeof(zero(T1)*zero(T2))}(undef,nA*nB)
nC = checksquare(C)
@boundscheck nC == nA*nB ||
throw(DimensionMismatch("expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)"))

@inbounds for i = 1:nA, j = 1:nB
valC[(i-1)*nB+j] = valA[i] * valB[j]
idx = (i-1)*nB+j
C[idx, idx] = valA[i] * valB[j]
end
return Diagonal(valC)
return C
end

function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number}
function kron(A::Diagonal{T1}, B::Diagonal{T2}) where {T1<:Number, T2<:Number}
valA = A.diag; nA = length(valA)
valB = B.diag; nB = length(valB)
valC = Vector{typeof(zero(T1)*zero(T2))}(undef,nA*nB)
C = Diagonal(valC)
return @inbounds kron!(C, A, B)
end

@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix)
Base.require_one_based_indexing(B)
(mA, nA) = size(A); (mB, nB) = size(B)
R = zeros(Base.promote_op(*, T, S), mA * mB, nA * nB)
(mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
m = 1
for j = 1:nA
@inbounds for j = 1:nA
A_jj = A[j,j]
for k = 1:nB
for l = 1:mB
R[m] = A_jj * B[l,k]
C[m] = A_jj * B[l,k]
m += 1
end
m += (nA - 1) * mB
end
m += mB
end
return R
return C
end

function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number}
@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::Diagonal)
require_one_based_indexing(A)
(mA, nA) = size(A); (mB, nB) = size(B)
R = zeros(promote_op(*, T, S), mA * mB, nA * nB)
(mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
m = 1
for j = 1:nA
@inbounds for j = 1:nA
for l = 1:mB
Bll = B[l,l]
for k = 1:mA
R[m] = A[k,j] * Bll
C[m] = A[k,j] * Bll
m += nB
end
m += 1
end
m -= nB
end
return R
return C
end

function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number}
(mA, nA) = size(A); (mB, nB) = size(B)
R = zeros(Base.promote_op(*, T, S), mA * mB, nA * nB)
return @inbounds kron!(R, A, B)
end

function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number}
(mA, nA) = size(A); (mB, nB) = size(B)
R = zeros(promote_op(*, T, S), mA * mB, nA * nB)
return @inbounds kron!(R, A, B)
end

conj(D::Diagonal) = Diagonal(conj(D.diag))
Expand Down

0 comments on commit 23bfba1

Please sign in to comment.