Skip to content

Commit

Permalink
recover from panics always
Browse files Browse the repository at this point in the history
  • Loading branch information
wroge committed Dec 8, 2024
1 parent c868584 commit 2775598
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 18 deletions.
6 changes: 0 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/google/safehtml v0.0.2 h1:ZOt2VXg4x24bW0m2jtzAOkhoXV0iM8vNKc0paByCZqM=
github.com/google/safehtml v0.0.2/go.mod h1:L4KWwDsUJdECRAEpZoBn3O64bQaywRscowZjJAzjHnU=
github.com/google/safehtml v0.1.0 h1:EwLKo8qawTKfsi0orxcQAZzu07cICaBeFMegAU9eaT8=
github.com/google/safehtml v0.1.0/go.mod h1:L4KWwDsUJdECRAEpZoBn3O64bQaywRscowZjJAzjHnU=
github.com/jba/templatecheck v0.7.0 h1:wjTb/VhGgSFeim5zjWVePBdaMo28X74bGLSABZV+zIA=
github.com/jba/templatecheck v0.7.0/go.mod h1:n1Etw+Rrw1mDDD8dDRsEKTwMZsJ98EkktgNJC6wLUGo=
github.com/jba/templatecheck v0.7.1 h1:yOEIFazBEwzdTPYHZF3Pm81NF1ksxx1+vJncSEwvjKc=
github.com/jba/templatecheck v0.7.1/go.mod h1:n1Etw+Rrw1mDDD8dDRsEKTwMZsJ98EkktgNJC6wLUGo=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
58 changes: 46 additions & 12 deletions sqlt.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,8 @@ func InTx(ctx context.Context, opts *sql.TxOptions, db *sql.DB, do func(db DB) e
}

