Skip to content

Commit

Permalink
Vectorize construct for uni
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed May 13, 2017
1 parent a1afe04 commit af2b6c6
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
12 changes: 6 additions & 6 deletions src/core/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,16 @@ typealias VarView Union{Int,UnitRange,Vector{Int},Vector{UnitRange}}
getidx(vi::VarInfo, vn::VarName) = vi.idcs[vn]

getrange(vi::VarInfo, vn::VarName) = vi.ranges[getidx(vi, vn)]
getranges(vi::VarInfo, vns::Vector{VarName}) = union(map(vn -> getrange(vi, vn), vns)...)

getval(vi::VarInfo, vn::VarName) = vi.vals[end][getrange(vi, vn)]
setval!(vi::VarInfo, val, vn::VarName) = vi.vals[end][getrange(vi, vn)] = val

getval(vi::VarInfo, vns::Vector{VarName}) = vi.vals[end][getranges(vi, vns)]

getval(vi::VarInfo, view::VarView) = vi.vals[end][view]
setval!(vi::VarInfo, val::Any, view::VarView) = vi.vals[end][view] = val
setval!(vi::VarInfo, val::Any, view::Vector{UnitRange}) = map(v->vi.vals[end][v] = val, view)
setval!(vi::VarInfo, val::Any, view::Vector{UnitRange}) = map(v -> vi.vals[end][v] = val, view)

getall(vi::VarInfo) = vi.vals[end]
setall!(vi::VarInfo, val::Any) = vi.vals[end] = val
Expand Down Expand Up @@ -142,11 +145,8 @@ end
Base.getindex(vi::VarInfo, vns::Vector{VarName}) = begin
@assert haskey(vi, vns[1]) "[Turing] attempted to replay unexisting variables in VarInfo"
dist = getdist(vi, vns[1])
if istrans(vi, vns[1])
[reconstruct(dist, getval(vi, vn)) for vn in vns]
else
[invlink(dist, reconstruct(dist, getval(vi, vn))) for vn in vns]
end
rs = reconstruct(dist, getval(vi, vns))
rs = istrans(vi, vns[1]) ? invlink(dist, rs) : rs
end

# NOTE: vi[view] will just return what insdie vi (no transformations applied)
Expand Down
8 changes: 7 additions & 1 deletion src/samplers/support/helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@ vectorize{T<:Real}(d::MatrixDistribution, r::Matrix{T}) = Vector{Real}(vec
# Note this is not the case for MultivariateDistribution so I guess this might be lack of
# support for some types related to matrices (like PDMat).
reconstruct(d::Distribution, val::Vector) = reconstruct(d, val, typeof(val[1]))
reconstruct(d::UnivariateDistribution, val::Vector, T::Type) = T(val[1])
reconstruct(d::UnivariateDistribution, val::Vector, T::Type) = begin
if length(val) == 1
T(val[1])
else
Vector{T}(val)
end
end
reconstruct(d::MultivariateDistribution, val::Vector, T::Type) = Vector{T}(val)
reconstruct(d::MatrixDistribution, val::Vector, T::Type) = Array{T, 2}(reshape(val, size(d)...))

Expand Down
8 changes: 4 additions & 4 deletions test/sampler.jl/vec_assume.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ setchunksize(100)
# Test for vectorize UnivariateDistribution
@model vdemo() = begin
x = Vector{Real}(100)
# x ~ [Normal(0, sqrt(4))]
for i = 1:100
x[i] ~ Normal(0, sqrt(4))
end
x ~ [Normal(0, sqrt(4))]
# for i = 1:100
# x[i] ~ Normal(0, sqrt(4))
# end
end

alg = HMC(1000, 0.2, 4)
Expand Down

0 comments on commit af2b6c6

Please sign in to comment.