Skip to content

Commit 50a1e1d

Browse files
committed
Remove dependency DistributionsAD
1 parent ff97551 commit 50a1e1d

File tree

8 files changed

+21
-12
lines changed

8 files changed

+21
-12
lines changed

Project.toml

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ version = "0.2.3"
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
88
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
9-
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
109
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1110
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1211
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1312
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
1413
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1514
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
1615
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
16+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1717
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1818
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1919
ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f"
@@ -22,7 +22,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2222
[compat]
2323
ChainRulesCore = "0.9.44"
2424
Distributions = "0.23, 0.24, 0.25"
25-
DistributionsAD = "0.6"
2625
ForwardDiff = "0.10"
2726
IterativeSolvers = "0.8, 0.9"
2827
LinearMaps = "3"

README.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88

99
*This is an implementation of the [Metric Gaussian Variational Inference](https://arxiv.org/abs/1901.11033) (MGVI) algorithm in julia*
1010

11-
1211
MGVI is an iterative method that performs a series of Gaussian approximations to the posterior. It alternates between approximating the covariance with the inverse Fisher information metric evaluated at an intermediate mean estimate and optimizing the KL-divergence for the given covariance with respect to the mean. This procedure is iterated until the uncertainty estimate is self-consistent with the mean parameter. We achieve linear scaling by avoiding to store the covariance explicitly at any time. Instead we draw samples from the approximating distribution relying on an implicit representation and numerical schemes to approximately solve linear equations. Those samples are used to approximate the KL-divergence and its gradient. The usage of natural gradient descent allows for rapid convergence. Formulating the Bayesian model in standardized coordinates makes MGVI applicable to any inference problem with continuous parameters.
1312

13+
Depending on the distributions used in your application, you may need to use the package [DistributionsAD](https://github.com/TuringLang/DistributionsAD.jl).
14+
15+
1416
## Documentation
1517
* [Documentation for stable version](https://bat.github.io/MGVI.jl/stable)
1618
* [Documentation for development version](https://bat.github.io/MGVI.jl/dev)

docs/src/index.md

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
MGVI is an iterative method that performs a series of Gaussian approximations to the posterior. We alternate between approximating the covariance with the inverse Fisher information metric evaluated at an intermediate mean estimate and optimizing the KL-divergence for the given covariance with respect to the mean. This procedure is iterated until the uncertainty estimate is self-consistent with the mean parameter. We achieve linear scaling by avoiding to store the covariance explicitly at any time. Instead we draw samples from the approximating distribution relying on an implicit representation and numerical schemes to approximately solve linear equations. Those samples are used to approximate the KL-divergence and its gradient. The usage of natural gradient descent allows for rapid convergence. Formulating the Bayesian model in standardized coordinates makes MGVI applicable to any inference problem with continuous parameters.
44

5+
Depending on the distributions used in your application, you may need to use the package [DistributionsAD](https://github.com/TuringLang/DistributionsAD.jl).
6+
57

68
## Citing MGVI.jl
79

src/MGVI.jl

+7-1
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,29 @@ using Random
1616
using SparseArrays
1717
using Base.Iterators
1818
using Distributions
19-
using DistributionsAD
2019
import ForwardDiff
2120
using LinearMaps
2221
using IterativeSolvers
2322
using Optim
2423
using PositiveFactorizations
24+
using Requires
2525
import SparseArrays: blockdiag
2626
using SparseArrays
2727
using StaticArrays
2828
using ValueShapes
2929
import Zygote
3030

31+
import Requires
32+
3133
include("custom_linear_maps.jl")
3234
include("shapes.jl")
3335
include("jacobian_maps.jl")
3436
include("information.jl")
3537
include("residual_samplers.jl")
3638
include("mgvi_impl.jl")
3739

40+
function __init__()
41+
@require DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" include("distributionsad_support.jl")
42+
end
43+
3844
end # module

src/distributionsad_support.jl

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# This file is a part of MGVI.jl, licensed under the MIT License (MIT).
2+
3+
4+
function unshaped_params(d::DistributionsAD.TuringDenseMvNormal)
5+
μ = d.m
6+
σ = convert(AbstractMatrix, d.C)
7+
vcat(μ, _uppertriang_to_vec(σ))
8+
end

src/shapes.jl

-6
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,3 @@ function unshaped_params(d::MvNormal)
5252
μ, σ = params(d)
5353
vcat(μ, _uppertriang_to_vec(σ))
5454
end
55-
56-
function unshaped_params(d::TuringDenseMvNormal)
57-
μ = d.m
58-
σ = convert(AbstractMatrix, d.C)
59-
vcat(μ, _uppertriang_to_vec(σ))
60-
end

test/Project.toml

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
[deps]
22
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
33
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4-
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
54
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
65
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
76
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"

test/information/information_utils.jl

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# This file is a part of MGVI.jl, licensed under the MIT License (MIT).
22

33
using Distributions
4-
using DistributionsAD
54
using LinearAlgebra
65
import Zygote
76

0 commit comments

Comments
 (0)