-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementation of VecCorrBijector
#246
Merged
Merged
Changes from 7 commits
Commits
Show all changes
179 commits
Select commit
Hold shift + click to select a range
79b92c9
initial work on VecCorrBijector
torfjelde aa2fe61
added some tests for CorrBijector, and fixed implementation for VecCo…
torfjelde 8d23094
improved tests and are now using integer sqrt and division
torfjelde a35e36f
moved things around a bit
torfjelde 8cadf69
added chainrule for ReverseDiff
torfjelde eaf5324
some fixes for AD
torfjelde 36ffbdb
added some TODOs
torfjelde 62ae1ac
Update src/bijectors/corr.jl
torfjelde 3f25a8b
define bijectors for `LKJ` and `LKJCholesky`
harisorgn e1567c3
add `TransformedDistribution` constructor
harisorgn 8d07e34
define `logpdf` for `LKJ` & `LKJCholesky`
harisorgn 9a59a9f
define `rand` for `LKJ` & `LKJCholesky`
harisorgn f15ad85
add util to extract Cholesky factor
harisorgn 53e78f3
TYPO: capitalize matrix
harisorgn ec7d20e
add util to convert `Vector` index
harisorgn 2ed00f4
add `VecTriBijector`s for `LKJCholesky`
harisorgn 07555fc
TYPO: capitilize matrix
harisorgn a75cabc
add `LKJCholesky` link for `UpperTriangular`
harisorgn 844b07e
add `LKJCholesky` link for `LowerTriangular`
harisorgn 792cfe9
TYPO: capitalize matrix
harisorgn 8f0886b
add `LKJCholesky` inverse link to `UpperTriangular`
harisorgn 35f1c03
rename `_logabsdetjac_chol_lkj`
harisorgn 9d55829
dispatch `_logabsdetjac_inv_corr` for `::Vector`
harisorgn adf10ad
add logabsdetjac for inverse link of `LKJCholesky`
harisorgn 03a55b2
add tests for `VecTriBijector`s
harisorgn 1059569
add `rrule` for LKJ(Cholesky) link function
harisorgn 222eb6e
Merge branch 'torfjelde/vec-corr' into ho/vec-lkj-cholesky
harisorgn 7f5d0fc
Merge pull request #1 from harisorgn/ho/vec-lkj-cholesky
harisorgn ad080ea
use `transpose` in link for `::LowerTriangular'
harisorgn 6e1a5b1
add `Tracker` support for inverse link
harisorgn 5fd0a65
better utility function call
harisorgn b38acda
use function barrier properly for type stability
harisorgn 424f8ca
account for difference in support dimensions
harisorgn b749d37
fix indexing in Jacobian of `VecCorrBijector`
harisorgn 7b1f74d
add `_logabsdetjac_dist` for `::LKJCholesky`
harisorgn 75c605b
replace function composition for proper barrier
harisorgn a7a6c05
add util convert `Transpose -> Matrix` for type stability
harisorgn 09c35b6
add `LKJCholesky` Jacobian+type tests
harisorgn 2ad5038
fix `logabsdetjac` for inverse link
harisorgn f5be4e2
use `Cholesky` constructor compatible with `v1.6`
harisorgn 10d9345
add empty line
harisorgn bcf32a3
fix `rrule` for link function
harisorgn 7f4551f
add link `rrule` test
harisorgn dc2c856
add `rrule` for inverse link
harisorgn 87bc3ca
remove TODO
harisorgn bfb7c15
add inverse link `rrule` test
harisorgn 20ab3b4
Update src/bijectors/corr.jl
harisorgn 7bb37e0
add link `rrule` for `LowerTriangular`
harisorgn 3e2c7a8
add `LowerTriangular` chainrule test
harisorgn adba9e8
Update src/bijectors/corr.jl
harisorgn ec18964
remove unused util
harisorgn 37c38ab
use `similar` instead of `zeros`
harisorgn 8fd13b0
update comments
harisorgn 56cc43f
remove old comment
harisorgn 8ee086a
minimize zero-setting operations in inverse link
harisorgn 837b49c
minimize zero-setting operations in `rrule`
harisorgn 0c3aa39
add parametric `Val` type to `VecCorrBijector`
harisorgn c1be272
update `VecCorrBijector` tests
harisorgn 29fced6
use field value instead of `Val`-parametric type
harisorgn 74d6edb
update tests with new `VecCorrBijector`
harisorgn 4c27987
`using VecCorrBijector` in test utils
harisorgn 9108c40
add `VecCorrBijector.mode` check
harisorgn 24847cc
update `VecCorrBijector` docstring
harisorgn bd4de96
specialise `Zygote@adjoint` for `AbstractMatrix`
harisorgn 65bfc42
`ReverseDiff` opt-in to `ChainRules`
harisorgn eca3411
empty lines format
harisorgn f02fd9b
add AD test for inverse link
harisorgn c90f7ac
include `VecCorrBijector` tests
harisorgn 974efb5
remove broken flag for `Tracker`
harisorgn 71fdae6
add roundtrip AD tests for `VecCorrBijector`
harisorgn 6524fe4
remove wrong `ReverseDiff.@grad` for `pd_from_upper`
harisorgn 5e4abae
add corrected `rrule` for `pd_from_upper`
harisorgn c547542
update AD tests
harisorgn 0d599e8
remove `Tracker` from broken
harisorgn a1f16b6
update zero-filling in `Tracker` pullback
harisorgn 8b4b0c7
fix `Zygote`
harisorgn 890127f
merge lines - applying feedback suggestions
harisorgn fa13e27
`unthunk` in `pd_from_upper` rrule
harisorgn a36f2b6
split structs into `VecCorrBijector` and `VecCholeskyBijector`
harisorgn 9690dd2
remove old `Zygote` adjoints
harisorgn 8a67713
update tests
harisorgn 37cfd90
fix `Union` in `@inferred` after splitting structs
harisorgn a3c7f57
remove `Tracker` tests as support is dropped
harisorgn df4d960
use `permutedims` instead of casting
harisorgn 17f784f
remove `Union` in `@inferred`
harisorgn 852573d
initial work on VecCorrBijector
torfjelde cea5f19
added some tests for CorrBijector, and fixed implementation for VecCo…
torfjelde 89612cc
improved tests and are now using integer sqrt and division
torfjelde bc8f755
moved things around a bit
torfjelde 9b3d7e9
added chainrule for ReverseDiff
torfjelde b1176d0
some fixes for AD
torfjelde f3a623f
added some TODOs
torfjelde d46e966
define bijectors for `LKJ` and `LKJCholesky`
harisorgn f210356
add `TransformedDistribution` constructor
harisorgn 71e1017
define `logpdf` for `LKJ` & `LKJCholesky`
harisorgn 37e649c
define `rand` for `LKJ` & `LKJCholesky`
harisorgn c09c5c8
add util to extract Cholesky factor
harisorgn 2a514c8
TYPO: capitalize matrix
harisorgn 6596c9e
add util to convert `Vector` index
harisorgn 6123d6d
add `VecTriBijector`s for `LKJCholesky`
harisorgn 791f764
TYPO: capitilize matrix
harisorgn f47cdac
add `LKJCholesky` link for `UpperTriangular`
harisorgn 959b836
add `LKJCholesky` link for `LowerTriangular`
harisorgn a8ccaa1
TYPO: capitalize matrix
harisorgn 82bf085
add `LKJCholesky` inverse link to `UpperTriangular`
harisorgn 597b6a1
rename `_logabsdetjac_chol_lkj`
harisorgn 54dd86d
dispatch `_logabsdetjac_inv_corr` for `::Vector`
harisorgn eaf60f7
add logabsdetjac for inverse link of `LKJCholesky`
harisorgn 861eef6
add tests for `VecTriBijector`s
harisorgn 78b9999
add `rrule` for LKJ(Cholesky) link function
harisorgn 5b4119a
use `transpose` in link for `::LowerTriangular'
harisorgn 011534c
add `Tracker` support for inverse link
harisorgn ff61ef0
better utility function call
harisorgn a2ec603
use function barrier properly for type stability
harisorgn 4c3a68b
account for difference in support dimensions
harisorgn 6349546
fix indexing in Jacobian of `VecCorrBijector`
harisorgn e65a78b
add `_logabsdetjac_dist` for `::LKJCholesky`
harisorgn b6b7fa6
replace function composition for proper barrier
harisorgn fd24602
add util convert `Transpose -> Matrix` for type stability
harisorgn 1cd62d1
add `LKJCholesky` Jacobian+type tests
harisorgn f437e68
fix `logabsdetjac` for inverse link
harisorgn 85397e8
use `Cholesky` constructor compatible with `v1.6`
harisorgn aa5685a
add empty line
harisorgn df264d6
fix `rrule` for link function
harisorgn 599cb66
add link `rrule` test
harisorgn 9cd42c0
add `rrule` for inverse link
harisorgn 9de4734
remove TODO
harisorgn befa1cc
add inverse link `rrule` test
harisorgn 6ba1c1f
Update src/bijectors/corr.jl
harisorgn 79ad5f8
add link `rrule` for `LowerTriangular`
harisorgn 19e8843
add `LowerTriangular` chainrule test
harisorgn 4216dbd
Update src/bijectors/corr.jl
harisorgn e70430f
remove unused util
harisorgn 2caba1c
use `similar` instead of `zeros`
harisorgn 561f6b1
update comments
harisorgn 69f5daa
remove old comment
harisorgn ca9807e
minimize zero-setting operations in inverse link
harisorgn 1883b36
minimize zero-setting operations in `rrule`
harisorgn f84b329
add parametric `Val` type to `VecCorrBijector`
harisorgn 2918463
update `VecCorrBijector` tests
harisorgn 2c4920d
use field value instead of `Val`-parametric type
harisorgn 1872bb6
update tests with new `VecCorrBijector`
harisorgn 1250592
`using VecCorrBijector` in test utils
harisorgn 66b4caa
add `VecCorrBijector.mode` check
harisorgn c5cb535
update `VecCorrBijector` docstring
harisorgn 8a06239
specialise `Zygote@adjoint` for `AbstractMatrix`
harisorgn 44b3b9f
`ReverseDiff` opt-in to `ChainRules`
harisorgn a5d601d
empty lines format
harisorgn 8783271
add AD test for inverse link
harisorgn a197076
include `VecCorrBijector` tests
harisorgn 7b9d1b2
remove broken flag for `Tracker`
harisorgn 5d1a7b8
add roundtrip AD tests for `VecCorrBijector`
harisorgn a0d5e52
remove wrong `ReverseDiff.@grad` for `pd_from_upper`
harisorgn bd0efff
add corrected `rrule` for `pd_from_upper`
harisorgn e3314a4
update AD tests
harisorgn c34ad47
remove `Tracker` from broken
harisorgn e154061
update zero-filling in `Tracker` pullback
harisorgn cffb616
fix `Zygote`
harisorgn c13fce6
merge lines - applying feedback suggestions
harisorgn dfeb71e
`unthunk` in `pd_from_upper` rrule
harisorgn 5210437
split structs into `VecCorrBijector` and `VecCholeskyBijector`
harisorgn 25a70b4
remove old `Zygote` adjoints
harisorgn b056fdd
update tests
harisorgn 33a8a29
fix `Union` in `@inferred` after splitting structs
harisorgn bfa448b
remove `Tracker` tests as support is dropped
harisorgn 96b90e6
use `permutedims` instead of casting
harisorgn 48edf87
remove `Union` in `@inferred`
harisorgn a25b36f
Merge branch 'torfjelde/vec-corr' of https://github.com/TuringLang/Bi…
harisorgn 159ddb6
wrap matrix in `Hermitian` before `cholesky`
harisorgn 1bfb2ee
Merge branch 'master' into torfjelde/vec-corr
torfjelde 9c3dec8
add hacky dispatch for `cholesky_factor` and `ReverseDiff`
harisorgn 980660a
Merge branch 'torfjelde/vec-corr' of https://github.com/TuringLang/Bi…
harisorgn 87a6fac
import `cholesky_factor` in ReverseDiff module for hacky dispatch
harisorgn 1d8999f
only use hacky `cholesky_factor` in versions before fix
harisorgn 424607d
change `LKJCholesky` shape to avoid stochastic test failures
harisorgn be5c1c5
Merge branch 'master' into torfjelde/vec-corr
yebai 6aeebbf
remove old TODOs
harisorgn 62ca234
add explicit zero-filling in link for `CorrBijector`
harisorgn f439682
Merge branch 'torfjelde/vec-corr' of https://github.com/TuringLang/Bi…
harisorgn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hate this:(
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we not have
w = cholesky(X).U
and work with aw::UpperTriangular
instead of a dense matrix? Tried it locally, no real gain in performance though.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, but this won't work with some of the AD backends (I know, it's super-annoying...). If you have a look at our compat-code for ReverseDiff (I believe), I think you'll see that we have to do some custom stuff to compute the pullback.
I don't think we'd expect it to because internally we're iterating over the relevant elements of the matrix anyways, i.e. we're not gaining anything by telling the rest of the computational path that we're actually working on a lower-triangular matrix because it already assumes the given matrix is lower-triangular.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Completely unrelated, but if it is not 100%-guaranteed that you always end up with an upper triangular matrix when calling
cholesky
(which I think you can't ifAbstractMatrix
is supported), it would be better to work with.UL
instead of.U
(as we already do in other places of Bijectors and other libraries).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's 100% guaranteed that it's available though, right? So it's a question of efficiency, not correctness.
The problem here is that
_link_chol_lkj
to work with vector).w
, hence we don't actually know if it was a uppper- or lower-triangular (in the case where we docholesky(...).UL
).All in all, it seems we need an additional diversion, e.g.
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would guess it's mainly for efficiency reasons. But it's difficult to say if there are other implications as well, e.g., regarding AD. An alternative to
.UL
would be something like what's used in PDMats: https://github.com/JuliaStats/PDMats.jl/blob/fff131e11e23403931a42f5bfb3384f0d2b114c9/src/chol.jl#L6-L11 That should also be quite efficient and you could continue working with upper-triangular matrices.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/JuliaLang/julia/blob/db7971f49912d1abba703345ca6eb43249607f32/stdlib/LinearAlgebra/src/cholesky.jl#L515-L527
Hmm fair, though IMO something like this seems like it would be a bug with the AD package, no?
Me and @harisorgn were just having a chat and we're thinking of replacing
upper_triangular(parent(cholesky(X).U))
withto make it less likely that we forget or mess up somewhere.
But we can make it
But are you sure there's not a good reason for why the default is
copy
? Of course it's more mem-intensive, but will stuff likeLowerTriangular(U')
lead to slower computation paths (since you're now working withadjoint(U)
rather than something that is actually lower-triangular)? E.g. indexingadjoin(U)
surely involves more computations than indexingcopy(adjoint(U))
.