Skip to content

Commit

Permalink
update pathfinder args for psis_resample and lp_calculate
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Jan 18, 2024
1 parent 2355447 commit c69ba62
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
14 changes: 12 additions & 2 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,9 @@ PathfinderArgs <- R6::R6Class(
num_paths = NULL,
max_lbfgs_iters = NULL,
num_elbo_draws = NULL,
save_single_paths = NULL) {
save_single_paths = NULL,
psis_resample = NULL,
calculate_lp = NULL) {
self$init_alpha <- init_alpha
self$tol_obj <- tol_obj
self$tol_rel_obj <- tol_rel_obj
Expand All @@ -580,6 +582,8 @@ PathfinderArgs <- R6::R6Class(
self$max_lbfgs_iters <- max_lbfgs_iters
self$num_elbo_draws <- num_elbo_draws
self$save_single_paths <- save_single_paths
self$psis_resample <- psis_resample
self$calculate_lp <- calculate_lp
invisible(self)
},

Expand Down Expand Up @@ -608,7 +612,9 @@ PathfinderArgs <- R6::R6Class(
.make_arg("num_paths"),
.make_arg("max_lbfgs_iters"),
.make_arg("num_elbo_draws"),
.make_arg("save_single_paths")
.make_arg("save_single_paths"),
.make_arg("psis_resample"),
.make_arg("calculate_lp")
)
new_args <- do.call(c, new_args)
c(args, new_args)
Expand Down Expand Up @@ -966,6 +972,10 @@ validate_pathfinder_args <- function(self) {
if (!is.null(self$save_single_paths)) {
self$save_single_paths <- 0
}
checkmate::assert_integerish(self$psis_resample, null.ok = TRUE,
lower = 0, upper = 1, len = 1)
checkmate::assert_integerish(self$calculate_lp, null.ok = TRUE,
lower = 0, upper = 1, len = 1)


# check args only available for lbfgs and bfgs
Expand Down
4 changes: 3 additions & 1 deletion tests/testthat/test-model-pathfinder.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ ok_arg_values <- list(
draws = 100,
num_paths = 4,
max_lbfgs_iters = 100,
save_single_paths = FALSE)
save_single_paths = FALSE,
calculate_lp = TRUE,
psis_resample=TRUE)

# using any one of these should cause sample() to error
bad_arg_values <- list(
Expand Down

0 comments on commit c69ba62

Please sign in to comment.