Implementation of VecCorrBijector #246

Merged
merged 179 commits into from
Jun 12, 2023
Changes from 7 commits
Commits
79b92c9
initial work on VecCorrBijector
torfjelde Feb 6, 2023
aa2fe61
added some tests for CorrBijector, and fixed implementation for VecCo…
torfjelde Feb 6, 2023
8d23094
improved tests and are now using integer sqrt and division
torfjelde Feb 6, 2023
a35e36f
moved things around a bit
torfjelde Feb 12, 2023
8cadf69
added chainrule for ReverseDiff
torfjelde Feb 13, 2023
eaf5324
some fixes for AD
torfjelde Feb 13, 2023
36ffbdb
added some TODOs
torfjelde Feb 13, 2023
62ae1ac
Update src/bijectors/corr.jl
torfjelde Mar 24, 2023
3f25a8b
define bijectors for `LKJ` and `LKJCholesky`
harisorgn Apr 4, 2023
e1567c3
add `TransformedDistribution` constructor
harisorgn Apr 6, 2023
8d07e34
define `logpdf` for `LKJ` & `LKJCholesky`
harisorgn Apr 6, 2023
9a59a9f
define `rand` for `LKJ` & `LKJCholesky`
harisorgn Apr 6, 2023
f15ad85
add util to extract Cholesky factor
harisorgn Apr 6, 2023
53e78f3
TYPO: capitalize matrix
harisorgn Apr 6, 2023
ec7d20e
add util to convert `Vector` index
harisorgn Apr 6, 2023
2ed00f4
add `VecTriBijector`s for `LKJCholesky`
harisorgn Apr 6, 2023
07555fc
TYPO: capitilize matrix
harisorgn Apr 6, 2023
a75cabc
add `LKJCholesky` link for `UpperTriangular`
harisorgn Apr 6, 2023
844b07e
add `LKJCholesky` link for `LowerTriangular`
harisorgn Apr 6, 2023
792cfe9
TYPO: capitalize matrix
harisorgn Apr 6, 2023
8f0886b
add `LKJCholesky` inverse link to `UpperTriangular`
harisorgn Apr 6, 2023
35f1c03
rename `_logabsdetjac_chol_lkj`
harisorgn Apr 6, 2023
9d55829
dispatch `_logabsdetjac_inv_corr` for `::Vector`
harisorgn Apr 6, 2023
adf10ad
add logabsdetjac for inverse link of `LKJCholesky`
harisorgn Apr 6, 2023
03a55b2
add tests for `VecTriBijector`s
harisorgn Apr 6, 2023
1059569
add `rrule` for LKJ(Cholesky) link function
harisorgn Apr 6, 2023
222eb6e
Merge branch 'torfjelde/vec-corr' into ho/vec-lkj-cholesky
harisorgn Apr 6, 2023
7f5d0fc
Merge pull request #1 from harisorgn/ho/vec-lkj-cholesky
harisorgn Apr 6, 2023
ad080ea
use `transpose` in link for `::LowerTriangular'
harisorgn Apr 11, 2023
6e1a5b1
add `Tracker` support for inverse link
harisorgn Apr 12, 2023
5fd0a65
better utility function call
harisorgn Apr 12, 2023
b38acda
use function barrier properly for type stability
harisorgn Apr 12, 2023
424f8ca
account for difference in support dimensions
harisorgn Apr 13, 2023
b749d37
fix indexing in Jacobian of `VecCorrBijector`
harisorgn Apr 13, 2023
7b1f74d
add `_logabsdetjac_dist` for `::LKJCholesky`
harisorgn Apr 13, 2023
75c605b
replace function composition for proper barrier
harisorgn Apr 13, 2023
a7a6c05
add util convert `Transpose -> Matrix` for type stability
harisorgn Apr 13, 2023
09c35b6
add `LKJCholesky` Jacobian+type tests
harisorgn Apr 13, 2023
2ad5038
fix `logabsdetjac` for inverse link
harisorgn Apr 14, 2023
f5be4e2
use `Cholesky` constructor compatible with `v1.6`
harisorgn Apr 14, 2023
10d9345
add empty line
harisorgn Apr 17, 2023
bcf32a3
fix `rrule` for link function
harisorgn Apr 17, 2023
7f4551f
add link `rrule` test
harisorgn Apr 17, 2023
dc2c856
add `rrule` for inverse link
harisorgn Apr 17, 2023
87bc3ca
remove TODO
harisorgn Apr 17, 2023
bfb7c15
add inverse link `rrule` test
harisorgn Apr 17, 2023
20ab3b4
Update src/bijectors/corr.jl
harisorgn Apr 17, 2023
7bb37e0
add link `rrule` for `LowerTriangular`
harisorgn Apr 18, 2023
3e2c7a8
add `LowerTriangular` chainrule test
harisorgn Apr 18, 2023
adba9e8
Update src/bijectors/corr.jl
harisorgn Apr 18, 2023
ec18964
remove unused util
harisorgn Apr 18, 2023
37c38ab
use `similar` instead of `zeros`
harisorgn Apr 18, 2023
8fd13b0
update comments
harisorgn Apr 18, 2023
56cc43f
remove old comment
harisorgn Apr 18, 2023
8ee086a
minimize zero-setting operations in inverse link
harisorgn Apr 18, 2023
837b49c
minimize zero-setting operations in `rrule`
harisorgn Apr 18, 2023
0c3aa39
add parametric `Val` type to `VecCorrBijector`
harisorgn Apr 18, 2023
c1be272
update `VecCorrBijector` tests
harisorgn Apr 18, 2023
29fced6
use field value instead of `Val`-parametric type
harisorgn Apr 18, 2023
74d6edb
update tests with new `VecCorrBijector`
harisorgn Apr 18, 2023
4c27987
`using VecCorrBijector` in test utils
harisorgn Apr 18, 2023
9108c40
add `VecCorrBijector.mode` check
harisorgn Apr 18, 2023
24847cc
update `VecCorrBijector` docstring
harisorgn Apr 18, 2023
bd4de96
specialise `Zygote@adjoint` for `AbstractMatrix`
harisorgn Apr 18, 2023
65bfc42
`ReverseDiff` opt-in to `ChainRules`
harisorgn Apr 18, 2023
eca3411
empty lines format
harisorgn Apr 18, 2023
f02fd9b
add AD test for inverse link
harisorgn Apr 18, 2023
c90f7ac
include `VecCorrBijector` tests
harisorgn Apr 18, 2023
974efb5
remove broken flag for `Tracker`
harisorgn Apr 18, 2023
71fdae6
add roundtrip AD tests for `VecCorrBijector`
harisorgn Apr 18, 2023
6524fe4
remove wrong `ReverseDiff.@grad` for `pd_from_upper`
harisorgn Apr 19, 2023
5e4abae
add corrected `rrule` for `pd_from_upper`
harisorgn Apr 19, 2023
c547542
update AD tests
harisorgn Apr 19, 2023
0d599e8
remove `Tracker` from broken
harisorgn Apr 19, 2023
a1f16b6
update zero-filling in `Tracker` pullback
harisorgn Apr 25, 2023
8b4b0c7
fix `Zygote`
harisorgn Apr 25, 2023
890127f
merge lines - applying feedback suggestions
harisorgn May 4, 2023
fa13e27
`unthunk` in `pd_from_upper` rrule
harisorgn May 24, 2023
a36f2b6
split structs into `VecCorrBijector` and `VecCholeskyBijector`
harisorgn May 24, 2023
9690dd2
remove old `Zygote` adjoints
harisorgn May 24, 2023
8a67713
update tests
harisorgn May 24, 2023
37cfd90
fix `Union` in `@inferred` after splitting structs
harisorgn May 24, 2023
a3c7f57
remove `Tracker` tests as support is dropped
harisorgn May 24, 2023
df4d960
use `permutedims` instead of casting
harisorgn Jun 6, 2023
17f784f
remove `Union` in `@inferred`
harisorgn Jun 6, 2023
852573d
initial work on VecCorrBijector
torfjelde Feb 6, 2023
cea5f19
added some tests for CorrBijector, and fixed implementation for VecCo…
torfjelde Feb 6, 2023
89612cc
improved tests and are now using integer sqrt and division
torfjelde Feb 6, 2023
bc8f755
moved things around a bit
torfjelde Feb 12, 2023
9b3d7e9
added chainrule for ReverseDiff
torfjelde Feb 13, 2023
b1176d0
some fixes for AD
torfjelde Feb 13, 2023
f3a623f
added some TODOs
torfjelde Feb 13, 2023
d46e966
define bijectors for `LKJ` and `LKJCholesky`
harisorgn Apr 4, 2023
f210356
add `TransformedDistribution` constructor
harisorgn Apr 6, 2023
71e1017
define `logpdf` for `LKJ` & `LKJCholesky`
harisorgn Apr 6, 2023
37e649c
define `rand` for `LKJ` & `LKJCholesky`
harisorgn Apr 6, 2023
c09c5c8
add util to extract Cholesky factor
harisorgn Apr 6, 2023
2a514c8
TYPO: capitalize matrix
harisorgn Apr 6, 2023
6596c9e
add util to convert `Vector` index
harisorgn Apr 6, 2023
6123d6d
add `VecTriBijector`s for `LKJCholesky`
harisorgn Apr 6, 2023
791f764
TYPO: capitilize matrix
harisorgn Apr 6, 2023
f47cdac
add `LKJCholesky` link for `UpperTriangular`
harisorgn Apr 6, 2023
959b836
add `LKJCholesky` link for `LowerTriangular`
harisorgn Apr 6, 2023
a8ccaa1
TYPO: capitalize matrix
harisorgn Apr 6, 2023
82bf085
add `LKJCholesky` inverse link to `UpperTriangular`
harisorgn Apr 6, 2023
597b6a1
rename `_logabsdetjac_chol_lkj`
harisorgn Apr 6, 2023
54dd86d
dispatch `_logabsdetjac_inv_corr` for `::Vector`
harisorgn Apr 6, 2023
eaf60f7
add logabsdetjac for inverse link of `LKJCholesky`
harisorgn Apr 6, 2023
861eef6
add tests for `VecTriBijector`s
harisorgn Apr 6, 2023
78b9999
add `rrule` for LKJ(Cholesky) link function
harisorgn Apr 6, 2023
5b4119a
use `transpose` in link for `::LowerTriangular'
harisorgn Apr 11, 2023
011534c
add `Tracker` support for inverse link
harisorgn Apr 12, 2023
ff61ef0
better utility function call
harisorgn Apr 12, 2023
a2ec603
use function barrier properly for type stability
harisorgn Apr 12, 2023
4c3a68b
account for difference in support dimensions
harisorgn Apr 13, 2023
6349546
fix indexing in Jacobian of `VecCorrBijector`
harisorgn Apr 13, 2023
e65a78b
add `_logabsdetjac_dist` for `::LKJCholesky`
harisorgn Apr 13, 2023
b6b7fa6
replace function composition for proper barrier
harisorgn Apr 13, 2023
fd24602
add util convert `Transpose -> Matrix` for type stability
harisorgn Apr 13, 2023
1cd62d1
add `LKJCholesky` Jacobian+type tests
harisorgn Apr 13, 2023
f437e68
fix `logabsdetjac` for inverse link
harisorgn Apr 14, 2023
85397e8
use `Cholesky` constructor compatible with `v1.6`
harisorgn Apr 14, 2023
aa5685a
add empty line
harisorgn Apr 17, 2023
df264d6
fix `rrule` for link function
harisorgn Apr 17, 2023
599cb66
add link `rrule` test
harisorgn Apr 17, 2023
9cd42c0
add `rrule` for inverse link
harisorgn Apr 17, 2023
9de4734
remove TODO
harisorgn Apr 17, 2023
befa1cc
add inverse link `rrule` test
harisorgn Apr 17, 2023
6ba1c1f
Update src/bijectors/corr.jl
harisorgn Apr 17, 2023
79ad5f8
add link `rrule` for `LowerTriangular`
harisorgn Apr 18, 2023
19e8843
add `LowerTriangular` chainrule test
harisorgn Apr 18, 2023
4216dbd
Update src/bijectors/corr.jl
harisorgn Apr 18, 2023
e70430f
remove unused util
harisorgn Apr 18, 2023
2caba1c
use `similar` instead of `zeros`
harisorgn Apr 18, 2023
561f6b1
update comments
harisorgn Apr 18, 2023
69f5daa
remove old comment
harisorgn Apr 18, 2023
ca9807e
minimize zero-setting operations in inverse link
harisorgn Apr 18, 2023
1883b36
minimize zero-setting operations in `rrule`
harisorgn Apr 18, 2023
f84b329
add parametric `Val` type to `VecCorrBijector`
harisorgn Apr 18, 2023
2918463
update `VecCorrBijector` tests
harisorgn Apr 18, 2023
2c4920d
use field value instead of `Val`-parametric type
harisorgn Apr 18, 2023
1872bb6
update tests with new `VecCorrBijector`
harisorgn Apr 18, 2023
1250592
`using VecCorrBijector` in test utils
harisorgn Apr 18, 2023
66b4caa
add `VecCorrBijector.mode` check
harisorgn Apr 18, 2023
c5cb535
update `VecCorrBijector` docstring
harisorgn Apr 18, 2023
8a06239
specialise `Zygote@adjoint` for `AbstractMatrix`
harisorgn Apr 18, 2023
44b3b9f
`ReverseDiff` opt-in to `ChainRules`
harisorgn Apr 18, 2023
a5d601d
empty lines format
harisorgn Apr 18, 2023
8783271
add AD test for inverse link
harisorgn Apr 18, 2023
a197076
include `VecCorrBijector` tests
harisorgn Apr 18, 2023
7b9d1b2
remove broken flag for `Tracker`
harisorgn Apr 18, 2023
5d1a7b8
add roundtrip AD tests for `VecCorrBijector`
harisorgn Apr 18, 2023
a0d5e52
remove wrong `ReverseDiff.@grad` for `pd_from_upper`
harisorgn Apr 19, 2023
bd0efff
add corrected `rrule` for `pd_from_upper`
harisorgn Apr 19, 2023
e3314a4
update AD tests
harisorgn Apr 19, 2023
c34ad47
remove `Tracker` from broken
harisorgn Apr 19, 2023
e154061
update zero-filling in `Tracker` pullback
harisorgn Apr 25, 2023
cffb616
fix `Zygote`
harisorgn Apr 25, 2023
c13fce6
merge lines - applying feedback suggestions
harisorgn May 4, 2023
dfeb71e
`unthunk` in `pd_from_upper` rrule
harisorgn May 24, 2023
5210437
split structs into `VecCorrBijector` and `VecCholeskyBijector`
harisorgn May 24, 2023
25a70b4
remove old `Zygote` adjoints
harisorgn May 24, 2023
b056fdd
update tests
harisorgn May 24, 2023
33a8a29
fix `Union` in `@inferred` after splitting structs
harisorgn May 24, 2023
bfa448b
remove `Tracker` tests as support is dropped
harisorgn May 24, 2023
96b90e6
use `permutedims` instead of casting
harisorgn Jun 6, 2023
48edf87
remove `Union` in `@inferred`
harisorgn Jun 6, 2023
a25b36f
Merge branch 'torfjelde/vec-corr' of https://github.com/TuringLang/Bi…
harisorgn Jun 6, 2023
159ddb6
wrap matrix in `Hermitian` before `cholesky`
harisorgn Jun 6, 2023
1bfb2ee
Merge branch 'master' into torfjelde/vec-corr
torfjelde Jun 6, 2023
9c3dec8
add hacky dispatch for `cholesky_factor` and `ReverseDiff`
harisorgn Jun 8, 2023
980660a
Merge branch 'torfjelde/vec-corr' of https://github.com/TuringLang/Bi…
harisorgn Jun 8, 2023
87a6fac
import `cholesky_factor` in ReverseDiff module for hacky dispatch
harisorgn Jun 8, 2023
1d8999f
only use hacky `cholesky_factor` in versions before fix
harisorgn Jun 8, 2023
424607d
change `LKJCholesky` shape to avoid stochastic test failures
harisorgn Jun 8, 2023
be5c1c5
Merge branch 'master' into torfjelde/vec-corr
yebai Jun 10, 2023
6aeebbf
remove old TODOs
harisorgn Jun 12, 2023
62ca234
add explicit zero-filling in link for `CorrBijector`
harisorgn Jun 12, 2023
f439682
Merge branch 'torfjelde/vec-corr' of https://github.com/TuringLang/Bi…
harisorgn Jun 12, 2023
218 changes: 183 additions & 35 deletions src/bijectors/corr.jl
@@ -66,7 +66,7 @@ struct CorrBijector <: Bijector end
with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x)

function transform(b::CorrBijector, x::AbstractMatrix{<:Real})
w = cholesky(x).U # keep LowerTriangular until here can avoid some computation
w = upper_triangular(parent(cholesky(x).U)) # keep LowerTriangular until here can avoid some computation
r = _link_chol_lkj(w)
return r + zero(x)
# This dense format itself is required by a test, though I can't get the point.
@@ -75,51 +75,158 @@ end

function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
w = _inv_link_chol_lkj(y)
return w' * w
return pd_from_upper(w)
end

function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
K = LinearAlgebra.checksquare(y)

result = float(zero(eltype(y)))
for j in 2:K, i in 1:(j - 1)
@inbounds abs_y_i_j = abs(y[i, j])
result += (K - i + 1) * (
IrrationalConstants.logtwo - (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j))
)
end

return result
end
logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_chol_lkj(Y)
function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
#=
It may be more efficient to work with the unconstrained value directly and avoid calling `b`.
If possible, call
`logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})`
directly.
=#
return -logabsdetjac(inverse(b), (b(X)))
return -logabsdetjac(inverse(b), (b(X)))
end

function _inv_link_chol_lkj(y)
K = LinearAlgebra.checksquare(y)
"""
triu_mask(X::AbstractMatrix, k::Int)

w = similar(y)
Return a mask for elements of `X` above the `k`th diagonal.
"""
function triu_mask(X::AbstractMatrix, k::Int)
# Ensure that we're working with a square matrix.
LinearAlgebra.checksquare(X)

@inbounds for j in 1:K
w[1, j] = 1
for i in 2:j
z = tanh(y[i-1, j])
tmp = w[i-1, j]
w[i-1, j] = z * tmp
w[i, j] = tmp * sqrt(1 - z^2)
end
for i in (j+1):K
w[i, j] = 0
# Using `similar` allows us to respect device of array, etc., e.g. `CuArray`.
m = similar(X, Bool)
return triu(.~m .| m, k)
end

triu_to_vec(X::AbstractMatrix{<:Real}, k::Int) = X[triu_mask(X, k)]

function update_triu_from_vec!(
vals::AbstractVector{<:Real},
k::Int,
X::AbstractMatrix{<:Real}
)
# Ensure that we're working with one-based indexing.
# `triu` requires this too.
LinearAlgebra.require_one_based_indexing(X)

# Set the values.
idx = 1
m, n = size(X)
for j = 1:n
for i = 1:min(j - k, m)
X[i, j] = vals[idx]
idx += 1
end
end

return w

return X
end

function update_triu_from_vec(vals::AbstractVector{<:Real}, k::Int, dim::Int)
X = similar(vals, dim, dim)
# TODO: Do we need this?
X .= 0
return update_triu_from_vec!(vals, k, X)
end

function ChainRulesCore.rrule(::typeof(update_triu_from_vec), x::AbstractVector{<:Real}, k::Int, dim::Int)
function update_triu_from_vec_pullback(ΔX)
return (
ChainRulesCore.NoTangent(),
triu_to_vec(ChainRulesCore.unthunk(ΔX), k),
ChainRulesCore.NoTangent(),
ChainRulesCore.NoTangent()
)
end
return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback
end

# n * (n - 1) / 2 = d
# ⟺ n^2 - n - 2d = 0
# ⟹ n = (1 + sqrt(1 + 8d)) / 2
_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2

"""
triu1_to_vec(X::AbstractMatrix{<:Real})

Extracts elements from upper triangle of `X` with offset `1` and returns them as a vector.
"""
triu1_to_vec(X::AbstractMatrix) = triu_to_vec(X, 1)

inverse(::typeof(triu1_to_vec)) = vec_to_triu1

"""
vec_to_triu1(x::AbstractVector{<:Real})

Constructs a matrix from a vector `x` by filling the upper triangle with offset `1`.
"""
function vec_to_triu1(x::AbstractVector)
n = _triu1_dim_from_length(length(x))
X = update_triu_from_vec(x, 1, n)
return upper_triangular(X)
end

inverse(::typeof(vec_to_triu1)) = triu1_to_vec

"""
VecCorrBijector <: Bijector

Similar to `CorrBijector`, but maps a correlation matrix to a vector,
and its inverse maps a vector back to a correlation matrix.

See also: [`CorrBijector`](@ref)

# Example

```jldoctest
julia> using LinearAlgebra

julia> using StableRNGs; rng = StableRNG(42);

julia> b = Bijectors.VecCorrBijector();

julia> X = rand(rng, LKJ(3, 1)) # Sample a correlation matrix.
3×3 Matrix{Float64}:
1.0 -0.705273 -0.348638
-0.705273 1.0 0.0534538
-0.348638 0.0534538 1.0

julia> y = b(X) # Transform to unconstrained vector representation.
3-element Vector{Float64}:
-0.8777149781928181
-0.3638927608636788
-0.29813769428942216

julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse.
true
```
"""
struct VecCorrBijector <: Bijector end
with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x)

function transform(::VecCorrBijector, X::AbstractMatrix{<:Real})
w = upper_triangular(parent(cholesky(X).U))
Member Author

I hate this:(

Member

Could we not have w = cholesky(X).U and work with a w::UpperTriangular instead of a dense matrix? Tried it locally, no real gain in performance though.

Member Author

> Could we not have w = cholesky(X).U and work with a w::UpperTriangular instead of a dense matrix?

Yeah, but this won't work with some of the AD backends (I know, it's super-annoying...). If you have a look at our compat-code for ReverseDiff (I believe), I think you'll see that we have to do some custom stuff to compute the pullback.

> Tried it locally, no real gain in performance though.

I don't think we'd expect it to because internally we're iterating over the relevant elements of the matrix anyways, i.e. we're not gaining anything by telling the rest of the computational path that we're actually working on a lower-triangular matrix because it already assumes the given matrix is lower-triangular.

Member

Completely unrelated, but if it is not 100%-guaranteed that you always end up with an upper triangular matrix when calling cholesky (which I think you can't if AbstractMatrix is supported), it would be better to work with .UL instead of .U (as we already do in other places of Bijectors and other libraries).

Member Author

It's 100% guaranteed that it's available though, right? So it's a question of efficiency, not correctness.

The problem here is that

  1. We need to work with a lower-triangular matrix (unless we re-implement _link_chol_lkj to work with vector).
  2. We drop the type-information in the construction of w, hence we don't actually know if it was an upper- or lower-triangular matrix (in the case where we do cholesky(...).UL).

All in all, it seems we need an additional diversion, e.g.

link_chol_lkj(x::LowerTriangular) = link_chol_lkj(lower_triangular(parent(x)))
link_chol_lkj(x::UpperTriangular) = link_chol_lkj(transpose(upper_triangular(parent(x))))
link_chol_lkj(x::AbstractMatrix) = _link_chol_lkj(x)  # assume it's lower-triangular

?
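
A minimal sketch of that kind of diversion, using only `LinearAlgebra`. The helper name `dense_upper_factor` is hypothetical and not part of Bijectors.jl; it only illustrates how dispatch on the wrapper type can normalize either triangle into one dense layout, under the assumption that a plain `AbstractMatrix` already holds upper-triangular data.

```julia
using LinearAlgebra

# Hypothetical helper (not from Bijectors.jl): always return a dense matrix
# that holds the *upper* Cholesky factor, whatever wrapper type comes in.
dense_upper_factor(x::UpperTriangular) = Matrix(x)
dense_upper_factor(x::LowerTriangular) = Matrix(transpose(x))
# Assumption: a plain matrix is already upper-triangular data.
dense_upper_factor(x::AbstractMatrix) = Matrix(x)

X = [1.0 0.3; 0.3 1.0]
C = cholesky(X)

U_from_upper = dense_upper_factor(C.U)
U_from_lower = dense_upper_factor(C.L)

# Both routes should describe the same factor, and U'U should recover X.
@assert U_from_upper ≈ U_from_lower
@assert U_from_upper' * U_from_upper ≈ X
```

With a normalizing step like this, the type information is used once at the boundary and the inner loop can keep assuming a single layout.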

Member

I would guess it's mainly for efficiency reasons. But it's difficult to say if there are other implications as well, e.g., regarding AD. An alternative to .UL would be something like what's used in PDMats: https://github.com/JuliaStats/PDMats.jl/blob/fff131e11e23403931a42f5bfb3384f0d2b114c9/src/chol.jl#L6-L11 That should also be quite efficient and you could continue working with upper-triangular matrices.

Member Author

> I would guess it's mainly for efficiency reasons.

https://github.com/JuliaLang/julia/blob/db7971f49912d1abba703345ca6eb43249607f32/stdlib/LinearAlgebra/src/cholesky.jl#L515-L527

> But it's difficult to say if there are other implications as well, e.g., regarding AD.

Hmm fair, though IMO something like this seems like it would be a bug with the AD package, no?

> An alternative to .UL would be something like what's used in PDMats: https://github.com/JuliaStats/PDMats.jl/blob/fff131e11e23403931a42f5bfb3384f0d2b114c9/src/chol.jl#L6-L11 That should also be quite efficient and you could continue working with upper-triangular matrices.

Me and @harisorgn were just having a chat and we're thinking of replacing upper_triangular(parent(cholesky(X).U)) with

cholesky_upper(X) = upper_triangular(parent(cholesky(X).U))
cholesky_lower(X) = lower_triangular(parent(cholesky(X).L))

to make it less likely that we forget or mess up somewhere.

But we can make it

cholesky_upper(X) = upper_triangular(parent(PDMats.chol_upper(cholesky(X))))
cholesky_lower(X) = lower_triangular(parent(PDMats.chol_lower(cholesky(X))))

But are you sure there's not a good reason for why the default is copy? Of course it's more mem-intensive, but will stuff like LowerTriangular(U') lead to slower computation paths (since you're now working with adjoint(U) rather than something that is actually lower-triangular)? E.g. indexing adjoint(U) surely involves more computations than indexing copy(adjoint(U)).
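
The trade-off discussed in this thread can be inspected directly with `LinearAlgebra`. This is a small sketch, not a benchmark: it only shows which wrapper types come out of `.U`, `.L` and `.UL`, and how a lazy adjoint differs from a materialized copy. The matrix `X` is an arbitrary example, and which triangle is stored (typically `'U'` for a dense `Matrix`) is an assumption about the input.

```julia
using LinearAlgebra

X = [1.0 0.5 0.2; 0.5 1.0 -0.3; 0.2 -0.3 1.0]
C = cholesky(X)

# Which triangle the factorization stores (assumed 'U' for a dense input).
@show C.uplo
# `.UL` hands back whichever triangle is stored.
@show typeof(C.UL)

# `.U` and `.L` describe the same factorization.
@assert C.L ≈ C.U'
@assert C.L * C.U ≈ X

# A lazy adjoint keeps an `Adjoint` as the parent; `copy` materializes a Matrix.
U = Matrix(C.U)
lazy = LowerTriangular(U')
dense = LowerTriangular(copy(U'))
@show typeof(parent(lazy))
@show typeof(parent(dense))

# Same values either way; only the storage behind the wrapper differs.
@assert lazy == dense
```

Whether the lazy form is actually slower in the link function would need measuring on the real code path.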

r = _link_chol_lkj(w)

# Extract only the upper triangle of `r`.
return triu1_to_vec(r)
end

function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real})
Y = vec_to_triu1(y)
w = _inv_link_chol_lkj(Y)
return pd_from_upper(w)
end

function logabsdetjac(b::VecCorrBijector, x)
return -logabsdetjac(inverse(b), b(x))
end
function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real})
return _logabsdetjac_chol_lkj(vec_to_triu1(y))
end
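
The vector representation relies on two facts from the helpers above: a strictly upper triangle of an `n × n` matrix has `d = n(n - 1)/2` entries, and the entries are packed in column-major order. A self-contained sketch of both follows; the `*_sketch` names are hypothetical stand-ins, not the functions from this PR.

```julia
using LinearAlgebra

# Invert d = n(n - 1)/2 for n using integer arithmetic only.
dim_from_length(d::Integer) = (1 + isqrt(1 + 8d)) ÷ 2

# Pack the strictly upper triangle column by column.
function triu1_to_vec_sketch(X::AbstractMatrix)
    n = LinearAlgebra.checksquare(X)
    return [X[i, j] for j in 2:n for i in 1:(j - 1)]
end

# Unpack a vector into the strictly upper triangle of a zero matrix.
function vec_to_triu1_sketch(x::AbstractVector)
    n = dim_from_length(length(x))
    X = zeros(eltype(x), n, n)
    idx = 1
    for j in 2:n, i in 1:(j - 1)
        X[i, j] = x[idx]
        idx += 1
    end
    return X
end

X = [0.0 1.0 2.0; 0.0 0.0 3.0; 0.0 0.0 0.0]
v = triu1_to_vec_sketch(X)
@assert v == [1.0, 2.0, 3.0]
@assert vec_to_triu1_sketch(v) == X
```

Using `isqrt` and `÷` keeps the dimension an `Int`, which avoids a floating-point round trip for large `d`.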

"""
@@ -138,7 +245,7 @@ end
But this implementation will not work when w[i-1, j] = 0.
Though it is a zero measure set, unit matrix initialization will not work.

For equivelence, following explanations is given by @torfjelde:
For equivalence, following explanations is given by @torfjelde:

For `(i, j)` in the loop below, we define

@@ -155,6 +262,7 @@ and so
which is the above implementation.
"""
function _link_chol_lkj(w)
# TODO: Implement adjoint to support reverse-mode AD backends properly.
K = LinearAlgebra.checksquare(w)

z = similar(w) # z is also UpperTriangular.
@@ -163,16 +271,56 @@ function _link_chol_lkj(w)
# This block can't be integrated with loop below, because w[1,1] != 0.
@inbounds z[1, 1] = 0

@inbounds for j=2:K
@inbounds for j = 2:K
z[1, j] = atanh(w[1, j])
tmp = sqrt(1 - w[1, j]^2)
for i in 2:(j - 1)
for i in 2:(j-1)
p = w[i, j] / tmp
tmp *= sqrt(1 - p^2)
z[i, j] = atanh(p)
end
z[j, j] = 0
end

return z
end

"""
_inv_link_chol_lkj(y)

Inverse link function for cholesky factor.
"""
function _inv_link_chol_lkj(y)
# TODO: Implement adjoint to support reverse-mode AD backends properly.
K = LinearAlgebra.checksquare(y)

w = similar(y)

@inbounds for j in 1:K
w[1, j] = 1
for i in 2:j
z = tanh(y[i-1, j])
tmp = w[i-1, j]
w[i-1, j] = z * tmp
w[i, j] = tmp * sqrt(1 - z^2)
end
for i in (j+1):K
w[i, j] = 0
end
end

return w
end
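
To see what the two link functions do together, here is a standalone re-statement of the same loops with a round-trip check. It is a sketch written for illustration: `link_sketch` and `inv_link_sketch` are hypothetical names, they allocate with `zeros` rather than `similar`, and the input `y` is an arbitrary example.

```julia
using LinearAlgebra

# Unconstrained strictly-upper matrix -> upper factor with unit-norm columns.
function inv_link_sketch(y::AbstractMatrix)
    K = LinearAlgebra.checksquare(y)
    w = zeros(float(eltype(y)), K, K)
    for j in 1:K
        w[1, j] = 1
        for i in 2:j
            z = tanh(y[i - 1, j])
            tmp = w[i - 1, j]
            w[i - 1, j] = z * tmp          # split the remaining norm ...
            w[i, j] = tmp * sqrt(1 - z^2)  # ... and pass the rest down the column
        end
    end
    return w
end

# Upper Cholesky factor of a correlation matrix -> unconstrained values.
function link_sketch(w::AbstractMatrix)
    K = LinearAlgebra.checksquare(w)
    z = zeros(float(eltype(w)), K, K)
    for j in 2:K
        z[1, j] = atanh(w[1, j])
        tmp = sqrt(1 - w[1, j]^2)
        for i in 2:(j - 1)
            p = w[i, j] / tmp
            tmp *= sqrt(1 - p^2)
            z[i, j] = atanh(p)
        end
    end
    return z
end

y = [0.0 0.4 -1.2; 0.0 0.0 0.8; 0.0 0.0 0.0]
W = inv_link_sketch(y)
R = W' * W

@assert diag(R) ≈ ones(3)     # unit diagonal, as a correlation matrix needs
@assert link_sketch(W) ≈ y    # the link undoes the inverse link
```

Because every column of `W` has unit norm by construction, any real-valued `y` yields a matrix `R` with ones on the diagonal.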

function _logabsdetjac_chol_lkj(Y::AbstractMatrix)
K = LinearAlgebra.checksquare(Y)

result = float(zero(eltype(Y)))
for j in 2:K, i in 1:(j-1)
@inbounds abs_y_i_j = abs(Y[i, j])
result += (K - i + 1) * (
IrrationalConstants.logtwo - (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j))
)
end
return result
end
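
Each term in the loop above is `(K - i + 1)` times `log(sech(y))`, written in a form that stays finite for large `|y|`. The sketch below checks that identity numerically. It substitutes Base's `log1p(exp(x))` for `LogExpFunctions.log1pexp` so that it runs without extra packages; that substitution and the function names are assumptions made for illustration.

```julia
# Stable form used in the loop (with log1p(exp(x)) standing in for log1pexp).
stable_logsech(y) = log(2) - (abs(y) + log1p(exp(-2 * abs(y))))

# Direct form, which loses all information once cosh overflows.
naive_logsech(y) = log(sech(y))

for y in (-3.0, -0.5, 0.0, 0.7, 10.0)
    @assert stable_logsech(y) ≈ naive_logsech(y)
end

# For very large |y| the direct form collapses to -Inf,
# while the stable form still returns a finite value.
@assert naive_logsech(800.0) == -Inf
@assert stable_logsech(800.0) ≈ log(2) - 800.0
```

This matters for samplers that can wander far into the tails of the unconstrained space, where a `-Inf` log-Jacobian would otherwise appear.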
6 changes: 2 additions & 4 deletions src/bijectors/pd.jl
@@ -9,16 +9,14 @@ function replace_diag(f, X)
end
transform(b::PDBijector, X::AbstractMatrix{<:Real}) = pd_link(X)
function pd_link(X)
Y = lower(parent(cholesky(X; check = true).L))
Y = lower_triangular(parent(cholesky(X; check = true).L))
return replace_diag(log, Y)
end
lower(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A))

function transform(ib::Inverse{PDBijector}, Y::AbstractMatrix{<:Real})
X = replace_diag(exp, Y)
return getpd(X)
return pd_from_lower(X)
end
getpd(X) = LowerTriangular(X) * LowerTriangular(X)'

function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real})
T = eltype(X)
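
For orientation, the positive-definite link touched in this file can be sketched in a few lines: take the lower Cholesky factor, apply `log` to its diagonal, and invert by applying `exp` and forming `L * L'`. The names `pd_link_sketch` and `pd_invlink_sketch` are hypothetical and the code is a simplified stand-in, not the package implementation.

```julia
using LinearAlgebra

# Lower Cholesky factor with a log-transformed diagonal.
function pd_link_sketch(X::AbstractMatrix)
    L = Matrix(cholesky(Symmetric(X)).L)
    for i in axes(L, 1)
        L[i, i] = log(L[i, i])
    end
    return L
end

# Undo the log on the diagonal, then rebuild the matrix as L * L'.
function pd_invlink_sketch(Y::AbstractMatrix)
    L = Matrix(LowerTriangular(Y))   # fresh dense copy of the lower triangle
    for i in axes(L, 1)
        L[i, i] = exp(L[i, i])
    end
    return L * L'
end

X = [2.0 0.5; 0.5 1.0]
Y = pd_link_sketch(X)
@assert pd_invlink_sketch(Y) ≈ X
```

The diagonal of a Cholesky factor must be positive, so the `log` is what makes every entry of `Y` unconstrained.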
33 changes: 26 additions & 7 deletions src/compat/reversediff.jl
@@ -1,13 +1,14 @@
module ReverseDiffCompat

using ..ReverseDiff: ReverseDiff, @grad, value, track, TrackedReal, TrackedVector,
TrackedMatrix
TrackedMatrix, @grad_from_chainrules
using Requires, LinearAlgebra

using ..Bijectors: Elementwise, SimplexBijector, maphcat, simplex_link_jacobian,
simplex_invlink_jacobian, simplex_logabsdetjac_gradient, Inverse
import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector,
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower,
_simplex_inv_bijector, replace_diag, jacobian, pd_from_lower, pd_from_upper,
lower_triangular, upper_triangular,
_inv_link_chol_lkj, _link_chol_lkj, _transform_ordered, _transform_inverse_ordered,
find_alpha

@@ -136,18 +137,34 @@ logabsdetjac(b::SimplexBijector, x::Union{TrackedVector, TrackedMatrix}) = track
end
end

getpd(X::TrackedMatrix) = track(getpd, X)
@grad function getpd(X::AbstractMatrix)
pd_from_lower(X::TrackedMatrix) = track(pd_from_lower, X)
@grad function pd_from_lower(X::AbstractMatrix)
Xd = value(X)
return LowerTriangular(Xd) * LowerTriangular(Xd)', Δ -> begin
Xl = LowerTriangular(Xd)
return (LowerTriangular(Δ' * Xl + Δ * Xl),)
end
end
lower(A::TrackedMatrix) = track(lower, A)
@grad function lower(A::AbstractMatrix)

pd_from_upper(X::TrackedMatrix) = track(pd_from_upper, X)
@grad function pd_from_upper(X::AbstractMatrix)
Xd = value(X)
return UpperTriangular(Xd)' * UpperTriangular(Xd), Δ -> begin
Xu = UpperTriangular(Xd)
return (UpperTriangular(Xu * Δ + Xu * Δ'),)
end
end

lower_triangular(A::TrackedMatrix) = track(lower_triangular, A)
@grad function lower_triangular(A::AbstractMatrix)
Ad = value(A)
return lower(Ad), Δ -> (lower(Δ),)
return lower_triangular(Ad), Δ -> (lower_triangular(Δ),)
end

upper_triangular(A::TrackedMatrix) = track(upper_triangular, A)
@grad function upper_triangular(A::AbstractMatrix)
Ad = value(A)
return upper_triangular(Ad), Δ -> (upper_triangular(Δ),)
end
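
The pullback for `pd_from_upper` can be sanity-checked without any AD package. For `Y = U' * U`, the cotangent with respect to `U` works out to the upper triangle of `U * (Δ + Δ')`. The sketch below compares that expression against central finite differences; all names ending in `_sketch` and `finite_diff_grad` are hypothetical, and `X` and `Δ` are arbitrary test values.

```julia
using LinearAlgebra

pd_from_upper_sketch(X) = UpperTriangular(X)' * UpperTriangular(X)

# Candidate pullback: upper triangle of U * (Δ + Δ').
pullback_sketch(X, Δ) = Matrix(UpperTriangular(UpperTriangular(X) * (Δ + Δ')))

# Scalar loss whose gradient with respect to Y is exactly Δ.
loss(X, Δ) = sum(Δ .* pd_from_upper_sketch(X))

# Central differences over the upper-triangular entries only,
# since entries below the diagonal do not affect the output.
function finite_diff_grad(X, Δ; h=1e-6)
    G = zeros(size(X))
    for j in axes(X, 2), i in 1:j
        E = zeros(size(X))
        E[i, j] = h
        G[i, j] = (loss(X + E, Δ) - loss(X - E, Δ)) / (2h)
    end
    return G
end

X = [1.0 0.4 -0.2; 0.0 0.9 0.3; 0.0 0.0 0.8]
Δ = [0.5 -1.0 2.0; 0.3 0.7 -0.4; 1.5 0.2 -0.6]

@assert isapprox(pullback_sketch(X, Δ), finite_diff_grad(X, Δ); atol=1e-6)
```

Note that the order of the product matters here: `(Δ + Δ') * U` is the right expression for the lower-triangular case `L * L'`, not for `U' * U`.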

function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal}
@@ -181,6 +198,8 @@ end
return y, (wrap_chainrules_output ∘ Base.tail ∘ dy)
end

@grad_from_chainrules update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int)

# NOTE: Probably doesn't work in complete generality.
wrap_chainrules_output(x) = x
wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing
10 changes: 5 additions & 5 deletions src/compat/tracker.jl
@@ -281,19 +281,19 @@ end
(b::Elementwise{typeof(log)})(x::TrackedVector) = log.(x)::vectorof(float(eltype(x)))
(b::Elementwise{typeof(log)})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x)))

Bijectors.getpd(X::TrackedMatrix) = track(Bijectors.getpd, X)
@grad function Bijectors.getpd(X::AbstractMatrix)
Bijectors.pd_from_lower(X::TrackedMatrix) = track(Bijectors.pd_from_lower, X)
@grad function Bijectors.pd_from_lower(X::AbstractMatrix)
Xd = data(X)
return Bijectors.LowerTriangular(Xd) * Bijectors.LowerTriangular(Xd)', Δ -> begin
Xl = Bijectors.LowerTriangular(Xd)
return (Bijectors.LowerTriangular(Δ' * Xl + Δ * Xl),)
end
end

Bijectors.lower(A::TrackedMatrix) = track(Bijectors.lower, A)
@grad function Bijectors.lower(A::AbstractMatrix)
Bijectors.lower_triangular(A::TrackedMatrix) = track(Bijectors.lower_triangular, A)
@grad function Bijectors.lower_triangular(A::AbstractMatrix)
Ad = data(A)
return Bijectors.lower(Ad), Δ -> (Bijectors.lower(Δ),)
return Bijectors.lower_triangular(Ad), Δ -> (Bijectors.lower_triangular(Δ),)
end

Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_lkj, y)
6 changes: 3 additions & 3 deletions src/compat/zygote.jl
@@ -158,10 +158,10 @@ end
end
return pullback(_maximum, d)
end
@adjoint function lower(A::AbstractMatrix)
return lower(A), Δ -> (lower(Δ),)
@adjoint function lower_triangular(A::AbstractMatrix)
return lower_triangular(A), Δ -> (lower_triangular(Δ),)
end
@adjoint function getpd(X::AbstractMatrix)
@adjoint function pd_from_lower(X::AbstractMatrix)
return LowerTriangular(X) * LowerTriangular(X)', Δ -> begin
Xl = LowerTriangular(X)
return (LowerTriangular(Δ' * Xl + Δ * Xl),)