Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vapply refactor 377 #658

Merged
merged 3 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ export(adagrad_da)
export(adam)
export(adamax)
export(apply)
export(are_null)
export(as.greta_model)
export(as.unknowns)
export(as_data)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ The following optimisers are removed, as they are no longer supported by Tensorf
* Update photo of Grete Hermann (#598)
* Use `%||%` internally to replace the pattern: `if (is.null(x)) x <- thing` with `x <- x %||% thing`. (#630)
* Add more explaining variables - replace `if (thing & thing & what == this)` with `if (explanation_of_thing)`.
*
* Refactored repeated uses of `vapply` into functions (#377, #658)

## Bug fixes

Expand Down
59 changes: 10 additions & 49 deletions R/checkers.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,7 @@ check_dims <- function(..., target_dim = NULL) {
# if more than one is non-scalar, need to check them
more_than_one_is_non_scalar <- sum(!scalars) > 1
if (more_than_one_is_non_scalar) {
match_first <- vapply(dim_list[!scalars],
identical,
FUN.VALUE = FALSE,
dim_list[!scalars][[1]]
)
match_first <- are_identical(dim_list[!scalars], dim_list[!scalars][[1]])

# if they're non-scalar, but have the same dimensions, that's fine too
if (!all(match_first)) {
Expand Down Expand Up @@ -139,11 +135,7 @@ check_dims <- function(..., target_dim = NULL) {
if (!all(scalars)) {

# check all arguments against this
matches_target <- vapply(dim_list[!scalars],
identical,
FUN.VALUE = FALSE,
target_dim
)
matches_target <- are_identical(dim_list[!scalars], target_dim)

# error if not
if (!all(matches_target)) {
Expand Down Expand Up @@ -276,11 +268,7 @@ check_n_realisations <- function(vectors = list(),

# if more than one has multiple rows, need to check them
if (sum(!single_rows) > 1) {
match_first <- vapply(nrows[!single_rows],
identical,
FUN.VALUE = FALSE,
nrows[!single_rows][[1]]
)
match_first <- are_identical(nrows[!single_rows], nrows[!single_rows][[1]])

# if they're non-scalar, but have the same dimensions, that's fine too
if (!all(match_first)) {
Expand Down Expand Up @@ -321,11 +309,7 @@ check_n_realisations <- function(vectors = list(),
if (!all(single_rows)) {

# check all arguments against this
matches_target <- vapply(nrows[!single_rows],
identical,
FUN.VALUE = FALSE,
target
)
matches_target <- are_identical(nrows[!single_rows], target)

# error if not
if (!all(matches_target)) {
Expand Down Expand Up @@ -395,9 +379,7 @@ check_dimension <- function(vectors = list(),
}

# make sure all the parameters match this dimension
match_dimension <- vapply(ncols, identical, dimension,
FUN.VALUE = FALSE
)
match_dimension <- are_identical(ncols, dimension)

# otherwise it's not fine
if (!all(match_dimension)) {
Expand Down Expand Up @@ -575,11 +557,7 @@ check_future_plan <- function() {
check_greta_arrays <- function(greta_array_list, fun_name, hint = NULL) {

# check they are greta arrays
are_greta_arrays <- vapply(greta_array_list,
is.greta_array,
FUN.VALUE = FALSE
)

are_greta_arrays <- are_greta_array(greta_array_list)

msg <- NULL

Expand Down Expand Up @@ -628,11 +606,7 @@ check_values_list <- function(values, env) {
fixed_greta_arrays <- lapply(names, get, envir = env)

# make sure that's what they are
are_greta_arrays <- vapply(fixed_greta_arrays,
inherits,
"greta_array",
FUN.VALUE = FALSE
)
are_greta_arrays <- are_greta_array(fixed_greta_arrays)

if (!all(are_greta_arrays)) {
cli::cli_abort(
Expand Down Expand Up @@ -873,17 +847,9 @@ check_if_unsampleable_and_unfixed <- function(fixed_greta_arrays, dag) {
# check there are no variables without distributions (or whose children have
# distributions - for lkj & wishart) that aren't given fixed values
variables <- dag$node_list[dag$node_types == "variable"]
have_distributions <- vapply(
variables,
has_distribution,
FUN.VALUE = logical(1)
)
have_distributions <- have_distribution(variables)
any_child_has_distribution <- function(variable) {
have_distributions <- vapply(
variable$children,
has_distribution,
FUN.VALUE = logical(1)
)
have_distributions <- have_distribution(variable$children)
any(have_distributions)
}
children_have_distributions <- vapply(
Expand All @@ -895,12 +861,7 @@ check_if_unsampleable_and_unfixed <- function(fixed_greta_arrays, dag) {
unsampleable <- !have_distributions & !children_have_distributions

fixed_nodes <- lapply(fixed_greta_arrays, get_node)
fixed_node_names <- vapply(
fixed_nodes,
member,
"unique_name",
FUN.VALUE = character(1)
)
fixed_node_names <- extract_unique_names(fixed_nodes)

unfixed <- !(names(variables) %in% fixed_node_names)

Expand Down
2 changes: 1 addition & 1 deletion R/dag_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ dag_class <- R6Class(
# get all distribution nodes that have a target
distribution_nodes <- self$node_list[self$node_types == "distribution"]
target_nodes <- lapply(distribution_nodes, member, "get_tf_target_node()")
has_target <- !vapply(target_nodes, is.null, FUN.VALUE = TRUE)
has_target <- !are_null(target_nodes)
distribution_nodes <- distribution_nodes[has_target]
target_nodes <- target_nodes[has_target]

Expand Down
13 changes: 5 additions & 8 deletions R/extract_replace_combine.R
Original file line number Diff line number Diff line change
Expand Up @@ -384,14 +384,14 @@ abind.greta_array <- function(...,
arg_list <- list(...)

# drop any NULLs
to_discard <- vapply(arg_list, is.null, FUN.VALUE = FALSE)
to_discard <- are_null(arg_list)
if (any(to_discard)) {
arg_list <- arg_list[!to_discard]
}

# get N first, in case they used the default value for along
dims <- lapply(arg_list, dim)
n <- max(vapply(dims, length, FUN.VALUE = 1L))
n <- max(lengths(dims))

# needed to keep the same formals as abind
N <- n # nolint
Expand Down Expand Up @@ -469,17 +469,14 @@ c.greta_array <- function(...) {
args <- list(...)

# drop NULLs from the list
is_null <- vapply(args, is.null, FUN.VALUE = FALSE)
is_null <- are_null(args)
args <- args[!is_null]

# try to coerce to greta arrays
args <- lapply(args, as.greta_array, optional = TRUE)

# return a list if they aren't all greta arrays
is_greta_array <- vapply(args,
inherits, "greta_array",
FUN.VALUE = FALSE
)
is_greta_array <- are_greta_array(args)

if (!all(is_greta_array)) {
return(args)
Expand All @@ -489,7 +486,7 @@ c.greta_array <- function(...) {
arrays <- lapply(args, flatten)

# get output dimensions
length_vec <- vapply(arrays, length, FUN.VALUE = 1)
length_vec <- lengths(arrays)
dim_out <- c(sum(length_vec), 1L)

# create the op, expanding 'arrays' out to match op()'s dots input
Expand Down
17 changes: 4 additions & 13 deletions R/greta_model_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,8 @@ plot.greta_model <- function(x,

# add greta array names where available
visible_nodes <- lapply(x$visible_greta_arrays, get_node)
known_nodes <- vapply(visible_nodes,
member,
"unique_name",
FUN.VALUE = ""
)
known_nodes <- extract_unique_names(visible_nodes)

known_nodes <- known_nodes[known_nodes %in% names]
known_idx <- match(known_nodes, names)
node_labels[known_idx] <- paste(names(known_nodes),
Expand Down Expand Up @@ -323,13 +320,7 @@ plot.greta_model <- function(x,

node_names <- lapply(
parameter_list,
function(parameters) {
vapply(parameters,
member,
"unique_name",
FUN.VALUE = ""
)
}
extract_unique_names
)

# for each distribution
Expand Down Expand Up @@ -360,7 +351,7 @@ plot.greta_model <- function(x,
"target"
)

keep <- !vapply(targets, is.null, TRUE)
keep <- !are_null(targets)
distrib_idx <- distrib_idx[keep]


Expand Down
2 changes: 1 addition & 1 deletion R/inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ parse_initial_values <- function(initials, dag) {
forward_names <- glue::glue("all_forward_{dag$node_tf_names}")
nodes <- dag$node_list[match(tf_names, forward_names)]
types <- lapply(nodes, node_type)
are_variables <- vapply(types, identical, "variable", FUN.VALUE = FALSE)
are_variables <- are_identical(types, "variable")

if (!all(are_variables)) {
cli::cli_abort(
Expand Down
2 changes: 1 addition & 1 deletion R/mixture.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ mixture_distribution <- R6Class(
truncations <- lapply(distribs, member, "truncation")
bounds <- lapply(distribs, member, "bounds")

truncated <- !vapply(truncations, is.null, logical(1))
truncated <- !are_null(truncations)
supports <- bounds
supports[truncated] <- truncations[truncated]

Expand Down
18 changes: 5 additions & 13 deletions R/node_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ node <- R6Class(
value <- value %||% unknowns(dim = dim)

self$value(value)
self$get_unique_name()
self$create_unique_name()
},
register = function(dag) {
## TODO add explaining variable
Expand All @@ -43,7 +43,7 @@ node <- R6Class(
family <- c(self$list_children(dag), self$list_parents(dag))

# get and assign their names
family_names <- vapply(family, member, "unique_name", FUN.VALUE = "")
family_names <- extract_unique_names(family)
names(family) <- family_names

# find the unregistered ones
Expand Down Expand Up @@ -102,11 +102,7 @@ node <- R6Class(
# that a child node
mode <- dag$how_to_define(self)
if (mode == "sampling" & has_distribution(self)) {
child_names <- vapply(children,
member,
"unique_name",
FUN.VALUE = character(1)
)
child_names <- extract_unique_names(children)
keep <- child_names != self$distribution$unique_name
children <- children[keep]
}
Expand All @@ -122,11 +118,7 @@ node <- R6Class(
parents <- self$parents

if (length(parents) > 0) {
names <- vapply(parents,
member,
"unique_name",
FUN.VALUE = character(1)
)
names <- extract_unique_names(parents)

if (recursive) {
their_parents <- function(x) {
Expand Down Expand Up @@ -266,7 +258,7 @@ node <- R6Class(
text
},

get_unique_name = function() {
create_unique_name = function() {
self$unique_name <- glue::glue("node_{rhex()}")
},
plotting_label = function() {
Expand Down
6 changes: 1 addition & 5 deletions R/node_types.R
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,7 @@ distribution_node <- R6Class(
# consider that a parent node
mode <- dag$how_to_define(self)
if (mode == "sampling" & !is.null(self$target)) {
parent_names <- vapply(parents,
member,
"unique_name",
FUN.VALUE = character(1)
)
parent_names <- extract_unique_names(parents)
keep <- parent_names != self$target$unique_name
parents <- parents[keep]
}
Expand Down
12 changes: 2 additions & 10 deletions R/simulate.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,8 @@ simulate.greta_model <- function(object,

# subset these to only those that are associated with the model
target_nodes <- lapply(target_greta_arrays, get_node)
target_node_names <- vapply(target_nodes,
member,
"unique_name",
FUN.VALUE = character(1)
)
object_node_names <- vapply(object$dag$node_list,
member,
"unique_name",
FUN.VALUE = character(1)
)
target_node_names <- extract_unique_names(target_nodes)
object_node_names <- extract_unique_names(object$dag$node_list)
keep <- target_node_names %in% object_node_names
target_greta_arrays <- target_greta_arrays[keep]

Expand Down
Loading