From b6ff498362910b9e6d58acadc13935cc97451896 Mon Sep 17 00:00:00 2001 From: Akshay Shah Date: Wed, 17 May 2023 10:48:41 -0400 Subject: [PATCH] Expose HTTP method in unary handlers Fixes #502 to expose the HTTP method in handlers. --- client_ext_test.go | 2 +- error_not_modified_example_test.go | 16 +++++++--------- protocol.go | 7 +++++++ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/client_ext_test.go b/client_ext_test.go index d0c7ccc5..7ce97e5c 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -198,7 +198,7 @@ type notModifiedPingServer struct { func (s *notModifiedPingServer) Ping( _ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { - if len(req.Peer().Query) > 0 && req.Header().Get("If-None-Match") == s.etag { + if req.HTTPMethod() == http.MethodGet && req.Header().Get("If-None-Match") == s.etag { return nil, connect.NewNotModifiedError(http.Header{"Etag": []string{s.etag}}) } resp := connect.NewResponse(&pingv1.PingResponse{}) diff --git a/error_not_modified_example_test.go b/error_not_modified_example_test.go index c6f6e0d9..d8f1e9f7 100644 --- a/error_not_modified_example_test.go +++ b/error_not_modified_example_test.go @@ -43,16 +43,14 @@ func (*ExampleCachingPingServer) Ping( }) // Our hashing logic is simple: we use the number in the PingResponse. hash := fmt.Sprint(resp.Msg.Number) - // If the request was an HTTP GET (which always has URL query parameters), - // we'll need to check if the client already has the response cached. - if len(req.Peer().Query) > 0 { - if req.Header().Get("If-None-Match") == hash { - return nil, connect.NewNotModifiedError(http.Header{ - "Etag": []string{hash}, - }) - } - resp.Header().Set("Etag", hash) + // If the request was an HTTP GET, we'll need to check if the client already + // has the response cached. + if req.HTTPMethod() == http.MethodGet && req.Header().Get("If-None-Match") == hash { + return nil, connect.NewNotModifiedError(http.Header{ + "Etag": []string{hash}, + }) } + resp.Header().Set("Etag", hash) return resp, nil } diff --git a/protocol.go b/protocol.go index c698f706..b5954290 100644 --- a/protocol.go +++ b/protocol.go @@ -181,6 +181,13 @@ func (hc *errorTranslatingHandlerConnCloser) Close(err error) error { return hc.fromWire(closeErr) } +func (hc *errorTranslatingHandlerConnCloser) getHTTPMethod() string { + if methoder, ok := hc.handlerConnCloser.(interface{ getHTTPMethod() string }); ok { + return methoder.getHTTPMethod() + } + return http.MethodPost +} + // errorTranslatingClientConn wraps a StreamingClientConn to make sure that we always // return coded errors from clients. //