Skip to content

Commit

Permalink
Use new cholesky pivot syntax in v1.8 (#633)
Browse files Browse the repository at this point in the history
* Use new cholesky pivot syntax in v1.8

* Increment patch number

* Increment minor version number
  • Loading branch information
sethaxen authored Aug 4, 2022
1 parent 2c1bce7 commit 1f4a8a9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.41.0"
version = "1.42.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
14 changes: 8 additions & 6 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!
const LU_RowMaximum = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum : Val{true}
const LU_NoPivot = VERSION >= v"1.7.0-DEV.1188" ? NoPivot : Val{false}

const CHOLESKY_NoPivot = VERSION >= v"1.8.0-rc1" ? Union{NoPivot, Val{false}} : Val{false}

function frule(
(_, Ȧ), ::typeof(lu!), A::StridedMatrix, pivot::Union{LU_RowMaximum,LU_NoPivot}; kwargs...
)
Expand Down Expand Up @@ -462,8 +464,8 @@ function _cholesky_Diagonal_pullback(ΔC, C)
end
return NoTangent(), Diagonal(Ādiag), NoTangent()
end
function rrule(::typeof(cholesky), A::Diagonal{<:Number}, ::Val{false}; check::Bool=true)
C = cholesky(A, Val(false); check=check)
function rrule(::typeof(cholesky), A::Diagonal{<:Number}, pivot::CHOLESKY_NoPivot; check::Bool=true)
C = cholesky(A, pivot; check=check)
cholesky_pullback(ȳ) = _cholesky_Diagonal_pullback(unthunk(ȳ), C)
return C, cholesky_pullback
end
Expand All @@ -474,10 +476,10 @@ end
function rrule(
::typeof(cholesky),
A::LinearAlgebra.RealHermSymComplexHerm{<:Real, <:StridedMatrix},
::Val{false};
pivot::CHOLESKY_NoPivot;
check::Bool=true,
)
C = cholesky(A, Val(false); check=check)
C = cholesky(A, pivot; check=check)
function cholesky_HermOrSym_pullback(ΔC)
= _cholesky_pullback_shared_code(C, unthunk(ΔC))
rmul!(Ā, one(eltype(Ā)) / 2)
Expand All @@ -489,10 +491,10 @@ end
function rrule(
::typeof(cholesky),
A::StridedMatrix{<:Union{Real,Complex}},
::Val{false};
pivot::CHOLESKY_NoPivot;
check::Bool=true,
)
C = cholesky(A, Val(false); check=check)
C = cholesky(A, pivot; check=check)
function cholesky_Strided_pullback(ΔC)
= _cholesky_pullback_shared_code(C, unthunk(ΔC))
idx = diagind(Ā)
Expand Down
30 changes: 16 additions & 14 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ end
const LU_ROW_MAXIMUM = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum() : Val(true)
const LU_NO_PIVOT = VERSION >= v"1.7.0-DEV.1188" ? NoPivot() : Val(false)

const CHOLESKY_NO_PIVOT = VERSION >= v"1.8.0-rc1" ? NoPivot() : Val(false)

# well-conditioned random n×n matrix with elements of type `T` for testing `eigen`
function rand_eigen(T::Type, n::Int)
# uniform distribution over `(-1, 1)` / `(-1, 1)^2`
Expand Down Expand Up @@ -394,7 +396,7 @@ end

@testset "Diagonal" begin
@testset "Diagonal{<:Real}" begin
test_rrule(cholesky, Diagonal([0.3, 0.2, 0.5, 0.6, 0.9]), Val(false))
test_rrule(cholesky, Diagonal([0.3, 0.2, 0.5, 0.6, 0.9]), CHOLESKY_NO_PIVOT)
end
@testset "Diagonal{<:Complex}" begin
# finite differences in general will produce matrices with non-real
Expand All @@ -403,26 +405,26 @@ end
D = Diagonal([0.3 + 0im, 0.2, 0.5, 0.6, 0.9])
C = cholesky(D)
test_rrule(
cholesky, D, Val(false);
cholesky, D, CHOLESKY_NO_PIVOT;
output_tangent=Tangent{typeof(C)}(factors=complex(randn(5, 5))),
fkwargs=(; check=false),
)
end
@testset "check has correct default and passed to primal" begin
@test_throws Exception rrule(cholesky, Diagonal(-rand(5)), Val(false))
rrule(cholesky, Diagonal(-rand(5)), Val(false); check=false)
@test_throws Exception rrule(cholesky, Diagonal(-rand(5)), CHOLESKY_NO_PIVOT)
rrule(cholesky, Diagonal(-rand(5)), CHOLESKY_NO_PIVOT; check=false)
end
@testset "failed factorization" begin
A = Diagonal(vcat(rand(4), -rand(4), rand(4)))
test_rrule(cholesky, A, Val(false); fkwargs=(; check=false))
test_rrule(cholesky, A, CHOLESKY_NO_PIVOT; fkwargs=(; check=false))
end
end

@testset "StridedMatrix" begin
@testset "Matrix{$T}" for T in (Float64, ComplexF64)
X = generate_well_conditioned_matrix(T, 10)
V = generate_well_conditioned_matrix(T, 10)
F, dX_pullback = rrule(cholesky, X, Val(false))
F, dX_pullback = rrule(cholesky, X, CHOLESKY_NO_PIVOT)
@testset "uplo=$p, cotangent eltype=$T" for p in [:U, :L], S in unique([T, complex(T)])
Y, dF_pullback = rrule(getproperty, F, p)
= randn(S, size(Y))
Expand All @@ -447,22 +449,22 @@ end
@testset "check has correct default and passed to primal" begin
# this will almost certainly be a non-PD matrix
X = Matrix(Symmetric(randn(10, 10)))
@test_throws Exception rrule(cholesky, X, Val(false))
rrule(cholesky, X, Val(false); check=false) # just check it doesn't throw
@test_throws Exception rrule(cholesky, X, CHOLESKY_NO_PIVOT)
rrule(cholesky, X, CHOLESKY_NO_PIVOT; check=false) # just check it doesn't throw
end
end

# Ensure that cotangents of cholesky(::StridedMatrix) and
# (cholesky ∘ Symmetric)(::StridedMatrix) are equal.
@testset "Symmetric" begin
X = generate_well_conditioned_matrix(10)
F, dX_pullback = rrule(cholesky, X, Val(false))
F, dX_pullback = rrule(cholesky, X, CHOLESKY_NO_PIVOT)
ΔU = randn(size(X))
ΔF = Tangent{typeof(F)}(; factors=ΔU)

@testset for uplo in (:L, :U)
X_symmetric, sym_back = rrule(Symmetric, X, uplo)
C, chol_back_sym = rrule(cholesky, X_symmetric, Val(false))
C, chol_back_sym = rrule(cholesky, X_symmetric, CHOLESKY_NO_PIVOT)

ΔC = Tangent{typeof(C)}(; factors=(uplo === :U ? ΔU : ΔU'))
ΔX_symmetric = chol_back_sym(ΔC)[2]
Expand All @@ -479,13 +481,13 @@ end
@testset "Hermitian" begin
@testset "Hermitian{$T}" for T in (Float64, ComplexF64)
X = generate_well_conditioned_matrix(T, 10)
F, dX_pullback = rrule(cholesky, X, Val(false))
F, dX_pullback = rrule(cholesky, X, CHOLESKY_NO_PIVOT)
ΔU = randn(T, size(X))
ΔF = Tangent{typeof(F)}(; factors=ΔU)

@testset for uplo in (:L, :U)
X_hermitian, herm_back = rrule(Hermitian, X, uplo)
C, chol_back_herm = rrule(cholesky, X_hermitian, Val(false))
C, chol_back_herm = rrule(cholesky, X_hermitian, CHOLESKY_NO_PIVOT)

ΔC = Tangent{typeof(C)}(; factors=(uplo === :U ? ΔU : ΔU'))
ΔX_hermitian = chol_back_herm(ΔC)[2]
Expand All @@ -499,8 +501,8 @@ end
@testset "check has correct default and passed to primal" begin
# this will almost certainly be a non-PD matrix
X = Hermitian(randn(10, 10))
@test_throws Exception rrule(cholesky, X, Val(false))
rrule(cholesky, X, Val(false); check=false)
@test_throws Exception rrule(cholesky, X, CHOLESKY_NO_PIVOT)
rrule(cholesky, X, CHOLESKY_NO_PIVOT; check=false)
end
end

Expand Down

6 comments on commit 1f4a8a9

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/65607

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.42.0 -m "<description of version>" 1f4a8a9d86c79f024a911f61aa180bdc094bb8a3
git push origin v1.42.0

Also, note the warning: Version 1.42.0 skips over 1.40.0
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/65607

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.42.0 -m "<description of version>" 1f4a8a9d86c79f024a911f61aa180bdc094bb8a3
git push origin v1.42.0

Also, note the warning: Version 1.42.0 skips over 1.41.0
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/65607

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.42.0 -m "<description of version>" 1f4a8a9d86c79f024a911f61aa180bdc094bb8a3
git push origin v1.42.0

Please sign in to comment.