From 6fc8845e68a92a0984b0ef89c2699e32b718b7d5 Mon Sep 17 00:00:00 2001 From: moralapablo Date: Wed, 20 Nov 2024 12:22:49 +0100 Subject: [PATCH] Changed output to matrix when single polynomial is being used --- R/eval_monomials.R | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/R/eval_monomials.R b/R/eval_monomials.R index ed21d4e..f48143c 100644 --- a/R/eval_monomials.R +++ b/R/eval_monomials.R @@ -10,6 +10,15 @@ #' #' @inheritParams eval_poly #' +#' @returns A 3D array where the dimensions are: +#' (n_samples, +#' n_monomial_terms, +#' n_polynomials). +#' +#' If n_polynomials = 1, a single matrix is returned with dimensions +#' (n_samples, +#' n_monomial_terms) +#' #' @seealso \code{eval_monomials()} is also used in [predict.nn2poly()]. #' eval_monomials <- function(poly, newdata) { @@ -25,8 +34,8 @@ eval_monomials <- function(poly, newdata) { # polynomials and columns equal to the number of observations evaluated. n_sample <- nrow(newdata) n_polynomials <- ncol(poly$values) - n_terms <- length(poly$labels) - response <- array(0,c(n_sample, n_terms, n_polynomials)) + n_monomial_terms <- length(poly$labels) + response <- array(0,c(n_sample, n_monomial_terms, n_polynomials)) for (k in 1:n_polynomials){ @@ -67,5 +76,11 @@ eval_monomials <- function(poly, newdata) { n_sample) } + # If single polynomial, the third dimension should be 1 and then we + # can return a matrix instead of a 3D array to simplify its use. + if (n_polynomials==1){ + response <- response[,,1] + } + return(response) }