
Design/improve/unify sample() interface #746

Closed
xukai92 opened this issue Apr 2, 2019 · 18 comments
@xukai92
Member

xukai92 commented Apr 2, 2019

Questions and my current thoughts

  • How to pass an initial state
    • Support sample(..., init::Array{Union{Missing, Array{Float64,1}},1}=[missing])
    • We flatten init and assign the variables that are not missing to vi[spl]
    • Constructing a correct init is the user's responsibility
  • How to pass the adaptation configuration
    • Make use of AdvancedHMC.Adaptation
    • sample(..., adapt::Union{Nothing,AbstractAdaptor}=nothing)
  • How to pass a continuing chain
    • sample(..., resume::Union{Nothing,AbstractChain}=nothing)
    • How do we resume an adapted sampler?
      • With the new interface of AdvancedHMC.Adaptation, the adapted sampler is described entirely by its types, not by the adaptor instance or any other state. So if we make the types in Turing 1:1 with those types, we can resume the adapted sampler.
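
For illustration, the init option above might be used like this (a sketch only; the model and the HMC argument order are hypothetical):

```julia
# Hypothetical model with two variables; we initialise the first and
# leave the second as `missing` so the sampler initialises it itself.
init = Union{Missing, Array{Float64,1}}[
    [0.5, -1.2],  # initial value for the first variable
    missing,      # sampler picks an initial value for the second
]

chn = sample(model, HMC(0.2, 3, :x), 1000; init = init)
```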

Related issue:

@trappmartin
Member

Does it make sense that we differentiate between setting an initial value and resuming sampling? Why not get rid of the resume interface and only provide a clean way to set an initial value?

@xukai92
Member Author

xukai92 commented Apr 2, 2019

Good idea! Then the only remaining question is how to resume an adapted sampler.

@yebai
Member

yebai commented Apr 4, 2019

I think we can use keyword arguments here to allow maximum flexibility with a minimal sample design. For example:

function sample(rng::AbstractRNG, 
                ℓ::ModelType, s::SamplerType, 
                N::Integer; args...
        ) where {ModelType, SamplerType, TransitionType}
    t = Array{TransitionType, 1}(undef, N)
    for i = 1:N
        t[i] = step(rng, ℓ, s, N; args...)
    end
    # convert `t` into `MCMCChains.Chains`.
end

where args captures keyword arguments; some example uses are:

  • θ₀::AbstractVector = zeros(D)
  • adaptor::Adaptor = StanNUTSAdaptor()
  • n_chains::Int = 1
  • resume::AbstractChain = chain
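
A call under this keyword-based design might then look like the following (a sketch; the keyword names follow the bullet examples above and are not an implemented API):

```julia
using Random

rng = MersenneTwister(1234)

# Hypothetical invocation: all optional configuration travels through kwargs.
chn = sample(rng, model, spl, 1000;
    θ₀ = zeros(2),               # initial parameter value
    adaptor = StanNUTSAdaptor(), # adaptation configuration
    n_chains = 4,                # run four chains
)
```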

@yebai
Member

yebai commented Apr 4, 2019

Related: #723

@yebai
Member

yebai commented Apr 4, 2019

@cpfiffer Does it make sense to move this sample function to MCMCChains? This design would allow packages such as AdvancedHMC to be fully functional by depending only on MCMCChains. That is, make both Turing and AdvancedHMC depend on MCMCChains, and remove the inter-dependency between them.

@cpfiffer
Member

cpfiffer commented Apr 5, 2019

I don't see why not. The function looks pretty small, and everything's parametric so there's not really any dependencies to speak of. The JuliaStan people may have some opinions though, so allow me to summon the JuliaStan kingpin.

@goedman, we're working on something of an overhaul of the Turing internals for our 0.7 release and wanted to know what you think about including a function in MCMCChains like the following:

function sample(rng::AbstractRNG, 
    ℓ::ModelType, s::SamplerType, 
    N::Integer; args...
) where {ModelType, SamplerType, TransitionType}
    t = Array{TransitionType, 1}(undef, N)

    for i = 1:N
        t[i] = step(rng, ℓ, s, N; args...)
    end

    return Chains(t...)
end

It's designed such that AdvancedHMC and Turing can work independently of one another. I'd view it as basically another MCMCChains constructor. Do you have any issues with including this in MCMCChains?

@goedman

goedman commented Apr 5, 2019

That's perfectly ok.

@goedman

goedman commented Apr 5, 2019

I've actually started to look at the AHMC sampler and would love to spend some time with it next week and try it out! If the performance is reasonably close to DynamicHMC, that would be a huge leap forward!

@yebai yebai mentioned this issue Apr 19, 2019
@cpfiffer
Member

cpfiffer commented Apr 21, 2019

@yebai Does this interface need some additional components? I've added all the bits to a draft of MCMCChains that includes the addition of AbstractSampler, AbstractModel, and AbstractTransition types which I could see Turing inheriting from. I can take this out, I'm not very convinced as to their necessity, but it does unify the "interface" component of all this.

I'm not sure on what AbstractTransition should look like, though, and I couldn't find an example of what this might be. The function above also doesn't have a mechanism for the type to actually make it into the function. Should AbstractModel have some parametric type AbstractModel{TransitionType}? This comment seems to indicate that the TransitionType is at the sampler level.

Additionally, many of the samplers have some initial setup of some kind that isn't caught well by the "iterate through each step" methodology above. I added some "empty functions" that specific samplers can overload if they need to. It would be really cool if this could catch the general cases for all the samplers.

For example:

function sample(
    rng::AbstractRNG,
    ℓ::ModelType, 
    s::SamplerType,
    N::Integer; 
    kwargs...
) where {ModelType<:AbstractModel,
    SamplerType<:AbstractSampler,
    TransitionType<:AbstractTransition} #Not sure how to get TransitionType into the function
    t = Array{TransitionType, 1}(undef, N)

    # Perform any necessary setup.
    sample_init!(rng, ℓ, s, N; kwargs...)

    # Step through the sampler.
    for i=1:N
        t[i] = step(rng::AbstractRNG, ℓ::ModelType, s::SamplerType, N::Integer; kwargs...)
    end

    # Wrap up the sampler.
    sample_end!(rng, ℓ, s, N; kwargs...)

    # Placeholder while I figure out what t looks like.
    return Chains(t)
end
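
One possible way to actually get TransitionType into the function (purely a sketch; transition_type is a hypothetical trait function, not an existing API) is to let each sampler declare the transition type it produces:

```julia
# Hypothetical trait function: the fallback returns the abstract type,
# and each concrete sampler overloads it with its own transition type.
transition_type(::AbstractSampler) = AbstractTransition
transition_type(::Hamiltonian) = HamiltonianTransition  # illustrative overload

function sample(rng::AbstractRNG, ℓ::AbstractModel, s::AbstractSampler,
                N::Integer; kwargs...)
    T = transition_type(s)  # transition type recovered via dispatch on the sampler
    t = Vector{T}(undef, N)
    # ... step through the sampler as in the body above ...
end
```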

In summary, this is what I see the sampling interface looking like:

  • Samplers inherit from AbstractSampler or some subtype of AbstractSampler, like Hamiltonian<:AbstractSampler
  • Models inherit from the AbstractModel type. This won't be too hard on the Turing side, as
    struct Model{pvars, dvars, F, TData, TDefaults}
    would just become
    struct Model{pvars, dvars, F, TData, TDefaults} <: AbstractModel
  • Samplers that need setup or wrap-up code overload their own versions of sample_init!(rng, ℓ, s, N; kwargs...) and sample_end!(rng, ℓ, s, N; kwargs...).
  • I have no idea what to do with TransitionType or AbstractTransition.

I also have wrapper functions for the above that allow for the sampling of multiple chains, and psample which will run multiple chains on parallel processes if any workers are available, but I can't really test them well without the whole interface working.
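
The multi-chain wrappers mentioned above could be sketched roughly as follows (hypothetical signature; chainscat is assumed to concatenate chains into one multi-chain object):

```julia
using Distributed, Random

# Hypothetical sketch: run `n_chains` independent chains, distributed over
# any available workers, and combine them into a single Chains object.
function psample(ℓ, s, N::Integer, n_chains::Integer; kwargs...)
    chains = pmap(1:n_chains) do _
        sample(MersenneTwister(), ℓ, s, N; kwargs...)
    end
    return chainscat(chains...)
end
```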

Comments, pointers, and corrections welcomed.

@yebai
Member

yebai commented May 3, 2019

Some desired refactoring goals for Sampler APIs

  1. Implement a shared sample(model, alg, ...) for all samplers (related to https://github.com/TuringLang/Turing.jl/issues/634#issuecomment-471339521, Pass number of parallel chains as an argument #582, Add example illustrating how to run multiple chains. #698, Better adaptation configuration interface #491, Add additional keyword arguments to the sample function #482). This would substantially simplify sampling algorithm code, since the current implementations each contain a customised sample with loads of code copied and pasted from the HMC sampler.

  2. Remove Sampler.info and introduce a Transition type (generalising the current Turing.Sample type) to store all MCMC runtime information (related: Logging in samplers / diagnosis statistics #599). This would remove the need to create a Sampler wrapper for each Gibbs component algorithm.

    mutable struct Sample
        weight :: Float64 # particle weight
        value :: Dict{Symbol,Any}
    end

  3. Remove the Sampler wrapper for Gibbs component algorithms: Gibbs.algs::Sampler ==> Gibbs.algs::InferenceAlgorithm. (related: Introduce MCMCSampler type #602, [Gibbs] Increase performance when subsamplers have several iterations #378)

    for i in 1:n_samplers
        sub_alg = alg.algs[i]
        if isa(sub_alg, GibbsComponent)
            samplers[i] = Sampler(typeof(sub_alg)(sub_alg, i), model)
        else
            @error("[Gibbs] unsupport base sampling algorithm $alg")
        end
        space = union(space, sub_alg.space)
    end

Related https://github.com/TuringLang/Turing.jl/issues/634 #602 #582 #599 #689

Here are some relevant sample implementations in the current code base:

HMC, SGLD and SGHMC:

  • function sample(
    model::Model,
    alg::Hamiltonian;
    save_state=false, # flag for state saving
    resume_from=nothing, # chain to continue
    reuse_spl_n=0, # flag for spl re-using
    adaptor = AHMCAdaptor(alg),
    init_theta::Union{Nothing,Array{<:Any,1}}=nothing,
    rng::AbstractRNG=GLOBAL_RNG,
    kwargs...
    )

Gibbs

  • function sample(
    model::Model,
    alg::Gibbs;
    save_state=false, # flag for state saving
    resume_from=nothing, # chain to continue
    reuse_spl_n=0 # flag for spl re-using
    )

PGibbs

  • function sample(
    model::Model,
    alg::PG;
    save_state=false, # flag for state saving
    resume_from=nothing, # chain to continue
    reuse_spl_n=0 # flag for spl re-using
    )

SMC

Stan

@yebai
Member

yebai commented May 3, 2019

I have no idea what to do with TransitionType or AbstractTransition.

@cpfiffer This is meant to replace the Sample type and Sampler.Info, see #723 and comments above (copied from #723)

Models inherit from the AbstractModel type.

I think we can inherit Distributions.Sampleable or Distributions.Distribution, e.g.

struct Model{pvars, dvars, F, TData, TDefaults} <: Distributions.Sampleable

@yebai yebai pinned this issue May 3, 2019
@cpfiffer
Member

cpfiffer commented May 5, 2019

I've poked around a bit more and come up with an example of how I might think about modifying one of the samplers to fit this interface style. The changes below would fit fairly well into the common sampler interface function noted above, and give other potential Turing users a better example of how to construct modular sampler/model/general inference tools.

Using SMC as an initial example, I would imagine the following structure:

The Transition Type

ParticleContainer seems like a natural thing to use as a Transition type. I'd think of this in terms of changing the current definition of

mutable struct ParticleContainer{T<:Particle, F}
  . . .
end

to simply inherit AbstractTransition, like so:

mutable struct ParticleContainer{T<:Particle, F} <: AbstractTransition
  . . .
end

The SMC package would then declare a function with the signature of

MCMCChains.Chains(t::ParticleContainer)

which would handle the retrieval of all the samples. This would remove the need to generate a Sample type --- we'd just overload a Chains function for each Transition type to retrieve the samples using multiple dispatch. Not too much in the way of moving parts here.
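
As a sketch, the overload might look like this (getsamples and getvarnames are illustrative placeholders for however ParticleContainer exposes its contents):

```julia
# Hypothetical: each transition type ships its own Chains constructor, so
# `sample` never needs to know how a sampler stores its draws internally.
function MCMCChains.Chains(t::ParticleContainer)
    vals = getsamples(t)    # illustrative: iterations × parameters array
    names = getvarnames(t)  # illustrative: parameter names
    return Chains(vals, names)
end
```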

An alternative specification for Transition would be to use the existing step function (which I'm not certain is currently used?) and save vi.vals each step. Then, the vector could be packaged up and sent off to MCMCChains.

function step(model, spl::Sampler{<:SMC}, vi::VarInfo)

It looks only slightly more complex for the HMC samplers. In this case we'd need to keep all the values from Sample(vi, spl).vals separately and then pull out varnames at the end. This would be doable during the sample_end! phase, or by making a function with the signature MCMCChains.Chains(t::TransitionType, spl::HamiltonianSampler). The spl here is necessary to get the vi info to the Chains object.

The Model type

As noted above, this shouldn't have to change too much:

struct Model{pvars, dvars, F, TData, TDefaults} <: Distributions.Sampleable

The Sampler type

In order to remove Sampler.info, we could just declare special Sampler types for each sampler, and store each sampler's necessary fields in the struct itself. The current definition of the SMC sampler dumps some random stuff into the generic Sampler struct:

function Sampler(alg::SMC, s::Selector)
    info = Dict{Symbol, Any}()
    info[:logevidence] = []
    return Sampler(alg, info, s)
end

For reference, the Sampler struct is

mutable struct Sampler{T} <: AbstractSampler
    alg      ::  T
    info     ::  Dict{Symbol, Any} # sampler information
    selector ::  Selector
end

The way I might think about changing this setup is to remove the generic Sampler type (or maybe replace it with the most bare-bones version possible for interface purposes) and force each sampler to create its own type:

mutable struct SMCSampler{T} <: AbstractSampler
    alg :: T
    logevidence :: Float64
    selector ::  Selector
    particles ::  ParticleContainer{Trace}
end

The gist here would be to just move everything currently stored in any sampler's spl.info entry and put it directly into the struct. An overload of the sample_init! and sample_end! functions would be necessary, as well:

function sample_init!(::AbstractRNG,
        model::ModelType,
        s::SMCSampler,
        ::Integer;
        kwargs...
    ) where ModelType<:Distributions.Sampleable
    # Instantiate the ParticleContainer.
    s.particles = ParticleContainer{Trace}(model)
    push!(s.particles, s.alg.n_particles, s, VarInfo())
end
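
A matching sample_end! might then move the accumulated evidence estimate into the sampler (a sketch only; the logE field name is illustrative):

```julia
# Hypothetical counterpart to the `sample_init!` above: once sampling
# finishes, store the evidence estimate accumulated by the particles.
function sample_end!(::AbstractRNG,
        model::ModelType,
        s::SMCSampler,
        ::Integer;
        kwargs...
    ) where ModelType<:Distributions.Sampleable
    s.logevidence = s.particles.logE  # illustrative field name
end
```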

yebai pushed a commit that referenced this issue May 6, 2019
* integrate AHMC


* make functions safe and add AHMC in REQUIRE

* implement experimental interface for ANUTS

* fix init_theta masking

* update to AHMC's new interface

* fix type constraint

* replace find_good_eps

* replace low-level HMC functions

* remove old comments

* remove rev and log funcs

* remove returned value from gen_grad_func

* replace adapation with AHMC code

* remove adapation code

* remove implict sampler constructors

* use integrated sample() if AHMC if possible

* delete ahmc.jl

* add version to AHMC

* make AHMC version master

* revert version of AHMC

* rename binding AdvancedHMC -> AHMC

* add AHMC#master to deps

* improve typing of mh_accept

* merge hmcda.jl and nuts.jl into hmc.jl

* replace conditions with multiple dispatch

* not storing adapt_conf in spl.info

* relax Sampler to AbstractSampler

* use AHMC.transition instead of copied version

* tweak comments

* mvoe hmc_core.jl into hmc.jl

* clean up

* clean up hmc.jl

* bugfix

* Remove some redundant docs.

* Some formating fixes - no functionality change.

* epsilon ==> ϵ.

* tau ==> n_leapfrog.

* add type for metricT

* remove redudant .logp

* fix typo in comment

* unroll first and rest iterations for steps

* Rename all: epsilon ==> ϵ

* Remove global STAN_DEFAULT_ADAPT_CONF and related types.

* Disable stan interface temporarily before fixing #746

* Ensure h.metric has the same dim as θ.

* Bugfix for using `step` in Gibbs.

* Merge SGLD into SGHMC - no functionality change.

* Add default metric type to SGLD and SGHMC.

* Merge test/SGLD into test/SGHMC - no functionality change.
yebai added a commit that referenced this issue May 6, 2019
commit beda95d
Author: Kai Xu <[email protected]>
Date:   Mon May 6 14:09:50 2019 +0100

    Replace HMC related codes using AHMC (#739)

    (squashed commit messages as listed above)
commit 502027b
Author: Mohamed Tarek <[email protected]>
Date:   Sun Apr 28 19:39:57 2019 +1000

    fix some bug in `eval_num` from #740 (#770)
@yebai
Member

yebai commented May 8, 2019

Thanks, @cpfiffer. The plan looks sensible. To summarise, the new design minimises the generic Sampler type and requires each algorithm to (optionally) implement two types

  • SomeSampler <: AbstractSampler
  • SomeTransition <: AbstractTransition

It is also worth noting that only some algorithms make use of Sampler.info, for example

  • SMC, PG, etc: info[:logevidence]
  • HMC, SGHMC, etc: info[:eval_num] and info[:i]
  • Gibbs, IPMCMC: info[:samplers]
  • MH: info[:proposal_ratio], info[:prior_prob], info[:violating_support]

So perhaps we can provide a default Sampler implementation for algorithms that do not utilise Sampler.info. Then we can explicitly require algorithms utilising Sampler.info to define a customised Sampler type, e.g.

struct Sampler <: AbstractSampler end # Default for all algorithms

get_selector(spl::AbstractSampler) = Selector() # callback interface 1
get_algstr(spl::AbstractSampler) = "AlgorithmName" # callback interface 2

mutable struct SMCSampler{SMC} <: AbstractSampler
    log_evidence::Float64
end

get_selector(spl::SMCSampler) = ...
get_algstr(spl::SMCSampler) = ...

In general, the InferenceAlgorithm type is meant to store immutable algorithm configuration parameters (e.g. the step size for HMC, the number of particles for PG), while the Sampler type is meant to store mutable algorithm runtime information (e.g. the preconditioning matrix for HMC).
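
As a concrete sketch of that split (all type names and fields here are illustrative, not the actual definitions):

```julia
# Immutable configuration lives in the InferenceAlgorithm.
struct HMC <: InferenceAlgorithm
    ϵ::Float64       # step size, fixed at construction
    n_leapfrog::Int  # number of leapfrog steps
end

# Mutable runtime state lives in the Sampler.
mutable struct HMCSampler{A<:HMC} <: AbstractSampler
    alg::A
    eval_num::Int            # replaces info[:eval_num]
    i::Int                   # replaces info[:i]
    precond::Matrix{Float64} # preconditioning matrix adapted at runtime
end
```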

Below are all Sampler definitions in Turing.

function Sampler(alg::SMC, s::Selector)
    info = Dict{Symbol, Any}()
    info[:logevidence] = []
    return Sampler(alg, info, s)
end

function Sampler(alg::PG, s::Selector)
    info = Dict{Symbol, Any}()
    info[:logevidence] = []
    Sampler(alg, info, s)
end

function Sampler(alg::Gibbs, model::Model, s::Selector)
    info = Dict{Symbol, Any}()
    spl = Sampler(alg, info, s)
    n_samplers = length(alg.algs)
    samplers = Array{Sampler}(undef, n_samplers)
    space = Set{Symbol}()
    for i in 1:n_samplers
        sub_alg = alg.algs[i]
        if isa(sub_alg, GibbsComponent)
            samplers[i] = Sampler(sub_alg, model, Selector(Symbol(typeof(sub_alg))))
        else
            @error("[Gibbs] unsupport base sampling algorithm $alg")
        end
        space = union(space, sub_alg.space)
    end
    # Sanity check for space
    @assert issubset(Set(get_pvars(model)), space) "[Gibbs] symbols specified to samplers ($space) doesn't cover the model parameters ($(Set(get_pvars(model))))"
    if Set(get_pvars(model)) != space
        @warn("[Gibbs] extra parameters specified by samplers don't exist in model: $(setdiff(space, Set(get_pvars(model))))")
    end
    info[:samplers] = samplers
    return spl
end

function Sampler(alg::Hamiltonian, s::Selector=Selector())
    info = Dict{Symbol, Any}()
    info[:eval_num] = 0
    info[:i] = 0
    Sampler(alg, info, s)
end

function Sampler(alg::IS, s::Selector)
    info = Dict{Symbol, Any}()
    Sampler(alg, info, s)
end

function Sampler(alg::MH, model::Model, s::Selector)
    alg_str = "MH"
    # Sanity check for space
    if (s.tag == :default) && !isempty(alg.space)
        @assert issubset(Set(get_pvars(model)), alg.space) "[$alg_str] symbols specified to samplers ($alg.space) doesn't cover the model parameters ($(Set(get_pvars(model))))"
        if Set(get_pvars(model)) != alg.space
            warn("[$alg_str] extra parameters specified by samplers don't exist in model: $(setdiff(alg.space, Set(get_pvars(model))))")
        end
    end
    info = Dict{Symbol, Any}()
    info[:proposal_ratio] = 0.0
    info[:prior_prob] = 0.0
    info[:violating_support] = false
    return Sampler(alg, info, s)
end

function Sampler(alg::PMMH, model::Model, s::Selector)
    info = Dict{Symbol, Any}()
    spl = Sampler(alg, info, s)
    alg_str = "PMMH"
    n_samplers = length(alg.algs)
    samplers = Array{Sampler}(undef, n_samplers)
    space = Set{Symbol}()
    for i in 1:n_samplers
        sub_alg = alg.algs[i]
        if isa(sub_alg, Union{SMC, MH})
            samplers[i] = Sampler(sub_alg, model, Selector(Symbol(typeof(sub_alg))))
        else
            error("[$alg_str] unsupport base sampling algorithm $alg")
        end
        if typeof(sub_alg) == MH && sub_alg.n_iters != 1
            warn("[$alg_str] number of iterations greater than 1 is useless for MH since it is only used for its proposal")
        end
        space = union(space, sub_alg.space)
    end
    # Sanity check for space
    if !isempty(space)
        @assert issubset(Set(get_pvars(model)), space) "[$alg_str] symbols specified to samplers ($space) doesn't cover the model parameters ($(Set(get_pvars(model))))"
        if Set(get_pvars(model)) != space
            warn("[$alg_str] extra parameters specified by samplers don't exist in model: $(setdiff(space, Set(get_pvars(model))))")
        end
    end
    info[:samplers] = samplers
    return spl
end

function Sampler(alg::IPMCMC, s::Selector)
    info = Dict{Symbol, Any}()
    spl = Sampler(alg, info, s)
    # Create SMC and CSMC nodes
    samplers = Array{Sampler}(undef, alg.n_nodes)
    # Use resampler_threshold=1.0 for SMC since adaptive resampling is invalid in this setting
    default_CSMC = CSMC(alg.n_particles, 1, alg.resampler, alg.space)
    default_SMC = SMC(alg.n_particles, alg.resampler, 1.0, false, alg.space)
    for i in 1:alg.n_csmc_nodes
        samplers[i] = Sampler(default_CSMC, Selector(Symbol(typeof(default_CSMC))))
    end
    for i in (alg.n_csmc_nodes+1):alg.n_nodes
        samplers[i] = Sampler(default_SMC, Selector(Symbol(typeof(default_CSMC))))
    end
    info[:samplers] = samplers
    return spl
end

@cpfiffer
Member

One thing I'm a little fuzzy on is how the proposed sampler API works with compositional Gibbs, particularly as it relates to how you declare sample sizes in the components. Currently, you can do this:

chn = sample(model, Gibbs(1000, HMC(1, 0.2, 3, :x), PG(20, 1, :y)))

The sampler API forces the number of steps to be outside the InferenceAlgorithm constructor, and so we'd maybe have something that looks like

chn = sample(model, Gibbs(HMC(0.2, 3, :x), PG(20, :y)), 1000)

Is this an issue? I presume this would force HMC and PG to run one iteration apiece, but I don't know if the flexibility lost there is a problem or not.

@yebai
Member

yebai commented May 13, 2019

It would be nice to have some support for running HMC or PG for multiple steps within a Gibbs step. If the current API is troublesome (although I don't quite catch the issue @cpfiffer mentioned), we can re-design it. Following @cpfiffer's design, this would be adapted as:

# default 1
chn = sample(model, Gibbs(HMC(0.2, 3, :x), PG(20, :y)), 1000) 

# default 2
chn = sample(model, Gibbs(HMC(0.2, 3, :x), PG(20, :y)), (1000,)) 

# HMC - 2 steps, PG - 2 steps
chn = sample(model, Gibbs(HMC(0.2, 3, :x), PG(20, :y)), (1000,2,2)) 
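
The two call forms could be separated by dispatch on the final argument (a sketch only, not an implemented API):

```julia
# Hypothetical dispatch: an Integer keeps the simple API, while a Tuple
# carries per-component step counts for the Gibbs components.
sample(model, alg::Gibbs, N::Integer; kwargs...) =
    sample(model, alg, (N,); kwargs...)

function sample(model, alg::Gibbs, N::Tuple; kwargs...)
    n_iters = N[1]
    steps = length(N) > 1 ? N[2:end] : ntuple(_ -> 1, length(alg.algs))
    # run `n_iters` Gibbs sweeps, giving component i `steps[i]` steps each
end
```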

This was referenced May 25, 2019
@xukai92
Member Author

xukai92 commented Jul 12, 2019

Or

# Default
chn = sample(model, Gibbs(HMC(0.2, 3, :x), PG(20, :y), (1, 1)), 1000) 

# HMC - 2 steps, PG - 2 steps
chn = sample(model, Gibbs(HMC(0.2, 3, :x), PG(20, :y), (2, 2)), 1000) 

?

@cpfiffer
Member

I like @xukai92's idea better; then there's no need to handle this weird edge case on the interface side, since N is strongly typed as an integer.

@yebai
Member

yebai commented Sep 17, 2019

Fixed by #793

@yebai yebai closed this as completed Sep 17, 2019
@yebai yebai unpinned this issue Sep 17, 2019