diff --git a/src/core/RandomVariables.jl b/src/core/RandomVariables.jl index 5c8d84677..1876e3f79 100644 --- a/src/core/RandomVariables.jl +++ b/src/core/RandomVariables.jl @@ -76,8 +76,8 @@ const VarInfo = AbstractVarInfo function Turing.runmodel!(model::Model, vi::AbstractVarInfo, spl::AbstractSampler = SampleFromPrior()) setlogp!(vi, zero(Float64)) - if spl isa Sampler && isdefined(spl.info, :eval_num) - spl.info.eval_num += 1 + if spl isa Sampler && haskey(spl.info, :eval_num) + spl.info[:eval_num] += 1 end model(vi, spl) return vi diff --git a/test/core/RandomVariables.jl b/test/core/RandomVariables.jl index c84e2ac17..4c0e0083c 100644 --- a/test/core/RandomVariables.jl +++ b/test/core/RandomVariables.jl @@ -1,5 +1,6 @@ using Turing, Random -using Turing: Selector, reconstruct, invlink, CACHERESET, SampleFromPrior +using Turing: Selector, reconstruct, invlink, CACHERESET, + SampleFromPrior, Sampler, runmodel! using Turing.RandomVariables using Turing.RandomVariables: uid, cuid, getvals, getidcs, set_retained_vns_del_by_spl!, is_flagged, @@ -13,6 +14,20 @@ i, j, k = 1, 2, 3 include("../test_utils/AllUtils.jl") @testset "RandomVariables.jl" begin + @turing_testset "runmodel!" begin + @model testmodel() = begin + x ~ Normal() + end + alg = HMC(1000, 0.1, 5) + spl = Sampler(alg) + vi = VarInfo() + m = testmodel() + m(vi) + runmodel!(m, vi, spl) + @test spl.info[:eval_num] == 1 + runmodel!(m, vi, spl) + @test spl.info[:eval_num] == 2 + end @turing_testset "flags" begin vi = VarInfo() vn_x = VarName(gensym(), :x, "", 1)