Skip to content

Commit

Permalink
Add remaining LM files
Browse files Browse the repository at this point in the history
  • Loading branch information
mschauer committed Jun 24, 2019
1 parent ec273a9 commit 620e7fa
Show file tree
Hide file tree
Showing 3 changed files with 507 additions and 0 deletions.
91 changes: 91 additions & 0 deletions landmarks/automaticdiff_lm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
Stochastic approximation to transition density.
Provide Wiener process.
"""
function slogpW(x0deepv, Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ)
x0 = deepvec2state(x0deepv)
Xᵒ = Bridge.solve(EulerMaruyama!(), x0, Wᵒ, Q)# this causes the problem
lptilde(vec(x0), Lt0, Mt⁺0, μt0, xobst0) + llikelihood(LeftRule(), Xᵒ, Q; skip = 1)
end
slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ) = (x) -> slogpW(x, Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ)
∇slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ) = (x) -> gradient(slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ), x)

function slogpWX(x0deepv, Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ,Xᵒ) # preferred way
x0 = deepvec2state(x0deepv)
# println(fieldnames(typeof(x0)))
solve!(EulerMaruyama!(), Xᵒ, x0, Wᵒ, Q)
lptilde(vec(x0), Lt0, Mt⁺0, μt0, xobst0) + llikelihood(LeftRule(), Xᵒ, Q; skip = 1)
end


"""
update initial momenta and/or guided proposals using either
sgd, sgld or mcmc
"""
function updatepath!(X,Xᵒ,W,Wᵒ,Wnew,ll,x,xᵒ,∇x, ∇xᵒ,
sampler,(Lt0, Mt⁺0, μt0, xobst0, Q),mask, mask_id, δ, ρ, acc)
if sampler in [:sgd, :sgld]
sample!(W, Wiener{Vector{StateW}}())
cfg = GradientConfig(slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, W), x, Chunk{d*P.n}()) # 2*d*P.n is maximal
@time gradient!(∇x, slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, W),x,cfg)

if sampler==:sgd
x .+= δ*mask.*∇x
end
if sampler==:sgld
x .+= .5*δ*mask.*∇x + sqrt(δ)*mask.*randn(2d*Q.target.n)
end
xstate = deepvec2state(x)
Bridge.solve!(EulerMaruyama!(), X, xstate, W, Q)
obj = lptilde(vec(xstate), Lt0, Mt⁺0, μt0, xobst0) +
llikelihood(LeftRule(), X, Q; skip = 1)
println("ll ", obj)
end
if sampler==:mcmc
# Update W
sample!(Wnew, Wiener{Vector{PointF}}())
Wᵒ.yy .= ρ * W.yy + sqrt(1-ρ^2) * Wnew.yy
solve!(EulerMaruyama!(), Xᵒ, deepvec2state(x), Wᵒ, Q)

llᵒ = llikelihood(Bridge.LeftRule(), Xᵒ, Q,skip=sk)
print("ll $ll $llᵒ, diff_ll: ",round(llᵒ-ll;digits=3))

if log(rand()) <= llᵒ - ll
X.yy .= Xᵒ.yy
W.yy .= Wᵒ.yy
ll = llᵒ
print("update innovation acceptedœ“")
acc[1] +=1
else
print("update innovation rejected")
end
println()

# MALA step (update x)
∇x .= ∇slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, W)(x)
xᵒ .= x .+ .5*δ * mask.* (∇x .+ sqrt(δ) * randn(length(x)))
∇xᵒ .= ∇slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, W)(xᵒ)

xstate = deepvec2state(x)
xᵒstate = deepvec2state(xᵒ)
solve!(EulerMaruyama!(), Xᵒ, xᵒstate, W, Q)
ainit = lptilde(vec(xᵒstate), Lt0, Mt⁺0, μt0, xobst0) + llikelihood(LeftRule(), Xᵒ, Q; skip = 1) -
lptilde(vec(xstate), Lt0, Mt⁺0, μt0, xobst0) - llikelihood(LeftRule(), X, Q; skip = 1) -
logpdf(MvNormal(d*P.n,sqrt(δ)),(xᵒ - x - .5*δ* mask.* ∇x)[mask_id]) +
logpdf(MvNormal(d*P.n,sqrt(δ)),(x - xᵒ - .5*δ* mask.* ∇xᵒ)[mask_id])
# compute acc prob
print("ainit: ", ainit)
if log(rand()) <= ainit
x .= xᵒ
xstate = xᵒstate
X.yy .= Xᵒ.yy
println("mala step acceptedœ“")
acc[2] +=1
else
println("mala step rejectedœ“")
end
obj = lptilde(vec(xstate), Lt0, Mt⁺0, μt0, xobst0) +
llikelihood(LeftRule(), X, Q; skip = 1)
end
X,Xᵒ,W,Wᵒ,ll,x,xᵒ,∇x,∇xᵒ, obj,acc
end
76 changes: 76 additions & 0 deletions landmarks/generatedata.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Generate landmarks data, or read landmarks data.
- dataset: specifies type of data to be generated/loaded
- P:
- t: grid used for generating data
- σobs: standard deviation of observation noise
Returns
- x0: initial state without noise
- xobs0, xobsT: (observed initial and final states with noise)
- Xf: forward simulated path
- P: only adjusted in case of 'real' data, like the bearskull data
Example
x0, xobs0, xobsT, Xf, P = generatedata(dataset,P,t,σobs)
"""
function generatedata(dataset,P,t,σobs)

