Skip to content

Commit

Permalink
Add method-based negotation.
Browse files Browse the repository at this point in the history
  • Loading branch information
jchadwick-buf committed Feb 21, 2023
1 parent caee276 commit 2c378b8
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 17 deletions.
45 changes: 30 additions & 15 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type Handler struct {
spec Spec
implementation StreamingHandlerFunc
protocolHandlers []protocolHandler
allowMethod string // Allow header
acceptPost string // Accept-Post header
}

Expand Down Expand Up @@ -86,6 +87,7 @@ func NewUnaryHandler[Req, Res any](
spec: config.newSpec(StreamTypeUnary),
implementation: implementation,
protocolHandlers: protocolHandlers,
allowMethod: sortedAllowMethodValue(protocolHandlers),
acceptPost: sortedAcceptPostValue(protocolHandlers),
}
}
Expand Down Expand Up @@ -182,26 +184,38 @@ func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Re
return
}

// The gRPC-HTTP2, gRPC-Web, and Connect protocols are all POST-only.
if request.Method != http.MethodPost {
responseWriter.Header().Set("Allow", http.MethodPost)
responseWriter.WriteHeader(http.StatusMethodNotAllowed)
return
var protocolHandlers []protocolHandler
for _, handler := range h.protocolHandlers {
if _, ok := handler.Methods()[request.Method]; ok {
protocolHandlers = append(protocolHandlers, handler)
}
}

// Find our implementation of the RPC protocol in use.
contentType := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))

// Find our implementation of the RPC protocol in use.
var protocolHandler protocolHandler
for _, handler := range h.protocolHandlers {
if _, ok := handler.ContentTypes()[contentType]; ok {
protocolHandler = handler
break
}
}
if protocolHandler == nil {
responseWriter.Header().Set("Accept-Post", h.acceptPost)
responseWriter.WriteHeader(http.StatusUnsupportedMediaType)
switch len(protocolHandlers) {
case 0:
responseWriter.Header().Set("Allow", h.allowMethod)
responseWriter.WriteHeader(http.StatusMethodNotAllowed)
return

case 1:
protocolHandler = protocolHandlers[0]

default:
for _, handler := range protocolHandlers {
if _, ok := handler.ContentTypes()[contentType]; ok {
protocolHandler = handler
break
}
}
if protocolHandler == nil {
responseWriter.Header().Set("Accept-Post", h.acceptPost)
responseWriter.WriteHeader(http.StatusUnsupportedMediaType)
return
}
}

// Establish a stream and serve the RPC.
Expand Down Expand Up @@ -316,6 +330,7 @@ func newStreamHandler(
spec: config.newSpec(streamType),
implementation: implementation,
protocolHandlers: protocolHandlers,
allowMethod: sortedAllowMethodValue(protocolHandlers),
acceptPost: sortedAcceptPostValue(protocolHandlers),
}
}
18 changes: 18 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ type protocolHandlerParams struct {
// Handler is the server side of a protocol. HTTP handlers typically support
// multiple protocols, codecs, and compressors.
type protocolHandler interface {
// Methods is the list of HTTP methods the protocol can handle.
Methods() map[string]struct{}

// ContentTypes is the set of HTTP Content-Types that the protocol can
// handle.
ContentTypes() map[string]struct{}
Expand Down Expand Up @@ -223,6 +226,21 @@ func sortedAcceptPostValue(handlers []protocolHandler) string {
return strings.Join(accept, ", ")
}

func sortedAllowMethodValue(handlers []protocolHandler) string {
methods := make(map[string]struct{})
for _, handler := range handlers {
for method := range handler.Methods() {
methods[method] = struct{}{}
}
}
allow := make([]string, 0, len(methods))
for ct := range methods {
allow = append(allow, ct)
}
sort.Strings(allow)
return strings.Join(allow, ", ")
}

func isCommaOrSpace(c rune) bool {
return c == ',' || c == ' '
}
Expand Down
12 changes: 11 additions & 1 deletion protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ type protocolConnect struct{}

// NewHandler implements protocol, so it must return an interface.
func (*protocolConnect) NewHandler(params *protocolHandlerParams) protocolHandler {
methods := make(map[string]struct{})
methods[http.MethodPost] = struct{}{}

contentTypes := make(map[string]struct{})
for _, name := range params.Codecs.Names() {
if params.Spec.StreamType == StreamTypeUnary {
Expand All @@ -66,8 +69,10 @@ func (*protocolConnect) NewHandler(params *protocolHandlerParams) protocolHandle
}
contentTypes[canonicalizeContentType(connectStreamingContentTypePrefix+name)] = struct{}{}
}

return &connectHandler{
protocolHandlerParams: *params,
methods: methods,
accept: contentTypes,
}
}
Expand All @@ -87,7 +92,12 @@ func (*protocolConnect) NewClient(params *protocolClientParams) (protocolClient,
type connectHandler struct {
protocolHandlerParams

accept map[string]struct{}
methods map[string]struct{}
accept map[string]struct{}
}

func (h *connectHandler) Methods() map[string]struct{} {
return h.methods
}

func (h *connectHandler) ContentTypes() map[string]struct{} {
Expand Down
9 changes: 8 additions & 1 deletion protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ var (
{time.Minute, 'M'},
{time.Hour, 'H'},
}
grpcTimeoutUnitLookup = make(map[byte]time.Duration)
grpcTimeoutUnitLookup = make(map[byte]time.Duration)
grpcAllowedMethods = map[string]struct{}{
http.MethodPost: {},
}
errTrailersWithoutGRPCStatus = fmt.Errorf("gRPC protocol error: no %s trailer", grpcHeaderStatus)

// defaultGrpcUserAgent follows
Expand Down Expand Up @@ -132,6 +135,10 @@ type grpcHandler struct {
accept map[string]struct{}
}

func (g *grpcHandler) Methods() map[string]struct{} {
return grpcAllowedMethods
}

func (g *grpcHandler) ContentTypes() map[string]struct{} {
return g.accept
}
Expand Down

0 comments on commit 2c378b8

Please sign in to comment.