Skip to content

Commit

Permalink
Work in progress.
Browse files Browse the repository at this point in the history
  • Loading branch information
trappmartin committed Apr 13, 2019
1 parent 2fd3bff commit c4aa64c
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 214 deletions.
11 changes: 9 additions & 2 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,16 @@ hash(s::Selector) = hash(s.gid)

abstract type AbstractRunner end
#abstract type AbstractSampler end
abstract type SampleFromDistribution end

struct SampleFromPrior <: SampleFromDistribution end
_rand(dist::Distribution) = rand(dist)
_rand(dist::distribution, n::Int) = rand(dist, n)

struct SampleFromUniform <: SampleFromDistribution end
_rand(dist::Distribution) = init(dist)
_rand(dist::Distributionm, n::Int) = init(dist, n)

struct SampleFromUniform <: AbstractRunner end
struct SampleFromPrior <: AbstractRunner end
struct ComputeLogJointDensity <: AbstractRunner end
struct ComputeLogDensity <: AbstractRunner end
struct ParticleFiltering <: AbstractRunner end
Expand Down
110 changes: 12 additions & 98 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,114 +102,28 @@ include("gibbs.jl")
require_gradient(spl::Sampler) = false
require_particles(spl::Sampler) = false

assume(spl::Sampler, dist::Distribution) =
error("Turing.assume: unmanaged inference algorithm: $(typeof(spl))")

observe(spl::Sampler, weight::Float64) =
error("Turing.observe: unmanaged inference algorithm: $(typeof(spl))")

## Default definitions for assume, observe, when sampler = nothing.
function assume(spl::A,
dist::Distribution,
vn::VarName,
vi::VarInfo) where {A<:Union{SampleFromPrior, SampleFromUniform}}

if haskey(vi, vn)
r = vi[vn]
else
r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
push!(vi, vn, r, dist)
end
# NOTE: The importance weight is not correctly computed here because
# r is genereated from some uniform distribution which is different from the prior
# acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
# ############################# #
# Assume and observe functions. #
# ############################# #

r, logpdf_with_trans(dist, r, istrans(vi, vn))
# Those functions have to be implemented for each new sampler.

function assume(spl::Sampler, dist::Distribution, vn::VarName, vi::VarInfo)
@error "[assume]: Unmanaged inference algorithm: $(typeof(spl))"
end

function assume(spl::A,
dists::Vector{T},
vn::VarName,
var::Any,
vi::VarInfo) where {T<:Distribution, A<:Union{SampleFromPrior, SampleFromUniform}}

@assert length(dists) == 1 "Turing.assume only support vectorizing i.i.d distribution"
dist = dists[1]
n = size(var)[end]

vns = map(i -> copybyindex(vn, "[$i]"), 1:n)

if haskey(vi, vns[1])
rs = vi[vns]
else
rs = isa(spl, SampleFromUniform) ? init(dist, n) : rand(dist, n)

if isa(dist, UnivariateDistribution) || isa(dist, MatrixDistribution)
for i = 1:n
push!(vi, vns[i], rs[i], dist)
end
@assert size(var) == size(rs) "Turing.assume: variable and random number dimension unmatched"
var = rs
elseif isa(dist, MultivariateDistribution)
for i = 1:n
push!(vi, vns[i], rs[:,i], dist)
end
if isa(var, Vector)
@assert length(var) == size(rs)[2] "Turing.assume: variable and random number dimension unmatched"
for i = 1:n
var[i] = rs[:,i]
end
elseif isa(var, Matrix)
@assert size(var) == size(rs) "Turing.assume: variable and random number dimension unmatched"
var = rs
else
@error("Turing.assume: unsupported variable container"); error()
end
end
end

# acclogp!(vi, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1]))))

var, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1])))

function assume(spl::Sampler, dists::Vector{<:Distribution}, vn::VarName, var, vi::VarInfo)
@error "[assume]: Unmanaged inference algorithm: $(typeof(spl))"
end


@inline observe(::Nothing, dist, value, vi) = observe(SampleFromPrior(), dist, value, vi)

