diff --git a/go.sum b/go.sum index 9d0ecd8..b2c218e 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,13 @@ github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= -github.com/google/safehtml v0.0.2 h1:ZOt2VXg4x24bW0m2jtzAOkhoXV0iM8vNKc0paByCZqM= -github.com/google/safehtml v0.0.2/go.mod h1:L4KWwDsUJdECRAEpZoBn3O64bQaywRscowZjJAzjHnU= github.com/google/safehtml v0.1.0 h1:EwLKo8qawTKfsi0orxcQAZzu07cICaBeFMegAU9eaT8= github.com/google/safehtml v0.1.0/go.mod h1:L4KWwDsUJdECRAEpZoBn3O64bQaywRscowZjJAzjHnU= -github.com/jba/templatecheck v0.7.0 h1:wjTb/VhGgSFeim5zjWVePBdaMo28X74bGLSABZV+zIA= -github.com/jba/templatecheck v0.7.0/go.mod h1:n1Etw+Rrw1mDDD8dDRsEKTwMZsJ98EkktgNJC6wLUGo= github.com/jba/templatecheck v0.7.1 h1:yOEIFazBEwzdTPYHZF3Pm81NF1ksxx1+vJncSEwvjKc= github.com/jba/templatecheck v0.7.1/go.mod h1:n1Etw+Rrw1mDDD8dDRsEKTwMZsJ98EkktgNJC6wLUGo= github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/sqlt.go b/sqlt.go index 0d4759d..c346f4d 100644 --- a/sqlt.go +++ b/sqlt.go @@ -37,14 +37,8 @@ func InTx(ctx context.Context, opts *sql.TxOptions, db *sql.DB, do func(db DB) e } defer func() { - if p := recover(); p != nil { - if err = tx.Rollback(); err != nil { - panic(fmt.Errorf("%w: %v", err, p)) - } else { - panic(p) - } - } else if err != nil { - err = errors.Join(err, tx.Rollback()) + if r := recover(); r != nil || err != nil { + err = errors.Join(err, toErr(r), tx.Rollback()) } else { err = tx.Commit() } @@ -53,6 +47,18 @@ func InTx(ctx context.Context, opts *sql.TxOptions, db *sql.DB, do func(db DB) e return do(tx) } +func toErr(r any) error { + if r == nil { + return nil + } + + if perr, ok := r.(error); ok { + return perr + } + + return fmt.Errorf("%v", r) +} + // Options are used to configure the statements. type Option interface { Configure(config *Config) @@ -470,6 +476,7 @@ func (s *Statement[Param]) Get(ctx context.Context) *Runner { } // Put a Runner into the pool and execute the end option. +// This function should be called within a defer block to recover from panics. func (s *Statement[Param]) Put(err error, runner *Runner) { if s.end != nil { s.end(err, runner) @@ -485,6 +492,10 @@ func (s *Statement[Param]) Exec(ctx context.Context, db DB, param Param) (result runner := s.Get(ctx) defer func() { + if r := recover(); r != nil { + err = errors.Join(err, toErr(r)) + } + s.Put(err, runner) }() @@ -496,6 +507,10 @@ func (s *Statement[Param]) QueryRow(ctx context.Context, db DB, param Param) (ro runner := s.Get(ctx) defer func() { + if r := recover(); r != nil { + err = errors.Join(err, toErr(r)) + } + s.Put(err, runner) }() @@ -507,6 +522,10 @@ func (s *Statement[Param]) Query(ctx context.Context, db DB, param Param) (rows runner := s.Get(ctx) defer func() { + if r := recover(); r != nil { + err = errors.Join(err, toErr(r)) + } + s.Put(err, runner) }() @@ -534,6 +553,8 @@ func (qr *QueryRunner[Dest]) Reset() { func QueryStmt[Param, Dest any](opts ...Option) *QueryStatement[Param, Dest] { _, file, line, _ := runtime.Caller(1) + location := fmt.Sprintf("%s:%d", file, line) + config := &Config{ Placeholder: "?", } @@ -564,12 +585,12 @@ func QueryStmt[Param, Dest any](opts ...Option) *QueryStatement[Param, Dest] { for _, to := range config.TemplateOptions { tpl, err = to(tpl) if err != nil { - panic(fmt.Errorf("location: [%s:%d]: %w", file, line, err)) + panic(fmt.Errorf("location: [%s]: %w", location, err)) } } if err = templatecheck.CheckText(tpl, *new(Param)); err != nil { - panic(fmt.Errorf("location: [%s:%d]: %w", file, line, err)) + panic(fmt.Errorf("location: [%s]: %w", location, err)) } escape(tpl) @@ -584,14 +605,14 @@ func QueryStmt[Param, Dest any](opts ...Option) *QueryStatement[Param, Dest] { New: func() any { t, err := tpl.Clone() if err != nil { - panic(fmt.Errorf("location: [%s:%d]: %w", file, line, err)) + panic(fmt.Errorf("clone: location: [%s]: %w", location, err)) } runner := &QueryRunner[Dest]{ Runner: &Runner{ Template: t, SQL: &SQL{}, - Location: fmt.Sprintf("%s:%d", file, line), + Location: location, }, Dest: new(Dest), } @@ -656,6 +677,7 @@ func (qs *QueryStatement[Param, Dest]) Get(ctx context.Context) *QueryRunner[Des } // Put a QueryRunner into the pool and execute the end option. +// This function should be called within a defer block to recover from panics. func (qs *QueryStatement[Param, Dest]) Put(err error, runner *QueryRunner[Dest]) { if qs.end != nil { qs.end(err, runner.Runner) @@ -671,6 +693,10 @@ func (qs *QueryStatement[Param, Dest]) All(ctx context.Context, db DB, param Par runner := qs.Get(ctx) defer func() { + if r := recover(); r != nil { + err = errors.Join(err, toErr(r)) + } + qs.Put(err, runner) }() @@ -718,6 +744,10 @@ func (qs *QueryStatement[Param, Dest]) One(ctx context.Context, db DB, param Par runner := qs.Get(ctx) defer func() { + if r := recover(); r != nil { + err = errors.Join(err, toErr(r)) + } + qs.Put(err, runner) }() @@ -767,6 +797,10 @@ func (qs *QueryStatement[Param, Dest]) First(ctx context.Context, db DB, param P runner := qs.Get(ctx) defer func() { + if r := recover(); r != nil { + err = errors.Join(err, toErr(r)) + } + qs.Put(err, runner) }() diff --git a/sqlt_test.go b/sqlt_test.go index 12c9b0e..d2d5302 100644 --- a/sqlt_test.go +++ b/sqlt_test.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "io/fs" "testing" "text/template" @@ -158,8 +159,17 @@ func TestOneErrorInTx(t *testing.T) { Title string } + type Key struct{} + stmt := sqlt.QueryStmt[Param, Book]( + sqlt.Start(func(runner *sqlt.Runner) { + runner.Context = context.WithValue(runner.Context, Key{}, "VALUE") + }), sqlt.End(func(err error, runner *sqlt.Runner) { + if runner.Context.Value(Key{}) != any("VALUE") { + t.Fail() + } + if err == nil || err.Error() != "ERROR" { t.Fail() } @@ -979,3 +989,70 @@ func TestMapError(t *testing.T) { t.Fatal(err) } } + +func TestQueryStmtPanic(t *testing.T) { + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + if err != nil { + t.Fatal(err) + } + + mock.ExpectQuery("id").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) + mock.ExpectQuery("id").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) + mock.ExpectQuery("id").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) + + stmt := sqlt.QueryStmt[string, int64]( + sqlt.Funcs(template.FuncMap{ + "MyScanner": func(value *int64, str string) sqlt.Scanner { + return sqlt.Scanner{ + SQL: str, + Value: value, + Map: func() error { + panic(errors.New("ERROR")) + }, + } + }, + }), + sqlt.Parse(`{{ MyScanner Dest "id" }}`), + ) + + _, err = stmt.All(context.Background(), db, "TEST") + if err == nil || err.Error() != "ERROR" { + t.Fatal(err) + } + + _, err = stmt.One(context.Background(), db, "TEST") + if err == nil || err.Error() != "ERROR" { + t.Fatal(err) + } + + _, err = stmt.First(context.Background(), db, "TEST") + if err == nil || err.Error() != "ERROR" { + t.Fatal(err) + } +} + +func TestStmtPanic(t *testing.T) { + stmt := sqlt.Stmt[string]( + sqlt.Funcs(template.FuncMap{ + "Test": func() sqlt.Raw { + panic(fmt.Errorf("ERROR")) + }, + }), + sqlt.Parse(`{{ Test }}`), + ) + + _, err := stmt.Exec(context.Background(), nil, "TEST") + if err == nil || err.Error() != `template: :1:3: executing "" at : error calling Test: ERROR` { + t.Fatal(err) + } + + _, err = stmt.Query(context.Background(), nil, "TEST") + if err == nil || err.Error() != `template: :1:3: executing "" at : error calling Test: ERROR` { + t.Fatal(err) + } + + _, err = stmt.QueryRow(context.Background(), nil, "TEST") + if err == nil || err.Error() != `template: :1:3: executing "" at : error calling Test: ERROR` { + t.Fatal(err) + } +}