-
Notifications
You must be signed in to change notification settings - Fork 223
/
Copy pathio.jl
152 lines (132 loc) · 4.57 KB
/
io.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#########################
# Sampler I/O Interface #
#########################
##########
# Sample #
##########
"""
    Sample

One weighted draw. `weight` is the particle weight; `value` maps variable names to
their values. Array-valued variables are expected to be stored element-wise under
indexed keys such as `Symbol("x[1]")`, `Symbol("x[1, 2]")` or `Symbol("x[1][2]")`,
which `s[:x]` reassembles.
"""
mutable struct Sample
    weight :: Float64 # particle weight
    value :: Dict{Symbol,Any}
end

# `s[:x]` rebuilds variable `x` (scalar or array) from the entries of `s.value`.
Base.getindex(s::Sample, v::Symbol) = getjuliatype(s, v)

"""
    parse_inds(inds)

Extract every integer in the index string `inds` (e.g. `"2"` or `"1, 3"`). Returns a
lone `Int` when exactly one is found, otherwise a `Tuple` of them.
"""
function parse_inds(inds)
    p_inds = [parse(Int, m.match) for m in eachmatch(r"\d+", inds)]
    return length(p_inds) == 1 ? p_inds[1] : Tuple(p_inds)
end

"""
    getjuliatype(s::Sample, v::Symbol, cached_syms=nothing)

Reassemble variable `v` from `s.value`. When no key of the form `v[...]` exists, the
value stored directly under `v` is returned. Otherwise all `v[...]` entries are placed
into a `Vector{Any}` / `Array{Any,N}` sized by the largest index seen in each dimension,
and nested keys such as `v[1][2]` are rebuilt recursively.

`cached_syms`, when given, is the collection of keys to search instead of
`keys(s.value)`; recursive calls pass the already-filtered keys so that very large
models are not rescanned at every nesting level.
"""
function getjuliatype(s::Sample, v::Symbol, cached_syms=nothing)
    prefix = string(v) * "["
    candidates = cached_syms === nothing ? keys(s.value) : cached_syms
    # Match on the `v[` prefix only, so e.g. `:x` does not pick up keys of `:xx` or `:ax`.
    syms = collect(Iterators.filter(k -> startswith(string(k), prefix), candidates))

    # No indexed entries: `v` is a plain variable such as `:x`.
    if isempty(syms)
        @assert haskey(s.value, v)
        return s.value[v]
    end

    # Per key, the index layers following the name: "x[1][2, 3]" -> ["1", "2, 3"].
    # `prefix` ends in the single-byte '[', so slicing from its last byte keeps that bracket.
    idx_comp = map(syms) do sym
        filter(!isempty, split(string(sym)[ncodeunits(prefix):end], ['[', ']']))
    end

    # Allocate the container for the outermost index layer.
    outer = map(c -> parse_inds(c[1]), idx_comp)
    if length(split(idx_comp[1][1], ',')) == 1
        sample = Vector{Any}(undef, maximum(outer))
    else
        # Element-wise maximum: plain `max` on tuples compares lexicographically and
        # under-sizes the array whenever the index grid is not fully populated.
        d = reduce((a, b) -> max.(a, b), outer)
        sample = Array{Any}(undef, d)
    end

    for (sym, comp, idx) in zip(syms, idx_comp, outer)
        if length(comp) == 1
            sample[idx...] = s.value[sym]
        elseif !isassigned(sample, idx...)
            # Nested container (e.g. x[1][2]): rebuild the inner value x[1] only once,
            # searching just the keys already known to start with `v[`.
            inner = Symbol(string(v), "[", comp[1], "]")
            sample[idx...] = getjuliatype(s, inner, syms)
        end
    end
    return sample
