-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
507 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
""" | ||
Stochastic approximation to transition density. | ||
Provide Wiener process. | ||
""" | ||
function slogpW(x0deepv, Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ) | ||
x0 = deepvec2state(x0deepv) | ||
Xᵒ = Bridge.solve(EulerMaruyama!(), x0, Wᵒ, Q)# this causes the problem | ||
lptilde(vec(x0), Lt0, Mt⁺0, μt0, xobst0) + llikelihood(LeftRule(), Xᵒ, Q; skip = 1) | ||
end | ||
slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ) = (x) -> slogpW(x, Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ) | ||
∇slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ) = (x) -> gradient(slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ), x) | ||
|
||
function slogpWX(x0deepv, Lt0, Mt⁺0, μt0, xobst0, Q, Wᵒ,Xᵒ) # preferred way | ||
x0 = deepvec2state(x0deepv) | ||
# println(fieldnames(typeof(x0))) | ||
solve!(EulerMaruyama!(), Xᵒ, x0, Wᵒ, Q) | ||
lptilde(vec(x0), Lt0, Mt⁺0, μt0, xobst0) + llikelihood(LeftRule(), Xᵒ, Q; skip = 1) | ||
end | ||
|
||
|
||
""" | ||
update initial momenta and/or guided proposals using either | ||
sgd, sgld or mcmc | ||
""" | ||
function updatepath!(X,Xᵒ,W,Wᵒ,Wnew,ll,x,xᵒ,∇x, ∇xᵒ, | ||
sampler,(Lt0, Mt⁺0, μt0, xobst0, Q),mask, mask_id, δ, ρ, acc) | ||
if sampler in [:sgd, :sgld] | ||
sample!(W, Wiener{Vector{StateW}}()) | ||
cfg = GradientConfig(slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, W), x, Chunk{d*P.n}()) # 2*d*P.n is maximal | ||
@time gradient!(∇x, slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, W),x,cfg) | ||
|
||
if sampler==:sgd | ||
x .+= δ*mask.*∇x | ||
end | ||
if sampler==:sgld | ||
x .+= .5*δ*mask.*∇x + sqrt(δ)*mask.*randn(2d*Q.target.n) | ||
end | ||
xstate = deepvec2state(x) | ||
Bridge.solve!(EulerMaruyama!(), X, xstate, W, Q) | ||
obj = lptilde(vec(xstate), Lt0, Mt⁺0, μt0, xobst0) + | ||
llikelihood(LeftRule(), X, Q; skip = 1) | ||
println("ll ", obj) | ||
end | ||
if sampler==:mcmc | ||
# Update W | ||
sample!(Wnew, Wiener{Vector{PointF}}()) | ||
Wᵒ.yy .= ρ * W.yy + sqrt(1-ρ^2) * Wnew.yy | ||
solve!(EulerMaruyama!(), Xᵒ, deepvec2state(x), Wᵒ, Q) | ||
|
||
llᵒ = llikelihood(Bridge.LeftRule(), Xᵒ, Q,skip=sk) | ||
print("ll $ll $llᵒ, diff_ll: ",round(llᵒ-ll;digits=3)) | ||
|
||
if log(rand()) <= llᵒ - ll | ||
X.yy .= Xᵒ.yy | ||
W.yy .= Wᵒ.yy | ||
ll = llᵒ | ||
print("update innovation accepted") | ||
acc[1] +=1 | ||
else | ||
print("update innovation rejected") | ||
end | ||
println() | ||
|
||
# MALA step (update x) | ||
∇x .= ∇slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, W)(x) | ||
xᵒ .= x .+ .5*δ * mask.* (∇x .+ sqrt(δ) * randn(length(x))) | ||
∇xᵒ .= ∇slogpW(Lt0, Mt⁺0, μt0, xobst0, Q, W)(xᵒ) | ||
|
||
xstate = deepvec2state(x) | ||
xᵒstate = deepvec2state(xᵒ) | ||
solve!(EulerMaruyama!(), Xᵒ, xᵒstate, W, Q) | ||
ainit = lptilde(vec(xᵒstate), Lt0, Mt⁺0, μt0, xobst0) + llikelihood(LeftRule(), Xᵒ, Q; skip = 1) - | ||
lptilde(vec(xstate), Lt0, Mt⁺0, μt0, xobst0) - llikelihood(LeftRule(), X, Q; skip = 1) - | ||
logpdf(MvNormal(d*P.n,sqrt(δ)),(xᵒ - x - .5*δ* mask.* ∇x)[mask_id]) + | ||
logpdf(MvNormal(d*P.n,sqrt(δ)),(x - xᵒ - .5*δ* mask.* ∇xᵒ)[mask_id]) | ||
# compute acc prob | ||
print("ainit: ", ainit) | ||
if log(rand()) <= ainit | ||
x .= xᵒ | ||
xstate = xᵒstate | ||
X.yy .= Xᵒ.yy | ||
println("mala step accepted") | ||
acc[2] +=1 | ||
else | ||
println("mala step rejected") | ||
end | ||
obj = lptilde(vec(xstate), Lt0, Mt⁺0, μt0, xobst0) + | ||
llikelihood(LeftRule(), X, Q; skip = 1) | ||
end | ||
X,Xᵒ,W,Wᵒ,ll,x,xᵒ,∇x,∇xᵒ, obj,acc | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
""" | ||
Generate landmarks data, or read landmarks data. | ||
- dataset: specifies type of data to be generated/loaded | ||
- P: | ||
- t: grid used for generating data | ||
- σobs: standard deviation of observation noise | ||
Returns | ||
- x0: initial state without noise | ||
- xobs0, xobsT: (observed initial and final states with noise) | ||
- Xf: forward simulated path | ||
- P: only adjusted in case of 'real' data, like the bearskull data | ||
Example | ||
x0, xobs0, xobsT, Xf, P = generatedata(dataset,P,t,σobs) | ||
""" | ||
function generatedata(dataset,P,t,σobs) | ||
|
||
n = P.n | ||
if P isa MarslandShardlow | ||
dwiener = P.n | ||
else | ||
dwiener = length(P.nfs) | ||
end | ||
if dataset=="forwardsimulated" | ||
q0 = [PointF(2.0cos(t), sin(t)) for t in (0:(2pi/n):2pi)[1:n]] #q0 = circshift(q0, (1,)) | ||
p0 = [Point(1.0, -3.0) for i in 1:n] # #p0 = [randn(Point) for i in 1:n] | ||
x0 = State(q0, p0) | ||
@time Wf, Xf = landmarksforward(t, dwiener, x0, P) | ||
xobs0 = x0.q + σobs * randn(PointF,n) | ||
xobsT = [Xf.yy[end].q[i] for i in 1:P.n ] + σobs * randn(PointF,n) | ||
end | ||
if dataset in ["shifted","shiftedextreme"] # first stretch, then rotate, then shift; finally add noise | ||
q0 = [PointF(2.0cos(t), sin(t)) for t in (0:(2pi/n):2pi)[1:n]] #q0 = circshift(q0, (1,)) | ||
p0 = [PointF(1.0, -3.0) for i in 1:n] # #p0 = [randn(Point) for i in 1:n] | ||
x0 = State(q0, p0) | ||
@time Wf, Xf = landmarksforward(t, dwiener, x0, P) | ||
xobs0 = x0.q + σobs * randn(PointF,n) | ||
if dataset == "shifted" θ, η = π/10, 0.2 end | ||
if dataset == "shiftedextreme" θ, η = π/5, 0.4 end | ||
rot = SMatrix{2,2}(cos(θ), sin(θ), -sin(θ), cos(θ)) | ||
stretch = SMatrix{2,2}(1.0 + η, 0.0, 0.0, 1.0 - η) | ||
shift = PointF(0.1,-0.1) | ||
xobsT = [rot * stretch * xobs0[i] + shift for i in 1:P.n ] + σobs * randn(PointF,n) | ||
end | ||
if dataset=="bear" | ||
cd("/Users/Frank/github/BridgeLandmarks/landmarks/beardata") | ||
bear0 = readdlm("bear1.csv",',') | ||
bearT = readdlm("bear2.csv",',') | ||
nb = size(bear0)[1] | ||
avePoint = Point(414.0, 290.0) # average of (x,y)-coords for bear0 to center figure at origin | ||
xobs0 = [Point(bear0[i,1], bear0[i,2]) - avePoint for i in 1:nb]/200. | ||
xobsT = [Point(bearT[i,1], bearT[i,2]) - avePoint for i in 1:nb]/200. | ||
# need to redefine P, because of n | ||
if model == :ms | ||
P = MarslandShardlow(a, γ, λ, nb) | ||
else | ||
P = Landmarks(a, 0.0, nb, nfs) | ||
end | ||
x0 = State(xobs0, rand(PointF,P.n)) | ||
Wf, Xf = landmarksforward(t, dwiener, x0, P) | ||
end | ||
if dataset=="heart" | ||
q0 = [PointF(2.0cos(t), 2.0sin(t)) for t in (0:(2pi/n):2pi)[1:n]] #q0 = circshift(q0, (1,)) | ||
p0 = [PointF(1.0, -3.0) for i in 1:n] # #p0 = [randn(Point) for i in 1:n] | ||
x0 = State(q0, p0) | ||
@time Wf, Xf = landmarksforward(t, dwiener, x0, P) | ||
xobs0 = x0.q + σobs * randn(PointF,n) | ||
heart_xcoord(s) = 0.2*(13cos(s)-5cos(2s)-2cos(3s)-cos(4s)) | ||
heart_ycoord(s) = 0.2*16(sin(s)^3) | ||
qT = [PointF(heart_xcoord(t), heart_ycoord(t)) for t in (0:(2pi/n):2pi)[1:n]] #q0 = circshift(q0, (1,)) | ||
xobsT = qT + σobs * randn(PointF,n) | ||
end | ||
x0, xobs0, xobsT, Xf, P | ||
end |
Oops, something went wrong.