Skip to content

Commit 4155986

Browse files
committed
fix: make begin,commit,rollback cancel-safe in sqlite (#2054)
1 parent 26f60d9 commit 4155986

File tree

1 file changed

+157
-10
lines changed

1 file changed

+157
-10
lines changed

sqlx-core/src/sqlite/connection/worker.rs

+157-10
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ pub(crate) struct ConnectionWorker {
3232
pub(crate) handle_raw: ConnectionHandleRaw,
3333
/// Mutex for locking access to the database.
3434
pub(crate) shared: Arc<WorkerSharedState>,
35+
36+
// Tracks `shared.conn.transaction_depth` to help provide cancel-safety. Updated only when
37+
// `begin()` / `commit()` / `rollback()` successfully complete.
38+
//
39+
// - If `transaction_depth == shared.conn.transaction_depth` then no cancellation occurred
40+
// - If `transaction_depth == shared.conn.transaction_depth - 1` then a `begin()` was cancelled
41+
// - If `transaction_depth == shared.conn.transaction_depth + 1` then a `commit()` or
42+
// `rollback()` was cancelled
43+
// - No other cases are possible (would indicate a logic bug)
44+
transaction_depth: usize,
3545
}
3646

3747
pub(crate) struct WorkerSharedState {
@@ -52,15 +62,19 @@ enum Command {
5262
query: Box<str>,
5363
arguments: Option<SqliteArguments<'static>>,
5464
persistent: bool,
65+
transaction_depth: usize,
5566
tx: flume::Sender<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
5667
},
5768
Begin {
69+
transaction_depth: usize,
5870
tx: oneshot::Sender<Result<(), Error>>,
5971
},
6072
Commit {
73+
transaction_depth: usize,
6174
tx: oneshot::Sender<Result<(), Error>>,
6275
},
6376
Rollback {
77+
transaction_depth: usize,
6478
tx: Option<oneshot::Sender<Result<(), Error>>>,
6579
},
6680
CreateCollation {
@@ -110,6 +124,7 @@ impl ConnectionWorker {
110124
command_tx,
111125
handle_raw: conn.handle.to_raw(),
112126
shared: Arc::clone(&shared),
127+
transaction_depth: 0,
113128
}))
114129
.is_err()
115130
{
@@ -135,8 +150,15 @@ impl ConnectionWorker {
135150
query,
136151
arguments,
137152
persistent,
153+
transaction_depth,
138154
tx,
139155
} => {
156+
if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth)
157+
{
158+
tx.send(Err(error)).ok();
159+
continue;
160+
}
161+
140162
let iter = match execute::iter(&mut conn, &query, arguments, persistent)
141163
{
142164
Ok(iter) => iter,
@@ -154,7 +176,16 @@ impl ConnectionWorker {
154176

155177
update_cached_statements_size(&conn, &shared.cached_statements_size);
156178
}
157-
Command::Begin { tx } => {
179+
Command::Begin {
180+
transaction_depth,
181+
tx,
182+
} => {
183+
if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth)
184+
{
185+
tx.send(Err(error)).ok();
186+
continue;
187+
}
188+
158189
let depth = conn.transaction_depth;
159190
let res =
160191
conn.handle
@@ -165,9 +196,17 @@ impl ConnectionWorker {
165196

166197
tx.send(res).ok();
167198
}
168-
Command::Commit { tx } => {
169-
let depth = conn.transaction_depth;
199+
Command::Commit {
200+
transaction_depth,
201+
tx,
202+
} => {
203+
if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth)
204+
{
205+
tx.send(Err(error)).ok();
206+
continue;
207+
}
170208

209+
let depth = conn.transaction_depth;
171210
let res = if depth > 0 {
172211
conn.handle
173212
.exec(commit_ansi_transaction_sql(depth))
@@ -180,9 +219,26 @@ impl ConnectionWorker {
180219

181220
tx.send(res).ok();
182221
}
183-
Command::Rollback { tx } => {
184-
let depth = conn.transaction_depth;
222+
Command::Rollback {
223+
transaction_depth,
224+
tx,
225+
} => {
226+
match handle_cancelled_begin_or_commit_or_rollback(
227+
&mut conn,
228+
transaction_depth,
229+
) {
230+
Ok(true) => (),
231+
Ok(false) => continue,
232+
Err(error) => {
233+
if let Some(tx) = tx {
234+
tx.send(Err(error)).ok();
235+
}
185236

237+
continue;
238+
}
239+
}
240+
241+
let depth = conn.transaction_depth;
186242
let res = if depth > 0 {
187243
conn.handle
188244
.exec(rollback_ansi_transaction_sql(depth))
@@ -259,6 +315,7 @@ impl ConnectionWorker {
259315
query: query.into(),
260316
arguments: args.map(SqliteArguments::into_static),
261317
persistent,
318+
transaction_depth: self.transaction_depth,
262319
tx,
263320
})
264321
.await
@@ -268,21 +325,56 @@ impl ConnectionWorker {
268325
}
269326

270327
pub(crate) async fn begin(&mut self) -> Result<(), Error> {
271-
self.oneshot_cmd(|tx| Command::Begin { tx }).await?
328+
let transaction_depth = self.transaction_depth;
329+
330+
self.oneshot_cmd(|tx| Command::Begin {
331+
transaction_depth,
332+
tx,
333+
})
334+
.await??;
335+
336+
self.transaction_depth += 1;
337+
338+
Ok(())
272339
}
273340

274341
pub(crate) async fn commit(&mut self) -> Result<(), Error> {
275-
self.oneshot_cmd(|tx| Command::Commit { tx }).await?
342+
let transaction_depth = self.transaction_depth;
343+
344+
self.oneshot_cmd(|tx| Command::Commit {
345+
transaction_depth,
346+
tx,
347+
})
348+
.await??;
349+
350+
self.transaction_depth -= 1;
351+
352+
Ok(())
276353
}
277354

278355
pub(crate) async fn rollback(&mut self) -> Result<(), Error> {
279-
self.oneshot_cmd(|tx| Command::Rollback { tx: Some(tx) })
280-
.await?
356+
let transaction_depth = self.transaction_depth;
357+
358+
self.oneshot_cmd(|tx| Command::Rollback {
359+
transaction_depth,
360+
tx: Some(tx),
361+
})
362+
.await??;
363+
364+
self.transaction_depth -= 1;
365+
366+
Ok(())
281367
}
282368

283369
pub(crate) fn start_rollback(&mut self) -> Result<(), Error> {
370+
let transaction_depth = self.transaction_depth;
371+
self.transaction_depth -= 1;
372+
284373
self.command_tx
285-
.send(Command::Rollback { tx: None })
374+
.send(Command::Rollback {
375+
transaction_depth,
376+
tx: None,
377+
})
286378
.map_err(|_| Error::WorkerCrashed)
287379
}
288380

@@ -387,3 +479,58 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement<'s
387479
fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
388480
size.store(conn.statements.len(), Ordering::Release);
389481
}
482+
483+
// If a `begin()` is cancelled before completion it might happen that the `Begin` command is still
484+
// sent to the worker thread but no `Transaction` is created and so there is no way to commit it or
485+
// roll it back. This function detects such case and handles it by automatically rolling the
486+
// transaction back.
487+
//
488+
// Use only when handling an `Execute`, `Begin` or `Commit` command.
489+
fn handle_cancelled_begin(
490+
conn: &mut ConnectionState,
491+
expected_transaction_depth: usize,
492+
) -> Result<(), Error> {
493+
if expected_transaction_depth != conn.transaction_depth {
494+
if expected_transaction_depth == conn.transaction_depth - 1 {
495+
let depth = conn.transaction_depth;
496+
conn.handle.exec(rollback_ansi_transaction_sql(depth))?;
497+
conn.transaction_depth -= 1;
498+
} else {
499+
// This would indicate cancelled `commit` or `rollback`, but that can only happen when
500+
// handling a `Rollback` command because `commit()` / `rollback()` take the
501+
// transaction by value and so when they are cancelled the transaction is immediately
502+
// dropped which sends a `Rollback`.
503+
unreachable!()
504+
}
505+
}
506+
507+
Ok(())
508+
}
509+
510+
// Same as `handle_cancelled_begin` but additionally handles cancelled `commit()` and `rollback()`
511+
// as well. If `commit()` / `rollback()` is cancelled, it might happen that the corresponding
512+
// `Commit` / `Rollback` command is still sent to the worker thread but the transaction's `open`
513+
// flag is not set to `false` which causes another `Rollback` to be sent when the transaction
514+
// is dropped. This function detects that case and indicates to ignore the superfluous `Rollback`.
515+
//
516+
// Use only when handling a `Rollback` command.
517+
fn handle_cancelled_begin_or_commit_or_rollback(
518+
conn: &mut ConnectionState,
519+
expected_transaction_depth: usize,
520+
) -> Result<bool, Error> {
521+
if expected_transaction_depth != conn.transaction_depth {
522+
if expected_transaction_depth + 1 == conn.transaction_depth {
523+
let depth = conn.transaction_depth;
524+
conn.handle.exec(rollback_ansi_transaction_sql(depth))?;
525+
conn.transaction_depth -= 1;
526+
527+
Ok(true)
528+
} else if expected_transaction_depth == conn.transaction_depth + 1 {
529+
Ok(false)
530+
} else {
531+
unreachable!()
532+
}
533+
} else {
534+
Ok(true)
535+
}
536+
}

0 commit comments

Comments
 (0)