Skip to content

Commit c2e9728

Browse files
committed
pool: fix panic when using callbacks
add regression test added missing typedef `MssqlPoolOptions`
1 parent 681cfb7 commit c2e9728

File tree

3 files changed

+156
-5
lines changed

3 files changed

+156
-5
lines changed

sqlx-core/src/mssql/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ pub use value::{MssqlValue, MssqlValueRef};
3434
/// An alias for [`Pool`][crate::pool::Pool], specialized for MSSQL.
3535
pub type MssqlPool = crate::pool::Pool<Mssql>;
3636

37+
/// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for MSSQL.
38+
pub type MssqlPoolOptions = crate::pool::PoolOptions<Mssql>;
39+
3740
/// An alias for [`Executor<'_, Database = Mssql>`][Executor].
3841
pub trait MssqlExecutor<'c>: Executor<'c, Database = Mssql> {}
3942
impl<'c, T: Executor<'c, Database = Mssql>> MssqlExecutor<'c> for T {}

tests/any/pool.rs

+18-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use sqlx::any::AnyPoolOptions;
1+
use sqlx::any::{AnyConnectOptions, AnyKind, AnyPoolOptions};
22
use sqlx::Executor;
33
use std::sync::atomic::AtomicI32;
44
use std::sync::{
@@ -69,15 +69,29 @@ async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> {
6969

7070
#[sqlx_macros::test]
7171
async fn test_pool_callbacks() -> anyhow::Result<()> {
72-
sqlx_test::setup_if_needed();
73-
7472
#[derive(sqlx::FromRow, Debug, PartialEq, Eq)]
7573
struct ConnStats {
7674
id: i32,
7775
before_acquire_calls: i32,
7876
after_release_calls: i32,
7977
}
8078

79+
sqlx_test::setup_if_needed();
80+
81+
let conn_options: AnyConnectOptions = std::env::var("DATABASE_URL")?.parse()?;
82+
83+
#[cfg(feature = "mssql")]
84+
if conn_options.kind() == AnyKind::Mssql {
85+
// MSSQL doesn't support `CREATE TEMPORARY TABLE`,
86+
// because why follow conventions when you can subvert them?
87+
// Instead, you prepend `#` to the table name for a session-local temporary table
88+
// which you also have to do when referencing it.
89+
90+
// Since that affects basically every query here,
91+
// it's just easier to have a separate MSSQL-specific test case.
92+
return Ok(());
93+
}
94+
8195
let current_id = AtomicI32::new(0);
8296

8397
let pool = AnyPoolOptions::new()
@@ -158,7 +172,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> {
158172
})
159173
})
160174
// Don't establish a connection yet.
161-
.connect_lazy(&dotenv::var("DATABASE_URL")?)?;
175+
.connect_lazy_with(conn_options)?;
162176

