diff --git a/Project.toml b/Project.toml index fa368706c..865448158 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.53" +version = "0.7.54" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index a8684c8a6..095c6b90a 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -42,6 +42,7 @@ include("rulesets/Statistics/statistics.jl") include("rulesets/LinearAlgebra/utils.jl") include("rulesets/LinearAlgebra/blas.jl") +include("rulesets/LinearAlgebra/lapack.jl") include("rulesets/LinearAlgebra/dense.jl") include("rulesets/LinearAlgebra/norm.jl") include("rulesets/LinearAlgebra/matfun.jl") diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 482a05ed4..5069f4c61 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -239,3 +239,86 @@ function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T} end return Y, pinv_pullback end + +##### +##### `sylvester` +##### + +# included because the primal uses `schur`, for which we don't have a rule + +function frule( + (_, ΔA, ΔB, ΔC), + ::typeof(sylvester), + A::StridedMatrix{T}, + B::StridedMatrix{T}, + C::StridedMatrix{T}, +) where {T<:BlasFloat} + RA, QA = schur(A) + RB, QB = schur(B) + D = QA' * (C * QB) + Y, scale = LAPACK.trsyl!('N', 'N', RA, RB, D) + Ω = rmul!(QA * (Y * QB'), -inv(scale)) + ∂D = QA' * (mul!(muladd(ΔA, Ω, ΔC), Ω, ΔB, true, true) * QB) + ∂Y, scale2 = LAPACK.trsyl!('N', 'N', RA, RB, ∂D) + ∂Ω = rmul!(QA * (∂Y * QB'), -inv(scale2)) + return Ω, ∂Ω +end + +# included because the primal mutates and uses `schur` and LAPACK + +function rrule( + ::typeof(sylvester), A::StridedMatrix{T}, B::StridedMatrix{T}, C::StridedMatrix{T} +) where {T<:BlasFloat} + RA, QA = schur(A) + RB, QB = schur(B) + D = QA' * (C * QB) + Y, scale = LAPACK.trsyl!('N', 'N', RA, RB, D) + Ω = rmul!(QA * (Y * QB'), -inv(scale)) + function sylvester_pullback(ΔΩ) + ∂Ω = T <: Real ? real(ΔΩ) : ΔΩ + ∂Y = QA' * (∂Ω * QB) + trans = T <: Complex ? 'C' : 'T' + ∂D, scale2 = LAPACK.trsyl!(trans, trans, RA, RB, ∂Y) + ∂Z = rmul!(QA * (∂D * QB'), -inv(scale2)) + return NO_FIELDS, @thunk(∂Z * Ω'), @thunk(Ω' * ∂Z), @thunk(∂Z * inv(scale)) + end + return Ω, sylvester_pullback +end + +##### +##### `lyap` +##### + +# included because the primal uses `schur`, for which we don't have a rule + +function frule( + (_, ΔA, ΔC), ::typeof(lyap), A::StridedMatrix{T}, C::StridedMatrix{T} +) where {T<:BlasFloat} + R, Q = schur(A) + D = Q' * (C * Q) + Y, scale = LAPACK.trsyl!('N', T <: Complex ? 'C' : 'T', R, R, D) + Ω = rmul!(Q * (Y * Q'), -inv(scale)) + ∂D = Q' * (mul!(muladd(ΔA, Ω, ΔC), Ω, ΔA', true, true) * Q) + ∂Y, scale2 = LAPACK.trsyl!('N', T <: Complex ? 'C' : 'T', R, R, ∂D) + ∂Ω = rmul!(Q * (∂Y * Q'), -inv(scale2)) + return Ω, ∂Ω +end + +# included because the primal mutates and uses `schur` and LAPACK + +function rrule( + ::typeof(lyap), A::StridedMatrix{T}, C::StridedMatrix{T} +) where {T<:BlasFloat} + R, Q = schur(A) + D = Q' * (C * Q) + Y, scale = LAPACK.trsyl!('N', T <: Complex ? 'C' : 'T', R, R, D) + Ω = rmul!(Q * (Y * Q'), -inv(scale)) + function lyap_pullback(ΔΩ) + ∂Ω = T <: Real ? real(ΔΩ) : ΔΩ + ∂Y = Q' * (∂Ω * Q) + ∂D, scale2 = LAPACK.trsyl!(T <: Complex ? 'C' : 'T', 'N', R, R, ∂Y) + ∂Z = rmul!(Q * (∂D * Q'), -inv(scale2)) + return NO_FIELDS, @thunk(mul!(∂Z * Ω', ∂Z', Ω, true, true)), @thunk(∂Z * inv(scale)) + end + return Ω, lyap_pullback +end diff --git a/src/rulesets/LinearAlgebra/lapack.jl b/src/rulesets/LinearAlgebra/lapack.jl new file mode 100644 index 000000000..e98c0e43f --- /dev/null +++ b/src/rulesets/LinearAlgebra/lapack.jl @@ -0,0 +1,25 @@ +##### +##### `LAPACK.trsyl!` +##### + +function ChainRules.frule( + (_, _, _, ΔA, ΔB, ΔC), + ::typeof(LAPACK.trsyl!), + transa::AbstractChar, + transb::AbstractChar, + A::AbstractMatrix{T}, + B::AbstractMatrix{T}, + C::AbstractMatrix{T}, + isgn::Int, +) where {T<:BlasFloat} + C, scale = LAPACK.trsyl!(transa, transb, A, B, C, isgn) + Y = (C, scale) + ΔAtrans = transa === 'T' ? transpose(ΔA) : (transa === 'C' ? ΔA' : ΔA) + ΔBtrans = transb === 'T' ? transpose(ΔB) : (transb === 'C' ? ΔB' : ΔB) + mul!(ΔC, ΔAtrans, C, -1, scale) + mul!(ΔC, C, ΔBtrans, -isgn, true) + ΔC, scale2 = LAPACK.trsyl!(transa, transb, A, B, ΔC, isgn) + rmul!(ΔC, inv(scale2)) + ∂Y = Composite{typeof(Y)}(ΔC, Zero()) + return Y, ∂Y +end diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 796ccb1f1..a9f1d75a3 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -102,4 +102,22 @@ test_frule(tr, randn(4, 4)) test_rrule(tr, randn(4, 4)) end + @testset "sylvester" begin + @testset "T=$T, m=$m, n=$n" for T in (Float64, ComplexF64), m in (2, 3), n in (1, 3) + A = randn(T, m, m) + B = randn(T, n, n) + C = randn(T, m, n) + test_frule(sylvester, A, B, C) + test_rrule(sylvester, A, B, C) + end + end + @testset "lyap" begin + n = 3 + @testset "Float64" for T in (Float64, ComplexF64) + A = randn(T, n, n) + C = randn(T, n, n) + test_frule(lyap, A, C) + test_rrule(lyap, A, C) + end + end end diff --git a/test/rulesets/LinearAlgebra/lapack.jl b/test/rulesets/LinearAlgebra/lapack.jl new file mode 100644 index 000000000..c6c755343 --- /dev/null +++ b/test/rulesets/LinearAlgebra/lapack.jl @@ -0,0 +1,27 @@ +@testset "LAPACK" begin + @testset "trsyl!" begin + @testset "T=$T, m=$m, n=$n, transa='$transa', transb='$transb', isgn=$isgn" for + T in (Float64, ComplexF64), + transa in (T <: Real ? ('N', 'C', 'T') : ('N', 'C')), + transb in (T <: Real ? ('N', 'C', 'T') : ('N', 'C')), + m in (2, 3), + n in (1, 3), + isgn in (1, -1) + + # make A and B quasi upper-triangular (or upper-triangular for complex) + # and their tangents have the same sparsity pattern + A = schur(randn(T, m, m)).T + B = schur(randn(T, n, n)).T + C = randn(T, m, n) + test_frule( + LAPACK.trsyl!, + transa ⊢ nothing, + transb ⊢ nothing, + A ⊢ rand_tangent(A) .* (!iszero).(A), # Match sparsity pattern + B ⊢ rand_tangent(B) .* (!iszero).(B), + C, + isgn ⊢ nothing, + ) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index fba65cf7a..e17b419d8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -48,6 +48,7 @@ println("Testing ChainRules.jl") include_test("rulesets/LinearAlgebra/symmetric.jl") include_test("rulesets/LinearAlgebra/factorization.jl") include_test("rulesets/LinearAlgebra/blas.jl") + include_test("rulesets/LinearAlgebra/lapack.jl") end println()