Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP Linear parametric solution #468

Merged
merged 22 commits into from
Jan 20, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ test/tmp/*.dot
test/tmp/*.tex

Manifest.toml

dev
46 changes: 46 additions & 0 deletions src/BeliefTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,49 @@ mutable struct NBPMessage <: Singleton
end

### SOME CONVERGENCE REQUIRED ^^^

#DEV NOTE it looks like it can be consolidated into one type
# if we can pass messages similar to EasyMessage:
# pts::Array{Float64,2}
# bws::Array{Float64,1}
# option a tuple
# bellief::Dict{Symbol, NamedTuple{(:vec, :bw, :inferdim),Tuple{Array{Int64,1},Array{Int64,1},Float64}}}
# or an extra type
# or the MsgPrior/PackedMessagePrior, depending on the serialization requirement of the channel
# but I would think only one message type

# mutable struct NBPMessage <: Singleton
# status::Symbol # Ek kort die in die boodskap
# p::Dict{Symbol, EasyMessage}
# end

struct TreeBelief
val::Array{Float64,2}
bw::Array{Float64,2}
inferdim::Float64
# manifolds::T TODO ? JT not needed as you have the variable with all the info in it?
Affie marked this conversation as resolved.
Show resolved Hide resolved
end
TreeBelief(p::BallTreeDensity, inferdim::Real=0.0) = TreeBelief(getPoints(p), getBW(p), inferdim)
TreeBelief(val::Array{Float64,2}, bw::Array{Float64,2}, inferdim::Real=0.0) = TreeBelief(val, bw, inferdim)

"""
CliqStatus
Clique status message enumerated type with status:
initialized, upsolved, marginalized, downsolved, uprecycled
"""
@enum CliqStatus initialized upsolved marginalized downsolved uprecycled error_status

abstract type AbstractBeliefMessage end
#JT Ek maak nog whahahaha, maar kom ons inheret van AbstractBeliefMessage?
# of ons gebruik T parameter in belief::Dict{Symbol, T} where T <: Union{BallTreeDensity, Vector{Float64}}
# It should be possible to use this as a standard message for all. Then just BeliefMessage
struct ParametricBeliefMessage <: AbstractBeliefMessage
# TODO JT nog velde, maar dis wat ek sover aan kan dink
status::CliqStatus
belief::Dict{Symbol, TreeBelief}
# of dalk named tuple maak
# bellief::Dict{Symbol, NamedTuple{(:val, :bw, :inferdim),Tuple{Array{Int64,1},Array{Int64,1},Float64}}}
end

ParametricBeliefMessage(status::CliqStatus) =
ParametricBeliefMessage(status, Dict{Symbol, TreeBelief}())
78 changes: 78 additions & 0 deletions src/DefaultNodeTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

SamplableBelief = Union{Distributions.Distribution, KernelDensityEstimate.BallTreeDensity, AliasingScalarSampler}

#Supported types for parametric
ParametricTypes = Union{Normal, MvNormal}
dehann marked this conversation as resolved.
Show resolved Hide resolved


"""
$(TYPEDEF)
Expand Down Expand Up @@ -48,6 +51,32 @@ struct Prior{T} <: IncrementalInference.FunctorSingleton where T <: SamplableBel
end
getSample(s::Prior, N::Int=1) = (reshape(rand(s.Z,N),:,N), )


# TODO maybe replace X with a type.
#TODO ParametricTypes or dispatch on each? I like it together for now. Maybe compiler will split it?
function (s::Prior{<:ParametricTypes})(X1::Vector{Float64};
userdata::Union{Nothing,FactorMetadata}=nothing)

if isa(s.Z, Normal)
Affie marked this conversation as resolved.
Show resolved Hide resolved
meas = s.Z.μ
σ = s.Z.σ
#TODO confirm signs
res = [meas - X1[1]]
return (res./σ) .^ 2

elseif isa(s.Z, MvNormal)
meas = mean(s.Z)
iΣ = invcov(s.Z)
#TODO confirm math : Σ^(1/2)*X
res = meas .- X1
return res' * iΣ * res #((res) .^ 2)

else
#this should not happen
@error("$s not suported, please use non-parametric")
end
end

"""
$(TYPEDEF)

Expand All @@ -67,6 +96,29 @@ function MsgPrior(z::T, infd::R) where {T <: SamplableBelief, R <: Real}
end
getSample(s::MsgPrior, N::Int=1) = (reshape(rand(s.Z,N),:,N), )

function (s::MsgPrior{<:ParametricTypes})(X1::Vector{Float64};
userdata::Union{Nothing,FactorMetadata}=nothing)

if isa(s.Z, Normal)
meas = s.Z.μ
σ = s.Z.σ
#TODO confirm signs
res = [meas - X1[1]]
return (res./σ) .^ 2

elseif isa(s.Z, MvNormal)
meas = mean(s.Z)
iΣ = invcov(s.Z)
#TODO confirm math : Σ^(1/2)*X
res = meas .- X1
return res' * iΣ * res

else
#this should not happen
@error("$s not suported, please use non-parametric")
end #
end

struct PackedMsgPrior <: PackedInferenceType where T
Z::String
inferdim::Float64
Expand Down Expand Up @@ -134,6 +186,32 @@ function (s::LinearConditional)(res::Array{Float64},
nothing
end

# parametric specific functor
function (s::LinearConditional{<:ParametricTypes})(X1::Vector{Float64},
Affie marked this conversation as resolved.
Show resolved Hide resolved
X2::Vector{Float64};
userdata::Union{Nothing,FactorMetadata}=nothing)
#can I change userdata to a keyword arg

if isa(s.Z, Normal)
meas = mean(s.Z)
σ = std(s.Z)
# res = similar(X2)
res = [meas - (X2[1] - X1[1])]
return (res/σ) .^ 2

elseif isa(s.Z, MvNormal)
meas = mean(s.Z)
iΣ = invcov(s.Z)
#TODO confirm math : Σ^(1/2)*X
res = meas .- (X2 .- X1)
return res' * iΣ * res

else
#this should not happen
@error("$s not suported, please use non-parametric")
end
end


"""
$(TYPEDEF)
Expand Down
52 changes: 49 additions & 3 deletions src/FactorGraph01.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ getData(v::Graphs.ExVertex) = v.attributes["data"]
Retrieve data structure stored in a variable.
"""
function getVariableData(dfg::AbstractDFG, lbl::Symbol; solveKey::Symbol=:default)::VariableNodeData
return solverData(getVariable(dfg, lbl, solveKey=solveKey))
return solverData(getVariable(dfg, lbl), solveKey)
end

"""
Expand Down Expand Up @@ -351,6 +351,43 @@ function getOutNeighbors(dfg::T, vertSym::Symbol; needdata::Bool=false, ready::I
return nodes
end



function DefaultNodeDataParametric(dodims::Int,
dehann marked this conversation as resolved.
Show resolved Hide resolved
dims::Int,
softtype::InferenceVariable;
initialized::Bool=true,
dontmargin::Bool=false)::VariableNodeData

# this should be the only function allocating memory for the node points
if false && initialized
error("not implemented yet")
# pN = AMP.manikde!(randn(dims, N), softtype.manifolds);
#
# sp = Int[0;] #round.(Int,range(dodims,stop=dodims+dims-1,length=dims))
# gbw = getBW(pN)[:,1]
# gbw2 = Array{Float64}(undef, length(gbw),1)
# gbw2[:,1] = gbw[:]
# pNpts = getPoints(pN)
# #initval, stdev
# return VariableNodeData(pNpts,
# gbw2, Symbol[], sp,
# dims, false, :_null, Symbol[], softtype, true, 0.0, false, dontmargin)
else
sp = round.(Int,range(dodims,stop=dodims+dims-1,length=dims))
return VariableNodeData(zeros(dims, 1),
zeros(dims,1), Symbol[], sp,
dims, false, :_null, Symbol[], softtype, false, 0.0, false, dontmargin)
end

end

function setDefaultNodeDataParametric!(v::DFGVariable, softtype::InferenceVariable; kwargs...)
vnd = DefaultNodeDataParametric(0, softtype.dims, softtype; kwargs...)
setSolverData(v, vnd, :parametric)
return nothing
end

function setDefaultNodeData!(v::DFGVariable,
dodims::Int,
N::Int,
Expand Down Expand Up @@ -417,14 +454,23 @@ function addVariable!(dfg::AbstractDFG,
dontmargin::Bool=false,
labels::Vector{Symbol}=Symbol[],
smalldata=Dict{String, String}(),
checkduplicates::Bool=true )::DFGVariable
checkduplicates::Bool=true,
initsolvekeys::Vector{Symbol}=getSolverParams(dfg).algorithms)::DFGVariable

#
v = DFGVariable(lbl, softtype)
v.solvable = solvable
# v.backendset = backendset
v.tags = union(labels, Symbol.(softtype.labels), [:VARIABLE])
v.smallData = smalldata
setDefaultNodeData!(v, 0, N, softtype.dims, initialized=!autoinit, softtype=softtype, dontmargin=dontmargin) # dodims

#JT, Ek weet nie of ek van die manier hou nie. Daar gaan nie so baie algoritmes wees nie so dit sal seker nie so groot raak nie
(:default in initsolvekeys) &&
setDefaultNodeData!(v, 0, N, softtype.dims, initialized=!autoinit, softtype=softtype, dontmargin=dontmargin) # dodims

(:parametric in initsolvekeys) &&
setDefaultNodeDataParametric!(v, softtype, initialized=!autoinit, dontmargin=dontmargin)

DFG.addVariable!(dfg, v)

return v
Expand Down
3 changes: 3 additions & 0 deletions src/FactorGraphTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ mutable struct SolverParams <: DFG.AbstractParams
N::Int
multiproc::Bool
logpath::String
algorithms::Vector{Symbol} # list of algorithms to run [:default] is mmisam
Affie marked this conversation as resolved.
Show resolved Hide resolved
devParams::Dict{Symbol,String}
SolverParams(;dimID::Int=0,
registeredModuleFunctions=nothing,
Expand All @@ -77,6 +78,7 @@ mutable struct SolverParams <: DFG.AbstractParams
N::Int=100,
multiproc::Bool=true,
logpath::String="/tmp/caesar/$(now())",
algorithms::Vector{Symbol}=[:default],
devParams::Dict{Symbol,String}=Dict{Symbol,String}()) = new(dimID,
registeredModuleFunctions,
reference,
Expand All @@ -95,6 +97,7 @@ mutable struct SolverParams <: DFG.AbstractParams
N,
multiproc,
logpath,
algorithms,
devParams)
#
end
Expand Down
5 changes: 5 additions & 0 deletions src/IncrementalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,11 @@ include("TetherUtils.jl")
include("CliqStateMachine.jl")
include("CliqStateMachineUtils.jl")

#EXPERIMENTAL parametric
include("ParametricSolveTree.jl")
include("ParametricCliqStateMachine.jl")
include("ParametricUtils.jl")
Affie marked this conversation as resolved.
Show resolved Hide resolved

# special variables and factors, see RoME.jl for more examples
include("Variables/Sphere1D.jl")
include("Factors/Sphere1D.jl")
Expand Down
11 changes: 11 additions & 0 deletions src/JunctionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,17 @@ function wipeBuildNewTree!(dfg::G;
return prepBatchTree!(dfg, ordering=ordering, drawpdf=drawpdf, show=show, filepath=filepath, viewerapp=viewerapp, imgs=imgs, maxparallel=maxparallel);
end

"""
$(SIGNATURES)
Experimental create and initialize tree message channels
"""
function initTreeMessageChannels!(tree::BayesTree) # TODO as dit lekker werk gebruik T vir ParametricBeliefMessage
for e = 1:tree.bt.nedges
push!(tree.messages, e=>(upMsg=Channel{ParametricBeliefMessage}(0),downMsg=Channel{ParametricBeliefMessage}(0)))
end
return nothing
end

"""
$(SIGNATURES)

Expand Down
9 changes: 7 additions & 2 deletions src/JunctionTreeTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ mutable struct BayesTree
btid::Int
cliques::Dict{Int,Graphs.ExVertex}
frontals::Dict{Symbol,Int}
#TEMP JT for evaluation, to store message channels in tree, would be better stored on edges
Affie marked this conversation as resolved.
Show resolved Hide resolved
messages::Dict{Int, NamedTuple{(:upMsg, :downMsg),Tuple{Channel{<:AbstractBeliefMessage},Channel{<:AbstractBeliefMessage}}}}
end

function emptyBayesTree()
bt = BayesTree(Graphs.inclist(Graphs.ExVertex,is_directed=true),
0,
Dict{Int,Graphs.ExVertex}(),
#[],
Dict{AbstractString, Int}())
Dict{AbstractString, Int}(),
Dict{Int, NamedTuple{(:upMsg, :downMsg),Tuple{Channel{ParametricBeliefMessage},Channel{ParametricBeliefMessage}}}}())
return bt
end

Expand Down Expand Up @@ -51,6 +54,8 @@ mutable struct CliqStateMachineContainer{BTND, T <: AbstractDFG, InMemG <: InMem
refactoring::Dict{Symbol, String}
oldcliqdata::BTND
logger::SimpleLogger
parametricMsgsUp::Vector{ParametricBeliefMessage} #TODO TEMP net om te kyk hoe dit werk, dit moet eintelik <: abstact wees
parametricMsgsDown::Vector{ParametricBeliefMessage}
Affie marked this conversation as resolved.
Show resolved Hide resolved
CliqStateMachineContainer{BTND}() where {BTND} = new{BTND, DFG.GraphsDFG, DFG.GraphsDFG}() # NOTE JT - GraphsDFG as default?
CliqStateMachineContainer{BTND}(x1::G,
x2::InMemoryDFGTypes,
Expand All @@ -66,7 +71,7 @@ mutable struct CliqStateMachineContainer{BTND, T <: AbstractDFG, InMemG <: InMem
x10aaa::SolverParams,
x10b::Dict{Symbol,String}=Dict{Symbol,String}(),
x11::BTND=emptyBTNodeData(),
x13::SimpleLogger=SimpleLogger(Base.stdout) ) where {BTND, G <: AbstractDFG} = new{BTND, G, typeof(x2)}(x1,x2,x3,x4,x5,x6,x7,x8,x9,x10a,x10aa,x10aaa,x10b,x11, x13)
x13::SimpleLogger=SimpleLogger(Base.stdout) ) where {BTND, G <: AbstractDFG} = new{BTND, G, typeof(x2)}(x1,x2,x3,x4,x5,x6,x7,x8,x9,x10a,x10aa,x10aaa,x10b,x11, x13, ParametricBeliefMessage[], ParametricBeliefMessage[])
end

function CliqStateMachineContainer(x1::G,
Expand Down
Loading