Skip to content

Commit

Permalink
Add bindings for round function
Browse files Browse the repository at this point in the history
  • Loading branch information
thisisnic committed Feb 23, 2023
1 parent 6451524 commit 3515c3f
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 0 deletions.
11 changes: 11 additions & 0 deletions R/pkg-arrow.R
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,17 @@ arrow_funs[["grepl"]] <- function(pattern, x, ...) {
substrait_call("string.contains", x, pattern)
}

# rounding functions
arrow_funs[["round"]] <- function(x, digits = 0) {
substrait_call(
"rounding.round",
x,
digits,
.options = list(
substrait$FunctionOption$create(name = "rounding",preference = "TIE_TO_EVEN"))
)
}

check_na_rm <- function(na.rm) {
if (!na.rm) {
warning("Missing value removal from aggregate functions not yet supported, switching to na.rm = TRUE")
Expand Down
10 changes: 10 additions & 0 deletions R/pkg-duckdb.R
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,16 @@ duckdb_funs[["n_distinct"]] <- function(x, na.rm = FALSE) {
substrait_call_agg("approx_count_distinct", x, .output_type = substrait_i64())
}

duckdb_funs[["round"]] <- function(x, digits = 0) {
substrait_call(
"round",
x,
as.integer(digits),
.options = list(
substrait$FunctionOption$create(name = "rounding",preference = "TIE_TO_EVEN"))
)
}

check_na_rm_duckdb <- function(na.rm) {
if (!na.rm) {
warning("Missing value removal from aggregate functions not supported in DuckDB, switching to na.rm = TRUE")
Expand Down
19 changes: 19 additions & 0 deletions tests/testthat/test-pkg-arrow.R
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,22 @@ test_that("arrow translation for if_else() works", {
)
)
})

test_that("arrow translation for round works", {
skip_if_not(has_arrow_with_substrait())

# Error: NotImplemented: conversion to arrow::DataType from Substrait type
# /home/nic2/arrow/cpp/src/arrow/engine/substrait/expression_internal.cc:119 FromProto(scalar_fn.output_type(), ext_set, conversion_options)
# /home/nic2/arrow/cpp/src/arrow/engine/substrait/expression_internal.cc:317 DecodeScalarFunction(function_id, scalar_fn, ext_set, conversion_options)
# /home/nic2/arrow/cpp/src/arrow/engine/substrait/relation_internal.cc:548 FromProto(expr, ext_set, conversion_options)
# /home/nic2/arrow/cpp/src/arrow/engine/substrait/serde.cc:157 FromProto(plan_rel.has_root() ? plan_rel.root().input() : plan_rel.rel(), ext_set, conversion_options)
expect_equal(
tibble::tibble(x = c(1, 2.34, 3.456, 4.5)) %>%
arrow_substrait_compiler() %>%
substrait_project(y = round(x)) %>%
dplyr::collect(),
tibble::tibble(
y = c(1, 2, 3, 4)
)
)
})
16 changes: 16 additions & 0 deletions tests/testthat/test-pkg-duckdb.R
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,19 @@ test_that("duckdb translation for n_distinct works", {
classes = "substrait.n_distinct.approximate"
)
})

test_that("duckdb translation for round works", {
skip_if_not(has_duckdb_with_substrait())

# DuckDB is ignoring the rounding strategy - need to check if it even can implement
# alternatives? Should we just do this one but warn??
expect_equal(
tibble::tibble(x = c(1, 2.34, 3.456, 4.5)) %>%
duckdb_substrait_compiler() %>%
substrait_project(x, y = round(x)) %>%
dplyr::collect(),
tibble::tibble(
y = c(1, 2, 3, 4)
)
)
})

0 comments on commit 3515c3f

Please sign in to comment.