Skip to content

Commit

Permalink
Change in-place exp to out-of-place in matrix trig functions (#56242)
Browse files Browse the repository at this point in the history
This makes the functions work for arbitrary matrix types that support
`exp`, but not necessarily the in-place `exp!`. For example, the
following works after this:
```julia
julia> m = SMatrix{2,2}(1:4);

julia> cos(m)
2×2 SMatrix{2, 2, Float64, 4} with indices SOneTo(2)×SOneTo(2):
  0.855423  -0.166315
 -0.110876   0.689109
```
There's a slight performance improvement as well because we don't
compute `im*A` and `-im*A` separately, but we negate the first to obtain
the second.
```julia
julia> A = rand(ComplexF64,100,100);

julia> @Btime sin($A);
  2.796 ms (48 allocations: 1.84 MiB) # nightly v"1.12.0-DEV.1571"
  2.304 ms (48 allocations: 1.84 MiB) # this PR
```
  • Loading branch information
jishnub authored Nov 11, 2024
1 parent afdba95 commit 3318941
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 28 deletions.
80 changes: 52 additions & 28 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,11 @@ function inv(A::StridedMatrix{T}) where T
return Ai
end

# helper function to perform a broadcast in-place if the destination is strided
# otherwise, this performs an out-of-place broadcast
@inline _broadcast!!(f, dest::StridedArray, args...) = broadcast!(f, dest, args...)
@inline _broadcast!!(f, dest, args...) = broadcast(f, args...)

"""
cos(A::AbstractMatrix)
Expand All @@ -1061,19 +1066,22 @@ function cos(A::AbstractMatrix{<:Real})
elseif issymmetric(A)
return copytri!(parent(cos(Symmetric(A))), 'U')
end
T = complex(float(eltype(A)))
return real(exp!(T.(im .* A)))
M = im .* float.(A)
return real(exp_maybe_inplace(M))
end
function cos(A::AbstractMatrix{<:Complex})
if isdiag(A)
return applydiagonal(cos, A)
elseif ishermitian(A)
return copytri!(parent(cos(Hermitian(A))), 'U', true)
end
T = complex(float(eltype(A)))
X = exp!(T.(im .* A))
@. X = (X + $exp!(T(-im*A))) / 2
return X
M = im .* float.(A)
N = -M
X = exp_maybe_inplace(M)
Y = exp_maybe_inplace(N)
# Compute (X + Y)/2 and return the result.
# Compute the result in-place if X is strided
_broadcast!!((x,y) -> (x + y)/2, X, X, Y)
end

"""
Expand All @@ -1098,23 +1106,22 @@ function sin(A::AbstractMatrix{<:Real})
elseif issymmetric(A)
return copytri!(parent(sin(Symmetric(A))), 'U')
end
T = complex(float(eltype(A)))
return imag(exp!(T.(im .* A)))
M = im .* float.(A)
return imag(exp_maybe_inplace(M))
end
function sin(A::AbstractMatrix{<:Complex})
if isdiag(A)
return applydiagonal(sin, A)
elseif ishermitian(A)
return copytri!(parent(sin(Hermitian(A))), 'U', true)
end
T = complex(float(eltype(A)))
X = exp!(T.(im .* A))
Y = exp!(T.(.-im .* A))
@inbounds for i in eachindex(X, Y)
x, y = X[i]/2, Y[i]/2
X[i] = Complex(imag(x)-imag(y), real(y)-real(x))
end
return X
M = im .* float.(A)
Mneg = -M
X = exp_maybe_inplace(M)
Y = exp_maybe_inplace(Mneg)
# Compute (X - Y)/2im and return the result.
# Compute the result in-place if X is strided
_broadcast!!((x,y) -> (x - y)/2im, X, X, Y)
end

"""
Expand Down Expand Up @@ -1144,8 +1151,8 @@ function sincos(A::AbstractMatrix{<:Real})
cosA = copytri!(parent(symcosA), 'U')
return sinA, cosA
end
T = complex(float(eltype(A)))
c, s = reim(exp!(T.(im .* A)))
M = im .* float.(A)
c, s = reim(exp_maybe_inplace(M))
return s, c
end
function sincos(A::AbstractMatrix{<:Complex})
Expand All @@ -1155,16 +1162,26 @@ function sincos(A::AbstractMatrix{<:Complex})
cosA = copytri!(parent(hermcosA), 'U', true)
return sinA, cosA
end
T = complex(float(eltype(A)))
X = exp!(T.(im .* A))
Y = exp!(T.(.-im .* A))
M = im .* float.(A)
Mneg = -M
X = exp_maybe_inplace(M)
Y = exp_maybe_inplace(Mneg)
_sincos(X, Y)
end
function _sincos(X::StridedMatrix, Y::StridedMatrix)
@inbounds for i in eachindex(X, Y)
x, y = X[i]/2, Y[i]/2
X[i] = Complex(imag(x)-imag(y), real(y)-real(x))
Y[i] = x+y
end
return X, Y
end
function _sincos(X, Y)
T = eltype(X)
S = T(0.5)*im .* (Y .- X)
C = T(0.5) .* (X .+ Y)
S, C
end

"""
tan(A::AbstractMatrix)
Expand Down Expand Up @@ -1205,8 +1222,9 @@ function cosh(A::AbstractMatrix)
return copytri!(parent(cosh(Hermitian(A))), 'U', true)
end
X = exp(A)
@. X = (X + $exp!(float(-A))) / 2
return X
negA = @. float(-A)
Y = exp_maybe_inplace(negA)
_broadcast!!((x,y) -> (x + y)/2, X, X, Y)
end

"""
Expand All @@ -1221,8 +1239,9 @@ function sinh(A::AbstractMatrix)
return copytri!(parent(sinh(Hermitian(A))), 'U', true)
end
X = exp(A)
@. X = (X - $exp!(float(-A))) / 2
return X
negA = @. float(-A)
Y = exp_maybe_inplace(negA)
_broadcast!!((x,y) -> (x - y)/2, X, X, Y)
end

"""
Expand All @@ -1237,15 +1256,20 @@ function tanh(A::AbstractMatrix)
return copytri!(parent(tanh(Hermitian(A))), 'U', true)
end
X = exp(A)
Y = exp!(float.(.-A))
negA = @. float(-A)
Y = exp_maybe_inplace(negA)
X′, Y′ = _subadd!!(X, Y)
return X′ / Y′
end
function _subadd!!(X::StridedMatrix, Y::StridedMatrix)
@inbounds for i in eachindex(X, Y)
x, y = X[i], Y[i]
X[i] = x - y
Y[i] = x + y
end
X /= Y
return X
return X, Y
end
_subadd!!(X, Y) = X - Y, X + Y

"""
acos(A::AbstractMatrix)
Expand Down
17 changes: 17 additions & 0 deletions stdlib/LinearAlgebra/test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ module TestDense
using Test, LinearAlgebra, Random
using LinearAlgebra: BlasComplex, BlasFloat, BlasReal

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
isdefined(Main, :FillArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "FillArrays.jl"))
import Main.FillArrays

@testset "Check that non-floats are correctly promoted" begin
@test [1 0 0; 0 1 0]\[1,1] [1;1;0]
end
Expand Down Expand Up @@ -1302,4 +1306,17 @@ end
end
end

@testset "trig functions for non-strided" begin
@testset for T in (Float32,ComplexF32)
A = FillArrays.Fill(T(0.1), 4, 4) # all.(<(1), eigvals(A)) for atanh
M = Matrix(A)
@testset for f in (sin,cos,tan,sincos,sinh,cosh,tanh)
@test f(A) == f(M)
end
@testset for f in (asin,acos,atan,asinh,acosh,atanh)
@test f(A) == f(M)
end
end
end

end # module TestDense

0 comments on commit 3318941

Please sign in to comment.