Skip to content

Commit

Permalink
resolve matrix multiplication issue - greta-dev#464
Browse files Browse the repository at this point in the history
* Add check for if y isn't greta array, coerce
* Add tests to check for different variations on multiplying as_data(x) and/or as_data(y)
  • Loading branch information
njtierney committed May 21, 2024
1 parent 471524e commit ace1f81
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 0 deletions.
3 changes: 3 additions & 0 deletions R/operators.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ NULL
# if y is a greta array, coerce x before dispatch
if (inherits(y, "greta_array") & !inherits(x, "greta_array")) {
as_data(x) %*% y
# if y is not a greta array and x is, coerce y before dispatch
} else if (!inherits(y, "greta_array") & inherits(x, "greta_array")){
x %*% as_data(y)
} else {
UseMethod("%*%", x)
}
Expand Down
81 changes: 81 additions & 0 deletions tests/testthat/_snaps/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,84 @@
only two-dimensional <greta_array>s can be matrix-multiplied
dimensions recorded were 3 and 4

# %*% works when one is a non-greta array

Code
x %*% y
Output
[,1]
[1,] 3
[2,] 3

---

Code
x %*% as_data(y)
Output
greta array (operation)
[,1]
[1,] ?
[2,] ?

---

Code
as_data(x) %*% y
Output
greta array (operation)
[,1]
[1,] ?
[2,] ?

---

Code
as_data(x) %*% as_data(y)
Output
greta array (operation)
[,1]
[1,] ?
[2,] ?

---

Code
calculate(res_1, nsim = 1)
Output
$res_1
, , 1
[,1] [,2]
[1,] 3 3

---

Code
calculate(res_2, nsim = 1)
Output
$res_2
, , 1
[,1] [,2]
[1,] 3 3

---

Code
calculate(res_3, nsim = 1)
Output
$res_3
, , 1
[,1] [,2]
[1,] 3 3

24 changes: 24 additions & 0 deletions tests/testthat/test_operators.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,27 @@ test_that("%*% errors informatively", {
a %*% c
)
})

test_that("%*% works when one is a non-greta array", {
x <- matrix(1, 2, 3)
y <- rep(1, 3)

expect_snapshot(x %*% y)
expect_snapshot(x %*% as_data(y))
expect_snapshot(as_data(x) %*% y)
expect_snapshot(as_data(x) %*% as_data(y))

dim1 <- dim(x %*% as_data(y))
dim2 <- dim(as_data(x) %*% y)
dim3 <- dim(as_data(x) %*% as_data(y))

expect_true(all(dim1 == dim2 & dim1 == dim3 & dim2 == dim3))

res_1 <- x %*% as_data(y)
res_2 <- as_data(x) %*% y
res_3 <- as_data(x) %*% as_data(y)

expect_snapshot(calculate(res_1, nsim = 1))
expect_snapshot(calculate(res_2, nsim = 1))
expect_snapshot(calculate(res_3, nsim = 1))
})

0 comments on commit ace1f81

Please sign in to comment.