-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Change Cholesky to use Composite
#164
Conversation
nickrobinson251
commented
Jan 23, 2020
- part of Update Cholesky and SVD to use Composite #151
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment on the tests explaining why they do test this correctly right now?
Given this change should have broken them
Here's the relevant test code in the tests, for both SVD and Cholesky, with julia> function svd_test_code()
rng = MersenneTwister(1)
X = generate_well_conditioned_matrix(rng, 2)
F, dX_pullback = rrule(svd, X);
p = :U
@show Y, dF_pullback = rrule(getproperty, F, p)
Ȳ = randn(rng, size(Y))
@show dself1, dF, dp = dF_pullback(Ȳ)
@show ΔF = unthunk(dF)
@show dself2, dX = dX_pullback(ΔF)
@show X̄_ad = unthunk(dX)
end
svd_test_code (generic function with 1 method)
julia> function cholesky_test_code()
rng = MersenneTwister(1)
X = generate_well_conditioned_matrix(rng, 2)
F, dX_pullback = rrule(cholesky, X);
p = :U
@show Y, dF_pullback = rrule(getproperty, F, p)
Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(rng, size(Y)))
@show dself1, dF, dp = dF_pullback(Ȳ)
@show ΔF = unthunk(dF)
@show dself2, dX = dX_pullback(ΔF)
@show X̄_ad = unthunk(dX)
end
cholesky_test_code (generic function with 1 method)
julia> svd_test_code()
(Y, dF_pullback) = rrule(getproperty, F, p) = ([-0.9434666894623768 -0.33146735265317234; -0.3314673526531724 0.943466689462377], ChainRules.var"#getproperty_svd_pullback#338"{SVD{Float64,Float64,Array{Float64,2}},Symbol}(:U))
(dself1, dF, dp) = dF_pullback(Ȳ) = (Zero(), Composite{SVD{Float64,Float64,Array{Float64,2}}}(U = [-0.839026854388764 2.2950878238373105; 0.31111133849833383 -2.2670863488005306],), DoesNotExist())
ΔF = unthunk(dF) = Composite{SVD{Float64,Float64,Array{Float64,2}}}(U = [-0.839026854388764 2.2950878238373105; 0.31111133849833383 -2.2670863488005306],)
(dself2, dX) = dX_pullback(ΔF) = (Zero(), Thunk(ChainRules.var"#335#337"{Composite{SVD{Float64,Float64,Array{Float64,2}},NamedTuple{(:U,),Tuple{Array{Float64,2}}}},SVD{Float64,Float64,Array{Float64,2}}}(Composite{SVD{Float64,Float64,Array{Float64,2}}}(U = [-0.839026854388764 2.2950878238373105; 0.31111133849833383 -2.2670863488005306],), SVD{Float64,Float64,Array{Float64,2}}([-0.9434666894623768 -0.33146735265317234; -0.3314673526531724 0.943466689462377], [1.4876799687617108, 1.1042029240797229], [-0.9434666894623771 -0.33146735265317245; -0.33146735265317245 0.9434666894623771])))
)
X̄_ad = unthunk(dX) = [1.6191958455730942 -1.6369264262492607; -2.4029745379167835 -1.619195845573095]
2×2 Array{Float64,2}:
1.6192 -1.63693
-2.40297 -1.6192
julia> cholesky_test_code()
(Y, dF_pullback) = rrule(getproperty, F, p) = ([1.202309075705321 0.09974487200952159; 0.0 1.0660144182073028], ChainRules.var"#getproperty_cholesky_pullback#348"{Cholesky{Float64,Array{Float64,2}},Cholesky{Float64,Array{Float64,2}},Symbol}(Cholesky{Float64,Array{Float64,2}}([1.202309075705321 0.09974487200952159; 0.11992416487211345 1.0660144182073028], 'U', 0), :U))
(dself1, dF, dp) = dF_pullback(Ȳ) = (Zero(), Composite{Cholesky{Float64,Array{Float64,2}}}(U = Thunk(ChainRules.var"#344#349"{UpperTriangular{Float64,Array{Float64,2}}}([-0.839026854388764 2.2950878238373105; 0.0 -2.2670863488005306]))
,), DoesNotExist())
ΔF = unthunk(dF) = Composite{Cholesky{Float64,Array{Float64,2}}}(U = Thunk(ChainRules.var"#344#349"{UpperTriangular{Float64,Array{Float64,2}}}([-0.839026854388764 2.2950878238373105; 0.0 -2.2670863488005306]))
,)
(dself2, dX) = dX_pullback(ΔF) = (Zero(), Thunk(ChainRules.var"#339#342"{Composite{Cholesky{Float64,Array{Float64,2}},NamedTuple{(:U,),Tuple{Thunk{ChainRules.var"#344#349"{UpperTriangular{Float64,Array{Float64,2}}}}}}},Cholesky{Float64,Array{Float64,2}}}(Composite{Cholesky{Float64,Array{Float64,2}}}(U = Thunk(ChainRules.var"#344#349"{UpperTriangular{Float64,Array{Float64,2}}}([-0.839026854388764 2.2950878238373105; 0.0 -2.2670863488005306]))
,), Cholesky{Float64,Array{Float64,2}}([1.202309075705321 0.09974487200952159; 0.11992416487211345 1.0660144182073028], 'U', 0)))
)
X̄_ad = unthunk(dX) = [-0.4354238588089415 2.085332868446868; 0.0 -1.063346944506177]
2×2 Array{Float64,2}:
-0.435424 2.08533
0.0 -1.06335 |
Now you prompted me to re-review (thanks!) There is one change to the test, and it as the same change as in the equivalent PR for SVD (#157): (actually SVD also has some tests for |
Composite