More robust and modular way of detecting divergence #16
Update: The current NUTS implementation rejects only leaf nodes that contain numerical errors (by setting the leaf node's energy to Inf; see below). This leads to biased samples in general. We need to reject the whole trajectory increment that contains the error (e.g. a numerical error or a divergent Hamiltonian). This increment is the part of the trajectory that is added to the current trajectory during a tree-doubling step (i.e. the subtree being created). See AdvancedHMC.jl/src/trajectory.jl, line 85 in 734c0fa.
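A minimal sketch of the intended fix, under assumed names (`Leaf`, `build_increment`, `Δ_max` are illustrative, not AdvancedHMC's actual API): when any leaf of a newly built subtree is divergent, the whole increment is discarded, instead of only setting that leaf's energy to Inf.

```julia
# Hypothetical sketch: reject an entire trajectory increment (subtree)
# when any of its leaves is divergent, rather than patching single leaves.

struct Leaf
    θ::Vector{Float64}   # position
    r::Vector{Float64}   # momentum
    H::Float64           # Hamiltonian energy at this point
end

# A point is divergent if its energy is non-finite or deviates from the
# initial energy H0 by more than a threshold (1000 is Stan's default).
is_divergent(H0, H; Δ_max = 1000.0) = !isfinite(H) || H - H0 > Δ_max

# Build a candidate increment; return `nothing` if it must be rejected,
# so the caller keeps the current trajectory unchanged.
function build_increment(leaves::Vector{Leaf}, H0::Float64)
    any(l -> is_divergent(H0, l.H), leaves) && return nothing
    return leaves
end
```

With this shape, the doubling loop simply stops doubling when `build_increment` returns `nothing`, which is what "rejecting the increment" amounts to.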
Update: The following two lines in AdvancedHMC.jl/src/stepsize.jl are also affected: line 16 and line 18 in e7bb08b.
Update: The following function, which takes the last point from a trajectory, should be removed: AdvancedHMC.jl/src/trajectory.jl, line 118 in 734c0fa.
Regarding #16 (comment): does it mean we should reject the whole subtree? Regarding #16 (comment): I guess I named the function wrong. It is actually doing slice sampling from the trajectory.
Yes, we should reject each trajectory increment that causes a numerical error or a divergent Hamiltonian.
@yebai I think we are tackling this correctly in Turing; see https://github.com/TuringLang/Turing.jl/blob/master/src/inference/nuts.jl#L91-L93. An invalid numerical error would cause a rejection there. I guess DynamicHMC follows a different abstraction from the NUTS paper, so we didn't realise that before. What I can do is add this divergence abstraction into the NUTS code to make it more readable, if that makes sense.
I see - that makes sense. Introducing some abstractions for divergence detection sounds like a good idea.
Good plan! |
Update: The following line seems to be problematic: AdvancedHMC.jl/src/proposal.jl, line 101 in 64a72da.
Nice catch! I'm wondering what's the best way to fix it, since we lost our interface for passing the step size.
Update: the following lines seem problematic too (AdvancedHMC.jl/src/integrator.jl, lines 45 to 46 in a0af7cc). Shouldn't it be:

```julia
θ_new = lf_position(lf.ϵ, h, θ, r)
r_new, _is_valid = lf_momentum(i == n_steps ? lf.ϵ / 2 : lf.ϵ, h, θ_new, r)
```
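For reference, a self-contained sketch of the standard leapfrog scheme this snippet is aiming at, with a half momentum step at the start and end, written for a unit-mass Gaussian target (∇U(θ) = θ); the function and its arguments are illustrative, not AdvancedHMC's code:

```julia
# Illustrative leapfrog integrator (kick-drift-kick), assuming a
# standard Gaussian potential U(θ) = θ²/2, so ∇U(θ) = θ.
∇U(θ) = θ

function leapfrog(θ, r, ϵ, n_steps)
    r = r - (ϵ / 2) * ∇U(θ)              # initial half step for momentum
    for i in 1:n_steps
        θ = θ + ϵ * r                    # full step for position
        # full momentum steps inside the loop, but only a half step
        # after the final position update — the pattern the quoted
        # `i == n_steps ? lf.ϵ / 2 : lf.ϵ` condition encodes
        r = r - (i == n_steps ? ϵ / 2 : ϵ) * ∇U(θ)
    end
    return θ, r
end
```

Because the scheme is symplectic, the Hamiltonian H = θ²/2 + r²/2 is conserved up to O(ϵ²) error, which is exactly the quantity a divergence check monitors.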
I'll add a
So what do we really want here? @yebai
Approximation error in leapfrog integration (i.e. accumulated Hamiltonian energy error) can sometimes explode, for example when the curvature of the current region is very high. This type of approximation error is sometimes called divergence [1], since it shifts the leapfrog simulation away from the correct solution. In Turing, this type of error is currently caught by a relatively ad-hoc function called `is_valid` (AdvancedHMC.jl/src/integrator.jl, line 7 in 734c0fa). `is_valid` can catch cases where one or more elements of the parameter vector is either NaN or Inf. This has several drawbacks (see AdvancedHMC.jl/src/integrator.jl, lines 26 and 45 in 734c0fa). Therefore, we might want to refactor the current code a bit for a more robust mechanism for handling leapfrog approximation errors. Perhaps we can learn from the DynamicHMC implementation: https://github.com/tpapp/DynamicHMC.jl/blob/master/src/buildingblocks.jl#L168
[1]: Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. arXiv preprint arXiv:1701.02434.