From fee130d924966e21ea3f6a5619092abf6be442ab Mon Sep 17 00:00:00 2001 From: Nick Robinson Date: Thu, 30 Jan 2020 21:48:49 +0000 Subject: [PATCH] Change Cholesky to use `Composite` (#164) * Update Cholesky to use Composite * Bump version to v0.3.3 * Add comment on why we call `Matrix` * Replace `extern` with `unthunk` * Move `Matrix` conversion * Thunk whole `Composite` instead of the field --- Project.toml | 2 +- src/rulesets/LinearAlgebra/factorization.jl | 28 ++++++++++++-------- test/rulesets/LinearAlgebra/factorization.jl | 4 +-- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index c3b31aaf8..206a3143e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.3.2" +version = "0.3.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 38a266a96..5daa046f4 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -71,26 +71,31 @@ end function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) F = cholesky(X) - function cholesky_pullback(Ȳ) - ∂X = @thunk(chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true)) + function cholesky_pullback(Ȳ::Composite{<:Cholesky}) + ∂X = if F.uplo === 'U' + @thunk(chol_blocked_rev(Ȳ.U, F.U, 25, true)) + else + @thunk(chol_blocked_rev(Ȳ.L, F.L, 25, false)) + end return (NO_FIELDS, ∂X) end return F, cholesky_pullback end -function rrule(::typeof(getproperty), F::Cholesky, x::Symbol) +function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky function getproperty_cholesky_pullback(Ȳ) - if x === :U + C = Composite{T} + ∂F = @thunk if x === :U if F.uplo === 'U' - ∂F = @thunk UpperTriangular(Ȳ) + C(U=UpperTriangular(Ȳ),) else - ∂F = @thunk LowerTriangular(Ȳ') + C(L=LowerTriangular(Ȳ'),) end elseif x === :L if F.uplo === 'L' - ∂F = @thunk LowerTriangular(Ȳ) + C(L=LowerTriangular(Ȳ),) else - ∂F = @thunk UpperTriangular(Ȳ') + C(U=UpperTriangular(Ȳ'),) end end return NO_FIELDS, ∂F, DoesNotExist() @@ -194,7 +199,7 @@ function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool end """ - chol_blocked_rev!(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool) + chol_blocked_rev!(Σ̄::StridedMatrix, L::StridedMatrix, nb::Integer, upper::Bool) Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities @@ -202,7 +207,7 @@ of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangl to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be indicated by passing `upper = true`. """ -function chol_blocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, nb::Integer, upper::Bool) where T<:Real +function chol_blocked_rev!(Σ̄::StridedMatrix{T}, L::StridedMatrix{T}, nb::Integer, upper::Bool) where T<:Real n = checksquare(Σ̄) tmp = Matrix{T}(undef, nb, nb) k = n @@ -252,5 +257,6 @@ function chol_blocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, nb::In end function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool) - return chol_blocked_rev!(copy(Σ̄), L, nb, upper) + # Convert to `Matrix`s because blas functions require StridedMatrix input. + return chol_blocked_rev!(Matrix(Σ̄), Matrix(L), nb, upper) end diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 949008707..3fd899ae4 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -70,9 +70,9 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo # machinery from FiniteDifferences because that isn't set up to respect # necessary special properties of the input. In the case of the Cholesky # factorization, we need the input to be Hermitian. - ΔF = extern(dF) + ΔF = unthunk(dF) _, dX = dX_pullback(ΔF) - X̄_ad = dot(extern(dX), V) + X̄_ad = dot(unthunk(dX), V) X̄_fd = _fdm() do ε dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p)) end