Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - Support for submodels #233

Closed
wants to merge 18 commits into from
Closed

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented Apr 25, 2021

Part of the motivations for #221 and #222 was so we could add submodels/model-nesting.

Well, now we can.

Special thanks to @devmotion who reviewed those PRs (several times), improving them significantly, made additional PRs and suggested the current impl of the @submodel ❤️

EDIT: We fixed the performance:) This now has zero runtime overhead! See comment-section.
EDIT 2: Thanks to @devmotion, we can now alos deal with dynamically specified prefices!

Motivating example: AR1-prior

using Turing
using DynamicPPL
# Could have made model which samples `num_obs` AR1 samples simulatenously,
# but for the sake of showing off dynamic prefixes, we'll only use a vector-implementation.
# The matrix implementation will be quite a bit faster too, but oh well.
@model function AR1(num_steps, α, μ, σ, ::Type{TV} = Vector{Float64}) where {TV}
    η ~ MvNormal(num_steps, 1.0)
    δ = sqrt(1 - α^2)

    x = TV(undef, num_steps)
    x[1] = η[1]
    @inbounds for t = 2:num_steps
        x[t] = @. α * x[t - 1] + δ * η[t]
    end

    return @. μ + σ * x
end

# Generate an observation
σ_obs = 0.1
num_obs = 5
num_steps = 10

ar1 = AR1(num_steps, 0.5, 1.0, 1.0)
ys = mapreduce(hcat, 1:num_obs) do i
    ar1() + σ_obs * randn(num_steps)
end
10×5 Matrix{Float64}:
  2.30189    0.301618  1.73268   -0.65096    1.46835
  2.11187   -1.34878   2.3728     1.02125    3.28422
 -0.249064   0.769488  1.34044    3.22175    2.52196
 -0.25863   -0.216914  0.528954   3.04756    3.8234
  0.372122   0.473511  0.708068   0.76197    0.202003
  0.41487    0.759435  1.80162    0.790204   0.12331
  1.32585    0.567929  2.74316    1.0874     2.82701
  1.84307    1.16138   1.36382    0.735388   1.07423
  3.20139    0.75177   1.57236    0.865401  -0.315341
  1.22479    1.35688   2.8239     0.597959   0.587955
@model function demo(y)
    α ~ Uniform()
    μ ~ Normal()
    σ ~ truncated(Normal(), 0, Inf)

    num_steps = size(y, 1)
    num_obs = size(y, 2)
    @inbounds for i = 1:num_obs
        x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ)
        y[:, i] ~ MvNormal(x, 0.1)
    end
end;

m = demo(y);
vi = VarInfo(m);
keys(vi)
8-element Vector{VarName{sym, Tuple{}} where sym}:
 α
 μ
 σ
 ar1_1.η
 ar1_2.η
 ar1_3.η
 ar1_4.η
 ar1_5.η
vi[@varname α]
0.9383208224122919
chain = sample(m, NUTS(1_000, 0.8), 3_000);
┌ Info: Found initial step size
│   ϵ = 0.025
└ @ Turing.Inference /home/tor/.julia/packages/Turing/rHLGJ/src/inference/hmc.jl:188
Sampling: 100%|█████████████████████████████████████████| Time: 0:04:00
chain[1001:end, [, , ], :]
Chains MCMC chain (2000×3×1 Array{Float64, 3}):

Iterations        = 1001:3000
Thinning interval = 1
Chains            = 1
Samples per chain = 2000
parameters        = α, μ, σ
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64 

           α    0.5474    0.1334     0.0030    0.0073   159.6969    0.9995
           μ    1.0039    0.2733     0.0061    0.0168   169.9106    1.0134
           σ    1.1294    0.1807     0.0040    0.0106   166.8670    0.9998

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           α    0.2684    0.4625    0.5534    0.6445    0.7861
           μ    0.4248    0.8227    1.0241    1.2011    1.4801
           σ    0.8781    1.0018    1.0989    1.2239    1.5472

Yay! We recovered the true parameters 🎉

@benchmark $m($vi)
BenchmarkTools.Trial: 
  memory estimate:  12.05 KiB
  allocs estimate:  123
  --------------
  minimum time:     15.091 μs (0.00% GC)
  median time:      17.861 μs (0.00% GC)
  mean time:        19.582 μs (5.23% GC)
  maximum time:     10.293 ms (99.46% GC)
  --------------
  samples:          10000
  evals/sample:     1

Demos

using DynamicPPL, Distributions
┌ Info: Precompiling DynamicPPL [366bfd00-2699-11ea-058f-f148b4cae6d8]
└ @ Base loading.jl:1317
@model function demo1(x)
    x ~ Normal()
end;
@model function demo2(x, y)
    @submodel demo1(x)
    y ~ Uniform()
