-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathintegrator.jl
52 lines (46 loc) · 1.79 KB
/
integrator.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Root of the integrator type hierarchy; concrete integrators such as `Leapfrog`
# subtype it.
abstract type AbstractIntegrator end
# Leapfrog integrator with a fixed step size.
#
# Fields:
#   ϵ :: T — step size. `step`/`steps` use `ϵ / 2` for the half momentum updates
#            and `ϵ` for the full position updates.
#
# The type parameter `T` is shared with the element type of the position and
# momentum vectors in the `step`/`steps` signatures, so they must agree.
# NOTE(review): with an integer `T`, `ϵ / 2` is a float and would not match the
# `ϵ::T` argument of `lf_momentum`, so `T` should presumably be a floating-point
# type — confirm with callers.
struct Leapfrog{T<:Real} <: AbstractIntegrator
    ϵ :: T
end
"""
    is_valid(v::AbstractVector{<:Real}) -> Bool

Return `true` when every entry of `v` is finite, i.e. `v` contains no `NaN` and no
`±Inf`. An empty vector is considered valid.
"""
is_valid(v::AbstractVector{<:Real}) = all(isfinite, v)
"""
    lf_momentum(ϵ, h, θ, r) -> (r′, ok)

Update the momentum by `-ϵ * ∂H∂θ(h, θ)`.

Return `(r - ϵ * ∂H∂θ(h, θ), true)` when the gradient is finite. When the gradient
contains `NaN`/`Inf`, return the untouched `(r, false)`.
"""
function lf_momentum(
    ϵ::T,
    h::Hamiltonian,
    θ::AbstractVector{T},
    r::AbstractVector{T},
) where {T<:Real}
    grad = ∂H∂θ(h, θ)
    if is_valid(grad)
        return r - ϵ * grad, true
    else
        # Non-finite gradient: leave the momentum as it was and signal failure.
        return r, false
    end
end
"""
    lf_position(ϵ, h, θ, r)

Return the position advanced by `ϵ` along `∂H∂r(h, r)`, i.e. `θ + ϵ * ∂H∂r(h, r)`.
Unlike `lf_momentum`, no validity check is performed.
"""
function lf_position(
    ϵ::T,
    h::Hamiltonian,
    θ::AbstractVector{T},
    r::AbstractVector{T},
) where {T<:Real}
    dθ = ∂H∂r(h, r)
    return θ + ϵ * dθ
end
"""
    step(lf::Leapfrog, h::Hamiltonian, θ, r) -> (θ′, r′, ok)

Advance `(θ, r)` by one leapfrog step of size `lf.ϵ`:

1. a half momentum update at `θ`,
2. a full position update using that momentum,
3. a half momentum update at the new position.

If either momentum update sees a non-finite gradient, return the original `(θ, r)`
with `ok == false`. Otherwise return the new state with `ok == true`.
"""
function step(
    lf::Leapfrog{T},
    h::Hamiltonian,
    θ::AbstractVector{T},
    r::AbstractVector{T},
) where {T<:Real}
    half_ϵ = lf.ϵ / 2
    r_half, ok = lf_momentum(half_ϵ, h, θ, r)
    ok || return θ, r, false
    θ_next = lf_position(lf.ϵ, h, θ, r_half)
    r_next, ok = lf_momentum(half_ϵ, h, θ_next, r_half)
    ok || return θ, r, false
    return θ_next, r_next, true
end
# TODO(type stability): the argument `r` is rebound to values returned by
# `lf_momentum`, which returns either its input `r` or `r - ϵ * grad`. When these
# differ in concrete type (e.g. the caller passes a view), inference of `r` may
# widen. Check with `@code_warntype` on the argument types actually used.
"""
    steps(lf::Leapfrog, h::Hamiltonian, θ, r, n_steps::Int) -> (θ′, r′, n_valid)

Run `n_steps` leapfrog steps of size `lf.ϵ` starting from `(θ, r)`.

Consecutive half momentum updates are merged. The sequence is one opening half
momentum update, then `n_steps` times (full position update, momentum update at the
new position), where the last momentum update is a half one. The result therefore
matches calling `step` `n_steps` times.

Return the final state and `n_valid`, the number of completed steps.

- If a momentum update hits a non-finite gradient part-way through, integration
  stops. The state after the last completed step is returned, with
  `n_valid < n_steps`.
- If `n_steps < 1`, the input `(θ, r)` is returned unchanged with `n_valid == 0`.
"""
function steps(lf::Leapfrog{T}, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T}, n_steps::Int) where {T<:Real}
    n_valid = 0
    # Nothing to integrate. Return before the opening half update, which would
    # otherwise leave `r` half a step ahead of `θ` with no position update.
    n_steps < 1 && return θ, r, n_valid
    # Opening half update: from here on `r` is half a step ahead of `θ`.
    r_new, _is_valid = lf_momentum(lf.ϵ / 2, h, θ, r)
    !_is_valid && return θ, r, n_valid
    r = r_new
    for i in 1:n_steps
        θ_new = lf_position(lf.ϵ, h, θ, r)
        # The momentum update must use the gradient at the *new* position. The final
        # update is a half one so that `r` ends synchronised with `θ`.
        r_new, _is_valid = lf_momentum(i == n_steps ? lf.ϵ / 2 : lf.ϵ, h, θ_new, r)
        if !_is_valid
            # Undo the pending half update so the returned `r` is synchronised with
            # `θ` again. The gradient at `θ` was already evaluated and found finite
            # (by the opening update or the previous iteration), so this reverse
            # update is numerically safe.
            r, _ = lf_momentum(-lf.ϵ / 2, h, θ, r)
            return θ, r, n_valid
        end
        θ, r = θ_new, r_new
        n_valid += 1
    end
    return θ, r, n_valid
end