Skip to content

Commit

Permalink
add simple 3-arg and 4-arg * methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott committed Oct 8, 2020
1 parent 301e082 commit a1aa074
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/matrix_multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ import LinearAlgebra: BlasFloat, matprod, mul!
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N} = vec(A) * B
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B

# Avoid LinearAlgebra._quad_matmul's order calculation on equal sizes
@inline *(A::StaticMatrix{N,N}, B::StaticMatrix{N,N}, C::StaticMatrix{N,N}) where {N} = (A*B)*C
@inline *(A::StaticMatrix{N,N}, B::StaticMatrix{N,N}, C::StaticMatrix{N,N}, D::StaticMatrix{N,N}) where {N} = ((A*B)*C)*D

"""
mul_result_structure(a::Type, b::Type)
Get a structure wrapper that should be applied to the result of multiplication of matrices
of given types (a*b).
of given types (a*b).
"""
function mul_result_structure(a, b)
return identity
Expand Down Expand Up @@ -119,7 +123,7 @@ end
else
exprs = [:(a[$i] * transpose(b[$j])) for i = 1:sa[1], j = 1:sb[2]]
end

return quote
@_inline_meta
T = promote_op(*, Ta, Tb)
Expand Down Expand Up @@ -214,7 +218,7 @@ end
while m < M
mu = min(M, m + M_r)
mrange = m+1:mu

atemps_init = [:($(atemps[k1]) = a[$k1]) for k1 = mrange]
exprs_init = [:($(tmps[k1,k2]) = $(atemps[k1]) * b[$(1 + (k2-1) * sb[1])]) for k1 = mrange, k2 = nrange]
atemps_loop_init = [:($(atemps[k1]) = a[$(k1-sa[1]) + $(sa[1])*j]) for k1 = mrange]
Expand Down
3 changes: 3 additions & 0 deletions test/matrix_multiply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ mul_wrappers = [
@test m*transpose(n) === @SMatrix [8 14; 18 32]
@test transpose(m)*transpose(n) === @SMatrix [11 19; 16 28]

@test @inferred(m*n*m) === @SMatrix [49 72; 109 160]
@test @inferred(m*n*m*n) === @SMatrix [386 507; 858 1127]

# check different sizes because there are multiple implementations for matrices of different sizes
for (mm, nn) in [
(m, n),
Expand Down

0 comments on commit a1aa074

Please sign in to comment.