diff --git a/R/args.R b/R/args.R index 123022a5..50c029b9 100644 --- a/R/args.R +++ b/R/args.R @@ -566,7 +566,9 @@ PathfinderArgs <- R6::R6Class( num_paths = NULL, max_lbfgs_iters = NULL, num_elbo_draws = NULL, - save_single_paths = NULL) { + save_single_paths = NULL, + psis_resample = NULL, + calculate_lp = NULL) { self$init_alpha <- init_alpha self$tol_obj <- tol_obj self$tol_rel_obj <- tol_rel_obj @@ -580,6 +582,8 @@ PathfinderArgs <- R6::R6Class( self$max_lbfgs_iters <- max_lbfgs_iters self$num_elbo_draws <- num_elbo_draws self$save_single_paths <- save_single_paths + self$psis_resample <- psis_resample + self$calculate_lp <- calculate_lp invisible(self) }, @@ -608,7 +612,9 @@ PathfinderArgs <- R6::R6Class( .make_arg("num_paths"), .make_arg("max_lbfgs_iters"), .make_arg("num_elbo_draws"), - .make_arg("save_single_paths") + .make_arg("save_single_paths"), + .make_arg("psis_resample"), + .make_arg("calculate_lp") ) new_args <- do.call(c, new_args) c(args, new_args) @@ -966,6 +972,10 @@ validate_pathfinder_args <- function(self) { if (!is.null(self$save_single_paths)) { self$save_single_paths <- 0 } + checkmate::assert_integerish(self$psis_resample, null.ok = TRUE, + lower = 0, upper = 1, len = 1) + checkmate::assert_integerish(self$calculate_lp, null.ok = TRUE, + lower = 0, upper = 1, len = 1) # check args only available for lbfgs and bfgs diff --git a/tests/testthat/test-model-pathfinder.R b/tests/testthat/test-model-pathfinder.R index 2f23cc9b..0860ba27 100644 --- a/tests/testthat/test-model-pathfinder.R +++ b/tests/testthat/test-model-pathfinder.R @@ -30,7 +30,9 @@ ok_arg_values <- list( draws = 100, num_paths = 4, max_lbfgs_iters = 100, - save_single_paths = FALSE) + save_single_paths = FALSE, + calculate_lp = TRUE, + psis_resample=TRUE) # using any one of these should cause sample() to error bad_arg_values <- list(