Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show and refresh #40

Merged
merged 9 commits into from
Dec 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
matrix:
version:
- '1'
- '1.5'
- '1.6'
- 'nightly'
os:
- ubuntu-latest
Expand Down
18 changes: 6 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Turkie"
uuid = "8156cc02-0533-41cd-9345-13411ebe105f"
authors = ["Theo Galy-Fajou <[email protected]> and contributors"]
version = "0.1.9"
version = "0.1.10"

[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Expand All @@ -13,17 +13,11 @@ OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
ColorSchemes = "3.10"
ColorSchemes = "3"
Colors = "0.12"
KernelDensity = "0.5, 0.6"
KernelDensity = "0.6"
MCMCChains = "4, 5"
Makie = "0.13, 0.14, 0.15"
Makie = "0.15"
OnlineStats = "1.5"
Turing = "0.15, 0.16, 0.17, 0.18"
julia = "1.5"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
Turing = "0.15, 0.16, 0.17, 0.18, 0.19"
julia = "1.6"
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ If you want a specific implementation of a certain stat please [open an issue](h
If you want to make a cool animation you can use the built-in recording features of Makie.
Here is the simple example, using the Turing example from above:
```julia
record(cb.scene, joinpath(@__DIR__, "video.webm")) do io
record(cb.figure, joinpath(@__DIR__, "video.webm")) do io
addIO!(cb, io)
sample(m, NUTS(0.65), 300; callback = cb)
end
Expand Down
4 changes: 2 additions & 2 deletions src/Turkie.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
module Turkie

using Makie: Makie, Figure, Scene, Point2f0
using Makie: Makie, Figure, Axis, Point2f0
using Makie: barplot!, lines!, scatter! # Plotting tools
using Makie: Observable, Node, lift, on # Observable tools
using Makie: recordframe! # Recording tools
using Makie.MakieLayout # Layouting tool
using Colors, ColorSchemes # Colors tools
using KernelDensity # To be able to give a KDE
using OnlineStats # Estimators
using Turing: DynamicPPL.VarInfo, DynamicPPL.Model, Inference._params_to_array
using Turing: DynamicPPL.VarInfo, DynamicPPL.Model, Inference, Inference._params_to_array

using MCMCChains

Expand Down
91 changes: 70 additions & 21 deletions src/live_sampling/online_stats_plots.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
function onlineplot!(fig, axis_dict, stats::AbstractVector, iter, data, variable, i)
function onlineplot!(fig::Figure, axis_dict::AbstractDict, stats::AbstractVector, stats_dict::AbstractDict, iter, data, variable, i)
# Iter over all the stats for a given variable, create an axis and add
# the appropriate plots on it.
for (j, stat) in enumerate(stats)
axis_dict[(variable, stat)] = fig[i, j] = Axis(fig, title="$(name(stat))")
limits!(axis_dict[(variable, stat)], 0.0, 10.0, -1.0, 1.0)
onlineplot!(axis_dict[(variable, stat)], stat, iter, data[variable], data[:iter], i, j)
stats_dict[(variable, stat)] = []
onlineplot!(axis_dict[(variable, stat)], stat, stats_dict[(variable, stat)], iter, data[variable], data[:iter], i, j)
tight_ticklabel_spacing!(axis_dict[(variable, stat)])
end
end

function onlineplot!(axis, stat::Symbol, args...)
# Reset the saved stats to be able to refresh the plots when wanted
reset!(::Any, ::Any) = nothing # Default behavior is to do nothing
reset!(stats, stat::Symbol) = reset!(stats, Val(stat))
reset!(stats, ::Val{:mean}) = reset!(stats, Mean())
reset!(stats, ::Val{:var}) = reset!(stats, Variance())
reset!(stats, ::Val{:autocov}) = reset!(stats, AutoCov(20))
reset!(stats, ::Val{:hist}) = reset!(stats, KHist(50, Float32))

function onlineplot!(axis::Axis, stat::Symbol, args...)
onlineplot!(axis, Val(stat), args...)
end

Expand All @@ -20,44 +31,62 @@ onlineplot!(axis, ::Val{:autocov}, args...) = onlineplot!(axis, AutoCov(20), arg
onlineplot!(axis, ::Val{:hist}, args...) = onlineplot!(axis, KHist(50, Float32), args...)

# Generic fallback for OnlineStat objects
function onlineplot!(axis, stat::T, iter, data, iterations, i, j) where {T<:OnlineStat}
window = data.b
function onlineplot!(axis, stat::T, stats, iter, data, iterations, i, j) where {T<:OnlineStat}
window = data[].b
@eval TStat = $(nameof(T))
# Create an observable based on the given stat
stat = Observable(TStat(Float32))
on(iter) do _
stat[] = fit!(stat[], last(value(data)))
stat[] = fit!(stat[], last(value(data[])))
end
push!(stats, stat)
# Create a moving window on this value
statvals = Observable(MovingWindow(window, Float32))
on(stat) do s
statvals[] = fit!(statvals[], Float32(value(s)))
end
push!(stats, statvals)
# Pass this observable to create points to pass to Makie
statpoints = map!(Observable(Point2f0.([0], [0])), statvals) do v
Point2f0.(value(iterations), value(v))
Point2f0.(value(iterations[]), value(v))
end
lines!(axis, statpoints, color = std_colors[i], linewidth = 3.0)
end

function onlineplot!(axis, ::Val{:trace}, iter, data, iterations, i, j)
function reset!(stats, stat::T) where {T<:OnlineStat}
@eval TStat = $(nameof(T))
stats[1].val = TStat(Float32) # Represent the actual stat
stats[2].val = MovingWindow(stats[2][].b, Float32) # Represent the moving window on the stat
end

function onlineplot!(axis, ::Val{:trace}, stats, iter, data, iterations, i, j)
trace = map!(Observable([Point2f0(0, 0)]), iter) do _
Point2f0.(value(iterations), value(data))
Point2f0.(value(iterations[]), value(data[]))
end
lines!(axis, trace, color = std_colors[i]; linewidth = 3.0)
end

function onlineplot!(axis, stat::KHist, iter, data, iterations, i, j)
function onlineplot!(axis, stat::KHist, stats, iter, data, iterations, i, j)
nbins = stat.k
stat = Observable(KHist(nbins, Float32))
on(iter) do _
stat[] = fit!(stat[], last(value(data)))
stat[] = fit!(stat[], last(value(data[])))
end
hist_vals = Node(Point2f0.(collect(range(0f0, 1f0, length=nbins)), zeros(Float32, nbins)))
push!(stats, stat)
on(stat) do h
edges, weights = OnlineStats.xy(h)
weights = nobs(h) > 1 ? weights / OnlineStats.area(h) : weights
hist_vals[] = Point2f0.(edges, weights)
end
push!(stats, hist_vals)
barplot!(axis, hist_vals; color=std_colors[i])
# barplot!(axis, rand(4), rand(4))
end

function reset!(stats, stat::KHist)
nbins = stat.k
stats[1].val = KHist(nbins, Float32)
stats[2].val = Point2f0.(collect(range(0f0, 1f0, length=nbins)), zeros(Float32, nbins))
end

function expand_extrema(xs)
Expand All @@ -68,37 +97,57 @@ function expand_extrema(xs)
return (xmin, xmax)
end

function onlineplot!(axis, ::Val{:kde}, iter, data, iterations, i, j)
function onlineplot!(axis, ::Val{:kde}, stats, iter, data, iterations, i, j)
interpkde = Observable(InterpKDE(kde([1f0])))
on(iter) do _
interpkde[] = InterpKDE(kde(value(data)))
interpkde[] = InterpKDE(kde(value(data[])))
end
push!(stats, interpkde)
xs = Observable(range(0, 2, length=10))
on(iter) do _
xs[] = range(expand_extrema(extrema(value(data)))..., length = 200)
xs[] = range(expand_extrema(extrema(value(data[])))..., length = 200)
end
push!(stats, xs)
kde_pdf = lift(xs) do xs
pdf.(Ref(interpkde[]), xs)
end
lines!(axis, xs, kde_pdf, color = std_colors[i], linewidth = 3.0)
end

function reset!(stats, ::Val{:kde})
stats[1].val = InterpKDE(kde([1f0]))
stats[2].val = range(0, 2, length=10)
end

name(s::Val{:histkde}) = "Hist. + KDE"

function onlineplot!(axis, ::Val{:histkde}, iter, data, iterations, i, j)
onlineplot!(axis, KHist(50), iter, data, iterations, i, j)
onlineplot!(axis, Val(:kde), iter, data, iterations, i, j)
function onlineplot!(axis, ::Val{:histkde}, stats, iter, data, iterations, i, j)
onlineplot!(axis, KHist(50), stats, iter, data, iterations, i, j)
onlineplot!(axis, Val(:kde), stats, iter, data, iterations, i, j)
end

function onlineplot!(axis, stat::AutoCov, iter, data, iterations, i, j)
b = length(stat.cross)
function reset!(stats, ::Val{:histkde})
reset!(stats[1:2], KHist(50))
reset!(stats[3:end], Val(:kde))
end

function onlineplot!(axis, stat::AutoCov, stats, iter, data, iterations, i, j)
b = length(stat.cross) - 1
stat = Observable(AutoCov(b, Float32))
on(iter) do _
stat[] = fit!(stat[], last(value(data)))
stat[] = fit!(stat[], last(value(data[])))
end
push!(stats, stat)
statvals = map!(Observable(zeros(Float32, b + 1)), stat) do s
value(s)
end
push!(stats, statvals)
scatter!(axis, Point2f0.([0.0, b], [-0.1, 1.0]), markersize = 0.0, color = RGBA(0.0, 0.0, 0.0, 0.0)) # Invisible points to keep limits fixed
lines!(axis, 0:b, statvals, color = std_colors[i], linewidth = 3.0)
end

function reset!(stats, stat::AutoCov)
b = length(stat.cross) - 1
stats[1].val = AutoCov(b, Float32)
stats[2].val = zeros(Float32, b + 1)
end
50 changes: 37 additions & 13 deletions src/live_sampling/turkie_callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ See the docs for some examples.
"""
TurkieCallback

struct TurkieCallback{TN<:NamedTuple,TD<:AbstractDict}
struct TurkieCallback{TN<:NamedTuple,TS<:AbstractDict,TD<:AbstractDict}
figure::Figure
data::Dict{Symbol, MovingWindow}
data::Dict{Symbol,Observable{MovingWindow}}
axis_dict::Dict
vars::TN
stats::TS
params::TD
iter::Observable{Int}
end
Expand All @@ -51,17 +52,18 @@ function TurkieCallback(vars::NamedTuple, params::Dict)
resolution = get!(params, :resolution, (1200, 700))
fig = Figure(;resolution=resolution, figure_padding=outer_padding)
window = get!(params, :window, 1000)
refresh = get!(params, :refresh, false)
get!(params, :refresh, false)
params[:t0] = 0
iter = Observable(0)
data = Dict{Symbol, MovingWindow}(:iter => MovingWindow(window, Int))
data = Dict{Symbol, Observable{MovingWindow}}(:iter => Node(MovingWindow(window, Int)))
axis_dict = Dict()
stats_dict = Dict()
for (i, variable) in enumerate(keys(vars))
plots = vars[variable]
data[variable] = MovingWindow(window, Float32)
data[variable] = Node(MovingWindow(window, Float32))
axis_dict[(variable, :varname)] = fig[i, 1, Left()] = Label(fig, string(variable), textsize = 30)
axis_dict[(variable, :varname)].padding = (0, 60, 0, 0)
onlineplot!(fig, axis_dict, plots, iter, data, variable, i)
onlineplot!(fig, axis_dict, plots, stats_dict, iter, data, variable, i)
end
on(iter) do i
if i > 1 # To deal with autolimits a certain number of samples are needed
Expand All @@ -74,32 +76,54 @@ function TurkieCallback(vars::NamedTuple, params::Dict)
end
MakieLayout.trim!(fig.layout)
display(fig)
return TurkieCallback(fig, data, axis_dict, vars, params, iter)
return TurkieCallback(fig, data, axis_dict, vars, stats_dict, params, iter)
end

function Base.show(io::IO, cb::TurkieCallback)
show(io, cb.figure)
end

function Base.show(io::IO, ::MIME"text/plain", cb::TurkieCallback)
print(io, "TurkieCallback tracking the following variables:\n")
for v in keys(cb.vars)
print(io, " ", v, "\t=> [")
for s in cb.vars[v][1:end-1]
print(io, name(s), ", ")
end
print(io, name(cb.vars[v][end]), "]\n")
end
end

function addIO!(cb::TurkieCallback, io)
cb.params[:io] = io
end

function (cb::TurkieCallback)(rng, model, sampler, transition, state, iteration; kwargs...)
if iteration == 1
if iteration == 1 && cb.iter[] != 0
if cb.params[:refresh]
refresh_plots!(cb)
end
cb.params[:t0] = cb.iter[]
end
fit!(cb.data[:iter], iteration + cb.params[:t0]) # Update the iteration value
for (variable, val) in zip(_params_to_array([transition])...)
fit!(cb.data[:iter][], iteration + cb.params[:t0]) # Update the iteration value
for (variable, val) in zip(Inference._params_to_array([transition])...)
if haskey(cb.data, variable) # Check if symbol should be plotted
fit!(cb.data[variable], Float32(val)) # Update its value
fit!(cb.data[variable][], Float32(val)) # Update its value
end
end
cb.iter[] = cb.iter[] + 1
cb.iter[] = cb.iter[] + 1 # Triggers all the updates
if haskey(cb.params, :io)
recordframe!(cb.params[:io])
end
end

function refresh_plots!(cb)
#TODO
for v in keys(cb.vars)
cb.data[v].val = MovingWindow(cb.params[:window], Float32) # Reset the moving window
for stat in cb.vars[v]
reset!(cb.stats[(v, stat)], stat) # Reset the stats observables
end
end
cb.data[:iter].val = MovingWindow(cb.params[:window], Int)
cb.iter.val = 0
end
6 changes: 6 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
CairoMakie = "0.6"
ColorSchemes = "3"
OnlineStats = "1"
Turing = "0.15, 0.16, 0.17, 0.18, 0.19"
8 changes: 3 additions & 5 deletions test/dev_test.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
using Pkg; Pkg.activate("..")
using Turing
using Turkie
using GLMakie # You could also use CairoMakie or another backend
using CairoMakie
GLMakie.activate!()
CairoMakie.activate!()
Turing.@model function demo(x) # Some random Turing model
m0 ~ Normal(0, 2)
s ~ InverseGamma(2, 3)
Expand All @@ -16,8 +13,9 @@ end

xs = randn(100) .+ 1;
m = demo(xs);
cb = TurkieCallback(m) # Create a callback function to be given to the sample function
chain = sample(m, NUTS(0.65), 30; callback = cb)
cb = TurkieCallback(m; refresh=false, resolution=(400, 400)) # Create a callback function to be given to the sample function
chain = sample(m, NUTS(), 300; callback = cb);
chain = sample(m, MH(), 30; callback = cb)


record(cb.figure, joinpath(@__DIR__, "video.gif")) do io
Expand Down
Loading