Skip to content

Commit 9ab0f5a

Browse files
committed
fix: make begin,commit,rollback cancel-safe in sqlite (launchbadge#2054)
1 parent 26df7ba commit 9ab0f5a

File tree

2 files changed

+124
-82
lines changed

2 files changed

+124
-82
lines changed

Cargo.lock

+12-67
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlx-core/src/sqlite/connection/worker.rs

+112-15
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ enum Command {
5959
tx: flume::Sender<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
6060
},
6161
Begin {
62-
tx: oneshot::Sender<Result<(), Error>>,
62+
tx: rendezvous_oneshot::Sender<Result<(), Error>>,
6363
options: SqliteTransactionOptions,
6464
},
6565
Commit {
66-
tx: oneshot::Sender<Result<(), Error>>,
66+
tx: rendezvous_oneshot::Sender<Result<(), Error>>,
6767
},
6868
Rollback {
69-
tx: Option<oneshot::Sender<Result<(), Error>>>,
69+
tx: Option<rendezvous_oneshot::Sender<Result<(), Error>>>,
7070
},
7171
CreateCollation {
7272
create_collation:
@@ -121,6 +121,11 @@ impl ConnectionWorker {
121121
return;
122122
}
123123

124+
// If COMMIT or ROLLBACK is processed but not acknowledged, there would be another
125+
// ROLLBACK sent when the `Transaction` drops. We need to ignore it otherwise we
126+
// would rollback an already completed transaction.
127+
let mut ignore_next_start_rollback = false;
128+
124129
for cmd in command_rx {
125130
match cmd {
126131
Command::Prepare { query, tx } => {
@@ -169,8 +174,25 @@ impl ConnectionWorker {
169174
let res = conn.handle.exec(stmt).map(|_| {
170175
conn.transaction_depth += 1;
171176
});
172-
173-
tx.send(res).ok();
177+
let res_ok = res.is_ok();
178+
179+
if tx.blocking_send(res).is_err() && res_ok {
180+
// The BEGIN was processed but not acknowledged. This means no
181+
// `Transaction` was created and so there is no way to commit /
182+
// rollback this transaction. We need to roll it back
183+
// immediately otherwise it would remain started forever.
184+
if let Err(e) =
185+
conn.handle.exec(rollback_sql(depth + 1)).map(|_| {
186+
conn.transaction_depth -= 1;
187+
})
188+
{
189+
// The rollback failed. To prevent leaving the connection
190+
// in an inconsistent state we shutdown this worker which
191+
// causes any subsequent operation on the connection to fail.
192+
log::error!("failed to rollback cancelled transaction: {}", e);
193+
break;
194+
}
195+
}
174196
}
175197
Command::Commit { tx } => {
176198
let depth = conn.transaction_depth;
@@ -186,26 +208,40 @@ impl ConnectionWorker {
186208
} else {
187209
Ok(())
188210
};
211+
let res_ok = res.is_ok();
189212

190-
tx.send(res).ok();
213+
if tx.blocking_send(res).is_err() && res_ok {
214+
// The COMMIT was processed but not acknowledged. This means that
215+
// the `Transaction` doesn't know it was committed and will try to
216+
// rollback on drop. We need to ignore that rollback.
217+
ignore_next_start_rollback = true;
218+
}
191219
}
192220
Command::Rollback { tx } => {
221+
if ignore_next_start_rollback && tx.is_none() {
222+
ignore_next_start_rollback = false;
223+
continue;
224+
}
225+
193226
let depth = conn.transaction_depth;
194227
let res = if depth > 0 {
195-
let stmt = if depth == 1 {
196-
ROLLBACK_ANSI_TRANSACTION.to_string()
197-
} else {
198-
rollback_savepoint_sql(depth)
199-
};
228+
let stmt = rollback_sql(depth);
200229
conn.handle.exec(stmt).map(|_| {
201230
conn.transaction_depth -= 1;
202231
})
203232
} else {
204233
Ok(())
205234
};
235+
let res_ok = res.is_ok();
206236

207237
if let Some(tx) = tx {
208-
tx.send(res).ok();
238+
if tx.blocking_send(res).is_err() && res_ok {
239+
// The ROLLBACK was processed but not acknowledged. This means
240+
// that the `Transaction` doesn't know it was rolled back and
241+
// will try to rollback again on drop. We need to ignore that
242+
// rollback.
243+
ignore_next_start_rollback = true;
244+
}
209245
}
210246
}
211247
Command::CreateCollation { create_collation } => {
@@ -279,16 +315,17 @@ impl ConnectionWorker {
279315
}
280316

281317
pub(crate) async fn begin(&mut self, options: SqliteTransactionOptions) -> Result<(), Error> {
282-
self.oneshot_cmd(|tx| Command::Begin { tx, options })
318+
self.oneshot_cmd_with_ack(|tx| Command::Begin { tx, options })
283319
.await?
284320
}
285321

286322
pub(crate) async fn commit(&mut self) -> Result<(), Error> {
287-
self.oneshot_cmd(|tx| Command::Commit { tx }).await?
323+
self.oneshot_cmd_with_ack(|tx| Command::Commit { tx })
324+
.await?
288325
}
289326

290327
pub(crate) async fn rollback(&mut self) -> Result<(), Error> {
291-
self.oneshot_cmd(|tx| Command::Rollback { tx: Some(tx) })
328+
self.oneshot_cmd_with_ack(|tx| Command::Rollback { tx: Some(tx) })
292329
.await?
293330
}
294331

@@ -316,6 +353,20 @@ impl ConnectionWorker {
316353
rx.await.map_err(|_| Error::WorkerCrashed)
317354
}
318355

356+
async fn oneshot_cmd_with_ack<F, T>(&mut self, command: F) -> Result<T, Error>
357+
where
358+
F: FnOnce(rendezvous_oneshot::Sender<T>) -> Command,
359+
{
360+
let (tx, rx) = rendezvous_oneshot::channel();
361+
362+
self.command_tx
363+
.send_async(command(tx))
364+
.await
365+
.map_err(|_| Error::WorkerCrashed)?;
366+
367+
rx.recv().await.map_err(|_| Error::WorkerCrashed)
368+
}
369+
319370
pub fn create_collation(
320371
&mut self,
321372
name: &str,
@@ -399,3 +450,49 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement<'s
399450
fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
400451
size.store(conn.statements.len(), Ordering::Release);
401452
}
453+
454+
fn rollback_sql(depth: usize) -> String {
455+
if depth == 1 {
456+
ROLLBACK_ANSI_TRANSACTION.to_string()
457+
} else {
458+
rollback_savepoint_sql(depth)
459+
}
460+
}
461+
462+
// A oneshot channel where send completes only after the receiver receives the value.
463+
mod rendezvous_oneshot {
464+
use super::oneshot::{self, Canceled};
465+
466+
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
467+
let (inner_tx, inner_rx) = oneshot::channel();
468+
(Sender { inner: inner_tx }, Receiver { inner: inner_rx })
469+
}
470+
471+
pub struct Sender<T> {
472+
inner: oneshot::Sender<(T, oneshot::Sender<()>)>,
473+
}
474+
475+
impl<T> Sender<T> {
476+
pub async fn send(self, value: T) -> Result<(), Canceled> {
477+
let (ack_tx, ack_rx) = oneshot::channel();
478+
self.inner.send((value, ack_tx)).map_err(|_| Canceled)?;
479+
ack_rx.await
480+
}
481+
482+
pub fn blocking_send(self, value: T) -> Result<(), Canceled> {
483+
futures_executor::block_on(self.send(value))
484+
}
485+
}
486+
487+
pub struct Receiver<T> {
488+
inner: oneshot::Receiver<(T, oneshot::Sender<()>)>,
489+
}
490+
491+
impl<T> Receiver<T> {
492+
pub async fn recv(self) -> Result<T, Canceled> {
493+
let (value, ack_tx) = self.inner.await?;
494+
ack_tx.send(()).map_err(|_| Canceled)?;
495+
Ok(value)
496+
}
497+
}
498+
}

0 commit comments

Comments
 (0)