Skip to content

Commit

Permalink
Update Turkie to new versions of Turing and Makie (#26)
Browse files Browse the repository at this point in the history
* General updates, update AbstractPlotting to Makie

* Working version

* Patch bump
  • Loading branch information
theogf authored May 26, 2021
1 parent 238a31c commit a05cc0a
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 40 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ test/Manifest.toml
test/video.gif
test/video.webm
docs/build/
docs/Manifest.toml
docs/Manifest.toml
.vscode/settings.json
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
name = "Turkie"
uuid = "8156cc02-0533-41cd-9345-13411ebe105f"
authors = ["Theo Galy-Fajou <[email protected]> and contributors"]
version = "0.1.3"
version = "0.1.4"

[deps]
AbstractPlotting = "537997a7-5e4e-5d89-9595-2241ea00577e"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
AbstractPlotting = "0.15, 0.16"
ColorSchemes = "3.10"
Colors = "0.12"
KernelDensity = "0.5, 0.6"
Makie = "0.13"
OnlineStats = "1.5"
Turing = "0.15"
julia = "1.4"
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Small example:
```julia
using Turing
using Turkie
using Makie # You could also use CairoMakie or another backend
using GLMakie # You could also use CairoMakie or another backend
@model function demo(x) # Some random Turing model
m0 ~ Normal(0, 2)
s ~ InverseGamma(2, 3)
Expand Down Expand Up @@ -76,7 +76,7 @@ If you want to record the video do

```julia
using Makie
record(cb.scene, joinpath(@__DIR__, "video.webm")) do io
record(cb, joinpath(@__DIR__, "video.webm")) do io
addIO!(cb, io)
sample(m, NUTS(0.65), 300; callback = cb)
end
Expand Down
29 changes: 15 additions & 14 deletions src/Turkie.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module Turkie

using AbstractPlotting: Scene, Point2f0
using AbstractPlotting: barplot!, lines!, scatter! # Plotting tools
using AbstractPlotting: Observable, Node, lift, on # Observable tools
using AbstractPlotting: recordframe! # Recording tools
using AbstractPlotting.MakieLayout # Layouting tool
using Makie: Figure, Scene, Point2f0
using Makie: barplot!, lines!, scatter! # Plotting tools
using Makie: Observable, Node, lift, on # Observable tools
using Makie: recordframe! # Recording tools
using Makie.MakieLayout # Layouting tool
using Colors, ColorSchemes # Colors tools
using KernelDensity # To be able to give a KDE
using OnlineStats # Estimators
Expand All @@ -16,6 +16,7 @@ export addIO!, record

include("online_stats_plots.jl")

# Uses the colorblind scheme of seaborn by default
const std_colors = ColorSchemes.seaborn_colorblind

name(s::Symbol) = name(Val(s))
Expand Down Expand Up @@ -43,7 +44,7 @@ See the docs for some examples.
TurkieCallback

struct TurkieCallback{TN<:NamedTuple,TD<:AbstractDict}
scene::Scene
figure::Figure
data::Dict{Symbol, MovingWindow}
axis_dict::Dict
vars::TN
Expand All @@ -70,7 +71,7 @@ end
function TurkieCallback(vars::NamedTuple, params::Dict)
# Create a scene and a layout
outer_padding = 5
scene, layout = layoutscene(outer_padding, resolution = (1200, 700))
fig = Figure(;resolution = (1200, 700), figure_padding=outer_padding)
window = get!(params, :window, 1000)
refresh = get!(params, :refresh, false)
params[:t0] = 0
Expand All @@ -81,9 +82,9 @@ function TurkieCallback(vars::NamedTuple, params::Dict)
for (i, variable) in enumerate(keys(vars))
plots = vars[variable]
data[variable] = MovingWindow(window, Float32)
axis_dict[(variable, :varname)] = layout[i, 1, Left()] = Label(scene, string(variable), textsize = 30)
axis_dict[(variable, :varname)] = fig[i, 1, Left()] = Label(fig, string(variable), textsize = 30)
axis_dict[(variable, :varname)].padding = (0, 60, 0, 0)
onlineplot!(scene, layout, axis_dict, plots, iter, data, variable, i)
onlineplot!(fig, axis_dict, plots, iter, data, variable, i)
end
on(iter) do i
if i > 1 # To deal with autolimits a certain number of samples are needed
Expand All @@ -94,16 +95,16 @@ function TurkieCallback(vars::NamedTuple, params::Dict)
end
end
end
MakieLayout.trim!(layout)
display(scene)
TurkieCallback(scene, data, axis_dict, vars, params, iter)
MakieLayout.trim!(fig.layout)
display(fig)
return TurkieCallback(fig, data, axis_dict, vars, params, iter)
end

function addIO!(cb::TurkieCallback, io)
cb.params[:io] = io
end

function (cb::TurkieCallback)(rng, model, sampler, transition, iteration)
function (cb::TurkieCallback)(rng, model, sampler, transition, state, iteration; kwargs...)
if iteration == 1
if cb.params[:refresh]
refresh_plots!(cb)
Expand All @@ -116,7 +117,7 @@ function (cb::TurkieCallback)(rng, model, sampler, transition, iteration)
fit!(cb.data[variable], Float32(val)) # Update its value
end
end
cb.iter[] += 1
cb.iter[] = cb.iter[] + 1
if haskey(cb.params, :io)
recordframe!(cb.params[:io])
end
Expand Down
34 changes: 18 additions & 16 deletions src/online_stats_plots.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function onlineplot!(scene, layout, axis_dict, stats::AbstractVector, iter, data, variable, i)
function onlineplot!(fig, axis_dict, stats::AbstractVector, iter, data, variable, i)
for (j, stat) in enumerate(stats)
axis_dict[(variable, stat)] = layout[i, j] = Axis(scene, title = "$(name(stat))")
axis_dict[(variable, stat)] = fig[i, j] = Axis(fig, title="$(name(stat))")
limits!(axis_dict[(variable, stat)], 0.0, 10.0, -1.0, 1.0)
onlineplot!(axis_dict[(variable, stat)], stat, iter, data[variable], data[:iter], i, j)
tight_ticklabel_spacing!(axis_dict[(variable, stat)])
Expand All @@ -19,27 +19,26 @@ onlineplot!(axis, ::Val{:autocov}, args...) = onlineplot!(axis, AutoCov(20), arg

onlineplot!(axis, ::Val{:hist}, args...) = onlineplot!(axis, KHist(50, Float32), args...)


# Generic fallback for OnlineStat objects
function onlineplot!(axis, stat::T, iter, data, iterations, i, j) where {T<:OnlineStat}
window = data.b
@eval TStat = $(nameof(T))
stat = Observable(TStat(Float32))
on(iter) do i
on(iter) do _
stat[] = fit!(stat[], last(value(data)))
end
statvals = Observable(MovingWindow(window, Float32))
on(stat) do s
statvals[] = fit!(statvals[], Float32(value(s)))
end
statpoints = lift(statvals; init = Point2f0.([0], [0])) do v
statpoints = map!(Observable(Point2f0.([0], [0])), statvals) do v
Point2f0.(value(iterations), value(v))
end
lines!(axis, statpoints, color = std_colors[i], linewidth = 3.0)
end

function onlineplot!(axis, ::Val{:trace}, iter, data, iterations, i, j)
trace = lift(iter; init = [Point2f0(0f0, 0f0)]) do i
trace = map!(Observable([Point2f0(0, 0)]), iter) do _
Point2f0.(value(iterations), value(data))
end
lines!(axis, trace, color = std_colors[i]; linewidth = 3.0)
Expand All @@ -48,15 +47,17 @@ end
function onlineplot!(axis, stat::KHist, iter, data, iterations, i, j)
nbins = stat.k
stat = Observable(KHist(nbins, Float32))
on(iter) do i
on(iter) do _
stat[] = fit!(stat[], last(value(data)))
end
hist_vals = lift(stat; init = Point2f0.(range(0, 1, length = nbins), zeros(Float32, nbins))) do h
hist_vals = Node(Point2f0.(collect(range(0f0, 1f0, length=nbins)), zeros(Float32, nbins)))
on(stat) do h
edges, weights = OnlineStats.xy(h)
weights = nobs(h) > 1 ? weights / OnlineStats.area(h) : weights
return Point2f0.(edges, weights)
hist_vals[] = Point2f0.(edges, weights)
end
barplot!(axis, hist_vals, color = std_colors[i])
barplot!(axis, hist_vals; color=std_colors[i])
# barplot!(axis, rand(4), rand(4))
end

function expand_extrema(xs)
Expand All @@ -69,19 +70,20 @@ end

function onlineplot!(axis, ::Val{:kde}, iter, data, iterations, i, j)
interpkde = Observable(InterpKDE(kde([1f0])))
on(iter) do i
on(iter) do _
interpkde[] = InterpKDE(kde(value(data)))
end
xs = lift(iter; init = range(0.0, 2.0, length = 200)) do i
range(expand_extrema(extrema(value(data)))..., length = 200)
xs = Observable(range(0, 2, length=10))
on(iter) do _
xs[] = range(expand_extrema(extrema(value(data)))..., length = 200)
end
kde_pdf = lift(xs) do xs
pdf.(Ref(interpkde[]), xs)
end
lines!(axis, xs, kde_pdf, color = std_colors[i], linewidth = 3.0)
end

name(s::Val{:histkde}) = "Hist + KDE"
name(s::Val{:histkde}) = "Hist. + KDE"

function onlineplot!(axis, ::Val{:histkde}, iter, data, iterations, i, j)
onlineplot!(axis, KHist(50), iter, data, iterations, i, j)
Expand All @@ -91,10 +93,10 @@ end
function onlineplot!(axis, stat::AutoCov, iter, data, iterations, i, j)
b = length(stat.cross)
stat = Observable(AutoCov(b, Float32))
on(iter) do i
on(iter) do _
stat[] = fit!(stat[], last(value(data)))
end
statvals = lift(stat; init = zeros(Float32, b + 1)) do s
statvals = map!(Observable(zeros(Float32, b + 1)), stat) do s
value(s)
end
scatter!(axis, Point2f0.([0.0, b], [-0.1, 1.0]), markersize = 0.0, color = RGBA(0.0, 0.0, 0.0, 0.0)) # Invisible points to keep limits fixed
Expand Down
3 changes: 2 additions & 1 deletion test/dev_test.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Pkg; Pkg.activate("..")
using Turing
using Turkie
using GLMakie # You could also use CairoMakie or another backend
Expand All @@ -19,7 +20,7 @@ cb = TurkieCallback(m) # Create a callback function to be given to the sample fu
chain = sample(m, NUTS(0.65), 30; callback = cb)


record(cb.scene, joinpath(@__DIR__, "video.gif")) do io
record(cb.figure, joinpath(@__DIR__, "video.gif")) do io
addIO!(cb, io)
sample(m, NUTS(0.65), 50; callback = cb)
end
Expand Down
7 changes: 4 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Turkie
using Test
using CairoMakie
CairoMakie.activate!()
using OnlineStats
using Turing

Expand All @@ -23,7 +24,7 @@ using Turing
@test Turkie.name(OnlineStats.Mean(Float32)) == "Mean"

cb = TurkieCallback(model; blah=2.0)
@test cb.scene isa Scene
@test cb.figure isa Figure
@test sort(collect(keys(cb.data))) == sort(vcat(vars, :iter))
@test cb.data[:m] isa MovingWindow{Float32}
@test sort(collect(keys(cb.vars))) == sort(vars)
Expand All @@ -41,13 +42,13 @@ using Turing
@testset "Vector of symbols" begin
for stat in [:histkde, :kde, :hist, :mean, :var, :trace, :autocov]
cb = TurkieCallback(Dict(:m => [stat]))
sample(model, MH(), 50; callback = cb)
sample(model, MH(), 50; callback=cb)
end
end
@testset "Series" begin
for stat in [Mean(Float32), Variance(Float32)]
cb = TurkieCallback(model, OnlineStats.Series(stat))
sample(model, MH(), 50; callback = cb)
sample(model, MH(), 50; callback=cb)
end
end
end
Expand Down

2 comments on commit a05cc0a

@theogf
Copy link
Owner Author

@theogf theogf commented on a05cc0a May 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/37548

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.4 -m "<description of version>" a05cc0a11ad15820db41805150658b7c612cb508
git push origin v0.1.4

Please sign in to comment.