Commit 7dd253c

Renamed runmodel! to logp!; initial support for catching numerical errors. (#634)

yebai committed Mar 11, 2019
1 parent c512d70
Showing 11 changed files with 51 additions and 39 deletions.
src/Turing.jl (2 changes: 1 addition & 1 deletion)

@@ -60,7 +60,7 @@ end
     return sym in data.types ? :(true) : :(false)
 end
 (model::Model)(args...; kwargs...) = model.f(args..., model; kwargs...)
-function runmodel! end
+function logp! end
 
 abstract type AbstractSampler end

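Aside: `function logp! end` above is Julia's zero-method function declaration. It creates the generic function in the `Turing` module so that submodules can attach methods to it later, as `Core.VarReplay` does below. A minimal standalone sketch of the idiom, with illustrative names that are not from this commit:

    module Demo

    # Declare a generic function with no methods; other modules attach them.
    function process! end

    module Sub
    import ..process!                        # import for method extension
    process!(x::Vector) = (push!(x, 0); x)   # attach a method from a submodule
    end

    end

    Demo.process!([1, 2])   # -> [1, 2, 0]
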
src/core/Core.jl (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@ module Core
 using MacroTools, Libtask, ForwardDiff
 using ..Utilities, Reexport
 using Flux.Tracker: Tracker
-using ..Turing: Turing, Model, runmodel!,
+using ..Turing: Turing, Model, logp!,
                 AbstractSampler, Sampler, SampleFromPrior
 
 include("VarReplay.jl")

src/core/VarReplay.jl (60 changes: 36 additions & 24 deletions)

@@ -9,27 +9,27 @@ using Distributions
 import Base: string, isequal, ==, hash, getindex, setindex!, push!, show, isempty
 import Turing: link, invlink
 
-export VarName,
-       VarInfo,
-       uid,
-       sym,
-       getlogp,
-       set_retained_vns_del_by_spl!,
-       resetlogp!,
-       is_flagged,
-       unset_flag!,
-       setgid!,
-       copybyindex,
-       setorder!,
-       updategid!,
-       acclogp!,
-       istrans,
-       link!,
-       invlink!,
-       setlogp!,
-       getranges,
-       getrange,
-       getvns,
+export VarName,
+       VarInfo,
+       uid,
+       sym,
+       getlogp,
+       set_retained_vns_del_by_spl!,
+       resetlogp!,
+       is_flagged,
+       unset_flag!,
+       setgid!,
+       copybyindex,
+       setorder!,
+       updategid!,
+       acclogp!,
+       istrans,
+       link!,
+       invlink!,
+       setlogp!,
+       getranges,
+       getrange,
+       getvns,
        getval
 
 ###########
@@ -103,17 +103,29 @@ mutable struct VarInfo
     end
 end
 
-@generated function Turing.runmodel!(model::Model, vi::VarInfo, spl::AbstractSampler)
+@generated function Turing.logp!(model::Model, vi::VarInfo, spl::AbstractSampler)
     expr_eval_num = spl <: Sampler ?
         :(if :eval_num ∈ keys(spl.info) spl.info[:eval_num] += 1 end) : :()
     return quote
         setlogp!(vi, zero(Real))
         $(expr_eval_num)
-        model(vi, spl)
+        vi.flags["divergent"] = [false]
+        try
+            model(vi, spl)
+            #log(-1)
+        catch e
+            if e isa DomainError
+                @warn "Numerical error has been detected: $(typeof(e))"
+                vi.flags["divergent"] = [true]
+                setlogp!(vi, log(0)*vi.logp)
+            else
+                throw(e)
+            end
+        end
         return vi
     end
 end
-Turing.runmodel!(model::Model, vi::VarInfo) = Turing.runmodel!(model, vi, SampleFromPrior())
+Turing.logp!(model::Model, vi::VarInfo) = Turing.logp!(model, vi, SampleFromPrior())
 
 const VarView = Union{Int,UnitRange,Vector{Int},Vector{UnitRange}}

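The hunk above is the core of this commit: `logp!` resets a "divergent" flag, evaluates the model inside a `try`, and on a `DomainError` records the divergence and penalizes the log density instead of crashing the sampler. A hypothetical usage sketch against the Turing API of this era; the model definition, the `fragile` name, and the call sequence are illustrative assumptions, not code from this commit:

    using Turing
    using Turing.Core.VarReplay: VarInfo, getlogp

    # A model whose log density can throw: log(s) raises a DomainError for s < 0.
    @model fragile(x) = begin
        s ~ Normal(0, 1)
        x ~ Normal(log(s), 1)
    end

    model = fragile(1.0)
    vi = VarInfo()
    Turing.logp!(model, vi)   # evaluates the model; DomainErrors are trapped
    vi.flags["divergent"]     # [true] if a numerical error occurred
    getlogp(vi)               # penalized log density in that case

Downstream samplers can then inspect the flag and reject the proposal rather than propagate a bad trajectory.
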
src/core/ad.jl (6 changes: 3 additions & 3 deletions)

@@ -106,7 +106,7 @@ function gradient_logp_forward(
     # Define function to compute log joint.
     function f(θ)
         vi[sampler] = θ
-        return runmodel!(model, vi, sampler).logp
+        return logp!(model, vi, sampler).logp
     end
 
     chunk_size = getchunksize(sampler)
@@ -147,7 +147,7 @@ function gradient_logp_reverse(
     # Specify objective function.
     function f(θ)
         vi[sampler] = θ
-        return runmodel!(model, vi, sampler).logp
+        return logp!(model, vi, sampler).logp
     end
 
     # Compute forward and reverse passes.
@@ -167,7 +167,7 @@ end
 function verifygrad(grad::AbstractVector{<:Real})
     if any(isnan, grad) || any(isinf, grad)
         @warn("Numerical error in gradients. Rejecting current proposal...")
-        @warn("grad = $(grad)")
+        Turing.DEBUG && @debug("grad = $(grad)")
         return false
     else
         return true

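Both gradient paths wrap the model in a closure `f(θ)` that writes `θ` into the `VarInfo` and returns the accumulated log density, which the AD backend then differentiates. A standalone sketch of that pattern with ForwardDiff; the toy `logdensity` stands in for the model evaluation and is an assumption, not Turing code:

    using ForwardDiff

    # Stand-in for logp!(model, vi, sampler).logp: a spherical Gaussian.
    logdensity(θ) = -0.5 * sum(abs2, θ)

    f(θ) = logdensity(θ)                  # the closure the backend differentiates
    θ0 = [0.5, -1.0, 2.0]

    # Fixed chunk size, analogous to chunk_size = getchunksize(sampler) above.
    cfg = ForwardDiff.GradientConfig(f, θ0, ForwardDiff.Chunk{3}())
    grad = ForwardDiff.gradient(f, θ0, cfg)   # == -θ0 for this density
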
src/inference/Inference.jl (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@ using ..Core, ..Core.VarReplay, ..Utilities
 using Distributions, Libtask, Bijectors
 using ProgressMeter, LinearAlgebra
 using ..Turing: PROGRESS, CACHERESET, AbstractSampler
-using ..Turing: Model, runmodel!, get_pvars, get_dvars,
+using ..Turing: Model, logp!, get_pvars, get_dvars,
                 Sampler, SampleFromPrior, HamiltonianRobustInit
 using ..Turing: in_pvars, in_dvars, Turing
 using StatsFuns: logsumexp

src/inference/dynamichmc.jl (2 changes: 1 addition & 1 deletion)

@@ -55,7 +55,7 @@ function sample(model::Model,
 
     if spl.alg.gid == 0
         link!(vi, spl)
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     function _lp(x)

src/inference/hmc.jl (4 changes: 2 additions & 2 deletions)

@@ -139,7 +139,7 @@ function sample(model::Model, alg::Hamiltonian;
 
     if spl.alg.gid == 0
         link!(vi, spl)
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     # HMC steps
@@ -221,7 +221,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{false})
     Turing.DEBUG && @debug "X-> R..."
     if spl.alg.gid != 0
         link!(vi, spl)
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     grad_func = gen_grad_func(vi, spl, model)

src/inference/mh.jl (6 changes: 3 additions & 3 deletions)

@@ -72,7 +72,7 @@ function propose(model, spl::Sampler{<:MH}, vi::VarInfo)
     spl.info[:proposal_ratio] = 0.0
     spl.info[:prior_prob] = 0.0
     spl.info[:violating_support] = false
-    return runmodel!(model, vi ,spl)
+    return logp!(model, vi ,spl)
 end
 
 function step(model, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{true})
@@ -81,7 +81,7 @@ end
 
 function step(model, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{false})
     if spl.alg.gid != 0 # Recompute joint in logp
-        runmodel!(model, vi)
+        logp!(model, vi)
     end
     old_θ = copy(vi[spl])
     old_logp = getlogp(vi)
@@ -135,7 +135,7 @@ function sample(model::Model, alg::MH;
     end
 
     if spl.alg.gid == 0
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     # MH steps

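For context on the cached `old_logp` above: Metropolis-Hastings recomputes the joint for the proposal and accepts it with probability min(1, exp(new_logp - old_logp + proposal_ratio)), which is what the `spl.info[:proposal_ratio]` bookkeeping feeds into. A generic standalone sketch of that acceptance arithmetic, not Turing code:

    # Accept with probability min(1, exp(new_logp - old_logp + proposal_ratio)).
    function mh_accept(old_logp::Real, new_logp::Real, proposal_ratio::Real = 0.0)
        return log(rand()) < new_logp - old_logp + proposal_ratio
    end

    mh_accept(-10.0, -9.5)   # true: uphill moves are always accepted
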
src/inference/sghmc.jl (2 changes: 1 addition & 1 deletion)

@@ -69,7 -69,7 @@ function step(model, spl::Sampler{<:SGHMC}, vi::VarInfo, is_first::Val{false})
     Turing.DEBUG && @debug "X-> R..."
     if spl.alg.gid != 0
         link!(vi, spl)
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     Turing.DEBUG && @debug "recording old variables..."

src/inference/sgld.jl (2 changes: 1 addition & 1 deletion)

@@ -74,7 +74,7 @@ function step(model, spl::Sampler{<:SGLD}, vi::VarInfo, is_first::Val{false})
     Turing.DEBUG && @debug "X-> R..."
     if spl.alg.gid != 0
         link!(vi, spl)
-        runmodel!(model, vi, spl)
+        logp!(model, vi, spl)
     end
 
     Turing.DEBUG && @debug "recording old variables..."

src/inference/support/hmc_core.jl (2 changes: 1 addition & 1 deletion)

@@ -21,7 +21,7 @@ Generate a function that takes `θ` and returns logpdf at `θ` for the model spe
 function gen_lj_func(vi::VarInfo, sampler::Sampler, model)
     return function(θ::AbstractVector{<:Real})
         vi[sampler] = θ
-        return runmodel!(model, vi, sampler).logp
+        return logp!(model, vi, sampler).logp
     end
 end

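`gen_lj_func` is the closure-over-state pattern in factory form: it returns a function of `θ` alone that HMC can call repeatedly while mutating the shared `VarInfo`. A minimal analogue with assumed names and a quadratic stand-in for the log joint:

    # Factory returning a log-joint closure over mutable state,
    # mirroring gen_lj_func(vi, sampler, model).
    function make_lj(state::Vector{Float64})
        return function (θ::AbstractVector{<:Real})
            copyto!(state, θ)               # stand-in for vi[sampler] = θ
            return -0.5 * sum(abs2, state)  # stand-in for logp!(model, vi, sampler).logp
        end
    end

    lj = make_lj(zeros(2))
    lj([1.0, 2.0])   # -2.5
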
