diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R index cd55c2af..ff90097c 100644 --- a/tests/testthat/helpers.R +++ b/tests/testthat/helpers.R @@ -654,7 +654,8 @@ check_geweke <- function( p_theta, p_x_bar_theta, niter = 2000, - warmup = 1000 + warmup = 1000, + thin = 1 ) { # sample independently target_theta <- p_theta(niter) @@ -670,9 +671,10 @@ check_geweke <- function( warmup = warmup ) - geweke_checks <- list(target_theta = target_theta, - greta_theta = greta_theta) - + geweke_checks <- list( + target_theta = do_thinning(target_theta, thin), + greta_theta = do_thinning(greta_theta, thin) + ) geweke_checks @@ -900,6 +902,12 @@ check_mvn_samples <- function(sampler, n_effective = 3000) { errors } +do_thinning <- function(x, thinning = 1) { + idx <- seq(1, length(x), by = thinning) + x[idx] +} + + # sample values of greta array 'x' (which must follow a distribution), and # compare the samples with iid samples returned by iid_function (which takes the # number of arguments as its sole argument), producing a labelled qqplot, and @@ -928,15 +936,9 @@ check_samples <- function( iid_samples <- iid_function(neff) mcmc_samples <- as.matrix(draws) - do_thinning <- function(x, thinning = 1) { - idx <- seq(1, length(x), by = thinning) - x[idx] - } - mcmc_samples <- do_thinning(mcmc_samples, thin) iid_samples <- do_thinning(iid_samples, thin) - # # plot # if (is.null(title)) { # distrib <- get_node(x)$distribution$distribution_name diff --git a/tests/testthat/test_posteriors_geweke.R b/tests/testthat/test_posteriors_geweke.R index 9a730f40..8b9868b8 100644 --- a/tests/testthat/test_posteriors_geweke.R +++ b/tests/testthat/test_posteriors_geweke.R @@ -41,7 +41,8 @@ test_that("samplers pass geweke tests", { model = model, data = x, p_theta = p_theta, - p_x_bar_theta = p_x_bar_theta + p_x_bar_theta = p_x_bar_theta, + thin = 5 ) geweke_qq(geweke_hmc, title = "HMC Geweke test") @@ -56,7 +57,8 @@ test_that("samplers pass geweke tests", { data = x, p_theta = p_theta, p_x_bar_theta = p_x_bar_theta, - warmup = 2000 + warmup = 2000, + thin = 5 ) geweke_qq(geweke_hmc_rwmh, title = "RWMH Geweke test")