Skip to content

Commit

Permalink
Merge pull request #644 from njtierney/greta-array-print-480
Browse files Browse the repository at this point in the history
Greta array print 480
  • Loading branch information
njtierney authored Jul 30, 2024
2 parents 2287519 + 22665f6 commit a859e6b
Show file tree
Hide file tree
Showing 10 changed files with 609 additions and 87 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ The following optimisers are removed, as they are no longer supported by Tensorf
* slice sampler no longer needs precision = "single" to work.
* greta now depends on R 4.1.0, which was released May 2021, over 3 years ago.
* export `is.greta_array()` and `is.greta_mcmc_list()`
* greta arrays now have a print method that stops them from printing too many rows into the console. Similar to MCMC print method, you can control the print output with the `n` argument: `print(object, n = <elements to print>)`. (#644)
* New print method for `greta_mcmc_list`. This means MCMC output will be shorter and more informative. (#644)

## Internals
Expand Down
43 changes: 37 additions & 6 deletions R/greta_array_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,47 @@ as.greta_array.default <- function(x, optional = FALSE, original_x = x, ...) {

# print method
#' @export
print.greta_array <- function(x, ...) {
print.greta_array <- function(x, ..., n = 10) {
node <- get_node(x)
text <- glue::glue(
"greta array ({node$description()})\n\n\n"
)

cat(text)
print(node$value(), ...)
node_desc <- node$cli_description()

cli::cli_text("{.pkg greta} array {.cls {node_desc}}")
cli::cli_text("\n")

if (is.unknowns(node$value())){
return(print(node$value(), ..., n = n))
}

x_val <- node$value()
n_print <- getOption("greta.print_max") %||% n

n_unknowns <- length(x_val)
x_head <- head(x_val, n = n_print)
remaining_vals <- n_unknowns - n_print

# print with question marks
print.default(x_head, quote = FALSE, max = n)

cli::cli_text("\n")

if (remaining_vals <= 0) {
return(invisible(x_val))
}

if (remaining_vals > 0 ) {
cli::cli_alert_info(
text = c(
"i" = "{remaining_vals} more values\n",
"i" = "Use {.code print(n = ...)} to see more values"
)
)
}


}


# summary method
#' @export
summary.greta_array <- function(object, ...) {
Expand Down
17 changes: 17 additions & 0 deletions R/node_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,23 @@ node <- R6Class(

text
},
cli_description = function() {
text <- node_type(self)
text <- node_type_colour(text)

dist_txt <- glue::glue("{self$distribution$distribution_name} distribution")
if (has_distribution(self)) {
text <- cli::cli_fmt(
cli::cli_text(
# "{text} following a {.strong {dist_txt}}"
"{text} following a {cli::col_yellow({dist_txt})}"
)
)
}

text
},

get_unique_name = function() {
self$unique_name <- glue::glue("node_{rhex()}")
},
Expand Down
28 changes: 26 additions & 2 deletions R/unknowns_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,39 @@ as.unknowns.matrix <- function(x) { # nolint
}

#' @export
print.unknowns <- function(x, ...) {
print.unknowns <- function(x, ..., n = 10) {
# remove 'unknown' class attribute
x <- unclass(x)

# set NA values to ? for printing
x[is.na(x)] <- " ?"

# browser()

n_print <- getOption("greta.print_max") %||% n

n_unknowns <- length(x)
x_head <- head(x, n = n_print)
remaining_vals <- n_unknowns - n_print

# print with question marks
print.default(x, quote = FALSE, ...)
print.default(x_head, quote = FALSE)

cli::cli_text("\n")

if (remaining_vals <= 0) {
return(invisible(x))
}

if (remaining_vals > 0 ) {
cli::cli_alert_info(
text = c(
"i" = "{remaining_vals} more values\n",
"i" = "Use {.code print(n = ...)} to see more values"
)
)
}

}

# create an unknowns array from some dimensions
Expand Down
13 changes: 13 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -1139,3 +1139,16 @@ is.unknowns <- function(x, ...){
is.initials <- function(x, ...){
inherits(x, "initials")
}

node_type_colour <- function(type){

switch_cols <- switch(
type,
variable = cli::col_red(type),
data = cli::col_green(type),
operation = cli::col_cyan(type),
distribution = cli::col_yellow(type)
)

switch_cols
}
110 changes: 55 additions & 55 deletions tests/testthat/_snaps/distributions_cholesky.md
Original file line number Diff line number Diff line change
@@ -1,70 +1,70 @@
# Cholesky factor of Wishart should be a lower triangular matrix

Code
calculate(chol_x, nsim = 1)
calculate(chol_x, nsim = 1, seed = 2024 - 7 - 30 - 1431)
Output
$chol_x
, , 1
[,1] [,2] [,3]
[1,] 1.191098 0 0
[,1] [,2] [,3]
[1,] 1.76182 0 0
, , 2
[,1] [,2] [,3]
[1,] -0.5446781 1.148268 0
[,1] [,2] [,3]
[1,] -0.005111915 1.734849 0
, , 3
[,1] [,2] [,3]
[1,] -0.07584197 -0.543763 0.852187
[,1] [,2] [,3]
[1,] -0.9636536 -0.8675692 0.7560725

---

Code
(calc_chol <- calculate(x, chol_x, nsim = 1))
(calc_chol <- calculate(x, chol_x, nsim = 1, seed = 2024 - 7 - 30 - 1431))
Output
$x
, , 1
[,1] [,2] [,3]
[1,] 3.280857 -0.8723121 -0.7930352
[,1] [,2] [,3]
[1,] 3.104011 -0.009006277 -1.697785
, , 2
[,1] [,2] [,3]
[1,] -0.8723121 4.256879 1.20619
[,1] [,2] [,3]
[1,] -0.009006277 3.009727 -1.500175
, , 3
[,1] [,2] [,3]
[1,] -0.7930352 1.20619 1.070123
[,1] [,2] [,3]
[1,] -1.697785 -1.500175 2.25295
$chol_x
, , 1
[,1] [,2] [,3]
[1,] 1.811314 0 0
[,1] [,2] [,3]
[1,] 1.76182 0 0
, , 2
[,1] [,2] [,3]
[1,] -0.4815909 2.006228 0
[,1] [,2] [,3]
[1,] -0.005111915 1.734849 0
, , 3
[,1] [,2] [,3]
[1,] -0.4378232 0.4961243 0.7951696
[,1] [,2] [,3]
[1,] -0.9636536 -0.8675692 0.7560725

# Cholesky factor of LJK_correlation should be a lower triangular matrix

Code
calculate(chol_x, nsim = 1)
calculate(chol_x, nsim = 1, seed = 2024 - 7 - 30 - 1431)
Output
$chol_x
, , 1
Expand All @@ -75,35 +75,35 @@
, , 2
[,1] [,2] [,3]
[1,] -0.2949061 0.9555263 0
[1,] -0.1775724 0.9841077 0
, , 3
[,1] [,2] [,3]
[1,] -0.02679775 0.2445977 0.9692543
[,1] [,2] [,3]
[1,] 0.2806787 0.7509681 0.5977177

---

Code
(calc_chol <- calculate(x, chol_x, nsim = 1))
(calc_chol <- calculate(x, chol_x, nsim = 1, seed = 2024 - 7 - 30 - 1431))
Output
$x
, , 1
[,1] [,2] [,3]
[1,] 1 -0.2012312 -0.1072049
[,1] [,2] [,3]
[1,] 1 -0.1775724 0.2806787
, , 2
[,1] [,2] [,3]
[1,] -0.2012312 1 0.5519843
[1,] -0.1775724 1 0.6891927
, , 3
[,1] [,2] [,3]
[1,] -0.1072049 0.5519843 1
[,1] [,2] [,3]
[1,] 0.2806787 0.6891927 1
$chol_x
Expand All @@ -115,12 +115,12 @@
, , 2
[,1] [,2] [,3]
[1,] -0.2012312 0.9795438 0
[1,] -0.1775724 0.9841077 0
, , 3
[,1] [,2] [,3]
[1,] -0.1072049 0.5414881 0.8338451
[,1] [,2] [,3]
[1,] 0.2806787 0.7509681 0.5977177

Expand All @@ -132,35 +132,35 @@
$x
, , 1
[,1] [,2] [,3]
[1,] 5.555332 -0.2148704 2.431943
[,1] [,2] [,3]
[1,] 3.104011 -0.009006277 -1.697785
, , 2
[,1] [,2] [,3]
[1,] -0.2148704 1.03555 -1.263782
[,1] [,2] [,3]
[1,] -0.009006277 3.009727 -1.500175
, , 3
[,1] [,2] [,3]
[1,] 2.431943 -1.263782 8.162073
[,1] [,2] [,3]
[1,] -1.697785 -1.500175 2.25295
$`chol(x)`
, , 1
[,1] [,2] [,3]
[1,] 2.356975 0 0
[,1] [,2] [,3]
[1,] 1.76182 0 0
, , 2
[,1] [,2] [,3]
[1,] -0.09116363 1.013528 0
[,1] [,2] [,3]
[1,] -0.005111915 1.734849 0
, , 3
[,1] [,2] [,3]
[1,] 1.031807 -1.154106 2.401143
[,1] [,2] [,3]
[1,] -0.9636536 -0.8675692 0.7560725

Expand All @@ -172,18 +172,18 @@
$x
, , 1
[,1] [,2] [,3]
[1,] 1 0.4092182 0.1046861
[,1] [,2] [,3]
[1,] 1 -0.1775724 0.2806787
, , 2
[,1] [,2] [,3]
[1,] 0.4092182 1 -0.03410899
[,1] [,2] [,3]
[1,] -0.1775724 1 0.6891927
, , 3
[,1] [,2] [,3]
[1,] 0.1046861 -0.03410899 1
[,1] [,2] [,3]
[1,] 0.2806787 0.6891927 1
$`chol(x)`
Expand All @@ -194,13 +194,13 @@
, , 2
[,1] [,2] [,3]
[1,] 0.4092182 0.9124365 0
[,1] [,2] [,3]
[1,] -0.1775724 0.9841077 0
, , 3
[,1] [,2] [,3]
[1,] 0.1046861 -0.08433292 0.9909232
[,1] [,2] [,3]
[1,] 0.2806787 0.7509681 0.5977177

Loading

0 comments on commit a859e6b

Please sign in to comment.