Skip to content

Commit acc3801

Browse files
committed
fix: make begin,commit,rollback cancel-safe in sqlite (launchbadge#2054)
1 parent 26f60d9 commit acc3801

File tree

1 file changed

+155
-10
lines changed

1 file changed

+155
-10
lines changed

sqlx-core/src/sqlite/connection/worker.rs

+155-10
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ pub(crate) struct ConnectionWorker {
3232
pub(crate) handle_raw: ConnectionHandleRaw,
3333
/// Mutex for locking access to the database.
3434
pub(crate) shared: Arc<WorkerSharedState>,
35+
36+
// Mirror of `shared.conn.transaction_depth` to help provide cancel-safety:
37+
//
38+
// - If `transaction_depth == shared.conn.transaction_depth` then no cancellation occurred
39+
// - If `transaction_depth == shared.conn.transaction_depth - 1` then a `begin()` was cancelled
40+
// - If `transaction_depth == shared.conn.transaction_depth + 1` then a `commit()` or
41+
// `rollback()` was cancelled
42+
// - No other cases are possible (would indicate a logic bug)
43+
transaction_depth: usize,
3544
}
3645

3746
pub(crate) struct WorkerSharedState {
@@ -52,15 +61,19 @@ enum Command {
5261
query: Box<str>,
5362
arguments: Option<SqliteArguments<'static>>,
5463
persistent: bool,
64+
transaction_depth: usize,
5565
tx: flume::Sender<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
5666
},
5767
Begin {
68+
transaction_depth: usize,
5869
tx: oneshot::Sender<Result<(), Error>>,
5970
},
6071
Commit {
72+
transaction_depth: usize,
6173
tx: oneshot::Sender<Result<(), Error>>,
6274
},
6375
Rollback {
76+
transaction_depth: usize,
6477
tx: Option<oneshot::Sender<Result<(), Error>>>,
6578
},
6679
CreateCollation {
@@ -110,6 +123,7 @@ impl ConnectionWorker {
110123
command_tx,
111124
handle_raw: conn.handle.to_raw(),
112125
shared: Arc::clone(&shared),
126+
transaction_depth: 0,
113127
}))
114128
.is_err()
115129
{
@@ -135,8 +149,15 @@ impl ConnectionWorker {
135149
query,
136150
arguments,
137151
persistent,
152+
transaction_depth,
138153
tx,
139154
} => {
155+
if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth)
156+
{
157+
tx.send(Err(error)).ok();
158+
continue;
159+
}
160+
140161
let iter = match execute::iter(&mut conn, &query, arguments, persistent)
141162
{
142163
Ok(iter) => iter,
@@ -154,7 +175,16 @@ impl ConnectionWorker {
154175

155176
update_cached_statements_size(&conn, &shared.cached_statements_size);
156177
}
157-
Command::Begin { tx } => {
178+
Command::Begin {
179+
transaction_depth,
180+
tx,
181+
} => {
182+
if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth)
183+
{
184+
tx.send(Err(error)).ok();
185+
continue;
186+
}
187+
158188
let depth = conn.transaction_depth;
159189
let res =
160190
conn.handle
@@ -165,9 +195,17 @@ impl ConnectionWorker {
165195

166196
tx.send(res).ok();
167197
}
168-
Command::Commit { tx } => {
169-
let depth = conn.transaction_depth;
198+
Command::Commit {
199+
transaction_depth,
200+
tx,
201+
} => {
202+
if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth)
203+
{
204+
tx.send(Err(error)).ok();
205+
continue;
206+
}
170207

208+
let depth = conn.transaction_depth;
171209
let res = if depth > 0 {
172210
conn.handle
173211
.exec(commit_ansi_transaction_sql(depth))
@@ -180,9 +218,26 @@ impl ConnectionWorker {
180218

181219
tx.send(res).ok();
182220
}
183-
Command::Rollback { tx } => {
184-
let depth = conn.transaction_depth;
221+
Command::Rollback {
222+
transaction_depth,
223+
tx,
224+
} => {
225+
match handle_cancelled_begin_or_commit_or_rollback(
226+
&mut conn,
227+
transaction_depth,
228+
) {
229+
Ok(true) => (),
230+
Ok(false) => continue,
231+
Err(error) => {
232+
if let Some(tx) = tx {
233+
tx.send(Err(error)).ok();
234+
}
185235

236+
continue;
237+
}
238+
}
239+
240+
let depth = conn.transaction_depth;
186241
let res = if depth > 0 {
187242
conn.handle
188243
.exec(rollback_ansi_transaction_sql(depth))
@@ -259,6 +314,7 @@ impl ConnectionWorker {
259314
query: query.into(),
260315
arguments: args.map(SqliteArguments::into_static),
261316
persistent,
317+
transaction_depth: self.transaction_depth,
262318
tx,
263319
})
264320
.await
@@ -268,21 +324,55 @@ impl ConnectionWorker {
268324
}
269325

270326
pub(crate) async fn begin(&mut self) -> Result<(), Error> {
271-
self.oneshot_cmd(|tx| Command::Begin { tx }).await?
327+
let transaction_depth = self.transaction_depth;
328+
329+
self.oneshot_cmd(|tx| Command::Begin {
330+
transaction_depth,
331+
tx,
332+
})
333+
.await??;
334+
335+
self.transaction_depth += 1;
336+
337+
Ok(())
272338
}
273339

274340
pub(crate) async fn commit(&mut self) -> Result<(), Error> {
275-
self.oneshot_cmd(|tx| Command::Commit { tx }).await?
341+
let transaction_depth = self.transaction_depth;
342+
343+
self.oneshot_cmd(|tx| Command::Commit {
344+
transaction_depth,
345+
tx,
346+
})
347+
.await??;
348+
349+
self.transaction_depth -= 1;
350+
351+
Ok(())
276352
}
277353

278354
pub(crate) async fn rollback(&mut self) -> Result<(), Error> {
279-
self.oneshot_cmd(|tx| Command::Rollback { tx: Some(tx) })
280-
.await?
355+
let transaction_depth = self.transaction_depth;
356+
357+
self.oneshot_cmd(|tx| Command::Rollback {
358+
transaction_depth,
359+
tx: Some(tx),
360+
})
361+
.await??;
362+
363+
self.transaction_depth -= 1;
364+
365+
Ok(())
281366
}
282367

283368
pub(crate) fn start_rollback(&mut self) -> Result<(), Error> {
369+
self.transaction_depth -= 1;
370+
284371
self.command_tx
285-
.send(Command::Rollback { tx: None })
372+
.send(Command::Rollback {
373+
transaction_depth: self.transaction_depth,
374+
tx: None,
375+
})
286376
.map_err(|_| Error::WorkerCrashed)
287377
}
288378

@@ -387,3 +477,58 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement<'s
387477
fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
388478
size.store(conn.statements.len(), Ordering::Release);
389479
}
480+
481+
// If a `begin()` is cancelled before completion it might happen that the `Begin` command is still
482+
// sent to the worker thread but no `Transaction` is created and so there is no way to commit it or
483+
// roll it back. This function detects such case and handles it by automatically rolling the
484+
// transaction back.
485+
//
486+
// Use only when handling an `Execute`, `Begin` or `Commit` command.
487+
fn handle_cancelled_begin(
488+
conn: &mut ConnectionState,
489+
expected_transaction_depth: usize,
490+
) -> Result<(), Error> {
491+
if expected_transaction_depth != conn.transaction_depth {
492+
if expected_transaction_depth == conn.transaction_depth - 1 {
493+
let depth = conn.transaction_depth;
494+
conn.handle.exec(rollback_ansi_transaction_sql(depth))?;
495+
conn.transaction_depth -= 1;
496+
} else {
497+
// This would indicate cancelled `commit` or `rollback`, but that can only happen when
498+
// handling a `Rollback` command because `commit()` / `rollback()` take the
499+
// transaction by value and so when they are cancelled the transaction is immediately
500+
// dropped which sends a `Rollback`.
501+
unreachable!()
502+
}
503+
}
504+
505+
Ok(())
506+
}
507+
508+
// Same as `handle_cancelled_begin` but additionally handles cancelled `commit()` and `rollback()`
509+
// as well. If `commit()` / `rollback()` is cancelled, it might happen that the corresponding
510+
// `Commit` / `Rollback` command is still sent to the worker thread but the transaction's `open`
511+
// flag is not set to `false` which causes another `Rollback` to be sent when the transaction
512+
// is dropped. This function detects that case and indicates to ignore the superfluous `Rollback`.
513+
//
514+
// Use only when handling a `Rollback` command.
515+
fn handle_cancelled_begin_or_commit_or_rollback(
516+
conn: &mut ConnectionState,
517+
expected_transaction_depth: usize,
518+
) -> Result<bool, Error> {
519+
if expected_transaction_depth != conn.transaction_depth {
520+
if expected_transaction_depth == conn.transaction_depth - 1 {
521+
let depth = conn.transaction_depth;
522+
conn.handle.exec(rollback_ansi_transaction_sql(depth))?;
523+
conn.transaction_depth -= 1;
524+
525+
Ok(true)
526+
} else if expected_transaction_depth == conn.transaction_depth + 1 {
527+
Ok(false)
528+
} else {
529+
unreachable!()
530+
}
531+
} else {
532+
Ok(true)
533+
}
534+
}

0 commit comments

Comments
 (0)