diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 28cab42..5544ef6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: matrix: version: - '1' - - '1.5' + - '1.6' - 'nightly' os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index a7501f6..ceeaba6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Turkie" uuid = "8156cc02-0533-41cd-9345-13411ebe105f" authors = ["Theo Galy-Fajou and contributors"] -version = "0.1.9" +version = "0.1.10" [deps] ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" @@ -13,17 +13,11 @@ OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -ColorSchemes = "3.10" +ColorSchemes = "3" Colors = "0.12" -KernelDensity = "0.5, 0.6" +KernelDensity = "0.6" MCMCChains = "4, 5" -Makie = "0.13, 0.14, 0.15" +Makie = "0.15" OnlineStats = "1.5" -Turing = "0.15, 0.16, 0.17, 0.18" -julia = "1.5" - -[extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["Test"] +Turing = "0.15, 0.16, 0.17, 0.18, 0.19" +julia = "1.6" \ No newline at end of file diff --git a/docs/src/index.md b/docs/src/index.md index d1f368f..e7fef97 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -94,7 +94,7 @@ If you want a specific implementation of a certain stat please [open an issue](h If you want to make a cool animation you can use the built-in recording features of Makie. Here is the simple example, using the Turing example from above: ```julia -record(cb.scene, joinpath(@__DIR__, "video.webm")) do io +record(cb.figure, joinpath(@__DIR__, "video.webm")) do io addIO!(cb, io) sample(m, NUTS(0.65), 300; callback = cb) end diff --git a/src/Turkie.jl b/src/Turkie.jl index fbe6b02..b2e917a 100644 --- a/src/Turkie.jl +++ b/src/Turkie.jl @@ -1,6 +1,6 @@ module Turkie -using Makie: Makie, Figure, Scene, Point2f0 +using Makie: Makie, Figure, Axis, Point2f0 using Makie: barplot!, lines!, scatter! # Plotting tools using Makie: Observable, Node, lift, on # Observable tools using Makie: recordframe! # Recording tools @@ -8,7 +8,7 @@ using Makie.MakieLayout # Layouting tool using Colors, ColorSchemes # Colors tools using KernelDensity # To be able to give a KDE using OnlineStats # Estimators -using Turing: DynamicPPL.VarInfo, DynamicPPL.Model, Inference._params_to_array +using Turing: DynamicPPL.VarInfo, DynamicPPL.Model, Inference, Inference._params_to_array using MCMCChains diff --git a/src/live_sampling/online_stats_plots.jl b/src/live_sampling/online_stats_plots.jl index ed0dbb1..ea409a9 100644 --- a/src/live_sampling/online_stats_plots.jl +++ b/src/live_sampling/online_stats_plots.jl @@ -1,13 +1,24 @@ -function onlineplot!(fig, axis_dict, stats::AbstractVector, iter, data, variable, i) +function onlineplot!(fig::Figure, axis_dict::AbstractDict, stats::AbstractVector, stats_dict::AbstractDict, iter, data, variable, i) + # Iter over all the stats for a given variable, create an axis and add + # the appropriate plots on it. for (j, stat) in enumerate(stats) axis_dict[(variable, stat)] = fig[i, j] = Axis(fig, title="$(name(stat))") limits!(axis_dict[(variable, stat)], 0.0, 10.0, -1.0, 1.0) - onlineplot!(axis_dict[(variable, stat)], stat, iter, data[variable], data[:iter], i, j) + stats_dict[(variable, stat)] = [] + onlineplot!(axis_dict[(variable, stat)], stat, stats_dict[(variable, stat)], iter, data[variable], data[:iter], i, j) tight_ticklabel_spacing!(axis_dict[(variable, stat)]) end end -function onlineplot!(axis, stat::Symbol, args...) +# Reset the saved stats to be able to refresh the plots when wanted +reset!(::Any, ::Any) = nothing # Default behavior is to do nothing +reset!(stats, stat::Symbol) = reset!(stats, Val(stat)) +reset!(stats, ::Val{:mean}) = reset!(stats, Mean()) +reset!(stats, ::Val{:var}) = reset!(stats, Variance()) +reset!(stats, ::Val{:autocov}) = reset!(stats, AutoCov(20)) +reset!(stats, ::Val{:hist}) = reset!(stats, KHist(50, Float32)) + +function onlineplot!(axis::Axis, stat::Symbol, args...) onlineplot!(axis, Val(stat), args...) end @@ -20,44 +31,62 @@ onlineplot!(axis, ::Val{:autocov}, args...) = onlineplot!(axis, AutoCov(20), arg onlineplot!(axis, ::Val{:hist}, args...) = onlineplot!(axis, KHist(50, Float32), args...) # Generic fallback for OnlineStat objects -function onlineplot!(axis, stat::T, iter, data, iterations, i, j) where {T<:OnlineStat} - window = data.b +function onlineplot!(axis, stat::T, stats, iter, data, iterations, i, j) where {T<:OnlineStat} + window = data[].b @eval TStat = $(nameof(T)) + # Create an observable based on the given stat stat = Observable(TStat(Float32)) on(iter) do _ - stat[] = fit!(stat[], last(value(data))) + stat[] = fit!(stat[], last(value(data[]))) end + push!(stats, stat) + # Create a moving window on this value statvals = Observable(MovingWindow(window, Float32)) on(stat) do s statvals[] = fit!(statvals[], Float32(value(s))) end + push!(stats, statvals) + # Pass this observable to create points to pass to Makie statpoints = map!(Observable(Point2f0.([0], [0])), statvals) do v - Point2f0.(value(iterations), value(v)) + Point2f0.(value(iterations[]), value(v)) end lines!(axis, statpoints, color = std_colors[i], linewidth = 3.0) end -function onlineplot!(axis, ::Val{:trace}, iter, data, iterations, i, j) +function reset!(stats, stat::T) where {T<:OnlineStat} + @eval TStat = $(nameof(T)) + stats[1].val = TStat(Float32) # Represent the actual stat + stats[2].val = MovingWindow(stats[2][].b, Float32) # Represent the moving window on the stat +end + +function onlineplot!(axis, ::Val{:trace}, stats, iter, data, iterations, i, j) trace = map!(Observable([Point2f0(0, 0)]), iter) do _ - Point2f0.(value(iterations), value(data)) + Point2f0.(value(iterations[]), value(data[])) end lines!(axis, trace, color = std_colors[i]; linewidth = 3.0) end -function onlineplot!(axis, stat::KHist, iter, data, iterations, i, j) +function onlineplot!(axis, stat::KHist, stats, iter, data, iterations, i, j) nbins = stat.k stat = Observable(KHist(nbins, Float32)) on(iter) do _ - stat[] = fit!(stat[], last(value(data))) + stat[] = fit!(stat[], last(value(data[]))) end hist_vals = Node(Point2f0.(collect(range(0f0, 1f0, length=nbins)), zeros(Float32, nbins))) + push!(stats, stat) on(stat) do h edges, weights = OnlineStats.xy(h) weights = nobs(h) > 1 ? weights / OnlineStats.area(h) : weights hist_vals[] = Point2f0.(edges, weights) end + push!(stats, hist_vals) barplot!(axis, hist_vals; color=std_colors[i]) - # barplot!(axis, rand(4), rand(4)) +end + +function reset!(stats, stat::KHist) + nbins = stat.k + stats[1].val = KHist(nbins, Float32) + stats[2].val = Point2f0.(collect(range(0f0, 1f0, length=nbins)), zeros(Float32, nbins)) end function expand_extrema(xs) @@ -68,37 +97,57 @@ function expand_extrema(xs) return (xmin, xmax) end -function onlineplot!(axis, ::Val{:kde}, iter, data, iterations, i, j) +function onlineplot!(axis, ::Val{:kde}, stats, iter, data, iterations, i, j) interpkde = Observable(InterpKDE(kde([1f0]))) on(iter) do _ - interpkde[] = InterpKDE(kde(value(data))) + interpkde[] = InterpKDE(kde(value(data[]))) end + push!(stats, interpkde) xs = Observable(range(0, 2, length=10)) on(iter) do _ - xs[] = range(expand_extrema(extrema(value(data)))..., length = 200) + xs[] = range(expand_extrema(extrema(value(data[])))..., length = 200) end + push!(stats, xs) kde_pdf = lift(xs) do xs pdf.(Ref(interpkde[]), xs) end lines!(axis, xs, kde_pdf, color = std_colors[i], linewidth = 3.0) end +function reset!(stats, ::Val{:kde}) + stats[1].val = InterpKDE(kde([1f0])) + stats[2].val = range(0, 2, length=10) +end + name(s::Val{:histkde}) = "Hist. + KDE" -function onlineplot!(axis, ::Val{:histkde}, iter, data, iterations, i, j) - onlineplot!(axis, KHist(50), iter, data, iterations, i, j) - onlineplot!(axis, Val(:kde), iter, data, iterations, i, j) +function onlineplot!(axis, ::Val{:histkde}, stats, iter, data, iterations, i, j) + onlineplot!(axis, KHist(50), stats, iter, data, iterations, i, j) + onlineplot!(axis, Val(:kde), stats, iter, data, iterations, i, j) end -function onlineplot!(axis, stat::AutoCov, iter, data, iterations, i, j) - b = length(stat.cross) +function reset!(stats, ::Val{:histkde}) + reset!(stats[1:2], KHist(50)) + reset!(stats[3:end], Val(:kde)) +end + +function onlineplot!(axis, stat::AutoCov, stats, iter, data, iterations, i, j) + b = length(stat.cross) - 1 stat = Observable(AutoCov(b, Float32)) on(iter) do _ - stat[] = fit!(stat[], last(value(data))) + stat[] = fit!(stat[], last(value(data[]))) end + push!(stats, stat) statvals = map!(Observable(zeros(Float32, b + 1)), stat) do s value(s) end + push!(stats, statvals) scatter!(axis, Point2f0.([0.0, b], [-0.1, 1.0]), markersize = 0.0, color = RGBA(0.0, 0.0, 0.0, 0.0)) # Invisible points to keep limits fixed lines!(axis, 0:b, statvals, color = std_colors[i], linewidth = 3.0) end + +function reset!(stats, stat::AutoCov) + b = length(stat.cross) - 1 + stats[1].val = AutoCov(b, Float32) + stats[2].val = zeros(Float32, b + 1) +end \ No newline at end of file diff --git a/src/live_sampling/turkie_callback.jl b/src/live_sampling/turkie_callback.jl index 24e5502..af9d61f 100644 --- a/src/live_sampling/turkie_callback.jl +++ b/src/live_sampling/turkie_callback.jl @@ -20,11 +20,12 @@ See the docs for some examples. """ TurkieCallback -struct TurkieCallback{TN<:NamedTuple,TD<:AbstractDict} +struct TurkieCallback{TN<:NamedTuple,TS<:AbstractDict,TD<:AbstractDict} figure::Figure - data::Dict{Symbol, MovingWindow} + data::Dict{Symbol,Observable{MovingWindow}} axis_dict::Dict vars::TN + stats::TS params::TD iter::Observable{Int} end @@ -51,17 +52,18 @@ function TurkieCallback(vars::NamedTuple, params::Dict) resolution = get!(params, :resolution, (1200, 700)) fig = Figure(;resolution=resolution, figure_padding=outer_padding) window = get!(params, :window, 1000) - refresh = get!(params, :refresh, false) + get!(params, :refresh, false) params[:t0] = 0 iter = Observable(0) - data = Dict{Symbol, MovingWindow}(:iter => MovingWindow(window, Int)) + data = Dict{Symbol, Observable{MovingWindow}}(:iter => Node(MovingWindow(window, Int))) axis_dict = Dict() + stats_dict = Dict() for (i, variable) in enumerate(keys(vars)) plots = vars[variable] - data[variable] = MovingWindow(window, Float32) + data[variable] = Node(MovingWindow(window, Float32)) axis_dict[(variable, :varname)] = fig[i, 1, Left()] = Label(fig, string(variable), textsize = 30) axis_dict[(variable, :varname)].padding = (0, 60, 0, 0) - onlineplot!(fig, axis_dict, plots, iter, data, variable, i) + onlineplot!(fig, axis_dict, plots, stats_dict, iter, data, variable, i) end on(iter) do i if i > 1 # To deal with autolimits a certain number of samples are needed @@ -74,7 +76,22 @@ function TurkieCallback(vars::NamedTuple, params::Dict) end MakieLayout.trim!(fig.layout) display(fig) - return TurkieCallback(fig, data, axis_dict, vars, params, iter) + return TurkieCallback(fig, data, axis_dict, vars, stats_dict, params, iter) +end + +function Base.show(io::IO, cb::TurkieCallback) + show(io, cb.figure) +end + +function Base.show(io::IO, ::MIME"text/plain", cb::TurkieCallback) + print(io, "TurkieCallback tracking the following variables:\n") + for v in keys(cb.vars) + print(io, " ", v, "\t=> [") + for s in cb.vars[v][1:end-1] + print(io, name(s), ", ") + end + print(io, name(cb.vars[v][end]), "]\n") + end end function addIO!(cb::TurkieCallback, io) @@ -82,24 +99,31 @@ function addIO!(cb::TurkieCallback, io) end function (cb::TurkieCallback)(rng, model, sampler, transition, state, iteration; kwargs...) - if iteration == 1 + if iteration == 1 && cb.iter[] != 0 if cb.params[:refresh] refresh_plots!(cb) end cb.params[:t0] = cb.iter[] end - fit!(cb.data[:iter], iteration + cb.params[:t0]) # Update the iteration value - for (variable, val) in zip(_params_to_array([transition])...) + fit!(cb.data[:iter][], iteration + cb.params[:t0]) # Update the iteration value + for (variable, val) in zip(Inference._params_to_array([transition])...) if haskey(cb.data, variable) # Check if symbol should be plotted - fit!(cb.data[variable], Float32(val)) # Update its value + fit!(cb.data[variable][], Float32(val)) # Update its value end end - cb.iter[] = cb.iter[] + 1 + cb.iter[] = cb.iter[] + 1 # Triggers all the updates if haskey(cb.params, :io) recordframe!(cb.params[:io]) end end function refresh_plots!(cb) - #TODO + for v in keys(cb.vars) + cb.data[v].val = MovingWindow(cb.params[:window], Float32) # Reset the moving window + for stat in cb.vars[v] + reset!(cb.stats[(v, stat)], stat) # Reset the stats observables + end + end + cb.data[:iter].val = MovingWindow(cb.params[:window], Int) + cb.iter.val = 0 end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index ecfdded..8991299 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,3 +4,9 @@ ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" + +[compat] +CairoMakie = "0.6" +ColorSchemes = "3" +OnlineStats = "1" +Turing = "0.15, 0.16, 0.17, 0.18, 0.19" \ No newline at end of file diff --git a/test/dev_test.jl b/test/dev_test.jl index 7e57c57..e7b141a 100644 --- a/test/dev_test.jl +++ b/test/dev_test.jl @@ -1,10 +1,7 @@ -using Pkg; Pkg.activate("..") using Turing using Turkie using GLMakie # You could also use CairoMakie or another backend -using CairoMakie GLMakie.activate!() -CairoMakie.activate!() Turing.@model function demo(x) # Some random Turing model m0 ~ Normal(0, 2) s ~ InverseGamma(2, 3) @@ -16,8 +13,9 @@ end xs = randn(100) .+ 1; m = demo(xs); -cb = TurkieCallback(m) # Create a callback function to be given to the sample function -chain = sample(m, NUTS(0.65), 30; callback = cb) +cb = TurkieCallback(m; refresh=false, resolution=(400, 400)) # Create a callback function to be given to the sample function +chain = sample(m, NUTS(), 300; callback = cb); +chain = sample(m, MH(), 30; callback = cb) record(cb.figure, joinpath(@__DIR__, "video.gif")) do io diff --git a/test/runtests.jl b/test/runtests.jl index 84210a6..965c8c7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,6 @@ -using Turkie using Test +using Turkie using CairoMakie -CairoMakie.activate!() using OnlineStats using Turing using ColorSchemes @@ -27,7 +26,7 @@ using ColorSchemes cb = TurkieCallback(model; blah=2.0) @test cb.figure isa Figure @test sort(collect(keys(cb.data))) == sort(vcat(vars, :iter)) - @test cb.data[:m] isa MovingWindow{Float32} + @test cb.data[:m] isa Observable{MovingWindow} @test sort(collect(keys(cb.vars))) == sort(vars) @test cb.vars[:m][1] == :histkde @test cb.vars[:m][2] == Mean(Float32) @@ -45,7 +44,7 @@ using ColorSchemes cb = TurkieCallback(Dict(:m => [stat])) sample(model, MH(), 50; progress=false, callback=cb) end - end + end @testset "Series" begin for stat in [Mean(Float32), Variance(Float32)] cb = TurkieCallback(model, OnlineStats.Series(stat)) @@ -53,5 +52,23 @@ using ColorSchemes end end end + @testset "Trying same thing with refresh" begin + @testset "Vector of symbols" begin + for stat in [:histkde, :kde, :hist, :mean, :var, :trace, :autocov] + @testset "$stat" begin + cb = TurkieCallback(Dict(:m => [stat]), refresh=true) + sample(model, MH(), 20; progress=false, callback=cb); + sample(model, MH(), 20; progress=false, callback=cb); + end + end + end + @testset "Series" begin + for stat in [Mean(Float32), Variance(Float32)] + cb = TurkieCallback(model, OnlineStats.Series(stat), refresh=true) + sample(model, MH(), 20; progress=false, callback=cb); + sample(model, MH(), 20; progress=false, callback=cb); + end + end + end end