From ace1f81c31166bb5b85de56710fcc412f94a747d Mon Sep 17 00:00:00 2001 From: njtierney Date: Tue, 21 May 2024 15:31:06 +1000 Subject: [PATCH] resolve matrix multiplication issue - #464 * Add check for if y isn't greta array, coerce * Add tests to check for different variations on multiplying as_data(x) and/or as_data(y) --- R/operators.R | 3 ++ tests/testthat/_snaps/operators.md | 81 ++++++++++++++++++++++++++++++ tests/testthat/test_operators.R | 24 +++++++++ 3 files changed, 108 insertions(+) diff --git a/R/operators.R b/R/operators.R index dc4958a3..2ffd9b5e 100644 --- a/R/operators.R +++ b/R/operators.R @@ -146,6 +146,9 @@ NULL # if y is a greta array, coerce x before dispatch if (inherits(y, "greta_array") & !inherits(x, "greta_array")) { as_data(x) %*% y + # if y is not a greta array and x is, coerce y before dispatch + } else if (!inherits(y, "greta_array") & inherits(x, "greta_array")){ + x %*% as_data(y) } else { UseMethod("%*%", x) } diff --git a/tests/testthat/_snaps/operators.md b/tests/testthat/_snaps/operators.md index 754850a3..c7a5fadb 100644 --- a/tests/testthat/_snaps/operators.md +++ b/tests/testthat/_snaps/operators.md @@ -7,3 +7,84 @@ only two-dimensional s can be matrix-multiplied dimensions recorded were 3 and 4 +# %*% works when one is a non-greta array + + Code + x %*% y + Output + [,1] + [1,] 3 + [2,] 3 + +--- + + Code + x %*% as_data(y) + Output + greta array (operation) + + [,1] + [1,] ? + [2,] ? + +--- + + Code + as_data(x) %*% y + Output + greta array (operation) + + [,1] + [1,] ? + [2,] ? + +--- + + Code + as_data(x) %*% as_data(y) + Output + greta array (operation) + + [,1] + [1,] ? + [2,] ? + +--- + + Code + calculate(res_1, nsim = 1) + Output + $res_1 + , , 1 + + [,1] [,2] + [1,] 3 3 + + + +--- + + Code + calculate(res_2, nsim = 1) + Output + $res_2 + , , 1 + + [,1] [,2] + [1,] 3 3 + + + +--- + + Code + calculate(res_3, nsim = 1) + Output + $res_3 + , , 1 + + [,1] [,2] + [1,] 3 3 + + + diff --git a/tests/testthat/test_operators.R b/tests/testthat/test_operators.R index 53769b08..df51ab70 100644 --- a/tests/testthat/test_operators.R +++ b/tests/testthat/test_operators.R @@ -107,3 +107,27 @@ test_that("%*% errors informatively", { a %*% c ) }) + +test_that("%*% works when one is a non-greta array", { + x <- matrix(1, 2, 3) + y <- rep(1, 3) + + expect_snapshot(x %*% y) + expect_snapshot(x %*% as_data(y)) + expect_snapshot(as_data(x) %*% y) + expect_snapshot(as_data(x) %*% as_data(y)) + + dim1 <- dim(x %*% as_data(y)) + dim2 <- dim(as_data(x) %*% y) + dim3 <- dim(as_data(x) %*% as_data(y)) + + expect_true(all(dim1 == dim2 & dim1 == dim3 & dim2 == dim3)) + + res_1 <- x %*% as_data(y) + res_2 <- as_data(x) %*% y + res_3 <- as_data(x) %*% as_data(y) + + expect_snapshot(calculate(res_1, nsim = 1)) + expect_snapshot(calculate(res_2, nsim = 1)) + expect_snapshot(calculate(res_3, nsim = 1)) +})