Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Cholesky to use Composite #164

Merged
merged 6 commits into from
Jan 30, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.3.2"
version = "0.3.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
29 changes: 18 additions & 11 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,26 +71,32 @@ end

function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real})
F = cholesky(X)
function cholesky_pullback(Ȳ)
∂X = @thunk(chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true))
function cholesky_pullback(Ȳ::Composite{<:Cholesky})
# TODO: avoid the explicit `unthunk` calls below - see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/100
∂X = if F.uplo === 'U'
@thunk(chol_blocked_rev(unthunk(Ȳ.U), F.U, 25, true))
else
@thunk(chol_blocked_rev(unthunk(Ȳ.L), F.L, 25, false))
end
return (NO_FIELDS, ∂X)
end
return F, cholesky_pullback
end

function rrule(::typeof(getproperty), F::Cholesky, x::Symbol)
function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
function getproperty_cholesky_pullback(Ȳ)
if x === :U
C = Composite{T}
∂F = if x === :U
if F.uplo === 'U'
∂F = @thunk UpperTriangular(Ȳ)
C(U=(@thunk UpperTriangular(Ȳ)),)
else
∂F = @thunk LowerTriangular(Ȳ')
C(L=(@thunk LowerTriangular(Ȳ')),)
end
elseif x === :L
if F.uplo === 'L'
∂F = @thunk LowerTriangular(Ȳ)
C(L=(@thunk LowerTriangular(Ȳ)),)
else
∂F = @thunk UpperTriangular(Ȳ')
C(U=(@thunk UpperTriangular(Ȳ')),)
end
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
end
return NO_FIELDS, ∂F, DoesNotExist()
Expand Down Expand Up @@ -194,15 +200,15 @@ function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool
end

"""
chol_blocked_rev!(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool)
chol_blocked_rev!(Σ̄::StridedMatrix, L::StridedMatrix, nb::Integer, upper::Bool)

Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly
procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities
of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used
to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be
indicated by passing `upper = true`.
"""
function chol_blocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, nb::Integer, upper::Bool) where T<:Real
function chol_blocked_rev!(Σ̄::StridedMatrix{T}, L::StridedMatrix{T}, nb::Integer, upper::Bool) where T<:Real
n = checksquare(Σ̄)
tmp = Matrix{T}(undef, nb, nb)
k = n
Expand Down Expand Up @@ -252,5 +258,6 @@ function chol_blocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, nb::In
end

function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool)
return chol_blocked_rev!(copy(Σ̄), L, nb, upper)
# Convert both arguments to `Matrix` because the BLAS functions require `StridedMatrix` input.
return chol_blocked_rev!(Matrix(Σ̄), Matrix(L), nb, upper)
end
4 changes: 2 additions & 2 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
# machinery from FiniteDifferences because that isn't set up to respect
# necessary special properties of the input. In the case of the Cholesky
# factorization, we need the input to be Hermitian.
ΔF = extern(dF)
ΔF = unthunk(dF)
_, dX = dX_pullback(ΔF)
X̄_ad = dot(extern(dX), V)
X̄_ad = dot(unthunk(dX), V)
X̄_fd = _fdm() do ε
dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
end
Expand Down