From 62939b3afb8ca622d6c24f73e8c5fc984334c9b6 Mon Sep 17 00:00:00 2001 From: Travis Harmon Date: Fri, 25 Mar 2022 12:36:07 -0500 Subject: [PATCH] Remove the requirement that selects must specify limit --- select_query.go | 34 +++++++++++----- select_query_test.go | 94 +++++++++++++++++++++++++++++--------------- 2 files changed, 86 insertions(+), 42 deletions(-) diff --git a/select_query.go b/select_query.go index 557f2c1..7b3b7a1 100644 --- a/select_query.go +++ b/select_query.go @@ -147,9 +147,6 @@ func (sq *SelectQuery) Get(out interface{}) error { if el.Kind() != reflect.Struct { break } - if sq.limit == 0 { - return errors.New("limit must be set and not zero when selecting multiple rows") - } var err error sq.model, err = sq.db.getModelOf(el) if err != nil { @@ -157,9 +154,6 @@ func (sq *SelectQuery) Get(out interface{}) error { } return sq.toMany(t, out) case reflect.Struct: - if sq.limit == 0 { - return errors.New("limit must be set and not zero when selecting multiple rows") - } var err error sq.model, err = sq.db.getModelOf(el) if err != nil { @@ -215,7 +209,12 @@ func (sq *SelectQuery) toMany(sliceType reflect.Type, outs interface{}) error { return err } defer rows.Close() - newOuts := reflect.MakeSlice(sliceType, int(sq.limit), int(sq.limit)) + var newOuts reflect.Value + if sq.limit == 0 { + newOuts = reflect.MakeSlice(sliceType, 8, 8) + } else { + newOuts = reflect.MakeSlice(sliceType, int(sq.limit), int(sq.limit)) + } i := 0 columns, _ := rows.Columns() fieldCount := len(columns) @@ -228,6 +227,10 @@ func (sq *SelectQuery) toMany(sliceType reflect.Type, outs interface{}) error { } dests := make([]interface{}, fieldCount) for rows.Next() { + if newOuts.Len() == i { + newOuts.SetCap(newOuts.Len() * 2) + newOuts.SetLen(newOuts.Len() * 2) + } newOut := newOuts.Index(i) newOut.Set(reflect.New(sq.model.typ)) for j := 0; j < fieldCount; j++ { @@ -253,7 +256,12 @@ func (sq *SelectQuery) toManyValues(sliceType reflect.Type, outs interface{}) er return err } defer rows.Close() - newOuts := reflect.MakeSlice(sliceType, int(sq.limit), int(sq.limit)) + var newOuts reflect.Value + if sq.limit == 0 { + newOuts = reflect.MakeSlice(sliceType, 8, 8) + } else { + newOuts = reflect.MakeSlice(sliceType, int(sq.limit), int(sq.limit)) + } i := 0 columns, _ := rows.Columns() fieldCount := len(columns) @@ -267,6 +275,10 @@ func (sq *SelectQuery) toManyValues(sliceType reflect.Type, outs interface{}) er dests := make([]interface{}, fieldCount) newOut := newOuts.Index(0) for rows.Next() { + if newOuts.Len() == i { + newOuts.SetCap(newOuts.Len() * 2) + newOuts.SetLen(newOuts.Len() * 2) + } newOut = newOuts.Index(i) for j := 0; j < fieldCount; j++ { dests[j] = newOut.Field(fieldIndecies[j]).Addr().Interface() @@ -324,8 +336,10 @@ func (sq *SelectQuery) String() string { q.WriteString(sq.order) } if sq.many { - q.WriteString(" limit ") - q.WriteString(strconv.FormatInt(sq.limit, 10)) + if sq.limit > 0 { + q.WriteString(" limit ") + q.WriteString(strconv.FormatInt(sq.limit, 10)) + } } else { q.WriteString(" limit 1") } diff --git a/select_query_test.go b/select_query_test.go index 4e5debd..9d578b5 100644 --- a/select_query_test.go +++ b/select_query_test.go @@ -47,9 +47,9 @@ func TestSelectQueryMany(t *testing.T) { for _, c := range control { rows.AddRow(c.ID, c.Name) } - mock.ExpectQuery(`^select \* from t limit 10$`).WillReturnRows(rows) + mock.ExpectQuery(`^select \* from t$`).WillReturnRows(rows) var test []*T - check(t, db.Select("*").Limit(10).Get(&test)) + check(t, db.Select("*").Get(&test)) check(t, mock.ExpectationsWereMet()) for i := 0; i < len(control); i++ { equals(t, control[i], test[i]) @@ -57,6 +57,66 @@ func TestSelectQueryMany(t *testing.T) { } func TestSelectQueryManyValues(t *testing.T) { + db, mock, err := getMockDB() + check(t, err) + type T struct { + ID int `idx:"primary"` + Name string + } + control := []T{ + { + ID: 5, + Name: "foo", + }, + { + ID: 6, + Name: "bar", + }, + } + rows := sqlmock.NewRows([]string{"id", "name"}) + for _, c := range control { + rows.AddRow(c.ID, c.Name) + } + mock.ExpectQuery(`^select \* from t$`).WillReturnRows(rows) + var test []T + check(t, db.Select("*").Get(&test)) + check(t, mock.ExpectationsWereMet()) + for i := 0; i < len(control); i++ { + equals(t, control[i], test[i]) + } +} + +func TestSelectQueryManyLimit(t *testing.T) { + db, mock, err := getMockDB() + check(t, err) + type T struct { + ID int `idx:"primary"` + Name string + } + control := []*T{ + { + ID: 5, + Name: "foo", + }, + { + ID: 6, + Name: "bar", + }, + } + rows := sqlmock.NewRows([]string{"id", "name"}) + for _, c := range control { + rows.AddRow(c.ID, c.Name) + } + mock.ExpectQuery(`^select \* from t$`).WillReturnRows(rows) + var test []*T + check(t, db.Select("*").Get(&test)) + check(t, mock.ExpectationsWereMet()) + for i := 0; i < len(control); i++ { + equals(t, control[i], test[i]) + } +} + +func TestSelectQueryManyValuesLimit(t *testing.T) { db, mock, err := getMockDB() check(t, err) type T struct { @@ -227,36 +287,6 @@ func TestSelectQueryErrNotStructOrSlice(t *testing.T) { } } -func TestSelectQueryErrLimitZeroManyValues(t *testing.T) { - db, _, err := getMockDB() - check(t, err) - type T struct { - ID int `idx:"primary"` - Name string - } - var test []T - if err := db.Select("*").Get(&test); err == nil { - t.Fatalf("expected err to be non nil") - } else { - contains(t, err.Error(), "limit") - } -} - -func TestSelectQueryErrLimitZeroMany(t *testing.T) { - db, _, err := getMockDB() - check(t, err) - type T struct { - ID int `idx:"primary"` - Name string - } - var test []*T - if err := db.Select("*").Get(&test); err == nil { - t.Fatalf("expected err to be non nil") - } else { - contains(t, err.Error(), "limit") - } -} - func TestSelectQueryCustomColumn(t *testing.T) { db, mock, err := getMockDB() check(t, err)