Add 3-arg * methods #37898
Conversation
Will this not run into the same problems as #24343 (comment)? That PR ended up doing it for |
I guess the functions which fuse the scalar are the ones to watch. For example:

julia> sprand(13, 11, 0.3) * sprand(11, 0.3) * 17
13-element SparseVector{Float64,Int64} with 3 stored entries:
  [1 ]  =  1.40409
  [2 ]  =  5.72894
  [7 ]  =  5.95465

julia> SA[1 2; 3 4] * SA[5,6] * 7
2-element MArray{Tuple{2},Int64,1,2} with indices SOneTo(2):
 119
 273

julia> 5 * Diagonal(1:2) * Diagonal(3:4)
2×2 Adjoint{Int64,SparseMatrixCSC{Int64,Int64}}:
 15   0
  0  40

The rest dispatches to the 2-arg methods. I haven't thought through all the weird special matrices in LinearAlgebra.

Edit -- the above examples (and all special matrices) now go down the fallback path, and no longer give unexpected types:

julia> SA[1 2; 3 4] * SA[5,6] * 7
2-element SArray{Tuple{2},Int64,1,2} with indices SOneTo(2):
 119
 273

julia> 5 * Diagonal(1:2) * Diagonal(3:4)
2×2 Diagonal{Int64,Array{Int64,1}}:
 15  ⋅
  ⋅  40 |
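For concreteness, here is a rough sketch of what such a scalar-fusing path can look like (my own illustration; the helper name and details are not the PR's actual code). It shows why these are the methods that can change the result type: the output container comes from `similar`, not from a plain 2-arg `*`.

using LinearAlgebra

# Illustrative only -- name and details are hypothetical, not the PR's code.
# The scalar is folded into a single 5-arg mul! call, so the output array is
# whatever `similar(x, ...)` produces, which is where surprising types sneak in.
function mat_vec_scalar_sketch(A::AbstractMatrix, x::AbstractVector, γ::Number)
    T = promote_type(eltype(A), eltype(x), typeof(γ))
    y = similar(x, T, size(A, 1))
    return mul!(y, A, x, γ, false)   # y = γ*A*x, ignoring y's old contents
end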
This looks good to me. Pretty verbose, but that's perhaps inherent in the matter. One "beautification" suggestion, and then I'd shout out for another review.
While it doesn't fit the title, should we also add some 4-arg cases? There are quite a few which are trivial to route to 2-arg and 3-arg methods, without needing any new functions:

# Four-argument *
*(α::Number, β::Number, γ::Number, D::AbstractArray) = (α*β*γ) * D
*(α::Number, β::Number, C::AbstractMatrix, D::AbstractVecOrMat) = (α*β) * C * D
*(α::Number, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = α * B * (C*x)
*(α::Number, vt::AdjOrTransAbsVec, C::AbstractMatrix, D::AbstractMatrix) = (α*vt*C) * D
*(α::Number, vt::AdjOrTransAbsVec, C::AbstractMatrix, x::AbstractVector) = α * (vt*C*x)
*(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = A * B * (C*x)
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix) = (vt*B) * C * D
*(vt::AdjOrTransAbsVec, B::AbstractMatrix, C::AbstractMatrix, x::AbstractVector) = vt * B * (C*x)
function _quad_mul(A,B,C,D)
c1 = _cost((A,B),(C,D))
c2 = _cost(((A,B),C),D)
c3 = _cost(A,(B,(C,D)))
c4 = _cost((A,(B,C)),D)
c5 = _cost(A,((B,C),D))
cmin = min(c1,c2,c3,c4,c5)
if c1 == cmin
(A*B) * (C*D)
elseif c2 == cmin
((A*B) * C) * D
elseif c3 == cmin
A * (B * (C*D))
elseif c4 == cmin
(A * (B*C)) * D
else
A * ((B*C) * D)
end
end
@inline _cost(A::AbstractMatrix) = 0
@inline _cost((A,B)::Tuple) = _cost(A,B)
@inline _cost(A,B) = _cost(A) + _cost(B) + *(_sizes(A)..., _sizes(B)[end])
@inline _sizes(A::AbstractMatrix) = size(A)
@inline _sizes((A,B)::Tuple) = _sizes(A)[begin], _sizes(B)[end]
using Random, Test, BenchmarkTools
s1,s2,s3,s4,s5 = shuffle([5,10,20,100,200])
a=rand(s1,s2); b=rand(s2,s3); c=rand(s3,s4); d=rand(s4,s5);
@test *(a,b,c,d) ≈ _quad_mul(a,b,c,d)
@btime *($a,$b,$c,$d);
@btime _quad_mul($a,$b,$c,$d);
s1,s2,s3,s4,s5 = fill(30,5) # 0.2% overhead at size 30
s1,s2,s3,s4,s5 = fill(3,5) # 4% overhead at size 3 (7ns)
using StaticArrays
s1,s2,s3,s4,s5 = shuffle([2,3,5,7,11])
a=@SMatrix rand(s1,s2); b=@SMatrix rand(s2,s3); c=@SMatrix rand(s3,s4); d=@SMatrix rand(s4,s5);
@test *(a,b,c,d) ≈ _quad_mul(a,b,c,d)
@btime *($(Ref(a))[],$(Ref(b))[],$(Ref(c))[],$(Ref(d))[]);
@btime _quad_mul($(Ref(a))[],$(Ref(b))[],$(Ref(c))[],$(Ref(d))[]);
s1,s2,s3,s4,s5 = fill(3,5) # 28% overhead for 3x3 SMatrix (6ns)

It should be easy to make StaticArrays avoid this overhead: mcabbott/StaticArrays.jl@a1aa074 |
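Something along these lines could be one way to do that (a guess at the idea only; the linked commit may do it differently):

using StaticArrays

# Hypothetical opt-out, not the code from mcabbott/StaticArrays.jl@a1aa074:
# for SMatrix the sizes are type-level constants and the matrices are tiny,
# so skip the runtime cost comparison and just multiply pairwise.
Base.:*(A::StaticMatrix, B::StaticMatrix, C::StaticMatrix, D::StaticMatrix) =
    ((A * B) * C) * D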
Using 3-arg `dot` isn't always faster. Would it be too strange to check the size before deciding what method to call?

julia> N = 10;
julia> @btime x'*A*y setup=(x=rand($N); A=rand($N,$N); y=rand($N));
138.139 ns (1 allocation: 160 bytes)
julia> @btime dot(x,A,y) setup=(x=rand($N); A=rand($N,$N); y=rand($N));
83.566 ns (0 allocations: 0 bytes)
julia> N = 100;
julia> @btime x'*A*y setup=(x=rand($N); A=rand($N,$N); y=rand($N));
1.151 μs (1 allocation: 896 bytes)
julia> @btime dot(x,A,y) setup=(x=rand($N); A=rand($N,$N); y=rand($N));
1.526 μs (0 allocations: 0 bytes)
julia> N = 10_000;
julia> @btime x'*A*y setup=(x=rand($N); A=rand($N,$N); y=rand($N));
33.634 ms (2 allocations: 78.20 KiB)
julia> @btime dot(x,A,y) setup=(x=rand($N); A=rand($N,$N); y=rand($N));
62.127 ms (0 allocations: 0 bytes)
julia> BLAS.vendor()
:mkl

Edit -- maybe it's best just never to call `dot` here. |
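A sketch of the kind of size check being suggested (the helper name and the cutoff value are made up for illustration):

using LinearAlgebra

# Hypothetical helper, not part of the PR: route x' * A * y to the
# non-allocating dot(x, A, y) only when the problem is small enough that
# the generic triple loop tends to beat the BLAS-backed (x' * A) * y.
function _adj_mat_vec(xt::Adjoint{<:Number,<:AbstractVector},
                      A::AbstractMatrix, y::AbstractVector; cutoff::Integer=64)
    x = parent(xt)
    return length(x) <= cutoff ? dot(x, A, y) : (xt * A) * y
end

Where the crossover sits is machine- and BLAS-dependent, as the MKL timings above suggest.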
Bump? Am happy to remove 4-arg cases, if they are thought too exotic / too complicated for |
This LGTM, but I think we should have somebody else take a look.
Just a few minor comments, and then we should really push this over the finish line. This is great stuff.
Why is this marked with the backport label, @dkarrasch? |
I thought we had a "soft feature freeze", and that we would include/backport PRs that have been developed for a while in the v1.6 cycle. This one, specifically, just got a bit forgotten, and my last comments are rather cosmetic. Please feel free to remove the backport label if I misunderstood the intention of the soft feature freeze. |
True, but now it is getting a bit late I think. I'll remove the label for now. |
This just missed 1.6, but perhaps it should be in 1.7? Looking over it again briefly, I'm slightly dismayed that it needs 100 lines of code, for what seems a pretty simple optimisation. Half of them are for the 4-arg case. However, there are very few existing methods to bump into ( |
I have two minor suggestions. Could you also add a comment as to what is optimized? I think it's the number of operations, and not memory allocation for intermediate results, right?
Good point, I have re-written the docstring to be more explicit, see if you like it. |
stdlib/LinearAlgebra/src/matmul.jl
If the last factor is a vector, or the first a transposed vector, then it is efficient
to deal with these first. In particular `x' * B * y` means `(x' * B) * y`
for an ordinary colum-major `B::Matrix`. This is often equivalent to `dot(x, B, y)`,
colum -> column
Hm, why not be explicit: "For scalar eltypes, this is equivalent to `dot(x, B, y)`" or something in that direction.
Maybe "Unlike dot(), ..."? I don't want to dwell on pinning down exactly what types they agree or disagree on.
As an aside, trying to puzzle out what recursive `dot` means in the 3-arg case, is this the intended behaviour?
julia> x = [rand(2,2) for _ in 1:3]; A = [rand(2,2) for _ in 1:3, _ in 1:3]; y = [rand(2,2) for _ in 1:3];
julia> dot(x,A,y)
7.411848453027886
julia> @tullio _ := x[i][b,a] * A[i,j][b,c] * y[j][c,a]
7.4118484530278845
julia> @tullio _ := tr(x[i]' * A[i,j] * y[j])
7.411848453027887
The action on the component matrices looks like a trace of a matrix product.
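That observation is easy to check directly (sizes here are arbitrary; only LinearAlgebra and Test are needed):

using LinearAlgebra, Test

# The recursive 3-arg dot on arrays of matrices should equal the sum of
# traces of x[i]' * A[i,j] * y[j], as in the @tullio comparison above.
x = [rand(2,2) for _ in 1:3]; A = [rand(2,2) for _ in 1:3, _ in 1:3]; y = [rand(2,2) for _ in 1:3];
@test dot(x, A, y) ≈ sum(tr(x[i]' * A[i,j] * y[j]) for i in 1:3, j in 1:3)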
This is ready to go, I think. |
This addresses the simplest part of #12065 (optimizing * for optimal matrix order), by adding some methods for * with 3 arguments, where this can be done more efficiently than working left-to-right. Co-authored-by: Daniel Karrasch <[email protected]>
The fallback introduced here:

mat_mat_scalar(A, B, γ) = (A*B) .* γ # fallback

breaks some of my custom array code (and maybe others'), and I'm wondering if it could be changed to just `(A*B) * γ`. For Base arrays these should be identical in terms of performance, since it will just call broadcasting one method deeper, but for custom arrays which had only defined `*` with a scalar (and not broadcasting) the new fallback breaks. |
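Concretely, the change being asked for would amount to something like this (a sketch against the PR's fallback, not a tested diff):

# Use scalar `*` instead of broadcasting, so a custom array type that only
# defines `*(::MyArray, ::Number)` keeps working and keeps its own type.
mat_mat_scalar(A, B, γ) = (A * B) * γ   # fallback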
For such a generic fallback, I think it is very reasonable to make it `*`. |
Sorry about breaking your code! I agree there's no good reason for `.*` here. |
Awesome thanks for the quick reply! And no worries, it does also improve some other things as intended! |
PR #37898 added methods to `*` for chained matrix multiplication. They have a descriptive docstring but I don't think this is mentioned in the manual.
This addresses the simplest part of JuliaLang/LinearAlgebra.jl#227, by adding some methods for `*` with 3 arguments, where this can be done more efficiently than working left-to-right.

I think it's careful about Adjoint & Transpose vectors, but might need more thought about Diagonal and other special matrix types. (Edit -- testing says `(zeros(0))' * Diagonal(zeros(0)) * zeros(0)` is ambiguous.)

See also https://github.com/AustinPrivett/MatrixChainMultiply.jl, and discussion https://discourse.julialang.org/t/why-is-multiplication-a-b-c-left-associative-foldl-not-right-associative-foldr/17552.
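To illustrate the point with the most common case (sizes here are arbitrary): with these methods, `A * B * x` never needs to form the matrix-matrix product.

using LinearAlgebra

A = rand(100, 100); B = rand(100, 100); x = rand(100);
y1 = (A * B) * x    # left-to-right: a 100×100 intermediate, O(n^3) work
y2 = A * (B * x)    # what the 3-arg method can do: two mat-vecs, O(n^2) work
y1 ≈ y2             # true, up to floating-point roundoff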