-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
240 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
using JuliaBUGS, MacroTools | ||
using JuliaBUGS: loop_fission_helper | ||
|
||
expr, L, R = JuliaBUGS.sort_expressions(JuliaBUGS.BUGSExamples.rats.model_def) | ||
data = JuliaBUGS.BUGSExamples.rats.data | ||
|
||
# lossless representation of the model Expr | ||
|
||
abstract type BUGSExpr end | ||
|
||
struct ForExpr{T} <: BUGSExpr | ||
loop_vars | ||
loop_bounds | ||
body_expr::T | ||
end | ||
|
||
struct LogicalAssignment <: BUGSExpr | ||
lhs | ||
rhs | ||
end | ||
|
||
struct StochasticAssignment <: BUGSExpr | ||
lhs | ||
rhs | ||
end | ||
|
||
function Assignment(expr) | ||
if Meta.isexpr(expr, :(~)) | ||
return StochasticAssignment(expr.args...) | ||
else | ||
return LogicalAssignment(expr.args...) | ||
end | ||
end | ||
|
||
function ForExpr(expr) | ||
processed_expr = only(loop_fission_helper(expr)) | ||
T = Union{Int, Symbol, Expr} | ||
loop_vars = Symbol[] | ||
loop_bounds = Tuple{T,T}[] | ||
while true | ||
loop_var, l, h, remaining = processed_expr | ||
push!(loop_vars, loop_var) | ||
push!(loop_bounds, (l, h)) | ||
if remaining isa Expr | ||
return ForExpr(loop_vars, loop_bounds, Assignment(remaining)) | ||
else | ||
processed_expr = remaining | ||
end | ||
end | ||
end | ||
|
||
# 1. determine what variables are in the model, if they are scalar or array | ||
# if array, then number of dimensions | ||
# 2. determine the size of the array: this can be done with pure static analysis | ||
# size table: | ||
# mu (i, false) | ||
# (j, false) | ||
# x (5, true) | ||
# 3. constant propagation: only care about the logical expressions | ||
# 4. type inference: let's ban x[2.5-1.5] | ||
# - Sources of type information | ||
# - data | ||
# - constants | ||
# - indexing | ||
# - distributions (datatype, but also support) | ||
|
||
# return a mapping from a Symbol representing variable name to the number of dimensions | ||
function free_variables(expr::Expr) | ||
vars = Dict() | ||
MacroTools.prewalk(expr) do sub_expr | ||
if MacroTools.@capture(sub_expr, f_(args__)) | ||
for arg in args | ||
if arg isa Symbol | ||
if !haskey(vars, arg) | ||
vars[arg] = 0 | ||
elseif vars[arg] != 0 | ||
error("$arg is used as both scalar and array") | ||
end | ||
end | ||
end | ||
elseif MacroTools.@capture(sub_expr, x_[idx__]) | ||
if !haskey(vars, x) | ||
vars[x] = length(idx) | ||
elseif vars[x] != length(idx) | ||
error("variable $x is used in multiple dimensions") | ||
end | ||
end | ||
sub_expr | ||
end | ||
return vars | ||
end | ||
|
||
fv = free_variables(expr) | ||
|
||
function size_table(fv, data) | ||
st = Dict() | ||
for k in keys(data) | ||
if data[k] isa Number | ||
st[k] = [true, Set()] | ||
else | ||
s = [] | ||
for i in 1:ndims(data[k]) | ||
push!(s, Any[true, Set{Any}(size(data[k])[i])]) | ||
end | ||
st[k] = s | ||
end | ||
end | ||
for (k, v) in fv | ||
if k in keys(data) | ||
@assert v == ndims(data[k]) | ||
elseif v == 0 | ||
st[k] = [true, Set()] | ||
else | ||
s = [] | ||
for i in 1:v | ||
push!(s, Any[false, Set{Any}(-1)]) | ||
end | ||
st[k] = s | ||
end | ||
end | ||
return st | ||
end | ||
|
||
st = size_table(fv, data) | ||
|
||
function collect_size_info(le, st) | ||
expr = le.body_expr.lhs # only care about the lhs | ||
st = deepcopy(st) | ||
if MacroTools.@capture(expr, x_[idx__]) | ||
for (i, id) in enumerate(idx) | ||
if id isa Number # good | ||
elseif id isa Symbol | ||
id = le.loop_bounds[findfirst(x->x==id, le.loop_vars)][2] | ||
else # Expr | ||
id = MacroTools.postwalk(id) do se | ||
if MacroTools.@capture(se, l_:h_) | ||
h | ||
elseif se in le.loop_vars | ||
le.loop_bounds[findfirst(x->x==se, le.loop_vars)][2] | ||
else | ||
se | ||
end | ||
end | ||
end | ||
push!(st[x][i][2], id) | ||
end | ||
end | ||
return st | ||
end | ||
|
||
collect_size_info(ForExpr(L[3]), st) | ||
|
||
test_e = @bugs begin | ||
for i in 1:N | ||
for j in 1:M | ||
Y[i:i+1, j:j+1] ~ Normal(mu[i, j], sigma) | ||
end | ||
end | ||
end | ||
|
||
# loop over for expressions | ||
st_c = deepcopy(st) | ||
for ll in [L..., R...] | ||
if Meta.isexpr(ll, :for) | ||
le = ForExpr(ll) | ||
st_c = collect_size_info(le, st_c) | ||
end | ||
end | ||
st_c | ||
|
||
# determine the size of the array | ||
function narrowing(st, data) | ||
st = deepcopy(st) | ||
concrete_st = Dict() | ||
for (k, v) in st | ||
if v[1] == true | ||
concrete_st[k] = 0 | ||
else | ||
ss = [] | ||
for i in 1:length(v) | ||
possible_sizes = map(Base.Fix2(JuliaBUGS.evaluate, data), collect(v[i][2])) | ||
@assert all(map(x->x isa Number, possible_sizes)) "some size can't be evaluated with values from data" | ||
if v[i][1] == true | ||
@assert all(possible_sizes .== possible_sizes[1]) | ||
push!(ss, possible_sizes[1]) | ||
else | ||
push!(ss, maximum(possible_sizes)) | ||
end | ||
end | ||
concrete_st[k] = ss | ||
end | ||
end | ||
return concrete_st | ||
end | ||
|
||
narrowing(st_c, data) | ||
|
||
function create_internal_values(concrete_st, data) | ||
values = Dict() | ||
for k in keys(data) | ||
v = data[k] | ||
if v isa Number | ||
values[k] = v | ||
else | ||
missing_indices = findall(ismissing, v) | ||
if isempty(missing_indices) | ||
values[k] = v | ||
else | ||
sparse_v = zeros(size(v)) | ||
sparse_v[missing_indices] = 1 | ||
values[k] = (v, sparse(sparse_v)) | ||
end | ||
end | ||
end | ||
for (k, v) in concrete_st | ||
if haskey(values, k) | ||
continue | ||
elseif v == 0 | ||
values[k] = missing | ||
else | ||
values[k] = fill(missing, v...) | ||
end | ||
end | ||
return values | ||
end | ||
|
||
create_internal_values(narrowing(st_c, data), data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters