diff --git a/README.md b/README.md index 618e5f8..a065482 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# sqlt: A Go Template-Based SQL Builder and ORM +# A Go Template-Based SQL Builder and ORM [![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white)](https://pkg.go.dev/github.com/wroge/sqlt) [![GitHub tag (latest SemVer)](https://img.shields.io/github/tag/wroge/sqlt.svg?style=social)](https://github.com/wroge/sqlt/tags) @@ -12,7 +12,7 @@ import "github.com/wroge/sqlt" ## Type-Safety without a Build Step -- Define SQL statements at the global level using functions like `New`, `Parse`, `ParseFiles`, `ParseFS`, `ParseGlob`, `Funcs` and `Lookup`. +- Define SQL statements at the global level using options like `New`, `Parse`, `ParseFiles`, `ParseFS`, `ParseGlob`, `Funcs` and `Lookup`. - **Templates are validated via [jba/templatecheck](https://github.com/jba/templatecheck) during application startup**. - Execute statements using methods such as `Exec`, `Query` or `QueryRow`. - Execute query statements using `First`, `One` or `All`. diff --git a/sqlt.go b/sqlt.go index e0a4790..15f0d01 100644 --- a/sqlt.go +++ b/sqlt.go @@ -505,7 +505,12 @@ func (s *Statement[Param]) QueryRow(ctx context.Context, db DB, param Param) (ro s.Put(err, runner) }() - return runner.QueryRow(db, param) + row, err = runner.QueryRow(db, param) + if err != nil { + return row, err + } + + return row, row.Err() } // Query takes a runner and queries rows. @@ -866,11 +871,9 @@ var ident = "___sqlt___" // copied from here: https://github.com/mhilton/sqltemplate/blob/main/escape.go func escape(text *template.Template) { for _, tpl := range text.Templates() { - if tpl.Tree.Root == nil { - continue + if tpl.Tree.Root != nil { + escapeNode(tpl.Tree, tpl.Tree.Root) } - - escapeNode(tpl.Tree, tpl.Tree.Root) } } diff --git a/sqlt_test.go b/sqlt_test.go index 0c2b11a..627d206 100644 --- a/sqlt_test.go +++ b/sqlt_test.go @@ -7,8 +7,10 @@ import ( "errors" "fmt" "io/fs" + "strings" "testing" "text/template" + "time" "github.com/DATA-DOG/go-sqlmock" "github.com/spf13/afero" @@ -198,6 +200,49 @@ func TestErrNoRows(t *testing.T) { } } +func TestErrTooManyRows(t *testing.T) { + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + if err != nil { + t.Fatal(err) + } + + mock.ExpectQuery("SELECT id FROM books WHERE title = ?").WithArgs("TEST").WillReturnRows( + sqlmock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2), + ) + + type Param struct { + Title string + } + + type Book struct { + ID int64 + } + + config := sqlt.Config{ + TemplateOptions: []sqlt.TemplateOption{ + sqlt.Funcs(template.FuncMap{ + "ScanRawJSON": sqlt.ScanJSON[json.RawMessage], + }), + }, + } + + stmt := sqlt.QueryStmt[Param, Book]( + config, + sqlt.Parse(` + SELECT + {{ ScanInt64 Dest.ID "id" }} + FROM books WHERE title = {{ .Title }} + `), + ) + + _, err = stmt.One(context.Background(), db, Param{Title: "TEST"}) + if !errors.Is(err, sqlt.ErrTooManyRows) { + t.Fatal(err) + } +} + func TestOneErrorInTx(t *testing.T) { db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) if err != nil { @@ -794,6 +839,9 @@ func TestExec(t *testing.T) { stmt := sqlt.Stmt[Book]( config, + sqlt.Start(func(runner *sqlt.Runner) { + runner.Context, _ = context.WithTimeout(runner.Context, time.Second) + }), sqlt.Parse(` INSERT INTO books (id, title, json) VALUES ({{ .ID }}, {{ .Title }}, {{ .JSON }}); `), @@ -1136,3 +1184,68 @@ func TestStmtPanic(t *testing.T) { t.Fatal(err) } } + +func TestInTxPanic(t *testing.T) { + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + if err != nil { + t.Fatal(err) + } + + mock.ExpectBegin() + mock.ExpectRollback() + + err = sqlt.InTx(context.Background(), nil, db, func(db sqlt.DB) error { + panic("unexpected panic") + }) + + if err == nil || !strings.Contains(err.Error(), "unexpected panic") { + t.Fatalf("expected panic error, got %v", err) + } +} + +func TestQueryRunnerQueryRowError(t *testing.T) { + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + if err != nil { + t.Fatal(err) + } + + mock.ExpectQuery("SELECT id FROM table WHERE id = ?").WillReturnError(errors.New("query error")) + + stmt := sqlt.Stmt[string]( + sqlt.Parse(`SELECT id FROM table WHERE id = {{ . }}`), + ) + + _, err = stmt.QueryRow(context.Background(), db, "123") + if err == nil || !strings.Contains(err.Error(), "query error") { + t.Fatal(err) + } +} + +func TestSQLWriteWhitespace(t *testing.T) { + sql := &sqlt.SQL{} + _, err := sql.Write([]byte(" ")) + if err != nil { + t.Fatal(err) + } + if sql.String() != "" { + t.Fatal("expected empty SQL string") + } +} + +func TestStmtExecMissingParams(t *testing.T) { + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)) + if err != nil { + t.Fatal(err) + } + + mock.ExpectExec("INSERT INTO table (column) VALUES (?)").WithArgs().WillReturnError(errors.New("missing parameters")) + + stmt := sqlt.Stmt[string]( + sqlt.Parse(`INSERT INTO table (column) VALUES ({{ . }})`), + ) + + _, err = stmt.Exec(context.Background(), db, "") + if err == nil || !strings.Contains(err.Error(), "missing parameters") { + t.Fatal(err) + } +}