Skip to content

Commit

Permalink
Prototype for new compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Oct 8, 2023
1 parent 1b46691 commit a334e63
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 14 deletions.
5 changes: 2 additions & 3 deletions src/compiler_pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,8 @@ end
function assignment!(pass::CollectVariables, expr::Expr, env)
lhs_expr, rhs_expr = expr.args[1:2]

v = find_variables_on_lhs(
Meta.isexpr(lhs_expr, :call) ? lhs_expr.args[2] : lhs_expr, env
)
@assert !Meta.isexpr(lhs_expr, :call) "Link functions should already be transformed, but not in $expr"
v = find_variables_on_lhs(lhs_expr, env)
if !isa(v, Scalar)
check_idxs(v.name, v.indices, env)
end
Expand Down
227 changes: 227 additions & 0 deletions src/new_passes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
using JuliaBUGS, MacroTools
using JuliaBUGS: loop_fission_helper

expr, L, R = JuliaBUGS.sort_expressions(JuliaBUGS.BUGSExamples.rats.model_def)
data = JuliaBUGS.BUGSExamples.rats.data

# lossless representation of the model Expr

abstract type BUGSExpr end

struct ForExpr{T} <: BUGSExpr
loop_vars
loop_bounds
body_expr::T
end

struct LogicalAssignment <: BUGSExpr
lhs
rhs
end

struct StochasticAssignment <: BUGSExpr
lhs
rhs
end

function Assignment(expr)
if Meta.isexpr(expr, :(~))
return StochasticAssignment(expr.args...)
else
return LogicalAssignment(expr.args...)
end
end

function ForExpr(expr)
processed_expr = only(loop_fission_helper(expr))
T = Union{Int, Symbol, Expr}
loop_vars = Symbol[]
loop_bounds = Tuple{T,T}[]
while true
loop_var, l, h, remaining = processed_expr
push!(loop_vars, loop_var)
push!(loop_bounds, (l, h))
if remaining isa Expr
return ForExpr(loop_vars, loop_bounds, Assignment(remaining))
else
processed_expr = remaining
end
end
end

# 1. determine what variables are in the model, if they are scalar or array
# if array, then number of dimensions
# 2. determine the size of the array: this can be done with pure static analysis
# size table:
# mu (i, false)
# (j, false)
# x (5, true)
# 3. constant propagation: only care about the logical expressions
# 4. type inference: let's ban x[2.5-1.5]
# - Sources of type information
# - data
# - constants
# - indexing
# - distributions (datatype, but also support)

# return a mapping from a Symbol representing variable name to the number of dimensions
function free_variables(expr::Expr)
vars = Dict()
MacroTools.prewalk(expr) do sub_expr
if MacroTools.@capture(sub_expr, f_(args__))
for arg in args
if arg isa Symbol
if !haskey(vars, arg)
vars[arg] = 0
elseif vars[arg] != 0
error("$arg is used as both scalar and array")
end
end
end
elseif MacroTools.@capture(sub_expr, x_[idx__])
if !haskey(vars, x)
vars[x] = length(idx)
elseif vars[x] != length(idx)
error("variable $x is used in multiple dimensions")
end
end
sub_expr
end
return vars
end

fv = free_variables(expr)

function size_table(fv, data)
st = Dict()
for k in keys(data)
if data[k] isa Number
st[k] = [true, Set()]
else
s = []
for i in 1:ndims(data[k])
push!(s, Any[true, Set{Any}(size(data[k])[i])])
end
st[k] = s
end
end
for (k, v) in fv
if k in keys(data)
@assert v == ndims(data[k])
elseif v == 0
st[k] = [true, Set()]
else
s = []
for i in 1:v
push!(s, Any[false, Set{Any}(-1)])
end
st[k] = s
end
end
return st
end

st = size_table(fv, data)

function collect_size_info(le, st)
expr = le.body_expr.lhs # only care about the lhs
st = deepcopy(st)
if MacroTools.@capture(expr, x_[idx__])
for (i, id) in enumerate(idx)
if id isa Number # good
elseif id isa Symbol
id = le.loop_bounds[findfirst(x->x==id, le.loop_vars)][2]
else # Expr
id = MacroTools.postwalk(id) do se
if MacroTools.@capture(se, l_:h_)
h
elseif se in le.loop_vars
le.loop_bounds[findfirst(x->x==se, le.loop_vars)][2]
else
se
end
end
end
push!(st[x][i][2], id)
end
end
return st
end

