Skip to content

Commit

Permalink
fix(pgdialect): postgres syntax errors for slices of pointers and jso…
Browse files Browse the repository at this point in the history
…n arrays #877
  • Loading branch information
rfarrjr committed Jan 27, 2025
1 parent dbae5e6 commit 1422b77
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 12 deletions.
2 changes: 2 additions & 0 deletions dialect/pgdialect/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
if typ.Elem().Kind() == reflect.Uint8 {
return appendBytesElemValue
}
case reflect.Ptr:
return schema.PtrAppender(d.arrayElemAppender(typ.Elem()))
}
return schema.Appender(d, typ)
}
Expand Down
45 changes: 33 additions & 12 deletions dialect/pgdialect/array_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@ type arrayParser struct {

elem []byte
err error

isJson bool
}

func newArrayParser(b []byte) *arrayParser {
p := new(arrayParser)

if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' {
if len(b) < 2 || (b[0] != '{' && b[0] != '[') || (b[len(b)-1] != '}' && b[len(b)-1] != ']') {
p.err = fmt.Errorf("pgdialect: can't parse array: %q", b)
return p
}
p.isJson = b[0] == '['

p.p.Reset(b[1 : len(b)-1])
return p
Expand Down Expand Up @@ -51,7 +54,7 @@ func (p *arrayParser) readNext() error {
}

switch ch {
case '}':
case '}', ']':
return io.EOF
case '"':
b, err := p.p.ReadSubstring(ch)
Expand All @@ -78,16 +81,34 @@ func (p *arrayParser) readNext() error {
p.elem = rng
return nil
default:
lit := p.p.ReadLiteral(ch)
if bytes.Equal(lit, []byte("NULL")) {
lit = nil
if ch == '{' && p.isJson {
json, err := p.p.ReadJSON()
if err != nil {
return err
}

for {
if p.p.Peek() == ',' || p.p.Peek() == ' ' {
p.p.Advance()
} else {
break
}
}

p.elem = json
return nil
} else {
lit := p.p.ReadLiteral(ch)
if bytes.Equal(lit, []byte("NULL")) {
lit = nil
}

if p.p.Peek() == ',' {
p.p.Advance()
}

p.elem = lit
return nil
}

if p.p.Peek() == ',' {
p.p.Advance()
}

p.elem = lit
return nil
}
}
4 changes: 4 additions & 0 deletions dialect/pgdialect/array_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func TestArrayParser(t *testing.T) {
{`{"1","2"}`, []string{"1", "2"}},
{`{"{1}","{2}"}`, []string{"{1}", "{2}"}},
{`{[1,2),[3,4)}`, []string{"[1,2)", "[3,4)"}},

{`[]`, []string{}},
{`[{"'\"[]"}]`, []string{`{"'\"[]"}`}},
{`[{"id": 1}, {"id":2, "name":"bob"}]`, []string{"{\"id\": 1}", "{\"id\":2, \"name\":\"bob\"}"}},
}

for i, test := range tests {
Expand Down
54 changes: 54 additions & 0 deletions dialect/pgdialect/array_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package pgdialect

import (
"testing"

"github.com/uptrace/bun/schema"
)

func ptr[T any](v T) *T {
return &v
}

func TestArrayAppend(t *testing.T) {
tcases := []struct {
input interface{}
out string
}{
{
input: []byte{1, 2},
out: `'{1,2}'`,
},
{
input: []*byte{ptr(byte(1)), ptr(byte(2))},
out: `'{1,2}'`,
},
{
input: []int{1, 2},
out: `'{1,2}'`,
},
{
input: []*int{ptr(1), ptr(2)},
out: `'{1,2}'`,
},
{
input: []string{"foo", "bar"},
out: `'{"foo","bar"}'`,
},
{
input: []*string{ptr("foo"), ptr("bar")},
out: `'{"foo","bar"}'`,
},
}

for _, tcase := range tcases {
out, err := Array(tcase.input).AppendQuery(schema.NewFormatter(New()), []byte{})
if err != nil {
t.Fatal(err)
}

if string(out) != tcase.out {
t.Errorf("expected output to be %s, was %s", tcase.out, string(out))
}
}
}
36 changes: 36 additions & 0 deletions dialect/pgdialect/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,39 @@ func (p *pgparser) ReadRange(ch byte) ([]byte, error) {

return p.buf, nil
}

func (p *pgparser) ReadJSON() ([]byte, error) {
p.Unread()

c, err := p.ReadByte()
if err != nil {
return nil, err
}

p.buf = p.buf[:0]

depth := 0
for {
switch c {
case '{':
depth++
case '}':
depth--
}

p.buf = append(p.buf, c)

if depth == 0 {
break
}

next, err := p.ReadByte()
if err != nil {
return nil, err
}

c = next
}

return p.buf, nil
}
4 changes: 4 additions & 0 deletions dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ func fieldSQLType(field *schema.Field) string {
}

func sqlType(typ reflect.Type) string {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}

switch typ {
case nullStringType: // typ.Kind() == reflect.Struct, test for exact match
return sqltype.VarChar
Expand Down
86 changes: 86 additions & 0 deletions internal/dbtest/pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
"github.com/uptrace/bun/schema"
)

func TestPostgresArray(t *testing.T) {
Expand All @@ -25,16 +26,20 @@ func TestPostgresArray(t *testing.T) {
Array1 []string `bun:",array"`
Array2 *[]string `bun:",array"`
Array3 *[]string `bun:",array"`
Array4 []*string `bun:",array"`
}

db := pg(t)
t.Cleanup(func() { db.Close() })
mustResetModel(t, ctx, db, (*Model)(nil))

str1 := "hello"
str2 := "world"
model1 := &Model{
ID: 123,
Array1: []string{"one", "two", "three"},
Array2: &[]string{"hello", "world"},
Array4: []*string{&str1, &str2},
}
_, err := db.NewInsert().Model(model1).Exec(ctx)
require.NoError(t, err)
Expand All @@ -56,6 +61,12 @@ func TestPostgresArray(t *testing.T) {
Scan(ctx, pgdialect.Array(&strs))
require.NoError(t, err)
require.Nil(t, strs)

err = db.NewSelect().Model((*Model)(nil)).
Column("array4").
Scan(ctx, pgdialect.Array(&strs))
require.NoError(t, err)
require.Equal(t, []string{"hello", "world"}, strs)
}

func TestPostgresArrayQuote(t *testing.T) {
Expand Down Expand Up @@ -877,3 +888,78 @@ func TestPostgresMultiRange(t *testing.T) {
err = db.NewSelect().Model(out).Scan(ctx)
require.NoError(t, err)
}

type UserID struct {
ID string
}

func (u UserID) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) {
v := []byte(`"` + u.ID + `"`)
return append(b, v...), nil
}

var _ schema.QueryAppender = (*UserID)(nil)

func (r *UserID) Scan(anySrc any) (err error) {
src, ok := anySrc.([]byte)
if !ok {
return fmt.Errorf("pgdialect: Range can't scan %T", anySrc)
}

r.ID = string(src)
return nil
}

var _ sql.Scanner = (*UserID)(nil)

func TestPostgresJSONB(t *testing.T) {
type Item struct {
Name string `json:"name"`
}
type Model struct {
ID int64 `bun:",pk,autoincrement"`
Item Item `bun:",type:jsonb"`
ItemPtr *Item `bun:",type:jsonb"`
Items []Item `bun:",type:jsonb"`
ItemsP []*Item `bun:",type:jsonb"`
TextItemA []UserID `bun:"type:text[]"`
}

db := pg(t)
t.Cleanup(func() { db.Close() })
mustResetModel(t, ctx, db, (*Model)(nil))

item1 := Item{Name: "one"}
item2 := Item{Name: "two"}
uid1 := UserID{ID: "1"}
uid2 := UserID{ID: "2"}
model1 := &Model{
ID: 123,
Item: item1,
ItemPtr: &item2,
Items: []Item{item1, item2},
ItemsP: []*Item{&item1, &item2},
TextItemA: []UserID{uid1, uid2},
}
_, err := db.NewInsert().Model(model1).Exec(ctx)
require.NoError(t, err)

model2 := new(Model)
err = db.NewSelect().Model(model2).Scan(ctx)
require.NoError(t, err)
require.Equal(t, model1, model2)

var items []Item
err = db.NewSelect().Model((*Model)(nil)).
Column("items").
Scan(ctx, pgdialect.Array(&items))
require.NoError(t, err)
require.Equal(t, []Item{item1, item2}, items)

err = db.NewSelect().Model((*Model)(nil)).
Column("itemsp").
Scan(ctx, pgdialect.Array(&items))
require.NoError(t, err)
require.Equal(t, []Item{item1, item2}, items)

}

0 comments on commit 1422b77

Please sign in to comment.