Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for gRPC-Web text encoding #150

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ const (
// streaming endpoints. However, bidirectional streams are only supported
// when combined with HTTP/2.
ProtocolGRPCWeb
// ProtocolGRPCWebText is a variant of ProtocolGRPCWeb that uses base64
// text encoding for the request and response bodies.
//
// This protocol is not supported on the server side.
Copy link
Member

@jhump jhump Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we even need/want an exported constant for it?

It would likely be trivial to implement as a separate, simpler middleware handler. Then we have the option to potentially move it to connect-go in the future, to enable grpc-web-text there. The Transcoder could always delegate to that middleware first since it should basically be free/pass-through for non-grpc-web-text requests. Then we don't need a protocol constant for it at all since the rest of the logic in here doesn't need to know about it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would also be relatively easy to test that separate middleware by copying it over into the connect conformance referenceserver and then see if the JS grpc-web-client starts passing the server stream tests that it currently opts out of. We'd have to enable grpc-web-text in that client, but that's as simple as changing the mode to "grpcwebtext" and re-generating.

(Admittedly, that's not a way to unit test this repo's implementation of it. But the tests you've added here could still be used of course.)

ProtocolGRPCWebText
// ProtocolREST indicates the REST+JSON protocol. This protocol often
// requires non-trivial transformations between HTTP requests and responses
// and Protobuf request and response messages.
Expand Down Expand Up @@ -96,6 +101,8 @@ func (p Protocol) serverHandler(op *operation) serverProtocolHandler {
return grpcServerProtocol{}
case ProtocolGRPCWeb:
return grpcWebServerProtocol{}
case ProtocolGRPCWebText:
return nil // gRPC-Web Text is not supported on the server.
case ProtocolREST:
return restServerProtocol{}
default:
Expand Down
208 changes: 206 additions & 2 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package vanguard

import (
"bytes"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
Expand Down Expand Up @@ -151,7 +152,7 @@ func (g grpcServerProtocol) String() string {
return g.protocol().String()
}

// grpcClientProtocol implements the gRPC protocol for
// grpcWebClientProtocol implements the gRPC-Web protocol for
// processing RPCs received from the client.
type grpcWebClientProtocol struct{}

Expand Down Expand Up @@ -212,7 +213,7 @@ func (g grpcWebClientProtocol) String() string {
return g.protocol().String()
}

// grpcServerProtocol implements the gRPC-Web protocol for
// grpcWebServerProtocol implements the gRPC-Web protocol for
// sending RPCs to the server handler.
type grpcWebServerProtocol struct{}

Expand Down Expand Up @@ -278,6 +279,73 @@ func (g grpcWebServerProtocol) String() string {
return g.protocol().String()
}

// grpcWebTextClientProtocol implements the gRPC-Web protocol for
// processing RPCs received from the client.
type grpcWebTextClientProtocol struct{}

var _ clientProtocolHandler = grpcWebTextClientProtocol{}
var _ clientBodyPreparer = grpcWebTextClientProtocol{}
var _ envelopedProtocolHandler = grpcWebTextClientProtocol{}

func (g grpcWebTextClientProtocol) protocol() Protocol {
return ProtocolGRPCWebText
}

func (g grpcWebTextClientProtocol) acceptsStreamType(_ *operation, _ connect.StreamType) bool {
return true
}

func (g grpcWebTextClientProtocol) requestNeedsPrep(op *operation) bool {
// Hijack the request and response body to handle base64 encoding/decoding.
op.request.Body = struct {
io.Reader
io.Closer
}{
Reader: newGRPCWebTextReader(op.request.Body),
Closer: op.request.Body,
}
op.writer = newGRPCWebTextResponseWriter(op.writer)
return false
}

func (g grpcWebTextClientProtocol) prepareUnmarshalledRequest(_ *operation, _ []byte, _ proto.Message) error {
// requestNeedsPrep always returns false.
return errors.New("gRPC-Web text prepareUnmarshalledRequest not implemented")
}

func (g grpcWebTextClientProtocol) responseNeedsPrep(_ *operation) bool {
return false // Setup in requestNeedsPrep.
}

func (g grpcWebTextClientProtocol) prepareMarshalledResponse(_ *operation, _ []byte, _ proto.Message, _ http.Header) ([]byte, error) {
// responseNeedsPrep always returns false.
return nil, errors.New("gRPC-Web text prepareMarshalledResponse not implemented")
}

func (g grpcWebTextClientProtocol) extractProtocolRequestHeaders(_ *operation, headers http.Header) (requestMeta, error) {
return grpcExtractRequestMeta("application/grpc-web-text", "application/grpc-web-text+", headers)
}

func (g grpcWebTextClientProtocol) addProtocolResponseHeaders(meta responseMeta, headers http.Header) int {
return grpcAddResponseMeta("application/grpc-web-text+", meta, headers)
}

func (g grpcWebTextClientProtocol) encodeEnd(op *operation, end *responseEnd, writer io.Writer, wasInHeaders bool) http.Header {
return grpcWebClientProtocol{}.encodeEnd(op, end, writer, wasInHeaders)
}

func (g grpcWebTextClientProtocol) decodeEnvelope(bytes envelopeBytes) (envelope, error) {
return grpcServerProtocol{}.decodeEnvelope(bytes)
}

func (g grpcWebTextClientProtocol) encodeEnvelope(env envelope) envelopeBytes {
return grpcWebClientProtocol{}.encodeEnvelope(env)
}

func (g grpcWebTextClientProtocol) String() string {
return g.protocol().String()
}

func grpcExtractRequestMeta(contentTypeShort, contentTypePrefix string, headers http.Header) (requestMeta, error) {
var reqMeta requestMeta
if err := grpcExtractTimeoutFromHeaders(headers, &reqMeta); err != nil {
Expand Down Expand Up @@ -641,3 +709,139 @@ func grpcTimeoutUnitLookup(unit byte) time.Duration {
return 0
}
}

// grpcWebTextResponseWriter wraps an http.ResponseWriter and base64-encodes
// the response body for grpc-web-text.
type grpcWebTextResponseWriter struct {
http.ResponseWriter

encoder io.WriteCloser
}

// newGRPCWebTextResponseWriter creates a new grpcWebTextResponseWriter.
func newGRPCWebTextResponseWriter(w http.ResponseWriter) *grpcWebTextResponseWriter {
return &grpcWebTextResponseWriter{
ResponseWriter: w,
}
}

func (w *grpcWebTextResponseWriter) Write(p []byte) (int, error) {
if w.encoder == nil {
w.encoder = base64.NewEncoder(base64.StdEncoding, w.ResponseWriter)
}
return w.encoder.Write(p)
}

func (w *grpcWebTextResponseWriter) Flush() {
// Close the base64 encoder to flush any remaining data. This may be
// called multiple times as needed, padding is output on Close.
// See https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-WEB.md
if w.encoder != nil {
_ = w.encoder.Close()
w.encoder = nil
}
// Some clients may expect a newline after each message. This does not
// affect the base64 encoding.
_, _ = w.ResponseWriter.Write([]byte{'\n'})
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

// Unwrap returns the underlying http.ResponseWriter.
func (w grpcWebTextResponseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

// grpcWebTextReader wraps an io.Reader and base64-decodes the response body
// for grpc-web-text.
type grpcWebTextReader struct {
delegate io.Reader
start, end int
inputBuffer [512]byte
outputBuffer [384]byte
output []byte
}

// newGRPCWebTextReader creates a new grpcWebTextReader.
func newGRPCWebTextReader(r io.Reader) *grpcWebTextReader {
return &grpcWebTextReader{
delegate: r,
}
}

// Read reads base64-encoded data from the underlying reader and decodes it.
// The reader handles padding characters within the stream. It will ensure that
// padding is always at the end of a chunk of data when processing the chunk.
func (r *grpcWebTextReader) Read(dst []byte) (int, error) {
if len(dst) == 0 {
return 0, nil
}
if len(r.output) > 0 {
size := copy(dst, r.output)
r.output = r.output[size:]
return size, nil
}
// Read from the stream in 4-byte tokens.
for r.end-r.start < 4 {
size, err := r.readWithoutNewlines(r.inputBuffer[r.end:])
if size == 0 {
if err == nil {
err = io.ErrNoProgress
} else if errors.Is(err, io.EOF) && r.end > r.start {
// Non 4-byte chunk at the end of the stream.
err = io.ErrUnexpectedEOF
}
return 0, err
}
r.end += size
}
// Decode the next chunk of data.
length := ((r.end - r.start) / 4) * 4
dstLength := base64.StdEncoding.EncodedLen(len(r.outputBuffer))
chunkLength := min(dstLength, length)
input := r.inputBuffer[r.start : r.start+chunkLength]
// If we have padding, we split the stream at the padding and decode the
// chunk up to the padding.
if index := bytes.IndexRune(input, base64.StdPadding); index != -1 {
chunkLength = ((index + 4) / 4) * 4
input = input[:chunkLength]
}
output := r.outputBuffer[:]
size, err := base64.StdEncoding.Decode(output, input)
if err != nil {
return 0, err
}
r.start += chunkLength
if r.start == r.end {
r.start, r.end = 0, 0
}
r.output = output[:size]
size = copy(dst, r.output)
r.output = r.output[size:]
return size, err
}

// readWithoutNewlines reads from the underlying reader, skipping over any
// newline characters in the buffer. This follows the behavior of
// base64.NewDecoder.
func (r *grpcWebTextReader) readWithoutNewlines(dst []byte) (n int, err error) {
n, err = r.delegate.Read(dst)
for n > 0 {
offset := 0
for i, b := range dst[:n] {
if b != '\r' && b != '\n' {
if i != offset {
dst[offset] = b
}
offset++
}
}
if offset > 0 {
return offset, err
}
// Previous buffer entirely whitespace, read again.
n, err = r.delegate.Read(dst)
}
return n, err
}
75 changes: 75 additions & 0 deletions protocol_grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ package vanguard
import (
"errors"
"fmt"
"io"
"math"
"net/http/httptest"
"strings"
"testing"
"testing/quick"
"time"
Expand Down Expand Up @@ -167,3 +169,76 @@ func compareErrors(t *testing.T, got, want *connect.Error) {
}
}
}

func TestGRPCWebTextResponseWriter(t *testing.T) {
t.Parallel()

rec := httptest.NewRecorder()
writer := newGRPCWebTextResponseWriter(rec)
writer.Header().Set("Content-Type", "application/grpc-web-text+proto")
_, err := writer.Write([]byte("Hello, 世界"))
require.NoError(t, err)
writer.Flush()
_, err = writer.Write([]byte("Hello, 世界"))
require.NoError(t, err)
writer.Flush()

assert.Equal(t, "SGVsbG8sIOS4lueVjA==\nSGVsbG8sIOS4lueVjA==\n", rec.Body.String())
assert.Equal(t, "application/grpc-web-text+proto", rec.Header().Get("Content-Type"))

out, err := io.ReadAll(newGRPCWebTextReader(strings.NewReader(rec.Body.String())))
require.NoError(t, err)
assert.Equal(t, "Hello, 世界Hello, 世界", string(out))
}

func TestGRPCWebTextReader(t *testing.T) {
t.Parallel()
for _, test := range []struct {
name, input, output string
}{
{"hello", "SGVsbG8sIOS4lueVjA==", "Hello, 世界"},
{"hello_duplicate", "SGVsbG8sIOS4lueVjA==SGVsbG8sIOS4lueVjA==", "Hello, 世界Hello, 世界"},
{"some_data", "c29tZSBkYXRhIHdpdGggACBhbmQg77u/", "some data with \x00 and \ufeff"},
{"ab", "QQ==Qg==", "AB"},
{"a_b", "Q\nQ=\r=Qg=\r=", "AB"},
{
"foobar",
"Zg==" + "Zm8=" + "Zm9v" + "Zm9vYg==" + "Zm9vYmE=" + "Zm9vYmFy",
"f" + "fo" + "foo" + "foob" + "fooba" + "foobar",
},
{
"RFC3548",
"FPucA9l+" + "FPucA9k=" + "FPucAw==",
"\x14\xfb\x9c\x03\xd9\x7e" + "\x14\xfb\x9c\x03\xd9" + "\x14\xfb\x9c\x03",
},
{
"wikipedia",
"c3VyZS4=" + "c3VyZQ==" + "c3Vy" + "c3U=" + "bGVhc3VyZS4=" + "ZWFzdXJlLg==" + "YXN1cmUu" + "c3VyZS4=",
"sure." + "sure" + "sur" + "su" + "leasure." + "easure." + "asure." + "sure.",
},
} {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
decoder := newGRPCWebTextReader(strings.NewReader(test.input))
b, err := io.ReadAll(decoder)
require.NoError(t, err)
output := string(b)
assert.Equal(t, test.output, output)
})
}
t.Run("partial_reads", func(t *testing.T) {
var buf [5]byte
decoder := newGRPCWebTextReader(strings.NewReader("SGVsbG8sIOS4lueVjA==SGVsbG8sIOS4lueVjA=="))
total := 0
for {
n, err := decoder.Read(buf[:])
if errors.Is(err, io.EOF) {
break
}
require.NoError(t, err)
total += n
}
assert.Equal(t, len("Hello, 世界Hello, 世界"), total)
})
}
3 changes: 3 additions & 0 deletions protocol_rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,9 @@ func restHTTPBodyRequest(op *operation) bool {
}

func restHTTPBodyResponse(op *operation) bool {
if op.restTarget == nil {
return false
}
return restIsHTTPBody(op.methodConf.descriptor.Output(), op.restTarget.responseBodyFields)
}

Expand Down
6 changes: 4 additions & 2 deletions transcoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ func classifyRequest(req *http.Request) (clientProtocolHandler, url.Values) {
return grpcClientProtocol{}, nil
case contentType == "application/grpc-web" || strings.HasPrefix(contentType, "application/grpc-web+"):
return grpcWebClientProtocol{}, nil
case contentType == "application/grpc-web-text" || strings.HasPrefix(contentType, "application/grpc-web-text+"):
return grpcWebTextClientProtocol{}, nil
case strings.HasPrefix(contentType, "application/"):
connectVersion := req.Header["Connect-Protocol-Version"]
if len(connectVersion) == 1 && connectVersion[0] == "1" {
Expand Down Expand Up @@ -2177,10 +2179,10 @@ func asFlusher(respWriter http.ResponseWriter) http.Flusher {
// we can't use that since it isn't available prior to Go 1.21.
for {
switch typedWriter := respWriter.(type) {
case http.Flusher:
return typedWriter
case errorFlusher:
return flusherNoError{f: typedWriter}
case http.Flusher:
return typedWriter
case interface{ Unwrap() http.ResponseWriter }:
respWriter = typedWriter.Unwrap()
default:
Expand Down
Loading