collect_size_info(ForExpr(L[3]), st)

test_e = @bugs begin
for i in 1:N
for j in 1:M
Y[i:i+1, j:j+1] ~ Normal(mu[i, j], sigma)
end
end
end

# loop over for expressions
st_c = deepcopy(st)
for ll in [L..., R...]
if Meta.isexpr(ll, :for)
le = ForExpr(ll)
st_c = collect_size_info(le, st_c)
end
end
st_c

# determine the size of the array
function narrowing(st, data)
st = deepcopy(st)
concrete_st = Dict()
for (k, v) in st
if v[1] == true
concrete_st[k] = 0
else
ss = []
for i in 1:length(v)
possible_sizes = map(Base.Fix2(JuliaBUGS.evaluate, data), collect(v[i][2]))
@assert all(map(x->x isa Number, possible_sizes)) "some size can't be evaluated with values from data"
if v[i][1] == true
@assert all(possible_sizes .== possible_sizes[1])
push!(ss, possible_sizes[1])
else
push!(ss, maximum(possible_sizes))
end
end
concrete_st[k] = ss
end
end
return concrete_st
end

narrowing(st_c, data)

function create_internal_values(concrete_st, data)
values = Dict()
for k in keys(data)
v = data[k]
if v isa Number
values[k] = v
else
missing_indices = findall(ismissing, v)
if isempty(missing_indices)
values[k] = v
else
sparse_v = zeros(size(v))
sparse_v[missing_indices] = 1
values[k] = (v, sparse(sparse_v))
end
end
end
for (k, v) in concrete_st
if haskey(values, k)
continue
elseif v == 0
values[k] = missing
else
values[k] = fill(missing, v...)
end
end
return values
end

create_internal_values(narrowing(st_c, data), data)
22 changes: 11 additions & 11 deletions src/parser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,7 @@ function loop_fission(expr::Expr)
return new_expr
end

# turn a for loop into tuple of form (loop_var, l, h, body) where body can be another loop
function loop_fission_helper(expr::Expr)
loops = []
MacroTools.prewalk(expr) do sub_expr
Expand Down Expand Up @@ -917,11 +918,7 @@ function generate_loop_expr(loop)
if !isa(remaining, Expr)
remaining = generate_loop_expr(remaining)
end
return MacroTools.prewalk(rmlines, :(
for $loop_var in ($l):($h)
$remaining
end
))
return Expr(:for, Expr(:(=), loop_var, Expr(:(:), l, h)), Expr(:block, remaining))
end

# TODO: @bugs will change Expr(:call, :(:), 1, :N) to Expr(:(:), 1, :(:), :N), beware when write tests
Expand All @@ -932,8 +929,7 @@ Returns the `loop_fission`ed expression with logical assignments in front of all
For all logical and stochastic assignments, simple assignments appear before for loops.
"""
function sort_expressions(expr)
expr = loop_fission(expr)
exprs = expr.args
# only works for fissioned expressions
function is_logical(e::Expr)
if Meta.isexpr(e, :~)
return false
Expand All @@ -946,6 +942,9 @@ function sort_expressions(expr)
return is_logical(e.args[2])
end
end

expr = loop_fission(expr)
exprs = expr.args
l_args = Expr[]
s_args = Expr[]
for e in exprs
Expand All @@ -955,8 +954,8 @@ function sort_expressions(expr)
push!(s_args, e)
end
end
return Expr(:block, l_args..., s_args...)
end
return Expr(:block, l_args..., s_args...), l_args, s_args
end

function check_idxs(expr::Expr)
return MacroTools.prewalk(expr) do sub_expr
Expand All @@ -975,9 +974,10 @@ function check_idxs(expr::Expr)
end
end

# This follow code are from early days of the parser, which uses a Julia String macro to
# The following code is from early days of the parser, which uses a Julia String macro to
# transform BUGS program into Julia program
# We have since implemented a new parser, see `parser.jl`
# We have since implemented a new parser that's a bit more robust, see `ProcessState` and
# `to_julia_program` above

macro _bugsmodel_str(s::String)
# Convert and wrap the whole thing in a block for parsing
Expand Down

0 comments on commit a334e63

Please sign in to comment.