Skip to content

Commit 40be06d

Browse files
authored
sqlite: fix handling of "early" context cancellations (#103)
The optional cancellation hook added in #85 and #86 has a deficiency: If the query successfully completes before its context ends, the cleanup context would unconditionally trigger an interrupt on the database connection. That interrupt could race with a subsequent query on that connection, and cause a spurious cancellation. To fix this, separate the cleanup context from the input context, and only effect an interrupt if the _input_ context terminates before cleanup occurs. Only if cleanup is definitively prior to the query finishing will we effect an explicit interrupt. Also: - Add a test that demonstrates the original problem. - Update Go version to 1.21
1 parent 38d2414 commit 40be06d

File tree

3 files changed

+72
-10
lines changed

3 files changed

+72
-10
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
module github.com/tailscale/sqlite
22

3-
go 1.20
3+
go 1.21

sqlite.go

+37-8
Original file line numberDiff line numberDiff line change
@@ -490,14 +490,25 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
490490
if err := s.bindAll(args); err != nil {
491491
return nil, s.reserr("Stmt.Exec(Bind)", err)
492492
}
493+
var done chan struct{}
493494
if ctx.Value(queryCancelKey{}) != nil {
494-
var cancel context.CancelFunc
495-
ctx, cancel = context.WithCancel(ctx)
496-
defer cancel()
495+
done = make(chan struct{})
496+
pctx, pcancel := context.WithCancel(ctx)
497+
defer pcancel() // to make the AfterFunc fire and close(done)
497498

498499
db := s.stmt.DBHandle()
499-
go func() { <-ctx.Done(); db.Interrupt() }()
500+
stop := context.AfterFunc(pctx, func() {
501+
defer close(done)
502+
if ctx.Err() != nil {
503+
db.Interrupt()
504+
}
505+
})
506+
// In the event we get an error from the query's initial execution of
507+
// sqlite3_step below and exit early, dissociate the cancellation since
508+
// we don't want it to fire and potentially stop a later execution.
509+
defer stop()
500510
}
511+
501512
row, lastInsertRowID, changes, duration, err := s.stmt.StepResult()
502513
s.bound = false // StepResult resets the query
503514
err = s.reserr("Stmt.Exec", err)
@@ -508,6 +519,9 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
508519
return nil, err
509520
}
510521
_ = row // TODO: return error if exec on query which returns rows?
522+
if done != nil {
523+
<-done
524+
}
511525
return getStmtResult(lastInsertRowID, changes), nil
512526
}
513527

@@ -549,12 +563,23 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
549563
return nil, err
550564
}
551565
cancel := func() {}
566+
var done chan struct{}
552567
if ctx.Value(queryCancelKey{}) != nil {
553-
ctx, cancel = context.WithCancel(ctx)
568+
done = make(chan struct{})
569+
pctx, pcancel := context.WithCancel(ctx)
570+
cancel = pcancel
554571
db := s.stmt.DBHandle()
555-
go func() { <-ctx.Done(); db.Interrupt() }()
572+
context.AfterFunc(pctx, func() {
573+
defer close(done)
574+
if ctx.Err() != nil {
575+
db.Interrupt()
576+
}
577+
})
578+
// In this case we do not have an early exit, so we don't need to
579+
// dissociate the cancellation handler: If the caller gets an error, it
580+
// will explicitly trigger the cancellation and wait in (*rows).Close.
556581
}
557-
return &rows{stmt: s, cancel: cancel}, nil
582+
return &rows{stmt: s, cancel: cancel, done: done}, nil
558583
}
559584

560585
func (s *stmt) resetAndClear() error {
@@ -732,6 +757,7 @@ type rows struct {
732757
stmt *stmt
733758
closed bool
734759
cancel context.CancelFunc // call when query ends
760+
done chan struct{} // either nil, or closed when cancellation is done
735761

736762
// colType is the column types for Step to fill on each row. We only use 23
737763
// as it packs well with the closed bool byte above (24 bytes total, same as
@@ -779,7 +805,10 @@ func (r *rows) Close() error {
779805
return ErrClosed
780806
}
781807
r.closed = true
782-
defer r.cancel()
808+
r.cancel()
809+
if r.done != nil {
810+
<-r.done
811+
}
783812
if err := r.stmt.resetAndClear(); err != nil {
784813
return r.stmt.reserr("Rows.Close(Reset)", err)
785814
}

sqlite_test.go

+34-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"os"
1414
"reflect"
1515
"runtime"
16+
"strconv"
1617
"strings"
1718
"sync"
1819
"testing"
@@ -418,7 +419,8 @@ func TestWithQueryCancel(t *testing.T) {
418419

419420
rows, err := db.QueryContext(WithQueryCancel(ctx), testQuery)
420421
if err != nil {
421-
t.Fatalf("QueryContext: unexpected error: %v", err)
422+
t.Errorf("QueryContext: unexpected error: %v", err)
423+
return
422424
}
423425
for rows.Next() {
424426
t.Error("Next result available before timeout")
@@ -438,6 +440,37 @@ func TestWithQueryCancel(t *testing.T) {
438440
}
439441
}
440442

443+
func TestWithQueryCancel_OK(t *testing.T) {
444+
db := openTestDB(t)
445+
446+
for i := 0; i < 100; i++ {
447+
t.Run(strconv.Itoa(i+1), func(t *testing.T) {
448+
// Set a timeout that is much longer than the expected runtime of the query.
449+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
450+
defer cancel()
451+
452+
rows, err := db.QueryContext(WithQueryCancel(ctx), `select 1`)
453+
if err != nil {
454+
t.Fatalf("QueryContext: unexpected error: %v", err)
455+
}
456+
for rows.Next() {
457+
var z int
458+
if err := rows.Scan(&z); err != nil {
459+
t.Fatalf("Scan: %v", err)
460+
} else if z != 1 {
461+
t.Errorf("Scan: got %d, want 1", z)
462+
}
463+
}
464+
if err := rows.Err(); err != nil {
465+
t.Errorf("Err reported %v", err)
466+
}
467+
if err := rows.Close(); err != nil {
468+
t.Errorf("Close reported %v", err)
469+
}
470+
})
471+
}
472+
}
473+
441474
func TestErrors(t *testing.T) {
442475
db := openTestDB(t)
443476
exec(t, db, "CREATE TABLE t (c)")

0 commit comments

Comments
 (0)