Skip to content

Commit

Permalink
Fix Base.rand Signature (#882)
Browse files Browse the repository at this point in the history
* fix rand signatures

* remove comment

* use consistent function names
  • Loading branch information
Brandon Gomes authored and yebai committed Aug 7, 2019
1 parent 27a4132 commit 9f6f127
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 29 deletions.
16 changes: 7 additions & 9 deletions docs/_docs/using-turing/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ end
### 2. Implement Sampling and Evaluation of the log-pdf


Second, define `rand()` and `logpdf()`, which will be used to run the model.
Second, define `rand` and `logpdf`, which will be used to run the model.


```julia
Distributions.rand(d::Flat) = rand()
Distributions.logpdf{T<:Real}(d::Flat, x::T) = zero(x)
Distributions.rand(rng::AbstractRNG, d::Flat) = rand(rng)
Distributions.logpdf(d::Flat, x::Real) = zero(x)
```


Expand All @@ -47,7 +47,7 @@ In most cases, it may be required to define helper functions, such as the `minim
#### 3.1 Domain Transformation


Some helper functions are necessary for domain transformation. For univariate distributions, the necessary ones to implement are `minimum()` and `maximum()`.
Some helper functions are necessary for domain transformation. For univariate distributions, the necessary ones to implement are `minimum` and `maximum`.


```julia
Expand All @@ -56,18 +56,17 @@ Distributions.maximum(d::Flat) = +Inf
```


Functions for domain transformation which may be required by multivariate or matrix-variate distributions are `size(d)`, `link(d, x)` and `invlink(d, x)`. Please see Turing's [`transform.jl`](https://github.com/TuringLang/Turing.jl/blob/master/src/utilities/transform.jl) for examples.
Functions for domain transformation which may be required by multivariate or matrix-variate distributions are `size`, `link` and `invlink`. Please see Turing's [`transform.jl`](https://github.com/TuringLang/Turing.jl/blob/master/src/utilities/transform.jl) for examples.


#### 3.2 Vectorization Support


The vectorization syntax follows `rv ~ [distribution]`, which requires `rand()` and `logpdf()` to be called on multiple data points at once. An appropriate implementation for `Flat` are shown below.
The vectorization syntax follows `rv ~ [distribution]`, which requires `rand` and `logpdf` to be called on multiple data points at once. An appropriate implementation for `Flat` is shown below.


```julia
Distributions.rand(d::Flat, n::Int) = Vector([rand() for _ = 1:n])
Distributions.logpdf{T<:Real}(d::Flat, x::Vector{T}) = zero(x)
Distributions.logpdf(d::Flat, x::AbstractVector{<:Real}) = zero(x)
```


Expand Down Expand Up @@ -247,4 +246,3 @@ model = gdemo([1.2, 3.5]
# Run all samples.
chns = reduce(chainscat, pmap(x->sample(model,sampler),1:num_chains))
```
31 changes: 16 additions & 15 deletions src/stdlib/RandomMeasures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using StatsFuns: logsumexp

import Distributions: sample, logpdf
import Base: maximum, minimum, rand
import Random: AbstractRNG

## ############### ##
## Representations ##
Expand All @@ -25,7 +26,7 @@ struct SizeBiasedSamplingProcess{T<:AbstractRandomProbabilityMeasure,V<:Abstract
end

logpdf(d::SizeBiasedSamplingProcess, x) = _logpdf(d, x)
rand(d::SizeBiasedSamplingProcess) = _rand(d)
rand(rng::AbstractRNG, d::SizeBiasedSamplingProcess) = _rand(rng, d)
minimum(d::SizeBiasedSamplingProcess) = zero(d.surplus)
maximum(d::SizeBiasedSamplingProcess) = d.surplus

Expand All @@ -39,7 +40,7 @@ struct StickBreakingProcess{T<:AbstractRandomProbabilityMeasure} <: ContinuousUn
end

logpdf(d::StickBreakingProcess, x) = _logpdf(d, x)
rand(d::StickBreakingProcess) = _rand(d)
rand(rng::AbstractRNG, d::StickBreakingProcess) = _rand(rng, d)
minimum(d::StickBreakingProcess) = 0.0
maximum(d::StickBreakingProcess) = 1.0

Expand Down Expand Up @@ -76,10 +77,10 @@ function logpdf(d::ChineseRestaurantProcess, x::Int)
end
end

function rand(d::ChineseRestaurantProcess)
function rand(rng::AbstractRNG, d::ChineseRestaurantProcess)
lp = _logpdf_table(d.rpm, d.m)
p = exp.(lp)
return rand(Categorical(p ./ sum(p)))
return rand(rng, Categorical(p ./ sum(p)))
end

minimum(d::ChineseRestaurantProcess) = 1
Expand Down Expand Up @@ -107,8 +108,8 @@ v_k \\sim Beta(1, \\alpha)
*Chinese Restaurant Process*
```math
p(z_n = k | z_{1:n-1}) \\propto \\begin{cases}
\\frac{m_k}{n-1+\\alpha}, \\text{if} m_k > 0\\\\
p(z_n = k | z_{1:n-1}) \\propto \\begin{cases}
\\frac{m_k}{n-1+\\alpha}, \\text{if} m_k > 0\\\\
\\frac{\\alpha}{n-1+\\alpha}
\\end{cases}
```
Expand All @@ -119,10 +120,10 @@ struct DirichletProcess{T<:Real} <: AbstractRandomProbabilityMeasure
α::T
end

_rand(d::StickBreakingProcess{DirichletProcess{T}}) where {T<:Real} = rand(Beta(one(T), d.rpm.α))
_rand(rng::AbstractRNG, d::StickBreakingProcess{DirichletProcess{T}}) where {T<:Real} = rand(rng, Beta(one(T), d.rpm.α))

function _rand(d::SizeBiasedSamplingProcess{DirichletProcess{T}}) where {T<:Real}
return d.surplus*rand(Beta(one(T), d.rpm.α))
function _rand(rng::AbstractRNG, d::SizeBiasedSamplingProcess{DirichletProcess{T}}) where {T<:Real}
return d.surplus*rand(rng, Beta(one(T), d.rpm.α))
end

function _logpdf(d::StickBreakingProcess{DirichletProcess{T}}, x::T) where {T<:Real}
Expand Down Expand Up @@ -168,8 +169,8 @@ v_k \\sim Beta(1-d, \\theta + t*d)
*Chinese Restaurant Process*
```math
p(z_n = k | z_{1:n-1}) \\propto \\begin{cases}
\\frac{m_k - d}{n+\\theta}, \\text{if} m_k > 0\\\\
p(z_n = k | z_{1:n-1}) \\propto \\begin{cases}
\\frac{m_k - d}{n+\\theta}, \\text{if} m_k > 0\\\\
\\frac{\\theta + d*t}{n+\\theta}
\\end{cases}
```
Expand All @@ -182,12 +183,12 @@ struct PitmanYorProcess{T<:Real} <: AbstractRandomProbabilityMeasure
t::Int
end

function _rand(d::StickBreakingProcess{PitmanYorProcess{T}}) where {T<:Real}
return rand(Beta(one(T)-d.rpm.d, d.rpm.θ + d.rpm.t*d.rpm.d))
function _rand(rng::AbstractRNG, d::StickBreakingProcess{PitmanYorProcess{T}}) where {T<:Real}
return rand(rng, Beta(one(T)-d.rpm.d, d.rpm.θ + d.rpm.t*d.rpm.d))
end

function _rand(d::SizeBiasedSamplingProcess{PitmanYorProcess{T}}) where {T<:Real}
return d.surplus*rand(Beta(one(T)-d.rpm.d, d.rpm.θ + d.rpm.t*d.rpm.d))
function _rand(rng::AbstractRNG, d::SizeBiasedSamplingProcess{PitmanYorProcess{T}}) where {T<:Real}
return d.surplus*rand(rng, Beta(one(T)-d.rpm.d, d.rpm.θ + d.rpm.t*d.rpm.d))
end

function _logpdf(d::StickBreakingProcess{PitmanYorProcess{T}}, x::T) where {T<:Real}
Expand Down
10 changes: 5 additions & 5 deletions src/stdlib/distributions.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
import Random: AbstractRNG

# No info
struct Flat <: ContinuousUnivariateDistribution end

Distributions.rand(d::Flat) = rand()
Distributions.logpdf(d::Flat, x::T) where T<:Real= zero(x)
Distributions.rand(rng::AbstractRNG, d::Flat) = rand(rng)
Distributions.logpdf(d::Flat, x::Real) = zero(x)
Distributions.minimum(d::Flat) = -Inf
Distributions.maximum(d::Flat) = +Inf

# For vec support
Distributions.rand(d::Flat, n::Int) = Vector([rand() for _ = 1:n])
Distributions.logpdf(d::Flat, x::AbstractVector{<:Real}) = zero(x)

# Pos
struct FlatPos{T<:Real} <: ContinuousUnivariateDistribution
l::T
end

Distributions.rand(d::FlatPos) = rand() + d.l
Distributions.rand(rng::AbstractRNG, d::FlatPos) = rand(rng) + d.l
Distributions.logpdf(d::FlatPos, x::Real) = x <= d.l ? -Inf : zero(x)
Distributions.minimum(d::FlatPos) = d.l
Distributions.maximum(d::FlatPos) = Inf

# For vec support
Distributions.rand(d::FlatPos, n::Int) = Vector([rand() for _ = 1:n] .+ d.l)
function Distributions.logpdf(d::FlatPos, x::AbstractVector{<:Real})
return any(x .<= d.l) ? -Inf : zero(x)
end
Expand Down

0 comments on commit 9f6f127

Please sign in to comment.