diff --git a/src/Turing.jl b/src/Turing.jl
index 3e15b4978..ffa7b23b7 100644
--- a/src/Turing.jl
+++ b/src/Turing.jl
@@ -60,7 +60,7 @@ end
     return sym in data.types ? :(true) : :(false)
 end
 (model::Model)(args...; kwargs...) = model.f(args..., model; kwargs...)
-function runmodel! end
+function logp! end
 
 abstract type AbstractSampler end
 
diff --git a/src/core/Core.jl b/src/core/Core.jl
index e1682708c..a13513a3e 100644
--- a/src/core/Core.jl
+++ b/src/core/Core.jl
@@ -3,7 +3,7 @@ module Core
 using MacroTools, Libtask, ForwardDiff
 using ..Utilities, Reexport
 using Flux.Tracker: Tracker
-using ..Turing: Turing, Model, runmodel!,
+using ..Turing: Turing, Model, logp!,
                 AbstractSampler, Sampler, SampleFromPrior
 
 include("VarReplay.jl")
diff --git a/src/core/VarReplay.jl b/src/core/VarReplay.jl
index c2865ca02..8cd32d9f6 100644
--- a/src/core/VarReplay.jl
+++ b/src/core/VarReplay.jl
@@ -9,27 +9,27 @@
 using Distributions
 import Base: string, isequal, ==, hash, getindex, setindex!, push!, show, isempty
 import Turing: link, invlink
 
-export VarName, 
-    VarInfo, 
-    uid, 
-    sym, 
-    getlogp, 
-    set_retained_vns_del_by_spl!, 
-    resetlogp!, 
-    is_flagged, 
-    unset_flag!, 
-    setgid!, 
-    copybyindex, 
-    setorder!, 
-    updategid!, 
-    acclogp!, 
-    istrans, 
-    link!, 
-    invlink!, 
-    setlogp!, 
-    getranges, 
-    getrange, 
-    getvns, 
+export VarName,
+    VarInfo,
+    uid,
+    sym,
+    getlogp,
+    set_retained_vns_del_by_spl!,
+    resetlogp!,
+    is_flagged,
+    unset_flag!,
+    setgid!,
+    copybyindex,
+    setorder!,
+    updategid!,
+    acclogp!,
+    istrans,
+    link!,
+    invlink!,
+    setlogp!,
+    getranges,
+    getrange,
+    getvns,
     getval
 
 ###########
@@ -103,17 +103,29 @@ mutable struct VarInfo
     end
 end
 
-@generated function Turing.runmodel!(model::Model, vi::VarInfo, spl::AbstractSampler)
+@generated function Turing.logp!(model::Model, vi::VarInfo, spl::AbstractSampler)
     expr_eval_num = spl <: Sampler ?
         :(if :eval_num ∈ keys(spl.info) spl.info[:eval_num] += 1 end) : :()
     return quote
         setlogp!(vi, zero(Real))
         $(expr_eval_num)
-        model(vi, spl)
+        vi.flags["divergent"] = [false]
+        try
+            model(vi, spl)
+            # If we reach here, the model evaluated without numerical errors.
+        catch e
+            if e isa DomainError
+                @warn "Numerical error has been detected: $(typeof(e))"
+                vi.flags["divergent"] = [true]
+                setlogp!(vi, log(zero(vi.logp)))  # -Inf of matching type => proposal is rejected
+            else
+                rethrow()  # preserve the original backtrace for non-numerical errors
+            end
+        end
         return vi
     end
 end
 
-Turing.runmodel!(model::Model, vi::VarInfo) = Turing.runmodel!(model, vi, SampleFromPrior())
+Turing.logp!(model::Model, vi::VarInfo) = Turing.logp!(model, vi, SampleFromPrior())
 
 const VarView = Union{Int,UnitRange,Vector{Int},Vector{UnitRange}}
