@@ -32,6 +32,15 @@ pub(crate) struct ConnectionWorker {
32
32
pub ( crate ) handle_raw : ConnectionHandleRaw ,
33
33
/// Mutex for locking access to the database.
34
34
pub ( crate ) shared : Arc < WorkerSharedState > ,
35
+
36
+ // Mirror of `shared.conn.transaction_depth` to help provide cancel-safety:
37
+ //
38
+ // - If `transaction_depth == shared.conn.transaction_depth` then no cancellation occurred
39
+ // - If `transaction_depth == shared.conn.transaction_depth - 1` then a `begin()` was cancelled
40
+ // - If `transaction_depth == shared.conn.transaction_depth + 1` then a `commit()` or
41
+ // `rollback()` was cancelled
42
+ // - No other cases are possible (would indicate a logic bug)
43
+ transaction_depth : usize ,
35
44
}
36
45
37
46
pub ( crate ) struct WorkerSharedState {
@@ -52,15 +61,19 @@ enum Command {
52
61
query : Box < str > ,
53
62
arguments : Option < SqliteArguments < ' static > > ,
54
63
persistent : bool ,
64
+ transaction_depth : usize ,
55
65
tx : flume:: Sender < Result < Either < SqliteQueryResult , SqliteRow > , Error > > ,
56
66
} ,
57
67
Begin {
68
+ transaction_depth : usize ,
58
69
tx : oneshot:: Sender < Result < ( ) , Error > > ,
59
70
} ,
60
71
Commit {
72
+ transaction_depth : usize ,
61
73
tx : oneshot:: Sender < Result < ( ) , Error > > ,
62
74
} ,
63
75
Rollback {
76
+ transaction_depth : usize ,
64
77
tx : Option < oneshot:: Sender < Result < ( ) , Error > > > ,
65
78
} ,
66
79
CreateCollation {
@@ -110,6 +123,7 @@ impl ConnectionWorker {
110
123
command_tx,
111
124
handle_raw : conn. handle . to_raw ( ) ,
112
125
shared : Arc :: clone ( & shared) ,
126
+ transaction_depth : 0 ,
113
127
} ) )
114
128
. is_err ( )
115
129
{
@@ -135,8 +149,15 @@ impl ConnectionWorker {
135
149
query,
136
150
arguments,
137
151
persistent,
152
+ transaction_depth,
138
153
tx,
139
154
} => {
155
+ if let Err ( error) = handle_cancelled_begin ( & mut conn, transaction_depth)
156
+ {
157
+ tx. send ( Err ( error) ) . ok ( ) ;
158
+ continue ;
159
+ }
160
+
140
161
let iter = match execute:: iter ( & mut conn, & query, arguments, persistent)
141
162
{
142
163
Ok ( iter) => iter,
@@ -154,7 +175,16 @@ impl ConnectionWorker {
154
175
155
176
update_cached_statements_size ( & conn, & shared. cached_statements_size ) ;
156
177
}
157
- Command :: Begin { tx } => {
178
+ Command :: Begin {
179
+ transaction_depth,
180
+ tx,
181
+ } => {
182
+ if let Err ( error) = handle_cancelled_begin ( & mut conn, transaction_depth)
183
+ {
184
+ tx. send ( Err ( error) ) . ok ( ) ;
185
+ continue ;
186
+ }
187
+
158
188
let depth = conn. transaction_depth ;
159
189
let res =
160
190
conn. handle
@@ -165,9 +195,17 @@ impl ConnectionWorker {
165
195
166
196
tx. send ( res) . ok ( ) ;
167
197
}
168
- Command :: Commit { tx } => {
169
- let depth = conn. transaction_depth ;
198
+ Command :: Commit {
199
+ transaction_depth,
200
+ tx,
201
+ } => {
202
+ if let Err ( error) = handle_cancelled_begin ( & mut conn, transaction_depth)
203
+ {
204
+ tx. send ( Err ( error) ) . ok ( ) ;
205
+ continue ;
206
+ }
170
207
208
+ let depth = conn. transaction_depth ;
171
209
let res = if depth > 0 {
172
210
conn. handle
173
211
. exec ( commit_ansi_transaction_sql ( depth) )
@@ -180,9 +218,26 @@ impl ConnectionWorker {
180
218
181
219
tx. send ( res) . ok ( ) ;
182
220
}
183
- Command :: Rollback { tx } => {
184
- let depth = conn. transaction_depth ;
221
+ Command :: Rollback {
222
+ transaction_depth,
223
+ tx,
224
+ } => {
225
+ match handle_cancelled_begin_or_commit_or_rollback (
226
+ & mut conn,
227
+ transaction_depth,
228
+ ) {
229
+ Ok ( true ) => ( ) ,
230
+ Ok ( false ) => continue ,
231
+ Err ( error) => {
232
+ if let Some ( tx) = tx {
233
+ tx. send ( Err ( error) ) . ok ( ) ;
234
+ }
185
235
236
+ continue ;
237
+ }
238
+ }
239
+
240
+ let depth = conn. transaction_depth ;
186
241
let res = if depth > 0 {
187
242
conn. handle
188
243
. exec ( rollback_ansi_transaction_sql ( depth) )
@@ -259,6 +314,7 @@ impl ConnectionWorker {
259
314
query : query. into ( ) ,
260
315
arguments : args. map ( SqliteArguments :: into_static) ,
261
316
persistent,
317
+ transaction_depth : self . transaction_depth ,
262
318
tx,
263
319
} )
264
320
. await
@@ -268,21 +324,55 @@ impl ConnectionWorker {
268
324
}
269
325
270
326
pub ( crate ) async fn begin ( & mut self ) -> Result < ( ) , Error > {
271
- self . oneshot_cmd ( |tx| Command :: Begin { tx } ) . await ?
327
+ let transaction_depth = self . transaction_depth ;
328
+
329
+ self . oneshot_cmd ( |tx| Command :: Begin {
330
+ transaction_depth,
331
+ tx,
332
+ } )
333
+ . await ??;
334
+
335
+ self . transaction_depth += 1 ;
336
+
337
+ Ok ( ( ) )
272
338
}
273
339
274
340
pub ( crate ) async fn commit ( & mut self ) -> Result < ( ) , Error > {
275
- self . oneshot_cmd ( |tx| Command :: Commit { tx } ) . await ?
341
+ let transaction_depth = self . transaction_depth ;
342
+
343
+ self . oneshot_cmd ( |tx| Command :: Commit {
344
+ transaction_depth,
345
+ tx,
346
+ } )
347
+ . await ??;
348
+
349
+ self . transaction_depth -= 1 ;
350
+
351
+ Ok ( ( ) )
276
352
}
277
353
278
354
pub ( crate ) async fn rollback ( & mut self ) -> Result < ( ) , Error > {
279
- self . oneshot_cmd ( |tx| Command :: Rollback { tx : Some ( tx) } )
280
- . await ?
355
+ let transaction_depth = self . transaction_depth ;
356
+
357
+ self . oneshot_cmd ( |tx| Command :: Rollback {
358
+ transaction_depth,
359
+ tx : Some ( tx) ,
360
+ } )
361
+ . await ??;
362
+
363
+ self . transaction_depth -= 1 ;
364
+
365
+ Ok ( ( ) )
281
366
}
282
367
283
368
pub ( crate ) fn start_rollback ( & mut self ) -> Result < ( ) , Error > {
369
+ self . transaction_depth -= 1 ;
370
+
284
371
self . command_tx
285
- . send ( Command :: Rollback { tx : None } )
372
+ . send ( Command :: Rollback {
373
+ transaction_depth : self . transaction_depth ,
374
+ tx : None ,
375
+ } )
286
376
. map_err ( |_| Error :: WorkerCrashed )
287
377
}
288
378
@@ -387,3 +477,58 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement<'s
387
477
fn update_cached_statements_size ( conn : & ConnectionState , size : & AtomicUsize ) {
388
478
size. store ( conn. statements . len ( ) , Ordering :: Release ) ;
389
479
}
480
+
481
+ // If a `begin()` is cancelled before completion it might happen that the `Begin` command is still
482
+ // sent to the worker thread but no `Transaction` is created and so there is no way to commit it or
483
+ // roll it back. This function detects such case and handles it by automatically rolling the
484
+ // transaction back.
485
+ //
486
+ // Use only when handling an `Execute`, `Begin` or `Commit` command.
487
+ fn handle_cancelled_begin (
488
+ conn : & mut ConnectionState ,
489
+ expected_transaction_depth : usize ,
490
+ ) -> Result < ( ) , Error > {
491
+ if expected_transaction_depth != conn. transaction_depth {
492
+ if expected_transaction_depth == conn. transaction_depth - 1 {
493
+ let depth = conn. transaction_depth ;
494
+ conn. handle . exec ( rollback_ansi_transaction_sql ( depth) ) ?;
495
+ conn. transaction_depth -= 1 ;
496
+ } else {
497
+ // This would indicate cancelled `commit` or `rollback`, but that can only happen when
498
+ // handling a `Rollback` command because `commit()` / `rollback()` take the
499
+ // transaction by value and so when they are cancelled the transaction is immediately
500
+ // dropped which sends a `Rollback`.
501
+ unreachable ! ( )
502
+ }
503
+ }
504
+
505
+ Ok ( ( ) )
506
+ }
507
+
508
+ // Same as `handle_cancelled_begin` but additionally handles cancelled `commit()` and `rollback()`
509
+ // as well. If `commit()` / `rollback()` is cancelled, it might happen that the corresponding
510
+ // `Commit` / `Rollback` command is still sent to the worker thread but the transaction's `open`
511
+ // flag is not set to `false` which causes another `Rollback` to be sent when the transaction
512
+ // is dropped. This function detects that case and indicates to ignore the superfluous `Rollback`.
513
+ //
514
+ // Use only when handling a `Rollback` command.
515
+ fn handle_cancelled_begin_or_commit_or_rollback (
516
+ conn : & mut ConnectionState ,
517
+ expected_transaction_depth : usize ,
518
+ ) -> Result < bool , Error > {
519
+ if expected_transaction_depth != conn. transaction_depth {
520
+ if expected_transaction_depth == conn. transaction_depth - 1 {
521
+ let depth = conn. transaction_depth ;
522
+ conn. handle . exec ( rollback_ansi_transaction_sql ( depth) ) ?;
523
+ conn. transaction_depth -= 1 ;
524
+
525
+ Ok ( true )
526
+ } else if expected_transaction_depth == conn. transaction_depth + 1 {
527
+ Ok ( false )
528
+ } else {
529
+ unreachable ! ( )
530
+ }
531
+ } else {
532
+ Ok ( true )
533
+ }
534
+ }
0 commit comments