Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modifies implementation of expected_loglik and add tests for it #93

Merged
merged 14 commits into from
Jan 19, 2022
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ApproximateGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["JuliaGaussianProcesses Team"]
version = "0.2.5"
version = "0.2.6"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand All @@ -11,6 +11,7 @@ FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
KLDivergences = "3c9cd921-3d3f-41e2-830c-e020174918cc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Expand All @@ -27,6 +28,7 @@ FastGaussQuadrature = "0.4"
FillArrays = "0.12"
ForwardDiff = "0.10"
GPLikelihoods = "0.1, 0.2"
IrrationalConstants = "0.1"
KLDivergences = "0.2.1"
PDMats = "0.11"
Reexport = "1"
Expand Down
1 change: 1 addition & 0 deletions src/ApproximateGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using ChainRulesCore
using FillArrays
using KLDivergences
using PDMats: chol_lower
using IrrationalConstants
theogf marked this conversation as resolved.
Show resolved Hide resolved

using AbstractGPs: AbstractGP, FiniteGP, LatentFiniteGP, ApproxPosteriorGP, At_A, diag_At_A

Expand Down
15 changes: 12 additions & 3 deletions src/expected_loglik.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,18 @@ function expected_loglik(
# (see e.g. en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)
xs, ws = gausshermite(gh.n_points)
# size(fs): (length(y), n_points)
theogf marked this conversation as resolved.
Show resolved Hide resolved
fs = √2 * std.(q_f) .* xs' .+ mean.(q_f)
theogf marked this conversation as resolved.
Show resolved Hide resolved
lls = loglikelihood.(lik.(fs), y)
return sum((1 / √π) * lls * ws)
return sum(Broadcast.instantiate(
theogf marked this conversation as resolved.
Show resolved Hide resolved
Broadcast.broadcasted(q_f, y) do q, y
st-- marked this conversation as resolved.
Show resolved Hide resolved
μ = mean(q)
σ = std(q)
sum(Broadcast.instantiate(
Broadcast.broadcasted(xs, ws) do x, w
f = sqrt2 * σ * x + μ
loglikelihood(lik(f), y) * w
end
theogf marked this conversation as resolved.
Show resolved Hide resolved
))
end
)) * invsqrtπ
theogf marked this conversation as resolved.
Show resolved Hide resolved
end

ChainRulesCore.@non_differentiable gausshermite(n)
Expand Down
14 changes: 14 additions & 0 deletions test/expected_loglik.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,18 @@
GaussHermite(), zeros(10), q_f, GaussianLikelihood()
) isa Real
@test ApproximateGPs._default_quadrature(θ -> Normal(0, θ)) isa GaussHermite

@testset "testing Zygote compatibility with GaussHermite" begin # see issue #82
N = 10
gh = GaussHermite(12)
μ = randn(rng, N)
σ = rand(rng, N)
for lik in likelihoods_to_test
y = rand.(rng, lik.(rand.(Normal.(μ, σ))))
g = only(Zygote.gradient(μ) do x
ApproximateGPs.expected_loglik(gh, y, Normal.(x, σ), lik)
end)
theogf marked this conversation as resolved.
Show resolved Hide resolved
@test all(isfinite, g)
end
end
end