end false;
m2 = demo2(missing, missing);
vi2 = VarInfo(m2);
keys(vi2)
2-element Vector{VarName{sym, Tuple{}} where sym}:
 x
 y
println(vi2[VarName(Symbol("x"))])
println(vi2[VarName(Symbol("y"))])
0.3069117531180063
0.7325324947386318

We can also observe without issues:

@model function demo2(x, y)
    @submodel demo1(x)
    y ~ Normal(x)
end false;
m2 = demo2(1000.0, missing);
vi2 = VarInfo(m2);
keys(vi2)
1-element Vector{VarName{:y, Tuple{}}}:
 y
vi2[@varname y]
1000.3905079427211
DynamicPPL.getlogp(vi2)
-500001.9141252931

But what if the models have the same variable-names?!

"Sure, this is cool and all, but can we even use the values from the nested values in the parent model?"

@model function demo_return(x)
    x ~ Normal()
    return x
end;

@model function demo_useval(x, y)
    x1 = @submodel sub1 demo_return(x)
    x2 = @submodel sub2 demo_return(y)

    z ~ Normal(x1 + x2 + 100, 1.0)
end false;
vi = VarInfo(demo_useval(missing, missing));
keys(vi)
3-element Vector{VarName{sym, Tuple{}} where sym}:
 sub1.x
 sub2.x
 z
vi[@varname z]
101.09066854862154

And just to prove a point:

@model function nested(x, y)
    @submodel 1 nested1(x, y)
    y ~ Uniform()
end false;
@model function nested1(x, y)
    @submodel 2 nested2(x, y)
    y ~ Uniform()
end false;
@model function nested2(x, y)
    z = @submodel 3 nested3(x, y)
    y ~ Normal(z, 1.0)
end false;
@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)

    return x + 1000
end false;

m = nested(missing, missing);
vi = VarInfo(m);
keys(vi)
5-element Vector{VarName{sym, Tuple{}} where sym}:
 1.2.3.x
 1.2.3.y
 1.2.y
 1.y
 y
vi[VarName(Symbol("1.2.y"))]
1000.5609156083766
DynamicPPL.getlogp(vi)
-4.620040828101227

Can it ever fail?

Yeah, if the user doesn't provide the prefix, it can:

@model function nested(x, y)
    @submodel nested1(x, y)
    y ~ Uniform()
end false;
@model function nested1(x, y)
    @submodel nested2(x, y)
    y ~ Uniform()
end false;
@model function nested2(x, y)
    z = @submodel nested3(x, y)
    y ~ Normal(z, 1.0)
end false;
@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)

    return x + 1000
end false;

m = nested(missing, missing);
vi = VarInfo(m);
keys(vi)
2-element Vector{VarName{sym, Tuple{}} where sym}:
 x
 y
# Inner-most value is recorded (i.e. the first one reached)
vi[@varname y]
-100.16836599596732

And it messes up the logp computation:

DynamicPPL.getlogp(vi)
-Inf

But I could imagine there's a way for us to fix this, or at least warn the user when this happens.

Benchmarks

At this point you're probably wondering, "but does it have any overhead (at runtime)?". For a "shallow" nestings, nah, but if you go deep enough there seems to be a tiny bit (likely because we're calling the "constructor" for the model):

using BenchmarkTools

@model function base(x, y)
    x ~ Uniform()
    y ~ Uniform()
    y1 ~ Uniform()
    z = x + 1000
    y12 ~ Normal()
    y123 ~ Normal(-100.0, 1.0)
end

m1 = base(missing, missing);
vi1 = VarInfo(m1);
@model function nested_shallow(x, y)
    @submodel 1 nested1_shallow(x, y)
    y ~ Uniform()
end false;
@model function nested1_shallow(x, y)
    x ~ Uniform()
    y ~ Uniform()
    z = x + 1000
    y12 ~ Normal()
    y123 ~ Normal(-100.0, 1.0)
end false;

m2 = nested_shallow(missing, missing);
vi2 = VarInfo(m2);
@model function nested(x, y)
    @submodel 1 nested1(x, y)
    y ~ Uniform()
end false;
@model function nested1(x, y)
    @submodel 2 nested2(x, y)
    y ~ Uniform()
end false;
@model function nested2(x, y)
    z = @submodel 3 nested3(x, y)
    y ~ Normal(z, 1.0)
end false;
@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)

    return x + 1000
end

m3 = nested(missing, missing);
vi3 = VarInfo(m3);
@model function nested_noprefix(x, y)
    @submodel nested_noprefix1(x, y)
    y ~ Uniform()
end false;
@model function nested_noprefix1(x, y)
    @submodel nested_noprefix2(x, y)
    y1 ~ Uniform()
end false;
@model function nested_noprefix2(x, y)
    z = @submodel nested_noprefix3(x, y)
    y2 ~ Normal(z, 1.0)
