diff --git a/bigquery/integration_test.go b/bigquery/integration_test.go index 82e362645ea6..3d23a18236c1 100644 --- a/bigquery/integration_test.go +++ b/bigquery/integration_test.go @@ -1166,23 +1166,27 @@ func TestIntegration_DML(t *testing.T) { func runDML(ctx context.Context, sql string) error { // Retry insert; sometimes it fails with INTERNAL. - return internal.Retry(ctx, gax.Backoff{}, func() (bool, error) { - // Use DML to insert. - q := client.Query(sql) - job, err := q.Run(ctx) + return internal.Retry(ctx, gax.Backoff{}, func() (stop bool, err error) { + ri, err := client.Query(sql).Read(ctx) if err != nil { if e, ok := err.(*googleapi.Error); ok && e.Code < 500 { return true, err // fail on 4xx } return false, err } - if err := wait(ctx, job); err != nil { - if e, ok := err.(*googleapi.Error); ok && e.Code < 500 { - return true, err // fail on 4xx - } - return false, err + // It is OK to try to iterate over DML results. The first call to Next + // will return iterator.Done. + err = ri.Next(nil) + if err == nil { + return true, errors.New("want iterator.Done on the first call, got nil") + } + if err == iterator.Done { + return true, nil } - return true, nil + if e, ok := err.(*googleapi.Error); ok && e.Code < 500 { + return true, err // fail on 4xx + } + return false, err }) } @@ -1891,6 +1895,7 @@ func TestIntegration_Model(t *testing.T) { VALUES (1, 0), (2, 1), (3, 0), (4, 1)`, tableName) wantNumRows := 4 + if err := runDML(ctx, sql); err != nil { t.Fatal(err) } diff --git a/bigquery/iterator.go b/bigquery/iterator.go index f8894773f736..1633f16255bc 100644 --- a/bigquery/iterator.go +++ b/bigquery/iterator.go @@ -23,16 +23,20 @@ import ( "google.golang.org/api/iterator" ) +// Construct a RowIterator. +// If pf is nil, there are no rows in the result set. func newRowIterator(ctx context.Context, t *Table, pf pageFetcher) *RowIterator { it := &RowIterator{ ctx: ctx, table: t, pf: pf, } - it.pageInfo, it.nextFunc = iterator.NewPageInfo( - it.fetch, - func() int { return len(it.rows) }, - func() interface{} { r := it.rows; it.rows = nil; return r }) + if pf != nil { + it.pageInfo, it.nextFunc = iterator.NewPageInfo( + it.fetch, + func() int { return len(it.rows) }, + func() interface{} { r := it.rows; it.rows = nil; return r }) + } return it } @@ -99,6 +103,9 @@ type RowIterator struct { // NullDateTime. You can also use a *[]Value or *map[string]Value to read from a // table with NULLs. func (it *RowIterator) Next(dst interface{}) error { + if it.pf == nil { // There are no rows in the result set. + return iterator.Done + } var vl ValueLoader switch dst := dst.(type) { case ValueLoader: diff --git a/bigquery/job.go b/bigquery/job.go index 132c7e98a025..62f816d4e527 100644 --- a/bigquery/job.go +++ b/bigquery/job.go @@ -226,7 +226,7 @@ func (j *Job) Wait(ctx context.Context) (js *JobStatus, err error) { if j.isQuery() { // We can avoid polling for query jobs. - if _, err := j.waitForQuery(ctx, j.projectID); err != nil { + if _, _, err := j.waitForQuery(ctx, j.projectID); err != nil { return nil, err } // Note: extra RPC even if you just want to wait for the query to finish. @@ -262,7 +262,7 @@ func (j *Job) Read(ctx context.Context) (ri *RowIterator, err error) { return j.read(ctx, j.waitForQuery, fetchPage) } -func (j *Job) read(ctx context.Context, waitForQuery func(context.Context, string) (Schema, error), pf pageFetcher) (*RowIterator, error) { +func (j *Job) read(ctx context.Context, waitForQuery func(context.Context, string) (Schema, uint64, error), pf pageFetcher) (*RowIterator, error) { if !j.isQuery() { return nil, errors.New("bigquery: cannot read from a non-query job") } @@ -272,7 +272,7 @@ func (j *Job) read(ctx context.Context, waitForQuery func(context.Context, strin if destTable != nil && projectID != destTable.ProjectId { return nil, fmt.Errorf("bigquery: job project ID is %q, but destination table's is %q", projectID, destTable.ProjectId) } - schema, err := waitForQuery(ctx, projectID) + schema, totalRows, err := waitForQuery(ctx, projectID) if err != nil { return nil, err } @@ -280,13 +280,18 @@ func (j *Job) read(ctx context.Context, waitForQuery func(context.Context, strin return nil, errors.New("bigquery: query job missing destination table") } dt := bqToTable(destTable, j.c) + if totalRows == 0 { + pf = nil + } it := newRowIterator(ctx, dt, pf) it.Schema = schema + it.TotalRows = totalRows return it, nil } -// waitForQuery waits for the query job to complete and returns its schema. -func (j *Job) waitForQuery(ctx context.Context, projectID string) (Schema, error) { +// waitForQuery waits for the query job to complete and returns its schema. It also +// returns the total number of rows in the result set. +func (j *Job) waitForQuery(ctx context.Context, projectID string) (Schema, uint64, error) { // Use GetQueryResults only to wait for completion, not to read results. call := j.c.bqs.Jobs.GetQueryResults(projectID, j.jobID).Location(j.location).Context(ctx).MaxResults(0) setClientHeader(call.Header()) @@ -307,9 +312,9 @@ func (j *Job) waitForQuery(ctx context.Context, projectID string) (Schema, error return true, nil }) if err != nil { - return nil, err + return nil, 0, err } - return bqToSchema(res.Schema), nil + return bqToSchema(res.Schema), res.TotalRows, nil } // JobStatistics contains statistics about a job. diff --git a/bigquery/read_test.go b/bigquery/read_test.go index fb0f6e2ea4d0..9499684b47eb 100644 --- a/bigquery/read_test.go +++ b/bigquery/read_test.go @@ -56,8 +56,8 @@ func (s *pageFetcherReadStub) fetchPage(ctx context.Context, t *Table, schema Sc return result, nil } -func waitForQueryStub(context.Context, string) (Schema, error) { - return nil, nil +func waitForQueryStub(context.Context, string) (Schema, uint64, error) { + return nil, 1, nil } func TestRead(t *testing.T) {