Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extend the HTTP/3 API for WebTransport support #3362

Merged
merged 8 commits into from
Apr 16, 2022
46 changes: 40 additions & 6 deletions http3/body.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
package http3

import (
"context"
"fmt"
"io"

"github.com/lucas-clemente/quic-go"
)

type StreamCreator interface {
OpenStream() (quic.Stream, error)
OpenStreamSync(context.Context) (quic.Stream, error)
OpenUniStream() (quic.SendStream, error)
OpenUniStreamSync(context.Context) (quic.SendStream, error)
}

var _ StreamCreator = quic.Connection(nil)

// A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body.
// It is used by WebTransport to create WebTransport streams after a session has been established.
type Hijacker interface {
StreamCreator() StreamCreator
}

// The body of a http.Request or http.Response.
type body struct {
str quic.Stream
Expand All @@ -24,21 +40,35 @@ type body struct {

var _ io.ReadCloser = &body{}

type hijackableBody struct {
body
conn quic.Connection // only needed to implement Hijacker
}

var _ Hijacker = &hijackableBody{}

func newRequestBody(str quic.Stream, onFrameError func()) *body {
return &body{
str: str,
onFrameError: onFrameError,
}
}

func newResponseBody(str quic.Stream, done chan<- struct{}, onFrameError func()) *body {
return &body{
str: str,
onFrameError: onFrameError,
reqDone: done,
func newResponseBody(str quic.Stream, conn quic.Connection, done chan<- struct{}, onFrameError func()) *hijackableBody {
return &hijackableBody{
body: body{
str: str,
onFrameError: onFrameError,
reqDone: done,
},
conn: conn,
}
}

func (r *hijackableBody) StreamCreator() StreamCreator {
return r.conn
}

func (r *body) Read(b []byte) (int, error) {
n, err := r.readImpl(b)
if err != nil {
Expand All @@ -51,7 +81,7 @@ func (r *body) readImpl(b []byte) (int, error) {
if r.bytesRemainingInFrame == 0 {
parseLoop:
for {
frame, err := parseNextFrame(r.str)
frame, err := parseNextFrame(r.str, nil)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -90,6 +120,10 @@ func (r *body) requestDone() {
r.reqDoneClosed = true
}

func (r *body) StreamID() quic.StreamID {
return r.str.StreamID()
}

func (r *body) Close() error {
r.requestDone()
// If the EOF was read, CancelRead() is a no-op.
Expand Down
4 changes: 2 additions & 2 deletions http3/body_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (t bodyType) String() string {

var _ = Describe("Body", func() {
var (
rb *body
rb io.ReadCloser
str *mockquic.MockStream
buf *bytes.Buffer
reqDone chan struct{}
Expand Down Expand Up @@ -68,7 +68,7 @@ var _ = Describe("Body", func() {
rb = newRequestBody(str, errorCb)
case bodyTypeResponse:
reqDone = make(chan struct{})
rb = newResponseBody(str, reqDone, errorCb)
rb = newResponseBody(str, nil, reqDone, errorCb)
}
})

Expand Down
41 changes: 36 additions & 5 deletions http3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type roundTripperOpts struct {
DisableCompression bool
EnableDatagram bool
MaxHeaderBytes int64
AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
}

// client is a HTTP3 client doing requests
Expand Down Expand Up @@ -74,7 +76,9 @@ func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, con
if len(conf.Versions) != 1 {
return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
}
conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
if conf.MaxIncomingStreams == 0 {
conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
}
conf.EnableDatagrams = opts.EnableDatagram
logger := utils.DefaultLogger.WithPrefix("h3 client")

Expand Down Expand Up @@ -117,6 +121,9 @@ func (c *client) dial(ctx context.Context) error {
}
}()

if c.opts.StreamHijacker != nil {
go c.handleBidirectionalStreams()
}
go c.handleUnidirectionalStreams()
return nil
}
Expand All @@ -130,11 +137,35 @@ func (c *client) setupConn() error {
buf := &bytes.Buffer{}
quicvarint.Write(buf, streamTypeControlStream)
// send the SETTINGS frame
(&settingsFrame{Datagram: c.opts.EnableDatagram}).Write(buf)
(&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Write(buf)
_, err = str.Write(buf.Bytes())
return err
}

func (c *client) handleBidirectionalStreams() {
for {
str, err := c.conn.AcceptStream(context.Background())
if err != nil {
c.logger.Debugf("accepting bidirectional stream failed: %s", err)
return
}
go func(str quic.Stream) {
for {
_, err := parseNextFrame(str, func(ft FrameType) (processed bool, err error) {
return c.opts.StreamHijacker(ft, c.conn, str)
})
if err == errHijacked {
return
}
if err != nil {
c.logger.Debugf("error handling stream: %s", err)
}
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to return here to avoid an infinite loop?

}(str)
}
}

func (c *client) handleUnidirectionalStreams() {
for {
str, err := c.conn.AcceptUniStream(context.Background())
Expand Down Expand Up @@ -164,7 +195,7 @@ func (c *client) handleUnidirectionalStreams() {
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
return
}
f, err := parseNextFrame(str)
f, err := parseNextFrame(str, nil)
if err != nil {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
return
Expand Down Expand Up @@ -275,7 +306,7 @@ func (c *client) doRequest(
return nil, newStreamError(errorInternalError, err)
}

frame, err := parseNextFrame(str)
frame, err := parseNextFrame(str, nil)
if err != nil {
return nil, newStreamError(errorFrameError, err)
}
Expand Down Expand Up @@ -316,7 +347,7 @@ func (c *client) doRequest(
res.Header.Add(hf.Name, hf.Value)
}
}
respBody := newResponseBody(str, reqDone, func() {
respBody := newResponseBody(str, c.conn, reqDone, func() {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "")
})

Expand Down
10 changes: 5 additions & 5 deletions http3/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ var _ = Describe("Client", func() {
})
}

It("resets streams other than the control stream and the QPACK streams", func() {
It("resets streams Other than the control stream and the QPACK streams", func() {
buf := &bytes.Buffer{}
quicvarint.Write(buf, 1337)
str := mockquic.NewMockStream(mockCtrl)
Expand Down Expand Up @@ -410,7 +410,7 @@ var _ = Describe("Client", func() {
fields := make(map[string]string)
decoder := qpack.NewDecoder(nil)

frame, err := parseNextFrame(str)
frame, err := parseNextFrame(str, nil)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{}))
headersFrame := frame.(*headersFrame)
Expand All @@ -429,7 +429,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw.WriteHeader(status)
rw.Flush()
return buf.Bytes()
Expand Down Expand Up @@ -717,7 +717,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(rw)
gz.Write([]byte("gzipped response"))
Expand All @@ -743,7 +743,7 @@ var _ = Describe("Client", func() {
buf := &bytes.Buffer{}
rstr := mockquic.NewMockStream(mockCtrl)
rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
rw := newResponseWriter(rstr, utils.DefaultLogger)
rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))
rw.Flush()
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
Expand Down
49 changes: 32 additions & 17 deletions http3/frames.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http3

import (
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
Expand All @@ -10,15 +11,34 @@ import (
"github.com/lucas-clemente/quic-go/quicvarint"
)

// FrameType is the frame type of a HTTP/3 frame
type FrameType uint64

type unknownFrameHandlerFunc func(FrameType) (processed bool, err error)

type frame interface{}

func parseNextFrame(r io.Reader) (frame, error) {
var errHijacked = errors.New("hijacked")

func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (frame, error) {
qr := quicvarint.NewReader(r)
for {
t, err := quicvarint.Read(qr)
if err != nil {
return nil, err
}
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
if t > 0xd && unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(FrameType(t))
if err != nil {
return nil, err
}
// If the unknownFrameHandler didn't process the frame, it is our responsibility to skip it.
if hijacked {
return nil, errHijacked
}
continue
}
l, err := quicvarint.Read(qr)
if err != nil {
return nil, err
Expand All @@ -32,18 +52,13 @@ func parseNextFrame(r io.Reader) (frame, error) {
case 0x4:
return parseSettingsFrame(r, l)
case 0x3: // CANCEL_PUSH
fallthrough
case 0x5: // PUSH_PROMISE
fallthrough
case 0x7: // GOAWAY
fallthrough
case 0xd: // MAX_PUSH_ID
fallthrough
default:
// skip over unknown frames
if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil {
return nil, err
}
}
// skip over unknown frames
if _, err := io.CopyN(ioutil.Discard, qr, int64(l)); err != nil {
return nil, err
}
}
}
Expand All @@ -70,7 +85,7 @@ const settingDatagram = 0xffd277

type settingsFrame struct {
Datagram bool
other map[uint64]uint64 // all settings that we don't explicitly recognize
Other map[uint64]uint64 // all settings that we don't explicitly recognize
}

func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
Expand Down Expand Up @@ -108,13 +123,13 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
}
frame.Datagram = val == 1
default:
if _, ok := frame.other[id]; ok {
if _, ok := frame.Other[id]; ok {
return nil, fmt.Errorf("duplicate setting: %d", id)
}
if frame.other == nil {
frame.other = make(map[uint64]uint64)
if frame.Other == nil {
frame.Other = make(map[uint64]uint64)
}
frame.other[id] = val
frame.Other[id] = val
}
}
return frame, nil
Expand All @@ -123,7 +138,7 @@ func parseSettingsFrame(r io.Reader, l uint64) (*settingsFrame, error) {
func (f *settingsFrame) Write(b *bytes.Buffer) {
quicvarint.Write(b, 0x4)
var l protocol.ByteCount
for id, val := range f.other {
for id, val := range f.Other {
l += quicvarint.Len(id) + quicvarint.Len(val)
}
if f.Datagram {
Expand All @@ -134,7 +149,7 @@ func (f *settingsFrame) Write(b *bytes.Buffer) {
quicvarint.Write(b, settingDatagram)
quicvarint.Write(b, 1)
}
for id, val := range f.other {
for id, val := range f.Other {
quicvarint.Write(b, id)
quicvarint.Write(b, val)
}
Expand Down
Loading