end
#########
# Chain #
#########
# Flattened entry names that `Chain` passes to `Chains` under the `:internals` key.
const _internal_vars = [
    "elapsed", "epsilon", "eval_num",
    "lf_eps", "lf_num", "lp",
]
"""
    Chain(w::Real, s::AbstractArray{Sample})

Collect the samples `s` into a single-chain `Chains` object: each sample becomes one
iteration (row) and each flattened entry name one parameter (column). Entries absent
from a sample are stored as `missing`. `w` is forwarded as `evidence`, and the names in
`_internal_vars` are passed under the `:internals` key.
"""
function Chain(w::Real, s::AbstractArray{Sample})
    samples = vec(flatten.(s))
    # Union of entry names over all samples (order follows Dict iteration order).
    names_ = collect(mapreduce(keys, union, samples))
    # One column per sample; `get` fills in `missing` for names a sample lacks.
    columns = map(samples) do smp
        Union{Missing, Float64}[get(smp, k, missing) for k in names_]
    end
    # `reduce(hcat, ...)` hits Base's vector-of-vectors method instead of building
    # intermediate matrices; `permutedims` then yields rows = samples, cols = names.
    values_ = permutedims(reduce(hcat, columns))
    chn = Chains(
        reshape(values_, size(values_, 1), size(values_, 2), 1),
        names_,
        Dict(:internals => _internal_vars),
        evidence = w
    )
    return chn
end
# Replacement for the deprecated `ind2sub`: map linear index `i` to its tuple of
# Cartesian subscripts, where `v` is a size tuple (or anything `CartesianIndices`
# accepts), e.g. `ind2sub((2, 3), 4) == (2, 2)`.
function ind2sub(v, i)
    cartesian = CartesianIndices(v)[i]
    return Tuple(cartesian)
end
# Flatten every variable in `s.value` into a `Dict` mapping each scalar entry name
# (built by the 4-argument `flatten`) to its `Float64` value.
function flatten(s::Sample)
    vals = Float64[]
    names = AbstractString[]
    foreach(pairs(s.value)) do (key, val)
        flatten(names, vals, string(key), val)
    end
    return Dict(zip(names, vals))
end
"""
    flatten(names, value::Array{Float64}, k::String, v)

Append the scalar contents of `v` to `value` and their names to `names`, in place.
A number is recorded under `k` itself. An array's elements are recorded as
`k * "[i, j, ...]"` using 1-based Cartesian subscripts, and nested arrays recurse with
chained brackets (e.g. `"x[1][2]"`). Any other element type raises an error.
Returns `nothing`.
"""
function flatten(names, value :: Array{Float64}, k :: String, v)
    if isa(v, Number)
        push!(value, v)
        push!(names, k)
    elseif isa(v, Array)
        # Subscripts come from the array's size, so they are always 1-based.
        subs = CartesianIndices(size(v))
        for i in eachindex(v)
            elem = v[i]
            name = k * _index_suffix(Tuple(subs[i]))
            if isa(elem, Number)
                push!(value, Float64(elem))
                push!(names, name)
            elseif isa(elem, Array)
                # Use the bracketed name so nested entries read back as e.g. "x[1][2]".
                flatten(names, value, name, elem)
            else
                error("Unknown var type: typeof($name)=$(typeof(elem))")
            end
        end
    else
        error("Unknown var type: typeof($v)=$(typeof(v))")
    end
    return
end

# Format a subscript tuple as a bracketed index: (1,) -> "[1]", (1, 2) -> "[1, 2]".
_index_suffix(sub) = "[" * join(sub, ", ") * "]"
"""
    save(c::Chains, spl::Sampler, model, vi, samples)

Store the sampling state (`spl`, `model`, a deep copy of `vi`, and `samples`) in the
`info` of `c` via `setinfo`, so that `resume` can later read `:model` and `:spl` back.
"""
function save(c::Chains, spl::Sampler, model, vi, samples)
    # deepcopy so later mutation of `vi` does not alter the saved snapshot.
    state = (spl = spl, model = model, vi = deepcopy(vi), samples = samples)
    # `merge` prefers its last argument: the state being saved must win over any stale
    # :spl/:model/:vi/:samples entries already present in `c.info`.
    return setinfo(c, merge(c.info, state))
end
"""
    resume(c::Chains, n_iter::Int)

Continue sampling from the state saved in `c.info`: the stored `:model` and the
algorithm of the stored `:spl` are passed to `sample`, together with
`resume_from = c` and `reuse_spl_n = n_iter`. Throws an `AssertionError` when
`c.info` is empty.
"""
function resume(c::Chains, n_iter::Int)
    isempty(c.info) &&
        throw(AssertionError("[Turing] cannot resume from a chain without state info"))
    saved = c.info
    # The positional algorithm argument is actually not used.
    return sample(saved[:model], saved[:spl].alg;
                  resume_from = c, reuse_spl_n = n_iter)
end