Skip to content

Commit

Permalink
Merge pull request #775 from JuliaRobotics/maint/21Q2/fix_defVariable
Browse files Browse the repository at this point in the history
Cleanup defVariable and add getPointType
  • Loading branch information
dehann authored Jun 18, 2021
2 parents 3200541 + 9fb530a commit be4d03b
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 89 deletions.
2 changes: 1 addition & 1 deletion src/DistributedFactorGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ export getSolverData
export getVariableType

# VariableType functions
export getDimension, getManifolds, getManifold
export getDimension, getManifolds, getManifold, getPointType

# Small Data CRUD
export SmallDataTypes, getSmallData, addSmallData!, updateSmallData!, deleteSmallData!, listSmallData, emptySmallData!
Expand Down
55 changes: 33 additions & 22 deletions src/services/DFGVariable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,6 @@ getVariableType(dfg::AbstractDFG, lbl::Symbol) = getVariableType(getVariable(dfg
##------------------------------------------------------------------------------
## InferenceVariable
##------------------------------------------------------------------------------
"""
$SIGNATURES
Interface function to return the `variableType` dimension of an InferenceVariable, extend this function for all Types<:InferenceVariable.
"""
function getDimension end
"""
$SIGNATURES
Interface function to return the `<:ManifoldsBase.AbstractManifold` object of `variableType<:InferenceVariable`, extend this function for all `Types<:InferenceVariable`.
"""
getManifold(vari::DFGVariable) = getVariableType(vari) |> getManifold

# """
# $SIGNATURES
Expand All @@ -83,7 +73,6 @@ getManifold(vari::DFGVariable) = getVariableType(vari) |> getManifold
# getManifolds(::Type{<:T}) where {T <: ManifoldsBase.AbstractManifold} = convert(Tuple, T)
# getManifolds(::T) where {T <: ManifoldsBase.AbstractManifold} = getManifolds(T)

getDimension(t_::T) where {T <: ManifoldsBase.AbstractManifold} = manifold_dimension(t_)

"""
@defVariable StructName manifolds<:ManifoldsBase.AbstractManifold
Expand All @@ -97,27 +86,49 @@ Example:
DFG.@defVariable Pose2 SpecialEuclidean(2)
```
"""
macro defVariable(structname, manifold)
macro defVariable(structname, manifold, point_type)
return esc(quote
Base.@__doc__ struct $structname <: InferenceVariable end

# user manifold must be a <:Manifold
@assert ($manifold isa AbstractManifold) "@defVariable of "*string($structname)*" requires that the "*string($manifold)*" be a subtype of `ManifoldsBase.AbstractManifold`"

# user manifold must be a <:Manifold
Base.convert(::Type{<:AbstractManifold}, ::Union{<:T, Type{<:T}}) where {T <: $structname} = $manifold
DFG.getManifold(::Type{$structname}) = $manifold

getManifold(::Type{M}) where {M <: $structname} = $manifold
getManifold(::M) where {M <: $structname} = getManifold(M)

DFG.getDimension(::Type{M}) where {M <: $structname} = manifold_dimension(getManifold(M))
DFG.getDimension(::M) where {M <: $structname} = manifold_dimension(getManifold(M))
DFG.getPointType(::Type{$structname}) = $point_type

# # # FIXME legacy API to be deprecated
# DFG.getManifolds(::Type{M}) where {M <: $structname} = convert(Tuple, $manifold)
# DFG.getManifolds(::M) where {M <: $structname} = convert(Tuple, $manifold)
end)
end

Base.convert(::Type{<:AbstractManifold}, ::Union{<:T, Type{<:T}}) where {T <: InferenceVariable} = getManifold(T)

"""
$SIGNATURES
Interface function to return the `<:ManifoldsBase.AbstractManifold` object of `variableType<:InferenceVariable`.
"""
getManifold(vari::DFGVariable) = getVariableType(vari) |> getManifold
getManifold(::T) where {T <: InferenceVariable} = getManifold(T)


"""
$SIGNATURES
Interface function to return the `variableType` dimension of an InferenceVariable, extend this function for all Types<:InferenceVariable.
"""
function getDimension end

getDimension(::Type{T}) where {T <: InferenceVariable} = manifold_dimension(getManifold(T))
getDimension(::T) where {T <: InferenceVariable} = manifold_dimension(getManifold(T))
getDimension(M::ManifoldsBase.AbstractManifold) = manifold_dimension(M)
getDimension(var::DFGVariable) = getDimension(getVariableType(var))


"""
$SIGNATURES
Interface function to return the manifold point type of an InferenceVariable, extend this function for all Types<:InferenceVariable.
"""
function getPointType end
getPointType(::T) where {T <: InferenceVariable} = getPointType(T)

##------------------------------------------------------------------------------
## solvedCount
##------------------------------------------------------------------------------
Expand Down
4 changes: 1 addition & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ using UUIDs
# logger = SimpleLogger(stdout, Logging.Debug)
# global_logger(logger)

# @testset "Check @defVariable design" begin
# include("test_defVariable.jl")
# end
include("test_defVariable.jl")

include("testBlocks.jl")

Expand Down
4 changes: 2 additions & 2 deletions test/testBlocks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import Base: convert
Base.convert(::Type{<:Tuple}, ::typeof(Euclidean(1))) = (:Euclid,)
Base.convert(::Type{<:Tuple}, ::typeof(Euclidean(2))) = (:Euclid, :Euclid)

@defVariable TestVariableType1 Euclidean(1)
@defVariable TestVariableType2 Euclidean(2)
@defVariable TestVariableType1 Euclidean(1) Vector{Float64}
@defVariable TestVariableType2 Euclidean(2) Vector{Float64}

# struct TestVariableType2 <: InferenceVariable
# dims::Int
Expand Down
79 changes: 18 additions & 61 deletions test/test_defVariable.jl
Original file line number Diff line number Diff line change
@@ -1,82 +1,39 @@

using ManifoldsBase, Manifolds
using Manifolds
using Test

# abstract type InferenceVariable end

##

Base.convert(::Type{<:Tuple}, ::typeof(Euclidean(1))) = (:Euclid,)
Base.convert(::Type{<:Tuple}, ::typeof(Euclidean(2))) = (:Euclid, :Euclid)

##

## WARNING, THIS IS A DUPLICATE OF THE MACRO IN DFGVariable.jl TO MAKE SURE OVERLOADS ARE WORKING RIGHT
macro defVariable(structname, manifold)
return esc(quote
Base.@__doc__ struct $structname <: InferenceVariable end

@assert ($manifold isa AbstractManifold) "@defVariable of "*string($structname)*" requires that the "*string($manifold)*" be a subtype of `ManifoldsBase.AbstractManifold`"

# manifold must be is a <:Manifold
Base.convert(::Type{<:AbstractManifold}, ::Union{<:T, Type{<:T}}) where {T <: $structname} = $manifold

getManifold(::Type{M}) where {M <: $structname} = $manifold
getManifold(::M) where {M <: $structname} = getManifold(M)

getDimension(::Type{M}) where {M <: $structname} = manifold_dimension(getManifold(M))
getDimension(::M) where {M <: $structname} = manifold_dimension(getManifold(M))
# FIXME legacy API to be deprecated
# getManifolds(::Type{M}) where {M <: $structname} = convert(Tuple, $manifold)
# getManifolds(::M) where {M <: $structname} = convert(Tuple, $manifold)
end)
end


##


ex = macroexpand(Main, :(@defVariable(TestVariableType1, Euclidean(1))) )


@testset "Testing @defVariable" begin
##

struct NotAManifold end

try
@defVariable(MyVar, NotAManifold())
catch AssertionError
@test true
end
@test_throws AssertionError @defVariable(MyVar, NotAManifold(), Matrix{3})

##

@defVariable(TestVariableType1, Euclidean(1))
@defVariable(TestVariableType2, Euclidean(2))
@defVariable(TestVarType1, Euclidean(3), Vector{Float64})
@defVariable(TestVarType2, SpecialEuclidean(3), ProductRepr{Tuple{Vector{Float64}, Matrix{Float64}}})


##

@test getManifold( TestVariableType1) == Euclidean(1)
@test getManifold( TestVariableType2) == Euclidean(2)
@test getManifold( TestVarType1) == Euclidean(3)
@test getManifold( TestVarType2) == SpecialEuclidean(3)

# legacy
# @test getManifolds(TestVariableType1) == (:Euclid,)
@test getDimension(TestVariableType1) === 1
# @test getManifolds(TestVariableType2) == (:Euclid,:Euclid)
@test getDimension(TestVariableType2) === 2
@test getDimension(TestVarType1) === 3
@test getDimension(TestVarType2) === 6

@test getPointType(TestVarType1) == Vector{Float64}
@test getPointType(TestVarType2) == ProductRepr{Tuple{Vector{Float64}, Matrix{Float64}}}
##


@test getManifold( TestVariableType1()) == Euclidean(1)
@test getManifold( TestVariableType2()) == Euclidean(2)
@test getManifold( TestVarType1()) == Euclidean(3)
@test getManifold( TestVarType2()) == SpecialEuclidean(3)

# legacy
# @test getManifolds(TestVariableType1()) == (:Euclid,)
@test getDimension(TestVariableType1()) === 1
# @test getManifolds(TestVariableType2()) == (:Euclid,:Euclid)
@test getDimension(TestVariableType2()) === 2
@test getDimension(TestVarType1()) === 3
@test getDimension(TestVarType2()) === 6

@test getPointType(TestVarType1()) == Vector{Float64}
@test getPointType(TestVarType2()) == ProductRepr{Tuple{Vector{Float64}, Matrix{Float64}}}

##
end

0 comments on commit be4d03b

Please sign in to comment.