Skip to content

Commit

Permalink
feat: Add SQL support for bit_count and bitwise &, |, and xor
Browse files Browse the repository at this point in the history
… operators (#19114)
  • Loading branch information
alexander-beedie authored Oct 20, 2024
1 parent 94b7e89 commit 08732c4
Show file tree
Hide file tree
Showing 10 changed files with 360 additions and 51 deletions.
2 changes: 1 addition & 1 deletion crates/polars-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sqlparser = { workspace = true }
# sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs.git", rev = "ae3b5844c839072c235965fe0d1bddc473dced87" }

[dev-dependencies]
# to display dataframes in case of test failures
Expand All @@ -34,6 +33,7 @@ polars-core = { workspace = true, features = ["fmt"] }
default = []
nightly = []
binary_encoding = ["polars-lazy/binary_encoding"]
bitwise = ["polars-lazy/bitwise"]
csv = ["polars-lazy/csv"]
diagonal_concat = ["polars-lazy/diagonal_concat"]
dtype-decimal = ["polars-lazy/dtype-decimal"]
Expand Down
81 changes: 66 additions & 15 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,40 @@ pub(crate) struct SQLFunctionVisitor<'a> {

/// SQL functions that are supported by Polars
pub(crate) enum PolarsSQLFunctions {
// ----
// Bitwise functions
// ----
/// SQL 'bit_and' function.
/// Returns the bitwise AND of the input expressions.
/// ```sql
/// SELECT BIT_AND(column_1, column_2) FROM df;
/// ```
BitAnd,
/// SQL 'bit_count' function.
/// Returns the number of set bits in the input expression.
/// ```sql
/// SELECT BIT_COUNT(column_1) FROM df;
/// ```
#[cfg(feature = "bitwise")]
BitCount,
/// SQL 'bit_or' function.
/// Returns the bitwise OR of the input expressions.
/// ```sql
/// SELECT BIT_OR(column_1, column_2) FROM df;
/// ```
BitOr,
/// SQL 'bit_xor' function.
/// Returns the bitwise XOR of the input expressions.
/// ```sql
/// SELECT BIT_XOR(column_1, column_2) FROM df;
/// ```
BitXor,

// ----
// Math functions
// ----
/// SQL 'abs' function
/// Returns the absolute value of the input column.
/// Returns the absolute value of the input expression.
/// ```sql
/// SELECT ABS(column_1) FROM df;
/// ```
Expand Down Expand Up @@ -142,67 +171,67 @@ pub(crate) enum PolarsSQLFunctions {
// Trig functions
// ----
/// SQL 'cos' function
/// Compute the cosine sine of the input column (in radians).
/// Compute the cosine sine of the input expression (in radians).
/// ```sql
/// SELECT COS(column_1) FROM df;
/// ```
Cos,
/// SQL 'cot' function
/// Compute the cotangent of the input column (in radians).
/// Compute the cotangent of the input expression (in radians).
/// ```sql
/// SELECT COT(column_1) FROM df;
/// ```
Cot,
/// SQL 'sin' function
/// Compute the sine of the input column (in radians).
/// Compute the sine of the input expression (in radians).
/// ```sql
/// SELECT SIN(column_1) FROM df;
/// ```
Sin,
/// SQL 'tan' function
/// Compute the tangent of the input column (in radians).
/// Compute the tangent of the input expression (in radians).
/// ```sql
/// SELECT TAN(column_1) FROM df;
/// ```
Tan,
/// SQL 'cosd' function
/// Compute the cosine sine of the input column (in degrees).
/// Compute the cosine sine of the input expression (in degrees).
/// ```sql
/// SELECT COSD(column_1) FROM df;
/// ```
CosD,
/// SQL 'cotd' function
/// Compute cotangent of the input column (in degrees).
/// Compute cotangent of the input expression (in degrees).
/// ```sql
/// SELECT COTD(column_1) FROM df;
/// ```
CotD,
/// SQL 'sind' function
/// Compute the sine of the input column (in degrees).
/// Compute the sine of the input expression (in degrees).
/// ```sql
/// SELECT SIND(column_1) FROM df;
/// ```
SinD,
/// SQL 'tand' function
/// Compute the tangent of the input column (in degrees).
/// Compute the tangent of the input expression (in degrees).
/// ```sql
/// SELECT TAND(column_1) FROM df;
/// ```
TanD,
/// SQL 'acos' function
/// Compute inverse cosinus of the input column (in radians).
/// Compute inverse cosinus of the input expression (in radians).
/// ```sql
/// SELECT ACOS(column_1) FROM df;
/// ```
Acos,
/// SQL 'asin' function
/// Compute inverse sine of the input column (in radians).
/// Compute inverse sine of the input expression (in radians).
/// ```sql
/// SELECT ASIN(column_1) FROM df;
/// ```
Asin,
/// SQL 'atan' function
/// Compute inverse tangent of the input column (in radians).
/// Compute inverse tangent of the input expression (in radians).
/// ```sql
/// SELECT ATAN(column_1) FROM df;
/// ```
Expand All @@ -214,19 +243,19 @@ pub(crate) enum PolarsSQLFunctions {
/// ```
Atan2,
/// SQL 'acosd' function
/// Compute inverse cosinus of the input column (in degrees).
/// Compute inverse cosinus of the input expression (in degrees).
/// ```sql
/// SELECT ACOSD(column_1) FROM df;
/// ```
AcosD,
/// SQL 'asind' function
/// Compute inverse sine of the input column (in degrees).
/// Compute inverse sine of the input expression (in degrees).
/// ```sql
/// SELECT ASIND(column_1) FROM df;
/// ```
AsinD,
/// SQL 'atand' function
/// Compute inverse tangent of the input column (in degrees).
/// Compute inverse tangent of the input expression (in degrees).
/// ```sql
/// SELECT ATAND(column_1) FROM df;
/// ```
Expand Down Expand Up @@ -656,7 +685,11 @@ impl PolarsSQLFunctions {
"atan2d",
"atand",
"avg",
"bit_and",
"bit_count",
"bit_length",
"bit_or",
"bit_xor",
"cbrt",
"ceil",
"ceiling",
Expand Down Expand Up @@ -741,6 +774,15 @@ impl PolarsSQLFunctions {
fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult<Self> {
let function_name = function.name.0[0].value.to_lowercase();
Ok(match function_name.as_str() {
// ----
// Bitwise functions
// ----
"bit_and" | "bitand" => Self::BitAnd,
#[cfg(feature = "bitwise")]
"bit_count" | "bitcount" => Self::BitCount,
"bit_or" | "bitor" => Self::BitOr,
"bit_xor" | "bitxor" | "xor" => Self::BitXor,

// ----
// Math functions
// ----
Expand Down Expand Up @@ -894,6 +936,15 @@ impl SQLFunctionVisitor<'_> {
}

match function_name {
// ----
// Bitwise functions
// ----
BitAnd => self.visit_binary::<Expr>(Expr::and),
#[cfg(feature = "bitwise")]
BitCount => self.visit_unary(Expr::bitwise_count_ones),
BitOr => self.visit_binary::<Expr>(Expr::or),
BitXor => self.visit_binary::<Expr>(Expr::xor),

// ----
// Math functions
// ----
Expand Down
69 changes: 37 additions & 32 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,48 +469,53 @@ impl SQLExprVisitor<'_> {
rhs = self.convert_temporal_strings(&lhs, &rhs);

Ok(match op {
SQLBinaryOperator::And => lhs.and(rhs),
SQLBinaryOperator::Divide => lhs / rhs,
SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64),
SQLBinaryOperator::Eq => lhs.eq(rhs),
SQLBinaryOperator::Gt => lhs.gt(rhs),
SQLBinaryOperator::GtEq => lhs.gt_eq(rhs),
SQLBinaryOperator::Lt => lhs.lt(rhs),
SQLBinaryOperator::LtEq => lhs.lt_eq(rhs),
SQLBinaryOperator::Minus => lhs - rhs,
SQLBinaryOperator::Modulo => lhs % rhs,
SQLBinaryOperator::Multiply => lhs * rhs,
SQLBinaryOperator::NotEq => lhs.eq(rhs).not(),
SQLBinaryOperator::Or => lhs.or(rhs),
SQLBinaryOperator::Plus => lhs + rhs,
SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs),
SQLBinaryOperator::StringConcat => {
// ----
// Bitwise operators
// ----
SQLBinaryOperator::BitwiseAnd => lhs.and(rhs), // "x & y"
SQLBinaryOperator::BitwiseOr => lhs.or(rhs), // "x | y"
SQLBinaryOperator::Xor => lhs.xor(rhs), // "x XOR y"

// ----
// General operators
// ----
SQLBinaryOperator::And => lhs.and(rhs), // "x AND y"
SQLBinaryOperator::Divide => lhs / rhs, // "x / y"
SQLBinaryOperator::DuckIntegerDivide => lhs.floor_div(rhs).cast(DataType::Int64), // "x // y"
SQLBinaryOperator::Eq => lhs.eq(rhs), // "x = y"
SQLBinaryOperator::Gt => lhs.gt(rhs), // "x > y"
SQLBinaryOperator::GtEq => lhs.gt_eq(rhs), // "x >= y"
SQLBinaryOperator::Lt => lhs.lt(rhs), // "x < y"
SQLBinaryOperator::LtEq => lhs.lt_eq(rhs), // "x <= y"
SQLBinaryOperator::Minus => lhs - rhs, // "x - y"
SQLBinaryOperator::Modulo => lhs % rhs, // "x % y"
SQLBinaryOperator::Multiply => lhs * rhs, // "x * y"
SQLBinaryOperator::NotEq => lhs.eq(rhs).not(), // "x != y"
SQLBinaryOperator::Or => lhs.or(rhs), // "x OR y"
SQLBinaryOperator::Plus => lhs + rhs, // "x + y"
SQLBinaryOperator::Spaceship => lhs.eq_missing(rhs), // "x <=> y"
SQLBinaryOperator::StringConcat => { // "x || y"
lhs.cast(DataType::String) + rhs.cast(DataType::String)
},
SQLBinaryOperator::Xor => lhs.xor(rhs),
SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs),
SQLBinaryOperator::PGStartsWith => lhs.str().starts_with(rhs), // "x ^@ y"
// ----
// Regular expression operators
// ----
// "a ~ b"
SQLBinaryOperator::PGRegexMatch => match rhs {
SQLBinaryOperator::PGRegexMatch => match rhs { // "x ~ y"
Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true),
_ => polars_bail!(SQLSyntax: "invalid pattern for '~' operator: {:?}", rhs),
},
// "a !~ b"
SQLBinaryOperator::PGRegexNotMatch => match rhs {
SQLBinaryOperator::PGRegexNotMatch => match rhs { // "x !~ y"
Expr::Literal(LiteralValue::String(_)) => lhs.str().contains(rhs, true).not(),
_ => polars_bail!(SQLSyntax: "invalid pattern for '!~' operator: {:?}", rhs),
},
// "a ~* b"
SQLBinaryOperator::PGRegexIMatch => match rhs {
SQLBinaryOperator::PGRegexIMatch => match rhs { // "x ~* y"
Expr::Literal(LiteralValue::String(pat)) => {
lhs.str().contains(lit(format!("(?i){}", pat)), true)
},
_ => polars_bail!(SQLSyntax: "invalid pattern for '~*' operator: {:?}", rhs),
},
// "a !~* b"
SQLBinaryOperator::PGRegexNotIMatch => match rhs {
SQLBinaryOperator::PGRegexNotIMatch => match rhs { // "x !~* y"
Expr::Literal(LiteralValue::String(pat)) => {
lhs.str().contains(lit(format!("(?i){}", pat)), true).not()
},
Expand All @@ -521,10 +526,10 @@ impl SQLExprVisitor<'_> {
// ----
// LIKE/ILIKE operators
// ----
SQLBinaryOperator::PGLikeMatch
| SQLBinaryOperator::PGNotLikeMatch
| SQLBinaryOperator::PGILikeMatch
| SQLBinaryOperator::PGNotILikeMatch => {
SQLBinaryOperator::PGLikeMatch // "x ~~ y"
| SQLBinaryOperator::PGNotLikeMatch // "x !~~ y"
| SQLBinaryOperator::PGILikeMatch // "x ~~* y"
| SQLBinaryOperator::PGNotILikeMatch => { // "x !~~* y"
let expr = if matches!(
op,
SQLBinaryOperator::PGLikeMatch | SQLBinaryOperator::PGNotLikeMatch
Expand All @@ -548,7 +553,7 @@ impl SQLExprVisitor<'_> {
// ----
// JSON/Struct field access operators
// ----
SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs {
SQLBinaryOperator::Arrow | SQLBinaryOperator::LongArrow => match rhs { // "x -> y", "x ->> y"
Expr::Literal(LiteralValue::String(path)) => {
let mut expr = self.struct_field_access_expr(&lhs, &path, false)?;
if let SQLBinaryOperator::LongArrow = op {
Expand All @@ -567,7 +572,7 @@ impl SQLExprVisitor<'_> {
polars_bail!(SQLSyntax: "invalid json/struct path-extract definition: {:?}", right)
},
},
SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => {
SQLBinaryOperator::HashArrow | SQLBinaryOperator::HashLongArrow => { // "x #> y", "x #>> y"
if let Expr::Literal(LiteralValue::String(path)) = rhs {
let mut expr = self.struct_field_access_expr(&lhs, &path, true)?;
if let SQLBinaryOperator::HashLongArrow = op {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-stream/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ version_check = { workspace = true }
[features]
nightly = []
bitwise = ["polars-core/bitwise", "polars-plan/bitwise"]
merge_sorted = ["polars-plan/merge_sorted"]
1 change: 1 addition & 0 deletions crates/polars-stream/src/physical_plan/lower_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ pub fn lower_ir(
IR::MapFunction { input, function } => {
// MergeSorted uses a rechunk hack incompatible with the
// streaming engine.
#[cfg(feature = "merge_sorted")]
if let FunctionIR::MergeSorted { .. } = function {
todo!()
}
Expand Down
8 changes: 7 additions & 1 deletion crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,13 @@ array_any_all = ["polars-lazy?/array_any_all", "dtype-array"]
asof_join = ["polars-lazy?/asof_join", "polars-ops/asof_join"]
iejoin = ["polars-lazy?/iejoin"]
binary_encoding = ["polars-ops/binary_encoding", "polars-lazy?/binary_encoding", "polars-sql?/binary_encoding"]
bitwise = ["polars-core/bitwise", "polars-plan?/bitwise", "polars-ops/bitwise", "polars-lazy?/bitwise"]
bitwise = [
"polars-core/bitwise",
"polars-plan?/bitwise",
"polars-ops/bitwise",
"polars-lazy?/bitwise",
"polars-sql?/bitwise",
]
business = ["polars-lazy?/business", "polars-ops/business"]
checked_arithmetic = ["polars-core/checked_arithmetic"]
chunked_ids = ["polars-ops?/chunked_ids"]
Expand Down
Loading

0 comments on commit 08732c4

Please sign in to comment.