Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transfer some test utility function into DynamicPPL #2049

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 3 additions & 38 deletions test/modes/OptimInterface.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,3 @@
# TODO: Remove these once the equivalent is present in `DynamicPPL.TestUtils.
function likelihood_optima(::DynamicPPL.TestUtils.UnivariateAssumeDemoModels)
return (s=1/16, m=7/4)
end
function posterior_optima(::DynamicPPL.TestUtils.UnivariateAssumeDemoModels)
# TODO: Figure out exact for `s`.
return (s=0.907407, m=7/6)
end

function likelihood_optima(model::DynamicPPL.TestUtils.MultivariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)

# NOTE: These are "as close to zero as we can get".
vals.s[1] = 1e-32
vals.s[2] = 1e-32

vals.m[1] = 1.5
vals.m[2] = 2.0

return vals
end
function posterior_optima(model::DynamicPPL.TestUtils.MultivariateAssumeDemoModels)
# Get some containers to fill.
vals = Random.rand(model)

# TODO: Figure out exact for `s[1]`.
vals.s[1] = 0.890625
vals.s[2] = 1
vals.m[1] = 3/4
vals.m[2] = 1

return vals
end

# Used for testing how well it works with nested contexts.
struct OverrideContext{C,T1,T2} <: DynamicPPL.AbstractContext
context::C
Expand All @@ -57,7 +22,7 @@ function DynamicPPL.tilde_observe(context::OverrideContext, right, left, vi)
return context.loglikelihood_weight, vi
end

@testset "OptimInterface.jl" begin
@numerical_testset "OptimInterface.jl" begin
@testset "MLE" begin
Random.seed!(222)
true_value = [0.0625, 1.75]
Expand Down Expand Up @@ -157,7 +122,7 @@ end
# FIXME: Some models doesn't work for Tracker and ReverseDiff.
if Turing.Essential.ADBACKEND[] === :forwarddiff
@testset "MAP for $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
result_true = posterior_optima(model)
result_true = DynamicPPL.TestUtils.posterior_optima(model)

@testset "$(nameof(typeof(optimizer)))" for optimizer in [LBFGS(), NelderMead()]
result = optimize(model, MAP(), optimizer)
Expand Down Expand Up @@ -188,7 +153,7 @@ end
DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix,
]
@testset "MLE for $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
result_true = likelihood_optima(model)
result_true = DynamicPPL.TestUtils.likelihood_optima(model)

# `NelderMead` seems to struggle with convergence here, so we exclude it.
@testset "$(nameof(typeof(optimizer)))" for optimizer in [LBFGS(),]
Expand Down