end false;
@model function nested_noprefix3(x, y)
    x ~ Uniform()
    y3 ~ Normal(-100.0, 1.0)

    return x + 1000
end

m4 = nested_noprefix(missing, missing);
vi4 = VarInfo(m4);
keys(vi1)
5-element Vector{VarName{sym, Tuple{}} where sym}:
 x
 y
 y1
 y12
 y123
keys(vi2)
5-element Vector{VarName{sym, Tuple{}} where sym}:
 1.x
 1.y
 1.y12
 1.y123
 y
keys(vi3)
5-element Vector{VarName{sym, Tuple{}} where sym}:
 1.2.3.x
 1.2.3.y
 1.2.y
 1.y
 y
keys(vi4)
5-element Vector{VarName{sym, Tuple{}} where sym}:
 x
 y3
 y2
 y1
 y
@benchmark $m1($vi1)
BenchmarkTools.Trial: 
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     1.714 μs (0.00% GC)
  median time:      1.747 μs (0.00% GC)
  mean time:        1.835 μs (0.00% GC)
  maximum time:     6.894 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10
@benchmark $m2($vi2)
BenchmarkTools.Trial: 
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     1.759 μs (0.00% GC)
  median time:      1.778 μs (0.00% GC)
  mean time:        1.819 μs (0.00% GC)
  maximum time:     5.563 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10
@benchmark $m3($vi3)
BenchmarkTools.Trial: 
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     1.718 μs (0.00% GC)
  median time:      1.746 μs (0.00% GC)
  mean time:        1.787 μs (0.00% GC)
  maximum time:     5.758 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10
@benchmark $m4($vi4)
BenchmarkTools.Trial: 
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     1.672 μs (0.00% GC)
  median time:      1.696 μs (0.00% GC)
  mean time:        1.756 μs (0.00% GC)
  maximum time:     4.882 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

Notice that the number of allocations have increased for the deeply nested model. Seems like the Julia compiler isn't too good at inferring the return-types of Turing-models? This seems to be the case too by looking at the lowered code. I haven't given this too much thought yet btw; likely is a way for us to help the compiler here.

@torfjelde
Copy link
Member Author

torfjelde commented Apr 25, 2021

Btw, it seems like the nested Prefix is the issue; not model-within-model. So adding the following constructor, fixes the issue:

@generated function PrefixContext{PrefixInner}(
    ctx::PrefixContext{<:Any, PrefixOuter}
) where {PrefixInner, PrefixOuter}
    :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, ":", PrefixInner)))}(ctx.ctx))
end

@generated is unfortunately needed, as it won't be compiled away otherwise (failed on this "simple" example). Also, note that this won't completely unwrap everything if we have other contexts wrapping PrefixContext again, i.e. it doesn't ensure that PrefixContext has zero overhead in general, but it does for the case where only the outer-most context is non-prefix-context, i.e. in the case of submodel:)

Now we get:

@model function nested_noprefix(x, y)
    @submodel nested_noprefix1(x, y)
    y ~ Uniform()
end false;
@model function nested_noprefix1(x, y)
    @submodel nested_noprefix2(x, y)
    y1 ~ Uniform()
end false;
@model function nested_noprefix2(x, y)
    z = @submodel nested_noprefix3(x, y)
    y2 ~ Normal(z, 1.0)
end false;
@model function nested_noprefix3(x, y)
    x ~ Uniform()
    y3 ~ Normal(-100.0, 1.0)

    return x + 1000
end

m4 = nested_noprefix(missing, missing);
vi4 = VarInfo(m4);
keys(vi1)
5-element Vector{VarName{sym, Tuple{}} where sym}:
 x
 y
 y1
 y12
 y123
keys(vi2)
5-element Vector{VarName{sym, Tuple{}} where sym}:
 1:x
 1:y
 1:y12
 1:y123
 y
keys(vi3)
5-element Vector{VarName{sym, Tuple{}} where sym}:
 1:2:3:x
 1:2:3:y
 1:2:y
 1:y
 y
keys(vi4)
5-element Vector{VarName{sym, Tuple{}} where sym}:
 x
 y3
 y2
 y1
 y
@benchmark $m1($vi1)
BenchmarkTools.Trial: 
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     2.586 μs (0.00% GC)
  median time:      2.609 μs (0.00% GC)
  mean time:        2.720 μs (0.00% GC)
  maximum time:     8.541 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     9
@benchmark $m2($vi2)
BenchmarkTools.Trial: 
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     2.555 μs (0.00% GC)
  median time:      2.638 μs (0.00% GC)
  mean time:        2.712 μs (0.00% GC)
  maximum time:     7.377 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     9
