Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Small simplification of compiler #221

Closed
wants to merge 13 commits into from
Closed
2 changes: 1 addition & 1 deletion .github/workflows/IntegrationTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: 1
version: 1.5
arch: x64
- uses: julia-actions/julia-buildpkg@latest
- name: Clone Downstream
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.10.11"
version = "0.10.12"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
94 changes: 38 additions & 56 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function model(mod, linenumbernode, expr, warn)

# Generate main body
modelinfo[:body] = generate_mainbody(
mod, modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn
mod, modelinfo[:modeldef][:body], warn
)

return build_output(modelinfo, linenumbernode)
Expand Down Expand Up @@ -155,92 +155,84 @@ function build_model_info(input_expr)
end

"""
generate_mainbody(mod, expr, args, warn)
generate_mainbody(mod, expr, warn)

Generate the body of the main evaluation function from expression `expr` and arguments
`args`.

If `warn` is true, a warning is displayed if internal variables are used in the model
definition.
"""
generate_mainbody(mod, expr, args, warn) = generate_mainbody!(mod, Symbol[], expr, args, warn)
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)

generate_mainbody!(mod, found, x, args, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, args, warn)
generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
if warn && sym in INTERNALNAMES && sym ∉ found
@warn "you are using the internal variable `$(sym)`"
push!(found, sym)
end
return sym
end
function generate_mainbody!(mod, found, expr::Expr, args, warn)
function generate_mainbody!(mod, found, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), args, warn)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return generate_dot_tilde(generate_mainbody!(mod, found, L, args, warn),
generate_mainbody!(mod, found, R, args, warn),
args) |> Base.remove_linenums!
return generate_dot_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
) |> Base.remove_linenums!
end

# Modify tilde operators.
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
return generate_tilde(generate_mainbody!(mod, found, L, args, warn),
generate_mainbody!(mod, found, R, args, warn),
args) |> Base.remove_linenums!
return generate_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
) |> Base.remove_linenums!
end

return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, args, warn), expr.args)...)
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
end



"""
generate_tilde(left, right, args)
generate_tilde(left, right)

Generate an `observe` expression for data variables and `assume` expression for parameter
variables.
"""
function generate_tilde(left, right, args)
function generate_tilde(left, right)
@gensym tmpright
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
|| throw(ArgumentError($DISTMSG)))]

if left isa Symbol || left isa Expr
@gensym out vn inds
@gensym out vn inds isassumption
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))

# It can only be an observation if the LHS is an argument of the model
if vsym(left) in args
@gensym isassumption
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume)(
_rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
else
$(DynamicPPL.tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
end
end
end

return quote
$(top...)
$left = $(DynamicPPL.tilde_assume)(_rng, _context, _sampler, $tmpright, $vn,
$inds, _varinfo)
$isassumption = $(DynamicPPL.isassumption(left))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW (I know not part of this PR 🙂): Why do we actually assign the output of isassumption to a variable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep:)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh haha, I read do we actually assign the output of isassumption to a variable?, didn't see the why.

It's because isassumption generates a larger if-statement which returns a Bool, so we can't e.g. do if $(DynamicPPL.isassumption(left)).

if $isassumption
$left = $(DynamicPPL.tilde_assume)(
_rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
else
$(DynamicPPL.tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
end
end
end

Expand All @@ -252,40 +244,30 @@ function generate_tilde(left, right, args)
end

"""
generate_dot_tilde(left, right, args)
generate_dot_tilde(left, right)

Generate the expression that replaces `left .~ right` in the model body.
"""
function generate_dot_tilde(left, right, args)
function generate_dot_tilde(left, right)
@gensym tmpright
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
|| throw(ArgumentError($DISTMSG)))]

if left isa Symbol || left isa Expr
@gensym out vn inds
@gensym out vn inds isassumption
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))

# It can only be an observation if the LHS is an argument of the model
if vsym(left) in args
@gensym isassumption
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume)(
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
else
$(DynamicPPL.dot_tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
end
end
end

return quote
$(top...)
$left .= $(DynamicPPL.dot_tilde_assume)(
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
$isassumption = $(DynamicPPL.isassumption(left)) || $left === missing
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume)(
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
else
$(DynamicPPL.dot_tilde_observe)(
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
end
end
end

Expand Down