Skip to content

Commit

Permalink
conj function sets
Browse files Browse the repository at this point in the history
  • Loading branch information
scsmithr committed Feb 26, 2025
1 parent dc542c3 commit 6057369
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 101 deletions.
6 changes: 5 additions & 1 deletion crates/rayexec_execution/src/expr/conjunction_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt;
use super::{AsScalarFunctionSet, Expression};
use crate::explain::context_display::{ContextDisplay, ContextDisplayMode, ContextDisplayWrapper};
use crate::functions::function_set::ScalarFunctionSet;
use crate::functions::scalar::builtin::boolean::{FUNCTION_SET_AND, FUNCTION_SET_OR};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConjunctionOperator {
Expand All @@ -12,7 +13,10 @@ pub enum ConjunctionOperator {

impl AsScalarFunctionSet for ConjunctionOperator {
fn as_scalar_function_set(&self) -> &ScalarFunctionSet {
unimplemented!()
match self {
ConjunctionOperator::And => &FUNCTION_SET_AND,
ConjunctionOperator::Or => &FUNCTION_SET_OR,
}
}
}

Expand Down
6 changes: 3 additions & 3 deletions crates/rayexec_execution/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use crate::explain::context_display::{ContextDisplay, ContextDisplayMode};
use crate::functions::function_set::{FunctionInfo, FunctionSet, ScalarFunctionSet};
use crate::functions::scalar::{FunctionVolatility, PlannedScalarFunction};
use crate::functions::CastType;
use crate::logical::binder::table_list::{TableList, TableRef};
use crate::logical::binder::table_list::TableRef;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Expression {
Expand Down Expand Up @@ -753,8 +753,8 @@ mod tests {
column((0, 0), DataType::Utf8).into(),
column((0, 1), DataType::Int32).into(),
or([
column((1, 4), DataType::Int8),
column((1, 2), DataType::Int8),
column((1, 4), DataType::Boolean),
column((1, 2), DataType::Boolean),
])
.unwrap()
.into(),
Expand Down
158 changes: 62 additions & 96 deletions crates/rayexec_execution/src/functions/scalar/builtin/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,65 +11,71 @@ use crate::arrays::executor::scalar::{BinaryExecutor, UnaryExecutor, UniformExec
use crate::arrays::executor::OutBuffer;
use crate::expr::Expression;
use crate::functions::documentation::{Category, Documentation, Example};
use crate::functions::scalar::{PlannedScalarFunction2, ScalarFunction2, ScalarFunctionImpl};
use crate::functions::{invalid_input_types_error, FunctionInfo, Signature};
use crate::logical::binder::table_list::TableList;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct And;

impl FunctionInfo for And {
fn name(&self) -> &'static str {
"and"
}

fn signatures(&self) -> &[Signature] {
&[Signature {
use crate::functions::function_set::ScalarFunctionSet;
use crate::functions::scalar::{BindState, RawScalarFunction, ScalarFunction};
use crate::functions::Signature;

pub const FUNCTION_SET_AND: ScalarFunctionSet = ScalarFunctionSet {
name: "and",
aliases: &[],
doc: Some(&Documentation {
category: Category::General,
description: "Boolean and all inputs.",
arguments: &["var_args"],
example: Some(Example {
example: "and(true, false, true)",
output: "false",
}),
}),
functions: &[RawScalarFunction::new(
Signature {
positional_args: &[],
variadic_arg: Some(DataTypeId::Boolean),
return_type: DataTypeId::Boolean,
doc: Some(&Documentation {
category: Category::General,
description: "Boolean and all inputs.",
arguments: &["var_args"],
example: Some(Example {
example: "and(true, false, true)",
output: "false",
}),
}),
}]
}
}
doc: None,
},
&And,
)],
};

pub const FUNCTION_SET_OR: ScalarFunctionSet = ScalarFunctionSet {
name: "or",
aliases: &[],
doc: Some(&Documentation {
category: Category::General,
description: "Boolean or all inputs.",
arguments: &["var_args"],
example: Some(Example {
example: "or(true, false, true)",
output: "true",
}),
}),
functions: &[RawScalarFunction::new(
Signature {
positional_args: &[],
variadic_arg: Some(DataTypeId::Boolean),
return_type: DataTypeId::Boolean,
doc: None,
},
&Or,
)],
};

impl ScalarFunction2 for And {
fn plan(
&self,
table_list: &TableList,
inputs: Vec<Expression>,
) -> Result<PlannedScalarFunction2> {
let datatypes = inputs
.iter()
.map(|input| input.datatype())
.collect::<Result<Vec<_>>>()?;

if !datatypes.iter().all(|dt| dt == &DataType::Boolean) {
return Err(invalid_input_types_error(self, &datatypes));
}
#[derive(Debug, Clone)]
pub struct And;

impl ScalarFunction for And {
type State = ();

Ok(PlannedScalarFunction2 {
function: Box::new(*self),
fn bind(&self, inputs: Vec<Expression>) -> Result<BindState<Self::State>> {
Ok(BindState {
state: (),
return_type: DataType::Boolean,
inputs,
function_impl: Box::new(AndImpl),
})
}
}

#[derive(Debug, Clone)]
pub struct AndImpl;

impl ScalarFunctionImpl for AndImpl {
fn execute(&self, input: &Batch, output: &mut Array) -> Result<()> {
fn execute(&self, _state: &Self::State, input: &Batch, output: &mut Array) -> Result<()> {
let sel = input.selection();

match input.arrays().len() {
Expand Down Expand Up @@ -114,61 +120,21 @@ impl ScalarFunctionImpl for AndImpl {
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct Or;

impl FunctionInfo for Or {
fn name(&self) -> &'static str {
"or"
}
impl ScalarFunction for Or {
type State = ();

fn signatures(&self) -> &[Signature] {
&[Signature {
positional_args: &[],
variadic_arg: Some(DataTypeId::Boolean),
return_type: DataTypeId::Boolean,
doc: Some(&Documentation {
category: Category::General,
description: "Boolean or all inputs.",
arguments: &["var_args"],
example: Some(Example {
example: "or(true, false, true)",
output: "true",
}),
}),
}]
}
}

impl ScalarFunction2 for Or {
fn plan(
&self,
table_list: &TableList,
inputs: Vec<Expression>,
) -> Result<PlannedScalarFunction2> {
let datatypes = inputs
.iter()
.map(|input| input.datatype())
.collect::<Result<Vec<_>>>()?;

if !datatypes.iter().all(|dt| dt == &DataType::Boolean) {
return Err(invalid_input_types_error(self, &datatypes));
}

Ok(PlannedScalarFunction2 {
function: Box::new(*self),
fn bind(&self, inputs: Vec<Expression>) -> Result<BindState<Self::State>> {
Ok(BindState {
state: (),
return_type: DataType::Boolean,
inputs,
function_impl: Box::new(OrImpl),
})
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct OrImpl;

impl ScalarFunctionImpl for OrImpl {
fn execute(&self, input: &Batch, output: &mut Array) -> Result<()> {
fn execute(&self, _state: &Self::State, input: &Batch, output: &mut Array) -> Result<()> {
let sel = input.selection();

match input.arrays().len() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use rayexec_error::Result;
use super::ExpressionRewriteRule;
use crate::expr::conjunction_expr::{ConjunctionExpr, ConjunctionOperator};
use crate::expr::Expression;
use crate::logical::binder::table_list::TableList;

/// Unnest nested AND or OR expressions.
///
Expand Down

0 comments on commit 6057369

Please sign in to comment.