Skip to content

Commit

Permalink
Merge pull request #1724 from JuliaRobotics/23Q2/enh/manopt
Browse files Browse the repository at this point in the history
Riemannian Levenberg-Marquardt for parametric solver
  • Loading branch information
dehann authored Jul 4, 2023
2 parents 3eb8e47 + a8e8eaf commit 1061dbb
Show file tree
Hide file tree
Showing 23 changed files with 553 additions and 106 deletions.
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ Alternatively, either use the Github Blame, or the Github `/compare/v0.18.0...v0

The list below highlights breaking changes according to normal semver workflow -- i.e. breaking changes go through at least one deprecatation (via warnings) on the dominant number in the version number. E.g. v0.18 -> v0.19 (warnings) -> v0.20 (breaking). Note that ongoing efforts are made to properly deprecate old code/APIs

# Changes in v0.34

- Start transition to Manopt.jl via Riemannian Levenberg Marquart.
- Deprecate `AbstractRelativeRoots`.

# Changes in v0.33

- Upgrades for DFG using StructTypes.jl (serialization).
# Changes in v0.32

- Major internal refactoring of `CommonConvWrapper` to avoid abstract field types, and better standardization; towards cleanup of internal multihypo handling and naming conventions.
Expand Down
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ version = "0.34.0"
ApproxManifoldProducts = "9bbbb610-88a1-53cd-9763-118ce10c1f89"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand All @@ -17,6 +18,7 @@ DistributedFactorGraphs = "b5cc3c7e-6572-11e9-2517-99fb8daf2f04"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
FunctionalStateMachine = "3e9e306e-7e3c-11e9-12d2-8f8f67a2f951"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Expand All @@ -26,6 +28,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
ManifoldDiff = "af67fdf4-a580-4b9f-bbec-742ef357defd"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Expand All @@ -38,6 +41,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -64,6 +68,7 @@ KernelDensityEstimate = "0.5.6"
ManifoldDiff = "0.3"
Manifolds = "0.8.15"
ManifoldsBase = "0.13.12, 0.14"
Manopt = "0.4.27"
MetaGraphs = "0.7"
NLSolversBase = "7.6"
NLsolve = "3, 4"
Expand All @@ -87,10 +92,11 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DifferentialEquations", "Flux", "Graphs", "Manopt", "InteractiveUtils", "Interpolations", "Pkg", "Rotations", "Test"]
test = ["DifferentialEquations", "Flux", "Graphs", "Manopt", "InteractiveUtils", "Interpolations", "LineSearches", "Pkg", "Rotations", "Test"]
23 changes: 23 additions & 0 deletions src/Deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,29 @@ function solveGraphParametric2(
return d, result, flatvar.idx, Σ
end


##==============================================================================
## Deprecate code below before v0.35
##==============================================================================



function _solveLambdaNumeric(
fcttype::Union{F, <:Mixture{N_, F, S, T}},
objResX::Function,
residual::AbstractVector{<:Real},
u0::AbstractVector{<:Real},
islen1::Bool = false,
) where {N_, F <: AbstractRelativeRoots, S, T}
#

#
r = NLsolve.nlsolve((res, x) -> res .= objResX(x), u0; inplace = true) #, ftol=1e-14)

#
return r.zero
end

##==============================================================================
## Deprecate code below before v0.35
##==============================================================================
Expand Down
2 changes: 1 addition & 1 deletion src/Factors/EuclidDistance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ $(TYPEDEF)
Default linear offset between two scalar variables.
"""
struct EuclidDistance{T <: SamplableBelief} <: AbstractRelativeMinimize
struct EuclidDistance{T <: SamplableBelief} <: AbstractManifoldMinimize # AbstractRelativeMinimize
Z::T
end

Expand Down
27 changes: 14 additions & 13 deletions src/Factors/GenericFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@ export ManifoldFactor
# For now, `Z` is on the tangent space in coordinates at the point used in the factor.
# For groups just the lie algebra
# As transition it will be easier this way, we can reevaluate
struct ManifoldFactor{M <: AbstractManifold, T <: SamplableBelief} <:
AbstractManifoldMinimize #AbstractFactor
struct ManifoldFactor{
M <: AbstractManifold,
T <: SamplableBelief
} <: AbstractManifoldMinimize
M::M
Z::T
end
Expand All @@ -80,22 +82,21 @@ DFG.getManifold(f::ManifoldFactor) = f.M
function getSample(cf::CalcFactor{<:ManifoldFactor{M, Z}}) where {M, Z}
#TODO @assert dim == cf.factor.Z's dimension
#TODO investigate use of SVector if small dims
if M isa ManifoldKernelDensity
ret = sample(cf.factor.Z.belief)[1]
else
ret = rand(cf.factor.Z)
end
# if M isa ManifoldKernelDensity
# ret = sample(cf.factor.Z.belief)[1]
# else
# ret = rand(cf.factor.Z)
# end

# ASSUME this function is only used for RelativeFactors which must use measurements as tangents
ret = sampleTangent(cf.factor.M, cf.factor.Z)
#return coordinates as we do not know the point here #TODO separate Lie group
return ret
end

# function (cf::CalcFactor{<:ManifoldFactor{<:AbstractDecoratorManifold}})(Xc, p, q)
function (cf::CalcFactor{<:ManifoldFactor})(Xc, p, q)
# function (cf::ManifoldFactor)(X, p, q)
M = cf.manifold # .factor.M
# M = cf.M
X = hat(M, p, Xc)
return distanceTangent2Point(M, X, p, q)
function (cf::CalcFactor{<:ManifoldFactor})(X, p, q)
return distanceTangent2Point(cf.manifold, X, p, q)
end

## ======================================================================================
Expand Down
2 changes: 1 addition & 1 deletion src/Factors/GenericMarginal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
$(TYPEDEF)
"""
mutable struct GenericMarginal <: AbstractRelativeRoots
mutable struct GenericMarginal <: AbstractManifoldMinimize # AbstractRelativeRoots
Zij::Array{Float64, 1}
Cov::Array{Float64, 1}
W::Array{Float64, 1}
Expand Down
2 changes: 1 addition & 1 deletion src/Factors/LinearRelative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Default linear offset between two scalar variables.
X_2 = X_1 + η_Z
```
"""
struct LinearRelative{N, T <: SamplableBelief} <: AbstractRelativeMinimize
struct LinearRelative{N, T <: SamplableBelief} <: AbstractManifoldMinimize # AbstractRelativeMinimize
Z::T
end

Expand Down
1 change: 1 addition & 0 deletions src/IncrementalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ include("CliqueStateMachine/services/CliqStateMachineUtils.jl")
#EXPERIMENTAL parametric
include("ParametricCSMFunctions.jl")
include("ParametricUtils.jl")
include("ParametricManoptDev.jl")
include("services/MaxMixture.jl")

#X-stroke
Expand Down
5 changes: 3 additions & 2 deletions src/ManifoldSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ function sampleTangent(x::ManifoldKernelDensity, p = mean(x))
end

# Sampling Distributions
function sampleTangent(M::AbstractManifold, z::Distribution, p, basis::AbstractBasis)
# assumes M is a group and will break for Riemannian, but leaving that enhancement as TODO
function sampleTangent(M::AbstractManifold, z::Distribution, p = getPointIdentity(M), basis::AbstractBasis = DefaultOrthogonalBasis())
return get_vector(M, p, rand(z), basis)
end

function sampleTangent(
M::AbstractDecoratorManifold,
z::Distribution,
p = getPointIdentity(M),
p = identity_element(M), #getPointIdentity(M),
)
return hat(M, p, rand(z, 1)[:]) #TODO find something better than (z,1)[:]
end
Expand Down
24 changes: 24 additions & 0 deletions src/ManifoldsExtentions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,30 @@ function Manifolds.exp!(M::NPowerManifold, q, p, X)
return q
end

function Manifolds.compose!(M::NPowerManifold, x, p, q)
rep_size = representation_size(M.manifold)
for i in Manifolds.get_iterator(M)
x[i] = compose(
M.manifold,
Manifolds._read(M, rep_size, p, i),
Manifolds._read(M, rep_size, q, i),
)
end
return x
end

function Manifolds.allocate_result(M::NPowerManifold, f, x...)
if length(x) == 0
return [Manifolds.allocate_result(M.manifold, f) for _ in Manifolds.get_iterator(M)]
else
return copy(x[1])
end
end

function Manifolds.allocate_result(::NPowerManifold, ::typeof(get_vector), p, X)
return copy(p)
end

## ================================================================================================
## ArrayPartition getPointIdentity (identity_element)
## ================================================================================================
Expand Down
Loading

0 comments on commit 1061dbb

Please sign in to comment.