[Merged by Bors] - Small simplification of compiler #221
Conversation
I think it's always great if the implementation can be simplified 🙂 I am a bit worried, though, about the impact it might have on performance and compile times, so it would be good to run some benchmarks and check compilation before releasing it. E.g., I assume that in most models most variables are not arguments (i.e. not observed), so currently the observe branch and the `isassumption` check are already elided completely at macro expansion.
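For reference, here's a rough sketch of that distinction under the old compiler; the names are illustrative, not the actual generated code:

```julia
# `m ~ Normal()` with `m` NOT among the model arguments: only the assume
# branch is generated; no `isassumption` check is emitted at all.
m = DynamicPPL.tilde_assume(_rng, _context, _sampler, Normal(), vn_m, inds_m, _varinfo)

# `x ~ Normal(m, 1)` with `x` a model argument: the check happens at runtime.
if isassumption_x  # roughly: `x` not really observed, e.g. `missing`
    x = DynamicPPL.tilde_assume(_rng, _context, _sampler, Normal(m, 1), vn_x, inds_x, _varinfo)
else
    DynamicPPL.tilde_observe(_context, _sampler, Normal(m, 1), x, vn_x, inds_x, _varinfo)
end
```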
src/compiler.jl
Outdated
```julia
$left = if $isassumption
    $(DynamicPPL.tilde_assume)(
        _rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
else
    $(DynamicPPL.tilde_observe)(
        _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
    $left
end
```
This is different from the current implementation and might lead to different return values. Currently, we assign `$left` if it is an assumption, but only return `tilde_observe(...)` if it is not. IIRC this was changed on purpose in some PR a while back, but I don't remember the exact details...
But in this case I assign `$left` to `$left`, i.e. no change, right?
I see what you mean: it does generate different code, but the resulting model will behave exactly the same.
No, the return value is possibly different. This PR has

```julia
$left = if $isassumption
    $(DynamicPPL.tilde_assume)(
        _rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
else
    $(DynamicPPL.tilde_observe)(
        _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
    $left
end
```

whereas before we had

```julia
if $isassumption
    $left = $(DynamicPPL.tilde_assume)(
        _rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
else
    $(DynamicPPL.tilde_observe)(
        _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
end
```

and we should not change it (at least not in this PR).
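A minimal plain-Julia sketch of the difference when the tilde statement is the last expression of the model body (here `observe` is just a stand-in for `tilde_observe`):

```julia
observe(x) = nothing  # stand-in for `tilde_observe`

left = 1.0

# Old expansion: the block's value is whatever the observe branch returns.
old = if false  # pretend `isassumption` is `false`
    left = 0.0  # assume branch
else
    observe(left)
end

# New expansion: the block's value (and `left`) is the observed value itself.
new = left = if false
    0.0  # assume branch
else
    observe(left)
    left
end

(old, new)  # (nothing, 1.0)
```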
But there will never be anything on the receiving end of this value though, right? We're replacing a line such as `x ~ Normal()` with the above.

Also, I introduced this because if you put the assignment of `left` inside the `if` statement, you're making the variable (if it doesn't already exist in the scope, i.e. it's not part of `args`) local to that scope, and thus not usable in the rest of the model.
BTW, `if` does not introduce a separate scope. And in the second branch `$left` is already defined. An additional reason why we don't want to do this: imagine you have `X[1] ~ Normal(0, 1)`; if `X` is an argument to the model, then you would evaluate the log pdf in `tilde_observe` and then reassign `X[1]` when you assign something to `left`. This is not needed (and probably undesired) and might be quite slow in higher-dimensional examples or with ranges of samples.
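A minimal sketch of the cost concern, assuming a column-slice observation like `x[:, i]`; the write-back copies the whole slice even though nothing changed:

```julia
using BenchmarkTools

X = randn(100, 10_000)
col = X[:, 1]

# This is effectively the redundant `X[:, 1] = ...` reassignment that
# `$left = if ...` would introduce in the observe case:
@btime $X[:, 1] = $col;  # cost scales with the slice length
```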
Aaaaah weeeeell, do we really care about that? 😅 I didn't think of this though, so good that you brought it up 👍 Still, it seems like this isn't something we should actually take into consideration? (You're the PPL people though, so you know better here for sure.)

Btw, this will go on the most recent release of DynamicPPL, which includes breaking changes, right? Just in case that plays into the consideration.
> Btw, this will go on the most recent release of DynamicPPL, which includes breaking changes, right? Just in case that plays into the consideration.

If we introduce breaking changes, we would have to bump the minor version (which is possible, of course). Turing uses (or can use) the latest release of DynamicPPL, which is (basically always) the master branch.
Ah, I thought we just did that for AbstractPPL, but I misread before 👍

But what do you say to my previous question: is this something we actually care about? The return type in the case where this is the last statement?

EDIT: Woah, sorry, I missed your previous comment! I see what you mean. Indexing is a big issue.
> BTW, `if` does not introduce a separate scope.

Ah, of course! Crap. I got confused because when I tried just using the if-statement, I got complaints that the variable wasn't defined... Hmm, maybe I just did something wrong. I'll try your suggestion again 👍
```julia
return quote
    $(top...)
    $left = $(DynamicPPL.tilde_assume)(_rng, _context, _sampler, $tmpright, $vn,
                                       $inds, _varinfo)
    $isassumption = $(DynamicPPL.isassumption(left))
```
BTW (I know it's not part of this PR 🙂): why do we actually assign the output of `isassumption` to a variable?
Yep :)
Oh haha, I read "do we actually assign the output of `isassumption` to a variable?" and didn't see the "why".

It's because `isassumption` generates a larger if-statement which returns a `Bool`, so we can't e.g. do `if $(DynamicPPL.isassumption(left))`.
100% with you! I did quickly check a simple model to make sure, and it didn't affect runtime. But yeah, we should do some additional benchmarking first. Do we have a go-to suite for this or something?
No, unfortunately not (at least I'm not aware of one). At some point we experimented with automatic benchmarking in Turing IIRC; maybe there's still something in the repo?
Hmm, aight, I'll see if I can come up with something. Btw, the tests pass (just tried locally) 👍
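For what it's worth, a minimal ad-hoc harness along the lines of what I ended up doing below; `benchmark_model` is just an illustrative helper, not an existing suite:

```julia
using DynamicPPL, Distributions, BenchmarkTools

# Rough timings for compilation and evaluation of an instantiated model.
function benchmark_model(model)
    @time model()              # first call, includes compilation
    varinfo = VarInfo(model)   # typed VarInfo
    @btime $model($varinfo)    # steady-state evaluation
end
```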
Co-authored-by: David Widmann <[email protected]>
Thoughts about the updated version, @devmotion?
👍 Have you checked if it affects performance, compilation, and/or the compiled code?
Nope, I'll do that now; I just wanted to know you thought the current version was good before starting down that path 👍
Looks good to me now 👍
NOTE: See the comments following this one for an "analysis" of what's actually going on here.

Okay, so this is going to be a bit of a mouthful 🙃 I'd recommend using the links in the TOC to jump around, and potentially having two browser windows next to each other: one looking at the new implementation, and one looking at the old implementation. If there's anything else you'd like to see a comparison for, feel free to mention it. I wasn't quite certain what to do here, so I just took some models and compared :)

## New implementation

```julia
using Pkg
Pkg.status()
using DynamicPPL, Distributions, BenchmarkTools, MacroTools
```

### Model 1

```julia
@time begin
@model function demo1(x)
m ~ Normal()
x ~ Normal(m, 1)
return (m = m, x = x)
end
end;
model_def = demo1
data = 1.0

@time model_def(data)();
@btime $(model_def(data))();
m = model_def(data);
var_info = VarInfo(m);
@btime $m($var_info);
untyped_var_info = VarInfo();
m(untyped_var_info)
@btime $m($untyped_var_info);
m = model_def(data);
var_info = VarInfo(m);
@code_warntype m(var_info)
rng = DynamicPPL.Random.MersenneTwister(42);
spl = DynamicPPL.SampleFromPrior()
ctx = DynamicPPL.DefaultContext()
@code_typed m.f(rng, m, var_info, spl, ctx, m.args...)
expr = @macroexpand begin
@model function demo1(x)
m ~ Normal()
x ~ Normal(m, 1)
return (m = m, x = x)
end
end
expr |> Base.remove_linenums!
```

### Model 2

```julia
@time begin
@model function demo2(y)
# Our prior belief about the probability of heads in a coin.
p ~ Beta(1, 1)
# The number of observations.
N = length(y)
for n in 1:N
# Heads or tails of a coin are drawn from a Bernoulli distribution.
y[n] ~ Bernoulli(p)
end
end;
end;
model_def = demo2
data = rand(0:1, 10)

@time model_def(data)();
@btime $(model_def(data))();
m = model_def(data);
var_info = VarInfo(m);
@btime $m($var_info);
untyped_var_info = VarInfo();
m(untyped_var_info)
@btime $m($untyped_var_info);
m = model_def(data);
var_info = VarInfo(m);
@code_warntype m(var_info)
rng = DynamicPPL.Random.MersenneTwister(42);
spl = DynamicPPL.SampleFromPrior()
ctx = DynamicPPL.DefaultContext()
@code_typed m.f(rng, m, var_info, spl, ctx, m.args...)
expr = @macroexpand begin
@model function demo2(y)
# Our prior belief about the probability of heads in a coin.
p ~ Beta(1, 1)
# The number of observations.
N = length(y)
for n in 1:N
# Heads or tails of a coin are drawn from a Bernoulli distribution.
y[n] ~ Bernoulli(p)
end
end;
end
expr |> Base.remove_linenums!
```

### Model 3

```julia
@time begin
@model function demo3(x)
D, N = size(x)
# Draw the parameters for cluster 1.
μ1 ~ Normal()
# Draw the parameters for cluster 2.
μ2 ~ Normal()
μ = [μ1, μ2]
# Comment out this line if you instead want to draw the weights.
w = [0.5, 0.5]
# Draw assignments for each datum and generate it from a multivariate normal.
k = Vector{Int}(undef, N)
for i in 1:N
k[i] ~ Categorical(w)
x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.)
end
return k
end
end;
model_def = demo3
# Construct 30 data points for each cluster.
N = 30
# Parameters for each cluster, we assume that each cluster is Gaussian distributed in the example.
μs = [-3.5, 0.0]
# Construct the data points.
data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2)

@time model_def(data)();
@btime $(model_def(data))();
m = model_def(data);
var_info = VarInfo(m);
@btime $m($var_info);
untyped_var_info = VarInfo();
m(untyped_var_info)
@btime $m($untyped_var_info);
m = model_def(data);
var_info = VarInfo(m);
@code_warntype m(var_info)
rng = DynamicPPL.Random.MersenneTwister(42);
spl = DynamicPPL.SampleFromPrior()
ctx = DynamicPPL.DefaultContext()
@code_typed m.f(rng, m, var_info, spl, ctx, m.args...)
expr = @macroexpand begin
@model function demo3(x)
D, N = size(x)
# Draw the parameters for cluster 1.
μ1 ~ Normal()
# Draw the parameters for cluster 2.
μ2 ~ Normal()
μ = [μ1, μ2]
# Comment out this line if you instead want to draw the weights.
w = [0.5, 0.5]
# Draw assignments for each datum and generate it from a multivariate normal.
k = Vector{Int}(undef, N)
for i in 1:N
k[i] ~ Categorical(w)
x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.)
end
return k
end
end
expr |> Base.remove_linenums!
```

## Old implementation

```julia
using Pkg
projpath = "/tmp/DynamicPPLOld/"
if !ispath(projpath)
Pkg.generate(projpath)
end
Pkg.activate(projpath)
pkg"add DynamicPPL#master"
pkg"add Distributions MacroTools"

Pkg.status()
using DynamicPPL, Distributions, BenchmarkTools, MacroTools
```

### Model 1

```julia
@time begin
@model function demo1(x)
m ~ Normal()
x ~ Normal(m, 1)
return (m = m, x = x)
end
end;
model_def = demo1
data = 1.0
@time model_def(data)();
@btime $(model_def(data))();
m = model_def(data);
var_info = VarInfo(m);
@btime $m($var_info);
untyped_var_info = VarInfo();
m(untyped_var_info)
@btime $m($untyped_var_info);
m = model_def(data);
var_info = VarInfo(m);
@code_warntype m(var_info)
rng = DynamicPPL.Random.MersenneTwister(42);
spl = DynamicPPL.SampleFromPrior()
ctx = DynamicPPL.DefaultContext()
@code_typed m.f(rng, m, var_info, spl, ctx, m.args...)
expr = @macroexpand begin
@model function demo1(x)
m ~ Normal()
x ~ Normal(m, 1)
return (m = m, x = x)
end
end
expr |> Base.remove_linenums!
```

### Model 2

```julia
@time begin
@model function demo2(y)
# Our prior belief about the probability of heads in a coin.
p ~ Beta(1, 1)
# The number of observations.
N = length(y)
for n in 1:N
# Heads or tails of a coin are drawn from a Bernoulli distribution.
y[n] ~ Bernoulli(p)
end
end;
end;
model_def = demo2
data = rand(0:1, 10)
@time model_def(data)();
@btime $(model_def(data))();
m = model_def(data);
var_info = VarInfo(m);
@btime $m($var_info);
untyped_var_info = VarInfo();
m(untyped_var_info)
@btime $m($untyped_var_info);
m = model_def(data);
var_info = VarInfo(m);
@code_warntype m(var_info)
rng = DynamicPPL.Random.MersenneTwister(42);
spl = DynamicPPL.SampleFromPrior()
ctx = DynamicPPL.DefaultContext()
@code_typed m.f(rng, m, var_info, spl, ctx, m.args...)
expr = @macroexpand begin
@model function demo2(y)
# Our prior belief about the probability of heads in a coin.
p ~ Beta(1, 1)
# The number of observations.
N = length(y)
for n in 1:N
# Heads or tails of a coin are drawn from a Bernoulli distribution.
y[n] ~ Bernoulli(p)
end
end;
end
expr |> Base.remove_linenums!
```

### Model 3

```julia
@time begin
@model function demo3(x)
D, N = size(x)
# Draw the parameters for cluster 1.
μ1 ~ Normal()
# Draw the parameters for cluster 2.
μ2 ~ Normal()
μ = [μ1, μ2]
# Comment out this line if you instead want to draw the weights.
w = [0.5, 0.5]
# Draw assignments for each datum and generate it from a multivariate normal.
k = Vector{Int}(undef, N)
for i in 1:N
k[i] ~ Categorical(w)
x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.)
end
return k
end
end;
model_def = demo3
# Construct 30 data points for each cluster.
N = 30
# Parameters for each cluster, we assume that each cluster is Gaussian distributed in the example.
μs = [-3.5, 0.0]
# Construct the data points.
data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2)
@time model_def(data)();
@btime $(model_def(data))();
m = model_def(data);
var_info = VarInfo(m);
@btime $m($var_info);
untyped_var_info = VarInfo();
m(untyped_var_info)
@btime $m($untyped_var_info);
m = model_def(data);
var_info = VarInfo(m);
@code_warntype m(var_info)
rng = DynamicPPL.Random.MersenneTwister(42);
spl = DynamicPPL.SampleFromPrior()
ctx = DynamicPPL.DefaultContext()
@code_typed m.f(rng, m, var_info, spl, ctx, m.args...)
expr = @macroexpand begin
@model function demo3(x)
D, N = size(x)
# Draw the parameters for cluster 1.
μ1 ~ Normal()
# Draw the parameters for cluster 2.
μ2 ~ Normal()
μ = [μ1, μ2]
# Comment out this line if you instead want to draw the weights.
w = [0.5, 0.5]
# Draw assignments for each datum and generate it from a multivariate normal.
k = Vector{Int}(undef, N)
for i in 1:N
k[i] ~ Categorical(w)
x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.)
end
return k
end
end
expr |> Base.remove_linenums!
```

## Comparison

### Model 1

```sh
diff -y model1_lowered_old.jl model1_lowered_new.jl
```
### Model 2

```sh
diff -y model2_lowered_old.jl model2_lowered_new.jl
```
### Model 3

```sh
diff -y model3_lowered_old.jl model3_lowered_new.jl
```
It seems the new implementation leads to more allocations, and so possibly also worse performance, when compiling the model. Also, running model 3 is significantly slower and leads to more allocations, both with typed and "untyped" `VarInfo`. I guess both results are not completely surprising: we introduce more branches, which puts more pressure on the compiler. The diffs are difficult to read; is there any significant difference, in particular for the third model?
Yeaaah it seems like it fails to infer what's going to happen with

So
Ah, but this makes sense since this is within the for-loop (btw, I realized I included the

```julia
var"##vn#432" = (VarName)(:k, ((i,),))
```

Hmm, so I get why the optimization was introduced in the first place then. It avoids unnecessarily leaving open the option of something being observed if it's not present in `args`.
I guess what we could do (as I believe I mentioned at some point before) is to include both implementations, e.g. make available a

It does "unnecessarily" make the code more cluttered though... Not sure.
I wonder if it would help the compiler if
You mean, e.g.

```julia
function isassumption(vn::VarName, model::Model, left)
    !inargnames(vn, model) || inmissings(vn, model) || false
end

function isassumption(vn::VarName, model::Model, left::Missing)
    !inargnames(vn, model) || inmissings(vn, model) || true
end
```

EDIT: I guess the issue is that we're not going to be able to evaluate if something is `missing` using this approach, since

EDIT 2: Though we can "fix" this by doing

```julia
$isassumption = $(DynamicPPL.isassumption)($vn, _model) || $left === missing
```

instead, and dropping the

EDIT 3: Doesn't seem to help. Also, thinking about it, I don't expect it to? I guess the issue lies in the fact that `inargnames` and `inmissings` will always fail to be inferred when the symbol passed to `VarName` isn't static.
I thought maybe something like:

```julia
isassumption(vn, model) = !inargnames(vn, model) || inmissings(vn, model)
isassumption(vn, model, left) = _isassumption(Val{isassumption(vn, model)}(), left)

_isassumption(::Val{true}, left) = true
_isassumption(::Val{false}, ::Missing) = true
_isassumption(::Val{false}, left) = false
```

The main intention would be to enforce that the static checks are prioritized and that the runtime checks for `missing` are only needed if it seems to be an observation according to the static checks. I am not sure if it makes a difference, but my hope would be that the dispatches are easier for the compiler to handle than the if-else statements.
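A self-contained sketch of just the dispatch part (repeating the `_isassumption` methods from above; `Val(false)` stands for "statically looks like an observation"):

```julia
_isassumption(::Val{true}, left) = true        # static checks already say "assumption"
_isassumption(::Val{false}, ::Missing) = true  # runtime `missing` promotes it to an assumption
_isassumption(::Val{false}, left) = false      # genuinely observed

_isassumption(Val(true), 1.0)       # true
_isassumption(Val(false), missing)  # true
_isassumption(Val(false), 1.0)      # false
```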
Good point regarding
Re-stating another edit from above: I guess the issue lies in the fact that `inargnames` and `inmissings` will always fail to be inferred when the symbol passed to `VarName` isn't static. I.e., I'm not sure we can actually circumvent this for non-static
You mean the indices, e.g., the
And it seems it should be possible to support dynamic analysis for indices as well, but again prioritize static checks?
Truuue! But for some reason it seems like the diversion through

Is it because we effectively do
EDIT: It actually has fewer allocations for Model 2 than the

EDIT 2:

EDIT 3: Eh, compilation time is a bit inconclusive. Model 1 is slower, Model 2 is faster, and Model 3 is about the same.

EDIT 4: It's consistently faster than the old implementation though! E.g. check out the numbers for Model 2 in particular.

When interpolating

```julia
using Pkg
Pkg.status()
using DynamicPPL, Distributions, BenchmarkTools, MacroTools
```

### Model 1

```julia
@time begin
@model function demo1(x)
m ~ Normal()
x ~ Normal(m, 1)
return (m = m, x = x)
end
end;
model_def = demo1
data = 1.0

@time model_def(data)();
@btime $(model_def(data))();
m = model_def(data);
var_info = VarInfo(m);
@btime $m($var_info);
untyped_var_info = VarInfo();
m(untyped_var_info)
@btime $m($untyped_var_info);
m = model_def(data);
var_info = VarInfo(m);
@code_warntype m(var_info)
rng = DynamicPPL.Random.MersenneTwister(42);
spl = DynamicPPL.SampleFromPrior()
ctx = DynamicPPL.DefaultContext()
@code_typed m.f(rng, m, var_info, spl, ctx, m.args...)
expr = @macroexpand begin
@model function demo1(x)
m ~ Normal()
x ~ Normal(m, 1)
return (m = m, x = x)
end
end
expr |> Base.remove_linenums!
```

### Model 2

```julia
@time begin
@model function demo2(y)
# Our prior belief about the probability of heads in a coin.
p ~ Beta(1, 1)
# The number of observations.
N = length(y)
for n in 1:N
# Heads or tails of a coin are drawn from a Bernoulli distribution.
y[n] ~ Bernoulli(p)
end
end;
end;
model_def = demo2
data = rand(0:1, 10)

@time model_def(data)();
@btime $(model_def(data))();
m = model_def(data);
var_info = VarInfo(m);
@btime $m($var_info);
untyped_var_info = VarInfo();
m(untyped_var_info)
@btime $m($untyped_var_info);
m = model_def(data);
var_info = VarInfo(m);
@code_warntype m(var_info)
rng = DynamicPPL.Random.MersenneTwister(42);
spl = DynamicPPL.SampleFromPrior()
ctx = DynamicPPL.DefaultContext()
@code_typed m.f(rng, m, var_info, spl, ctx, m.args...)
expr = @macroexpand begin
@model function demo2(y)
# Our prior belief about the probability of heads in a coin.
p ~ Beta(1, 1)
# The number of observations.
N = length(y)
for n in 1:N
# Heads or tails of a coin are drawn from a Bernoulli distribution.
y[n] ~ Bernoulli(p)
end
end;
end
expr |> Base.remove_linenums!
```

### Model 3

```julia
@time begin
@model function demo3(x)
D, N = size(x)
# Draw the parameters for cluster 1.
μ1 ~ Normal()
# Draw the parameters for cluster 2.
μ2 ~ Normal()
μ = [μ1, μ2]
# Comment out this line if you instead want to draw the weights.
w = [0.5, 0.5]
# Draw assignments for each datum and generate it from a multivariate normal.
k = Vector{Int}(undef, N)
for i in 1:N
k[i] ~ Categorical(w)
x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.)
end
return k
end
end;
model_def = demo3
# Construct 30 data points for each cluster.
N = 30
# Parameters for each cluster, we assume that each cluster is Gaussian distributed in the example.
μs = [-3.5, 0.0]
# Construct the data points.
data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2)

@time model_def(data)();
@btime $(model_def(data))();
m = model_def(data);
var_info = VarInfo(m);
@btime $m($var_info);
untyped_var_info = VarInfo();
m(untyped_var_info)
@btime $m($untyped_var_info);
m = model_def(data);
var_info = VarInfo(m);
@code_warntype m(var_info)
rng = DynamicPPL.Random.MersenneTwister(42);
spl = DynamicPPL.SampleFromPrior()
ctx = DynamicPPL.DefaultContext()
@code_typed m.f(rng, m, var_info, spl, ctx, m.args...)
expr = @macroexpand begin
@model function demo3(x)
D, N = size(x)
# Draw the parameters for cluster 1.
μ1 ~ Normal()
# Draw the parameters for cluster 2.
μ2 ~ Normal()
μ = [μ1, μ2]
# Comment out this line if you instead want to draw the weights.
w = [0.5, 0.5]
# Draw assignments for each datum and generate it from a multivariate normal.
k = Vector{Int}(undef, N)
for i in 1:N
k[i] ~ Categorical(w)
x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.)
end
return k
end
end
expr |> Base.remove_linenums!
```
Lovely :) Thanks!

```julia
using DynamicPPL, MacroTools
using Distributions
import Random
# Since we don't need access to `args` at expansion, we can re-use
# the implementation of `generate_mainbody`.
"""
@tilde x ~ Distribution
Generates the equivalent of a tilde-statement in DynamicPPL.
"""
macro tilde(expr)
esc(DynamicPPL.generate_mainbody(__module__, expr, false))
end
"""
@submodel_def expr
Insert variables internal to `DynamicPPL.@model` at beginning of `args` in method definition.
"""
macro submodel_def(expr)
modeldef = MacroTools.splitdef(expr)
modeldef[:args] = vcat(
[
:(_rng::$(Random.AbstractRNG)),
:(_model::$(DynamicPPL.Model)),
:(_varinfo::$(DynamicPPL.AbstractVarInfo)),
:(_sampler::$(DynamicPPL.AbstractSampler)),
:(_context::$(DynamicPPL.AbstractContext)),
],
modeldef[:args]
)
return esc(MacroTools.combinedef(modeldef))
end
"""
@submodel_call f(args...; kwargs...)
Insert variables internal to `@model` at beginning of `args` in a function call.
"""
macro submodel_call(expr)
# If not a `:call`, don't do anything.
if !Meta.isexpr(expr, :call)
return expr
end
# Insert arguments internal to `@model`.
has_kwargs = Meta.isexpr(expr.args[2], :parameters)
kwargs = has_kwargs ? expr.args[2] : []
args_start_idx = 2 + has_kwargs
args = if length(expr.args) ≥ args_start_idx
expr.args[args_start_idx:end]
else
[]
end
expr.args = vcat(
expr.args[1],
kwargs,
[:(_rng), :(_model), :(_varinfo), :(_sampler), :(_context)],
args
)
return esc(expr)
end
```

```julia
julia> @submodel_def function example_submodel(x)
@tilde x ~ Normal()
return x
end
example_submodel (generic function with 1 method)
julia> @model function demo(x)
# Equivalent of
# `example_submodel(_rng, _model, _varinfo, _sampler, _context)`
@submodel_call example_submodel(x)
end false
demo (generic function with 1 method)
julia> m = demo(missing)
Model{var"#1#2", (:x,), (), (:x,), Tuple{Missing}, Tuple{}}(:demo,
var"#1#2"(), (x = missing,), NamedTuple())
julia> vi = VarInfo(m);
julia> vi[@varname(x)]
-0.6285896547007351
julia> m = demo(1.0)
Model{var"#1#2", (:x,), (), (), Tuple{Float64}, Tuple{}}(:demo, var"#1#2"(), (x = 1.0,), NamedTuple())
julia> m()
1.0
```

Some example output of the macros defined here (because I can't embed julia-repl examples in the docstrings on GitHub):

```julia
julia> @macroexpand @tilde x ~ Normal()
quote
var"##tmpright#257" = Normal()
var"##tmpright#257" isa Union{Distribution, AbstractVector{<:Distribution}} || throw(ArgumentError("Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions."))
var"##vn#259" = (VarName){:x}()
var"##inds#260" = ()
var"##isassumption#261" = begin
let var"##vn#262" = (VarName){:x}()
if !((DynamicPPL.inargnames)(var"##vn#262", _model)) || (DynamicPPL.inmissings)(var"##vn#262", _model)
true
else
x === missing
end
end
end
if var"##isassumption#261"
x = (DynamicPPL.tilde_assume)(_rng, _context, _sampler, var"##tmpright#257", var"##vn#259", var"##inds#260", _varinfo)
else
(DynamicPPL.tilde_observe)(_context, _sampler, var"##tmpright#257", x, var"##vn#259", var"##inds#260", _varinfo)
end
end
```

```julia
julia> @macroexpand(@submodel_def function f(x, y; z = 1)
return x + y + z
end) |> Base.remove_linenums!
:(function f(_rng::AbstractRNG, _model::Model, _varinfo::AbstractVarInfo, _sampler::AbstractMCMC.AbstractSampler, _context::DynamicPPL.AbstractContext, x, y; z = 1)
return x + y + z
end)
```

```julia
julia> @macroexpand @submodel_call f(x, y; z = 1)
:(f(_rng, _model, _varinfo, _sampler, _context, x, y; z = 1))
```

It's worth pointing out that I'm not sure this is something we want to "offer" the users 😅 But at least such a thing is possible. Of course, one could just re-implement the
Ah I see, yes, the

```julia
struct CatchAllArgs end
Base.in(::Symbol, ::CatchAllArgs) = true
"""
@tilde x ~ Distribution
Generates the equivalent of a tilde-statement in DynamicPPL.
"""
macro tilde(expr)
esc(DynamicPPL.generate_mainbody(__module__, expr, CatchAllArgs(), false))
end
```

I also thought about something like
Ah yes this is better! But it still complicates things a bit 😕
Agreed 👍 You could of course return the values to make them accessible in the main scope, which one might argue is the best way to handle submodels anyway; e.g., often you have some process with latent variables that you don't care about, and so you're fine with those not being available in the global scope. We could also start getting fancy, doing something like

```julia
@submodel (a, b, c) ~ submodel1(args...) # `submodel1` actually contains variables `(a, b, c, d, e, f)`
```

which generates something like

```julia
begin
@submodel_call submodel1.f(args...)
_varinfo[@varname(a)], _varinfo[@varname(b)], _varinfo[@varname(c)]
end
```

And for the variables which are NOT returned back to the main scope, we could even add an identifier to the variable names somehow, e.g. prefixing them with the name of the model/function, so that when the latent variables show up in the resulting chain, it's all good.
The only annoying thing here (and the reason why I implemented a "regular" function instead) is that you need to actually instantiate the model, which can be a bit annoying; e.g., in the above example you need to also instantiate the

```julia
@model function submodel(x)
x ~ Normal()
end
subm = submodel(missing)
subm()
@model function demo(x)
@submodel_call subm.f(x)
end false
m = demo(missing)
m()
m = demo(1.0)
m()
```
It should be added anyway since otherwise variables of the same name will mess with each other in the VarInfo object.
You can hide this by instantiating it implicitly. E.g., one could define

```julia
macro submodel(expr)
return :(_evaluate($(esc(:_rng)), $(esc(expr)), $(esc(:_varinfo)), $(esc(:_sampler)), $(esc(:_context))))
end
```

and then just write

```julia
@model function demo1(x)
x ~ Normal()
end
@model function demo2(x, y)
@submodel demo1(x)
y ~ Uniform()
end
```

In principle, one could even use a different name for `x`:

```julia
@model function demo2(a, y)
@submodel demo1(a)
y ~ Uniform()
end
```
Oh wow, nice! I literally just converted

But what you just did is waaaay nicer! I guess there would be a super-tiny overhead to
So really the only thing missing is just passing around a

EDIT: On second thought, I guess this wouldn't be sufficient, since we'd have to add the prefix at DPPL-compile time, at which point we don't know whether or not it's a submodel.

EDIT 2: No matter; being able to use a submodel but having to "manually" make sure you're not using the same variable names is still very, very useful.

EDIT 3: On third thought, even if you have to specify the prefix at definition, it would still be immensely useful. Most of the time you have a particular part of the model that you want to use as a sub-model in different models, and being able to just do

```julia
@model function m(args...)
# ...
end true "prefix_yo"
```

would be very helpful.
I assume it would be most intuitive/convenient if you can specify the prefix when including the submodel, not when defining it. Not sure if there are any dispatch issues, but my first try would be to call

BTW, coming back to this PR: there's a new version of AbstractPPL now with the changes in the AbstractPPL PR. I think it would be good to merge #223 first (because otherwise the integration tests don't work), and then we can probably merge this PR if there are no major concerns 🙂
Contexts! YES! I was thinking "should we include a

But yeah, we can sort that out in a separate issue/PR 👍 I'll merge and get tests running 👍
bors try
Co-authored-by: David Widmann <[email protected]>
LGTM
Should I bump the version and bors it?
bors r+
Part of the motivation for #221 and #222 was so we could add submodels/model-nesting. Well, now we can. Special thanks to @devmotion, who reviewed those PRs (several times), improved them significantly, made additional PRs, and suggested the current impl of the `@submodel` ❤️

EDIT: We fixed the performance :) This now has zero runtime overhead! See the comment section.

EDIT 2: Thanks to @devmotion, we can now also deal with dynamically specified prefixes!

- [Motivating example: AR1-prior](#org46a90a5)
- [Demos](#org7e05701)
- [Can it ever fail?](#org75acb71)
- [Benchmarks](#orga99bcf4)

<a id="org46a90a5"></a>
# Motivating example: AR1-prior

```julia
using Turing
using DynamicPPL
```

```julia
# Could have made a model which samples `num_obs` AR1 samples simultaneously,
# but for the sake of showing off dynamic prefixes, we'll only use a vector-implementation.
# The matrix implementation will be quite a bit faster too, but oh well.
@model function AR1(num_steps, α, μ, σ, ::Type{TV} = Vector{Float64}) where {TV}
    η ~ MvNormal(num_steps, 1.0)
    δ = sqrt(1 - α^2)

    x = TV(undef, num_steps)
    x[1] = η[1]
    @inbounds for t = 2:num_steps
        x[t] = @. α * x[t - 1] + δ * η[t]
    end

    return @. μ + σ * x
end

# Generate an observation
σ_obs = 0.1
num_obs = 5
num_steps = 10

ar1 = AR1(num_steps, 0.5, 1.0, 1.0)
ys = mapreduce(hcat, 1:num_obs) do i
    ar1() + σ_obs * randn(num_steps)
end
```

```
10×5 Matrix{Float64}:
  2.30189    0.301618  1.73268   -0.65096    1.46835
  2.11187   -1.34878   2.3728     1.02125    3.28422
 -0.249064   0.769488  1.34044    3.22175    2.52196
 -0.25863   -0.216914  0.528954   3.04756    3.8234
  0.372122   0.473511  0.708068   0.76197    0.202003
  0.41487    0.759435  1.80162    0.790204   0.12331
  1.32585    0.567929  2.74316    1.0874     2.82701
  1.84307    1.16138   1.36382    0.735388   1.07423
  3.20139    0.75177   1.57236    0.865401  -0.315341
  1.22479    1.35688   2.8239     0.597959   0.587955
```

```julia
@model function demo(y)
    α ~ Uniform()
    μ ~ Normal()
    σ ~ truncated(Normal(), 0, Inf)

    num_steps = size(y, 1)
    num_obs = size(y, 2)
    @inbounds for i = 1:num_obs
        x = @SubModel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ)
        y[:, i] ~ MvNormal(x, 0.1)
    end
end;

m = demo(y);
vi = VarInfo(m);
```

```julia
keys(vi)
```

```
8-element Vector{VarName{sym, Tuple{}} where sym}:
 α
 μ
 σ
 ar1_1.η
 ar1_2.η
 ar1_3.η
 ar1_4.η
 ar1_5.η
```

```julia
vi[@varname α]
```

```
0.9383208224122919
```

```julia
chain = sample(m, NUTS(1_000, 0.8), 3_000);
```

```
┌ Info: Found initial step size
│   ϵ = 0.025
└ @ Turing.Inference /home/tor/.julia/packages/Turing/rHLGJ/src/inference/hmc.jl:188
Sampling: 100%|█████████████████████████████████████████| Time: 0:04:00
```

```julia
chain[1001:end, [:α, :μ, :σ], :]
```

```
Chains MCMC chain (2000×3×1 Array{Float64, 3}):

Iterations        = 1001:3000
Thinning interval = 1
Chains            = 1
Samples per chain = 2000
parameters        = α, μ, σ
internals         =

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64

           α    0.5474    0.1334     0.0030    0.0073   159.6969    0.9995
           μ    1.0039    0.2733     0.0061    0.0168   169.9106    1.0134
           σ    1.1294    0.1807     0.0040    0.0106   166.8670    0.9998

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           α    0.2684    0.4625    0.5534    0.6445    0.7861
           μ    0.4248    0.8227    1.0241    1.2011    1.4801
           σ    0.8781    1.0018    1.0989    1.2239    1.5472
```

Yay!

We recovered the true parameters :tada:

```julia
@benchmark $m($vi)
```

```
BenchmarkTools.Trial:
  memory estimate:  12.05 KiB
  allocs estimate:  123
  --------------
  minimum time:     15.091 μs (0.00% GC)
  median time:      17.861 μs (0.00% GC)
  mean time:        19.582 μs (5.23% GC)
  maximum time:     10.293 ms (99.46% GC)
  --------------
  samples:          10000
  evals/sample:     1
```

<a id="org7e05701"></a>
# Demos

```julia
using DynamicPPL, Distributions
```

```
┌ Info: Precompiling DynamicPPL [366bfd00-2699-11ea-058f-f148b4cae6d8]
└ @ Base loading.jl:1317
```

```julia
@model function demo1(x)
    x ~ Normal()
end;

@model function demo2(x, y)
    @SubModel demo1(x)
    y ~ Uniform()
end false;

m2 = demo2(missing, missing);
vi2 = VarInfo(m2);
keys(vi2)
```

```
2-element Vector{VarName{sym, Tuple{}} where sym}:
 x
 y
```

```julia
println(vi2[VarName(Symbol("x"))])
println(vi2[VarName(Symbol("y"))])
```

```
0.3069117531180063
0.7325324947386318
```

We can also `observe` without issues:

```julia
@model function demo2(x, y)
    @SubModel demo1(x)
    y ~ Normal(x)
end false;

m2 = demo2(1000.0, missing);
vi2 = VarInfo(m2);
keys(vi2)
```

```
1-element Vector{VarName{:y, Tuple{}}}:
 y
```

```julia
vi2[@varname y]
```

```
1000.3905079427211
```

```julia
DynamicPPL.getlogp(vi2)
```

```
-500001.9141252931
```

But what if the models have the same variable names?!

"Sure, this is cool and all, but can we even use the values from the nested models in the parent model?"

```julia
@model function demo_return(x)
    x ~ Normal()
    return x
end;

@model function demo_useval(x, y)
    x1 = @SubModel sub1 demo_return(x)
    x2 = @SubModel sub2 demo_return(y)

    z ~ Normal(x1 + x2 + 100, 1.0)
end false;

vi = VarInfo(demo_useval(missing, missing));
keys(vi)
```

```
3-element Vector{VarName{sym, Tuple{}} where sym}:
 sub1.x
 sub2.x
 z
```

```julia
vi[@varname z]
```

```
101.09066854862154
```

And just to prove a point:

```julia
@model function nested(x, y)
    @SubModel 1 nested1(x, y)
    y ~ Uniform()
end false;

@model function nested1(x, y)
    @SubModel 2 nested2(x, y)
    y ~ Uniform()
end false;

@model function nested2(x, y)
    z = @SubModel 3 nested3(x, y)
    y ~ Normal(z, 1.0)
end false;

@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)
    return x + 1000
end false;

m = nested(missing, missing);
vi = VarInfo(m);
keys(vi)
```

```
5-element Vector{VarName{sym, Tuple{}} where sym}:
 1.2.3.x
 1.2.3.y
 1.2.y
 1.y
 y
```

```julia
vi[VarName(Symbol("1.2.y"))]
```

```
1000.5609156083766
```

```julia
DynamicPPL.getlogp(vi)
```

```
-4.620040828101227
```

<a id="org75acb71"></a>
# Can it ever fail?

Yeah, if the user doesn't provide the prefix, it can:

```julia
@model function nested(x, y)
    @SubModel nested1(x, y)
    y ~ Uniform()
end false;

@model function nested1(x, y)
    @SubModel nested2(x, y)
    y ~ Uniform()
end false;

@model function nested2(x, y)
    z = @SubModel nested3(x, y)
    y ~ Normal(z, 1.0)
end false;

@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)
    return x + 1000
end false;

m = nested(missing, missing);
vi = VarInfo(m);
keys(vi)
```

```
2-element Vector{VarName{sym, Tuple{}} where sym}:
 x
 y
```

```julia
# Inner-most value is recorded (i.e. the first one reached)
vi[@varname y]
```

```
-100.16836599596732
```

And it messes up the logp computation:

```julia
DynamicPPL.getlogp(vi)
```

```
-Inf
```

But I could imagine there's a way for us to fix this, or at least warn the user when this happens.

<a id="orga99bcf4"></a>
# Benchmarks

At this point you're probably wondering, "but does it have any overhead (at runtime)?"

For "shallow" nestings, nah, but if you go deep enough there seems to be a tiny bit (likely because we're calling the "constructor" for the model):

```julia
using BenchmarkTools

@model function base(x, y)
    x ~ Uniform()
    y ~ Uniform()
    y1 ~ Uniform()
    z = x + 1000
    y12 ~ Normal()
    y123 ~ Normal(-100.0, 1.0)
end

m1 = base(missing, missing);
vi1 = VarInfo(m1);
```

```julia
@model function nested_shallow(x, y)
    @SubModel 1 nested1_shallow(x, y)
    y ~ Uniform()
end false;

@model function nested1_shallow(x, y)
    x ~ Uniform()
    y ~ Uniform()
    z = x + 1000
    y12 ~ Normal()
    y123 ~ Normal(-100.0, 1.0)
end false;

m2 = nested_shallow(missing, missing);
vi2 = VarInfo(m2);
```

```julia
@model function nested(x, y)
    @SubModel 1 nested1(x, y)
    y ~ Uniform()
end false;

@model function nested1(x, y)
    @SubModel 2 nested2(x, y)
    y ~ Uniform()
end false;

@model function nested2(x, y)
    z = @SubModel 3 nested3(x, y)
    y ~ Normal(z, 1.0)
end false;

@model function nested3(x, y)
    x ~ Uniform()
    y ~ Normal(-100.0, 1.0)
    return x + 1000
end

m3 = nested(missing, missing);
vi3 = VarInfo(m3);
```

```julia
@model function nested_noprefix(x, y)
    @SubModel nested_noprefix1(x, y)
    y ~ Uniform()
end false;

@model function nested_noprefix1(x, y)
    @SubModel nested_noprefix2(x, y)
    y1 ~ Uniform()
end false;

@model function nested_noprefix2(x, y)
    z = @SubModel nested_noprefix3(x, y)
    y2 ~ Normal(z, 1.0)
end false;

@model function nested_noprefix3(x, y)
    x ~ Uniform()
    y3 ~ Normal(-100.0, 1.0)
    return x + 1000
end

m4 = nested_noprefix(missing, missing);
vi4 = VarInfo(m4);
```

```julia
keys(vi1)
```

```
5-element Vector{VarName{sym, Tuple{}} where sym}:
 x
 y
 y1
 y12
 y123
```

```julia
keys(vi2)
```

```
5-element Vector{VarName{sym, Tuple{}} where sym}:
 1.x
 1.y
 1.y12
 1.y123
 y
```

```julia
keys(vi3)
```

```
5-element Vector{VarName{sym, Tuple{}} where sym}:
 1.2.3.x
 1.2.3.y
 1.2.y
 1.y
 y
```

```julia
keys(vi4)
```

```
5-element Vector{VarName{sym, Tuple{}} where sym}:
 x
 y3
 y2
 y1
 y
```

```julia
@benchmark $m1($vi1)
```

```
BenchmarkTools.Trial:
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     1.714 μs (0.00% GC)
  median time:      1.747 μs (0.00% GC)
  mean time:        1.835 μs (0.00% GC)
  maximum time:     6.894 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10
```

```julia
@benchmark $m2($vi2)
```

```
BenchmarkTools.Trial:
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     1.759 μs (0.00% GC)
  median time:      1.778 μs (0.00% GC)
  mean time:        1.819 μs (0.00% GC)
  maximum time:     5.563 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10
```

```julia
@benchmark $m3($vi3)
```

```
BenchmarkTools.Trial:
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     1.718 μs (0.00% GC)
  median time:      1.746 μs (0.00% GC)
  mean time:        1.787 μs (0.00% GC)
  maximum time:     5.758 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10
```

```julia
@benchmark $m4($vi4)
```

```
BenchmarkTools.Trial:
  memory estimate:  160 bytes
  allocs estimate:  5
  --------------
  minimum time:     1.672 μs (0.00% GC)
  median time:      1.696 μs (0.00% GC)
  mean time:        1.756 μs (0.00% GC)
  maximum time:     4.882 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10
```

Notice that the number of allocations has increased for the deeply nested model. Seems like the Julia compiler isn't too good at inferring the return types of Turing models? This seems to be the case too when looking at the lowered code. I haven't given this too much thought yet, btw; there's likely a way for us to help the compiler here.
Changes to DPPL can often have quite significant effects on compilation time and performance, both for itself and for downstream packages. It's also sometimes difficult to discover these performance regressions. E.g., in #221 we made a small simplification to the compiler, and it ended up taking quite a while to figure out what was going wrong; we had to test several models to identify the issue.

So, this is a WIP PR for including a small set of models which we can `weave` into a document where we can look at the changes. It's unclear to me whether this should go in DPPL itself or in a separate package. I found it useful myself and figured I'd put it here so we can maybe start to get some "standard" benchmarks to run for testing purposes. IMO we don't need many of them, as we will add more as we go along.

For each model the following will be included in the document:

- Benchmarked evaluation of the model on untyped and typed `VarInfo`.
- Timing of the compilation of the model in the typed `VarInfo`.
- Lowered code for the model.
- If `:prefix` is provided to `weave`, the string representation of `code_typed` for the evaluation of the model will be saved to a file `$(prefix)_(model.name)`. Furthermore, if `:prefix_old` is provided, pointing to the `:prefix` used for a previous run (likely using a different version of DPPL), we will `diff` the `code_typed` for the two models by loading the saved files.
## Overview

At the moment, we perform a check at model-expansion as to whether or not `vsym(left) in args`, where `args` is the arguments of the model.

1. If `true`, we return a block of code which uses `DynamicPPL.isassumption` to check whether or not to call `assume` or `observe` for the variable present in `args`.
2. Otherwise, we generate a block which is identical to the `assume` block in the if-statement mentioned in (1).

The thing is, `DynamicPPL.isassumption` performs exactly the same check as above but using `DynamicPPL.inargnames`, i.e. at runtime. So if we're using `TypedVarInfo`, the check at macro-expansion vs. at runtime is completely redundant, since all the information necessary to determine `DynamicPPL.inargnames` is available at compile time.

Therefore I suggest we remove this check at model-expansion and simply handle it using `DynamicPPL.isassumption`.

## Pros & cons

Pros:

- No need to pass `args` around everywhere.
- `generate_tilde` and `generate_dot_tilde` are much simpler: there are only two possible blocks we can generate, either a) assume/observe, or b) observe literal.

Cons:

- We need to perform *one* more check at runtime when using `UntypedVarInfo`.

IMO, this is really worth it.

## Motivation (sort of)

The main motivation behind this PR is simplification, but there's a different reason why I came across this.

I came to this because I was thinking about trying to "customize" the behavior of `~`, and I was thinking of using a macro to do it, e.g. `@mymacro x ~ Normal()`. Atm we're actually performing model-expansion on the code passed to the macro, and thus trying to alter the way DynamicPPL treats `~` using a macro is veeeery difficult, since you actually have to work with the *expanded* code; but let's ignore that issue for now (and take that discussion somewhere else, because IMO we shouldn't do this).

Suppose we didn't perform model-expansion of the code fed to the macros. Then you could just copy-paste `generate_tilde`, customize it to do what you want, and BAM, you've got yourself a working `@mymacro x ~ Normal()` which can do neat stuff! This is *not* possible atm because we don't have access to `args`, and so you have to take the approach in this PR to get there. That means that it's of course possible to do atm, but it's a bit icky, since it ends up looking fundamentally different from `generate_tilde` rather than just slightly different.

Then we can implement things like a `@tilde` macro which expands to `generate_tilde` and can be used *internally* in functions (if the "internal" variables are present in the functions, of course, but we can also simplify this in different ways), actually allowing people to modularize their models a bit; `@reparam` from #220 using very similar pieces of code; a `@track` macro to deal with the explicit tracking of variables rather than putting this directly in the compiler; etc. Endless opportunities! (Of course, I'm not suggesting we add these, but this makes it a bit easier to explore.)