diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml
index f325cb7c3f..3db2c732ac 100644
--- a/.github/workflows/sqlx.yml
+++ b/.github/workflows/sqlx.yml
@@ -3,8 +3,6 @@ name: SQLx
 on:
   pull_request:
   push:
-    branches:
-      - master
 
 jobs:
   format:
diff --git a/Cargo.toml b/Cargo.toml
index bb8da5c219..b8dd9fc8cf 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -258,6 +258,11 @@ name = "postgres-derives"
 path = "tests/postgres/derives.rs"
 required-features = ["postgres", "macros"]
 
+[[test]]
+name = "postgres-conditional_query"
+path = "tests/postgres/conditional_query.rs"
+required-features = ["postgres", "macros"]
+
 #
 # Microsoft SQL Server (MSSQL)
 #
diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs
index 19dc0f055d..2725508d11 100644
--- a/sqlx-core/src/lib.rs
+++ b/sqlx-core/src/lib.rs
@@ -105,3 +105,6 @@ pub mod mssql;
 /// sqlx uses ahash for increased performance, at the cost of reduced DoS resistance.
 use ahash::AHashMap as HashMap;
 //type HashMap = std::collections::HashMap;
+
+#[doc(hidden)]
+pub use futures_core;
diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml
index b6c6c99b08..54d2088aad 100644
--- a/sqlx-macros/Cargo.toml
+++ b/sqlx-macros/Cargo.toml
@@ -89,6 +89,6 @@ sqlx-rt = { version = "0.5.11", default-features = false, path = "../sqlx-rt" }
 serde = { version = "1.0.132", features = ["derive"], optional = true }
 serde_json = { version = "1.0.73", optional = true }
 sha2 = { version = "0.9.8", optional = true }
-syn = { version = "1.0.84", default-features = false, features = ["full"] }
+syn = { version = "1.0.84", default-features = false, features = ["full", "extra-traits"] }
 quote = { version = "1.0.14", default-features = false }
 url = { version = "2.2.2", default-features = false }
diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs
index c1f173e655..957e5c6649 100644
--- a/sqlx-macros/src/lib.rs
+++ b/sqlx-macros/src/lib.rs
@@ -41,6 +41,210 @@ pub fn expand_query(input: TokenStream) -> TokenStream {
     }
 }
 
+/// A variant of [query!] which takes a path to an explicitly defined struct as the output type.
+///
+/// This lets you return the struct from a function or add your own trait implementations.
+///
+/// **No trait implementations are required**; the macro maps rows using a struct literal
+/// where the names of columns in the query are expected to be the same as the fields of the struct
+/// (but the order does not need to be the same). The types of the columns are based on the
+/// query and not the corresponding fields of the struct, so this is type-safe as well.
+///
+/// This enforces a few things:
+/// * The query must output at least one column.
+/// * The column names of the query must match the field names of the struct.
+/// * The field types must be the Rust equivalent of their SQL counterparts; see the corresponding
+///   module for your database for mappings:
+///     * Postgres: [crate::postgres::types]
+///     * MySQL: [crate::mysql::types]
+///     * SQLite: [crate::sqlite::types]
+///     * MSSQL: [crate::mssql::types]
+/// * If a column may be `NULL`, the corresponding field's type must be wrapped in `Option<_>`.
+/// * Neither the query nor the struct may have unused fields.
+///
+/// In contrast to the syntax of `query!()`, the struct name is given before the SQL string.
+/// Arguments may be passed, as with `query!()`, within a comma-separated list after the
+/// SQL string, or inline within the query string using `"{}"`:
+/// ```rust,ignore
+/// # use sqlx::Connect;
+/// # #[cfg(all(feature = "mysql", feature = "_rt-async-std"))]
+/// # #[async_std::main]
+/// # async fn main() -> sqlx::Result<()>{
+/// # let db_url = dotenv::var("DATABASE_URL").expect("DATABASE_URL must be set");
+/// #
+/// # if !(db_url.starts_with("mysql") || db_url.starts_with("mariadb")) { return Ok(()) }
+/// # let mut conn = sqlx::MySqlConnection::connect(db_url).await?;
+/// #[derive(Debug)]
+/// struct Account {
+///     id: i32,
+///     name: String
+/// }
+///
+/// // let mut conn = <impl sqlx::Executor>;
+/// let id = 1;
+/// let account = sqlx::query_as!(
+///     Account,
+///     "select * from (select (1) as id, 'Herp Derpinson' as name) accounts where id = {id}"
+/// )
+/// .fetch_one(&mut conn)
+/// .await?;
+///
+/// println!("{:?}", account);
+/// println!("{}: {}", account.id, account.name);
+///
+/// # Ok(())
+/// # }
+/// #
+/// # #[cfg(any(not(feature = "mysql"), not(feature = "_rt-async-std")))]
+/// # fn main() {}
+/// ```
+///
+/// **The method you want to call depends on how many rows you're expecting.**
+///
+/// | Number of Rows | Method to Call*             | Returns (`T` being the given struct)   | Notes |
+/// |----------------| ----------------------------|----------------------------------------|-------|
+/// | Zero or One    | `.fetch_optional(...).await`| `sqlx::Result<Option<T>>`              | Extra rows are ignored. |
+/// | Exactly One    | `.fetch_one(...).await`     | `sqlx::Result<T>`                      | Errors if no rows were returned. Extra rows are ignored. Aggregate queries, use this. |
+/// | At Least One   | `.fetch(...)`               | `impl Stream<Item = sqlx::Result<T>>`  | Call `.try_next().await` to get each row result. |
+/// | Multiple       | `.fetch_all(...)`           | `sqlx::Result<Vec<T>>`                 | |
+///
+/// \* All methods accept one of `&mut {connection type}`, `&mut Transaction` or `&Pool`.
+/// (`.execute()` is omitted as this macro requires at least one column to be returned.)
+///
+/// ### Column Type Override: Infer from Struct Field
+/// In addition to the column type overrides supported by [query!], `query_as!()` supports an
+/// additional override option:
+///
+/// If you select a column `foo as "foo: _"` (Postgres/SQLite) or `` foo as `foo: _` `` (MySQL)
+/// it causes that column to be inferred based on the type of the corresponding field in the given
+/// record struct. Runtime type-checking is still done so an error will be emitted if the types
+/// are not compatible.
+///
+/// This allows you to override the inferred type of a column to instead use a custom-defined type:
+///
+/// ```rust,ignore
+/// #[derive(sqlx::Type)]
+/// #[sqlx(transparent)]
+/// struct MyInt4(i32);
+///
+/// struct Record {
+///     id: MyInt4,
+/// }
+///
+/// let my_int = MyInt4(1);
+///
+/// // Postgres/SQLite
+/// sqlx::query_as!(Record, r#"select 1 as "id: _""#) // MySQL: use "select 1 as `id: _`" instead
+///     .fetch_one(&mut conn)
+///     .await?;
+///
+/// assert_eq!(record.id, MyInt4(1));
+/// ```
+///
+/// ### Conditional Queries
+/// This macro allows you to dynamically construct queries at runtime, while still ensuring that
+/// they are checked at compile-time.
+///
+/// Consider an example: you want to query all products from your database, and the user may
+/// decide whether they should be ordered in ascending or descending order.
+/// This could be achieved by writing both queries out by hand:
+/// ```rust,ignore
+/// let products = if order_ascending {
+///     sqlx::query_as!(
+///         Product,
+///         "SELECT * FROM products ORDER BY name ASC"
+///     )
+///     .fetch_all(&mut con)
+///     .await?
+/// } else {
+///     sqlx::query_as!(
+///         Product,
+///         "SELECT * FROM products ORDER BY name DESC"
+///     )
+///     .fetch_all(&mut con)
+///     .await?
+/// };
+/// ```
+/// To avoid repetition in these cases, you may use `if`, `if let` and `match` directly within the macro
+/// invocation:
+/// ```rust,ignore
+/// let products = sqlx::query_as!(
+///     Product,
+///     "SELECT * FROM products ORDER BY NAME"
+///     if order_ascending { "ASC" } else { "DESC" }
+/// )
+/// .fetch_all(&mut con)
+/// .await?;
+/// ```
+/// The macro expands to something similar to the verbose example above, ensuring that
+/// every possible query which may result from the macro invocation is checked at compile-time.
+///
+/// When writing *conditional* queries, parameters may only be given inline.
+///
+/// Avoid using many `if` and `match` clauses within a single `query_as!` invocation: every
+/// combination of branches produces a query that has to be checked at compile time, so using
+/// much more than six within a single query can drastically increase compile times.
+///
+/// ### Troubleshooting: "error: mismatched types"
+/// If you get a "mismatched types" error from an invocation of this macro and the error
+/// isn't pointing specifically at a parameter, the mismatch is between a field of the given
+/// struct and the type the macro inferred for the corresponding column.
+///
+/// For example, code like this (using a Postgres database):
+///
+/// ```rust,ignore
+/// struct Account {
+///     id: i32,
+///     name: Option<String>,
+/// }
+///
+/// let account = sqlx::query_as!(
+///     Account,
+///     r#"SELECT id, name from (VALUES (1, 'Herp Derpinson')) accounts(id, name)"#,
+/// )
+/// .fetch_one(&mut conn)
+/// .await?;
+/// ```
+///
+/// Might produce an error like this:
+/// ```text,ignore
+/// error[E0308]: mismatched types
+///   --> tests/postgres/macros.rs:126:19
+///    |
+/// 126 |       let account = sqlx::query_as!(
+///    |  ___________________^
+/// 127 | |         Account,
+/// 128 | |         r#"SELECT id, name from (VALUES (1, 'Herp Derpinson')) accounts(id, name)"#,
+/// 129 | |     )
+///    | |_____^ expected `i32`, found enum `std::option::Option<i32>`
+///    |
+///    = note: expected type `i32`
+///              found enum `std::option::Option<i32>`
+/// ```
+///
+/// This means that you need to check that any field of the "expected" type (here, `i32`) matches
+/// the Rust type mapping for its corresponding SQL column (see the `types` module of your database,
+/// listed above, for mappings). The "found" type is the SQL->Rust mapping that the macro chose.
+///
+/// In the above example, the returned column is inferred to be nullable because it's being
+/// returned from a `VALUES` statement in Postgres, so the macro inferred the field to be nullable
+/// and so used `Option<i32>` instead of `i32`. **In this specific case** we could use
+/// `select id as "id!"` to override the inferred nullability because we know in practice
+/// that column will never be `NULL` and it will fix the error.
+///
+/// Nullability inference and type overrides are discussed in detail in the docs for [query!].
+///
+/// It unfortunately doesn't appear to be possible right now to make the error specifically mention
+/// the field; this probably requires the `const-panic` feature (still unstable as of Rust 1.45).
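+///
+/// ### Conditional Queries: `match` Example
+/// As noted above, `match` (and `if let`) may also be used inside the macro invocation. The
+/// following sketch is purely illustrative; the `Product` struct, the `Filter` enum and the
+/// `filter` variable stand in for your own code:
+/// ```rust,ignore
+/// enum Filter {
+///     Name(String),
+///     MaxPrice(i64),
+/// }
+///
+/// let products = sqlx::query_as!(
+///     Product,
+///     "SELECT * FROM products"
+///     if let Some(filter) = filter {
+///         "WHERE"
+///         match filter {
+///             Filter::Name(name) => "name ILIKE {name}",
+///             Filter::MaxPrice(price) => "price <= {price}",
+///         }
+///     }
+/// )
+/// .fetch_all(&mut conn)
+/// .await?;
+/// ```
+/// Every combination of branches (here: no filter, filter by name, filter by maximum price) is
+/// expanded and checked against the database at compile time.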
+#[cfg_attr(docsrs, doc(cfg(feature = "macros")))]
+#[proc_macro]
+pub fn query_as(input: TokenStream) -> TokenStream {
+    match query::query_as(input.into()) {
+        Ok(output) => output,
+        Err(err) => err.to_compile_error(),
+    }
+    .into()
+}
+
 #[proc_macro_derive(Encode, attributes(sqlx))]
 pub fn derive_encode(tokenstream: TokenStream) -> TokenStream {
     let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
diff --git a/sqlx-macros/src/query/conditional/map.rs b/sqlx-macros/src/query/conditional/map.rs
new file mode 100644
index 0000000000..7169401014
--- /dev/null
+++ b/sqlx-macros/src/query/conditional/map.rs
@@ -0,0 +1,80 @@
+use proc_macro2::TokenStream;
+use quote::{format_ident, quote};
+
+/// Generate the `ConditionalMap` enum with `n` variants, one per possible query, which forwards
+/// the `fetch*` methods to whichever variant is active at runtime.
+pub fn generate_conditional_map(n: usize) -> TokenStream {
+    let map_fns = (1..=n).map(|i| format_ident!("F{}", i)).collect::<Vec<_>>();
+    let args = (1..=n).map(|i| format_ident!("A{}", i)).collect::<Vec<_>>();
+    let variants = (1..=n).map(|i| format_ident!("_{}", i)).collect::<Vec<_>>();
+    let variant_declarations = (0..n).map(|i| {
+        let variant = &variants[i];
+        let map_fn = &map_fns[i];
+        let args = &args[i];
+        quote!(#variant(sqlx::query::Map<'q, DB, #map_fn, #args>))
+    });
+
+    quote! {
+        #[doc(hidden)]
+        pub enum ConditionalMap<'q, DB, O, #(#map_fns,)* #(#args,)*>
+        where
+            DB: sqlx::Database,
+            O: Send + Unpin,
+            #(#map_fns: FnMut(DB::Row) -> sqlx::Result<O> + Send,)*
+            #(#args: 'q + Send + sqlx::IntoArguments<'q, DB>,)*
+        {
+            #(#variant_declarations),*
+        }
+        impl<'q, DB, O, #(#map_fns,)* #(#args,)*> ConditionalMap<'q, DB, O, #(#map_fns,)* #(#args,)*>
+        where
+            DB: sqlx::Database,
+            O: Send + Unpin,
+            #(#map_fns: FnMut(DB::Row) -> sqlx::Result<O> + Send,)*
+            #(#args: 'q + Send + sqlx::IntoArguments<'q, DB>,)*
+        {
+            pub fn fetch<'e, 'c: 'e, E>(self, executor: E) -> sqlx::futures_core::stream::BoxStream<'e, sqlx::Result<O>>
+            where
+                'q: 'e,
+                E: 'e + sqlx::Executor<'c, Database = DB>,
+                DB: 'e,
+                O: 'e,
+                #(#map_fns: 'e,)*
+            {
+                match self { #(
+                    Self::#variants(x) => x.fetch(executor)
+                ),* }
+            }
+            pub async fn fetch_all<'e, 'c: 'e, E>(self, executor: E) -> sqlx::Result<Vec<O>>
+            where
+                'q: 'e,
+                DB: 'e,
+                E: 'e + sqlx::Executor<'c, Database = DB>,
+                O: 'e
+            {
+                match self { #(
+                    Self::#variants(x) => x.fetch_all(executor).await
+                ),* }
+            }
+            pub async fn fetch_one<'e, 'c: 'e, E>(self, executor: E) -> sqlx::Result<O>
+            where
+                'q: 'e,
+                E: 'e + sqlx::Executor<'c, Database = DB>,
+                DB: 'e,
+                O: 'e,
+            {
+                match self { #(
+                    Self::#variants(x) => x.fetch_one(executor).await
+                ),* }
+            }
+            pub async fn fetch_optional<'e, 'c: 'e, E>(self, executor: E) -> sqlx::Result<Option<O>>
+            where
+                'q: 'e,
+                E: 'e + sqlx::Executor<'c, Database = DB>,
+                DB: 'e,
+                O: 'e,
+            {
+                match self { #(
+                    Self::#variants(x) => x.fetch_optional(executor).await
+                ),* }
+            }
+        }
+    }
+}
diff --git a/sqlx-macros/src/query/conditional/mod.rs b/sqlx-macros/src/query/conditional/mod.rs
new file mode 100644
index 0000000000..7adb740c05
--- /dev/null
+++ b/sqlx-macros/src/query/conditional/mod.rs
@@ -0,0 +1,509 @@
+/// This module introduces support for building dynamic queries while still having them checked at
+/// compile time.
+/// This is achieved by computing every possible query within a procedural macro.
+/// It is only at runtime that the appropriate query is chosen and executed.
+///
+/// ## Return type
+/// Since a single invocation of `query_as!` executes one of many possible queries, its return type
+/// would differ between different invocations.
+/// This, of course, would break as soon as one tried to
+/// do anything with the return value, for example `.await`ing it.
+/// Therefore, this module introduces a workaround for conditional queries. The behaviour of normal
+/// queries is not affected by this.
+///
+/// For each *conditional* invocation of `query_as!`, an enum will be generated, and the invocation
+/// expands to an instance of this enum. This enum contains a variant for each possible query;
+/// see [`map::generate_conditional_map`].
+///
+/// ## Arguments
+/// For conditional queries, arguments must be specified *inline* (for example `".. WHERE name ILIKE {filter}"`).
+/// For normal queries, arguments can still be passed as a comma-separated list.
+///
+/// ## Example
+/// To outline how this all works, let's consider the following example.
+/// ```rust,ignore
+/// sqlx::query_as!(
+///     Article,
+///     "SELECT * FROM articles"
+///     if let Some(name_filter) = filter {
+///         "WHERE name ILIKE {name_filter}"
+///     }
+/// )
+/// ```
+///
+/// This input will first be parsed into a list of `QuerySegment`s.
+/// For the example above, this would result in something like this:
+/// ```rust,ignore
+/// [
+///     SqlSegment { sql: "SELECT * FROM articles", args: [] },
+///     IfSegment {
+///         condition: "let Some(name_filter) = filter",
+///         then: [
+///             SqlSegment { sql: "WHERE name ILIKE {name_filter}" }
+///         ]
+///     }
+/// ]
+/// ```
+///
+/// These segments are then transformed into a tree structure. In essence, this would result in:
+/// ```rust,ignore
+/// IfContext {
+///     condition: "let Some(name_filter) = filter",
+///     then: NormalContext { sql: "SELECT * FROM articles WHERE name ILIKE ?", args: ["name_filter"] },
+///     or_else: NormalContext { sql: "SELECT * FROM articles", args: [] },
+/// }
+/// ```
+///
+/// Finally, the resulting code is generated:
+/// ```rust,ignore
+/// enum ConditionalMap { .. }
+/// if let Some(name_filter) = filter {
+///     ConditionalMap::_1(sqlx_macros::expand_query!(
+///         record = Article,
+///         source = "SELECT * FROM articles WHERE name ILIKE ?",
+///         args = [name_filter]
+///     ))
+/// } else {
+///     ConditionalMap::_2(sqlx_macros::expand_query!(
+///         record = Article,
+///         source = "SELECT * FROM articles",
+///         args = []
+///     ))
+/// }
+/// ```
+use std::{fmt::Write, rc::Rc};
+
+use proc_macro2::{Span, TokenStream};
+use quote::{format_ident, quote};
+use segment::*;
+use syn::{
+    parse::{Parse, ParseStream},
+    punctuated::Punctuated,
+    Error, Expr, Pat, Path, Result, Token,
+};
+
+mod map;
+mod segment;
+
+/// Entry point of the `query_as!` macro
+pub fn query_as(input: TokenStream) -> Result<TokenStream> {
+    let input = syn::parse2::<Input>(input)?;
+    let ctx = input.to_context()?;
+    let out = ctx.generate_output();
+    Ok(out)
+}
+
+/// Input to the `query_as!` macro.
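+///
+/// For illustration, a sketch of how an invocation maps onto the fields (the `Article` struct and
+/// the `id` variable are made-up names for this example):
+/// ```rust,ignore
+/// // query_as  = Article
+/// // segments  = [SqlSegment { sql: "SELECT * FROM articles WHERE id = {id}", .. }]
+/// // arguments = []
+/// query_as!(Article, "SELECT * FROM articles WHERE id = {id}");
+///
+/// // query_as  = Article
+/// // segments  = [SqlSegment { sql: "SELECT * FROM articles WHERE id = $1", .. }]
+/// // arguments = [id]
+/// query_as!(Article, "SELECT * FROM articles WHERE id = $1", id);
+/// ```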
+struct Input {
+    query_as: Path,
+    segments: Vec<QuerySegment>,
+    // separately specified arguments
+    arguments: Vec<Expr>,
+}
+
+impl Input {
+    /// Convert the input into a context
+    fn to_context(&self) -> Result<Context> {
+        let mut ctx = Context::Default(NormalContext {
+            query_as: Rc::new(self.query_as.clone()),
+            sql: String::new(),
+            args: vec![],
+        });
+
+        for segment in &self.segments {
+            ctx.add_segment(segment);
+        }
+
+        // add separately specified arguments to context
+        if !self.arguments.is_empty() {
+            if ctx.branches() > 1 {
+                let err = Error::new(
+                    Span::call_site(),
+                    "branches (`match` and `if`) can only be used with inline arguments",
+                );
+                return Err(err);
+            }
+            match &mut ctx {
+                Context::Default(ctx) => ctx.args.extend(self.arguments.iter().cloned()),
+                // we know this can only be a default context since there is only one branch
+                _ => unreachable!(),
+            }
+        }
+
+        Ok(ctx)
+    }
+}
+
+impl Parse for Input {
+    fn parse(input: ParseStream) -> Result<Self> {
+        let query_as = input.parse::<Path>()?;
+        input.parse::<Token![,]>()?;
+        let segments = QuerySegment::parse_until(input, Token![,])?;
+        let arguments = match input.parse::<Option<Token![,]>>()? {
+            None => vec![],
+            Some(..) => Punctuated::<Expr, Token![,]>::parse_terminated(input)?
+                .into_iter()
+                .collect(),
+        };
+        Ok(Self {
+            query_as,
+            segments,
+            arguments,
+        })
+    }
+}
+
+/// A context describes the current position within a conditional query.
+#[derive(Clone, Debug)]
+enum Context {
+    Default(NormalContext),
+    If(IfContext),
+    Match(MatchContext),
+}
+
+trait IsContext {
+    /// Return the number of branches the current context, including its children, contains.
+    fn branches(&self) -> usize;
+    /// Generate a call to a sqlx query macro for this context.
+    fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream;
+    /// Add a piece of an SQL query to this context.
+    fn add_sql(&mut self, sql: &SqlSegment);
+}
+
+impl IsContext for Context {
+    fn branches(&self) -> usize {
+        self.as_dyn().branches()
+    }
+
+    fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream {
+        self.as_dyn().to_query(branches, branch_counter)
+    }
+
+    fn add_sql(&mut self, sql: &SqlSegment) {
+        self.as_dyn_mut().add_sql(sql);
+    }
+}
+
+impl Context {
+    fn generate_output(&self) -> TokenStream {
+        let branches = self.branches();
+
+        let result = {
+            let mut branch_counter = 0;
+            let output = self.to_query(branches, &mut branch_counter);
+            assert_eq!(branch_counter, branches);
+            output
+        };
+
+        match branches {
+            1 => quote!( #result ),
+            _ => {
+                let map = map::generate_conditional_map(branches);
+                quote!( { #map #result } )
+            }
+        }
+    }
+
+    fn as_dyn(&self) -> &dyn IsContext {
+        match self {
+            Context::Default(c) => c as _,
+            Context::If(c) => c as _,
+            Context::Match(c) => c as _,
+        }
+    }
+
+    fn as_dyn_mut(&mut self) -> &mut dyn IsContext {
+        match self {
+            Context::Default(c) => c as _,
+            Context::If(c) => c as _,
+            Context::Match(c) => c as _,
+        }
+    }
+
+    fn add_segment(&mut self, s: &QuerySegment) {
+        match s {
+            QuerySegment::Sql(s) => self.add_sql(s),
+            QuerySegment::If(s) => self.add_if(s),
+            QuerySegment::Match(s) => self.add_match(s),
+        }
+    }
+
+    fn add_if(&mut self, s: &IfSegment) {
+        let mut if_ctx = IfContext {
+            condition: s.condition.clone(),
+            then: Box::new(self.clone()),
+            or_else: Box::new(self.clone()),
+        };
+        for then in &s.then {
+            if_ctx.then.add_segment(then);
+        }
+        for or_else in &s.or_else {
+            if_ctx.or_else.add_segment(or_else);
+        }
+        // replace the current context with the new IfContext
+        *self = Context::If(if_ctx);
+    }
+
+    fn add_match(&mut self, s: &MatchSegment) {
+        let arms = s
+            .arms
+            .iter()
+            .map(|arm| {
+                let mut arm_ctx = MatchArmContext {
+                    pattern: arm.pat.clone(),
+                    inner: Box::new(self.clone()),
+                };
+                for segment in &arm.body {
+                    arm_ctx.inner.add_segment(segment);
+                }
+                arm_ctx
+            })
+            .collect::<Vec<_>>();
+
+        let match_ctx = MatchContext {
+            expr: s.expr.clone(),
+            arms,
+        };
+
+        // replace the current context with the new MatchContext
+        *self = Context::Match(match_ctx);
+    }
+}
+
+/// A "normal" linear context without any branches.
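+///
+/// For example (with illustrative values), the query `"SELECT * FROM articles WHERE id = {id}"`
+/// processed against Postgres produces roughly:
+/// ```rust,ignore
+/// NormalContext {
+///     query_as: Article,
+///     sql: "SELECT * FROM articles WHERE id = $1", // `?` for the other databases
+///     args: [id],
+/// }
+/// ```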
+#[derive(Clone, Debug)]
+struct NormalContext {
+    query_as: Rc<Path>,
+    sql: String,
+    args: Vec<Expr>,
+}
+
+impl NormalContext {
+    fn add_parameter(&mut self, expr: Expr) {
+        self.args.push(expr.clone());
+        if cfg!(feature = "postgres") {
+            write!(&mut self.sql, "${}", self.args.len()).unwrap();
+        } else {
+            self.sql.push('?');
+        }
+    }
+}
+
+impl IsContext for NormalContext {
+    fn branches(&self) -> usize {
+        1
+    }
+
+    fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream {
+        let NormalContext {
+            query_as,
+            sql,
+            args,
+        } = self;
+        *branch_counter += 1;
+
+        let query_call = quote!(sqlx_macros::expand_query!(
+            record = #query_as,
+            source = #sql,
+            args = [#(#args),*]
+        ));
+        match branches {
+            1 => query_call,
+            _ => {
+                let variant = format_ident!("_{}", branch_counter);
+                quote!(ConditionalMap::#variant(#query_call))
+            }
+        }
+    }
+
+    fn add_sql(&mut self, sql: &SqlSegment) {
+        if !self.sql.is_empty() {
+            self.sql.push(' ');
+        }
+
+        // push the new sql segment, replacing inline arguments (`"{some rust expression}"`)
+        // with the appropriate placeholder (`$n` or `?`)
+        let mut args = sql.args.iter();
+        let mut arg = args.next();
+        for (idx, c) in sql.sql.chars().enumerate() {
+            if let Some((start, expr, end)) = arg {
+                if idx < *start {
+                    self.sql.push(c);
+                }
+                if idx == *end {
+                    self.add_parameter(expr.clone());
+                    arg = args.next();
+                }
+            } else {
+                self.sql.push(c);
+            }
+        }
+    }
+}
+
+/// Context within an `if .. {..} else ..` clause.
+#[derive(Clone, Debug)]
+struct IfContext {
+    condition: Expr,
+    then: Box<Context>,
+    or_else: Box<Context>,
+}
+
+impl IsContext for IfContext {
+    fn branches(&self) -> usize {
+        self.then.branches() + self.or_else.branches()
+    }
+
+    fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream {
+        let condition = &self.condition;
+        let then = self.then.to_query(branches, branch_counter);
+        let or_else = self.or_else.to_query(branches, branch_counter);
+        quote! {
+            if #condition {
+                #then
+            } else {
+                #or_else
+            }
+        }
+    }
+
+    fn add_sql(&mut self, sql: &SqlSegment) {
+        self.then.add_sql(sql);
+        self.or_else.add_sql(sql);
+    }
+}
+
+/// Context within `match .. { .. }`
+#[derive(Clone, Debug)]
+struct MatchContext {
+    expr: Expr,
+    arms: Vec<MatchArmContext>,
+}
+
+impl IsContext for MatchContext {
+    fn branches(&self) -> usize {
+        self.arms.iter().map(|arm| arm.branches()).sum()
+    }
+
+    fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream {
+        let expr = &self.expr;
+        let arms = self
+            .arms
+            .iter()
+            .map(|arm| arm.to_query(branches, branch_counter));
+        quote! {
+            match #expr { #(#arms,)* }
+        }
+    }
+
+    fn add_sql(&mut self, sql: &SqlSegment) {
+        for arm in &mut self.arms {
+            arm.add_sql(sql);
+        }
+    }
+}
+
+/// Context within the arm (`Pat => ..`) of a `match`
+#[derive(Clone, Debug)]
+struct MatchArmContext {
+    pattern: Pat,
+    inner: Box<Context>,
+}
+
+impl IsContext for MatchArmContext {
+    fn branches(&self) -> usize {
+        self.inner.branches()
+    }
+
+    fn to_query(&self, branches: usize, branch_counter: &mut usize) -> TokenStream {
+        let pat = &self.pattern;
+        let inner = self.inner.to_query(branches, branch_counter);
+        quote! {
+            #pat => #inner
+        }
+    }
+
+    fn add_sql(&mut self, sql: &SqlSegment) {
+        self.inner.add_sql(sql);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use proc_macro2::{TokenStream, TokenTree};
+    use quote::quote;
+
+    use crate::query::conditional::Input;
+
+    // credits: Yandros#4299
+    fn assert_token_stream_eq(ts1: TokenStream, ts2: TokenStream) {
+        fn assert_tt_eq(tt1: TokenTree, tt2: TokenTree) {
+            use ::proc_macro2::TokenTree::*;
+            match (tt1, tt2) {
+                (Group(g1), Group(g2)) => assert_token_stream_eq(g1.stream(), g2.stream()),
+                (Ident(lhs), Ident(rhs)) => assert_eq!(lhs.to_string(), rhs.to_string()),
+                (Punct(lhs), Punct(rhs)) => assert_eq!(lhs.as_char(), rhs.as_char()),
+                (Literal(lhs), Literal(rhs)) => assert_eq!(lhs.to_string(), rhs.to_string()),
+                _ => panic!("Not equal!"),
+            }
+        }
+
+        let mut ts1 = ts1.into_iter();
+        let mut ts2 = ts2.into_iter();
+        loop {
+            match (ts1.next(), ts2.next()) {
+                (Some(tt1), Some(tt2)) => assert_tt_eq(tt1, tt2),
+                (None, None) => return,
+                _ => panic!("Not equal!"),
+            }
+        }
+    }
+
+    #[test]
+    fn simple() {
+        let input = quote! {
+            OptionalRecord, "select something from somewhere where something_else = {1}"
+        };
+        let result = syn::parse2::<Input>(input).unwrap();
+        let expected_query = if cfg!(feature = "postgres") {
+            "select something from somewhere where something_else = $1"
+        } else {
+            "select something from somewhere where something_else = ?"
+        };
+        assert_token_stream_eq(
+            result.to_context().unwrap().generate_output(),
+            quote! {
+                sqlx_macros::expand_query!(
+                    record = OptionalRecord,
+                    source = #expected_query,
+                    args = [1]
+                )
+            },
+        );
+    }
+
+    #[test]
+    fn single_if() {
+        let input = quote!(
+            Article,
+            "SELECT * FROM articles"
+            if let Some(name_filter) = filter {
+                "WHERE name ILIKE {name_filter}"
+            }
+        );
+        let result = syn::parse2::<Input>(input).unwrap();
+        let ctx = result.to_context().unwrap();
+        let output = ctx.generate_output();
+    }
+
+    #[test]
+    fn raw_literal() {
+        let input = quote!(
+            Article,
+            r#"SELECT * FROM articles"#
+            if let Some(name_filter) = filter {
+                r#"WHERE "name" ILIKE {name_filter}"#
+            }
+        );
+        let result = syn::parse2::<Input>(input).unwrap();
+        let ctx = result.to_context().unwrap();
+        let output = ctx.generate_output();
+    }
+}
diff --git a/sqlx-macros/src/query/conditional/segment.rs b/sqlx-macros/src/query/conditional/segment.rs
new file mode 100644
index 0000000000..18f892256d
--- /dev/null
+++ b/sqlx-macros/src/query/conditional/segment.rs
@@ -0,0 +1,233 @@
+use proc_macro2::Span;
+use syn::{
+    braced,
+    parse::{Parse, ParseStream, Peek},
+    Error, Expr, LitStr, Pat, Token,
+};
+
+/// A single "piece" of the input.
+#[derive(Debug)]
+pub enum QuerySegment {
+    /// A part of an SQL query, like `"SELECT *"`
+    Sql(SqlSegment),
+    /// An `if .. { .. }`, with optional `else ..`
+    If(IfSegment),
+    /// An exhaustive `match .. { .. }`
+    Match(MatchSegment),
+}
+
+impl QuerySegment {
+    /// Parse segments up to the first occurrence of the given token, or until the input is empty.
+    pub fn parse_until<T: Peek>(input: ParseStream, until: T) -> syn::Result<Vec<Self>> {
+        let mut segments = vec![];
+        while !input.is_empty() && !input.peek(until) {
+            segments.push(QuerySegment::parse(input)?);
+        }
+        Ok(segments)
+    }
+
+    /// Parse segments until the input is empty.
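+    /// For example (an illustrative input), `"SELECT 1" if flag { "FOR UPDATE" }` yields a
+    /// `Sql` segment followed by an `If` segment.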
+    pub fn parse_all(input: ParseStream) -> syn::Result<Vec<Self>> {
+        let mut segments = vec![];
+        while !input.is_empty() {
+            segments.push(QuerySegment::parse(input)?);
+        }
+        Ok(segments)
+    }
+}
+
+#[derive(Debug)]
+pub struct SqlSegment {
+    pub sql: String,
+    pub args: Vec<(usize, Expr, usize)>,
+}
+
+impl SqlSegment {
+    const EXPECT: &'static str = "\"..\"";
+
+    fn matches(input: &ParseStream) -> bool {
+        input.fork().parse::<LitStr>().is_ok()
+    }
+}
+
+impl Parse for SqlSegment {
+    fn parse(input: ParseStream) -> syn::Result<Self> {
+        let sql = input.parse::<LitStr>()?.value();
+        let args = parse_inline_args(&sql)?;
+
+        Ok(Self { sql, args })
+    }
+}
+
+// parses inline arguments in the query, for example `".. WHERE user_id = {1}"`, returning them with
+// the index of `{`, the parsed argument, and the index of the `}`.
+fn parse_inline_args(sql: &str) -> syn::Result<Vec<(usize, Expr, usize)>> {
+    let mut args = vec![];
+    let mut curr_level = 0;
+    let mut curr_arg = None;
+
+    for (idx, c) in sql.chars().enumerate() {
+        match c {
+            '{' => {
+                if curr_arg.is_none() {
+                    curr_arg = Some((idx, String::new()));
+                }
+                curr_level += 1;
+            }
+            '}' => {
+                if curr_arg.is_none() {
+                    let err = Error::new(Span::call_site(), "unexpected '}' in query string");
+                    return Err(err);
+                };
+                if curr_level == 1 {
+                    let (arg_start, arg_str) = std::mem::replace(&mut curr_arg, None).unwrap();
+                    let arg = syn::parse_str(&arg_str)?;
+                    args.push((arg_start, arg, idx));
+                }
+                curr_level -= 1;
+            }
+            c => {
+                if let Some((_, arg)) = &mut curr_arg {
+                    arg.push(c);
+                }
+            }
+        }
+    }
+
+    if curr_arg.is_some() {
+        let err = Error::new(Span::call_site(), "expected '}', but got end of string");
+        return Err(err);
+    }
+
+    Ok(args)
+}
+
+#[derive(Debug)]
+pub struct MatchSegment {
+    pub expr: Expr,
+    pub arms: Vec<MatchSegmentArm>,
+}
+
+#[derive(Debug)]
+pub struct MatchSegmentArm {
+    pub pat: Pat,
+    pub body: Vec<QuerySegment>,
+}
+
+impl MatchSegmentArm {
+    fn parse_all(input: ParseStream) -> syn::Result<Vec<Self>> {
+        let mut arms = vec![];
+        while !input.is_empty() {
+            arms.push(Self::parse(input)?);
+        }
+        Ok(arms)
+    }
+}
+
+impl Parse for MatchSegmentArm {
+    fn parse(input: ParseStream) -> syn::Result<Self> {
+        let pat = input.parse()?;
+        input.parse::<Token![=>]>()?;
+        let body = if input.peek(syn::token::Brace) {
+            let body;
+            braced!(body in input);
+            QuerySegment::parse_all(&body)?
+        } else {
+            QuerySegment::parse_until(input, Token![,])?
+        };
+        input.parse::<Option<Token![,]>>()?;
+        Ok(Self { pat, body })
+    }
+}
+
+impl MatchSegment {
+    const EXPECT: &'static str = "match .. { .. }";
+
+    fn matches(input: ParseStream) -> bool {
+        input.peek(Token![match])
+    }
+}
+
+impl Parse for MatchSegment {
+    fn parse(input: ParseStream) -> syn::Result<Self> {
+        input.parse::<Token![match]>()?;
+        let expr = input.call(Expr::parse_without_eager_brace)?;
+        let input = {
+            let content;
+            braced!(content in input);
+            content
+        };
+        let arms = MatchSegmentArm::parse_all(&input)?;
+
+        Ok(Self { expr, arms })
+    }
+}
+
+#[derive(Debug)]
+pub struct IfSegment {
+    pub condition: Expr,
+    pub then: Vec<QuerySegment>,
+    pub or_else: Vec<QuerySegment>,
+}
+
+impl IfSegment {
+    const EXPECT: &'static str = "if { .. }";
+
+    fn matches(input: ParseStream) -> bool {
+        input.peek(Token![if])
+    }
+}
+
+impl Parse for IfSegment {
+    fn parse(input: ParseStream) -> syn::Result<Self> {
+        input.parse::<Token![if]>()?;
+        let condition = input.call(Expr::parse_without_eager_brace)?;
+        let then = {
+            let if_then;
+            braced!(if_then in input);
+            QuerySegment::parse_all(&if_then)?
+ }; + let or_else = if input.parse::>()?.is_some() { + if IfSegment::matches(input) { + let else_if = IfSegment::parse(input)?; + vec![QuerySegment::If(else_if)] + } else { + let or_else; + braced!(or_else in input); + QuerySegment::parse_all(&or_else)? + } + } else { + vec![] + }; + Ok(Self { + condition, + then, + or_else, + }) + } +} + +impl Parse for QuerySegment { + fn parse(input: ParseStream) -> syn::Result { + // parse optional '+' for backwards compatibility + if input.peek(Token![+]) { + input.parse::()?; + } + + if SqlSegment::matches(&input) { + Ok(QuerySegment::Sql(input.parse()?)) + } else if IfSegment::matches(&input) { + Ok(QuerySegment::If(input.parse()?)) + } else if MatchSegment::matches(input) { + Ok(QuerySegment::Match(input.parse()?)) + } else { + let error = format!( + "expected `{}`, `{}` or `{}`", + SqlSegment::EXPECT, + IfSegment::EXPECT, + MatchSegment::EXPECT, + ); + Err(Error::new(input.span(), error)) + } + } +} diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs index aa862bd743..835cb162b4 100644 --- a/sqlx-macros/src/query/mod.rs +++ b/sqlx-macros/src/query/mod.rs @@ -20,10 +20,13 @@ use crate::query::input::RecordType; use either::Either; mod args; +mod conditional; mod data; mod input; mod output; +pub use conditional::query_as; + struct Metadata { #[allow(unused)] manifest_dir: PathBuf, diff --git a/src/lib.rs b/src/lib.rs index 9d00dfcfdc..c8337bcad5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,6 +65,10 @@ pub use sqlx_core::postgres::{self, PgConnection, PgExecutor, PgPool, Postgres}; #[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqliteExecutor, SqlitePool}; +// used by conditional query macro +#[doc(hidden)] +pub use sqlx_core::futures_core; + #[cfg(feature = "macros")] #[doc(hidden)] pub extern crate sqlx_macros; @@ -74,6 +78,9 @@ pub extern crate sqlx_macros; #[doc(hidden)] pub use sqlx_macros::{FromRow, Type}; +#[cfg(feature = "macros")] +pub use sqlx_macros::query_as; + #[cfg(feature = "macros")] mod macros; diff --git a/src/macros.rs b/src/macros.rs index 0ca85d61b4..fecf27d6bf 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -395,166 +395,6 @@ macro_rules! query_file_unchecked ( }) ); -/// A variant of [query!] which takes a path to an explicitly defined struct as the output type. -/// -/// This lets you return the struct from a function or add your own trait implementations. -/// -/// **No trait implementations are required**; the macro maps rows using a struct literal -/// where the names of columns in the query are expected to be the same as the fields of the struct -/// (but the order does not need to be the same). The types of the columns are based on the -/// query and not the corresponding fields of the struct, so this is type-safe as well. -/// -/// This enforces a few things: -/// * The query must output at least one column. -/// * The column names of the query must match the field names of the struct. -/// * The field types must be the Rust equivalent of their SQL counterparts; see the corresponding -/// module for your database for mappings: -/// * Postgres: [crate::postgres::types] -/// * MySQL: [crate::mysql::types] -/// * SQLite: [crate::sqlite::types] -/// * MSSQL: [crate::mssql::types] -/// * If a column may be `NULL`, the corresponding field's type must be wrapped in `Option<_>`. -/// * Neither the query nor the struct may have unused fields. 
-/// -/// The only modification to the `query!()` syntax is that the struct name is given before the SQL -/// string: -/// ```rust,ignore -/// # use sqlx::Connect; -/// # #[cfg(all(feature = "mysql", feature = "_rt-async-std"))] -/// # #[async_std::main] -/// # async fn main() -> sqlx::Result<()>{ -/// # let db_url = dotenv::var("DATABASE_URL").expect("DATABASE_URL must be set"); -/// # -/// # if !(db_url.starts_with("mysql") || db_url.starts_with("mariadb")) { return Ok(()) } -/// # let mut conn = sqlx::MySqlConnection::connect(db_url).await?; -/// #[derive(Debug)] -/// struct Account { -/// id: i32, -/// name: String -/// } -/// -/// // let mut conn = ; -/// let account = sqlx::query_as!( -/// Account, -/// "select * from (select (1) as id, 'Herp Derpinson' as name) accounts where id = ?", -/// 1i32 -/// ) -/// .fetch_one(&mut conn) -/// .await?; -/// -/// println!("{:?}", account); -/// println!("{}: {}", account.id, account.name); -/// -/// # Ok(()) -/// # } -/// # -/// # #[cfg(any(not(feature = "mysql"), not(feature = "_rt-async-std")))] -/// # fn main() {} -/// ``` -/// -/// **The method you want to call depends on how many rows you're expecting.** -/// -/// | Number of Rows | Method to Call* | Returns (`T` being the given struct) | Notes | -/// |----------------| ----------------------------|----------------------------------------|-------| -/// | Zero or One | `.fetch_optional(...).await`| `sqlx::Result>` | Extra rows are ignored. | -/// | Exactly One | `.fetch_one(...).await` | `sqlx::Result` | Errors if no rows were returned. Extra rows are ignored. Aggregate queries, use this. | -/// | At Least One | `.fetch(...)` | `impl Stream>` | Call `.try_next().await` to get each row result. | -/// | Multiple | `.fetch_all(...)` | `sqlx::Result>` | | -/// -/// \* All methods accept one of `&mut {connection type}`, `&mut Transaction` or `&Pool`. -/// (`.execute()` is omitted as this macro requires at least one column to be returned.) -/// -/// ### Column Type Override: Infer from Struct Field -/// In addition to the column type overrides supported by [query!], `query_as!()` supports an -/// additional override option: -/// -/// If you select a column `foo as "foo: _"` (Postgres/SQLite) or `` foo as `foo: _` `` (MySQL) -/// it causes that column to be inferred based on the type of the corresponding field in the given -/// record struct. Runtime type-checking is still done so an error will be emitted if the types -/// are not compatible. -/// -/// This allows you to override the inferred type of a column to instead use a custom-defined type: -/// -/// ```rust,ignore -/// #[derive(sqlx::Type)] -/// #[sqlx(transparent)] -/// struct MyInt4(i32); -/// -/// struct Record { -/// id: MyInt4, -/// } -/// -/// let my_int = MyInt4(1); -/// -/// // Postgres/SQLite -/// sqlx::query_as!(Record, r#"select 1 as "id: _""#) // MySQL: use "select 1 as `id: _`" instead -/// .fetch_one(&mut conn) -/// .await?; -/// -/// assert_eq!(record.id, MyInt4(1)); -/// ``` -/// -/// ### Troubleshooting: "error: mismatched types" -/// If you get a "mismatched types" error from an invocation of this macro and the error -/// isn't pointing specifically at a parameter. 
-/// -/// For example, code like this (using a Postgres database): -/// -/// ```rust,ignore -/// struct Account { -/// id: i32, -/// name: Option, -/// } -/// -/// let account = sqlx::query_as!( -/// Account, -/// r#"SELECT id, name from (VALUES (1, 'Herp Derpinson')) accounts(id, name)"#, -/// ) -/// .fetch_one(&mut conn) -/// .await?; -/// ``` -/// -/// Might produce an error like this: -/// ```text,ignore -/// error[E0308]: mismatched types -/// --> tests/postgres/macros.rs:126:19 -/// | -/// 126 | let account = sqlx::query_as!( -/// | ___________________^ -/// 127 | | Account, -/// 128 | | r#"SELECT id, name from (VALUES (1, 'Herp Derpinson')) accounts(id, name)"#, -/// 129 | | ) -/// | |_____^ expected `i32`, found enum `std::option::Option` -/// | -/// = note: expected type `i32` -/// found enum `std::option::Option` -/// ``` -/// -/// This means that you need to check that any field of the "expected" type (here, `i32`) matches -/// the Rust type mapping for its corresponding SQL column (see the `types` module of your database, -/// listed above, for mappings). The "found" type is the SQL->Rust mapping that the macro chose. -/// -/// In the above example, the returned column is inferred to be nullable because it's being -/// returned from a `VALUES` statement in Postgres, so the macro inferred the field to be nullable -/// and so used `Option` instead of `i32`. **In this specific case** we could use -/// `select id as "id!"` to override the inferred nullability because we know in practice -/// that column will never be `NULL` and it will fix the error. -/// -/// Nullability inference and type overrides are discussed in detail in the docs for [query!]. -/// -/// It unfortunately doesn't appear to be possible right now to make the error specifically mention -/// the field; this probably requires the `const-panic` feature (still unstable as of Rust 1.45). -#[macro_export] -#[cfg_attr(docsrs, doc(cfg(feature = "macros")))] -macro_rules! query_as ( - ($out_struct:path, $query:expr) => ( { - $crate::sqlx_macros::expand_query!(record = $out_struct, source = $query) - }); - ($out_struct:path, $query:expr, $($args:tt)*) => ( { - $crate::sqlx_macros::expand_query!(record = $out_struct, source = $query, args = [$($args)*]) - }) -); - /// Combines the syntaxes of [query_as!] and [query_file!]. /// /// Enforces requirements of both macros; see them for details. 
diff --git a/tests/postgres/conditional_query.rs b/tests/postgres/conditional_query.rs
new file mode 100644
index 0000000000..ce4c415935
--- /dev/null
+++ b/tests/postgres/conditional_query.rs
@@ -0,0 +1,143 @@
+use sqlx_core::postgres::{PgConnection, Postgres};
+use sqlx_test::new;
+
+#[sqlx_macros::test]
+async fn simple() -> anyhow::Result<()> {
+    let mut conn = new::<Postgres>().await?;
+
+    struct Result {
+        result: Option<i32>,
+    }
+
+    for value in [true, false] {
+        let result = sqlx::query_as!(
+            Result,
+            "SELECT"
+            if value { "42" } else { "12" }
+            r#"AS "result""#
+        )
+        .fetch_one(&mut conn)
+        .await?;
+
+        if value {
+            assert_eq!(result.result, Some(42));
+        } else {
+            assert_eq!(result.result, Some(12));
+        }
+    }
+
+    Ok(())
+}
+
+#[sqlx_macros::test]
+async fn dynamic_ordering() -> anyhow::Result<()> {
+    let mut conn = new::<Postgres>().await?;
+
+    #[derive(Clone, Eq, PartialEq, Debug)]
+    struct Article {
+        id: i32,
+        title: String,
+        author: String,
+    }
+
+    let expected = vec![
+        Article {
+            id: 1,
+            title: "Article1".to_owned(),
+            author: "Peter".to_owned(),
+        },
+        Article {
+            id: 2,
+            title: "Article2".to_owned(),
+            author: "John".to_owned(),
+        },
+    ];
+    for reverse_order in [true, false] {
+        let articles = sqlx::query_as!(
+            Article,
+            "SELECT"
+            r#"id AS "id!", title AS "title!", author AS "author!""#
+            "FROM ("
+            "VALUES (1, 'Article1', 'Peter'), (2, 'Article2', 'John')"
+            ") articles(id, title, author)"
+            "ORDER BY title"
+            if reverse_order {
+                "DESC"
+            } else {
+                "ASC"
+            }
+        )
+        .fetch_all(&mut conn)
+        .await?;
+
+        if reverse_order {
+            let mut expected = expected.clone();
+            expected.reverse();
+            assert_eq!(articles, expected);
+        } else {
+            assert_eq!(articles, expected);
+        }
+    }
+
+    Ok(())
+}
+
+#[sqlx_macros::test]
+async fn dynamic_filtering() -> anyhow::Result<()> {
+    let mut conn = new::<Postgres>().await?;
+
+    #[derive(Clone, Eq, PartialEq, Debug)]
+    struct Article {
+        id: i32,
+        title: String,
+        author: String,
+    }
+
+    enum Filter {
+        Id(i32),
+        Title(String),
+        Author(String),
+    }
+
+    async fn query_articles(
+        con: &mut PgConnection,
+        filter: Option<Filter>,
+    ) -> anyhow::Result<Vec<Article>> {
+        let articles = sqlx::query_as!(
+            Article,
+            "SELECT"
+            r#"id AS "id!", title AS "title!", author AS "author!""#
+            "FROM ("
+            "VALUES (1, 'Article1', 'Peter'), (2, 'Article2', 'John'), (3, 'Article3', 'James')"
+            ") articles(id, title, author)"
+            if let Some(filter) = filter {
+                "WHERE"
+                match filter {
+                    Filter::Id(id) => "id = {id}",
+                    Filter::Title(title) => "title ILIKE {title}",
+                    Filter::Author(author) => "author ILIKE {author}"
+                }
+            }
+        )
+        .fetch_all(con)
+        .await?;
+        Ok(articles)
+    }
+
+    let result = query_articles(&mut conn, None).await?;
+    assert_eq!(result.len(), 3);
+
+    let result = query_articles(&mut conn, Some(Filter::Id(1))).await?;
+    assert_eq!(result.len(), 1);
+    assert_eq!(result[0].id, 1);
+
+    let result = query_articles(&mut conn, Some(Filter::Title("article2".to_owned()))).await?;
+    assert_eq!(result.len(), 1);
+    assert_eq!(result[0].id, 2);
+
+    let result = query_articles(&mut conn, Some(Filter::Author("james".to_owned()))).await?;
+    assert_eq!(result.len(), 1);
+    assert_eq!(result[0].id, 3);
+
+    Ok(())
+}
diff --git a/tests/postgres/macros.rs b/tests/postgres/macros.rs
index 51d1f89bb8..c7b6ac2105 100644
--- a/tests/postgres/macros.rs
+++ b/tests/postgres/macros.rs
@@ -137,6 +137,25 @@ async fn test_query_as() -> anyhow::Result<()> {
     Ok(())
 }
 
+#[sqlx_macros::test]
+async fn test_query_as_inline_args() -> anyhow::Result<()> {
+    let mut conn = new::<Postgres>().await?;
+
+    let name: Option<&str> = None;
+    let
account = sqlx::query_as!( + Account, + r#"SELECT id "id!", name from (VALUES (1, {name})) accounts(id, name)"#, + ) + .fetch_one(&mut conn) + .await?; + + assert_eq!(None, account.name); + + println!("{:?}", account); + + Ok(()) +} + #[derive(Debug)] struct RawAccount { r#type: i32,