Skip to content

Commit

Permalink
QueryRow Err & add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wroge committed Dec 9, 2024
1 parent 987acce commit 8f10d13
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 7 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# sqlt: A Go Template-Based SQL Builder and ORM
# A Go Template-Based SQL Builder and ORM

[![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white)](https://pkg.go.dev/github.com/wroge/sqlt)
[![GitHub tag (latest SemVer)](https://img.shields.io/github/tag/wroge/sqlt.svg?style=social)](https://github.com/wroge/sqlt/tags)
Expand All @@ -12,7 +12,7 @@ import "github.com/wroge/sqlt"

## Type-Safety without a Build Step

- Define SQL statements at the global level using functions like `New`, `Parse`, `ParseFiles`, `ParseFS`, `ParseGlob`, `Funcs` and `Lookup`.
- Define SQL statements at the global level using options like `New`, `Parse`, `ParseFiles`, `ParseFS`, `ParseGlob`, `Funcs` and `Lookup`.
- **Templates are validated via [jba/templatecheck](https://github.com/jba/templatecheck) during application startup**.
- Execute statements using methods such as `Exec`, `Query` or `QueryRow`.
- Execute query statements using `First`, `One` or `All`.
Expand Down
13 changes: 8 additions & 5 deletions sqlt.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,12 @@ func (s *Statement[Param]) QueryRow(ctx context.Context, db DB, param Param) (ro
s.Put(err, runner)
}()

return runner.QueryRow(db, param)
row, err = runner.QueryRow(db, param)
if err != nil {
return row, err
}

return row, row.Err()
}

// Query takes a runner and queries rows.
Expand Down Expand Up @@ -866,11 +871,9 @@ var ident = "___sqlt___"
// copied from here: https://github.com/mhilton/sqltemplate/blob/main/escape.go
func escape(text *template.Template) {
for _, tpl := range text.Templates() {
if tpl.Tree.Root == nil {
continue
if tpl.Tree.Root != nil {
escapeNode(tpl.Tree, tpl.Tree.Root)
}

escapeNode(tpl.Tree, tpl.Tree.Root)
}
}

Expand Down
113 changes: 113 additions & 0 deletions sqlt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"errors"
"fmt"
"io/fs"
"strings"
"testing"
"text/template"
"time"

"github.com/DATA-DOG/go-sqlmock"
"github.com/spf13/afero"
Expand Down Expand Up @@ -198,6 +200,49 @@ func TestErrNoRows(t *testing.T) {
}
}

func TestErrTooManyRows(t *testing.T) {
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatal(err)
}

mock.ExpectQuery("SELECT id FROM books WHERE title = ?").WithArgs("TEST").WillReturnRows(
sqlmock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2),
)

type Param struct {
Title string
}

type Book struct {
ID int64
}

config := sqlt.Config{
TemplateOptions: []sqlt.TemplateOption{
sqlt.Funcs(template.FuncMap{
"ScanRawJSON": sqlt.ScanJSON[json.RawMessage],
}),
},
}

stmt := sqlt.QueryStmt[Param, Book](
config,
sqlt.Parse(`
SELECT
{{ ScanInt64 Dest.ID "id" }}
FROM books WHERE title = {{ .Title }}
`),
)

_, err = stmt.One(context.Background(), db, Param{Title: "TEST"})
if !errors.Is(err, sqlt.ErrTooManyRows) {
t.Fatal(err)
}
}

func TestOneErrorInTx(t *testing.T) {
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
Expand Down Expand Up @@ -794,6 +839,9 @@ func TestExec(t *testing.T) {

stmt := sqlt.Stmt[Book](
config,
sqlt.Start(func(runner *sqlt.Runner) {
runner.Context, _ = context.WithTimeout(runner.Context, time.Second)

Check failure on line 843 in sqlt_test.go

View workflow job for this annotation

GitHub Actions / check

lostcancel: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak (govet)
}),
sqlt.Parse(`
INSERT INTO books (id, title, json) VALUES ({{ .ID }}, {{ .Title }}, {{ .JSON }});
`),
Expand Down Expand Up @@ -1136,3 +1184,68 @@ func TestStmtPanic(t *testing.T) {
t.Fatal(err)
}
}

func TestInTxPanic(t *testing.T) {
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatal(err)
}

mock.ExpectBegin()
mock.ExpectRollback()

err = sqlt.InTx(context.Background(), nil, db, func(db sqlt.DB) error {
panic("unexpected panic")
})

if err == nil || !strings.Contains(err.Error(), "unexpected panic") {
t.Fatalf("expected panic error, got %v", err)
}
}

func TestQueryRunnerQueryRowError(t *testing.T) {
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatal(err)
}

mock.ExpectQuery("SELECT id FROM table WHERE id = ?").WillReturnError(errors.New("query error"))

stmt := sqlt.Stmt[string](
sqlt.Parse(`SELECT id FROM table WHERE id = {{ . }}`),
)

_, err = stmt.QueryRow(context.Background(), db, "123")
if err == nil || !strings.Contains(err.Error(), "query error") {
t.Fatal(err)
}
}

func TestSQLWriteWhitespace(t *testing.T) {
sql := &sqlt.SQL{}
_, err := sql.Write([]byte(" "))
if err != nil {
t.Fatal(err)
}
if sql.String() != "" {
t.Fatal("expected empty SQL string")
}
}

func TestStmtExecMissingParams(t *testing.T) {
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatal(err)
}

mock.ExpectExec("INSERT INTO table (column) VALUES (?)").WithArgs().WillReturnError(errors.New("missing parameters"))

stmt := sqlt.Stmt[string](
sqlt.Parse(`INSERT INTO table (column) VALUES ({{ . }})`),
)

_, err = stmt.Exec(context.Background(), db, "")
if err == nil || !strings.Contains(err.Error(), "missing parameters") {
t.Fatal(err)
}
}

0 comments on commit 8f10d13

Please sign in to comment.