Skip to content

Commit

Permalink
fix: update ExponetialFamily.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
Nimrais committed Jul 25, 2024
1 parent 68e2f0d commit 99b234b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"

[compat]
BayesBase = "1.3"
ExponentialFamily = "1.4.3"
ExponentialFamily = "1.5.1"
LinearAlgebra = "1.10"
Manifolds = "0.9"
ManifoldsBase = "0.15"
Expand Down
2 changes: 1 addition & 1 deletion src/single_point_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ ManifoldsBase.injectivity_radius(M::SinglePointManifold) = zero(eltype(M.point))
ManifoldsBase.default_retraction_method(::SinglePointManifold) = ExponentialRetraction()

function ManifoldsBase.check_point(M::SinglePointManifold, p; kwargs...)
if p[1] != M.point
if !(p M.point)
return DomainError(p, "The point $(p) does not lie on $(M), which contains only $(M.point).")
end
return nothing
Expand Down
6 changes: 3 additions & 3 deletions test/natural_manifolds/categorical_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ end

function f(M, p)
ef = convert(ExponentialFamilyDistribution, M, p)
return (mean(ef) - 0.5)^2
η = getnaturalparameters(ef)
return (mean(η) - 0.5)^2
end

function g(M, p)
ef = convert(ExponentialFamilyDistribution, M, p)
return project(M, p, 2 * (mean(ef) - 0.5) * p ./ 10)
return project(M, p, 2 * p ./ 10)
end

q = gradient_descent(M, f, g, rand(rng, M))
Expand Down

0 comments on commit 99b234b

Please sign in to comment.