Skip to content

Commit af1946c

Browse files
authored
sqlite: fix cancellation cleanup for ExecContext (#104)
I did not test the changes in #103 sufficiently, and missed a corner case in ExecContext: We need to make sure the context cleanup happens prior to the synchronization, not in a defer. Add a test to actually trigger this case, and fix the cleanup so it happens before return as intended. Add some missing documentation too.
1 parent 40be06d commit af1946c

File tree

2 files changed

+104
-61
lines changed

2 files changed

+104
-61
lines changed

sqlite.go

+14-12
Original file line numberDiff line numberDiff line change
@@ -490,23 +490,25 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
490490
if err := s.bindAll(args); err != nil {
491491
return nil, s.reserr("Stmt.Exec(Bind)", err)
492492
}
493-
var done chan struct{}
493+
494494
if ctx.Value(queryCancelKey{}) != nil {
495-
done = make(chan struct{})
495+
done := make(chan struct{})
496496
pctx, pcancel := context.WithCancel(ctx)
497-
defer pcancel() // to make the AfterFunc fire and close(done)
498-
499497
db := s.stmt.DBHandle()
500-
stop := context.AfterFunc(pctx, func() {
498+
context.AfterFunc(pctx, func() {
501499
defer close(done)
500+
501+
// Note: We respond to cancellation on the primary context (ctx) not
502+
// the cleanup context (pctx).
502503
if ctx.Err() != nil {
503504
db.Interrupt()
504505
}
505506
})
506-
// In the event we get an error from the query's initial execution of
507-
// sqlite3_step below and exit early, dissociate the cancellation since
508-
// we don't want it to fire and potentially stop a later execution.
509-
defer stop()
507+
508+
// We must wait prior to returning to ensure a cancellation context (if
509+
// present) has completed, so that a cancellation cannot outlast this
510+
// request and fire during a later execution.
511+
defer func() { pcancel(); <-done }()
510512
}
511513

512514
row, lastInsertRowID, changes, duration, err := s.stmt.StepResult()
@@ -519,9 +521,6 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
519521
return nil, err
520522
}
521523
_ = row // TODO: return error if exec on query which returns rows?
522-
if done != nil {
523-
<-done
524-
}
525524
return getStmtResult(lastInsertRowID, changes), nil
526525
}
527526

@@ -571,6 +570,9 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
571570
db := s.stmt.DBHandle()
572571
context.AfterFunc(pctx, func() {
573572
defer close(done)
573+
574+
// Note: We respond to cancellation on the primary context (ctx) not
575+
// the cleanup context (pctx).
574576
if ctx.Err() != nil {
575577
db.Interrupt()
576578
}

sqlite_test.go

+90-49
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ func TestWithPersist(t *testing.T) {
400400
}
401401
}
402402

403-
func TestWithQueryCancel(t *testing.T) {
403+
func TestWithQueryCancel_Timeout(t *testing.T) {
404404
// This test query runs forever until interrupted.
405405
const testQuery = `WITH RECURSIVE inf(n) AS (
406406
SELECT 1
@@ -410,65 +410,106 @@ func TestWithQueryCancel(t *testing.T) {
410410

411411
db := openTestDB(t)
412412

413-
done := make(chan struct{})
414-
go func() {
415-
defer close(done)
413+
t.Run("QueryContext", func(t *testing.T) {
414+
done := make(chan struct{})
415+
go func() {
416+
defer close(done)
416417

417-
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
418-
defer cancel()
418+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
419+
defer cancel()
419420

420-
rows, err := db.QueryContext(WithQueryCancel(ctx), testQuery)
421-
if err != nil {
422-
t.Errorf("QueryContext: unexpected error: %v", err)
423-
return
424-
}
425-
for rows.Next() {
426-
t.Error("Next result available before timeout")
427-
}
428-
if err := rows.Err(); err == nil {
429-
t.Error("Rows did not report an error")
430-
} else if !strings.Contains(err.Error(), "SQLITE_INTERRUPT") {
431-
t.Errorf("Rows err=%v, want SQLITE_INTERRUPT", err)
421+
rows, err := db.QueryContext(WithQueryCancel(ctx), testQuery)
422+
if err != nil {
423+
t.Errorf("QueryContext: unexpected error: %v", err)
424+
return
425+
}
426+
for rows.Next() {
427+
t.Error("Next result available before timeout")
428+
}
429+
if err := rows.Err(); err == nil {
430+
t.Error("Rows did not report an error")
431+
} else if !strings.Contains(err.Error(), "SQLITE_INTERRUPT") {
432+
t.Errorf("Rows err=%v, want SQLITE_INTERRUPT", err)
433+
}
434+
}()
435+
436+
select {
437+
case <-done:
438+
// OK
439+
case <-time.After(30 * time.Second):
440+
t.Fatal("Timeout waiting for query to end")
432441
}
433-
}()
442+
})
443+
t.Run("ExecContext", func(t *testing.T) {
444+
done := make(chan struct{})
445+
go func() {
446+
defer close(done)
434447

435-
select {
436-
case <-done:
437-
// OK
438-
case <-time.After(30 * time.Second):
439-
t.Fatal("Timeout waiting for query to end")
440-
}
448+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
449+
defer cancel()
450+
451+
res, err := db.ExecContext(WithQueryCancel(ctx), testQuery)
452+
if err == nil {
453+
t.Errorf("ExecContext: got %v, want error", res)
454+
} else if !strings.Contains(err.Error(), "SQLITE_INTERRUPT") {
455+
t.Errorf("ExecContext err=%v, want SQLITE_INTERRUPT", err)
456+
}
457+
}()
458+
459+
select {
460+
case <-done:
461+
// OK
462+
case <-time.After(30 * time.Second):
463+
t.Fatal("Timeout waiting for query to end")
464+
}
465+
})
441466
}
442467

443468
func TestWithQueryCancel_OK(t *testing.T) {
444469
db := openTestDB(t)
445470

446-
for i := 0; i < 100; i++ {
447-
t.Run(strconv.Itoa(i+1), func(t *testing.T) {
448-
// Set a timeout that is much longer than the expected runtime of the query.
449-
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
450-
defer cancel()
471+
t.Run("QueryContext", func(t *testing.T) {
472+
for i := 0; i < 100; i++ {
473+
t.Run(strconv.Itoa(i+1), func(t *testing.T) {
474+
// Set a timeout that is much longer than the expected runtime of the query.
475+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
476+
defer cancel()
451477

452-
rows, err := db.QueryContext(WithQueryCancel(ctx), `select 1`)
453-
if err != nil {
454-
t.Fatalf("QueryContext: unexpected error: %v", err)
455-
}
456-
for rows.Next() {
457-
var z int
458-
if err := rows.Scan(&z); err != nil {
459-
t.Fatalf("Scan: %v", err)
460-
} else if z != 1 {
461-
t.Errorf("Scan: got %d, want 1", z)
478+
rows, err := db.QueryContext(WithQueryCancel(ctx), `select 1`)
479+
if err != nil {
480+
t.Fatalf("QueryContext: unexpected error: %v", err)
462481
}
463-
}
464-
if err := rows.Err(); err != nil {
465-
t.Errorf("Err reported %v", err)
466-
}
467-
if err := rows.Close(); err != nil {
468-
t.Errorf("Close reported %v", err)
469-
}
470-
})
471-
}
482+
for rows.Next() {
483+
var z int
484+
if err := rows.Scan(&z); err != nil {
485+
t.Fatalf("Scan: %v", err)
486+
} else if z != 1 {
487+
t.Errorf("Scan: got %d, want 1", z)
488+
}
489+
}
490+
if err := rows.Err(); err != nil {
491+
t.Errorf("Err reported %v", err)
492+
}
493+
if err := rows.Close(); err != nil {
494+
t.Errorf("Close reported %v", err)
495+
}
496+
})
497+
}
498+
})
499+
t.Run("ExecContext", func(t *testing.T) {
500+
for i := 0; i < 100; i++ {
501+
t.Run(strconv.Itoa(i+1), func(t *testing.T) {
502+
// Set a timeout that is much longer than the expected runtime of the query.
503+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
504+
defer cancel()
505+
506+
_, err := db.ExecContext(WithQueryCancel(ctx), `select 1`)
507+
if err != nil && !strings.Contains(err.Error(), "SQLITE_INTERRUPT") {
508+
t.Errorf("ExecContext: unexpected error: %v", err)
509+
}
510+
})
511+
}
512+
})
472513
}
473514

474515
func TestErrors(t *testing.T) {

0 commit comments

Comments
 (0)