diff --git a/src/core/ad.jl b/src/core/ad.jl
index 1a6f107ff..e4bc7bc86 100644
--- a/src/core/ad.jl
+++ b/src/core/ad.jl
@@ -106,7 +106,7 @@ function gradient_logp_forward(
     # Define function to compute log joint.
     function f(θ)
         vi[sampler] = θ
-        return runmodel!(model, vi, sampler).logp
+        return logp!(model, vi, sampler).logp
     end
 
     chunk_size = getchunksize(sampler)
@@ -147,7 +147,7 @@ function gradient_logp_reverse(
     # Specify objective function.
     function f(θ)
         vi[sampler] = θ
-        return runmodel!(model, vi, sampler).logp
+        return logp!(model, vi, sampler).logp
     end
 
     # Compute forward and reverse passes.
@@ -167,7 +167,7 @@ end
 function verifygrad(grad::AbstractVector{<:Real})
     if any(isnan, grad) || any(isinf, grad)
         @warn("Numerical error in gradients. Rejecting current proposal...")
-        @warn("grad = $(grad)")
+        Turing.DEBUG && @debug("grad = $(grad)")
         return false
     else
         return true
diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl
index 141b3bddd..f2943b7de 100644
--- a/src/inference/Inference.jl
+++ b/src/inference/Inference.jl
@@ -4,7 +4,7 @@ using ..Core, ..Core.VarReplay, ..Utilities
 using Distributions, Libtask, Bijectors
 using ProgressMeter, LinearAlgebra
 using ..Turing: PROGRESS, CACHERESET, AbstractSampler
-using ..Turing: Model, runmodel!, get_pvars, get_dvars,
+using ..Turing: Model, logp!, get_pvars, get_dvars,
                 Sampler, SampleFromPrior, HamiltonianRobustInit
 using ..Turing: in_pvars, in_dvars, Turing
 using StatsFuns: logsumexp
diff --git a/src/inference/dynamichmc.jl b/src/inference/dynamichmc.jl
index 2bb7f6a99..d7d41801c 100644
--- a/src/inference/dynamichmc.jl
+++ b/src/inference/dynamichmc.jl
@@ -55,7 +55,7 @@ function sample(model::Model,
 
     if spl.alg.gid == 0
         link!(vi, spl)
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     function _lp(x)
diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl
index 890018162..7a404e53e 100644
--- a/src/inference/hmc.jl
+++ b/src/inference/hmc.jl
@@ -139,7 +139,7 @@ function sample(model::Model, alg::Hamiltonian;
 
     if spl.alg.gid == 0
         link!(vi, spl)
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     # HMC steps
@@ -221,7 +221,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
     Turing.DEBUG && @debug "X-> R..."
     if spl.alg.gid != 0
         link!(vi, spl)
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     grad_func = gen_grad_func(vi, spl, model)
diff --git a/src/inference/mh.jl b/src/inference/mh.jl
index bfc91c79c..4b03114f0 100644
--- a/src/inference/mh.jl
+++ b/src/inference/mh.jl
@@ -72,7 +72,7 @@ function propose(model, spl::Sampler{<:MH}, vi::VarInfo)
     spl.info[:proposal_ratio] = 0.0
     spl.info[:prior_prob] = 0.0
     spl.info[:violating_support] = false
-    return runmodel!(model, vi ,spl)
+    return logp!(model, vi ,spl)
 end
 
 function step(model, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{true})
@@ -81,7 +81,7 @@ end
 function step(model, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{false})
     if spl.alg.gid != 0 # Recompute joint in logp
-        runmodel!(model, vi)
+        logp!(model, vi)
     end
 
     old_θ = copy(vi[spl])
     old_logp = getlogp(vi)
@@ -135,7 +135,7 @@ function sample(model::Model, alg::MH;
     end
 
     if spl.alg.gid == 0
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     # MH steps
diff --git a/src/inference/sghmc.jl b/src/inference/sghmc.jl
index 26187c0fc..ec35102df 100644
--- a/src/inference/sghmc.jl
+++ b/src/inference/sghmc.jl
@@ -69,7 +69,7 @@ function step(model, spl::Sampler{<:SGHMC}, vi::VarInfo, is_first::Val{false})
     Turing.DEBUG && @debug "X-> R..."
     if spl.alg.gid != 0
         link!(vi, spl)
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     Turing.DEBUG && @debug "recording old variables..."
diff --git a/src/inference/sgld.jl b/src/inference/sgld.jl
index 49796d8c4..60c480562 100644
--- a/src/inference/sgld.jl
+++ b/src/inference/sgld.jl
@@ -74,7 +74,7 @@ function step(model, spl::Sampler{<:SGLD}, vi::VarInfo, is_first::Val{false})
     Turing.DEBUG && @debug "X-> R..."
     if spl.alg.gid != 0
         link!(vi, spl)
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     Turing.DEBUG && @debug "recording old variables..."
diff --git a/src/inference/support/hmc_core.jl b/src/inference/support/hmc_core.jl
index 1e643b16d..dce122c3d 100644
--- a/src/inference/support/hmc_core.jl
+++ b/src/inference/support/hmc_core.jl
@@ -21,7 +21,7 @@ Generate a function that takes `θ` and returns logpdf at `θ` for the model spe
 """
 function gen_lj_func(vi::VarInfo, sampler::Sampler, model)
     return function(θ::AbstractVector{<:Real})
         vi[sampler] = θ
-        return runmodel!(model, vi, sampler).logp
+        return logp!(model, vi, sampler).logp
     end
 end