Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add SinglePointManifold #16

Merged
merged 4 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ExponentialFamilyManifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using BayesBase, ExponentialFamily, ManifoldsBase, Manifolds, Random, LinearAlge
include("symmetric_negative_definite.jl")
include("shifted_negative_numbers.jl")
include("shifted_positive_numbers.jl")
include("SinglePointManifold.jl")
include("natural_manifolds.jl")

include("natural_manifolds/bernoulli.jl")
Expand Down
68 changes: 68 additions & 0 deletions src/SinglePointManifold.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
using ManifoldsBase
using Random

"""
SymmetricNegativeDefinite(point)

This manifold represents a set from one point.
"""
struct SinglePointManifold{T, R} <: AbstractManifold{ℝ}
point::T
representation_size::R
end

function SinglePointManifold(point::T) where {T}
return SinglePointManifold(point, size(point))
end

function Base.show(io::IO, M::SinglePointManifold)
print(io, "SinglePointManifold(", M.point, ")")
end

ManifoldsBase.manifold_dimension(::SinglePointManifold) = 0
ManifoldsBase.representation_size(M::SinglePointManifold) = M.representation_size
ManifoldsBase.injectivity_radius(M::SinglePointManifold) = zero(eltype(M.point))

ManifoldsBase.default_retraction_method(::SinglePointManifold) = ExponentialRetraction()

function ManifoldsBase.check_point(M::SinglePointManifold, p; kwargs...)
if p != M.point
return DomainError(p, "The point $(p) does not lie on $(M), which contains only $(M.point).")
end
return nothing
end

function ManifoldsBase.check_vector(M::SinglePointManifold, p, X; kwargs...)
if !iszero(X) && size(M.point) == size(X)
return DomainError(X, "The tangent space of $(M) contains only the zero vector.")
end
return nothing
end

ManifoldsBase.is_flat(::SinglePointManifold) = true

ManifoldsBase.embed(::SinglePointManifold, p) = p
ManifoldsBase.embed(::SinglePointManifold, p, X) = X

function ManifoldsBase.inner(::SinglePointManifold, p, X, Y)
return zero(eltype(X))
end

function ManifoldsBase.exp!(M::SinglePointManifold, q, p, X, t::Number=1)
q .= M.point
return q
end

function ManifoldsBase.log!(::SinglePointManifold, X, p, q)
X .= zero(eltype(X))
return X
end

function ManifoldsBase.project!(::SinglePointManifold, Y, p, X)
fill!(Y, zero(eltype(Y)))
return Y
end

function ManifoldsBase.zero_vector!(::SinglePointManifold, X, p)
return fill!(X, zero(eltype(X)))
end
6 changes: 4 additions & 2 deletions src/natural_manifolds/categorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
Get the natural manifold base for the `Categorical` distribution.
"""
function get_natural_manifold_base(::Type{Categorical}, ::Tuple{}, conditioner=nothing)
return Euclidean(conditioner)
return ProductManifold(
Euclidean(conditioner-1), SinglePointManifold(0)
)
end

"""
Expand All @@ -15,5 +17,5 @@ Converts the `point` to a compatible representation for the natural manifold of
"""
function partition_point(::Type{Categorical}, ::Tuple{}, p, conditioner=nothing)
# See comment in `get_natural_manifold_base` for `Categorical`
return ArrayPartition(p)
return ArrayPartition(p[1:end-1], p[end])
end
62 changes: 62 additions & 0 deletions test/single_point_manifold_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
@testitem "Generic properties of SinglePointManifold" begin
import ManifoldsBase: check_point, check_vector, representation_size, injectivity_radius, get_embedding, is_flat, inner, manifold_dimension
import ExponentialFamilyManifolds: SinglePointManifold
using ManifoldsBase, Static, StaticArrays, JET, Manifolds

points = [
0,
0.0,
0.0f0,
1,
1.0,
1.0f0,
-1,
2,
π,
rand(),
randn()
]

for p in points
M = SinglePointManifold(p)

@test repr(M) == "SinglePointManifold($p)"

@test @inferred(representation_size(M)) === ()
@test @inferred(manifold_dimension(M)) === 0
@test @inferred(is_flat(M)) === true
@test injectivity_radius(M) ≈ 0

@test_throws MethodError get_embedding(M)

@test check_point(M, p) === nothing
@test check_point(M, p + 1) isa DomainError
@test check_point(M, p - 1) isa DomainError

@test check_vector(M, p, 0) === nothing
@test check_vector(M, p, 1) isa DomainError
@test check_vector(M, p, -1) isa DomainError

@test @eval(@allocated(representation_size($M))) === 0
@test @eval(@allocated(manifold_dimension($M))) === 0
@test @eval(@allocated(is_flat($M))) === 0

X = [1]
Y = [1]

@test_opt inner(M, p, X, Y)
@test_opt inner(M, p, 0, 0)
end

vector_points = [[1], [1, 2], [1, 2, 3]]

for p in vector_points
M = SinglePointManifold(p)
q = similar(p)
X = zero_vector(M, p)
@test ManifoldsBase.exp!(M, q, p, X) == p
@test ManifoldsBase.log!(M, X, p, p) == zero_vector(M, p)
@test ManifoldsBase.log(M, p, p) == zero_vector(M, p)
@test ManifoldsBase.project!(M, similar(X), p, similar(X)) == zero_vector(M, p)
end
end
Loading