Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: introduce the framework of sqlsmith #3305

Merged
merged 21 commits into from
Jun 20, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ members = [
"src/stream",
"src/test_runner",
"src/tests/regress",
"src/tests/sqlsmith",
"src/utils/logging",
"src/utils/memcomparable",
"src/utils/pgwire",
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ mod values;

pub use bind_context::BindContext;
pub use delete::BoundDelete;
pub use expr::bind_data_type;
pub use insert::BoundInsert;
pub use query::BoundQuery;
pub use relation::{
Expand Down
5 changes: 4 additions & 1 deletion src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ pub type ExprType = risingwave_pb::expr::expr_node::Type;

pub use expr_rewriter::ExprRewriter;
pub use expr_visitor::ExprVisitor;
pub use type_inference::{align_types, cast_ok, infer_type, least_restrictive, CastContext};
pub use type_inference::{
align_types, cast_ok, func_sig_map, infer_type, least_restrictive, CastContext, DataTypeName,
FuncSign,
};
pub use utils::*;

/// the trait of bound exprssions
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/expr/type_inference/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use itertools::Itertools as _;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::DataType;

use super::{name_of, DataTypeName};
use super::DataTypeName;
use crate::expr::{Expr as _, ExprImpl};

/// Find the least restrictive type. Used by `VALUES`, `CASE`, `UNION`, etc.
Expand Down Expand Up @@ -77,7 +77,7 @@ pub enum CastContext {

/// Checks whether casting from `source` to `target` is ok in `allows` context.
pub fn cast_ok(source: &DataType, target: &DataType, allows: &CastContext) -> bool {
let k = (name_of(source), name_of(target));
let k = (DataTypeName::from(source), DataTypeName::from(target));
matches!(CAST_MAP.get(&k), Some(context) if context <= allows)
}

Expand Down
17 changes: 11 additions & 6 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ use itertools::iproduct;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::DataType;

use super::{name_of, DataTypeName};
use super::DataTypeName;
use crate::expr::ExprType;

/// Infers the return type of a function. Returns `Err` if the function with specified data types
/// is not supported on backend.
pub fn infer_type(func_type: ExprType, inputs_type: Vec<DataType>) -> Result<DataType> {
// With our current simplified type system, where all types are nullable and not parameterized
// by things like length or precision, the inference can be done with a map lookup.
let input_type_names = inputs_type.iter().map(name_of).collect();
let input_type_names = inputs_type.iter().map(DataTypeName::from).collect();
infer_type_name(func_type, input_type_names).map(|type_name| match type_name {
DataTypeName::Boolean => DataType::Boolean,
DataTypeName::Int16 => DataType::Int16,
Expand Down Expand Up @@ -58,10 +58,10 @@ fn infer_type_name(func_type: ExprType, inputs_type: Vec<DataTypeName>) -> Resul
})
}

#[derive(PartialEq, Hash)]
struct FuncSign {
func: ExprType,
inputs_type: Vec<DataTypeName>,
#[derive(PartialEq, Hash, Clone)]
pub struct FuncSign {
pub func: ExprType,
pub inputs_type: Vec<DataTypeName>,
}

impl Eq for FuncSign {}
Expand Down Expand Up @@ -327,6 +327,11 @@ lazy_static::lazy_static! {
};
}

/// The table of function signatures.
pub fn func_sig_map() -> &'static HashMap<FuncSign, DataTypeName> {
&*FUNC_SIG_MAP
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
46 changes: 27 additions & 19 deletions src/frontend/src/expr/type_inference/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ use risingwave_common::types::DataType;
mod cast;
mod func;
pub use cast::{align_types, cast_ok, least_restrictive, CastContext};
pub use func::infer_type;
pub use func::{func_sig_map, infer_type, FuncSign};

/// `DataTypeName` is designed for type derivation here. In other scenarios,
/// use `DataType` instead.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
enum DataTypeName {
pub enum DataTypeName {
Boolean,
Int16,
Int32,
Expand All @@ -42,22 +42,30 @@ enum DataTypeName {
List,
}

fn name_of(ty: &DataType) -> DataTypeName {
match ty {
DataType::Boolean => DataTypeName::Boolean,
DataType::Int16 => DataTypeName::Int16,
DataType::Int32 => DataTypeName::Int32,
DataType::Int64 => DataTypeName::Int64,
DataType::Decimal => DataTypeName::Decimal,
DataType::Float32 => DataTypeName::Float32,
DataType::Float64 => DataTypeName::Float64,
DataType::Varchar => DataTypeName::Varchar,
DataType::Date => DataTypeName::Date,
DataType::Timestamp => DataTypeName::Timestamp,
DataType::Timestampz => DataTypeName::Timestampz,
DataType::Time => DataTypeName::Time,
DataType::Interval => DataTypeName::Interval,
DataType::Struct { .. } => DataTypeName::Struct,
DataType::List { .. } => DataTypeName::List,
impl From<&DataType> for DataTypeName {
fn from(ty: &DataType) -> Self {
match ty {
DataType::Boolean => DataTypeName::Boolean,
DataType::Int16 => DataTypeName::Int16,
DataType::Int32 => DataTypeName::Int32,
DataType::Int64 => DataTypeName::Int64,
DataType::Decimal => DataTypeName::Decimal,
DataType::Float32 => DataTypeName::Float32,
DataType::Float64 => DataTypeName::Float64,
DataType::Varchar => DataTypeName::Varchar,
DataType::Date => DataTypeName::Date,
DataType::Timestamp => DataTypeName::Timestamp,
DataType::Timestampz => DataTypeName::Timestampz,
DataType::Time => DataTypeName::Time,
DataType::Interval => DataTypeName::Interval,
DataType::Struct { .. } => DataTypeName::Struct,
DataType::List { .. } => DataTypeName::List,
}
}
}

impl From<DataType> for DataTypeName {
fn from(ty: DataType) -> Self {
(&ty).into()
}
}
21 changes: 13 additions & 8 deletions src/frontend/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,15 @@ impl Session for SessionImpl {
tracing::error!("failed to parse sql:\n{}:\n{}", sql, e);
e
})?;
// With pgwire, there would be at most 1 statement in the vec.
assert!(stmts.len() <= 1);
if stmts.is_empty() {
return Ok(PgResponse::new(
return Ok(PgResponse::empty_result(
pgwire::pg_response::StatementType::EMPTY,
0,
vec![],
vec![],
));
}
if stmts.len() > 1 {
return Ok(PgResponse::empty_result_with_notice(
pgwire::pg_response::StatementType::EMPTY,
"cannot insert multiple commands into statement".to_string(),
));
}
let stmt = stmts.swap_remove(0);
Expand All @@ -565,11 +566,15 @@ impl Session for SessionImpl {
tracing::error!("failed to parse sql:\n{}:\n{}", sql, e);
e
})?;
// With pgwire, there would be at most 1 statement in the vec.
assert!(stmts.len() <= 1);
if stmts.is_empty() {
return Ok(vec![]);
}
if stmts.len() > 1 {
return Err(Box::new(Error::new(
ErrorKind::InvalidInput,
"cannot insert multiple commands into statement",
)));
}
let stmt = stmts.swap_remove(0);
let rsp = infer(self, stmt).map_err(|e| {
tracing::error!("failed to handle sql:\n{}:\n{}", sql, e);
Expand Down
37 changes: 36 additions & 1 deletion src/sqlparser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,13 @@ impl fmt::Display for Expr {
low,
high
),
Expr::BinaryOp { left, op, right } => write!(f, "{} {} {}", left, op, right),
Expr::BinaryOp { left, op, right } => write!(
f,
"{} {} {}",
fmt_expr_with_paren(left),
op,
fmt_expr_with_paren(right)
),
Expr::UnaryOp { op, expr } => {
if op == &UnaryOperator::PGPostfixFactorial {
write!(f, "{}{}", expr, op)
Expand Down Expand Up @@ -481,6 +487,35 @@ impl fmt::Display for Expr {
}
}

/// Wrap complex expressions with parentheses.
fn fmt_expr_with_paren(e: &Expr) -> String {
use BinaryOperator as B;
if let Expr::BinaryOp { op, .. } = e {
match op {
B::Plus
| B::Multiply
| B::Modulo
| B::Minus
| B::LtEq
| B::GtEq
| B::Eq
| B::Gt
| B::Lt
| B::Xor
| B::NotEq
| B::Divide
| B::BitwiseAnd
| B::BitwiseOr
| B::BitwiseXor
| B::PGBitwiseXor
| B::PGBitwiseShiftLeft
| B::PGBitwiseShiftRight => return format!("({})", e),
_ => {}
}
}
format!("{}", e)
}

/// A window specification (i.e. `OVER (PARTITION BY .. ORDER BY .. etc.)`)
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down
7 changes: 1 addition & 6 deletions src/sqlparser/test_runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@ use anyhow::{anyhow, Result};
use risingwave_sqlparser::parser::Parser;
use serde::Deserialize;

// 1. The input sql.
// 2. ---
// 3. If sql parsing succeeds, the line is the formatted sql.
// Otherwise, it is the error message.
// 4. => No exist if the parsing is expected to fail.
// 5. The formatted ast.
/// `TestCase` will be deserialized from yaml.
#[derive(PartialEq, Eq, Debug, Deserialize)]
struct TestCase {
input: String,
Expand Down
24 changes: 10 additions & 14 deletions src/sqlparser/tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,6 @@ fn parse_collate() {
);
}

#[test]
fn parse_select_string_predicate() {
let sql = "SELECT id, fname, lname FROM customer \
WHERE salary <> 'Not Provided' AND salary <> ''";
let _ast = verified_only_select(sql);
// TODO: add assertions
}

#[test]
fn parse_projection_nested_type() {
let sql = "SELECT customer.address.state FROM foo";
Expand Down Expand Up @@ -505,6 +497,8 @@ fn parse_compound_expr_1() {
use self::BinaryOperator::*;
use self::Expr::*;
let sql = "a + b * c";
let ast = run_parser_method(sql, |parser| parser.parse_expr()).unwrap();
assert_eq!("a + (b * c)", &ast.to_string());
assert_eq!(
BinaryOp {
left: Box::new(Identifier(Ident::new("a"))),
Expand All @@ -515,7 +509,7 @@ fn parse_compound_expr_1() {
right: Box::new(Identifier(Ident::new("c")))
})
},
verified_expr(sql)
ast
);
}

Expand All @@ -524,6 +518,8 @@ fn parse_compound_expr_2() {
use self::BinaryOperator::*;
use self::Expr::*;
let sql = "a * b + c";
let ast = run_parser_method(sql, |parser| parser.parse_expr()).unwrap();
assert_eq!("(a * b) + c", &ast.to_string());
assert_eq!(
BinaryOp {
left: Box::new(BinaryOp {
Expand All @@ -534,7 +530,7 @@ fn parse_compound_expr_2() {
op: Plus,
right: Box::new(Identifier(Ident::new("c")))
},
verified_expr(sql)
ast
);
}

Expand Down Expand Up @@ -917,15 +913,15 @@ fn parse_between_with_expr() {
select.selection.unwrap()
);

let sql = "SELECT * FROM t WHERE 1 = 1 AND 1 + x BETWEEN 1 AND 2";
let sql = "SELECT * FROM t WHERE (1 = 1) AND 1 + x BETWEEN 1 AND 2";
let select = verified_only_select(sql);
assert_eq!(
Expr::BinaryOp {
left: Box::new(Expr::BinaryOp {
left: Box::new(Expr::Nested(Box::new(Expr::BinaryOp {
left: Box::new(Expr::Value(number("1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Value(number("1"))),
}),
}))),
op: BinaryOperator::And,
right: Box::new(Expr::Between {
expr: Box::new(Expr::BinaryOp {
Expand Down Expand Up @@ -1558,7 +1554,7 @@ fn parse_alter_table_constraints() {
check_one("PRIMARY KEY (foo, bar)");
check_one("UNIQUE (id)");
check_one("FOREIGN KEY (foo, bar) REFERENCES AnotherTable(foo, bar)");
check_one("CHECK (end_date > start_date OR end_date IS NULL)");
check_one("CHECK ((end_date > start_date) OR end_date IS NULL)");

fn check_one(constraint_text: &str) {
match verified_stmt(&format!("ALTER TABLE tab ADD {}", constraint_text)) {
Expand Down
3 changes: 3 additions & 0 deletions src/sqlparser/tests/testdata/select.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,6 @@
formatted_sql: SELECT * FROM unnest(ARRAY[1, 2, 3])
formatted_ast: |
Query(Query { with: None, body: Select(Select { distinct: false, projection: [Wildcard], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "unnest", quote_style: None }]), alias: None, args: [Unnamed(Expr(Array([Value(Number("1", false)), Value(Number("2", false)), Value(Number("3", false))])))] }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })

- input: SELECT id, fname, lname FROM customer WHERE salary <> 'Not Provided' AND salary <> ''
formatted_sql: SELECT id, fname, lname FROM customer WHERE (salary <> 'Not Provided') AND (salary <> '')
Loading