From 97488a3d9cf9b03ec0e32491dfbf73b7a5bf7a3d Mon Sep 17 00:00:00 2001 From: Albert Date: Thu, 8 Aug 2024 13:25:44 +0200 Subject: [PATCH 01/32] Add tests for MvNormalMeanScalePrecision --- .../mv_normal_mean_scale_precision_tests.jl | 128 ++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl new file mode 100644 index 00000000..d5426385 --- /dev/null +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -0,0 +1,128 @@ + +@testitem "MvNormalMeanScalePrecision: Constructor" begin + include("./normal_family_setuptests.jl") + + @test MvNormalMeanScalePrecision <: AbstractMvNormal + + @test MvNormalMeanScalePrecision([1.0, 1.0]) == MvNormalMeanScalePrecision([1.0, 1.0], 1.0) + @test MvNormalMeanScalePrecision([1.0, 2.0]) == MvNormalMeanScalePrecision([1.0, 2.0], 1.0) + @test MvNormalMeanScalePrecision([1, 2]) == MvNormalMeanScalePrecision([1.0, 2.0], 1.0) + @test MvNormalMeanScalePrecision([1.0f0, 2.0f0]) == MvNormalMeanScalePrecision([1.0f0, 2.0f0], 1.0f0) + + @test eltype(MvNormalMeanScalePrecision([1.0, 1.0])) === Float64 + @test eltype(MvNormalMeanScalePrecision([1.0, 1.0], 1.0)) === Float64 + @test eltype(MvNormalMeanScalePrecision([1, 1])) === Float64 + @test eltype(MvNormalMeanScalePrecision([1, 1], 1)) === Float64 + @test eltype(MvNormalMeanScalePrecision([1.0f0, 1.0f0])) === Float32 + @test eltype(MvNormalMeanScalePrecision([1.0f0, 1.0f0], 1.0f0)) === Float32 + + @test MvNormalMeanScalePrecision(ones(3), 5) == MvNormalMeanScalePrecision(ones(3), 5) + @test MvNormalMeanScalePrecision([1, 2, 3, 4], 7.0) == MvNormalMeanPrecision([1.0, 2.0, 3.0, 4.0], 7.0) +end + +@testitem "MvNormalMeanScalePrecision: distrname" begin + include("./normal_family_setuptests.jl") + + @test ExponentialFamily.distrname(MvNormalMeanScalePrecision(zeros(2))) === "MvNormalMeanScalePrecision" +end + +@testitem "MvNormalMeanScalePrecision: Stats methods" begin + include("./normal_family_setuptests.jl") + + μ = [0.2, 3.0, 4.0] + γ = 2.0 + dist = MvNormalMeanScalePrecision(μ, γ) + rdist = MvNormalMeanPrecision(μ, γ * ones(length(μ))) + + @test mean(dist) == μ + @test mode(dist) == μ + @test weightedmean(dist) == γ * μ + @test invcov(dist) == γ + @test precision(dist) == γ + @test cov(dist) ≈ inv(Λ) + @test std(dist) * std(dist)' ≈ inv(γ) + @test all(mean_cov(dist) .≈ (μ, inv(γ))) + @test all(mean_invcov(dist) .≈ (μ, γ)) + @test all(mean_precision(dist) .≈ (μ, γ)) + @test all(weightedmean_cov(dist) .≈ (Λ * μ, inv(γ))) + @test all(weightedmean_invcov(dist) .≈ (γ * μ, γ)) + @test all(weightedmean_precision(dist) .≈ (γ * μ, γ)) + + @test length(dist) == 3 + @test entropy(dist) ≈ entropy(rdist) + @test pdf(dist, [0.2, 3.0, 4.0]) ≈ pdf(rdist, [0.2, 3.0, 4.0]) + @test pdf(dist, [0.202, 3.002, 4.002]) ≈ pdf(rdist, [0.2, 3.0, 4.0]) + @test logpdf(dist, [0.2, 3.0, 4.0]) ≈ pdf(rdist, [0.2, 3.0, 4.0]) + @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ pdf(rdist, [0.2, 3.0, 4.0]) +end + +@testitem "MvNormalMeanScalePrecision: Base methods" begin + include("./normal_family_setuptests.jl") + + @test convert(MvNormalMeanScalePrecision{Float32}, MvNormalMeanScalePrecision([0.0, 0.0])) == + MvNormalMeanScalePrecision([0.0f0, 0.0f0], 1.0f0) + @test convert(MvNormalMeanScalePrecision{Float64}, [0.0, 0.0], 2.0) == + MvNormalMeanScalePrecision([0.0, 0.0], 2.0) + + @test length(MvNormalMeanScalePrecision([0.0, 0.0])) === 2 + @test length(MvNormalMeanScalePrecision([0.0, 0.0, 0.0])) === 3 + @test ndims(MvNormalMeanScalePrecision([0.0, 0.0])) === 2 + @test ndims(MvNormalMeanScalePrecision([0.0, 0.0, 0.0])) === 3 + @test size(MvNormalMeanScalePrecision([0.0, 0.0])) === (2,) + @test size(MvNormalMeanScalePrecision([0.0, 0.0, 0.0])) === (3,) + + distribution = MvNormalMeanScalePrecision([0.0, 0.0], [2.0 0.0; 0.0 3.0]) + + @test distribution ≈ distribution + @test distribution ≈ convert(MvNormalMeanCovariance, distribution) + @test distribution ≈ convert(MvNormalMeanPrecision, distribution) + @test distribution ≈ convert(MvNormalWeightedMeanPrecision, distribution) +end + +@testitem "MvNormalMeanScalePrecision: vague" begin + include("./normal_family_setuptests.jl") + + @test_throws MethodError vague(MvNormalMeanScalePrecision) + + d1 = vague(MvNormalMeanScalePrecision, 2) + + @test typeof(d1) <: MvNormalMeanScalePrecision + @test mean(d1) == zeros(2) + @test invcov(d1) == Matrix(Diagonal(1e-12 * ones(2))) + @test ndims(d1) == 2 + + d2 = vague(MvNormalMeanScalePrecision, 3) + + @test typeof(d2) <: MvNormalMeanScalePrecision + @test mean(d2) == zeros(3) + @test invcov(d2) == Matrix(Diagonal(1e-12 * ones(3))) + @test ndims(d2) == 3 +end + +@testitem "MvNormalMeanScalePrecision: prod" begin + include("./normal_family_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) + @test prod(strategy, MvNormalMeanScalePrecision([-1, -1], 2), MvNormalMeanPrecision([1, 1], [2, 4])) ≈ + MvNormalWeightedMeanPrecision([0, 2], [4, 6]) + + μ = [1.0, 2.0, 3.0] + γ = 2.0 + dist = MvNormalMeanPrecision(μ, Λ) + + @test prod(strategy, dist, dist) ≈ + MvNormalMeanScalePrecision([4.0, 8.0, 12.0], 2γ) + end +end + +@testitem "MvNormalMeanScalePrecision: convert" begin + include("./normal_family_setuptests.jl") + + @test convert(MvNormalMeanScalePrecision, zeros(2), 1.0) == + MvNormalMeanScalePrecision(zeros(2), 1.0) + @test begin + m = rand(5) + c = rand() + convert(MvNormalMeanScalePrecision, m, c) == MvNormalMeanPrecision(m, c) + end +end From a590f7fcd51457bd87f92de583fdb269221d6a14 Mon Sep 17 00:00:00 2001 From: Albert Date: Thu, 8 Aug 2024 17:54:15 +0200 Subject: [PATCH 02/32] Add MvNormalMeanScalePrecision --- .../mv_normal_mean_scale_precision.jl | 113 ++++++++++++++++++ .../normal_family/normal_family.jl | 38 +++++- 2 files changed, 146 insertions(+), 5 deletions(-) create mode 100644 src/distributions/normal_family/mv_normal_mean_scale_precision.jl diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl new file mode 100644 index 00000000..a63f0e44 --- /dev/null +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -0,0 +1,113 @@ +export MvNormalMeanScalePrecision + +import Distributions: logdetcov, distrname, sqmahal, sqmahal!, AbstractMvNormal +import LinearAlgebra: diag, Diagonal, dot +import Base: ndims, precision, length, size, prod + +""" + MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal + +A multivariate normal distribution with mean `μ` and scale parameter `γ` that scales the identity precision matrix. + +# Type Parameters +- `T`: The element type of the mean vector and scale parameter +- `M`: The type of the mean vector, which must be a subtype of `AbstractVector{T}` + +# Fields +- `μ::M`: The mean vector of the multivariate normal distribution +- `γ::T`: The scale parameter that scales the identity precision matrix + +# Notes +The precision matrix of this distribution is `γ * I`, where `I` is the identity matrix. +The covariance matrix is the inverse of the precision matrix, i.e., `(1/γ) * I`. +""" +struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal + μ::M + γ::T +end + +function MvNormalMeanScalePrecision(μ::AbstractVector{<:Real}, γ::Real) + T = promote_type(eltype(μ), eltype(γ)) + return MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) +end + +function MvNormalMeanScalePrecision(μ::AbstractVector{<:Integer}, γ::Real) + return MvNormalMeanScalePrecision(float.(μ), float(γ)) +end + +function MvNormalMeanScalePrecision(μ::AbstractVector{T}) where {T} + return MvNormalMeanScalePrecision(μ, convert(T, 1)) +end + +function MvNormalMeanScalePrecision(μ::AbstractVector{T1}, γ::T2) where {T1, T2} + T = promote_type(T1, T2) + μ_new = convert(AbstractArray{T}, μ) + γ_new = convert(T, γ)(length(μ)) + return MvNormalMeanScalePrecision(μ_new, γ_new) +end + +Distributions.distrname(::MvNormalMeanScalePrecision) = "MvNormalMeanScalePrecision" + +BayesBase.weightedmean(dist::MvNormalMeanScalePrecision) = precision(dist) * mean(dist) + +BayesBase.mean(dist::MvNormalMeanScalePrecision) = dist.μ +BayesBase.mode(dist::MvNormalMeanScalePrecision) = mean(dist) +BayesBase.var(dist::MvNormalMeanScalePrecision) = diag(cov(dist)) +BayesBase.cov(dist::MvNormalMeanScalePrecision) = cholinv(invcov(dist)) +BayesBase.invcov(dist::MvNormalMeanScalePrecision) = dist.γ * I(length(mean(dist))) +BayesBase.std(dist::MvNormalMeanScalePrecision) = cholsqrt(cov(dist)) +BayesBase.logdetcov(dist::MvNormalMeanScalePrecision) = -chollogdet(invcov(dist)) +BayesBase.params(dist::MvNormalMeanScalePrecision) = (mean(dist), invcov(dist)) + +function Distributions.sqmahal(dist::MvNormalMeanScalePrecision, x::AbstractVector) + T = promote_type(eltype(x), paramfloattype(dist)) + return sqmahal!(similar(x, T), dist, x) +end + +function Distributions.sqmahal!(r, dist::MvNormalMeanScalePrecision, x::AbstractVector) + μ = mean(dist) + @inbounds @simd for i in 1:length(r) + r[i] = μ[i] - x[i] + end + return dot3arg(r, invcov(dist), r) # x' * A * x +end + +Base.eltype(::MvNormalMeanScalePrecision{T}) where {T} = T +Base.precision(dist::MvNormalMeanScalePrecision) = dist.γ +Base.length(dist::MvNormalMeanScalePrecision) = length(mean(dist)) +Base.ndims(dist::MvNormalMeanScalePrecision) = length(dist) +Base.size(dist::MvNormalMeanScalePrecision) = (length(dist),) + +Base.convert(::Type{<:MvNormalMeanScalePrecision}, μ::AbstractVector, γ::Real) = MvNormalMeanScalePrecision(μ, γ) + +function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::T) where {T <: Real} + @show "hi" + MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) +end + +BayesBase.vague(::Type{<:MvNormalMeanScalePrecision}, dims::Int) = + MvNormalMeanScalePrecision(zeros(Float64, dims), convert(Float64, tiny)) + +BayesBase.default_prod_rule(::Type{<:MvNormalMeanScalePrecision}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution) + +function BayesBase.prod(::PreserveTypeProd{Distribution}, left::MvNormalMeanScalePrecision, right::MvNormalMeanScalePrecision) + w = precision(left) + precision(right) + m = (precision(left) * mean(left) + precision(right) * mean(right)) / w + return MvNormalMeanScalePrecision(m, w) +end + +function BayesBase.prod( + ::PreserveTypeProd{Distribution}, + left::MvNormalMeanScalePrecision{T1}, + right::MvNormalMeanScalePrecision{T2} +) where {T1 <: LinearAlgebra.BlasFloat, T2 <: LinearAlgebra.BlasFloat} + w = precision(left) + precision(right) + + xi = precision(right) * mean(right) + T = promote_type(T1, T2) + xi = convert(AbstractVector{T}, xi) + w = convert(T, w) + xi = BLAS.gemv!('N', one(T), convert(AbstractMatrix{T}, precision(left)), convert(AbstractVector{T}, mean(left)), one(T), xi) + + return MvNormalMeanScalePrecision(xi / w, w) +end diff --git a/src/distributions/normal_family/normal_family.jl b/src/distributions/normal_family/normal_family.jl index 80e70658..49eba1e0 100644 --- a/src/distributions/normal_family/normal_family.jl +++ b/src/distributions/normal_family/normal_family.jl @@ -10,9 +10,10 @@ const GaussianWeighteMeanPrecision = NormalWeightedMeanPrecision const MvGaussianMeanCovariance = MvNormalMeanCovariance const MvGaussianMeanPrecision = MvNormalMeanPrecision const MvGaussianWeightedMeanPrecision = MvNormalWeightedMeanPrecision +const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision const UnivariateNormalDistributionsFamily{T} = Union{NormalMeanPrecision{T}, NormalMeanVariance{T}, NormalWeightedMeanPrecision{T}, Normal{T}} -const MultivariateNormalDistributionsFamily{T} = Union{MvNormalMeanPrecision{T}, MvNormalMeanCovariance{T}, MvNormalWeightedMeanPrecision{T}, MvNormal{T}} +const MultivariateNormalDistributionsFamily{T} = Union{MvNormalMeanPrecision{T}, MvNormalMeanScalePrecision{T}, MvNormalMeanCovariance{T}, MvNormalWeightedMeanPrecision{T}, MvNormal{T}} const NormalDistributionsFamily{T} = Union{UnivariateNormalDistributionsFamily{T}, MultivariateNormalDistributionsFamily{T}} const UnivariateGaussianDistributionsFamily = UnivariateNormalDistributionsFamily @@ -251,6 +252,12 @@ function Base.convert( return MvNormal(convert(M, mean), Distributions.PDMats.PDMat(convert(AbstractMatrix{T}, cov))) end +# Special case for `MvNormalMeanScalePrecision` to `MvNormal` +function Base.convert(::Type{MvNormal{T, C, M}}, dist::MvNormalMeanScalePrecision) where {T <: Real, C <: Distributions.PDMats.PDMat{T, Matrix{T}}, M <: AbstractVector{T}} + m, σ = mean(dist), std(dist) + return MvNormal(convert(M, m), convert(T, σ)) +end + function Base.convert(::Type{MvNormal{T}}, dist::MultivariateNormalDistributionsFamily) where {T <: Real} return convert(MvNormal{T, Distributions.PDMats.PDMat{T, Matrix{T}}, Vector{T}}, dist) end @@ -263,6 +270,8 @@ function Base.convert(::Type{MvNormal{T}}, mean::AbstractVector, cov::AbstractMa return MvNormal(convert(AbstractVector{T}, mean), Distributions.PDMats.PDMat(convert(AbstractMatrix{T}, cov))) end + + # Conversion to mean - variance parametrisation function Base.convert(::Type{NormalMeanVariance{T}}, dist::UnivariateNormalDistributionsFamily) where {T <: Real} @@ -297,6 +306,13 @@ function Base.convert(::Type{MvNormalMeanCovariance}, dist::MultivariateNormalDi return convert(MvNormalMeanCovariance{T}, dist) end + +# Special case for `MvNormalMeanScalePrecision` to `MvNormalMeanCovariance` +function Base.convert(::Type{MvNormalMeanCovariance}, dist::MvNormalMeanScalePrecision) + m, σ = mean(dist), cov(dist) + return MvNormalMeanCovariance(m, σ*diagm(ones(length(m)))) +end + # Conversion to mean - precision parametrisation function Base.convert(::Type{NormalMeanPrecision{T}}, dist::UnivariateNormalDistributionsFamily) where {T <: Real} @@ -331,6 +347,12 @@ function Base.convert(::Type{MvNormalMeanPrecision}, dist::MultivariateNormalDis return convert(MvNormalMeanPrecision{T}, dist) end +# Special case for `MvNormalMeanScalePrecision` to `MvNormalMeanPrecision` +function Base.convert(::Type{MvNormalMeanPrecision}, dist::MvNormalMeanScalePrecision) + m, γ = mean(dist), precision(dist) + return MvNormalMeanPrecision(m, γ*diagm(ones(length(m)))) +end + # Conversion to weighted mean - precision parametrisation function Base.convert( @@ -393,13 +415,19 @@ function Base.convert(::Type{MvNormalWeightedMeanPrecision}, dist::FullNormal) return MvNormalWeightedMeanPrecision(precision * mean, precision) end +# Special case for `MvNormalMeanScalePrecision` to `MvNormalWeightedMeanPrecision` +function Base.convert(::Type{MvNormalWeightedMeanPrecision}, dist::MvNormalMeanScalePrecision) + m, γ = mean(dist), precision(dist) + return MvNormalWeightedMeanPrecision(γ*m, γ*diagm(ones(length(m)))) +end + # isapprox function Base.isapprox(left::UnivariateNormalDistributionsFamily, right::UnivariateNormalDistributionsFamily; kwargs...) return all(p -> isapprox(p[1], p[2]; kwargs...), zip(mean_var(left), mean_var(right))) end -function Base.isapprox(left::D, right::D; kwargs...) where { D <: UnivariateNormalDistributionsFamily } +function Base.isapprox(left::D, right::D; kwargs...) where {D <: UnivariateNormalDistributionsFamily} return all(p -> isapprox(p[1], p[2]; kwargs...), zip(params(left), params(right))) end @@ -407,7 +435,7 @@ function Base.isapprox(left::MultivariateNormalDistributionsFamily, right::Multi return all(p -> isapprox(p[1], p[2]; kwargs...), zip(mean_cov(left), mean_cov(right))) end -function Base.isapprox(left::D, right::D; kwargs...) where { D <: MultivariateNormalDistributionsFamily } +function Base.isapprox(left::D, right::D; kwargs...) where {D <: MultivariateNormalDistributionsFamily} return all(p -> isapprox(p[1], p[2]; kwargs...), zip(params(left), params(right))) end @@ -586,7 +614,7 @@ end getgradlogpartition(::NaturalParametersSpace, ::Type{NormalMeanVariance}) = (η) -> begin (η₁, η₂) = unpack_parameters(NormalMeanVariance, η) - return SA[-η₁ * inv(η₂*2), abs2(η₁) / ( 4 * abs2(η₂)) - 1 / (2 * η₂)] + return SA[-η₁*inv(η₂ * 2), abs2(η₁)/(4*abs2(η₂))-1/(2*η₂)] end getfisherinformation(::NaturalParametersSpace, ::Type{NormalMeanVariance}) = @@ -608,7 +636,7 @@ end getgradlogpartition(::MeanParametersSpace, ::Type{NormalMeanVariance}) = (θ) -> begin (μ, σ²) = unpack_parameters(NormalMeanVariance, θ) - return SA[μ / σ², - abs2(μ) / (2σ²^2) + 1 / σ²] + return SA[μ/σ², -abs2(μ)/(2σ²^2)+1/σ²] end getfisherinformation(::MeanParametersSpace, ::Type{NormalMeanVariance}) = (θ) -> begin From d78da98c4066d3e74627666b3bed0d93df507bd8 Mon Sep 17 00:00:00 2001 From: Albert Date: Thu, 8 Aug 2024 18:19:33 +0200 Subject: [PATCH 03/32] Fix distribution --- src/ExponentialFamily.jl | 1 + .../mv_normal_mean_scale_precision.jl | 13 ++++--- .../mv_normal_mean_scale_precision_tests.jl | 34 +++++++++---------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/ExponentialFamily.jl b/src/ExponentialFamily.jl index 0366326e..5fe00a55 100644 --- a/src/ExponentialFamily.jl +++ b/src/ExponentialFamily.jl @@ -45,6 +45,7 @@ include("distributions/normal_family/normal_weighted_mean_precision.jl") include("distributions/normal_family/mv_normal_mean_covariance.jl") include("distributions/normal_family/mv_normal_mean_precision.jl") include("distributions/normal_family/mv_normal_weighted_mean_precision.jl") +include("distributions/normal_family/mv_normal_mean_scale_precision.jl") include("distributions/normal_family/normal_family.jl") include("distributions/gamma_inverse.jl") include("distributions/geometric.jl") diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index a63f0e44..1e5263f1 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -21,7 +21,7 @@ A multivariate normal distribution with mean `μ` and scale parameter `γ` that The precision matrix of this distribution is `γ * I`, where `I` is the identity matrix. The covariance matrix is the inverse of the precision matrix, i.e., `(1/γ) * I`. """ -struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal +struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} μ::M γ::T end @@ -73,17 +73,16 @@ function Distributions.sqmahal!(r, dist::MvNormalMeanScalePrecision, x::Abstract end Base.eltype(::MvNormalMeanScalePrecision{T}) where {T} = T -Base.precision(dist::MvNormalMeanScalePrecision) = dist.γ +Base.precision(dist::MvNormalMeanScalePrecision) = invcov(dist) Base.length(dist::MvNormalMeanScalePrecision) = length(mean(dist)) Base.ndims(dist::MvNormalMeanScalePrecision) = length(dist) Base.size(dist::MvNormalMeanScalePrecision) = (length(dist),) -Base.convert(::Type{<:MvNormalMeanScalePrecision}, μ::AbstractVector, γ::Real) = MvNormalMeanScalePrecision(μ, γ) +# Base.convert(::Type{<:MvNormalMeanScalePrecision}, μ::AbstractVector, γ::Real) = MvNormalMeanScalePrecision(μ, γ) -function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::T) where {T <: Real} - @show "hi" - MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) -end +# function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::T) where {T <: Real} +# MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) +# end BayesBase.vague(::Type{<:MvNormalMeanScalePrecision}, dims::Int) = MvNormalMeanScalePrecision(zeros(Float64, dims), convert(Float64, tiny)) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index d5426385..e9594fec 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -17,7 +17,7 @@ @test eltype(MvNormalMeanScalePrecision([1.0f0, 1.0f0], 1.0f0)) === Float32 @test MvNormalMeanScalePrecision(ones(3), 5) == MvNormalMeanScalePrecision(ones(3), 5) - @test MvNormalMeanScalePrecision([1, 2, 3, 4], 7.0) == MvNormalMeanPrecision([1.0, 2.0, 3.0, 4.0], 7.0) + @test MvNormalMeanScalePrecision([1, 2, 3, 4], 7.0) == MvNormalMeanScalePrecision([1.0, 2.0, 3.0, 4.0], 7.0) end @testitem "MvNormalMeanScalePrecision: distrname" begin @@ -36,24 +36,24 @@ end @test mean(dist) == μ @test mode(dist) == μ - @test weightedmean(dist) == γ * μ - @test invcov(dist) == γ - @test precision(dist) == γ - @test cov(dist) ≈ inv(Λ) - @test std(dist) * std(dist)' ≈ inv(γ) - @test all(mean_cov(dist) .≈ (μ, inv(γ))) - @test all(mean_invcov(dist) .≈ (μ, γ)) - @test all(mean_precision(dist) .≈ (μ, γ)) - @test all(weightedmean_cov(dist) .≈ (Λ * μ, inv(γ))) - @test all(weightedmean_invcov(dist) .≈ (γ * μ, γ)) - @test all(weightedmean_precision(dist) .≈ (γ * μ, γ)) + @test weightedmean(dist) == weightedmean(rdist) + @test invcov(dist) == invcov(rdist) + @test precision(dist) == precision(rdist) + @test cov(dist) ≈ cov(rdist) + @test std(dist) * std(dist)' ≈ std(rdist) * std(rdist)' + @test all(mean_cov(dist) .≈ mean_cov(rdist)) + @test all(mean_invcov(dist) .≈ mean_invcov(rdist)) + @test all(mean_precision(dist) .≈ mean_precision(rdist)) + @test all(weightedmean_cov(dist) .≈ weightedmean_cov(rdist)) + @test all(weightedmean_invcov(dist) .≈ weightedmean_invcov(rdist)) + @test all(weightedmean_precision(dist) .≈ weightedmean_precision(rdist)) @test length(dist) == 3 @test entropy(dist) ≈ entropy(rdist) @test pdf(dist, [0.2, 3.0, 4.0]) ≈ pdf(rdist, [0.2, 3.0, 4.0]) - @test pdf(dist, [0.202, 3.002, 4.002]) ≈ pdf(rdist, [0.2, 3.0, 4.0]) - @test logpdf(dist, [0.2, 3.0, 4.0]) ≈ pdf(rdist, [0.2, 3.0, 4.0]) - @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ pdf(rdist, [0.2, 3.0, 4.0]) + @test pdf(dist, [0.202, 3.002, 4.002]) ≈ pdf(rdist, [0.202, 3.002, 4.002]) + @test logpdf(dist, [0.2, 3.0, 4.0]) ≈ logpdf(rdist, [0.2, 3.0, 4.0]) + @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ logpdf(rdist, [0.202, 3.002, 4.002]) end @testitem "MvNormalMeanScalePrecision: Base methods" begin @@ -71,7 +71,7 @@ end @test size(MvNormalMeanScalePrecision([0.0, 0.0])) === (2,) @test size(MvNormalMeanScalePrecision([0.0, 0.0, 0.0])) === (3,) - distribution = MvNormalMeanScalePrecision([0.0, 0.0], [2.0 0.0; 0.0 3.0]) + distribution = MvNormalMeanScalePrecision([0.0, 0.0], 2.0) @test distribution ≈ distribution @test distribution ≈ convert(MvNormalMeanCovariance, distribution) @@ -108,7 +108,7 @@ end μ = [1.0, 2.0, 3.0] γ = 2.0 - dist = MvNormalMeanPrecision(μ, Λ) + dist = MvNormalMeanScalePrecision(μ, γ) @test prod(strategy, dist, dist) ≈ MvNormalMeanScalePrecision([4.0, 8.0, 12.0], 2γ) From 45d0f591411974853e78dc1d417cee2158a7d75b Mon Sep 17 00:00:00 2001 From: Albert Date: Fri, 9 Aug 2024 11:46:22 +0200 Subject: [PATCH 04/32] Fix tests --- .../mv_normal_mean_scale_precision.jl | 25 ++++++++++--------- .../normal_family/normal_family.jl | 15 +++++++++++ .../mv_normal_mean_scale_precision_tests.jl | 4 +-- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index 1e5263f1..ecf29181 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -21,7 +21,7 @@ A multivariate normal distribution with mean `μ` and scale parameter `γ` that The precision matrix of this distribution is `γ * I`, where `I` is the identity matrix. The covariance matrix is the inverse of the precision matrix, i.e., `(1/γ) * I`. """ -struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} +struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal μ::M γ::T end @@ -78,11 +78,11 @@ Base.length(dist::MvNormalMeanScalePrecision) = length(mean(dist)) Base.ndims(dist::MvNormalMeanScalePrecision) = length(dist) Base.size(dist::MvNormalMeanScalePrecision) = (length(dist),) -# Base.convert(::Type{<:MvNormalMeanScalePrecision}, μ::AbstractVector, γ::Real) = MvNormalMeanScalePrecision(μ, γ) +Base.convert(::Type{<:MvNormalMeanScalePrecision}, μ::AbstractVector, γ::Real) = MvNormalMeanScalePrecision(μ, γ) -# function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::T) where {T <: Real} -# MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) -# end +function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::T) where {T <: Real} + MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) +end BayesBase.vague(::Type{<:MvNormalMeanScalePrecision}, dims::Int) = MvNormalMeanScalePrecision(zeros(Float64, dims), convert(Float64, tiny)) @@ -90,7 +90,7 @@ BayesBase.vague(::Type{<:MvNormalMeanScalePrecision}, dims::Int) = BayesBase.default_prod_rule(::Type{<:MvNormalMeanScalePrecision}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution) function BayesBase.prod(::PreserveTypeProd{Distribution}, left::MvNormalMeanScalePrecision, right::MvNormalMeanScalePrecision) - w = precision(left) + precision(right) + w = left.γ + right.γ m = (precision(left) * mean(left) + precision(right) * mean(right)) / w return MvNormalMeanScalePrecision(m, w) end @@ -100,13 +100,14 @@ function BayesBase.prod( left::MvNormalMeanScalePrecision{T1}, right::MvNormalMeanScalePrecision{T2} ) where {T1 <: LinearAlgebra.BlasFloat, T2 <: LinearAlgebra.BlasFloat} - w = precision(left) + precision(right) + w = left.γ + right.γ - xi = precision(right) * mean(right) - T = promote_type(T1, T2) - xi = convert(AbstractVector{T}, xi) - w = convert(T, w) - xi = BLAS.gemv!('N', one(T), convert(AbstractMatrix{T}, precision(left)), convert(AbstractVector{T}, mean(left)), one(T), xi) + T = promote_type(T1, T2) + + xi = convert(AbstractVector{T}, right.γ * mean(right)) + w = convert(T, w) + + xi .+= convert(T, left.γ) .* convert(AbstractVector{T}, mean(left)) return MvNormalMeanScalePrecision(xi / w, w) end diff --git a/src/distributions/normal_family/normal_family.jl b/src/distributions/normal_family/normal_family.jl index 49eba1e0..304a868b 100644 --- a/src/distributions/normal_family/normal_family.jl +++ b/src/distributions/normal_family/normal_family.jl @@ -307,6 +307,21 @@ function Base.convert(::Type{MvNormalMeanCovariance}, dist::MultivariateNormalDi end +function Base.convert( + ::Type{MvNormalMeanScalePrecision{T, M}}, + dist::MvNormalMeanScalePrecision +) where {T <: Real, M <: AbstractArray{T}} + m, γ = mean(dist), dist.γ + return MvNormalMeanScalePrecision{T, M}(convert(M, m), convert(T, γ)) +end + +function Base.convert( + ::Type{MvNormalMeanScalePrecision{T}}, + dist::MvNormalMeanScalePrecision +) where {T <: Real} + return convert(MvNormalMeanScalePrecision{T, AbstractArray{T, 1}}, dist) +end + # Special case for `MvNormalMeanScalePrecision` to `MvNormalMeanCovariance` function Base.convert(::Type{MvNormalMeanCovariance}, dist::MvNormalMeanScalePrecision) m, σ = mean(dist), cov(dist) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index e9594fec..e46f7d98 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -111,7 +111,7 @@ end dist = MvNormalMeanScalePrecision(μ, γ) @test prod(strategy, dist, dist) ≈ - MvNormalMeanScalePrecision([4.0, 8.0, 12.0], 2γ) + MvNormalMeanScalePrecision([1.0, 2.0, 3.0], 2γ) end end @@ -123,6 +123,6 @@ end @test begin m = rand(5) c = rand() - convert(MvNormalMeanScalePrecision, m, c) == MvNormalMeanPrecision(m, c) + convert(MvNormalMeanScalePrecision, m, c) == MvNormalMeanScalePrecision(m, c) end end From 9818bd317c93d9bd2ec7c30361d420968ff04ed7 Mon Sep 17 00:00:00 2001 From: Albert Date: Fri, 9 Aug 2024 15:14:57 +0200 Subject: [PATCH 05/32] Update structure and tests --- src/ExponentialFamily.jl | 2 +- .../mv_normal_mean_scale_precision.jl | 70 ++++++++++++++----- .../normal_family/normal_family.jl | 43 +----------- .../mv_normal_mean_scale_precision_tests.jl | 13 ++-- 4 files changed, 63 insertions(+), 65 deletions(-) diff --git a/src/ExponentialFamily.jl b/src/ExponentialFamily.jl index 5fe00a55..f2ec4c1c 100644 --- a/src/ExponentialFamily.jl +++ b/src/ExponentialFamily.jl @@ -45,8 +45,8 @@ include("distributions/normal_family/normal_weighted_mean_precision.jl") include("distributions/normal_family/mv_normal_mean_covariance.jl") include("distributions/normal_family/mv_normal_mean_precision.jl") include("distributions/normal_family/mv_normal_weighted_mean_precision.jl") -include("distributions/normal_family/mv_normal_mean_scale_precision.jl") include("distributions/normal_family/normal_family.jl") +include("distributions/normal_family/mv_normal_mean_scale_precision.jl") include("distributions/gamma_inverse.jl") include("distributions/geometric.jl") include("distributions/matrix_dirichlet.jl") diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index ecf29181..ce82eeee 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -1,4 +1,4 @@ -export MvNormalMeanScalePrecision +export MvNormalMeanScalePrecision, MvGaussianMeanScalePrecision import Distributions: logdetcov, distrname, sqmahal, sqmahal!, AbstractMvNormal import LinearAlgebra: diag, Diagonal, dot @@ -26,6 +26,8 @@ struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: Abstract γ::T end +const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision + function MvNormalMeanScalePrecision(μ::AbstractVector{<:Real}, γ::Real) T = promote_type(eltype(μ), eltype(γ)) return MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) @@ -46,6 +48,42 @@ function MvNormalMeanScalePrecision(μ::AbstractVector{T1}, γ::T2) where {T1, T return MvNormalMeanScalePrecision(μ_new, γ_new) end +# Conversions +function Base.convert(::Type{MvNormal{T, C, M}}, dist::MvNormalMeanScalePrecision) where {T <: Real, C <: Distributions.PDMats.PDMat{T, Matrix{T}}, M <: AbstractVector{T}} + m, σ = mean(dist), std(dist) + return MvNormal(convert(M, m), convert(T, σ)) +end + +function Base.convert( + ::Type{MvNormalMeanScalePrecision{T, M}}, + dist::MvNormalMeanScalePrecision +) where {T <: Real, M <: AbstractArray{T}} + m, γ = mean(dist), dist.γ + return MvNormalMeanScalePrecision{T, M}(convert(M, m), convert(T, γ)) +end + +function Base.convert( + ::Type{MvNormalMeanScalePrecision{T}}, + dist::MvNormalMeanScalePrecision +) where {T <: Real} + return convert(MvNormalMeanScalePrecision{T, AbstractArray{T, 1}}, dist) +end + +function Base.convert(::Type{MvNormalMeanCovariance}, dist::MvNormalMeanScalePrecision) + m, σ = mean(dist), cov(dist) + return MvNormalMeanCovariance(m, σ*diagm(ones(length(m)))) +end + +function Base.convert(::Type{MvNormalMeanPrecision}, dist::MvNormalMeanScalePrecision) + m, γ = mean(dist), precision(dist) + return MvNormalMeanPrecision(m, γ*diagm(ones(length(m)))) +end + +function Base.convert(::Type{MvNormalWeightedMeanPrecision}, dist::MvNormalMeanScalePrecision) + m, γ = mean(dist), precision(dist) + return MvNormalWeightedMeanPrecision(γ*m, γ*diagm(ones(length(m)))) +end + Distributions.distrname(::MvNormalMeanScalePrecision) = "MvNormalMeanScalePrecision" BayesBase.weightedmean(dist::MvNormalMeanScalePrecision) = precision(dist) * mean(dist) @@ -54,10 +92,11 @@ BayesBase.mean(dist::MvNormalMeanScalePrecision) = dist.μ BayesBase.mode(dist::MvNormalMeanScalePrecision) = mean(dist) BayesBase.var(dist::MvNormalMeanScalePrecision) = diag(cov(dist)) BayesBase.cov(dist::MvNormalMeanScalePrecision) = cholinv(invcov(dist)) -BayesBase.invcov(dist::MvNormalMeanScalePrecision) = dist.γ * I(length(mean(dist))) +BayesBase.invcov(dist::MvNormalMeanScalePrecision) = scale(dist) * I(length(mean(dist))) BayesBase.std(dist::MvNormalMeanScalePrecision) = cholsqrt(cov(dist)) BayesBase.logdetcov(dist::MvNormalMeanScalePrecision) = -chollogdet(invcov(dist)) BayesBase.params(dist::MvNormalMeanScalePrecision) = (mean(dist), invcov(dist)) +BayesBase.scale(dist::MvNormalMeanScalePrecision) = dist.γ function Distributions.sqmahal(dist::MvNormalMeanScalePrecision, x::AbstractVector) T = promote_type(eltype(x), paramfloattype(dist)) @@ -90,24 +129,19 @@ BayesBase.vague(::Type{<:MvNormalMeanScalePrecision}, dims::Int) = BayesBase.default_prod_rule(::Type{<:MvNormalMeanScalePrecision}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution) function BayesBase.prod(::PreserveTypeProd{Distribution}, left::MvNormalMeanScalePrecision, right::MvNormalMeanScalePrecision) - w = left.γ + right.γ - m = (precision(left) * mean(left) + precision(right) * mean(right)) / w + w = scale(left) + scale(right) + m = (scale(left) * mean(left) + scale(right) * mean(right)) / w return MvNormalMeanScalePrecision(m, w) end +BayesBase.default_prod_rule(::Type{<:MultivariateNormalDistributionsFamily}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution) + function BayesBase.prod( ::PreserveTypeProd{Distribution}, - left::MvNormalMeanScalePrecision{T1}, - right::MvNormalMeanScalePrecision{T2} -) where {T1 <: LinearAlgebra.BlasFloat, T2 <: LinearAlgebra.BlasFloat} - w = left.γ + right.γ - - T = promote_type(T1, T2) - - xi = convert(AbstractVector{T}, right.γ * mean(right)) - w = convert(T, w) - - xi .+= convert(T, left.γ) .* convert(AbstractVector{T}, mean(left)) - - return MvNormalMeanScalePrecision(xi / w, w) -end + left::L, + right::R +) where {L <: MultivariateNormalDistributionsFamily, R <: MvNormalMeanScalePrecision} + wleft = convert(MvNormalWeightedMeanPrecision, left) + wright = convert(MvNormalWeightedMeanPrecision, right) + return prod(BayesBase.default_prod_rule(wleft, wright), wleft, wright) +end \ No newline at end of file diff --git a/src/distributions/normal_family/normal_family.jl b/src/distributions/normal_family/normal_family.jl index 304a868b..e68fdfe0 100644 --- a/src/distributions/normal_family/normal_family.jl +++ b/src/distributions/normal_family/normal_family.jl @@ -10,10 +10,9 @@ const GaussianWeighteMeanPrecision = NormalWeightedMeanPrecision const MvGaussianMeanCovariance = MvNormalMeanCovariance const MvGaussianMeanPrecision = MvNormalMeanPrecision const MvGaussianWeightedMeanPrecision = MvNormalWeightedMeanPrecision -const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision const UnivariateNormalDistributionsFamily{T} = Union{NormalMeanPrecision{T}, NormalMeanVariance{T}, NormalWeightedMeanPrecision{T}, Normal{T}} -const MultivariateNormalDistributionsFamily{T} = Union{MvNormalMeanPrecision{T}, MvNormalMeanScalePrecision{T}, MvNormalMeanCovariance{T}, MvNormalWeightedMeanPrecision{T}, MvNormal{T}} +const MultivariateNormalDistributionsFamily{T} = Union{MvNormalMeanPrecision{T}, MvNormalMeanCovariance{T}, MvNormalWeightedMeanPrecision{T}, MvNormal{T}} const NormalDistributionsFamily{T} = Union{UnivariateNormalDistributionsFamily{T}, MultivariateNormalDistributionsFamily{T}} const UnivariateGaussianDistributionsFamily = UnivariateNormalDistributionsFamily @@ -252,12 +251,6 @@ function Base.convert( return MvNormal(convert(M, mean), Distributions.PDMats.PDMat(convert(AbstractMatrix{T}, cov))) end -# Special case for `MvNormalMeanScalePrecision` to `MvNormal` -function Base.convert(::Type{MvNormal{T, C, M}}, dist::MvNormalMeanScalePrecision) where {T <: Real, C <: Distributions.PDMats.PDMat{T, Matrix{T}}, M <: AbstractVector{T}} - m, σ = mean(dist), std(dist) - return MvNormal(convert(M, m), convert(T, σ)) -end - function Base.convert(::Type{MvNormal{T}}, dist::MultivariateNormalDistributionsFamily) where {T <: Real} return convert(MvNormal{T, Distributions.PDMats.PDMat{T, Matrix{T}}, Vector{T}}, dist) end @@ -306,28 +299,6 @@ function Base.convert(::Type{MvNormalMeanCovariance}, dist::MultivariateNormalDi return convert(MvNormalMeanCovariance{T}, dist) end - -function Base.convert( - ::Type{MvNormalMeanScalePrecision{T, M}}, - dist::MvNormalMeanScalePrecision -) where {T <: Real, M <: AbstractArray{T}} - m, γ = mean(dist), dist.γ - return MvNormalMeanScalePrecision{T, M}(convert(M, m), convert(T, γ)) -end - -function Base.convert( - ::Type{MvNormalMeanScalePrecision{T}}, - dist::MvNormalMeanScalePrecision -) where {T <: Real} - return convert(MvNormalMeanScalePrecision{T, AbstractArray{T, 1}}, dist) -end - -# Special case for `MvNormalMeanScalePrecision` to `MvNormalMeanCovariance` -function Base.convert(::Type{MvNormalMeanCovariance}, dist::MvNormalMeanScalePrecision) - m, σ = mean(dist), cov(dist) - return MvNormalMeanCovariance(m, σ*diagm(ones(length(m)))) -end - # Conversion to mean - precision parametrisation function Base.convert(::Type{NormalMeanPrecision{T}}, dist::UnivariateNormalDistributionsFamily) where {T <: Real} @@ -362,11 +333,7 @@ function Base.convert(::Type{MvNormalMeanPrecision}, dist::MultivariateNormalDis return convert(MvNormalMeanPrecision{T}, dist) end -# Special case for `MvNormalMeanScalePrecision` to `MvNormalMeanPrecision` -function Base.convert(::Type{MvNormalMeanPrecision}, dist::MvNormalMeanScalePrecision) - m, γ = mean(dist), precision(dist) - return MvNormalMeanPrecision(m, γ*diagm(ones(length(m)))) -end + # Conversion to weighted mean - precision parametrisation @@ -430,12 +397,6 @@ function Base.convert(::Type{MvNormalWeightedMeanPrecision}, dist::FullNormal) return MvNormalWeightedMeanPrecision(precision * mean, precision) end -# Special case for `MvNormalMeanScalePrecision` to `MvNormalWeightedMeanPrecision` -function Base.convert(::Type{MvNormalWeightedMeanPrecision}, dist::MvNormalMeanScalePrecision) - m, γ = mean(dist), precision(dist) - return MvNormalWeightedMeanPrecision(γ*m, γ*diagm(ones(length(m)))) -end - # isapprox function Base.isapprox(left::UnivariateNormalDistributionsFamily, right::UnivariateNormalDistributionsFamily; kwargs...) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index e46f7d98..eb05a644 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -36,6 +36,7 @@ end @test mean(dist) == μ @test mode(dist) == μ + @test scale(dist) == γ @test weightedmean(dist) == weightedmean(rdist) @test invcov(dist) == invcov(rdist) @test precision(dist) == precision(rdist) @@ -70,13 +71,14 @@ end @test ndims(MvNormalMeanScalePrecision([0.0, 0.0, 0.0])) === 3 @test size(MvNormalMeanScalePrecision([0.0, 0.0])) === (2,) @test size(MvNormalMeanScalePrecision([0.0, 0.0, 0.0])) === (3,) - - distribution = MvNormalMeanScalePrecision([0.0, 0.0], 2.0) + + μ, γ = zeros(2), 2.0 + distribution = MvNormalMeanScalePrecision(μ, γ) @test distribution ≈ distribution - @test distribution ≈ convert(MvNormalMeanCovariance, distribution) - @test distribution ≈ convert(MvNormalMeanPrecision, distribution) - @test distribution ≈ convert(MvNormalWeightedMeanPrecision, distribution) + @test convert(MvNormalMeanCovariance, distribution) == MvNormalMeanCovariance(μ, inv(γ)*I(length(μ))) + @test convert(MvNormalMeanPrecision, distribution) == MvNormalMeanPrecision(μ, γ*I(length(μ))) + @test convert(MvNormalWeightedMeanPrecision, distribution) == MvNormalWeightedMeanPrecision(γ*μ, γ*I(length(μ))) end @testitem "MvNormalMeanScalePrecision: vague" begin @@ -102,6 +104,7 @@ end @testitem "MvNormalMeanScalePrecision: prod" begin include("./normal_family_setuptests.jl") + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) @test prod(strategy, MvNormalMeanScalePrecision([-1, -1], 2), MvNormalMeanPrecision([1, 1], [2, 4])) ≈ MvNormalWeightedMeanPrecision([0, 2], [4, 6]) From 22e611a4f8fc520d7feee7a9a5a7264ce7abceeb Mon Sep 17 00:00:00 2001 From: Albert Date: Mon, 12 Aug 2024 12:52:44 +0200 Subject: [PATCH 06/32] Add natural parameters related functions --- .../mv_normal_mean_scale_precision.jl | 48 +++++++++++++++---- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index ce82eeee..24f1bed0 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -26,7 +26,7 @@ struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: Abstract γ::T end -const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision +const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision function MvNormalMeanScalePrecision(μ::AbstractVector{<:Real}, γ::Real) T = promote_type(eltype(μ), eltype(γ)) @@ -48,8 +48,36 @@ function MvNormalMeanScalePrecision(μ::AbstractVector{T1}, γ::T2) where {T1, T return MvNormalMeanScalePrecision(μ_new, γ_new) end +function unpack_parameters(::Type{MvNormalMeanScalePrecision}, packed) + len = length(packed) + n = div(-1 + isqrt(1 + 4 * len), 2) + + p₁ = view(packed, 1:n) + p₂ = packed[end] + + return (p₁, p₂) +end + +function isproper(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}, η, conditioner) + k = length(η) - 1 + if length(η) < 2 || (length(η) !== k + 1) + return false + end + (η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η) + return isnothing(conditioner) && length(η₁) === size(η₂, 1) && (size(η₂, 1) === size(η₂, 2)) && isposdef(-η₂) +end + +function (::MeanToNatural{MvNormalMeanScalePrecision})(tuple_of_θ::Tuple{Any, Any}) + (μ, γ) = tuple_of_θ + Σ⁻¹ = 1 / γ + return (Σ⁻¹ * μ, Σ⁻¹ / -2) +end + # Conversions -function Base.convert(::Type{MvNormal{T, C, M}}, dist::MvNormalMeanScalePrecision) where {T <: Real, C <: Distributions.PDMats.PDMat{T, Matrix{T}}, M <: AbstractVector{T}} +function Base.convert( + ::Type{MvNormal{T, C, M}}, + dist::MvNormalMeanScalePrecision +) where {T <: Real, C <: Distributions.PDMats.PDMat{T, Matrix{T}}, M <: AbstractVector{T}} m, σ = mean(dist), std(dist) return MvNormal(convert(M, m), convert(T, σ)) end @@ -70,18 +98,18 @@ function Base.convert( end function Base.convert(::Type{MvNormalMeanCovariance}, dist::MvNormalMeanScalePrecision) - m, σ = mean(dist), cov(dist) - return MvNormalMeanCovariance(m, σ*diagm(ones(length(m)))) + m, σ = mean(dist), cov(dist) + return MvNormalMeanCovariance(m, σ * diagm(ones(length(m)))) end function Base.convert(::Type{MvNormalMeanPrecision}, dist::MvNormalMeanScalePrecision) - m, γ = mean(dist), precision(dist) - return MvNormalMeanPrecision(m, γ*diagm(ones(length(m)))) + m, γ = mean(dist), precision(dist) + return MvNormalMeanPrecision(m, γ * diagm(ones(length(m)))) end function Base.convert(::Type{MvNormalWeightedMeanPrecision}, dist::MvNormalMeanScalePrecision) - m, γ = mean(dist), precision(dist) - return MvNormalWeightedMeanPrecision(γ*m, γ*diagm(ones(length(m)))) + m, γ = mean(dist), precision(dist) + return MvNormalWeightedMeanPrecision(γ * m, γ * diagm(ones(length(m)))) end Distributions.distrname(::MvNormalMeanScalePrecision) = "MvNormalMeanScalePrecision" @@ -95,8 +123,8 @@ BayesBase.cov(dist::MvNormalMeanScalePrecision) = cholinv(invcov(dist)) BayesBase.invcov(dist::MvNormalMeanScalePrecision) = scale(dist) * I(length(mean(dist))) BayesBase.std(dist::MvNormalMeanScalePrecision) = cholsqrt(cov(dist)) BayesBase.logdetcov(dist::MvNormalMeanScalePrecision) = -chollogdet(invcov(dist)) -BayesBase.params(dist::MvNormalMeanScalePrecision) = (mean(dist), invcov(dist)) BayesBase.scale(dist::MvNormalMeanScalePrecision) = dist.γ +BayesBase.params(dist::MvNormalMeanScalePrecision) = (mean(dist), scale(dist)) function Distributions.sqmahal(dist::MvNormalMeanScalePrecision, x::AbstractVector) T = promote_type(eltype(x), paramfloattype(dist)) @@ -144,4 +172,4 @@ function BayesBase.prod( wleft = convert(MvNormalWeightedMeanPrecision, left) wright = convert(MvNormalWeightedMeanPrecision, right) return prod(BayesBase.default_prod_rule(wleft, wright), wleft, wright) -end \ No newline at end of file +end From c9ad326056de0af6ddb4e679939a5eb0fa5f3831 Mon Sep 17 00:00:00 2001 From: Albert Date: Wed, 14 Aug 2024 17:59:50 +0200 Subject: [PATCH 07/32] WIP: Parameters transforamtion --- .../mv_normal_mean_scale_precision.jl | 35 +++++++++++++++---- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index 24f1bed0..813aae48 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -49,10 +49,7 @@ function MvNormalMeanScalePrecision(μ::AbstractVector{T1}, γ::T2) where {T1, T end function unpack_parameters(::Type{MvNormalMeanScalePrecision}, packed) - len = length(packed) - n = div(-1 + isqrt(1 + 4 * len), 2) - - p₁ = view(packed, 1:n) + p₁ = view(packed, 1:length(packed) - 1) p₂ = packed[end] return (p₁, p₂) @@ -64,13 +61,18 @@ function isproper(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}, return false end (η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η) - return isnothing(conditioner) && length(η₁) === size(η₂, 1) && (size(η₂, 1) === size(η₂, 2)) && isposdef(-η₂) + return isnothing(conditioner) && isone(size(η₂, 1)) && isposdef(-η₂) end function (::MeanToNatural{MvNormalMeanScalePrecision})(tuple_of_θ::Tuple{Any, Any}) (μ, γ) = tuple_of_θ - Σ⁻¹ = 1 / γ - return (Σ⁻¹ * μ, Σ⁻¹ / -2) + return (Σ⁻¹ * μ, γ / -2) +end + +function (::NaturalToMean{MvNormalMeanScalePrecision})(tuple_of_η::Tuple{Any, Any}) + (η₁, η₂) = tuple_of_η + Σ = -1 / η₂ + return (Σ * η₁, Σ) end # Conversions @@ -173,3 +175,22 @@ function BayesBase.prod( wright = convert(MvNormalWeightedMeanPrecision, right) return prod(BayesBase.default_prod_rule(wleft, wright), wleft, wright) end + +getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = + (η) -> begin + (η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η) + invη2 = -cholinv(-η₂) + n = size(η₁, 1) + ident = Eye(n) + Iₙ = PermutationMatrix(1, 1) + offdiag = + 1 / 4 * (invη2 * kron(ident, transpose(invη2 * η₁)) + invη2 * kron(η₁' * invη2, ident)) * + kron(ident, kron(Iₙ, ident)) + G = + -1 / 4 * + ( + kron(invη2, invη2) * kron(ident, η₁) * kron(ident, transpose(invη2 * η₁)) + + kron(invη2, invη2) * kron(η₁, ident) * kron(η₁' * invη2, ident) + ) * kron(ident, kron(Iₙ, ident)) + 1 / 2 * kron(invη2, invη2) + [-1/2*invη2 offdiag; offdiag' G] + end \ No newline at end of file From a0ca84877101037885c26d214ac481a817868a60 Mon Sep 17 00:00:00 2001 From: Albert Date: Thu, 15 Aug 2024 14:45:28 +0200 Subject: [PATCH 08/32] Add fisher information --- .../mv_normal_mean_scale_precision.jl | 20 ++++++++++--------- .../normal_family/normal_family.jl | 4 ---- .../mv_normal_mean_scale_precision_tests.jl | 19 +++++++++--------- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index 813aae48..bc3208f0 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -49,7 +49,7 @@ function MvNormalMeanScalePrecision(μ::AbstractVector{T1}, γ::T2) where {T1, T end function unpack_parameters(::Type{MvNormalMeanScalePrecision}, packed) - p₁ = view(packed, 1:length(packed) - 1) + p₁ = view(packed, 1:length(packed)-1) p₂ = packed[end] return (p₁, p₂) @@ -66,13 +66,13 @@ end function (::MeanToNatural{MvNormalMeanScalePrecision})(tuple_of_θ::Tuple{Any, Any}) (μ, γ) = tuple_of_θ - return (Σ⁻¹ * μ, γ / -2) + return (γ * μ, γ / -2) end function (::NaturalToMean{MvNormalMeanScalePrecision})(tuple_of_η::Tuple{Any, Any}) (η₁, η₂) = tuple_of_η - Σ = -1 / η₂ - return (Σ * η₁, Σ) + γ = -2 * η₂ + return (η₁ / γ, γ) end # Conversions @@ -182,6 +182,7 @@ getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision invη2 = -cholinv(-η₂) n = size(η₁, 1) ident = Eye(n) + kronprod = invη2^2 * Eye(n^2) Iₙ = PermutationMatrix(1, 1) offdiag = 1 / 4 * (invη2 * kron(ident, transpose(invη2 * η₁)) + invη2 * kron(η₁' * invη2, ident)) * @@ -189,8 +190,9 @@ getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision G = -1 / 4 * ( - kron(invη2, invη2) * kron(ident, η₁) * kron(ident, transpose(invη2 * η₁)) + - kron(invη2, invη2) * kron(η₁, ident) * kron(η₁' * invη2, ident) - ) * kron(ident, kron(Iₙ, ident)) + 1 / 2 * kron(invη2, invη2) - [-1/2*invη2 offdiag; offdiag' G] - end \ No newline at end of file + kronprod * kron(ident, η₁) * kron(ident, transpose(invη2 * η₁)) + + kronprod * kron(η₁, ident) * kron(η₁' * invη2 * ident, ident) + ) * kron(ident, kron(Iₙ, ident)) + 1 / 2 * kronprod + + [-1/2*invη2*ident offdiag; offdiag' G] + end diff --git a/src/distributions/normal_family/normal_family.jl b/src/distributions/normal_family/normal_family.jl index e68fdfe0..74999551 100644 --- a/src/distributions/normal_family/normal_family.jl +++ b/src/distributions/normal_family/normal_family.jl @@ -263,8 +263,6 @@ function Base.convert(::Type{MvNormal{T}}, mean::AbstractVector, cov::AbstractMa return MvNormal(convert(AbstractVector{T}, mean), Distributions.PDMats.PDMat(convert(AbstractMatrix{T}, cov))) end - - # Conversion to mean - variance parametrisation function Base.convert(::Type{NormalMeanVariance{T}}, dist::UnivariateNormalDistributionsFamily) where {T <: Real} @@ -333,8 +331,6 @@ function Base.convert(::Type{MvNormalMeanPrecision}, dist::MultivariateNormalDis return convert(MvNormalMeanPrecision{T}, dist) end - - # Conversion to weighted mean - precision parametrisation function Base.convert( diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index eb05a644..946f5693 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -29,8 +29,8 @@ end @testitem "MvNormalMeanScalePrecision: Stats methods" begin include("./normal_family_setuptests.jl") - μ = [0.2, 3.0, 4.0] - γ = 2.0 + μ = [0.2, 3.0, 4.0] + γ = 2.0 dist = MvNormalMeanScalePrecision(μ, γ) rdist = MvNormalMeanPrecision(μ, γ * ones(length(μ))) @@ -61,9 +61,9 @@ end include("./normal_family_setuptests.jl") @test convert(MvNormalMeanScalePrecision{Float32}, MvNormalMeanScalePrecision([0.0, 0.0])) == - MvNormalMeanScalePrecision([0.0f0, 0.0f0], 1.0f0) + MvNormalMeanScalePrecision([0.0f0, 0.0f0], 1.0f0) @test convert(MvNormalMeanScalePrecision{Float64}, [0.0, 0.0], 2.0) == - MvNormalMeanScalePrecision([0.0, 0.0], 2.0) + MvNormalMeanScalePrecision([0.0, 0.0], 2.0) @test length(MvNormalMeanScalePrecision([0.0, 0.0])) === 2 @test length(MvNormalMeanScalePrecision([0.0, 0.0, 0.0])) === 3 @@ -76,9 +76,9 @@ end distribution = MvNormalMeanScalePrecision(μ, γ) @test distribution ≈ distribution - @test convert(MvNormalMeanCovariance, distribution) == MvNormalMeanCovariance(μ, inv(γ)*I(length(μ))) - @test convert(MvNormalMeanPrecision, distribution) == MvNormalMeanPrecision(μ, γ*I(length(μ))) - @test convert(MvNormalWeightedMeanPrecision, distribution) == MvNormalWeightedMeanPrecision(γ*μ, γ*I(length(μ))) + @test convert(MvNormalMeanCovariance, distribution) == MvNormalMeanCovariance(μ, inv(γ) * I(length(μ))) + @test convert(MvNormalMeanPrecision, distribution) == MvNormalMeanPrecision(μ, γ * I(length(μ))) + @test convert(MvNormalWeightedMeanPrecision, distribution) == MvNormalWeightedMeanPrecision(γ * μ, γ * I(length(μ))) end @testitem "MvNormalMeanScalePrecision: vague" begin @@ -104,7 +104,6 @@ end @testitem "MvNormalMeanScalePrecision: prod" begin include("./normal_family_setuptests.jl") - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) @test prod(strategy, MvNormalMeanScalePrecision([-1, -1], 2), MvNormalMeanPrecision([1, 1], [2, 4])) ≈ MvNormalWeightedMeanPrecision([0, 2], [4, 6]) @@ -114,7 +113,7 @@ end dist = MvNormalMeanScalePrecision(μ, γ) @test prod(strategy, dist, dist) ≈ - MvNormalMeanScalePrecision([1.0, 2.0, 3.0], 2γ) + MvNormalMeanScalePrecision([1.0, 2.0, 3.0], 2γ) end end @@ -122,7 +121,7 @@ end include("./normal_family_setuptests.jl") @test convert(MvNormalMeanScalePrecision, zeros(2), 1.0) == - MvNormalMeanScalePrecision(zeros(2), 1.0) + MvNormalMeanScalePrecision(zeros(2), 1.0) @test begin m = rand(5) c = rand() From 8a37b2c8360c31e3260673d18e44d9270d8c4f9b Mon Sep 17 00:00:00 2001 From: Albert Date: Wed, 21 Aug 2024 11:59:56 +0200 Subject: [PATCH 09/32] Add fisher tests --- .../mv_normal_mean_scale_precision.jl | 8 +++++++ .../mv_normal_mean_scale_precision_tests.jl | 21 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index bc3208f0..958282a7 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -196,3 +196,11 @@ getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision [-1/2*invη2*ident offdiag; offdiag' G] end + +getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) = (θ) -> begin + μ, γ = unpack_parameters(MvNormalMeanScalePrecision, θ) + n = size(μ, 1) + offdiag = zeros(n, n^2) + G = (1 / 2) * γ^2 * Eye(n^2) + [γ*Eye(n) offdiag; offdiag' G] +end diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index 946f5693..e539f9ca 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -128,3 +128,24 @@ end convert(MvNormalMeanScalePrecision, m, c) == MvNormalMeanScalePrecision(m, c) end end + +@testitem "Fisher information tests" begin + include("./normal_family_setuptests.jl") + + for s in 2:5 + μ = randn(s) + γ = rand() + θ = pack_parameters(MvNormalMeanScalePrecision, (μ, γ)) + η = MeanToNatural(MvNormalMeanScalePrecision)(θ) + θ1 = pack_parameters(MvNormalMeanCovariance, (μ, inv(γ) * I(s))) + η1 = MeanToNatural(MvNormalMeanCovariance)(θ1) + @test begin + getfisherinformation(NaturalParametersSpace(), MvNormalMeanScalePrecision)(η) ≈ + getfisherinformation(NaturalParametersSpace(), MvNormalMeanCovariance)(η1) + end + + @test begin + getfisherinformation(MeanParametersSpace(), MvNormalMeanScalePrecision)(θ) ≈ getfisherinformation(MeanParametersSpace(), MvNormalMeanCovariance)(θ1) + end + end +end From 1260dd34ea3ae30419170e16321582d6548ebdba Mon Sep 17 00:00:00 2001 From: Albert Date: Wed, 21 Aug 2024 13:44:41 +0200 Subject: [PATCH 10/32] Add rand --- .../mv_normal_mean_scale_precision.jl | 29 ++++++++++++++++++- .../mv_normal_mean_scale_precision_tests.jl | 20 +++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index 958282a7..92b68395 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -149,7 +149,7 @@ Base.size(dist::MvNormalMeanScalePrecision) = (length(dist),) Base.convert(::Type{<:MvNormalMeanScalePrecision}, μ::AbstractVector, γ::Real) = MvNormalMeanScalePrecision(μ, γ) -function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::T) where {T <: Real} +function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::Real) where {T <: Real} MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ)) end @@ -176,6 +176,33 @@ function BayesBase.prod( return prod(BayesBase.default_prod_rule(wleft, wright), wleft, wright) end +function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}) where {T} + μ, γ = mean(dist), scale(dist) + return μ + 1 / γ * I(length(μ)) * randn(rng, T, length(μ)) +end + +function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}, size::Int64) where {T} + container = Matrix{T}(undef, length(dist), size) + return rand!(rng, dist, container) +end + +# FIXME: This is not the most efficient way to generate random samples within container +# it needs to work with scale method, not with std +function BayesBase.rand!( + rng::AbstractRNG, + dist::MvGaussianMeanScalePrecision, + container::AbstractArray{T} +) where {T <: Real} + preallocated = similar(container) + randn!(rng, reshape(preallocated, length(preallocated))) + μ, L = mean_std(dist) + @views for i in axes(preallocated, 2) + copyto!(container[:, i], μ) + mul!(container[:, i], L, preallocated[:, i], 1, 1) + end + container +end + getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = (η) -> begin (η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index e539f9ca..59dcd73f 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -149,3 +149,23 @@ end end end end + +@testitem "MvNormalMeanScalePrecision: rand" begin + include("./normal_family_setuptests.jl") + + rng = MersenneTwister(42) + + for T in (Float32, Float64) + @testset "Basic functionality" begin + μ = [1.0, 2.0, 3.0] + γ = 2.0 + dist = convert(MvNormalMeanScalePrecision{T}, μ, γ) + + @test typeof(rand(dist)) <: Vector{T} + + samples = rand(rng, dist, 5_000) + + @test isapprox(mean(samples), mean(μ), atol = 0.5) + end + end +end From d8b2370ddc0e3e4abbb3a73fd8b0a3608f9e8e25 Mon Sep 17 00:00:00 2001 From: Albert Date: Wed, 21 Aug 2024 14:18:57 +0200 Subject: [PATCH 11/32] Add MvNormalMeanScalePrecision to library.md --- docs/src/library.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/library.md b/docs/src/library.md index f1786834..a271e018 100644 --- a/docs/src/library.md +++ b/docs/src/library.md @@ -14,6 +14,7 @@ ExponentialFamily.NormalWeightedMeanPrecision ExponentialFamily.MvNormalMeanPrecision ExponentialFamily.MvNormalMeanCovariance ExponentialFamily.MvNormalWeightedMeanPrecision +ExponentialFamily.MvNormalMeanScalePrecision ExponentialFamily.JointNormal ExponentialFamily.JointGaussian ExponentialFamily.WishartFast From 44e2ce65fca27a6255ee3e82d9a89b943b1b3ffd Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 20 Sep 2024 13:09:29 +0200 Subject: [PATCH 12/32] test: add test exponentialfamily interface for MvNormalMeanScalePrecision --- .../mv_normal_mean_scale_precision_tests.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index 59dcd73f..31030906 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -26,6 +26,19 @@ end @test ExponentialFamily.distrname(MvNormalMeanScalePrecision(zeros(2))) === "MvNormalMeanScalePrecision" end +@testitem "MvNormalMeanScalePrecision: ExponentialFamilyDistribution" begin + include("../distributions_setuptests.jl") + + for s in 2:5 + μ = randn(s) + γ = rand() + + @testset let d = MvNormalMeanScalePrecision(μ, γ) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + end + end +end + @testitem "MvNormalMeanScalePrecision: Stats methods" begin include("./normal_family_setuptests.jl") From 49670f8bc3671984446602bf7b5bf42472633370 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 20 Sep 2024 14:07:04 +0200 Subject: [PATCH 13/32] feat: add basic functions for MvNormalMeanScalePrecision --- .../mv_normal_mean_scale_precision.jl | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index 92b68395..ffcaf474 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -75,6 +75,12 @@ function (::NaturalToMean{MvNormalMeanScalePrecision})(tuple_of_η::Tuple{Any, A return (η₁ / γ, γ) end +function nabs2(x) + return sum(map(abs2, x)) +end + +getsufficientstatistics(::Type{MvNormalMeanScalePrecision}) = (identity, nabs2) + # Conversions function Base.convert( ::Type{MvNormal{T, C, M}}, @@ -203,6 +209,29 @@ function BayesBase.rand!( container end +function getsupport(ef::ExponentialFamilyDistribution{MvNormalMeanScalePrecision}) + dim = length(getnaturalparameters(ef)) - 1 + return Domain(IndicatorFunction{AbstractVector}(MvNormalDomainIndicator(dim))) +end + +getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = + (η) -> begin + η1 = @view η[1:end-1] + η2 = η[end] + k = length(η1) + Cinv = -inv(η2) + l = log(-inv(η2)) + return (dot(η1, Cinv, η1) / 2 - (k * log(2) + l)) / 2 + end + +getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = + (η) -> begin + η1 = @view η[1:end-1] + η2 = η[end] + Cinv = log(-inv(η2)) + return pack_parameters(MvNormalMeanCovariance, (0.5 * Cinv * η1, 0.25 * Cinv^2 * dot(η1,η1) + 0.5 * Cinv)) + end + getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = (η) -> begin (η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η) From 1877a70fdbb489c069f40e734ebe5915db1f40bf Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 20 Sep 2024 15:37:23 +0200 Subject: [PATCH 14/32] feat: draft MvNormalMeanScalePrecision --- .../mv_normal_mean_scale_precision.jl | 2 + .../mv_normal_mean_scale_precision_tests.jl | 23 +- test/repopack-output.txt | 5221 +++++++++++++++++ 3 files changed, 5224 insertions(+), 22 deletions(-) create mode 100644 test/repopack-output.txt diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index ffcaf474..9939cc95 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -214,6 +214,8 @@ function getsupport(ef::ExponentialFamilyDistribution{MvNormalMeanScalePrecision return Domain(IndicatorFunction{AbstractVector}(MvNormalDomainIndicator(dim))) end +getbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> (2π)^(- length(x) / 2) + getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = (η) -> begin η1 = @view η[1:end-1] diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index 31030906..63286da9 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -34,7 +34,7 @@ end γ = rand() @testset let d = MvNormalMeanScalePrecision(μ, γ) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + ef = test_exponentialfamily_interface(d;) end end end @@ -142,27 +142,6 @@ end end end -@testitem "Fisher information tests" begin - include("./normal_family_setuptests.jl") - - for s in 2:5 - μ = randn(s) - γ = rand() - θ = pack_parameters(MvNormalMeanScalePrecision, (μ, γ)) - η = MeanToNatural(MvNormalMeanScalePrecision)(θ) - θ1 = pack_parameters(MvNormalMeanCovariance, (μ, inv(γ) * I(s))) - η1 = MeanToNatural(MvNormalMeanCovariance)(θ1) - @test begin - getfisherinformation(NaturalParametersSpace(), MvNormalMeanScalePrecision)(η) ≈ - getfisherinformation(NaturalParametersSpace(), MvNormalMeanCovariance)(η1) - end - - @test begin - getfisherinformation(MeanParametersSpace(), MvNormalMeanScalePrecision)(θ) ≈ getfisherinformation(MeanParametersSpace(), MvNormalMeanCovariance)(θ1) - end - end -end - @testitem "MvNormalMeanScalePrecision: rand" begin include("./normal_family_setuptests.jl") diff --git a/test/repopack-output.txt b/test/repopack-output.txt new file mode 100644 index 00000000..156907a3 --- /dev/null +++ b/test/repopack-output.txt @@ -0,0 +1,5221 @@ +This file is a merged representation of the entire codebase, combining all repository files into a single document. +Generated by Repopack on: 2024-09-19T14:22:03.493Z + +================================================================ +File Summary +================================================================ + +Purpose: +-------- +This file contains a packed representation of the entire repository's contents. +It is designed to be easily consumable by AI systems for analysis, code review, +or other automated processes. + +File Format: +------------ +The content is organized as follows: +1. This summary section +2. Repository information +3. Repository structure +4. Multiple file entries, each consisting of: + a. A separator line (================) + b. The file path (File: path/to/file) + c. Another separator line + d. The full contents of the file + e. A blank line + +Usage Guidelines: +----------------- +- This file should be treated as read-only. Any changes should be made to the + original repository files, not this packed version. +- When processing this file, use the file path to distinguish + between different files in the repository. +- Be aware that this file may contain sensitive information. Handle it with + the same level of security as you would the original repository. + +Notes: +------ +- Some files may have been excluded based on .gitignore rules and Repopack's + configuration. +- Binary files are not included in this packed representation. Please refer to + the Repository Structure section for a complete list of file paths, including + binary files. + +Additional Info: +---------------- + +For more information about Repopack, visit: https://github.com/yamadashy/repopack + +================================================================ +Repository Structure +================================================================ +distributions/ + gamma_family/ + gamma_family_setuptests.jl + gamma_family_tests.jl + gamma_shape_rate_tests.jl + gamma_shape_scale_tests.jl + normal_family/ + mv_normal_mean_covariance_tests.jl + mv_normal_mean_precision_tests.jl + mv_normal_weighted_mean_precision_tests.jl + normal_family_setuptests.jl + normal_family_tests.jl + normal_mean_precision_tests.jl + normal_mean_variance_tests.jl + normal_weighted_mean_precision_tests.jl + wip/ + test_continuous_bernoulli.jl + test_multinomial.jl + bernoulli_tests.jl + beta_tests.jl + binomial_tests.jl + categorical_tests.jl + chi_squared_tests.jl + dirichlet_tests.jl + distributions_setuptests.jl + erlang_tests.jl + exponential_tests.jl + gamma_inverse_tests.jl + geometric_tests.jl + laplace_tests.jl + lognormal_tests.jl + matrix_dirichlet_tests.jl + mv_normal_wishart_tests.jl + negative_binomial_tests.jl + normal_gamma_tests.jl + pareto_tests.jl + poisson_tests.jl + rayleigh_tests.jl + von_mises_fisher_tests.jl + vonmises_tests.jl + weibull_tests.jl + wishart_inverse_tests.jl + wishart_tests.jl +common_tests.jl +exponential_family_setuptests.jl +exponential_family_tests.jl +runtests.jl + +================================================================ +Repository Files +================================================================ + +================ +File: distributions/gamma_family/gamma_family_setuptests.jl +================ +include("../distributions_setuptests.jl") + +import ExponentialFamily: xtlog + +function compare_basic_statistics(left, right) + @test mean(left) ≈ mean(right) + @test var(left) ≈ var(right) + @test cov(left) ≈ cov(right) + @test shape(left) ≈ shape(right) + @test scale(left) ≈ scale(right) + @test rate(left) ≈ rate(right) + @test entropy(left) ≈ entropy(right) + @test pdf(left, 1.0) ≈ pdf(right, 1.0) + @test pdf(left, 10.0) ≈ pdf(right, 10.0) + @test logpdf(left, 1.0) ≈ logpdf(right, 1.0) + @test logpdf(left, 10.0) ≈ logpdf(right, 10.0) + + @test mean(log, left) ≈ mean(log, right) + @test mean(loggamma, left) ≈ mean(loggamma, right) + @test mean(xtlog, left) ≈ mean(xtlog, right) + + return true +end + +================ +File: distributions/gamma_family/gamma_family_tests.jl +================ +@testitem "GammaFamily: Base statistical methods" begin + include("./gamma_family_setuptests.jl") + + types = union_types(GammaDistributionsFamily{Float64}) + rng = MersenneTwister(1234) + for _ in 1:10 + for type in types + left = convert(type, 100 * rand(rng, Float64), 100 * rand(rng, Float64)) + for type in types + right = convert(type, left) + @test compare_basic_statistics(left, right) + + @test all(params(MeanParametersSpace(), left) .== (shape(left), scale(left))) + @test all(params(MeanParametersSpace(), right) .== (shape(right), scale(right))) + end + end + end +end + +@testitem "GammaFamily: ExponentialFamilyDistribution" begin + include("./gamma_family_setuptests.jl") + + for k in (0.1, 2.0, 5.0), θ in (0.1, 2.0, 5.0), T in union_types(GammaDistributionsFamily{Float64}) + @testset let d = convert(T, GammaShapeScale(k, θ)) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + + (η₁, η₂) = (shape(d) - 1, -inv(scale(d))) + + for x in 0.1:0.5:5.0 + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === oneunit(x) + @test all(@inferred(sufficientstatistics(ef, x)) .=== (log(x), x)) + @test @inferred(logpartition(ef)) ≈ loggamma(η₁ + 1) - (η₁ + 1) * log(-η₂) + end + + @test @inferred(insupport(ef, 0.5)) + @test !@inferred(insupport(ef, -0.5)) + + # # Not in the support + @test_throws Exception logpdf(ef, -0.5) + end + end + + # Test failing isproper cases + @test !isproper(MeanParametersSpace(), Gamma, [-1]) + @test !isproper(MeanParametersSpace(), Gamma, [1, -1]) + @test !isproper(MeanParametersSpace(), Gamma, [-1, -1]) + @test !isproper(NaturalParametersSpace(), Gamma, [-1]) + @test !isproper(NaturalParametersSpace(), Gamma, [1, 10]) + @test !isproper(NaturalParametersSpace(), Gamma, [-100, -1]) + + # shapes must add up to something more than 1, otherwise is not proper + let ef = convert(ExponentialFamilyDistribution, Gamma(0.1, 1.0)) + @test !isproper(prod(PreserveTypeProd(ExponentialFamilyDistribution), ef, ef)) + end +end + +@testitem "GammaFamily: prod with ExponentialFamilyDistribution" begin + include("./gamma_family_setuptests.jl") + + for kleft in 0.51:1.0:5.0, kright in 0.51:1.0:5.0, θleft in 0.1:1.0:5.0, θright in 0.1:1.0:5.0, Tleft in union_types(GammaDistributionsFamily{Float64}), + Tright in union_types(GammaDistributionsFamily{Float64}) + + @testset let (left, right) = (convert(Tleft, Gamma(kleft, θleft)), convert(Tright, Gamma(kright, θright))) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{Gamma}) + ) + ) + end + end +end + +================ +File: distributions/gamma_family/gamma_shape_rate_tests.jl +================ +@testitem "GammaShapeRate: Constructor" begin + include("./gamma_family_setuptests.jl") + + @test GammaShapeScale <: GammaDistributionsFamily + + @test GammaShapeRate() == GammaShapeRate{Float64}(1.0, 1.0) + @test GammaShapeRate(1.0) == GammaShapeRate{Float64}(1.0, 1.0) + @test GammaShapeRate(1.0, 2.0) == GammaShapeRate{Float64}(1.0, 2.0) + @test GammaShapeRate(1) == GammaShapeRate{Float64}(1.0, 1.0) + @test GammaShapeRate(1, 2) == GammaShapeRate{Float64}(1.0, 2.0) + @test GammaShapeRate(1.0, 2) == GammaShapeRate{Float64}(1.0, 2.0) + @test GammaShapeRate(1, 2.0) == GammaShapeRate{Float64}(1.0, 2.0) + @test GammaShapeRate(1.0f0) == GammaShapeRate{Float32}(1.0f0, 1.0f0) + @test GammaShapeRate(1.0f0, 2.0f0) == GammaShapeRate{Float32}(1.0f0, 2.0f0) + @test GammaShapeRate(1.0f0, 2) == GammaShapeRate{Float32}(1.0f0, 2.0f0) + @test GammaShapeRate(1.0f0, 2.0) == GammaShapeRate{Float64}(1.0, 2.0) + + @test paramfloattype(GammaShapeRate(1.0, 2.0)) === Float64 + @test paramfloattype(GammaShapeRate(1.0f0, 2.0f0)) === Float32 + + @test convert(GammaShapeRate{Float32}, GammaShapeRate()) == GammaShapeRate{Float32}(1.0f0, 1.0f0) + @test convert(GammaShapeRate{Float64}, GammaShapeRate(1.0, 10.0)) == GammaShapeRate{Float64}(1.0, 10.0) + @test convert(GammaShapeRate{Float64}, GammaShapeRate(1.0, 0.1)) == GammaShapeRate{Float64}(1.0, 0.1) + @test convert(GammaShapeRate{Float64}, 1, 1) == GammaShapeRate{Float64}(1.0, 1.0) + @test convert(GammaShapeRate{Float64}, 1, 10) == GammaShapeRate{Float64}(1.0, 10.0) + @test convert(GammaShapeRate{Float64}, 1.0, 0.1) == GammaShapeRate{Float64}(1.0, 0.1) + + @test convert(GammaShapeRate, GammaShapeRate(2.0, 2.0)) == GammaShapeRate{Float64}(2.0, 2.0) + @test convert(GammaShapeScale, GammaShapeRate(2.0, 2.0)) == GammaShapeScale{Float64}(2.0, 1.0 / 2.0) + + @test convert(GammaShapeRate, GammaShapeScale(2.0, 2.0)) == GammaShapeRate{Float64}(2.0, 1.0 / 2.0) + @test convert(GammaShapeScale, GammaShapeScale(2.0, 2.0)) == GammaShapeScale{Float64}(2.0, 2.0) +end + +@testitem "GammaShapeRate: vague" begin + include("./gamma_family_setuptests.jl") + + @test vague(GammaShapeRate) == GammaShapeRate(1.0, 1e-12) +end + +@testitem "GammaShapeRate: stats methods" begin + include("./gamma_family_setuptests.jl") + + dist1 = GammaShapeRate(1.0, 1.0) + + @test mean(dist1) === 1.0 + @test var(dist1) === 1.0 + @test cov(dist1) === 1.0 + @test shape(dist1) === 1.0 + @test scale(dist1) === 1.0 + @test rate(dist1) === 1.0 + @test entropy(dist1) ≈ 1.0 + @test pdf(dist1, 1.0) ≈ 0.36787944117144233 + @test logpdf(dist1, 1.0) ≈ -1.0 + + dist2 = GammaShapeRate(1.0, 2.0) + + @test mean(dist2) === inv(2.0) + @test var(dist2) === inv(4.0) + @test cov(dist2) === inv(4.0) + @test shape(dist2) === 1.0 + @test scale(dist2) === inv(2.0) + @test rate(dist2) === 2.0 + @test entropy(dist2) ≈ 0.3068528194400547 + @test pdf(dist2, 1.0) ≈ 0.2706705664732254 + @test logpdf(dist2, 1.0) ≈ -1.3068528194400546 + + dist3 = GammaShapeRate(2.0, 2.0) + + @test mean(dist3) === 1.0 + @test var(dist3) === inv(2.0) + @test cov(dist3) === inv(2.0) + @test shape(dist3) === 2.0 + @test scale(dist3) === inv(2.0) + @test rate(dist3) === 2.0 + @test entropy(dist3) ≈ 0.8840684843415857 + @test pdf(dist3, 1.0) ≈ 0.5413411329464508 + @test logpdf(dist3, 1.0) ≈ -0.6137056388801094 + + # see https://github.com/ReactiveBayes/ReactiveMP.jl/issues/314 + dist = GammaShapeRate(257.37489915581654, 3.0) + @test pdf(dist, 86.2027941354432) == 0.07400338986721687 +end + +@testitem "GammaShapeRate: prod" begin + include("./gamma_family_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) + @test prod(strategy, GammaShapeScale(1, 1), GammaShapeScale(1, 1)) == GammaShapeScale(1, 1 / 2) + @test prod(strategy, GammaShapeScale(1, 2), GammaShapeScale(1, 1)) == GammaShapeScale(1, 2 / 3) + @test prod(strategy, GammaShapeScale(1, 2), GammaShapeScale(1, 2)) == GammaShapeScale(1, 1) + @test prod(strategy, GammaShapeScale(2, 2), GammaShapeScale(1, 2)) == GammaShapeScale(2, 1) + @test prod(strategy, GammaShapeScale(2, 2), GammaShapeScale(2, 2)) == GammaShapeScale(3, 1) + @test prod(strategy, GammaShapeScale(1, 1), GammaShapeRate(1, 1)) == GammaShapeScale(1, 1 / 2) + @test prod(strategy, GammaShapeScale(1, 2), GammaShapeRate(1, 1)) == GammaShapeScale(1, 2 / 3) + @test prod(strategy, GammaShapeScale(1, 2), GammaShapeRate(1, 2)) == GammaShapeScale(1, 2 / 5) + @test prod(strategy, GammaShapeScale(2, 2), GammaShapeRate(1, 2)) == GammaShapeScale(2, 2 / 5) + @test prod(strategy, GammaShapeScale(2, 2), GammaShapeRate(2, 2)) == GammaShapeScale(3, 2 / 5) + end +end + +================ +File: distributions/gamma_family/gamma_shape_scale_tests.jl +================ +@testitem "GammaShapeScale: Constructor" begin + include("./gamma_family_setuptests.jl") + + @test GammaShapeScale <: GammaDistributionsFamily + + @test GammaShapeScale() == GammaShapeScale{Float64}(1.0, 1.0) + @test GammaShapeScale(1.0) == GammaShapeScale{Float64}(1.0, 1.0) + @test GammaShapeScale(1.0, 2.0) == GammaShapeScale{Float64}(1.0, 2.0) + @test GammaShapeScale(1) == GammaShapeScale{Float64}(1.0, 1.0) + @test GammaShapeScale(1, 2) == GammaShapeScale{Float64}(1.0, 2.0) + @test GammaShapeScale(1.0, 2) == GammaShapeScale{Float64}(1.0, 2.0) + @test GammaShapeScale(1, 2.0) == GammaShapeScale{Float64}(1.0, 2.0) + @test GammaShapeScale(1.0f0) == GammaShapeScale{Float32}(1.0f0, 1.0f0) + @test GammaShapeScale(1.0f0, 2.0f0) == GammaShapeScale{Float32}(1.0f0, 2.0f0) + @test GammaShapeScale(1.0f0, 2) == GammaShapeScale{Float32}(1.0f0, 2.0f0) + @test GammaShapeScale(1.0f0, 2.0) == GammaShapeScale{Float64}(1.0, 2.0) + + @test paramfloattype(GammaShapeScale(1.0, 2.0)) === Float64 + @test paramfloattype(GammaShapeScale(1.0f0, 2.0f0)) === Float32 + + @test convert(GammaShapeScale{Float32}, GammaShapeScale()) == GammaShapeScale{Float32}(1.0f0, 1.0f0) + @test convert(GammaShapeScale{Float64}, GammaShapeScale(1.0, 10.0)) == GammaShapeScale{Float64}(1.0, 10.0) + @test convert(GammaShapeScale{Float64}, GammaShapeScale(1.0, 0.1)) == GammaShapeScale{Float64}(1.0, 0.1) + @test convert(GammaShapeScale{Float64}, 1, 1) == GammaShapeScale{Float64}(1.0, 1.0) + @test convert(GammaShapeScale{Float64}, 1, 10) == GammaShapeScale{Float64}(1.0, 10.0) + @test convert(GammaShapeScale{Float64}, 1.0, 0.1) == GammaShapeScale{Float64}(1.0, 0.1) + + @test convert(GammaShapeRate, GammaShapeScale(2.0, 2.0)) == GammaShapeRate{Float64}(2.0, 1.0 / 2.0) + @test convert(GammaShapeScale, GammaShapeScale(2.0, 2.0)) == GammaShapeScale{Float64}(2.0, 2.0) +end + +@testitem "GammaShapeScale: vague" begin + include("./gamma_family_setuptests.jl") + + @test vague(GammaShapeScale) == GammaShapeScale(1.0, 1e12) +end + +@testitem "GammaShapeScale: stats methods" begin + include("./gamma_family_setuptests.jl") + + dist1 = GammaShapeScale(1.0, 1.0) + + @test mean(dist1) === 1.0 + @test var(dist1) === 1.0 + @test cov(dist1) === 1.0 + @test shape(dist1) === 1.0 + @test scale(dist1) === 1.0 + @test rate(dist1) === 1.0 + @test entropy(dist1) ≈ 1.0 + + dist2 = GammaShapeScale(1.0, 2.0) + + @test mean(dist2) === 2.0 + @test var(dist2) === 4.0 + @test cov(dist2) === 4.0 + @test shape(dist2) === 1.0 + @test scale(dist2) === 2.0 + @test rate(dist2) === inv(2.0) + @test entropy(dist2) ≈ 1.6931471805599454 + + dist3 = GammaShapeScale(2.0, 2.0) + + @test mean(dist3) === 4.0 + @test var(dist3) === 8.0 + @test cov(dist3) === 8.0 + @test shape(dist3) === 2.0 + @test scale(dist3) === 2.0 + @test rate(dist3) === inv(2.0) + @test entropy(dist3) ≈ 2.2703628454614764 +end + +@testitem "GammaShapeScale: prod" begin + include("./gamma_family_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) + @test prod(strategy, GammaShapeRate(1, 1), GammaShapeRate(1, 1)) == GammaShapeRate(1, 2) + @test prod(strategy, GammaShapeRate(1, 2), GammaShapeRate(1, 1)) == GammaShapeRate(1, 3) + @test prod(strategy, GammaShapeRate(1, 2), GammaShapeRate(1, 2)) == GammaShapeRate(1, 4) + @test prod(strategy, GammaShapeRate(2, 2), GammaShapeRate(1, 2)) == GammaShapeRate(2, 4) + @test prod(strategy, GammaShapeRate(2, 2), GammaShapeRate(2, 2)) == GammaShapeRate(3, 4) + @test prod(strategy, GammaShapeRate(1, 1), GammaShapeScale(1, 1)) == GammaShapeRate(1, 2) + @test prod(strategy, GammaShapeRate(1, 2), GammaShapeScale(1, 1)) == GammaShapeRate(1, 3) + @test prod(strategy, GammaShapeRate(1, 2), GammaShapeScale(1, 2)) == GammaShapeRate(1, 5 / 2) + @test prod(strategy, GammaShapeRate(2, 2), GammaShapeScale(1, 2)) == GammaShapeRate(2, 5 / 2) + @test prod(strategy, GammaShapeRate(2, 2), GammaShapeScale(2, 2)) == GammaShapeRate(3, 5 / 2) + end +end + +================ +File: distributions/normal_family/mv_normal_mean_covariance_tests.jl +================ +@testitem "MvNormalMeanCovariance: Constructor" begin + include("./normal_family_setuptests.jl") + + @test MvNormalMeanCovariance <: AbstractMvNormal + + @test MvNormalMeanCovariance([1.0, 1.0]) == MvNormalMeanCovariance([1.0, 1.0], [1.0, 1.0]) + @test MvNormalMeanCovariance([1.0, 2.0]) == MvNormalMeanCovariance([1.0, 2.0], [1.0, 1.0]) + @test MvNormalMeanCovariance([1, 2]) == MvNormalMeanCovariance([1.0, 2.0], [1.0, 1.0]) + @test MvNormalMeanCovariance([1.0f0, 2.0f0]) == MvNormalMeanCovariance([1.0f0, 2.0f0], [1.0f0, 1.0f0]) + + @test eltype(MvNormalMeanCovariance([1.0, 1.0])) === Float64 + @test eltype(MvNormalMeanCovariance([1.0, 1.0], [1.0, 1.0])) === Float64 + @test eltype(MvNormalMeanCovariance([1, 1])) === Float64 + @test eltype(MvNormalMeanCovariance([1, 1], [1, 1])) === Float64 + @test eltype(MvNormalMeanCovariance([1.0f0, 1.0f0])) === Float32 + @test eltype(MvNormalMeanCovariance([1.0f0, 1.0f0], [1.0f0, 1.0f0])) === Float32 + + @test MvNormalMeanCovariance(ones(3), 5I) == MvNormalMeanCovariance(ones(3), Diagonal(5 * ones(3))) + @test MvNormalMeanCovariance([1, 2, 3, 4], 7.0I) == MvNormalMeanCovariance([1.0, 2.0, 3.0, 4.0], Diagonal(7.0 * ones(4))) +end + +@testitem "MvNormalMeanCovariance: distrname" begin + include("./normal_family_setuptests.jl") + + @test ExponentialFamily.distrname(MvNormalMeanCovariance(zeros(2))) === "MvNormalMeanCovariance" +end + +@testitem "MvNormalMeanCovariance: Stats methods" begin + include("./normal_family_setuptests.jl") + + μ = [0.2, 3.0, 4.0] + Σ = [1.5 -0.3 0.1; -0.3 1.8 0.0; 0.1 0.0 3.5] + dist = MvNormalMeanCovariance(μ, Σ) + + @test mean(dist) == μ + @test mode(dist) == μ + @test weightedmean(dist) ≈ inv(Σ) * μ + @test invcov(dist) ≈ inv(Σ) + @test precision(dist) ≈ inv(Σ) + @test cov(dist) == Σ + @test std(dist) * std(dist)' ≈ Σ + @test all(mean_cov(dist) .≈ (μ, Σ)) + @test all(mean_invcov(dist) .≈ (μ, inv(Σ))) + @test all(mean_precision(dist) .≈ (μ, inv(Σ))) + @test all(weightedmean_cov(dist) .≈ (inv(Σ) * μ, Σ)) + @test all(weightedmean_invcov(dist) .≈ (inv(Σ) * μ, inv(Σ))) + @test all(weightedmean_precision(dist) .≈ (inv(Σ) * μ, inv(Σ))) + + @test length(dist) == 3 + @test entropy(dist) ≈ 5.361886000915401 + @test pdf(dist, [0.2, 3.0, 4.0]) ≈ 0.021028302702542 + @test pdf(dist, [0.202, 3.002, 4.002]) ≈ 0.021028229679079503 + @test logpdf(dist, [0.2, 3.0, 4.0]) ≈ -3.8618860009154012 + @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ -3.861889473548943 +end + +@testitem "MvNormalMeanCovariance: Base methods" begin + include("./normal_family_setuptests.jl") + + @test convert(MvNormalMeanCovariance{Float32}, MvNormalMeanCovariance([0.0, 0.0])) == + MvNormalMeanCovariance([0.0f0, 0.0f0], [1.0f0, 1.0f0]) + @test convert(MvNormalMeanCovariance{Float64}, [0.0, 0.0], [2 0; 0 3]) == + MvNormalMeanCovariance([0.0, 0.0], [2.0 0.0; 0.0 3.0]) + + @test length(MvNormalMeanCovariance([0.0, 0.0])) === 2 + @test length(MvNormalMeanCovariance([0.0, 0.0, 0.0])) === 3 + @test ndims(MvNormalMeanCovariance([0.0, 0.0])) === 2 + @test ndims(MvNormalMeanCovariance([0.0, 0.0, 0.0])) === 3 + @test size(MvNormalMeanCovariance([0.0, 0.0])) === (2,) + @test size(MvNormalMeanCovariance([0.0, 0.0, 0.0])) === (3,) + + distribution = MvNormalMeanCovariance([0.0, 0.0], [2.0 0.0; 0.0 3.0]) + + @test distribution ≈ distribution + @test distribution ≈ convert(MvNormalMeanPrecision, distribution) + @test distribution ≈ convert(MvNormalWeightedMeanPrecision, distribution) +end + +@testitem "MvNormalMeanCovariance: vague" begin + include("./normal_family_setuptests.jl") + + @test_throws MethodError vague(MvNormalMeanCovariance) + + d1 = vague(MvNormalMeanCovariance, 2) + + @test typeof(d1) <: MvNormalMeanCovariance + @test mean(d1) == zeros(2) + @test cov(d1) == Matrix(Diagonal(1e12 * ones(2))) + @test ndims(d1) == 2 + + d2 = vague(MvNormalMeanCovariance, 3) + + @test typeof(d2) <: MvNormalMeanCovariance + @test mean(d2) == zeros(3) + @test cov(d2) == Matrix(Diagonal(1e12 * ones(3))) + @test ndims(d2) == 3 +end + +@testitem "MvNormalMeanCovariance: prod" begin + include("./normal_family_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) + @test prod(strategy, MvNormalMeanCovariance([-1, -1], [2, 2]), MvNormalMeanCovariance([1, 1], [2, 4])) ≈ + MvNormalWeightedMeanPrecision([0, -1 / 4], [1, 3 / 4]) + + μ = [1.0, 2.0, 3.0] + Σ = diagm([1.0, 2.0, 3.0]) + dist = MvNormalMeanCovariance(μ, Σ) + + @test prod(strategy, dist, dist) ≈ + MvNormalWeightedMeanPrecision([2.0, 2.0, 2.0], diagm([2.0, 1.0, 2 / 3])) + end +end + +@testitem "MvNormalMeanCovariance: convert" begin + include("./normal_family_setuptests.jl") + + @test convert(MvNormalMeanCovariance, zeros(2), Matrix(Diagonal(ones(2)))) == + MvNormalMeanCovariance(zeros(2), Matrix(Diagonal(ones(2)))) + @test begin + m = rand(5) + c = Matrix(Symmetric(rand(5, 5))) + convert(MvNormalMeanCovariance, m, c) == MvNormalMeanCovariance(m, c) + end +end + +================ +File: distributions/normal_family/mv_normal_mean_precision_tests.jl +================ +@testitem "MvNormalMeanPrecision: Constructor" begin + include("./normal_family_setuptests.jl") + + @test MvNormalMeanPrecision <: AbstractMvNormal + + @test MvNormalMeanPrecision([1.0, 1.0]) == MvNormalMeanPrecision([1.0, 1.0], [1.0, 1.0]) + @test MvNormalMeanPrecision([1.0, 2.0]) == MvNormalMeanPrecision([1.0, 2.0], [1.0, 1.0]) + @test MvNormalMeanPrecision([1, 2]) == MvNormalMeanPrecision([1.0, 2.0], [1.0, 1.0]) + @test MvNormalMeanPrecision([1.0f0, 2.0f0]) == MvNormalMeanPrecision([1.0f0, 2.0f0], [1.0f0, 1.0f0]) + + @test eltype(MvNormalMeanPrecision([1.0, 1.0])) === Float64 + @test eltype(MvNormalMeanPrecision([1.0, 1.0], [1.0, 1.0])) === Float64 + @test eltype(MvNormalMeanPrecision([1, 1])) === Float64 + @test eltype(MvNormalMeanPrecision([1, 1], [1, 1])) === Float64 + @test eltype(MvNormalMeanPrecision([1.0f0, 1.0f0])) === Float32 + @test eltype(MvNormalMeanPrecision([1.0f0, 1.0f0], [1.0f0, 1.0f0])) === Float32 + + @test MvNormalMeanPrecision(ones(3), 5I) == MvNormalMeanPrecision(ones(3), Diagonal(5 * ones(3))) + @test MvNormalMeanPrecision([1, 2, 3, 4], 7.0I) == MvNormalMeanPrecision([1.0, 2.0, 3.0, 4.0], Diagonal(7.0 * ones(4))) +end + +@testitem "MvNormalMeanPrecision: distrname" begin + include("./normal_family_setuptests.jl") + + @test ExponentialFamily.distrname(MvNormalMeanPrecision(zeros(2))) === "MvNormalMeanPrecision" +end + +@testitem "MvNormalMeanPrecision: Stats methods" begin + include("./normal_family_setuptests.jl") + + μ = [0.2, 3.0, 4.0] + Λ = [1.5 -0.3 0.1; -0.3 1.8 0.0; 0.1 0.0 3.5] + dist = MvNormalMeanPrecision(μ, Λ) + + @test mean(dist) == μ + @test mode(dist) == μ + @test weightedmean(dist) == Λ * μ + @test invcov(dist) == Λ + @test precision(dist) == Λ + @test cov(dist) ≈ inv(Λ) + @test std(dist) * std(dist)' ≈ inv(Λ) + @test all(mean_cov(dist) .≈ (μ, inv(Λ))) + @test all(mean_invcov(dist) .≈ (μ, Λ)) + @test all(mean_precision(dist) .≈ (μ, Λ)) + @test all(weightedmean_cov(dist) .≈ (Λ * μ, inv(Λ))) + @test all(weightedmean_invcov(dist) .≈ (Λ * μ, Λ)) + @test all(weightedmean_precision(dist) .≈ (Λ * μ, Λ)) + + @test length(dist) == 3 + @test entropy(dist) ≈ 3.1517451983126357 + @test pdf(dist, [0.2, 3.0, 4.0]) ≈ 0.19171503573907536 + @test pdf(dist, [0.202, 3.002, 4.002]) ≈ 0.19171258180232315 + @test logpdf(dist, [0.2, 3.0, 4.0]) ≈ -1.6517451983126357 + @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ -1.6517579983126356 +end + +@testitem "MvNormalMeanPrecision: Base methods" begin + include("./normal_family_setuptests.jl") + + @test convert(MvNormalMeanPrecision{Float32}, MvNormalMeanPrecision([0.0, 0.0])) == + MvNormalMeanPrecision([0.0f0, 0.0f0], [1.0f0, 1.0f0]) + @test convert(MvNormalMeanPrecision{Float64}, [0.0, 0.0], [2 0; 0 3]) == + MvNormalMeanPrecision([0.0, 0.0], [2.0 0.0; 0.0 3.0]) + + @test length(MvNormalMeanPrecision([0.0, 0.0])) === 2 + @test length(MvNormalMeanPrecision([0.0, 0.0, 0.0])) === 3 + @test ndims(MvNormalMeanPrecision([0.0, 0.0])) === 2 + @test ndims(MvNormalMeanPrecision([0.0, 0.0, 0.0])) === 3 + @test size(MvNormalMeanPrecision([0.0, 0.0])) === (2,) + @test size(MvNormalMeanPrecision([0.0, 0.0, 0.0])) === (3,) + + distribution = MvNormalMeanPrecision([0.0, 0.0], [2.0 0.0; 0.0 3.0]) + + @test distribution ≈ distribution + @test distribution ≈ convert(MvNormalMeanCovariance, distribution) + @test distribution ≈ convert(MvNormalWeightedMeanPrecision, distribution) +end + +@testitem "MvNormalMeanPrecision: vague" begin + include("./normal_family_setuptests.jl") + + @test_throws MethodError vague(MvNormalMeanPrecision) + + d1 = vague(MvNormalMeanPrecision, 2) + + @test typeof(d1) <: MvNormalMeanPrecision + @test mean(d1) == zeros(2) + @test invcov(d1) == Matrix(Diagonal(1e-12 * ones(2))) + @test ndims(d1) == 2 + + d2 = vague(MvNormalMeanPrecision, 3) + + @test typeof(d2) <: MvNormalMeanPrecision + @test mean(d2) == zeros(3) + @test invcov(d2) == Matrix(Diagonal(1e-12 * ones(3))) + @test ndims(d2) == 3 +end + +@testitem "MvNormalMeanPrecision: prod" begin + include("./normal_family_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) + @test prod(strategy, MvNormalMeanPrecision([-1, -1], [2, 2]), MvNormalMeanPrecision([1, 1], [2, 4])) ≈ + MvNormalWeightedMeanPrecision([0, 2], [4, 6]) + + μ = [1.0, 2.0, 3.0] + Λ = diagm(1 ./ [1.0, 2.0, 3.0]) + dist = MvNormalMeanPrecision(μ, Λ) + + @test prod(strategy, dist, dist) ≈ + MvNormalWeightedMeanPrecision([2.0, 2.0, 2.0], diagm([2.0, 1.0, 2 / 3])) + end +end + +@testitem "MvNormalMeanPrecision: convert" begin + include("./normal_family_setuptests.jl") + + @test convert(MvNormalMeanPrecision, zeros(2), Matrix(Diagonal(ones(2)))) == + MvNormalMeanPrecision(zeros(2), Matrix(Diagonal(ones(2)))) + @test begin + m = rand(5) + c = Matrix(Symmetric(rand(5, 5))) + convert(MvNormalMeanPrecision, m, c) == MvNormalMeanPrecision(m, c) + end +end + +================ +File: distributions/normal_family/mv_normal_weighted_mean_precision_tests.jl +================ +@testitem "MvNormalWeightedMeanPrecision: Constructor" begin + include("./normal_family_setuptests.jl") + + @test MvNormalWeightedMeanPrecision <: AbstractMvNormal + + @test MvNormalWeightedMeanPrecision([1.0, 1.0]) == MvNormalWeightedMeanPrecision([1.0, 1.0], [1.0, 1.0]) + @test MvNormalWeightedMeanPrecision([1.0, 2.0]) == MvNormalWeightedMeanPrecision([1.0, 2.0], [1.0, 1.0]) + @test MvNormalWeightedMeanPrecision([1, 2]) == MvNormalWeightedMeanPrecision([1.0, 2.0], [1.0, 1.0]) + @test MvNormalWeightedMeanPrecision([1.0f0, 2.0f0]) == + MvNormalWeightedMeanPrecision([1.0f0, 2.0f0], [1.0f0, 1.0f0]) + + @test eltype(MvNormalWeightedMeanPrecision([1.0, 1.0])) === Float64 + @test eltype(MvNormalWeightedMeanPrecision([1.0, 1.0], [1.0, 1.0])) === Float64 + @test eltype(MvNormalWeightedMeanPrecision([1, 1])) === Float64 + @test eltype(MvNormalWeightedMeanPrecision([1, 1], [1, 1])) === Float64 + @test eltype(MvNormalWeightedMeanPrecision([1.0f0, 1.0f0])) === Float32 + @test eltype(MvNormalWeightedMeanPrecision([1.0f0, 1.0f0], [1.0f0, 1.0f0])) === Float32 + + @test MvNormalWeightedMeanPrecision(ones(3), 5I) == MvNormalWeightedMeanPrecision(ones(3), Diagonal(5 * ones(3))) + @test MvNormalWeightedMeanPrecision([1, 2, 3, 4], 7.0I) == MvNormalWeightedMeanPrecision([1.0, 2.0, 3.0, 4.0], Diagonal(7.0 * ones(4))) +end + +@testitem "MvNormalWeightedMeanPrecision: distrname" begin + include("./normal_family_setuptests.jl") + + @test ExponentialFamily.distrname(MvNormalWeightedMeanPrecision(zeros(2))) === "MvNormalWeightedMeanPrecision" +end + +@testitem "MvNormalWeightedMeanPrecision: Stats methods" begin + include("./normal_family_setuptests.jl") + + xi = [-0.2, 5.34, 14.02] + Λ = [1.5 -0.3 0.1; -0.3 1.8 0.0; 0.1 0.0 3.5] + dist = MvNormalWeightedMeanPrecision(xi, Λ) + + @test mean(dist) ≈ inv(Λ) * xi + @test mode(dist) ≈ inv(Λ) * xi + @test weightedmean(dist) == xi + @test invcov(dist) == Λ + @test precision(dist) == Λ + @test cov(dist) ≈ inv(Λ) + @test std(dist) * std(dist)' ≈ inv(Λ) + @test all(mean_cov(dist) .≈ (inv(Λ) * xi, inv(Λ))) + @test all(mean_invcov(dist) .≈ (inv(Λ) * xi, Λ)) + @test all(mean_precision(dist) .≈ (inv(Λ) * xi, Λ)) + @test all(weightedmean_cov(dist) .≈ (xi, inv(Λ))) + @test all(weightedmean_invcov(dist) .≈ (xi, Λ)) + @test all(weightedmean_precision(dist) .≈ (xi, Λ)) + + @test length(dist) == 3 + @test entropy(dist) ≈ 3.1517451983126357 + @test pdf(dist, [0.2, 3.0, 4.0]) ≈ 0.19171503573907536 + @test pdf(dist, [0.202, 3.002, 4.002]) ≈ 0.19171258180232315 + @test logpdf(dist, [0.2, 3.0, 4.0]) ≈ -1.6517451983126357 + @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ -1.6517579983126356 +end + +@testitem "MvNormalWeightedMeanPrecision: Base methods" begin + include("./normal_family_setuptests.jl") + + @test convert(MvNormalWeightedMeanPrecision{Float32}, MvNormalWeightedMeanPrecision([0.0, 0.0])) == + MvNormalWeightedMeanPrecision([0.0f0, 0.0f0], [1.0f0, 1.0f0]) + @test convert(MvNormalWeightedMeanPrecision{Float64}, [0.0, 0.0], [2 0; 0 3]) == + MvNormalWeightedMeanPrecision([0.0, 0.0], [2.0 0.0; 0.0 3.0]) + + @test length(MvNormalWeightedMeanPrecision([0.0, 0.0])) === 2 + @test length(MvNormalWeightedMeanPrecision([0.0, 0.0, 0.0])) === 3 + @test ndims(MvNormalWeightedMeanPrecision([0.0, 0.0])) === 2 + @test ndims(MvNormalWeightedMeanPrecision([0.0, 0.0, 0.0])) === 3 + @test size(MvNormalWeightedMeanPrecision([0.0, 0.0])) === (2,) + @test size(MvNormalWeightedMeanPrecision([0.0, 0.0, 0.0])) === (3,) + + distribution = MvNormalWeightedMeanPrecision([0.0, 0.0], [2.0 0.0; 0.0 3.0]) + + @test distribution ≈ distribution + @test distribution ≈ convert(MvNormalMeanCovariance, distribution) + @test distribution ≈ convert(MvNormalMeanPrecision, distribution) +end + +@testitem "MvNormalWeightedMeanPrecision: vague" begin + include("./normal_family_setuptests.jl") + + @test_throws MethodError vague(MvNormalWeightedMeanPrecision) + + d1 = vague(MvNormalWeightedMeanPrecision, 2) + + @test typeof(d1) <: MvNormalWeightedMeanPrecision + @test mean(d1) == zeros(2) + @test invcov(d1) == Matrix(Diagonal(1e-12 * ones(2))) + @test ndims(d1) == 2 + + d2 = vague(MvNormalWeightedMeanPrecision, 3) + + @test typeof(d2) <: MvNormalWeightedMeanPrecision + @test mean(d2) == zeros(3) + @test invcov(d2) == Matrix(Diagonal(1e-12 * ones(3))) + @test ndims(d2) == 3 +end + +@testitem "MvNormalWeightedMeanPrecision: prod" begin + include("./normal_family_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) + @test prod( + strategy, + MvNormalWeightedMeanPrecision([-1, -1], [2, 2]), + MvNormalWeightedMeanPrecision([1, 1], [2, 4]) + ) ≈ MvNormalWeightedMeanPrecision([0, 0], [4, 6]) + + xi = [0.2, 3.0, 4.0] + Λ = [1.5 -0.1 0.1; -0.1 1.8 0.0; 0.1 0.0 3.5] + dist = MvNormalWeightedMeanPrecision(xi, Λ) + + @test prod(strategy, dist, dist) ≈ + MvNormalWeightedMeanPrecision([0.40, 6.00, 8.00], [3.00 -0.20 0.20; -0.20 3.60 0.00; 0.20 0.00 7.00]) + end +end + +@testitem "MvNormalWeightedMeanPrecision: convert" begin + include("./normal_family_setuptests.jl") + + @test convert(MvNormalWeightedMeanPrecision, zeros(2), Matrix(Diagonal(ones(2)))) == + MvNormalWeightedMeanPrecision(zeros(2), Matrix(Diagonal(ones(2)))) + @test begin + m = rand(5) + c = Matrix(Symmetric(rand(5, 5))) + convert(MvNormalWeightedMeanPrecision, m, c) == MvNormalWeightedMeanPrecision(m, c) + end +end + +================ +File: distributions/normal_family/normal_family_setuptests.jl +================ +include("../distributions_setuptests.jl") + +import ExponentialFamily: dot3arg + +# We need this extra function to ensure better derivatives with AD, it is slower than our implementation +# but is more AD friendly +function getlogpartitionfortest(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) + return (η) -> begin + weightedmean, minushalfprecision = unpack_parameters(MvNormalMeanCovariance, η) + return (dot3arg(weightedmean, inv(-minushalfprecision), weightedmean) / 2 - logdet(-2 * minushalfprecision)) / 2 + end +end + +function gaussianlpdffortest(params, x) + k = length(x) + μ, Σ = params[1:k], reshape(params[k+1:end], k, k) + coef = (2π)^(-k / 2) * det(Σ)^(-1 / 2) + exponent = -0.5 * (x - μ)' * inv(Σ) * (x - μ) + return log(coef * exp(exponent)) +end + +function check_basic_statistics(left::UnivariateNormalDistributionsFamily, right::UnivariateNormalDistributionsFamily) + @test mean(left) ≈ mean(right) + @test median(left) ≈ median(right) + @test mode(left) ≈ mode(right) + @test var(left) ≈ var(right) + @test std(left) ≈ std(right) + @test entropy(left) ≈ entropy(right) + + for value in (1.0, -1.0, 0.0, mean(left), mean(right), rand()) + @test pdf(left, value) ≈ pdf(right, value) + @test logpdf(left, value) ≈ logpdf(right, value) + @test all( + ForwardDiff.gradient((x) -> logpdf(left, x[1]), [value]) .≈ + ForwardDiff.gradient((x) -> logpdf(right, x[1]), [value]) + ) + @test all( + ForwardDiff.hessian((x) -> logpdf(left, x[1]), [value]) .≈ + ForwardDiff.hessian((x) -> logpdf(right, x[1]), [value]) + ) + end + + # `Normal` is not defining some of these methods and we don't want to define them either, because of the type piracy + if !(left isa Normal || right isa Normal) + @test cov(left) ≈ cov(right) + @test invcov(left) ≈ invcov(right) + @test weightedmean(left) ≈ weightedmean(right) + @test precision(left) ≈ precision(right) + @test all(mean_cov(left) .≈ mean_cov(right)) + @test all(mean_invcov(left) .≈ mean_invcov(right)) + @test all(mean_precision(left) .≈ mean_precision(right)) + @test all(weightedmean_cov(left) .≈ weightedmean_cov(right)) + @test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right)) + @test all(weightedmean_precision(left) .≈ weightedmean_precision(right)) + end +end + +function check_basic_statistics(left::MultivariateNormalDistributionsFamily, right::MultivariateNormalDistributionsFamily) + @test mean(left) ≈ mean(right) + @test mode(left) ≈ mode(right) + @test var(left) ≈ var(right) + @test cov(left) ≈ cov(right) + @test logdetcov(left) ≈ logdetcov(right) + @test length(left) === length(right) + @test size(left) === size(right) + @test entropy(left) ≈ entropy(right) + + dims = length(mean(left)) + + for value in ( + fill(1.0, dims), + fill(-1.0, dims), + fill(0.1, dims), + mean(left), + mean(right), + rand(dims) + ) + @test pdf(left, value) ≈ pdf(right, value) + @test logpdf(left, value) ≈ logpdf(right, value) + @test all( + isapprox.( + ForwardDiff.gradient((x) -> logpdf(left, x), value), + ForwardDiff.gradient((x) -> logpdf(right, x), value), + atol = 1e-12 + ) + ) + if !all( + isapprox.( + ForwardDiff.hessian((x) -> logpdf(left, x), value), + ForwardDiff.hessian((x) -> logpdf(right, x), value), + atol = 1e-12 + ) + ) + error(left, right) + end + end + + # `MvNormal` is not defining some of these methods and we don't want to define them either, because of the type piracy + if !(left isa MvNormal || right isa MvNormal) + @test ndims(left) === ndims(right) + @test invcov(left) ≈ invcov(right) + @test weightedmean(left) ≈ weightedmean(right) + @test precision(left) ≈ precision(right) + @test all(mean_cov(left) .≈ mean_cov(right)) + @test all(mean_invcov(left) .≈ mean_invcov(right)) + @test all(mean_precision(left) .≈ mean_precision(right)) + @test all(weightedmean_cov(left) .≈ weightedmean_cov(right)) + @test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right)) + @test all(weightedmean_precision(left) .≈ weightedmean_precision(right)) + end +end + +================ +File: distributions/normal_family/normal_family_tests.jl +================ +@testitem "NormalFamily: Univariate conversions" begin + include("./normal_family_setuptests.jl") + + types = union_types(UnivariateNormalDistributionsFamily{Float64}) + etypes = union_types(UnivariateNormalDistributionsFamily) + + rng = MersenneTwister(1234) + + for type in types + left = convert(type, rand(rng, Float64), rand(rng, Float64)) + + for type in [types..., etypes...] + right = convert(type, left) + check_basic_statistics(left, right) + + p1 = prod(PreserveTypeLeftProd(), left, right) + @test typeof(p1) <: typeof(left) + + p2 = prod(PreserveTypeRightProd(), left, right) + @test typeof(p2) <: typeof(right) + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) + p3 = prod(strategy, left, right) + + check_basic_statistics(p1, p2) + check_basic_statistics(p2, p3) + check_basic_statistics(p1, p3) + end + end + end +end + +@testitem "NormalFamily: Multivariate conversions" begin + include("./normal_family_setuptests.jl") + + types = union_types(MultivariateNormalDistributionsFamily{Float64}) + etypes = union_types(MultivariateNormalDistributionsFamily) + + dims = (2, 3, 5) + rng = MersenneTwister(1234) + + for dim in dims + for type in types + left = convert(type, rand(rng, Float64, dim), Matrix(Diagonal(abs.(rand(rng, Float64, dim))))) + + for type in [types..., etypes...] + right = convert(type, left) + check_basic_statistics(left, right) + + p1 = prod(PreserveTypeLeftProd(), left, right) + @test typeof(p1) <: typeof(left) + + p2 = prod(PreserveTypeRightProd(), left, right) + @test typeof(p2) <: typeof(right) + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) + p3 = prod(strategy, left, right) + + check_basic_statistics(p1, p2) + check_basic_statistics(p2, p3) + check_basic_statistics(p1, p3) + end + end + end + end +end + +@testitem "NormalFamily: Variate forms promotions" begin + include("./normal_family_setuptests.jl") + + @test promote_variate_type(Univariate, NormalMeanVariance) === NormalMeanVariance + @test promote_variate_type(Univariate, NormalMeanPrecision) === NormalMeanPrecision + @test promote_variate_type(Univariate, NormalWeightedMeanPrecision) === NormalWeightedMeanPrecision + + @test promote_variate_type(Multivariate, NormalMeanVariance) === MvNormalMeanCovariance + @test promote_variate_type(Multivariate, NormalMeanPrecision) === MvNormalMeanPrecision + @test promote_variate_type(Multivariate, NormalWeightedMeanPrecision) === MvNormalWeightedMeanPrecision + + @test promote_variate_type(Univariate, MvNormalMeanCovariance) === NormalMeanVariance + @test promote_variate_type(Univariate, MvNormalMeanPrecision) === NormalMeanPrecision + @test promote_variate_type(Univariate, MvNormalWeightedMeanPrecision) === NormalWeightedMeanPrecision + + @test promote_variate_type(Multivariate, MvNormalMeanCovariance) === MvNormalMeanCovariance + @test promote_variate_type(Multivariate, MvNormalMeanPrecision) === MvNormalMeanPrecision + @test promote_variate_type(Multivariate, MvNormalWeightedMeanPrecision) === MvNormalWeightedMeanPrecision +end + +@testitem "NormalFamily: Sampling univariate" begin + include("./normal_family_setuptests.jl") + + rng = MersenneTwister(1234) + + for T in (Float32, Float64) + let # NormalMeanVariance + μ, v = 10randn(rng), 10rand(rng) + d = convert(NormalMeanVariance{T}, μ, v) + + @test typeof(rand(d)) <: T + + samples = rand(rng, d, 5_000) + + @test isapprox(mean(samples), μ, atol = 0.5) + @test isapprox(var(samples), v, atol = 0.5) + end + + let # NormalMeanPrecision + μ, w = 10randn(rng), 10rand(rng) + d = convert(NormalMeanPrecision{T}, μ, w) + + @test typeof(rand(d)) <: T + + samples = rand(rng, d, 5_000) + + @test isapprox(mean(samples), μ, atol = 0.5) + @test isapprox(inv(var(samples)), w, atol = 0.5) + end + + let # WeightedMeanPrecision + wμ, w = 10randn(rng), 10rand(rng) + d = convert(NormalWeightedMeanPrecision{T}, wμ, w) + + @test typeof(rand(d)) <: T + + samples = rand(rng, d, 5_000) + + @test isapprox(inv(var(samples)) * mean(samples), wμ, atol = 0.5) + @test isapprox(inv(var(samples)), w, atol = 0.5) + end + end +end + +@testitem "NormalFamily: Sampling multivariate" begin + include("./normal_family_setuptests.jl") + + rng = MersenneTwister(1234) + for n in (2, 3), T in (Float64,), nsamples in (10_000,) + μ = randn(rng, n) + L = randn(rng, n, n) + Σ = L * L' + + d = convert(MvNormalMeanCovariance{T}, μ, Σ) + @test typeof(rand(d)) <: Vector{T} + + samples = eachcol(rand(rng, d, nsamples)) + weights = fill(1 / nsamples, nsamples) + + @test isapprox(sum(sample for sample in samples) / nsamples, mean(d), atol = n * 0.5) + @test isapprox( + sum((sample - mean(d)) * (sample - mean(d))' for sample in samples) / nsamples, + cov(d), + atol = n * 0.5 + ) + + μ = randn(rng, n) + L = randn(rng, n, n) + W = L * L' + d = convert(MvNormalMeanCovariance{T}, μ, W) + @test typeof(rand(d)) <: Vector{T} + + samples = eachcol(rand(rng, d, nsamples)) + weights = fill(1 / nsamples, nsamples) + + @test isapprox(sum(sample for sample in samples) / nsamples, mean(d), atol = n * 0.5) + @test isapprox( + sum((sample - mean(d)) * (sample - mean(d))' for sample in samples) / nsamples, + cov(d), + atol = n * 0.5 + ) + + ξ = randn(rng, n) + L = randn(rng, n, n) + W = L * L' + + d = convert(MvNormalWeightedMeanPrecision{T}, ξ, W) + + @test typeof(rand(d)) <: Vector{T} + + samples = eachcol(rand(rng, d, nsamples)) + weights = fill(1 / nsamples, nsamples) + + @test isapprox(sum(sample for sample in samples) / nsamples, mean(d), atol = n * 0.5) + @test isapprox( + sum((sample - mean(d)) * (sample - mean(d))' for sample in samples) / nsamples, + cov(d), + atol = n * 0.5 + ) + end +end + +@testitem "NormalFamily: ExponentialFamilyDistribution{NormalMeanVariance}" begin + include("./normal_family_setuptests.jl") + + for μ in -10.0:5.0:10.0, σ² in 0.1:1.0:5.0, T in union_types(UnivariateNormalDistributionsFamily) + @testset let d = convert(T, NormalMeanVariance(μ, σ²)) + ef = test_exponentialfamily_interface(d) + + (η₁, η₂) = (mean(d) / var(d), -1 / 2var(d)) + + for x in 10randn(4) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) ≈ 1 / sqrt(2π) + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (x, abs2(x))) + @test @inferred(logpartition(ef)) ≈ (-η₁^2 / 4η₂ - 1 / 2 * log(-2η₂)) + @test @inferred(insupport(ef, x)) + end + end + end + + # Test failing isproper cases + @test !isproper(MeanParametersSpace(), NormalMeanVariance, [-1]) + @test !isproper(MeanParametersSpace(), NormalMeanVariance, [1, -0.1]) + @test !isproper(MeanParametersSpace(), NormalMeanVariance, [-0.1, -1]) + @test !isproper(NaturalParametersSpace(), NormalMeanVariance, [-1.1]) + @test !isproper(NaturalParametersSpace(), NormalMeanVariance, [1, 1]) + @test !isproper(NaturalParametersSpace(), NormalMeanVariance, [-1.1, 1]) +end + +@testitem "NormalFamily: prod with ExponentialFamilyDistribution{NormalMeanVariance}" begin + include("./normal_family_setuptests.jl") + + for μleft in 10randn(4), σ²left in 10rand(4), μright in 10randn(4), + σ²right in 10rand(4), Tleft in union_types(UnivariateNormalDistributionsFamily), + Tright in union_types(UnivariateNormalDistributionsFamily) + + @testset let (left, right) = (convert(Tleft, NormalMeanVariance(μleft, σ²left)), convert(Tright, NormalMeanVariance(μright, σ²right))) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{NormalMeanVariance}) + ) + ) + end + end +end + +@testitem "NormalFamily: ExponentialFamilyDistribution{MvNormalMeanCovariance}" begin + include("./normal_family_setuptests.jl") + + for s in (2, 3), T in union_types(MultivariateNormalDistributionsFamily) + μ = 10randn(s) + L = LowerTriangular(randn(s, s) + s * I) + Σ = L * L' + @testset let d = convert(T, MvNormalMeanCovariance(μ, Σ)) + ef = test_exponentialfamily_interface( + d; + # These are handled differently below + test_fisherinformation_against_hessian = false, + test_fisherinformation_against_jacobian = false, + test_gradlogpartition_properties = false + ) + + (η₁, η₂) = (inv(Σ) * mean(d), -inv(Σ) / 2) + + for x in [10randn(s) for _ in 1:4] + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) ≈ (2π)^(-s / 2) + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (x, x * x')) + @test @inferred(logpartition(ef)) ≈ -1 / 4 * (η₁' * inv(η₂) * η₁) - 1 / 2 * logdet(-2η₂) + @test @inferred(insupport(ef, x)) + end + + run_test_gradlogpartition_properties(d, test_against_forwardiff = false) + + # Extra test with AD-friendly logpartition function + lp_ag = ForwardDiff.gradient(getlogpartitionfortest(NaturalParametersSpace(), MvNormalMeanCovariance), getnaturalparameters(ef)) + @test gradlogpartition(ef) ≈ lp_ag + end + end + + # Test failing isproper cases (naive) + @test !isproper(MeanParametersSpace(), MvNormalMeanCovariance, [-1]) + @test !isproper(MeanParametersSpace(), MvNormalMeanCovariance, [1, -0.1]) + @test !isproper(MeanParametersSpace(), MvNormalMeanCovariance, [-0.1, -1]) + @test !isproper(MeanParametersSpace(), MvNormalMeanCovariance, [-1, 2, 3, 4]) # shapes are incompatible + @test !isproper(MeanParametersSpace(), MvNormalMeanCovariance, [1, -0.1, -1, 0, 0, -1]) # covariance is not posdef + + @test !isproper(NaturalParametersSpace(), MvNormalMeanCovariance, [-1.1]) + @test !isproper(NaturalParametersSpace(), MvNormalMeanCovariance, [1, 1]) + @test !isproper(NaturalParametersSpace(), MvNormalMeanCovariance, [-1.1, 1]) + @test !isproper(NaturalParametersSpace(), MvNormalMeanCovariance, [-1, 2, 3, 4]) # shapes are incompatible + @test !isproper(NaturalParametersSpace(), MvNormalMeanCovariance, [1, -0.1, 1, 0, 0, 1]) # -η₂ is not posdef +end + +@testitem "NormalFamily: Fisher information matrix in natural parameters space" begin + include("./normal_family_setuptests.jl") + + for i in 1:5, d in 2:10 + rng = StableRNG(d * i) + μ = 10randn(rng, d) + L = LowerTriangular(randn(rng, d, d) + d * I) + Σ = L * L' + ef = convert(ExponentialFamilyDistribution, MvNormalMeanCovariance(μ, Σ)) + + fi_ef = fisherinformation(ef) + # @test_broken isposdef(fi_ef) + # The `isposdef` check is not really reliable in Julia, here, instead we compute eigen values + @test issymmetric(fi_ef) || (LowerTriangular(fi_ef) ≈ (UpperTriangular(fi_ef)')) + @test isposdef(fi_ef) || all(>(0), eigvals(fi_ef)) + + fi_ef_inv = inv(fi_ef) + @test (fi_ef_inv * fi_ef) ≈ Diagonal(ones(d + d^2)) + + # WARNING: ForwardDiff returns a non-positive definite Hessian for a convex function. + # The matrices are identical up to permutations, resulting in eigenvalues that are the same up to a sign. + fi_ag = ForwardDiff.hessian(getlogpartitionfortest(NaturalParametersSpace(), MvNormalMeanCovariance), getnaturalparameters(ef)) + @test norm(sort(eigvals(fi_ef)) - sort(abs.(eigvals(fi_ag)))) ≈ 0 atol = (1e-9 * d^2) + end +end + +# We normally perform test with jacobian transformation, but autograd fails to compute jacobians with duplicated elements. +@testitem "Fisher information matrix in mean parameters space" begin + include("./normal_family_setuptests.jl") + + for i in 1:5, d in 2:3 + rng = StableRNG(d * i) + μ = 10randn(rng, d) + L = LowerTriangular(randn(rng, d, d) + d * I) + Σ = L * L' + n_samples = 10000 + dist = MvNormalMeanCovariance(μ, Σ) + + samples = rand(rng, dist, n_samples) + + θ = pack_parameters(MvNormalMeanCovariance, (μ, Σ)) + + approxHessian = zeros(length(θ), length(θ)) + for sample in eachcol(samples) + approxHessian -= ForwardDiff.hessian(Base.Fix2(gaussianlpdffortest, sample), θ) + end + approxFisherInformation = approxHessian /= n_samples + + # The error will be higher for sampling tests, tolerance adjusted accordingly. + fi_dist = getfisherinformation(MeanParametersSpace(), MvNormalMeanCovariance)(θ) + @test isposdef(fi_dist) || all(>(0), eigvals(fi_dist)) + @test issymmetric(fi_dist) || (LowerTriangular(fi_dist) ≈ (UpperTriangular(fi_dist)')) + @test sort(eigvals(fi_dist)) ≈ sort(abs.(eigvals(approxFisherInformation))) rtol = 1e-1 + @test sort(svd(fi_dist).S) ≈ sort(svd(approxFisherInformation).S) rtol = 1e-1 + end +end + +@testitem "Diffrentiabilty of ExponentialFamily(ExponentialFamily.MvNormalMeanCovariance) logpdf" begin + include("./normal_family_setuptests.jl") + for i in 1:5, d in 2:3 + rng = StableRNG(d * i) + μ = 10randn(rng, d) + L = LowerTriangular(randn(rng, d, d) + d * I) + Σ = L * L' + n_samples = 1 + dist = MvNormalMeanCovariance(μ, Σ) + + samples = rand(rng, dist, n_samples) + + θ = pack_parameters(MvNormalMeanCovariance, (μ, Σ)) + ef = convert(ExponentialFamilyDistribution, MvNormalMeanCovariance(μ, Σ)) + + nat_space2mean_space = (η) -> begin + dist = convert(Distribution, ExponentialFamilyDistribution(MvNormalMeanCovariance, η)) + μ, Σ = mean(dist), cov(dist) + pack_parameters(MvNormalMeanCovariance, (μ, Σ)) + end + + for sample in eachcol(samples) + mean_gradient = ForwardDiff.gradient(Base.Fix2(gaussianlpdffortest, sample), θ) + nat_gradient = ForwardDiff.gradient((η) -> logpdf(ExponentialFamilyDistribution(MvNormalMeanCovariance, η), sample), getnaturalparameters(ef)) + jacobian = ForwardDiff.jacobian(nat_space2mean_space, getnaturalparameters(ef)) + #autograd failing to compute jacobian of matrix part correclty. Comparing only vector (mean) part. + @test nat_gradient[1:d] ≈ (jacobian'*mean_gradient)[1:d] + end + end +end + +================ +File: distributions/normal_family/normal_mean_precision_tests.jl +================ +@testitem "NormalMeanPrecision: Constructor" begin + include("./normal_family_setuptests.jl") + + @test NormalMeanPrecision <: NormalDistributionsFamily + @test NormalMeanPrecision <: UnivariateNormalDistributionsFamily + + @test NormalMeanPrecision() == NormalMeanPrecision{Float64}(0.0, 1.0) + @test NormalMeanPrecision(1.0) == NormalMeanPrecision{Float64}(1.0, 1.0) + @test NormalMeanPrecision(1.0, 2.0) == NormalMeanPrecision{Float64}(1.0, 2.0) + @test NormalMeanPrecision(1) == NormalMeanPrecision{Float64}(1.0, 1.0) + @test NormalMeanPrecision(1, 2) == NormalMeanPrecision{Float64}(1.0, 2.0) + @test NormalMeanPrecision(1.0, 2) == NormalMeanPrecision{Float64}(1.0, 2.0) + @test NormalMeanPrecision(1, 2.0) == NormalMeanPrecision{Float64}(1.0, 2.0) + @test NormalMeanPrecision(1.0f0) == NormalMeanPrecision{Float32}(1.0f0, 1.0f0) + @test NormalMeanPrecision(1.0f0, 2.0f0) == NormalMeanPrecision{Float32}(1.0f0, 2.0f0) + @test NormalMeanPrecision(1.0f0, 2) == NormalMeanPrecision{Float32}(1.0f0, 2.0f0) + @test NormalMeanPrecision(1.0f0, 2.0) == NormalMeanPrecision{Float64}(1.0, 2.0) + + @test eltype(NormalMeanPrecision()) === Float64 + @test eltype(NormalMeanPrecision(0.0)) === Float64 + @test eltype(NormalMeanPrecision(0.0, 1.0)) === Float64 + @test eltype(NormalMeanPrecision(0)) === Float64 + @test eltype(NormalMeanPrecision(0, 1)) === Float64 + @test eltype(NormalMeanPrecision(0.0, 1)) === Float64 + @test eltype(NormalMeanPrecision(0, 1.0)) === Float64 + @test eltype(NormalMeanPrecision(0.0f0)) === Float32 + @test eltype(NormalMeanPrecision(0.0f0, 1.0f0)) === Float32 + @test eltype(NormalMeanPrecision(0.0f0, 1.0)) === Float64 + + @test NormalMeanPrecision(3, 5I) == NormalMeanPrecision(3, 5) + @test NormalMeanPrecision(2, 7.0I) == NormalMeanPrecision(2.0, 7.0) +end + +@testitem "NormalMeanPrecision: Stats methods" begin + include("./normal_family_setuptests.jl") + + dist1 = NormalMeanPrecision(0.0, 1.0) + + @test mean(dist1) === 0.0 + @test median(dist1) === 0.0 + @test mode(dist1) === 0.0 + @test weightedmean(dist1) === 0.0 + @test var(dist1) === 1.0 + @test std(dist1) === 1.0 + @test cov(dist1) === 1.0 + @test invcov(dist1) === 1.0 + @test precision(dist1) === 1.0 + @test entropy(dist1) ≈ 1.41893853320467 + @test pdf(dist1, 1.0) ≈ 0.24197072451914337 + @test pdf(dist1, -1.0) ≈ 0.24197072451914337 + @test pdf(dist1, 0.0) ≈ 0.3989422804014327 + @test logpdf(dist1, 1.0) ≈ -1.4189385332046727 + @test logpdf(dist1, -1.0) ≈ -1.4189385332046727 + @test logpdf(dist1, 0.0) ≈ -0.9189385332046728 + + dist2 = NormalMeanPrecision(1.0, 1.0) + + @test mean(dist2) === 1.0 + @test median(dist2) === 1.0 + @test mode(dist2) === 1.0 + @test weightedmean(dist2) === 1.0 + @test var(dist2) === 1.0 + @test std(dist2) === 1.0 + @test cov(dist2) === 1.0 + @test invcov(dist2) === 1.0 + @test precision(dist2) === 1.0 + @test entropy(dist2) ≈ 1.41893853320467 + @test pdf(dist2, 1.0) ≈ 0.3989422804014327 + @test pdf(dist2, -1.0) ≈ 0.05399096651318806 + @test pdf(dist2, 0.0) ≈ 0.24197072451914337 + @test logpdf(dist2, 1.0) ≈ -0.9189385332046728 + @test logpdf(dist2, -1.0) ≈ -2.9189385332046727 + @test logpdf(dist2, 0.0) ≈ -1.4189385332046727 + + dist3 = NormalMeanPrecision(1.0, 0.5) + + @test mean(dist3) === 1.0 + @test median(dist3) === 1.0 + @test mode(dist3) === 1.0 + @test weightedmean(dist3) === inv(2.0) + @test var(dist3) === 2.0 + @test std(dist3) === sqrt(2.0) + @test cov(dist3) === 2.0 + @test invcov(dist3) === inv(2.0) + @test precision(dist3) === inv(2.0) + @test entropy(dist3) ≈ 1.7655121234846454 + @test pdf(dist3, 1.0) ≈ 0.28209479177387814 + @test pdf(dist3, -1.0) ≈ 0.1037768743551487 + @test pdf(dist3, 0.0) ≈ 0.21969564473386122 + @test logpdf(dist3, 1.0) ≈ -1.2655121234846454 + @test logpdf(dist3, -1.0) ≈ -2.2655121234846454 + @test logpdf(dist3, 0.0) ≈ -1.5155121234846454 +end + +@testitem "NormalMeanPrecision: Base methods" begin + include("./normal_family_setuptests.jl") + + @test convert(NormalMeanPrecision{Float32}, NormalMeanPrecision()) == NormalMeanPrecision{Float32}(0.0f0, 1.0f0) + @test convert(NormalMeanPrecision{Float64}, NormalMeanPrecision(0.0, 10.0)) == + NormalMeanPrecision{Float64}(0.0, 10.0) + @test convert(NormalMeanPrecision{Float64}, NormalMeanPrecision(0.0, 0.1)) == + NormalMeanPrecision{Float64}(0.0, 0.1) + @test convert(NormalMeanPrecision{Float64}, 0, 1) == NormalMeanPrecision{Float64}(0.0, 1.0) + @test convert(NormalMeanPrecision{Float64}, 0, 10) == NormalMeanPrecision{Float64}(0.0, 10.0) + @test convert(NormalMeanPrecision{Float64}, 0, 0.1) == NormalMeanPrecision{Float64}(0.0, 0.1) + @test convert(NormalMeanPrecision, 0, 1) == NormalMeanPrecision{Float64}(0.0, 1.0) + @test convert(NormalMeanPrecision, 0, 10) == NormalMeanPrecision{Float64}(0.0, 10.0) + @test convert(NormalMeanPrecision, 0, 0.1) == NormalMeanPrecision{Float64}(0.0, 0.1) + + distribution = NormalMeanPrecision(-2.0, 3.0) + + @test distribution ≈ distribution + @test distribution ≈ convert(NormalMeanVariance, distribution) + @test distribution ≈ convert(NormalWeightedMeanPrecision, distribution) +end + +@testitem "NormalMeanPrecision: vague" begin + include("./normal_family_setuptests.jl") + + d1 = vague(NormalMeanPrecision) + + @test typeof(d1) <: NormalMeanPrecision + @test mean(d1) == 0.0 + @test precision(d1) == 1e-12 +end + +@testitem "NormalMeanPrecision: prod" begin + include("./normal_family_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) + @test prod(strategy, NormalMeanPrecision(-1, 1 / 1), NormalMeanPrecision(1, 1 / 1)) ≈ + NormalWeightedMeanPrecision(0.0, 2.0) + @test prod(strategy, NormalMeanPrecision(-1, 1 / 2), NormalMeanPrecision(1, 1 / 4)) ≈ + NormalWeightedMeanPrecision(-1 / 4, 3 / 4) + @test prod(strategy, NormalMeanPrecision(2, 1 / 2), NormalMeanPrecision(0, 1 / 10)) ≈ + NormalWeightedMeanPrecision(1, 3 / 5) + end +end + +================ +File: distributions/normal_family/normal_mean_variance_tests.jl +================ +@testitem "NormalMeanVariance: Constructor" begin + include("./normal_family_setuptests.jl") + + @test NormalMeanVariance <: NormalDistributionsFamily + @test NormalMeanVariance <: UnivariateNormalDistributionsFamily + + @test NormalMeanVariance() == NormalMeanVariance{Float64}(0.0, 1.0) + @test NormalMeanVariance(1.0) == NormalMeanVariance{Float64}(1.0, 1.0) + @test NormalMeanVariance(1.0, 2.0) == NormalMeanVariance{Float64}(1.0, 2.0) + @test NormalMeanVariance(1) == NormalMeanVariance{Float64}(1.0, 1.0) + @test NormalMeanVariance(1, 2) == NormalMeanVariance{Float64}(1.0, 2.0) + @test NormalMeanVariance(1.0, 2) == NormalMeanVariance{Float64}(1.0, 2.0) + @test NormalMeanVariance(1, 2.0) == NormalMeanVariance{Float64}(1.0, 2.0) + @test NormalMeanVariance(1.0f0) == NormalMeanVariance{Float32}(1.0f0, 1.0f0) + @test NormalMeanVariance(1.0f0, 2.0f0) == NormalMeanVariance{Float32}(1.0f0, 2.0f0) + @test NormalMeanVariance(1.0f0, 2) == NormalMeanVariance{Float32}(1.0f0, 2.0f0) + @test NormalMeanVariance(1.0f0, 2.0) == NormalMeanVariance{Float64}(1.0, 2.0) + + @test eltype(NormalMeanVariance()) === Float64 + @test eltype(NormalMeanVariance(0.0)) === Float64 + @test eltype(NormalMeanVariance(0.0, 1.0)) === Float64 + @test eltype(NormalMeanVariance(0)) === Float64 + @test eltype(NormalMeanVariance(0, 1)) === Float64 + @test eltype(NormalMeanVariance(0.0, 1)) === Float64 + @test eltype(NormalMeanVariance(0, 1.0)) === Float64 + @test eltype(NormalMeanVariance(0.0f0)) === Float32 + @test eltype(NormalMeanVariance(0.0f0, 1.0f0)) === Float32 + @test eltype(NormalMeanVariance(0.0f0, 1.0)) === Float64 + + @test NormalMeanVariance(3, 5I) == NormalMeanVariance(3, 5) + @test NormalMeanVariance(2, 7.0I) == NormalMeanVariance(2.0, 7.0) +end + +@testitem "NormalMeanVariance: Stats methods" begin + include("./normal_family_setuptests.jl") + + dist1 = NormalMeanVariance(0.0, 1.0) + + @test mean(dist1) === 0.0 + @test median(dist1) === 0.0 + @test mode(dist1) === 0.0 + @test weightedmean(dist1) === 0.0 + @test var(dist1) === 1.0 + @test std(dist1) === 1.0 + @test cov(dist1) === 1.0 + @test invcov(dist1) === 1.0 + @test precision(dist1) === 1.0 + @test entropy(dist1) ≈ 1.41893853320467 + @test pdf(dist1, 1.0) ≈ 0.24197072451914337 + @test pdf(dist1, -1.0) ≈ 0.24197072451914337 + @test pdf(dist1, 0.0) ≈ 0.3989422804014327 + @test logpdf(dist1, 1.0) ≈ -1.4189385332046727 + @test logpdf(dist1, -1.0) ≈ -1.4189385332046727 + @test logpdf(dist1, 0.0) ≈ -0.9189385332046728 + + dist2 = NormalMeanVariance(1.0, 1.0) + + @test mean(dist2) === 1.0 + @test median(dist2) === 1.0 + @test mode(dist2) === 1.0 + @test weightedmean(dist2) === 1.0 + @test var(dist2) === 1.0 + @test std(dist2) === 1.0 + @test cov(dist2) === 1.0 + @test invcov(dist2) === 1.0 + @test precision(dist2) === 1.0 + @test entropy(dist2) ≈ 1.41893853320467 + @test pdf(dist2, 1.0) ≈ 0.3989422804014327 + @test pdf(dist2, -1.0) ≈ 0.05399096651318806 + @test pdf(dist2, 0.0) ≈ 0.24197072451914337 + @test logpdf(dist2, 1.0) ≈ -0.9189385332046728 + @test logpdf(dist2, -1.0) ≈ -2.9189385332046727 + @test logpdf(dist2, 0.0) ≈ -1.4189385332046727 + + dist3 = NormalMeanVariance(1.0, 2.0) + + @test mean(dist3) === 1.0 + @test median(dist3) === 1.0 + @test mode(dist3) === 1.0 + @test weightedmean(dist3) === inv(2.0) + @test var(dist3) === 2.0 + @test std(dist3) === sqrt(2.0) + @test cov(dist3) === 2.0 + @test invcov(dist3) === inv(2.0) + @test precision(dist3) === inv(2.0) + @test entropy(dist3) ≈ 1.7655121234846454 + @test pdf(dist3, 1.0) ≈ 0.28209479177387814 + @test pdf(dist3, -1.0) ≈ 0.1037768743551487 + @test pdf(dist3, 0.0) ≈ 0.21969564473386122 + @test logpdf(dist3, 1.0) ≈ -1.2655121234846454 + @test logpdf(dist3, -1.0) ≈ -2.2655121234846454 + @test logpdf(dist3, 0.0) ≈ -1.5155121234846454 +end + +@testitem "NormalMeanVariance: Base methods" begin + include("./normal_family_setuptests.jl") + + @test convert(NormalMeanVariance{Float32}, NormalMeanVariance()) === NormalMeanVariance{Float32}(0.0f0, 1.0f0) + @test convert(NormalMeanVariance{Float64}, NormalMeanVariance(0.0, 10.0)) == + NormalMeanVariance{Float64}(0.0, 10.0) + @test convert(NormalMeanVariance{Float64}, NormalMeanVariance(0.0, 0.1)) == + NormalMeanVariance{Float64}(0.0, 0.1) + @test convert(NormalMeanVariance{Float64}, 0, 1) == NormalMeanVariance{Float64}(0.0, 1.0) + @test convert(NormalMeanVariance{Float64}, 0, 10) == NormalMeanVariance{Float64}(0.0, 10.0) + @test convert(NormalMeanVariance{Float64}, 0, 0.1) == NormalMeanVariance{Float64}(0.0, 0.1) + @test convert(NormalMeanVariance, 0, 1) == NormalMeanVariance{Float64}(0.0, 1.0) + @test convert(NormalMeanVariance, 0, 10) == NormalMeanVariance{Float64}(0.0, 10.0) + @test convert(NormalMeanVariance, 0, 0.1) == NormalMeanVariance{Float64}(0.0, 0.1) + + distribution = NormalMeanVariance(-2.0, 3.0) + + @test distribution ≈ distribution + @test distribution ≈ convert(NormalMeanPrecision, distribution) + @test distribution ≈ convert(NormalWeightedMeanPrecision, distribution) +end + +@testitem "NormalMeanVariance: vague" begin + include("./normal_family_setuptests.jl") + + d1 = vague(NormalMeanVariance) + + @test typeof(d1) <: NormalMeanVariance + @test mean(d1) == 0.0 + @test var(d1) == 1e12 +end + +@testitem "NormalMeanVariance: prod" begin + include("./normal_family_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) + @test prod(strategy, NormalMeanVariance(-1, 1), NormalMeanVariance(1, 1)) ≈ + NormalWeightedMeanPrecision(0.0, 2.0) + @test prod(strategy, NormalMeanVariance(-1, 2), NormalMeanVariance(1, 4)) ≈ + NormalWeightedMeanPrecision(-1 / 4, 3 / 4) + @test prod(strategy, NormalMeanVariance(2, 2), NormalMeanVariance(0, 10)) ≈ + NormalWeightedMeanPrecision(1.0, 3 / 5) + end +end + +================ +File: distributions/normal_family/normal_weighted_mean_precision_tests.jl +================ +@testitem "NormalWeightedMeanPrecision: Constructor" begin + include("./normal_family_setuptests.jl") + + @test NormalWeightedMeanPrecision <: NormalDistributionsFamily + @test NormalWeightedMeanPrecision <: UnivariateNormalDistributionsFamily + + @test NormalWeightedMeanPrecision() == NormalWeightedMeanPrecision{Float64}(0.0, 1.0) + @test NormalWeightedMeanPrecision(1.0) == NormalWeightedMeanPrecision{Float64}(1.0, 1.0) + @test NormalWeightedMeanPrecision(1.0, 2.0) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) + @test NormalWeightedMeanPrecision(1) == NormalWeightedMeanPrecision{Float64}(1.0, 1.0) + @test NormalWeightedMeanPrecision(1, 2) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) + @test NormalWeightedMeanPrecision(1.0, 2) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) + @test NormalWeightedMeanPrecision(1, 2.0) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) + @test NormalWeightedMeanPrecision(1.0f0) == NormalWeightedMeanPrecision{Float32}(1.0f0, 1.0f0) + @test NormalWeightedMeanPrecision(1.0f0, 2.0f0) == NormalWeightedMeanPrecision{Float32}(1.0f0, 2.0f0) + @test NormalWeightedMeanPrecision(1.0f0, 2.0) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) + + @test eltype(NormalWeightedMeanPrecision()) === Float64 + @test eltype(NormalWeightedMeanPrecision(0.0)) === Float64 + @test eltype(NormalWeightedMeanPrecision(0.0, 1.0)) === Float64 + @test eltype(NormalWeightedMeanPrecision(0)) === Float64 + @test eltype(NormalWeightedMeanPrecision(0, 1)) === Float64 + @test eltype(NormalWeightedMeanPrecision(0.0, 1)) === Float64 + @test eltype(NormalWeightedMeanPrecision(0, 1.0)) === Float64 + @test eltype(NormalWeightedMeanPrecision(0.0f0)) === Float32 + @test eltype(NormalWeightedMeanPrecision(0.0f0, 1.0f0)) === Float32 + @test eltype(NormalWeightedMeanPrecision(0.0f0, 1.0)) === Float64 + + @test NormalWeightedMeanPrecision(3, 5I) == NormalWeightedMeanPrecision(3, 5) + @test NormalWeightedMeanPrecision(2, 7.0I) == NormalWeightedMeanPrecision(2.0, 7.0) +end + +@testitem "NormalWeightedMeanPrecision: Stats methods" begin + include("./normal_family_setuptests.jl") + + dist1 = NormalWeightedMeanPrecision(0.0, 1.0) + + @test mean(dist1) === 0.0 + @test median(dist1) === 0.0 + @test mode(dist1) === 0.0 + @test weightedmean(dist1) === 0.0 + @test var(dist1) === 1.0 + @test std(dist1) === 1.0 + @test cov(dist1) === 1.0 + @test invcov(dist1) === 1.0 + @test precision(dist1) === 1.0 + @test entropy(dist1) ≈ 1.41893853320467 + @test pdf(dist1, 1.0) ≈ 0.24197072451914337 + @test pdf(dist1, -1.0) ≈ 0.24197072451914337 + @test pdf(dist1, 0.0) ≈ 0.3989422804014327 + @test logpdf(dist1, 1.0) ≈ -1.4189385332046727 + @test logpdf(dist1, -1.0) ≈ -1.4189385332046727 + @test logpdf(dist1, 0.0) ≈ -0.9189385332046728 + + dist2 = NormalWeightedMeanPrecision(1.0, 1.0) + + @test mean(dist2) === 1.0 + @test median(dist2) === 1.0 + @test mode(dist2) === 1.0 + @test weightedmean(dist2) === 1.0 + @test var(dist2) === 1.0 + @test std(dist2) === 1.0 + @test cov(dist2) === 1.0 + @test invcov(dist2) === 1.0 + @test precision(dist2) === 1.0 + @test entropy(dist2) ≈ 1.41893853320467 + @test pdf(dist2, 1.0) ≈ 0.3989422804014327 + @test pdf(dist2, -1.0) ≈ 0.05399096651318806 + @test pdf(dist2, 0.0) ≈ 0.24197072451914337 + @test logpdf(dist2, 1.0) ≈ -0.9189385332046728 + @test logpdf(dist2, -1.0) ≈ -2.9189385332046727 + @test logpdf(dist2, 0.0) ≈ -1.4189385332046727 + + dist3 = NormalWeightedMeanPrecision(1.0, 0.5) + + @test mean(dist3) === inv(0.5) + @test median(dist3) === inv(0.5) + @test mode(dist3) === inv(0.5) + @test weightedmean(dist3) === 1.0 + @test var(dist3) === 2.0 + @test std(dist3) === sqrt(2.0) + @test cov(dist3) === 2.0 + @test invcov(dist3) === inv(2.0) + @test precision(dist3) === inv(2.0) + @test entropy(dist3) ≈ 1.7655121234846454 + @test pdf(dist3, 1.0) ≈ 0.21969564473386122 + @test pdf(dist3, -1.0) ≈ 0.02973257230590734 + @test pdf(dist3, 0.0) ≈ 0.1037768743551487 + @test logpdf(dist3, 1.0) ≈ -1.5155121234846454 + @test logpdf(dist3, -1.0) ≈ -3.5155121234846454 + @test logpdf(dist3, 0.0) ≈ -2.2655121234846454 +end + +@testitem "NormalWeightedMeanPrecision: Base methods" begin + include("./normal_family_setuptests.jl") + + @test convert(NormalWeightedMeanPrecision{Float32}, NormalWeightedMeanPrecision()) == + NormalWeightedMeanPrecision{Float32}(0.0f0, 1.0f0) + @test convert(NormalWeightedMeanPrecision{Float64}, NormalWeightedMeanPrecision(0.0, 10.0)) == + NormalWeightedMeanPrecision{Float64}(0.0, 10.0) + @test convert(NormalWeightedMeanPrecision{Float64}, NormalWeightedMeanPrecision(0.0, 0.1)) == + NormalWeightedMeanPrecision{Float64}(0.0, 0.1) + @test convert(NormalWeightedMeanPrecision{Float64}, 0, 1) == NormalWeightedMeanPrecision{Float64}(0.0, 1.0) + @test convert(NormalWeightedMeanPrecision{Float64}, 0, 10) == NormalWeightedMeanPrecision{Float64}(0.0, 10.0) + @test convert(NormalWeightedMeanPrecision{Float64}, 0, 0.1) == NormalWeightedMeanPrecision{Float64}(0.0, 0.1) + @test convert(NormalWeightedMeanPrecision, 0, 1) == NormalWeightedMeanPrecision{Float64}(0.0, 1.0) + @test convert(NormalWeightedMeanPrecision, 0, 10) == NormalWeightedMeanPrecision{Float64}(0.0, 10.0) + @test convert(NormalWeightedMeanPrecision, 0, 0.1) == NormalWeightedMeanPrecision{Float64}(0.0, 0.1) + + distribution = NormalWeightedMeanPrecision(-2.0, 3.0) + + @test distribution ≈ distribution + @test distribution ≈ convert(NormalMeanPrecision, distribution) + @test distribution ≈ convert(NormalMeanVariance, distribution) +end + +@testitem "NormalWeightedMeanPrecision: vague" begin + include("./normal_family_setuptests.jl") + + d1 = vague(NormalWeightedMeanPrecision) + + @test typeof(d1) <: NormalWeightedMeanPrecision + @test mean(d1) == 0.0 + @test precision(d1) == 1e-12 +end + +@testitem "NormalWeightedMeanPrecision: prod" begin + include("./normal_family_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) + @test prod(strategy, NormalWeightedMeanPrecision(-1, 1 / 1), NormalWeightedMeanPrecision(1, 1 / 1)) ≈ + NormalWeightedMeanPrecision(0, 2) + @test prod(strategy, NormalWeightedMeanPrecision(-1, 1 / 2), NormalWeightedMeanPrecision(1, 1 / 4)) ≈ + NormalWeightedMeanPrecision(0, 3 / 4) + @test prod(strategy, NormalWeightedMeanPrecision(2, 1 / 2), NormalWeightedMeanPrecision(0, 1 / 10)) ≈ + NormalWeightedMeanPrecision(2, 3 / 5) + end +end + +================ +File: distributions/wip/test_continuous_bernoulli.jl +================ +module ContinuousBernoulliTest + +using Test +using ExponentialFamily +using Distributions +using Random +using StatsFuns +using ForwardDiff +import ExponentialFamily: + ExponentialFamilyDistribution, getnaturalparameters, compute_logscale, logpartition, basemeasure, + fisherinformation, isvague + +@testset "ContinuousBernoulli" begin + @testset "vague" begin + d = vague(ContinuousBernoulli) + + @test typeof(d) <: ContinuousBernoulli + @test mean(d) === 0.5 + @test succprob(d) === 0.5 + @test failprob(d) === 0.5 + end + + @testset "probvec" begin + @test probvec(ContinuousBernoulli(0.5)) === (0.5, 0.5) + @test probvec(ContinuousBernoulli(0.3)) === (0.7, 0.3) + @test probvec(ContinuousBernoulli(0.6)) === (0.4, 0.6) + end + + @testset "natural parameters related" begin + @test logpartition(convert(ExponentialFamilyDistribution, ContinuousBernoulli(0.5))) ≈ log(2) + @test logpartition(convert(ExponentialFamilyDistribution, ContinuousBernoulli(0.2))) ≈ + log((-3 / 4) / log(1 / 4)) + b_99 = ContinuousBernoulli(0.99) + for i in 1:9 + b = ContinuousBernoulli(i / 10.0) + bnp = convert(ExponentialFamilyDistribution, b) + @test convert(Distribution, bnp) ≈ b + @test logpdf(bnp, 1) ≈ logpdf(b, 1) + @test logpdf(bnp, 0) ≈ logpdf(b, 0) + + @test convert(ExponentialFamilyDistribution, b) == + ExponentialFamilyDistribution(ContinuousBernoulli, [logit(i / 10.0)]) + end + @test isproper(ExponentialFamilyDistribution(ContinuousBernoulli, [10])) === true + @test basemeasure(ExponentialFamilyDistribution(ContinuousBernoulli, [10]), 0.2) == 1.0 + end + + @testset "prod" begin + @test prod(ClosedProd(), ContinuousBernoulli(0.5), ContinuousBernoulli(0.5)) ≈ ContinuousBernoulli(0.5) + @test prod(ClosedProd(), ContinuousBernoulli(0.1), ContinuousBernoulli(0.6)) ≈ + ContinuousBernoulli(0.14285714285714285) + @test prod(ClosedProd(), ContinuousBernoulli(0.78), ContinuousBernoulli(0.05)) ≈ + ContinuousBernoulli(0.1572580645161291) + + left = convert(ExponentialFamilyDistribution, ContinuousBernoulli(0.5)) + right = convert(ExponentialFamilyDistribution, ContinuousBernoulli(0.6)) + @test prod(left, right) == convert(ExponentialFamilyDistribution, ContinuousBernoulli(0.6)) + end + + @testset "rand" begin + dist = ContinuousBernoulli(0.3) + nsamples = 1000 + rng = collect(1:100) + for i in 1:10 + samples = rand(MersenneTwister(rng[i]), dist, nsamples) + mestimated = mean(samples) + weights = ones(nsamples) / nsamples + @test isapprox(mestimated, mean(dist), atol = 1e-1) + @test isapprox( + sum(weight * (sample - mestimated)^2 for (sample, weight) in (samples, weights)), + var(dist), + atol = 1e-1 + ) + end + end + + @testset "fisher information" begin + function transformation(params) + return logistic(params[1]) + end + + for κ in 0.000001:0.01:0.49 + dist = ContinuousBernoulli(κ) + ef = convert(ExponentialFamilyDistribution, dist) + η = getnaturalparameters(ef) + + f_logpartition = (η) -> logpartition(ExponentialFamilyDistribution(ContinuousBernoulli, η)) + autograd_information = (η) -> ForwardDiff.hessian(f_logpartition, η) + @test first(fisherinformation(ef)) ≈ first(autograd_information(η)) atol = 1e-9 + J = ForwardDiff.gradient(transformation, η) + @test J' * fisherinformation(dist) * J ≈ first(fisherinformation(ef)) atol = 1e-9 + end + + for κ in 0.51:0.01:0.99 + dist = ContinuousBernoulli(κ) + ef = convert(ExponentialFamilyDistribution, dist) + η = getnaturalparameters(ef) + + f_logpartition = (η) -> logpartition(ExponentialFamilyDistribution(ContinuousBernoulli, η)) + autograd_information = (η) -> ForwardDiff.hessian(f_logpartition, η) + @test first(fisherinformation(ef)) ≈ first(autograd_information(η)) atol = 1e-9 + J = ForwardDiff.gradient(transformation, η) + @test J' * fisherinformation(dist) * J ≈ first(fisherinformation(ef)) atol = 1e-9 + end + + for κ in 0.499:0.0001:0.50001 + dist = ContinuousBernoulli(κ) + ef = convert(ExponentialFamilyDistribution, dist) + η = getnaturalparameters(ef) + + J = ForwardDiff.gradient(transformation, η) + @test J' * fisherinformation(dist) * J ≈ first(fisherinformation(ef)) atol = 1e-9 + end + end + + @testset "ExponentialFamilyDistribution mean var" begin + for ν in 0.1:0.1:0.99 + dist = ContinuousBernoulli(ν) + ef = convert(ExponentialFamilyDistribution, dist) + @test mean(dist) ≈ mean(ef) atol = 1e-8 + @test var(dist) ≈ var(ef) atol = 1e-8 + end + end +end + +end + +================ +File: distributions/wip/test_multinomial.jl +================ +module MultinomialTest + +using Test +using ExponentialFamily +using Distributions +using Random +using StableRNGs +using ForwardDiff +import ExponentialFamily: ExponentialFamilyDistribution, getnaturalparameters, basemeasure, fisherinformation + +@testset "Multinomial" begin + @testset "probvec" begin + @test probvec(Multinomial(5, [1 / 3, 1 / 3, 1 / 3])) == [1 / 3, 1 / 3, 1 / 3] + @test probvec(Multinomial(3, [0.2, 0.2, 0.4, 0.1, 0.1])) == [0.2, 0.2, 0.4, 0.1, 0.1] + @test probvec(Multinomial(2, [0.5, 0.5])) == [0.5, 0.5] + end + + @testset "vague" begin + @test_throws MethodError vague(Multinomial) + @test_throws MethodError vague(Multinomial, 4) + + vague_dist1 = vague(Multinomial, 5, 4) + @test typeof(vague_dist1) <: Multinomial + @test probvec(vague_dist1) == [1 / 4, 1 / 4, 1 / 4, 1 / 4] + + vague_dist2 = vague(Multinomial, 3, 5) + @test typeof(vague_dist2) <: Multinomial + @test probvec(vague_dist2) == [1 / 5, 1 / 5, 1 / 5, 1 / 5, 1 / 5] + end + + @testset "prod" begin + for n in 2:3 + plength = Int64(ceil(rand(Uniform(1, n)))) + pleft = rand(plength) + pleft = pleft ./ sum(pleft) + pright = rand(plength) + pright = pright ./ sum(pright) + left = Multinomial(n, pleft) + right = Multinomial(n, pright) + efleft = convert(ExponentialFamilyDistribution, left) + efright = convert(ExponentialFamilyDistribution, right) + prod_dist = prod(ClosedProd(), left, right) + prod_ef = prod(efleft, efright) + d = Multinomial(n, ones(plength) ./ plength) + sample_space = unique(rand(StableRNG(1), d, 4000), dims = 2) + sample_space = [sample_space[:, i] for i in 1:size(sample_space, 2)] + + hist_sum(x) = + prod_dist.basemeasure(x) * exp( + prod_dist.naturalparameters' * prod_dist.sufficientstatistics(x) - + prod_dist.logpartition(prod_dist.naturalparameters) + ) + hist_sumef(x) = + prod_ef.basemeasure(x) * exp( + prod_ef.naturalparameters' * prod_ef.sufficientstatistics(x) - + prod_ef.logpartition(prod_ef.naturalparameters) + ) + @test sum(hist_sum(x_sample) for x_sample in sample_space) ≈ 1.0 atol = 1e-10 + @test sum(hist_sumef(x_sample) for x_sample in sample_space) ≈ 1.0 atol = 1e-10 + sample_x = rand(d, 5) + for xi in sample_x + @test prod_dist.basemeasure(xi) ≈ (factorial(n) / prod(@.factorial(xi)))^2 atol = 1e-10 + @test prod_dist.sufficientstatistics(xi) == xi + + @test prod_ef.basemeasure(xi) ≈ (factorial(n) / prod(@.factorial(xi)))^2 atol = 1e-10 + @test prod_ef.sufficientstatistics(xi) == xi + end + end + + @test_throws AssertionError prod( + ClosedProd(), + Multinomial(4, [0.2, 0.4, 0.4]), + Multinomial(5, [0.1, 0.3, 0.6]) + ) + @test_throws AssertionError prod( + ClosedProd(), + Multinomial(4, [0.2, 0.4, 0.4]), + Multinomial(3, [0.1, 0.3, 0.6]) + ) + end + + @testset "natural parameters related " begin + d1 = Multinomial(5, [0.1, 0.4, 0.5]) + d2 = Multinomial(5, [0.2, 0.4, 0.4]) + η1 = ExponentialFamilyDistribution(Multinomial, [log(0.1 / 0.5), log(0.4 / 0.5), 0.0], 5) + η2 = ExponentialFamilyDistribution(Multinomial, [log(0.2 / 0.4), 0.0, 0.0], 5) + + @test convert(ExponentialFamilyDistribution, d1) ≈ η1 + @test convert(ExponentialFamilyDistribution, d2) ≈ η2 + + @test convert(Distribution, η1) ≈ d1 + @test convert(Distribution, η2) ≈ d2 + + @test logpartition(η1) == 3.4657359027997265 + @test logpartition(η2) == 4.5814536593707755 + + @test basemeasure(η1, [1, 2, 2]) == 30.0 + @test basemeasure(η2, [1, 2, 2]) == 30.0 + + @test logpdf(η1, [1, 2, 2]) == logpdf(d1, [1, 2, 2]) + @test logpdf(η2, [1, 2, 2]) == logpdf(d2, [1, 2, 2]) + + @test pdf(η1, [1, 2, 2]) == pdf(d1, [1, 2, 2]) + @test pdf(η2, [1, 2, 2]) == pdf(d2, [1, 2, 2]) + end + + @testset "fisher information" begin + function transformation(η) + expη = exp.(η) + expη / sum(expη) + end + rng = StableRNG(42) + ## ForwardDiff hessian is slow so we only test one time with hessian + n = 3 + p = rand(rng, Dirichlet(ones(n))) + dist = Multinomial(n, p) + ef = convert(ExponentialFamilyDistribution, dist) + η = getnaturalparameters(ef) + + f_logpartition = (η) -> logpartition(ExponentialFamilyDistribution(Multinomial, η, n)) + autograd_information = (η) -> ForwardDiff.hessian(f_logpartition, η) + @test fisherinformation(ef) ≈ autograd_information(η) atol = 1e-8 + + for n in 2:12 + p = rand(rng, Dirichlet(ones(n))) + dist = Multinomial(n, p) + ef = convert(ExponentialFamilyDistribution, dist) + η = getnaturalparameters(ef) + + J = ForwardDiff.jacobian(transformation, η) + @test J' * fisherinformation(dist) * J ≈ fisherinformation(ef) atol = 1e-8 + end + end + + @testset "ExponentialFamilyDistribution mean,cov" begin + rng = StableRNG(42) + for n in 2:12 + p = rand(rng, Dirichlet(ones(n))) + dist = Multinomial(n, p) + ef = convert(ExponentialFamilyDistribution, dist) + @test mean(dist) ≈ mean(ef) atol = 1e-8 + @test cov(dist) ≈ cov(ef) atol = 1e-8 + end + end +end +end + +================ +File: distributions/bernoulli_tests.jl +================ +# Bernoulli comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Bernoulli: vague" begin + include("distributions_setuptests.jl") + + d = vague(Bernoulli) + + @test typeof(d) <: Bernoulli + @test mean(d) === 0.5 + @test succprob(d) === 0.5 + @test failprob(d) === 0.5 +end + +@testitem "Bernoulli: probvec" begin + include("distributions_setuptests.jl") + + @test probvec(Bernoulli(0.5)) === (0.5, 0.5) + @test probvec(Bernoulli(0.3)) === (0.7, 0.3) + @test probvec(Bernoulli(0.6)) === (0.4, 0.6) +end + +@testitem "Bernoulli: logscale Bernoulli-Bernoulli/Categorical" begin + include("distributions_setuptests.jl") + + @test BayesBase.compute_logscale(Bernoulli(0.5), Bernoulli(0.5), Bernoulli(0.5)) ≈ log(0.5) + @test BayesBase.compute_logscale(Bernoulli(1), Bernoulli(0.5), Bernoulli(1)) ≈ log(0.5) + @test BayesBase.compute_logscale(Categorical([0.5, 0.5]), Bernoulli(0.5), Categorical([0.5, 0.5])) ≈ log(0.5) + @test BayesBase.compute_logscale(Categorical([0.5, 0.5]), Categorical([0.5, 0.5]), Bernoulli(0.5)) ≈ log(0.5) + @test BayesBase.compute_logscale(Categorical([1.0, 0.0]), Bernoulli(0.5), Categorical([1])) ≈ log(0.5) + @test BayesBase.compute_logscale(Categorical([1.0, 0.0, 0.0]), Bernoulli(0.5), Categorical([1.0, 0, 0])) ≈ log(0.5) +end + +@testitem "Bernoulli: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for p in 0.1:0.1:0.9 + @testset let d = Bernoulli(p) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + η₁ = logit(p) + + for x in (0, 1) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === oneunit(x) + @test @inferred(sufficientstatistics(ef, x)) === (x,) + @test @inferred(logpartition(ef)) ≈ log(1 + exp(η₁)) + end + + @test !@inferred(insupport(ef, -0.5)) + @test !@inferred(insupport(ef, 0.5)) + + # Not in the support + @test_throws Exception logpdf(ef, 0.5) + @test_throws Exception logpdf(ef, -0.5) + end + end + + # Test failing isproper cases + @test !isproper(MeanParametersSpace(), Bernoulli, [-1]) + @test !isproper(MeanParametersSpace(), Bernoulli, [0.5, 0.5]) + @test !isproper(NaturalParametersSpace(), Bernoulli, [0.5, 0.5]) + @test !isproper(NaturalParametersSpace(), Bernoulli, [Inf]) + + @test_throws Exception convert(ExponentialFamilyDistribution, Bernoulli(1.0)) # We cannot convert from `1.0`, `logit` function returns `Inf` +end + +@testitem "Bernoulli: prod with Distribution" begin + include("distributions_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) + @test @inferred(prod(strategy, Bernoulli(0.5), Bernoulli(0.5))) ≈ Bernoulli(0.5) + @test @inferred(prod(strategy, Bernoulli(0.1), Bernoulli(0.6))) ≈ Bernoulli(0.14285714285714285) + @test @inferred(prod(strategy, Bernoulli(0.78), Bernoulli(0.05))) ≈ Bernoulli(0.1572580645161291) + end + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) + # Test symmetric case + @test @inferred(prod(strategy, Bernoulli(0.5), Categorical([0.5, 0.5]))) ≈ Categorical([0.5, 0.5]) + @test @inferred(prod(strategy, Categorical([0.5, 0.5]), Bernoulli(0.5))) ≈ Categorical([0.5, 0.5]) + end + + @test @allocated(prod(ClosedProd(), Bernoulli(0.5), Bernoulli(0.5))) === 0 + @test @allocated(prod(PreserveTypeProd(Distribution), Bernoulli(0.5), Bernoulli(0.5))) === 0 + @test @allocated(prod(GenericProd(), Bernoulli(0.5), Bernoulli(0.5))) === 0 +end + +@testitem "Bernoulli: prod with Categorical" begin + include("distributions_setuptests.jl") + + @test prod(ClosedProd(), Bernoulli(0.5), Categorical([0.5, 0.5])) ≈ + Categorical([0.5, 0.5]) + @test prod(ClosedProd(), Bernoulli(0.1), Categorical(0.4, 0.6)) ≈ + Categorical([1 - 0.14285714285714285, 0.14285714285714285]) + @test prod(ClosedProd(), Bernoulli(0.78), Categorical([0.95, 0.05])) ≈ + Categorical([1 - 0.1572580645161291, 0.1572580645161291]) + @test prod(ClosedProd(), Bernoulli(0.5), Categorical([0.3, 0.3, 0.4])) ≈ + Categorical([0.5, 0.5, 0]) + @test prod(ClosedProd(), Bernoulli(0.5), Categorical([1.0])) ≈ + Categorical([1.0, 0]) +end + +@testitem "Bernoulli: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for pleft in 0.1:0.1:0.9, pright in 0.1:0.1:0.9 + @testset let (left, right) = (Bernoulli(pleft), Bernoulli(pright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{Bernoulli}) + ) + ) + end + end +end + +================ +File: distributions/beta_tests.jl +================ +# Beta comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Beta: vague" begin + include("distributions_setuptests.jl") + + d = vague(Beta) + + @test typeof(d) <: Beta + @test mean(d) === 0.5 + @test params(d) === (1.0, 1.0) +end + +@testitem "Beta: mean(::typeof(log))" begin + include("distributions_setuptests.jl") + + @test mean(log, Beta(1.0, 3.0)) ≈ -1.8333333333333335 + @test mean(log, Beta(0.1, 0.3)) ≈ -7.862370395825961 + @test mean(log, Beta(4.5, 0.3)) ≈ -0.07197681436958758 +end + +@testitem "Beta: mean(::typeof(mirrorlog))" begin + include("distributions_setuptests.jl") + + @test mean(mirrorlog, Beta(1.0, 3.0)) ≈ -0.33333333333333337 + @test mean(mirrorlog, Beta(0.1, 0.3)) ≈ -0.9411396776150167 + @test mean(mirrorlog, Beta(4.5, 0.3)) ≈ -4.963371962929249 +end + +@testitem "Beta: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for a in 0.1:0.1:0.9, b in 1.1:0.2:2.0 + @testset let d = Beta(a, b) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + + (η₁, η₂) = (a - 1, b - 1) + + for x in 0.1:0.1:0.9 + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === oneunit(x) + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (log(x), log(1 - x))) + @test @inferred(logpartition(ef)) ≈ (logbeta(η₁ + 1, η₂ + 1)) + end + + @test !@inferred(insupport(ef, -0.5)) + @test @inferred(insupport(ef, 0.5)) + + # Not in the support + @test_throws Exception logpdf(ef, -0.5) + end + end + + # Test failing isproper cases + @test !isproper(MeanParametersSpace(), Beta, [-1]) + @test !isproper(MeanParametersSpace(), Beta, [1, -0.1]) + @test !isproper(MeanParametersSpace(), Beta, [-0.1, 1]) + @test !isproper(NaturalParametersSpace(), Beta, [-1.1]) + @test !isproper(NaturalParametersSpace(), Beta, [1, -1.1]) + @test !isproper(NaturalParametersSpace(), Beta, [-1.1, 1]) + + # `a`s must add up to something more than 1, otherwise is not proper + let ef = convert(ExponentialFamilyDistribution, Beta(0.1, 1.0)) + @test !isproper(prod(PreserveTypeProd(ExponentialFamilyDistribution), ef, ef)) + end + + # `b`s must add up to something more than 1, otherwise is not proper + let ef = convert(ExponentialFamilyDistribution, Beta(1.0, 0.1)) + @test !isproper(prod(PreserveTypeProd(ExponentialFamilyDistribution), ef, ef)) + end +end + +@testitem "Beta: prod with Distributions" begin + include("distributions_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) + @test prod(strategy, Beta(3.0, 2.0), Beta(2.0, 1.0)) ≈ Beta(4.0, 2.0) + @test prod(strategy, Beta(7.0, 1.0), Beta(0.1, 4.5)) ≈ Beta(6.1, 4.5) + @test prod(strategy, Beta(1.0, 3.0), Beta(0.2, 0.4)) ≈ Beta(0.19999999999999996, 2.4) + end + + @test @allocated(prod(ClosedProd(), Beta(3.0, 2.0), Beta(2.0, 1.0))) === 0 + @test @allocated(prod(GenericProd(), Beta(3.0, 2.0), Beta(2.0, 1.0))) === 0 +end + +@testitem "Beta: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for aleft in 0.51:1.0:5.0, aright in 0.51:1.0:5.0, bleft in 0.51:1.0:5.0, bright in 0.51:1.0:5.0 + @testset let (left, right) = (Beta(aleft, bleft), Beta(aright, bright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{Beta}) + ) + ) + end + end +end + +================ +File: distributions/binomial_tests.jl +================ +# Binomial comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Binomial: probvec" begin + include("distributions_setuptests.jl") + + @test all(probvec(Binomial(2, 0.8)) .≈ (0.2, 0.8)) + @test probvec(Binomial(2, 0.2)) == (0.8, 0.2) + @test probvec(Binomial(2, 0.1)) == (0.9, 0.1) + @test probvec(Binomial(2)) == (0.5, 0.5) +end + +@testitem "Binomial: vague" begin + include("distributions_setuptests.jl") + + @test_throws MethodError vague(Binomial) + @test_throws MethodError vague(Binomial, 1 / 2) + + vague_dist = vague(Binomial, 5) + @test typeof(vague_dist) <: Binomial + @test probvec(vague_dist) == (0.5, 0.5) +end + +@testitem "Binomial: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for n in (2, 3, 4), p in 0.1:0.2:0.9 + @testset let d = Binomial(n, p) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + + η₁ = log(p / (1 - p)) + + for x in 0:n + @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === binomial(n, x) + @test @inferred(logbasemeasure(ef, x)) === loggamma(n+1) - (loggamma(n - x + 1) + loggamma(x + 1)) + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (x,)) + @test @inferred(logpartition(ef)) ≈ (n * log(1 + exp(η₁))) + end + + @test !@inferred(insupport(ef, -1)) + @test @inferred(insupport(ef, 0)) + + # Not in the support + @test_throws Exception logpdf(ef, -1) + end + end + + # Test failing isproper cases + @test !isproper(MeanParametersSpace(), Binomial, [-1], 1) + @test !isproper(MeanParametersSpace(), Binomial, [0.5], -1) + @test !isproper(MeanParametersSpace(), Binomial, [-0.1, 1], 10) + @test !isproper(NaturalParametersSpace(), Binomial, [-1.1], -1) + @test !isproper(NaturalParametersSpace(), Binomial, [1, -1.1], 10) +end + +@testitem "Binomial: prod ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for nleft in 1:1, pleft in 0.1:0.1:0.1, nright in 1:1, pright in 0.1:0.1:0.1 + @testset let (left, right) = (Binomial(nleft, pleft), Binomial(nright, pright)) + for (efleft, efright) in ((left, right), (convert(ExponentialFamilyDistribution, left), convert(ExponentialFamilyDistribution, right))) + for strategy in (PreserveTypeProd(ExponentialFamilyDistribution),) + prod_dist = prod(strategy, efleft, efright) + + @test prod_dist isa ExponentialFamilyDistribution + + hist_sum(x) = + basemeasure(prod_dist, x) * exp( + dot(ExponentialFamily.flatten_parameters(sufficientstatistics(prod_dist, x)), getnaturalparameters(prod_dist)) - + logpartition(prod_dist) + ) + + support = 0:1:max(nleft, nright) + + @test sum(hist_sum, support) ≈ 1.0 atol = 1e-9 + + for x in support + @test basemeasure(prod_dist, x) ≈ (binomial(nleft, x) * binomial(nright, x)) + @test all(sufficientstatistics(prod_dist, x) .≈ (x,)) + end + end + end + end + end +end + +================ +File: distributions/categorical_tests.jl +================ +# Categorical comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Categorical: vague" begin + include("distributions_setuptests.jl") + + @test_throws MethodError vague(Categorical) + + d1 = vague(Categorical, 2) + + @test typeof(d1) <: Categorical + @test probvec(d1) ≈ [0.5, 0.5] + + d2 = vague(Categorical, 4) + + @test typeof(d2) <: Categorical + @test probvec(d2) ≈ [0.25, 0.25, 0.25, 0.25] +end + +@testitem "Categorical: probvec" begin + include("distributions_setuptests.jl") + + @test probvec(Categorical([0.1, 0.4, 0.5])) == [0.1, 0.4, 0.5] + @test probvec(Categorical([1 / 3, 1 / 3, 1 / 3])) == [1 / 3, 1 / 3, 1 / 3] + @test probvec(Categorical([0.8, 0.1, 0.1])) == [0.8, 0.1, 0.1] +end + +@testitem "Categorical: prod Distribution" begin + include("distributions_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) + @test prod(strategy, Categorical([0.1, 0.4, 0.5]), Categorical([1 / 3, 1 / 3, 1 / 3])) == + Categorical([0.1, 0.4, 0.5]) + @test prod(strategy, Categorical([0.1, 0.4, 0.5]), Categorical([0.8, 0.1, 0.1])) == + Categorical([0.47058823529411764, 0.23529411764705882, 0.2941176470588235]) + @test prod(strategy, Categorical([0.2, 0.6, 0.2]), Categorical([0.8, 0.1, 0.1])) ≈ + Categorical([2 / 3, 1 / 4, 1 / 12]) + end +end + +@testitem "Categorical: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for s in (2, 3, 4, 5) + @testset let d = Categorical(normalize!(rand(s), 1)) + ef = test_exponentialfamily_interface( + d; + test_fisherinformation_properties = false, # The fisher information is not-posdef, to discuss + test_fisherinformation_against_jacobian = false + ) + + run_test_fisherinformation_against_jacobian(d; assume_no_allocations = false, mappings = ( + NaturalParametersSpace() => MeanParametersSpace(), + # MeanParametersSpace() => NaturalParametersSpace(), # here is the problem for discussion, the test is broken + )) + + θ = probvec(d) + η = map(p -> log(p / θ[end]), θ) + + for x in 1:s + v = zeros(s) + v[x] = 1 + + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === oneunit(x) + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (v,)) + @test @inferred(logpartition(ef)) ≈ logsumexp(η) + end + + @test !@inferred(insupport(ef, s + 1)) + @test @inferred(insupport(ef, s)) + + # # Not in the support + @test_throws Exception logpdf(ef, ones(s)) + end + end + + # Test failing isproper cases + @test !isproper(MeanParametersSpace(), Categorical, [-1], 2) # conditioner does not match the length + @test !isproper(MeanParametersSpace(), Categorical, [-1], 1) + @test !isproper(MeanParametersSpace(), Categorical, [1, 0.5], 2) + @test !isproper(MeanParametersSpace(), Categorical, [-0.5, 1.5], 2) + @test !isproper(NaturalParametersSpace(), Categorical, [-1.1], 2) # conditioner does not match the length + @test !isproper(NaturalParametersSpace(), Categorical, [-1.1], 1) + @test !isproper(NaturalParametersSpace(), Categorical, [1], 1) # length should be >=2 +end + +@testitem "Categorical ExponentialFamilyDistribution supports RecursiveArrayTools" begin + using RecursiveArrayTools + include("distributions_setuptests.jl") + ef = ExponentialFamilyDistribution(Categorical, ArrayPartition([0, 1, 0]), 3, nothing) + part_ef = ExponentialFamilyDistribution(Categorical, ArrayPartition([0, 1], [0]), 3, nothing) + @test convert(Distribution, ef) ≈ convert(Distribution, part_ef) +end + +@testitem "Categorical: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for s in (2, 3, 4, 5) + @testset let (left, right) = (Categorical(normalize!(rand(s), 1)), Categorical(normalize!(rand(s), 1))) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{Categorical}) + ) + ) + end + end +end + +================ +File: distributions/chi_squared_tests.jl +================ +# Chisq comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Chisq: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for i in 3:0.5:7 + @testset let d = Chisq(2 * (i + 1)) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + η₁ = first(getnaturalparameters(ef)) + + for x in (0.1, 0.5, 1.0) + @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === exp(-x / 2) + @test @inferred(sufficientstatistics(ef, x)) === (log(x),) + @test @inferred(logpartition(ef)) ≈ loggamma(η₁ + 1) + (η₁ + 1) * log(2.0) + end + end + end + + for space in (MeanParametersSpace(), NaturalParametersSpace()) + @test !isproper(space, Chisq, [Inf]) + @test !isproper(space, Chisq, [-1.0]) + @test !isproper(space, Chisq, [NaN]) + @test !isproper(space, Chisq, [1.0], NaN) + @test !isproper(space, Chisq, [0.5, 0.5], 1.0) + end + + ## mean parameter should be integer in the MeanParametersSpace + @test isproper(MeanParametersSpace(), Chisq, [0.1]) + @test isproper(NaturalParametersSpace(), Chisq, [-0.5]) + @test !isproper(NaturalParametersSpace(), Chisq, [-1.5]) + @test convert(ExponentialFamilyDistribution, Chisq(0.5)) ≈ ExponentialFamilyDistribution(Chisq, [-0.75]) + @test_throws Exception convert(ExponentialFamilyDistribution, Chisq(Inf)) +end + +@testitem "Chisq: prod with Distribution and ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + @testset for i in 3:6 + left = Chisq(i + 1) + right = Chisq(i) + prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), left, right) + efleft = convert(ExponentialFamilyDistribution, left) + efright = convert(ExponentialFamilyDistribution, right) + prod_ef = prod(PreserveTypeProd(ExponentialFamilyDistribution), efleft, efright) + η_left = getnaturalparameters(efleft) + η_right = getnaturalparameters(efright) + naturalparameters = η_left + η_right + + @test prod_dist.naturalparameters == naturalparameters + @test getbasemeasure(prod_dist)(i) ≈ exp(-i) + @test sufficientstatistics(prod_dist, i) === (log(i),) + @test getlogpartition(prod_dist)(η_left + η_right) ≈ loggamma(η_left[1] + η_right[1] + 1) + @test getsupport(prod_dist) === support(left) + + @test prod_ef.naturalparameters == naturalparameters + @test getbasemeasure(prod_ef)(i) ≈ exp(-i) + @test sufficientstatistics(prod_ef, i) === (log(i),) + @test getlogpartition(prod_ef)(η_left + η_right) ≈ loggamma(η_left[1] + η_right[1] + 1) + @test getsupport(prod_ef) === support(left) + end +end + +================ +File: distributions/dirichlet_tests.jl +================ +# Dirichlet comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Dirichlet: vague" begin + include("distributions_setuptests.jl") + + @test_throws MethodError vague(Dirichlet) + + d1 = vague(Dirichlet, 2) + + @test typeof(d1) <: Dirichlet + @test probvec(d1) == ones(2) + + d2 = vague(Dirichlet, 4) + + @test typeof(d2) <: Dirichlet + @test probvec(d2) == ones(4) +end + +@testitem "Dirichlet: mean(::typeof(log))" begin + include("distributions_setuptests.jl") + + import Base.Broadcast: BroadcastFunction + + @test mean(BroadcastFunction(log), Dirichlet([1.0, 1.0, 1.0])) ≈ [-1.5000000000000002, -1.5000000000000002, -1.5000000000000002] + @test mean(BroadcastFunction(log), Dirichlet([1.1, 2.0, 2.0])) ≈ [-1.9517644694670657, -1.1052251939575213, -1.1052251939575213] + @test mean(BroadcastFunction(log), Dirichlet([3.0, 1.2, 5.0])) ≈ [-1.2410879175727905, -2.4529121492634465, -0.657754584239457] +end + +@testitem "Dirichlet: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + rng = StableRNG(42) + for len in 3:5 + α = rand(rng, len) + @testset let d = Dirichlet(α) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false) + η1 = getnaturalparameters(ef) + + for x in [rand(rng, len) for _ in 1:3] + x = x ./ sum(x) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === 1.0 + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (map(log, x),)) + firstterm = mapreduce(x -> loggamma(x + 1), +, η1) + secondterm = loggamma(sum(η1) + length(η1)) + @test @inferred(logpartition(ef)) ≈ firstterm - secondterm + end + end + end + + for space in (MeanParametersSpace(), NaturalParametersSpace()) + @test !isproper(space, Dirichlet, [Inf, Inf], 1.0) + @test !isproper(space, Dirichlet, [1.0], Inf) + @test !isproper(space, Dirichlet, [NaN], 1.0) + @test !isproper(space, Dirichlet, [1.0], NaN) + @test !isproper(space, Dirichlet, [0.5, 0.5], 1.0) + @test isproper(space, Dirichlet, [2.0, 3.0]) + @test !isproper(space, Dirichlet, [-1.0, -1.2]) + end + + @test_throws Exception convert(ExponentialFamilyDistribution, Dirichlet([Inf, Inf])) +end + +@testitem "Dirichlet: prod with Distribution" begin + include("distributions_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) + @test @inferred(prod(strategy, Dirichlet([1.0, 1.0, 1.0]), Dirichlet([1.0, 1.0, 1.0]))) ≈ Dirichlet([1.0, 1.0, 1.0]) + @test @inferred(prod(strategy, Dirichlet([1.1, 1.0, 2.0]), Dirichlet([1.0, 1.2, 1.0]))) ≈ Dirichlet([1.1, 1.2, 2.0]) + @test @inferred(prod(strategy, Dirichlet([1.1, 2.0, 2.0]), Dirichlet([3.0, 1.2, 5.0]))) ≈ Dirichlet([3.1, 2.2, 6.0]) + end +end + +@testitem "Dirichlet: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + rng = StableRNG(123) + for len in 3:6 + αleft = rand(rng, len) .+ 1 + αright = rand(rng, len) .+ 1 + @testset let (left, right) = (Dirichlet(αleft), Dirichlet(αright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd() + ) + ) + end + end +end + +================ +File: distributions/distributions_setuptests.jl +================ +using ExponentialFamily, BayesBase, FastCholesky, Distributions, LinearAlgebra, TinyHugeNumbers +using Test, ForwardDiff, Random, StatsFuns, StableRNGs, FillArrays, JET, SpecialFunctions + +import BayesBase: compute_logscale + +import ExponentialFamily: + ExponentialFamilyDistribution, + getnaturalparameters, + getconditioner, + logpartition, + basemeasure, + logbasemeasure, + insupport, + sufficientstatistics, + fisherinformation, + pack_parameters, + unpack_parameters, + isbasemeasureconstant, + ConstantBaseMeasure, + MeanToNatural, + NaturalToMean, + NaturalParametersSpace, + invscatter, + location, + locationdim + +import Distributions: + variate_form, + value_support + +import SpecialFunctions: + logbeta, + loggamma, + digamma, + logfactorial, + besseli + +import HCubature: + hquadrature + +import DomainSets: + NaturalNumbers + +union_types(x::Union) = (x.a, union_types(x.b)...) +union_types(x::Type) = (x,) + +function Base.isapprox(a::Tuple, b::Tuple; kwargs...) + return length(a) === length(b) && all((d) -> isapprox(d[1], d[2]; kwargs...), zip(a, b)) +end + +JET_function_filter(@nospecialize f) = ((f === FastCholesky.cholinv) || (f === FastCholesky.cholsqrt)) + +macro test_opt(expr) + return esc(quote + JET.@test_opt function_filter=JET_function_filter ignored_modules=(Base,LinearAlgebra) $expr + end) +end + +function test_exponentialfamily_interface(distribution; + test_parameters_conversion = true, + test_similar_creation = true, + test_distribution_conversion = true, + test_packing_unpacking = true, + test_isproper = true, + test_basic_functions = true, + test_gradlogpartition_properties = true, + test_fisherinformation_properties = true, + test_fisherinformation_against_hessian = true, + test_fisherinformation_against_jacobian = true, + test_plogpdf_interface = true, + option_assume_no_allocations = false +) + T = ExponentialFamily.exponential_family_typetag(distribution) + + ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) + + @test_opt convert(ExponentialFamilyDistribution, distribution) + + @test ef isa ExponentialFamilyDistribution{T} + + test_parameters_conversion && run_test_parameters_conversion(distribution) + test_similar_creation && run_test_similar_creation(distribution) + test_distribution_conversion && run_test_distribution_conversion(distribution; assume_no_allocations = option_assume_no_allocations) + test_packing_unpacking && run_test_packing_unpacking(distribution) + test_isproper && run_test_isproper(distribution; assume_no_allocations = option_assume_no_allocations) + test_basic_functions && run_test_basic_functions(distribution; assume_no_allocations = option_assume_no_allocations) + test_gradlogpartition_properties && run_test_gradlogpartition_properties(distribution) + test_fisherinformation_properties && run_test_fisherinformation_properties(distribution) + test_fisherinformation_against_hessian && run_test_fisherinformation_against_hessian(distribution; assume_no_allocations = option_assume_no_allocations) + test_fisherinformation_against_jacobian && run_test_fisherinformation_against_jacobian(distribution; assume_no_allocations = option_assume_no_allocations) + test_plogpdf_interface && run_test_plogpdf_interface(distribution) + return ef +end + +function run_test_plogpdf_interface(distribution) + ef = convert(ExponentialFamily.ExponentialFamilyDistribution, distribution) + η = getnaturalparameters(ef) + samples = rand(StableRNG(42), distribution, 10) + _, _samples = ExponentialFamily.check_logpdf(variate_form(typeof(ef)), typeof(samples), eltype(samples), ef, samples) + ss_vectors = map(s -> ExponentialFamily.pack_parameters(ExponentialFamily.sufficientstatistics(ef, s)), _samples) + unnormalized_logpdfs = map(v -> dot(v, η), ss_vectors) + @test all(unnormalized_logpdfs ≈ map(x -> ExponentialFamily._plogpdf(ef, x, 0, 0), _samples)) +end + +function run_test_parameters_conversion(distribution) + T = ExponentialFamily.exponential_family_typetag(distribution) + + tuple_of_θ, conditioner = ExponentialFamily.separate_conditioner(T, params(MeanParametersSpace(), distribution)) + + @test all(ExponentialFamily.join_conditioner(T, tuple_of_θ, conditioner) .== params(MeanParametersSpace(), distribution)) + + @test_opt ExponentialFamily.separate_conditioner(T, params(MeanParametersSpace(), distribution)) + @test_opt ExponentialFamily.join_conditioner(T, tuple_of_θ, conditioner) + @test_opt params(MeanParametersSpace(), distribution) + + ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) + + @test conditioner === getconditioner(ef) + + # Check the `conditioned` conversions, should work for un-conditioned members as well + tuple_of_η = MeanToNatural(T)(tuple_of_θ, conditioner) + + @test all(NaturalToMean(T)(tuple_of_η, conditioner) .≈ tuple_of_θ) + @test all(MeanToNatural(T)(tuple_of_θ, conditioner) .≈ tuple_of_η) + @test all(NaturalToMean(T)(pack_parameters(NaturalParametersSpace(), T, tuple_of_η), conditioner) .≈ pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) + @test all(MeanToNatural(T)(pack_parameters(MeanParametersSpace(), T, tuple_of_θ), conditioner) .≈ pack_parameters(NaturalParametersSpace(), T, tuple_of_η)) + + @test_opt NaturalToMean(T)(tuple_of_η, conditioner) + @test_opt MeanToNatural(T)(tuple_of_θ, conditioner) + @test_opt NaturalToMean(T)(pack_parameters(NaturalParametersSpace(), T, tuple_of_η), conditioner) + @test_opt MeanToNatural(T)(pack_parameters(MeanParametersSpace(), T, tuple_of_θ), conditioner) + + @test all(map(NaturalParametersSpace() => MeanParametersSpace(), T, tuple_of_η, conditioner) .≈ tuple_of_θ) + @test all(map(MeanParametersSpace() => NaturalParametersSpace(), T, tuple_of_θ, conditioner) .≈ tuple_of_η) + @test all( + map(NaturalParametersSpace() => MeanParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, tuple_of_η), conditioner) .≈ + pack_parameters(MeanParametersSpace(), T, tuple_of_θ) + ) + @test all( + map(MeanParametersSpace() => NaturalParametersSpace(), T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ), conditioner) .≈ + pack_parameters(NaturalParametersSpace(), T, tuple_of_η) + ) + + # Double check the `conditioner` free conversions + if isnothing(conditioner) + local _tuple_of_η = MeanToNatural(T)(tuple_of_θ) + + @test all(_tuple_of_η .== tuple_of_η) + @test all(NaturalToMean(T)(_tuple_of_η) .≈ tuple_of_θ) + @test all(NaturalToMean(T)(_tuple_of_η) .≈ tuple_of_θ) + @test all(MeanToNatural(T)(tuple_of_θ) .≈ _tuple_of_η) + @test all(NaturalToMean(T)(pack_parameters(NaturalParametersSpace(), T, _tuple_of_η)) .≈ pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) + @test all(MeanToNatural(T)(pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) .≈ pack_parameters(NaturalParametersSpace(), T, _tuple_of_η)) + + @test all(map(NaturalParametersSpace() => MeanParametersSpace(), T, _tuple_of_η) .≈ tuple_of_θ) + @test all(map(NaturalParametersSpace() => MeanParametersSpace(), T, _tuple_of_η) .≈ tuple_of_θ) + @test all(map(MeanParametersSpace() => NaturalParametersSpace(), T, tuple_of_θ) .≈ _tuple_of_η) + @test all( + map(NaturalParametersSpace() => MeanParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, _tuple_of_η)) .≈ + pack_parameters(MeanParametersSpace(), T, tuple_of_θ) + ) + @test all( + map(MeanParametersSpace() => NaturalParametersSpace(), T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) .≈ + pack_parameters(NaturalParametersSpace(), T, _tuple_of_η) + ) + end + + @test all(unpack_parameters(NaturalParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, tuple_of_η)) .== tuple_of_η) + @test all(unpack_parameters(MeanParametersSpace(), T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) .== tuple_of_θ) + + @test_opt unpack_parameters(NaturalParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, tuple_of_η)) + @test_opt unpack_parameters(MeanParametersSpace(), T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) + + # Extra methods for conditioner free distributions + if isnothing(conditioner) + @test all( + params(MeanParametersSpace(), distribution) .≈ + map(NaturalParametersSpace() => MeanParametersSpace(), T, params(NaturalParametersSpace(), distribution)) + ) + @test all( + params(NaturalParametersSpace(), distribution) .≈ + map(MeanParametersSpace() => NaturalParametersSpace(), T, params(MeanParametersSpace(), distribution)) + ) + end +end + +function run_test_similar_creation(distribution) + T = ExponentialFamily.exponential_family_typetag(distribution) + + ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) + + @test similar(ef) isa ExponentialFamilyDistribution{T} + @test_opt similar(ef) +end + +function run_test_distribution_conversion(distribution; assume_no_allocations = true) + T = ExponentialFamily.exponential_family_typetag(distribution) + + ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) + + @test @inferred(convert(Distribution, ef)) ≈ distribution + @test_opt convert(Distribution, ef) + + if assume_no_allocations + @test @allocated(convert(Distribution, ef)) === 0 + end +end + +function run_test_packing_unpacking(distribution) + T = ExponentialFamily.exponential_family_typetag(distribution) + + tuple_of_θ, conditioner = ExponentialFamily.separate_conditioner(T, params(MeanParametersSpace(), distribution)) + ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) + + tuple_of_η = MeanToNatural(T)(tuple_of_θ, conditioner) + + @test all(unpack_parameters(ef) .≈ tuple_of_η) + @test @allocated(unpack_parameters(ef)) === 0 + + @test_opt ExponentialFamily.separate_conditioner(T, params(MeanParametersSpace(), distribution)) + @test_opt unpack_parameters(ef) +end + +function run_test_isproper(distribution; assume_no_allocations = true) + T = ExponentialFamily.exponential_family_typetag(distribution) + + exponential_family_form = @inferred(convert(ExponentialFamilyDistribution, distribution)) + + @test @inferred(isproper(exponential_family_form)) + @test_opt isproper(exponential_family_form) + + if assume_no_allocations + @test @allocated(isproper(exponential_family_form)) === 0 + end +end + +function run_test_basic_functions(distribution; nsamples = 10, test_gradients = true, test_samples_logpdf = true, assume_no_allocations = true) + T = ExponentialFamily.exponential_family_typetag(distribution) + + ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) + + (η, conditioner) = (getnaturalparameters(ef), getconditioner(ef)) + + # ! do not use `rand(distribution, nsamples)` + # ! do not use fixed RNG + samples = [rand(distribution) for _ in 1:nsamples] + + # Not all methods are defined for all objects in Distributions.jl + # For this methods we first test if the method is defined for the distribution + # And only then we test the method for the exponential family form + potentially_missing_methods = ( + cov, + skewness, + kurtosis + ) + + argument_type = Tuple{typeof(distribution)} + + @test_opt logpdf(ef, first(samples)) + @test_opt pdf(ef, first(samples)) + @test_opt mean(ef) + @test_opt var(ef) + @test_opt std(ef) + # Sampling is not type-stable for all distributions + # due to fallback to `Distributions.jl` + # @test_opt rand(ef) + # @test_opt rand(ef, 10) + # @test_opt rand!(ef, rand(ef, 10)) + + @test_opt isbasemeasureconstant(ef) + @test_opt basemeasure(ef, first(samples)) + @test_opt logbasemeasure(ef, first(samples)) + @test_opt sufficientstatistics(ef, first(samples)) + @test_opt logpartition(ef) + @test_opt gradlogpartition(ef) + @test_opt fisherinformation(ef) + + for x in samples + # We believe in the implementation in the `Distributions.jl` + @test @inferred(logpdf(ef, x)) ≈ logpdf(distribution, x) + @test @inferred(pdf(ef, x)) ≈ pdf(distribution, x) + @test @inferred(mean(ef)) ≈ mean(distribution) + @test @inferred(var(ef)) ≈ var(distribution) + @test @inferred(std(ef)) ≈ std(distribution) + @test last(size(rand(ef, 10))) === 10 # Test that `rand` without explicit `rng` works + @test rand(StableRNG(42), ef) ≈ rand(StableRNG(42), distribution) + @test all(rand(StableRNG(42), ef, 10) .≈ rand(StableRNG(42), distribution, 10)) + @test all(rand!(StableRNG(42), ef, [deepcopy(x) for _ in 1:10]) .≈ rand!(StableRNG(42), distribution, [deepcopy(x) for _ in 1:10])) + + for method in potentially_missing_methods + if hasmethod(method, argument_type) + @test @inferred(method(ef)) ≈ method(distribution) + end + end + + @test @inferred(isbasemeasureconstant(ef)) === isbasemeasureconstant(T) + @test @inferred(basemeasure(ef, x)) == getbasemeasure(T, conditioner)(x) + @test @inferred(logbasemeasure(ef, x)) == getlogbasemeasure(T, conditioner)(x) + @test logbasemeasure(ef, x) ≈ log(basemeasure(ef, x)) atol = 1e-8 + @test all(@inferred(sufficientstatistics(ef, x)) .== map(f -> f(x), getsufficientstatistics(T, conditioner))) + @test @inferred(logpartition(ef)) == getlogpartition(T, conditioner)(η) + @test @inferred(fisherinformation(ef)) == getfisherinformation(T, conditioner)(η) + + # Double check the `conditioner` free methods + if isnothing(conditioner) + @test @inferred(basemeasure(ef, x)) == getbasemeasure(T)(x) + @test @inferred(logbasemeasure(ef, x)) == getlogbasemeasure(T)(x) + @test all(@inferred(sufficientstatistics(ef, x)) .== map(f -> f(x), getsufficientstatistics(T))) + @test @inferred(logpartition(ef)) == getlogpartition(T)(η) + @test @inferred(fisherinformation(ef)) == getfisherinformation(T)(η) + end + + if test_gradients && value_support(T) === Continuous && x isa Number + let tlogpdf = ForwardDiff.derivative((x) -> logpdf(distribution, x), x) + if !isnan(tlogpdf) && !isinf(tlogpdf) + @test ForwardDiff.derivative((x) -> logpdf(ef, x), x) ≈ tlogpdf + @test ForwardDiff.gradient((x) -> logpdf(ef, x[1]), [x])[1] ≈ tlogpdf + end + end + let tpdf = ForwardDiff.derivative((x) -> pdf(distribution, x), x) + if !isnan(tpdf) && !isinf(tpdf) + @test ForwardDiff.derivative((x) -> pdf(ef, x), x) ≈ tpdf + @test ForwardDiff.gradient((x) -> pdf(ef, x[1]), [x])[1] ≈ tpdf + end + end + end + + if test_gradients && value_support(T) === Continuous && x isa AbstractVector + let tlogpdf = ForwardDiff.gradient((x) -> logpdf(distribution, x), x) + if !any(isnan, tlogpdf) && !any(isinf, tlogpdf) + @test ForwardDiff.gradient((x) -> logpdf(ef, x), x) ≈ tlogpdf + end + end + let tpdf = ForwardDiff.gradient((x) -> pdf(distribution, x), x) + if !any(isnan, tpdf) && !any(isinf, tpdf) + @test ForwardDiff.gradient((x) -> pdf(ef, x), x) ≈ tpdf + end + end + end + + # Test that the selected methods do not allocate + if assume_no_allocations + @test @allocated(logpdf(ef, x)) === 0 + @test @allocated(pdf(ef, x)) === 0 + @test @allocated(mean(ef)) === 0 + @test @allocated(var(ef)) === 0 + @test @allocated(basemeasure(ef, x)) === 0 + @test @allocated(logbasemeasure(ef, x)) === 0 + @test @allocated(sufficientstatistics(ef, x)) === 0 + end + end + + if test_samples_logpdf + @test @inferred(logpdf(ef, samples)) ≈ map((s) -> logpdf(distribution, s), samples) + @test @inferred(pdf(ef, samples)) ≈ map((s) -> pdf(distribution, s), samples) + end +end + +function run_test_fisherinformation_properties(distribution; test_properties_in_natural_space = true, test_properties_in_mean_space = true) + T = ExponentialFamily.exponential_family_typetag(distribution) + + ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) + + (η, conditioner) = (getnaturalparameters(ef), getconditioner(ef)) + + if test_properties_in_natural_space + F = getfisherinformation(NaturalParametersSpace(), T, conditioner)(η) + + @test_opt getfisherinformation(NaturalParametersSpace(), T, conditioner)(η) + + @test issymmetric(F) || (LowerTriangular(F) ≈ (UpperTriangular(F)')) + @test isposdef(F) || all(>(0), eigvals(F)) + @test size(F, 1) === size(F, 2) + @test size(F, 1) === isqrt(length(F)) + @test (inv(fastcholesky(F)) * F ≈ Diagonal(ones(size(F, 1)))) rtol = 1e-2 + end + + if test_properties_in_mean_space + θ = map(NaturalParametersSpace() => MeanParametersSpace(), T, η, conditioner) + F = getfisherinformation(MeanParametersSpace(), T, conditioner)(θ) + + @test_opt getfisherinformation(MeanParametersSpace(), T, conditioner)(θ) + + @test issymmetric(F) || (LowerTriangular(F) ≈ (UpperTriangular(F)')) + @test isposdef(F) || all(>(0), eigvals(F)) + @test size(F, 1) === size(F, 2) + @test size(F, 1) === isqrt(length(F)) + @test (inv(fastcholesky(F)) * F ≈ Diagonal(ones(size(F, 1)))) rtol = 1e-2 + end +end + +function run_test_gradlogpartition_properties(distribution; nsamples = 6000, test_against_forwardiff = true) + ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) + + (η, conditioner) = (getnaturalparameters(ef), getconditioner(ef)) + + rng = StableRNG(42) + # Some distributions do not use a vector to store a collection of samples (e.g. matrix for MvGaussian) + collection_of_samples = rand(rng, distribution, nsamples) + # The `check_logpdf` here converts the collection to a vector like iterable + _, samples = ExponentialFamily.check_logpdf(ef, collection_of_samples) + expectation_of_sufficient_statistics = mean((s) -> ExponentialFamily.pack_parameters(ExponentialFamily.sufficientstatistics(ef, s)), samples) + gradient = gradlogpartition(ef) + inverse_fisher = cholinv(fisherinformation(ef)) + @test length(gradient) === length(η) + @test dot(gradient - expectation_of_sufficient_statistics, inverse_fisher, gradient - expectation_of_sufficient_statistics) ≈ 0 atol = 0.01 + + if test_against_forwardiff + @test gradient ≈ ForwardDiff.gradient((η) -> getlogpartition(ef)(η), getnaturalparameters(ef)) + end +end + +function run_test_fisherinformation_against_hessian(distribution; assume_ours_faster = true, assume_no_allocations = true) + T = ExponentialFamily.exponential_family_typetag(distribution) + + ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) + + (η, conditioner) = (getnaturalparameters(ef), getconditioner(ef)) + + @test fisherinformation(ef) ≈ ForwardDiff.hessian(η -> getlogpartition(NaturalParametersSpace(), T, conditioner)(η), η) + + # Double check the `conditioner` free methods + if isnothing(conditioner) + @test fisherinformation(ef) ≈ ForwardDiff.hessian(η -> getlogpartition(NaturalParametersSpace(), T)(η), η) + end + + if assume_ours_faster + @test @elapsed(fisherinformation(ef)) < (@elapsed(ForwardDiff.hessian(η -> getlogpartition(NaturalParametersSpace(), T, conditioner)(η), η))) + end + + if assume_no_allocations + @test @allocated(fisherinformation(ef)) === 0 + end +end + +function run_test_fisherinformation_against_jacobian( + distribution; + assume_no_allocations = true, + mappings = ( + NaturalParametersSpace() => MeanParametersSpace(), + MeanParametersSpace() => NaturalParametersSpace() + ) +) + T = ExponentialFamily.exponential_family_typetag(distribution) + + ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) + + (η, conditioner) = (getnaturalparameters(ef), getconditioner(ef)) + θ = map(NaturalParametersSpace() => MeanParametersSpace(), T, η, conditioner) + + # Check natural to mean Jacobian based FI computation + # So here we check that the fisher information matrices are identical with respect to `J`, which is the jacobian of the + # transformation. For example if we have a mapping T : M -> N, the fisher information matrices computed in M and N + # respectively must follow this relation `Fₘ = J' * Fₙ * J` + for (M, N, parameters) in ((NaturalParametersSpace(), MeanParametersSpace(), η), (MeanParametersSpace(), NaturalParametersSpace(), θ)) + if (M => N) ∈ mappings + mapping = getmapping(M => N, T) + m = parameters + n = mapping(m, conditioner) + J = ForwardDiff.jacobian(Base.Fix2(mapping, conditioner), m) + Fₘ = getfisherinformation(M, T, conditioner)(m) + Fₙ = getfisherinformation(N, T, conditioner)(n) + + @test Fₘ ≈ (J' * Fₙ * J) + + # Check the default space + if M === NaturalParametersSpace() + # The `fisherinformation` uses the `NaturalParametersSpace` by default + @test fisherinformation(ef) ≈ (J' * Fₙ * J) + end + + # Double check the `conditioner` free methods + if isnothing(conditioner) + n = mapping(m) + J = ForwardDiff.jacobian(mapping, m) + Fₘ = getfisherinformation(M, T)(m) + Fₙ = getfisherinformation(N, T)(n) + + @test Fₘ ≈ (J' * Fₙ * J) + + if M === NaturalParametersSpace() + @test fisherinformation(ef) ≈ (J' * Fₙ * J) + end + end + + if assume_no_allocations + @test @allocated(getfisherinformation(M, T, conditioner)(m)) === 0 + @test @allocated(getfisherinformation(N, T, conditioner)(n)) === 0 + end + end + end +end + +# This generic testing works only for the same distributions `D` +function test_generic_simple_exponentialfamily_product( + left::Distribution, + right::Distribution; + strategies = (GenericProd(),), + test_inplace_version = true, + test_inplace_assume_no_allocations = true, + test_preserve_type_prod_of_distribution = true, + test_against_distributions_prod_if_possible = true +) + Tl = ExponentialFamily.exponential_family_typetag(left) + Tr = ExponentialFamily.exponential_family_typetag(right) + + @test Tl === Tr + + T = Tl + + efleft = @inferred(convert(ExponentialFamilyDistribution, left)) + efright = @inferred(convert(ExponentialFamilyDistribution, right)) + ηleft = @inferred(getnaturalparameters(efleft)) + ηright = @inferred(getnaturalparameters(efright)) + + if (!isnothing(getconditioner(efleft)) || !isnothing(getconditioner(efright))) + @test isapprox(getconditioner(efleft), getconditioner(efright)) + end + + prod_dist = prod(GenericProd(), left, right) + + @test_opt prod(GenericProd(), left, right) + + # We check against the `prod_dist` only if we have the proper solution, and skip if the result is of type `ProductOf` + if test_against_distributions_prod_if_possible && (prod_dist isa ProductOf || !(typeof(prod_dist) <: T)) + prod_dist = nothing + end + + for strategy in strategies + @test @inferred(prod(strategy, efleft, efright)) == ExponentialFamilyDistribution(T, ηleft + ηright, getconditioner(efleft)) + + # Double check the `conditioner` free methods + if isnothing(getconditioner(efleft)) && isnothing(getconditioner(efright)) + @test @inferred(prod(strategy, efleft, efright)) == ExponentialFamilyDistribution(T, ηleft + ηright) + end + + # Check that the result is consistent with the `prod_dist` + if !isnothing(prod_dist) + @test convert(T, prod(strategy, efleft, efright)) ≈ prod_dist + end + end + + if test_inplace_version + @test @inferred(prod!(similar(efleft), efleft, efright)) == + ExponentialFamilyDistribution(T, ηleft + ηright, getconditioner(efleft)) + + if test_inplace_assume_no_allocations + let _similar = similar(efleft) + @test @allocated(prod!(_similar, efleft, efright)) === 0 + end + end + end + + if test_preserve_type_prod_of_distribution + @test @inferred(prod(PreserveTypeProd(T), efleft, efright)) ≈ + prod(PreserveTypeProd(T), left, right) + end + + return true +end + +================ +File: distributions/erlang_tests.jl +================ +# Erlang comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Erlang: vague" begin + include("distributions_setuptests.jl") + + @test Erlang() == Erlang(1, 1.0) + @test vague(Erlang) == Erlang(1, 1e12) +end + +@testitem "Erlang: mean(::typeof(log))" begin + include("distributions_setuptests.jl") + + @test mean(log, Erlang(1, 3.0)) ≈ digamma(1) + log(3.0) + @test mean(log, Erlang(2, 0.3)) ≈ digamma(2) + log(0.3) + @test mean(log, Erlang(3, 0.3)) ≈ digamma(3) + log(0.3) +end + +@testitem "Erlang: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for a in 1:3, b in 1.0:1.0:3.0 + @testset let d = Erlang(a, b) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + + (η1, η2) = (a - 1, -inv(b)) + for x in 10rand(4) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === oneunit(x) + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (log(x), x)) + @test @inferred(logpartition(ef)) ≈ (logfactorial(η1) - (η1 + one(η1)) * log(-η2)) + end + + @test !@inferred(insupport(ef, -0.5)) + @test @inferred(insupport(ef, 0.5)) + + # Not in the support + @test_throws Exception logpdf(ef, -0.5) + end + end + + # Test failing isproper cases + @test !isproper(MeanParametersSpace(), Erlang, [-1]) + @test !isproper(MeanParametersSpace(), Erlang, [1, -0.1]) + @test !isproper(MeanParametersSpace(), Erlang, [-0.1, 1]) + @test !isproper(NaturalParametersSpace(), Erlang, [-1.1]) + @test isproper(NaturalParametersSpace(), Erlang, [1, -1.1]) + @test !isproper(NaturalParametersSpace(), Erlang, [-1.1, 1]) +end + +@testitem "Erlang: prod with Distributions" begin + include("distributions_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) + @test prod(strategy, Erlang(1, 1), Erlang(1, 1)) == Erlang(1, 1 / 2) + @test prod(strategy, Erlang(1, 2), Erlang(1, 1)) == Erlang(1, 2 / 3) + @test prod(strategy, Erlang(1, 2), Erlang(1, 2)) == Erlang(1, 1) + @test prod(strategy, Erlang(2, 2), Erlang(1, 2)) == Erlang(2, 1) + @test prod(strategy, Erlang(2, 2), Erlang(2, 2)) == Erlang(3, 1) + end + + @test @allocated(prod(ClosedProd(), Erlang(1, 1), Erlang(1, 1))) == 0 +end + +@testitem "Erlang: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for aleft in 1:3, aright in 2:5, bleft in 0.51:1.0:5.0, bright in 0.51:1.0:5.0 + @testset let (left, right) = (Erlang(aleft, bleft), Erlang(aright, bright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{Erlang}) + ) + ) + end + end +end + +================ +File: distributions/exponential_tests.jl +================ +# Exponential comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Exponential: vague" begin + include("distributions_setuptests.jl") + + d = vague(Exponential) + + @test typeof(d) <: Exponential + @test mean(d) === 1e12 + @test params(d) === (1e12,) +end + +@testitem "Exponential: mean(::typeof(log))" begin + include("distributions_setuptests.jl") + + @test mean(log, Exponential(1)) ≈ -MathConstants.eulergamma + @test mean(log, Exponential(10)) ≈ 1.7253694280925127 + @test mean(log, Exponential(0.1)) ≈ -2.8798007578955787 +end + +@testitem "Exponential: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for scale in (0.1, 1.0, 10.0, 10.0rand()) + @testset let d = Exponential(scale) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + + (η₁,) = -inv(scale) + + for x in [100rand() for _ in 1:4] + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === oneunit(x) + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (x,)) + @test @inferred(logpartition(ef)) ≈ (-log(-η₁)) + end + + @test !@inferred(insupport(ef, -0.5)) + @test @inferred(insupport(ef, 0.5)) + + # Not in the support + @test_throws Exception logpdf(ef, -0.5) + end + end + + # Test failing isproper cases + @test !isproper(MeanParametersSpace(), Exponential, [-1]) + @test !isproper(MeanParametersSpace(), Exponential, [1, -0.1]) + @test !isproper(MeanParametersSpace(), Exponential, [-0.1, 1]) + @test !isproper(NaturalParametersSpace(), Exponential, [1.1]) + @test !isproper(NaturalParametersSpace(), Exponential, [1, -1.1]) + @test !isproper(NaturalParametersSpace(), Exponential, [-1.1, 1]) +end + +@testitem "Exponential: prod with Distributiond" begin + include("distributions_setuptests.jl") + + for strategy in (GenericProd(), ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd()) + @test prod(strategy, Exponential(5), Exponential(4)) ≈ Exponential(1 / 0.45) + @test prod(strategy, Exponential(1), Exponential(1)) ≈ Exponential(1 / 2) + @test prod(strategy, Exponential(0.1), Exponential(0.1)) ≈ Exponential(0.05) + end +end + +@testitem "Exponential: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for sleft in (0.1, 1.0, 10.0, 10.0rand()), sright in (0.1, 1.0, 10.0, 10.0rand()) + @testset let (left, right) = (Exponential(sleft), Exponential(sright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{Exponential}) + ) + ) + end + end +end + +================ +File: distributions/gamma_inverse_tests.jl +================ +# GammaInverse comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "GammaInverse: vague" begin + include("distributions_setuptests.jl") + + d = vague(GammaInverse) + @test typeof(d) <: GammaInverse + @test mean(d) == huge + @test params(d) == (2.0, huge) +end + +# (α, θ) = (α_L + α_R + 1, θ_L + θ_R) +@testitem "GammaInverse: prod" begin + include("distributions_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(GammaInverse), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) + @test @inferred(prod(ClosedProd(), GammaInverse(3.0, 2.0), GammaInverse(2.0, 1.0))) ≈ GammaInverse(6.0, 3.0) + @test @inferred(prod(ClosedProd(), GammaInverse(7.0, 1.0), GammaInverse(0.1, 4.5))) ≈ GammaInverse(8.1, 5.5) + @test @inferred(prod(ClosedProd(), GammaInverse(1.0, 3.0), GammaInverse(0.2, 0.4))) ≈ GammaInverse(2.2, 3.4) + end +end + +# log(θ) - digamma(α) +@testitem "GammaInverse: mean(::typeof(log))" begin + include("distributions_setuptests.jl") + + @test mean(log, GammaInverse(1.0, 3.0)) ≈ 1.6758279535696414 + @test mean(log, GammaInverse(0.1, 0.3)) ≈ 9.21978213608514 + @test mean(log, GammaInverse(4.5, 0.3)) ≈ -2.5928437306854653 + @test mean(log, GammaInverse(42.0, 42.0)) ≈ 0.011952000346086233 +end + +# α / θ +@testitem "GammaInverse: mean(::typeof(inv))" begin + include("distributions_setuptests.jl") + + @test mean(inv, GammaInverse(1.0, 3.0)) ≈ 0.33333333333333333 + @test mean(inv, GammaInverse(0.1, 0.3)) ≈ 0.33333333333333337 + @test mean(inv, GammaInverse(4.5, 0.3)) ≈ 15.0000000000000000 + @test mean(inv, GammaInverse(42.0, 42.0)) ≈ 1.0000000000000000 +end + +@testitem "GammaInverse: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for α in (10rand(4) .+ 4.0), θ in 10rand(4) + @testset let d = InverseGamma(α, θ) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + + (α, β) = params(MeanParametersSpace(), d) + + (η₁, η₂) = (-α - 1, -β) + + for x in 10rand(4) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === oneunit(x) + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (log(x), inv(x))) + @test @inferred(logpartition(ef)) ≈ (loggamma(-η₁ - 1) - (-η₁ - 1) * log(-η₂)) + end + + @test !@inferred(insupport(ef, -0.5)) + @test @inferred(insupport(ef, 0.5)) + + # Not in the support + @test_throws Exception logpdf(ef, -0.5) + end + end + + # Test failing isproper cases + @test !isproper(MeanParametersSpace(), InverseGamma, [-1]) + @test !isproper(MeanParametersSpace(), InverseGamma, [1, -0.1]) + @test !isproper(MeanParametersSpace(), InverseGamma, [-0.1, 1]) + @test !isproper(NaturalParametersSpace(), InverseGamma, [-0.5]) + @test !isproper(NaturalParametersSpace(), InverseGamma, [1, -1.1]) + @test !isproper(NaturalParametersSpace(), InverseGamma, [-0.5, 1]) +end + +@testitem "GammaInverse: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for aleft in 10rand(4), aright in 10rand(4), bleft in 10rand(4), bright in 10rand(4) + @testset let (left, right) = (InverseGamma(aleft, bleft), InverseGamma(aright, bright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{InverseGamma}) + ) + ) + end + end +end + +================ +File: distributions/geometric_tests.jl +================ +# Geometric comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Geometric: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + @testset for p in 0.1:0.2:1.0 + @testset let d = Geometric(p) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + η1 = first(getnaturalparameters(ef)) + + for x in (1, 3, 5) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === one(x) + @test @inferred(sufficientstatistics(ef, x)) === (x,) + @test @inferred(logpartition(ef)) ≈ -log(one(η1) - exp(η1)) + end + end + end + + for space in (MeanParametersSpace(), NaturalParametersSpace()) + @test !isproper(space, Geometric, [2.0]) + @test !isproper(space, Geometric, [Inf]) + @test !isproper(space, Geometric, [NaN]) + @test !isproper(space, Geometric, [1.0], NaN) + @test !isproper(space, Geometric, [0.5, 0.5], 1.0) + end + ## mean parameter should be integer in the MeanParametersSpace + @test !isproper(MeanParametersSpace(), Geometric, [-0.1]) + @test_throws Exception convert(ExponentialFamilyDistribution, Geometric(Inf)) +end + +@testitem "Geometric: prod with Distributions" begin + include("distributions_setuptests.jl") + + for strategy in (GenericProd(), ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd()) + @test prod(strategy, Geometric(0.5), Geometric(0.6)) == Geometric(0.8) + @test prod(strategy, Geometric(0.3), Geometric(0.8)) == Geometric(0.8600000000000001) + @test prod(strategy, Geometric(0.5), Geometric(0.5)) == Geometric(0.75) + end + + @test @allocated(prod(ClosedProd(), Geometric(0.5), Geometric(0.6))) == 0 +end + +@testitem "Geometric: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for pleft in 0.1:0.2:0.5, pright in 0.5:0.2:1.0 + @testset let (left, right) = (Geometric(pleft), Geometric(pright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{Geometric}) + ) + ) + end + end +end + +================ +File: distributions/laplace_tests.jl +================ +# Laplace comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Laplace: vague" begin + include("distributions_setuptests.jl") + + d = vague(Laplace) + + @test typeof(d) <: Laplace + @test mean(d) === 0.0 + @test params(d) === (0.0, 1e12) +end + +@testitem "Laplace: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for location in (-1.0, 0.0, 1.0), scale in (0.25, 0.5, 2.0) + @testset let d = Laplace(location, scale) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + η₁ = -1 / scale + + for x in (-1.0, 0.0, 1.0) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === oneunit(x) + @test @inferred(sufficientstatistics(ef, x)) === (abs(x - location),) + @test @inferred(logpartition(ef)) ≈ log(-2 / η₁) + end + end + end + + for space in (MeanParametersSpace(), NaturalParametersSpace()) + @test !isproper(space, Laplace, [Inf], 1.0) + @test !isproper(space, Laplace, [1.0], Inf) + @test !isproper(space, Laplace, [NaN], 1.0) + @test !isproper(space, Laplace, [1.0], NaN) + @test !isproper(space, Laplace, [0.5, 0.5], 1.0) + + # Conditioner is required + @test_throws Exception isproper(space, Laplace, [0.5], [0.5, 0.5]) + @test_throws Exception isproper(space, Laplace, [1.0], nothing) + @test_throws Exception isproper(space, Laplace, [1.0], nothing) + end + + @test_throws Exception convert(ExponentialFamilyDistribution, Laplace(Inf, Inf)) +end + +@testitem "Laplace: prod with Distribution" begin + include("distributions_setuptests.jl") + + @test @inferred(prod(PreserveTypeProd(Laplace), Laplace(0.0, 0.5), Laplace(0.0, 0.5))) ≈ Laplace(0.0, 0.25) + @test @inferred(prod(PreserveTypeProd(Laplace), Laplace(1.0, 1.0), Laplace(1.0, 1.0))) ≈ Laplace(1.0, 0.5) + @test @inferred(prod(PreserveTypeProd(Laplace), Laplace(2.0, 3.0), Laplace(2.0, 7.0))) ≈ Laplace(2.0, 2.1) + + # GenericProd should always check the default strategy and fallback if available + @test @inferred(prod(GenericProd(), Laplace(0.0, 0.5), Laplace(0.0, 0.5))) ≈ Laplace(0.0, 0.25) + @test @inferred(prod(GenericProd(), Laplace(1.0, 1.0), Laplace(1.0, 1.0))) ≈ Laplace(1.0, 0.5) + @test @inferred(prod(GenericProd(), Laplace(2.0, 3.0), Laplace(2.0, 7.0))) ≈ Laplace(2.0, 2.1) + + # Different location parameters cannot be compute a closed prod with the same type + @test_throws Exception prod(PreserveTypeProd(Laplace), Laplace(0.0, 0.5), Laplace(0.01, 0.5)) + @test_throws Exception prod(PreserveTypeProd(Laplace), Laplace(1.0, 0.5), Laplace(-1.0, 0.5)) +end + +@testitem "Laplace: prod with ExponentialFamilyDistribution: same location parameter" begin + include("distributions_setuptests.jl") + + for location in (0.0, 1.0), sleft in 0.1:0.1:0.9, sright in 0.1:0.1:0.9 + @testset let (left, right) = (Laplace(location, sleft), Laplace(location, sright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = (PreserveTypeProd(ExponentialFamilyDistribution{Laplace}), GenericProd()) + ) + end + end + + # Different location parameters cannot be compute a closed prod with the same type + @test_throws Exception prod( + PreserveTypeProd(ExponentialFamilyDistribution{Laplace}), + convert(ExponentialFamilyDistribution, Laplace(0.0, 0.5)), + convert(ExponentialFamilyDistribution, Laplace(0.01, 0.5)) + ) + @test_throws Exception prod( + PreserveTypeProd(ExponentialFamilyDistribution{Laplace}), + convert(ExponentialFamilyDistribution, Laplace(1.0, 0.5)), + convert(ExponentialFamilyDistribution, Laplace(-1.0, 0.5)) + ) +end + +@testitem "Laplace: prod with ExponentialFamilyDistribution: different location parameter" begin + include("distributions_setuptests.jl") + + for locationleft in (0.0, 1.0), sleft in 0.1:0.1:0.4, locationright in (2.0, 3.0), sright in 1.1:0.1:1.3 + @testset let (left, right) = (Laplace(locationleft, sleft), Laplace(locationright, sright)) + ef_left = convert(ExponentialFamilyDistribution, left) + ef_right = convert(ExponentialFamilyDistribution, right) + ef_prod = prod(PreserveTypeProd(ExponentialFamilyDistribution), ef_left, ef_right) + @test first(hquadrature(x -> pdf(ef_prod, tan(x * pi / 2)) * (pi / 2) * (1 / cos(x * pi / 2)^2), -1.0, 1.0)) ≈ 1.0 atol = 1e-6 + end + end +end + +================ +File: distributions/lognormal_tests.jl +================ +# LogNormal comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "LogNormal: vague" begin + include("distributions_setuptests.jl") + + @test LogNormal() == LogNormal(0.0, 1.0) + @test typeof(vague(LogNormal)) <: LogNormal + @test vague(LogNormal) == LogNormal(1, 1e12) +end + +@testitem "LogNormal: prod with Distribution" begin + include("distributions_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(LogNormal), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) + @test prod(strategy, LogNormal(1.0, 1.0), LogNormal(1.0, 1.0)) == LogNormal(0.5, sqrt(1 / 2)) + @test prod(strategy, LogNormal(2.0, 1.0), LogNormal(2.0, 1.0)) == LogNormal(1.5, sqrt(1 / 2)) + @test prod(strategy, LogNormal(1.0, 1.0), LogNormal(2.0, 1.0)) == LogNormal(1.0, sqrt(1 / 2)) + @test prod(strategy, LogNormal(1.0, 2.0), LogNormal(1.0, 2.0)) == LogNormal(-1.0, sqrt(2)) + @test prod(strategy, LogNormal(2.0, 2.0), LogNormal(2.0, 2.0)) == LogNormal(0.0, sqrt(2)) + end +end + +@testitem "LogNormal: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for μleft in 10randn(4), μright in 10randn(4), σleft in 10rand(4), σright in 10rand(4) + @testset let (left, right) = (LogNormal(μleft, σleft), LogNormal(μright, σright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{LogNormal}) + ) + ) + end + end +end + +@testitem "LogNormal: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for μ in 10randn(4), σ in 10rand(4) + @testset let d = LogNormal(μ, σ) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + + (η₁, η₂) = (μ / abs2(σ) - 1, -1 / (2abs2(σ))) + + for x in 10rand(4) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) ≈ invsqrt2π + @test @inferred(sufficientstatistics(ef, x)) === (log(x), abs2(log(x))) + @test @inferred(logpartition(ef)) ≈ (-(η₁ + 1)^2 / (4η₂) - 1 / 2 * log(-2η₂)) + end + end + end + + @test !isproper(MeanParametersSpace(), LogNormal, [1.0]) + @test !isproper(MeanParametersSpace(), LogNormal, [-1.0, 0.0]) + @test !isproper(MeanParametersSpace(), LogNormal, [1.0, -1.0]) + @test !isproper(NaturalParametersSpace(), LogNormal, [1.0]) + @test !isproper(NaturalParametersSpace(), LogNormal, [-1.0, 0.0]) + @test !isproper(NaturalParametersSpace(), LogNormal, [1.0, 1.0]) +end + +================ +File: distributions/matrix_dirichlet_tests.jl +================ +@testitem "MatrixDirichlet: common" begin + include("distributions_setuptests.jl") + + @test MatrixDirichlet <: Distribution + @test MatrixDirichlet <: ContinuousDistribution + @test MatrixDirichlet <: MatrixDistribution + + @test value_support(MatrixDirichlet) === Continuous + @test variate_form(MatrixDirichlet) === Matrixvariate +end + +@testitem "MatrixDirichlet: vague" begin + include("distributions_setuptests.jl") + + @test_throws MethodError vague(MatrixDirichlet) + + d1 = vague(MatrixDirichlet, 3) + + @test typeof(d1) <: MatrixDirichlet + @test mean(d1) == ones(3, 3) ./ sum(ones(3, 3); dims = 1) + + d2 = vague(MatrixDirichlet, 4) + + @test typeof(d2) <: MatrixDirichlet + @test mean(d2) == ones(4, 4) ./ sum(ones(4, 4); dims = 1) + + @test vague(MatrixDirichlet, 3, 3) == vague(MatrixDirichlet, (3, 3)) + @test vague(MatrixDirichlet, 4, 4) == vague(MatrixDirichlet, (4, 4)) + @test vague(MatrixDirichlet, 3, 4) == vague(MatrixDirichlet, (3, 4)) + @test vague(MatrixDirichlet, 4, 3) == vague(MatrixDirichlet, (4, 3)) + + d3 = vague(MatrixDirichlet, 3, 4) + + @test typeof(d3) <: MatrixDirichlet + @test mean(d3) == ones(3, 4) ./ sum(ones(3, 4); dims = 1) +end + +@testitem "MatrixDirichlet: entropy" begin + include("distributions_setuptests.jl") + + @test entropy(MatrixDirichlet([1.0 1.0; 1.0 1.0; 1.0 1.0])) ≈ -1.3862943611198906 + @test entropy(MatrixDirichlet([1.2 3.3; 4.0 5.0; 2.0 1.1])) ≈ -3.1139933152617787 + @test entropy(MatrixDirichlet([0.2 3.4; 5.0 11.0; 0.2 0.6])) ≈ -11.444984495104693 +end + +@testitem "MatrixDirichlet: mean(::typeof(log))" begin + include("distributions_setuptests.jl") + + import Base.Broadcast: BroadcastFunction + + @test mean(BroadcastFunction(log), MatrixDirichlet([1.0 1.0; 1.0 1.0; 1.0 1.0])) ≈ [ + -1.5000000000000002 -1.5000000000000002 + -1.5000000000000002 -1.5000000000000002 + -1.5000000000000002 -1.5000000000000002 + ] + @test mean(BroadcastFunction(log), MatrixDirichlet([1.2 3.3; 4.0 5.0; 2.0 1.1])) ≈ [ + -2.1920720408623637 -1.1517536610071326 + -0.646914475838374 -0.680458481634953 + -1.480247809171707 -2.6103310904778305 + ] + @test mean(BroadcastFunction(log), MatrixDirichlet([0.2 3.4; 5.0 11.0; 0.2 0.6])) ≈ [ + -6.879998107291004 -1.604778825293528 + -0.08484054226701443 -0.32259407259407213 + -6.879998107291004 -4.214965875553984 + ] +end + +@testitem "MatrixDirichlet: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for len in 3:5 + α = rand(1.0:2.0, len, len) + @testset let d = MatrixDirichlet(α) + ef = test_exponentialfamily_interface(d; test_basic_functions = true, option_assume_no_allocations = false) + η1 = getnaturalparameters(ef) + + for x in [rand(1.0:2.0, len, len) for _ in 1:3] + x = x ./ sum(x) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === 1.0 + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (map(log, x),)) + @test @inferred(logpartition(ef)) ≈ mapreduce( + d -> getlogpartition(NaturalParametersSpace(), Dirichlet)(convert(Vector, d)), + +, + eachcol(first(unpack_parameters(MatrixDirichlet, η1))) + ) + end + end + end + + for space in (MeanParametersSpace(), NaturalParametersSpace()) + @test !isproper(space, MatrixDirichlet, [Inf Inf; Inf 1.0], 1.0) + @test !isproper(space, MatrixDirichlet, [1.0], Inf) + @test !isproper(space, MatrixDirichlet, [NaN], 1.0) + @test !isproper(space, MatrixDirichlet, [1.0], NaN) + @test !isproper(space, MatrixDirichlet, [0.5, 0.5], 1.0) + @test isproper(space, MatrixDirichlet, [2.0, 3.0]) + @test !isproper(space, MatrixDirichlet, [-1.0, -1.2]) + end + + @test_throws Exception convert(ExponentialFamilyDistribution, MatrixDirichlet([Inf Inf; 2 3])) +end + +@testitem "MatrixDirichlet: prod with Distribution" begin + include("distributions_setuptests.jl") + + d1 = MatrixDirichlet([0.2 3.4; 5.0 11.0; 0.2 0.6]) + d2 = MatrixDirichlet([1.2 3.3; 4.0 5.0; 2.0 1.1]) + d3 = MatrixDirichlet([1.0 1.0; 1.0 1.0; 1.0 1.0]) + for strategy in (GenericProd(), ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd()) + @test @inferred(prod(strategy, d1, d2)) ≈ + MatrixDirichlet([0.3999999999999999 5.699999999999999; 8.0 15.0; 1.2000000000000002 0.7000000000000002]) + @test @inferred(prod(strategy, d1, d3)) ≈ MatrixDirichlet( + [0.19999999999999996 3.4000000000000004; 5.0 11.0; 0.19999999999999996 0.6000000000000001] + ) + @test @inferred(prod(strategy, d2, d3)) ≈ MatrixDirichlet([1.2000000000000002 3.3; 4.0 5.0; 2.0 1.1]) + end +end + +@testitem "MatrixDirichlet: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for len in 3:6 + αleft = rand(len, len) .+ 1 + αright = rand(len, len) .+ 1 + @testset let (left, right) = (MatrixDirichlet(αleft), MatrixDirichlet(αright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd() + ) + ) + end + end +end + +@testitem "MatrixDirichlet: promote_variate_type" begin + include("distributions_setuptests.jl") + + @test_throws MethodError promote_variate_type(Univariate, MatrixDirichlet) + + @test promote_variate_type(Multivariate, Dirichlet) === Dirichlet + @test promote_variate_type(Matrixvariate, Dirichlet) === MatrixDirichlet + + @test promote_variate_type(Multivariate, MatrixDirichlet) === Dirichlet + @test promote_variate_type(Matrixvariate, MatrixDirichlet) === MatrixDirichlet +end + +================ +File: distributions/mv_normal_wishart_tests.jl +================ +@testitem "MvNormalWishart: common" begin + include("distributions_setuptests.jl") + + m = rand(2) + dist = MvNormalWishart(m, [1.0 0.0; 0.0 1.0], 0.1, 3.0) + @test params(dist) == (m, [1.0 0.0; 0.0 1.0], 0.1, 3.0) + @test dof(dist) == 3.0 + @test invscatter(dist) == [1.0 0.0; 0.0 1.0] + @test scale(dist) == 0.1 + @test locationdim(dist) == 2 +end + +@testitem "MvNormalWishart: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for dim in (3,), invS in rand(Wishart(10, Array(Eye(dim))), 4) + ν = dim + 2 + @testset let (d = MvNormalWishart(rand(dim), invS, rand(), ν)) + ef = test_exponentialfamily_interface( + d; + option_assume_no_allocations = false, + test_basic_functions = false, + test_fisherinformation_against_hessian = false, + test_fisherinformation_against_jacobian = false, + test_gradlogpartition_properties = false, + test_plogpdf_interface = false + ) + + run_test_basic_functions(d; assume_no_allocations = false, test_samples_logpdf = false) + end + end +end + +@testitem "MvNormalWishart: prod" begin + include("distributions_setuptests.jl") + + for j in 2:2, κ in 1:2 + m1 = rand(j) + m2 = rand(j) + Ψ1 = m1 * m1' + I + Ψ2 = m2 * m2' + I + dist1 = MvNormalWishart(m1, Ψ1, κ + rand(), rand() + 4) + dist2 = MvNormalWishart(m2, Ψ2, κ + rand(), rand() + 4) + ef1 = convert(ExponentialFamilyDistribution, dist1) + ef2 = convert(ExponentialFamilyDistribution, dist2) + @test prod(PreserveTypeProd(Distribution), dist1, dist2) ≈ convert(Distribution, prod(ClosedProd(), ef1, ef2)) + end +end + +@testitem "MvNormalWishart: prod with ExponentialFamilyDistribution{MvNormalWishart}" begin + include("distributions_setuptests.jl") + + for Sleft in rand(Wishart(10, Array(Eye(2))), 2), Sright in rand(Wishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5) + @testset let (left, right) = (MvNormalWishart(rand(2), Sleft, rand(), νleft), MvNormalWishart(rand(2), Sright, rand(), νright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = (PreserveTypeProd(ExponentialFamilyDistribution{MvNormalWishart}), GenericProd()) + ) + end + end +end + +================ +File: distributions/negative_binomial_tests.jl +================ +# NegativeBinomial comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "NegativeBinomial: probvec" begin + include("distributions_setuptests.jl") + + @test all(probvec(NegativeBinomial(2, 0.8)) .≈ (0.2, 0.8)) + @test probvec(NegativeBinomial(2, 0.2)) == (0.8, 0.2) + @test probvec(NegativeBinomial(2, 0.1)) == (0.9, 0.1) + @test probvec(NegativeBinomial(2)) == (0.5, 0.5) +end + +@testitem "NegativeBinomial: vague" begin + include("distributions_setuptests.jl") + + @test_throws MethodError vague(NegativeBinomial) + @test_throws MethodError vague(NegativeBinomial, 1 / 2) + + vague_dist = vague(NegativeBinomial, 5) + @test typeof(vague_dist) <: NegativeBinomial + @test probvec(vague_dist) == (0.5, 0.5) +end + +@testitem "NegativeBinomial: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for p in (0.1, 0.4), r in (2, 3, 4) + @testset let d = NegativeBinomial(r, p) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false) + for x in 2:4 + @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === binomial(Int(x + r - 1), x) + @test @inferred(sufficientstatistics(ef, x)) === (x,) + @test @inferred(logpartition(ef)) ≈ -r * log(1 - exp(getnaturalparameters(ef)[1])) + end + end + end + + for space in (MeanParametersSpace(), NaturalParametersSpace()) + @test !isproper(space, NegativeBinomial, [Inf], 1.0) + @test !isproper(space, NegativeBinomial, [1.0], Inf) + @test !isproper(space, NegativeBinomial, [NaN], 1.0) + @test !isproper(space, NegativeBinomial, [1.0], NaN) + @test !isproper(space, NegativeBinomial, [0.5, 0.5], 1.0) + + # Conditioner is required + @test_throws Exception isproper(space, NegativeBinomial, [0.5], [0.5, 0.5]) + @test_throws Exception isproper(space, NegativeBinomial, [1.0], nothing) + @test_throws Exception isproper(space, NegativeBinomial, [1.0], nothing) + end + + @test_throws Exception convert(ExponentialFamilyDistribution, NegativeBinomial(Inf, Inf)) +end + +@testitem "NegativeBinomial: prod" begin + include("distributions_setuptests.jl") + + for nleft in 3:5, pleft in 0.01:0.3:0.99 + left = NegativeBinomial(nleft, pleft) + efleft = convert(ExponentialFamilyDistribution, left) + η_left = getnaturalparameters(efleft) + for nright in 6:7, pright in 0.01:0.3:0.99 + right = NegativeBinomial(nright, pright) + efright = convert(ExponentialFamilyDistribution, right) + η_right = first(getnaturalparameters(efright)) + prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), left, right) + + @test sum(pdf(prod_dist, x) for x in 0:max(nleft, nright)) ≈ 1.0 atol = 1e-5 + end + end +end + +================ +File: distributions/normal_gamma_tests.jl +================ +@testitem "NormalGamma: common" begin + include("distributions_setuptests.jl") + + m = rand() + s, a, b = 1.0, 0.1, 3.0 + dist = NormalGamma(m, s, a, b) + @test params(dist) == (m, s, a, b) + @test location(dist) == m + @test scale(dist) == s + @test shape(dist) == a + @test rate(dist) == b +end + +@testitem "NormalGamma: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for μ in 10randn(3), λ in 10rand(3), α in (1 .+ 10rand(3)), β in 10 * rand(3) + @testset let d = NormalGamma(μ, λ, α, β) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false) + + (η1, η2, η3, η4) = unpack_parameters(NormalGamma, getnaturalparameters(ef)) + η3half = η3 + 1 / 2 + for x in rand(d, 3) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) ≈ invsqrt2π + @test @inferred(sufficientstatistics(ef, x)) === (x[1] * x[2], x[1]^2 * x[2], log(x[2]), x[2]) + @test @inferred(logpartition(ef)) ≈ loggamma(η3half) - log(-2η2) * (1 / 2) - (η3half) * log(-η4 + η1^2 / (4η2)) + end + end + end + + @test !isproper(MeanParametersSpace(), NormalGamma, [1.0, 0.0, -1.0, 2.0]) + @test !isproper(MeanParametersSpace(), NormalGamma, [1.0, 0.0, -1.0, 2.0]) + @test !isproper(NaturalParametersSpace(), NormalGamma, [1.0, -0.2, 1.0, -1/0.8 ]) + @test !isproper(NaturalParametersSpace(), NormalGamma, [1.0, -0.2, 1.0, 10/8]) + @test isproper(NaturalParametersSpace(), NormalGamma, [1.0, -0.2, 1.0, -11/8]) + @test !isproper(MeanParametersSpace(), NormalGamma, [-1.0, 0.0, NaN, 1.0], [Inf]) +end + +@testitem "NormalGamma: prod with Distribution" begin + include("distributions_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeProd(NormalGamma), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) + @test prod(strategy, NormalGamma(1.0, 1.0, 2.0, 3.0), NormalGamma(1.0, 1.0, 5.0, 6.0)) == NormalGamma(1.0, 2.0, 6.5, 9.0) + @test prod(strategy, NormalGamma(2.0, 1.0, 3.0, 4.0), NormalGamma(2.0, 1.0, 0.4, 2.0)) == NormalGamma(2.0, 2.0, 2.9, 6.0) + end +end + +@testitem "NormalGamma: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for μleft in 10randn(2), μright in 10randn(2), σleft in 10rand(2), σright in 10rand(2), + αleft in (1 .+ 10rand(2)), αright in (1 .+ 10rand(2)), βleft in 10rand(2), βright in 10rand(2) + + let left = NormalGamma(μleft, σleft, αleft, βleft), right = NormalGamma(μright, σright, αright + 1 / 2, βright) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{NormalGamma}) + ) + ) + end + end +end + +================ +File: distributions/pareto_tests.jl +================ +# Pareto comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Pareto: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for shape in (5.0, 6.0, 7.0), scale in (0.25, 0.5, 2.0) + @testset let d = Pareto(shape, scale) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false) + η1 = -shape - 1 + for x in scale:1.0:scale+3.0 + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === oneunit(x) + @test @inferred(sufficientstatistics(ef, x)) === (log(x),) + @test @inferred(logpartition(ef)) ≈ log(scale^(one(η1) + η1) / (-one(η1) - η1)) + end + end + end + + for space in (MeanParametersSpace(), NaturalParametersSpace()) + @test !isproper(space, Pareto, [Inf], 1.0) + @test !isproper(space, Pareto, [1.0], Inf) + @test !isproper(space, Pareto, [NaN], 1.0) + @test !isproper(space, Pareto, [1.0], NaN) + @test !isproper(space, Pareto, [0.5, 0.5], 1.0) + + # Conditioner is required + @test_throws Exception isproper(space, Pareto, [0.5], [0.5, 0.5]) + @test_throws Exception isproper(space, Pareto, [1.0], nothing) + @test_throws Exception isproper(space, Pareto, [1.0], nothing) + end + + @test_throws Exception convert(ExponentialFamilyDistribution, Pareto(Inf, Inf)) +end + +@testitem "Pareto: prod with Distributions" begin + include("distributions_setuptests.jl") + + @test prod(PreserveTypeProd(Pareto), Pareto(0.5), Pareto(0.6)) == Pareto(2.1) + @test prod(PreserveTypeProd(Pareto), Pareto(0.3), Pareto(0.8)) == Pareto(2.1) + @test prod(PreserveTypeProd(Pareto), Pareto(0.5), Pareto(0.5)) == Pareto(2.0) + @test prod(PreserveTypeProd(Pareto), Pareto(3), Pareto(2)) == Pareto(6.0) +end + +@testitem "Pareto: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for conditioner in (0.01, 1.0), alphaleft in 0.1:0.1:0.9, alpharight in 0.1:0.1:0.9 + let left = Pareto(alphaleft, conditioner), right = Pareto(alpharight, conditioner) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = (PreserveTypeProd(ExponentialFamilyDistribution{Pareto}), GenericProd()) + ) + end + end + + # Different conditioner parameters cannot be compute a closed prod with the same type + @test_throws Exception prod( + PreserveTypeProd(ExponentialFamilyDistribution{Pareto}), + convert(ExponentialFamilyDistribution, Pareto(0.0, 0.54)), + convert(ExponentialFamilyDistribution, Pareto(0.01, 0.5)) + ) + @test_throws Exception prod( + PreserveTypeProd(ExponentialFamilyDistribution{Pareto}), + convert(ExponentialFamilyDistribution, Pareto(1.0, 0.56)), + convert(ExponentialFamilyDistribution, Pareto(2.0, 0.5)) + ) +end + +@testitem "Pareto: prod with different conditioner" begin + include("distributions_setuptests.jl") + + for conditioner_left in (2, 3), conditioner_right in (4, 5), alphaleft in 0.1:0.1:0.3, alpharight in 0.1:0.1:0.3 + let left = Pareto(alphaleft, conditioner_left), right = Pareto(alpharight, conditioner_right) + ef_left = convert(ExponentialFamilyDistribution, left) + ef_right = convert(ExponentialFamilyDistribution, right) + prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), ef_left, ef_right) + @test getnaturalparameters(prod_dist) ≈ getnaturalparameters(ef_left) + getnaturalparameters(ef_right) + @test getsupport(prod_dist).lb == max(conditioner_left, conditioner_right) + @test sufficientstatistics(prod_dist, (max(conditioner_left, conditioner_right) + 1)) === (log(max(conditioner_left, conditioner_right) + 1),) + @test first( + hquadrature(x -> pdf(prod_dist, tan(x * pi / 2)) * (pi / 2) * (1 / cos(x * pi / 2)^2), (2 / pi) * atan(getsupport(prod_dist).lb), 1.0) + ) ≈ 1.0 + end + end +end + +================ +File: distributions/poisson_tests.jl +================ +# Poisson comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Poisson: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + @testset for i in 2:7 + @testset let d = Poisson(2 * (i + 1)) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + η1 = first(getnaturalparameters(ef)) + + for x in 1:5 + @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === 1 / factorial(x) + @test @inferred(logbasemeasure(ef, x)) === -loggamma(x + one(x)) + @test @inferred(sufficientstatistics(ef, x)) === (x,) + @test @inferred(logpartition(ef)) ≈ exp(η1) + end + end + end + + for space in (MeanParametersSpace(), NaturalParametersSpace()) + @test !isproper(space, Poisson, [Inf]) + @test !isproper(space, Poisson, [NaN]) + @test !isproper(space, Poisson, [1.0], NaN) + @test !isproper(space, Poisson, [0.5, 0.5], 1.0) + end + ## mean parameter should be integer in the MeanParametersSpace + @test !isproper(MeanParametersSpace(), Poisson, [-0.1]) + @test_throws Exception convert(ExponentialFamilyDistribution, Poisson(Inf)) +end + +@testitem "Poisson: prod" begin + include("distributions_setuptests.jl") + + @testset for λleft in 2:3, λright in 3:4 + left = Poisson(λleft) + right = Poisson(λright) + prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), left, right) + sample_points = collect(1:5) + for x in sample_points + @test basemeasure(prod_dist, x) == (1 / gamma(x + one(x))^2) + @test sufficientstatistics(prod_dist, x) == (x,) + end + sample_points = [-5, -2, 0, 2, 5] + for η in sample_points + @test logpartition(prod_dist, η) == log(abs(besseli(0, 2 * exp(η / 2)))) + end + @test getnaturalparameters(prod_dist) == [log(λleft) + log(λright)] + @test getsupport(prod_dist) == NaturalNumbers() + + @test sum(pdf(prod_dist, x) for x in 0:15) ≈ 1.0 + end +end + +================ +File: distributions/rayleigh_tests.jl +================ +# Rayleigh comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Rayleigh: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for σ in 10rand(4) + @testset let d = Rayleigh(σ) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) + η1 = first(getnaturalparameters(ef)) + + for x in 10rand(4) + @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === x + @test @inferred(sufficientstatistics(ef, x)) === (x^2,) + @test @inferred(logpartition(ef)) ≈ -log(-2 * η1) + end + end + end + + for space in (MeanParametersSpace(), NaturalParametersSpace()) + @test !isproper(space, Rayleigh, [Inf]) + @test !isproper(space, Rayleigh, [NaN]) + @test !isproper(space, Rayleigh, [1.0], NaN) + @test !isproper(space, Rayleigh, [0.5, 0.5], 1.0) + end + @test !isproper(MeanParametersSpace(), Rayleigh, [-1.0]) + @test_throws Exception convert(ExponentialFamilyDistribution, Rayleigh(Inf)) +end + +@testitem "Rayleigh: prod with PreserveTypeProd{ExponentialFamilyDistribution}" begin + include("distributions_setuptests.jl") + + for σleft in 1:4, σright in 4:7 + @testset let (left, right) = (Rayleigh(σleft), Rayleigh(σright)) + ef_left = convert(ExponentialFamilyDistribution, left) + ef_right = convert(ExponentialFamilyDistribution, right) + prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), left, right) + @test first(hquadrature(x -> pdf(prod_dist, tan(x * pi / 2)) * (pi / 2) * (1 / cos(x * pi / 2)^2), 0.0, 1.0)) ≈ 1.0 + @test getnaturalparameters(prod_dist) == getnaturalparameters(ef_left) + getnaturalparameters(ef_right) + end + end +end + +================ +File: distributions/von_mises_fisher_tests.jl +================ +# VonMisesFisher comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "VonMisesFisher: vague" begin + include("distributions_setuptests.jl") + + d = vague(VonMisesFisher, 3) + + @test typeof(d) <: VonMisesFisher + @test mean(d) == zeros(3) + @test params(d) == (zeros(3), 1.0e-12) +end + +@testitem "VonMisesFisher: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for len in 3:5, b in (0.5) + a_unnormalized = rand(len) + a = a_unnormalized ./ norm(a_unnormalized) + @testset let d = VonMisesFisher(a, b) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false, test_fisherinformation_against_jacobian = false, + test_fisherinformation_properties = false + ) + + run_test_fisherinformation_against_jacobian(d; assume_no_allocations = false, mappings = ( + NaturalParametersSpace() => MeanParametersSpace(), + # MeanParametersSpace() => NaturalParametersSpace(), # here is the problem for discussion, the test is broken + )) + + for x in rand(d) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === (1 / twoπ)^(length(x) * (1 / 2)) + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (x,)) + @test @inferred(logpartition(ef)) ≈ log(besseli((len / 2) - 1, b)) - ((len / 2) - 1) * log(b) + end + end + end +end + +@testitem "VonMisesFisher: prod" begin + include("distributions_setuptests.jl") + + for strategy in (ClosedProd(), PreserveTypeLeftProd(), PreserveTypeRightProd(), PreserveTypeProd(Distribution)) + @test prod(strategy, VonMisesFisher([sin(30), cos(30)], 3.0), VonMisesFisher([sin(45), cos(45)], 4.0)) ≈ + Base.convert( + Distribution, + prod(strategy, convert(ExponentialFamilyDistribution, VonMisesFisher([sin(30), cos(30)], 3.0)), + convert(ExponentialFamilyDistribution, VonMisesFisher([sin(45), cos(45)], 4.0))) + ) + @test prod(strategy, VonMisesFisher([sin(15), cos(15)], 5.0), VonMisesFisher([cos(20), sin(20)], 2.0)) ≈ + Base.convert( + Distribution, + prod(strategy, convert(ExponentialFamilyDistribution, VonMisesFisher([sin(15), cos(15)], 5.0)), + convert(ExponentialFamilyDistribution, VonMisesFisher([cos(20), sin(20)], 2.0))) + ) + end +end + +@testitem "VonMisesFisher: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for μleft in eachcol(10rand(4, 4)), μright in eachcol(10rand(4, 4)), σleft in (2, 3), σright in (2, 3) + @testset let (left, right) = (VonMisesFisher(μleft / norm(μleft), σleft), VonMisesFisher(μright / norm(μright), σright)) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = ( + ClosedProd(), + GenericProd(), + PreserveTypeProd(ExponentialFamilyDistribution), + PreserveTypeProd(ExponentialFamilyDistribution{VonMisesFisher}) + ) + ) + end + end +end + +================ +File: distributions/vonmises_tests.jl +================ +# VonMises comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "VonMises: vague" begin + include("distributions_setuptests.jl") + + d = vague(VonMises) + + @test typeof(d) <: VonMises + @test mean(d) === 0.0 + @test params(d) === (0.0, 1.0e-12) +end + +@testitem "VonMises: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for a in -2:1.0:2, b in 0.1:4.0:10.0 + @testset let d = VonMises(a, b) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false) + + for x in a-1:0.5:a+1 + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === inv(twoπ) + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (cos(x), sin(x))) + @test @inferred(logpartition(ef)) ≈ (log(besseli(0, b))) + end + + @test !@inferred(insupport(ef, -6)) + @test @inferred(insupport(ef, 0.5)) + + # Not in the support + @test_throws Exception logpdf(ef, -6.0) + end + end + + # Test failing isproper cases + @test !isproper(MeanParametersSpace(), VonMises, [-1]) + @test !isproper(MeanParametersSpace(), VonMises, [1], 3.0) + @test !isproper(MeanParametersSpace(), VonMises, [1, -2]) +end + +@testitem "VonMises: prod with Distributions" begin + include("distributions_setuptests.jl") + for dist1 in [VonMises(10randn(),10rand()) for _=1:20], dist2 in [VonMises(10randn(),10rand()) for _=1:20] + ef1 = convert(ExponentialFamilyDistribution,dist1) + ef2 = convert(ExponentialFamilyDistribution,dist2) + prod_ef = prod(GenericProd(),ef1,ef2) + for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) + @test prod(strategy, dist1, dist2) ≈ convert(Distribution,prod_ef) + end + end +end + +@testitem "VonMises: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for kleft in (0.01, 1.0), kright in (0.01, 1.0), alphaleft in 0.1:0.1:0.9, alpharight in 0.1:0.1:0.9 + let left = VonMises(alphaleft, kleft), right = VonMises(alpharight, kright) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = (PreserveTypeProd(ExponentialFamilyDistribution{VonMises}), GenericProd()) + ) + end + end +end + +================ +File: distributions/weibull_tests.jl +================ +# Weibull comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Weibull: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + for shape in (1.0, 2.0, 3.0), scale in (0.25, 0.5, 2.0) + @testset let d = Weibull(shape, scale) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false, test_fisherinformation_against_jacobian = false) + η1 = first(getnaturalparameters(ef)) + run_test_fisherinformation_against_jacobian( + d; + assume_no_allocations = true, + mappings = ( + MeanParametersSpace() => NaturalParametersSpace() + ) + ) + for x in scale:1.0:scale+3.0 + @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === x^(shape - 1) + @test @inferred(sufficientstatistics(ef, x)) === (x^shape,) + @test @inferred(logpartition(ef)) ≈ -log(-η1) - log(shape) + end + end + end + + @testset "fisher information by natural to mean jacobian" begin + @testset for k in (1, 3), λ in (0.1, 4.0) + η = -(1 / λ)^k + transformation(η) = [k, (-1 / η[1])^(1 / k)] + J = ForwardDiff.jacobian(transformation, [η]) + @test first(J' * getfisherinformation(MeanParametersSpace(), Weibull, k)(λ) * J) ≈ + first(getfisherinformation(NaturalParametersSpace(), Weibull, k)(η)) atol = 1e-8 + end + end + + for space in (MeanParametersSpace(), NaturalParametersSpace()) + @test !isproper(space, Weibull, [Inf], 1.0) + @test !isproper(space, Weibull, [1.0], Inf) + @test !isproper(space, Weibull, [NaN], 1.0) + @test !isproper(space, Weibull, [1.0], NaN) + @test !isproper(space, Weibull, [0.5, 0.5], 1.0) + + # Conditioner is required + @test_throws Exception isproper(space, Weibull, [0.5], [0.5, 0.5]) + @test_throws Exception isproper(space, Weibull, [1.0], nothing) + @test_throws Exception isproper(space, Weibull, [1.0], nothing) + end + + @test_throws Exception convert(ExponentialFamilyDistribution, Weibull(Inf, Inf)) +end + +@testitem "Weibull: prod with PreserveTypeProd{ExponentialFamilyDistribution} for the same conditioner" begin + include("distributions_setuptests.jl") + + for η in -2.0:0.5:-0.5, k in 1.0:0.5:2, x in 0.5:0.5:2.0 + ef_left = convert(Distribution, ExponentialFamilyDistribution(Weibull, [η], k)) + ef_right = convert(Distribution, ExponentialFamilyDistribution(Weibull, [-η^2], k)) + res = prod(PreserveTypeProd(ExponentialFamilyDistribution), ef_left, ef_right) + @test getbasemeasure(res)(x) == x^(2 * (k - 1)) + @test sufficientstatistics(res, x) == (x^k,) + @test getlogpartition(res)(η - η^2) == + log(abs(η - η^2)^(1 / k)) + loggamma(2 - 1 / k) - 2 * log(abs(η - η^2)) - log(k) + @test getnaturalparameters(res) ≈ [η - η^2] + @test first(hquadrature(x -> pdf(res, tan(x * pi / 2)) * (pi / 2) * (1 / cos(pi * x / 2))^2, 0.0, 1.0)) ≈ + 1.0 + end +end + +@testitem "Weibull: prod with PreserveTypeProd{ExponentialFamilyDistribution} for different k" begin + include("distributions_setuptests.jl") + + for η in -12:4:-0.5, k in 1.0:4:10, x in 0.5:4:10 + ef_left = convert(Distribution, ExponentialFamilyDistribution(Weibull, [η], k * 2)) + ef_right = convert(Distribution, ExponentialFamilyDistribution(Weibull, [-η^2], k)) + res = prod(PreserveTypeProd(ExponentialFamilyDistribution), ef_left, ef_right) + @test getbasemeasure(res)(x) == x^(k + k * 2 - 2) + @test sufficientstatistics(res, x) == (x^(2 * k), x^k) + @test getnaturalparameters(res) ≈ [η, -η^2] + @test first(hquadrature(x -> pdf(res, tan(x * pi / 2)) * (pi / 2) * (1 / cos(pi * x / 2))^2, 0.0, 1.0)) ≈ + 1.0 + end +end + +================ +File: distributions/wishart_inverse_tests.jl +================ +@testitem "InverseWishart: common" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + + @test InverseWishartFast <: Distribution + @test InverseWishartFast <: ContinuousDistribution + @test InverseWishartFast <: MatrixDistribution + + @test value_support(InverseWishartFast) === Continuous + @test variate_form(InverseWishartFast) === Matrixvariate +end + +@testitem "InverseWishart: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + + rng = StableRNG(42) + + @testset for dim in (3), S in rand(rng, InverseWishart(10, Array(Eye(dim))), 2) + ν = dim + 4 + @testset let (d = InverseWishartFast(ν, S)) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false, test_fisherinformation_against_hessian = false) + (η1, η2) = unpack_parameters(InverseWishartFast, getnaturalparameters(ef)) + + for x in (Eye(dim), Diagonal(ones(dim)), Array(Eye(dim))) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === 1.0 + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (logdet(x), inv(x))) + @test @inferred(logpartition(ef)) ≈ (η1 + (dim + 1) / 2) * logdet(-η2) + logmvgamma(dim, -(η1 + (dim + 1) / 2)) + end + end + end +end + +@testitem "InverseWishart: statistics" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + + rng = StableRNG(42) + # ν > dim(d) + 1 + for ν in 4:10 + L = randn(rng, ν - 2, ν - 2) + S = L * L' + d = InverseWishartFast(ν, S) + + @test mean(d) == mean(InverseWishart(params(d)...)) + @test mode(d) == mode(InverseWishart(params(d)...)) + end + + # ν > dim(d) + 3 + for ν in 5:10 + L = randn(rng, ν - 4, ν - 4) + S = L * L' + d = InverseWishartFast(ν, S) + + @test cov(d) == cov(InverseWishart(params(d)...)) + @test var(d) == var(InverseWishart(params(d)...)) + end +end + +@testitem "InverseWishart: vague" begin + include("distributions_setuptests.jl") + + dims = 3 + d1 = vague(InverseWishart, dims) + + @test typeof(d1) <: InverseWishart + ν1, S1 = params(d1) + @test ν1 == dims + 2 + @test S1 == tiny .* Eye(dims) + + @test mean(d1) == S1 + + dims = 4 + d2 = vague(InverseWishart, dims) + + @test typeof(d2) <: InverseWishart + ν2, S2 = params(d2) + @test ν2 == dims + 2 + @test S2 == tiny .* Eye(dims) + + @test mean(d2) == S2 +end + +@testitem "InverseWishart: entropy" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + + @test entropy( + InverseWishartFast( + 2.0, + [2.2658069783329573 -0.47934965873423374; -0.47934965873423374 1.4313564100863712] + ) + ) ≈ 10.111427477184794 + @test entropy(InverseWishartFast(5.0, Eye(4))) ≈ 8.939145914882221 +end + +@testitem "InverseWishart: convert" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + + rng = StableRNG(42) + for ν in 2:10 + L = randn(rng, ν, ν) + S = L * L' + d = InverseWishartFast(ν, S) + @test convert(InverseWishart, d) == InverseWishart(ν, S) + end +end + +@testitem "InverseWishart: mean(::typeof(logdet))" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + + rng = StableRNG(123) + ν, S = 2.0, [2.2658069783329573 -0.47934965873423374; -0.47934965873423374 1.4313564100863712] + samples = rand(rng, InverseWishart(ν, S), Int(1e6)) + @test isapprox(mean(logdet, InverseWishartFast(ν, S)), mean(logdet.(samples)), atol = 1e-2) + @test isapprox(mean(logdet, InverseWishart(ν, S)), mean(logdet.(samples)), atol = 1e-2) + + ν, S = 4.0, Array(Eye(3)) + samples = rand(rng, InverseWishart(ν, S), Int(1e6)) + @test isapprox(mean(logdet, InverseWishartFast(ν, S)), mean(logdet.(samples)), atol = 1e-2) + @test isapprox(mean(logdet, InverseWishart(ν, S)), mean(logdet.(samples)), atol = 1e-2) +end + +@testitem "InverseWishart: mean(::typeof(inv))" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + + rng = StableRNG(321) + ν, S = 2.0, [2.2658069783329573 -0.47934965873423374; -0.47934965873423374 1.4313564100863712] + samples = rand(rng, InverseWishart(ν, S), Int(1e6)) + @test isapprox(mean(inv, InverseWishartFast(ν, S)), mean(inv.(samples)), atol = 1e-2) + @test isapprox(mean(inv, InverseWishart(ν, S)), mean(inv.(samples)), atol = 1e-2) + + ν, S = 4.0, Array(Eye(3)) + samples = rand(rng, InverseWishart(ν, S), Int(1e6)) + @test isapprox(mean(inv, InverseWishartFast(ν, S)), mean(inv.(samples)), atol = 1e-2) + @test isapprox(mean(inv, InverseWishart(ν, S)), mean(inv.(samples)), atol = 1e-2) +end + +@testitem "InverseWishart: prod" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + + d1 = InverseWishartFast(3.0, Eye(2)) + d2 = InverseWishartFast(-3.0, [0.6423504672769315 0.9203141654948761; 0.9203141654948761 1.528137747462735]) + + @test prod(PreserveTypeProd(Distribution), d1, d2) ≈ + InverseWishartFast(3.0, [1.6423504672769313 0.9203141654948761; 0.9203141654948761 2.528137747462735]) + + d1 = InverseWishartFast(4.0, Eye(3)) + d2 = InverseWishartFast(-2.0, Eye(3)) + + @test prod(PreserveTypeProd(Distribution), d1, d2) ≈ InverseWishartFast(6.0, 2 * Eye(3)) +end + +@testitem "InverseWishart: rand" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + + for d in (2, 3, 4, 5) + v = rand() + d + L = rand(d, d) + S = L' * L + d * Eye(d) + cS = copy(S) + container1 = [zeros(d, d) for _ in 1:100] + container2 = [zeros(d, d) for _ in 1:100] + + # Check in-place version + @test rand!(StableRNG(321), InverseWishart(v, S), container1) ≈ + rand!(StableRNG(321), InverseWishartFast(v, S), container2) + + # Check that the matrix has not been corrupted + @test all(S .=== cS) + + # Check non-inplace version + @test rand(StableRNG(321), InverseWishart(v, S), length(container1)) ≈ + rand(StableRNG(321), InverseWishartFast(v, S), length(container2)) + end +end + +@testitem "InverseWishart: pdf!" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + import Distributions: pdf! + + for d in (2, 3, 4, 5), n in (10, 20) + v = rand() + d + L = rand(d, d) + S = L' * L + d * Eye(d) + + samples = map(1:n) do _ + L_sample = rand(d, d) + return L_sample' * L_sample + d * Eye(d) + end + + result = zeros(n) + + @test all(pdf(InverseWishart(v, S), samples) .≈ pdf!(result, InverseWishartFast(v, S), samples)) + end +end + +@testitem "InverseWishart: prod with ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + + for Sleft in rand(InverseWishart(10, Array(Eye(2))), 2), Sright in rand(InverseWishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5) + let left = InverseWishartFast(νleft, Sleft), right = InverseWishartFast(νright, Sright) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = (PreserveTypeProd(ExponentialFamilyDistribution{InverseWishartFast}), GenericProd()) + ) + end + end +end + +================ +File: distributions/wishart_tests.jl +================ +# Wishart comes from Distributions.jl and most of the things should be covered there +# Here we test some extra ExponentialFamily.jl specific functionality + +@testitem "Wishart: mean(::logdet)" begin + include("distributions_setuptests.jl") + + @test mean(logdet, Wishart(3, [1.0 0.0; 0.0 1.0])) ≈ 0.845568670196936 + @test mean( + logdet, + Wishart( + 5, + [ + 1.4659658963311604 1.111775094889733 0.8741034114800605 + 1.111775094889733 0.8746971141492232 0.6545661366809246 + 0.8741034114800605 0.6545661366809246 0.5498917856395482 + ] + ) + ) ≈ -3.4633310802040693 +end + +@testitem "Wishart: mean(::cholinv)" begin + include("distributions_setuptests.jl") + + L = rand(2, 2) + S = L * L' + Eye(2) + invS = inv(S) + @test mean(inv, Wishart(5, S)) ≈ mean(InverseWishart(5, invS)) +end + +@testitem "Wishart: vague" begin + include("distributions_setuptests.jl") + + @test_throws MethodError vague(Wishart) + + d = vague(Wishart, 3) + + @test typeof(d) <: Wishart + @test mean(d) == Matrix(Diagonal(3 * 1e12 * ones(3))) +end + +@testitem "Wishart: rand" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: WishartFast + + rng = StableRNG(42) + + for d in (2, 3, 4, 5) + v = rand(rng) + d + L = rand(rng, d, d) + S = L' * L + d * Eye(d) + invS = inv(S) + cS = copy(S) + cinvS = copy(invS) + container1 = [zeros(d, d) for _ in 1:100] + container2 = [zeros(d, d) for _ in 1:100] + + # Check inplace versions + @test rand!(StableRNG(321), Wishart(v, S), container1) ≈ + rand!(StableRNG(321), WishartFast(v, invS), container2) + + # Check that matrices are not corrupted + @test all(S .=== cS) + @test all(invS .=== cinvS) + + # Check non-inplace versions + @test rand(StableRNG(321), Wishart(v, S), length(container1)) ≈ + rand(StableRNG(321), WishartFast(v, invS), length(container2)) + end +end + + +@testitem "Wishart: ExponentialFamilyDistribution" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: WishartFast + + rng = StableRNG(42) + + for dim in (3, 4), invS in rand(rng, Wishart(10, Array(Eye(dim))), 2) + ν = dim + 2 + @testset let (d = WishartFast(ν, invS)) + ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false, test_fisherinformation_against_hessian = false) + (η1, η2) = unpack_parameters(WishartFast, getnaturalparameters(ef)) + + for x in (Eye(dim), Diagonal(ones(dim)), Array(Eye(dim))) + @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() + @test @inferred(basemeasure(ef, x)) === 1.0 + @test all(@inferred(sufficientstatistics(ef, x)) .≈ (logdet(x), x)) + @test @inferred(logpartition(ef)) ≈ -(η1 + (dim + 1) / 2) * logdet(-η2) + logmvgamma(dim, η1 + (dim + 1) / 2) + end + end + end +end + +@testitem "Wishart: prod" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: WishartFast + + inv_v1 = inv([9.0 -3.4; -3.4 11.0]) + inv_v2 = inv([10.2 -3.3; -3.3 5.0]) + inv_v3 = inv([8.1 -2.7; -2.7 9.0]) + + @test prod(PreserveTypeProd(Distribution), WishartFast(3, inv_v1), WishartFast(3, inv_v2)) ≈ + WishartFast( + 3, + inv_v1 + inv_v2 + ) + @test prod(PreserveTypeProd(Distribution), WishartFast(4, inv_v1), WishartFast(4, inv_v3)) ≈ + WishartFast( + 5, + inv_v1 + inv_v3 + ) + @test prod(PreserveTypeProd(Distribution), WishartFast(5, inv_v2), WishartFast(4, inv_v3)) ≈ + WishartFast(6, inv([4.51459128065395 -1.4750681198910067; -1.4750681198910067 3.129155313351499])) +end + +@testitem "Wishart: prod with ExponentialFamilyDistribution{Wishart}" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: WishartFast + + for Sleft in rand(Wishart(10, Array(Eye(2))), 2), Sright in rand(Wishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5) + let left = WishartFast(νleft, Sleft), right = WishartFast(νright, Sright) + @test test_generic_simple_exponentialfamily_product( + left, + right, + strategies = (PreserveTypeProd(ExponentialFamilyDistribution{WishartFast}), GenericProd()) + ) + end + end +end + +================ +File: common_tests.jl +================ +@testitem "dot3arg" begin + using LinearAlgebra, ForwardDiff + using ExponentialFamily: dot3arg + + for n in 2:10 + x = rand(n) + y = rand(n) + A = rand(n, n) + @test dot3arg(x, A, y) ≈ dot(x, A, y) + @test all(ForwardDiff.hessian((x) -> dot3arg(x, A, x), x) .!== 0) + @test all(ForwardDiff.hessian((x) -> dot3arg(x, A, x), y) .!== 0) + end +end + +================ +File: exponential_family_setuptests.jl +================ +using ExponentialFamily, BayesBase, Distributions, Test, StatsFuns, BenchmarkTools, Random, FillArrays + +import Distributions: RealInterval, ContinuousUnivariateDistribution, Univariate +import ExponentialFamily: basemeasure, logbasemeasure, sufficientstatistics, logpartition, insupport, ConstantBaseMeasure +import ExponentialFamily: getnaturalparameters, getbasemeasure, getlogbasemeasure, getsufficientstatistics, getlogpartition, getsupport +import ExponentialFamily: ExponentialFamilyDistributionAttributes, NaturalParametersSpace + +# import ExponentialFamily: +# ExponentialFamilyDistribution, getnaturalparameters, getconditioner, reconstructargument!, as_vec, +# pack_naturalparameters, unpack_naturalparameters, insupport +# import Distributions: pdf, logpdf, cdf + +## =========================================================================== +## Tests fixtures +const ArbitraryExponentialFamilyAttributes = ExponentialFamilyDistributionAttributes( + (x) -> 1 / x, + ((x) -> x, (x) -> log(x)), + (η) -> 1 / sum(η), + RealInterval(0, Inf) +) + +# Arbitrary distribution (un-conditioned) +struct ArbitraryDistributionFromExponentialFamily <: ContinuousUnivariateDistribution + p1::Float64 + p2::Float64 +end + +ExponentialFamily.isproper(::NaturalParametersSpace, ::Type{ArbitraryDistributionFromExponentialFamily}, η, conditioner) = isnothing(conditioner) +ExponentialFamily.isbasemeasureconstant(::Type{ArbitraryDistributionFromExponentialFamily}) = ConstantBaseMeasure() +ExponentialFamily.getbasemeasure(::Type{ArbitraryDistributionFromExponentialFamily}) = (x) -> oneunit(x) +ExponentialFamily.getlogbasemeasure(::Type{ArbitraryDistributionFromExponentialFamily}) = (x) -> zero(x) +ExponentialFamily.getsufficientstatistics(::Type{ArbitraryDistributionFromExponentialFamily}) = + ((x) -> x, (x) -> log(x)) +ExponentialFamily.getlogpartition(::NaturalParametersSpace, ::Type{ArbitraryDistributionFromExponentialFamily}) = (η) -> 1 / sum(η) +ExponentialFamily.getsupport(::Type{ArbitraryDistributionFromExponentialFamily}) = RealInterval(0, Inf) + +BayesBase.vague(::Type{ArbitraryDistributionFromExponentialFamily}) = + ArbitraryDistributionFromExponentialFamily(1.0, 1.0) + +BayesBase.params(dist::ArbitraryDistributionFromExponentialFamily) = (dist.p1, dist.p2) + +(::MeanToNatural{ArbitraryDistributionFromExponentialFamily})(params::Tuple) = (params[1] + 1, params[2] + 1) +(::NaturalToMean{ArbitraryDistributionFromExponentialFamily})(params::Tuple) = (params[1] - 1, params[2] - 1) + +ExponentialFamily.unpack_parameters(::Type{ArbitraryDistributionFromExponentialFamily}, η) = (η[1], η[2]) + +# Arbitrary distribution (conditioned) +struct ArbitraryConditionedDistributionFromExponentialFamily <: ContinuousUnivariateDistribution + con::Int + p1::Float64 +end + +ExponentialFamily.isproper(::NaturalParametersSpace, ::Type{ArbitraryConditionedDistributionFromExponentialFamily}, η, conditioner) = isinteger(conditioner) +ExponentialFamily.isbasemeasureconstant(::Type{ArbitraryConditionedDistributionFromExponentialFamily}) = NonConstantBaseMeasure() +ExponentialFamily.getbasemeasure(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = (x) -> x^conditioner +ExponentialFamily.getlogbasemeasure(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = (x) -> conditioner*log(x) +ExponentialFamily.getsufficientstatistics(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = + ((x) -> log(x - conditioner),) +ExponentialFamily.getlogpartition(::NaturalParametersSpace, ::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = + (η) -> conditioner / sum(η) +ExponentialFamily.getsupport(::Type{ArbitraryConditionedDistributionFromExponentialFamily}) = RealInterval(0, Inf) + +BayesBase.vague(::Type{ArbitraryConditionedDistributionFromExponentialFamily}) = + ArbitraryConditionedDistributionFromExponentialFamily(1.0, -2) + +BayesBase.params(dist::ArbitraryConditionedDistributionFromExponentialFamily) = (dist.con, dist.p1) + +ExponentialFamily.separate_conditioner(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, params) = ((params[2],), params[1]) +ExponentialFamily.join_conditioner(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, cparams, conditioner) = (conditioner, cparams...) + +(::MeanToNatural{ArbitraryConditionedDistributionFromExponentialFamily})(params::Tuple, conditioner::Number) = (params[1] + conditioner,) +(::NaturalToMean{ArbitraryConditionedDistributionFromExponentialFamily})(params::Tuple, conditioner::Number) = (params[1] - conditioner,) + +ExponentialFamily.unpack_parameters(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, η) = (η[1],) + +================ +File: exponential_family_tests.jl +================ +@testitem "pack_parameters" begin + include("./exponential_family_setuptests.jl") + + import ExponentialFamily: pack_parameters + + to_test_fixed = [ + (1, 2), + (1.0, 2), + (1, 2.0), + (1.0, 2.0), + ([1, 2, 3], 3), + ([1, 2, 3], 3.0), + ([1.0, 2.0, 3.0], 3), + ([1.0, 2.0, 3.0], 3.0), + (4, [1, 2, 3]), + (4.0, [1, 2, 3]), + (4, [1.0, 2.0, 3.0]), + (4.0, [1.0, 2.0, 3.0]), + ([1, 2, 3], 3, [1 2 3; 1 2 3; 1 2 3], 4), + ([1, 2, 3], 3, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4), + ([1, 2, 3], 3, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4.0), + ([1, 2, 3], 3.0, [1 2 3; 1 2 3; 1 2 3], 4), + ([1, 2, 3], 3.0, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4), + ([1, 2, 3], 3.0, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4.0), + ([1.0, 2.0, 3.0], 3, [1 2 3; 1 2 3; 1 2 3], 4), + ([1.0, 2.0, 3.0], 3, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4), + ([1.0, 2.0, 3.0], 3, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4.0) + ] + + for test in to_test_fixed + @test all(@inferred(pack_parameters(test)) .== collect(Iterators.flatten(test))) + end + + for _ in 1:10 + to_test_random = [ + rand(Float64), + rand(1:10), + [rand(Float64) for _ in rand(1:10)], + [rand(1:10) for _ in rand(1:10)], + [rand(Float64) for _ in rand(1:10) for _ in rand(1:10)], + [rand(1:10) for _ in rand(1:10) for _ in rand(1:10)] + ] + params = Tuple(shuffle(to_test_random)) + @test all(@inferred(pack_parameters(params)) .== collect(Iterators.flatten(params))) + end +end + +@testitem "ExponentialFamilyDistributionAttributes" begin + include("./exponential_family_setuptests.jl") + + @testset "getmapping" begin + @test @inferred(getmapping(MeanParametersSpace() => NaturalParametersSpace(), ArbitraryDistributionFromExponentialFamily)) === + MeanToNatural{ArbitraryDistributionFromExponentialFamily}() + @test @inferred(getmapping(NaturalParametersSpace() => MeanParametersSpace(), ArbitraryDistributionFromExponentialFamily)) === + NaturalToMean{ArbitraryDistributionFromExponentialFamily}() + @test @allocated(getmapping(MeanParametersSpace() => NaturalParametersSpace(), ArbitraryDistributionFromExponentialFamily)) === 0 + @test @allocated(getmapping(NaturalParametersSpace() => MeanParametersSpace(), ArbitraryDistributionFromExponentialFamily)) === 0 + end + + @testset let attributes = ArbitraryExponentialFamilyAttributes + @test @inferred(getbasemeasure(attributes)(2.0)) ≈ 0.5 + @test @inferred(getlogbasemeasure(attributes)(2.0)) ≈ log(0.5) + @test @inferred(getsufficientstatistics(attributes)[1](2.0)) ≈ 2.0 + @test @inferred(getsufficientstatistics(attributes)[2](2.0)) ≈ log(2.0) + @test @inferred(getlogpartition(attributes)([2.0])) ≈ 0.5 + @test @inferred(getsupport(attributes)) == RealInterval(0, Inf) + @test @inferred(insupport(attributes, 1.0)) + @test !@inferred(insupport(attributes, -1.0)) + end + + @testset let member = + ExponentialFamilyDistribution(Univariate, [2.0, 2.0], nothing, ArbitraryExponentialFamilyAttributes) + η = @inferred(getnaturalparameters(member)) + + @test ExponentialFamily.exponential_family_typetag(member) === Univariate + + @test @inferred(basemeasure(member, 2.0)) ≈ 0.5 + @test @inferred(getbasemeasure(member)(2.0)) ≈ 0.5 + @test @inferred(getbasemeasure(member)(4.0)) ≈ 0.25 + + @test all(@inferred(sufficientstatistics(member, 2.0)) .≈ (2.0, log(2.0))) + @test all(@inferred(map(f -> f(2.0), getsufficientstatistics(member))) .≈ (2.0, log(2.0))) + @test all(@inferred(map(f -> f(4.0), getsufficientstatistics(member))) .≈ (4.0, log(4.0))) + + @test @inferred(logpartition(member)) ≈ 0.25 + @test @inferred(getlogpartition(member)([2.0, 2.0])) ≈ 0.25 + @test @inferred(getlogpartition(member)([4.0, 4.0])) ≈ 0.125 + + @test @inferred(getsupport(member)) == RealInterval(0, Inf) + @test @inferred(insupport(member, 1.0)) + @test !@inferred(insupport(member, -1.0)) + + _similar = @inferred(similar(member)) + + # The standard `@allocated` is not really reliable in this test + # We avoid using the `BenchmarkTools`, but here it is essential + @test @ballocated(logpdf($member, 1.0), samples = 1, evals = 1) === 0 + @test @ballocated(pdf($member, 1.0), samples = 1, evals = 1) === 0 + + @test _similar isa typeof(member) + + # `similar` most probably returns the un-initialized natural parameters with garbage in it + # But we do expect the functions to work anyway given proper values + @test @inferred(basemeasure(_similar, 2.0)) ≈ 0.5 + @test all(@inferred(sufficientstatistics(_similar, 2.0)) .≈ (2.0, log(2.0))) + @test @inferred(logpartition(_similar, η)) ≈ 0.25 + @test @inferred(getsupport(_similar)) == RealInterval(0, Inf) + end +end + +@testitem "ArbitraryDistributionFromExponentialFamily" begin + include("./exponential_family_setuptests.jl") + + @testset for member in ( + ExponentialFamilyDistribution(ArbitraryDistributionFromExponentialFamily, [2.0, 2.0]), + convert(ExponentialFamilyDistribution, ArbitraryDistributionFromExponentialFamily(1.0, 1.0)) + ) + η = @inferred(getnaturalparameters(member)) + + @test ExponentialFamily.exponential_family_typetag(member) === ArbitraryDistributionFromExponentialFamily + + @test convert(ExponentialFamilyDistribution, convert(Distribution, member)) == + ExponentialFamilyDistribution(ArbitraryDistributionFromExponentialFamily, [2.0, 2.0]) + @test convert(Distribution, convert(ExponentialFamilyDistribution, member)) == ArbitraryDistributionFromExponentialFamily(1.0, 1.0) + + @test @inferred(basemeasure(member, 2.0)) ≈ 1.0 + @test @inferred(getbasemeasure(member)(2.0)) ≈ 1.0 + @test @inferred(getbasemeasure(member)(4.0)) ≈ 1.0 + + @test @inferred(logbasemeasure(member, 2.0)) ≈ log(1.0) + @test @inferred(getlogbasemeasure(member)(2.0)) ≈ log(1.0) + @test @inferred(getlogbasemeasure(member)(4.0)) ≈ log(1.0) + + @test all(@inferred(sufficientstatistics(member, 2.0)) .≈ (2.0, log(2.0))) + @test all(@inferred(map(f -> f(2.0), getsufficientstatistics(member))) .≈ (2.0, log(2.0))) + @test all(@inferred(map(f -> f(4.0), getsufficientstatistics(member))) .≈ (4.0, log(4.0))) + + @test @inferred(logpartition(member)) ≈ 0.25 + @test @inferred(getlogpartition(member)([2.0, 2.0])) ≈ 0.25 + @test @inferred(getlogpartition(member)([4.0, 4.0])) ≈ 0.125 + + @test @inferred(getsupport(member)) == RealInterval(0, Inf) + @test insupport(member, 1.0) + @test !insupport(member, -1.0) + + # Computed by hand + @test @inferred(logpdf(member, 2.0)) ≈ (3.75 + 2log(2)) + @test @inferred(logpdf(member, 4.0)) ≈ (7.75 + 4log(2)) + @test @inferred(pdf(member, 2.0)) ≈ exp(3.75 + 2log(2)) + @test @inferred(pdf(member, 4.0)) ≈ exp(7.75 + 4log(2)) + + # The standard `@allocated` is not really reliable in this test + # We avoid using the `BenchmarkTools`, but here it is essential + @test @ballocated(logpdf($member, 2.0), samples = 1, evals = 1) === 0 + @test @ballocated(pdf($member, 2.0), samples = 1, evals = 1) === 0 + + @test @inferred(member == member) + @test @inferred(member ≈ member) + + _similar = @inferred(similar(member)) + _prod = ExponentialFamilyDistribution(ArbitraryDistributionFromExponentialFamily, [4.0, 4.0]) + + @test @inferred(prod(ClosedProd(), member, member)) == _prod + @test @inferred(prod(GenericProd(), member, member)) == _prod + @test @inferred(prod(PreserveTypeProd(ExponentialFamilyDistribution), member, member)) == _prod + @test @inferred(prod(PreserveTypeLeftProd(), member, member)) == _prod + @test @inferred(prod(PreserveTypeRightProd(), member, member)) == _prod + + # Test that the generic prod version does not allocate as much as simply creating a similar ef member + # This is important, because the generic prod version should simply call the in-place version + @test @allocated(prod(ClosedProd(), member, member)) <= @allocated(similar(member)) + @test @allocated(prod(GenericProd(), member, member)) <= @allocated(similar(member)) + + # This test is actually passing, but does not work if you re-run tests for some reason (which is hapenning often during development) + # @test @allocated(prod(PreserveTypeProd(ExponentialFamilyDistribution), member, member)) <= + # @allocated(similar(member)) + + @test @inferred(prod!(_similar, member, member)) == _prod + + # Test that the in-place prod preserves the container paramfloatype + for F in (Float16, Float32, Float64) + @test @inferred(paramfloattype(prod!(similar(member, F), member, member))) === F + @test @inferred(prod!(similar(member, F), member, member)) == convert_paramfloattype(F, _prod) + end + + # Test that the generic in-place prod! version does not allocate at all + @test @allocated(prod!(_similar, member, member)) === 0 + end +end + +@testitem "ArbitraryConditionedDistributionFromExponentialFamily" begin + include("./exponential_family_setuptests.jl") + + # See the `ArbitraryDistributionFromExponentialFamily` defined in the fixtures (above) + # p1 = 3.0, con = -2 + @testset for member in ( + ExponentialFamilyDistribution(ArbitraryConditionedDistributionFromExponentialFamily, [1.0], -2), + convert(ExponentialFamilyDistribution, ArbitraryConditionedDistributionFromExponentialFamily(-2, 3.0)) + ) + @test ExponentialFamily.exponential_family_typetag(member) === ArbitraryConditionedDistributionFromExponentialFamily + + η = @inferred(getnaturalparameters(member)) + + @test convert(ExponentialFamilyDistribution, convert(Distribution, member)) == + ExponentialFamilyDistribution(ArbitraryConditionedDistributionFromExponentialFamily, [1.0], -2) + @test convert(Distribution, convert(ExponentialFamilyDistribution, member)) == ArbitraryConditionedDistributionFromExponentialFamily(-2, 3.0) + + @test @inferred(basemeasure(member, 2.0)) ≈ 2.0^-2 + @test @inferred(getbasemeasure(member)(2.0)) ≈ 2.0^-2 + @test @inferred(getbasemeasure(member)(4.0)) ≈ 4.0^-2 + + @test @inferred(logbasemeasure(member, 2.0)) ≈ -2*log(2.0) + @test @inferred(getlogbasemeasure(member)(2.0)) ≈ -2*log(2.0) + @test @inferred(getlogbasemeasure(member)(4.0)) ≈ -2*log(4.0) + + @test all(@inferred(sufficientstatistics(member, 2.0)) .≈ (log(2.0 + 2),)) + @test all(@inferred(map(f -> f(2.0), getsufficientstatistics(member))) .≈ (log(2.0 + 2),)) + @test all(@inferred(map(f -> f(4.0), getsufficientstatistics(member))) .≈ (log(4.0 + 2),)) + + @test @inferred(logpartition(member)) ≈ -2.0 + @test @inferred(getlogpartition(member)([2.0])) ≈ -1.0 + @test @inferred(getlogpartition(member)([4.0])) ≈ -0.5 + + @test @inferred(getsupport(member)) == RealInterval(0, Inf) + @test insupport(member, 1.0) + @test !insupport(member, -1.0) + + # # Computed by hand + @test @inferred(logpdf(member, 2.0)) ≈ (log(2.0^-2) + log(2.0 + 2) + 2.0) + @test @inferred(logpdf(member, 4.0)) ≈ (log(4.0^-2) + log(4.0 + 2) + 2.0) + @test @inferred(pdf(member, 2.0)) ≈ exp((log(2.0^-2) + log(2.0 + 2) + 2.0)) + @test @inferred(pdf(member, 4.0)) ≈ exp((log(4.0^-2) + log(4.0 + 2) + 2.0)) + + # The standard `@allocated` is not really reliable in this test + # We avoid using the `BenchmarkTools`, but here it is essential + @test @ballocated(logpdf($member, 2.0), samples = 1, evals = 1) === 0 + @test @ballocated(pdf($member, 2.0), samples = 1, evals = 1) === 0 + + @test @inferred(member == member) + @test @inferred(member ≈ member) + + _similar = @inferred(similar(member)) + _prod = ExponentialFamilyDistribution(ArbitraryConditionedDistributionFromExponentialFamily, [1.0], -2) + + # We don't test the prod becasue the basemeasure is not a constant, so the generic prod is not applicable + + # # Test that the in-place prod preserves the container paramfloatype + for F in (Float16, Float32, Float64) + @test @inferred(paramfloattype(similar(member, F))) === F + end + end +end + +@testitem "vague" begin + include("./exponential_family_setuptests.jl") + + @test @inferred(vague(ExponentialFamilyDistribution{ArbitraryDistributionFromExponentialFamily})) isa + ExponentialFamilyDistribution{ArbitraryDistributionFromExponentialFamily} + + @test @inferred(vague(ExponentialFamilyDistribution{ArbitraryConditionedDistributionFromExponentialFamily})) isa + ExponentialFamilyDistribution{ArbitraryConditionedDistributionFromExponentialFamily} +end + +================ +File: runtests.jl +================ +using Aqua, CpuId, ReTestItems, ExponentialFamily + +# `ambiguities = false` - there are quite some ambiguities, but these should be normal and should not be encountered under normal circumstances +# `piracies = false` - we extend/add some of the methods to the objects defined in the Distributions.jl +Aqua.test_all(ExponentialFamily, ambiguities = false, deps_compat = (; check_extras = false, check_weakdeps = true), piracies = false) + +nthreads = max(cputhreads(), 1) +ncores = max(cpucores(), 1) + +runtests(ExponentialFamily, + nworkers = ncores, + nworker_threads = Int(nthreads / ncores), + memory_threshold = 1.0 +) From 118ccfdc49eb65beec3c514f3229a80fc2b5e7b7 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 20 Sep 2024 16:10:36 +0200 Subject: [PATCH 15/32] fix: dimension match --- .../mv_normal_mean_scale_precision.jl | 33 ++++++------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index 9939cc95..e67a78c3 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -214,7 +214,9 @@ function getsupport(ef::ExponentialFamilyDistribution{MvNormalMeanScalePrecision return Domain(IndicatorFunction{AbstractVector}(MvNormalDomainIndicator(dim))) end -getbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> (2π)^(- length(x) / 2) +isbasemeasureconstant(::Type{MvNormalMeanScalePrecision}) = ConstantBaseMeasure() + +getbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> (2π)^(-length(x) / 2) getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = (η) -> begin @@ -238,27 +240,12 @@ getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision (η) -> begin (η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η) invη2 = -cholinv(-η₂) - n = size(η₁, 1) - ident = Eye(n) - kronprod = invη2^2 * Eye(n^2) - Iₙ = PermutationMatrix(1, 1) - offdiag = - 1 / 4 * (invη2 * kron(ident, transpose(invη2 * η₁)) + invη2 * kron(η₁' * invη2, ident)) * - kron(ident, kron(Iₙ, ident)) - G = - -1 / 4 * - ( - kronprod * kron(ident, η₁) * kron(ident, transpose(invη2 * η₁)) + - kronprod * kron(η₁, ident) * kron(η₁' * invη2 * ident, ident) - ) * kron(ident, kron(Iₙ, ident)) + 1 / 2 * kronprod - - [-1/2*invη2*ident offdiag; offdiag' G] + return Diagonal([η₁..., invη2]) end -getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) = (θ) -> begin - μ, γ = unpack_parameters(MvNormalMeanScalePrecision, θ) - n = size(μ, 1) - offdiag = zeros(n, n^2) - G = (1 / 2) * γ^2 * Eye(n^2) - [γ*Eye(n) offdiag; offdiag' G] -end +getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) = + (η) -> begin + (η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η) + invη2 = -cholinv(-η₂) + return Diagonal([η₁..., invη2]) + end \ No newline at end of file From 9d6159a9c7f0b0ce14ca4f84b31f118102129bb0 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 23 Sep 2024 16:22:22 +0200 Subject: [PATCH 16/32] test: add check that samples are correct --- .../mv_normal_mean_scale_precision_tests.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index 63286da9..4aad9529 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -29,9 +29,11 @@ end @testitem "MvNormalMeanScalePrecision: ExponentialFamilyDistribution" begin include("../distributions_setuptests.jl") + rng = StableRNG(42) + for s in 2:5 - μ = randn(s) - γ = rand() + μ = randn(rng, s) + γ = rand(rng) @testset let d = MvNormalMeanScalePrecision(μ, γ) ef = test_exponentialfamily_interface(d;) @@ -65,9 +67,10 @@ end @test length(dist) == 3 @test entropy(dist) ≈ entropy(rdist) @test pdf(dist, [0.2, 3.0, 4.0]) ≈ pdf(rdist, [0.2, 3.0, 4.0]) - @test pdf(dist, [0.202, 3.002, 4.002]) ≈ pdf(rdist, [0.202, 3.002, 4.002]) + @test pdf(dist, [0.202, 3.002, 4.002]) ≈ pdf(rdist, [0.202, 3.002, 4.002]) atol = 1e-4 @test logpdf(dist, [0.2, 3.0, 4.0]) ≈ logpdf(rdist, [0.2, 3.0, 4.0]) - @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ logpdf(rdist, [0.202, 3.002, 4.002]) + @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ logpdf(rdist, [0.202, 3.002, 4.002]) atol = 1e-4 + @test rand(StableRNG(42), dist, 1000) ≈ rand(StableRNG(42), rdist, 1000) end @testitem "MvNormalMeanScalePrecision: Base methods" begin From 2fb5717bb9a117c32ee7e6b9f06e36ebc7238af7 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 23 Sep 2024 16:25:39 +0200 Subject: [PATCH 17/32] feat: implement getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) --- Project.toml | 3 +- .../mv_normal_mean_scale_precision.jl | 54 ++++++++++++------- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index 0a063ae9..be538333 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "1.5.1" [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" +BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756" @@ -57,8 +58,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CpuId = "adafc99b-e345-5852-983c-f28acb93d879" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" -RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index e67a78c3..4477c403 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -3,6 +3,7 @@ export MvNormalMeanScalePrecision, MvGaussianMeanScalePrecision import Distributions: logdetcov, distrname, sqmahal, sqmahal!, AbstractMvNormal import LinearAlgebra: diag, Diagonal, dot import Base: ndims, precision, length, size, prod +import BlockArrays: Block, BlockArray, undef_blocks """ MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal @@ -66,7 +67,7 @@ end function (::MeanToNatural{MvNormalMeanScalePrecision})(tuple_of_θ::Tuple{Any, Any}) (μ, γ) = tuple_of_θ - return (γ * μ, γ / -2) + return (γ * μ, - γ / 2) end function (::NaturalToMean{MvNormalMeanScalePrecision})(tuple_of_η::Tuple{Any, Any}) @@ -140,11 +141,11 @@ function Distributions.sqmahal(dist::MvNormalMeanScalePrecision, x::AbstractVect end function Distributions.sqmahal!(r, dist::MvNormalMeanScalePrecision, x::AbstractVector) - μ = mean(dist) + μ, γ = params(dist) @inbounds @simd for i in 1:length(r) r[i] = μ[i] - x[i] end - return dot3arg(r, invcov(dist), r) # x' * A * x + return dot3arg(r, γ, r) # x' * A * x end Base.eltype(::MvNormalMeanScalePrecision{T}) where {T} = T @@ -184,7 +185,7 @@ end function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}) where {T} μ, γ = mean(dist), scale(dist) - return μ + 1 / γ * I(length(μ)) * randn(rng, T, length(μ)) + return μ + 1 / γ .* randn(rng, T, length(μ)) end function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}, size::Int64) where {T} @@ -218,34 +219,51 @@ isbasemeasureconstant(::Type{MvNormalMeanScalePrecision}) = ConstantBaseMeasure( getbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> (2π)^(-length(x) / 2) +getlogbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> -length(x) / 2 * log2π + getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = (η) -> begin η1 = @view η[1:end-1] η2 = η[end] k = length(η1) - Cinv = -inv(η2) - l = log(-inv(η2)) - return (dot(η1, Cinv, η1) / 2 - (k * log(2) + l)) / 2 + Cinv = inv(η2) + return -dot(η1, 1/4*Cinv, η1) - (k / 2)*log(-2*η2) end getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = (η) -> begin η1 = @view η[1:end-1] η2 = η[end] - Cinv = log(-inv(η2)) - return pack_parameters(MvNormalMeanCovariance, (0.5 * Cinv * η1, 0.25 * Cinv^2 * dot(η1,η1) + 0.5 * Cinv)) + inv2 = inv(η2) + k = length(η1) + return pack_parameters(MvNormalMeanCovariance, (-1/(2*η2) * η1, dot(η1,η1) / 4*inv2^2 - k/2 * inv2)) end -getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = +getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) = (η) -> begin - (η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η) - invη2 = -cholinv(-η₂) - return Diagonal([η₁..., invη2]) + η1 = @view η[1:end-1] + η2 = η[end] + k = length(η1) + + inv_η2 = inv(η2) + η1_part = -1/(2*inv_η2)* I(length(η1)) + η1η2 = zeros(k, 1) + η1η2 .= 2*η1/inv_η2^2 + + η2_part = zeros(1, 1) + η2_part .= -dot(η1,η1) / 2*inv_η2^3 + k/(2inv_η2) + + fisher = BlockArray{eltype(η)}(undef_blocks, [k, 1], [k, 1]) + + fisher[Block(1), Block(1)] = η1_part + fisher[Block(1), Block(2)] = η1η2 + fisher[Block(2), Block(1)] = η1η2' + fisher[Block(2), Block(2)] = η2_part + return fisher end + -getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) = - (η) -> begin - (η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η) - invη2 = -cholinv(-η₂) - return Diagonal([η₁..., invη2]) + getfisherinformation(::MeanParametersSpace, ::Type{NormalMeanVariance}) = (θ) -> begin + (_, σ²) = unpack_parameters(NormalMeanVariance, θ) + return SA[inv(σ²) 0; 0 inv(2 * (σ²^2))] end \ No newline at end of file From 89a4932d5e451706d015d3828407c96851736c74 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Mon, 23 Sep 2024 16:25:57 +0200 Subject: [PATCH 18/32] feat: implement getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) --- .../normal_family/mv_normal_mean_scale_precision.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index 4477c403..18ff40e0 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -263,7 +263,5 @@ getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision end - getfisherinformation(::MeanParametersSpace, ::Type{NormalMeanVariance}) = (θ) -> begin - (_, σ²) = unpack_parameters(NormalMeanVariance, θ) - return SA[inv(σ²) 0; 0 inv(2 * (σ²^2))] - end \ No newline at end of file +getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) = + error("MeanParametersSpaceFisher is not implemented for MvNormalMeanScalePrecision") \ No newline at end of file From e3199637e4d6ab92290eff685501b9d48431ddd7 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Tue, 24 Sep 2024 13:39:35 +0200 Subject: [PATCH 19/32] fix: correct getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) --- .../mv_normal_mean_scale_precision.jl | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index 18ff40e0..d97522a4 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -263,5 +263,26 @@ getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision end -getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) = - error("MeanParametersSpaceFisher is not implemented for MvNormalMeanScalePrecision") \ No newline at end of file +getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) = + (θ) -> begin + μ = @view θ[1:end-1] + γ = θ[end] + k = length(μ) + + μ_part = γ * I(k) + + μγ_part = zeros(k, 1) + μγ_part .= μ + + γ_part = zeros(1, 1) + γ_part .= k/(2*γ^2) + + fisher = BlockArray{eltype(θ)}(undef_blocks, [k, 1], [k, 1]) + + fisher[Block(1), Block(1)] = μ_part + fisher[Block(1), Block(2)] = μγ_part + fisher[Block(2), Block(1)] = μγ_part' + fisher[Block(2), Block(2)] = γ_part + + return fisher + end \ No newline at end of file From 77f4a0d83c4321cdbb8df969b500ee015c2f6c34 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Tue, 24 Sep 2024 14:49:58 +0200 Subject: [PATCH 20/32] test: use test_exponentialfamily_interface and add MvNormalMeanScalePrecision efficency test fix: remove unneeded code fix: remove not needed stuff fix: remove unused code test: add efficency test fix: return distributions_setuptests to HEAD test(fix): typo test(fix): remove unneeded testset test(fix): update efficency test --- Project.toml | 1 + .../mv_normal_mean_scale_precision.jl | 15 +++-- .../distributions/distributions_setuptests.jl | 2 +- .../mv_normal_mean_scale_precision_tests.jl | 64 ++++++++++++++++++- 4 files changed, 73 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index be538333..ffe0e999 100644 --- a/Project.toml +++ b/Project.toml @@ -30,6 +30,7 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" [compat] Aqua = "0.8.7" BayesBase = "1.2" +BlockArrays = "1.1.1" Distributions = "0.25" DomainSets = "0.5.2, 0.6, 0.7" FastCholesky = "1.0" diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index d97522a4..fb7231d6 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -245,13 +245,14 @@ getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision η2 = η[end] k = length(η1) - inv_η2 = inv(η2) - η1_part = -1/(2*inv_η2)* I(length(η1)) + η1_part = -inv(2*η2)* I(length(η1)) η1η2 = zeros(k, 1) - η1η2 .= 2*η1/inv_η2^2 + η1η2 .= η1*inv(2*η2^2) + #η₁/(2abs2(η₂)) η2_part = zeros(1, 1) - η2_part .= -dot(η1,η1) / 2*inv_η2^3 + k/(2inv_η2) + η2_part .= k*inv(2abs2(η2)) - dot(η1,η1) / (2*η2^3) + # inv(2abs2(η₂))-abs2(η₁)/(2(η₂^3)) fisher = BlockArray{eltype(η)}(undef_blocks, [k, 1], [k, 1]) @@ -272,10 +273,10 @@ getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) μ_part = γ * I(k) μγ_part = zeros(k, 1) - μγ_part .= μ + μγ_part .= 0 γ_part = zeros(1, 1) - γ_part .= k/(2*γ^2) + γ_part .= k*inv(2abs2(γ)) fisher = BlockArray{eltype(θ)}(undef_blocks, [k, 1], [k, 1]) @@ -285,4 +286,4 @@ getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) fisher[Block(2), Block(2)] = γ_part return fisher - end \ No newline at end of file + end diff --git a/test/distributions/distributions_setuptests.jl b/test/distributions/distributions_setuptests.jl index a6186e20..7292f8f7 100644 --- a/test/distributions/distributions_setuptests.jl +++ b/test/distributions/distributions_setuptests.jl @@ -557,4 +557,4 @@ function test_generic_simple_exponentialfamily_product( end return true -end +end \ No newline at end of file diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index 4aad9529..f3d695c2 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -31,7 +31,7 @@ end rng = StableRNG(42) - for s in 2:5 + for s in 1:6 μ = randn(rng, s) γ = rand(rng) @@ -39,6 +39,19 @@ end ef = test_exponentialfamily_interface(d;) end end + + μ = randn(rng, 1) + γ = rand(rng) + + d = MvNormalMeanScalePrecision(μ, γ) + ef = convert(ExponentialFamilyDistribution, d) + + d1d = NormalMeanPrecision(μ[1], γ) + ef1d = convert(ExponentialFamilyDistribution, d1d) + + @test logpartition(ef) ≈ logpartition(ef1d) + @test gradlogpartition(ef) ≈ gradlogpartition(ef1d) + @test fisherinformation(ef) ≈ fisherinformation(ef1d) end @testitem "MvNormalMeanScalePrecision: Stats methods" begin @@ -164,3 +177,52 @@ end end end end + +@testitem "MvNormalMeanScalePrecision: Fisher is faster then for full parametrization" begin + include("./normal_family_setuptests.jl") + using BenchmarkTools + using LinearAlgebra + using JET + + rng = StableRNG(42) + for k in 20:40 + μ = randn(rng, k) + γ = rand(rng) + cov = γ * I(k) + + ef_small = convert(ExponentialFamilyDistribution, MvNormalMeanScalePrecision(μ, γ)) + ef_full = convert(ExponentialFamilyDistribution, MvNormalMeanCovariance(μ, cov)) + + fi_small = fisherinformation(ef_small) + fi_full = fisherinformation(ef_full) + + @test_opt fisherinformation(ef_small) + @test_opt fisherinformation(ef_full) + + fi_mvsp_time = @elapsed fisherinformation(ef_small) + fi_mvsp_alloc = @allocated fisherinformation(ef_small) + + fi_full_time = @elapsed fisherinformation(ef_full) + fi_full_alloc = @allocated fisherinformation(ef_full) + + @test_opt cholinv(fi_small) + @test_opt cholinv(fi_full) + + cholinv_time_small = @elapsed cholinv(fi_small) + cholinv_alloc_small = @allocated fisherinformation(ef_small) + + cholinv_time_full = @elapsed cholinv(fi_full) + cholinv_alloc_full = @allocated cholinv(fi_full) + + fi_small = fisherinformation(ef_small) + fi_full = fisherinformation(ef_full) + + # small time is supposed to be O(k) and full time is supposed to O(k^2) + # the constant C is selected to account to fluctuations in test runs + C = 0.9 + @test fi_mvsp_time < fi_full_time/(C*k) + @test fi_mvsp_alloc < fi_full_alloc/(C*k) + @test cholinv_time_small < cholinv_time_full/(C*k) + @test cholinv_alloc_small < cholinv_alloc_full/(C*k) + end +end \ No newline at end of file From 0b875690bb4b0aae46312eac448d0467b4219bba Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Wed, 25 Sep 2024 17:42:37 +0200 Subject: [PATCH 21/32] Delete test/repopack-output.txt --- test/repopack-output.txt | 5221 -------------------------------------- 1 file changed, 5221 deletions(-) delete mode 100644 test/repopack-output.txt diff --git a/test/repopack-output.txt b/test/repopack-output.txt deleted file mode 100644 index 156907a3..00000000 --- a/test/repopack-output.txt +++ /dev/null @@ -1,5221 +0,0 @@ -This file is a merged representation of the entire codebase, combining all repository files into a single document. -Generated by Repopack on: 2024-09-19T14:22:03.493Z - -================================================================ -File Summary -================================================================ - -Purpose: --------- -This file contains a packed representation of the entire repository's contents. -It is designed to be easily consumable by AI systems for analysis, code review, -or other automated processes. - -File Format: ------------- -The content is organized as follows: -1. This summary section -2. Repository information -3. Repository structure -4. Multiple file entries, each consisting of: - a. A separator line (================) - b. The file path (File: path/to/file) - c. Another separator line - d. The full contents of the file - e. A blank line - -Usage Guidelines: ------------------ -- This file should be treated as read-only. Any changes should be made to the - original repository files, not this packed version. -- When processing this file, use the file path to distinguish - between different files in the repository. -- Be aware that this file may contain sensitive information. Handle it with - the same level of security as you would the original repository. - -Notes: ------- -- Some files may have been excluded based on .gitignore rules and Repopack's - configuration. -- Binary files are not included in this packed representation. Please refer to - the Repository Structure section for a complete list of file paths, including - binary files. - -Additional Info: ----------------- - -For more information about Repopack, visit: https://github.com/yamadashy/repopack - -================================================================ -Repository Structure -================================================================ -distributions/ - gamma_family/ - gamma_family_setuptests.jl - gamma_family_tests.jl - gamma_shape_rate_tests.jl - gamma_shape_scale_tests.jl - normal_family/ - mv_normal_mean_covariance_tests.jl - mv_normal_mean_precision_tests.jl - mv_normal_weighted_mean_precision_tests.jl - normal_family_setuptests.jl - normal_family_tests.jl - normal_mean_precision_tests.jl - normal_mean_variance_tests.jl - normal_weighted_mean_precision_tests.jl - wip/ - test_continuous_bernoulli.jl - test_multinomial.jl - bernoulli_tests.jl - beta_tests.jl - binomial_tests.jl - categorical_tests.jl - chi_squared_tests.jl - dirichlet_tests.jl - distributions_setuptests.jl - erlang_tests.jl - exponential_tests.jl - gamma_inverse_tests.jl - geometric_tests.jl - laplace_tests.jl - lognormal_tests.jl - matrix_dirichlet_tests.jl - mv_normal_wishart_tests.jl - negative_binomial_tests.jl - normal_gamma_tests.jl - pareto_tests.jl - poisson_tests.jl - rayleigh_tests.jl - von_mises_fisher_tests.jl - vonmises_tests.jl - weibull_tests.jl - wishart_inverse_tests.jl - wishart_tests.jl -common_tests.jl -exponential_family_setuptests.jl -exponential_family_tests.jl -runtests.jl - -================================================================ -Repository Files -================================================================ - -================ -File: distributions/gamma_family/gamma_family_setuptests.jl -================ -include("../distributions_setuptests.jl") - -import ExponentialFamily: xtlog - -function compare_basic_statistics(left, right) - @test mean(left) ≈ mean(right) - @test var(left) ≈ var(right) - @test cov(left) ≈ cov(right) - @test shape(left) ≈ shape(right) - @test scale(left) ≈ scale(right) - @test rate(left) ≈ rate(right) - @test entropy(left) ≈ entropy(right) - @test pdf(left, 1.0) ≈ pdf(right, 1.0) - @test pdf(left, 10.0) ≈ pdf(right, 10.0) - @test logpdf(left, 1.0) ≈ logpdf(right, 1.0) - @test logpdf(left, 10.0) ≈ logpdf(right, 10.0) - - @test mean(log, left) ≈ mean(log, right) - @test mean(loggamma, left) ≈ mean(loggamma, right) - @test mean(xtlog, left) ≈ mean(xtlog, right) - - return true -end - -================ -File: distributions/gamma_family/gamma_family_tests.jl -================ -@testitem "GammaFamily: Base statistical methods" begin - include("./gamma_family_setuptests.jl") - - types = union_types(GammaDistributionsFamily{Float64}) - rng = MersenneTwister(1234) - for _ in 1:10 - for type in types - left = convert(type, 100 * rand(rng, Float64), 100 * rand(rng, Float64)) - for type in types - right = convert(type, left) - @test compare_basic_statistics(left, right) - - @test all(params(MeanParametersSpace(), left) .== (shape(left), scale(left))) - @test all(params(MeanParametersSpace(), right) .== (shape(right), scale(right))) - end - end - end -end - -@testitem "GammaFamily: ExponentialFamilyDistribution" begin - include("./gamma_family_setuptests.jl") - - for k in (0.1, 2.0, 5.0), θ in (0.1, 2.0, 5.0), T in union_types(GammaDistributionsFamily{Float64}) - @testset let d = convert(T, GammaShapeScale(k, θ)) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - - (η₁, η₂) = (shape(d) - 1, -inv(scale(d))) - - for x in 0.1:0.5:5.0 - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === oneunit(x) - @test all(@inferred(sufficientstatistics(ef, x)) .=== (log(x), x)) - @test @inferred(logpartition(ef)) ≈ loggamma(η₁ + 1) - (η₁ + 1) * log(-η₂) - end - - @test @inferred(insupport(ef, 0.5)) - @test !@inferred(insupport(ef, -0.5)) - - # # Not in the support - @test_throws Exception logpdf(ef, -0.5) - end - end - - # Test failing isproper cases - @test !isproper(MeanParametersSpace(), Gamma, [-1]) - @test !isproper(MeanParametersSpace(), Gamma, [1, -1]) - @test !isproper(MeanParametersSpace(), Gamma, [-1, -1]) - @test !isproper(NaturalParametersSpace(), Gamma, [-1]) - @test !isproper(NaturalParametersSpace(), Gamma, [1, 10]) - @test !isproper(NaturalParametersSpace(), Gamma, [-100, -1]) - - # shapes must add up to something more than 1, otherwise is not proper - let ef = convert(ExponentialFamilyDistribution, Gamma(0.1, 1.0)) - @test !isproper(prod(PreserveTypeProd(ExponentialFamilyDistribution), ef, ef)) - end -end - -@testitem "GammaFamily: prod with ExponentialFamilyDistribution" begin - include("./gamma_family_setuptests.jl") - - for kleft in 0.51:1.0:5.0, kright in 0.51:1.0:5.0, θleft in 0.1:1.0:5.0, θright in 0.1:1.0:5.0, Tleft in union_types(GammaDistributionsFamily{Float64}), - Tright in union_types(GammaDistributionsFamily{Float64}) - - @testset let (left, right) = (convert(Tleft, Gamma(kleft, θleft)), convert(Tright, Gamma(kright, θright))) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{Gamma}) - ) - ) - end - end -end - -================ -File: distributions/gamma_family/gamma_shape_rate_tests.jl -================ -@testitem "GammaShapeRate: Constructor" begin - include("./gamma_family_setuptests.jl") - - @test GammaShapeScale <: GammaDistributionsFamily - - @test GammaShapeRate() == GammaShapeRate{Float64}(1.0, 1.0) - @test GammaShapeRate(1.0) == GammaShapeRate{Float64}(1.0, 1.0) - @test GammaShapeRate(1.0, 2.0) == GammaShapeRate{Float64}(1.0, 2.0) - @test GammaShapeRate(1) == GammaShapeRate{Float64}(1.0, 1.0) - @test GammaShapeRate(1, 2) == GammaShapeRate{Float64}(1.0, 2.0) - @test GammaShapeRate(1.0, 2) == GammaShapeRate{Float64}(1.0, 2.0) - @test GammaShapeRate(1, 2.0) == GammaShapeRate{Float64}(1.0, 2.0) - @test GammaShapeRate(1.0f0) == GammaShapeRate{Float32}(1.0f0, 1.0f0) - @test GammaShapeRate(1.0f0, 2.0f0) == GammaShapeRate{Float32}(1.0f0, 2.0f0) - @test GammaShapeRate(1.0f0, 2) == GammaShapeRate{Float32}(1.0f0, 2.0f0) - @test GammaShapeRate(1.0f0, 2.0) == GammaShapeRate{Float64}(1.0, 2.0) - - @test paramfloattype(GammaShapeRate(1.0, 2.0)) === Float64 - @test paramfloattype(GammaShapeRate(1.0f0, 2.0f0)) === Float32 - - @test convert(GammaShapeRate{Float32}, GammaShapeRate()) == GammaShapeRate{Float32}(1.0f0, 1.0f0) - @test convert(GammaShapeRate{Float64}, GammaShapeRate(1.0, 10.0)) == GammaShapeRate{Float64}(1.0, 10.0) - @test convert(GammaShapeRate{Float64}, GammaShapeRate(1.0, 0.1)) == GammaShapeRate{Float64}(1.0, 0.1) - @test convert(GammaShapeRate{Float64}, 1, 1) == GammaShapeRate{Float64}(1.0, 1.0) - @test convert(GammaShapeRate{Float64}, 1, 10) == GammaShapeRate{Float64}(1.0, 10.0) - @test convert(GammaShapeRate{Float64}, 1.0, 0.1) == GammaShapeRate{Float64}(1.0, 0.1) - - @test convert(GammaShapeRate, GammaShapeRate(2.0, 2.0)) == GammaShapeRate{Float64}(2.0, 2.0) - @test convert(GammaShapeScale, GammaShapeRate(2.0, 2.0)) == GammaShapeScale{Float64}(2.0, 1.0 / 2.0) - - @test convert(GammaShapeRate, GammaShapeScale(2.0, 2.0)) == GammaShapeRate{Float64}(2.0, 1.0 / 2.0) - @test convert(GammaShapeScale, GammaShapeScale(2.0, 2.0)) == GammaShapeScale{Float64}(2.0, 2.0) -end - -@testitem "GammaShapeRate: vague" begin - include("./gamma_family_setuptests.jl") - - @test vague(GammaShapeRate) == GammaShapeRate(1.0, 1e-12) -end - -@testitem "GammaShapeRate: stats methods" begin - include("./gamma_family_setuptests.jl") - - dist1 = GammaShapeRate(1.0, 1.0) - - @test mean(dist1) === 1.0 - @test var(dist1) === 1.0 - @test cov(dist1) === 1.0 - @test shape(dist1) === 1.0 - @test scale(dist1) === 1.0 - @test rate(dist1) === 1.0 - @test entropy(dist1) ≈ 1.0 - @test pdf(dist1, 1.0) ≈ 0.36787944117144233 - @test logpdf(dist1, 1.0) ≈ -1.0 - - dist2 = GammaShapeRate(1.0, 2.0) - - @test mean(dist2) === inv(2.0) - @test var(dist2) === inv(4.0) - @test cov(dist2) === inv(4.0) - @test shape(dist2) === 1.0 - @test scale(dist2) === inv(2.0) - @test rate(dist2) === 2.0 - @test entropy(dist2) ≈ 0.3068528194400547 - @test pdf(dist2, 1.0) ≈ 0.2706705664732254 - @test logpdf(dist2, 1.0) ≈ -1.3068528194400546 - - dist3 = GammaShapeRate(2.0, 2.0) - - @test mean(dist3) === 1.0 - @test var(dist3) === inv(2.0) - @test cov(dist3) === inv(2.0) - @test shape(dist3) === 2.0 - @test scale(dist3) === inv(2.0) - @test rate(dist3) === 2.0 - @test entropy(dist3) ≈ 0.8840684843415857 - @test pdf(dist3, 1.0) ≈ 0.5413411329464508 - @test logpdf(dist3, 1.0) ≈ -0.6137056388801094 - - # see https://github.com/ReactiveBayes/ReactiveMP.jl/issues/314 - dist = GammaShapeRate(257.37489915581654, 3.0) - @test pdf(dist, 86.2027941354432) == 0.07400338986721687 -end - -@testitem "GammaShapeRate: prod" begin - include("./gamma_family_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) - @test prod(strategy, GammaShapeScale(1, 1), GammaShapeScale(1, 1)) == GammaShapeScale(1, 1 / 2) - @test prod(strategy, GammaShapeScale(1, 2), GammaShapeScale(1, 1)) == GammaShapeScale(1, 2 / 3) - @test prod(strategy, GammaShapeScale(1, 2), GammaShapeScale(1, 2)) == GammaShapeScale(1, 1) - @test prod(strategy, GammaShapeScale(2, 2), GammaShapeScale(1, 2)) == GammaShapeScale(2, 1) - @test prod(strategy, GammaShapeScale(2, 2), GammaShapeScale(2, 2)) == GammaShapeScale(3, 1) - @test prod(strategy, GammaShapeScale(1, 1), GammaShapeRate(1, 1)) == GammaShapeScale(1, 1 / 2) - @test prod(strategy, GammaShapeScale(1, 2), GammaShapeRate(1, 1)) == GammaShapeScale(1, 2 / 3) - @test prod(strategy, GammaShapeScale(1, 2), GammaShapeRate(1, 2)) == GammaShapeScale(1, 2 / 5) - @test prod(strategy, GammaShapeScale(2, 2), GammaShapeRate(1, 2)) == GammaShapeScale(2, 2 / 5) - @test prod(strategy, GammaShapeScale(2, 2), GammaShapeRate(2, 2)) == GammaShapeScale(3, 2 / 5) - end -end - -================ -File: distributions/gamma_family/gamma_shape_scale_tests.jl -================ -@testitem "GammaShapeScale: Constructor" begin - include("./gamma_family_setuptests.jl") - - @test GammaShapeScale <: GammaDistributionsFamily - - @test GammaShapeScale() == GammaShapeScale{Float64}(1.0, 1.0) - @test GammaShapeScale(1.0) == GammaShapeScale{Float64}(1.0, 1.0) - @test GammaShapeScale(1.0, 2.0) == GammaShapeScale{Float64}(1.0, 2.0) - @test GammaShapeScale(1) == GammaShapeScale{Float64}(1.0, 1.0) - @test GammaShapeScale(1, 2) == GammaShapeScale{Float64}(1.0, 2.0) - @test GammaShapeScale(1.0, 2) == GammaShapeScale{Float64}(1.0, 2.0) - @test GammaShapeScale(1, 2.0) == GammaShapeScale{Float64}(1.0, 2.0) - @test GammaShapeScale(1.0f0) == GammaShapeScale{Float32}(1.0f0, 1.0f0) - @test GammaShapeScale(1.0f0, 2.0f0) == GammaShapeScale{Float32}(1.0f0, 2.0f0) - @test GammaShapeScale(1.0f0, 2) == GammaShapeScale{Float32}(1.0f0, 2.0f0) - @test GammaShapeScale(1.0f0, 2.0) == GammaShapeScale{Float64}(1.0, 2.0) - - @test paramfloattype(GammaShapeScale(1.0, 2.0)) === Float64 - @test paramfloattype(GammaShapeScale(1.0f0, 2.0f0)) === Float32 - - @test convert(GammaShapeScale{Float32}, GammaShapeScale()) == GammaShapeScale{Float32}(1.0f0, 1.0f0) - @test convert(GammaShapeScale{Float64}, GammaShapeScale(1.0, 10.0)) == GammaShapeScale{Float64}(1.0, 10.0) - @test convert(GammaShapeScale{Float64}, GammaShapeScale(1.0, 0.1)) == GammaShapeScale{Float64}(1.0, 0.1) - @test convert(GammaShapeScale{Float64}, 1, 1) == GammaShapeScale{Float64}(1.0, 1.0) - @test convert(GammaShapeScale{Float64}, 1, 10) == GammaShapeScale{Float64}(1.0, 10.0) - @test convert(GammaShapeScale{Float64}, 1.0, 0.1) == GammaShapeScale{Float64}(1.0, 0.1) - - @test convert(GammaShapeRate, GammaShapeScale(2.0, 2.0)) == GammaShapeRate{Float64}(2.0, 1.0 / 2.0) - @test convert(GammaShapeScale, GammaShapeScale(2.0, 2.0)) == GammaShapeScale{Float64}(2.0, 2.0) -end - -@testitem "GammaShapeScale: vague" begin - include("./gamma_family_setuptests.jl") - - @test vague(GammaShapeScale) == GammaShapeScale(1.0, 1e12) -end - -@testitem "GammaShapeScale: stats methods" begin - include("./gamma_family_setuptests.jl") - - dist1 = GammaShapeScale(1.0, 1.0) - - @test mean(dist1) === 1.0 - @test var(dist1) === 1.0 - @test cov(dist1) === 1.0 - @test shape(dist1) === 1.0 - @test scale(dist1) === 1.0 - @test rate(dist1) === 1.0 - @test entropy(dist1) ≈ 1.0 - - dist2 = GammaShapeScale(1.0, 2.0) - - @test mean(dist2) === 2.0 - @test var(dist2) === 4.0 - @test cov(dist2) === 4.0 - @test shape(dist2) === 1.0 - @test scale(dist2) === 2.0 - @test rate(dist2) === inv(2.0) - @test entropy(dist2) ≈ 1.6931471805599454 - - dist3 = GammaShapeScale(2.0, 2.0) - - @test mean(dist3) === 4.0 - @test var(dist3) === 8.0 - @test cov(dist3) === 8.0 - @test shape(dist3) === 2.0 - @test scale(dist3) === 2.0 - @test rate(dist3) === inv(2.0) - @test entropy(dist3) ≈ 2.2703628454614764 -end - -@testitem "GammaShapeScale: prod" begin - include("./gamma_family_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) - @test prod(strategy, GammaShapeRate(1, 1), GammaShapeRate(1, 1)) == GammaShapeRate(1, 2) - @test prod(strategy, GammaShapeRate(1, 2), GammaShapeRate(1, 1)) == GammaShapeRate(1, 3) - @test prod(strategy, GammaShapeRate(1, 2), GammaShapeRate(1, 2)) == GammaShapeRate(1, 4) - @test prod(strategy, GammaShapeRate(2, 2), GammaShapeRate(1, 2)) == GammaShapeRate(2, 4) - @test prod(strategy, GammaShapeRate(2, 2), GammaShapeRate(2, 2)) == GammaShapeRate(3, 4) - @test prod(strategy, GammaShapeRate(1, 1), GammaShapeScale(1, 1)) == GammaShapeRate(1, 2) - @test prod(strategy, GammaShapeRate(1, 2), GammaShapeScale(1, 1)) == GammaShapeRate(1, 3) - @test prod(strategy, GammaShapeRate(1, 2), GammaShapeScale(1, 2)) == GammaShapeRate(1, 5 / 2) - @test prod(strategy, GammaShapeRate(2, 2), GammaShapeScale(1, 2)) == GammaShapeRate(2, 5 / 2) - @test prod(strategy, GammaShapeRate(2, 2), GammaShapeScale(2, 2)) == GammaShapeRate(3, 5 / 2) - end -end - -================ -File: distributions/normal_family/mv_normal_mean_covariance_tests.jl -================ -@testitem "MvNormalMeanCovariance: Constructor" begin - include("./normal_family_setuptests.jl") - - @test MvNormalMeanCovariance <: AbstractMvNormal - - @test MvNormalMeanCovariance([1.0, 1.0]) == MvNormalMeanCovariance([1.0, 1.0], [1.0, 1.0]) - @test MvNormalMeanCovariance([1.0, 2.0]) == MvNormalMeanCovariance([1.0, 2.0], [1.0, 1.0]) - @test MvNormalMeanCovariance([1, 2]) == MvNormalMeanCovariance([1.0, 2.0], [1.0, 1.0]) - @test MvNormalMeanCovariance([1.0f0, 2.0f0]) == MvNormalMeanCovariance([1.0f0, 2.0f0], [1.0f0, 1.0f0]) - - @test eltype(MvNormalMeanCovariance([1.0, 1.0])) === Float64 - @test eltype(MvNormalMeanCovariance([1.0, 1.0], [1.0, 1.0])) === Float64 - @test eltype(MvNormalMeanCovariance([1, 1])) === Float64 - @test eltype(MvNormalMeanCovariance([1, 1], [1, 1])) === Float64 - @test eltype(MvNormalMeanCovariance([1.0f0, 1.0f0])) === Float32 - @test eltype(MvNormalMeanCovariance([1.0f0, 1.0f0], [1.0f0, 1.0f0])) === Float32 - - @test MvNormalMeanCovariance(ones(3), 5I) == MvNormalMeanCovariance(ones(3), Diagonal(5 * ones(3))) - @test MvNormalMeanCovariance([1, 2, 3, 4], 7.0I) == MvNormalMeanCovariance([1.0, 2.0, 3.0, 4.0], Diagonal(7.0 * ones(4))) -end - -@testitem "MvNormalMeanCovariance: distrname" begin - include("./normal_family_setuptests.jl") - - @test ExponentialFamily.distrname(MvNormalMeanCovariance(zeros(2))) === "MvNormalMeanCovariance" -end - -@testitem "MvNormalMeanCovariance: Stats methods" begin - include("./normal_family_setuptests.jl") - - μ = [0.2, 3.0, 4.0] - Σ = [1.5 -0.3 0.1; -0.3 1.8 0.0; 0.1 0.0 3.5] - dist = MvNormalMeanCovariance(μ, Σ) - - @test mean(dist) == μ - @test mode(dist) == μ - @test weightedmean(dist) ≈ inv(Σ) * μ - @test invcov(dist) ≈ inv(Σ) - @test precision(dist) ≈ inv(Σ) - @test cov(dist) == Σ - @test std(dist) * std(dist)' ≈ Σ - @test all(mean_cov(dist) .≈ (μ, Σ)) - @test all(mean_invcov(dist) .≈ (μ, inv(Σ))) - @test all(mean_precision(dist) .≈ (μ, inv(Σ))) - @test all(weightedmean_cov(dist) .≈ (inv(Σ) * μ, Σ)) - @test all(weightedmean_invcov(dist) .≈ (inv(Σ) * μ, inv(Σ))) - @test all(weightedmean_precision(dist) .≈ (inv(Σ) * μ, inv(Σ))) - - @test length(dist) == 3 - @test entropy(dist) ≈ 5.361886000915401 - @test pdf(dist, [0.2, 3.0, 4.0]) ≈ 0.021028302702542 - @test pdf(dist, [0.202, 3.002, 4.002]) ≈ 0.021028229679079503 - @test logpdf(dist, [0.2, 3.0, 4.0]) ≈ -3.8618860009154012 - @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ -3.861889473548943 -end - -@testitem "MvNormalMeanCovariance: Base methods" begin - include("./normal_family_setuptests.jl") - - @test convert(MvNormalMeanCovariance{Float32}, MvNormalMeanCovariance([0.0, 0.0])) == - MvNormalMeanCovariance([0.0f0, 0.0f0], [1.0f0, 1.0f0]) - @test convert(MvNormalMeanCovariance{Float64}, [0.0, 0.0], [2 0; 0 3]) == - MvNormalMeanCovariance([0.0, 0.0], [2.0 0.0; 0.0 3.0]) - - @test length(MvNormalMeanCovariance([0.0, 0.0])) === 2 - @test length(MvNormalMeanCovariance([0.0, 0.0, 0.0])) === 3 - @test ndims(MvNormalMeanCovariance([0.0, 0.0])) === 2 - @test ndims(MvNormalMeanCovariance([0.0, 0.0, 0.0])) === 3 - @test size(MvNormalMeanCovariance([0.0, 0.0])) === (2,) - @test size(MvNormalMeanCovariance([0.0, 0.0, 0.0])) === (3,) - - distribution = MvNormalMeanCovariance([0.0, 0.0], [2.0 0.0; 0.0 3.0]) - - @test distribution ≈ distribution - @test distribution ≈ convert(MvNormalMeanPrecision, distribution) - @test distribution ≈ convert(MvNormalWeightedMeanPrecision, distribution) -end - -@testitem "MvNormalMeanCovariance: vague" begin - include("./normal_family_setuptests.jl") - - @test_throws MethodError vague(MvNormalMeanCovariance) - - d1 = vague(MvNormalMeanCovariance, 2) - - @test typeof(d1) <: MvNormalMeanCovariance - @test mean(d1) == zeros(2) - @test cov(d1) == Matrix(Diagonal(1e12 * ones(2))) - @test ndims(d1) == 2 - - d2 = vague(MvNormalMeanCovariance, 3) - - @test typeof(d2) <: MvNormalMeanCovariance - @test mean(d2) == zeros(3) - @test cov(d2) == Matrix(Diagonal(1e12 * ones(3))) - @test ndims(d2) == 3 -end - -@testitem "MvNormalMeanCovariance: prod" begin - include("./normal_family_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) - @test prod(strategy, MvNormalMeanCovariance([-1, -1], [2, 2]), MvNormalMeanCovariance([1, 1], [2, 4])) ≈ - MvNormalWeightedMeanPrecision([0, -1 / 4], [1, 3 / 4]) - - μ = [1.0, 2.0, 3.0] - Σ = diagm([1.0, 2.0, 3.0]) - dist = MvNormalMeanCovariance(μ, Σ) - - @test prod(strategy, dist, dist) ≈ - MvNormalWeightedMeanPrecision([2.0, 2.0, 2.0], diagm([2.0, 1.0, 2 / 3])) - end -end - -@testitem "MvNormalMeanCovariance: convert" begin - include("./normal_family_setuptests.jl") - - @test convert(MvNormalMeanCovariance, zeros(2), Matrix(Diagonal(ones(2)))) == - MvNormalMeanCovariance(zeros(2), Matrix(Diagonal(ones(2)))) - @test begin - m = rand(5) - c = Matrix(Symmetric(rand(5, 5))) - convert(MvNormalMeanCovariance, m, c) == MvNormalMeanCovariance(m, c) - end -end - -================ -File: distributions/normal_family/mv_normal_mean_precision_tests.jl -================ -@testitem "MvNormalMeanPrecision: Constructor" begin - include("./normal_family_setuptests.jl") - - @test MvNormalMeanPrecision <: AbstractMvNormal - - @test MvNormalMeanPrecision([1.0, 1.0]) == MvNormalMeanPrecision([1.0, 1.0], [1.0, 1.0]) - @test MvNormalMeanPrecision([1.0, 2.0]) == MvNormalMeanPrecision([1.0, 2.0], [1.0, 1.0]) - @test MvNormalMeanPrecision([1, 2]) == MvNormalMeanPrecision([1.0, 2.0], [1.0, 1.0]) - @test MvNormalMeanPrecision([1.0f0, 2.0f0]) == MvNormalMeanPrecision([1.0f0, 2.0f0], [1.0f0, 1.0f0]) - - @test eltype(MvNormalMeanPrecision([1.0, 1.0])) === Float64 - @test eltype(MvNormalMeanPrecision([1.0, 1.0], [1.0, 1.0])) === Float64 - @test eltype(MvNormalMeanPrecision([1, 1])) === Float64 - @test eltype(MvNormalMeanPrecision([1, 1], [1, 1])) === Float64 - @test eltype(MvNormalMeanPrecision([1.0f0, 1.0f0])) === Float32 - @test eltype(MvNormalMeanPrecision([1.0f0, 1.0f0], [1.0f0, 1.0f0])) === Float32 - - @test MvNormalMeanPrecision(ones(3), 5I) == MvNormalMeanPrecision(ones(3), Diagonal(5 * ones(3))) - @test MvNormalMeanPrecision([1, 2, 3, 4], 7.0I) == MvNormalMeanPrecision([1.0, 2.0, 3.0, 4.0], Diagonal(7.0 * ones(4))) -end - -@testitem "MvNormalMeanPrecision: distrname" begin - include("./normal_family_setuptests.jl") - - @test ExponentialFamily.distrname(MvNormalMeanPrecision(zeros(2))) === "MvNormalMeanPrecision" -end - -@testitem "MvNormalMeanPrecision: Stats methods" begin - include("./normal_family_setuptests.jl") - - μ = [0.2, 3.0, 4.0] - Λ = [1.5 -0.3 0.1; -0.3 1.8 0.0; 0.1 0.0 3.5] - dist = MvNormalMeanPrecision(μ, Λ) - - @test mean(dist) == μ - @test mode(dist) == μ - @test weightedmean(dist) == Λ * μ - @test invcov(dist) == Λ - @test precision(dist) == Λ - @test cov(dist) ≈ inv(Λ) - @test std(dist) * std(dist)' ≈ inv(Λ) - @test all(mean_cov(dist) .≈ (μ, inv(Λ))) - @test all(mean_invcov(dist) .≈ (μ, Λ)) - @test all(mean_precision(dist) .≈ (μ, Λ)) - @test all(weightedmean_cov(dist) .≈ (Λ * μ, inv(Λ))) - @test all(weightedmean_invcov(dist) .≈ (Λ * μ, Λ)) - @test all(weightedmean_precision(dist) .≈ (Λ * μ, Λ)) - - @test length(dist) == 3 - @test entropy(dist) ≈ 3.1517451983126357 - @test pdf(dist, [0.2, 3.0, 4.0]) ≈ 0.19171503573907536 - @test pdf(dist, [0.202, 3.002, 4.002]) ≈ 0.19171258180232315 - @test logpdf(dist, [0.2, 3.0, 4.0]) ≈ -1.6517451983126357 - @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ -1.6517579983126356 -end - -@testitem "MvNormalMeanPrecision: Base methods" begin - include("./normal_family_setuptests.jl") - - @test convert(MvNormalMeanPrecision{Float32}, MvNormalMeanPrecision([0.0, 0.0])) == - MvNormalMeanPrecision([0.0f0, 0.0f0], [1.0f0, 1.0f0]) - @test convert(MvNormalMeanPrecision{Float64}, [0.0, 0.0], [2 0; 0 3]) == - MvNormalMeanPrecision([0.0, 0.0], [2.0 0.0; 0.0 3.0]) - - @test length(MvNormalMeanPrecision([0.0, 0.0])) === 2 - @test length(MvNormalMeanPrecision([0.0, 0.0, 0.0])) === 3 - @test ndims(MvNormalMeanPrecision([0.0, 0.0])) === 2 - @test ndims(MvNormalMeanPrecision([0.0, 0.0, 0.0])) === 3 - @test size(MvNormalMeanPrecision([0.0, 0.0])) === (2,) - @test size(MvNormalMeanPrecision([0.0, 0.0, 0.0])) === (3,) - - distribution = MvNormalMeanPrecision([0.0, 0.0], [2.0 0.0; 0.0 3.0]) - - @test distribution ≈ distribution - @test distribution ≈ convert(MvNormalMeanCovariance, distribution) - @test distribution ≈ convert(MvNormalWeightedMeanPrecision, distribution) -end - -@testitem "MvNormalMeanPrecision: vague" begin - include("./normal_family_setuptests.jl") - - @test_throws MethodError vague(MvNormalMeanPrecision) - - d1 = vague(MvNormalMeanPrecision, 2) - - @test typeof(d1) <: MvNormalMeanPrecision - @test mean(d1) == zeros(2) - @test invcov(d1) == Matrix(Diagonal(1e-12 * ones(2))) - @test ndims(d1) == 2 - - d2 = vague(MvNormalMeanPrecision, 3) - - @test typeof(d2) <: MvNormalMeanPrecision - @test mean(d2) == zeros(3) - @test invcov(d2) == Matrix(Diagonal(1e-12 * ones(3))) - @test ndims(d2) == 3 -end - -@testitem "MvNormalMeanPrecision: prod" begin - include("./normal_family_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) - @test prod(strategy, MvNormalMeanPrecision([-1, -1], [2, 2]), MvNormalMeanPrecision([1, 1], [2, 4])) ≈ - MvNormalWeightedMeanPrecision([0, 2], [4, 6]) - - μ = [1.0, 2.0, 3.0] - Λ = diagm(1 ./ [1.0, 2.0, 3.0]) - dist = MvNormalMeanPrecision(μ, Λ) - - @test prod(strategy, dist, dist) ≈ - MvNormalWeightedMeanPrecision([2.0, 2.0, 2.0], diagm([2.0, 1.0, 2 / 3])) - end -end - -@testitem "MvNormalMeanPrecision: convert" begin - include("./normal_family_setuptests.jl") - - @test convert(MvNormalMeanPrecision, zeros(2), Matrix(Diagonal(ones(2)))) == - MvNormalMeanPrecision(zeros(2), Matrix(Diagonal(ones(2)))) - @test begin - m = rand(5) - c = Matrix(Symmetric(rand(5, 5))) - convert(MvNormalMeanPrecision, m, c) == MvNormalMeanPrecision(m, c) - end -end - -================ -File: distributions/normal_family/mv_normal_weighted_mean_precision_tests.jl -================ -@testitem "MvNormalWeightedMeanPrecision: Constructor" begin - include("./normal_family_setuptests.jl") - - @test MvNormalWeightedMeanPrecision <: AbstractMvNormal - - @test MvNormalWeightedMeanPrecision([1.0, 1.0]) == MvNormalWeightedMeanPrecision([1.0, 1.0], [1.0, 1.0]) - @test MvNormalWeightedMeanPrecision([1.0, 2.0]) == MvNormalWeightedMeanPrecision([1.0, 2.0], [1.0, 1.0]) - @test MvNormalWeightedMeanPrecision([1, 2]) == MvNormalWeightedMeanPrecision([1.0, 2.0], [1.0, 1.0]) - @test MvNormalWeightedMeanPrecision([1.0f0, 2.0f0]) == - MvNormalWeightedMeanPrecision([1.0f0, 2.0f0], [1.0f0, 1.0f0]) - - @test eltype(MvNormalWeightedMeanPrecision([1.0, 1.0])) === Float64 - @test eltype(MvNormalWeightedMeanPrecision([1.0, 1.0], [1.0, 1.0])) === Float64 - @test eltype(MvNormalWeightedMeanPrecision([1, 1])) === Float64 - @test eltype(MvNormalWeightedMeanPrecision([1, 1], [1, 1])) === Float64 - @test eltype(MvNormalWeightedMeanPrecision([1.0f0, 1.0f0])) === Float32 - @test eltype(MvNormalWeightedMeanPrecision([1.0f0, 1.0f0], [1.0f0, 1.0f0])) === Float32 - - @test MvNormalWeightedMeanPrecision(ones(3), 5I) == MvNormalWeightedMeanPrecision(ones(3), Diagonal(5 * ones(3))) - @test MvNormalWeightedMeanPrecision([1, 2, 3, 4], 7.0I) == MvNormalWeightedMeanPrecision([1.0, 2.0, 3.0, 4.0], Diagonal(7.0 * ones(4))) -end - -@testitem "MvNormalWeightedMeanPrecision: distrname" begin - include("./normal_family_setuptests.jl") - - @test ExponentialFamily.distrname(MvNormalWeightedMeanPrecision(zeros(2))) === "MvNormalWeightedMeanPrecision" -end - -@testitem "MvNormalWeightedMeanPrecision: Stats methods" begin - include("./normal_family_setuptests.jl") - - xi = [-0.2, 5.34, 14.02] - Λ = [1.5 -0.3 0.1; -0.3 1.8 0.0; 0.1 0.0 3.5] - dist = MvNormalWeightedMeanPrecision(xi, Λ) - - @test mean(dist) ≈ inv(Λ) * xi - @test mode(dist) ≈ inv(Λ) * xi - @test weightedmean(dist) == xi - @test invcov(dist) == Λ - @test precision(dist) == Λ - @test cov(dist) ≈ inv(Λ) - @test std(dist) * std(dist)' ≈ inv(Λ) - @test all(mean_cov(dist) .≈ (inv(Λ) * xi, inv(Λ))) - @test all(mean_invcov(dist) .≈ (inv(Λ) * xi, Λ)) - @test all(mean_precision(dist) .≈ (inv(Λ) * xi, Λ)) - @test all(weightedmean_cov(dist) .≈ (xi, inv(Λ))) - @test all(weightedmean_invcov(dist) .≈ (xi, Λ)) - @test all(weightedmean_precision(dist) .≈ (xi, Λ)) - - @test length(dist) == 3 - @test entropy(dist) ≈ 3.1517451983126357 - @test pdf(dist, [0.2, 3.0, 4.0]) ≈ 0.19171503573907536 - @test pdf(dist, [0.202, 3.002, 4.002]) ≈ 0.19171258180232315 - @test logpdf(dist, [0.2, 3.0, 4.0]) ≈ -1.6517451983126357 - @test logpdf(dist, [0.202, 3.002, 4.002]) ≈ -1.6517579983126356 -end - -@testitem "MvNormalWeightedMeanPrecision: Base methods" begin - include("./normal_family_setuptests.jl") - - @test convert(MvNormalWeightedMeanPrecision{Float32}, MvNormalWeightedMeanPrecision([0.0, 0.0])) == - MvNormalWeightedMeanPrecision([0.0f0, 0.0f0], [1.0f0, 1.0f0]) - @test convert(MvNormalWeightedMeanPrecision{Float64}, [0.0, 0.0], [2 0; 0 3]) == - MvNormalWeightedMeanPrecision([0.0, 0.0], [2.0 0.0; 0.0 3.0]) - - @test length(MvNormalWeightedMeanPrecision([0.0, 0.0])) === 2 - @test length(MvNormalWeightedMeanPrecision([0.0, 0.0, 0.0])) === 3 - @test ndims(MvNormalWeightedMeanPrecision([0.0, 0.0])) === 2 - @test ndims(MvNormalWeightedMeanPrecision([0.0, 0.0, 0.0])) === 3 - @test size(MvNormalWeightedMeanPrecision([0.0, 0.0])) === (2,) - @test size(MvNormalWeightedMeanPrecision([0.0, 0.0, 0.0])) === (3,) - - distribution = MvNormalWeightedMeanPrecision([0.0, 0.0], [2.0 0.0; 0.0 3.0]) - - @test distribution ≈ distribution - @test distribution ≈ convert(MvNormalMeanCovariance, distribution) - @test distribution ≈ convert(MvNormalMeanPrecision, distribution) -end - -@testitem "MvNormalWeightedMeanPrecision: vague" begin - include("./normal_family_setuptests.jl") - - @test_throws MethodError vague(MvNormalWeightedMeanPrecision) - - d1 = vague(MvNormalWeightedMeanPrecision, 2) - - @test typeof(d1) <: MvNormalWeightedMeanPrecision - @test mean(d1) == zeros(2) - @test invcov(d1) == Matrix(Diagonal(1e-12 * ones(2))) - @test ndims(d1) == 2 - - d2 = vague(MvNormalWeightedMeanPrecision, 3) - - @test typeof(d2) <: MvNormalWeightedMeanPrecision - @test mean(d2) == zeros(3) - @test invcov(d2) == Matrix(Diagonal(1e-12 * ones(3))) - @test ndims(d2) == 3 -end - -@testitem "MvNormalWeightedMeanPrecision: prod" begin - include("./normal_family_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) - @test prod( - strategy, - MvNormalWeightedMeanPrecision([-1, -1], [2, 2]), - MvNormalWeightedMeanPrecision([1, 1], [2, 4]) - ) ≈ MvNormalWeightedMeanPrecision([0, 0], [4, 6]) - - xi = [0.2, 3.0, 4.0] - Λ = [1.5 -0.1 0.1; -0.1 1.8 0.0; 0.1 0.0 3.5] - dist = MvNormalWeightedMeanPrecision(xi, Λ) - - @test prod(strategy, dist, dist) ≈ - MvNormalWeightedMeanPrecision([0.40, 6.00, 8.00], [3.00 -0.20 0.20; -0.20 3.60 0.00; 0.20 0.00 7.00]) - end -end - -@testitem "MvNormalWeightedMeanPrecision: convert" begin - include("./normal_family_setuptests.jl") - - @test convert(MvNormalWeightedMeanPrecision, zeros(2), Matrix(Diagonal(ones(2)))) == - MvNormalWeightedMeanPrecision(zeros(2), Matrix(Diagonal(ones(2)))) - @test begin - m = rand(5) - c = Matrix(Symmetric(rand(5, 5))) - convert(MvNormalWeightedMeanPrecision, m, c) == MvNormalWeightedMeanPrecision(m, c) - end -end - -================ -File: distributions/normal_family/normal_family_setuptests.jl -================ -include("../distributions_setuptests.jl") - -import ExponentialFamily: dot3arg - -# We need this extra function to ensure better derivatives with AD, it is slower than our implementation -# but is more AD friendly -function getlogpartitionfortest(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) - return (η) -> begin - weightedmean, minushalfprecision = unpack_parameters(MvNormalMeanCovariance, η) - return (dot3arg(weightedmean, inv(-minushalfprecision), weightedmean) / 2 - logdet(-2 * minushalfprecision)) / 2 - end -end - -function gaussianlpdffortest(params, x) - k = length(x) - μ, Σ = params[1:k], reshape(params[k+1:end], k, k) - coef = (2π)^(-k / 2) * det(Σ)^(-1 / 2) - exponent = -0.5 * (x - μ)' * inv(Σ) * (x - μ) - return log(coef * exp(exponent)) -end - -function check_basic_statistics(left::UnivariateNormalDistributionsFamily, right::UnivariateNormalDistributionsFamily) - @test mean(left) ≈ mean(right) - @test median(left) ≈ median(right) - @test mode(left) ≈ mode(right) - @test var(left) ≈ var(right) - @test std(left) ≈ std(right) - @test entropy(left) ≈ entropy(right) - - for value in (1.0, -1.0, 0.0, mean(left), mean(right), rand()) - @test pdf(left, value) ≈ pdf(right, value) - @test logpdf(left, value) ≈ logpdf(right, value) - @test all( - ForwardDiff.gradient((x) -> logpdf(left, x[1]), [value]) .≈ - ForwardDiff.gradient((x) -> logpdf(right, x[1]), [value]) - ) - @test all( - ForwardDiff.hessian((x) -> logpdf(left, x[1]), [value]) .≈ - ForwardDiff.hessian((x) -> logpdf(right, x[1]), [value]) - ) - end - - # `Normal` is not defining some of these methods and we don't want to define them either, because of the type piracy - if !(left isa Normal || right isa Normal) - @test cov(left) ≈ cov(right) - @test invcov(left) ≈ invcov(right) - @test weightedmean(left) ≈ weightedmean(right) - @test precision(left) ≈ precision(right) - @test all(mean_cov(left) .≈ mean_cov(right)) - @test all(mean_invcov(left) .≈ mean_invcov(right)) - @test all(mean_precision(left) .≈ mean_precision(right)) - @test all(weightedmean_cov(left) .≈ weightedmean_cov(right)) - @test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right)) - @test all(weightedmean_precision(left) .≈ weightedmean_precision(right)) - end -end - -function check_basic_statistics(left::MultivariateNormalDistributionsFamily, right::MultivariateNormalDistributionsFamily) - @test mean(left) ≈ mean(right) - @test mode(left) ≈ mode(right) - @test var(left) ≈ var(right) - @test cov(left) ≈ cov(right) - @test logdetcov(left) ≈ logdetcov(right) - @test length(left) === length(right) - @test size(left) === size(right) - @test entropy(left) ≈ entropy(right) - - dims = length(mean(left)) - - for value in ( - fill(1.0, dims), - fill(-1.0, dims), - fill(0.1, dims), - mean(left), - mean(right), - rand(dims) - ) - @test pdf(left, value) ≈ pdf(right, value) - @test logpdf(left, value) ≈ logpdf(right, value) - @test all( - isapprox.( - ForwardDiff.gradient((x) -> logpdf(left, x), value), - ForwardDiff.gradient((x) -> logpdf(right, x), value), - atol = 1e-12 - ) - ) - if !all( - isapprox.( - ForwardDiff.hessian((x) -> logpdf(left, x), value), - ForwardDiff.hessian((x) -> logpdf(right, x), value), - atol = 1e-12 - ) - ) - error(left, right) - end - end - - # `MvNormal` is not defining some of these methods and we don't want to define them either, because of the type piracy - if !(left isa MvNormal || right isa MvNormal) - @test ndims(left) === ndims(right) - @test invcov(left) ≈ invcov(right) - @test weightedmean(left) ≈ weightedmean(right) - @test precision(left) ≈ precision(right) - @test all(mean_cov(left) .≈ mean_cov(right)) - @test all(mean_invcov(left) .≈ mean_invcov(right)) - @test all(mean_precision(left) .≈ mean_precision(right)) - @test all(weightedmean_cov(left) .≈ weightedmean_cov(right)) - @test all(weightedmean_invcov(left) .≈ weightedmean_invcov(right)) - @test all(weightedmean_precision(left) .≈ weightedmean_precision(right)) - end -end - -================ -File: distributions/normal_family/normal_family_tests.jl -================ -@testitem "NormalFamily: Univariate conversions" begin - include("./normal_family_setuptests.jl") - - types = union_types(UnivariateNormalDistributionsFamily{Float64}) - etypes = union_types(UnivariateNormalDistributionsFamily) - - rng = MersenneTwister(1234) - - for type in types - left = convert(type, rand(rng, Float64), rand(rng, Float64)) - - for type in [types..., etypes...] - right = convert(type, left) - check_basic_statistics(left, right) - - p1 = prod(PreserveTypeLeftProd(), left, right) - @test typeof(p1) <: typeof(left) - - p2 = prod(PreserveTypeRightProd(), left, right) - @test typeof(p2) <: typeof(right) - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) - p3 = prod(strategy, left, right) - - check_basic_statistics(p1, p2) - check_basic_statistics(p2, p3) - check_basic_statistics(p1, p3) - end - end - end -end - -@testitem "NormalFamily: Multivariate conversions" begin - include("./normal_family_setuptests.jl") - - types = union_types(MultivariateNormalDistributionsFamily{Float64}) - etypes = union_types(MultivariateNormalDistributionsFamily) - - dims = (2, 3, 5) - rng = MersenneTwister(1234) - - for dim in dims - for type in types - left = convert(type, rand(rng, Float64, dim), Matrix(Diagonal(abs.(rand(rng, Float64, dim))))) - - for type in [types..., etypes...] - right = convert(type, left) - check_basic_statistics(left, right) - - p1 = prod(PreserveTypeLeftProd(), left, right) - @test typeof(p1) <: typeof(left) - - p2 = prod(PreserveTypeRightProd(), left, right) - @test typeof(p2) <: typeof(right) - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) - p3 = prod(strategy, left, right) - - check_basic_statistics(p1, p2) - check_basic_statistics(p2, p3) - check_basic_statistics(p1, p3) - end - end - end - end -end - -@testitem "NormalFamily: Variate forms promotions" begin - include("./normal_family_setuptests.jl") - - @test promote_variate_type(Univariate, NormalMeanVariance) === NormalMeanVariance - @test promote_variate_type(Univariate, NormalMeanPrecision) === NormalMeanPrecision - @test promote_variate_type(Univariate, NormalWeightedMeanPrecision) === NormalWeightedMeanPrecision - - @test promote_variate_type(Multivariate, NormalMeanVariance) === MvNormalMeanCovariance - @test promote_variate_type(Multivariate, NormalMeanPrecision) === MvNormalMeanPrecision - @test promote_variate_type(Multivariate, NormalWeightedMeanPrecision) === MvNormalWeightedMeanPrecision - - @test promote_variate_type(Univariate, MvNormalMeanCovariance) === NormalMeanVariance - @test promote_variate_type(Univariate, MvNormalMeanPrecision) === NormalMeanPrecision - @test promote_variate_type(Univariate, MvNormalWeightedMeanPrecision) === NormalWeightedMeanPrecision - - @test promote_variate_type(Multivariate, MvNormalMeanCovariance) === MvNormalMeanCovariance - @test promote_variate_type(Multivariate, MvNormalMeanPrecision) === MvNormalMeanPrecision - @test promote_variate_type(Multivariate, MvNormalWeightedMeanPrecision) === MvNormalWeightedMeanPrecision -end - -@testitem "NormalFamily: Sampling univariate" begin - include("./normal_family_setuptests.jl") - - rng = MersenneTwister(1234) - - for T in (Float32, Float64) - let # NormalMeanVariance - μ, v = 10randn(rng), 10rand(rng) - d = convert(NormalMeanVariance{T}, μ, v) - - @test typeof(rand(d)) <: T - - samples = rand(rng, d, 5_000) - - @test isapprox(mean(samples), μ, atol = 0.5) - @test isapprox(var(samples), v, atol = 0.5) - end - - let # NormalMeanPrecision - μ, w = 10randn(rng), 10rand(rng) - d = convert(NormalMeanPrecision{T}, μ, w) - - @test typeof(rand(d)) <: T - - samples = rand(rng, d, 5_000) - - @test isapprox(mean(samples), μ, atol = 0.5) - @test isapprox(inv(var(samples)), w, atol = 0.5) - end - - let # WeightedMeanPrecision - wμ, w = 10randn(rng), 10rand(rng) - d = convert(NormalWeightedMeanPrecision{T}, wμ, w) - - @test typeof(rand(d)) <: T - - samples = rand(rng, d, 5_000) - - @test isapprox(inv(var(samples)) * mean(samples), wμ, atol = 0.5) - @test isapprox(inv(var(samples)), w, atol = 0.5) - end - end -end - -@testitem "NormalFamily: Sampling multivariate" begin - include("./normal_family_setuptests.jl") - - rng = MersenneTwister(1234) - for n in (2, 3), T in (Float64,), nsamples in (10_000,) - μ = randn(rng, n) - L = randn(rng, n, n) - Σ = L * L' - - d = convert(MvNormalMeanCovariance{T}, μ, Σ) - @test typeof(rand(d)) <: Vector{T} - - samples = eachcol(rand(rng, d, nsamples)) - weights = fill(1 / nsamples, nsamples) - - @test isapprox(sum(sample for sample in samples) / nsamples, mean(d), atol = n * 0.5) - @test isapprox( - sum((sample - mean(d)) * (sample - mean(d))' for sample in samples) / nsamples, - cov(d), - atol = n * 0.5 - ) - - μ = randn(rng, n) - L = randn(rng, n, n) - W = L * L' - d = convert(MvNormalMeanCovariance{T}, μ, W) - @test typeof(rand(d)) <: Vector{T} - - samples = eachcol(rand(rng, d, nsamples)) - weights = fill(1 / nsamples, nsamples) - - @test isapprox(sum(sample for sample in samples) / nsamples, mean(d), atol = n * 0.5) - @test isapprox( - sum((sample - mean(d)) * (sample - mean(d))' for sample in samples) / nsamples, - cov(d), - atol = n * 0.5 - ) - - ξ = randn(rng, n) - L = randn(rng, n, n) - W = L * L' - - d = convert(MvNormalWeightedMeanPrecision{T}, ξ, W) - - @test typeof(rand(d)) <: Vector{T} - - samples = eachcol(rand(rng, d, nsamples)) - weights = fill(1 / nsamples, nsamples) - - @test isapprox(sum(sample for sample in samples) / nsamples, mean(d), atol = n * 0.5) - @test isapprox( - sum((sample - mean(d)) * (sample - mean(d))' for sample in samples) / nsamples, - cov(d), - atol = n * 0.5 - ) - end -end - -@testitem "NormalFamily: ExponentialFamilyDistribution{NormalMeanVariance}" begin - include("./normal_family_setuptests.jl") - - for μ in -10.0:5.0:10.0, σ² in 0.1:1.0:5.0, T in union_types(UnivariateNormalDistributionsFamily) - @testset let d = convert(T, NormalMeanVariance(μ, σ²)) - ef = test_exponentialfamily_interface(d) - - (η₁, η₂) = (mean(d) / var(d), -1 / 2var(d)) - - for x in 10randn(4) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) ≈ 1 / sqrt(2π) - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (x, abs2(x))) - @test @inferred(logpartition(ef)) ≈ (-η₁^2 / 4η₂ - 1 / 2 * log(-2η₂)) - @test @inferred(insupport(ef, x)) - end - end - end - - # Test failing isproper cases - @test !isproper(MeanParametersSpace(), NormalMeanVariance, [-1]) - @test !isproper(MeanParametersSpace(), NormalMeanVariance, [1, -0.1]) - @test !isproper(MeanParametersSpace(), NormalMeanVariance, [-0.1, -1]) - @test !isproper(NaturalParametersSpace(), NormalMeanVariance, [-1.1]) - @test !isproper(NaturalParametersSpace(), NormalMeanVariance, [1, 1]) - @test !isproper(NaturalParametersSpace(), NormalMeanVariance, [-1.1, 1]) -end - -@testitem "NormalFamily: prod with ExponentialFamilyDistribution{NormalMeanVariance}" begin - include("./normal_family_setuptests.jl") - - for μleft in 10randn(4), σ²left in 10rand(4), μright in 10randn(4), - σ²right in 10rand(4), Tleft in union_types(UnivariateNormalDistributionsFamily), - Tright in union_types(UnivariateNormalDistributionsFamily) - - @testset let (left, right) = (convert(Tleft, NormalMeanVariance(μleft, σ²left)), convert(Tright, NormalMeanVariance(μright, σ²right))) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{NormalMeanVariance}) - ) - ) - end - end -end - -@testitem "NormalFamily: ExponentialFamilyDistribution{MvNormalMeanCovariance}" begin - include("./normal_family_setuptests.jl") - - for s in (2, 3), T in union_types(MultivariateNormalDistributionsFamily) - μ = 10randn(s) - L = LowerTriangular(randn(s, s) + s * I) - Σ = L * L' - @testset let d = convert(T, MvNormalMeanCovariance(μ, Σ)) - ef = test_exponentialfamily_interface( - d; - # These are handled differently below - test_fisherinformation_against_hessian = false, - test_fisherinformation_against_jacobian = false, - test_gradlogpartition_properties = false - ) - - (η₁, η₂) = (inv(Σ) * mean(d), -inv(Σ) / 2) - - for x in [10randn(s) for _ in 1:4] - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) ≈ (2π)^(-s / 2) - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (x, x * x')) - @test @inferred(logpartition(ef)) ≈ -1 / 4 * (η₁' * inv(η₂) * η₁) - 1 / 2 * logdet(-2η₂) - @test @inferred(insupport(ef, x)) - end - - run_test_gradlogpartition_properties(d, test_against_forwardiff = false) - - # Extra test with AD-friendly logpartition function - lp_ag = ForwardDiff.gradient(getlogpartitionfortest(NaturalParametersSpace(), MvNormalMeanCovariance), getnaturalparameters(ef)) - @test gradlogpartition(ef) ≈ lp_ag - end - end - - # Test failing isproper cases (naive) - @test !isproper(MeanParametersSpace(), MvNormalMeanCovariance, [-1]) - @test !isproper(MeanParametersSpace(), MvNormalMeanCovariance, [1, -0.1]) - @test !isproper(MeanParametersSpace(), MvNormalMeanCovariance, [-0.1, -1]) - @test !isproper(MeanParametersSpace(), MvNormalMeanCovariance, [-1, 2, 3, 4]) # shapes are incompatible - @test !isproper(MeanParametersSpace(), MvNormalMeanCovariance, [1, -0.1, -1, 0, 0, -1]) # covariance is not posdef - - @test !isproper(NaturalParametersSpace(), MvNormalMeanCovariance, [-1.1]) - @test !isproper(NaturalParametersSpace(), MvNormalMeanCovariance, [1, 1]) - @test !isproper(NaturalParametersSpace(), MvNormalMeanCovariance, [-1.1, 1]) - @test !isproper(NaturalParametersSpace(), MvNormalMeanCovariance, [-1, 2, 3, 4]) # shapes are incompatible - @test !isproper(NaturalParametersSpace(), MvNormalMeanCovariance, [1, -0.1, 1, 0, 0, 1]) # -η₂ is not posdef -end - -@testitem "NormalFamily: Fisher information matrix in natural parameters space" begin - include("./normal_family_setuptests.jl") - - for i in 1:5, d in 2:10 - rng = StableRNG(d * i) - μ = 10randn(rng, d) - L = LowerTriangular(randn(rng, d, d) + d * I) - Σ = L * L' - ef = convert(ExponentialFamilyDistribution, MvNormalMeanCovariance(μ, Σ)) - - fi_ef = fisherinformation(ef) - # @test_broken isposdef(fi_ef) - # The `isposdef` check is not really reliable in Julia, here, instead we compute eigen values - @test issymmetric(fi_ef) || (LowerTriangular(fi_ef) ≈ (UpperTriangular(fi_ef)')) - @test isposdef(fi_ef) || all(>(0), eigvals(fi_ef)) - - fi_ef_inv = inv(fi_ef) - @test (fi_ef_inv * fi_ef) ≈ Diagonal(ones(d + d^2)) - - # WARNING: ForwardDiff returns a non-positive definite Hessian for a convex function. - # The matrices are identical up to permutations, resulting in eigenvalues that are the same up to a sign. - fi_ag = ForwardDiff.hessian(getlogpartitionfortest(NaturalParametersSpace(), MvNormalMeanCovariance), getnaturalparameters(ef)) - @test norm(sort(eigvals(fi_ef)) - sort(abs.(eigvals(fi_ag)))) ≈ 0 atol = (1e-9 * d^2) - end -end - -# We normally perform test with jacobian transformation, but autograd fails to compute jacobians with duplicated elements. -@testitem "Fisher information matrix in mean parameters space" begin - include("./normal_family_setuptests.jl") - - for i in 1:5, d in 2:3 - rng = StableRNG(d * i) - μ = 10randn(rng, d) - L = LowerTriangular(randn(rng, d, d) + d * I) - Σ = L * L' - n_samples = 10000 - dist = MvNormalMeanCovariance(μ, Σ) - - samples = rand(rng, dist, n_samples) - - θ = pack_parameters(MvNormalMeanCovariance, (μ, Σ)) - - approxHessian = zeros(length(θ), length(θ)) - for sample in eachcol(samples) - approxHessian -= ForwardDiff.hessian(Base.Fix2(gaussianlpdffortest, sample), θ) - end - approxFisherInformation = approxHessian /= n_samples - - # The error will be higher for sampling tests, tolerance adjusted accordingly. - fi_dist = getfisherinformation(MeanParametersSpace(), MvNormalMeanCovariance)(θ) - @test isposdef(fi_dist) || all(>(0), eigvals(fi_dist)) - @test issymmetric(fi_dist) || (LowerTriangular(fi_dist) ≈ (UpperTriangular(fi_dist)')) - @test sort(eigvals(fi_dist)) ≈ sort(abs.(eigvals(approxFisherInformation))) rtol = 1e-1 - @test sort(svd(fi_dist).S) ≈ sort(svd(approxFisherInformation).S) rtol = 1e-1 - end -end - -@testitem "Diffrentiabilty of ExponentialFamily(ExponentialFamily.MvNormalMeanCovariance) logpdf" begin - include("./normal_family_setuptests.jl") - for i in 1:5, d in 2:3 - rng = StableRNG(d * i) - μ = 10randn(rng, d) - L = LowerTriangular(randn(rng, d, d) + d * I) - Σ = L * L' - n_samples = 1 - dist = MvNormalMeanCovariance(μ, Σ) - - samples = rand(rng, dist, n_samples) - - θ = pack_parameters(MvNormalMeanCovariance, (μ, Σ)) - ef = convert(ExponentialFamilyDistribution, MvNormalMeanCovariance(μ, Σ)) - - nat_space2mean_space = (η) -> begin - dist = convert(Distribution, ExponentialFamilyDistribution(MvNormalMeanCovariance, η)) - μ, Σ = mean(dist), cov(dist) - pack_parameters(MvNormalMeanCovariance, (μ, Σ)) - end - - for sample in eachcol(samples) - mean_gradient = ForwardDiff.gradient(Base.Fix2(gaussianlpdffortest, sample), θ) - nat_gradient = ForwardDiff.gradient((η) -> logpdf(ExponentialFamilyDistribution(MvNormalMeanCovariance, η), sample), getnaturalparameters(ef)) - jacobian = ForwardDiff.jacobian(nat_space2mean_space, getnaturalparameters(ef)) - #autograd failing to compute jacobian of matrix part correclty. Comparing only vector (mean) part. - @test nat_gradient[1:d] ≈ (jacobian'*mean_gradient)[1:d] - end - end -end - -================ -File: distributions/normal_family/normal_mean_precision_tests.jl -================ -@testitem "NormalMeanPrecision: Constructor" begin - include("./normal_family_setuptests.jl") - - @test NormalMeanPrecision <: NormalDistributionsFamily - @test NormalMeanPrecision <: UnivariateNormalDistributionsFamily - - @test NormalMeanPrecision() == NormalMeanPrecision{Float64}(0.0, 1.0) - @test NormalMeanPrecision(1.0) == NormalMeanPrecision{Float64}(1.0, 1.0) - @test NormalMeanPrecision(1.0, 2.0) == NormalMeanPrecision{Float64}(1.0, 2.0) - @test NormalMeanPrecision(1) == NormalMeanPrecision{Float64}(1.0, 1.0) - @test NormalMeanPrecision(1, 2) == NormalMeanPrecision{Float64}(1.0, 2.0) - @test NormalMeanPrecision(1.0, 2) == NormalMeanPrecision{Float64}(1.0, 2.0) - @test NormalMeanPrecision(1, 2.0) == NormalMeanPrecision{Float64}(1.0, 2.0) - @test NormalMeanPrecision(1.0f0) == NormalMeanPrecision{Float32}(1.0f0, 1.0f0) - @test NormalMeanPrecision(1.0f0, 2.0f0) == NormalMeanPrecision{Float32}(1.0f0, 2.0f0) - @test NormalMeanPrecision(1.0f0, 2) == NormalMeanPrecision{Float32}(1.0f0, 2.0f0) - @test NormalMeanPrecision(1.0f0, 2.0) == NormalMeanPrecision{Float64}(1.0, 2.0) - - @test eltype(NormalMeanPrecision()) === Float64 - @test eltype(NormalMeanPrecision(0.0)) === Float64 - @test eltype(NormalMeanPrecision(0.0, 1.0)) === Float64 - @test eltype(NormalMeanPrecision(0)) === Float64 - @test eltype(NormalMeanPrecision(0, 1)) === Float64 - @test eltype(NormalMeanPrecision(0.0, 1)) === Float64 - @test eltype(NormalMeanPrecision(0, 1.0)) === Float64 - @test eltype(NormalMeanPrecision(0.0f0)) === Float32 - @test eltype(NormalMeanPrecision(0.0f0, 1.0f0)) === Float32 - @test eltype(NormalMeanPrecision(0.0f0, 1.0)) === Float64 - - @test NormalMeanPrecision(3, 5I) == NormalMeanPrecision(3, 5) - @test NormalMeanPrecision(2, 7.0I) == NormalMeanPrecision(2.0, 7.0) -end - -@testitem "NormalMeanPrecision: Stats methods" begin - include("./normal_family_setuptests.jl") - - dist1 = NormalMeanPrecision(0.0, 1.0) - - @test mean(dist1) === 0.0 - @test median(dist1) === 0.0 - @test mode(dist1) === 0.0 - @test weightedmean(dist1) === 0.0 - @test var(dist1) === 1.0 - @test std(dist1) === 1.0 - @test cov(dist1) === 1.0 - @test invcov(dist1) === 1.0 - @test precision(dist1) === 1.0 - @test entropy(dist1) ≈ 1.41893853320467 - @test pdf(dist1, 1.0) ≈ 0.24197072451914337 - @test pdf(dist1, -1.0) ≈ 0.24197072451914337 - @test pdf(dist1, 0.0) ≈ 0.3989422804014327 - @test logpdf(dist1, 1.0) ≈ -1.4189385332046727 - @test logpdf(dist1, -1.0) ≈ -1.4189385332046727 - @test logpdf(dist1, 0.0) ≈ -0.9189385332046728 - - dist2 = NormalMeanPrecision(1.0, 1.0) - - @test mean(dist2) === 1.0 - @test median(dist2) === 1.0 - @test mode(dist2) === 1.0 - @test weightedmean(dist2) === 1.0 - @test var(dist2) === 1.0 - @test std(dist2) === 1.0 - @test cov(dist2) === 1.0 - @test invcov(dist2) === 1.0 - @test precision(dist2) === 1.0 - @test entropy(dist2) ≈ 1.41893853320467 - @test pdf(dist2, 1.0) ≈ 0.3989422804014327 - @test pdf(dist2, -1.0) ≈ 0.05399096651318806 - @test pdf(dist2, 0.0) ≈ 0.24197072451914337 - @test logpdf(dist2, 1.0) ≈ -0.9189385332046728 - @test logpdf(dist2, -1.0) ≈ -2.9189385332046727 - @test logpdf(dist2, 0.0) ≈ -1.4189385332046727 - - dist3 = NormalMeanPrecision(1.0, 0.5) - - @test mean(dist3) === 1.0 - @test median(dist3) === 1.0 - @test mode(dist3) === 1.0 - @test weightedmean(dist3) === inv(2.0) - @test var(dist3) === 2.0 - @test std(dist3) === sqrt(2.0) - @test cov(dist3) === 2.0 - @test invcov(dist3) === inv(2.0) - @test precision(dist3) === inv(2.0) - @test entropy(dist3) ≈ 1.7655121234846454 - @test pdf(dist3, 1.0) ≈ 0.28209479177387814 - @test pdf(dist3, -1.0) ≈ 0.1037768743551487 - @test pdf(dist3, 0.0) ≈ 0.21969564473386122 - @test logpdf(dist3, 1.0) ≈ -1.2655121234846454 - @test logpdf(dist3, -1.0) ≈ -2.2655121234846454 - @test logpdf(dist3, 0.0) ≈ -1.5155121234846454 -end - -@testitem "NormalMeanPrecision: Base methods" begin - include("./normal_family_setuptests.jl") - - @test convert(NormalMeanPrecision{Float32}, NormalMeanPrecision()) == NormalMeanPrecision{Float32}(0.0f0, 1.0f0) - @test convert(NormalMeanPrecision{Float64}, NormalMeanPrecision(0.0, 10.0)) == - NormalMeanPrecision{Float64}(0.0, 10.0) - @test convert(NormalMeanPrecision{Float64}, NormalMeanPrecision(0.0, 0.1)) == - NormalMeanPrecision{Float64}(0.0, 0.1) - @test convert(NormalMeanPrecision{Float64}, 0, 1) == NormalMeanPrecision{Float64}(0.0, 1.0) - @test convert(NormalMeanPrecision{Float64}, 0, 10) == NormalMeanPrecision{Float64}(0.0, 10.0) - @test convert(NormalMeanPrecision{Float64}, 0, 0.1) == NormalMeanPrecision{Float64}(0.0, 0.1) - @test convert(NormalMeanPrecision, 0, 1) == NormalMeanPrecision{Float64}(0.0, 1.0) - @test convert(NormalMeanPrecision, 0, 10) == NormalMeanPrecision{Float64}(0.0, 10.0) - @test convert(NormalMeanPrecision, 0, 0.1) == NormalMeanPrecision{Float64}(0.0, 0.1) - - distribution = NormalMeanPrecision(-2.0, 3.0) - - @test distribution ≈ distribution - @test distribution ≈ convert(NormalMeanVariance, distribution) - @test distribution ≈ convert(NormalWeightedMeanPrecision, distribution) -end - -@testitem "NormalMeanPrecision: vague" begin - include("./normal_family_setuptests.jl") - - d1 = vague(NormalMeanPrecision) - - @test typeof(d1) <: NormalMeanPrecision - @test mean(d1) == 0.0 - @test precision(d1) == 1e-12 -end - -@testitem "NormalMeanPrecision: prod" begin - include("./normal_family_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) - @test prod(strategy, NormalMeanPrecision(-1, 1 / 1), NormalMeanPrecision(1, 1 / 1)) ≈ - NormalWeightedMeanPrecision(0.0, 2.0) - @test prod(strategy, NormalMeanPrecision(-1, 1 / 2), NormalMeanPrecision(1, 1 / 4)) ≈ - NormalWeightedMeanPrecision(-1 / 4, 3 / 4) - @test prod(strategy, NormalMeanPrecision(2, 1 / 2), NormalMeanPrecision(0, 1 / 10)) ≈ - NormalWeightedMeanPrecision(1, 3 / 5) - end -end - -================ -File: distributions/normal_family/normal_mean_variance_tests.jl -================ -@testitem "NormalMeanVariance: Constructor" begin - include("./normal_family_setuptests.jl") - - @test NormalMeanVariance <: NormalDistributionsFamily - @test NormalMeanVariance <: UnivariateNormalDistributionsFamily - - @test NormalMeanVariance() == NormalMeanVariance{Float64}(0.0, 1.0) - @test NormalMeanVariance(1.0) == NormalMeanVariance{Float64}(1.0, 1.0) - @test NormalMeanVariance(1.0, 2.0) == NormalMeanVariance{Float64}(1.0, 2.0) - @test NormalMeanVariance(1) == NormalMeanVariance{Float64}(1.0, 1.0) - @test NormalMeanVariance(1, 2) == NormalMeanVariance{Float64}(1.0, 2.0) - @test NormalMeanVariance(1.0, 2) == NormalMeanVariance{Float64}(1.0, 2.0) - @test NormalMeanVariance(1, 2.0) == NormalMeanVariance{Float64}(1.0, 2.0) - @test NormalMeanVariance(1.0f0) == NormalMeanVariance{Float32}(1.0f0, 1.0f0) - @test NormalMeanVariance(1.0f0, 2.0f0) == NormalMeanVariance{Float32}(1.0f0, 2.0f0) - @test NormalMeanVariance(1.0f0, 2) == NormalMeanVariance{Float32}(1.0f0, 2.0f0) - @test NormalMeanVariance(1.0f0, 2.0) == NormalMeanVariance{Float64}(1.0, 2.0) - - @test eltype(NormalMeanVariance()) === Float64 - @test eltype(NormalMeanVariance(0.0)) === Float64 - @test eltype(NormalMeanVariance(0.0, 1.0)) === Float64 - @test eltype(NormalMeanVariance(0)) === Float64 - @test eltype(NormalMeanVariance(0, 1)) === Float64 - @test eltype(NormalMeanVariance(0.0, 1)) === Float64 - @test eltype(NormalMeanVariance(0, 1.0)) === Float64 - @test eltype(NormalMeanVariance(0.0f0)) === Float32 - @test eltype(NormalMeanVariance(0.0f0, 1.0f0)) === Float32 - @test eltype(NormalMeanVariance(0.0f0, 1.0)) === Float64 - - @test NormalMeanVariance(3, 5I) == NormalMeanVariance(3, 5) - @test NormalMeanVariance(2, 7.0I) == NormalMeanVariance(2.0, 7.0) -end - -@testitem "NormalMeanVariance: Stats methods" begin - include("./normal_family_setuptests.jl") - - dist1 = NormalMeanVariance(0.0, 1.0) - - @test mean(dist1) === 0.0 - @test median(dist1) === 0.0 - @test mode(dist1) === 0.0 - @test weightedmean(dist1) === 0.0 - @test var(dist1) === 1.0 - @test std(dist1) === 1.0 - @test cov(dist1) === 1.0 - @test invcov(dist1) === 1.0 - @test precision(dist1) === 1.0 - @test entropy(dist1) ≈ 1.41893853320467 - @test pdf(dist1, 1.0) ≈ 0.24197072451914337 - @test pdf(dist1, -1.0) ≈ 0.24197072451914337 - @test pdf(dist1, 0.0) ≈ 0.3989422804014327 - @test logpdf(dist1, 1.0) ≈ -1.4189385332046727 - @test logpdf(dist1, -1.0) ≈ -1.4189385332046727 - @test logpdf(dist1, 0.0) ≈ -0.9189385332046728 - - dist2 = NormalMeanVariance(1.0, 1.0) - - @test mean(dist2) === 1.0 - @test median(dist2) === 1.0 - @test mode(dist2) === 1.0 - @test weightedmean(dist2) === 1.0 - @test var(dist2) === 1.0 - @test std(dist2) === 1.0 - @test cov(dist2) === 1.0 - @test invcov(dist2) === 1.0 - @test precision(dist2) === 1.0 - @test entropy(dist2) ≈ 1.41893853320467 - @test pdf(dist2, 1.0) ≈ 0.3989422804014327 - @test pdf(dist2, -1.0) ≈ 0.05399096651318806 - @test pdf(dist2, 0.0) ≈ 0.24197072451914337 - @test logpdf(dist2, 1.0) ≈ -0.9189385332046728 - @test logpdf(dist2, -1.0) ≈ -2.9189385332046727 - @test logpdf(dist2, 0.0) ≈ -1.4189385332046727 - - dist3 = NormalMeanVariance(1.0, 2.0) - - @test mean(dist3) === 1.0 - @test median(dist3) === 1.0 - @test mode(dist3) === 1.0 - @test weightedmean(dist3) === inv(2.0) - @test var(dist3) === 2.0 - @test std(dist3) === sqrt(2.0) - @test cov(dist3) === 2.0 - @test invcov(dist3) === inv(2.0) - @test precision(dist3) === inv(2.0) - @test entropy(dist3) ≈ 1.7655121234846454 - @test pdf(dist3, 1.0) ≈ 0.28209479177387814 - @test pdf(dist3, -1.0) ≈ 0.1037768743551487 - @test pdf(dist3, 0.0) ≈ 0.21969564473386122 - @test logpdf(dist3, 1.0) ≈ -1.2655121234846454 - @test logpdf(dist3, -1.0) ≈ -2.2655121234846454 - @test logpdf(dist3, 0.0) ≈ -1.5155121234846454 -end - -@testitem "NormalMeanVariance: Base methods" begin - include("./normal_family_setuptests.jl") - - @test convert(NormalMeanVariance{Float32}, NormalMeanVariance()) === NormalMeanVariance{Float32}(0.0f0, 1.0f0) - @test convert(NormalMeanVariance{Float64}, NormalMeanVariance(0.0, 10.0)) == - NormalMeanVariance{Float64}(0.0, 10.0) - @test convert(NormalMeanVariance{Float64}, NormalMeanVariance(0.0, 0.1)) == - NormalMeanVariance{Float64}(0.0, 0.1) - @test convert(NormalMeanVariance{Float64}, 0, 1) == NormalMeanVariance{Float64}(0.0, 1.0) - @test convert(NormalMeanVariance{Float64}, 0, 10) == NormalMeanVariance{Float64}(0.0, 10.0) - @test convert(NormalMeanVariance{Float64}, 0, 0.1) == NormalMeanVariance{Float64}(0.0, 0.1) - @test convert(NormalMeanVariance, 0, 1) == NormalMeanVariance{Float64}(0.0, 1.0) - @test convert(NormalMeanVariance, 0, 10) == NormalMeanVariance{Float64}(0.0, 10.0) - @test convert(NormalMeanVariance, 0, 0.1) == NormalMeanVariance{Float64}(0.0, 0.1) - - distribution = NormalMeanVariance(-2.0, 3.0) - - @test distribution ≈ distribution - @test distribution ≈ convert(NormalMeanPrecision, distribution) - @test distribution ≈ convert(NormalWeightedMeanPrecision, distribution) -end - -@testitem "NormalMeanVariance: vague" begin - include("./normal_family_setuptests.jl") - - d1 = vague(NormalMeanVariance) - - @test typeof(d1) <: NormalMeanVariance - @test mean(d1) == 0.0 - @test var(d1) == 1e12 -end - -@testitem "NormalMeanVariance: prod" begin - include("./normal_family_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) - @test prod(strategy, NormalMeanVariance(-1, 1), NormalMeanVariance(1, 1)) ≈ - NormalWeightedMeanPrecision(0.0, 2.0) - @test prod(strategy, NormalMeanVariance(-1, 2), NormalMeanVariance(1, 4)) ≈ - NormalWeightedMeanPrecision(-1 / 4, 3 / 4) - @test prod(strategy, NormalMeanVariance(2, 2), NormalMeanVariance(0, 10)) ≈ - NormalWeightedMeanPrecision(1.0, 3 / 5) - end -end - -================ -File: distributions/normal_family/normal_weighted_mean_precision_tests.jl -================ -@testitem "NormalWeightedMeanPrecision: Constructor" begin - include("./normal_family_setuptests.jl") - - @test NormalWeightedMeanPrecision <: NormalDistributionsFamily - @test NormalWeightedMeanPrecision <: UnivariateNormalDistributionsFamily - - @test NormalWeightedMeanPrecision() == NormalWeightedMeanPrecision{Float64}(0.0, 1.0) - @test NormalWeightedMeanPrecision(1.0) == NormalWeightedMeanPrecision{Float64}(1.0, 1.0) - @test NormalWeightedMeanPrecision(1.0, 2.0) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) - @test NormalWeightedMeanPrecision(1) == NormalWeightedMeanPrecision{Float64}(1.0, 1.0) - @test NormalWeightedMeanPrecision(1, 2) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) - @test NormalWeightedMeanPrecision(1.0, 2) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) - @test NormalWeightedMeanPrecision(1, 2.0) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) - @test NormalWeightedMeanPrecision(1.0f0) == NormalWeightedMeanPrecision{Float32}(1.0f0, 1.0f0) - @test NormalWeightedMeanPrecision(1.0f0, 2.0f0) == NormalWeightedMeanPrecision{Float32}(1.0f0, 2.0f0) - @test NormalWeightedMeanPrecision(1.0f0, 2.0) == NormalWeightedMeanPrecision{Float64}(1.0, 2.0) - - @test eltype(NormalWeightedMeanPrecision()) === Float64 - @test eltype(NormalWeightedMeanPrecision(0.0)) === Float64 - @test eltype(NormalWeightedMeanPrecision(0.0, 1.0)) === Float64 - @test eltype(NormalWeightedMeanPrecision(0)) === Float64 - @test eltype(NormalWeightedMeanPrecision(0, 1)) === Float64 - @test eltype(NormalWeightedMeanPrecision(0.0, 1)) === Float64 - @test eltype(NormalWeightedMeanPrecision(0, 1.0)) === Float64 - @test eltype(NormalWeightedMeanPrecision(0.0f0)) === Float32 - @test eltype(NormalWeightedMeanPrecision(0.0f0, 1.0f0)) === Float32 - @test eltype(NormalWeightedMeanPrecision(0.0f0, 1.0)) === Float64 - - @test NormalWeightedMeanPrecision(3, 5I) == NormalWeightedMeanPrecision(3, 5) - @test NormalWeightedMeanPrecision(2, 7.0I) == NormalWeightedMeanPrecision(2.0, 7.0) -end - -@testitem "NormalWeightedMeanPrecision: Stats methods" begin - include("./normal_family_setuptests.jl") - - dist1 = NormalWeightedMeanPrecision(0.0, 1.0) - - @test mean(dist1) === 0.0 - @test median(dist1) === 0.0 - @test mode(dist1) === 0.0 - @test weightedmean(dist1) === 0.0 - @test var(dist1) === 1.0 - @test std(dist1) === 1.0 - @test cov(dist1) === 1.0 - @test invcov(dist1) === 1.0 - @test precision(dist1) === 1.0 - @test entropy(dist1) ≈ 1.41893853320467 - @test pdf(dist1, 1.0) ≈ 0.24197072451914337 - @test pdf(dist1, -1.0) ≈ 0.24197072451914337 - @test pdf(dist1, 0.0) ≈ 0.3989422804014327 - @test logpdf(dist1, 1.0) ≈ -1.4189385332046727 - @test logpdf(dist1, -1.0) ≈ -1.4189385332046727 - @test logpdf(dist1, 0.0) ≈ -0.9189385332046728 - - dist2 = NormalWeightedMeanPrecision(1.0, 1.0) - - @test mean(dist2) === 1.0 - @test median(dist2) === 1.0 - @test mode(dist2) === 1.0 - @test weightedmean(dist2) === 1.0 - @test var(dist2) === 1.0 - @test std(dist2) === 1.0 - @test cov(dist2) === 1.0 - @test invcov(dist2) === 1.0 - @test precision(dist2) === 1.0 - @test entropy(dist2) ≈ 1.41893853320467 - @test pdf(dist2, 1.0) ≈ 0.3989422804014327 - @test pdf(dist2, -1.0) ≈ 0.05399096651318806 - @test pdf(dist2, 0.0) ≈ 0.24197072451914337 - @test logpdf(dist2, 1.0) ≈ -0.9189385332046728 - @test logpdf(dist2, -1.0) ≈ -2.9189385332046727 - @test logpdf(dist2, 0.0) ≈ -1.4189385332046727 - - dist3 = NormalWeightedMeanPrecision(1.0, 0.5) - - @test mean(dist3) === inv(0.5) - @test median(dist3) === inv(0.5) - @test mode(dist3) === inv(0.5) - @test weightedmean(dist3) === 1.0 - @test var(dist3) === 2.0 - @test std(dist3) === sqrt(2.0) - @test cov(dist3) === 2.0 - @test invcov(dist3) === inv(2.0) - @test precision(dist3) === inv(2.0) - @test entropy(dist3) ≈ 1.7655121234846454 - @test pdf(dist3, 1.0) ≈ 0.21969564473386122 - @test pdf(dist3, -1.0) ≈ 0.02973257230590734 - @test pdf(dist3, 0.0) ≈ 0.1037768743551487 - @test logpdf(dist3, 1.0) ≈ -1.5155121234846454 - @test logpdf(dist3, -1.0) ≈ -3.5155121234846454 - @test logpdf(dist3, 0.0) ≈ -2.2655121234846454 -end - -@testitem "NormalWeightedMeanPrecision: Base methods" begin - include("./normal_family_setuptests.jl") - - @test convert(NormalWeightedMeanPrecision{Float32}, NormalWeightedMeanPrecision()) == - NormalWeightedMeanPrecision{Float32}(0.0f0, 1.0f0) - @test convert(NormalWeightedMeanPrecision{Float64}, NormalWeightedMeanPrecision(0.0, 10.0)) == - NormalWeightedMeanPrecision{Float64}(0.0, 10.0) - @test convert(NormalWeightedMeanPrecision{Float64}, NormalWeightedMeanPrecision(0.0, 0.1)) == - NormalWeightedMeanPrecision{Float64}(0.0, 0.1) - @test convert(NormalWeightedMeanPrecision{Float64}, 0, 1) == NormalWeightedMeanPrecision{Float64}(0.0, 1.0) - @test convert(NormalWeightedMeanPrecision{Float64}, 0, 10) == NormalWeightedMeanPrecision{Float64}(0.0, 10.0) - @test convert(NormalWeightedMeanPrecision{Float64}, 0, 0.1) == NormalWeightedMeanPrecision{Float64}(0.0, 0.1) - @test convert(NormalWeightedMeanPrecision, 0, 1) == NormalWeightedMeanPrecision{Float64}(0.0, 1.0) - @test convert(NormalWeightedMeanPrecision, 0, 10) == NormalWeightedMeanPrecision{Float64}(0.0, 10.0) - @test convert(NormalWeightedMeanPrecision, 0, 0.1) == NormalWeightedMeanPrecision{Float64}(0.0, 0.1) - - distribution = NormalWeightedMeanPrecision(-2.0, 3.0) - - @test distribution ≈ distribution - @test distribution ≈ convert(NormalMeanPrecision, distribution) - @test distribution ≈ convert(NormalMeanVariance, distribution) -end - -@testitem "NormalWeightedMeanPrecision: vague" begin - include("./normal_family_setuptests.jl") - - d1 = vague(NormalWeightedMeanPrecision) - - @test typeof(d1) <: NormalWeightedMeanPrecision - @test mean(d1) == 0.0 - @test precision(d1) == 1e-12 -end - -@testitem "NormalWeightedMeanPrecision: prod" begin - include("./normal_family_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) - @test prod(strategy, NormalWeightedMeanPrecision(-1, 1 / 1), NormalWeightedMeanPrecision(1, 1 / 1)) ≈ - NormalWeightedMeanPrecision(0, 2) - @test prod(strategy, NormalWeightedMeanPrecision(-1, 1 / 2), NormalWeightedMeanPrecision(1, 1 / 4)) ≈ - NormalWeightedMeanPrecision(0, 3 / 4) - @test prod(strategy, NormalWeightedMeanPrecision(2, 1 / 2), NormalWeightedMeanPrecision(0, 1 / 10)) ≈ - NormalWeightedMeanPrecision(2, 3 / 5) - end -end - -================ -File: distributions/wip/test_continuous_bernoulli.jl -================ -module ContinuousBernoulliTest - -using Test -using ExponentialFamily -using Distributions -using Random -using StatsFuns -using ForwardDiff -import ExponentialFamily: - ExponentialFamilyDistribution, getnaturalparameters, compute_logscale, logpartition, basemeasure, - fisherinformation, isvague - -@testset "ContinuousBernoulli" begin - @testset "vague" begin - d = vague(ContinuousBernoulli) - - @test typeof(d) <: ContinuousBernoulli - @test mean(d) === 0.5 - @test succprob(d) === 0.5 - @test failprob(d) === 0.5 - end - - @testset "probvec" begin - @test probvec(ContinuousBernoulli(0.5)) === (0.5, 0.5) - @test probvec(ContinuousBernoulli(0.3)) === (0.7, 0.3) - @test probvec(ContinuousBernoulli(0.6)) === (0.4, 0.6) - end - - @testset "natural parameters related" begin - @test logpartition(convert(ExponentialFamilyDistribution, ContinuousBernoulli(0.5))) ≈ log(2) - @test logpartition(convert(ExponentialFamilyDistribution, ContinuousBernoulli(0.2))) ≈ - log((-3 / 4) / log(1 / 4)) - b_99 = ContinuousBernoulli(0.99) - for i in 1:9 - b = ContinuousBernoulli(i / 10.0) - bnp = convert(ExponentialFamilyDistribution, b) - @test convert(Distribution, bnp) ≈ b - @test logpdf(bnp, 1) ≈ logpdf(b, 1) - @test logpdf(bnp, 0) ≈ logpdf(b, 0) - - @test convert(ExponentialFamilyDistribution, b) == - ExponentialFamilyDistribution(ContinuousBernoulli, [logit(i / 10.0)]) - end - @test isproper(ExponentialFamilyDistribution(ContinuousBernoulli, [10])) === true - @test basemeasure(ExponentialFamilyDistribution(ContinuousBernoulli, [10]), 0.2) == 1.0 - end - - @testset "prod" begin - @test prod(ClosedProd(), ContinuousBernoulli(0.5), ContinuousBernoulli(0.5)) ≈ ContinuousBernoulli(0.5) - @test prod(ClosedProd(), ContinuousBernoulli(0.1), ContinuousBernoulli(0.6)) ≈ - ContinuousBernoulli(0.14285714285714285) - @test prod(ClosedProd(), ContinuousBernoulli(0.78), ContinuousBernoulli(0.05)) ≈ - ContinuousBernoulli(0.1572580645161291) - - left = convert(ExponentialFamilyDistribution, ContinuousBernoulli(0.5)) - right = convert(ExponentialFamilyDistribution, ContinuousBernoulli(0.6)) - @test prod(left, right) == convert(ExponentialFamilyDistribution, ContinuousBernoulli(0.6)) - end - - @testset "rand" begin - dist = ContinuousBernoulli(0.3) - nsamples = 1000 - rng = collect(1:100) - for i in 1:10 - samples = rand(MersenneTwister(rng[i]), dist, nsamples) - mestimated = mean(samples) - weights = ones(nsamples) / nsamples - @test isapprox(mestimated, mean(dist), atol = 1e-1) - @test isapprox( - sum(weight * (sample - mestimated)^2 for (sample, weight) in (samples, weights)), - var(dist), - atol = 1e-1 - ) - end - end - - @testset "fisher information" begin - function transformation(params) - return logistic(params[1]) - end - - for κ in 0.000001:0.01:0.49 - dist = ContinuousBernoulli(κ) - ef = convert(ExponentialFamilyDistribution, dist) - η = getnaturalparameters(ef) - - f_logpartition = (η) -> logpartition(ExponentialFamilyDistribution(ContinuousBernoulli, η)) - autograd_information = (η) -> ForwardDiff.hessian(f_logpartition, η) - @test first(fisherinformation(ef)) ≈ first(autograd_information(η)) atol = 1e-9 - J = ForwardDiff.gradient(transformation, η) - @test J' * fisherinformation(dist) * J ≈ first(fisherinformation(ef)) atol = 1e-9 - end - - for κ in 0.51:0.01:0.99 - dist = ContinuousBernoulli(κ) - ef = convert(ExponentialFamilyDistribution, dist) - η = getnaturalparameters(ef) - - f_logpartition = (η) -> logpartition(ExponentialFamilyDistribution(ContinuousBernoulli, η)) - autograd_information = (η) -> ForwardDiff.hessian(f_logpartition, η) - @test first(fisherinformation(ef)) ≈ first(autograd_information(η)) atol = 1e-9 - J = ForwardDiff.gradient(transformation, η) - @test J' * fisherinformation(dist) * J ≈ first(fisherinformation(ef)) atol = 1e-9 - end - - for κ in 0.499:0.0001:0.50001 - dist = ContinuousBernoulli(κ) - ef = convert(ExponentialFamilyDistribution, dist) - η = getnaturalparameters(ef) - - J = ForwardDiff.gradient(transformation, η) - @test J' * fisherinformation(dist) * J ≈ first(fisherinformation(ef)) atol = 1e-9 - end - end - - @testset "ExponentialFamilyDistribution mean var" begin - for ν in 0.1:0.1:0.99 - dist = ContinuousBernoulli(ν) - ef = convert(ExponentialFamilyDistribution, dist) - @test mean(dist) ≈ mean(ef) atol = 1e-8 - @test var(dist) ≈ var(ef) atol = 1e-8 - end - end -end - -end - -================ -File: distributions/wip/test_multinomial.jl -================ -module MultinomialTest - -using Test -using ExponentialFamily -using Distributions -using Random -using StableRNGs -using ForwardDiff -import ExponentialFamily: ExponentialFamilyDistribution, getnaturalparameters, basemeasure, fisherinformation - -@testset "Multinomial" begin - @testset "probvec" begin - @test probvec(Multinomial(5, [1 / 3, 1 / 3, 1 / 3])) == [1 / 3, 1 / 3, 1 / 3] - @test probvec(Multinomial(3, [0.2, 0.2, 0.4, 0.1, 0.1])) == [0.2, 0.2, 0.4, 0.1, 0.1] - @test probvec(Multinomial(2, [0.5, 0.5])) == [0.5, 0.5] - end - - @testset "vague" begin - @test_throws MethodError vague(Multinomial) - @test_throws MethodError vague(Multinomial, 4) - - vague_dist1 = vague(Multinomial, 5, 4) - @test typeof(vague_dist1) <: Multinomial - @test probvec(vague_dist1) == [1 / 4, 1 / 4, 1 / 4, 1 / 4] - - vague_dist2 = vague(Multinomial, 3, 5) - @test typeof(vague_dist2) <: Multinomial - @test probvec(vague_dist2) == [1 / 5, 1 / 5, 1 / 5, 1 / 5, 1 / 5] - end - - @testset "prod" begin - for n in 2:3 - plength = Int64(ceil(rand(Uniform(1, n)))) - pleft = rand(plength) - pleft = pleft ./ sum(pleft) - pright = rand(plength) - pright = pright ./ sum(pright) - left = Multinomial(n, pleft) - right = Multinomial(n, pright) - efleft = convert(ExponentialFamilyDistribution, left) - efright = convert(ExponentialFamilyDistribution, right) - prod_dist = prod(ClosedProd(), left, right) - prod_ef = prod(efleft, efright) - d = Multinomial(n, ones(plength) ./ plength) - sample_space = unique(rand(StableRNG(1), d, 4000), dims = 2) - sample_space = [sample_space[:, i] for i in 1:size(sample_space, 2)] - - hist_sum(x) = - prod_dist.basemeasure(x) * exp( - prod_dist.naturalparameters' * prod_dist.sufficientstatistics(x) - - prod_dist.logpartition(prod_dist.naturalparameters) - ) - hist_sumef(x) = - prod_ef.basemeasure(x) * exp( - prod_ef.naturalparameters' * prod_ef.sufficientstatistics(x) - - prod_ef.logpartition(prod_ef.naturalparameters) - ) - @test sum(hist_sum(x_sample) for x_sample in sample_space) ≈ 1.0 atol = 1e-10 - @test sum(hist_sumef(x_sample) for x_sample in sample_space) ≈ 1.0 atol = 1e-10 - sample_x = rand(d, 5) - for xi in sample_x - @test prod_dist.basemeasure(xi) ≈ (factorial(n) / prod(@.factorial(xi)))^2 atol = 1e-10 - @test prod_dist.sufficientstatistics(xi) == xi - - @test prod_ef.basemeasure(xi) ≈ (factorial(n) / prod(@.factorial(xi)))^2 atol = 1e-10 - @test prod_ef.sufficientstatistics(xi) == xi - end - end - - @test_throws AssertionError prod( - ClosedProd(), - Multinomial(4, [0.2, 0.4, 0.4]), - Multinomial(5, [0.1, 0.3, 0.6]) - ) - @test_throws AssertionError prod( - ClosedProd(), - Multinomial(4, [0.2, 0.4, 0.4]), - Multinomial(3, [0.1, 0.3, 0.6]) - ) - end - - @testset "natural parameters related " begin - d1 = Multinomial(5, [0.1, 0.4, 0.5]) - d2 = Multinomial(5, [0.2, 0.4, 0.4]) - η1 = ExponentialFamilyDistribution(Multinomial, [log(0.1 / 0.5), log(0.4 / 0.5), 0.0], 5) - η2 = ExponentialFamilyDistribution(Multinomial, [log(0.2 / 0.4), 0.0, 0.0], 5) - - @test convert(ExponentialFamilyDistribution, d1) ≈ η1 - @test convert(ExponentialFamilyDistribution, d2) ≈ η2 - - @test convert(Distribution, η1) ≈ d1 - @test convert(Distribution, η2) ≈ d2 - - @test logpartition(η1) == 3.4657359027997265 - @test logpartition(η2) == 4.5814536593707755 - - @test basemeasure(η1, [1, 2, 2]) == 30.0 - @test basemeasure(η2, [1, 2, 2]) == 30.0 - - @test logpdf(η1, [1, 2, 2]) == logpdf(d1, [1, 2, 2]) - @test logpdf(η2, [1, 2, 2]) == logpdf(d2, [1, 2, 2]) - - @test pdf(η1, [1, 2, 2]) == pdf(d1, [1, 2, 2]) - @test pdf(η2, [1, 2, 2]) == pdf(d2, [1, 2, 2]) - end - - @testset "fisher information" begin - function transformation(η) - expη = exp.(η) - expη / sum(expη) - end - rng = StableRNG(42) - ## ForwardDiff hessian is slow so we only test one time with hessian - n = 3 - p = rand(rng, Dirichlet(ones(n))) - dist = Multinomial(n, p) - ef = convert(ExponentialFamilyDistribution, dist) - η = getnaturalparameters(ef) - - f_logpartition = (η) -> logpartition(ExponentialFamilyDistribution(Multinomial, η, n)) - autograd_information = (η) -> ForwardDiff.hessian(f_logpartition, η) - @test fisherinformation(ef) ≈ autograd_information(η) atol = 1e-8 - - for n in 2:12 - p = rand(rng, Dirichlet(ones(n))) - dist = Multinomial(n, p) - ef = convert(ExponentialFamilyDistribution, dist) - η = getnaturalparameters(ef) - - J = ForwardDiff.jacobian(transformation, η) - @test J' * fisherinformation(dist) * J ≈ fisherinformation(ef) atol = 1e-8 - end - end - - @testset "ExponentialFamilyDistribution mean,cov" begin - rng = StableRNG(42) - for n in 2:12 - p = rand(rng, Dirichlet(ones(n))) - dist = Multinomial(n, p) - ef = convert(ExponentialFamilyDistribution, dist) - @test mean(dist) ≈ mean(ef) atol = 1e-8 - @test cov(dist) ≈ cov(ef) atol = 1e-8 - end - end -end -end - -================ -File: distributions/bernoulli_tests.jl -================ -# Bernoulli comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Bernoulli: vague" begin - include("distributions_setuptests.jl") - - d = vague(Bernoulli) - - @test typeof(d) <: Bernoulli - @test mean(d) === 0.5 - @test succprob(d) === 0.5 - @test failprob(d) === 0.5 -end - -@testitem "Bernoulli: probvec" begin - include("distributions_setuptests.jl") - - @test probvec(Bernoulli(0.5)) === (0.5, 0.5) - @test probvec(Bernoulli(0.3)) === (0.7, 0.3) - @test probvec(Bernoulli(0.6)) === (0.4, 0.6) -end - -@testitem "Bernoulli: logscale Bernoulli-Bernoulli/Categorical" begin - include("distributions_setuptests.jl") - - @test BayesBase.compute_logscale(Bernoulli(0.5), Bernoulli(0.5), Bernoulli(0.5)) ≈ log(0.5) - @test BayesBase.compute_logscale(Bernoulli(1), Bernoulli(0.5), Bernoulli(1)) ≈ log(0.5) - @test BayesBase.compute_logscale(Categorical([0.5, 0.5]), Bernoulli(0.5), Categorical([0.5, 0.5])) ≈ log(0.5) - @test BayesBase.compute_logscale(Categorical([0.5, 0.5]), Categorical([0.5, 0.5]), Bernoulli(0.5)) ≈ log(0.5) - @test BayesBase.compute_logscale(Categorical([1.0, 0.0]), Bernoulli(0.5), Categorical([1])) ≈ log(0.5) - @test BayesBase.compute_logscale(Categorical([1.0, 0.0, 0.0]), Bernoulli(0.5), Categorical([1.0, 0, 0])) ≈ log(0.5) -end - -@testitem "Bernoulli: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for p in 0.1:0.1:0.9 - @testset let d = Bernoulli(p) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - η₁ = logit(p) - - for x in (0, 1) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === oneunit(x) - @test @inferred(sufficientstatistics(ef, x)) === (x,) - @test @inferred(logpartition(ef)) ≈ log(1 + exp(η₁)) - end - - @test !@inferred(insupport(ef, -0.5)) - @test !@inferred(insupport(ef, 0.5)) - - # Not in the support - @test_throws Exception logpdf(ef, 0.5) - @test_throws Exception logpdf(ef, -0.5) - end - end - - # Test failing isproper cases - @test !isproper(MeanParametersSpace(), Bernoulli, [-1]) - @test !isproper(MeanParametersSpace(), Bernoulli, [0.5, 0.5]) - @test !isproper(NaturalParametersSpace(), Bernoulli, [0.5, 0.5]) - @test !isproper(NaturalParametersSpace(), Bernoulli, [Inf]) - - @test_throws Exception convert(ExponentialFamilyDistribution, Bernoulli(1.0)) # We cannot convert from `1.0`, `logit` function returns `Inf` -end - -@testitem "Bernoulli: prod with Distribution" begin - include("distributions_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) - @test @inferred(prod(strategy, Bernoulli(0.5), Bernoulli(0.5))) ≈ Bernoulli(0.5) - @test @inferred(prod(strategy, Bernoulli(0.1), Bernoulli(0.6))) ≈ Bernoulli(0.14285714285714285) - @test @inferred(prod(strategy, Bernoulli(0.78), Bernoulli(0.05))) ≈ Bernoulli(0.1572580645161291) - end - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), GenericProd()) - # Test symmetric case - @test @inferred(prod(strategy, Bernoulli(0.5), Categorical([0.5, 0.5]))) ≈ Categorical([0.5, 0.5]) - @test @inferred(prod(strategy, Categorical([0.5, 0.5]), Bernoulli(0.5))) ≈ Categorical([0.5, 0.5]) - end - - @test @allocated(prod(ClosedProd(), Bernoulli(0.5), Bernoulli(0.5))) === 0 - @test @allocated(prod(PreserveTypeProd(Distribution), Bernoulli(0.5), Bernoulli(0.5))) === 0 - @test @allocated(prod(GenericProd(), Bernoulli(0.5), Bernoulli(0.5))) === 0 -end - -@testitem "Bernoulli: prod with Categorical" begin - include("distributions_setuptests.jl") - - @test prod(ClosedProd(), Bernoulli(0.5), Categorical([0.5, 0.5])) ≈ - Categorical([0.5, 0.5]) - @test prod(ClosedProd(), Bernoulli(0.1), Categorical(0.4, 0.6)) ≈ - Categorical([1 - 0.14285714285714285, 0.14285714285714285]) - @test prod(ClosedProd(), Bernoulli(0.78), Categorical([0.95, 0.05])) ≈ - Categorical([1 - 0.1572580645161291, 0.1572580645161291]) - @test prod(ClosedProd(), Bernoulli(0.5), Categorical([0.3, 0.3, 0.4])) ≈ - Categorical([0.5, 0.5, 0]) - @test prod(ClosedProd(), Bernoulli(0.5), Categorical([1.0])) ≈ - Categorical([1.0, 0]) -end - -@testitem "Bernoulli: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for pleft in 0.1:0.1:0.9, pright in 0.1:0.1:0.9 - @testset let (left, right) = (Bernoulli(pleft), Bernoulli(pright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{Bernoulli}) - ) - ) - end - end -end - -================ -File: distributions/beta_tests.jl -================ -# Beta comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Beta: vague" begin - include("distributions_setuptests.jl") - - d = vague(Beta) - - @test typeof(d) <: Beta - @test mean(d) === 0.5 - @test params(d) === (1.0, 1.0) -end - -@testitem "Beta: mean(::typeof(log))" begin - include("distributions_setuptests.jl") - - @test mean(log, Beta(1.0, 3.0)) ≈ -1.8333333333333335 - @test mean(log, Beta(0.1, 0.3)) ≈ -7.862370395825961 - @test mean(log, Beta(4.5, 0.3)) ≈ -0.07197681436958758 -end - -@testitem "Beta: mean(::typeof(mirrorlog))" begin - include("distributions_setuptests.jl") - - @test mean(mirrorlog, Beta(1.0, 3.0)) ≈ -0.33333333333333337 - @test mean(mirrorlog, Beta(0.1, 0.3)) ≈ -0.9411396776150167 - @test mean(mirrorlog, Beta(4.5, 0.3)) ≈ -4.963371962929249 -end - -@testitem "Beta: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for a in 0.1:0.1:0.9, b in 1.1:0.2:2.0 - @testset let d = Beta(a, b) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - - (η₁, η₂) = (a - 1, b - 1) - - for x in 0.1:0.1:0.9 - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === oneunit(x) - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (log(x), log(1 - x))) - @test @inferred(logpartition(ef)) ≈ (logbeta(η₁ + 1, η₂ + 1)) - end - - @test !@inferred(insupport(ef, -0.5)) - @test @inferred(insupport(ef, 0.5)) - - # Not in the support - @test_throws Exception logpdf(ef, -0.5) - end - end - - # Test failing isproper cases - @test !isproper(MeanParametersSpace(), Beta, [-1]) - @test !isproper(MeanParametersSpace(), Beta, [1, -0.1]) - @test !isproper(MeanParametersSpace(), Beta, [-0.1, 1]) - @test !isproper(NaturalParametersSpace(), Beta, [-1.1]) - @test !isproper(NaturalParametersSpace(), Beta, [1, -1.1]) - @test !isproper(NaturalParametersSpace(), Beta, [-1.1, 1]) - - # `a`s must add up to something more than 1, otherwise is not proper - let ef = convert(ExponentialFamilyDistribution, Beta(0.1, 1.0)) - @test !isproper(prod(PreserveTypeProd(ExponentialFamilyDistribution), ef, ef)) - end - - # `b`s must add up to something more than 1, otherwise is not proper - let ef = convert(ExponentialFamilyDistribution, Beta(1.0, 0.1)) - @test !isproper(prod(PreserveTypeProd(ExponentialFamilyDistribution), ef, ef)) - end -end - -@testitem "Beta: prod with Distributions" begin - include("distributions_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) - @test prod(strategy, Beta(3.0, 2.0), Beta(2.0, 1.0)) ≈ Beta(4.0, 2.0) - @test prod(strategy, Beta(7.0, 1.0), Beta(0.1, 4.5)) ≈ Beta(6.1, 4.5) - @test prod(strategy, Beta(1.0, 3.0), Beta(0.2, 0.4)) ≈ Beta(0.19999999999999996, 2.4) - end - - @test @allocated(prod(ClosedProd(), Beta(3.0, 2.0), Beta(2.0, 1.0))) === 0 - @test @allocated(prod(GenericProd(), Beta(3.0, 2.0), Beta(2.0, 1.0))) === 0 -end - -@testitem "Beta: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for aleft in 0.51:1.0:5.0, aright in 0.51:1.0:5.0, bleft in 0.51:1.0:5.0, bright in 0.51:1.0:5.0 - @testset let (left, right) = (Beta(aleft, bleft), Beta(aright, bright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{Beta}) - ) - ) - end - end -end - -================ -File: distributions/binomial_tests.jl -================ -# Binomial comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Binomial: probvec" begin - include("distributions_setuptests.jl") - - @test all(probvec(Binomial(2, 0.8)) .≈ (0.2, 0.8)) - @test probvec(Binomial(2, 0.2)) == (0.8, 0.2) - @test probvec(Binomial(2, 0.1)) == (0.9, 0.1) - @test probvec(Binomial(2)) == (0.5, 0.5) -end - -@testitem "Binomial: vague" begin - include("distributions_setuptests.jl") - - @test_throws MethodError vague(Binomial) - @test_throws MethodError vague(Binomial, 1 / 2) - - vague_dist = vague(Binomial, 5) - @test typeof(vague_dist) <: Binomial - @test probvec(vague_dist) == (0.5, 0.5) -end - -@testitem "Binomial: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for n in (2, 3, 4), p in 0.1:0.2:0.9 - @testset let d = Binomial(n, p) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - - η₁ = log(p / (1 - p)) - - for x in 0:n - @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === binomial(n, x) - @test @inferred(logbasemeasure(ef, x)) === loggamma(n+1) - (loggamma(n - x + 1) + loggamma(x + 1)) - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (x,)) - @test @inferred(logpartition(ef)) ≈ (n * log(1 + exp(η₁))) - end - - @test !@inferred(insupport(ef, -1)) - @test @inferred(insupport(ef, 0)) - - # Not in the support - @test_throws Exception logpdf(ef, -1) - end - end - - # Test failing isproper cases - @test !isproper(MeanParametersSpace(), Binomial, [-1], 1) - @test !isproper(MeanParametersSpace(), Binomial, [0.5], -1) - @test !isproper(MeanParametersSpace(), Binomial, [-0.1, 1], 10) - @test !isproper(NaturalParametersSpace(), Binomial, [-1.1], -1) - @test !isproper(NaturalParametersSpace(), Binomial, [1, -1.1], 10) -end - -@testitem "Binomial: prod ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for nleft in 1:1, pleft in 0.1:0.1:0.1, nright in 1:1, pright in 0.1:0.1:0.1 - @testset let (left, right) = (Binomial(nleft, pleft), Binomial(nright, pright)) - for (efleft, efright) in ((left, right), (convert(ExponentialFamilyDistribution, left), convert(ExponentialFamilyDistribution, right))) - for strategy in (PreserveTypeProd(ExponentialFamilyDistribution),) - prod_dist = prod(strategy, efleft, efright) - - @test prod_dist isa ExponentialFamilyDistribution - - hist_sum(x) = - basemeasure(prod_dist, x) * exp( - dot(ExponentialFamily.flatten_parameters(sufficientstatistics(prod_dist, x)), getnaturalparameters(prod_dist)) - - logpartition(prod_dist) - ) - - support = 0:1:max(nleft, nright) - - @test sum(hist_sum, support) ≈ 1.0 atol = 1e-9 - - for x in support - @test basemeasure(prod_dist, x) ≈ (binomial(nleft, x) * binomial(nright, x)) - @test all(sufficientstatistics(prod_dist, x) .≈ (x,)) - end - end - end - end - end -end - -================ -File: distributions/categorical_tests.jl -================ -# Categorical comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Categorical: vague" begin - include("distributions_setuptests.jl") - - @test_throws MethodError vague(Categorical) - - d1 = vague(Categorical, 2) - - @test typeof(d1) <: Categorical - @test probvec(d1) ≈ [0.5, 0.5] - - d2 = vague(Categorical, 4) - - @test typeof(d2) <: Categorical - @test probvec(d2) ≈ [0.25, 0.25, 0.25, 0.25] -end - -@testitem "Categorical: probvec" begin - include("distributions_setuptests.jl") - - @test probvec(Categorical([0.1, 0.4, 0.5])) == [0.1, 0.4, 0.5] - @test probvec(Categorical([1 / 3, 1 / 3, 1 / 3])) == [1 / 3, 1 / 3, 1 / 3] - @test probvec(Categorical([0.8, 0.1, 0.1])) == [0.8, 0.1, 0.1] -end - -@testitem "Categorical: prod Distribution" begin - include("distributions_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) - @test prod(strategy, Categorical([0.1, 0.4, 0.5]), Categorical([1 / 3, 1 / 3, 1 / 3])) == - Categorical([0.1, 0.4, 0.5]) - @test prod(strategy, Categorical([0.1, 0.4, 0.5]), Categorical([0.8, 0.1, 0.1])) == - Categorical([0.47058823529411764, 0.23529411764705882, 0.2941176470588235]) - @test prod(strategy, Categorical([0.2, 0.6, 0.2]), Categorical([0.8, 0.1, 0.1])) ≈ - Categorical([2 / 3, 1 / 4, 1 / 12]) - end -end - -@testitem "Categorical: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for s in (2, 3, 4, 5) - @testset let d = Categorical(normalize!(rand(s), 1)) - ef = test_exponentialfamily_interface( - d; - test_fisherinformation_properties = false, # The fisher information is not-posdef, to discuss - test_fisherinformation_against_jacobian = false - ) - - run_test_fisherinformation_against_jacobian(d; assume_no_allocations = false, mappings = ( - NaturalParametersSpace() => MeanParametersSpace(), - # MeanParametersSpace() => NaturalParametersSpace(), # here is the problem for discussion, the test is broken - )) - - θ = probvec(d) - η = map(p -> log(p / θ[end]), θ) - - for x in 1:s - v = zeros(s) - v[x] = 1 - - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === oneunit(x) - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (v,)) - @test @inferred(logpartition(ef)) ≈ logsumexp(η) - end - - @test !@inferred(insupport(ef, s + 1)) - @test @inferred(insupport(ef, s)) - - # # Not in the support - @test_throws Exception logpdf(ef, ones(s)) - end - end - - # Test failing isproper cases - @test !isproper(MeanParametersSpace(), Categorical, [-1], 2) # conditioner does not match the length - @test !isproper(MeanParametersSpace(), Categorical, [-1], 1) - @test !isproper(MeanParametersSpace(), Categorical, [1, 0.5], 2) - @test !isproper(MeanParametersSpace(), Categorical, [-0.5, 1.5], 2) - @test !isproper(NaturalParametersSpace(), Categorical, [-1.1], 2) # conditioner does not match the length - @test !isproper(NaturalParametersSpace(), Categorical, [-1.1], 1) - @test !isproper(NaturalParametersSpace(), Categorical, [1], 1) # length should be >=2 -end - -@testitem "Categorical ExponentialFamilyDistribution supports RecursiveArrayTools" begin - using RecursiveArrayTools - include("distributions_setuptests.jl") - ef = ExponentialFamilyDistribution(Categorical, ArrayPartition([0, 1, 0]), 3, nothing) - part_ef = ExponentialFamilyDistribution(Categorical, ArrayPartition([0, 1], [0]), 3, nothing) - @test convert(Distribution, ef) ≈ convert(Distribution, part_ef) -end - -@testitem "Categorical: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for s in (2, 3, 4, 5) - @testset let (left, right) = (Categorical(normalize!(rand(s), 1)), Categorical(normalize!(rand(s), 1))) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{Categorical}) - ) - ) - end - end -end - -================ -File: distributions/chi_squared_tests.jl -================ -# Chisq comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Chisq: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for i in 3:0.5:7 - @testset let d = Chisq(2 * (i + 1)) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - η₁ = first(getnaturalparameters(ef)) - - for x in (0.1, 0.5, 1.0) - @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === exp(-x / 2) - @test @inferred(sufficientstatistics(ef, x)) === (log(x),) - @test @inferred(logpartition(ef)) ≈ loggamma(η₁ + 1) + (η₁ + 1) * log(2.0) - end - end - end - - for space in (MeanParametersSpace(), NaturalParametersSpace()) - @test !isproper(space, Chisq, [Inf]) - @test !isproper(space, Chisq, [-1.0]) - @test !isproper(space, Chisq, [NaN]) - @test !isproper(space, Chisq, [1.0], NaN) - @test !isproper(space, Chisq, [0.5, 0.5], 1.0) - end - - ## mean parameter should be integer in the MeanParametersSpace - @test isproper(MeanParametersSpace(), Chisq, [0.1]) - @test isproper(NaturalParametersSpace(), Chisq, [-0.5]) - @test !isproper(NaturalParametersSpace(), Chisq, [-1.5]) - @test convert(ExponentialFamilyDistribution, Chisq(0.5)) ≈ ExponentialFamilyDistribution(Chisq, [-0.75]) - @test_throws Exception convert(ExponentialFamilyDistribution, Chisq(Inf)) -end - -@testitem "Chisq: prod with Distribution and ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - @testset for i in 3:6 - left = Chisq(i + 1) - right = Chisq(i) - prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), left, right) - efleft = convert(ExponentialFamilyDistribution, left) - efright = convert(ExponentialFamilyDistribution, right) - prod_ef = prod(PreserveTypeProd(ExponentialFamilyDistribution), efleft, efright) - η_left = getnaturalparameters(efleft) - η_right = getnaturalparameters(efright) - naturalparameters = η_left + η_right - - @test prod_dist.naturalparameters == naturalparameters - @test getbasemeasure(prod_dist)(i) ≈ exp(-i) - @test sufficientstatistics(prod_dist, i) === (log(i),) - @test getlogpartition(prod_dist)(η_left + η_right) ≈ loggamma(η_left[1] + η_right[1] + 1) - @test getsupport(prod_dist) === support(left) - - @test prod_ef.naturalparameters == naturalparameters - @test getbasemeasure(prod_ef)(i) ≈ exp(-i) - @test sufficientstatistics(prod_ef, i) === (log(i),) - @test getlogpartition(prod_ef)(η_left + η_right) ≈ loggamma(η_left[1] + η_right[1] + 1) - @test getsupport(prod_ef) === support(left) - end -end - -================ -File: distributions/dirichlet_tests.jl -================ -# Dirichlet comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Dirichlet: vague" begin - include("distributions_setuptests.jl") - - @test_throws MethodError vague(Dirichlet) - - d1 = vague(Dirichlet, 2) - - @test typeof(d1) <: Dirichlet - @test probvec(d1) == ones(2) - - d2 = vague(Dirichlet, 4) - - @test typeof(d2) <: Dirichlet - @test probvec(d2) == ones(4) -end - -@testitem "Dirichlet: mean(::typeof(log))" begin - include("distributions_setuptests.jl") - - import Base.Broadcast: BroadcastFunction - - @test mean(BroadcastFunction(log), Dirichlet([1.0, 1.0, 1.0])) ≈ [-1.5000000000000002, -1.5000000000000002, -1.5000000000000002] - @test mean(BroadcastFunction(log), Dirichlet([1.1, 2.0, 2.0])) ≈ [-1.9517644694670657, -1.1052251939575213, -1.1052251939575213] - @test mean(BroadcastFunction(log), Dirichlet([3.0, 1.2, 5.0])) ≈ [-1.2410879175727905, -2.4529121492634465, -0.657754584239457] -end - -@testitem "Dirichlet: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - rng = StableRNG(42) - for len in 3:5 - α = rand(rng, len) - @testset let d = Dirichlet(α) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false) - η1 = getnaturalparameters(ef) - - for x in [rand(rng, len) for _ in 1:3] - x = x ./ sum(x) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === 1.0 - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (map(log, x),)) - firstterm = mapreduce(x -> loggamma(x + 1), +, η1) - secondterm = loggamma(sum(η1) + length(η1)) - @test @inferred(logpartition(ef)) ≈ firstterm - secondterm - end - end - end - - for space in (MeanParametersSpace(), NaturalParametersSpace()) - @test !isproper(space, Dirichlet, [Inf, Inf], 1.0) - @test !isproper(space, Dirichlet, [1.0], Inf) - @test !isproper(space, Dirichlet, [NaN], 1.0) - @test !isproper(space, Dirichlet, [1.0], NaN) - @test !isproper(space, Dirichlet, [0.5, 0.5], 1.0) - @test isproper(space, Dirichlet, [2.0, 3.0]) - @test !isproper(space, Dirichlet, [-1.0, -1.2]) - end - - @test_throws Exception convert(ExponentialFamilyDistribution, Dirichlet([Inf, Inf])) -end - -@testitem "Dirichlet: prod with Distribution" begin - include("distributions_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) - @test @inferred(prod(strategy, Dirichlet([1.0, 1.0, 1.0]), Dirichlet([1.0, 1.0, 1.0]))) ≈ Dirichlet([1.0, 1.0, 1.0]) - @test @inferred(prod(strategy, Dirichlet([1.1, 1.0, 2.0]), Dirichlet([1.0, 1.2, 1.0]))) ≈ Dirichlet([1.1, 1.2, 2.0]) - @test @inferred(prod(strategy, Dirichlet([1.1, 2.0, 2.0]), Dirichlet([3.0, 1.2, 5.0]))) ≈ Dirichlet([3.1, 2.2, 6.0]) - end -end - -@testitem "Dirichlet: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - rng = StableRNG(123) - for len in 3:6 - αleft = rand(rng, len) .+ 1 - αright = rand(rng, len) .+ 1 - @testset let (left, right) = (Dirichlet(αleft), Dirichlet(αright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd() - ) - ) - end - end -end - -================ -File: distributions/distributions_setuptests.jl -================ -using ExponentialFamily, BayesBase, FastCholesky, Distributions, LinearAlgebra, TinyHugeNumbers -using Test, ForwardDiff, Random, StatsFuns, StableRNGs, FillArrays, JET, SpecialFunctions - -import BayesBase: compute_logscale - -import ExponentialFamily: - ExponentialFamilyDistribution, - getnaturalparameters, - getconditioner, - logpartition, - basemeasure, - logbasemeasure, - insupport, - sufficientstatistics, - fisherinformation, - pack_parameters, - unpack_parameters, - isbasemeasureconstant, - ConstantBaseMeasure, - MeanToNatural, - NaturalToMean, - NaturalParametersSpace, - invscatter, - location, - locationdim - -import Distributions: - variate_form, - value_support - -import SpecialFunctions: - logbeta, - loggamma, - digamma, - logfactorial, - besseli - -import HCubature: - hquadrature - -import DomainSets: - NaturalNumbers - -union_types(x::Union) = (x.a, union_types(x.b)...) -union_types(x::Type) = (x,) - -function Base.isapprox(a::Tuple, b::Tuple; kwargs...) - return length(a) === length(b) && all((d) -> isapprox(d[1], d[2]; kwargs...), zip(a, b)) -end - -JET_function_filter(@nospecialize f) = ((f === FastCholesky.cholinv) || (f === FastCholesky.cholsqrt)) - -macro test_opt(expr) - return esc(quote - JET.@test_opt function_filter=JET_function_filter ignored_modules=(Base,LinearAlgebra) $expr - end) -end - -function test_exponentialfamily_interface(distribution; - test_parameters_conversion = true, - test_similar_creation = true, - test_distribution_conversion = true, - test_packing_unpacking = true, - test_isproper = true, - test_basic_functions = true, - test_gradlogpartition_properties = true, - test_fisherinformation_properties = true, - test_fisherinformation_against_hessian = true, - test_fisherinformation_against_jacobian = true, - test_plogpdf_interface = true, - option_assume_no_allocations = false -) - T = ExponentialFamily.exponential_family_typetag(distribution) - - ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) - - @test_opt convert(ExponentialFamilyDistribution, distribution) - - @test ef isa ExponentialFamilyDistribution{T} - - test_parameters_conversion && run_test_parameters_conversion(distribution) - test_similar_creation && run_test_similar_creation(distribution) - test_distribution_conversion && run_test_distribution_conversion(distribution; assume_no_allocations = option_assume_no_allocations) - test_packing_unpacking && run_test_packing_unpacking(distribution) - test_isproper && run_test_isproper(distribution; assume_no_allocations = option_assume_no_allocations) - test_basic_functions && run_test_basic_functions(distribution; assume_no_allocations = option_assume_no_allocations) - test_gradlogpartition_properties && run_test_gradlogpartition_properties(distribution) - test_fisherinformation_properties && run_test_fisherinformation_properties(distribution) - test_fisherinformation_against_hessian && run_test_fisherinformation_against_hessian(distribution; assume_no_allocations = option_assume_no_allocations) - test_fisherinformation_against_jacobian && run_test_fisherinformation_against_jacobian(distribution; assume_no_allocations = option_assume_no_allocations) - test_plogpdf_interface && run_test_plogpdf_interface(distribution) - return ef -end - -function run_test_plogpdf_interface(distribution) - ef = convert(ExponentialFamily.ExponentialFamilyDistribution, distribution) - η = getnaturalparameters(ef) - samples = rand(StableRNG(42), distribution, 10) - _, _samples = ExponentialFamily.check_logpdf(variate_form(typeof(ef)), typeof(samples), eltype(samples), ef, samples) - ss_vectors = map(s -> ExponentialFamily.pack_parameters(ExponentialFamily.sufficientstatistics(ef, s)), _samples) - unnormalized_logpdfs = map(v -> dot(v, η), ss_vectors) - @test all(unnormalized_logpdfs ≈ map(x -> ExponentialFamily._plogpdf(ef, x, 0, 0), _samples)) -end - -function run_test_parameters_conversion(distribution) - T = ExponentialFamily.exponential_family_typetag(distribution) - - tuple_of_θ, conditioner = ExponentialFamily.separate_conditioner(T, params(MeanParametersSpace(), distribution)) - - @test all(ExponentialFamily.join_conditioner(T, tuple_of_θ, conditioner) .== params(MeanParametersSpace(), distribution)) - - @test_opt ExponentialFamily.separate_conditioner(T, params(MeanParametersSpace(), distribution)) - @test_opt ExponentialFamily.join_conditioner(T, tuple_of_θ, conditioner) - @test_opt params(MeanParametersSpace(), distribution) - - ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) - - @test conditioner === getconditioner(ef) - - # Check the `conditioned` conversions, should work for un-conditioned members as well - tuple_of_η = MeanToNatural(T)(tuple_of_θ, conditioner) - - @test all(NaturalToMean(T)(tuple_of_η, conditioner) .≈ tuple_of_θ) - @test all(MeanToNatural(T)(tuple_of_θ, conditioner) .≈ tuple_of_η) - @test all(NaturalToMean(T)(pack_parameters(NaturalParametersSpace(), T, tuple_of_η), conditioner) .≈ pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) - @test all(MeanToNatural(T)(pack_parameters(MeanParametersSpace(), T, tuple_of_θ), conditioner) .≈ pack_parameters(NaturalParametersSpace(), T, tuple_of_η)) - - @test_opt NaturalToMean(T)(tuple_of_η, conditioner) - @test_opt MeanToNatural(T)(tuple_of_θ, conditioner) - @test_opt NaturalToMean(T)(pack_parameters(NaturalParametersSpace(), T, tuple_of_η), conditioner) - @test_opt MeanToNatural(T)(pack_parameters(MeanParametersSpace(), T, tuple_of_θ), conditioner) - - @test all(map(NaturalParametersSpace() => MeanParametersSpace(), T, tuple_of_η, conditioner) .≈ tuple_of_θ) - @test all(map(MeanParametersSpace() => NaturalParametersSpace(), T, tuple_of_θ, conditioner) .≈ tuple_of_η) - @test all( - map(NaturalParametersSpace() => MeanParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, tuple_of_η), conditioner) .≈ - pack_parameters(MeanParametersSpace(), T, tuple_of_θ) - ) - @test all( - map(MeanParametersSpace() => NaturalParametersSpace(), T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ), conditioner) .≈ - pack_parameters(NaturalParametersSpace(), T, tuple_of_η) - ) - - # Double check the `conditioner` free conversions - if isnothing(conditioner) - local _tuple_of_η = MeanToNatural(T)(tuple_of_θ) - - @test all(_tuple_of_η .== tuple_of_η) - @test all(NaturalToMean(T)(_tuple_of_η) .≈ tuple_of_θ) - @test all(NaturalToMean(T)(_tuple_of_η) .≈ tuple_of_θ) - @test all(MeanToNatural(T)(tuple_of_θ) .≈ _tuple_of_η) - @test all(NaturalToMean(T)(pack_parameters(NaturalParametersSpace(), T, _tuple_of_η)) .≈ pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) - @test all(MeanToNatural(T)(pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) .≈ pack_parameters(NaturalParametersSpace(), T, _tuple_of_η)) - - @test all(map(NaturalParametersSpace() => MeanParametersSpace(), T, _tuple_of_η) .≈ tuple_of_θ) - @test all(map(NaturalParametersSpace() => MeanParametersSpace(), T, _tuple_of_η) .≈ tuple_of_θ) - @test all(map(MeanParametersSpace() => NaturalParametersSpace(), T, tuple_of_θ) .≈ _tuple_of_η) - @test all( - map(NaturalParametersSpace() => MeanParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, _tuple_of_η)) .≈ - pack_parameters(MeanParametersSpace(), T, tuple_of_θ) - ) - @test all( - map(MeanParametersSpace() => NaturalParametersSpace(), T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) .≈ - pack_parameters(NaturalParametersSpace(), T, _tuple_of_η) - ) - end - - @test all(unpack_parameters(NaturalParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, tuple_of_η)) .== tuple_of_η) - @test all(unpack_parameters(MeanParametersSpace(), T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) .== tuple_of_θ) - - @test_opt unpack_parameters(NaturalParametersSpace(), T, pack_parameters(NaturalParametersSpace(), T, tuple_of_η)) - @test_opt unpack_parameters(MeanParametersSpace(), T, pack_parameters(MeanParametersSpace(), T, tuple_of_θ)) - - # Extra methods for conditioner free distributions - if isnothing(conditioner) - @test all( - params(MeanParametersSpace(), distribution) .≈ - map(NaturalParametersSpace() => MeanParametersSpace(), T, params(NaturalParametersSpace(), distribution)) - ) - @test all( - params(NaturalParametersSpace(), distribution) .≈ - map(MeanParametersSpace() => NaturalParametersSpace(), T, params(MeanParametersSpace(), distribution)) - ) - end -end - -function run_test_similar_creation(distribution) - T = ExponentialFamily.exponential_family_typetag(distribution) - - ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) - - @test similar(ef) isa ExponentialFamilyDistribution{T} - @test_opt similar(ef) -end - -function run_test_distribution_conversion(distribution; assume_no_allocations = true) - T = ExponentialFamily.exponential_family_typetag(distribution) - - ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) - - @test @inferred(convert(Distribution, ef)) ≈ distribution - @test_opt convert(Distribution, ef) - - if assume_no_allocations - @test @allocated(convert(Distribution, ef)) === 0 - end -end - -function run_test_packing_unpacking(distribution) - T = ExponentialFamily.exponential_family_typetag(distribution) - - tuple_of_θ, conditioner = ExponentialFamily.separate_conditioner(T, params(MeanParametersSpace(), distribution)) - ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) - - tuple_of_η = MeanToNatural(T)(tuple_of_θ, conditioner) - - @test all(unpack_parameters(ef) .≈ tuple_of_η) - @test @allocated(unpack_parameters(ef)) === 0 - - @test_opt ExponentialFamily.separate_conditioner(T, params(MeanParametersSpace(), distribution)) - @test_opt unpack_parameters(ef) -end - -function run_test_isproper(distribution; assume_no_allocations = true) - T = ExponentialFamily.exponential_family_typetag(distribution) - - exponential_family_form = @inferred(convert(ExponentialFamilyDistribution, distribution)) - - @test @inferred(isproper(exponential_family_form)) - @test_opt isproper(exponential_family_form) - - if assume_no_allocations - @test @allocated(isproper(exponential_family_form)) === 0 - end -end - -function run_test_basic_functions(distribution; nsamples = 10, test_gradients = true, test_samples_logpdf = true, assume_no_allocations = true) - T = ExponentialFamily.exponential_family_typetag(distribution) - - ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) - - (η, conditioner) = (getnaturalparameters(ef), getconditioner(ef)) - - # ! do not use `rand(distribution, nsamples)` - # ! do not use fixed RNG - samples = [rand(distribution) for _ in 1:nsamples] - - # Not all methods are defined for all objects in Distributions.jl - # For this methods we first test if the method is defined for the distribution - # And only then we test the method for the exponential family form - potentially_missing_methods = ( - cov, - skewness, - kurtosis - ) - - argument_type = Tuple{typeof(distribution)} - - @test_opt logpdf(ef, first(samples)) - @test_opt pdf(ef, first(samples)) - @test_opt mean(ef) - @test_opt var(ef) - @test_opt std(ef) - # Sampling is not type-stable for all distributions - # due to fallback to `Distributions.jl` - # @test_opt rand(ef) - # @test_opt rand(ef, 10) - # @test_opt rand!(ef, rand(ef, 10)) - - @test_opt isbasemeasureconstant(ef) - @test_opt basemeasure(ef, first(samples)) - @test_opt logbasemeasure(ef, first(samples)) - @test_opt sufficientstatistics(ef, first(samples)) - @test_opt logpartition(ef) - @test_opt gradlogpartition(ef) - @test_opt fisherinformation(ef) - - for x in samples - # We believe in the implementation in the `Distributions.jl` - @test @inferred(logpdf(ef, x)) ≈ logpdf(distribution, x) - @test @inferred(pdf(ef, x)) ≈ pdf(distribution, x) - @test @inferred(mean(ef)) ≈ mean(distribution) - @test @inferred(var(ef)) ≈ var(distribution) - @test @inferred(std(ef)) ≈ std(distribution) - @test last(size(rand(ef, 10))) === 10 # Test that `rand` without explicit `rng` works - @test rand(StableRNG(42), ef) ≈ rand(StableRNG(42), distribution) - @test all(rand(StableRNG(42), ef, 10) .≈ rand(StableRNG(42), distribution, 10)) - @test all(rand!(StableRNG(42), ef, [deepcopy(x) for _ in 1:10]) .≈ rand!(StableRNG(42), distribution, [deepcopy(x) for _ in 1:10])) - - for method in potentially_missing_methods - if hasmethod(method, argument_type) - @test @inferred(method(ef)) ≈ method(distribution) - end - end - - @test @inferred(isbasemeasureconstant(ef)) === isbasemeasureconstant(T) - @test @inferred(basemeasure(ef, x)) == getbasemeasure(T, conditioner)(x) - @test @inferred(logbasemeasure(ef, x)) == getlogbasemeasure(T, conditioner)(x) - @test logbasemeasure(ef, x) ≈ log(basemeasure(ef, x)) atol = 1e-8 - @test all(@inferred(sufficientstatistics(ef, x)) .== map(f -> f(x), getsufficientstatistics(T, conditioner))) - @test @inferred(logpartition(ef)) == getlogpartition(T, conditioner)(η) - @test @inferred(fisherinformation(ef)) == getfisherinformation(T, conditioner)(η) - - # Double check the `conditioner` free methods - if isnothing(conditioner) - @test @inferred(basemeasure(ef, x)) == getbasemeasure(T)(x) - @test @inferred(logbasemeasure(ef, x)) == getlogbasemeasure(T)(x) - @test all(@inferred(sufficientstatistics(ef, x)) .== map(f -> f(x), getsufficientstatistics(T))) - @test @inferred(logpartition(ef)) == getlogpartition(T)(η) - @test @inferred(fisherinformation(ef)) == getfisherinformation(T)(η) - end - - if test_gradients && value_support(T) === Continuous && x isa Number - let tlogpdf = ForwardDiff.derivative((x) -> logpdf(distribution, x), x) - if !isnan(tlogpdf) && !isinf(tlogpdf) - @test ForwardDiff.derivative((x) -> logpdf(ef, x), x) ≈ tlogpdf - @test ForwardDiff.gradient((x) -> logpdf(ef, x[1]), [x])[1] ≈ tlogpdf - end - end - let tpdf = ForwardDiff.derivative((x) -> pdf(distribution, x), x) - if !isnan(tpdf) && !isinf(tpdf) - @test ForwardDiff.derivative((x) -> pdf(ef, x), x) ≈ tpdf - @test ForwardDiff.gradient((x) -> pdf(ef, x[1]), [x])[1] ≈ tpdf - end - end - end - - if test_gradients && value_support(T) === Continuous && x isa AbstractVector - let tlogpdf = ForwardDiff.gradient((x) -> logpdf(distribution, x), x) - if !any(isnan, tlogpdf) && !any(isinf, tlogpdf) - @test ForwardDiff.gradient((x) -> logpdf(ef, x), x) ≈ tlogpdf - end - end - let tpdf = ForwardDiff.gradient((x) -> pdf(distribution, x), x) - if !any(isnan, tpdf) && !any(isinf, tpdf) - @test ForwardDiff.gradient((x) -> pdf(ef, x), x) ≈ tpdf - end - end - end - - # Test that the selected methods do not allocate - if assume_no_allocations - @test @allocated(logpdf(ef, x)) === 0 - @test @allocated(pdf(ef, x)) === 0 - @test @allocated(mean(ef)) === 0 - @test @allocated(var(ef)) === 0 - @test @allocated(basemeasure(ef, x)) === 0 - @test @allocated(logbasemeasure(ef, x)) === 0 - @test @allocated(sufficientstatistics(ef, x)) === 0 - end - end - - if test_samples_logpdf - @test @inferred(logpdf(ef, samples)) ≈ map((s) -> logpdf(distribution, s), samples) - @test @inferred(pdf(ef, samples)) ≈ map((s) -> pdf(distribution, s), samples) - end -end - -function run_test_fisherinformation_properties(distribution; test_properties_in_natural_space = true, test_properties_in_mean_space = true) - T = ExponentialFamily.exponential_family_typetag(distribution) - - ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) - - (η, conditioner) = (getnaturalparameters(ef), getconditioner(ef)) - - if test_properties_in_natural_space - F = getfisherinformation(NaturalParametersSpace(), T, conditioner)(η) - - @test_opt getfisherinformation(NaturalParametersSpace(), T, conditioner)(η) - - @test issymmetric(F) || (LowerTriangular(F) ≈ (UpperTriangular(F)')) - @test isposdef(F) || all(>(0), eigvals(F)) - @test size(F, 1) === size(F, 2) - @test size(F, 1) === isqrt(length(F)) - @test (inv(fastcholesky(F)) * F ≈ Diagonal(ones(size(F, 1)))) rtol = 1e-2 - end - - if test_properties_in_mean_space - θ = map(NaturalParametersSpace() => MeanParametersSpace(), T, η, conditioner) - F = getfisherinformation(MeanParametersSpace(), T, conditioner)(θ) - - @test_opt getfisherinformation(MeanParametersSpace(), T, conditioner)(θ) - - @test issymmetric(F) || (LowerTriangular(F) ≈ (UpperTriangular(F)')) - @test isposdef(F) || all(>(0), eigvals(F)) - @test size(F, 1) === size(F, 2) - @test size(F, 1) === isqrt(length(F)) - @test (inv(fastcholesky(F)) * F ≈ Diagonal(ones(size(F, 1)))) rtol = 1e-2 - end -end - -function run_test_gradlogpartition_properties(distribution; nsamples = 6000, test_against_forwardiff = true) - ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) - - (η, conditioner) = (getnaturalparameters(ef), getconditioner(ef)) - - rng = StableRNG(42) - # Some distributions do not use a vector to store a collection of samples (e.g. matrix for MvGaussian) - collection_of_samples = rand(rng, distribution, nsamples) - # The `check_logpdf` here converts the collection to a vector like iterable - _, samples = ExponentialFamily.check_logpdf(ef, collection_of_samples) - expectation_of_sufficient_statistics = mean((s) -> ExponentialFamily.pack_parameters(ExponentialFamily.sufficientstatistics(ef, s)), samples) - gradient = gradlogpartition(ef) - inverse_fisher = cholinv(fisherinformation(ef)) - @test length(gradient) === length(η) - @test dot(gradient - expectation_of_sufficient_statistics, inverse_fisher, gradient - expectation_of_sufficient_statistics) ≈ 0 atol = 0.01 - - if test_against_forwardiff - @test gradient ≈ ForwardDiff.gradient((η) -> getlogpartition(ef)(η), getnaturalparameters(ef)) - end -end - -function run_test_fisherinformation_against_hessian(distribution; assume_ours_faster = true, assume_no_allocations = true) - T = ExponentialFamily.exponential_family_typetag(distribution) - - ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) - - (η, conditioner) = (getnaturalparameters(ef), getconditioner(ef)) - - @test fisherinformation(ef) ≈ ForwardDiff.hessian(η -> getlogpartition(NaturalParametersSpace(), T, conditioner)(η), η) - - # Double check the `conditioner` free methods - if isnothing(conditioner) - @test fisherinformation(ef) ≈ ForwardDiff.hessian(η -> getlogpartition(NaturalParametersSpace(), T)(η), η) - end - - if assume_ours_faster - @test @elapsed(fisherinformation(ef)) < (@elapsed(ForwardDiff.hessian(η -> getlogpartition(NaturalParametersSpace(), T, conditioner)(η), η))) - end - - if assume_no_allocations - @test @allocated(fisherinformation(ef)) === 0 - end -end - -function run_test_fisherinformation_against_jacobian( - distribution; - assume_no_allocations = true, - mappings = ( - NaturalParametersSpace() => MeanParametersSpace(), - MeanParametersSpace() => NaturalParametersSpace() - ) -) - T = ExponentialFamily.exponential_family_typetag(distribution) - - ef = @inferred(convert(ExponentialFamilyDistribution, distribution)) - - (η, conditioner) = (getnaturalparameters(ef), getconditioner(ef)) - θ = map(NaturalParametersSpace() => MeanParametersSpace(), T, η, conditioner) - - # Check natural to mean Jacobian based FI computation - # So here we check that the fisher information matrices are identical with respect to `J`, which is the jacobian of the - # transformation. For example if we have a mapping T : M -> N, the fisher information matrices computed in M and N - # respectively must follow this relation `Fₘ = J' * Fₙ * J` - for (M, N, parameters) in ((NaturalParametersSpace(), MeanParametersSpace(), η), (MeanParametersSpace(), NaturalParametersSpace(), θ)) - if (M => N) ∈ mappings - mapping = getmapping(M => N, T) - m = parameters - n = mapping(m, conditioner) - J = ForwardDiff.jacobian(Base.Fix2(mapping, conditioner), m) - Fₘ = getfisherinformation(M, T, conditioner)(m) - Fₙ = getfisherinformation(N, T, conditioner)(n) - - @test Fₘ ≈ (J' * Fₙ * J) - - # Check the default space - if M === NaturalParametersSpace() - # The `fisherinformation` uses the `NaturalParametersSpace` by default - @test fisherinformation(ef) ≈ (J' * Fₙ * J) - end - - # Double check the `conditioner` free methods - if isnothing(conditioner) - n = mapping(m) - J = ForwardDiff.jacobian(mapping, m) - Fₘ = getfisherinformation(M, T)(m) - Fₙ = getfisherinformation(N, T)(n) - - @test Fₘ ≈ (J' * Fₙ * J) - - if M === NaturalParametersSpace() - @test fisherinformation(ef) ≈ (J' * Fₙ * J) - end - end - - if assume_no_allocations - @test @allocated(getfisherinformation(M, T, conditioner)(m)) === 0 - @test @allocated(getfisherinformation(N, T, conditioner)(n)) === 0 - end - end - end -end - -# This generic testing works only for the same distributions `D` -function test_generic_simple_exponentialfamily_product( - left::Distribution, - right::Distribution; - strategies = (GenericProd(),), - test_inplace_version = true, - test_inplace_assume_no_allocations = true, - test_preserve_type_prod_of_distribution = true, - test_against_distributions_prod_if_possible = true -) - Tl = ExponentialFamily.exponential_family_typetag(left) - Tr = ExponentialFamily.exponential_family_typetag(right) - - @test Tl === Tr - - T = Tl - - efleft = @inferred(convert(ExponentialFamilyDistribution, left)) - efright = @inferred(convert(ExponentialFamilyDistribution, right)) - ηleft = @inferred(getnaturalparameters(efleft)) - ηright = @inferred(getnaturalparameters(efright)) - - if (!isnothing(getconditioner(efleft)) || !isnothing(getconditioner(efright))) - @test isapprox(getconditioner(efleft), getconditioner(efright)) - end - - prod_dist = prod(GenericProd(), left, right) - - @test_opt prod(GenericProd(), left, right) - - # We check against the `prod_dist` only if we have the proper solution, and skip if the result is of type `ProductOf` - if test_against_distributions_prod_if_possible && (prod_dist isa ProductOf || !(typeof(prod_dist) <: T)) - prod_dist = nothing - end - - for strategy in strategies - @test @inferred(prod(strategy, efleft, efright)) == ExponentialFamilyDistribution(T, ηleft + ηright, getconditioner(efleft)) - - # Double check the `conditioner` free methods - if isnothing(getconditioner(efleft)) && isnothing(getconditioner(efright)) - @test @inferred(prod(strategy, efleft, efright)) == ExponentialFamilyDistribution(T, ηleft + ηright) - end - - # Check that the result is consistent with the `prod_dist` - if !isnothing(prod_dist) - @test convert(T, prod(strategy, efleft, efright)) ≈ prod_dist - end - end - - if test_inplace_version - @test @inferred(prod!(similar(efleft), efleft, efright)) == - ExponentialFamilyDistribution(T, ηleft + ηright, getconditioner(efleft)) - - if test_inplace_assume_no_allocations - let _similar = similar(efleft) - @test @allocated(prod!(_similar, efleft, efright)) === 0 - end - end - end - - if test_preserve_type_prod_of_distribution - @test @inferred(prod(PreserveTypeProd(T), efleft, efright)) ≈ - prod(PreserveTypeProd(T), left, right) - end - - return true -end - -================ -File: distributions/erlang_tests.jl -================ -# Erlang comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Erlang: vague" begin - include("distributions_setuptests.jl") - - @test Erlang() == Erlang(1, 1.0) - @test vague(Erlang) == Erlang(1, 1e12) -end - -@testitem "Erlang: mean(::typeof(log))" begin - include("distributions_setuptests.jl") - - @test mean(log, Erlang(1, 3.0)) ≈ digamma(1) + log(3.0) - @test mean(log, Erlang(2, 0.3)) ≈ digamma(2) + log(0.3) - @test mean(log, Erlang(3, 0.3)) ≈ digamma(3) + log(0.3) -end - -@testitem "Erlang: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for a in 1:3, b in 1.0:1.0:3.0 - @testset let d = Erlang(a, b) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - - (η1, η2) = (a - 1, -inv(b)) - for x in 10rand(4) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === oneunit(x) - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (log(x), x)) - @test @inferred(logpartition(ef)) ≈ (logfactorial(η1) - (η1 + one(η1)) * log(-η2)) - end - - @test !@inferred(insupport(ef, -0.5)) - @test @inferred(insupport(ef, 0.5)) - - # Not in the support - @test_throws Exception logpdf(ef, -0.5) - end - end - - # Test failing isproper cases - @test !isproper(MeanParametersSpace(), Erlang, [-1]) - @test !isproper(MeanParametersSpace(), Erlang, [1, -0.1]) - @test !isproper(MeanParametersSpace(), Erlang, [-0.1, 1]) - @test !isproper(NaturalParametersSpace(), Erlang, [-1.1]) - @test isproper(NaturalParametersSpace(), Erlang, [1, -1.1]) - @test !isproper(NaturalParametersSpace(), Erlang, [-1.1, 1]) -end - -@testitem "Erlang: prod with Distributions" begin - include("distributions_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) - @test prod(strategy, Erlang(1, 1), Erlang(1, 1)) == Erlang(1, 1 / 2) - @test prod(strategy, Erlang(1, 2), Erlang(1, 1)) == Erlang(1, 2 / 3) - @test prod(strategy, Erlang(1, 2), Erlang(1, 2)) == Erlang(1, 1) - @test prod(strategy, Erlang(2, 2), Erlang(1, 2)) == Erlang(2, 1) - @test prod(strategy, Erlang(2, 2), Erlang(2, 2)) == Erlang(3, 1) - end - - @test @allocated(prod(ClosedProd(), Erlang(1, 1), Erlang(1, 1))) == 0 -end - -@testitem "Erlang: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for aleft in 1:3, aright in 2:5, bleft in 0.51:1.0:5.0, bright in 0.51:1.0:5.0 - @testset let (left, right) = (Erlang(aleft, bleft), Erlang(aright, bright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{Erlang}) - ) - ) - end - end -end - -================ -File: distributions/exponential_tests.jl -================ -# Exponential comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Exponential: vague" begin - include("distributions_setuptests.jl") - - d = vague(Exponential) - - @test typeof(d) <: Exponential - @test mean(d) === 1e12 - @test params(d) === (1e12,) -end - -@testitem "Exponential: mean(::typeof(log))" begin - include("distributions_setuptests.jl") - - @test mean(log, Exponential(1)) ≈ -MathConstants.eulergamma - @test mean(log, Exponential(10)) ≈ 1.7253694280925127 - @test mean(log, Exponential(0.1)) ≈ -2.8798007578955787 -end - -@testitem "Exponential: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for scale in (0.1, 1.0, 10.0, 10.0rand()) - @testset let d = Exponential(scale) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - - (η₁,) = -inv(scale) - - for x in [100rand() for _ in 1:4] - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === oneunit(x) - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (x,)) - @test @inferred(logpartition(ef)) ≈ (-log(-η₁)) - end - - @test !@inferred(insupport(ef, -0.5)) - @test @inferred(insupport(ef, 0.5)) - - # Not in the support - @test_throws Exception logpdf(ef, -0.5) - end - end - - # Test failing isproper cases - @test !isproper(MeanParametersSpace(), Exponential, [-1]) - @test !isproper(MeanParametersSpace(), Exponential, [1, -0.1]) - @test !isproper(MeanParametersSpace(), Exponential, [-0.1, 1]) - @test !isproper(NaturalParametersSpace(), Exponential, [1.1]) - @test !isproper(NaturalParametersSpace(), Exponential, [1, -1.1]) - @test !isproper(NaturalParametersSpace(), Exponential, [-1.1, 1]) -end - -@testitem "Exponential: prod with Distributiond" begin - include("distributions_setuptests.jl") - - for strategy in (GenericProd(), ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd()) - @test prod(strategy, Exponential(5), Exponential(4)) ≈ Exponential(1 / 0.45) - @test prod(strategy, Exponential(1), Exponential(1)) ≈ Exponential(1 / 2) - @test prod(strategy, Exponential(0.1), Exponential(0.1)) ≈ Exponential(0.05) - end -end - -@testitem "Exponential: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for sleft in (0.1, 1.0, 10.0, 10.0rand()), sright in (0.1, 1.0, 10.0, 10.0rand()) - @testset let (left, right) = (Exponential(sleft), Exponential(sright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{Exponential}) - ) - ) - end - end -end - -================ -File: distributions/gamma_inverse_tests.jl -================ -# GammaInverse comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "GammaInverse: vague" begin - include("distributions_setuptests.jl") - - d = vague(GammaInverse) - @test typeof(d) <: GammaInverse - @test mean(d) == huge - @test params(d) == (2.0, huge) -end - -# (α, θ) = (α_L + α_R + 1, θ_L + θ_R) -@testitem "GammaInverse: prod" begin - include("distributions_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(GammaInverse), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) - @test @inferred(prod(ClosedProd(), GammaInverse(3.0, 2.0), GammaInverse(2.0, 1.0))) ≈ GammaInverse(6.0, 3.0) - @test @inferred(prod(ClosedProd(), GammaInverse(7.0, 1.0), GammaInverse(0.1, 4.5))) ≈ GammaInverse(8.1, 5.5) - @test @inferred(prod(ClosedProd(), GammaInverse(1.0, 3.0), GammaInverse(0.2, 0.4))) ≈ GammaInverse(2.2, 3.4) - end -end - -# log(θ) - digamma(α) -@testitem "GammaInverse: mean(::typeof(log))" begin - include("distributions_setuptests.jl") - - @test mean(log, GammaInverse(1.0, 3.0)) ≈ 1.6758279535696414 - @test mean(log, GammaInverse(0.1, 0.3)) ≈ 9.21978213608514 - @test mean(log, GammaInverse(4.5, 0.3)) ≈ -2.5928437306854653 - @test mean(log, GammaInverse(42.0, 42.0)) ≈ 0.011952000346086233 -end - -# α / θ -@testitem "GammaInverse: mean(::typeof(inv))" begin - include("distributions_setuptests.jl") - - @test mean(inv, GammaInverse(1.0, 3.0)) ≈ 0.33333333333333333 - @test mean(inv, GammaInverse(0.1, 0.3)) ≈ 0.33333333333333337 - @test mean(inv, GammaInverse(4.5, 0.3)) ≈ 15.0000000000000000 - @test mean(inv, GammaInverse(42.0, 42.0)) ≈ 1.0000000000000000 -end - -@testitem "GammaInverse: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for α in (10rand(4) .+ 4.0), θ in 10rand(4) - @testset let d = InverseGamma(α, θ) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - - (α, β) = params(MeanParametersSpace(), d) - - (η₁, η₂) = (-α - 1, -β) - - for x in 10rand(4) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === oneunit(x) - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (log(x), inv(x))) - @test @inferred(logpartition(ef)) ≈ (loggamma(-η₁ - 1) - (-η₁ - 1) * log(-η₂)) - end - - @test !@inferred(insupport(ef, -0.5)) - @test @inferred(insupport(ef, 0.5)) - - # Not in the support - @test_throws Exception logpdf(ef, -0.5) - end - end - - # Test failing isproper cases - @test !isproper(MeanParametersSpace(), InverseGamma, [-1]) - @test !isproper(MeanParametersSpace(), InverseGamma, [1, -0.1]) - @test !isproper(MeanParametersSpace(), InverseGamma, [-0.1, 1]) - @test !isproper(NaturalParametersSpace(), InverseGamma, [-0.5]) - @test !isproper(NaturalParametersSpace(), InverseGamma, [1, -1.1]) - @test !isproper(NaturalParametersSpace(), InverseGamma, [-0.5, 1]) -end - -@testitem "GammaInverse: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for aleft in 10rand(4), aright in 10rand(4), bleft in 10rand(4), bright in 10rand(4) - @testset let (left, right) = (InverseGamma(aleft, bleft), InverseGamma(aright, bright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{InverseGamma}) - ) - ) - end - end -end - -================ -File: distributions/geometric_tests.jl -================ -# Geometric comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Geometric: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - @testset for p in 0.1:0.2:1.0 - @testset let d = Geometric(p) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - η1 = first(getnaturalparameters(ef)) - - for x in (1, 3, 5) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === one(x) - @test @inferred(sufficientstatistics(ef, x)) === (x,) - @test @inferred(logpartition(ef)) ≈ -log(one(η1) - exp(η1)) - end - end - end - - for space in (MeanParametersSpace(), NaturalParametersSpace()) - @test !isproper(space, Geometric, [2.0]) - @test !isproper(space, Geometric, [Inf]) - @test !isproper(space, Geometric, [NaN]) - @test !isproper(space, Geometric, [1.0], NaN) - @test !isproper(space, Geometric, [0.5, 0.5], 1.0) - end - ## mean parameter should be integer in the MeanParametersSpace - @test !isproper(MeanParametersSpace(), Geometric, [-0.1]) - @test_throws Exception convert(ExponentialFamilyDistribution, Geometric(Inf)) -end - -@testitem "Geometric: prod with Distributions" begin - include("distributions_setuptests.jl") - - for strategy in (GenericProd(), ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd()) - @test prod(strategy, Geometric(0.5), Geometric(0.6)) == Geometric(0.8) - @test prod(strategy, Geometric(0.3), Geometric(0.8)) == Geometric(0.8600000000000001) - @test prod(strategy, Geometric(0.5), Geometric(0.5)) == Geometric(0.75) - end - - @test @allocated(prod(ClosedProd(), Geometric(0.5), Geometric(0.6))) == 0 -end - -@testitem "Geometric: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for pleft in 0.1:0.2:0.5, pright in 0.5:0.2:1.0 - @testset let (left, right) = (Geometric(pleft), Geometric(pright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{Geometric}) - ) - ) - end - end -end - -================ -File: distributions/laplace_tests.jl -================ -# Laplace comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Laplace: vague" begin - include("distributions_setuptests.jl") - - d = vague(Laplace) - - @test typeof(d) <: Laplace - @test mean(d) === 0.0 - @test params(d) === (0.0, 1e12) -end - -@testitem "Laplace: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for location in (-1.0, 0.0, 1.0), scale in (0.25, 0.5, 2.0) - @testset let d = Laplace(location, scale) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - η₁ = -1 / scale - - for x in (-1.0, 0.0, 1.0) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === oneunit(x) - @test @inferred(sufficientstatistics(ef, x)) === (abs(x - location),) - @test @inferred(logpartition(ef)) ≈ log(-2 / η₁) - end - end - end - - for space in (MeanParametersSpace(), NaturalParametersSpace()) - @test !isproper(space, Laplace, [Inf], 1.0) - @test !isproper(space, Laplace, [1.0], Inf) - @test !isproper(space, Laplace, [NaN], 1.0) - @test !isproper(space, Laplace, [1.0], NaN) - @test !isproper(space, Laplace, [0.5, 0.5], 1.0) - - # Conditioner is required - @test_throws Exception isproper(space, Laplace, [0.5], [0.5, 0.5]) - @test_throws Exception isproper(space, Laplace, [1.0], nothing) - @test_throws Exception isproper(space, Laplace, [1.0], nothing) - end - - @test_throws Exception convert(ExponentialFamilyDistribution, Laplace(Inf, Inf)) -end - -@testitem "Laplace: prod with Distribution" begin - include("distributions_setuptests.jl") - - @test @inferred(prod(PreserveTypeProd(Laplace), Laplace(0.0, 0.5), Laplace(0.0, 0.5))) ≈ Laplace(0.0, 0.25) - @test @inferred(prod(PreserveTypeProd(Laplace), Laplace(1.0, 1.0), Laplace(1.0, 1.0))) ≈ Laplace(1.0, 0.5) - @test @inferred(prod(PreserveTypeProd(Laplace), Laplace(2.0, 3.0), Laplace(2.0, 7.0))) ≈ Laplace(2.0, 2.1) - - # GenericProd should always check the default strategy and fallback if available - @test @inferred(prod(GenericProd(), Laplace(0.0, 0.5), Laplace(0.0, 0.5))) ≈ Laplace(0.0, 0.25) - @test @inferred(prod(GenericProd(), Laplace(1.0, 1.0), Laplace(1.0, 1.0))) ≈ Laplace(1.0, 0.5) - @test @inferred(prod(GenericProd(), Laplace(2.0, 3.0), Laplace(2.0, 7.0))) ≈ Laplace(2.0, 2.1) - - # Different location parameters cannot be compute a closed prod with the same type - @test_throws Exception prod(PreserveTypeProd(Laplace), Laplace(0.0, 0.5), Laplace(0.01, 0.5)) - @test_throws Exception prod(PreserveTypeProd(Laplace), Laplace(1.0, 0.5), Laplace(-1.0, 0.5)) -end - -@testitem "Laplace: prod with ExponentialFamilyDistribution: same location parameter" begin - include("distributions_setuptests.jl") - - for location in (0.0, 1.0), sleft in 0.1:0.1:0.9, sright in 0.1:0.1:0.9 - @testset let (left, right) = (Laplace(location, sleft), Laplace(location, sright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = (PreserveTypeProd(ExponentialFamilyDistribution{Laplace}), GenericProd()) - ) - end - end - - # Different location parameters cannot be compute a closed prod with the same type - @test_throws Exception prod( - PreserveTypeProd(ExponentialFamilyDistribution{Laplace}), - convert(ExponentialFamilyDistribution, Laplace(0.0, 0.5)), - convert(ExponentialFamilyDistribution, Laplace(0.01, 0.5)) - ) - @test_throws Exception prod( - PreserveTypeProd(ExponentialFamilyDistribution{Laplace}), - convert(ExponentialFamilyDistribution, Laplace(1.0, 0.5)), - convert(ExponentialFamilyDistribution, Laplace(-1.0, 0.5)) - ) -end - -@testitem "Laplace: prod with ExponentialFamilyDistribution: different location parameter" begin - include("distributions_setuptests.jl") - - for locationleft in (0.0, 1.0), sleft in 0.1:0.1:0.4, locationright in (2.0, 3.0), sright in 1.1:0.1:1.3 - @testset let (left, right) = (Laplace(locationleft, sleft), Laplace(locationright, sright)) - ef_left = convert(ExponentialFamilyDistribution, left) - ef_right = convert(ExponentialFamilyDistribution, right) - ef_prod = prod(PreserveTypeProd(ExponentialFamilyDistribution), ef_left, ef_right) - @test first(hquadrature(x -> pdf(ef_prod, tan(x * pi / 2)) * (pi / 2) * (1 / cos(x * pi / 2)^2), -1.0, 1.0)) ≈ 1.0 atol = 1e-6 - end - end -end - -================ -File: distributions/lognormal_tests.jl -================ -# LogNormal comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "LogNormal: vague" begin - include("distributions_setuptests.jl") - - @test LogNormal() == LogNormal(0.0, 1.0) - @test typeof(vague(LogNormal)) <: LogNormal - @test vague(LogNormal) == LogNormal(1, 1e12) -end - -@testitem "LogNormal: prod with Distribution" begin - include("distributions_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(LogNormal), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) - @test prod(strategy, LogNormal(1.0, 1.0), LogNormal(1.0, 1.0)) == LogNormal(0.5, sqrt(1 / 2)) - @test prod(strategy, LogNormal(2.0, 1.0), LogNormal(2.0, 1.0)) == LogNormal(1.5, sqrt(1 / 2)) - @test prod(strategy, LogNormal(1.0, 1.0), LogNormal(2.0, 1.0)) == LogNormal(1.0, sqrt(1 / 2)) - @test prod(strategy, LogNormal(1.0, 2.0), LogNormal(1.0, 2.0)) == LogNormal(-1.0, sqrt(2)) - @test prod(strategy, LogNormal(2.0, 2.0), LogNormal(2.0, 2.0)) == LogNormal(0.0, sqrt(2)) - end -end - -@testitem "LogNormal: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for μleft in 10randn(4), μright in 10randn(4), σleft in 10rand(4), σright in 10rand(4) - @testset let (left, right) = (LogNormal(μleft, σleft), LogNormal(μright, σright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{LogNormal}) - ) - ) - end - end -end - -@testitem "LogNormal: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for μ in 10randn(4), σ in 10rand(4) - @testset let d = LogNormal(μ, σ) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - - (η₁, η₂) = (μ / abs2(σ) - 1, -1 / (2abs2(σ))) - - for x in 10rand(4) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) ≈ invsqrt2π - @test @inferred(sufficientstatistics(ef, x)) === (log(x), abs2(log(x))) - @test @inferred(logpartition(ef)) ≈ (-(η₁ + 1)^2 / (4η₂) - 1 / 2 * log(-2η₂)) - end - end - end - - @test !isproper(MeanParametersSpace(), LogNormal, [1.0]) - @test !isproper(MeanParametersSpace(), LogNormal, [-1.0, 0.0]) - @test !isproper(MeanParametersSpace(), LogNormal, [1.0, -1.0]) - @test !isproper(NaturalParametersSpace(), LogNormal, [1.0]) - @test !isproper(NaturalParametersSpace(), LogNormal, [-1.0, 0.0]) - @test !isproper(NaturalParametersSpace(), LogNormal, [1.0, 1.0]) -end - -================ -File: distributions/matrix_dirichlet_tests.jl -================ -@testitem "MatrixDirichlet: common" begin - include("distributions_setuptests.jl") - - @test MatrixDirichlet <: Distribution - @test MatrixDirichlet <: ContinuousDistribution - @test MatrixDirichlet <: MatrixDistribution - - @test value_support(MatrixDirichlet) === Continuous - @test variate_form(MatrixDirichlet) === Matrixvariate -end - -@testitem "MatrixDirichlet: vague" begin - include("distributions_setuptests.jl") - - @test_throws MethodError vague(MatrixDirichlet) - - d1 = vague(MatrixDirichlet, 3) - - @test typeof(d1) <: MatrixDirichlet - @test mean(d1) == ones(3, 3) ./ sum(ones(3, 3); dims = 1) - - d2 = vague(MatrixDirichlet, 4) - - @test typeof(d2) <: MatrixDirichlet - @test mean(d2) == ones(4, 4) ./ sum(ones(4, 4); dims = 1) - - @test vague(MatrixDirichlet, 3, 3) == vague(MatrixDirichlet, (3, 3)) - @test vague(MatrixDirichlet, 4, 4) == vague(MatrixDirichlet, (4, 4)) - @test vague(MatrixDirichlet, 3, 4) == vague(MatrixDirichlet, (3, 4)) - @test vague(MatrixDirichlet, 4, 3) == vague(MatrixDirichlet, (4, 3)) - - d3 = vague(MatrixDirichlet, 3, 4) - - @test typeof(d3) <: MatrixDirichlet - @test mean(d3) == ones(3, 4) ./ sum(ones(3, 4); dims = 1) -end - -@testitem "MatrixDirichlet: entropy" begin - include("distributions_setuptests.jl") - - @test entropy(MatrixDirichlet([1.0 1.0; 1.0 1.0; 1.0 1.0])) ≈ -1.3862943611198906 - @test entropy(MatrixDirichlet([1.2 3.3; 4.0 5.0; 2.0 1.1])) ≈ -3.1139933152617787 - @test entropy(MatrixDirichlet([0.2 3.4; 5.0 11.0; 0.2 0.6])) ≈ -11.444984495104693 -end - -@testitem "MatrixDirichlet: mean(::typeof(log))" begin - include("distributions_setuptests.jl") - - import Base.Broadcast: BroadcastFunction - - @test mean(BroadcastFunction(log), MatrixDirichlet([1.0 1.0; 1.0 1.0; 1.0 1.0])) ≈ [ - -1.5000000000000002 -1.5000000000000002 - -1.5000000000000002 -1.5000000000000002 - -1.5000000000000002 -1.5000000000000002 - ] - @test mean(BroadcastFunction(log), MatrixDirichlet([1.2 3.3; 4.0 5.0; 2.0 1.1])) ≈ [ - -2.1920720408623637 -1.1517536610071326 - -0.646914475838374 -0.680458481634953 - -1.480247809171707 -2.6103310904778305 - ] - @test mean(BroadcastFunction(log), MatrixDirichlet([0.2 3.4; 5.0 11.0; 0.2 0.6])) ≈ [ - -6.879998107291004 -1.604778825293528 - -0.08484054226701443 -0.32259407259407213 - -6.879998107291004 -4.214965875553984 - ] -end - -@testitem "MatrixDirichlet: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for len in 3:5 - α = rand(1.0:2.0, len, len) - @testset let d = MatrixDirichlet(α) - ef = test_exponentialfamily_interface(d; test_basic_functions = true, option_assume_no_allocations = false) - η1 = getnaturalparameters(ef) - - for x in [rand(1.0:2.0, len, len) for _ in 1:3] - x = x ./ sum(x) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === 1.0 - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (map(log, x),)) - @test @inferred(logpartition(ef)) ≈ mapreduce( - d -> getlogpartition(NaturalParametersSpace(), Dirichlet)(convert(Vector, d)), - +, - eachcol(first(unpack_parameters(MatrixDirichlet, η1))) - ) - end - end - end - - for space in (MeanParametersSpace(), NaturalParametersSpace()) - @test !isproper(space, MatrixDirichlet, [Inf Inf; Inf 1.0], 1.0) - @test !isproper(space, MatrixDirichlet, [1.0], Inf) - @test !isproper(space, MatrixDirichlet, [NaN], 1.0) - @test !isproper(space, MatrixDirichlet, [1.0], NaN) - @test !isproper(space, MatrixDirichlet, [0.5, 0.5], 1.0) - @test isproper(space, MatrixDirichlet, [2.0, 3.0]) - @test !isproper(space, MatrixDirichlet, [-1.0, -1.2]) - end - - @test_throws Exception convert(ExponentialFamilyDistribution, MatrixDirichlet([Inf Inf; 2 3])) -end - -@testitem "MatrixDirichlet: prod with Distribution" begin - include("distributions_setuptests.jl") - - d1 = MatrixDirichlet([0.2 3.4; 5.0 11.0; 0.2 0.6]) - d2 = MatrixDirichlet([1.2 3.3; 4.0 5.0; 2.0 1.1]) - d3 = MatrixDirichlet([1.0 1.0; 1.0 1.0; 1.0 1.0]) - for strategy in (GenericProd(), ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd()) - @test @inferred(prod(strategy, d1, d2)) ≈ - MatrixDirichlet([0.3999999999999999 5.699999999999999; 8.0 15.0; 1.2000000000000002 0.7000000000000002]) - @test @inferred(prod(strategy, d1, d3)) ≈ MatrixDirichlet( - [0.19999999999999996 3.4000000000000004; 5.0 11.0; 0.19999999999999996 0.6000000000000001] - ) - @test @inferred(prod(strategy, d2, d3)) ≈ MatrixDirichlet([1.2000000000000002 3.3; 4.0 5.0; 2.0 1.1]) - end -end - -@testitem "MatrixDirichlet: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for len in 3:6 - αleft = rand(len, len) .+ 1 - αright = rand(len, len) .+ 1 - @testset let (left, right) = (MatrixDirichlet(αleft), MatrixDirichlet(αright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd() - ) - ) - end - end -end - -@testitem "MatrixDirichlet: promote_variate_type" begin - include("distributions_setuptests.jl") - - @test_throws MethodError promote_variate_type(Univariate, MatrixDirichlet) - - @test promote_variate_type(Multivariate, Dirichlet) === Dirichlet - @test promote_variate_type(Matrixvariate, Dirichlet) === MatrixDirichlet - - @test promote_variate_type(Multivariate, MatrixDirichlet) === Dirichlet - @test promote_variate_type(Matrixvariate, MatrixDirichlet) === MatrixDirichlet -end - -================ -File: distributions/mv_normal_wishart_tests.jl -================ -@testitem "MvNormalWishart: common" begin - include("distributions_setuptests.jl") - - m = rand(2) - dist = MvNormalWishart(m, [1.0 0.0; 0.0 1.0], 0.1, 3.0) - @test params(dist) == (m, [1.0 0.0; 0.0 1.0], 0.1, 3.0) - @test dof(dist) == 3.0 - @test invscatter(dist) == [1.0 0.0; 0.0 1.0] - @test scale(dist) == 0.1 - @test locationdim(dist) == 2 -end - -@testitem "MvNormalWishart: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for dim in (3,), invS in rand(Wishart(10, Array(Eye(dim))), 4) - ν = dim + 2 - @testset let (d = MvNormalWishart(rand(dim), invS, rand(), ν)) - ef = test_exponentialfamily_interface( - d; - option_assume_no_allocations = false, - test_basic_functions = false, - test_fisherinformation_against_hessian = false, - test_fisherinformation_against_jacobian = false, - test_gradlogpartition_properties = false, - test_plogpdf_interface = false - ) - - run_test_basic_functions(d; assume_no_allocations = false, test_samples_logpdf = false) - end - end -end - -@testitem "MvNormalWishart: prod" begin - include("distributions_setuptests.jl") - - for j in 2:2, κ in 1:2 - m1 = rand(j) - m2 = rand(j) - Ψ1 = m1 * m1' + I - Ψ2 = m2 * m2' + I - dist1 = MvNormalWishart(m1, Ψ1, κ + rand(), rand() + 4) - dist2 = MvNormalWishart(m2, Ψ2, κ + rand(), rand() + 4) - ef1 = convert(ExponentialFamilyDistribution, dist1) - ef2 = convert(ExponentialFamilyDistribution, dist2) - @test prod(PreserveTypeProd(Distribution), dist1, dist2) ≈ convert(Distribution, prod(ClosedProd(), ef1, ef2)) - end -end - -@testitem "MvNormalWishart: prod with ExponentialFamilyDistribution{MvNormalWishart}" begin - include("distributions_setuptests.jl") - - for Sleft in rand(Wishart(10, Array(Eye(2))), 2), Sright in rand(Wishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5) - @testset let (left, right) = (MvNormalWishart(rand(2), Sleft, rand(), νleft), MvNormalWishart(rand(2), Sright, rand(), νright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = (PreserveTypeProd(ExponentialFamilyDistribution{MvNormalWishart}), GenericProd()) - ) - end - end -end - -================ -File: distributions/negative_binomial_tests.jl -================ -# NegativeBinomial comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "NegativeBinomial: probvec" begin - include("distributions_setuptests.jl") - - @test all(probvec(NegativeBinomial(2, 0.8)) .≈ (0.2, 0.8)) - @test probvec(NegativeBinomial(2, 0.2)) == (0.8, 0.2) - @test probvec(NegativeBinomial(2, 0.1)) == (0.9, 0.1) - @test probvec(NegativeBinomial(2)) == (0.5, 0.5) -end - -@testitem "NegativeBinomial: vague" begin - include("distributions_setuptests.jl") - - @test_throws MethodError vague(NegativeBinomial) - @test_throws MethodError vague(NegativeBinomial, 1 / 2) - - vague_dist = vague(NegativeBinomial, 5) - @test typeof(vague_dist) <: NegativeBinomial - @test probvec(vague_dist) == (0.5, 0.5) -end - -@testitem "NegativeBinomial: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for p in (0.1, 0.4), r in (2, 3, 4) - @testset let d = NegativeBinomial(r, p) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false) - for x in 2:4 - @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === binomial(Int(x + r - 1), x) - @test @inferred(sufficientstatistics(ef, x)) === (x,) - @test @inferred(logpartition(ef)) ≈ -r * log(1 - exp(getnaturalparameters(ef)[1])) - end - end - end - - for space in (MeanParametersSpace(), NaturalParametersSpace()) - @test !isproper(space, NegativeBinomial, [Inf], 1.0) - @test !isproper(space, NegativeBinomial, [1.0], Inf) - @test !isproper(space, NegativeBinomial, [NaN], 1.0) - @test !isproper(space, NegativeBinomial, [1.0], NaN) - @test !isproper(space, NegativeBinomial, [0.5, 0.5], 1.0) - - # Conditioner is required - @test_throws Exception isproper(space, NegativeBinomial, [0.5], [0.5, 0.5]) - @test_throws Exception isproper(space, NegativeBinomial, [1.0], nothing) - @test_throws Exception isproper(space, NegativeBinomial, [1.0], nothing) - end - - @test_throws Exception convert(ExponentialFamilyDistribution, NegativeBinomial(Inf, Inf)) -end - -@testitem "NegativeBinomial: prod" begin - include("distributions_setuptests.jl") - - for nleft in 3:5, pleft in 0.01:0.3:0.99 - left = NegativeBinomial(nleft, pleft) - efleft = convert(ExponentialFamilyDistribution, left) - η_left = getnaturalparameters(efleft) - for nright in 6:7, pright in 0.01:0.3:0.99 - right = NegativeBinomial(nright, pright) - efright = convert(ExponentialFamilyDistribution, right) - η_right = first(getnaturalparameters(efright)) - prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), left, right) - - @test sum(pdf(prod_dist, x) for x in 0:max(nleft, nright)) ≈ 1.0 atol = 1e-5 - end - end -end - -================ -File: distributions/normal_gamma_tests.jl -================ -@testitem "NormalGamma: common" begin - include("distributions_setuptests.jl") - - m = rand() - s, a, b = 1.0, 0.1, 3.0 - dist = NormalGamma(m, s, a, b) - @test params(dist) == (m, s, a, b) - @test location(dist) == m - @test scale(dist) == s - @test shape(dist) == a - @test rate(dist) == b -end - -@testitem "NormalGamma: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for μ in 10randn(3), λ in 10rand(3), α in (1 .+ 10rand(3)), β in 10 * rand(3) - @testset let d = NormalGamma(μ, λ, α, β) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false) - - (η1, η2, η3, η4) = unpack_parameters(NormalGamma, getnaturalparameters(ef)) - η3half = η3 + 1 / 2 - for x in rand(d, 3) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) ≈ invsqrt2π - @test @inferred(sufficientstatistics(ef, x)) === (x[1] * x[2], x[1]^2 * x[2], log(x[2]), x[2]) - @test @inferred(logpartition(ef)) ≈ loggamma(η3half) - log(-2η2) * (1 / 2) - (η3half) * log(-η4 + η1^2 / (4η2)) - end - end - end - - @test !isproper(MeanParametersSpace(), NormalGamma, [1.0, 0.0, -1.0, 2.0]) - @test !isproper(MeanParametersSpace(), NormalGamma, [1.0, 0.0, -1.0, 2.0]) - @test !isproper(NaturalParametersSpace(), NormalGamma, [1.0, -0.2, 1.0, -1/0.8 ]) - @test !isproper(NaturalParametersSpace(), NormalGamma, [1.0, -0.2, 1.0, 10/8]) - @test isproper(NaturalParametersSpace(), NormalGamma, [1.0, -0.2, 1.0, -11/8]) - @test !isproper(MeanParametersSpace(), NormalGamma, [-1.0, 0.0, NaN, 1.0], [Inf]) -end - -@testitem "NormalGamma: prod with Distribution" begin - include("distributions_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeProd(NormalGamma), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) - @test prod(strategy, NormalGamma(1.0, 1.0, 2.0, 3.0), NormalGamma(1.0, 1.0, 5.0, 6.0)) == NormalGamma(1.0, 2.0, 6.5, 9.0) - @test prod(strategy, NormalGamma(2.0, 1.0, 3.0, 4.0), NormalGamma(2.0, 1.0, 0.4, 2.0)) == NormalGamma(2.0, 2.0, 2.9, 6.0) - end -end - -@testitem "NormalGamma: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for μleft in 10randn(2), μright in 10randn(2), σleft in 10rand(2), σright in 10rand(2), - αleft in (1 .+ 10rand(2)), αright in (1 .+ 10rand(2)), βleft in 10rand(2), βright in 10rand(2) - - let left = NormalGamma(μleft, σleft, αleft, βleft), right = NormalGamma(μright, σright, αright + 1 / 2, βright) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{NormalGamma}) - ) - ) - end - end -end - -================ -File: distributions/pareto_tests.jl -================ -# Pareto comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Pareto: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for shape in (5.0, 6.0, 7.0), scale in (0.25, 0.5, 2.0) - @testset let d = Pareto(shape, scale) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false) - η1 = -shape - 1 - for x in scale:1.0:scale+3.0 - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === oneunit(x) - @test @inferred(sufficientstatistics(ef, x)) === (log(x),) - @test @inferred(logpartition(ef)) ≈ log(scale^(one(η1) + η1) / (-one(η1) - η1)) - end - end - end - - for space in (MeanParametersSpace(), NaturalParametersSpace()) - @test !isproper(space, Pareto, [Inf], 1.0) - @test !isproper(space, Pareto, [1.0], Inf) - @test !isproper(space, Pareto, [NaN], 1.0) - @test !isproper(space, Pareto, [1.0], NaN) - @test !isproper(space, Pareto, [0.5, 0.5], 1.0) - - # Conditioner is required - @test_throws Exception isproper(space, Pareto, [0.5], [0.5, 0.5]) - @test_throws Exception isproper(space, Pareto, [1.0], nothing) - @test_throws Exception isproper(space, Pareto, [1.0], nothing) - end - - @test_throws Exception convert(ExponentialFamilyDistribution, Pareto(Inf, Inf)) -end - -@testitem "Pareto: prod with Distributions" begin - include("distributions_setuptests.jl") - - @test prod(PreserveTypeProd(Pareto), Pareto(0.5), Pareto(0.6)) == Pareto(2.1) - @test prod(PreserveTypeProd(Pareto), Pareto(0.3), Pareto(0.8)) == Pareto(2.1) - @test prod(PreserveTypeProd(Pareto), Pareto(0.5), Pareto(0.5)) == Pareto(2.0) - @test prod(PreserveTypeProd(Pareto), Pareto(3), Pareto(2)) == Pareto(6.0) -end - -@testitem "Pareto: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for conditioner in (0.01, 1.0), alphaleft in 0.1:0.1:0.9, alpharight in 0.1:0.1:0.9 - let left = Pareto(alphaleft, conditioner), right = Pareto(alpharight, conditioner) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = (PreserveTypeProd(ExponentialFamilyDistribution{Pareto}), GenericProd()) - ) - end - end - - # Different conditioner parameters cannot be compute a closed prod with the same type - @test_throws Exception prod( - PreserveTypeProd(ExponentialFamilyDistribution{Pareto}), - convert(ExponentialFamilyDistribution, Pareto(0.0, 0.54)), - convert(ExponentialFamilyDistribution, Pareto(0.01, 0.5)) - ) - @test_throws Exception prod( - PreserveTypeProd(ExponentialFamilyDistribution{Pareto}), - convert(ExponentialFamilyDistribution, Pareto(1.0, 0.56)), - convert(ExponentialFamilyDistribution, Pareto(2.0, 0.5)) - ) -end - -@testitem "Pareto: prod with different conditioner" begin - include("distributions_setuptests.jl") - - for conditioner_left in (2, 3), conditioner_right in (4, 5), alphaleft in 0.1:0.1:0.3, alpharight in 0.1:0.1:0.3 - let left = Pareto(alphaleft, conditioner_left), right = Pareto(alpharight, conditioner_right) - ef_left = convert(ExponentialFamilyDistribution, left) - ef_right = convert(ExponentialFamilyDistribution, right) - prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), ef_left, ef_right) - @test getnaturalparameters(prod_dist) ≈ getnaturalparameters(ef_left) + getnaturalparameters(ef_right) - @test getsupport(prod_dist).lb == max(conditioner_left, conditioner_right) - @test sufficientstatistics(prod_dist, (max(conditioner_left, conditioner_right) + 1)) === (log(max(conditioner_left, conditioner_right) + 1),) - @test first( - hquadrature(x -> pdf(prod_dist, tan(x * pi / 2)) * (pi / 2) * (1 / cos(x * pi / 2)^2), (2 / pi) * atan(getsupport(prod_dist).lb), 1.0) - ) ≈ 1.0 - end - end -end - -================ -File: distributions/poisson_tests.jl -================ -# Poisson comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Poisson: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - @testset for i in 2:7 - @testset let d = Poisson(2 * (i + 1)) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - η1 = first(getnaturalparameters(ef)) - - for x in 1:5 - @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === 1 / factorial(x) - @test @inferred(logbasemeasure(ef, x)) === -loggamma(x + one(x)) - @test @inferred(sufficientstatistics(ef, x)) === (x,) - @test @inferred(logpartition(ef)) ≈ exp(η1) - end - end - end - - for space in (MeanParametersSpace(), NaturalParametersSpace()) - @test !isproper(space, Poisson, [Inf]) - @test !isproper(space, Poisson, [NaN]) - @test !isproper(space, Poisson, [1.0], NaN) - @test !isproper(space, Poisson, [0.5, 0.5], 1.0) - end - ## mean parameter should be integer in the MeanParametersSpace - @test !isproper(MeanParametersSpace(), Poisson, [-0.1]) - @test_throws Exception convert(ExponentialFamilyDistribution, Poisson(Inf)) -end - -@testitem "Poisson: prod" begin - include("distributions_setuptests.jl") - - @testset for λleft in 2:3, λright in 3:4 - left = Poisson(λleft) - right = Poisson(λright) - prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), left, right) - sample_points = collect(1:5) - for x in sample_points - @test basemeasure(prod_dist, x) == (1 / gamma(x + one(x))^2) - @test sufficientstatistics(prod_dist, x) == (x,) - end - sample_points = [-5, -2, 0, 2, 5] - for η in sample_points - @test logpartition(prod_dist, η) == log(abs(besseli(0, 2 * exp(η / 2)))) - end - @test getnaturalparameters(prod_dist) == [log(λleft) + log(λright)] - @test getsupport(prod_dist) == NaturalNumbers() - - @test sum(pdf(prod_dist, x) for x in 0:15) ≈ 1.0 - end -end - -================ -File: distributions/rayleigh_tests.jl -================ -# Rayleigh comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Rayleigh: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for σ in 10rand(4) - @testset let d = Rayleigh(σ) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) - η1 = first(getnaturalparameters(ef)) - - for x in 10rand(4) - @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === x - @test @inferred(sufficientstatistics(ef, x)) === (x^2,) - @test @inferred(logpartition(ef)) ≈ -log(-2 * η1) - end - end - end - - for space in (MeanParametersSpace(), NaturalParametersSpace()) - @test !isproper(space, Rayleigh, [Inf]) - @test !isproper(space, Rayleigh, [NaN]) - @test !isproper(space, Rayleigh, [1.0], NaN) - @test !isproper(space, Rayleigh, [0.5, 0.5], 1.0) - end - @test !isproper(MeanParametersSpace(), Rayleigh, [-1.0]) - @test_throws Exception convert(ExponentialFamilyDistribution, Rayleigh(Inf)) -end - -@testitem "Rayleigh: prod with PreserveTypeProd{ExponentialFamilyDistribution}" begin - include("distributions_setuptests.jl") - - for σleft in 1:4, σright in 4:7 - @testset let (left, right) = (Rayleigh(σleft), Rayleigh(σright)) - ef_left = convert(ExponentialFamilyDistribution, left) - ef_right = convert(ExponentialFamilyDistribution, right) - prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), left, right) - @test first(hquadrature(x -> pdf(prod_dist, tan(x * pi / 2)) * (pi / 2) * (1 / cos(x * pi / 2)^2), 0.0, 1.0)) ≈ 1.0 - @test getnaturalparameters(prod_dist) == getnaturalparameters(ef_left) + getnaturalparameters(ef_right) - end - end -end - -================ -File: distributions/von_mises_fisher_tests.jl -================ -# VonMisesFisher comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "VonMisesFisher: vague" begin - include("distributions_setuptests.jl") - - d = vague(VonMisesFisher, 3) - - @test typeof(d) <: VonMisesFisher - @test mean(d) == zeros(3) - @test params(d) == (zeros(3), 1.0e-12) -end - -@testitem "VonMisesFisher: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for len in 3:5, b in (0.5) - a_unnormalized = rand(len) - a = a_unnormalized ./ norm(a_unnormalized) - @testset let d = VonMisesFisher(a, b) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false, test_fisherinformation_against_jacobian = false, - test_fisherinformation_properties = false - ) - - run_test_fisherinformation_against_jacobian(d; assume_no_allocations = false, mappings = ( - NaturalParametersSpace() => MeanParametersSpace(), - # MeanParametersSpace() => NaturalParametersSpace(), # here is the problem for discussion, the test is broken - )) - - for x in rand(d) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === (1 / twoπ)^(length(x) * (1 / 2)) - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (x,)) - @test @inferred(logpartition(ef)) ≈ log(besseli((len / 2) - 1, b)) - ((len / 2) - 1) * log(b) - end - end - end -end - -@testitem "VonMisesFisher: prod" begin - include("distributions_setuptests.jl") - - for strategy in (ClosedProd(), PreserveTypeLeftProd(), PreserveTypeRightProd(), PreserveTypeProd(Distribution)) - @test prod(strategy, VonMisesFisher([sin(30), cos(30)], 3.0), VonMisesFisher([sin(45), cos(45)], 4.0)) ≈ - Base.convert( - Distribution, - prod(strategy, convert(ExponentialFamilyDistribution, VonMisesFisher([sin(30), cos(30)], 3.0)), - convert(ExponentialFamilyDistribution, VonMisesFisher([sin(45), cos(45)], 4.0))) - ) - @test prod(strategy, VonMisesFisher([sin(15), cos(15)], 5.0), VonMisesFisher([cos(20), sin(20)], 2.0)) ≈ - Base.convert( - Distribution, - prod(strategy, convert(ExponentialFamilyDistribution, VonMisesFisher([sin(15), cos(15)], 5.0)), - convert(ExponentialFamilyDistribution, VonMisesFisher([cos(20), sin(20)], 2.0))) - ) - end -end - -@testitem "VonMisesFisher: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for μleft in eachcol(10rand(4, 4)), μright in eachcol(10rand(4, 4)), σleft in (2, 3), σright in (2, 3) - @testset let (left, right) = (VonMisesFisher(μleft / norm(μleft), σleft), VonMisesFisher(μright / norm(μright), σright)) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = ( - ClosedProd(), - GenericProd(), - PreserveTypeProd(ExponentialFamilyDistribution), - PreserveTypeProd(ExponentialFamilyDistribution{VonMisesFisher}) - ) - ) - end - end -end - -================ -File: distributions/vonmises_tests.jl -================ -# VonMises comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "VonMises: vague" begin - include("distributions_setuptests.jl") - - d = vague(VonMises) - - @test typeof(d) <: VonMises - @test mean(d) === 0.0 - @test params(d) === (0.0, 1.0e-12) -end - -@testitem "VonMises: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for a in -2:1.0:2, b in 0.1:4.0:10.0 - @testset let d = VonMises(a, b) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false) - - for x in a-1:0.5:a+1 - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === inv(twoπ) - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (cos(x), sin(x))) - @test @inferred(logpartition(ef)) ≈ (log(besseli(0, b))) - end - - @test !@inferred(insupport(ef, -6)) - @test @inferred(insupport(ef, 0.5)) - - # Not in the support - @test_throws Exception logpdf(ef, -6.0) - end - end - - # Test failing isproper cases - @test !isproper(MeanParametersSpace(), VonMises, [-1]) - @test !isproper(MeanParametersSpace(), VonMises, [1], 3.0) - @test !isproper(MeanParametersSpace(), VonMises, [1, -2]) -end - -@testitem "VonMises: prod with Distributions" begin - include("distributions_setuptests.jl") - for dist1 in [VonMises(10randn(),10rand()) for _=1:20], dist2 in [VonMises(10randn(),10rand()) for _=1:20] - ef1 = convert(ExponentialFamilyDistribution,dist1) - ef2 = convert(ExponentialFamilyDistribution,dist2) - prod_ef = prod(GenericProd(),ef1,ef2) - for strategy in (ClosedProd(), PreserveTypeProd(Distribution), PreserveTypeLeftProd(), PreserveTypeRightProd(), GenericProd()) - @test prod(strategy, dist1, dist2) ≈ convert(Distribution,prod_ef) - end - end -end - -@testitem "VonMises: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for kleft in (0.01, 1.0), kright in (0.01, 1.0), alphaleft in 0.1:0.1:0.9, alpharight in 0.1:0.1:0.9 - let left = VonMises(alphaleft, kleft), right = VonMises(alpharight, kright) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = (PreserveTypeProd(ExponentialFamilyDistribution{VonMises}), GenericProd()) - ) - end - end -end - -================ -File: distributions/weibull_tests.jl -================ -# Weibull comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Weibull: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - for shape in (1.0, 2.0, 3.0), scale in (0.25, 0.5, 2.0) - @testset let d = Weibull(shape, scale) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false, test_fisherinformation_against_jacobian = false) - η1 = first(getnaturalparameters(ef)) - run_test_fisherinformation_against_jacobian( - d; - assume_no_allocations = true, - mappings = ( - MeanParametersSpace() => NaturalParametersSpace() - ) - ) - for x in scale:1.0:scale+3.0 - @test @inferred(isbasemeasureconstant(ef)) === NonConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === x^(shape - 1) - @test @inferred(sufficientstatistics(ef, x)) === (x^shape,) - @test @inferred(logpartition(ef)) ≈ -log(-η1) - log(shape) - end - end - end - - @testset "fisher information by natural to mean jacobian" begin - @testset for k in (1, 3), λ in (0.1, 4.0) - η = -(1 / λ)^k - transformation(η) = [k, (-1 / η[1])^(1 / k)] - J = ForwardDiff.jacobian(transformation, [η]) - @test first(J' * getfisherinformation(MeanParametersSpace(), Weibull, k)(λ) * J) ≈ - first(getfisherinformation(NaturalParametersSpace(), Weibull, k)(η)) atol = 1e-8 - end - end - - for space in (MeanParametersSpace(), NaturalParametersSpace()) - @test !isproper(space, Weibull, [Inf], 1.0) - @test !isproper(space, Weibull, [1.0], Inf) - @test !isproper(space, Weibull, [NaN], 1.0) - @test !isproper(space, Weibull, [1.0], NaN) - @test !isproper(space, Weibull, [0.5, 0.5], 1.0) - - # Conditioner is required - @test_throws Exception isproper(space, Weibull, [0.5], [0.5, 0.5]) - @test_throws Exception isproper(space, Weibull, [1.0], nothing) - @test_throws Exception isproper(space, Weibull, [1.0], nothing) - end - - @test_throws Exception convert(ExponentialFamilyDistribution, Weibull(Inf, Inf)) -end - -@testitem "Weibull: prod with PreserveTypeProd{ExponentialFamilyDistribution} for the same conditioner" begin - include("distributions_setuptests.jl") - - for η in -2.0:0.5:-0.5, k in 1.0:0.5:2, x in 0.5:0.5:2.0 - ef_left = convert(Distribution, ExponentialFamilyDistribution(Weibull, [η], k)) - ef_right = convert(Distribution, ExponentialFamilyDistribution(Weibull, [-η^2], k)) - res = prod(PreserveTypeProd(ExponentialFamilyDistribution), ef_left, ef_right) - @test getbasemeasure(res)(x) == x^(2 * (k - 1)) - @test sufficientstatistics(res, x) == (x^k,) - @test getlogpartition(res)(η - η^2) == - log(abs(η - η^2)^(1 / k)) + loggamma(2 - 1 / k) - 2 * log(abs(η - η^2)) - log(k) - @test getnaturalparameters(res) ≈ [η - η^2] - @test first(hquadrature(x -> pdf(res, tan(x * pi / 2)) * (pi / 2) * (1 / cos(pi * x / 2))^2, 0.0, 1.0)) ≈ - 1.0 - end -end - -@testitem "Weibull: prod with PreserveTypeProd{ExponentialFamilyDistribution} for different k" begin - include("distributions_setuptests.jl") - - for η in -12:4:-0.5, k in 1.0:4:10, x in 0.5:4:10 - ef_left = convert(Distribution, ExponentialFamilyDistribution(Weibull, [η], k * 2)) - ef_right = convert(Distribution, ExponentialFamilyDistribution(Weibull, [-η^2], k)) - res = prod(PreserveTypeProd(ExponentialFamilyDistribution), ef_left, ef_right) - @test getbasemeasure(res)(x) == x^(k + k * 2 - 2) - @test sufficientstatistics(res, x) == (x^(2 * k), x^k) - @test getnaturalparameters(res) ≈ [η, -η^2] - @test first(hquadrature(x -> pdf(res, tan(x * pi / 2)) * (pi / 2) * (1 / cos(pi * x / 2))^2, 0.0, 1.0)) ≈ - 1.0 - end -end - -================ -File: distributions/wishart_inverse_tests.jl -================ -@testitem "InverseWishart: common" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: InverseWishartFast - - @test InverseWishartFast <: Distribution - @test InverseWishartFast <: ContinuousDistribution - @test InverseWishartFast <: MatrixDistribution - - @test value_support(InverseWishartFast) === Continuous - @test variate_form(InverseWishartFast) === Matrixvariate -end - -@testitem "InverseWishart: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: InverseWishartFast - - rng = StableRNG(42) - - @testset for dim in (3), S in rand(rng, InverseWishart(10, Array(Eye(dim))), 2) - ν = dim + 4 - @testset let (d = InverseWishartFast(ν, S)) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false, test_fisherinformation_against_hessian = false) - (η1, η2) = unpack_parameters(InverseWishartFast, getnaturalparameters(ef)) - - for x in (Eye(dim), Diagonal(ones(dim)), Array(Eye(dim))) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === 1.0 - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (logdet(x), inv(x))) - @test @inferred(logpartition(ef)) ≈ (η1 + (dim + 1) / 2) * logdet(-η2) + logmvgamma(dim, -(η1 + (dim + 1) / 2)) - end - end - end -end - -@testitem "InverseWishart: statistics" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: InverseWishartFast - - rng = StableRNG(42) - # ν > dim(d) + 1 - for ν in 4:10 - L = randn(rng, ν - 2, ν - 2) - S = L * L' - d = InverseWishartFast(ν, S) - - @test mean(d) == mean(InverseWishart(params(d)...)) - @test mode(d) == mode(InverseWishart(params(d)...)) - end - - # ν > dim(d) + 3 - for ν in 5:10 - L = randn(rng, ν - 4, ν - 4) - S = L * L' - d = InverseWishartFast(ν, S) - - @test cov(d) == cov(InverseWishart(params(d)...)) - @test var(d) == var(InverseWishart(params(d)...)) - end -end - -@testitem "InverseWishart: vague" begin - include("distributions_setuptests.jl") - - dims = 3 - d1 = vague(InverseWishart, dims) - - @test typeof(d1) <: InverseWishart - ν1, S1 = params(d1) - @test ν1 == dims + 2 - @test S1 == tiny .* Eye(dims) - - @test mean(d1) == S1 - - dims = 4 - d2 = vague(InverseWishart, dims) - - @test typeof(d2) <: InverseWishart - ν2, S2 = params(d2) - @test ν2 == dims + 2 - @test S2 == tiny .* Eye(dims) - - @test mean(d2) == S2 -end - -@testitem "InverseWishart: entropy" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: InverseWishartFast - - @test entropy( - InverseWishartFast( - 2.0, - [2.2658069783329573 -0.47934965873423374; -0.47934965873423374 1.4313564100863712] - ) - ) ≈ 10.111427477184794 - @test entropy(InverseWishartFast(5.0, Eye(4))) ≈ 8.939145914882221 -end - -@testitem "InverseWishart: convert" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: InverseWishartFast - - rng = StableRNG(42) - for ν in 2:10 - L = randn(rng, ν, ν) - S = L * L' - d = InverseWishartFast(ν, S) - @test convert(InverseWishart, d) == InverseWishart(ν, S) - end -end - -@testitem "InverseWishart: mean(::typeof(logdet))" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: InverseWishartFast - - rng = StableRNG(123) - ν, S = 2.0, [2.2658069783329573 -0.47934965873423374; -0.47934965873423374 1.4313564100863712] - samples = rand(rng, InverseWishart(ν, S), Int(1e6)) - @test isapprox(mean(logdet, InverseWishartFast(ν, S)), mean(logdet.(samples)), atol = 1e-2) - @test isapprox(mean(logdet, InverseWishart(ν, S)), mean(logdet.(samples)), atol = 1e-2) - - ν, S = 4.0, Array(Eye(3)) - samples = rand(rng, InverseWishart(ν, S), Int(1e6)) - @test isapprox(mean(logdet, InverseWishartFast(ν, S)), mean(logdet.(samples)), atol = 1e-2) - @test isapprox(mean(logdet, InverseWishart(ν, S)), mean(logdet.(samples)), atol = 1e-2) -end - -@testitem "InverseWishart: mean(::typeof(inv))" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: InverseWishartFast - - rng = StableRNG(321) - ν, S = 2.0, [2.2658069783329573 -0.47934965873423374; -0.47934965873423374 1.4313564100863712] - samples = rand(rng, InverseWishart(ν, S), Int(1e6)) - @test isapprox(mean(inv, InverseWishartFast(ν, S)), mean(inv.(samples)), atol = 1e-2) - @test isapprox(mean(inv, InverseWishart(ν, S)), mean(inv.(samples)), atol = 1e-2) - - ν, S = 4.0, Array(Eye(3)) - samples = rand(rng, InverseWishart(ν, S), Int(1e6)) - @test isapprox(mean(inv, InverseWishartFast(ν, S)), mean(inv.(samples)), atol = 1e-2) - @test isapprox(mean(inv, InverseWishart(ν, S)), mean(inv.(samples)), atol = 1e-2) -end - -@testitem "InverseWishart: prod" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: InverseWishartFast - - d1 = InverseWishartFast(3.0, Eye(2)) - d2 = InverseWishartFast(-3.0, [0.6423504672769315 0.9203141654948761; 0.9203141654948761 1.528137747462735]) - - @test prod(PreserveTypeProd(Distribution), d1, d2) ≈ - InverseWishartFast(3.0, [1.6423504672769313 0.9203141654948761; 0.9203141654948761 2.528137747462735]) - - d1 = InverseWishartFast(4.0, Eye(3)) - d2 = InverseWishartFast(-2.0, Eye(3)) - - @test prod(PreserveTypeProd(Distribution), d1, d2) ≈ InverseWishartFast(6.0, 2 * Eye(3)) -end - -@testitem "InverseWishart: rand" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: InverseWishartFast - - for d in (2, 3, 4, 5) - v = rand() + d - L = rand(d, d) - S = L' * L + d * Eye(d) - cS = copy(S) - container1 = [zeros(d, d) for _ in 1:100] - container2 = [zeros(d, d) for _ in 1:100] - - # Check in-place version - @test rand!(StableRNG(321), InverseWishart(v, S), container1) ≈ - rand!(StableRNG(321), InverseWishartFast(v, S), container2) - - # Check that the matrix has not been corrupted - @test all(S .=== cS) - - # Check non-inplace version - @test rand(StableRNG(321), InverseWishart(v, S), length(container1)) ≈ - rand(StableRNG(321), InverseWishartFast(v, S), length(container2)) - end -end - -@testitem "InverseWishart: pdf!" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: InverseWishartFast - import Distributions: pdf! - - for d in (2, 3, 4, 5), n in (10, 20) - v = rand() + d - L = rand(d, d) - S = L' * L + d * Eye(d) - - samples = map(1:n) do _ - L_sample = rand(d, d) - return L_sample' * L_sample + d * Eye(d) - end - - result = zeros(n) - - @test all(pdf(InverseWishart(v, S), samples) .≈ pdf!(result, InverseWishartFast(v, S), samples)) - end -end - -@testitem "InverseWishart: prod with ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: InverseWishartFast - - for Sleft in rand(InverseWishart(10, Array(Eye(2))), 2), Sright in rand(InverseWishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5) - let left = InverseWishartFast(νleft, Sleft), right = InverseWishartFast(νright, Sright) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = (PreserveTypeProd(ExponentialFamilyDistribution{InverseWishartFast}), GenericProd()) - ) - end - end -end - -================ -File: distributions/wishart_tests.jl -================ -# Wishart comes from Distributions.jl and most of the things should be covered there -# Here we test some extra ExponentialFamily.jl specific functionality - -@testitem "Wishart: mean(::logdet)" begin - include("distributions_setuptests.jl") - - @test mean(logdet, Wishart(3, [1.0 0.0; 0.0 1.0])) ≈ 0.845568670196936 - @test mean( - logdet, - Wishart( - 5, - [ - 1.4659658963311604 1.111775094889733 0.8741034114800605 - 1.111775094889733 0.8746971141492232 0.6545661366809246 - 0.8741034114800605 0.6545661366809246 0.5498917856395482 - ] - ) - ) ≈ -3.4633310802040693 -end - -@testitem "Wishart: mean(::cholinv)" begin - include("distributions_setuptests.jl") - - L = rand(2, 2) - S = L * L' + Eye(2) - invS = inv(S) - @test mean(inv, Wishart(5, S)) ≈ mean(InverseWishart(5, invS)) -end - -@testitem "Wishart: vague" begin - include("distributions_setuptests.jl") - - @test_throws MethodError vague(Wishart) - - d = vague(Wishart, 3) - - @test typeof(d) <: Wishart - @test mean(d) == Matrix(Diagonal(3 * 1e12 * ones(3))) -end - -@testitem "Wishart: rand" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: WishartFast - - rng = StableRNG(42) - - for d in (2, 3, 4, 5) - v = rand(rng) + d - L = rand(rng, d, d) - S = L' * L + d * Eye(d) - invS = inv(S) - cS = copy(S) - cinvS = copy(invS) - container1 = [zeros(d, d) for _ in 1:100] - container2 = [zeros(d, d) for _ in 1:100] - - # Check inplace versions - @test rand!(StableRNG(321), Wishart(v, S), container1) ≈ - rand!(StableRNG(321), WishartFast(v, invS), container2) - - # Check that matrices are not corrupted - @test all(S .=== cS) - @test all(invS .=== cinvS) - - # Check non-inplace versions - @test rand(StableRNG(321), Wishart(v, S), length(container1)) ≈ - rand(StableRNG(321), WishartFast(v, invS), length(container2)) - end -end - - -@testitem "Wishart: ExponentialFamilyDistribution" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: WishartFast - - rng = StableRNG(42) - - for dim in (3, 4), invS in rand(rng, Wishart(10, Array(Eye(dim))), 2) - ν = dim + 2 - @testset let (d = WishartFast(ν, invS)) - ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false, test_fisherinformation_against_hessian = false) - (η1, η2) = unpack_parameters(WishartFast, getnaturalparameters(ef)) - - for x in (Eye(dim), Diagonal(ones(dim)), Array(Eye(dim))) - @test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure() - @test @inferred(basemeasure(ef, x)) === 1.0 - @test all(@inferred(sufficientstatistics(ef, x)) .≈ (logdet(x), x)) - @test @inferred(logpartition(ef)) ≈ -(η1 + (dim + 1) / 2) * logdet(-η2) + logmvgamma(dim, η1 + (dim + 1) / 2) - end - end - end -end - -@testitem "Wishart: prod" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: WishartFast - - inv_v1 = inv([9.0 -3.4; -3.4 11.0]) - inv_v2 = inv([10.2 -3.3; -3.3 5.0]) - inv_v3 = inv([8.1 -2.7; -2.7 9.0]) - - @test prod(PreserveTypeProd(Distribution), WishartFast(3, inv_v1), WishartFast(3, inv_v2)) ≈ - WishartFast( - 3, - inv_v1 + inv_v2 - ) - @test prod(PreserveTypeProd(Distribution), WishartFast(4, inv_v1), WishartFast(4, inv_v3)) ≈ - WishartFast( - 5, - inv_v1 + inv_v3 - ) - @test prod(PreserveTypeProd(Distribution), WishartFast(5, inv_v2), WishartFast(4, inv_v3)) ≈ - WishartFast(6, inv([4.51459128065395 -1.4750681198910067; -1.4750681198910067 3.129155313351499])) -end - -@testitem "Wishart: prod with ExponentialFamilyDistribution{Wishart}" begin - include("distributions_setuptests.jl") - - import ExponentialFamily: WishartFast - - for Sleft in rand(Wishart(10, Array(Eye(2))), 2), Sright in rand(Wishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5) - let left = WishartFast(νleft, Sleft), right = WishartFast(νright, Sright) - @test test_generic_simple_exponentialfamily_product( - left, - right, - strategies = (PreserveTypeProd(ExponentialFamilyDistribution{WishartFast}), GenericProd()) - ) - end - end -end - -================ -File: common_tests.jl -================ -@testitem "dot3arg" begin - using LinearAlgebra, ForwardDiff - using ExponentialFamily: dot3arg - - for n in 2:10 - x = rand(n) - y = rand(n) - A = rand(n, n) - @test dot3arg(x, A, y) ≈ dot(x, A, y) - @test all(ForwardDiff.hessian((x) -> dot3arg(x, A, x), x) .!== 0) - @test all(ForwardDiff.hessian((x) -> dot3arg(x, A, x), y) .!== 0) - end -end - -================ -File: exponential_family_setuptests.jl -================ -using ExponentialFamily, BayesBase, Distributions, Test, StatsFuns, BenchmarkTools, Random, FillArrays - -import Distributions: RealInterval, ContinuousUnivariateDistribution, Univariate -import ExponentialFamily: basemeasure, logbasemeasure, sufficientstatistics, logpartition, insupport, ConstantBaseMeasure -import ExponentialFamily: getnaturalparameters, getbasemeasure, getlogbasemeasure, getsufficientstatistics, getlogpartition, getsupport -import ExponentialFamily: ExponentialFamilyDistributionAttributes, NaturalParametersSpace - -# import ExponentialFamily: -# ExponentialFamilyDistribution, getnaturalparameters, getconditioner, reconstructargument!, as_vec, -# pack_naturalparameters, unpack_naturalparameters, insupport -# import Distributions: pdf, logpdf, cdf - -## =========================================================================== -## Tests fixtures -const ArbitraryExponentialFamilyAttributes = ExponentialFamilyDistributionAttributes( - (x) -> 1 / x, - ((x) -> x, (x) -> log(x)), - (η) -> 1 / sum(η), - RealInterval(0, Inf) -) - -# Arbitrary distribution (un-conditioned) -struct ArbitraryDistributionFromExponentialFamily <: ContinuousUnivariateDistribution - p1::Float64 - p2::Float64 -end - -ExponentialFamily.isproper(::NaturalParametersSpace, ::Type{ArbitraryDistributionFromExponentialFamily}, η, conditioner) = isnothing(conditioner) -ExponentialFamily.isbasemeasureconstant(::Type{ArbitraryDistributionFromExponentialFamily}) = ConstantBaseMeasure() -ExponentialFamily.getbasemeasure(::Type{ArbitraryDistributionFromExponentialFamily}) = (x) -> oneunit(x) -ExponentialFamily.getlogbasemeasure(::Type{ArbitraryDistributionFromExponentialFamily}) = (x) -> zero(x) -ExponentialFamily.getsufficientstatistics(::Type{ArbitraryDistributionFromExponentialFamily}) = - ((x) -> x, (x) -> log(x)) -ExponentialFamily.getlogpartition(::NaturalParametersSpace, ::Type{ArbitraryDistributionFromExponentialFamily}) = (η) -> 1 / sum(η) -ExponentialFamily.getsupport(::Type{ArbitraryDistributionFromExponentialFamily}) = RealInterval(0, Inf) - -BayesBase.vague(::Type{ArbitraryDistributionFromExponentialFamily}) = - ArbitraryDistributionFromExponentialFamily(1.0, 1.0) - -BayesBase.params(dist::ArbitraryDistributionFromExponentialFamily) = (dist.p1, dist.p2) - -(::MeanToNatural{ArbitraryDistributionFromExponentialFamily})(params::Tuple) = (params[1] + 1, params[2] + 1) -(::NaturalToMean{ArbitraryDistributionFromExponentialFamily})(params::Tuple) = (params[1] - 1, params[2] - 1) - -ExponentialFamily.unpack_parameters(::Type{ArbitraryDistributionFromExponentialFamily}, η) = (η[1], η[2]) - -# Arbitrary distribution (conditioned) -struct ArbitraryConditionedDistributionFromExponentialFamily <: ContinuousUnivariateDistribution - con::Int - p1::Float64 -end - -ExponentialFamily.isproper(::NaturalParametersSpace, ::Type{ArbitraryConditionedDistributionFromExponentialFamily}, η, conditioner) = isinteger(conditioner) -ExponentialFamily.isbasemeasureconstant(::Type{ArbitraryConditionedDistributionFromExponentialFamily}) = NonConstantBaseMeasure() -ExponentialFamily.getbasemeasure(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = (x) -> x^conditioner -ExponentialFamily.getlogbasemeasure(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = (x) -> conditioner*log(x) -ExponentialFamily.getsufficientstatistics(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = - ((x) -> log(x - conditioner),) -ExponentialFamily.getlogpartition(::NaturalParametersSpace, ::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = - (η) -> conditioner / sum(η) -ExponentialFamily.getsupport(::Type{ArbitraryConditionedDistributionFromExponentialFamily}) = RealInterval(0, Inf) - -BayesBase.vague(::Type{ArbitraryConditionedDistributionFromExponentialFamily}) = - ArbitraryConditionedDistributionFromExponentialFamily(1.0, -2) - -BayesBase.params(dist::ArbitraryConditionedDistributionFromExponentialFamily) = (dist.con, dist.p1) - -ExponentialFamily.separate_conditioner(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, params) = ((params[2],), params[1]) -ExponentialFamily.join_conditioner(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, cparams, conditioner) = (conditioner, cparams...) - -(::MeanToNatural{ArbitraryConditionedDistributionFromExponentialFamily})(params::Tuple, conditioner::Number) = (params[1] + conditioner,) -(::NaturalToMean{ArbitraryConditionedDistributionFromExponentialFamily})(params::Tuple, conditioner::Number) = (params[1] - conditioner,) - -ExponentialFamily.unpack_parameters(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, η) = (η[1],) - -================ -File: exponential_family_tests.jl -================ -@testitem "pack_parameters" begin - include("./exponential_family_setuptests.jl") - - import ExponentialFamily: pack_parameters - - to_test_fixed = [ - (1, 2), - (1.0, 2), - (1, 2.0), - (1.0, 2.0), - ([1, 2, 3], 3), - ([1, 2, 3], 3.0), - ([1.0, 2.0, 3.0], 3), - ([1.0, 2.0, 3.0], 3.0), - (4, [1, 2, 3]), - (4.0, [1, 2, 3]), - (4, [1.0, 2.0, 3.0]), - (4.0, [1.0, 2.0, 3.0]), - ([1, 2, 3], 3, [1 2 3; 1 2 3; 1 2 3], 4), - ([1, 2, 3], 3, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4), - ([1, 2, 3], 3, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4.0), - ([1, 2, 3], 3.0, [1 2 3; 1 2 3; 1 2 3], 4), - ([1, 2, 3], 3.0, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4), - ([1, 2, 3], 3.0, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4.0), - ([1.0, 2.0, 3.0], 3, [1 2 3; 1 2 3; 1 2 3], 4), - ([1.0, 2.0, 3.0], 3, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4), - ([1.0, 2.0, 3.0], 3, [1.0 2.0 3.0; 1.0 2.0 3.0; 1.0 2.0 3.0], 4.0) - ] - - for test in to_test_fixed - @test all(@inferred(pack_parameters(test)) .== collect(Iterators.flatten(test))) - end - - for _ in 1:10 - to_test_random = [ - rand(Float64), - rand(1:10), - [rand(Float64) for _ in rand(1:10)], - [rand(1:10) for _ in rand(1:10)], - [rand(Float64) for _ in rand(1:10) for _ in rand(1:10)], - [rand(1:10) for _ in rand(1:10) for _ in rand(1:10)] - ] - params = Tuple(shuffle(to_test_random)) - @test all(@inferred(pack_parameters(params)) .== collect(Iterators.flatten(params))) - end -end - -@testitem "ExponentialFamilyDistributionAttributes" begin - include("./exponential_family_setuptests.jl") - - @testset "getmapping" begin - @test @inferred(getmapping(MeanParametersSpace() => NaturalParametersSpace(), ArbitraryDistributionFromExponentialFamily)) === - MeanToNatural{ArbitraryDistributionFromExponentialFamily}() - @test @inferred(getmapping(NaturalParametersSpace() => MeanParametersSpace(), ArbitraryDistributionFromExponentialFamily)) === - NaturalToMean{ArbitraryDistributionFromExponentialFamily}() - @test @allocated(getmapping(MeanParametersSpace() => NaturalParametersSpace(), ArbitraryDistributionFromExponentialFamily)) === 0 - @test @allocated(getmapping(NaturalParametersSpace() => MeanParametersSpace(), ArbitraryDistributionFromExponentialFamily)) === 0 - end - - @testset let attributes = ArbitraryExponentialFamilyAttributes - @test @inferred(getbasemeasure(attributes)(2.0)) ≈ 0.5 - @test @inferred(getlogbasemeasure(attributes)(2.0)) ≈ log(0.5) - @test @inferred(getsufficientstatistics(attributes)[1](2.0)) ≈ 2.0 - @test @inferred(getsufficientstatistics(attributes)[2](2.0)) ≈ log(2.0) - @test @inferred(getlogpartition(attributes)([2.0])) ≈ 0.5 - @test @inferred(getsupport(attributes)) == RealInterval(0, Inf) - @test @inferred(insupport(attributes, 1.0)) - @test !@inferred(insupport(attributes, -1.0)) - end - - @testset let member = - ExponentialFamilyDistribution(Univariate, [2.0, 2.0], nothing, ArbitraryExponentialFamilyAttributes) - η = @inferred(getnaturalparameters(member)) - - @test ExponentialFamily.exponential_family_typetag(member) === Univariate - - @test @inferred(basemeasure(member, 2.0)) ≈ 0.5 - @test @inferred(getbasemeasure(member)(2.0)) ≈ 0.5 - @test @inferred(getbasemeasure(member)(4.0)) ≈ 0.25 - - @test all(@inferred(sufficientstatistics(member, 2.0)) .≈ (2.0, log(2.0))) - @test all(@inferred(map(f -> f(2.0), getsufficientstatistics(member))) .≈ (2.0, log(2.0))) - @test all(@inferred(map(f -> f(4.0), getsufficientstatistics(member))) .≈ (4.0, log(4.0))) - - @test @inferred(logpartition(member)) ≈ 0.25 - @test @inferred(getlogpartition(member)([2.0, 2.0])) ≈ 0.25 - @test @inferred(getlogpartition(member)([4.0, 4.0])) ≈ 0.125 - - @test @inferred(getsupport(member)) == RealInterval(0, Inf) - @test @inferred(insupport(member, 1.0)) - @test !@inferred(insupport(member, -1.0)) - - _similar = @inferred(similar(member)) - - # The standard `@allocated` is not really reliable in this test - # We avoid using the `BenchmarkTools`, but here it is essential - @test @ballocated(logpdf($member, 1.0), samples = 1, evals = 1) === 0 - @test @ballocated(pdf($member, 1.0), samples = 1, evals = 1) === 0 - - @test _similar isa typeof(member) - - # `similar` most probably returns the un-initialized natural parameters with garbage in it - # But we do expect the functions to work anyway given proper values - @test @inferred(basemeasure(_similar, 2.0)) ≈ 0.5 - @test all(@inferred(sufficientstatistics(_similar, 2.0)) .≈ (2.0, log(2.0))) - @test @inferred(logpartition(_similar, η)) ≈ 0.25 - @test @inferred(getsupport(_similar)) == RealInterval(0, Inf) - end -end - -@testitem "ArbitraryDistributionFromExponentialFamily" begin - include("./exponential_family_setuptests.jl") - - @testset for member in ( - ExponentialFamilyDistribution(ArbitraryDistributionFromExponentialFamily, [2.0, 2.0]), - convert(ExponentialFamilyDistribution, ArbitraryDistributionFromExponentialFamily(1.0, 1.0)) - ) - η = @inferred(getnaturalparameters(member)) - - @test ExponentialFamily.exponential_family_typetag(member) === ArbitraryDistributionFromExponentialFamily - - @test convert(ExponentialFamilyDistribution, convert(Distribution, member)) == - ExponentialFamilyDistribution(ArbitraryDistributionFromExponentialFamily, [2.0, 2.0]) - @test convert(Distribution, convert(ExponentialFamilyDistribution, member)) == ArbitraryDistributionFromExponentialFamily(1.0, 1.0) - - @test @inferred(basemeasure(member, 2.0)) ≈ 1.0 - @test @inferred(getbasemeasure(member)(2.0)) ≈ 1.0 - @test @inferred(getbasemeasure(member)(4.0)) ≈ 1.0 - - @test @inferred(logbasemeasure(member, 2.0)) ≈ log(1.0) - @test @inferred(getlogbasemeasure(member)(2.0)) ≈ log(1.0) - @test @inferred(getlogbasemeasure(member)(4.0)) ≈ log(1.0) - - @test all(@inferred(sufficientstatistics(member, 2.0)) .≈ (2.0, log(2.0))) - @test all(@inferred(map(f -> f(2.0), getsufficientstatistics(member))) .≈ (2.0, log(2.0))) - @test all(@inferred(map(f -> f(4.0), getsufficientstatistics(member))) .≈ (4.0, log(4.0))) - - @test @inferred(logpartition(member)) ≈ 0.25 - @test @inferred(getlogpartition(member)([2.0, 2.0])) ≈ 0.25 - @test @inferred(getlogpartition(member)([4.0, 4.0])) ≈ 0.125 - - @test @inferred(getsupport(member)) == RealInterval(0, Inf) - @test insupport(member, 1.0) - @test !insupport(member, -1.0) - - # Computed by hand - @test @inferred(logpdf(member, 2.0)) ≈ (3.75 + 2log(2)) - @test @inferred(logpdf(member, 4.0)) ≈ (7.75 + 4log(2)) - @test @inferred(pdf(member, 2.0)) ≈ exp(3.75 + 2log(2)) - @test @inferred(pdf(member, 4.0)) ≈ exp(7.75 + 4log(2)) - - # The standard `@allocated` is not really reliable in this test - # We avoid using the `BenchmarkTools`, but here it is essential - @test @ballocated(logpdf($member, 2.0), samples = 1, evals = 1) === 0 - @test @ballocated(pdf($member, 2.0), samples = 1, evals = 1) === 0 - - @test @inferred(member == member) - @test @inferred(member ≈ member) - - _similar = @inferred(similar(member)) - _prod = ExponentialFamilyDistribution(ArbitraryDistributionFromExponentialFamily, [4.0, 4.0]) - - @test @inferred(prod(ClosedProd(), member, member)) == _prod - @test @inferred(prod(GenericProd(), member, member)) == _prod - @test @inferred(prod(PreserveTypeProd(ExponentialFamilyDistribution), member, member)) == _prod - @test @inferred(prod(PreserveTypeLeftProd(), member, member)) == _prod - @test @inferred(prod(PreserveTypeRightProd(), member, member)) == _prod - - # Test that the generic prod version does not allocate as much as simply creating a similar ef member - # This is important, because the generic prod version should simply call the in-place version - @test @allocated(prod(ClosedProd(), member, member)) <= @allocated(similar(member)) - @test @allocated(prod(GenericProd(), member, member)) <= @allocated(similar(member)) - - # This test is actually passing, but does not work if you re-run tests for some reason (which is hapenning often during development) - # @test @allocated(prod(PreserveTypeProd(ExponentialFamilyDistribution), member, member)) <= - # @allocated(similar(member)) - - @test @inferred(prod!(_similar, member, member)) == _prod - - # Test that the in-place prod preserves the container paramfloatype - for F in (Float16, Float32, Float64) - @test @inferred(paramfloattype(prod!(similar(member, F), member, member))) === F - @test @inferred(prod!(similar(member, F), member, member)) == convert_paramfloattype(F, _prod) - end - - # Test that the generic in-place prod! version does not allocate at all - @test @allocated(prod!(_similar, member, member)) === 0 - end -end - -@testitem "ArbitraryConditionedDistributionFromExponentialFamily" begin - include("./exponential_family_setuptests.jl") - - # See the `ArbitraryDistributionFromExponentialFamily` defined in the fixtures (above) - # p1 = 3.0, con = -2 - @testset for member in ( - ExponentialFamilyDistribution(ArbitraryConditionedDistributionFromExponentialFamily, [1.0], -2), - convert(ExponentialFamilyDistribution, ArbitraryConditionedDistributionFromExponentialFamily(-2, 3.0)) - ) - @test ExponentialFamily.exponential_family_typetag(member) === ArbitraryConditionedDistributionFromExponentialFamily - - η = @inferred(getnaturalparameters(member)) - - @test convert(ExponentialFamilyDistribution, convert(Distribution, member)) == - ExponentialFamilyDistribution(ArbitraryConditionedDistributionFromExponentialFamily, [1.0], -2) - @test convert(Distribution, convert(ExponentialFamilyDistribution, member)) == ArbitraryConditionedDistributionFromExponentialFamily(-2, 3.0) - - @test @inferred(basemeasure(member, 2.0)) ≈ 2.0^-2 - @test @inferred(getbasemeasure(member)(2.0)) ≈ 2.0^-2 - @test @inferred(getbasemeasure(member)(4.0)) ≈ 4.0^-2 - - @test @inferred(logbasemeasure(member, 2.0)) ≈ -2*log(2.0) - @test @inferred(getlogbasemeasure(member)(2.0)) ≈ -2*log(2.0) - @test @inferred(getlogbasemeasure(member)(4.0)) ≈ -2*log(4.0) - - @test all(@inferred(sufficientstatistics(member, 2.0)) .≈ (log(2.0 + 2),)) - @test all(@inferred(map(f -> f(2.0), getsufficientstatistics(member))) .≈ (log(2.0 + 2),)) - @test all(@inferred(map(f -> f(4.0), getsufficientstatistics(member))) .≈ (log(4.0 + 2),)) - - @test @inferred(logpartition(member)) ≈ -2.0 - @test @inferred(getlogpartition(member)([2.0])) ≈ -1.0 - @test @inferred(getlogpartition(member)([4.0])) ≈ -0.5 - - @test @inferred(getsupport(member)) == RealInterval(0, Inf) - @test insupport(member, 1.0) - @test !insupport(member, -1.0) - - # # Computed by hand - @test @inferred(logpdf(member, 2.0)) ≈ (log(2.0^-2) + log(2.0 + 2) + 2.0) - @test @inferred(logpdf(member, 4.0)) ≈ (log(4.0^-2) + log(4.0 + 2) + 2.0) - @test @inferred(pdf(member, 2.0)) ≈ exp((log(2.0^-2) + log(2.0 + 2) + 2.0)) - @test @inferred(pdf(member, 4.0)) ≈ exp((log(4.0^-2) + log(4.0 + 2) + 2.0)) - - # The standard `@allocated` is not really reliable in this test - # We avoid using the `BenchmarkTools`, but here it is essential - @test @ballocated(logpdf($member, 2.0), samples = 1, evals = 1) === 0 - @test @ballocated(pdf($member, 2.0), samples = 1, evals = 1) === 0 - - @test @inferred(member == member) - @test @inferred(member ≈ member) - - _similar = @inferred(similar(member)) - _prod = ExponentialFamilyDistribution(ArbitraryConditionedDistributionFromExponentialFamily, [1.0], -2) - - # We don't test the prod becasue the basemeasure is not a constant, so the generic prod is not applicable - - # # Test that the in-place prod preserves the container paramfloatype - for F in (Float16, Float32, Float64) - @test @inferred(paramfloattype(similar(member, F))) === F - end - end -end - -@testitem "vague" begin - include("./exponential_family_setuptests.jl") - - @test @inferred(vague(ExponentialFamilyDistribution{ArbitraryDistributionFromExponentialFamily})) isa - ExponentialFamilyDistribution{ArbitraryDistributionFromExponentialFamily} - - @test @inferred(vague(ExponentialFamilyDistribution{ArbitraryConditionedDistributionFromExponentialFamily})) isa - ExponentialFamilyDistribution{ArbitraryConditionedDistributionFromExponentialFamily} -end - -================ -File: runtests.jl -================ -using Aqua, CpuId, ReTestItems, ExponentialFamily - -# `ambiguities = false` - there are quite some ambiguities, but these should be normal and should not be encountered under normal circumstances -# `piracies = false` - we extend/add some of the methods to the objects defined in the Distributions.jl -Aqua.test_all(ExponentialFamily, ambiguities = false, deps_compat = (; check_extras = false, check_weakdeps = true), piracies = false) - -nthreads = max(cputhreads(), 1) -ncores = max(cpucores(), 1) - -runtests(ExponentialFamily, - nworkers = ncores, - nworker_threads = Int(nthreads / ncores), - memory_threshold = 1.0 -) From 575c4afc6038da949b8d36922dd9806483b136f5 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Thu, 26 Sep 2024 09:41:00 +0200 Subject: [PATCH 22/32] Update test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl Co-authored-by: Bagaev Dmitry --- .../normal_family/mv_normal_mean_scale_precision_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index f3d695c2..18b4681b 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -209,7 +209,7 @@ end @test_opt cholinv(fi_full) cholinv_time_small = @elapsed cholinv(fi_small) - cholinv_alloc_small = @allocated fisherinformation(ef_small) + cholinv_alloc_small = @allocated cholinv(ef_small) cholinv_time_full = @elapsed cholinv(fi_full) cholinv_alloc_full = @allocated cholinv(fi_full) From 2178b10d8cb7fc9114141e007dfab5fb85b75928 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Thu, 26 Sep 2024 09:57:19 +0200 Subject: [PATCH 23/32] test(fix): typo in @allocated cholinv(fi_small) --- .../normal_family/mv_normal_mean_scale_precision.jl | 1 - .../normal_family/mv_normal_mean_scale_precision_tests.jl | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index fb7231d6..44f74558 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -248,7 +248,6 @@ getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision η1_part = -inv(2*η2)* I(length(η1)) η1η2 = zeros(k, 1) η1η2 .= η1*inv(2*η2^2) - #η₁/(2abs2(η₂)) η2_part = zeros(1, 1) η2_part .= k*inv(2abs2(η2)) - dot(η1,η1) / (2*η2^3) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index 18b4681b..d8f43f9c 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -209,14 +209,11 @@ end @test_opt cholinv(fi_full) cholinv_time_small = @elapsed cholinv(fi_small) - cholinv_alloc_small = @allocated cholinv(ef_small) + cholinv_alloc_small = @allocated cholinv(fi_small) cholinv_time_full = @elapsed cholinv(fi_full) cholinv_alloc_full = @allocated cholinv(fi_full) - fi_small = fisherinformation(ef_small) - fi_full = fisherinformation(ef_full) - # small time is supposed to be O(k) and full time is supposed to O(k^2) # the constant C is selected to account to fluctuations in test runs C = 0.9 From 2a5db1346e4dd45743960927abbe8df1e895521b Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 11 Oct 2024 14:26:21 +0200 Subject: [PATCH 24/32] fix: MvNormalMeanScalePrecision should be faster from 10 dimensions --- .../normal_family/mv_normal_mean_scale_precision_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index d8f43f9c..222a004d 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -185,7 +185,7 @@ end using JET rng = StableRNG(42) - for k in 20:40 + for k in 10:40 μ = randn(rng, k) γ = rand(rng) cov = γ * I(k) From 7a00b4069ceccbcdafe161642b399a4f8c71dd84 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 11 Oct 2024 14:26:50 +0200 Subject: [PATCH 25/32] fix: bump BayesBase 1.4.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ffe0e999..597e5025 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,7 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" [compat] Aqua = "0.8.7" -BayesBase = "1.2" +BayesBase = "1.4.0" BlockArrays = "1.1.1" Distributions = "0.25" DomainSets = "0.5.2, 0.6, 0.7" From 148d14033d9eaaab553e12f5f9bdb2fdf10ae49f Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 11 Oct 2024 14:40:34 +0200 Subject: [PATCH 26/32] refactor: mean param fisher for MvNormalMeanScalePrecision --- .../mv_normal_mean_scale_precision.jl | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index 44f74558..ca98da55 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -269,20 +269,8 @@ getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) γ = θ[end] k = length(μ) - μ_part = γ * I(k) - - μγ_part = zeros(k, 1) - μγ_part .= 0 - - γ_part = zeros(1, 1) - γ_part .= k*inv(2abs2(γ)) - - fisher = BlockArray{eltype(θ)}(undef_blocks, [k, 1], [k, 1]) - - fisher[Block(1), Block(1)] = μ_part - fisher[Block(1), Block(2)] = μγ_part - fisher[Block(2), Block(1)] = μγ_part' - fisher[Block(2), Block(2)] = γ_part - - return fisher + matrix = zeros(eltype(μ), (k+1)) + matrix[1:k] .= γ + matrix[k+1] = k*inv(2abs2(γ)) + return Diagonal(matrix) end From 99902803524b7e10c11658a16281a561f7cfdc7f Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Fri, 11 Oct 2024 14:56:44 +0200 Subject: [PATCH 27/32] fix: remove BlockArrays --- Project.toml | 2 -- .../normal_family/mv_normal_mean_scale_precision.jl | 13 ++----------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 597e5025..cdfaaf51 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "1.5.1" [deps] BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e" -BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756" @@ -30,7 +29,6 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" [compat] Aqua = "0.8.7" BayesBase = "1.4.0" -BlockArrays = "1.1.1" Distributions = "0.25" DomainSets = "0.5.2, 0.6, 0.7" FastCholesky = "1.0" diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index ca98da55..f4c65037 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -3,7 +3,6 @@ export MvNormalMeanScalePrecision, MvGaussianMeanScalePrecision import Distributions: logdetcov, distrname, sqmahal, sqmahal!, AbstractMvNormal import LinearAlgebra: diag, Diagonal, dot import Base: ndims, precision, length, size, prod -import BlockArrays: Block, BlockArray, undef_blocks """ MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal @@ -249,17 +248,9 @@ getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision η1η2 = zeros(k, 1) η1η2 .= η1*inv(2*η2^2) - η2_part = zeros(1, 1) - η2_part .= k*inv(2abs2(η2)) - dot(η1,η1) / (2*η2^3) - # inv(2abs2(η₂))-abs2(η₁)/(2(η₂^3)) - - fisher = BlockArray{eltype(η)}(undef_blocks, [k, 1], [k, 1]) + η2_part = k*inv(2abs2(η2)) - dot(η1,η1) / (2*η2^3) - fisher[Block(1), Block(1)] = η1_part - fisher[Block(1), Block(2)] = η1η2 - fisher[Block(2), Block(1)] = η1η2' - fisher[Block(2), Block(2)] = η2_part - return fisher + return ArrowheadMatrix(η2_part, η1η2, diag(η1_part)) end From 98f2343641aff2d7b23ef73cfecf8232b0a1dd27 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Tue, 22 Oct 2024 11:43:43 +0200 Subject: [PATCH 28/32] fix: update BayesBase 1.5.0 --- Project.toml | 2 +- .../normal_family/mv_normal_mean_scale_precision_tests.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index cdfaaf51..ec6942bf 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,7 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" [compat] Aqua = "0.8.7" -BayesBase = "1.4.0" +BayesBase = "1.5.0" Distributions = "0.25" DomainSets = "0.5.2, 0.6, 0.7" FastCholesky = "1.0" diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index 222a004d..1b64f29e 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -181,6 +181,7 @@ end @testitem "MvNormalMeanScalePrecision: Fisher is faster then for full parametrization" begin include("./normal_family_setuptests.jl") using BenchmarkTools + using FastCholesky using LinearAlgebra using JET From fc1103d9b597e2e7c5a7ba9ce45e597e76399f70 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Tue, 22 Oct 2024 12:56:00 +0200 Subject: [PATCH 29/32] fix: use rand! in rand for MvGaussianMeanScalePrecision --- .../normal_family/mv_normal_mean_scale_precision.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl index f4c65037..07aa8a69 100644 --- a/src/distributions/normal_family/mv_normal_mean_scale_precision.jl +++ b/src/distributions/normal_family/mv_normal_mean_scale_precision.jl @@ -183,8 +183,9 @@ function BayesBase.prod( end function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}) where {T} - μ, γ = mean(dist), scale(dist) - return μ + 1 / γ .* randn(rng, T, length(μ)) + μ, γ = params(dist) + d = length(μ) + return rand!(rng, dist, Vector{T}(undef, d)) end function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}, size::Int64) where {T} From 9bd15a97eec8a41f5f3000210e9b79927ee6a7e8 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Tue, 22 Oct 2024 13:10:06 +0200 Subject: [PATCH 30/32] fix: make C=0.7 for Fisher is faster test --- .../normal_family/mv_normal_mean_scale_precision_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index 1b64f29e..7b3b789b 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -217,7 +217,7 @@ end # small time is supposed to be O(k) and full time is supposed to O(k^2) # the constant C is selected to account to fluctuations in test runs - C = 0.9 + C = 0.7 @test fi_mvsp_time < fi_full_time/(C*k) @test fi_mvsp_alloc < fi_full_alloc/(C*k) @test cholinv_time_small < cholinv_time_full/(C*k) From 809dc6760fa261de1ea2a97755a03e73ec08d937 Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Tue, 22 Oct 2024 13:56:01 +0200 Subject: [PATCH 31/32] test(fix): use benchmark --- .../normal_family/mv_normal_mean_scale_precision_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index 7b3b789b..1a2fa956 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -200,10 +200,10 @@ end @test_opt fisherinformation(ef_small) @test_opt fisherinformation(ef_full) - fi_mvsp_time = @elapsed fisherinformation(ef_small) + fi_mvsp_time = min((@benchmark fisherinformation($ef_small)).times...) fi_mvsp_alloc = @allocated fisherinformation(ef_small) - fi_full_time = @elapsed fisherinformation(ef_full) + fi_full_time = min((@benchmark fisherinformation($ef_full)).times...) fi_full_alloc = @allocated fisherinformation(ef_full) @test_opt cholinv(fi_small) From 24108bbd7eeeaabcfd82aa560ffe9b18bd856fea Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Tue, 22 Oct 2024 14:53:30 +0200 Subject: [PATCH 32/32] Change nr of samples belapsed and # dimensions --- .../mv_normal_mean_scale_precision_tests.jl | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl index 1a2fa956..2c9dd1c3 100644 --- a/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl +++ b/test/distributions/normal_family/mv_normal_mean_scale_precision_tests.jl @@ -34,7 +34,7 @@ end for s in 1:6 μ = randn(rng, s) γ = rand(rng) - + @testset let d = MvNormalMeanScalePrecision(μ, γ) ef = test_exponentialfamily_interface(d;) end @@ -45,10 +45,10 @@ end d = MvNormalMeanScalePrecision(μ, γ) ef = convert(ExponentialFamilyDistribution, d) - + d1d = NormalMeanPrecision(μ[1], γ) ef1d = convert(ExponentialFamilyDistribution, d1d) - + @test logpartition(ef) ≈ logpartition(ef1d) @test gradlogpartition(ef) ≈ gradlogpartition(ef1d) @test fisherinformation(ef) ≈ fisherinformation(ef1d) @@ -186,7 +186,7 @@ end using JET rng = StableRNG(42) - for k in 10:40 + for k in 10:5:40 μ = randn(rng, k) γ = rand(rng) cov = γ * I(k) @@ -200,27 +200,27 @@ end @test_opt fisherinformation(ef_small) @test_opt fisherinformation(ef_full) - fi_mvsp_time = min((@benchmark fisherinformation($ef_small)).times...) + fi_mvsp_time = @elapsed fisherinformation(ef_small) fi_mvsp_alloc = @allocated fisherinformation(ef_small) - fi_full_time = min((@benchmark fisherinformation($ef_full)).times...) + fi_full_time = @elapsed fisherinformation(ef_full) fi_full_alloc = @allocated fisherinformation(ef_full) @test_opt cholinv(fi_small) @test_opt cholinv(fi_full) - cholinv_time_small = @elapsed cholinv(fi_small) + cholinv_time_small = @belapsed cholinv($fi_small) samples = 3 cholinv_alloc_small = @allocated cholinv(fi_small) - cholinv_time_full = @elapsed cholinv(fi_full) + cholinv_time_full = @belapsed cholinv($fi_full) samples = 3 cholinv_alloc_full = @allocated cholinv(fi_full) # small time is supposed to be O(k) and full time is supposed to O(k^2) # the constant C is selected to account to fluctuations in test runs C = 0.7 - @test fi_mvsp_time < fi_full_time/(C*k) - @test fi_mvsp_alloc < fi_full_alloc/(C*k) - @test cholinv_time_small < cholinv_time_full/(C*k) - @test cholinv_alloc_small < cholinv_alloc_full/(C*k) + @test fi_mvsp_time < fi_full_time / (C * k) + @test fi_mvsp_alloc < fi_full_alloc / (C * k) + @test cholinv_time_small < cholinv_time_full / (C * k) + @test cholinv_alloc_small < cholinv_alloc_full / (C * k) end end \ No newline at end of file