Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
trappmartin committed Apr 9, 2019
1 parent 5560d9f commit 2fd3bff
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 13 deletions.
12 changes: 5 additions & 7 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using ProgressMeter, LinearAlgebra
using ..Turing: PROGRESS, CACHERESET, AbstractRunner
using ..Turing: Model, runmodel!, get_pvars, get_dvars,
Sampler, SampleFromPrior, SampleFromUniform,
Selector
Selector, ComputeLogJointDensity
using ..Turing: in_pvars, in_dvars, Turing
using StatsFuns: logsumexp

Expand Down Expand Up @@ -80,6 +80,8 @@ getADtype(::Type{<:Hamiltonian{AD}}) where {AD} = AD
include("adapt/adapt.jl")
include("support/hmc_core.jl")

include("runners.jl")

# Concrete algorithm implementations.
include("hmcda.jl")
include("nuts.jl")
Expand Down Expand Up @@ -174,10 +176,7 @@ function assume(spl::A,
end


observe(::Nothing,
dist::T,
value::Any,
vi::VarInfo) where T = observe(SampleFromPrior(), dist, value, vi)
@inline observe(::Nothing, dist, value, vi) = observe(SampleFromPrior(), dist, value, vi)

function observe(spl::A,
dist::Distribution,
Expand All @@ -189,8 +188,7 @@ function observe(spl::A,
Turing.DEBUG && @debug "value = $value"

# acclogp!(vi, logpdf(dist, value))
logpdf(dist, value)

return logpdf(dist, value)
end

function observe(spl::A,
Expand Down
116 changes: 110 additions & 6 deletions src/inference/runners.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,114 @@
function _getdist(dists::Vector{<:Distribution})
@assert length(dists) == 1 "[observe] Turing only support vectorizing iid distribution."
return first(dists)
end

##
# Default definition when runner = nothing
##

function observe(::Nothing, dist::Distribution, value, vi::VarInfo)
vi.num_produce += one(vi.num_produce)
return logpdf(dist, value)
end

function observe(::Nothing, dists::Vector{<:UnivariateDistribution}, values, vi::VarInfo)
dist = _getdist(dists)
return sum(logpdf.(dist, values))
end

# NOTE: this is necessary as we cannot use broadcasting for MV dists.
function observe(::Nothing, dists::Vector{<:MultivariateDistribution}, values, vi::VarInfo)
dist = _getdist(dists)
return sum(logpdf(dist, values))
end

##
# Sample from prior
##

function assume(spl::SampleFromPrior, dist::Distribution, vn::VarName, vi::VarInfo)

if haskey(vi, vn)
r = vi[vn]
else
r = rand(dist)
push!(vi, vn, r, dist)
end

return r, logpdf_with_trans(dist, r, istrans(vi, vn))
end

function assume(spl::SampleFromUniform, dist::Distribution, vn::VarName, vi::VarInfo)

if haskey(vi, vn)
r = vi[vn]
else
r = init(dist)
push!(vi, vn, r, dist)
end
# NOTE: The importance weight is not correctly computed here because
# r is genereated from some uniform distribution which is different from the prior

return r, logpdf_with_trans(dist, r, istrans(vi, vn))
end

function assume(spl::A,
dists::Vector{T},
vn::VarName,
var::Any,
vi::VarInfo) where {T<:Distribution, A<:Union{SampleFromPrior, SampleFromUniform}}

@assert length(dists) == 1 "Turing.assume only support vectorizing i.i.d distribution"
dist = dists[1]
n = size(var)[end]

vns = map(i -> copybyindex(vn, "[$i]"), 1:n)

if haskey(vi, vns[1])
rs = vi[vns]
else
rs = isa(spl, SampleFromUniform) ? init(dist, n) : rand(dist, n)

if isa(dist, UnivariateDistribution) || isa(dist, MatrixDistribution)
for i = 1:n
push!(vi, vns[i], rs[i], dist)
end
@assert size(var) == size(rs) "Turing.assume: variable and random number dimension unmatched"
var = rs
elseif isa(dist, MultivariateDistribution)
for i = 1:n
push!(vi, vns[i], rs[:,i], dist)
end
if isa(var, Vector)
@assert length(var) == size(rs)[2] "Turing.assume: variable and random number dimension unmatched"
for i = 1:n
var[i] = rs[:,i]
end
elseif isa(var, Matrix)
@assert size(var) == size(rs) "Turing.assume: variable and random number dimension unmatched"
var = rs
else
@error("Turing.assume: unsupported variable container"); error()
end
end
end

# acclogp!(vi, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1]))))

var, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1])))

end

@inline observe(spl::SampleFromPrior, dist, value, vi) = observe(nothing, dist, value, vi)
@inline observe(spl::SampleFromUniform, dist, value, vi) = observe(nothing, dist, value, vi)

###
# Functions for runner to compute the log joint.
###

function assume(spl::ComputeLogJointDensity,
dist::Distribution,
vn::VarName,
vi::VarInfo)

function assume(spl::ComputeLogJointDensity, dist::Distribution, vn::VarName, vi::VarInfo)
@assert haskey(vi, vn)
r = vi[vn]
return r, logpdf_with_trans(dist, r, istrans(vi, vn))
end
Expand All @@ -17,7 +119,7 @@ function assume(spl::ComputeLogJointDensity,
var,
vi::VarInfo)

@assert length(dists) == 1 "[Turing.assume] Turing only supports vectorizing iid distributions"
@assert length(dists) == 1 "[assume] Turing only supports vectorizing iid distributions"

dist = first(dist)
N = size(var)[end]
Expand Down Expand Up @@ -46,3 +148,5 @@ function assume(spl::ComputeLogJointDensity,

return var, sum(logpdf_with_trans(dist, rs, istrans(vi, first(vns))))
end

@inline observe(spl::ComputeLogJointDensity, dist, value, vi) = observe(nothing, dist, value, vi)
36 changes: 36 additions & 0 deletions test/inference/common/runners.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using Turing, Random, Bijectors
using Distributions
using Turing.RandomVariables
using Test

@testset "Common/Runners" begin
@turing_testset "ComputeLogJointDensity" begin

# Test assume with single random variable (unconstrained)
vi = VarInfo()
vn = VarName(gensym(), :x, "", 1)
dist = Normal(0,1)
r = rand(dist)
gid = Turing.Selector()
runner = Turing.ComputeLogJointDensity()

push!(vi, vn, r, dist, gid)
@test Turing.assume(runner, dist, vn, vi) == (r, logpdf(dist, r))

# Test assume with single random variable (constrained)
vi = VarInfo()
vn = VarName(gensym(), :x, "", 1)
dist = Truncated(Normal(0,1), 0, Inf)
r = link(dist, rand(dist))
gid = Turing.Selector()
runner = Turing.ComputeLogJointDensity()

push!(vi, vn, r, dist, gid)
Turing.RandomVariables.settrans!(vi, true, vn)
r_ = invlink(dist, r)
@test Turing.assume(runner, dist, vn, vi) == (r_, logpdf_with_trans(dist, r_, true))

# Test observe with single random variable

end
end

0 comments on commit 2fd3bff

Please sign in to comment.