Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes, debugging, wip on DERelative tests #1774

Merged
merged 4 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 48 additions & 7 deletions ext/IncrInfrDiffEqFactorExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import DifferentialEquations: solve
using Dates

using IncrementalInference
import IncrementalInference: getSample, getManifold, DERelative
import IncrementalInference: sampleFactor
import IncrementalInference: DERelative, _solveFactorODE!
import IncrementalInference: getSample, sampleFactor, getManifold

using DocStringExtensions

Expand Down Expand Up @@ -174,12 +174,12 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
# need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable)
solveforIdx = 2
# use forward solve for all solvefor not in [1;2]
u0pts = getBelief(cf.fullvariables[1]) |> getPoints
# u0pts = getBelief(cf.fullvariables[1]) |> getPoints
# update parameters for additional variables
_solveFactorODE!(
meas1,
oderel.forwardProblem,
u0pts[cf._sampleIdx],
X[1], # u0pts[cf._sampleIdx],
_maketuplebeyond2args(X...)...,
)
end
Expand All @@ -192,13 +192,52 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
#FIXME
res = zeros(size(X[2], 1))
for i = 1:size(X[2], 1)
# diffop( test, reference ) <===> ΔX = test \ reference
# diffop( reference?, test? ) <===> ΔX = test \ reference
res[i] = diffOp[i](X[solveforIdx][i], meas1[i])
end
return res
end


# # FIXME see #1025, `multihypo=` will not work properly yet
# function getSample(cf::CalcFactor{<:DERelative})

# oder = cf.factor

# # how many trajectories to propagate?
# # @show getLabel(cf.fullvariables[2]), getDimension(cf.fullvariables[2])
# meas = zeros(getDimension(cf.fullvariables[2]))

# # pick forward or backward direction
# # set boundary condition
# u0pts = if cf.solvefor == 1
# # backward direction
# prob = oder.backwardProblem
# addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
# convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))),
# )
# cf._legacyParams[2]
# else
# # forward backward
# prob = oder.forwardProblem
# # buffer manifold operations for use during factor evaluation
# addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
# convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))),
# )
# cf._legacyParams[1]
# end

# i = cf._sampleIdx
# # solve likely elements
# # TODO, does this respect hyporecipe ???
# idxArr = (k -> cf._legacyParams[k][i]).(1:length(cf._legacyParams))
# _solveFactorODE!(meas, prob, u0pts[i], _maketuplebeyond2args(idxArr...)...)
# # _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...)

# return meas, diffOp
# end




## =========================================================================
Expand All @@ -221,15 +260,17 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))),
)
getBelief(cf.fullvariables[2]) |> getPoints
# getBelief(cf.fullvariables[2]) |> getPoints
cf._legacyParams[2]
else
# forward backward
prob = oder.forwardProblem
# buffer manifold operations for use during factor evaluation
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))),
)
getBelief(cf.fullvariables[1]) |> getPoints
# getBelief(cf.fullvariables[1]) |> getPoints
cf._legacyParams[1]
end

# solve likely elements
Expand Down
3 changes: 3 additions & 0 deletions ext/WeakDepsPrototypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
function _ccolamd! end
function _ccolamd end

# DiffEq
function _solveFactorODE! end

# Flux.jl
function MixtureFluxModels end

Expand Down
2 changes: 1 addition & 1 deletion src/services/ApproxConv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function approxConvBelief(
)
#
v_trg = getVariable(dfg, target)
N = N == 0 ? getNumPts(v_trg; solveKey = solveKey) : N
N = N == 0 ? getNumPts(v_trg; solveKey) : N
# approxConv should push its result into duplicate memory destination, NOT the variable.VND.val itself. ccw.varValsAll always points directly to variable.VND.val
# points and infoPerCoord

