Skip to content

Commit

Permalink
Fix rand and randn type piracy (#387)
Browse files Browse the repository at this point in the history
* Fix `rand` and `randn` type piracy

* Fix missing `randn` method

* Fix format

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Extend `_randn` to vectors

* Qualify `randn!`

* Fix `_randn`

* Add `_rand`

* Rename internal functions that sample momentum to `rand_momentum`

* Increase number of samples

* Fix `randcat`

* Fix typo

* Update `jitter`

* Fix return value of `jitter`

* Qualify `rand!`

* Increase number of samples

* Update CI

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
devmotion and github-actions[bot] authored Feb 22, 2025
1 parent 85e6aa6 commit b61cbb5
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 45 deletions.
6 changes: 3 additions & 3 deletions research/src/relativistic_hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ using AdaptiveRejectionSampling: RejectionSampler, run_sampler!
import AdvancedHMC: _rand

# TODO Support AbstractVector{<:AbstractRNG}
function _rand(
function rand_momentum(
rng::AbstractRNG,
metric::UnitEuclideanMetric{T},
kinetic::RelativisticKinetic{T},
Expand All @@ -71,12 +71,12 @@ function _rand(
end

# TODO Support AbstractVector{<:AbstractRNG}
function _rand(
function rand_momentum(
rng::AbstractRNG,
metric::DiagEuclideanMetric{T},
kinetic::RelativisticKinetic{T},
) where {T}
r = _rand(rng, UnitEuclideanMetric(size(metric)), kinetic)
r = rand_momentum(rng, UnitEuclideanMetric(size(metric)), kinetic)
# p' = A p where A = sqrtM
r ./= metric.sqrtM⁻¹
return r
Expand Down
12 changes: 6 additions & 6 deletions research/src/riemannian_hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ refresh(
# To change L146 of metric.jl
# Ignore θ by default (i.e. not position-dependent)
Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic, θ) =
_rand(rng, metric, kinetic) # this disambiguity is required by Random.rand
rand_momentum(rng, metric, kinetic) # this disambiguity is required by Random.rand
Base.rand(rng::AbstractVector{<:AbstractRNG}, metric::AbstractMetric, kinetic, θ) =
_rand(rng, metric, kinetic)
rand_momentum(rng, metric, kinetic)
# Convenience method: sample with the default RNG when the caller supplies none.
# θ is forwarded (instead of being dropped) so that the four-argument `rand` methods
# decide how the position is used: the generic `AbstractMetric` method ignores it,
# while position-dependent metrics require it to build their momentum distribution.
Base.rand(metric::AbstractMetric, kinetic, θ) =
    rand(Random.default_rng(), metric, kinetic, θ)

### metric.jl
Expand Down Expand Up @@ -208,27 +208,27 @@ Base.size(e::DenseRiemannianMetric) = e.size
# Extent of the metric along dimension `dim`, looked up in the stored `size` field.
function Base.size(e::DenseRiemannianMetric, dim::Int)
    return e.size[dim]
end
# Compact display: write a fixed placeholder to `io` rather than the metric's contents.
function Base.show(io::IO, dem::DenseRiemannianMetric)
    return print(io, "DenseRiemannianMetric(...)")
end

function _rand(
function rand_momentum(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
metric::DenseRiemannianMetric{T},
kinetic,
θ,
) where {T}
r = randn(rng, T, size(metric)...)
r = _randn(rng, T, size(metric)...)
G⁻¹ = inv(metric.map(metric.G(θ)))
chol = cholesky(Symmetric(G⁻¹))
ldiv!(chol.U, r)
return r
end

Base.rand(rng::AbstractRNG, metric::AbstractRiemannianMetric, kinetic, θ) =
_rand(rng, metric, kinetic, θ)
rand_momentum(rng, metric, kinetic, θ)
Base.rand(
rng::AbstractVector{<:AbstractRNG},
metric::AbstractRiemannianMetric,
kinetic,
θ,
) = _rand(rng, metric, kinetic, θ)
) = rand_momentum(rng, metric, kinetic, θ)

### hamiltonian.jl

Expand Down
26 changes: 15 additions & 11 deletions src/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,21 +127,25 @@ nom_step_size(lf::JitteredLeapfrog) = lf.ϵ0
# Replace the nominal step size `ϵ0` of a jittered leapfrog integrator.
# NOTE(review): `@set` is presumably the Setfield/Accessors macro, which builds an
# updated copy instead of mutating `lf` — confirm against the module's imports.
function update_nom_step_size(lf::JitteredLeapfrog, ϵ0)
    return @set lf.ϵ0 = ϵ0
end

# Jitter step size; ref: https://github.com/stan-dev/stan/blob/1bb054027b01326e66ec610e95ef9b2a60aa6bec/src/stan/mcmc/hmc/base_hmc.hpp#L177-L178
function _jitter(
# Jitter a scalar step size: perturb the nominal step size `ϵ0` by a random relative
# amount controlled by `lf.jitter` and store the result in `ϵ`.
function jitter(rng::AbstractRNG, lf::JitteredLeapfrog{FT,FT}) where {FT<:AbstractFloat}
    # One uniform draw in precision `FT`, rescaled from the unit interval to be
    # centred on zero.
    u = rand(rng, FT)
    centred = 2 * u - 1
    # Relative perturbation factor applied to the nominal step size.
    factor = 1 + lf.jitter * centred
    ϵ = lf.ϵ0 * factor
    return @set lf.ϵ = ϵ
end
function jitter(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
lf::JitteredLeapfrog{FT,T},
) where {FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}}
ϵ = lf.ϵ0 .* (1 .+ lf.jitter .* (2 .* rand(rng) .- 1))
return @set lf.ϵ = FT.(ϵ)
) where {FT<:AbstractFloat,T<:AbstractVector{FT}}
ϵ = similar(lf.ϵ0)
if rng isa AbstractRNG
Random.rand!(rng, ϵ)
else
@argcheck length(rng) == length(ϵ)
map!(Base.Fix2(rand, FT), ϵ, rng)
end
@. ϵ = lf.ϵ0 * (1 + lf.jitter * (2 * ϵ - 1))
return @set lf.ϵ = ϵ
end

jitter(rng::AbstractRNG, lf::JitteredLeapfrog) = _jitter(rng, lf)

jitter(
rng::AbstractVector{<:AbstractRNG},
lf::JitteredLeapfrog{FT,T},
) where {FT<:AbstractFloat,T<:AbstractScalarOrVec{FT}} = _jitter(rng, lf)

### Tempering
# TODO: add ref or at least explain what exactly we're doing
"""
Expand Down
16 changes: 8 additions & 8 deletions src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,43 +94,43 @@ Base.show(io::IO, dem::DenseEuclideanMetric) =

# `rand` functions for `metric` types.

function _rand(
function rand_momentum(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
metric::UnitEuclideanMetric{T},
kinetic::GaussianKinetic,
) where {T}
r = randn(rng, T, size(metric)...)
r = _randn(rng, T, size(metric)...)
return r
end

function _rand(
function rand_momentum(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
metric::DiagEuclideanMetric{T},
kinetic::GaussianKinetic,
) where {T}
r = randn(rng, T, size(metric)...)
r = _randn(rng, T, size(metric)...)
r ./= metric.sqrtM⁻¹
return r
end

function _rand(
function rand_momentum(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
metric::DenseEuclideanMetric{T},
kinetic::GaussianKinetic,
) where {T}
r = randn(rng, T, size(metric)...)
r = _randn(rng, T, size(metric)...)
ldiv!(metric.cholM⁻¹, r)
return r
end

# TODO (kai) The rand interface should be updated as "rand from momentum distribution + optional affine transformation by metric"
Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic) =
_rand(rng, metric, kinetic) # this disambiguity is required by Random.rand
rand_momentum(rng, metric, kinetic) # this disambiguity is required by Random.rand
Base.rand(
rng::AbstractVector{<:AbstractRNG},
metric::AbstractMetric,
kinetic::AbstractKinetic,
) = _rand(rng, metric, kinetic)
) = rand_momentum(rng, metric, kinetic)
# RNG-free convenience method: delegate to the RNG-taking `rand` method using
# `Random.default_rng()`.
function Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic)
    rng = Random.default_rng()
    return rand(rng, metric, kinetic)
end

Expand Down
7 changes: 5 additions & 2 deletions src/trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -858,8 +858,11 @@ function mh_accept_ratio(
# the chains. We need to revisit this more rigorously
# in the future. See discussions at
# https://github.com/TuringLang/AdvancedHMC.jl/pull/166#pullrequestreview-367216534
_rng = rng isa AbstractRNG ? (rng,) : rng
accept = Hproposal .< Horiginal .+ Random.randexp.(_rng, (T,))
accept = if rng isa AbstractRNG
Hproposal .< Horiginal .+ Random.randexp(rng, T, length(Hproposal))
else
Hproposal .< Horiginal .+ Random.randexp.(rng, (T,))
end
α = min.(one(T), exp.(Horiginal .- Hproposal))
return accept, α
end
39 changes: 26 additions & 13 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,27 @@ const AbstractScalarOrVec{T} = Union{T,AbstractVector{T}} where {T<:AbstractFloa

# Support of passing a vector of RNGs

Base.rand(rng::AbstractVector{<:AbstractRNG}) = rand.(rng)

Base.randn(rng::AbstractVector{<:AbstractRNG}) = randn.(rng)

function Base.rand(rng::AbstractVector{<:AbstractRNG}, T, n_chains::Int)
@argcheck length(rng) == n_chains
return rand.(rng, T)
# Single RNG, one sample per chain: forwards directly to `randn`, returning a vector of
# `n_chains` standard-normal draws with element type `T`.
_randn(rng::AbstractRNG, ::Type{T}, n_chains::Int) where {T} = randn(rng, T, n_chains)
# Single RNG, matrix of samples: forwards directly to `randn`, returning a
# `dim × n_chains` array of standard-normal draws with element type `T`.
_randn(rng::AbstractRNG, ::Type{T}, dim::Int, n_chains::Int) where {T} =
    randn(rng, T, dim, n_chains)

function Base.randn(rng::AbstractVector{<:AbstractRNG}, T, dim::Int, n_chains::Int)
@argcheck length(rng) == n_chains
return cat(randn.(rng, T, dim)...; dims = 2)
# Vector of RNGs, one sample per chain: each RNG contributes exactly one standard-normal
# draw of type `T`. The number of RNGs must equal `n_chains`.
function _randn(rngs::AbstractVector{<:AbstractRNG}, ::Type{T}, n_chains::Int) where {T}
    @argcheck length(rngs) == n_chains
    return map(rngs) do rng
        randn(rng, T)
    end
end
# Vector of RNGs, matrix of samples: column `i` of the `dim × n_chains` result is filled
# in place with standard-normal draws from the `i`-th RNG. The number of RNGs must equal
# `n_chains`.
function _randn(
    rngs::AbstractVector{<:AbstractRNG},
    ::Type{T},
    dim::Int,
    n_chains::Int,
) where {T}
    @argcheck length(rngs) == n_chains
    samples = similar(rngs, T, dim, n_chains)
    # Pair every RNG with its own column and fill that column in place.
    for (rng, column) in zip(rngs, eachcol(samples))
        Random.randn!(rng, column)
    end
    return samples
end

"""
Expand Down Expand Up @@ -80,8 +89,12 @@ function randcat(
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
P::AbstractMatrix{T},
) where {T}
u = rand(rng, T, size(P, 2))
u = if rng isa AbstractRNG
rand(rng, T, size(P, 2))
else
@argcheck length(rng) == size(P, 2)
map(Base.Fix2(rand, T), rng)
end
C = cumsum(P; dims = 1)
indices = convert.(Int, vec(sum(C .< u'; dims = 1))) .+ 1
return max.(indices, 1) # prevent numerical issue for Float32
return max.(vec(count(C .< u'; dims = 1)) .+ 1, 1) # prevent numerical issue for Float32
end
4 changes: 2 additions & 2 deletions test/adaptation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ using ReTest, LinearAlgebra, Distributions, AdvancedHMC, Random, ForwardDiff
using AdvancedHMC.Adaptation:
WelfordVar, NaiveVar, WelfordCov, NaiveCov, get_estimation, get_estimation, reset!

function runnuts(ℓπ, metric; n_samples = 3_000)
function runnuts(ℓπ, metric; n_samples = 10_000)
D = size(metric, 1)
n_adapts = 1_500
n_adapts = 5_000
θ_init = rand(D)
rng = MersenneTwister(0)

Expand Down

0 comments on commit b61cbb5

Please sign in to comment.