Skip to content

Commit

Permalink
add tests for manifolds
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Jun 10, 2024
1 parent 2337aad commit 084f36e
Show file tree
Hide file tree
Showing 15 changed files with 130 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/natural_manifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ function get_natural_manifold(::Type{T}, dims, conditioner = nothing) where {T}
)
end

function get_natural_manifold(ef::ExponentialFamilyDistribution)
return get_natural_manifold(exponential_family_typetag(ef), getconditioner(ef))
end

"""
get_natural_manifold_base(::Type{T}, conditioner = nothing)
Expand Down
6 changes: 4 additions & 2 deletions src/natural_manifolds/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ function get_natural_manifold_base(
dims::Tuple{Int},
conditioner = nothing,
)
# `ProductManifold` here is important to treat the `PowerManifold` as a vector, and not matrix
# `PowerManifold` does treat the vector as a matrix with one row
# In the `parition_point` we transpose the vector and use `ArrayPartition` for `ProductManifold`
return ProductManifold(PowerManifold(ShiftedPositiveNumbers(static(-1)), first(dims)))
end

function partition_point(::Type{Dirichlet}, dims::Tuple{Int}, p, conditioner = nothing)
return p
# See comment in `get_natural_manifold_base` for `Dirichlet`
return ArrayPartition(p')
end
2 changes: 1 addition & 1 deletion src/natural_manifolds/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ function get_natural_manifold_base(::Type{LogNormal}, ::Tuple{}, conditioner = n
return ProductManifold(Euclidean(1), ShiftedNegativeNumbers(static(0)))
end

function partition_point(::Type{LogNormal}, p, conditioner = nothing)
function partition_point(::Type{LogNormal}, ::Tuple{}, p, conditioner = nothing)
return ArrayPartition(view(p, 1:1), view(p, 2:2))
end
2 changes: 1 addition & 1 deletion src/natural_manifolds/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function get_natural_manifold_base(
return ProductManifold(Euclidean(1), ShiftedNegativeNumbers(static(0)))
end

function partition_point(::Type{NormalMeanVariance}, p, conditioner = nothing)
function partition_point(::Type{NormalMeanVariance}, ::Tuple{}, p, conditioner = nothing)
return ArrayPartition(view(p, 1:1), view(p, 2:2))
end

Expand Down
9 changes: 9 additions & 0 deletions test/natural_manifolds/bernoulli_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testitem "Check `Bernoulli` natural manifold" begin

include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
return Bernoulli(rand(rng))
end

end
9 changes: 9 additions & 0 deletions test/natural_manifolds/beta_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testitem "Check `Beta` natural manifold" begin

include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
return Beta(10rand(rng), 10rand(rng))
end

end
9 changes: 9 additions & 0 deletions test/natural_manifolds/chisq_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testitem "Check `Chisq` natural manifold" begin

include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
return Chisq(1 + 10rand(rng))
end

end
10 changes: 10 additions & 0 deletions test/natural_manifolds/dirichlet_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@testitem "Check `Dirichlet` natural manifold" begin

include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
k = rand(rng, 2:10)
return Dirichlet(10rand(rng, k))
end

end
9 changes: 9 additions & 0 deletions test/natural_manifolds/exponential_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testitem "Check `Exponential` natural manifold" begin

include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
return Exponential(100rand(rng))
end

end
7 changes: 7 additions & 0 deletions test/natural_manifolds/gamma_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
@testitem "Check `Gamma` natural manifold" begin
include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
return Gamma(10rand(rng), 10rand(rng))
end
end
9 changes: 9 additions & 0 deletions test/natural_manifolds/geometric_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testitem "Check `Geometric` natural manifold" begin

include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
return Geometric(rand(rng))
end

end
9 changes: 9 additions & 0 deletions test/natural_manifolds/laplace_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testitem "Check `Laplace` natural manifold" begin

include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
return Laplace(10rand(rng), 10rand(rng))
end

end
9 changes: 9 additions & 0 deletions test/natural_manifolds/lognormal_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testitem "Check `LogNormal` natural manifold" begin

include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
return LogNormal(10randn(rng), 10rand(rng))
end

end
21 changes: 21 additions & 0 deletions test/natural_manifolds/natural_manifolds_setuptests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

using StableRNGs, ExponentialFamily, ManifoldsBase, LinearAlgebra

import ExponentialFamilyManifolds: get_natural_manifold, partition_point

function test_natural_manifold(f; seed=42, ndistributions=100)
rng = StableRNG(seed)

foreach(1:ndistributions) do _
distribution = f(rng)
sample = rand(rng, distribution)
dims = size(sample)

ef = convert(ExponentialFamilyDistribution, distribution)
T = ExponentialFamily.exponential_family_typetag(ef)
M = get_natural_manifold(T, dims, getconditioner(ef))
η = partition_point(T, dims, getnaturalparameters(ef), getconditioner(ef))

@test is_point(M, η, error=:error)
end
end
19 changes: 19 additions & 0 deletions test/natural_manifolds/normal_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
@testitem "Check `Normal` natural manifold" begin
include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
return NormalMeanVariance(10randn(rng), 10rand(rng))
end
end

@testitem "Check `MvNormal` natural manifold" begin
include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
k = rand(rng, 1:10)
m = randn(k)
L = LowerTriangular(randn(k, k))
C = L * L' + k * I
return MvNormalMeanCovariance(m, C)
end
end

0 comments on commit 084f36e

Please sign in to comment.