diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index a15cf6e8e..28ed3570a 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -24,7 +24,7 @@ jobs: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 with: - version: 1 + version: 1.5 arch: x64 - uses: julia-actions/julia-buildpkg@latest - name: Clone Downstream diff --git a/Project.toml b/Project.toml index 4b5926753..ec9ef4e96 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.10.11" +version = "0.10.12" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/compiler.jl b/src/compiler.jl index 1de89ef3d..a70401bfa 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -72,7 +72,7 @@ function model(mod, linenumbernode, expr, warn) # Generate main body modelinfo[:body] = generate_mainbody( - mod, modelinfo[:modeldef][:body], modelinfo[:allargs_syms], warn + mod, modelinfo[:modeldef][:body], warn ) return build_output(modelinfo, linenumbernode) @@ -155,7 +155,7 @@ function build_model_info(input_expr) end """ - generate_mainbody(mod, expr, args, warn) + generate_mainbody(mod, expr, warn) Generate the body of the main evaluation function from expression `expr` and arguments `args`. @@ -163,84 +163,76 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(mod, expr, args, warn) = generate_mainbody!(mod, Symbol[], expr, args, warn) +generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn) -generate_mainbody!(mod, found, x, args, warn) = x -function generate_mainbody!(mod, found, sym::Symbol, args, warn) +generate_mainbody!(mod, found, x, warn) = x +function generate_mainbody!(mod, found, sym::Symbol, warn) if warn && sym in INTERNALNAMES && sym ∉ found @warn "you are using the internal variable `$(sym)`" push!(found, sym) end return sym end -function generate_mainbody!(mod, found, expr::Expr, args, warn) +function generate_mainbody!(mod, found, expr::Expr, warn) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), args, warn) + return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) end # Modify dotted tilde operators. args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde - return generate_dot_tilde(generate_mainbody!(mod, found, L, args, warn), - generate_mainbody!(mod, found, R, args, warn), - args) |> Base.remove_linenums! + return generate_dot_tilde( + generate_mainbody!(mod, found, L, warn), + generate_mainbody!(mod, found, R, warn), + ) |> Base.remove_linenums! end # Modify tilde operators. args_tilde = getargs_tilde(expr) if args_tilde !== nothing L, R = args_tilde - return generate_tilde(generate_mainbody!(mod, found, L, args, warn), - generate_mainbody!(mod, found, R, args, warn), - args) |> Base.remove_linenums! + return generate_tilde( + generate_mainbody!(mod, found, L, warn), + generate_mainbody!(mod, found, R, warn), + ) |> Base.remove_linenums! end - return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, args, warn), expr.args)...) + return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) end """ - generate_tilde(left, right, args) + generate_tilde(left, right) Generate an `observe` expression for data variables and `assume` expression for parameter variables. """ -function generate_tilde(left, right, args) +function generate_tilde(left, right) @gensym tmpright top = [:($tmpright = $right), :($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}} || throw(ArgumentError($DISTMSG)))] if left isa Symbol || left isa Expr - @gensym out vn inds + @gensym out vn inds isassumption push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left)))) - # It can only be an observation if the LHS is an argument of the model - if vsym(left) in args - @gensym isassumption - return quote - $(top...) - $isassumption = $(DynamicPPL.isassumption(left)) - if $isassumption - $left = $(DynamicPPL.tilde_assume)( - _rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo) - else - $(DynamicPPL.tilde_observe)( - _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo) - end - end - end - return quote $(top...) - $left = $(DynamicPPL.tilde_assume)(_rng, _context, _sampler, $tmpright, $vn, - $inds, _varinfo) + $isassumption = $(DynamicPPL.isassumption(left)) + if $isassumption + $left = $(DynamicPPL.tilde_assume)( + _rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo) + else + $(DynamicPPL.tilde_observe)( + _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo) + end end end @@ -252,40 +244,30 @@ function generate_tilde(left, right, args) end """ - generate_dot_tilde(left, right, args) + generate_dot_tilde(left, right) Generate the expression that replaces `left .~ right` in the model body. """ -function generate_dot_tilde(left, right, args) +function generate_dot_tilde(left, right) @gensym tmpright top = [:($tmpright = $right), :($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}} || throw(ArgumentError($DISTMSG)))] if left isa Symbol || left isa Expr - @gensym out vn inds + @gensym out vn inds isassumption push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left)))) - # It can only be an observation if the LHS is an argument of the model - if vsym(left) in args - @gensym isassumption - return quote - $(top...) - $isassumption = $(DynamicPPL.isassumption(left)) - if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume)( - _rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo) - else - $(DynamicPPL.dot_tilde_observe)( - _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo) - end - end - end - return quote $(top...) - $left .= $(DynamicPPL.dot_tilde_assume)( - _rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo) + $isassumption = $(DynamicPPL.isassumption(left)) || $left === missing + if $isassumption + $left .= $(DynamicPPL.dot_tilde_assume)( + _rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo) + else + $(DynamicPPL.dot_tilde_observe)( + _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo) + end end end