diff --git a/params_test.go b/params_test.go index 8404a44..ddfa329 100644 --- a/params_test.go +++ b/params_test.go @@ -388,6 +388,11 @@ func TestSetParameter(t *testing.T) { msg.Set(field, value) return msg }(), + }, { + fields: "unknownField", + value: "hello", + want: &testv1.ParameterValues{}, + wantErr: "unknown field in field path \"unknownField\": element \"unknownField\" does not correspond to any field of type vanguard.test.v1.ParameterValues", }} for _, testCase := range testCases { testCase := testCase diff --git a/protocol_http_test.go b/protocol_http_test.go index b7c0d32..803bcf6 100644 --- a/protocol_http_test.go +++ b/protocol_http_test.go @@ -233,6 +233,13 @@ func TestHTTPEncodePathValues(t *testing.T) { tmpl: "/v2/**", wantPath: "/v2/**", wantQuery: url.Values{}, + }, { + input: &testv1.ParameterValues{StringValue: "books/1"}, + reqFieldPath: "unknownQueryParam", + tmpl: "/v1/{string_value=books/*}:get", + wantPath: "/v1/books/1:get", + wantQuery: url.Values{}, + wantErr: "unknown field in field path \"unknownQueryParam\": element \"unknownQueryParam\" does not correspond to any field of type vanguard.test.v1.ParameterValues", }} for _, testCase := range testCases { testCase := testCase diff --git a/protocol_rest.go b/protocol_rest.go index 9512976..110ec8c 100644 --- a/protocol_rest.go +++ b/protocol_rest.go @@ -16,6 +16,7 @@ package vanguard import ( "bytes" + "errors" "fmt" "io" "net/http" @@ -160,12 +161,17 @@ func (r restClientProtocol) prepareUnmarshalledRequest(op *operation, src []byte return err } } + // And finally from the query string: + discardUnknownQueryParams := op.methodConf.serviceOptions.restUnmarshalOptions.DiscardUnknownQueryParams for fieldPath, values := range op.queryValues() { fields, err := resolvePathToFieldDescriptors( msg.Descriptor(), fieldPath, true, ) if err != nil { + if discardUnknownQueryParams && errors.Is(err, errUnknownField) { + continue + } return err } for _, value := range values { diff --git a/router.go b/router.go index 60ab258..2ea07d9 100644 --- a/router.go +++ b/router.go @@ -24,6 +24,10 @@ import ( "google.golang.org/protobuf/reflect/protoreflect" ) +var ( + errUnknownField = errors.New("unknown field") +) + // routeTrie is a prefix trie of valid REST URI paths to route targets. // It supports evaluation of variables as the path is matched, for // interpolating parts of the URI path into an RPC request field. The @@ -375,8 +379,8 @@ func resolvePathToFieldDescriptors( if field == nil { field = fields.ByName(protoreflect.Name(part)) if field == nil { - return nil, fmt.Errorf("in field path %q: element %q does not correspond to any field of type %s", - path, part, msg.FullName()) + return nil, fmt.Errorf("%w in field path %q: element %q does not correspond to any field of type %s", + errUnknownField, path, part, msg.FullName()) } } result[i] = field diff --git a/vanguard.go b/vanguard.go index f9cf3f2..e69cd99 100644 --- a/vanguard.go +++ b/vanguard.go @@ -393,6 +393,20 @@ func WithMaxGetURLBytes(limit uint32) ServiceOption { }) } +// WithRESTUnmarshalOptions returns a service option that sets the unmarshal options for use with the REST protocol. +func WithRESTUnmarshalOptions(options RESTUnmarshalOptions) ServiceOption { + return serviceOptionFunc(func(opts *serviceOptions) { + opts.restUnmarshalOptions = options + }) +} + +// RESTUnmarshalOptions contains options for unmarshalling REST requests. +type RESTUnmarshalOptions struct { + // If DiscardUnknownQueryParams is true, any query parameters in a request that do not correspond to a field in the + // request message will be ignored. If false, such query parameters will cause an error. Defaults to false. + DiscardUnknownQueryParams bool +} + type transcoderOptions struct { defaultServiceOptions []ServiceOption rules []*annotations.HttpRule @@ -420,6 +434,7 @@ type serviceOptions struct { preferredCodec string maxMsgBufferBytes uint32 maxGetURLBytes uint32 + restUnmarshalOptions RESTUnmarshalOptions } type methodConfig struct { diff --git a/vanguard_restxrpc_test.go b/vanguard_restxrpc_test.go index 90fbbe0..7b97a4d 100644 --- a/vanguard_restxrpc_test.go +++ b/vanguard_restxrpc_test.go @@ -94,6 +94,9 @@ func TestMux_RESTxRPC(t *testing.T) { } else { opts = append(opts, WithNoTargetCompression()) } + + opts = append(opts, WithRESTUnmarshalOptions(RESTUnmarshalOptions{DiscardUnknownQueryParams: true})) + svcHandler := protocolAssertMiddleware(protocol, codec, compression, handler) services := make([]*Service, len(serviceNames)) @@ -524,6 +527,44 @@ func TestMux_RESTxRPC(t *testing.T) { "Content-Type": []string{"text/plain"}, }, }, + }, { + name: "DiscardUnknownQueryParams", + input: input{ + method: http.MethodGet, + path: "/message.txt:download?unknownParam=1", + }, + stream: testStream{ + method: testv1connect.ContentServiceDownloadProcedure, + msgs: []testMsg{ + {in: &testMsgIn{ + msg: &testv1.DownloadRequest{ + Filename: "message.txt", + }, + }}, + {out: &testMsgOut{ + msg: &testv1.DownloadResponse{ + File: &httpbody.HttpBody{ + ContentType: "text/plain", + Data: []byte("hello"), + }, + }, + }}, + {out: &testMsgOut{ + msg: &testv1.DownloadResponse{ + File: &httpbody.HttpBody{ + Data: []byte(" world"), + }, + }, + }}, + }, + }, + output: output{ + code: http.StatusOK, + rawBody: `hello world`, + meta: http.Header{ + "Content-Type": []string{"text/plain"}, + }, + }, }} type testOpt struct {