Skip to content

Commit

Permalink
Merge pull request #15 from ReactiveBayes/categorical_manifold
Browse files Browse the repository at this point in the history
Add manifold for Categorical distribution
  • Loading branch information
bvdmitri authored Jul 25, 2024
2 parents 8dedefe + 99b234b commit 61ce1a7
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"

[compat]
BayesBase = "1.3"
ExponentialFamily = "1.4.3"
ExponentialFamily = "1.5.1"
LinearAlgebra = "1.10"
Manifolds = "0.9"
ManifoldsBase = "0.15"
Expand Down
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ ExponentialFamilyManifolds.partition_point
ExponentialFamilyManifolds.ShiftedPositiveNumbers
ExponentialFamilyManifolds.ShiftedNegativeNumbers
ExponentialFamilyManifolds.SymmetricNegativeDefinite
ExponentialFamilyManifolds.SinglePointManifold
```

## Optimization example
Expand Down
2 changes: 2 additions & 0 deletions src/ExponentialFamilyManifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ using BayesBase, ExponentialFamily, ManifoldsBase, Manifolds, Random, LinearAlge
include("symmetric_negative_definite.jl")
include("shifted_negative_numbers.jl")
include("shifted_positive_numbers.jl")
include("single_point_manifold.jl")
include("natural_manifolds.jl")

include("natural_manifolds/bernoulli.jl")
include("natural_manifolds/beta.jl")
include("natural_manifolds/binomial.jl")
include("natural_manifolds/chisq.jl")
include("natural_manifolds/categorical.jl")
include("natural_manifolds/dirichlet.jl")
include("natural_manifolds/exponential.jl")
include("natural_manifolds/gamma.jl")
Expand Down
20 changes: 20 additions & 0 deletions src/natural_manifolds/categorical.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

"""
get_natural_manifold_base(::Type{Categorical}, dims::Tuple{Int}, conditioner=nothing)
Get the natural manifold base for the `Categorical` distribution.
"""
function get_natural_manifold_base(::Type{Categorical}, ::Tuple{}, conditioner=nothing)
return ProductManifold(
Euclidean(conditioner-1), SinglePointManifold([0])
)
end

"""
partition_point(::Type{Categorical}, dims::Tuple{Int}, p, conditioner=nothing)
Converts the `point` to a compatible representation for the natural manifold of type `Categorical`.
"""
function partition_point(::Type{Categorical}, ::Tuple{}, p, conditioner=nothing)
return ArrayPartition(view(p, 1:conditioner-1), view(p, conditioner:conditioner))
end
76 changes: 76 additions & 0 deletions src/single_point_manifold.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using ManifoldsBase
using Random

"""
SinglePointManifold(point)
This manifold represents a set from one point.
"""
struct SinglePointManifold{T, R} <: AbstractManifold{ℝ}
point::T
representation_size::R
end

function SinglePointManifold(point::T) where {T}
return SinglePointManifold(point, size(point))
end

function Base.show(io::IO, M::SinglePointManifold)
print(io, "SinglePointManifold(", M.point, ")")
end

ManifoldsBase.manifold_dimension(::SinglePointManifold) = 0
ManifoldsBase.representation_size(M::SinglePointManifold) = M.representation_size
ManifoldsBase.injectivity_radius(M::SinglePointManifold) = zero(eltype(M.point))

ManifoldsBase.default_retraction_method(::SinglePointManifold) = ExponentialRetraction()

function ManifoldsBase.check_point(M::SinglePointManifold, p; kwargs...)
if !(p M.point)
return DomainError(p, "The point $(p) does not lie on $(M), which contains only $(M.point).")
end
return nothing
end

function ManifoldsBase.check_vector(M::SinglePointManifold, p, X; kwargs...)
if !iszero(X) && size(M.point) == size(X)
return DomainError(X, "The tangent space of $(M) contains only the zero vector.")
end
return nothing
end

ManifoldsBase.is_flat(::SinglePointManifold) = true

ManifoldsBase.embed(::SinglePointManifold, p) = p
ManifoldsBase.embed(::SinglePointManifold, p, X) = X

function ManifoldsBase.inner(::SinglePointManifold, p, X, Y)
return zero(eltype(X))
end

function ManifoldsBase.exp!(M::SinglePointManifold, q, p, X, t::Number=1)
q .= M.point
return q
end

function ManifoldsBase.log!(::SinglePointManifold, X, p, q)
X .= zero(eltype(X))
return X
end

function ManifoldsBase.project!(::SinglePointManifold, Y, p, X)
fill!(Y, zero(eltype(Y)))
return Y
end

function ManifoldsBase.zero_vector!(::SinglePointManifold, X, p)
return fill!(X, zero(eltype(X)))
end

function Random.rand(M::SinglePointManifold; kwargs...)
return rand(Random.default_rng(), M; kwargs...)
end

function Random.rand(rng::AbstractRNG, M::SinglePointManifold; kwargs...)
return M.point
end
40 changes: 40 additions & 0 deletions test/natural_manifolds/categorical_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
@testitem "Check `Categorical` natural manifold" begin
include("natural_manifolds_setuptests.jl")

test_natural_manifold() do rng
p = rand(rng, 10)
normalize!(p, 1)
return Categorical(p)
end
end

@testitem "Check that optimization work on Categorical" begin
include("natural_manifolds_setuptests.jl")

using Manopt, ForwardDiff
using BayesBase

rng = StableRNG(42)
p = rand(StableRNG(42), 10)
normalize!(p, 1)
distribution = Categorical(p)
sample = rand(rng, distribution)
dims = size(sample)
ef = convert(ExponentialFamilyDistribution, distribution)
T = ExponentialFamily.exponential_family_typetag(ef)
M = get_natural_manifold(T, dims, getconditioner(ef))

function f(M, p)
ef = convert(ExponentialFamilyDistribution, M, p)
η = getnaturalparameters(ef)
return (mean(η) - 0.5)^2
end

function g(M, p)
return project(M, p, 2 * p ./ 10)
end

q = gradient_descent(M, f, g, rand(rng, M))
@test q M
@test mean(q) 0.5 atol = 1e-1
end
108 changes: 108 additions & 0 deletions test/single_point_manifold_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
@testitem "Generic properties of SinglePointManifold" begin
import ManifoldsBase: check_point, check_vector, embed, representation_size, injectivity_radius, get_embedding, is_flat, inner, manifold_dimension
import ExponentialFamilyManifolds: SinglePointManifold
using ManifoldsBase, Static, StaticArrays, JET, Manifolds
using StableRNGs
using Random

rng = StableRNG(42)


points = [
0,
0.0,
0.0f0,
1,
1.0,
1.0f0,
-1,
2,
π,
rand(),
randn()
]

for p in points
M = SinglePointManifold(p)

@test repr(M) == "SinglePointManifold($p)"

@test @inferred(representation_size(M)) === ()
@test @inferred(manifold_dimension(M)) === 0
@test @inferred(is_flat(M)) === true
@test injectivity_radius(M) 0
@test default_retraction_method(M) == ExponentialRetraction()

@test_throws MethodError get_embedding(M)

@test check_point(M, p) === nothing
@test check_point(M, p + 1) isa DomainError
@test check_point(M, p - 1) isa DomainError

@test check_vector(M, p, 0) === nothing
@test check_vector(M, p, 1) isa DomainError
@test check_vector(M, p, -1) isa DomainError

@test @eval(@allocated(representation_size($M))) === 0
@test @eval(@allocated(manifold_dimension($M))) === 0
@test @eval(@allocated(is_flat($M))) === 0

X = [1]
Y = [1]

@test_opt inner(M, p, X, Y)
@test_opt inner(M, p, 0, 0)

@test embed(M, p) == p
@test embed(M, p, 0) == 0
@test inner(M, p, 0, 0) == 0
end

vector_points = [[1], [1, 2], [1, 2, 3]]

for p in vector_points
M = SinglePointManifold(p)
q = similar(p)
X = zero_vector(M, p)
@test ManifoldsBase.exp!(M, q, p, X) == p
@test ManifoldsBase.log!(M, X, p, p) == zero_vector(M, p)
@test ManifoldsBase.log(M, p, p) == zero_vector(M, p)
@test ManifoldsBase.project!(M, similar(X), p, similar(X)) == zero_vector(M, p)
@test rand(rng, M) M
@test rand(M) M
end
end

@testitem "Simple manifold optimization problem #1" begin
using Manopt, ForwardDiff, Static, StableRNGs, LinearAlgebra

import ExponentialFamilyManifolds: SinglePointManifold

for a in (2.0, 3.0),
b in (10.0, 5.0),
c in (1.0, 10.0, -1.0),
eps in (1e-4, 1e-5, 1e-8, 1e-10),
stepsize in (ConstantStepsize(0.1), ConstantStepsize(0.01), ConstantStepsize(0.001))

f(M, x) = (a .* x .^ 2 .+ b .* x .+ c)[1]
grad_f(M, x) = 2 .* a .* x .+ b

rng = StableRNG(42)

for s in [0, 0.0, 10]
M = SinglePointManifold(s)
p0 = rand(rng, M)

q1 = gradient_descent(
M,
f,
grad_f,
p0;
stepsize=stepsize,
stopping_criterion=StopAfterIteration(1)
)

@test q1 s
end
end
end

0 comments on commit 61ce1a7

Please sign in to comment.