From c7bbcd13b5ce62515dd77f664e25c59c522b80c1 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 13 Oct 2021 03:01:15 +0200 Subject: [PATCH 01/30] WIP: initial implementation of conditional_query_as! (#1488) --- Cargo.toml | 8 + sqlx-macros/src/lib.rs | 8 + sqlx-macros/src/query/conditional/map.rs | 82 +++++ sqlx-macros/src/query/conditional/mod.rs | 332 +++++++++++++++++++ sqlx-macros/src/query/conditional/segment.rs | 202 +++++++++++ sqlx-macros/src/query/mod.rs | 1 + src/lib.rs | 2 +- tests/conditional_query/conditional_query.rs | 112 +++++++ 8 files changed, 746 insertions(+), 1 deletion(-) create mode 100644 sqlx-macros/src/query/conditional/map.rs create mode 100644 sqlx-macros/src/query/conditional/mod.rs create mode 100644 sqlx-macros/src/query/conditional/segment.rs create mode 100644 tests/conditional_query/conditional_query.rs diff --git a/Cargo.toml b/Cargo.toml index aab4f3b613..5e9112b91f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -174,6 +174,14 @@ name = "migrate-macro" path = "tests/migrate/macro.rs" required-features = ["macros", "migrate"] +# +# Conditional query +# + +[[test]] +name = "conditional-query" +path = "tests/conditional_query/conditional_query.rs" + # # SQLite # diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index c1f173e655..d40822a9a6 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -41,6 +41,14 @@ pub fn expand_query(input: TokenStream) -> TokenStream { } } +#[proc_macro] +pub fn conditional_query_as(input: TokenStream) -> TokenStream { + match query::conditional::conditional_query_as(input.into()) { + Ok(output) => output, + Err(err) => err.to_compile_error() + }.into() +} + #[proc_macro_derive(Encode, attributes(sqlx))] pub fn derive_encode(tokenstream: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput); diff --git a/sqlx-macros/src/query/conditional/map.rs b/sqlx-macros/src/query/conditional/map.rs new file mode 100644 index 0000000000..b1dece392e --- /dev/null +++ b/sqlx-macros/src/query/conditional/map.rs @@ -0,0 +1,82 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::Ident; + +pub fn generate_conditional_map(n: usize) -> TokenStream { + let call_site = Span::call_site(); + let map_fns = (1..=n).map(|i| Ident::new(&format!("F{}", i), call_site)).collect::>(); + let args = (1..=n).map(|i| Ident::new(&format!("A{}", i), call_site)).collect::>(); + let variants = (1..=n).map(|i| Ident::new(&format!("_{}", i), call_site)).collect::>(); + let variant_declarations = (0..n).map(|i| { + let variant = &variants[i]; + let map_fn = &map_fns[i]; + let args = &args[i]; + quote!(#variant(sqlx::query::Map<'q, DB, #map_fn, #args>)) + }); + + quote! { + #[doc(hidden)] + pub enum ConditionalMap<'q, DB, O, #(#map_fns,)* #(#args,)*> + where + DB: sqlx::Database, + O: Send + Unpin, + #(#map_fns: FnMut(DB::Row) -> sqlx::Result + Send,)* + #(#args: 'q + Send + sqlx::IntoArguments<'q, DB>,)* + { + #(#variant_declarations),* + } + impl<'q, DB, O, #(#map_fns,)* #(#args,)*> ConditionalMap<'q, DB, O, #(#map_fns,)* #(#args,)*> + where + DB: sqlx::Database, + O: Send + Unpin, + #(#map_fns: FnMut(DB::Row) -> sqlx::Result + Send,)* + #(#args: 'q + Send + sqlx::IntoArguments<'q, DB>,)* + { + pub fn fetch<'e, 'c: 'e, E>(self, executor: E) -> ormx::exports::futures::stream::BoxStream<'e, sqlx::Result> + where + 'q: 'e, + E: 'e + sqlx::Executor<'c, Database = DB>, + DB: 'e, + O: 'e, + #(#map_fns: 'e,)* + { + match self { #( + Self::#variants(x) => x.fetch(executor) + ),* } + } + pub async fn fetch_all<'e, 'c: 'e, E>(self, executor: E) -> sqlx::Result> + where + 'q: 'e, + DB: 'e, + E: 'e + sqlx::Executor<'c, Database = DB>, + O: 'e + { + match self { #( + Self::#variants(x) => x.fetch_all(executor).await + ),* } + } + pub async fn fetch_one<'e, 'c: 'e, E>(self, executor: E) -> sqlx::Result + where + 'q: 'e, + E: 'e + sqlx::Executor<'c, Database = DB>, + DB: 'e, + O: 'e, + { + match self { #( + Self::#variants(x) => x.fetch_one(executor).await + ),* } + } + pub async fn fetch_optional<'e, 'c: 'e, E>(self, executor: E) -> sqlx::Result> + where + 'q: 'e, + E: 'e + sqlx::Executor<'c, Database = DB>, + DB: 'e, + O: 'e, + { + match self { #( + Self::#variants(x) => x.fetch_optional(executor).await + ),* } + } + } + } +} \ No newline at end of file diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs new file mode 100644 index 0000000000..f9d4e81b1e --- /dev/null +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -0,0 +1,332 @@ +use std::rc::Rc; + +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::{ + Expr, + Ident, parse::{Parse, ParseStream}, Pat, Path, Result, Token, +}; + +use segment::*; + +mod map; +mod segment; + +/// Expand a call to `conditional_query_as!` +pub fn conditional_query_as(input: TokenStream) -> Result { + let input = syn::parse2::(input)?; + let ctx = input.to_context(); + let out = ctx.generate_output(input.testing); + Ok(out) +} + +/// Input to the conditional query macro. +struct Input { + /// `true` if the macro should only output information about the query instead of actually + /// calling a query macro + testing: bool, + query_as: Path, + segments: Vec, +} + +impl Input { + /// Convert the input into a context + fn to_context(&self) -> Context { + let mut ctx = Context::Default(NormalContext { + query_as: Rc::new(self.query_as.clone()), + sql: String::new(), + args: vec![], + }); + + for segment in &self.segments { + ctx.add_segment(segment); + } + + ctx + } +} + +syn::custom_keyword!(testing); + +impl Parse for Input { + fn parse(input: ParseStream) -> Result { + let testing = input.parse::>()?.is_some(); + let query_as = input.parse::()?; + input.parse::()?; + Ok(Self { + testing, + query_as, + segments: QuerySegment::parse_all(input)?, + }) + } +} + +/// A context describes the current position within a conditional query. +#[derive(Clone)] +enum Context { + Default(NormalContext), + If(IfContext), + Match(MatchContext), +} + +trait IsContext { + /// Return the number of branches the current context, including its children, contains. + fn branches(&self) -> usize; + /// Generate a call to a sqlx query macro for this context. + fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream; + /// Add a piece of an SQL query to this context. + fn add_sql(&mut self, sql: &SqlSegment); + /// Add an argument to this context. + fn add_arg(&mut self, arg: &ArgSegment); +} + +impl IsContext for Context { + fn branches(&self) -> usize { + self.as_dyn().branches() + } + + fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream { + self.as_dyn().to_query(testing, branch_counter) + } + + fn add_sql(&mut self, sql: &SqlSegment) { + self.as_dyn_mut().add_sql(sql); + } + + fn add_arg(&mut self, arg: &ArgSegment) { + self.as_dyn_mut().add_arg(arg); + } +} + +impl Context { + fn generate_output(&self, testing: bool) -> TokenStream { + let branches = self.branches(); + + let result = { + let mut branch_counter = 1; + let output = self.to_query(testing, &mut branch_counter); + assert_eq!(branch_counter, branches + 1); + output + }; + + if testing { + quote!( #result ) + } else { + let map = map::generate_conditional_map(branches); + quote!( { #map #result } ) + } + } + + fn as_dyn(&self) -> &dyn IsContext { + match self { + Context::Default(c) => c as _, + Context::If(c) => c as _, + Context::Match(c) => c as _, + } + } + + fn as_dyn_mut(&mut self) -> &mut dyn IsContext { + match self { + Context::Default(c) => c as _, + Context::If(c) => c as _, + Context::Match(c) => c as _, + } + } + + fn add_segment(&mut self, s: &QuerySegment) { + match s { + QuerySegment::Sql(s) => self.add_sql(s), + QuerySegment::If(s) => self.add_if(s), + QuerySegment::Match(s) => self.add_match(s), + QuerySegment::Arg(s) => self.add_arg(s), + } + } + + fn add_if(&mut self, s: &IfSegment) { + let mut if_ctx = IfContext { + condition: s.condition.clone(), + then: Box::new(self.clone()), + or_else: Box::new(self.clone()), + }; + for then in &s.then { + if_ctx.then.add_segment(then); + } + for or_else in &s.or_else { + if_ctx.or_else.add_segment(or_else); + } + // replace the current context with the new IfContext + *self = Context::If(if_ctx); + } + + fn add_match(&mut self, s: &MatchSegment) { + let arms = s + .arms + .iter() + .map(|arm| { + let mut arm_ctx = MatchArmContext { + pattern: arm.pat.clone(), + inner: Box::new(self.clone()), + }; + for segment in &arm.body { + arm_ctx.inner.add_segment(segment); + } + arm_ctx + }) + .collect::>(); + + let match_ctx = MatchContext { + expr: s.expr.clone(), + arms, + }; + + // replace the current context with the new MatchContext + *self = Context::Match(match_ctx); + } +} + +/// A "normal" linear context without any branches. +#[derive(Clone)] +struct NormalContext { + query_as: Rc, + sql: String, + args: Vec, +} + +impl IsContext for NormalContext { + fn branches(&self) -> usize { + 1 + } + + fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream { + let NormalContext { query_as, sql, args } = self; + *branch_counter += 1; + + if testing { + quote! { + (stringify!(#query_as), #sql, vec![#(stringify!(#args)),*]) + as (&'static str, &'static str, Vec<&'static str>) + } + } else { + let variant = Ident::new(&format!("_{}", branch_counter), Span::call_site()); + quote!(ConditionalMap::#variant(sqlx::query_as!(#query_as, #sql, #(#args),*))) + } + } + + fn add_sql(&mut self, sql: &SqlSegment) { + if !self.sql.is_empty() { + self.sql.push(' '); + } + self.sql.push_str(&sql.query); + } + + fn add_arg(&mut self, arg: &ArgSegment) { + if cfg!(feature = "postgres") { + self.sql.push_str(&format!(" ${}", self.args.len() + 1)); + } else { + self.sql.push_str(" ?"); + } + self.args.push(arg.argument.clone()); + } +} + +/// Context within an `if .. {..} else ..` clause. +#[derive(Clone)] +struct IfContext { + condition: Expr, + then: Box, + or_else: Box, +} + +impl IsContext for IfContext { + fn branches(&self) -> usize { + self.then.branches() + self.or_else.branches() + } + + fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream { + let condition = &self.condition; + let then = self.then.to_query(testing, branch_counter); + let or_else = self.or_else.to_query(testing, branch_counter); + quote! { + if #condition { + #then + } else { + #or_else + } + } + } + + fn add_sql(&mut self, sql: &SqlSegment) { + self.then.add_sql(sql); + self.or_else.add_sql(sql); + } + + fn add_arg(&mut self, arg: &ArgSegment) { + self.then.add_arg(arg); + self.or_else.add_arg(arg); + } +} + +/// Context within `match .. { .. }` +#[derive(Clone)] +struct MatchContext { + expr: Expr, + arms: Vec, +} + +impl IsContext for MatchContext { + fn branches(&self) -> usize { + self.arms.iter().map(|arm| arm.branches()).sum() + } + + fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream { + let expr = &self.expr; + let arms = self + .arms + .iter() + .map(|arm| arm.to_query(testing, branch_counter)); + quote! { + match #expr { #(#arms,)* } + } + } + + fn add_sql(&mut self, sql: &SqlSegment) { + for arm in &mut self.arms { + arm.add_sql(sql); + } + } + + fn add_arg(&mut self, arg: &ArgSegment) { + for arm in &mut self.arms { + arm.add_arg(arg) + } + } +} + +/// Context within the arm (`Pat => ..`) of a `match` +#[derive(Clone)] +struct MatchArmContext { + pattern: Pat, + inner: Box, +} + +impl IsContext for MatchArmContext { + fn branches(&self) -> usize { + self.inner.branches() + } + + fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream { + let pat = &self.pattern; + let inner = self.inner.to_query(testing, branch_counter); + quote! { + #pat => #inner + } + } + + fn add_sql(&mut self, sql: &SqlSegment) { + self.inner.add_sql(sql); + } + + fn add_arg(&mut self, arg: &ArgSegment) { + self.inner.add_arg(arg); + } +} \ No newline at end of file diff --git a/sqlx-macros/src/query/conditional/segment.rs b/sqlx-macros/src/query/conditional/segment.rs new file mode 100644 index 0000000000..fb5514857d --- /dev/null +++ b/sqlx-macros/src/query/conditional/segment.rs @@ -0,0 +1,202 @@ +use syn::{ + braced, + Error, + Expr, LitStr, parse::{Parse, ParseStream, Peek}, Pat, Token, +}; + +/// A single "piece" of the input. +pub enum QuerySegment { + /// A part of an SQL query, like `"SELECT *"` + Sql(SqlSegment), + /// An `if .. { .. }`, with optional `else ..` + If(IfSegment), + /// An exhaustive `match .. { .. }` + Match(MatchSegment), + /// A query argument. Can be an arbitrary expression, prefixed by `?`, like `?search.trim()` + Arg(ArgSegment), +} + +impl QuerySegment { + /// Parse segments up to the first occurrence of the given token, or until the input is empty. + fn parse_until(input: ParseStream, until: T) -> syn::Result> { + let mut segments = vec![]; + while !input.is_empty() && !input.peek(until) { + segments.push(QuerySegment::parse(input)?) + } + Ok(segments) + } + + /// Parse segments until the input is empty. + pub fn parse_all(input: ParseStream) -> syn::Result> { + let mut segments = vec![]; + while !input.is_empty() { + segments.push(QuerySegment::parse(input)?); + } + Ok(segments) + } +} + +pub struct ArgSegment { + pub argument: Expr, +} + +impl ArgSegment { + const EXPECT: &'static str = "?.."; + + fn matches(input: ParseStream) -> bool { + input.peek(Token![?]) + } +} + +impl Parse for ArgSegment { + fn parse(input: ParseStream) -> syn::Result { + input.parse::()?; + Ok(Self { + argument: input.parse::()?, + }) + } +} + +pub struct SqlSegment { + pub query: String, +} + +impl SqlSegment { + const EXPECT: &'static str = "\"..\""; + + fn matches(input: &ParseStream) -> bool { + input.fork().parse::().is_ok() + } +} + +impl Parse for SqlSegment { + fn parse(input: ParseStream) -> syn::Result { + let lit = input.parse::()?; + Ok(Self { query: lit.value() }) + } +} + +pub struct MatchSegment { + pub expr: Expr, + pub arms: Vec, +} + +pub struct MatchSegmentArm { + pub pat: Pat, + pub body: Vec, +} + +impl MatchSegmentArm { + fn parse_all(input: ParseStream) -> syn::Result> { + let mut arms = vec![]; + while !input.is_empty() { + arms.push(Self::parse(input)?); + } + Ok(arms) + } +} + +impl Parse for MatchSegmentArm { + fn parse(input: ParseStream) -> syn::Result { + let pat = input.parse()?; + input.parse::]>()?; + let body = if input.peek(syn::token::Brace) { + let body; + braced!(body in input); + QuerySegment::parse_all(&body)? + } else { + QuerySegment::parse_until(input, Token![,])? + }; + input.parse::>()?; + Ok(Self { pat, body }) + } +} + +impl MatchSegment { + const EXPECT: &'static str = "match .. { .. }"; + + fn matches(input: ParseStream) -> bool { + input.peek(Token![match]) + } +} + +impl Parse for MatchSegment { + fn parse(input: ParseStream) -> syn::Result { + input.parse::()?; + let expr = input.call(Expr::parse_without_eager_brace)?; + let input = { + let content; + braced!(content in input); + content + }; + let arms = MatchSegmentArm::parse_all(&input)?; + + Ok(Self { expr, arms }) + } +} + +pub struct IfSegment { + pub condition: Expr, + pub then: Vec, + pub or_else: Vec, +} + +impl IfSegment { + const EXPECT: &'static str = "if { .. }"; + + fn matches(input: ParseStream) -> bool { + input.peek(Token![if]) + } +} + +impl Parse for IfSegment { + fn parse(input: ParseStream) -> syn::Result { + input.parse::()?; + let condition = input.call(Expr::parse_without_eager_brace)?; + let then = { + let if_then; + braced!(if_then in input); + QuerySegment::parse_all(&if_then)? + }; + let or_else = if input.parse::>()?.is_some() { + if IfSegment::matches(input) { + let else_if = IfSegment::parse(input)?; + vec![QuerySegment::If(else_if)] + } else { + let or_else; + braced!(or_else in input); + QuerySegment::parse_all(&or_else)? + } + } else { + vec![] + }; + Ok(Self { + condition, + then, + or_else, + }) + } +} + +impl Parse for QuerySegment { + fn parse(input: ParseStream) -> syn::Result { + if SqlSegment::matches(&input) { + Ok(QuerySegment::Sql(SqlSegment::parse(input)?)) + } else if IfSegment::matches(&input) { + Ok(QuerySegment::If(IfSegment::parse(input)?)) + } else if MatchSegment::matches(input) { + Ok(QuerySegment::Match(MatchSegment::parse(input)?)) + } else if ArgSegment::matches(input) { + Ok(QuerySegment::Arg(ArgSegment::parse(input)?)) + } else { + let error = format!( + "expected `{}`, `{}`, `{}` or `{}`", + SqlSegment::EXPECT, + IfSegment::EXPECT, + MatchSegment::EXPECT, + ArgSegment::EXPECT + ); + Err(Error::new(input.span(), error)) + } + } +} diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs index aa862bd743..20cdc631f8 100644 --- a/sqlx-macros/src/query/mod.rs +++ b/sqlx-macros/src/query/mod.rs @@ -23,6 +23,7 @@ mod args; mod data; mod input; mod output; +pub mod conditional; struct Metadata { #[allow(unused)] diff --git a/src/lib.rs b/src/lib.rs index db1d361b08..8d0c8e30c3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,7 +71,7 @@ pub extern crate sqlx_macros; // derives #[cfg(feature = "macros")] #[doc(hidden)] -pub use sqlx_macros::{FromRow, Type}; +pub use sqlx_macros::{FromRow, Type, conditional_query_as}; #[cfg(feature = "macros")] mod macros; diff --git a/tests/conditional_query/conditional_query.rs b/tests/conditional_query/conditional_query.rs new file mode 100644 index 0000000000..3dc30ac75a --- /dev/null +++ b/tests/conditional_query/conditional_query.rs @@ -0,0 +1,112 @@ +#[test] +fn simple() { + let result = sqlx::conditional_query_as!( + testing X, + "A" "B" ?id + 1 + ); + assert_eq!(result, ("X", "A B $1", vec!["id + 1"])) +} + +#[test] +fn single_if() { + let limit = Some(1); + let result = sqlx::conditional_query_as!( + testing X, + "A" "B" ?id + 1 + if let Some(limit) = limit { "C" ?limit } + "D" + ); + assert_eq!(result, ("X", "A B $1 C $2 D", vec!["id + 1", "limit"])) +} + +#[test] +fn if_else() { + let value = true; + let result = sqlx::conditional_query_as!( + testing X, + "A" if value { "B" } else { "C" } "D" + ); + assert_eq!(result, ("X", "A B D", vec![])) +} + +#[test] +fn if_else_2() { + let value = false; + let result = sqlx::conditional_query_as!( + testing X, + "A" if value { "B" } else { "C" } "D" + ); + assert_eq!(result, ("X", "A C D", vec![])) +} + +#[test] +fn single_if_2() { + let limit: Option = None; + let result = sqlx::conditional_query_as!( + testing X, + "A" "B" ?id + 1 + if let Some(limit) = limit { "C" ?limit } + "D" + ); + assert_eq!(result, ("X", "A B $1 D", vec!["id + 1"])) +} + +#[test] +fn single_match() { + enum Y { + C, + D, + } + let value = Y::D; + + let result = sqlx::conditional_query_as!( + testing X, + "A" + "B" ?id + 1 + match value { + Y::C => "C", + Y::D => "D", + } + ); + assert_eq!(result, ("X", "A B $1 D", vec!["id + 1"])) +} + +#[test] +fn nested_if() { + let result = sqlx::conditional_query_as!( + testing X, + "A" + "B" ?id + 1 + if false { + if true { + if true { + "C" + } else { + "D" + } + } + } else if true { + if true { + "D" + } + } + ); + assert_eq!(result, ("X", "A B $1 D", vec!["id + 1"])) +} + +#[test] +fn empty() { + let result = sqlx::conditional_query_as!(testing A, ""); + assert_eq!(result, ("A", "", vec![])) +} + +#[test] +fn empty_2() { + let result = sqlx::conditional_query_as!(testing A,); + assert_eq!(result, ("A", "", vec![])) +} +#[test] +fn empty_3() { + let result = sqlx::conditional_query_as!(testing A, if false { "X" }); + assert_eq!(result, ("A", "", vec![])) +} From 7844fc4b95db7e9c7d396bfe543cd3a5ce0fdb9e Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Tue, 26 Oct 2021 01:19:52 +0200 Subject: [PATCH 02/30] restructure code --- Cargo.toml | 8 -- sqlx-macros/src/query/conditional/map.rs | 14 ++- sqlx-macros/src/query/conditional/mod.rs | 82 ++++++-------- sqlx-macros/src/query/conditional/segment.rs | 4 +- tests/conditional_query/conditional_query.rs | 112 ------------------- 5 files changed, 45 insertions(+), 175 deletions(-) delete mode 100644 tests/conditional_query/conditional_query.rs diff --git a/Cargo.toml b/Cargo.toml index 5e9112b91f..aab4f3b613 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -174,14 +174,6 @@ name = "migrate-macro" path = "tests/migrate/macro.rs" required-features = ["macros", "migrate"] -# -# Conditional query -# - -[[test]] -name = "conditional-query" -path = "tests/conditional_query/conditional_query.rs" - # # SQLite # diff --git a/sqlx-macros/src/query/conditional/map.rs b/sqlx-macros/src/query/conditional/map.rs index b1dece392e..08bc518d94 100644 --- a/sqlx-macros/src/query/conditional/map.rs +++ b/sqlx-macros/src/query/conditional/map.rs @@ -4,9 +4,15 @@ use syn::Ident; pub fn generate_conditional_map(n: usize) -> TokenStream { let call_site = Span::call_site(); - let map_fns = (1..=n).map(|i| Ident::new(&format!("F{}", i), call_site)).collect::>(); - let args = (1..=n).map(|i| Ident::new(&format!("A{}", i), call_site)).collect::>(); - let variants = (1..=n).map(|i| Ident::new(&format!("_{}", i), call_site)).collect::>(); + let map_fns = (1..=n) + .map(|i| Ident::new(&format!("F{}", i), call_site)) + .collect::>(); + let args = (1..=n) + .map(|i| Ident::new(&format!("A{}", i), call_site)) + .collect::>(); + let variants = (1..=n) + .map(|i| Ident::new(&format!("_{}", i), call_site)) + .collect::>(); let variant_declarations = (0..n).map(|i| { let variant = &variants[i]; let map_fn = &map_fns[i]; @@ -79,4 +85,4 @@ pub fn generate_conditional_map(n: usize) -> TokenStream { } } } -} \ No newline at end of file +} diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index f9d4e81b1e..619fe32d82 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -1,14 +1,13 @@ -use std::rc::Rc; +use std::{fmt::Write, rc::Rc}; use proc_macro2::{Span, TokenStream}; use quote::quote; +use segment::*; use syn::{ - Expr, - Ident, parse::{Parse, ParseStream}, Pat, Path, Result, Token, + parse::{Parse, ParseStream}, + Expr, Ident, Pat, Path, Result, Token, }; -use segment::*; - mod map; mod segment; @@ -16,15 +15,12 @@ mod segment; pub fn conditional_query_as(input: TokenStream) -> Result { let input = syn::parse2::(input)?; let ctx = input.to_context(); - let out = ctx.generate_output(input.testing); + let out = ctx.generate_output(); Ok(out) } /// Input to the conditional query macro. struct Input { - /// `true` if the macro should only output information about the query instead of actually - /// calling a query macro - testing: bool, query_as: Path, segments: Vec, } @@ -46,15 +42,11 @@ impl Input { } } -syn::custom_keyword!(testing); - impl Parse for Input { fn parse(input: ParseStream) -> Result { - let testing = input.parse::>()?.is_some(); let query_as = input.parse::()?; input.parse::()?; Ok(Self { - testing, query_as, segments: QuerySegment::parse_all(input)?, }) @@ -73,7 +65,7 @@ trait IsContext { /// Return the number of branches the current context, including its children, contains. fn branches(&self) -> usize; /// Generate a call to a sqlx query macro for this context. - fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream; + fn to_query(&self, branch_counter: &mut usize) -> TokenStream; /// Add a piece of an SQL query to this context. fn add_sql(&mut self, sql: &SqlSegment); /// Add an argument to this context. @@ -85,8 +77,8 @@ impl IsContext for Context { self.as_dyn().branches() } - fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream { - self.as_dyn().to_query(testing, branch_counter) + fn to_query(&self, branch_counter: &mut usize) -> TokenStream { + self.as_dyn().to_query(branch_counter) } fn add_sql(&mut self, sql: &SqlSegment) { @@ -99,22 +91,18 @@ impl IsContext for Context { } impl Context { - fn generate_output(&self, testing: bool) -> TokenStream { + fn generate_output(&self) -> TokenStream { let branches = self.branches(); let result = { let mut branch_counter = 1; - let output = self.to_query(testing, &mut branch_counter); + let output = self.to_query(&mut branch_counter); assert_eq!(branch_counter, branches + 1); output }; - if testing { - quote!( #result ) - } else { - let map = map::generate_conditional_map(branches); - quote!( { #map #result } ) - } + let map = map::generate_conditional_map(branches); + quote!( { #map #result } ) } fn as_dyn(&self) -> &dyn IsContext { @@ -197,19 +185,15 @@ impl IsContext for NormalContext { 1 } - fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream { - let NormalContext { query_as, sql, args } = self; + fn to_query(&self, branch_counter: &mut usize) -> TokenStream { + let NormalContext { + query_as, + sql, + args, + } = self; *branch_counter += 1; - - if testing { - quote! { - (stringify!(#query_as), #sql, vec![#(stringify!(#args)),*]) - as (&'static str, &'static str, Vec<&'static str>) - } - } else { - let variant = Ident::new(&format!("_{}", branch_counter), Span::call_site()); - quote!(ConditionalMap::#variant(sqlx::query_as!(#query_as, #sql, #(#args),*))) - } + let variant = Ident::new(&format!("_{}", branch_counter), Span::call_site()); + quote!(ConditionalMap::#variant(sqlx::query_as!(#query_as, #sql, #(#args),*))) } fn add_sql(&mut self, sql: &SqlSegment) { @@ -220,10 +204,13 @@ impl IsContext for NormalContext { } fn add_arg(&mut self, arg: &ArgSegment) { + if !self.sql.is_empty() { + self.sql.push(' '); + } if cfg!(feature = "postgres") { - self.sql.push_str(&format!(" ${}", self.args.len() + 1)); + write!(&mut self.sql, "${}", self.args.len() + 1).unwrap(); } else { - self.sql.push_str(" ?"); + self.sql.push('?'); } self.args.push(arg.argument.clone()); } @@ -242,10 +229,10 @@ impl IsContext for IfContext { self.then.branches() + self.or_else.branches() } - fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream { + fn to_query(&self, branch_counter: &mut usize) -> TokenStream { let condition = &self.condition; - let then = self.then.to_query(testing, branch_counter); - let or_else = self.or_else.to_query(testing, branch_counter); + let then = self.then.to_query(branch_counter); + let or_else = self.or_else.to_query(branch_counter); quote! { if #condition { #then @@ -278,12 +265,9 @@ impl IsContext for MatchContext { self.arms.iter().map(|arm| arm.branches()).sum() } - fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream { + fn to_query(&self, branch_counter: &mut usize) -> TokenStream { let expr = &self.expr; - let arms = self - .arms - .iter() - .map(|arm| arm.to_query(testing, branch_counter)); + let arms = self.arms.iter().map(|arm| arm.to_query(branch_counter)); quote! { match #expr { #(#arms,)* } } @@ -314,9 +298,9 @@ impl IsContext for MatchArmContext { self.inner.branches() } - fn to_query(&self, testing: bool, branch_counter: &mut usize) -> TokenStream { + fn to_query(&self, branch_counter: &mut usize) -> TokenStream { let pat = &self.pattern; - let inner = self.inner.to_query(testing, branch_counter); + let inner = self.inner.to_query(branch_counter); quote! { #pat => #inner } @@ -329,4 +313,4 @@ impl IsContext for MatchArmContext { fn add_arg(&mut self, arg: &ArgSegment) { self.inner.add_arg(arg); } -} \ No newline at end of file +} diff --git a/sqlx-macros/src/query/conditional/segment.rs b/sqlx-macros/src/query/conditional/segment.rs index fb5514857d..eca1e0a5d5 100644 --- a/sqlx-macros/src/query/conditional/segment.rs +++ b/sqlx-macros/src/query/conditional/segment.rs @@ -1,7 +1,7 @@ use syn::{ braced, - Error, - Expr, LitStr, parse::{Parse, ParseStream, Peek}, Pat, Token, + parse::{Parse, ParseStream, Peek}, + Error, Expr, LitStr, Pat, Token, }; /// A single "piece" of the input. diff --git a/tests/conditional_query/conditional_query.rs b/tests/conditional_query/conditional_query.rs deleted file mode 100644 index 3dc30ac75a..0000000000 --- a/tests/conditional_query/conditional_query.rs +++ /dev/null @@ -1,112 +0,0 @@ -#[test] -fn simple() { - let result = sqlx::conditional_query_as!( - testing X, - "A" "B" ?id + 1 - ); - assert_eq!(result, ("X", "A B $1", vec!["id + 1"])) -} - -#[test] -fn single_if() { - let limit = Some(1); - let result = sqlx::conditional_query_as!( - testing X, - "A" "B" ?id + 1 - if let Some(limit) = limit { "C" ?limit } - "D" - ); - assert_eq!(result, ("X", "A B $1 C $2 D", vec!["id + 1", "limit"])) -} - -#[test] -fn if_else() { - let value = true; - let result = sqlx::conditional_query_as!( - testing X, - "A" if value { "B" } else { "C" } "D" - ); - assert_eq!(result, ("X", "A B D", vec![])) -} - -#[test] -fn if_else_2() { - let value = false; - let result = sqlx::conditional_query_as!( - testing X, - "A" if value { "B" } else { "C" } "D" - ); - assert_eq!(result, ("X", "A C D", vec![])) -} - -#[test] -fn single_if_2() { - let limit: Option = None; - let result = sqlx::conditional_query_as!( - testing X, - "A" "B" ?id + 1 - if let Some(limit) = limit { "C" ?limit } - "D" - ); - assert_eq!(result, ("X", "A B $1 D", vec!["id + 1"])) -} - -#[test] -fn single_match() { - enum Y { - C, - D, - } - let value = Y::D; - - let result = sqlx::conditional_query_as!( - testing X, - "A" - "B" ?id + 1 - match value { - Y::C => "C", - Y::D => "D", - } - ); - assert_eq!(result, ("X", "A B $1 D", vec!["id + 1"])) -} - -#[test] -fn nested_if() { - let result = sqlx::conditional_query_as!( - testing X, - "A" - "B" ?id + 1 - if false { - if true { - if true { - "C" - } else { - "D" - } - } - } else if true { - if true { - "D" - } - } - ); - assert_eq!(result, ("X", "A B $1 D", vec!["id + 1"])) -} - -#[test] -fn empty() { - let result = sqlx::conditional_query_as!(testing A, ""); - assert_eq!(result, ("A", "", vec![])) -} - -#[test] -fn empty_2() { - let result = sqlx::conditional_query_as!(testing A,); - assert_eq!(result, ("A", "", vec![])) -} -#[test] -fn empty_3() { - let result = sqlx::conditional_query_as!(testing A, if false { "X" }); - assert_eq!(result, ("A", "", vec![])) -} From 0d0028c3a0fe3790cc4dfe0545307ab2e56e7f8c Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Tue, 26 Oct 2021 23:38:08 +0200 Subject: [PATCH 03/30] make conditional_query_as! expand to just sqlx::query! if no branches are present --- sqlx-macros/src/query/conditional/mod.rs | 51 +++++++++++++++--------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index 619fe32d82..5d7a8c7665 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -2,12 +2,13 @@ use std::{fmt::Write, rc::Rc}; use proc_macro2::{Span, TokenStream}; use quote::quote; -use segment::*; use syn::{ - parse::{Parse, ParseStream}, - Expr, Ident, Pat, Path, Result, Token, + Expr, + Ident, parse::{Parse, ParseStream}, Pat, Path, Result, Token, }; +use segment::*; + mod map; mod segment; @@ -65,7 +66,7 @@ trait IsContext { /// Return the number of branches the current context, including its children, contains. fn branches(&self) -> usize; /// Generate a call to a sqlx query macro for this context. - fn to_query(&self, branch_counter: &mut usize) -> TokenStream; + fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream; /// Add a piece of an SQL query to this context. fn add_sql(&mut self, sql: &SqlSegment); /// Add an argument to this context. @@ -77,8 +78,8 @@ impl IsContext for Context { self.as_dyn().branches() } - fn to_query(&self, branch_counter: &mut usize) -> TokenStream { - self.as_dyn().to_query(branch_counter) + fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream { + self.as_dyn().to_query(branches, branch_counter) } fn add_sql(&mut self, sql: &SqlSegment) { @@ -96,13 +97,18 @@ impl Context { let result = { let mut branch_counter = 1; - let output = self.to_query(&mut branch_counter); + let output = self.to_query(branches, &mut branch_counter); assert_eq!(branch_counter, branches + 1); output }; - let map = map::generate_conditional_map(branches); - quote!( { #map #result } ) + match branches { + 1 => quote!( #result ), + _ => { + let map = map::generate_conditional_map(branches); + quote!( { #map #result } ) + } + } } fn as_dyn(&self) -> &dyn IsContext { @@ -185,15 +191,22 @@ impl IsContext for NormalContext { 1 } - fn to_query(&self, branch_counter: &mut usize) -> TokenStream { + fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream { let NormalContext { query_as, sql, args, } = self; *branch_counter += 1; - let variant = Ident::new(&format!("_{}", branch_counter), Span::call_site()); - quote!(ConditionalMap::#variant(sqlx::query_as!(#query_as, #sql, #(#args),*))) + + let query_call = quote!(sqlx::query_as!(#query_as, #sql, #(#args),*)); + match branches { + 1 => query_call, + _ => { + let variant = Ident::new(&format!("_{}", branch_counter), Span::call_site()); + quote!(ConditionalMap::#variant(#query_call)) + } + } } fn add_sql(&mut self, sql: &SqlSegment) { @@ -229,10 +242,10 @@ impl IsContext for IfContext { self.then.branches() + self.or_else.branches() } - fn to_query(&self, branch_counter: &mut usize) -> TokenStream { + fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream { let condition = &self.condition; - let then = self.then.to_query(branch_counter); - let or_else = self.or_else.to_query(branch_counter); + let then = self.then.to_query(branches, branch_counter); + let or_else = self.or_else.to_query(branches, branch_counter); quote! { if #condition { #then @@ -265,9 +278,9 @@ impl IsContext for MatchContext { self.arms.iter().map(|arm| arm.branches()).sum() } - fn to_query(&self, branch_counter: &mut usize) -> TokenStream { + fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream { let expr = &self.expr; - let arms = self.arms.iter().map(|arm| arm.to_query(branch_counter)); + let arms = self.arms.iter().map(|arm| arm.to_query(branches, branch_counter)); quote! { match #expr { #(#arms,)* } } @@ -298,9 +311,9 @@ impl IsContext for MatchArmContext { self.inner.branches() } - fn to_query(&self, branch_counter: &mut usize) -> TokenStream { + fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream { let pat = &self.pattern; - let inner = self.inner.to_query(branch_counter); + let inner = self.inner.to_query(branches, branch_counter); quote! { #pat => #inner } From 8b7fb08d299cf583592043aeb9ab5de38790dc05 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 27 Oct 2021 01:26:26 +0200 Subject: [PATCH 04/30] support inline arguments and remove `?` syntax --- sqlx-macros/src/query/conditional/mod.rs | 60 +++++++-------- sqlx-macros/src/query/conditional/segment.rs | 81 ++++++++++++-------- 2 files changed, 78 insertions(+), 63 deletions(-) diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index 5d7a8c7665..ce084780f1 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -69,8 +69,6 @@ trait IsContext { fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream; /// Add a piece of an SQL query to this context. fn add_sql(&mut self, sql: &SqlSegment); - /// Add an argument to this context. - fn add_arg(&mut self, arg: &ArgSegment); } impl IsContext for Context { @@ -85,10 +83,6 @@ impl IsContext for Context { fn add_sql(&mut self, sql: &SqlSegment) { self.as_dyn_mut().add_sql(sql); } - - fn add_arg(&mut self, arg: &ArgSegment) { - self.as_dyn_mut().add_arg(arg); - } } impl Context { @@ -132,7 +126,6 @@ impl Context { QuerySegment::Sql(s) => self.add_sql(s), QuerySegment::If(s) => self.add_if(s), QuerySegment::Match(s) => self.add_match(s), - QuerySegment::Arg(s) => self.add_arg(s), } } @@ -186,6 +179,16 @@ struct NormalContext { args: Vec, } +impl NormalContext { + fn add_parameter(&mut self) { + if cfg!(feature = "postgres") { + write!(&mut self.sql, "${}", self.args.len() + 1).unwrap(); + } else { + self.sql.push('?'); + } + } +} + impl IsContext for NormalContext { fn branches(&self) -> usize { 1 @@ -213,19 +216,25 @@ impl IsContext for NormalContext { if !self.sql.is_empty() { self.sql.push(' '); } - self.sql.push_str(&sql.query); - } - fn add_arg(&mut self, arg: &ArgSegment) { - if !self.sql.is_empty() { - self.sql.push(' '); - } - if cfg!(feature = "postgres") { - write!(&mut self.sql, "${}", self.args.len() + 1).unwrap(); - } else { - self.sql.push('?'); + // push the new sql segment, replacing inline arguments (`"{some rust expression}"`) + // with the appropriate placeholder (`$n` or `?`) + let mut args = sql.args.iter(); + let mut arg = args.next(); + for (idx, c) in sql.sql.chars().enumerate() { + if let Some((start, expr, end)) = arg { + if idx < *start { + self.sql.push(c); + } + if idx == *end { + self.args.push(expr.clone()); + self.add_parameter(); + arg = args.next(); + } + } else { + self.sql.push(c); + } } - self.args.push(arg.argument.clone()); } } @@ -259,11 +268,6 @@ impl IsContext for IfContext { self.then.add_sql(sql); self.or_else.add_sql(sql); } - - fn add_arg(&mut self, arg: &ArgSegment) { - self.then.add_arg(arg); - self.or_else.add_arg(arg); - } } /// Context within `match .. { .. }` @@ -291,12 +295,6 @@ impl IsContext for MatchContext { arm.add_sql(sql); } } - - fn add_arg(&mut self, arg: &ArgSegment) { - for arm in &mut self.arms { - arm.add_arg(arg) - } - } } /// Context within the arm (`Pat => ..`) of a `match` @@ -322,8 +320,4 @@ impl IsContext for MatchArmContext { fn add_sql(&mut self, sql: &SqlSegment) { self.inner.add_sql(sql); } - - fn add_arg(&mut self, arg: &ArgSegment) { - self.inner.add_arg(arg); - } } diff --git a/sqlx-macros/src/query/conditional/segment.rs b/sqlx-macros/src/query/conditional/segment.rs index eca1e0a5d5..c07c3a11f2 100644 --- a/sqlx-macros/src/query/conditional/segment.rs +++ b/sqlx-macros/src/query/conditional/segment.rs @@ -1,3 +1,6 @@ +use std::mem::swap; +use std::ptr::replace; +use proc_macro2::Span; use syn::{ braced, parse::{Parse, ParseStream, Peek}, @@ -12,8 +15,6 @@ pub enum QuerySegment { If(IfSegment), /// An exhaustive `match .. { .. }` Match(MatchSegment), - /// A query argument. Can be an arbitrary expression, prefixed by `?`, like `?search.trim()` - Arg(ArgSegment), } impl QuerySegment { @@ -36,29 +37,9 @@ impl QuerySegment { } } -pub struct ArgSegment { - pub argument: Expr, -} - -impl ArgSegment { - const EXPECT: &'static str = "?.."; - - fn matches(input: ParseStream) -> bool { - input.peek(Token![?]) - } -} - -impl Parse for ArgSegment { - fn parse(input: ParseStream) -> syn::Result { - input.parse::()?; - Ok(Self { - argument: input.parse::()?, - }) - } -} - pub struct SqlSegment { - pub query: String, + pub sql: String, + pub args: Vec<(usize, Expr, usize)> } impl SqlSegment { @@ -71,9 +52,52 @@ impl SqlSegment { impl Parse for SqlSegment { fn parse(input: ParseStream) -> syn::Result { - let lit = input.parse::()?; - Ok(Self { query: lit.value() }) + let sql = input.parse::()?.value(); + let args = parse_inline_args(&sql)?; + + Ok(Self { sql, args }) + } +} + +// parses inline arguments in the query, for example `".. WHERE user_id = {1}"`, returning them with +// the index of `}`, the parsed argument, and the index of the `}`. +fn parse_inline_args(sql: &str) -> syn::Result> { + let mut args = vec![]; + let mut curr_level = 0; + let mut curr_arg = None; + + for (idx, c) in sql.chars().enumerate() { + match c { + '}' => { + if curr_arg.is_none() { + curr_arg = Some((idx, String::new())); + } + curr_level += 1; + }, + '{' => { + if curr_arg.is_none() { + let err = Error::new(Span::call_site(), "unexpected '}' in query string"); + return Err(err); + }; + if curr_level == 1 { + let (arg_start, arg_str) = std::mem::replace(&mut curr_arg, None).unwrap(); + let arg = syn::parse_str(&arg_str)?; + args.push((arg_start, arg, idx)); + } + curr_level -= 1; + }, + c => if let Some((_, arg)) = &mut curr_arg { + arg.push(c); + } + } + }; + + if curr_arg.is_some() { + let err = Error::new(Span::call_site(), "expected '}', but got end of string"); + return Err(err) } + + Ok(args) } pub struct MatchSegment { @@ -186,15 +210,12 @@ impl Parse for QuerySegment { Ok(QuerySegment::If(IfSegment::parse(input)?)) } else if MatchSegment::matches(input) { Ok(QuerySegment::Match(MatchSegment::parse(input)?)) - } else if ArgSegment::matches(input) { - Ok(QuerySegment::Arg(ArgSegment::parse(input)?)) } else { let error = format!( - "expected `{}`, `{}`, `{}` or `{}`", + "expected `{}`, `{}` or `{}`", SqlSegment::EXPECT, IfSegment::EXPECT, MatchSegment::EXPECT, - ArgSegment::EXPECT ); Err(Error::new(input.span(), error)) } From 4f2ce57712001dfc5adec984f50e59e93abae28e Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 27 Oct 2021 01:33:09 +0200 Subject: [PATCH 05/30] allow concatenation of segments for backwards compatibility --- sqlx-macros/src/query/conditional/segment.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sqlx-macros/src/query/conditional/segment.rs b/sqlx-macros/src/query/conditional/segment.rs index c07c3a11f2..1ea2187f4d 100644 --- a/sqlx-macros/src/query/conditional/segment.rs +++ b/sqlx-macros/src/query/conditional/segment.rs @@ -22,7 +22,7 @@ impl QuerySegment { fn parse_until(input: ParseStream, until: T) -> syn::Result> { let mut segments = vec![]; while !input.is_empty() && !input.peek(until) { - segments.push(QuerySegment::parse(input)?) + segments.push(QuerySegment::parse(input)?); } Ok(segments) } @@ -204,6 +204,11 @@ impl Parse for IfSegment { impl Parse for QuerySegment { fn parse(input: ParseStream) -> syn::Result { + // parse optional '+' for backwards compatibility + if input.peek(Token![+]) { + input.parse::()?; + } + if SqlSegment::matches(&input) { Ok(QuerySegment::Sql(SqlSegment::parse(input)?)) } else if IfSegment::matches(&input) { From 21060cefa3a6cea3064f4874b1b0b00aaf735448 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 27 Oct 2021 01:49:00 +0200 Subject: [PATCH 06/30] make use of format_ident! --- sqlx-macros/src/query/conditional/map.rs | 8 ++++---- sqlx-macros/src/query/conditional/mod.rs | 4 ++-- sqlx-macros/src/query/conditional/segment.rs | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sqlx-macros/src/query/conditional/map.rs b/sqlx-macros/src/query/conditional/map.rs index 08bc518d94..fa91e7eeda 100644 --- a/sqlx-macros/src/query/conditional/map.rs +++ b/sqlx-macros/src/query/conditional/map.rs @@ -1,17 +1,17 @@ use proc_macro2::{Span, TokenStream}; -use quote::quote; +use quote::{format_ident, quote}; use syn::Ident; pub fn generate_conditional_map(n: usize) -> TokenStream { let call_site = Span::call_site(); let map_fns = (1..=n) - .map(|i| Ident::new(&format!("F{}", i), call_site)) + .map(|i| format_ident!("F{}", i)) .collect::>(); let args = (1..=n) - .map(|i| Ident::new(&format!("A{}", i), call_site)) + .map(|i| format_ident!("A{}", i)) .collect::>(); let variants = (1..=n) - .map(|i| Ident::new(&format!("_{}", i), call_site)) + .map(|i| format_ident!("_{}", i)) .collect::>(); let variant_declarations = (0..n).map(|i| { let variant = &variants[i]; diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index ce084780f1..4a36fd027f 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -1,7 +1,7 @@ use std::{fmt::Write, rc::Rc}; use proc_macro2::{Span, TokenStream}; -use quote::quote; +use quote::{format_ident, quote}; use syn::{ Expr, Ident, parse::{Parse, ParseStream}, Pat, Path, Result, Token, @@ -206,7 +206,7 @@ impl IsContext for NormalContext { match branches { 1 => query_call, _ => { - let variant = Ident::new(&format!("_{}", branch_counter), Span::call_site()); + let variant = format_ident!("_{}", branch_counter); quote!(ConditionalMap::#variant(#query_call)) } } diff --git a/sqlx-macros/src/query/conditional/segment.rs b/sqlx-macros/src/query/conditional/segment.rs index 1ea2187f4d..94771139a0 100644 --- a/sqlx-macros/src/query/conditional/segment.rs +++ b/sqlx-macros/src/query/conditional/segment.rs @@ -210,11 +210,11 @@ impl Parse for QuerySegment { } if SqlSegment::matches(&input) { - Ok(QuerySegment::Sql(SqlSegment::parse(input)?)) + Ok(QuerySegment::Sql(input.parse()?)) } else if IfSegment::matches(&input) { - Ok(QuerySegment::If(IfSegment::parse(input)?)) + Ok(QuerySegment::If(input.parse()?)) } else if MatchSegment::matches(input) { - Ok(QuerySegment::Match(MatchSegment::parse(input)?)) + Ok(QuerySegment::Match(input.parse()?)) } else { let error = format!( "expected `{}`, `{}` or `{}`", From dabef597470d3cd0a5b922c6879a28b2ce16655b Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 27 Oct 2021 17:40:47 +0200 Subject: [PATCH 07/30] make calls to expand_query! internally, make (hopefully) completely backwards compatible with query_as!, replace query_as! with new proc macro --- sqlx-macros/src/lib.rs | 155 +++++++++++++++++++++- sqlx-macros/src/query/conditional/mod.rs | 47 +++++-- sqlx-macros/src/query/mod.rs | 4 +- src/lib.rs | 2 +- src/macros.rs | 160 ----------------------- 5 files changed, 193 insertions(+), 175 deletions(-) diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index d40822a9a6..b3fd432fd5 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -41,9 +41,160 @@ pub fn expand_query(input: TokenStream) -> TokenStream { } } +/// A variant of [query!] which takes a path to an explicitly defined struct as the output type. +/// +/// This lets you return the struct from a function or add your own trait implementations. +/// +/// **No trait implementations are required**; the macro maps rows using a struct literal +/// where the names of columns in the query are expected to be the same as the fields of the struct +/// (but the order does not need to be the same). The types of the columns are based on the +/// query and not the corresponding fields of the struct, so this is type-safe as well. +/// +/// This enforces a few things: +/// * The query must output at least one column. +/// * The column names of the query must match the field names of the struct. +/// * The field types must be the Rust equivalent of their SQL counterparts; see the corresponding +/// module for your database for mappings: +/// * Postgres: [crate::postgres::types] +/// * MySQL: [crate::mysql::types] +/// * SQLite: [crate::sqlite::types] +/// * MSSQL: [crate::mssql::types] +/// * If a column may be `NULL`, the corresponding field's type must be wrapped in `Option<_>`. +/// * Neither the query nor the struct may have unused fields. +/// +/// The only modification to the `query!()` syntax is that the struct name is given before the SQL +/// string: +/// ```rust,ignore +/// # use sqlx::Connect; +/// # #[cfg(all(feature = "mysql", feature = "_rt-async-std"))] +/// # #[async_std::main] +/// # async fn main() -> sqlx::Result<()>{ +/// # let db_url = dotenv::var("DATABASE_URL").expect("DATABASE_URL must be set"); +/// # +/// # if !(db_url.starts_with("mysql") || db_url.starts_with("mariadb")) { return Ok(()) } +/// # let mut conn = sqlx::MySqlConnection::connect(db_url).await?; +/// #[derive(Debug)] +/// struct Account { +/// id: i32, +/// name: String +/// } +/// +/// // let mut conn = ; +/// let account = sqlx::query_as!( +/// Account, +/// "select * from (select (1) as id, 'Herp Derpinson' as name) accounts where id = ?", +/// 1i32 +/// ) +/// .fetch_one(&mut conn) +/// .await?; +/// +/// println!("{:?}", account); +/// println!("{}: {}", account.id, account.name); +/// +/// # Ok(()) +/// # } +/// # +/// # #[cfg(any(not(feature = "mysql"), not(feature = "_rt-async-std")))] +/// # fn main() {} +/// ``` +/// +/// **The method you want to call depends on how many rows you're expecting.** +/// +/// | Number of Rows | Method to Call* | Returns (`T` being the given struct) | Notes | +/// |----------------| ----------------------------|----------------------------------------|-------| +/// | Zero or One | `.fetch_optional(...).await`| `sqlx::Result>` | Extra rows are ignored. | +/// | Exactly One | `.fetch_one(...).await` | `sqlx::Result` | Errors if no rows were returned. Extra rows are ignored. Aggregate queries, use this. | +/// | At Least One | `.fetch(...)` | `impl Stream>` | Call `.try_next().await` to get each row result. | +/// | Multiple | `.fetch_all(...)` | `sqlx::Result>` | | +/// +/// \* All methods accept one of `&mut {connection type}`, `&mut Transaction` or `&Pool`. +/// (`.execute()` is omitted as this macro requires at least one column to be returned.) +/// +/// ### Column Type Override: Infer from Struct Field +/// In addition to the column type overrides supported by [query!], `query_as!()` supports an +/// additional override option: +/// +/// If you select a column `foo as "foo: _"` (Postgres/SQLite) or `` foo as `foo: _` `` (MySQL) +/// it causes that column to be inferred based on the type of the corresponding field in the given +/// record struct. Runtime type-checking is still done so an error will be emitted if the types +/// are not compatible. +/// +/// This allows you to override the inferred type of a column to instead use a custom-defined type: +/// +/// ```rust,ignore +/// #[derive(sqlx::Type)] +/// #[sqlx(transparent)] +/// struct MyInt4(i32); +/// +/// struct Record { +/// id: MyInt4, +/// } +/// +/// let my_int = MyInt4(1); +/// +/// // Postgres/SQLite +/// sqlx::query_as!(Record, r#"select 1 as "id: _""#) // MySQL: use "select 1 as `id: _`" instead +/// .fetch_one(&mut conn) +/// .await?; +/// +/// assert_eq!(record.id, MyInt4(1)); +/// ``` +/// +/// ### Troubleshooting: "error: mismatched types" +/// If you get a "mismatched types" error from an invocation of this macro and the error +/// isn't pointing specifically at a parameter. +/// +/// For example, code like this (using a Postgres database): +/// +/// ```rust,ignore +/// struct Account { +/// id: i32, +/// name: Option, +/// } +/// +/// let account = sqlx::query_as!( +/// Account, +/// r#"SELECT id, name from (VALUES (1, 'Herp Derpinson')) accounts(id, name)"#, +/// ) +/// .fetch_one(&mut conn) +/// .await?; +/// ``` +/// +/// Might produce an error like this: +/// ```text,ignore +/// error[E0308]: mismatched types +/// --> tests/postgres/macros.rs:126:19 +/// | +/// 126 | let account = sqlx::query_as!( +/// | ___________________^ +/// 127 | | Account, +/// 128 | | r#"SELECT id, name from (VALUES (1, 'Herp Derpinson')) accounts(id, name)"#, +/// 129 | | ) +/// | |_____^ expected `i32`, found enum `std::option::Option` +/// | +/// = note: expected type `i32` +/// found enum `std::option::Option` +/// ``` +/// +/// This means that you need to check that any field of the "expected" type (here, `i32`) matches +/// the Rust type mapping for its corresponding SQL column (see the `types` module of your database, +/// listed above, for mappings). The "found" type is the SQL->Rust mapping that the macro chose. +/// +/// In the above example, the returned column is inferred to be nullable because it's being +/// returned from a `VALUES` statement in Postgres, so the macro inferred the field to be nullable +/// and so used `Option` instead of `i32`. **In this specific case** we could use +/// `select id as "id!"` to override the inferred nullability because we know in practice +/// that column will never be `NULL` and it will fix the error. +/// +/// Nullability inference and type overrides are discussed in detail in the docs for [query!]. +/// +/// It unfortunately doesn't appear to be possible right now to make the error specifically mention +/// the field; this probably requires the `const-panic` feature (still unstable as of Rust 1.45). +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] #[proc_macro] -pub fn conditional_query_as(input: TokenStream) -> TokenStream { - match query::conditional::conditional_query_as(input.into()) { +pub fn query_as(input: TokenStream) -> TokenStream { + match query::query_as(input.into()) { Ok(output) => output, Err(err) => err.to_compile_error() }.into() diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index 4a36fd027f..9b64c37dff 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -1,21 +1,20 @@ use std::{fmt::Write, rc::Rc}; +use std::any::Any; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; -use syn::{ - Expr, - Ident, parse::{Parse, ParseStream}, Pat, Path, Result, Token, -}; +use syn::{Error, Expr, Ident, parse::{Parse, ParseStream}, Pat, Path, Result, Token}; +use syn::punctuated::Punctuated; use segment::*; mod map; mod segment; -/// Expand a call to `conditional_query_as!` -pub fn conditional_query_as(input: TokenStream) -> Result { + +pub fn query_as(input: TokenStream) -> Result { let input = syn::parse2::(input)?; - let ctx = input.to_context(); + let ctx = input.to_context()?; let out = ctx.generate_output(); Ok(out) } @@ -24,11 +23,13 @@ pub fn conditional_query_as(input: TokenStream) -> Result { struct Input { query_as: Path, segments: Vec, + // separately specified arguments + arguments: Vec, } impl Input { /// Convert the input into a context - fn to_context(&self) -> Context { + fn to_context(&self) -> Result { let mut ctx = Context::Default(NormalContext { query_as: Rc::new(self.query_as.clone()), sql: String::new(), @@ -39,7 +40,19 @@ impl Input { ctx.add_segment(segment); } - ctx + if ctx.branches() > 1 && !self.arguments.is_empty() { + let err = Error::new( + Span::call_site(), + "branches (`match` and `if`) can only be used with inline arguments" + ); + Err(err) + } else { + match &mut ctx { + Context::Default(ctx) => ctx.args.extend(self.arguments.iter().cloned()), + _ => unreachable!() + } + Ok(ctx) + } } } @@ -47,9 +60,15 @@ impl Parse for Input { fn parse(input: ParseStream) -> Result { let query_as = input.parse::()?; input.parse::()?; + let segments = QuerySegment::parse_all(input)?; + let arguments = match input.parse::>()? { + None => vec![], + Some(..) => Punctuated::::parse_terminated(input)?.into_iter().collect() + }; Ok(Self { query_as, - segments: QuerySegment::parse_all(input)?, + segments, + arguments, }) } } @@ -202,7 +221,13 @@ impl IsContext for NormalContext { } = self; *branch_counter += 1; - let query_call = quote!(sqlx::query_as!(#query_as, #sql, #(#args),*)); + let query_call = quote!( + sqlx_macros::expand_query!( + record = #query_as, + source = $sql, + args = [$($args),*] + ) + ); match branches { 1 => query_call, _ => { diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs index 20cdc631f8..04ff78b1a1 100644 --- a/sqlx-macros/src/query/mod.rs +++ b/sqlx-macros/src/query/mod.rs @@ -23,7 +23,9 @@ mod args; mod data; mod input; mod output; -pub mod conditional; +mod conditional; + +pub use conditional::query_as; struct Metadata { #[allow(unused)] diff --git a/src/lib.rs b/src/lib.rs index 8d0c8e30c3..56ecd58153 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,7 +71,7 @@ pub extern crate sqlx_macros; // derives #[cfg(feature = "macros")] #[doc(hidden)] -pub use sqlx_macros::{FromRow, Type, conditional_query_as}; +pub use sqlx_macros::{FromRow, Type, query_as}; #[cfg(feature = "macros")] mod macros; diff --git a/src/macros.rs b/src/macros.rs index c1dceaba1b..8fef39ec6e 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -396,166 +396,6 @@ macro_rules! query_file_unchecked ( }) ); -/// A variant of [query!] which takes a path to an explicitly defined struct as the output type. -/// -/// This lets you return the struct from a function or add your own trait implementations. -/// -/// **No trait implementations are required**; the macro maps rows using a struct literal -/// where the names of columns in the query are expected to be the same as the fields of the struct -/// (but the order does not need to be the same). The types of the columns are based on the -/// query and not the corresponding fields of the struct, so this is type-safe as well. -/// -/// This enforces a few things: -/// * The query must output at least one column. -/// * The column names of the query must match the field names of the struct. -/// * The field types must be the Rust equivalent of their SQL counterparts; see the corresponding -/// module for your database for mappings: -/// * Postgres: [crate::postgres::types] -/// * MySQL: [crate::mysql::types] -/// * SQLite: [crate::sqlite::types] -/// * MSSQL: [crate::mssql::types] -/// * If a column may be `NULL`, the corresponding field's type must be wrapped in `Option<_>`. -/// * Neither the query nor the struct may have unused fields. -/// -/// The only modification to the `query!()` syntax is that the struct name is given before the SQL -/// string: -/// ```rust,ignore -/// # use sqlx::Connect; -/// # #[cfg(all(feature = "mysql", feature = "_rt-async-std"))] -/// # #[async_std::main] -/// # async fn main() -> sqlx::Result<()>{ -/// # let db_url = dotenv::var("DATABASE_URL").expect("DATABASE_URL must be set"); -/// # -/// # if !(db_url.starts_with("mysql") || db_url.starts_with("mariadb")) { return Ok(()) } -/// # let mut conn = sqlx::MySqlConnection::connect(db_url).await?; -/// #[derive(Debug)] -/// struct Account { -/// id: i32, -/// name: String -/// } -/// -/// // let mut conn = ; -/// let account = sqlx::query_as!( -/// Account, -/// "select * from (select (1) as id, 'Herp Derpinson' as name) accounts where id = ?", -/// 1i32 -/// ) -/// .fetch_one(&mut conn) -/// .await?; -/// -/// println!("{:?}", account); -/// println!("{}: {}", account.id, account.name); -/// -/// # Ok(()) -/// # } -/// # -/// # #[cfg(any(not(feature = "mysql"), not(feature = "_rt-async-std")))] -/// # fn main() {} -/// ``` -/// -/// **The method you want to call depends on how many rows you're expecting.** -/// -/// | Number of Rows | Method to Call* | Returns (`T` being the given struct) | Notes | -/// |----------------| ----------------------------|----------------------------------------|-------| -/// | Zero or One | `.fetch_optional(...).await`| `sqlx::Result>` | Extra rows are ignored. | -/// | Exactly One | `.fetch_one(...).await` | `sqlx::Result` | Errors if no rows were returned. Extra rows are ignored. Aggregate queries, use this. | -/// | At Least One | `.fetch(...)` | `impl Stream>` | Call `.try_next().await` to get each row result. | -/// | Multiple | `.fetch_all(...)` | `sqlx::Result>` | | -/// -/// \* All methods accept one of `&mut {connection type}`, `&mut Transaction` or `&Pool`. -/// (`.execute()` is omitted as this macro requires at least one column to be returned.) -/// -/// ### Column Type Override: Infer from Struct Field -/// In addition to the column type overrides supported by [query!], `query_as!()` supports an -/// additional override option: -/// -/// If you select a column `foo as "foo: _"` (Postgres/SQLite) or `` foo as `foo: _` `` (MySQL) -/// it causes that column to be inferred based on the type of the corresponding field in the given -/// record struct. Runtime type-checking is still done so an error will be emitted if the types -/// are not compatible. -/// -/// This allows you to override the inferred type of a column to instead use a custom-defined type: -/// -/// ```rust,ignore -/// #[derive(sqlx::Type)] -/// #[sqlx(transparent)] -/// struct MyInt4(i32); -/// -/// struct Record { -/// id: MyInt4, -/// } -/// -/// let my_int = MyInt4(1); -/// -/// // Postgres/SQLite -/// sqlx::query_as!(Record, r#"select 1 as "id: _""#) // MySQL: use "select 1 as `id: _`" instead -/// .fetch_one(&mut conn) -/// .await?; -/// -/// assert_eq!(record.id, MyInt4(1)); -/// ``` -/// -/// ### Troubleshooting: "error: mismatched types" -/// If you get a "mismatched types" error from an invocation of this macro and the error -/// isn't pointing specifically at a parameter. -/// -/// For example, code like this (using a Postgres database): -/// -/// ```rust,ignore -/// struct Account { -/// id: i32, -/// name: Option, -/// } -/// -/// let account = sqlx::query_as!( -/// Account, -/// r#"SELECT id, name from (VALUES (1, 'Herp Derpinson')) accounts(id, name)"#, -/// ) -/// .fetch_one(&mut conn) -/// .await?; -/// ``` -/// -/// Might produce an error like this: -/// ```text,ignore -/// error[E0308]: mismatched types -/// --> tests/postgres/macros.rs:126:19 -/// | -/// 126 | let account = sqlx::query_as!( -/// | ___________________^ -/// 127 | | Account, -/// 128 | | r#"SELECT id, name from (VALUES (1, 'Herp Derpinson')) accounts(id, name)"#, -/// 129 | | ) -/// | |_____^ expected `i32`, found enum `std::option::Option` -/// | -/// = note: expected type `i32` -/// found enum `std::option::Option` -/// ``` -/// -/// This means that you need to check that any field of the "expected" type (here, `i32`) matches -/// the Rust type mapping for its corresponding SQL column (see the `types` module of your database, -/// listed above, for mappings). The "found" type is the SQL->Rust mapping that the macro chose. -/// -/// In the above example, the returned column is inferred to be nullable because it's being -/// returned from a `VALUES` statement in Postgres, so the macro inferred the field to be nullable -/// and so used `Option` instead of `i32`. **In this specific case** we could use -/// `select id as "id!"` to override the inferred nullability because we know in practice -/// that column will never be `NULL` and it will fix the error. -/// -/// Nullability inference and type overrides are discussed in detail in the docs for [query!]. -/// -/// It unfortunately doesn't appear to be possible right now to make the error specifically mention -/// the field; this probably requires the `const-panic` feature (still unstable as of Rust 1.45). -#[macro_export] -#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] -macro_rules! query_as ( - ($out_struct:path, $query:expr) => ( { - $crate::sqlx_macros::expand_query!(record = $out_struct, source = $query) - }); - ($out_struct:path, $query:expr, $($args:tt)*) => ( { - $crate::sqlx_macros::expand_query!(record = $out_struct, source = $query, args = [$($args)*]) - }) -); - /// Combines the syntaxes of [query_as!] and [query_file!]. /// /// Enforces requirements of both macros; see them for details. From 1f10f6405451a25495c1d23f6292b080ffee4015 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 27 Oct 2021 18:01:39 +0200 Subject: [PATCH 08/30] fix bad call to quote! --- sqlx-macros/src/query/conditional/mod.rs | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index 9b64c37dff..a90711800d 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -224,8 +224,8 @@ impl IsContext for NormalContext { let query_call = quote!( sqlx_macros::expand_query!( record = #query_as, - source = $sql, - args = [$($args),*] + source = #sql, + args = [#(#args),*] ) ); match branches { @@ -346,3 +346,18 @@ impl IsContext for MatchArmContext { self.inner.add_sql(sql); } } + +#[cfg(test)] +mod tests { + use crate::query::conditional::Input; + use quote::quote; + + #[test] + fn simple() { + let input = quote! { + OptionalRecord, "select owner_id as `id: _` from tweet" + }; + let result = syn::parse2::(input).unwrap(); + println!("{}", result.segments.len()); + } +} From 860b72df0ce8e4fd0b8195a8e81bca5a29bde10b Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 27 Oct 2021 18:16:22 +0200 Subject: [PATCH 09/30] accidentally swapped { and } --- sqlx-macros/src/query/conditional/segment.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlx-macros/src/query/conditional/segment.rs b/sqlx-macros/src/query/conditional/segment.rs index 94771139a0..0c3fcad10b 100644 --- a/sqlx-macros/src/query/conditional/segment.rs +++ b/sqlx-macros/src/query/conditional/segment.rs @@ -68,13 +68,13 @@ fn parse_inline_args(sql: &str) -> syn::Result> { for (idx, c) in sql.chars().enumerate() { match c { - '}' => { + '{' => { if curr_arg.is_none() { curr_arg = Some((idx, String::new())); } curr_level += 1; }, - '{' => { + '}' => { if curr_arg.is_none() { let err = Error::new(Span::call_site(), "unexpected '}' in query string"); return Err(err); From 89878edd70b65835a79a32b36f5e0b1fd2a5881a Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 27 Oct 2021 18:26:59 +0200 Subject: [PATCH 10/30] add test, clean up --- sqlx-macros/src/lib.rs | 1 - sqlx-macros/src/query/conditional/map.rs | 1 - sqlx-macros/src/query/conditional/mod.rs | 83 ++++++++++++++++++------ 3 files changed, 63 insertions(+), 22 deletions(-) diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index b3fd432fd5..da01632fe4 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -190,7 +190,6 @@ pub fn expand_query(input: TokenStream) -> TokenStream { /// /// It unfortunately doesn't appear to be possible right now to make the error specifically mention /// the field; this probably requires the `const-panic` feature (still unstable as of Rust 1.45). -#[macro_export] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] #[proc_macro] pub fn query_as(input: TokenStream) -> TokenStream { diff --git a/sqlx-macros/src/query/conditional/map.rs b/sqlx-macros/src/query/conditional/map.rs index fa91e7eeda..4bd6be9f7d 100644 --- a/sqlx-macros/src/query/conditional/map.rs +++ b/sqlx-macros/src/query/conditional/map.rs @@ -3,7 +3,6 @@ use quote::{format_ident, quote}; use syn::Ident; pub fn generate_conditional_map(n: usize) -> TokenStream { - let call_site = Span::call_site(); let map_fns = (1..=n) .map(|i| format_ident!("F{}", i)) .collect::>(); diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index a90711800d..72c25b2ce9 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -1,17 +1,17 @@ -use std::{fmt::Write, rc::Rc}; -use std::any::Any; +use std::{any::Any, fmt::Write, rc::Rc}; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; -use syn::{Error, Expr, Ident, parse::{Parse, ParseStream}, Pat, Path, Result, Token}; -use syn::punctuated::Punctuated; - use segment::*; +use syn::{ + parse::{Parse, ParseStream}, + punctuated::Punctuated, + Error, Expr, Ident, Pat, Path, Result, Token, +}; mod map; mod segment; - pub fn query_as(input: TokenStream) -> Result { let input = syn::parse2::(input)?; let ctx = input.to_context()?; @@ -43,13 +43,13 @@ impl Input { if ctx.branches() > 1 && !self.arguments.is_empty() { let err = Error::new( Span::call_site(), - "branches (`match` and `if`) can only be used with inline arguments" + "branches (`match` and `if`) can only be used with inline arguments", ); Err(err) } else { match &mut ctx { Context::Default(ctx) => ctx.args.extend(self.arguments.iter().cloned()), - _ => unreachable!() + _ => unreachable!(), } Ok(ctx) } @@ -63,7 +63,9 @@ impl Parse for Input { let segments = QuerySegment::parse_all(input)?; let arguments = match input.parse::>()? { None => vec![], - Some(..) => Punctuated::::parse_terminated(input)?.into_iter().collect() + Some(..) => Punctuated::::parse_terminated(input)? + .into_iter() + .collect(), }; Ok(Self { query_as, @@ -221,13 +223,11 @@ impl IsContext for NormalContext { } = self; *branch_counter += 1; - let query_call = quote!( - sqlx_macros::expand_query!( - record = #query_as, - source = #sql, - args = [#(#args),*] - ) - ); + let query_call = quote!(sqlx_macros::expand_query!( + record = #query_as, + source = #sql, + args = [#(#args),*] + )); match branches { 1 => query_call, _ => { @@ -309,7 +309,10 @@ impl IsContext for MatchContext { fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream { let expr = &self.expr; - let arms = self.arms.iter().map(|arm| arm.to_query(branches, branch_counter)); + let arms = self + .arms + .iter() + .map(|arm| arm.to_query(branches, branch_counter)); quote! { match #expr { #(#arms,)* } } @@ -349,15 +352,55 @@ impl IsContext for MatchArmContext { #[cfg(test)] mod tests { - use crate::query::conditional::Input; + use proc_macro2::{TokenStream, TokenTree}; use quote::quote; + use crate::query::conditional::Input; + + // credits: Yandros#4299 + fn assert_token_stream_eq(ts1: TokenStream, ts2: TokenStream) { + fn assert_tt_eq(tt1: TokenTree, tt2: TokenTree) { + use ::proc_macro2::TokenTree::*; + match (tt1, tt2) { + (Group(g1), Group(g2)) => assert_token_stream_eq(g1.stream(), g2.stream()), + (Ident(lhs), Ident(rhs)) => assert_eq!(lhs.to_string(), rhs.to_string()), + (Punct(lhs), Punct(rhs)) => assert_eq!(lhs.as_char(), rhs.as_char()), + (Literal(lhs), Literal(rhs)) => assert_eq!(lhs.to_string(), rhs.to_string()), + _ => panic!("Not equal!"), + } + } + + let mut ts1 = ts1.into_iter(); + let mut ts2 = ts2.into_iter(); + loop { + match (ts1.next(), ts2.next()) { + (Some(tt1), Some(tt2)) => assert_tt_eq(tt1, tt2), + (None, None) => return, + _ => panic!("Not equal!"), + } + } + } + #[test] fn simple() { let input = quote! { - OptionalRecord, "select owner_id as `id: _` from tweet" + OptionalRecord, "select something from somewhere where something_else = {1}" }; let result = syn::parse2::(input).unwrap(); - println!("{}", result.segments.len()); + let expected_query = if cfg!(feature = "postgres") { + "select something from somewhere where something_else = $1" + } else { + "select something from somewhere where something_else = ?" + }; + assert_token_stream_eq( + result.to_context().unwrap().generate_output(), + quote! { + sqlx_macros::expand_query!( + record = OptionalRecord, + source = #expected_query, + args = [1] + ) + }, + ); } } From 39ae6c257477e75533d4137b839adfa69d431351 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 27 Oct 2021 18:35:01 +0200 Subject: [PATCH 11/30] don't hide query_as! from docs --- sqlx-macros/src/query/conditional/segment.rs | 2 +- src/lib.rs | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sqlx-macros/src/query/conditional/segment.rs b/sqlx-macros/src/query/conditional/segment.rs index 0c3fcad10b..bf385b9b8b 100644 --- a/sqlx-macros/src/query/conditional/segment.rs +++ b/sqlx-macros/src/query/conditional/segment.rs @@ -60,7 +60,7 @@ impl Parse for SqlSegment { } // parses inline arguments in the query, for example `".. WHERE user_id = {1}"`, returning them with -// the index of `}`, the parsed argument, and the index of the `}`. +// the index of `{`, the parsed argument, and the index of the `}`. fn parse_inline_args(sql: &str) -> syn::Result> { let mut args = vec![]; let mut curr_level = 0; diff --git a/src/lib.rs b/src/lib.rs index 56ecd58153..1edaa92dee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,7 +71,10 @@ pub extern crate sqlx_macros; // derives #[cfg(feature = "macros")] #[doc(hidden)] -pub use sqlx_macros::{FromRow, Type, query_as}; +pub use sqlx_macros::{FromRow, Type}; + +#[cfg(feature = "macros")] +pub use sqlx_macros::query_as; #[cfg(feature = "macros")] mod macros; From 6489cdc486a49db78157850a3385af6a56b139b4 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 2 Mar 2022 13:29:29 +0100 Subject: [PATCH 12/30] add brief summary to conditional/mod.rs --- sqlx-macros/src/query/conditional/mod.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index 72c25b2ce9..81a47a9d79 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -1,3 +1,8 @@ +/// This module introduces support for building dynamic queries while still having them checked at +/// compile time. +/// This is achieved by computing every possible query within a procedural macro. +/// It's only during runtime when the appropriate query will be chosen and executed. + use std::{any::Any, fmt::Write, rc::Rc}; use proc_macro2::{Span, TokenStream}; @@ -12,6 +17,7 @@ use syn::{ mod map; mod segment; +/// Entry point of the `query_as!` macro pub fn query_as(input: TokenStream) -> Result { let input = syn::parse2::(input)?; let ctx = input.to_context()?; @@ -19,7 +25,7 @@ pub fn query_as(input: TokenStream) -> Result { Ok(out) } -/// Input to the conditional query macro. +/// Input to the `query_as!` macro. struct Input { query_as: Path, segments: Vec, From b483b984e0cbfc1dfad688a07b0d1fcb5ca0b826 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 2 Mar 2022 13:30:55 +0100 Subject: [PATCH 13/30] enable CI for development --- .github/workflows/sqlx.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index f325cb7c3f..3db2c732ac 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -3,8 +3,6 @@ name: SQLx on: pull_request: push: - branches: - - master jobs: format: From 1a726219b02d1b18393ef6aedda69f63a469d991 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 2 Mar 2022 14:51:50 +0100 Subject: [PATCH 14/30] add documentation, format, derive Debug for QuerySegment and Context --- sqlx-macros/Cargo.toml | 2 +- sqlx-macros/src/lib.rs | 5 +- sqlx-macros/src/query/conditional/map.rs | 12 +- sqlx-macros/src/query/conditional/mod.rs | 131 ++++++++++++++++--- sqlx-macros/src/query/conditional/segment.rs | 23 ++-- sqlx-macros/src/query/mod.rs | 2 +- 6 files changed, 137 insertions(+), 38 deletions(-) diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index b6c6c99b08..54d2088aad 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -89,6 +89,6 @@ sqlx-rt = { version = "0.5.11", default-features = false, path = "../sqlx-rt" } serde = { version = "1.0.132", features = ["derive"], optional = true } serde_json = { version = "1.0.73", optional = true } sha2 = { version = "0.9.8", optional = true } -syn = { version = "1.0.84", default-features = false, features = ["full"] } +syn = { version = "1.0.84", default-features = false, features = ["full", "extra-traits"] } quote = { version = "1.0.14", default-features = false } url = { version = "2.2.2", default-features = false } diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index da01632fe4..2e70eb279d 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -195,8 +195,9 @@ pub fn expand_query(input: TokenStream) -> TokenStream { pub fn query_as(input: TokenStream) -> TokenStream { match query::query_as(input.into()) { Ok(output) => output, - Err(err) => err.to_compile_error() - }.into() + Err(err) => err.to_compile_error(), + } + .into() } #[proc_macro_derive(Encode, attributes(sqlx))] diff --git a/sqlx-macros/src/query/conditional/map.rs b/sqlx-macros/src/query/conditional/map.rs index 4bd6be9f7d..c871212211 100644 --- a/sqlx-macros/src/query/conditional/map.rs +++ b/sqlx-macros/src/query/conditional/map.rs @@ -3,15 +3,9 @@ use quote::{format_ident, quote}; use syn::Ident; pub fn generate_conditional_map(n: usize) -> TokenStream { - let map_fns = (1..=n) - .map(|i| format_ident!("F{}", i)) - .collect::>(); - let args = (1..=n) - .map(|i| format_ident!("A{}", i)) - .collect::>(); - let variants = (1..=n) - .map(|i| format_ident!("_{}", i)) - .collect::>(); + let map_fns = (1..=n).map(|i| format_ident!("F{}", i)).collect::>(); + let args = (1..=n).map(|i| format_ident!("A{}", i)).collect::>(); + let variants = (1..=n).map(|i| format_ident!("_{}", i)).collect::>(); let variant_declarations = (0..n).map(|i| { let variant = &variants[i]; let map_fn = &map_fns[i]; diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index 81a47a9d79..d6dad5e0cf 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -2,7 +2,72 @@ /// compile time. /// This is achieved by computing every possible query within a procedural macro. /// It's only during runtime when the appropriate query will be chosen and executed. - +/// +/// ## Return type +/// Since a single invocation of `query_as!` executes one of many possible queries, it's return type +/// would differ between different invocations. This, of course, would break as soon as one tried to +/// do anything with the return value, for example `.await`ing it. +/// Therefor, this module introduces a workaround for conditional queries. The behaviour of normal +/// queries is not affected by this. +/// +/// For each *conditional* invocation of `query_as!`, an enum will be generated, and the invocation +/// expands to an instance of this enum. This enum contains a variant for each possible query. +/// see `[map::generate_conditional_map]` +/// +/// ## Arguments +/// For conditional queries, arguments must be specified *inline* (for example ".. WHERE name ILIKE {filter}). +/// For normal queries, arguments can still be passed as a comma-separated list. +/// +/// ## Example +/// To outline how this all works, let's consider the following example. +/// ```rust,ignore +/// sqlx::query_as!( +/// Article, +/// "SELECT * FROM articles" +/// if let Some(name_filter) = filter { +/// "WHERE name ILIKE {name_filter} +/// } +/// ``` +/// +/// This input will first be parsed into a list of `QuerySegment`s. +/// For the example above, this would result in something like this: +/// ```rust,ignore +/// [ +/// SqlSegment { sql: "SELECT * FROM articles", args: [] }, +/// IfSegment { +/// condition: "let Some (name_filter) = filter", +/// then: [ +/// SqlSegment { sql: "WHERE name ILIKE {name_filter}" } +/// ] +/// } +/// ``` +/// +/// These segments are now transformed into a tree-structure. In essence, this would result in: +/// ```rust,ignore +/// IfContext { +/// condition: "let Some(name_filter) = filter", +/// then: NormalContext { sql: "SELECT * FROM articles WHERE name ILIKE ?", args: ["name_filter"] }, +/// or_else: NormalContext { sql: "SELECT * FROM articles", args: [] }, +/// } +/// ``` +/// +/// Finally, the resulting code is generated: +/// ```rust,ignore +/// enum ConditionalMap { .. } +/// if let Some(name_filter) = filter { +/// ConditionalMap::_1(sqlx_macros::expand_query!( +/// record = Article, +/// source = "SELECT * FROM articles WHERE name ILIKE ?", +/// args = [name_filter] +/// )) +/// } else { +/// ConditionalMap::_2(sqlx_macros::expand_query!( +/// record = Article, +/// source = "SELECT * FROM articles", +/// args = [] +/// )) +/// } +/// ``` use std::{any::Any, fmt::Write, rc::Rc}; use proc_macro2::{Span, TokenStream}; @@ -46,19 +111,23 @@ impl Input { ctx.add_segment(segment); } - if ctx.branches() > 1 && !self.arguments.is_empty() { - let err = Error::new( - Span::call_site(), - "branches (`match` and `if`) can only be used with inline arguments", - ); - Err(err) - } else { + // add separately specified arguments to context + if !self.arguments.is_empty() { + if ctx.branches() > 1 { + let err = Error::new( + Span::call_site(), + "branches (`match` and `if`) can only be used with inline arguments", + ); + return Err(err); + } match &mut ctx { Context::Default(ctx) => ctx.args.extend(self.arguments.iter().cloned()), - _ => unreachable!(), + // we know this can only be a default context since there is only one branch + _ => unreachable!() } - Ok(ctx) } + + Ok(ctx) } } @@ -82,7 +151,7 @@ impl Parse for Input { } /// A context describes the current position within a conditional query. -#[derive(Clone)] +#[derive(Clone, Debug)] enum Context { Default(NormalContext), If(IfContext), @@ -117,9 +186,9 @@ impl Context { let branches = self.branches(); let result = { - let mut branch_counter = 1; + let mut branch_counter = 0; let output = self.to_query(branches, &mut branch_counter); - assert_eq!(branch_counter, branches + 1); + assert_eq!(branch_counter, branches); output }; @@ -199,7 +268,7 @@ impl Context { } /// A "normal" linear context without any branches. -#[derive(Clone)] +#[derive(Clone, Debug)] struct NormalContext { query_as: Rc, sql: String, @@ -270,7 +339,7 @@ impl IsContext for NormalContext { } /// Context within an `if .. {..} else ..` clause. -#[derive(Clone)] +#[derive(Clone, Debug)] struct IfContext { condition: Expr, then: Box, @@ -302,7 +371,7 @@ impl IsContext for IfContext { } /// Context within `match .. { .. }` -#[derive(Clone)] +#[derive(Clone, Debug)] struct MatchContext { expr: Expr, arms: Vec, @@ -332,7 +401,7 @@ impl IsContext for MatchContext { } /// Context within the arm (`Pat => ..`) of a `match` -#[derive(Clone)] +#[derive(Clone, Debug)] struct MatchArmContext { pattern: Pat, inner: Box, @@ -409,4 +478,32 @@ mod tests { }, ); } + + #[test] + fn single_if() { + let input = quote!( + Article, + "SELECT * FROM articles" + if let Some(name_filter) = filter { + "WHERE name ILIKE {name_filter}" + } + ); + let result = syn::parse2::(input).unwrap(); + let ctx = result.to_context().unwrap(); + let output = ctx.generate_output(); + } + + #[test] + fn raw_literal() { + let input = quote!( + Article, + r#"SELECT * FROM articles"# + if let Some(name_filter) = filter { + r#"WHERE "name" ILIKE {name_filter}"#, + } + ); + let result = syn::parse2::(input).unwrap(); + let ctx = result.to_context().unwrap(); + let output = ctx.generate_output(); + } } diff --git a/sqlx-macros/src/query/conditional/segment.rs b/sqlx-macros/src/query/conditional/segment.rs index bf385b9b8b..1a1e838a1c 100644 --- a/sqlx-macros/src/query/conditional/segment.rs +++ b/sqlx-macros/src/query/conditional/segment.rs @@ -1,6 +1,6 @@ +use proc_macro2::Span; use std::mem::swap; use std::ptr::replace; -use proc_macro2::Span; use syn::{ braced, parse::{Parse, ParseStream, Peek}, @@ -8,6 +8,7 @@ use syn::{ }; /// A single "piece" of the input. +#[derive(Debug)] pub enum QuerySegment { /// A part of an SQL query, like `"SELECT *"` Sql(SqlSegment), @@ -37,9 +38,10 @@ impl QuerySegment { } } +#[derive(Debug)] pub struct SqlSegment { pub sql: String, - pub args: Vec<(usize, Expr, usize)> + pub args: Vec<(usize, Expr, usize)>, } impl SqlSegment { @@ -73,7 +75,7 @@ fn parse_inline_args(sql: &str) -> syn::Result> { curr_arg = Some((idx, String::new())); } curr_level += 1; - }, + } '}' => { if curr_arg.is_none() { let err = Error::new(Span::call_site(), "unexpected '}' in query string"); @@ -85,26 +87,30 @@ fn parse_inline_args(sql: &str) -> syn::Result> { args.push((arg_start, arg, idx)); } curr_level -= 1; - }, - c => if let Some((_, arg)) = &mut curr_arg { - arg.push(c); + } + c => { + if let Some((_, arg)) = &mut curr_arg { + arg.push(c); + } } } - }; + } if curr_arg.is_some() { let err = Error::new(Span::call_site(), "expected '}', but got end of string"); - return Err(err) + return Err(err); } Ok(args) } +#[derive(Debug)] pub struct MatchSegment { pub expr: Expr, pub arms: Vec, } +#[derive(Debug)] pub struct MatchSegmentArm { pub pat: Pat, pub body: Vec, @@ -159,6 +165,7 @@ impl Parse for MatchSegment { } } +#[derive(Debug)] pub struct IfSegment { pub condition: Expr, pub then: Vec, diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs index 04ff78b1a1..835cb162b4 100644 --- a/sqlx-macros/src/query/mod.rs +++ b/sqlx-macros/src/query/mod.rs @@ -20,10 +20,10 @@ use crate::query::input::RecordType; use either::Either; mod args; +mod conditional; mod data; mod input; mod output; -mod conditional; pub use conditional::query_as; From 7ff06a5442e314e264a93f84665648d6217ff322 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 2 Mar 2022 15:23:42 +0100 Subject: [PATCH 15/30] fix parsing of trailing comma --- sqlx-macros/src/query/conditional/mod.rs | 6 +++--- sqlx-macros/src/query/conditional/segment.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index d6dad5e0cf..531a4d6837 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -123,7 +123,7 @@ impl Input { match &mut ctx { Context::Default(ctx) => ctx.args.extend(self.arguments.iter().cloned()), // we know this can only be a default context since there is only one branch - _ => unreachable!() + _ => unreachable!(), } } @@ -135,7 +135,7 @@ impl Parse for Input { fn parse(input: ParseStream) -> Result { let query_as = input.parse::()?; input.parse::()?; - let segments = QuerySegment::parse_all(input)?; + let segments = QuerySegment::parse_until(input, Token![,])?; let arguments = match input.parse::>()? { None => vec![], Some(..) => Punctuated::::parse_terminated(input)? @@ -499,7 +499,7 @@ mod tests { Article, r#"SELECT * FROM articles"# if let Some(name_filter) = filter { - r#"WHERE "name" ILIKE {name_filter}"#, + r#"WHERE "name" ILIKE {name_filter}"# } ); let result = syn::parse2::(input).unwrap(); diff --git a/sqlx-macros/src/query/conditional/segment.rs b/sqlx-macros/src/query/conditional/segment.rs index 1a1e838a1c..f507b077e9 100644 --- a/sqlx-macros/src/query/conditional/segment.rs +++ b/sqlx-macros/src/query/conditional/segment.rs @@ -20,7 +20,7 @@ pub enum QuerySegment { impl QuerySegment { /// Parse segments up to the first occurrence of the given token, or until the input is empty. - fn parse_until(input: ParseStream, until: T) -> syn::Result> { + pub fn parse_until(input: ParseStream, until: T) -> syn::Result> { let mut segments = vec![]; while !input.is_empty() && !input.peek(until) { segments.push(QuerySegment::parse(input)?); From 5767150814bae2f9f5d6243443b8380c1cf850f6 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 03:46:21 +0100 Subject: [PATCH 16/30] re-export futures_core from sqlx-core --- sqlx-core/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 19dc0f055d..2725508d11 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -105,3 +105,6 @@ pub mod mssql; /// sqlx uses ahash for increased performance, at the cost of reduced DoS resistance. use ahash::AHashMap as HashMap; //type HashMap = std::collections::HashMap; + +#[doc(hidden)] +pub use futures_core; From aade05a77a3e70790470c22f38eaa912e791f958 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 03:47:03 +0100 Subject: [PATCH 17/30] fix return type of conditional map fetch method --- sqlx-macros/src/query/conditional/map.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-macros/src/query/conditional/map.rs b/sqlx-macros/src/query/conditional/map.rs index c871212211..2c08e2cb2c 100644 --- a/sqlx-macros/src/query/conditional/map.rs +++ b/sqlx-macros/src/query/conditional/map.rs @@ -31,7 +31,7 @@ pub fn generate_conditional_map(n: usize) -> TokenStream { #(#map_fns: FnMut(DB::Row) -> sqlx::Result + Send,)* #(#args: 'q + Send + sqlx::IntoArguments<'q, DB>,)* { - pub fn fetch<'e, 'c: 'e, E>(self, executor: E) -> ormx::exports::futures::stream::BoxStream<'e, sqlx::Result> + pub fn fetch<'e, 'c: 'e, E>(self, executor: E) -> sqlx::futures_core::stream::BoxStream<'e, sqlx::Result> where 'q: 'e, E: 'e + sqlx::Executor<'c, Database = DB>, From f280e674f8ae89b9f8a291b995b04d71f9553697 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 03:59:52 +0100 Subject: [PATCH 18/30] add simple test --- Cargo.toml | 5 +++++ tests/postgres/conditional_query.rs | 34 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 tests/postgres/conditional_query.rs diff --git a/Cargo.toml b/Cargo.toml index bb8da5c219..b8dd9fc8cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -258,6 +258,11 @@ name = "postgres-derives" path = "tests/postgres/derives.rs" required-features = ["postgres", "macros"] +[[test]] +name = "postgres-conditonal_query" +path = "tests/postgres/conditional_query.rs" +required-features = ["postgres", "macros"] + # # Microsoft SQL Server (MSSQL) # diff --git a/tests/postgres/conditional_query.rs b/tests/postgres/conditional_query.rs new file mode 100644 index 0000000000..c182563e80 --- /dev/null +++ b/tests/postgres/conditional_query.rs @@ -0,0 +1,34 @@ +use sqlx_core::postgres::Postgres; +use sqlx_test::new; + +#[sqlx_macros::test] +async fn simple() -> anyhow::Result<()> { + let mut conn = new::().await?; + + struct Result { + result: i32 + } + + for value in [true, false] { + let result = sqlx::query_as!( + Result, + "SELECT" + if value { "42" } else { "12" } + ) + .fetch_one(&mut conn) + .await?; + + if value { + assert_eq!(result.result, 42); + } else { + assert_eq!(result.result, 12); + } + } + + Ok(()) +} + +#[sqlx_macros::test] +async fn fail() -> anyhow::Result<()> { + panic!("test this gets executed"); +} From 64b0ff64425cc9fdbf910447a76e657b5bf04e5f Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 04:20:22 +0100 Subject: [PATCH 19/30] re-export futures_core from sqlx --- src/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 49b6844ab8..c8337bcad5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,6 +65,10 @@ pub use sqlx_core::postgres::{self, PgConnection, PgExecutor, PgPool, Postgres}; #[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqliteExecutor, SqlitePool}; +// used by conditional query macro +#[doc(hidden)] +pub use sqlx_core::futures_core; + #[cfg(feature = "macros")] #[doc(hidden)] pub extern crate sqlx_macros; From 1fc3697ca546a1d05fc4c4159fa8ab6723605262 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 04:20:42 +0100 Subject: [PATCH 20/30] add an other test --- tests/postgres/conditional_query.rs | 68 ++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/tests/postgres/conditional_query.rs b/tests/postgres/conditional_query.rs index c182563e80..1a84b89b75 100644 --- a/tests/postgres/conditional_query.rs +++ b/tests/postgres/conditional_query.rs @@ -6,7 +6,7 @@ async fn simple() -> anyhow::Result<()> { let mut conn = new::().await?; struct Result { - result: i32 + result: i32, } for value in [true, false] { @@ -14,6 +14,72 @@ async fn simple() -> anyhow::Result<()> { Result, "SELECT" if value { "42" } else { "12" } + r#"AS "result""# + ) + .fetch_one(&mut conn) + .await?; + + if value { + assert_eq!(result.result, 42); + } else { + assert_eq!(result.result, 12); + } + } + + Ok(()) +} + +#[sqlx_macros::test] +async fn single_if() -> anyhow::Result<()> { + let mut conn = new::().await?; + + #[derive(Clone, Eq, PartialEq. Debug)] + struct Article { + id: i32, + title: String, + author: String, + } + + let expected = vec![ + Article { + id: 1, + title: "Article1".to_owned(), + author: "Peter".to_owned(), + }, + Article { + id: 2, + title: "Article2".to_owned(), + author: "John".to_owned(), + }, + ]; + for reverse_order in [true, false] { + let articles = sqlx::query_as!( + Article, + "SELECT *" + r#"FROM (VALUES (1, "Article1", "Peter"), (2, "Article2", "John"))"# + "ORDER BY name" + if reverse_order { + "REV" + } + ) + .fetch_all(&mut conn) + .await?; + + if reverse_order { + let mut expected = expected.clone(); + expected.reverse(); + assert_eq!(articles, expected); + } else { + assert_eq!(articles, expected); + } + } + + for value in [true, false] { + let result = sqlx::query_as!( + Result, + "SELECT" + if value { "42" } else { "12" } + r#"AS "result""# ) .fetch_one(&mut conn) .await?; From fb072bb18fa265427296fb9e1f39fa6838aa19ae Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 04:33:39 +0100 Subject: [PATCH 21/30] remove unused imports --- sqlx-macros/src/query/conditional/map.rs | 3 +-- sqlx-macros/src/query/conditional/mod.rs | 4 ++-- sqlx-macros/src/query/conditional/segment.rs | 2 -- tests/postgres/conditional_query.rs | 12 ++++++------ 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/sqlx-macros/src/query/conditional/map.rs b/sqlx-macros/src/query/conditional/map.rs index 2c08e2cb2c..7169401014 100644 --- a/sqlx-macros/src/query/conditional/map.rs +++ b/sqlx-macros/src/query/conditional/map.rs @@ -1,6 +1,5 @@ -use proc_macro2::{Span, TokenStream}; +use proc_macro2::TokenStream; use quote::{format_ident, quote}; -use syn::Ident; pub fn generate_conditional_map(n: usize) -> TokenStream { let map_fns = (1..=n).map(|i| format_ident!("F{}", i)).collect::>(); diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index 531a4d6837..95a5ecfd6f 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -68,7 +68,7 @@ /// )) /// } /// ``` -use std::{any::Any, fmt::Write, rc::Rc}; +use std::{fmt::Write, rc::Rc}; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; @@ -76,7 +76,7 @@ use segment::*; use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, - Error, Expr, Ident, Pat, Path, Result, Token, + Error, Expr, Pat, Path, Result, Token, }; mod map; diff --git a/sqlx-macros/src/query/conditional/segment.rs b/sqlx-macros/src/query/conditional/segment.rs index f507b077e9..18f892256d 100644 --- a/sqlx-macros/src/query/conditional/segment.rs +++ b/sqlx-macros/src/query/conditional/segment.rs @@ -1,6 +1,4 @@ use proc_macro2::Span; -use std::mem::swap; -use std::ptr::replace; use syn::{ braced, parse::{Parse, ParseStream, Peek}, diff --git a/tests/postgres/conditional_query.rs b/tests/postgres/conditional_query.rs index 1a84b89b75..81be882e95 100644 --- a/tests/postgres/conditional_query.rs +++ b/tests/postgres/conditional_query.rs @@ -6,7 +6,7 @@ async fn simple() -> anyhow::Result<()> { let mut conn = new::().await?; struct Result { - result: i32, + result: Option, } for value in [true, false] { @@ -20,9 +20,9 @@ async fn simple() -> anyhow::Result<()> { .await?; if value { - assert_eq!(result.result, 42); + assert_eq!(result.result, Some(42)); } else { - assert_eq!(result.result, 12); + assert_eq!(result.result, Some(12)); } } @@ -33,7 +33,7 @@ async fn simple() -> anyhow::Result<()> { async fn single_if() -> anyhow::Result<()> { let mut conn = new::().await?; - #[derive(Clone, Eq, PartialEq. Debug)] + #[derive(Clone, Eq, PartialEq, Debug)] struct Article { id: i32, title: String, @@ -56,8 +56,8 @@ async fn single_if() -> anyhow::Result<()> { let articles = sqlx::query_as!( Article, "SELECT *" - r#"FROM (VALUES (1, "Article1", "Peter"), (2, "Article2", "John"))"# - "ORDER BY name" + r#"FROM (VALUES (1, "Article1", "Peter"), (2, "Article2", "John")) articles(id, title, author)"# + "ORDER BY title" if reverse_order { "REV" } From ac4e02dc9ec8403a41ca019353e3582c9d008498 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 04:42:11 +0100 Subject: [PATCH 22/30] fix tests --- tests/postgres/conditional_query.rs | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/tests/postgres/conditional_query.rs b/tests/postgres/conditional_query.rs index 81be882e95..aa06ec42b0 100644 --- a/tests/postgres/conditional_query.rs +++ b/tests/postgres/conditional_query.rs @@ -56,10 +56,14 @@ async fn single_if() -> anyhow::Result<()> { let articles = sqlx::query_as!( Article, "SELECT *" - r#"FROM (VALUES (1, "Article1", "Peter"), (2, "Article2", "John")) articles(id, title, author)"# + "FROM (" + "VALUES (1, 'Article1', 'Peter'), (2, 'Article2', 'John')" + ") articles(id, title, author)" "ORDER BY title" if reverse_order { - "REV" + "DESC" + } else { + "ASC" } ) .fetch_all(&mut conn) @@ -74,23 +78,6 @@ async fn single_if() -> anyhow::Result<()> { } } - for value in [true, false] { - let result = sqlx::query_as!( - Result, - "SELECT" - if value { "42" } else { "12" } - r#"AS "result""# - ) - .fetch_one(&mut conn) - .await?; - - if value { - assert_eq!(result.result, 42); - } else { - assert_eq!(result.result, 12); - } - } - Ok(()) } From 024c9be3c7d51313ad42859c7ddeb70ac082804e Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 04:49:54 +0100 Subject: [PATCH 23/30] tests: fix nullability issue --- tests/postgres/conditional_query.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/postgres/conditional_query.rs b/tests/postgres/conditional_query.rs index aa06ec42b0..9a950191ca 100644 --- a/tests/postgres/conditional_query.rs +++ b/tests/postgres/conditional_query.rs @@ -55,7 +55,8 @@ async fn single_if() -> anyhow::Result<()> { for reverse_order in [true, false] { let articles = sqlx::query_as!( Article, - "SELECT *" + "SELECT" + r#"id AS "id!", title AS "title!", author AS "author!""# "FROM (" "VALUES (1, 'Article1', 'Peter'), (2, 'Article2', 'John')" ") articles(id, title, author)" From d45d86b8e3f3a8cb449c8fd96435be69eafe9395 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 13:52:36 +0100 Subject: [PATCH 24/30] tests: add dynamic_filtering test --- tests/postgres/conditional_query.rs | 63 +++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/tests/postgres/conditional_query.rs b/tests/postgres/conditional_query.rs index 9a950191ca..2938e880b9 100644 --- a/tests/postgres/conditional_query.rs +++ b/tests/postgres/conditional_query.rs @@ -1,4 +1,4 @@ -use sqlx_core::postgres::Postgres; +use sqlx_core::postgres::{PgConnection, Postgres}; use sqlx_test::new; #[sqlx_macros::test] @@ -30,7 +30,7 @@ async fn simple() -> anyhow::Result<()> { } #[sqlx_macros::test] -async fn single_if() -> anyhow::Result<()> { +async fn dynamic_ordering() -> anyhow::Result<()> { let mut conn = new::().await?; #[derive(Clone, Eq, PartialEq, Debug)] @@ -83,6 +83,61 @@ async fn single_if() -> anyhow::Result<()> { } #[sqlx_macros::test] -async fn fail() -> anyhow::Result<()> { - panic!("test this gets executed"); +async fn dynamic_filtering() -> anyhow::Result<()> { + let mut conn = new::().await?; + + #[derive(Clone, Eq, PartialEq, Debug)] + struct Article { + id: i32, + title: String, + author: String, + } + + enum Filter { + Id(i32), + Title(String), + Author(String), + } + + async fn query_articles( + con: &mut PgConnection, + filter: Option, + ) -> anyhow::Result> { + let articles = sqlx::query_as!( + Article, + "SELECT" + r#"id AS "id!", title AS "title!", author AS "author!""# + "FROM (" + "VALUES (1, 'Article1', 'Peter'), (2, 'Article2', 'John'), (3, 'Article3', 'James')" + ") articles(id, title, author)" + if let Some(filter) = filter { + "WHERE" + match filter { + Filter::Id(id) => "id = {id}", + Filter::Title(title) => "title ILIKE {title}", + Filter::Author(author) => "author ILIKE {author}" + } + } + ) + .fetch_all(&mut conn) + .await?; + Ok(articles) + } + + let result = query_articles(&mut conn, None).await?; + assert_eq!(result.len(), 3); + + let result = query_articles(&mut conn, Some(Filter::Id(1))).await?; + assert_eq!(result.len(), 1); + assert_eq!(result[0].id, 1); + + let result = query_articles(&mut conn, Some(Filter::Title("article2".to_owned()))).await?; + assert_eq!(result.len(), 1); + assert_eq!(result[0].id, 2); + + let result = query_articles(&mut conn, Some(Filter::Author("james".to_owned()))).await?; + assert_eq!(result.len(), 1); + assert_eq!(result[0].id, 3); + + Ok(()) } From 046c2dd3f5be98d0a0cc57b6fa5c93a11b81e35d Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 14:19:58 +0100 Subject: [PATCH 25/30] docs: add user-facing documentation --- sqlx-macros/src/lib.rs | 53 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 2e70eb279d..957e5c6649 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -62,8 +62,9 @@ pub fn expand_query(input: TokenStream) -> TokenStream { /// * If a column may be `NULL`, the corresponding field's type must be wrapped in `Option<_>`. /// * Neither the query nor the struct may have unused fields. /// -/// The only modification to the `query!()` syntax is that the struct name is given before the SQL -/// string: +/// In contrast to the syntax of `query!()`, the struct name is given before the SQL +/// string. Arguments may be passed, like in `query!()`, within a comma-seperated list after the +/// SQL string, or inline within the query string using `"{}"`: /// ```rust,ignore /// # use sqlx::Connect; /// # #[cfg(all(feature = "mysql", feature = "_rt-async-std"))] @@ -80,10 +81,10 @@ pub fn expand_query(input: TokenStream) -> TokenStream { /// } /// /// // let mut conn = ; +/// let id = 1; /// let account = sqlx::query_as!( /// Account, -/// "select * from (select (1) as id, 'Herp Derpinson' as name) accounts where id = ?", -/// 1i32 +/// "select * from (select (1) as id, 'Herp Derpinson' as name) accounts where id = {id}" /// ) /// .fetch_one(&mut conn) /// .await?; @@ -140,6 +141,50 @@ pub fn expand_query(input: TokenStream) -> TokenStream { /// assert_eq!(record.id, MyInt4(1)); /// ``` /// +/// ### Conditional Queries +/// This macro allows you to dynamically construct queries at runtime, while still ensuring that +/// they are checked at compile-time. +/// +/// Let's consider an example first. Let's say you want to query all products from your database, +/// while the user may decide if he wants them ordered in ascending or descending order. +/// This could be achieved by writing both queries out by hand: +/// ```rust,ignore +/// let products = if order_ascending { +/// sqlx::query_as!( +/// Product, +/// "SELECT * FROM products ORDER BY name ASC" +/// ) +/// .fetch_all(&mut con) +/// .await? +/// } else { +/// sqlx::query_as!( +/// Product, +/// "SELECT * FROM products ORDER BY name DESC" +/// ) +/// .fetch_all(&mut con) +/// .await? +/// }; +/// ``` +/// To avoid repetition in these cases, you may use `if`, `if let` and `match` directly within the macro +/// invocation: +/// ```rust,ignore +/// let products = sqlx::query_as!( +/// Product, +/// "SELECT * FROM products ORDER BY NAME" +/// if order_ascending { "ASC" } else { "DESC" } +/// ) +/// .fetch_all(&mut con) +/// .await?; +/// ``` +/// The macro will expand to something similar like in the verbose example above, ensuring that +/// every possible query which may result from the macro invocation is checked at compile-time. +/// +/// When writing *conditional* queries, parameters may only be given inline. +/// +/// It is recommended to avoid using a lot of `if` and `match` clauses within a single `query_as!` +/// invocation. Do not use much more than 6 within a single query to avoid drastically increased +/// compile times. +/// /// ### Troubleshooting: "error: mismatched types" /// If you get a "mismatched types" error from an invocation of this macro and the error /// isn't pointing specifically at a parameter. From 11be4f334f7d5dd2d02e84ecb6558c9fe72c6b99 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 14:20:24 +0100 Subject: [PATCH 26/30] tests: test inline arguments in query_as! --- tests/postgres/macros.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/postgres/macros.rs b/tests/postgres/macros.rs index 51d1f89bb8..c7b6ac2105 100644 --- a/tests/postgres/macros.rs +++ b/tests/postgres/macros.rs @@ -137,6 +137,25 @@ async fn test_query_as() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn test_query_as_inline_args() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let name: Option<&str> = None; + let account = sqlx::query_as!( + Account, + r#"SELECT id "id!", name from (VALUES (1, {name})) accounts(id, name)"#, + ) + .fetch_one(&mut conn) + .await?; + + assert_eq!(None, account.name); + + println!("{:?}", account); + + Ok(()) +} + #[derive(Debug)] struct RawAccount { r#type: i32, From bbe8d0797b745e961b78869b35da8311e76a9e91 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 14:36:13 +0100 Subject: [PATCH 27/30] add debug log --- sqlx-macros/src/query/conditional/mod.rs | 1 + tests/postgres/conditional_query.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index 95a5ecfd6f..b157ad4c93 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -188,6 +188,7 @@ impl Context { let result = { let mut branch_counter = 0; let output = self.to_query(branches, &mut branch_counter); + println!("{}", output); assert_eq!(branch_counter, branches); output }; diff --git a/tests/postgres/conditional_query.rs b/tests/postgres/conditional_query.rs index 2938e880b9..ce4c415935 100644 --- a/tests/postgres/conditional_query.rs +++ b/tests/postgres/conditional_query.rs @@ -119,7 +119,7 @@ async fn dynamic_filtering() -> anyhow::Result<()> { } } ) - .fetch_all(&mut conn) + .fetch_all(con) .await?; Ok(articles) } From 5a29c81d665583aa1ff5d1422a3bfd50b4de037e Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 14:48:38 +0100 Subject: [PATCH 28/30] fix off-by-one in parameter generation --- sqlx-macros/src/query/conditional/mod.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index b157ad4c93..72481fa674 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -73,6 +73,7 @@ use std::{fmt::Write, rc::Rc}; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use segment::*; +use sha2::digest::generic_array::typenum::Exp; use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, @@ -277,9 +278,10 @@ struct NormalContext { } impl NormalContext { - fn add_parameter(&mut self) { + fn add_parameter(&mut self, expr: Expr) { + self.args.push(expr.clone()); if cfg!(feature = "postgres") { - write!(&mut self.sql, "${}", self.args.len() + 1).unwrap(); + write!(&mut self.sql, "${}", self.args.len()).unwrap(); } else { self.sql.push('?'); } @@ -328,8 +330,7 @@ impl IsContext for NormalContext { self.sql.push(c); } if idx == *end { - self.args.push(expr.clone()); - self.add_parameter(); + self.add_parameter(expr.clone()); arg = args.next(); } } else { From 85065a1d0bf7169c978f46c4167325a1a6111640 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 14:58:33 +0100 Subject: [PATCH 29/30] fix unused import --- sqlx-macros/src/query/conditional/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index 72481fa674..6c4701ef46 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -73,7 +73,6 @@ use std::{fmt::Write, rc::Rc}; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use segment::*; -use sha2::digest::generic_array::typenum::Exp; use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, From bae5dd638aad0a64959ee329bebd6655da76d680 Mon Sep 17 00:00:00 2001 From: Moritz Bischof Date: Wed, 9 Mar 2022 15:13:27 +0100 Subject: [PATCH 30/30] re-run CI --- sqlx-macros/src/query/conditional/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs index 6c4701ef46..7adb740c05 100644 --- a/sqlx-macros/src/query/conditional/mod.rs +++ b/sqlx-macros/src/query/conditional/mod.rs @@ -188,7 +188,6 @@ impl Context { let result = { let mut branch_counter = 0; let output = self.to_query(branches, &mut branch_counter); - println!("{}", output); assert_eq!(branch_counter, branches); output };