-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from ReactiveBayes/categorical_manifold
Add manifold for Categorical distribution
- Loading branch information
Showing
7 changed files
with
248 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
|
||
""" | ||
get_natural_manifold_base(::Type{Categorical}, dims::Tuple{Int}, conditioner=nothing) | ||
Get the natural manifold base for the `Categorical` distribution. | ||
""" | ||
function get_natural_manifold_base(::Type{Categorical}, ::Tuple{}, conditioner=nothing) | ||
return ProductManifold( | ||
Euclidean(conditioner-1), SinglePointManifold([0]) | ||
) | ||
end | ||
|
||
""" | ||
partition_point(::Type{Categorical}, dims::Tuple{Int}, p, conditioner=nothing) | ||
Converts the `point` to a compatible representation for the natural manifold of type `Categorical`. | ||
""" | ||
function partition_point(::Type{Categorical}, ::Tuple{}, p, conditioner=nothing) | ||
return ArrayPartition(view(p, 1:conditioner-1), view(p, conditioner:conditioner)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
using ManifoldsBase | ||
using Random | ||
|
||
""" | ||
SinglePointManifold(point) | ||
This manifold represents a set from one point. | ||
""" | ||
struct SinglePointManifold{T, R} <: AbstractManifold{ℝ} | ||
point::T | ||
representation_size::R | ||
end | ||
|
||
function SinglePointManifold(point::T) where {T} | ||
return SinglePointManifold(point, size(point)) | ||
end | ||
|
||
function Base.show(io::IO, M::SinglePointManifold) | ||
print(io, "SinglePointManifold(", M.point, ")") | ||
end | ||
|
||
ManifoldsBase.manifold_dimension(::SinglePointManifold) = 0 | ||
ManifoldsBase.representation_size(M::SinglePointManifold) = M.representation_size | ||
ManifoldsBase.injectivity_radius(M::SinglePointManifold) = zero(eltype(M.point)) | ||
|
||
ManifoldsBase.default_retraction_method(::SinglePointManifold) = ExponentialRetraction() | ||
|
||
function ManifoldsBase.check_point(M::SinglePointManifold, p; kwargs...) | ||
if !(p ≈ M.point) | ||
return DomainError(p, "The point $(p) does not lie on $(M), which contains only $(M.point).") | ||
end | ||
return nothing | ||
end | ||
|
||
function ManifoldsBase.check_vector(M::SinglePointManifold, p, X; kwargs...) | ||
if !iszero(X) && size(M.point) == size(X) | ||
return DomainError(X, "The tangent space of $(M) contains only the zero vector.") | ||
end | ||
return nothing | ||
end | ||
|
||
ManifoldsBase.is_flat(::SinglePointManifold) = true | ||
|
||
ManifoldsBase.embed(::SinglePointManifold, p) = p | ||
ManifoldsBase.embed(::SinglePointManifold, p, X) = X | ||
|
||
function ManifoldsBase.inner(::SinglePointManifold, p, X, Y) | ||
return zero(eltype(X)) | ||
end | ||
|
||
function ManifoldsBase.exp!(M::SinglePointManifold, q, p, X, t::Number=1) | ||
q .= M.point | ||
return q | ||
end | ||
|
||
function ManifoldsBase.log!(::SinglePointManifold, X, p, q) | ||
X .= zero(eltype(X)) | ||
return X | ||
end | ||
|
||
function ManifoldsBase.project!(::SinglePointManifold, Y, p, X) | ||
fill!(Y, zero(eltype(Y))) | ||
return Y | ||
end | ||
|
||
function ManifoldsBase.zero_vector!(::SinglePointManifold, X, p) | ||
return fill!(X, zero(eltype(X))) | ||
end | ||
|
||
function Random.rand(M::SinglePointManifold; kwargs...) | ||
return rand(Random.default_rng(), M; kwargs...) | ||
end | ||
|
||
function Random.rand(rng::AbstractRNG, M::SinglePointManifold; kwargs...) | ||
return M.point | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
@testitem "Check `Categorical` natural manifold" begin | ||
include("natural_manifolds_setuptests.jl") | ||
|
||
test_natural_manifold() do rng | ||
p = rand(rng, 10) | ||
normalize!(p, 1) | ||
return Categorical(p) | ||
end | ||
end | ||
|
||
@testitem "Check that optimization work on Categorical" begin | ||
include("natural_manifolds_setuptests.jl") | ||
|
||
using Manopt, ForwardDiff | ||
using BayesBase | ||
|
||
rng = StableRNG(42) | ||
p = rand(StableRNG(42), 10) | ||
normalize!(p, 1) | ||
distribution = Categorical(p) | ||
sample = rand(rng, distribution) | ||
dims = size(sample) | ||
ef = convert(ExponentialFamilyDistribution, distribution) | ||
T = ExponentialFamily.exponential_family_typetag(ef) | ||
M = get_natural_manifold(T, dims, getconditioner(ef)) | ||
|
||
function f(M, p) | ||
ef = convert(ExponentialFamilyDistribution, M, p) | ||
η = getnaturalparameters(ef) | ||
return (mean(η) - 0.5)^2 | ||
end | ||
|
||
function g(M, p) | ||
return project(M, p, 2 * p ./ 10) | ||
end | ||
|
||
q = gradient_descent(M, f, g, rand(rng, M)) | ||
@test q ∈ M | ||
@test mean(q) ≈ 0.5 atol = 1e-1 | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
@testitem "Generic properties of SinglePointManifold" begin | ||
import ManifoldsBase: check_point, check_vector, embed, representation_size, injectivity_radius, get_embedding, is_flat, inner, manifold_dimension | ||
import ExponentialFamilyManifolds: SinglePointManifold | ||
using ManifoldsBase, Static, StaticArrays, JET, Manifolds | ||
using StableRNGs | ||
using Random | ||
|
||
rng = StableRNG(42) | ||
|
||
|
||
points = [ | ||
0, | ||
0.0, | ||
0.0f0, | ||
1, | ||
1.0, | ||
1.0f0, | ||
-1, | ||
2, | ||
π, | ||
rand(), | ||
randn() | ||
] | ||
|
||
for p in points | ||
M = SinglePointManifold(p) | ||
|
||
@test repr(M) == "SinglePointManifold($p)" | ||
|
||
@test @inferred(representation_size(M)) === () | ||
@test @inferred(manifold_dimension(M)) === 0 | ||
@test @inferred(is_flat(M)) === true | ||
@test injectivity_radius(M) ≈ 0 | ||
@test default_retraction_method(M) == ExponentialRetraction() | ||
|
||
@test_throws MethodError get_embedding(M) | ||
|
||
@test check_point(M, p) === nothing | ||
@test check_point(M, p + 1) isa DomainError | ||
@test check_point(M, p - 1) isa DomainError | ||
|
||
@test check_vector(M, p, 0) === nothing | ||
@test check_vector(M, p, 1) isa DomainError | ||
@test check_vector(M, p, -1) isa DomainError | ||
|
||
@test @eval(@allocated(representation_size($M))) === 0 | ||
@test @eval(@allocated(manifold_dimension($M))) === 0 | ||
@test @eval(@allocated(is_flat($M))) === 0 | ||
|
||
X = [1] | ||
Y = [1] | ||
|
||
@test_opt inner(M, p, X, Y) | ||
@test_opt inner(M, p, 0, 0) | ||
|
||
@test embed(M, p) == p | ||
@test embed(M, p, 0) == 0 | ||
@test inner(M, p, 0, 0) == 0 | ||
end | ||
|
||
vector_points = [[1], [1, 2], [1, 2, 3]] | ||
|
||
for p in vector_points | ||
M = SinglePointManifold(p) | ||
q = similar(p) | ||
X = zero_vector(M, p) | ||
@test ManifoldsBase.exp!(M, q, p, X) == p | ||
@test ManifoldsBase.log!(M, X, p, p) == zero_vector(M, p) | ||
@test ManifoldsBase.log(M, p, p) == zero_vector(M, p) | ||
@test ManifoldsBase.project!(M, similar(X), p, similar(X)) == zero_vector(M, p) | ||
@test rand(rng, M) ∈ M | ||
@test rand(M) ∈ M | ||
end | ||
end | ||
|
||
@testitem "Simple manifold optimization problem #1" begin | ||
using Manopt, ForwardDiff, Static, StableRNGs, LinearAlgebra | ||
|
||
import ExponentialFamilyManifolds: SinglePointManifold | ||
|
||
for a in (2.0, 3.0), | ||
b in (10.0, 5.0), | ||
c in (1.0, 10.0, -1.0), | ||
eps in (1e-4, 1e-5, 1e-8, 1e-10), | ||
stepsize in (ConstantStepsize(0.1), ConstantStepsize(0.01), ConstantStepsize(0.001)) | ||
|
||
f(M, x) = (a .* x .^ 2 .+ b .* x .+ c)[1] | ||
grad_f(M, x) = 2 .* a .* x .+ b | ||
|
||
rng = StableRNG(42) | ||
|
||
for s in [0, 0.0, 10] | ||
M = SinglePointManifold(s) | ||
p0 = rand(rng, M) | ||
|
||
q1 = gradient_descent( | ||
M, | ||
f, | ||
grad_f, | ||
p0; | ||
stepsize=stepsize, | ||
stopping_criterion=StopAfterIteration(1) | ||
) | ||
|
||
@test q1 ≈ s | ||
end | ||
end | ||
end |