Skip to content

Commit 8334760

Browse files
committed
fix: make begin,commit,rollback cancel-safe in sqlite (#2054)
1 parent 2d65c5d commit 8334760

File tree

1 file changed

+107
-10
lines changed

1 file changed

+107
-10
lines changed

sqlx-core/src/sqlite/connection/worker.rs

+107-10
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ enum Command {
5555
tx: flume::Sender<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
5656
},
5757
Begin {
58-
tx: oneshot::Sender<Result<(), Error>>,
58+
tx: rendezvous_oneshot::Sender<Result<(), Error>>,
5959
},
6060
Commit {
61-
tx: oneshot::Sender<Result<(), Error>>,
61+
tx: rendezvous_oneshot::Sender<Result<(), Error>>,
6262
},
6363
Rollback {
64-
tx: Option<oneshot::Sender<Result<(), Error>>>,
64+
tx: Option<rendezvous_oneshot::Sender<Result<(), Error>>>,
6565
},
6666
CreateCollation {
6767
create_collation:
@@ -116,6 +116,11 @@ impl ConnectionWorker {
116116
return;
117117
}
118118

119+
// If COMMIT or ROLLBACK is processed but not acknowledged, there would be another
120+
// ROLLBACK sent when the `Transaction` drops. We need to ignore it otherwise we
121+
// would rollback an already completed transaction.
122+
let mut ignore_next_start_rollback = false;
123+
119124
for cmd in command_rx {
120125
match cmd {
121126
Command::Prepare { query, tx } => {
@@ -162,8 +167,27 @@ impl ConnectionWorker {
162167
.map(|_| {
163168
conn.transaction_depth += 1;
164169
});
165-
166-
tx.send(res).ok();
170+
let res_ok = res.is_ok();
171+
172+
if tx.blocking_send(res).is_err() && res_ok {
173+
// The BEGIN was processed but not acknowledged. This means no
174+
// `Transaction` was created and so there is no way to commit /
175+
// rollback this transaction. We need to roll it back
176+
// immediately otherwise it would remain started forever.
177+
if let Err(e) = conn
178+
.handle
179+
.exec(rollback_ansi_transaction_sql(depth + 1))
180+
.map(|_| {
181+
conn.transaction_depth -= 1;
182+
})
183+
{
184+
// The rollback failed. To prevent leaving the connection
185+
// in an inconsistent state we shutdown this worker which
186+
// causes any subsequent operation on the connection to fail.
187+
log::error!("failed to rollback cancelled transaction: {}", e);
188+
break;
189+
}
190+
}
167191
}
168192
Command::Commit { tx } => {
169193
let depth = conn.transaction_depth;
@@ -177,10 +201,21 @@ impl ConnectionWorker {
177201
} else {
178202
Ok(())
179203
};
204+
let res_ok = res.is_ok();
180205

181-
tx.send(res).ok();
206+
if tx.blocking_send(res).is_err() && res_ok {
207+
// The COMMIT was processed but not acknowledged. This means that
208+
// the `Transaction` doesn't know it was committed and will try to
209+
// rollback on drop. We need to ignore that rollback.
210+
ignore_next_start_rollback = true;
211+
}
182212
}
183213
Command::Rollback { tx } => {
214+
if ignore_next_start_rollback && tx.is_none() {
215+
ignore_next_start_rollback = false;
216+
continue;
217+
}
218+
184219
let depth = conn.transaction_depth;
185220

186221
let res = if depth > 0 {
@@ -193,8 +228,16 @@ impl ConnectionWorker {
193228
Ok(())
194229
};
195230

231+
let res_ok = res.is_ok();
232+
196233
if let Some(tx) = tx {
197-
tx.send(res).ok();
234+
if tx.blocking_send(res).is_err() && res_ok {
235+
// The ROLLBACK was processed but not acknowledged. This means
236+
// that the `Transaction` doesn't know it was rolled back and
237+
// will try to rollback again on drop. We need to ignore that
238+
// rollback.
239+
ignore_next_start_rollback = true;
240+
}
198241
}
199242
}
200243
Command::CreateCollation { create_collation } => {
@@ -268,15 +311,17 @@ impl ConnectionWorker {
268311
}
269312

270313
pub(crate) async fn begin(&mut self) -> Result<(), Error> {
271-
self.oneshot_cmd(|tx| Command::Begin { tx }).await?
314+
self.oneshot_cmd_with_ack(|tx| Command::Begin { tx })
315+
.await?
272316
}
273317

274318
pub(crate) async fn commit(&mut self) -> Result<(), Error> {
275-
self.oneshot_cmd(|tx| Command::Commit { tx }).await?
319+
self.oneshot_cmd_with_ack(|tx| Command::Commit { tx })
320+
.await?
276321
}
277322

278323
pub(crate) async fn rollback(&mut self) -> Result<(), Error> {
279-
self.oneshot_cmd(|tx| Command::Rollback { tx: Some(tx) })
324+
self.oneshot_cmd_with_ack(|tx| Command::Rollback { tx: Some(tx) })
280325
.await?
281326
}
282327

@@ -304,6 +349,20 @@ impl ConnectionWorker {
304349
rx.await.map_err(|_| Error::WorkerCrashed)
305350
}
306351

352+
async fn oneshot_cmd_with_ack<F, T>(&mut self, command: F) -> Result<T, Error>
353+
where
354+
F: FnOnce(rendezvous_oneshot::Sender<T>) -> Command,
355+
{
356+
let (tx, rx) = rendezvous_oneshot::channel();
357+
358+
self.command_tx
359+
.send_async(command(tx))
360+
.await
361+
.map_err(|_| Error::WorkerCrashed)?;
362+
363+
rx.recv().await.map_err(|_| Error::WorkerCrashed)
364+
}
365+
307366
pub fn create_collation(
308367
&mut self,
309368
name: &str,
@@ -387,3 +446,41 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement<'s
387446
fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
388447
size.store(conn.statements.len(), Ordering::Release);
389448
}
449+
450+
// A oneshot channel where send completes only after the receiver receives the value.
451+
mod rendezvous_oneshot {
452+
use super::oneshot::{self, Canceled};
453+
454+
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
455+
let (inner_tx, inner_rx) = oneshot::channel();
456+
(Sender { inner: inner_tx }, Receiver { inner: inner_rx })
457+
}
458+
459+
pub struct Sender<T> {
460+
inner: oneshot::Sender<(T, oneshot::Sender<()>)>,
461+
}
462+
463+
impl<T> Sender<T> {
464+
pub async fn send(self, value: T) -> Result<(), Canceled> {
465+
let (ack_tx, ack_rx) = oneshot::channel();
466+
self.inner.send((value, ack_tx)).map_err(|_| Canceled)?;
467+
ack_rx.await
468+
}
469+
470+
pub fn blocking_send(self, value: T) -> Result<(), Canceled> {
471+
futures_executor::block_on(self.send(value))
472+
}
473+
}
474+
475+
pub struct Receiver<T> {
476+
inner: oneshot::Receiver<(T, oneshot::Sender<()>)>,
477+
}
478+
479+
impl<T> Receiver<T> {
480+
pub async fn recv(self) -> Result<T, Canceled> {
481+
let (value, ack_tx) = self.inner.await?;
482+
ack_tx.send(()).map_err(|_| Canceled)?;
483+
Ok(value)
484+
}
485+
}
486+
}

0 commit comments

Comments
 (0)