defer func() {
if p := recover(); p != nil {
if err = tx.Rollback(); err != nil {
panic(fmt.Errorf("%w: %v", err, p))
} else {
panic(p)
}
} else if err != nil {
err = errors.Join(err, tx.Rollback())
if r := recover(); r != nil || err != nil {
err = errors.Join(err, toErr(r), tx.Rollback())
} else {
err = tx.Commit()
}
Expand All @@ -53,6 +47,18 @@ func InTx(ctx context.Context, opts *sql.TxOptions, db *sql.DB, do func(db DB) e
return do(tx)
}

func toErr(r any) error {
if r == nil {
return nil
}

if perr, ok := r.(error); ok {
return perr
}

return fmt.Errorf("%v", r)
}

// Options are used to configure the statements.
type Option interface {
Configure(config *Config)
Expand Down Expand Up @@ -470,6 +476,7 @@ func (s *Statement[Param]) Get(ctx context.Context) *Runner {
}

// Put a Runner into the pool and execute the end option.
// This function should be called within a defer block to recover from panics.
func (s *Statement[Param]) Put(err error, runner *Runner) {
if s.end != nil {
s.end(err, runner)
Expand All @@ -485,6 +492,10 @@ func (s *Statement[Param]) Exec(ctx context.Context, db DB, param Param) (result
runner := s.Get(ctx)

defer func() {
if r := recover(); r != nil {
err = errors.Join(err, toErr(r))
}

s.Put(err, runner)
}()

Expand All @@ -496,6 +507,10 @@ func (s *Statement[Param]) QueryRow(ctx context.Context, db DB, param Param) (ro
runner := s.Get(ctx)

defer func() {
if r := recover(); r != nil {
err = errors.Join(err, toErr(r))
}

s.Put(err, runner)
}()

Expand All @@ -507,6 +522,10 @@ func (s *Statement[Param]) Query(ctx context.Context, db DB, param Param) (rows
runner := s.Get(ctx)

defer func() {
if r := recover(); r != nil {
err = errors.Join(err, toErr(r))
}

s.Put(err, runner)
}()

Expand Down Expand Up @@ -534,6 +553,8 @@ func (qr *QueryRunner[Dest]) Reset() {
func QueryStmt[Param, Dest any](opts ...Option) *QueryStatement[Param, Dest] {
_, file, line, _ := runtime.Caller(1)

location := fmt.Sprintf("%s:%d", file, line)

config := &Config{
Placeholder: "?",
}
Expand Down Expand Up @@ -564,12 +585,12 @@ func QueryStmt[Param, Dest any](opts ...Option) *QueryStatement[Param, Dest] {
for _, to := range config.TemplateOptions {
tpl, err = to(tpl)
if err != nil {
panic(fmt.Errorf("location: [%s:%d]: %w", file, line, err))
panic(fmt.Errorf("location: [%s]: %w", location, err))
}
}

if err = templatecheck.CheckText(tpl, *new(Param)); err != nil {
panic(fmt.Errorf("location: [%s:%d]: %w", file, line, err))
panic(fmt.Errorf("location: [%s]: %w", location, err))
}

escape(tpl)
Expand All @@ -584,14 +605,14 @@ func QueryStmt[Param, Dest any](opts ...Option) *QueryStatement[Param, Dest] {
New: func() any {
t, err := tpl.Clone()
if err != nil {
panic(fmt.Errorf("location: [%s:%d]: %w", file, line, err))
panic(fmt.Errorf("clone: location: [%s]: %w", location, err))
}

runner := &QueryRunner[Dest]{
Runner: &Runner{
Template: t,
SQL: &SQL{},
Location: fmt.Sprintf("%s:%d", file, line),
Location: location,
},
Dest: new(Dest),
}
Expand Down Expand Up @@ -656,6 +677,7 @@ func (qs *QueryStatement[Param, Dest]) Get(ctx context.Context) *QueryRunner[Des
}

// Put a QueryRunner into the pool and execute the end option.
// This function should be called within a defer block to recover from panics.
func (qs *QueryStatement[Param, Dest]) Put(err error, runner *QueryRunner[Dest]) {
if qs.end != nil {
qs.end(err, runner.Runner)
Expand All @@ -671,6 +693,10 @@ func (qs *QueryStatement[Param, Dest]) All(ctx context.Context, db DB, param Par
runner := qs.Get(ctx)

defer func() {
if r := recover(); r != nil {
err = errors.Join(err, toErr(r))
}

qs.Put(err, runner)
}()

Expand Down Expand Up @@ -718,6 +744,10 @@ func (qs *QueryStatement[Param, Dest]) One(ctx context.Context, db DB, param Par
runner := qs.Get(ctx)

defer func() {
if r := recover(); r != nil {
err = errors.Join(err, toErr(r))
}

qs.Put(err, runner)
}()

Expand Down Expand Up @@ -767,6 +797,10 @@ func (qs *QueryStatement[Param, Dest]) First(ctx context.Context, db DB, param P
runner := qs.Get(ctx)

defer func() {
if r := recover(); r != nil {
err = errors.Join(err, toErr(r))
}

qs.Put(err, runner)
}()

Expand Down
77 changes: 77 additions & 0 deletions sqlt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"io/fs"
"testing"
"text/template"
Expand Down Expand Up @@ -158,8 +159,17 @@ func TestOneErrorInTx(t *testing.T) {
Title string
}

type Key struct{}

stmt := sqlt.QueryStmt[Param, Book](
sqlt.Start(func(runner *sqlt.Runner) {
runner.Context = context.WithValue(runner.Context, Key{}, "VALUE")
}),
sqlt.End(func(err error, runner *sqlt.Runner) {
if runner.Context.Value(Key{}) != any("VALUE") {
t.Fail()
}

if err == nil || err.Error() != "ERROR" {
t.Fail()
}
Expand Down Expand Up @@ -979,3 +989,70 @@ func TestMapError(t *testing.T) {
t.Fatal(err)
}
}

func TestQueryStmtPanic(t *testing.T) {
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatal(err)
}

mock.ExpectQuery("id").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectQuery("id").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectQuery("id").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))

stmt := sqlt.QueryStmt[string, int64](
sqlt.Funcs(template.FuncMap{
"MyScanner": func(value *int64, str string) sqlt.Scanner {
return sqlt.Scanner{
SQL: str,
Value: value,
Map: func() error {
panic(errors.New("ERROR"))
},
}
},
}),
sqlt.Parse(`{{ MyScanner Dest "id" }}`),
)

_, err = stmt.All(context.Background(), db, "TEST")
if err == nil || err.Error() != "ERROR" {
t.Fatal(err)
}

_, err = stmt.One(context.Background(), db, "TEST")
if err == nil || err.Error() != "ERROR" {
t.Fatal(err)
}

_, err = stmt.First(context.Background(), db, "TEST")
if err == nil || err.Error() != "ERROR" {
t.Fatal(err)
}
}

func TestStmtPanic(t *testing.T) {
stmt := sqlt.Stmt[string](
sqlt.Funcs(template.FuncMap{
"Test": func() sqlt.Raw {
panic(fmt.Errorf("ERROR"))
},
}),
sqlt.Parse(`{{ Test }}`),
)

_, err := stmt.Exec(context.Background(), nil, "TEST")
if err == nil || err.Error() != `template: :1:3: executing "" at <Test>: error calling Test: ERROR` {
t.Fatal(err)
}

_, err = stmt.Query(context.Background(), nil, "TEST")
if err == nil || err.Error() != `template: :1:3: executing "" at <Test>: error calling Test: ERROR` {
t.Fatal(err)
}

_, err = stmt.QueryRow(context.Background(), nil, "TEST")
if err == nil || err.Error() != `template: :1:3: executing "" at <Test>: error calling Test: ERROR` {
t.Fatal(err)
}
}

0 comments on commit 2775598

Please sign in to comment.