Expand Down
4 changes: 2 additions & 2 deletions src/services/CalcFactor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end
function CalcFactor(
ccwl::CommonConvWrapper;
factor = ccwl.usrfnc!,
_sampleIdx = 0,
_sampleIdx = ccwl.particleidx[],
_legacyParams = ccwl.varValsAll[],
_allowThreads = true,
cache = ccwl.dummyCache,
Expand Down Expand Up @@ -399,7 +399,7 @@ function _createCCW(
# MeasType = Vector{Float64} # FIXME use `usrfnc` to get this information instead
_cf = CalcFactor(
usrfnc,
0,
1,
_varValsAll,
false,
userCache,
Expand Down
4 changes: 2 additions & 2 deletions src/services/EvalFactor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -572,10 +572,10 @@ function evalFactor(
dfg::AbstractDFG,
fct::DFGFactor,
solvefor::Symbol,
measurement::AbstractVector = Tuple[];
measurement::AbstractVector = Tuple[]; # FIXME ensure type stable in all cases
needFreshMeasurements::Bool = true,
solveKey::Symbol = :default,
variables = getVariable.(dfg, getVariableOrder(fct)), # because we trying to use StaticArrays, go figure
variables = getVariable.(dfg, getVariableOrder(fct)), # FIXME use tuple instead for type stability
N::Int = length(measurement),
inflateCycles::Int = getSolverParams(dfg).inflateCycles,
nullSurplus::Real = 0,
Expand Down
20 changes: 15 additions & 5 deletions src/services/FactorGraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,12 @@ function setValKDE!(
setinit::Bool = true,
ipc::AbstractVector{<:Real} = [0.0;];
solveKey::Symbol = :default,
) where {P}
ppeType::Type{T} = MeanMaxPPE,
) where {P, T}
vnd = getSolverData(v, solveKey)
# recover variableType information
setValKDE!(getSolverData(v, solveKey), val, setinit, ipc)
setValKDE!(vnd, val, setinit, ipc)
setPPE!(v; solveKey, ppeType)
return nothing
end
function setValKDE!(
Expand Down Expand Up @@ -246,7 +249,7 @@ end

function setValKDE!(
vnd::VariableNodeData,
mkd::ManifoldKernelDensity{M, B, Nothing},
mkd::ManifoldKernelDensity{M, B, Nothing}, # TBD dispatch without partial?
setinit::Bool = true,
ipc::AbstractVector{<:Real} = [0.0;],
) where {M, B}
Expand Down Expand Up @@ -282,8 +285,15 @@ function setValKDE!(
return nothing
end

function setBelief!(vari::DFGVariable, bel::ManifoldKernelDensity, setinit::Bool=true,ipc::AbstractVector{<:Real}=[0.0;])
setValKDE!(vari,getPoints(bel, false),setinit, ipc)
function setBelief!(
vari::DFGVariable,
bel::ManifoldKernelDensity,
setinit::Bool=true,
ipc::AbstractVector{<:Real}=[0.0;];
solveKey::Symbol = :default
)
setValKDE!(vari, bel, setinit, ipc; solveKey)
# setValKDE!(vari,getPoints(bel, false), setinit, ipc)
end

"""
Expand Down
16 changes: 9 additions & 7 deletions src/services/GraphProductOperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,18 @@ function propagateBelief(

# get proposal beliefs
destlbl = getLabel(destvar)
ipc = proposalbeliefs!(dfg, destlbl, factors, dens; solveKey = solveKey, N = N, dbg = dbg)
ipc = proposalbeliefs!(dfg, destlbl, factors, dens; solveKey, N, dbg)

# @show dens[1].manifold

# make sure oldpts has right number of points
# make sure oldPoints vector has right length
oldBel = getBelief(dfg, destlbl, solveKey)
oldpts = if Npts(oldBel) == N
getPoints(oldBel)
_pts = getPoints(oldBel, false)
oldPoints = if Npts(oldBel) <= N
_pts[1:N]
else
sample(oldBel, N)[1]
nn = N - length(_pts) # should be larger than 0
vcat(_pts, sample(oldBel, nn))
end

# few more data requirements
Expand All @@ -51,8 +53,8 @@ function propagateBelief(
dens,
M;
Niter = 1,
oldPoints = oldpts,
N = N,
oldPoints,
N,
u0 = getPointDefault(varType),
)

Expand Down
2 changes: 1 addition & 1 deletion src/services/SolveTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ function doFMCIteration(
logger,
)

if 0 < length(getPoints(dens))
if 0 < Npts(dens)
setBelief!(vert, dens, true, ipc)
# setValKDE!(vert, densPts, true, ipc)
# TODO perhaps more debugging inside `propagateBelief`?
Expand Down
36 changes: 30 additions & 6 deletions src/services/SolverUtilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,18 @@ function mmd(
nodeType::Union{InstanceType{<:InferenceVariable}, InstanceType{<:AbstractFactor}},
threads::Bool = true;
bw::AbstractVector{<:Real} = SA[0.001;],
asPartial::Bool = true
)
#
return mmd(getPoints(p1), getPoints(p2), nodeType, threads; bw)
return mmd(getPoints(p1, asPartial), getPoints(p2, asPartial), nodeType, threads; bw)
end

# part of consolidation, see #927
function sampleFactor!(ccwl::CommonConvWrapper, N::Int; _allowThreads::Bool=true)
function sampleFactor!(
ccwl::CommonConvWrapper,
N::Int;
_allowThreads::Bool=true
)
#

# FIXME get allocations here down to 0
Expand All @@ -60,23 +65,42 @@ function sampleFactor!(ccwl::CommonConvWrapper, N::Int; _allowThreads::Bool=true
return ccwl.measurement
end

function sampleFactor(ccwl::CommonConvWrapper, N::Int; _allowThreads::Bool=true)
function sampleFactor(
ccwl::CommonConvWrapper,
N::Int;
_allowThreads::Bool=true
)
#
cf = CalcFactor(ccwl; _allowThreads)
return sampleFactor(cf, N)
end

sampleFactor(fct::DFGFactor, N::Int = 1; _allowThreads::Bool=true) = sampleFactor(_getCCW(fct), N; _allowThreads)
sampleFactor(
fct::DFGFactor,
N::Int = 1;
_allowThreads::Bool=true
) = sampleFactor(
_getCCW(fct),
N;
_allowThreads
)

function sampleFactor(dfg::AbstractDFG, sym::Symbol, N::Int = 1; _allowThreads::Bool=true)
function sampleFactor(
dfg::AbstractDFG,
sym::Symbol,
N::Int = 1;
_allowThreads::Bool=true
)
#
return sampleFactor(getFactor(dfg, sym), N; _allowThreads)
end

"""
$(SIGNATURES)

Update cliq `cliqID` in Bayes (Juction) tree `bt` according to contents of `urt` -- intended use is to update main clique after a upward belief propagation computation has been completed per clique.
Update cliq `cliqID` in Bayes (Juction) tree `bt` according to contents of `urt`.
Intended use is to update main clique after a upward belief propagation computation
has been completed per clique.
"""
function updateFGBT!(
fg::AbstractDFG,
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ TEST_GROUP = get(ENV, "IIF_TEST_GROUP", "all")
# temporarily moved to start (for debugging)
#...
if TEST_GROUP in ["all", "tmp_debug_group"]
include("testSpecialOrthogonalMani.jl")
include("testDERelative.jl")
include("testSpecialOrthogonalMani.jl")
include("testMultiHypo3Door.jl")
include("priorusetest.jl")
end
Expand Down
44 changes: 43 additions & 1 deletion test/testBasicGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,15 +306,57 @@ pts_ = getPoints(getBelief(fg, :x4))
TensorCast.@cast pts[i,j] := pts_[j][i]
@test 0.2 < Statistics.cov( pts[1,:] ) < 3.2



@testset "Test localProduct on solveKey" begin

localProduct(fg,:x2)

localProduct(fg,:x2, solveKey=:graphinit)

end

end


##
@testset "consistency check on more factors (origin is a DERelative fail case)" begin
##

fg = initfg()

addVariable!(fg, :x0, Position{1})
addFactor!(fg, [:x0], Prior(Normal(1.0, 0.01)))

# force a basic setup
initAll!(fg)
@test isapprox( 1, getPPE(fg, :x0).suggested[1]; atol=0.1)

##

addVariable!(fg, :x1, Position{1})
addFactor!(fg, [:x0;:x1], LinearRelative(Normal(1.0, 0.01)))

addVariable!(fg, :x2, Position{1})
addFactor!(fg, [:x1;:x2], LinearRelative(Normal(1.0, 0.01)))

addVariable!(fg, :x3, Position{1})
addFactor!(fg, [:x2;:x3], LinearRelative(Normal(1.0, 0.01)))

##

tree = solveGraph!(fg)

##

@test isapprox( 1, getPPE(fg, :x0).suggested[1]; atol=0.1)
@test isapprox( 4, getPPE(fg, :x3).suggested[1]; atol=0.3)

## check contents of tree messages

tree[1]
msg1 = IIF.getMessageBuffer(tree[1])

##
end


Expand Down
Loading