@@ -32,6 +32,16 @@ pub(crate) struct ConnectionWorker {
32
32
pub ( crate ) handle_raw : ConnectionHandleRaw ,
33
33
/// Mutex for locking access to the database.
34
34
pub ( crate ) shared : Arc < WorkerSharedState > ,
35
+
36
+ // Tracks `shared.conn.transaction_depth` to help provide cancel-safety. Updated only when
37
+ // `begin()` / `commit()` / `rollback()` successfully complete.
38
+ //
39
+ // - If `transaction_depth == shared.conn.transaction_depth` then no cancellation occurred
40
+ // - If `transaction_depth == shared.conn.transaction_depth - 1` then a `begin()` was cancelled
41
+ // - If `transaction_depth == shared.conn.transaction_depth + 1` then a `commit()` or
42
+ // `rollback()` was cancelled
43
+ // - No other cases are possible (would indicate a logic bug)
44
+ transaction_depth : usize ,
35
45
}
36
46
37
47
pub ( crate ) struct WorkerSharedState {
@@ -52,15 +62,19 @@ enum Command {
52
62
query : Box < str > ,
53
63
arguments : Option < SqliteArguments < ' static > > ,
54
64
persistent : bool ,
65
+ transaction_depth : usize ,
55
66
tx : flume:: Sender < Result < Either < SqliteQueryResult , SqliteRow > , Error > > ,
56
67
} ,
57
68
Begin {
69
+ transaction_depth : usize ,
58
70
tx : oneshot:: Sender < Result < ( ) , Error > > ,
59
71
} ,
60
72
Commit {
73
+ transaction_depth : usize ,
61
74
tx : oneshot:: Sender < Result < ( ) , Error > > ,
62
75
} ,
63
76
Rollback {
77
+ transaction_depth : usize ,
64
78
tx : Option < oneshot:: Sender < Result < ( ) , Error > > > ,
65
79
} ,
66
80
CreateCollation {
@@ -110,6 +124,7 @@ impl ConnectionWorker {
110
124
command_tx,
111
125
handle_raw : conn. handle . to_raw ( ) ,
112
126
shared : Arc :: clone ( & shared) ,
127
+ transaction_depth : 0 ,
113
128
} ) )
114
129
. is_err ( )
115
130
{
@@ -135,8 +150,15 @@ impl ConnectionWorker {
135
150
query,
136
151
arguments,
137
152
persistent,
153
+ transaction_depth,
138
154
tx,
139
155
} => {
156
+ if let Err ( error) = handle_cancelled_begin ( & mut conn, transaction_depth)
157
+ {
158
+ tx. send ( Err ( error) ) . ok ( ) ;
159
+ continue ;
160
+ }
161
+
140
162
let iter = match execute:: iter ( & mut conn, & query, arguments, persistent)
141
163
{
142
164
Ok ( iter) => iter,
@@ -154,7 +176,16 @@ impl ConnectionWorker {
154
176
155
177
update_cached_statements_size ( & conn, & shared. cached_statements_size ) ;
156
178
}
157
- Command :: Begin { tx } => {
179
+ Command :: Begin {
180
+ transaction_depth,
181
+ tx,
182
+ } => {
183
+ if let Err ( error) = handle_cancelled_begin ( & mut conn, transaction_depth)
184
+ {
185
+ tx. send ( Err ( error) ) . ok ( ) ;
186
+ continue ;
187
+ }
188
+
158
189
let depth = conn. transaction_depth ;
159
190
let res =
160
191
conn. handle
@@ -165,9 +196,17 @@ impl ConnectionWorker {
165
196
166
197
tx. send ( res) . ok ( ) ;
167
198
}
168
- Command :: Commit { tx } => {
169
- let depth = conn. transaction_depth ;
199
+ Command :: Commit {
200
+ transaction_depth,
201
+ tx,
202
+ } => {
203
+ if let Err ( error) = handle_cancelled_begin ( & mut conn, transaction_depth)
204
+ {
205
+ tx. send ( Err ( error) ) . ok ( ) ;
206
+ continue ;
207
+ }
170
208
209
+ let depth = conn. transaction_depth ;
171
210
let res = if depth > 0 {
172
211
conn. handle
173
212
. exec ( commit_ansi_transaction_sql ( depth) )
@@ -180,9 +219,26 @@ impl ConnectionWorker {
180
219
181
220
tx. send ( res) . ok ( ) ;
182
221
}
183
- Command :: Rollback { tx } => {
184
- let depth = conn. transaction_depth ;
222
+ Command :: Rollback {
223
+ transaction_depth,
224
+ tx,
225
+ } => {
226
+ match handle_cancelled_begin_or_commit_or_rollback (
227
+ & mut conn,
228
+ transaction_depth,
229
+ ) {
230
+ Ok ( true ) => ( ) ,
231
+ Ok ( false ) => continue ,
232
+ Err ( error) => {
233
+ if let Some ( tx) = tx {
234
+ tx. send ( Err ( error) ) . ok ( ) ;
235
+ }
185
236
237
+ continue ;
238
+ }
239
+ }
240
+
241
+ let depth = conn. transaction_depth ;
186
242
let res = if depth > 0 {
187
243
conn. handle
188
244
. exec ( rollback_ansi_transaction_sql ( depth) )
@@ -259,6 +315,7 @@ impl ConnectionWorker {
259
315
query : query. into ( ) ,
260
316
arguments : args. map ( SqliteArguments :: into_static) ,
261
317
persistent,
318
+ transaction_depth : self . transaction_depth ,
262
319
tx,
263
320
} )
264
321
. await
@@ -268,21 +325,56 @@ impl ConnectionWorker {
268
325
}
269
326
270
327
pub ( crate ) async fn begin ( & mut self ) -> Result < ( ) , Error > {
271
- self . oneshot_cmd ( |tx| Command :: Begin { tx } ) . await ?
328
+ let transaction_depth = self . transaction_depth ;
329
+
330
+ self . oneshot_cmd ( |tx| Command :: Begin {
331
+ transaction_depth,
332
+ tx,
333
+ } )
334
+ . await ??;
335
+
336
+ self . transaction_depth += 1 ;
337
+
338
+ Ok ( ( ) )
272
339
}
273
340
274
341
pub ( crate ) async fn commit ( & mut self ) -> Result < ( ) , Error > {
275
- self . oneshot_cmd ( |tx| Command :: Commit { tx } ) . await ?
342
+ let transaction_depth = self . transaction_depth ;
343
+
344
+ self . oneshot_cmd ( |tx| Command :: Commit {
345
+ transaction_depth,
346
+ tx,
347
+ } )
348
+ . await ??;
349
+
350
+ self . transaction_depth -= 1 ;
351
+
352
+ Ok ( ( ) )
276
353
}
277
354
278
355
pub ( crate ) async fn rollback ( & mut self ) -> Result < ( ) , Error > {
279
- self . oneshot_cmd ( |tx| Command :: Rollback { tx : Some ( tx) } )
280
- . await ?
356
+ let transaction_depth = self . transaction_depth ;
357
+
358
+ self . oneshot_cmd ( |tx| Command :: Rollback {
359
+ transaction_depth,
360
+ tx : Some ( tx) ,
361
+ } )
362
+ . await ??;
363
+
364
+ self . transaction_depth -= 1 ;
365
+
366
+ Ok ( ( ) )
281
367
}
282
368
283
369
pub ( crate ) fn start_rollback ( & mut self ) -> Result < ( ) , Error > {
370
+ let transaction_depth = self . transaction_depth ;
371
+ self . transaction_depth -= 1 ;
372
+
284
373
self . command_tx
285
- . send ( Command :: Rollback { tx : None } )
374
+ . send ( Command :: Rollback {
375
+ transaction_depth,
376
+ tx : None ,
377
+ } )
286
378
. map_err ( |_| Error :: WorkerCrashed )
287
379
}
288
380
@@ -387,3 +479,58 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement<'s
387
479
fn update_cached_statements_size ( conn : & ConnectionState , size : & AtomicUsize ) {
388
480
size. store ( conn. statements . len ( ) , Ordering :: Release ) ;
389
481
}
482
+
483
+ // If a `begin()` is cancelled before completion it might happen that the `Begin` command is still
484
+ // sent to the worker thread but no `Transaction` is created and so there is no way to commit it or
485
+ // roll it back. This function detects such case and handles it by automatically rolling the
486
+ // transaction back.
487
+ //
488
+ // Use only when handling an `Execute`, `Begin` or `Commit` command.
489
+ fn handle_cancelled_begin (
490
+ conn : & mut ConnectionState ,
491
+ expected_transaction_depth : usize ,
492
+ ) -> Result < ( ) , Error > {
493
+ if expected_transaction_depth != conn. transaction_depth {
494
+ if expected_transaction_depth == conn. transaction_depth - 1 {
495
+ let depth = conn. transaction_depth ;
496
+ conn. handle . exec ( rollback_ansi_transaction_sql ( depth) ) ?;
497
+ conn. transaction_depth -= 1 ;
498
+ } else {
499
+ // This would indicate cancelled `commit` or `rollback`, but that can only happen when
500
+ // handling a `Rollback` command because `commit()` / `rollback()` take the
501
+ // transaction by value and so when they are cancelled the transaction is immediately
502
+ // dropped which sends a `Rollback`.
503
+ unreachable ! ( )
504
+ }
505
+ }
506
+
507
+ Ok ( ( ) )
508
+ }
509
+
510
+ // Same as `handle_cancelled_begin` but additionally handles cancelled `commit()` and `rollback()`
511
+ // as well. If `commit()` / `rollback()` is cancelled, it might happen that the corresponding
512
+ // `Commit` / `Rollback` command is still sent to the worker thread but the transaction's `open`
513
+ // flag is not set to `false` which causes another `Rollback` to be sent when the transaction
514
+ // is dropped. This function detects that case and indicates to ignore the superfluous `Rollback`.
515
+ //
516
+ // Use only when handling a `Rollback` command.
517
+ fn handle_cancelled_begin_or_commit_or_rollback (
518
+ conn : & mut ConnectionState ,
519
+ expected_transaction_depth : usize ,
520
+ ) -> Result < bool , Error > {
521
+ if expected_transaction_depth != conn. transaction_depth {
522
+ if expected_transaction_depth + 1 == conn. transaction_depth {
523
+ let depth = conn. transaction_depth ;
524
+ conn. handle . exec ( rollback_ansi_transaction_sql ( depth) ) ?;
525
+ conn. transaction_depth -= 1 ;
526
+
527
+ Ok ( true )
528
+ } else if expected_transaction_depth == conn. transaction_depth + 1 {
529
+ Ok ( false )
530
+ } else {
531
+ unreachable ! ( )
532
+ }
533
+ } else {
534
+ Ok ( true )
535
+ }
536
+ }
0 commit comments