diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 04c1e5f7..b29cdf41 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1144,6 +1144,58 @@ mod test { Ok(()) } + #[tokio::test] + async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> { + let opts = get_opts().client_found_rows(true); + let mut conn = Conn::new(opts).await.unwrap(); + + "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)" + .ignore(&mut conn) + .await?; + + "INSERT INTO mysql.found_rows (val) VALUES (1)" + .ignore(&mut conn) + .await?; + + // Inserted one row, affected should be one. + assert_eq!(conn.affected_rows(), 1); + + "UPDATE mysql.found_rows SET val = 1 WHERE val = 1" + .ignore(&mut conn) + .await?; + + // The query doesn't affect any rows, but due to us wanting FOUND rows, + // this has to return one. + assert_eq!(conn.affected_rows(), 1); + + Ok(()) + } + + #[tokio::test] + async fn should_not_return_found_rows_if_flag_is_not_set() -> super::Result<()> { + let mut conn = Conn::new(get_opts()).await.unwrap(); + + "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)" + .ignore(&mut conn) + .await?; + + "INSERT INTO mysql.found_rows (val) VALUES (1)" + .ignore(&mut conn) + .await?; + + // Inserted one row, affected should be one. + assert_eq!(conn.affected_rows(), 1); + + "UPDATE mysql.found_rows SET val = 1 WHERE val = 1" + .ignore(&mut conn) + .await?; + + // The query doesn't affect any rows. + assert_eq!(conn.affected_rows(), 0); + + Ok(()) + } + async fn read_binlog_streams_and_close_their_connections( pool: Option<&Pool>, binlog_server_ids: (u32, u32, u32), diff --git a/src/opts/mod.rs b/src/opts/mod.rs index a74a6fc0..560b1cb9 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -404,6 +404,12 @@ pub(crate) struct MysqlOpts { /// /// Available via `secure_auth` connection url parameter. secure_auth: bool, + + /// Enables `CLIENT_FOUND_ROWS` capability (defaults to `false`). + /// + /// Changes the behavior of the affected count returned for writes (UPDATE/INSERT etc). + /// It makes MySQL return the FOUND rows instead of the AFFECTED rows. + client_found_rows: bool, } /// Mysql connection options. @@ -721,6 +727,26 @@ impl Opts { self.inner.mysql_opts.secure_auth } + /// Returns `true` if `CLIENT_FOUND_ROWS` capability is enabled (defaults to `false`). + /// + /// `CLIENT_FOUND_ROWS` changes the behavior of the affected count returned for writes + /// (UPDATE/INSERT etc). It makes MySQL return the FOUND rows instead of the AFFECTED rows. + /// + /// # Connection URL + /// + /// Use `client_found_rows` URL parameter to set this value. E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?client_found_rows=true")?; + /// assert!(opts.client_found_rows()); + /// # Ok(()) } + /// ``` + pub fn client_found_rows(&self) -> bool { + self.inner.mysql_opts.client_found_rows + } + pub(crate) fn get_capabilities(&self) -> CapabilityFlags { let mut out = CapabilityFlags::CLIENT_PROTOCOL_41 | CapabilityFlags::CLIENT_SECURE_CONNECTION @@ -742,6 +768,9 @@ impl Opts { if self.inner.mysql_opts.compression.is_some() { out |= CapabilityFlags::CLIENT_COMPRESS; } + if self.client_found_rows() { + out |= CapabilityFlags::CLIENT_FOUND_ROWS; + } out } @@ -767,6 +796,7 @@ impl Default for MysqlOpts { max_allowed_packet: None, wait_timeout: None, secure_auth: true, + client_found_rows: false, } } } @@ -1017,6 +1047,12 @@ impl OptsBuilder { self.opts.secure_auth = secure_auth; self } + + /// Enables or disables `CLIENT_FOUND_ROWS` capability. See [`Opts::client_found_rows`]. + pub fn client_found_rows(mut self, client_found_rows: bool) -> Self { + self.opts.client_found_rows = client_found_rows; + self + } } impl From for Opts { @@ -1245,6 +1281,18 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "client_found_rows" { + match bool::from_str(&*value) { + Ok(client_found_rows) => { + opts.client_found_rows = client_found_rows; + } + _ => { + return Err(UrlError::InvalidParamValue { + param: "client_found_rows".into(), + value, + }); + } + } } else if key == "socket" { opts.socket = Some(value) } else if key == "compression" {