@benchmark $m3($vi3)
BenchmarkTools.Trial: 
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     2.624 μs (0.00% GC)
  median time:      2.655 μs (0.00% GC)
  mean time:        2.800 μs (0.00% GC)
  maximum time:     8.954 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     9
@benchmark $m4($vi4)
BenchmarkTools.Trial: 
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     2.585 μs (0.00% GC)
  median time:      2.665 μs (0.00% GC)
  mean time:        2.755 μs (0.00% GC)
  maximum time:     8.014 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     9

@torfjelde
Copy link
Member Author

torfjelde commented Apr 25, 2021

One thing that is a bit annoying is that we can't do prefix the submodel's variables dynamically, i.e. it's not possible to do

for i = 1:n
    @submodel "sub_$i" submodel(...)
end

😕

EDIT: Nvm, now we can:)

@@ -62,7 +62,7 @@ end

To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
"""
macro model(expr, warn=true)
macro model(expr, warn=false)
Copy link
Member

@devmotion devmotion Apr 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess, if the default value is false we can also just remove it since I doubt that anyone will enable the warnings explicitly. I am not completely sure if the warnings are useful anymore, in particular with the new variable names __varinfo__ etc. it seems unlikely thats someone would use the same name in their model definition. On the other hand, if we could ensure that official macros such as @addlogprob! and @submodel do not cause these warnings, I don't think there is any harm in keeping them.

So if possible, I think it would be better to check in the macro expansion step of the compiler if it is one of the official macros and disable warnings for only the expression generated by them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@torfjelde What's your opinion?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I weakly lean towards keeping this feature for developers.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry! But yes, I left it there because of the same reason as Hong said. I'm pro leaving it as is, and then if no one uses it for a long time, we might as well just drop it then. No need to rush completely removing it IMO.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought one should not only keep it but also show warnings if not explicitly requested otherwise - i.e., I suggested reverting it back to

Suggested change
macro model(expr, warn=false)
macro model(expr, warn=true)

However, to avoid printing warnings if users use @submodel or @addlogprob! I think one should disable warnings for the expanded code of these macros. It seems a simple if statement in the macro expansion in

return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
should be sufficient to achieve this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I have to admit I don't like it either 😄 So I think I changed my mind and I would be fine with changing it to warn=false. Even though this changes the behaviour of @model this won't break anyone's code. And in the next breaking release we might even consider removing the warn argument completely.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, lovely 👍 True, plus I don't think I've ever come across anyone actually using these warnings...

I'll make default false and push 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nvm, it's already this way, haha. I think this is good to go then!:)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know that it's used in DiffEqBayes since I added it there to avoid the warnings: https://github.com/SciML/DiffEqBayes.jl/blob/1749bc7ade1511d62a858eec4359705901126c92/src/turing_inference.jl#L53 😄 So as long as we do not suddenly remove it completely in a supposedly non-breaking release, it's fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, nice:) Good! I just merged with master and checking that tests run locally. Once that's done I'll bump version and it should be ready for bors!

@@ -62,7 +62,7 @@ end

To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
"""
macro model(expr, warn=true)
macro model(expr, warn=false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I weakly lean towards keeping this feature for developers.

@yebai
Copy link
Member

yebai commented Apr 26, 2021

Thanks, @torfjelde, it looks great overall. Just a minor suggestion for the submodel syntax. Maybe we can add a sugar-syntax, for example:

@model function demo(y)
    α ~ Uniform()
    μ ~ Normal()
    σ ~ truncated(Normal(), 0, Inf)

    num_steps = size(y, 1)
    num_obs = size(y, 2)
    x = zeros(num_obs) # pre-allocate `x` array
    @inbounds for i = 1:num_obs
        # x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ)
       x[i] ~ AR1(num_steps, α, μ, σ) # we can use `x[i]` as the prefix.  
        y[:, i] ~ MvNormal(x[i], 0.1)
    end
end;

m = demo(y);
vi = VarInfo(m);

In the above example, we can re-use the clean tilde notation for submodels. Under the hood, tilde will dispatch based on the type of right-hand-side distribution objects (i.e. either a Distributions.Distribution or a Turing.Model). The implementation can also use the left-hand-side variable (which is assumed to be unique in the outer/parent model) as a prefix. A special case is that, when a submodel returns nothing, then the left-hand-side will be assigned to nothing, which is fine.

@torfjelde
Copy link
Member Author

Just a minor suggestion for the submodel syntax. Maybe we can add a sugar-syntax, for example:

The reason why we didn't do this from the start was because I was worried about changing internals; I wanted it to just be an addition rather than a change.

But when I first saw your suggestion I was like "Yeah, why don't we just do that? o.O". Buuut I've been playing around with it a bit now, and there are some issues (took a bit to find them though; was hoping this actually would work)

  1. What do you do in the case of x[:, i] ~ Model()? There are two options:
    • Use the ::Type{TV} approach and pre-allocate x. This fails though, because x is not present in VarInfo and so we can't use get_matching_type to get the corresponding TV. This in turn means we have to deal with Real (which, other than being slow, leads to downstream issues, i.e. StackOverflow because something dispatches in a cycle).
    • Don't actually allocate, just create some other variable to use there. This will be horrible for the user...
  2. And another thing, though not breaking like the above, is that the compiler seems to have a really hard time. At first there was quite a significant overhead (even m2 in the above benchmark was orders of magnitude slower than m1, with way more allocations). After removing the if @generated and forcing it to, aaand adding some more generated functions, I got it down to one additional allocation per nested level. I looked briefly at the generated code, and it seems like the compiler can't inline the model in the same way as it can with @submodel. It's a bit weird though, and might in fact be due to (1).

So, unless we want to add the possibility of tracking arbitrary quantities in VarInfo, a la deterministic, I can't see how we'll solve (1) 😕

@yebai
Copy link
Member

yebai commented Apr 26, 2021

So, unless we want to add the possibility of tracking arbitrary quantities in VarInfo, a la deterministic, I can't see how we'll solve (1) 😕

Distributions already have a Dirac distribution - maybe we can create a fake (multivariate) Dirac distribution for the return value of submodels?

@torfjelde
Copy link
Member Author

That will screw up HMC, etc. with the current setup 😕
IMO we need a separate field in VarInfo which handles determinstic variables. I think anything else will lead to lots of downstream issues.

@yebai
Copy link
Member

yebai commented Apr 26, 2021

IMO we need a separate field in VarInfo which handles determinstic variables. I think anything else will lead to lots of downstream issues.

Summon VarInfo guru @mohamed82008. Any thoughts on how we can quickly add tracking for deterministic variables without breaking lots of functionalities and efficiency for VarInfo?

@yebai
Copy link
Member

yebai commented Apr 28, 2021

I'm happy to merge this PR as-is and implement the suggested sugar syntax in separate efforts. Great PR overall!

@devmotion
Copy link
Member

@torfjelde, could you address #233 (comment) before merging the PR such that warnings are only disabled for the @addlogprob! and @submodel parts? Otherwise it looks good 👍

@torfjelde
Copy link
Member Author

@torfjelde, could you address #233 (comment) before merging the PR such that warnings are only disabled for the @addlogprob! and @SubModel parts? Otherwise it looks good +1

Yep! Also need to add tests before we merge (I'll convert examples above into tests). I just delayed this until we could conclude that this was in fact the approach we wanted to take.

@torfjelde
Copy link
Member Author

bors try

bors bot added a commit that referenced this pull request May 16, 2021
@phipsgabler
Copy link
Member

Absolutely awesome stuff.

Just one thing: can we not again take the path of stringifying structural information into variable names? That was a hassle with indexing, and I fear it will turn out the same with submodel prefixes.

A good case for VarName enhancement, perhaps?

@torfjelde
Copy link
Member Author

Just one thing: can we not again take the path of stringifying structural information into variable names? That was a hassle with indexing, and I fear it will turn out the same with submodel prefixes.

I do agree with you, but don't see how it's feasible to avoid this with the current implementation. And the motivation behind prefixing isn't to encode any sort of structural information; it's only so we can ensure that the variable names are unique but at the same type readable for a human. So for now this is good enough (and is non-breaking), and then should come up with a more "first-class support" for submodels once we start making changes to VarInfo and the like.

@phipsgabler
Copy link
Member

Let me say first say that I totally agree with the sentiment of "it's good for now, and let's add breaking improvements later". So I'm not saying that anything should be changed right now. But...

And the motivation behind prefixing isn't to encode any sort of structural information; it's only so we can ensure that the variable names are unique but at the same type readable for a human.

If for something like

x = @submodel bla()
y ~ blub()

the user wants to query all the variables from within x from a chain, or use a sampler for everything in x, how can they do that, except with string matching on x.?

That's what I meant with "structural". It looks like a hierarchical namespace kind of thing; if that is not the intent, it feels like an unnatural scheme. If the prefix were only a gensym-like "common string part for debugging purposes", I'd expect the internal names to be completely mangled and not for publicly access at all; but in the current form (which I do think is the better option!) my mental model is looking for structural access.

@yebai
Copy link
Member

yebai commented May 17, 2021

Bors try

bors bot added a commit that referenced this pull request May 17, 2021
@torfjelde
Copy link
Member Author

That's what I meant with "structural". It looks like a hierarchical namespace kind of thing; if that is not the intent, it feels like an unnatural scheme.

I agree with this. But when we have lenses in VarName everything will be better 🤓

@torfjelde
Copy link
Member Author

bors r+

bors bot pushed a commit that referenced this pull request May 18, 2021
Part of the motivations for #221 and #222 was so we could add submodels/model-nesting.

Well, now we can. 

Special thanks to @devmotion who reviewed those PRs (several times), improving them significantly, made additional PRs and suggested the current impl of the `@submodel` ❤️ 

EDIT: We fixed the performance:) This now has zero runtime overhead! See comment-section.
EDIT 2: Thanks to @devmotion, we can now alos deal with dynamically specified prefices!

- [Motivating example: AR1-prior](#org46a90a5)
- [Demos](#org7e05701)
- [Can it ever fail?](#org75acb71)
- [Benchmarks](#orga99bcf4)


<a id="org46a90a5"></a>

# Motivating example: AR1-prior

```julia
using Turing
using DynamicPPL
```

```julia
# Could have made model which samples `num_obs` AR1 samples simulatenously,
# but for the sake of showing off dynamic prefixes, we'll only use a vector-implementation.
# The matrix implementation will be quite a bit faster too, but oh well.
@model function AR1(num_steps, α, μ, σ, ::Type{TV} = Vector{Float64}) where {TV}
    η ~ MvNormal(num_steps, 1.0)
    δ = sqrt(1 - α^2)

    x = TV(undef, num_steps)
    x[1] = η[1]
    @inbounds for t = 2:num_steps
        x[t] = @. α * x[t - 1] + δ * η[t]
    end

    return @. μ + σ * x
end

# Generate an observation
σ_obs = 0.1
num_obs = 5
num_steps = 10

ar1 = AR1(num_steps, 0.5, 1.0, 1.0)
ys = mapreduce(hcat, 1:num_obs) do i
    ar1() + σ_obs * randn(num_steps)
end
```

    10×5 Matrix{Float64}:
      2.30189    0.301618  1.73268   -0.65096    1.46835
      2.11187   -1.34878   2.3728     1.02125    3.28422
     -0.249064   0.769488  1.34044    3.22175    2.52196
     -0.25863   -0.216914  0.528954   3.04756    3.8234
      0.372122   0.473511  0.708068   0.76197    0.202003
      0.41487    0.759435  1.80162    0.790204   0.12331
      1.32585    0.567929  2.74316    1.0874     2.82701
      1.84307    1.16138   1.36382    0.735388   1.07423
      3.20139    0.75177   1.57236    0.865401  -0.315341
      1.22479    1.35688   2.8239     0.597959   0.587955

```julia
@model function demo(y)
    α ~ Uniform()
    μ ~ Normal()
    σ ~ truncated(Normal(), 0, Inf)

    num_steps = size(y, 1)
    num_obs = size(y, 2)
    @inbounds for i = 1:num_obs
        x = @SubModel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ)
        y[:, i] ~ MvNormal(x, 0.1)
    end
end;

m = demo(y);
vi = VarInfo(m);
```

```julia
keys(vi)
```

    8-element Vector{VarName{sym, Tuple{}} where sym}:
     α
     μ
     σ
     ar1_1.η
     ar1_2.η
     ar1_3.η
     ar1_4.η
     ar1_5.η

```julia
vi[@varname α]
```

    0.9383208224122919

```julia
chain = sample(m, NUTS(1_000, 0.8), 3_000);
```

    ┌ Info: Found initial step size
    │   ϵ = 0.025
    └ @ Turing.Inference /home/tor/.julia/packages/Turing/rHLGJ/src/inference/hmc.jl:188
    Sampling: 100%|█████████████████████████████████████████| Time: 0:04:00

```julia
chain[1001:end, [:α, :μ, :σ], :]
```

    Chains MCMC chain (2000×3×1 Array{Float64, 3}):
    
    Iterations        = 1001:3000
    Thinning interval = 1
    Chains            = 1
    Samples per chain = 2000
    parameters        = α, μ, σ
    internals         = 
    
    Summary Statistics
      parameters      mean       std   naive_se      mcse        ess      rhat 
          Symbol   Float64   Float64    Float64   Float64    Float64   Float64 
    
               α    0.5474    0.1334     0.0030    0.0073   159.6969    0.9995
               μ    1.0039    0.2733     0.0061    0.0168   169.9106    1.0134
               σ    1.1294    0.1807     0.0040    0.0106   166.8670    0.9998
    
    Quantiles
      parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
          Symbol   Float64   Float64   Float64   Float64   Float64 
    
               α    0.2684    0.4625    0.5534    0.6445    0.7861
               μ    0.4248    0.8227    1.0241    1.2011    1.4801
               σ    0.8781    1.0018    1.0989    1.2239    1.5472

Yay! We recovered the true parameters :tada:

```julia
@benchmark $m($vi)
```

    BenchmarkTools.Trial: 
      memory estimate:  12.05 KiB
      allocs estimate:  123
      --------------
      minimum time:     15.091 μs (0.00% GC)
      median time:      17.861 μs (0.00% GC)
      mean time:        19.582 μs (5.23% GC)
      maximum time:     10.293 ms (99.46% GC)
      --------------
      samples:          10000
      evals/sample:     1


<a id="org7e05701"></a>

# Demos

```julia
using DynamicPPL, Distributions
```

    ┌ Info: Precompiling DynamicPPL [366bfd00-2699-11ea-058f-f148b4cae6d8]
    └ @ Base loading.jl:1317

```julia
@model function demo1(x)
    x ~ Normal()
end;
@model function demo2(x, y)
    @SubModel demo1(x)
    y ~ Uniform()
end false;
m2 = demo2(missing, missing);
vi2 = VarInfo(m2);
keys(vi2)
```

    2-element Vector{VarName{sym, Tuple{}} where sym}:
     x
     y

```julia
println(vi2[VarName(Symbol("x"))])
println(vi2[VarName(Symbol("y"))])
```

    0.3069117531180063
    0.7325324947386318

We can also `observe` without issues:

```julia
@model function demo2(x, y)
    @SubModel demo1(x)
    y ~ Normal(x)
end false;
m2 = demo2(1000.0, missing);
vi2 = VarInfo(m2);
keys(vi2)
```

    1-element Vector{VarName{:y, Tuple{}}}:
     y

```julia
vi2[@varname y]
```

    1000.3905079427211

```julia
DynamicPPL.getlogp(vi2)
```

    -500001.9141252931

But what if the models have the same variable-names?!

"Sure, this is cool and all, but can we even use the values from the nested values in the parent model?"

```julia
@model function demo_return(x)
    x ~ Normal()
    return x
end;

@model function demo_useval(x, y)
    x1 = @SubModel sub1 demo_return(x)
    x2 = @SubModel sub2 demo_return(y)

    z ~ Normal(x1 + x2 + 100, 1.0)
end false;
vi = VarInfo(demo_useval(missing, missing));
keys(vi)
```

    3-element Vector{VarName{sym, Tuple{}} where sym}:
     sub1.x
     sub2.x
     z

```julia
vi[@varname z]
```

    101.09066854862154

And just to prove a point:

```julia
@model function nested(x, y)
    @SubModel 1 nested1(x, y)
    y ~ Uniform()
end false;
@model function nested1(x, y)
    @SubModel 2 nested2(x, y)
    y ~ Uniform()
end false;
@model function nested2(x, y)
    z = @SubModel 3 nested3(x, y)
    y ~ Normal(z, 1.0)
end false;
@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)

    return x + 1000
end false;

m = nested(missing, missing);
vi = VarInfo(m);
keys(vi)
```

    5-element Vector{VarName{sym, Tuple{}} where sym}:
     1.2.3.x
     1.2.3.y
     1.2.y
     1.y
     y

```julia
vi[VarName(Symbol("1.2.y"))]
```

    1000.5609156083766

```julia
DynamicPPL.getlogp(vi)
```

    -4.620040828101227


<a id="org75acb71"></a>

# Can it ever fail?

Yeah, if the user doesn't provide the prefix, it can:

```julia
@model function nested(x, y)
    @SubModel nested1(x, y)
    y ~ Uniform()
end false;
@model function nested1(x, y)
    @SubModel nested2(x, y)
    y ~ Uniform()
end false;
@model function nested2(x, y)
    z = @SubModel nested3(x, y)
    y ~ Normal(z, 1.0)
end false;
@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)

    return x + 1000
end false;

m = nested(missing, missing);
vi = VarInfo(m);
keys(vi)
```

    2-element Vector{VarName{sym, Tuple{}} where sym}:
     x
     y

```julia
# Inner-most value is recorded (i.e. the first one reached)
vi[@varname y]
```

    -100.16836599596732

And it messes up the logp computation:

```julia
DynamicPPL.getlogp(vi)
```

    -Inf

But I could imagine there's a way for us to fix this, or at least warn the user when this happens.


<a id="orga99bcf4"></a>

# Benchmarks

At this point you're probably wondering, "but does it have any overhead (at runtime)?". For a "shallow" nestings, nah, but if you go deep enough there seems to be a tiny bit (likely because we're calling the "constructor" for the model):

```julia
using BenchmarkTools

@model function base(x, y)
    x ~ Uniform()
    y ~ Uniform()
    y1 ~ Uniform()
    z = x + 1000
    y12 ~ Normal()
    y123 ~ Normal(-100.0, 1.0)
end

m1 = base(missing, missing);
vi1 = VarInfo(m1);
```

```julia
@model function nested_shallow(x, y)
    @SubModel 1 nested1_shallow(x, y)
    y ~ Uniform()
end false;
@model function nested1_shallow(x, y)
    x ~ Uniform()
    y ~ Uniform()
    z = x + 1000
    y12 ~ Normal()
    y123 ~ Normal(-100.0, 1.0)
end false;

m2 = nested_shallow(missing, missing);
vi2 = VarInfo(m2);
```

```julia
@model function nested(x, y)
    @SubModel 1 nested1(x, y)
    y ~ Uniform()
end false;
@model function nested1(x, y)
    @SubModel 2 nested2(x, y)
    y ~ Uniform()
end false;
@model function nested2(x, y)
    z = @SubModel 3 nested3(x, y)
    y ~ Normal(z, 1.0)
end false;
@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)

    return x + 1000
end

m3 = nested(missing, missing);
vi3 = VarInfo(m3);
```

```julia
@model function nested_noprefix(x, y)
    @SubModel nested_noprefix1(x, y)
    y ~ Uniform()
end false;
@model function nested_noprefix1(x, y)
    @SubModel nested_noprefix2(x, y)
    y1 ~ Uniform()
end false;
@model function nested_noprefix2(x, y)
    z = @SubModel nested_noprefix3(x, y)
    y2 ~ Normal(z, 1.0)
end false;
@model function nested_noprefix3(x, y)
    x ~ Uniform()
    y3 ~ Normal(-100.0, 1.0)

    return x + 1000
end

m4 = nested_noprefix(missing, missing);
vi4 = VarInfo(m4);
```

```julia
keys(vi1)
```

    5-element Vector{VarName{sym, Tuple{}} where sym}:
     x
     y
     y1
     y12
     y123

```julia
keys(vi2)
```

    5-element Vector{VarName{sym, Tuple{}} where sym}:
     1.x
     1.y
     1.y12
     1.y123
     y

```julia
keys(vi3)
```

    5-element Vector{VarName{sym, Tuple{}} where sym}:
     1.2.3.x
     1.2.3.y
     1.2.y
     1.y
     y

```julia
keys(vi4)
```

    5-element Vector{VarName{sym, Tuple{}} where sym}:
     x
     y3
     y2
     y1
     y

```julia
@benchmark $m1($vi1)
```

    BenchmarkTools.Trial: 
      memory estimate:  160 bytes
      allocs estimate:  5
      --------------
      minimum time:     1.714 μs (0.00% GC)
      median time:      1.747 μs (0.00% GC)
      mean time:        1.835 μs (0.00% GC)
      maximum time:     6.894 μs (0.00% GC)
      --------------
      samples:          10000
      evals/sample:     10

```julia
@benchmark $m2($vi2)
```

    BenchmarkTools.Trial: 
      memory estimate:  160 bytes
      allocs estimate:  5
      --------------
      minimum time:     1.759 μs (0.00% GC)
      median time:      1.778 μs (0.00% GC)
      mean time:        1.819 μs (0.00% GC)
      maximum time:     5.563 μs (0.00% GC)
      --------------
      samples:          10000
      evals/sample:     10

```julia
@benchmark $m3($vi3)
```

    BenchmarkTools.Trial: 
      memory estimate:  160 bytes
      allocs estimate:  5
      --------------
      minimum time:     1.718 μs (0.00% GC)
      median time:      1.746 μs (0.00% GC)
      mean time:        1.787 μs (0.00% GC)
      maximum time:     5.758 μs (0.00% GC)
      --------------
      samples:          10000
      evals/sample:     10

```julia
@benchmark $m4($vi4)
```

    BenchmarkTools.Trial: 
      memory estimate:  160 bytes
      allocs estimate:  5
      --------------
      minimum time:     1.672 μs (0.00% GC)
      median time:      1.696 μs (0.00% GC)
      mean time:        1.756 μs (0.00% GC)
      maximum time:     4.882 μs (0.00% GC)
      --------------
      samples:          10000
      evals/sample:     10

Notice that the number of allocations have increased for the deeply nested model. Seems like the Julia compiler isn't too good at inferring the return-types of Turing-models? This seems to be the case too by looking at the lowered code. I haven't given this too much thought yet btw; likely is a way for us to help the compiler here.
@bors bors bot changed the title Support for submodels [Merged by Bors] - Support for submodels May 18, 2021
@bors bors bot closed this May 18, 2021
@bors bors bot deleted the tor/submodels branch May 18, 2021 13:26
@yebai yebai mentioned this pull request Jun 30, 2021
@dlakelan
Copy link

Hi all. Is there any discussion of submodels in the documentation at all? I couldn't find it. If not, where should I file an issue to track the fact that this needs to get documented for users?

@torfjelde
Copy link
Member Author

@dlakelan
Copy link

Hi. Yes, useful if you know where to look. There should maybe be some references to that in the main Turing.jl docs? https://turing.ml/stable/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants