Skip to content

Commit

Permalink
add thinning to Geweke tests
Browse files Browse the repository at this point in the history
  • Loading branch information
goldingn committed Nov 1, 2024
1 parent 99fdaf3 commit 3620d48
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
22 changes: 12 additions & 10 deletions tests/testthat/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,8 @@ check_geweke <- function(
p_theta,
p_x_bar_theta,
niter = 2000,
warmup = 1000
warmup = 1000,
thin = 1
) {
# sample independently
target_theta <- p_theta(niter)
Expand All @@ -670,9 +671,10 @@ check_geweke <- function(
warmup = warmup
)

geweke_checks <- list(target_theta = target_theta,
greta_theta = greta_theta)

geweke_checks <- list(
target_theta = do_thinning(target_theta, thin),
greta_theta = do_thinning(greta_theta, thin)
)

geweke_checks

Expand Down Expand Up @@ -900,6 +902,12 @@ check_mvn_samples <- function(sampler, n_effective = 3000) {
errors
}

do_thinning <- function(x, thinning = 1) {
idx <- seq(1, length(x), by = thinning)
x[idx]
}


# sample values of greta array 'x' (which must follow a distribution), and
# compare the samples with iid samples returned by iid_function (which takes the
# number of arguments as its sole argument), producing a labelled qqplot, and
Expand Down Expand Up @@ -928,15 +936,9 @@ check_samples <- function(
iid_samples <- iid_function(neff)
mcmc_samples <- as.matrix(draws)

do_thinning <- function(x, thinning = 1) {
idx <- seq(1, length(x), by = thinning)
x[idx]
}

mcmc_samples <- do_thinning(mcmc_samples, thin)
iid_samples <- do_thinning(iid_samples, thin)


# # plot
# if (is.null(title)) {
# distrib <- get_node(x)$distribution$distribution_name
Expand Down
6 changes: 4 additions & 2 deletions tests/testthat/test_posteriors_geweke.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ test_that("samplers pass geweke tests", {
model = model,
data = x,
p_theta = p_theta,
p_x_bar_theta = p_x_bar_theta
p_x_bar_theta = p_x_bar_theta,
thin = 5
)

geweke_qq(geweke_hmc, title = "HMC Geweke test")
Expand All @@ -56,7 +57,8 @@ test_that("samplers pass geweke tests", {
data = x,
p_theta = p_theta,
p_x_bar_theta = p_x_bar_theta,
warmup = 2000
warmup = 2000,
thin = 5
)

geweke_qq(geweke_hmc_rwmh, title = "RWMH Geweke test")
Expand Down

0 comments on commit 3620d48

Please sign in to comment.