diff --git a/codec.go b/codec.go index 568c56ba..5d3f19b3 100644 --- a/codec.go +++ b/codec.go @@ -22,8 +22,9 @@ import ( ) const ( - codecNameProto = "proto" - codecNameJSON = "json" + codecNameProto = "proto" + codecNameJSON = "json" + codecNameJSONCharsetUTF8 = codecNameJSON + "; charset=utf-8" ) // Codec marshals structs (typically generated from a schema) to and from bytes. @@ -70,11 +71,13 @@ func (c *protoBinaryCodec) Unmarshal(data []byte, message any) error { return proto.Unmarshal(data, protoMessage) } -type protoJSONCodec struct{} +type protoJSONCodec struct { + name string +} var _ Codec = (*protoJSONCodec)(nil) -func (c *protoJSONCodec) Name() string { return codecNameJSON } +func (c *protoJSONCodec) Name() string { return c.name } func (c *protoJSONCodec) Marshal(message any) ([]byte, error) { protoMessage, ok := message.(proto.Message) diff --git a/error_writer.go b/error_writer.go index efca9da1..1c62760f 100644 --- a/error_writer.go +++ b/error_writer.go @@ -83,7 +83,7 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { // IsSupported checks whether a request is using one of the ErrorWriter's // supported RPC protocols. func (w *ErrorWriter) IsSupported(request *http.Request) bool { - ctype := request.Header.Get(headerContentType) + ctype := canonicalizeContentType(request.Header.Get(headerContentType)) _, ok := w.allContentTypes[ctype] return ok } @@ -94,7 +94,7 @@ func (w *ErrorWriter) IsSupported(request *http.Request) bool { // // Write does not read or close the request body. func (w *ErrorWriter) Write(response http.ResponseWriter, request *http.Request, err error) error { - ctype := request.Header.Get(headerContentType) + ctype := canonicalizeContentType(request.Header.Get(headerContentType)) if _, ok := w.unaryConnectContentTypes[ctype]; ok { // Unary errors are always JSON. response.Header().Set(headerContentType, connectUnaryContentTypeJSON) diff --git a/handler.go b/handler.go index 14950a22..86eda1b8 100644 --- a/handler.go +++ b/handler.go @@ -172,7 +172,7 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re } // Find our implementation of the RPC protocol in use. - contentType := request.Header.Get("Content-Type") + contentType := canonicalizeContentType(request.Header.Get("Content-Type")) var protocolHandler protocolHandler for _, handler := range h.protocolHandlers { if _, ok := handler.ContentTypes()[contentType]; ok { @@ -187,6 +187,7 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re } // Establish a stream and serve the RPC. + request.Header.Set("Content-Type", contentType) // prefer canonicalized value ctx, cancel, timeoutErr := protocolHandler.SetTimeout(request) if timeoutErr != nil { ctx = request.Context() @@ -235,7 +236,7 @@ func newHandlerConfig(procedure string, options []HandlerOption) *handlerConfig BufferPool: newBufferPool(), } withProtoBinaryCodec().applyToHandler(&config) - withProtoJSONCodec().applyToHandler(&config) + withProtoJSONCodecs().applyToHandler(&config) withGzip().applyToHandler(&config) for _, opt := range options { opt.applyToHandler(&config) diff --git a/handler_ext_test.go b/handler_ext_test.go index e36da1a8..7135842c 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -74,15 +74,50 @@ func TestHandler_ServeHTTP(t *testing.T) { assert.Equal(t, resp.Header.Get("Accept-Post"), strings.Join([]string{ "application/grpc", "application/grpc+json", + "application/grpc+json; charset=utf-8", "application/grpc+proto", "application/grpc-web", "application/grpc-web+json", + "application/grpc-web+json; charset=utf-8", "application/grpc-web+proto", "application/json", + "application/json; charset=utf-8", "application/proto", }, ", ")) }) + t.Run("charset_in_content_type_header", func(t *testing.T) { + t.Parallel() + req, err := http.NewRequestWithContext( + context.Background(), + http.MethodPost, + server.URL+pingProcedure, + strings.NewReader("{}"), + ) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/json;Charset=utf-8") + resp, err := client.Do(req) + assert.Nil(t, err) + defer resp.Body.Close() + assert.Equal(t, resp.StatusCode, http.StatusOK) + }) + + t.Run("unsupported_charset", func(t *testing.T) { + t.Parallel() + req, err := http.NewRequestWithContext( + context.Background(), + http.MethodPost, + server.URL+pingProcedure, + strings.NewReader("{}"), + ) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/json; charset=shift-jis") + resp, err := client.Do(req) + assert.Nil(t, err) + defer resp.Body.Close() + assert.Equal(t, resp.StatusCode, http.StatusUnsupportedMediaType) + }) + t.Run("unsupported_content_encoding", func(t *testing.T) { t.Parallel() req, err := http.NewRequestWithContext( diff --git a/option.go b/option.go index 9697245c..f86d75fe 100644 --- a/option.go +++ b/option.go @@ -78,7 +78,7 @@ func WithGRPCWeb() ClientOption { // lowerCamelCase, zero values are omitted, missing required fields are errors, // enums are emitted as strings, etc. func WithProtoJSON() ClientOption { - return WithCodec(&protoJSONCodec{}) + return WithCodec(&protoJSONCodec{codecNameJSON}) } // WithSendCompression configures the client to use the specified algorithm to @@ -452,6 +452,9 @@ func withProtoBinaryCodec() Option { return WithCodec(&protoBinaryCodec{}) } -func withProtoJSONCodec() HandlerOption { - return WithCodec(&protoJSONCodec{}) +func withProtoJSONCodecs() HandlerOption { + return WithHandlerOptions( + WithCodec(&protoJSONCodec{codecNameJSON}), + WithCodec(&protoJSONCodec{codecNameJSONCharsetUTF8}), + ) } diff --git a/protocol.go b/protocol.go index 9f0cb8bc..e4b9f416 100644 --- a/protocol.go +++ b/protocol.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "io" + "mime" "net/http" "net/url" "sort" @@ -290,3 +291,11 @@ func flushResponseWriter(w http.ResponseWriter) { f.Flush() } } + +func canonicalizeContentType(ct string) string { + base, params, err := mime.ParseMediaType(ct) + if err != nil { + return ct + } + return mime.FormatMediaType(base, params) +}