WIP Linear parametric solution #468

Merged: 22 commits, Jan 20, 2020
Changes from 13 commits
45 changes: 45 additions & 0 deletions src/BeliefTypes.jl
@@ -41,3 +41,48 @@ mutable struct NBPMessage <: Singleton
end

### SOME CONVERGENCE REQUIRED ^^^

# DEV NOTE: it looks like this can be consolidated into one type
# if we can pass messages similar to EasyMessage:
#   pts::Array{Float64,2}
#   bws::Array{Float64,1}
# option a: a tuple
#   belief::Dict{Symbol, NamedTuple{(:vec, :bw, :inferdim),Tuple{Array{Int64,1},Array{Int64,1},Float64}}}
# or an extra type,
# or the MsgPrior/PackedMessagePrior, depending on the serialization requirement of the channel,
# but I would think only one message type

# mutable struct NBPMessage <: Singleton
# status::Symbol # I need this in the message
# p::Dict{Symbol, EasyMessage}
# end

struct TreeBelief
val::Array{Float64,2}
bw::Array{Float64,2}
inferdim::Float64
# manifolds::T TODO? JT: not needed, since the variable already carries all that info?
end
TreeBelief(p::BallTreeDensity, inferdim::Real=0.0) = TreeBelief(getPoints(p), getBW(p), inferdim)
TreeBelief(val::Array{Float64,2}, bw::Array{Float64,2}, inferdim::Real=0.0) = TreeBelief(val, bw, inferdim)

"""
CliqStatus
Clique status message enumerated type with status:
initialized, upsolved, marginalized, downsolved, uprecycled, error_status
"""
@enum CliqStatus initialized upsolved marginalized downsolved uprecycled error_status


"""
$(TYPEDEF)
Belief message for message passing on the tree.
$(TYPEDFIELDS)
"""
struct BeliefMessage
status::CliqStatus
belief::Dict{Symbol, TreeBelief}
end

BeliefMessage(status::CliqStatus) =
BeliefMessage(status, Dict{Symbol, TreeBelief}())
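
A hedged usage sketch of the new types (the variable label, density, and inferred dimension below are illustrative; `kde!` is assumed from KernelDensityEstimate):

# hypothetical: package a clique's marginal estimate into a tree message
pts    = randn(2, 100)                      # 100 samples of a 2D variable
belief = TreeBelief(kde!(pts), 0.5)         # points, bandwidths, inferred dim
msg    = BeliefMessage(upsolved, Dict(:x1 => belief))
msg.status == upsolved                      # status and beliefs travel together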
78 changes: 78 additions & 0 deletions src/DefaultNodeTypes.jl
@@ -2,6 +2,9 @@

SamplableBelief = Union{Distributions.Distribution, KernelDensityEstimate.BallTreeDensity, AliasingScalarSampler}

# Supported types for the parametric solver
ParametricTypes = Union{Normal, MvNormal}


"""
$(TYPEDEF)
@@ -48,6 +51,32 @@ struct Prior{T} <: IncrementalInference.FunctorSingleton where T <: SamplableBelief
end
getSample(s::Prior, N::Int=1) = (reshape(rand(s.Z,N),:,N), )


# TODO maybe replace X with a type.
# TODO: keep ParametricTypes as one Union or dispatch on each type? I like it together for now; maybe the compiler will split it.
function (s::Prior{<:ParametricTypes})(X1::Vector{Float64};
userdata::Union{Nothing,FactorMetadata}=nothing)

if isa(s.Z, Normal)
meas = s.Z.μ
σ = s.Z.σ
#TODO confirm signs
res = [meas - X1[1]]
return (res./σ) .^ 2

elseif isa(s.Z, MvNormal)
meas = mean(s.Z)
iΣ = invcov(s.Z)
#TODO confirm math : Σ^(1/2)*X
res = meas .- X1
return res' * iΣ * res #((res) .^ 2)

else
# this should not happen
@error("$s not supported, please use non-parametric")
end
end
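
In effect the parametric functors return a squared, noise-weighted residual rather than samples. A rough check of the Normal branch above (values illustrative):

# hypothetical: parametric Prior residual is ((μ - x)/σ)^2
p = Prior(Normal(2.0, 0.5))
p([2.0])   # [0.0] at the mean
p([3.0])   # [4.0] one unit away is two sigma, squared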

"""
$(TYPEDEF)

@@ -67,6 +96,29 @@ function MsgPrior(z::T, infd::R) where {T <: SamplableBelief, R <: Real}
end
getSample(s::MsgPrior, N::Int=1) = (reshape(rand(s.Z,N),:,N), )

function (s::MsgPrior{<:ParametricTypes})(X1::Vector{Float64};
userdata::Union{Nothing,FactorMetadata}=nothing)

if isa(s.Z, Normal)
meas = s.Z.μ
σ = s.Z.σ
#TODO confirm signs
res = [meas - X1[1]]
return (res./σ) .^ 2

elseif isa(s.Z, MvNormal)
meas = mean(s.Z)
iΣ = invcov(s.Z)
#TODO confirm math : Σ^(1/2)*X
res = meas .- X1
return res' * iΣ * res

else
# this should not happen
@error("$s not supported, please use non-parametric")
end
end

struct PackedMsgPrior <: PackedInferenceType where T
Z::String
inferdim::Float64
@@ -134,6 +186,32 @@ function (s::LinearConditional)(res::Array{Float64},
nothing
end

# parametric specific functor
function (s::LinearConditional{<:ParametricTypes})(X1::Vector{Float64},
X2::Vector{Float64};
userdata::Union{Nothing,FactorMetadata}=nothing)
#can I change userdata to a keyword arg

if isa(s.Z, Normal)
meas = mean(s.Z)
σ = std(s.Z)
# res = similar(X2)
res = [meas - (X2[1] - X1[1])]
return (res/σ) .^ 2

elseif isa(s.Z, MvNormal)
meas = mean(s.Z)
iΣ = invcov(s.Z)
#TODO confirm math : Σ^(1/2)*X
res = meas .- (X2 .- X1)
return res' * iΣ * res

else
# this should not happen
@error("$s not supported, please use non-parametric")
end
end
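
The pairwise version penalises deviation of the relative value X2 - X1 from the measurement in the same way. A rough sketch (values illustrative):

# hypothetical: LinearConditional residual is ((μ - (x2 - x1))/σ)^2
f = LinearConditional(Normal(10.0, 1.0))
f([0.0], [10.0])   # [0.0] measurement matches the difference
f([0.0], [12.0])   # [4.0] two-sigma violation, squared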


"""
$(TYPEDEF)
52 changes: 49 additions & 3 deletions src/FactorGraph01.jl
@@ -25,7 +25,7 @@ end
Retrieve data structure stored in a variable.
"""
function getVariableData(dfg::AbstractDFG, lbl::Symbol; solveKey::Symbol=:default)::VariableNodeData
return solverData(getVariable(dfg, lbl, solveKey=solveKey))
return solverData(getVariable(dfg, lbl), solveKey)
end

"""
@@ -355,6 +355,43 @@ function getOutNeighbors(dfg::T, vertSym::Symbol; needdata::Bool=false, ready::I
return nodes
end



function DefaultNodeDataParametric(dodims::Int,
dims::Int,
softtype::InferenceVariable;
initialized::Bool=true,
dontmargin::Bool=false)::VariableNodeData

# this should be the only function allocating memory for the node points
if false && initialized
error("not implemented yet")
# pN = AMP.manikde!(randn(dims, N), softtype.manifolds);
#
# sp = Int[0;] #round.(Int,range(dodims,stop=dodims+dims-1,length=dims))
# gbw = getBW(pN)[:,1]
# gbw2 = Array{Float64}(undef, length(gbw),1)
# gbw2[:,1] = gbw[:]
# pNpts = getPoints(pN)
# #initval, stdev
# return VariableNodeData(pNpts,
# gbw2, Symbol[], sp,
# dims, false, :_null, Symbol[], softtype, true, 0.0, false, dontmargin)
else
sp = round.(Int,range(dodims,stop=dodims+dims-1,length=dims))
return VariableNodeData(zeros(dims, 1),
zeros(dims,1), Symbol[], sp,
dims, false, :_null, Symbol[], softtype, false, 0.0, false, dontmargin)
end

end

function setDefaultNodeDataParametric!(v::DFGVariable, softtype::InferenceVariable; kwargs...)
vnd = DefaultNodeDataParametric(0, softtype.dims, softtype; kwargs...)
setSolverData(v, vnd, :parametric)
return nothing
end
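
In rough terms the :parametric solve key stores a single point estimate per variable instead of N samples. A hedged inspection sketch (field names assumed from the DFG VariableNodeData layout of this era):

# hypothetical: parametric solver data keeps one point per dimension
vnd = DefaultNodeDataParametric(0, 1, ContinuousScalar())
size(vnd.val)    # (1, 1) single point estimate, vs (dims, N) samples for :default
vnd.initialized  # false until the parametric solver fills it in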

function setDefaultNodeData!(v::DFGVariable,
dodims::Int,
N::Int,
@@ -458,14 +495,23 @@ function addVariable!(dfg::AbstractDFG,
dontmargin::Bool=false,
labels::Vector{Symbol}=Symbol[],
smalldata=Dict{String, String}(),
checkduplicates::Bool=true )::DFGVariable
checkduplicates::Bool=true,
initsolvekeys::Vector{Symbol}=getSolverParams(dfg).algorithms)::DFGVariable

#
v = DFGVariable(lbl, softtype)
v.solvable = solvable
# v.backendset = backendset
v.tags = union(labels, Symbol.(softtype.labels), [:VARIABLE])
v.smallData = smalldata
setDefaultNodeData!(v, 0, N, softtype.dims, initialized=!autoinit, softtype=softtype, dontmargin=dontmargin) # dodims

# JT: I'm not sure I like this approach. There won't be that many algorithms, so it probably won't grow too large.
(:default in initsolvekeys) &&
setDefaultNodeData!(v, 0, N, softtype.dims, initialized=!autoinit, softtype=softtype, dontmargin=dontmargin) # dodims

(:parametric in initsolvekeys) &&
setDefaultNodeDataParametric!(v, softtype, initialized=!autoinit, dontmargin=dontmargin)

DFG.addVariable!(dfg, v)

return v
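
A hedged usage sketch of the new keyword (assuming the label/softtype convenience method of addVariable! forwards initsolvekeys; names illustrative):

# hypothetical: request both the sample-based and the parametric solve keys
fg = initfg()
v  = addVariable!(fg, :x0, ContinuousScalar, initsolvekeys=[:default, :parametric])
solverData(v, :parametric).val   # zero-initialized parametric point estimate
solverData(v, :default).val      # dims × N sample matrix for mmisam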
3 changes: 3 additions & 0 deletions src/FactorGraphTypes.jl
@@ -58,6 +58,7 @@ mutable struct SolverParams <: DFG.AbstractParams
N::Int
multiproc::Bool
logpath::String
algorithms::Vector{Symbol} # list of algorithms to run; [:default] is mmisam
devParams::Dict{Symbol,String}
SolverParams(;dimID::Int=0,
registeredModuleFunctions=nothing,
@@ -77,6 +78,7 @@
N::Int=100,
multiproc::Bool=true,
logpath::String="/tmp/caesar/$(now())",
algorithms::Vector{Symbol}=[:default],
devParams::Dict{Symbol,String}=Dict{Symbol,String}()) = new(dimID,
registeredModuleFunctions,
reference,
@@ -95,6 +97,7 @@
N,
multiproc,
logpath,
algorithms,
devParams)
#
end
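
A small sketch of how the new field is meant to be used (defaults assumed from the constructor above):

# hypothetical: opt a graph into the parametric pipeline
SolverParams().algorithms                                     # [:default], mmisam only
SolverParams(algorithms=[:default, :parametric]).algorithms   # also build :parametric data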
6 changes: 6 additions & 0 deletions src/IncrementalInference.jl
@@ -543,6 +543,11 @@ include("TetherUtils.jl")
include("CliqStateMachine.jl")
include("CliqStateMachineUtils.jl")

#EXPERIMENTAL parametric
include("ParametricSolveTree.jl")
include("ParametricCliqStateMachine.jl")
include("ParametricUtils.jl")

# special variables and factors, see RoME.jl for more examples
include("Variables/Sphere1D.jl")
include("Factors/Sphere1D.jl")
@@ -553,6 +558,7 @@ include("SolverAPI.jl")
include("CanonicalGraphExamples.jl")
include("Deprecated.jl")


exportimg(pl) = error("Please do `using Gadfly` before IncrementalInference is used to allow image export.")
function __init__()
@require Gadfly="c91e804a-d5a3-530f-b6f0-dfbca275c004" begin
37 changes: 37 additions & 0 deletions src/JunctionTree.jl
@@ -31,8 +31,10 @@ function addClique!(bt::AbstractBayesTree, dfg::G, varID::Symbol, condIDs::Array
if isa(bt.bt,GenericIncidenceList)
Graphs.add_vertex!(bt.bt, clq)
bt.cliques[bt.btid] = clq
# clId = bt.btid
elseif isa(bt.bt, MetaDiGraph)
MetaGraphs.add_vertex!(bt.bt, :clique, clq)
# clId = MetaGraphs.nv(bt.bt)
else
error("Oops, something went wrong")
end
@@ -41,6 +43,7 @@
# already emptyBTNodeData() in constructor
# setData!(clq, emptyBTNodeData())

# appendClique!(bt, clId, dfg, varID, condIDs)
appendClique!(bt, bt.btid, dfg, varID, condIDs)
return clq
end
@@ -52,14 +55,29 @@ export getClique, getCliques, getCliqueIds, getCliqueData

getClique(tree::AbstractBayesTree, cId::Int)::TreeClique = tree.cliques[cId]

getClique(tree::MetaBayesTree, cId::Int)::TreeClique = MetaGraphs.get_prop(tree.bt, cId, :clique)

#TODO
addClique!(tree::AbstractBayesTree, parentCliqId::Int, cliq::TreeClique)::Bool = error("addClique!(tree::AbstractBayesTree, parentCliqId::Int, cliq::TreeClique) not implemented")
updateClique!(tree::AbstractBayesTree, cliq::TreeClique)::Bool = error("updateClique!(tree::AbstractBayesTree, cliq::TreeClique)::Bool not implemented")
deleteClique!(tree::AbstractBayesTree, cId::Int)::TreeClique = error("deleteClique!(tree::AbstractBayesTree, cId::Int)::TreeClique not implemented")

getCliques(tree::AbstractBayesTree) = tree.cliques

function getCliques(tree::MetaBayesTree)
d = Dict{Int,Any}()
for (k,v) in tree.bt.vprops
d[k] = v[:clique]
end
return d
end

getCliqueIds(tree::AbstractBayesTree) = keys(getCliques(tree))

function getCliqueIds(tree::MetaBayesTree)
MetaGraphs.vertices(tree.bt)
end
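
These accessors let downstream code iterate cliques without caring about the tree backend. A hedged sketch (assuming a previously built tree):

# hypothetical: backend-agnostic iteration over tree cliques
for cid in getCliqueIds(tree)
    cliq = getClique(tree, cid)   # works for the Graphs.jl and MetaGraphs backends
    data = getCliqueData(cliq)    # BayesTreeNodeData for this clique
end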

getCliqueData(cliq::TreeClique)::BayesTreeNodeData = cliq.data
getCliqueData(tree::AbstractBayesTree, cId::Int)::BayesTreeNodeData = getClique(tree, cId) |> getCliqueData

@@ -590,6 +608,25 @@ function wipeBuildNewTree!(dfg::G;
return prepBatchTree!(dfg, variableOrder=variableOrder, ordering=ordering, drawpdf=drawpdf, show=show, filepath=filepath, viewerapp=viewerapp, imgs=imgs, maxparallel=maxparallel);
end

"""
$(SIGNATURES)
Experimental: create and initialize the tree message channels.
"""
function initTreeMessageChannels!(tree::BayesTree)
for e = 1:tree.bt.nedges
push!(tree.messages, e=>(upMsg=Channel{BeliefMessage}(0),downMsg=Channel{BeliefMessage}(0)))
end
return nothing
end

function initTreeMessageChannels!(tree::MetaBayesTree)
for e = MetaGraphs.edges(tree.bt)
set_props!(tree.bt, e, Dict{Symbol,Any}(:upMsg=>Channel{BeliefMessage}(0),:downMsg=>Channel{BeliefMessage}(0)))
# push!(tree.messages, e=>(upMsg=Channel{BeliefMessage}(0),downMsg=Channel{BeliefMessage}(0)))
end
return nothing
end
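
The zero-capacity channels make every tree edge a rendezvous point: put! blocks until the neighbouring clique take!s the message. A hedged sketch of one edge's hand-off (assuming a MetaBayesTree built beforehand; task structure illustrative):

# hypothetical: hand an up-message across one tree edge
initTreeMessageChannels!(tree)
e  = first(MetaGraphs.edges(tree.bt))
up = MetaGraphs.get_prop(tree.bt, e, :upMsg)
@async put!(up, BeliefMessage(upsolved))   # child blocks until the parent receives
msg = take!(up)                            # parent unblocks the child
msg.status == upsolved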

"""
$(SIGNATURES)
