Skip to content

Commit 29dcd44

Browse files
fix(mysql): Close prepared statement if persistence is disabled (#2905)
* close prepared statement if persistence or statement cache are disabled * add tests
1 parent 31e541a commit 29dcd44

File tree

2 files changed

+127
-29
lines changed

2 files changed

+127
-29
lines changed

sqlx-mysql/src/connection/executor.rs

+61-29
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,10 @@ use futures_util::{pin_mut, TryStreamExt};
2525
use std::{borrow::Cow, sync::Arc};
2626

2727
impl MySqlConnection {
28-
async fn get_or_prepare<'c>(
28+
async fn prepare_statement<'c>(
2929
&mut self,
3030
sql: &str,
31-
persistent: bool,
3231
) -> Result<(u32, MySqlStatementMetadata), Error> {
33-
if let Some(statement) = self.cache_statement.get_mut(sql) {
34-
// <MySqlStatementMetadata> is internally reference-counted
35-
return Ok((*statement).clone());
36-
}
37-
3832
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html
3933
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK
4034

@@ -72,11 +66,23 @@ impl MySqlConnection {
7266
column_names: Arc::new(column_names),
7367
};
7468

75-
if persistent && self.cache_statement.is_enabled() {
76-
// in case of the cache being full, close the least recently used statement
77-
if let Some((id, _)) = self.cache_statement.insert(sql, (id, metadata.clone())) {
78-
self.stream.send_packet(StmtClose { statement: id }).await?;
79-
}
69+
Ok((id, metadata))
70+
}
71+
72+
async fn get_or_prepare_statement<'c>(
73+
&mut self,
74+
sql: &str,
75+
) -> Result<(u32, MySqlStatementMetadata), Error> {
76+
if let Some(statement) = self.cache_statement.get_mut(sql) {
77+
// <MySqlStatementMetadata> is internally reference-counted
78+
return Ok((*statement).clone());
79+
}
80+
81+
let (id, metadata) = self.prepare_statement(sql).await?;
82+
83+
// in case of the cache being full, close the least recently used statement
84+
if let Some((id, _)) = self.cache_statement.insert(sql, (id, metadata.clone())) {
85+
self.stream.send_packet(StmtClose { statement: id }).await?;
8086
}
8187

8288
Ok((id, metadata))
@@ -102,21 +108,37 @@ impl MySqlConnection {
102108
let mut columns = Arc::new(Vec::new());
103109

104110
let (mut column_names, format, mut needs_metadata) = if let Some(arguments) = arguments {
105-
let (id, metadata) = self.get_or_prepare(
106-
sql,
107-
persistent,
108-
)
109-
.await?;
110-
111-
// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
112-
self.stream
113-
.send_packet(StatementExecute {
114-
statement: id,
115-
arguments: &arguments,
116-
})
117-
.await?;
118-
119-
(metadata.column_names, MySqlValueFormat::Binary, false)
111+
if persistent && self.cache_statement.is_enabled() {
112+
let (id, metadata) = self
113+
.get_or_prepare_statement(sql)
114+
.await?;
115+
116+
// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
117+
self.stream
118+
.send_packet(StatementExecute {
119+
statement: id,
120+
arguments: &arguments,
121+
})
122+
.await?;
123+
124+
(metadata.column_names, MySqlValueFormat::Binary, false)
125+
} else {
126+
let (id, metadata) = self
127+
.prepare_statement(sql)
128+
.await?;
129+
130+
// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
131+
self.stream
132+
.send_packet(StatementExecute {
133+
statement: id,
134+
arguments: &arguments,
135+
})
136+
.await?;
137+
138+
self.stream.send_packet(StmtClose { statement: id }).await?;
139+
140+
(metadata.column_names, MySqlValueFormat::Binary, false)
141+
}
120142
} else {
121143
// https://dev.mysql.com/doc/internals/en/com-query.html
122144
self.stream.send_packet(Query(sql)).await?;
@@ -269,7 +291,15 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
269291
Box::pin(async move {
270292
self.stream.wait_until_ready().await?;
271293

272-
let (_, metadata) = self.get_or_prepare(sql, true).await?;
294+
let metadata = if self.cache_statement.is_enabled() {
295+
self.get_or_prepare_statement(sql).await?.1
296+
} else {
297+
let (id, metadata) = self.prepare_statement(sql).await?;
298+
299+
self.stream.send_packet(StmtClose { statement: id }).await?;
300+
301+
metadata
302+
};
273303

274304
Ok(MySqlStatement {
275305
sql: Cow::Borrowed(sql),
@@ -287,7 +317,9 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
287317
Box::pin(async move {
288318
self.stream.wait_until_ready().await?;
289319

290-
let (_, metadata) = self.get_or_prepare(sql, false).await?;
320+
let (id, metadata) = self.prepare_statement(sql).await?;
321+
322+
self.stream.send_packet(StmtClose { statement: id }).await?;
291323

292324
let columns = (&*metadata.columns).clone();
293325

tests/mysql/mysql.rs

+66
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,57 @@ async fn it_caches_statements() -> anyhow::Result<()> {
237237
Ok(())
238238
}
239239

240+
#[sqlx_macros::test]
241+
async fn it_closes_statements_with_persistent_disabled() -> anyhow::Result<()> {
242+
let mut conn = new::<MySql>().await?;
243+
244+
let old_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();
245+
246+
for i in 0..2 {
247+
let row = sqlx::query("SELECT ? AS val")
248+
.bind(i)
249+
.persistent(false)
250+
.fetch_one(&mut conn)
251+
.await?;
252+
253+
let val: i32 = row.get("val");
254+
255+
assert_eq!(i, val);
256+
}
257+
258+
let new_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();
259+
260+
assert_eq!(old_statement_count, new_statement_count);
261+
262+
Ok(())
263+
}
264+
265+
#[sqlx_macros::test]
266+
async fn it_closes_statements_with_cache_disabled() -> anyhow::Result<()> {
267+
setup_if_needed();
268+
269+
let mut url = url::Url::parse(&env::var("DATABASE_URL")?)?;
270+
url.query_pairs_mut()
271+
.append_pair("statement-cache-capacity", "0");
272+
273+
let mut conn = MySqlConnection::connect(url.as_ref()).await?;
274+
275+
let old_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();
276+
277+
for index in 1..=10_i32 {
278+
let _ = sqlx::query("SELECT ?")
279+
.bind(index)
280+
.execute(&mut conn)
281+
.await?;
282+
}
283+
284+
let new_statement_count = select_statement_count(&mut conn).await.unwrap_or_default();
285+
286+
assert_eq!(old_statement_count, new_statement_count);
287+
288+
Ok(())
289+
}
290+
240291
#[sqlx_macros::test]
241292
async fn it_can_bind_null_and_non_null_issue_540() -> anyhow::Result<()> {
242293
let mut conn = new::<MySql>().await?;
@@ -510,3 +561,18 @@ async fn test_shrink_buffers() -> anyhow::Result<()> {
510561

511562
Ok(())
512563
}
564+
565+
async fn select_statement_count(conn: &mut MySqlConnection) -> Result<i64, sqlx::Error> {
566+
// Fails if performance schema does not exist
567+
sqlx::query_scalar(
568+
r#"
569+
SELECT COUNT(*)
570+
FROM performance_schema.threads AS t
571+
INNER JOIN performance_schema.prepared_statements_instances AS psi
572+
ON psi.OWNER_THREAD_ID = t.THREAD_ID
573+
WHERE t.processlist_id = CONNECTION_ID()
574+
"#,
575+
)
576+
.fetch_one(conn)
577+
.await
578+
}

0 commit comments

Comments
 (0)