163177
// Expected pattern of (id, before_acquire_calls, after_release_calls)
164178
let pattern = [

tests/mssql/mssql.rs

+135-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use futures::TryStreamExt;
2-
use sqlx::mssql::Mssql;
2+
use sqlx::mssql::{Mssql, MssqlPoolOptions};
33
use sqlx::{Column, Connection, Executor, MssqlConnection, Row, Statement, TypeInfo};
44
use sqlx_core::mssql::MssqlRow;
55
use sqlx_test::new;
6+
use std::sync::atomic::{AtomicI32, Ordering};
7+
use std::time::Duration;
68

79
#[sqlx_macros::test]
810
async fn it_connects() -> anyhow::Result<()> {
@@ -325,3 +327,135 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> {
325327

326328
Ok(())
327329
}
330+
331+
// MSSQL-specific copy of the test case in `tests/any/pool.rs`
332+
// because MSSQL has its own bespoke syntax for temporary tables.
333+
#[sqlx_macros::test]
334+
async fn test_pool_callbacks() -> anyhow::Result<()> {
335+
#[derive(sqlx::FromRow, Debug, PartialEq, Eq)]
336+
struct ConnStats {
337+
id: i32,
338+
before_acquire_calls: i32,
339+
after_release_calls: i32,
340+
}
341+
342+
sqlx_test::setup_if_needed();
343+
344+
let current_id = AtomicI32::new(0);
345+
346+
let pool = MssqlPoolOptions::new()
347+
.max_connections(1)
348+
.acquire_timeout(Duration::from_secs(5))
349+
.after_connect(move |conn, meta| {
350+
assert_eq!(meta.age, Duration::ZERO);
351+
assert_eq!(meta.idle_for, Duration::ZERO);
352+
353+
let id = current_id.fetch_add(1, Ordering::AcqRel);
354+
355+
Box::pin(async move {
356+
let statement = format!(
357+
// language=MSSQL
358+
r#"
359+
CREATE TABLE #conn_stats(
360+
id int primary key,
361+
before_acquire_calls int default 0,
362+
after_release_calls int default 0
363+
);
364+
INSERT INTO #conn_stats(id) VALUES ({});
365+
"#,
366+
// Until we have generalized bind parameters
367+
id
368+
);
369+
370+
conn.execute(&statement[..]).await?;
371+
Ok(())
372+
})
373+
})
374+
.before_acquire(|conn, meta| {
375+
// `age` and `idle_for` should both be nonzero
376+
assert_ne!(meta.age, Duration::ZERO);
377+
assert_ne!(meta.idle_for, Duration::ZERO);
378+
379+
Box::pin(async move {
380+
// MSSQL doesn't support UPDATE ... RETURNING either
381+
sqlx::query(
382+
r#"
383+
UPDATE #conn_stats
384+
SET before_acquire_calls = before_acquire_calls + 1
385+
"#,
386+
)
387+
.execute(&mut *conn)
388+
.await?;
389+
390+
let stats: ConnStats = sqlx::query_as("SELECT * FROM #conn_stats")
391+
.fetch_one(conn)
392+
.await?;
393+
394+
// For even IDs, cap by the number of before_acquire calls.
395+
// Ignore the check for odd IDs.
396+
Ok((stats.id & 1) == 1 || stats.before_acquire_calls < 3)
397+
})
398+
})
399+
.after_release(|conn, meta| {
400+
// `age` should be nonzero but `idle_for` should be zero.
401+
assert_ne!(meta.age, Duration::ZERO);
402+
assert_eq!(meta.idle_for, Duration::ZERO);
403+
404+
Box::pin(async move {
405+
sqlx::query(
406+
r#"
407+
UPDATE #conn_stats
408+
SET after_release_calls = after_release_calls + 1
409+
"#,
410+
)
411+
.execute(&mut *conn)
412+
.await?;
413+
414+
let stats: ConnStats = sqlx::query_as("SELECT * FROM #conn_stats")
415+
.fetch_one(conn)
416+
.await?;
417+
418+
// For odd IDs, cap by the number of before_release calls.
419+
// Ignore the check for even IDs.
420+
Ok((stats.id & 1) == 0 || stats.after_release_calls < 4)
421+
})
422+
})
423+
// Don't establish a connection yet.
424+
.connect_lazy(&std::env::var("DATABASE_URL")?)?;
425+
426+
// Expected pattern of (id, before_acquire_calls, after_release_calls)
427+
let pattern = [
428+
// The connection pool starts empty.
429+
(0, 0, 0),
430+
(0, 1, 1),
431+
(0, 2, 2),
432+
(1, 0, 0),
433+
(1, 1, 1),
434+
(1, 2, 2),
435+
// We should expect one more `acquire` because the ID is odd
436+
(1, 3, 3),
437+
(2, 0, 0),
438+
(2, 1, 1),
439+
(2, 2, 2),
440+
(3, 0, 0),
441+
];
442+
443+
for (id, before_acquire_calls, after_release_calls) in pattern {
444+
let conn_stats: ConnStats = sqlx::query_as("SELECT * FROM #conn_stats")
445+
.fetch_one(&pool)
446+
.await?;
447+
448+
assert_eq!(
449+
conn_stats,
450+
ConnStats {
451+
id,
452+
before_acquire_calls,
453+
after_release_calls
454+
}
455+
);
456+
}
457+
458+
pool.close().await;
459+
460+
Ok(())
461+
}

0 commit comments

Comments
 (0)