diff --git a/src/ExponentialFamilyManifolds.jl b/src/ExponentialFamilyManifolds.jl index 0079dce..1d4f181 100644 --- a/src/ExponentialFamilyManifolds.jl +++ b/src/ExponentialFamilyManifolds.jl @@ -5,6 +5,7 @@ using BayesBase, ExponentialFamily, ManifoldsBase, Manifolds, Random, LinearAlge include("symmetric_negative_definite.jl") include("shifted_negative_numbers.jl") include("shifted_positive_numbers.jl") +include("SinglePointManifold.jl") include("natural_manifolds.jl") include("natural_manifolds/bernoulli.jl") diff --git a/src/SinglePointManifold.jl b/src/SinglePointManifold.jl new file mode 100644 index 0000000..9386b8e --- /dev/null +++ b/src/SinglePointManifold.jl @@ -0,0 +1,68 @@ +using ManifoldsBase +using Random + +""" + SymmetricNegativeDefinite(point) + +This manifold represents a set from one point. +""" +struct SinglePointManifold{T, R} <: AbstractManifold{ℝ} + point::T + representation_size::R +end + +function SinglePointManifold(point::T) where {T} + return SinglePointManifold(point, size(point)) +end + +function Base.show(io::IO, M::SinglePointManifold) + print(io, "SinglePointManifold(", M.point, ")") +end + +ManifoldsBase.manifold_dimension(::SinglePointManifold) = 0 +ManifoldsBase.representation_size(M::SinglePointManifold) = M.representation_size +ManifoldsBase.injectivity_radius(M::SinglePointManifold) = zero(eltype(M.point)) + +ManifoldsBase.default_retraction_method(::SinglePointManifold) = ExponentialRetraction() + +function ManifoldsBase.check_point(M::SinglePointManifold, p; kwargs...) + if p != M.point + return DomainError(p, "The point $(p) does not lie on $(M), which contains only $(M.point).") + end + return nothing +end + +function ManifoldsBase.check_vector(M::SinglePointManifold, p, X; kwargs...) + if !iszero(X) && size(M.point) == size(X) + return DomainError(X, "The tangent space of $(M) contains only the zero vector.") + end + return nothing +end + +ManifoldsBase.is_flat(::SinglePointManifold) = true + +ManifoldsBase.embed(::SinglePointManifold, p) = p +ManifoldsBase.embed(::SinglePointManifold, p, X) = X + +function ManifoldsBase.inner(::SinglePointManifold, p, X, Y) + return zero(eltype(X)) +end + +function ManifoldsBase.exp!(M::SinglePointManifold, q, p, X, t::Number=1) + q .= M.point + return q +end + +function ManifoldsBase.log!(::SinglePointManifold, X, p, q) + X .= zero(eltype(X)) + return X +end + +function ManifoldsBase.project!(::SinglePointManifold, Y, p, X) + fill!(Y, zero(eltype(Y))) + return Y +end + +function ManifoldsBase.zero_vector!(::SinglePointManifold, X, p) + return fill!(X, zero(eltype(X))) +end \ No newline at end of file diff --git a/src/natural_manifolds/categorical.jl b/src/natural_manifolds/categorical.jl index 583f323..282af7c 100644 --- a/src/natural_manifolds/categorical.jl +++ b/src/natural_manifolds/categorical.jl @@ -5,7 +5,9 @@ Get the natural manifold base for the `Categorical` distribution. """ function get_natural_manifold_base(::Type{Categorical}, ::Tuple{}, conditioner=nothing) - return Euclidean(conditioner) + return ProductManifold( + Euclidean(conditioner-1), SinglePointManifold(0) + ) end """ @@ -15,5 +17,5 @@ Converts the `point` to a compatible representation for the natural manifold of """ function partition_point(::Type{Categorical}, ::Tuple{}, p, conditioner=nothing) # See comment in `get_natural_manifold_base` for `Categorical` - return ArrayPartition(p) + return ArrayPartition(p[1:end-1], p[end]) end \ No newline at end of file diff --git a/test/single_point_manifold_tests.jl b/test/single_point_manifold_tests.jl new file mode 100644 index 0000000..f6647c3 --- /dev/null +++ b/test/single_point_manifold_tests.jl @@ -0,0 +1,62 @@ +@testitem "Generic properties of SinglePointManifold" begin + import ManifoldsBase: check_point, check_vector, representation_size, injectivity_radius, get_embedding, is_flat, inner, manifold_dimension + import ExponentialFamilyManifolds: SinglePointManifold + using ManifoldsBase, Static, StaticArrays, JET, Manifolds + + points = [ + 0, + 0.0, + 0.0f0, + 1, + 1.0, + 1.0f0, + -1, + 2, + π, + rand(), + randn() + ] + + for p in points + M = SinglePointManifold(p) + + @test repr(M) == "SinglePointManifold($p)" + + @test @inferred(representation_size(M)) === () + @test @inferred(manifold_dimension(M)) === 0 + @test @inferred(is_flat(M)) === true + @test injectivity_radius(M) ≈ 0 + + @test_throws MethodError get_embedding(M) + + @test check_point(M, p) === nothing + @test check_point(M, p + 1) isa DomainError + @test check_point(M, p - 1) isa DomainError + + @test check_vector(M, p, 0) === nothing + @test check_vector(M, p, 1) isa DomainError + @test check_vector(M, p, -1) isa DomainError + + @test @eval(@allocated(representation_size($M))) === 0 + @test @eval(@allocated(manifold_dimension($M))) === 0 + @test @eval(@allocated(is_flat($M))) === 0 + + X = [1] + Y = [1] + + @test_opt inner(M, p, X, Y) + @test_opt inner(M, p, 0, 0) + end + + vector_points = [[1], [1, 2], [1, 2, 3]] + + for p in vector_points + M = SinglePointManifold(p) + q = similar(p) + X = zero_vector(M, p) + @test ManifoldsBase.exp!(M, q, p, X) == p + @test ManifoldsBase.log!(M, X, p, p) == zero_vector(M, p) + @test ManifoldsBase.log(M, p, p) == zero_vector(M, p) + @test ManifoldsBase.project!(M, similar(X), p, similar(X)) == zero_vector(M, p) + end +end \ No newline at end of file