diff --git a/docs/_docs/using-turing/advanced.md b/docs/_docs/using-turing/advanced.md
index 29158552e..8dd176f8f 100644
--- a/docs/_docs/using-turing/advanced.md
+++ b/docs/_docs/using-turing/advanced.md
@@ -29,12 +29,12 @@ end
 
 ### 2. Implement Sampling and Evaluation of the log-pdf
 
-Second, define `rand()` and `logpdf()`, which will be used to run the model.
+Second, define `rand` and `logpdf`, which will be used to run the model.
 
 ```julia
-Distributions.rand(d::Flat) = rand()
-Distributions.logpdf{T<:Real}(d::Flat, x::T) = zero(x)
+Distributions.rand(rng::AbstractRNG, d::Flat) = rand(rng)
+Distributions.logpdf(d::Flat, x::Real) = zero(x)
 ```
 
@@ -47,7 +47,7 @@ In most cases, it may be required to define helper functions, such as the `minim
 
 #### 3.1 Domain Transformation
 
-Some helper functions are necessary for domain transformation. For univariate distributions, the necessary ones to implement are `minimum()` and `maximum()`.
+Some helper functions are necessary for domain transformation. For univariate distributions, the necessary ones to implement are `minimum` and `maximum`.
 
 ```julia
@@ -56,18 +56,17 @@ Distributions.maximum(d::Flat) = +Inf
 ```
 
-Functions for domain transformation which may be required by multivariate or matrix-variate distributions are `size(d)`, `link(d, x)` and `invlink(d, x)`. Please see Turing's [`transform.jl`](https://github.com/TuringLang/Turing.jl/blob/master/src/utilities/transform.jl) for examples.
+Functions for domain transformation which may be required by multivariate or matrix-variate distributions are `size`, `link` and `invlink`. Please see Turing's [`transform.jl`](https://github.com/TuringLang/Turing.jl/blob/master/src/utilities/transform.jl) for examples.
 
 #### 3.2 Vectorization Support
 
-The vectorization syntax follows `rv ~ [distribution]`, which requires `rand()` and `logpdf()` to be called on multiple data points at once. An appropriate implementation for `Flat` are shown below.
+The vectorization syntax follows `rv ~ [distribution]`, which requires `rand` and `logpdf` to be called on multiple data points at once. An appropriate implementation for `Flat` is shown below.
 
 ```julia
-Distributions.rand(d::Flat, n::Int) = Vector([rand() for _ = 1:n])
-Distributions.logpdf{T<:Real}(d::Flat, x::Vector{T}) = zero(x)
+Distributions.logpdf(d::Flat, x::AbstractVector{<:Real}) = zero(x)
 ```
 
@@ -247,4 +246,3 @@ model = gdemo([1.2, 3.5]
 # Run all samples.
 chns = reduce(chainscat, pmap(x->sample(model,sampler),1:num_chains))
 ```
-
diff --git a/src/stdlib/RandomMeasures.jl b/src/stdlib/RandomMeasures.jl
index a41bc2e25..bd2f81c61 100644
--- a/src/stdlib/RandomMeasures.jl
+++ b/src/stdlib/RandomMeasures.jl
@@ -7,6 +7,7 @@ using StatsFuns: logsumexp
 
 import Distributions: sample, logpdf
 import Base: maximum, minimum, rand
+import Random: AbstractRNG
 
 ## ############### ##
 ## Representations ##
@@ -25,7 +26,7 @@ struct SizeBiasedSamplingProcess{T<:AbstractRandomProbabilityMeasure,V<:Abstract
 end
 
 logpdf(d::SizeBiasedSamplingProcess, x) = _logpdf(d, x)
-rand(d::SizeBiasedSamplingProcess) = _rand(d)
+rand(rng::AbstractRNG, d::SizeBiasedSamplingProcess) = _rand(rng, d)
 minimum(d::SizeBiasedSamplingProcess) = zero(d.surplus)
 maximum(d::SizeBiasedSamplingProcess) = d.surplus
 
@@ -39,7 +40,7 @@ struct StickBreakingProcess{T<:AbstractRandomProbabilityMeasure} <: ContinuousUn
 end
 
 logpdf(d::StickBreakingProcess, x) = _logpdf(d, x)
-rand(d::StickBreakingProcess) = _rand(d)
+rand(rng::AbstractRNG, d::StickBreakingProcess) = _rand(rng, d)
 minimum(d::StickBreakingProcess) = 0.0
 maximum(d::StickBreakingProcess) = 1.0
 
@@ -76,10 +77,10 @@ function logpdf(d::ChineseRestaurantProcess, x::Int)
     end
 end
 
-function rand(d::ChineseRestaurantProcess)
+function rand(rng::AbstractRNG, d::ChineseRestaurantProcess)
     lp = _logpdf_table(d.rpm, d.m)
     p = exp.(lp)
-    return rand(Categorical(p ./ sum(p)))
+    return rand(rng, Categorical(p ./ sum(p)))
 end
 
 minimum(d::ChineseRestaurantProcess) = 1
@@ -107,8 +108,8 @@ v_k \\sim Beta(1, \\alpha)
 
 *Chinese Restaurant Process*
 ```math
-p(z_n = k | z_{1:n-1}) \\propto \\begin{cases} 
-    \\frac{m_k}{n-1+\\alpha}, \\text{if} m_k > 0\\\\ 
+p(z_n = k | z_{1:n-1}) \\propto \\begin{cases}
+    \\frac{m_k}{n-1+\\alpha}, \\text{if} m_k > 0\\\\
     \\frac{\\alpha}{n-1+\\alpha}
 \\end{cases}
 ```
@@ -119,10 +120,10 @@ struct DirichletProcess{T<:Real} <: AbstractRandomProbabilityMeasure
     α::T
 end
 
-_rand(d::StickBreakingProcess{DirichletProcess{T}}) where {T<:Real} = rand(Beta(one(T), d.rpm.α))
+_rand(rng::AbstractRNG, d::StickBreakingProcess{DirichletProcess{T}}) where {T<:Real} = rand(rng, Beta(one(T), d.rpm.α))
 
-function _rand(d::SizeBiasedSamplingProcess{DirichletProcess{T}}) where {T<:Real}
-    return d.surplus*rand(Beta(one(T), d.rpm.α))
+function _rand(rng::AbstractRNG, d::SizeBiasedSamplingProcess{DirichletProcess{T}}) where {T<:Real}
+    return d.surplus*rand(rng, Beta(one(T), d.rpm.α))
 end
 
 function _logpdf(d::StickBreakingProcess{DirichletProcess{T}}, x::T) where {T<:Real}
@@ -168,8 +169,8 @@ v_k \\sim Beta(1-d, \\theta + t*d)
 
 *Chinese Restaurant Process*
 ```math
-p(z_n = k | z_{1:n-1}) \\propto \\begin{cases} 
-    \\frac{m_k - d}{n+\\theta}, \\text{if} m_k > 0\\\\ 
+p(z_n = k | z_{1:n-1}) \\propto \\begin{cases}
+    \\frac{m_k - d}{n+\\theta}, \\text{if} m_k > 0\\\\
     \\frac{\\theta + d*t}{n+\\theta}
 \\end{cases}
 ```
@@ -182,12 +183,12 @@ struct PitmanYorProcess{T<:Real} <: AbstractRandomProbabilityMeasure
     t::Int
 end
 
-function _rand(d::StickBreakingProcess{PitmanYorProcess{T}}) where {T<:Real}
-    return rand(Beta(one(T)-d.rpm.d, d.rpm.θ + d.rpm.t*d.rpm.d))
+function _rand(rng::AbstractRNG, d::StickBreakingProcess{PitmanYorProcess{T}}) where {T<:Real}
+    return rand(rng, Beta(one(T)-d.rpm.d, d.rpm.θ + d.rpm.t*d.rpm.d))
 end
 
-function _rand(d::SizeBiasedSamplingProcess{PitmanYorProcess{T}}) where {T<:Real}
-    return d.surplus*rand(Beta(one(T)-d.rpm.d, d.rpm.θ + d.rpm.t*d.rpm.d))
+function _rand(rng::AbstractRNG, d::SizeBiasedSamplingProcess{PitmanYorProcess{T}}) where {T<:Real}
+    return d.surplus*rand(rng, Beta(one(T)-d.rpm.d, d.rpm.θ + d.rpm.t*d.rpm.d))
 end
 
 function _logpdf(d::StickBreakingProcess{PitmanYorProcess{T}}, x::T) where {T<:Real}
diff --git a/src/stdlib/distributions.jl b/src/stdlib/distributions.jl
index 3d36080a7..72e366099 100644
--- a/src/stdlib/distributions.jl
+++ b/src/stdlib/distributions.jl
@@ -1,13 +1,14 @@
+import Random: AbstractRNG
+
 # No info
 struct Flat <: ContinuousUnivariateDistribution end
 
-Distributions.rand(d::Flat) = rand()
-Distributions.logpdf(d::Flat, x::T) where T<:Real= zero(x)
+Distributions.rand(rng::AbstractRNG, d::Flat) = rand(rng)
+Distributions.logpdf(d::Flat, x::Real) = zero(x)
 Distributions.minimum(d::Flat) = -Inf
 Distributions.maximum(d::Flat) = +Inf
 
 # For vec support
-Distributions.rand(d::Flat, n::Int) = Vector([rand() for _ = 1:n])
 Distributions.logpdf(d::Flat, x::AbstractVector{<:Real}) = zero(x)
 
 # Pos
@@ -15,13 +16,12 @@ struct FlatPos{T<:Real} <: ContinuousUnivariateDistribution
     l::T
 end
 
-Distributions.rand(d::FlatPos) = rand() + d.l
+Distributions.rand(rng::AbstractRNG, d::FlatPos) = rand(rng) + d.l
 Distributions.logpdf(d::FlatPos, x::Real) = x <= d.l ? -Inf : zero(x)
 Distributions.minimum(d::FlatPos) = d.l
 Distributions.maximum(d::FlatPos) = Inf
 
 # For vec support
-Distributions.rand(d::FlatPos, n::Int) = Vector([rand() for _ = 1:n] .+ d.l)
 function Distributions.logpdf(d::FlatPos, x::AbstractVector{<:Real})
     return any(x .<= d.l) ? -Inf : zero(x)
 end
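Taken together, the patch threads an explicit `AbstractRNG` through every `rand` call in these files, so draws become reproducible when a seeded generator is supplied. The snippet below is not part of the diff; it is a minimal usage sketch assuming the patched Turing.jl is loaded, that `Flat` and `FlatPos` live in the `Turing` module as above, that `RandomMeasures` is a submodule of `Turing`, and that `ChineseRestaurantProcess(rpm, m)` takes a base measure and a vector of current cluster counts (its constructor is not shown in this diff).

```julia
using Random
using Turing: Flat, FlatPos
using Turing.RandomMeasures: DirichletProcess, ChineseRestaurantProcess

# A seeded generator; passing it explicitly is what the new
# rand(rng::AbstractRNG, d) methods enable.
rng = MersenneTwister(42)

# Improper flat priors from src/stdlib/distributions.jl.
x = rand(rng, Flat())        # a uniform draw on [0, 1)
y = rand(rng, FlatPos(0.5))  # a uniform draw shifted by the lower bound l = 0.5

# Chinese restaurant process from src/stdlib/RandomMeasures.jl,
# here with a Dirichlet process base measure and assumed counts [2, 1].
crp = ChineseRestaurantProcess(DirichletProcess(1.0), [2, 1])
z = rand(rng, crp)           # index of an existing cluster or a new one

# Re-seeding the generator reproduces the first draw exactly.
@assert rand(MersenneTwister(42), Flat()) == x
```

With the old zero-argument definitions, every draw went through the hard-coded global RNG; forwarding `rng` is what lets a sampler route its own generator into these distributions.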