diff --git a/internal/gengapic/generator.go b/internal/gengapic/generator.go index b39960df7..c8ababa87 100644 --- a/internal/gengapic/generator.go +++ b/internal/gengapic/generator.go @@ -53,6 +53,11 @@ var enableNewAuthLibraryBlocklist = map[string]bool{ "spanner.googleapis.com": true, } +// keyed by proto package name, e.g. "google.cloud.foo.v1". +var enableWrapperTypesForPageSize = map[string]bool{ + "google.cloud.bigquery.v2": true, +} + type generator struct { pt printer.P diff --git a/internal/gengapic/genrest.go b/internal/gengapic/genrest.go index 75b8e798b..a40bc57e4 100644 --- a/internal/gengapic/genrest.go +++ b/internal/gengapic/genrest.go @@ -791,20 +791,12 @@ func (g *generator) pagingRESTCall(servName string, m *descriptorpb.MethodDescri verb := strings.ToUpper(info.verb) - max := "math.MaxInt32" g.imports[pbinfo.ImportSpec{Path: "math"}] = true - psTyp := pbinfo.GoTypeForPrim[pageSize.GetType()] - ps := fmt.Sprintf("%s(pageSize)", psTyp) - if isOptional(inType, pageSize.GetName()) { - max = fmt.Sprintf("proto.%s(%s)", upperFirst(psTyp), max) - ps = fmt.Sprintf("proto.%s(%s)", upperFirst(psTyp), ps) - } tok := "pageToken" if isOptional(inType, "page_token") { tok = fmt.Sprintf("proto.String(%s)", tok) } - pageSizeFieldName := snakeToCamel(pageSize.GetName()) p("func (c *%s) %s(ctx context.Context, req *%s.%s, opts ...gax.CallOption) *%s {", lowcaseServName, m.GetName(), inSpec.Name, inType.GetName(), pt.iterTypeName) p("it := &%s{}", pt.iterTypeName) @@ -819,7 +811,7 @@ func (g *generator) pagingRESTCall(servName string, m *descriptorpb.MethodDescri p("unm := protojson.UnmarshalOptions{AllowPartial: true, DiscardUnknown: true}") p("it.InternalFetch = func(pageSize int, pageToken string) ([]%s, string, error) {", pt.elemTypeName) - g.internalFetchSetup(outType, outSpec, tok, pageSizeFieldName, max, ps) + g.internalFetchSetup(outType, outSpec, pageSize, tok) if info.body != "" { p(" jsonReq, err := m.Marshal(req)") @@ -874,7 +866,7 @@ func (g *generator) pagingRESTCall(servName string, m *descriptorpb.MethodDescri p(" return %s, resp.GetNextPageToken(), nil", elems) p("}") p("") - g.makeFetchAndIterUpdate(pageSizeFieldName) + g.makeFetchAndIterUpdate(pageSize) p("}") g.imports[pbinfo.ImportSpec{Path: "google.golang.org/api/iterator"}] = true diff --git a/internal/gengapic/paging.go b/internal/gengapic/paging.go index df5f3d858..7ad12918d 100644 --- a/internal/gengapic/paging.go +++ b/internal/gengapic/paging.go @@ -31,6 +31,34 @@ type iterType struct { elemImports []pbinfo.ImportSpec } +// isPageSizeField evaluates whether a particular field is a page size field, and whether this +// field will require a dependency on wrapper types in the generator. +// +// https://google.aip.dev/158 guidance is to use `page_size`, but older APIs like compute +// and bigquery use `max_results`. Similarly, `int32` is the expected scalar type, but +// there's more variance here in implementations, so int32 and uint32 are allowed. +// +// If wrapper support is allowed, the page size detection will include the +// usage of equivalent wrapper types as well (Int32Value, UInt32Value). This is legacy behavior +// due to older APIs that were built prior to proto3 presence being (re)introduced. +func isPageSizeField(f *descriptorpb.FieldDescriptorProto, wrappersAllowed bool) (isCandidate, requiresWrapper bool) { + if f.GetName() == "page_size" || f.GetName() == "max_results" { + // Scalar types. + if f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_INT32 || f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_UINT32 { + return true, false + } + // Wrapper types. + if wrappersAllowed { + if f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE { + if f.GetTypeName() == ".google.protobuf.Int32Value" || f.GetTypeName() == ".google.protobuf.UInt32Value" { + return true, true + } + } + } + } + return false, false +} + // iterTypeOf deduces iterType from a field to be iterated over. // elemField should be the "resource" of a paginating RPC. // TODO(dovs): augment with paged map iterators @@ -121,22 +149,43 @@ func (g *generator) iterTypeOf(elemField *descriptorpb.FieldDescriptorProto) (*i // Makes particular allowance for diregapic idioms: maps can be paginated over, // and either 'page_size' XOR 'max_results' are allowable fields in the request. func (g *generator) getPagingFields(m *descriptorpb.MethodDescriptorProto) (repeatedField, pageSizeField *descriptorpb.FieldDescriptorProto, e error) { - // TODO: remove this once the next version of the Talent API is published. - // - // This is a workaround to disable auto-pagination for specifc RPCs in - // Talent v4beta1. The API team will make their API non-conforming in the - // next version. - // - // This should not be done for any other API. - if g.descInfo.ParentFile[m].GetPackage() == "google.cloud.talent.v4beta1" && - (m.GetName() == "SearchProfiles" || m.GetName() == "SearchJobs") { - return nil, nil, nil + // TODO: Remove this skip logic once annotation-based pagination config supercedes heuristic-based mechanisms. + // FR is tracked internally as b/337021569. + var paginationOverrides = []struct { + pkgName string + disallowedMethods []string // methods explicitly denied from pagination + }{ + { + pkgName: "google.cloud.talent.v4beta1", + disallowedMethods: []string{"SearchProfiles", "SearchJobs"}, + }, + { + pkgName: "google.cloud.bigquery.v2", + disallowedMethods: []string{"GetQueryResults"}, + }, } + for _, cfg := range paginationOverrides { + if g.descInfo.ParentFile[m].GetPackage() == cfg.pkgName { + for _, skipMethod := range cfg.disallowedMethods { + if m.GetName() == skipMethod { + return nil, nil, nil + } + } + } + } if m.GetClientStreaming() || m.GetServerStreaming() { return nil, nil, nil } + var wrapperTypesAllowed bool + for p, ok := range enableWrapperTypesForPageSize { + if g.descInfo.ParentFile[m].GetPackage() == p && ok { + wrapperTypesAllowed = true + break + } + } + inType := g.descInfo.Type[m.GetInputType()] if inType == nil { return nil, nil, fmt.Errorf("expected %q to be message type, found %T", m.GetInputType(), inType) @@ -157,12 +206,16 @@ func (g *generator) getPagingFields(m *descriptorpb.MethodDescriptorProto) (repe hasPageToken := false for _, f := range inMsg.GetField() { - isInt32 := f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_INT32 || f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_UINT32 - if (f.GetName() == "page_size" || f.GetName() == "max_results") && isInt32 { - if pageSizeField != nil { - return nil, nil, fmt.Errorf("found both page_size and max_results fields in message %q", m.GetInputType()) + candidate, needsWrapper := isPageSizeField(f, wrapperTypesAllowed) + if candidate { + if pageSizeField == nil { + pageSizeField = f + if needsWrapper { + g.imports[pbinfo.ImportSpec{Path: "google.golang.org/protobuf/types/known/wrapperspb"}] = true + } + } else { + return nil, nil, fmt.Errorf("found multiple page size fields in message %q: %q and %q", m.GetInputType(), pageSizeField.GetName(), f.GetName()) } - pageSizeField = f continue } @@ -222,7 +275,7 @@ func (g *generator) maybeSortMapPage(elemField *descriptorpb.FieldDescriptorProt return elems } -func (g *generator) makeFetchAndIterUpdate(pageSizeFieldName string) { +func (g *generator) makeFetchAndIterUpdate(pageSize *descriptorpb.FieldDescriptorProto) { p := g.printf p("fetch := func(pageSize int, pageToken string) (string, error) {") @@ -235,13 +288,29 @@ func (g *generator) makeFetchAndIterUpdate(pageSizeFieldName string) { p("}") p("") p("it.pageInfo, it.nextFunc = iterator.NewPageInfo(fetch, it.bufLen, it.takeBuf)") - p("it.pageInfo.MaxSize = int(req.Get%s())", pageSizeFieldName) + internalPageInfoMax(p, pageSize) p("it.pageInfo.Token = req.GetPageToken()") p("") p("return it") } -func (g *generator) internalFetchSetup(outType *descriptorpb.DescriptorProto, outSpec pbinfo.ImportSpec, tok, pageSizeFieldName, max, ps string) { +// internalPageInfo handles the logic for setting MaxSize in PageInfo. +// This method is called from makeFetchAndIterUpdate() and deals with the +// various types allowed for the page_size field. +func internalPageInfoMax(p func(s string, a ...interface{}), pageSize *descriptorpb.FieldDescriptorProto) { + cName := snakeToCamel(pageSize.GetName()) + switch pageSize.GetType() { + case descriptorpb.FieldDescriptorProto_TYPE_INT32, descriptorpb.FieldDescriptorProto_TYPE_UINT32: + p("it.pageInfo.MaxSize = int(req.Get%s())", cName) + case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: + // Both wrapper types use a castable GetValue() field. + p("if psVal := req.Get%s(); psVal != nil {", cName) + p(" it.pageInfo.MaxSize = int(psVal.GetValue())") + p("}") + } +} + +func (g *generator) internalFetchSetup(outType *descriptorpb.DescriptorProto, outSpec pbinfo.ImportSpec, pageSize *descriptorpb.FieldDescriptorProto, tok string) { p := g.printf p(" resp := &%s.%s{}", outSpec.Name, outType.GetName()) @@ -249,12 +318,46 @@ func (g *generator) internalFetchSetup(outType *descriptorpb.DescriptorProto, ou p(" req.PageToken = %s", tok) p(" }") p(" if pageSize > math.MaxInt32 {") - p(" req.%s = %s", pageSizeFieldName, max) + internalPageSizeSetter(p, pageSize, "math.MaxInt32") p(" } else if pageSize != 0 {") - p(" req.%s = %s", pageSizeFieldName, ps) + internalPageSizeSetter(p, pageSize, "pageSize") p(" }") } +// internalPageSizeSetter is a helper for injecting the value setting expression. +// The incoming setVal is based on an incoming set int32-based value variable, +// typically either labelled as 'pageSize' or 'math.MaxInt32'. +func internalPageSizeSetter(p func(s string, a ...interface{}), pageSize *descriptorpb.FieldDescriptorProto, setVal string) { + cName := snakeToCamel(pageSize.GetName()) + switch pageSize.GetType() { + case descriptorpb.FieldDescriptorProto_TYPE_INT32: + if pageSize.GetProto3Optional() { + p("req.%s = proto.Int32(%s)", cName, setVal) + } else { + if setVal != "math.MaxInt32" { + setVal = fmt.Sprintf("int32(%s)", setVal) + } + p("req.%s = %s", cName, setVal) + } + case descriptorpb.FieldDescriptorProto_TYPE_UINT32: + if pageSize.GetProto3Optional() { + p("req.%s = proto.UInt32(%s)", cName, setVal) + } else { + p("req.%s = uint32(%s)", cName, setVal) + } + case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: + switch pageSize.GetTypeName() { + case ".google.protobuf.Int32Value": + if setVal != "math.MaxInt32" { + setVal = fmt.Sprintf("int32(%s)", setVal) + } + p("req.%s = &wrapperspb.Int32Value{Value: %s}", cName, setVal) + case ".google.protobuf.UInt32Value": + p("req.%s = &wrapperspb.UInt32Value{Value: uint32(%s)}", cName, setVal) + } + } +} + func (g *generator) pagingCall(servName string, m *descriptorpb.MethodDescriptorProto, elemField, pageSize *descriptorpb.FieldDescriptorProto, pt *iterType) error { inType := g.descInfo.Type[m.GetInputType()].(*descriptorpb.DescriptorProto) outType := g.descInfo.Type[m.GetOutputType()].(*descriptorpb.DescriptorProto) @@ -272,13 +375,6 @@ func (g *generator) pagingCall(servName string, m *descriptorpb.MethodDescriptor return err } - max := "math.MaxInt32" - ps := "int32(pageSize)" - if isOptional(inType, "page_size") { - max = fmt.Sprintf("proto.Int32(%s)", max) - ps = fmt.Sprintf("proto.Int32(%s)", ps) - } - tok := "pageToken" if isOptional(inType, "page_token") { tok = fmt.Sprintf("proto.String(%s)", tok) @@ -290,11 +386,10 @@ func (g *generator) pagingCall(servName string, m *descriptorpb.MethodDescriptor g.insertRequestHeaders(m, grpc) g.appendCallOpts(m) - pageSizeFieldName := snakeToCamel(pageSize.GetName()) p("it := &%s{}", pt.iterTypeName) p("req = proto.Clone(req).(*%s.%s)", inSpec.Name, inType.GetName()) p("it.InternalFetch = func(pageSize int, pageToken string) ([]%s, string, error) {", pt.elemTypeName) - g.internalFetchSetup(outType, outSpec, tok, pageSizeFieldName, max, ps) + g.internalFetchSetup(outType, outSpec, pageSize, tok) p(" err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error {") p(" var err error") p(" resp, err = %s", g.grpcStubCall(m)) @@ -308,7 +403,7 @@ func (g *generator) pagingCall(servName string, m *descriptorpb.MethodDescriptor elems := g.maybeSortMapPage(elemField, pt) p(" return %s, resp.GetNextPageToken(), nil", elems) p("}") - g.makeFetchAndIterUpdate(pageSizeFieldName) + g.makeFetchAndIterUpdate(pageSize) p("}") p("") diff --git a/internal/gengapic/paging_test.go b/internal/gengapic/paging_test.go index 32a21125b..e9f0d9554 100644 --- a/internal/gengapic/paging_test.go +++ b/internal/gengapic/paging_test.go @@ -125,6 +125,19 @@ func TestPagingField(t *testing.T) { // * Response has a string next_page_token field // * Response has one and only one repeated or map field + // This test manipulates the enableWrapperTypesForPageSize allowlist, so ensure we're + // not tainting state after test completes. + origAllowList := make(map[string]bool) + for k, v := range enableWrapperTypesForPageSize { + origAllowList[k] = v + } + + defer func() { + enableWrapperTypesForPageSize = origAllowList + }() + // Clear the allowlist for this test. + enableWrapperTypesForPageSize = make(map[string]bool) + // Messages validPageSize := &descriptorpb.DescriptorProto{ Name: proto.String("ValidPageSizeRequest"), @@ -156,6 +169,38 @@ func TestPagingField(t *testing.T) { }, }, } + restrictedWrapperMaxResults := &descriptorpb.DescriptorProto{ + Name: proto.String("WrapperInt32Request"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("max_results"), + Number: proto.Int32(int32(1)), + Type: typep(descriptorpb.FieldDescriptorProto_TYPE_MESSAGE), + TypeName: proto.String(".google.protobuf.Int32Value"), + }, + { + Name: proto.String("page_token"), + Number: proto.Int32(int32(2)), + Type: typep(descriptorpb.FieldDescriptorProto_TYPE_STRING), + }, + }, + } + restrictedWrapperMaxResultsUint32 := &descriptorpb.DescriptorProto{ + Name: proto.String("WrapperUInt32Request"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("max_results"), + Number: proto.Int32(int32(1)), + Type: typep(descriptorpb.FieldDescriptorProto_TYPE_MESSAGE), + TypeName: proto.String(".google.protobuf.UInt32Value"), + }, + { + Name: proto.String("page_token"), + Number: proto.Int32(int32(2)), + Type: typep(descriptorpb.FieldDescriptorProto_TYPE_STRING), + }, + }, + } randomMessage := &descriptorpb.DescriptorProto{Name: proto.String("RandomMessage")} validRepeated := &descriptorpb.DescriptorProto{ Name: proto.String("ValidRepeatedResponse"), @@ -330,6 +375,16 @@ func TestPagingField(t *testing.T) { InputType: proto.String(".paging.ValidPageSizeRequest"), OutputType: proto.String(".paging.ValidMapResponse"), } + restrictedPageWrapperMthd := &descriptorpb.MethodDescriptorProto{ + Name: proto.String("RestrictedPageSize"), + InputType: proto.String(".paging.WrapperInt32Request"), + OutputType: proto.String(".paging.ValidRepeatedResponse"), + } + restrictedPageWrapperUint32Mthd := &descriptorpb.MethodDescriptorProto{ + Name: proto.String("RestrictedPageSize"), + InputType: proto.String(".paging.WrapperUInt32Request"), + OutputType: proto.String(".paging.ValidRepeatedResponse"), + } clientStreamingMthd := &descriptorpb.MethodDescriptorProto{ Name: proto.String("ClientStreaming"), InputType: proto.String(".paging.ValidPageSizeRequest"), @@ -379,6 +434,8 @@ func TestPagingField(t *testing.T) { tooManyRepeated, validMap, validMaxResults, + restrictedWrapperMaxResults, + restrictedWrapperMaxResultsUint32, validPageSize, validRepeated, }, @@ -397,11 +454,14 @@ func TestPagingField(t *testing.T) { validPageSizeMapMthd, validPageSizeMthd, validPageSizeMultipleMthd, + restrictedPageWrapperMthd, + restrictedPageWrapperUint32Mthd, }, }, }, } + // First, test using the default "paging" package. req := pluginpb.CodeGeneratorRequest{ Parameter: proto.String("go-gapic-package=path;mypackage,transport=rest"), ProtoFile: []*descriptorpb.FileDescriptorProto{file}, @@ -423,6 +483,8 @@ func TestPagingField(t *testing.T) { {mthd: tooManyRepeatedMthd}, {mthd: noNextPageTokenMthd}, {mthd: noRepeatedFieldMthd}, + {mthd: restrictedPageWrapperMthd}, // not pageable without an allowed package + {mthd: restrictedPageWrapperUint32Mthd}, // not pageable without an allowed package {mthd: validMaxResultsRepeatedMthd, sizeField: validMaxResults.GetField()[0], iterField: validRepeated.GetField()[1]}, {mthd: validPageSizeMapMthd, sizeField: validPageSize.GetField()[0], iterField: validMap.GetField()[1]}, {mthd: validPageSizeMthd, sizeField: validPageSize.GetField()[0], iterField: validRepeated.GetField()[1]}, @@ -438,4 +500,31 @@ func TestPagingField(t *testing.T) { t.Errorf("test %s iter field: got %s, want %s, err %v", tst.mthd.GetName(), actualIter, tst.iterField, err) } } + + // Re-test, adding the "paging" package to the allowlist. + enableWrapperTypesForPageSize["paging"] = true + g, err = newGenerator(&req) + if err != nil { + t.Fatal(err) + } + g.opts = &options{transports: []transport{rest}} + for _, tst := range []struct { + mthd *descriptorpb.MethodDescriptorProto + sizeField *descriptorpb.FieldDescriptorProto // A nil field means this is not a paged method + iterField *descriptorpb.FieldDescriptorProto // A nil field means this is not a paged method + }{ + {mthd: noRepeatedFieldMthd}, // still invalid + {mthd: restrictedPageWrapperMthd, sizeField: restrictedWrapperMaxResults.GetField()[0], iterField: validRepeated.GetField()[1]}, + {mthd: restrictedPageWrapperUint32Mthd, sizeField: restrictedWrapperMaxResultsUint32.GetField()[0], iterField: validRepeated.GetField()[1]}, + {mthd: validMaxResultsRepeatedMthd, sizeField: validMaxResults.GetField()[0], iterField: validRepeated.GetField()[1]}, // valid regardless of allowlist + } { + actualIter, actualSize, err := g.getPagingFields(tst.mthd) + if actualSize != tst.sizeField { + t.Errorf("test w/wrapper %s page size field: got %s, want %s, err %v", tst.mthd.GetName(), actualSize, tst.sizeField, err) + } + + if actualIter != tst.iterField { + t.Errorf("test w/wrapper %s iter field: got %s, want %s, err %v", tst.mthd.GetName(), actualIter, tst.iterField, err) + } + } } diff --git a/internal/gengapic/testdata/method_GetManyThingsOptional.want b/internal/gengapic/testdata/method_GetManyThingsOptional.want index ee68a6499..e913744aa 100644 --- a/internal/gengapic/testdata/method_GetManyThingsOptional.want +++ b/internal/gengapic/testdata/method_GetManyThingsOptional.want @@ -14,7 +14,7 @@ func (c *fooGRPCClient) GetManyThingsOptional(ctx context.Context, req *mypackag if pageSize > math.MaxInt32 { req.PageSize = proto.Int32(math.MaxInt32) } else if pageSize != 0 { - req.PageSize = proto.Int32(int32(pageSize)) + req.PageSize = proto.Int32(pageSize) } err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error { var err error diff --git a/test.sh b/test.sh index c9a720b78..43ba9cd29 100755 --- a/test.sh +++ b/test.sh @@ -57,6 +57,10 @@ generate --go_gapic_opt 'go-gapic-package=cloud.google.com/go/retail/apiv2;retai echo "Generating Apigee Connect v1 - Dual Transport, partial REGAPIC" generate --go_gapic_opt 'go-gapic-package=cloud.google.com/go/apigeeconnect/apiv1;apigeeconnect,transport=grpc+rest' $GOOGLEAPIS/google/cloud/apigeeconnect/v1/*.proto +echo "Generating BigQuery v2 - REGAPIC, atypical list RPCs" +generate --go_gapic_opt 'go-gapic-package=cloud.google.com/go/bigquery/apiv2;bigquery,transport=rest' $GOOGLEAPIS/google/cloud/bigquery/v2/*.proto + + echo "Generation complete" echo "Running gofmt to check for syntax errors"