n = P.n
if P isa MarslandShardlow
dwiener = P.n
else
dwiener = length(P.nfs)
end
if dataset=="forwardsimulated"
q0 = [PointF(2.0cos(t), sin(t)) for t in (0:(2pi/n):2pi)[1:n]] #q0 = circshift(q0, (1,))
p0 = [Point(1.0, -3.0) for i in 1:n] # #p0 = [randn(Point) for i in 1:n]
x0 = State(q0, p0)
@time Wf, Xf = landmarksforward(t, dwiener, x0, P)
xobs0 = x0.q + σobs * randn(PointF,n)
xobsT = [Xf.yy[end].q[i] for i in 1:P.n ] + σobs * randn(PointF,n)
end
if dataset in ["shifted","shiftedextreme"] # first stretch, then rotate, then shift; finally add noise
q0 = [PointF(2.0cos(t), sin(t)) for t in (0:(2pi/n):2pi)[1:n]] #q0 = circshift(q0, (1,))
p0 = [PointF(1.0, -3.0) for i in 1:n] # #p0 = [randn(Point) for i in 1:n]
x0 = State(q0, p0)
@time Wf, Xf = landmarksforward(t, dwiener, x0, P)
xobs0 = x0.q + σobs * randn(PointF,n)
if dataset == "shifted" θ, η = π/10, 0.2 end
if dataset == "shiftedextreme" θ, η = π/5, 0.4 end
rot = SMatrix{2,2}(cos(θ), sin(θ), -sin(θ), cos(θ))
stretch = SMatrix{2,2}(1.0 + η, 0.0, 0.0, 1.0 - η)
shift = PointF(0.1,-0.1)
xobsT = [rot * stretch * xobs0[i] + shift for i in 1:P.n ] + σobs * randn(PointF,n)
end
if dataset=="bear"
cd("/Users/Frank/github/BridgeLandmarks/landmarks/beardata")
bear0 = readdlm("bear1.csv",',')
bearT = readdlm("bear2.csv",',')
nb = size(bear0)[1]
avePoint = Point(414.0, 290.0) # average of (x,y)-coords for bear0 to center figure at origin
xobs0 = [Point(bear0[i,1], bear0[i,2]) - avePoint for i in 1:nb]/200.
xobsT = [Point(bearT[i,1], bearT[i,2]) - avePoint for i in 1:nb]/200.
# need to redefine P, because of n
if model == :ms
P = MarslandShardlow(a, γ, λ, nb)
else
P = Landmarks(a, 0.0, nb, nfs)
end
x0 = State(xobs0, rand(PointF,P.n))
Wf, Xf = landmarksforward(t, dwiener, x0, P)
end
if dataset=="heart"
q0 = [PointF(2.0cos(t), 2.0sin(t)) for t in (0:(2pi/n):2pi)[1:n]] #q0 = circshift(q0, (1,))
p0 = [PointF(1.0, -3.0) for i in 1:n] # #p0 = [randn(Point) for i in 1:n]
x0 = State(q0, p0)
@time Wf, Xf = landmarksforward(t, dwiener, x0, P)
xobs0 = x0.q + σobs * randn(PointF,n)
heart_xcoord(s) = 0.2*(13cos(s)-5cos(2s)-2cos(3s)-cos(4s))
heart_ycoord(s) = 0.2*16(sin(s)^3)
qT = [PointF(heart_xcoord(t), heart_ycoord(t)) for t in (0:(2pi/n):2pi)[1:n]] #q0 = circshift(q0, (1,))
xobsT = qT + σobs * randn(PointF,n)
end
x0, xobs0, xobsT, Xf, P
end
Loading

0 comments on commit 620e7fa

Please sign in to comment.