Skip to content

Commit

Permalink
fix some bug in eval_num from #740 (#770)
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 authored and yebai committed Apr 28, 2019
1 parent e67bb64 commit 502027b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/core/RandomVariables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ const VarInfo = AbstractVarInfo

function Turing.runmodel!(model::Model, vi::AbstractVarInfo, spl::AbstractSampler = SampleFromPrior())
setlogp!(vi, zero(Float64))
if spl isa Sampler && isdefined(spl.info, :eval_num)
spl.info.eval_num += 1
if spl isa Sampler && haskey(spl.info, :eval_num)
spl.info[:eval_num] += 1
end
model(vi, spl)
return vi
Expand Down
17 changes: 16 additions & 1 deletion test/core/RandomVariables.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Turing, Random
using Turing: Selector, reconstruct, invlink, CACHERESET, SampleFromPrior
using Turing: Selector, reconstruct, invlink, CACHERESET,
SampleFromPrior, Sampler, runmodel!
using Turing.RandomVariables
using Turing.RandomVariables: uid, cuid, getvals, getidcs,
set_retained_vns_del_by_spl!, is_flagged,
Expand All @@ -13,6 +14,20 @@ i, j, k = 1, 2, 3
include("../test_utils/AllUtils.jl")

@testset "RandomVariables.jl" begin
@turing_testset "runmodel!" begin
@model testmodel() = begin
x ~ Normal()
end
alg = HMC(1000, 0.1, 5)
spl = Sampler(alg)
vi = VarInfo()
m = testmodel()
m(vi)
runmodel!(m, vi, spl)
@test spl.info[:eval_num] == 1
runmodel!(m, vi, spl)
@test spl.info[:eval_num] == 2
end
@turing_testset "flags" begin
vi = VarInfo()
vn_x = VarName(gensym(), :x, "", 1)
Expand Down

0 comments on commit 502027b

Please sign in to comment.