function observe(spl::A,
dist::Distribution,
value::Any,
vi::VarInfo) where {A<:Union{SampleFromPrior, SampleFromUniform}}

vi.num_produce += one(vi.num_produce)
Turing.DEBUG && @debug "dist = $dist"
Turing.DEBUG && @debug "value = $value"

# acclogp!(vi, logpdf(dist, value))
return logpdf(dist, value)
function observe(spl::Sampler, dist::Distribution, value, vi::VarInfo)
@error "[observe]: Unmanaged inference algorithm: $(typeof(spl))"
end

function observe(spl::A,
dists::Vector{T},
value::Any,
vi::VarInfo) where {T<:Distribution, A<:Union{SampleFromPrior, SampleFromUniform}}

@assert length(dists) == 1 "Turing.observe only support vectorizing i.i.d distribution"
dist = dists[1]
@assert isa(dist, UnivariateDistribution) || isa(dist, MultivariateDistribution) "Turing.observe: vectorizing matrix distribution is not supported"
if isa(dist, UnivariateDistribution) # only univariate distributions support broadcast operation (logpdf.) by Distributions.jl
# acclogp!(vi, sum(logpdf.(Ref(dist), value)))
sum(logpdf.(Ref(dist), value))
else
# acclogp!(vi, sum(logpdf(dist, value)))
sum(logpdf(dist, value))
end

function observe(spl::Sampler, dists::Vector{<:Distribution}, values, vi::VarInfo)
@error "[observe]: Unmanaged inference algorithm: $(typeof(spl))"
end


##############
# Utilities #
##############
Expand Down
46 changes: 5 additions & 41 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,50 +263,14 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
end

function assume(spl::Sampler{<:Hamiltonian}, dist::Distribution, vn::VarName, vi::VarInfo)
Turing.DEBUG && @debug "assuming..."
# Why do we need this here and not if dist is a vector of dists?
updategid!(vi, vn, spl)
r = vi[vn]
# acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
# r
Turing.DEBUG && @debug "dist = $dist"
Turing.DEBUG && @debug "vn = $vn"
Turing.DEBUG && @debug "r = $r" "typeof(r)=$(typeof(r))"
r, logpdf_with_trans(dist, r, istrans(vi, vn))
return assume(ComputeLogJointDensity(), dist, vn, vi)
end

function assume(spl::Sampler{<:Hamiltonian}, dists::Vector{<:Distribution}, vn::VarName, var::Any, vi::VarInfo)
@assert length(dists) == 1 "[observe] Turing only support vectorizing i.i.d distribution"
dist = dists[1]
n = size(var)[end]

vns = map(i -> copybyindex(vn, "[$i]"), 1:n)

rs = vi[vns] # NOTE: inside Turing the Julia conversion should be sticked to

# acclogp!(vi, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1]))))

if isa(dist, UnivariateDistribution) || isa(dist, MatrixDistribution)
@assert size(var) == size(rs) "Turing.assume variable and random number dimension unmatched"
var = rs
elseif isa(dist, MultivariateDistribution)
if isa(var, Vector)
@assert length(var) == size(rs)[2] "Turing.assume variable and random number dimension unmatched"
for i = 1:n
var[i] = rs[:,i]
end
elseif isa(var, Matrix)
@assert size(var) == size(rs) "Turing.assume variable and random number dimension unmatched"
var = rs
else
error("[Turing] unsupported variable container")
end
end

var, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1])))
return assume(ComputeLogJointDensity(), dists, vn, var, vi)
end

observe(spl::Sampler{<:Hamiltonian}, d::Distribution, value::Any, vi::VarInfo) =
observe(nothing, d, value, vi)

observe(spl::Sampler{<:Hamiltonian}, ds::Vector{<:Distribution}, value::Any, vi::VarInfo) =
observe(nothing, ds, value, vi)
observe(spl::Sampler{<:Hamiltonian}, d, value, vi::VarInfo) =
observe(ComputeLogJointDensity(), d, value, vi)
Loading

0 comments on commit c4aa64c

Please sign in to comment.