Skip to content

Commit

Permalink
use golden_cholesky flag. An attempt at greta-dev#604
Browse files Browse the repository at this point in the history
  • Loading branch information
njtierney committed May 10, 2024
1 parent c951562 commit 917f936
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 18 deletions.
4 changes: 4 additions & 0 deletions R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,9 @@ chol.greta_array <- function(x, ...) {
)
}

# set golden_cholesky flag
x <- set_golden_cholesky(x)

if (has_representation(x, "cholesky")) {
result <- copy_representation(x, "cholesky")
} else {
Expand All @@ -390,6 +393,7 @@ chol.greta_array <- function(x, ...) {
dim = dim,
tf_operation = "tf_chol"
)

}

result
Expand Down
6 changes: 6 additions & 0 deletions R/greta_array_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ get_node <- function(x) {
attr(x, "node")
}

set_golden_cholesky <- function(x) {
x_node <- get_node(x)
x_node$golden_cholesky <- TRUE
x_node
}

# check for and get representations
representation <- function(x, name, error = TRUE) {
if (inherits(x, "greta_array")) {
Expand Down
38 changes: 20 additions & 18 deletions R/node_types.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,15 @@ operation_node <- R6Class(
# computational speedups or numerical stability. E.g. a logarithm or a
# cholesky factor
representations = list(),
golden_cholesky = FALSE,
initialize = function(operation,
...,
dim = NULL,
operation_args = list(),
tf_operation = NULL,
value = NULL,
representations = list(),
golden_cholesky = FALSE,
tf_function_env = parent.frame(3),
expand_scalars = FALSE) {

Expand Down Expand Up @@ -127,6 +129,7 @@ operation_node <- R6Class(
self$operation <- tf_operation
self$operation_args <- operation_args
self$representations <- representations
self$golden_cholesky <- golden_cholesky
self$tf_function_env <- tf_function_env

# assign empty value of the right dimension, or the values passed via the
Expand Down Expand Up @@ -161,35 +164,34 @@ operation_node <- R6Class(
# if sampling get the distribution constructor and sample this
if (mode == "sampling") {
tensor <- dag$draw_sample(self$distribution)
if (has_representation(self, "cholesky")) {
# error here since when sampling from a cholesky represented variable
# we don't really get consistent results
is_cholesky <- isTRUE(self$golden_cholesky)
if (has_representation(self, "cholesky") && is_cholesky){
## TF1/2
## This approach currently fails because of how we use representations
## within greta.
# We will now error here since when sampling from a cholesky
# represented variable, we don't really get consistent results
cli::cli_warn(
## Could note that there are false positives
## Could note that there are false positives?
message = c(
"When using {.fun calculate} to sample a greta array with a \\
cholesky factor, the output can sometimes be unreliable.",
"Cannot use {.fun calculate} to sample a cholesky factor of a \\
greta array",
"E.g., {.code x_chol <- chol(wishart(df = 4, Sigma = diag(3)))}",
"{.code {.code calculate(x_chol)}}",
"This is due to an internal issue with how greta handles \\
cholesky representations.",
"See issue here on github for more details:",
"{.url }"
"{.url https://github.com/greta-dev/greta/issues/593}"
)
)
}
cholesky_tensor <- tf_chol(tensor)
cholesky_tf_name <- dag$tf_name(self$representation$cholesky)
assign(cholesky_tf_name, cholesky_tensor, envir = dag$tf_environment)
## TF1/2
## This gives strange/bad behaviour for two reasons
## 1. self$representation should be self$representation**s**
## But if you do this then you end up with a tensor being passed
## to dag$tf_name(self$representation$cholesky), which is an error
## So instead I think it should be
## dag$tf_name(self)
## 2. This assignment I think is supposed to be passed down to later
## on in the script, as `cholesky_tf_name` gets overwritten
# cholesky_tf_name <- dag$tf_name(self)

# tf_name <- cholesky_tf_name
# tensor <- cholesky_tensor
}
}

if (mode == "forward") {

Expand Down

0 comments on commit 917f936

Please sign in to comment.