Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support wrapper types for autopagination #1541

Merged
merged 14 commits into from
Aug 15, 2024
Merged
5 changes: 5 additions & 0 deletions internal/gengapic/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ var enableNewAuthLibraryBlocklist = map[string]bool{
"spanner.googleapis.com": true,
}

// keyed by proto package name, e.g. "google.cloud.foo.v1".
var enableWrapperTypesForPageSize = map[string]bool{
"google.cloud.bigquery.v2": true,
}

type generator struct {
pt printer.P

Expand Down
12 changes: 2 additions & 10 deletions internal/gengapic/genrest.go
Original file line number Diff line number Diff line change
Expand Up @@ -791,20 +791,12 @@ func (g *generator) pagingRESTCall(servName string, m *descriptorpb.MethodDescri

verb := strings.ToUpper(info.verb)

max := "math.MaxInt32"
g.imports[pbinfo.ImportSpec{Path: "math"}] = true
psTyp := pbinfo.GoTypeForPrim[pageSize.GetType()]
ps := fmt.Sprintf("%s(pageSize)", psTyp)
if isOptional(inType, pageSize.GetName()) {
max = fmt.Sprintf("proto.%s(%s)", upperFirst(psTyp), max)
ps = fmt.Sprintf("proto.%s(%s)", upperFirst(psTyp), ps)
}
tok := "pageToken"
if isOptional(inType, "page_token") {
tok = fmt.Sprintf("proto.String(%s)", tok)
}

pageSizeFieldName := snakeToCamel(pageSize.GetName())
p("func (c *%s) %s(ctx context.Context, req *%s.%s, opts ...gax.CallOption) *%s {",
lowcaseServName, m.GetName(), inSpec.Name, inType.GetName(), pt.iterTypeName)
p("it := &%s{}", pt.iterTypeName)
Expand All @@ -819,7 +811,7 @@ func (g *generator) pagingRESTCall(servName string, m *descriptorpb.MethodDescri

p("unm := protojson.UnmarshalOptions{AllowPartial: true, DiscardUnknown: true}")
p("it.InternalFetch = func(pageSize int, pageToken string) ([]%s, string, error) {", pt.elemTypeName)
g.internalFetchSetup(outType, outSpec, tok, pageSizeFieldName, max, ps)
g.internalFetchSetup(outType, outSpec, pageSize, tok)

if info.body != "" {
p(" jsonReq, err := m.Marshal(req)")
Expand Down Expand Up @@ -874,7 +866,7 @@ func (g *generator) pagingRESTCall(servName string, m *descriptorpb.MethodDescri
p(" return %s, resp.GetNextPageToken(), nil", elems)
p("}")
p("")
g.makeFetchAndIterUpdate(pageSizeFieldName)
g.makeFetchAndIterUpdate(pageSize)
p("}")

g.imports[pbinfo.ImportSpec{Path: "google.golang.org/api/iterator"}] = true
Expand Down
155 changes: 125 additions & 30 deletions internal/gengapic/paging.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,34 @@ type iterType struct {
elemImports []pbinfo.ImportSpec
}

// isPageSizeField evaluates whether a particular field is a page size field, and whether this
// field will require a dependency on wrapper types in the generator.
//
// https://google.aip.dev/158 guidance is to use `page_size`, but older APIs like compute
// and bigquery use `max_results`. Similarly, `int32` is the expected scalar type, but
// there's more variance here in implementations, so int32 and uint32 are allowed.
//
// If wrapper support is allowed, the page size detection will include the
// usage of equivalent wrapper types as well (Int32Value, UInt32Value). This is legacy behavior
// due to older APIs that were built prior to proto3 presence being (re)introduced.
func isPageSizeField(f *descriptorpb.FieldDescriptorProto, wrappersAllowed bool) (isCandidate, requiresWrapper bool) {
if f.GetName() == "page_size" || f.GetName() == "max_results" {
// Scalar types.
if f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_INT32 || f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_UINT32 {
return true, false
}
// Wrapper types.
if wrappersAllowed {
if f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE {
if f.GetTypeName() == ".google.protobuf.Int32Value" || f.GetTypeName() == ".google.protobuf.UInt32Value" {
return true, true
}
}
}
}
return false, false
}

// iterTypeOf deduces iterType from a field to be iterated over.
// elemField should be the "resource" of a paginating RPC.
// TODO(dovs): augment with paged map iterators
Expand Down Expand Up @@ -121,22 +149,43 @@ func (g *generator) iterTypeOf(elemField *descriptorpb.FieldDescriptorProto) (*i
// Makes particular allowance for diregapic idioms: maps can be paginated over,
// and either 'page_size' XOR 'max_results' are allowable fields in the request.
func (g *generator) getPagingFields(m *descriptorpb.MethodDescriptorProto) (repeatedField, pageSizeField *descriptorpb.FieldDescriptorProto, e error) {
// TODO: remove this once the next version of the Talent API is published.
//
// This is a workaround to disable auto-pagination for specifc RPCs in
// Talent v4beta1. The API team will make their API non-conforming in the
// next version.
//
// This should not be done for any other API.
if g.descInfo.ParentFile[m].GetPackage() == "google.cloud.talent.v4beta1" &&
(m.GetName() == "SearchProfiles" || m.GetName() == "SearchJobs") {
return nil, nil, nil
// TODO: Remove this skip logic once annotation-based pagination config supercedes heuristic-based mechanisms.
// FR is tracked internally as b/337021569.
var paginationOverrides = []struct {
pkgName string
disallowedMethods []string // methods explicitly denied from pagination
}{
{
pkgName: "google.cloud.talent.v4beta1",
disallowedMethods: []string{"SearchProfiles", "SearchJobs"},
},
{
pkgName: "google.cloud.bigquery.v2",
disallowedMethods: []string{"GetQueryResults"},
},
}

for _, cfg := range paginationOverrides {
if g.descInfo.ParentFile[m].GetPackage() == cfg.pkgName {
for _, skipMethod := range cfg.disallowedMethods {
if m.GetName() == skipMethod {
return nil, nil, nil
}
}
}
}
if m.GetClientStreaming() || m.GetServerStreaming() {
return nil, nil, nil
}

var wrapperTypesAllowed bool
for p, ok := range enableWrapperTypesForPageSize {
if g.descInfo.ParentFile[m].GetPackage() == p && ok {
wrapperTypesAllowed = true
break
}
}

inType := g.descInfo.Type[m.GetInputType()]
if inType == nil {
return nil, nil, fmt.Errorf("expected %q to be message type, found %T", m.GetInputType(), inType)
Expand All @@ -157,12 +206,16 @@ func (g *generator) getPagingFields(m *descriptorpb.MethodDescriptorProto) (repe

hasPageToken := false
for _, f := range inMsg.GetField() {
isInt32 := f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_INT32 || f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_UINT32
if (f.GetName() == "page_size" || f.GetName() == "max_results") && isInt32 {
if pageSizeField != nil {
return nil, nil, fmt.Errorf("found both page_size and max_results fields in message %q", m.GetInputType())
candidate, needsWrapper := isPageSizeField(f, wrapperTypesAllowed)
if candidate {
if pageSizeField == nil {
pageSizeField = f
if needsWrapper {
g.imports[pbinfo.ImportSpec{Path: "google.golang.org/protobuf/types/known/wrapperspb"}] = true
}
} else {
return nil, nil, fmt.Errorf("found multiple page size fields in message %q: %q and %q", m.GetInputType(), pageSizeField.GetName(), f.GetName())
}
pageSizeField = f
continue
}

Expand Down Expand Up @@ -222,7 +275,7 @@ func (g *generator) maybeSortMapPage(elemField *descriptorpb.FieldDescriptorProt
return elems
}

func (g *generator) makeFetchAndIterUpdate(pageSizeFieldName string) {
func (g *generator) makeFetchAndIterUpdate(pageSize *descriptorpb.FieldDescriptorProto) {
p := g.printf

p("fetch := func(pageSize int, pageToken string) (string, error) {")
Expand All @@ -235,26 +288,76 @@ func (g *generator) makeFetchAndIterUpdate(pageSizeFieldName string) {
p("}")
p("")
p("it.pageInfo, it.nextFunc = iterator.NewPageInfo(fetch, it.bufLen, it.takeBuf)")
p("it.pageInfo.MaxSize = int(req.Get%s())", pageSizeFieldName)
internalPageInfoMax(p, pageSize)
p("it.pageInfo.Token = req.GetPageToken()")
p("")
p("return it")
}

func (g *generator) internalFetchSetup(outType *descriptorpb.DescriptorProto, outSpec pbinfo.ImportSpec, tok, pageSizeFieldName, max, ps string) {
// internalPageInfo handles the logic for setting MaxSize in PageInfo.
// This method is called from makeFetchAndIterUpdate() and deals with the
// various types allowed for the page_size field.
func internalPageInfoMax(p func(s string, a ...interface{}), pageSize *descriptorpb.FieldDescriptorProto) {
cName := snakeToCamel(pageSize.GetName())
switch pageSize.GetType() {
case descriptorpb.FieldDescriptorProto_TYPE_INT32, descriptorpb.FieldDescriptorProto_TYPE_UINT32:
p("it.pageInfo.MaxSize = int(req.Get%s())", cName)
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
// Both wrapper types use a castable GetValue() field.
p("if psVal := req.Get%s(); psVal != nil {", cName)
p(" it.pageInfo.MaxSize = int(psVal.GetValue())")
p("}")
}
}

func (g *generator) internalFetchSetup(outType *descriptorpb.DescriptorProto, outSpec pbinfo.ImportSpec, pageSize *descriptorpb.FieldDescriptorProto, tok string) {
p := g.printf

p(" resp := &%s.%s{}", outSpec.Name, outType.GetName())
p(` if pageToken != "" {`)
p(" req.PageToken = %s", tok)
p(" }")
p(" if pageSize > math.MaxInt32 {")
p(" req.%s = %s", pageSizeFieldName, max)
internalPageSizeSetter(p, pageSize, "math.MaxInt32")
p(" } else if pageSize != 0 {")
p(" req.%s = %s", pageSizeFieldName, ps)
internalPageSizeSetter(p, pageSize, "pageSize")
p(" }")
}

// internalPageSizeSetter is a helper for injecting the value setting expression.
// The incoming setVal is based on an incoming set int32-based value variable,
// typically either labelled as 'pageSize' or 'math.MaxInt32'.
func internalPageSizeSetter(p func(s string, a ...interface{}), pageSize *descriptorpb.FieldDescriptorProto, setVal string) {
cName := snakeToCamel(pageSize.GetName())
switch pageSize.GetType() {
case descriptorpb.FieldDescriptorProto_TYPE_INT32:
if pageSize.GetProto3Optional() {
p("req.%s = proto.Int32(%s)", cName, setVal)
} else {
if setVal != "math.MaxInt32" {
setVal = fmt.Sprintf("int32(%s)", setVal)
}
p("req.%s = %s", cName, setVal)
}
case descriptorpb.FieldDescriptorProto_TYPE_UINT32:
if pageSize.GetProto3Optional() {
p("req.%s = proto.UInt32(%s)", cName, setVal)
} else {
p("req.%s = uint32(%s)", cName, setVal)
}
case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
switch pageSize.GetTypeName() {
case ".google.protobuf.Int32Value":
if setVal != "math.MaxInt32" {
setVal = fmt.Sprintf("int32(%s)", setVal)
}
p("req.%s = &wrapperspb.Int32Value{Value: %s}", cName, setVal)
case ".google.protobuf.UInt32Value":
p("req.%s = &wrapperspb.UInt32Value{Value: uint32(%s)}", cName, setVal)
}
}
}

func (g *generator) pagingCall(servName string, m *descriptorpb.MethodDescriptorProto, elemField, pageSize *descriptorpb.FieldDescriptorProto, pt *iterType) error {
inType := g.descInfo.Type[m.GetInputType()].(*descriptorpb.DescriptorProto)
outType := g.descInfo.Type[m.GetOutputType()].(*descriptorpb.DescriptorProto)
Expand All @@ -272,13 +375,6 @@ func (g *generator) pagingCall(servName string, m *descriptorpb.MethodDescriptor
return err
}

max := "math.MaxInt32"
ps := "int32(pageSize)"
if isOptional(inType, "page_size") {
max = fmt.Sprintf("proto.Int32(%s)", max)
ps = fmt.Sprintf("proto.Int32(%s)", ps)
}

tok := "pageToken"
if isOptional(inType, "page_token") {
tok = fmt.Sprintf("proto.String(%s)", tok)
Expand All @@ -290,11 +386,10 @@ func (g *generator) pagingCall(servName string, m *descriptorpb.MethodDescriptor

g.insertRequestHeaders(m, grpc)
g.appendCallOpts(m)
pageSizeFieldName := snakeToCamel(pageSize.GetName())
p("it := &%s{}", pt.iterTypeName)
p("req = proto.Clone(req).(*%s.%s)", inSpec.Name, inType.GetName())
p("it.InternalFetch = func(pageSize int, pageToken string) ([]%s, string, error) {", pt.elemTypeName)
g.internalFetchSetup(outType, outSpec, tok, pageSizeFieldName, max, ps)
g.internalFetchSetup(outType, outSpec, pageSize, tok)
p(" err := gax.Invoke(ctx, func(ctx context.Context, settings gax.CallSettings) error {")
p(" var err error")
p(" resp, err = %s", g.grpcStubCall(m))
Expand All @@ -308,7 +403,7 @@ func (g *generator) pagingCall(servName string, m *descriptorpb.MethodDescriptor
elems := g.maybeSortMapPage(elemField, pt)
p(" return %s, resp.GetNextPageToken(), nil", elems)
p("}")
g.makeFetchAndIterUpdate(pageSizeFieldName)
g.makeFetchAndIterUpdate(pageSize)
p("}")
p("")

Expand Down
Loading