Skip to content

Commit

Permalink
Buffer in the server pipeline configuration (grpc#1564)
Browse files Browse the repository at this point in the history
Motivation:

When not using TLS, the server pipeline configurator inspects the first
bytes on a connection to determine whether HTTP1 or HTTP2 is being used
and closes the connection if it is determined that neither are. It does
this by only parsing the first packet, which may not have enough bytes
to make a correct determination.

Modifications:

- Buffer bytes in the configurator.
- Parse the buffered bytes and only close if enough bytes have been
  received.

Result:

Better version determination.
  • Loading branch information
glbrntt authored and pinlin168 committed Aug 24, 2023
1 parent c9f0d1a commit 9b54ca3
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 41 deletions.
140 changes: 110 additions & 30 deletions Sources/GRPC/GRPCServerPipelineConfigurator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
/// The server configuration.
private let configuration: Server.Configuration

/// Reads which we're holding on to before the pipeline is configured.
private var bufferedReads = CircularBuffer<NIOAny>()
/// A buffer containing the buffered bytes.
private var buffer: ByteBuffer?

/// The current state.
private var state: State
Expand Down Expand Up @@ -212,13 +212,17 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
buffer: ByteBuffer,
context: ChannelHandlerContext
) {
if HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer) {
switch HTTPVersionParser.determineHTTPVersion(buffer) {
case .http2:
self.configureHTTP2(context: context)
} else if HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer) {
case .http1:
self.configureHTTP1(context: context)
} else {
case .unknown:
// Neither H2 nor H1 or the length limit has been exceeded.
self.configuration.logger.error("Unable to determine http version, closing")
context.close(mode: .all, promise: nil)
case .notEnoughBytes:
() // Try again with more bytes.
}
}

Expand Down Expand Up @@ -268,13 +272,9 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan

/// Try to parse the buffered data to determine whether or not HTTP/2 or HTTP/1 should be used.
private func tryParsingBufferedData(context: ChannelHandlerContext) {
guard let first = self.bufferedReads.first else {
// No data buffered yet. We'll try when we read.
return
if let buffer = self.buffer {
self.determineHTTPVersionAndConfigurePipeline(buffer: buffer, context: context)
}

let buffer = self.unwrapInboundIn(first)
self.determineHTTPVersionAndConfigurePipeline(buffer: buffer, context: context)
}

// MARK: - Channel Handler
Expand Down Expand Up @@ -312,7 +312,8 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
}

internal func channelRead(context: ChannelHandlerContext, data: NIOAny) {
self.bufferedReads.append(data)
var buffer = self.unwrapInboundIn(data)
self.buffer.setOrWriteBuffer(&buffer)

switch self.state {
case .notConfigured(alpn: .notExpected),
Expand All @@ -335,8 +336,9 @@ final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChan
removalToken: ChannelHandlerContext.RemovalToken
) {
// Forward any buffered reads.
while let read = self.bufferedReads.popFirst() {
context.fireChannelRead(read)
if let buffer = self.buffer {
self.buffer = nil
context.fireChannelRead(self.wrapInboundOut(buffer))
}
context.leavePipeline(removalToken: removalToken)
}
Expand Down Expand Up @@ -375,16 +377,64 @@ struct HTTPVersionParser {

/// Determines whether the bytes in the `ByteBuffer` are prefixed with the HTTP/2 client
/// connection preface.
static func prefixedWithHTTP2ConnectionPreface(_ buffer: ByteBuffer) -> Bool {
static func prefixedWithHTTP2ConnectionPreface(_ buffer: ByteBuffer) -> SubParseResult {
let view = buffer.readableBytesView

guard view.count >= HTTPVersionParser.http2ClientMagic.count else {
// Not enough bytes.
return false
return .notEnoughBytes
}

let slice = view[view.startIndex ..< view.startIndex.advanced(by: self.http2ClientMagic.count)]
return slice.elementsEqual(HTTPVersionParser.http2ClientMagic)
return slice.elementsEqual(HTTPVersionParser.http2ClientMagic) ? .accepted : .rejected
}

enum ParseResult: Hashable {
case http1
case http2
case unknown
case notEnoughBytes
}

enum SubParseResult: Hashable {
case accepted
case rejected
case notEnoughBytes
}

private static let maxLengthToCheck = 1024

static func determineHTTPVersion(_ buffer: ByteBuffer) -> ParseResult {
switch Self.prefixedWithHTTP2ConnectionPreface(buffer) {
case .accepted:
return .http2

case .notEnoughBytes:
switch Self.prefixedWithHTTP1RequestLine(buffer) {
case .accepted:
// Not enough bytes to check H2, but enough to confirm H1.
return .http1
case .notEnoughBytes:
// Not enough bytes to check H2 or H1.
return .notEnoughBytes
case .rejected:
// Not enough bytes to check H2 and definitely not H1.
return .notEnoughBytes
}

case .rejected:
switch Self.prefixedWithHTTP1RequestLine(buffer) {
case .accepted:
// Not H2, but H1 is confirmed.
return .http1
case .notEnoughBytes:
// Not H2, but not enough bytes to reject H1 yet.
return .notEnoughBytes
case .rejected:
// Not H2 or H1.
return .unknown
}
}
}

private static let http1_1 = [
Expand All @@ -399,29 +449,59 @@ struct HTTPVersionParser {
]

/// Determines whether the bytes in the `ByteBuffer` are prefixed with an HTTP/1.1 request line.
static func prefixedWithHTTP1RequestLine(_ buffer: ByteBuffer) -> Bool {
static func prefixedWithHTTP1RequestLine(_ buffer: ByteBuffer) -> SubParseResult {
var readableBytesView = buffer.readableBytesView

// We don't need to validate the request line, only determine whether we think it's an HTTP1
// request line. Another handler will parse it properly.

// From RFC 2616 § 5.1:
// Request-Line = Method SP Request-URI SP HTTP-Version CRLF

// Read off the Method and Request-URI (and spaces).
guard readableBytesView.trimPrefix(to: UInt8(ascii: " ")) != nil,
readableBytesView.trimPrefix(to: UInt8(ascii: " ")) != nil else {
return false
// Get through the first space.
guard readableBytesView.dropPrefix(through: UInt8(ascii: " ")) != nil else {
let tooLong = buffer.readableBytes > Self.maxLengthToCheck
return tooLong ? .rejected : .notEnoughBytes
}

// Get through the second space.
guard readableBytesView.dropPrefix(through: UInt8(ascii: " ")) != nil else {
let tooLong = buffer.readableBytes > Self.maxLengthToCheck
return tooLong ? .rejected : .notEnoughBytes
}

// +2 for \r\n
guard readableBytesView.count >= (Self.http1_1.count + 2) else {
return .notEnoughBytes
}

// Read off the HTTP-Version and CR.
guard let versionView = readableBytesView.trimPrefix(to: UInt8(ascii: "\r")) else {
return false
guard let version = readableBytesView.dropPrefix(through: UInt8(ascii: "\r")),
readableBytesView.first == UInt8(ascii: "\n") else {
// If we didn't drop the prefix OR we did and the next byte wasn't '\n', then we had enough
// bytes but the '\r\n' wasn't present: reject this as being HTTP1.
return .rejected
}

return version.elementsEqual(Self.http1_1) ? .accepted : .rejected
}
}

extension Collection where Self == Self.SubSequence, Self.Element: Equatable {
/// Drops the prefix off the collection up to and including the first `separator`
/// only if that separator appears in the collection.
///
/// Returns the prefix up to but not including the separator if it was found, nil otherwise.
mutating func dropPrefix(through separator: Element) -> SubSequence? {
if self.isEmpty {
return nil
}

// Check that the LF followed the CR.
guard readableBytesView.first == UInt8(ascii: "\n") else {
return false
guard let separatorIndex = self.firstIndex(of: separator) else {
return nil
}

// Now check the HTTP version.
return versionView.elementsEqual(HTTPVersionParser.http1_1)
let prefix = self[..<separatorIndex]
self = self[self.index(after: separatorIndex)...]
return prefix
}
}
51 changes: 51 additions & 0 deletions Tests/GRPCTests/GRPCServerPipelineConfiguratorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,24 @@ class GRPCServerPipelineConfiguratorTests: GRPCTestCase {
self.assertHTTP2Handler(isPresent: true)
}

func testHTTP2SetupViaBytesDripFed() {
self.setUp(tls: false)
var bytes = ByteBuffer(staticString: "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
var head = bytes.readSlice(length: bytes.readableBytes - 1)!
let tail = bytes.readSlice(length: 1)!

while let slice = head.readSlice(length: 1) {
assertThat(try self.channel.writeInbound(slice), .doesNotThrow())
self.assertConfigurator(isPresent: true)
self.assertHTTP2Handler(isPresent: false)
}

// Final byte.
assertThat(try self.channel.writeInbound(tail), .doesNotThrow())
self.assertConfigurator(isPresent: false)
self.assertHTTP2Handler(isPresent: true)
}

func testHTTP1Dot1SetupViaBytes() {
self.setUp(tls: false)
let bytes = ByteBuffer(staticString: "GET http://www.foo.bar HTTP/1.1\r\n")
Expand All @@ -143,6 +161,39 @@ class GRPCServerPipelineConfiguratorTests: GRPCTestCase {
self.assertGRPCWebToHTTP2Handler(isPresent: true)
}

func testHTTP1Dot1SetupViaBytesDripFed() {
self.setUp(tls: false)
var bytes = ByteBuffer(staticString: "GET http://www.foo.bar HTTP/1.1\r\n")
var head = bytes.readSlice(length: bytes.readableBytes - 1)!
let tail = bytes.readSlice(length: 1)!

while let slice = head.readSlice(length: 1) {
assertThat(try self.channel.writeInbound(slice), .doesNotThrow())
self.assertConfigurator(isPresent: true)
self.assertGRPCWebToHTTP2Handler(isPresent: false)
}

// Final byte.
assertThat(try self.channel.writeInbound(tail), .doesNotThrow())
self.assertConfigurator(isPresent: false)
self.assertGRPCWebToHTTP2Handler(isPresent: true)
}

func testUnexpectedInputClosesEventuallyWhenDripFed() {
self.setUp(tls: false)
var bytes = ByteBuffer(repeating: UInt8(ascii: "a"), count: 2048)

while let slice = bytes.readSlice(length: 1) {
assertThat(try self.channel.writeInbound(slice), .doesNotThrow())
self.assertConfigurator(isPresent: true)
self.assertHTTP2Handler(isPresent: false)
self.assertGRPCWebToHTTP2Handler(isPresent: false)
}

self.channel.embeddedEventLoop.run()
assertThat(try self.channel.closeFuture.wait(), .doesNotThrow())
}

func testReadsAreUnbufferedAfterConfiguration() throws {
self.setUp(tls: false)

Expand Down
28 changes: 17 additions & 11 deletions Tests/GRPCTests/HTTPVersionParserTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,58 +22,64 @@ class HTTPVersionParserTests: GRPCTestCase {

func testHTTP2ExactlyTheRightBytes() {
let buffer = ByteBuffer(string: self.preface)
XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .accepted)
}

func testHTTP2TheRightBytesAndMore() {
var buffer = ByteBuffer(string: self.preface)
buffer.writeRepeatingByte(42, count: 1024)
XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .accepted)
}

func testHTTP2NoBytes() {
let empty = ByteBuffer()
XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(empty))
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(empty), .notEnoughBytes)
}

func testHTTP2NotEnoughBytes() {
var buffer = ByteBuffer(string: self.preface)
buffer.moveWriterIndex(to: buffer.writerIndex - 1)
XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .notEnoughBytes)
}

func testHTTP2EnoughOfTheWrongBytes() {
let buffer = ByteBuffer(string: String(self.preface.reversed()))
XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer))
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .rejected)
}

func testHTTP1RequestLine() {
let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1\r\n")
XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .accepted)
}

func testHTTP1RequestLineAndMore() {
let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1\r\nMore")
XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .accepted)
}

func testHTTP1RequestLineWithoutCRLF() {
let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1")
XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .notEnoughBytes)
}

func testHTTP1NoBytes() {
let empty = ByteBuffer()
XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(empty))
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(empty), .notEnoughBytes)
}

func testHTTP1IncompleteRequestLine() {
let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html")
XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .notEnoughBytes)
}

func testHTTP1MalformedVersion() {
let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html ptth/1.1\r\n")
XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer))
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .rejected)
}

func testTooManyIncorrectBytes() {
let buffer = ByteBuffer(repeating: UInt8(ascii: "\r"), count: 2048)
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer), .rejected)
XCTAssertEqual(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer), .rejected)
}
}

0 comments on commit 9b54ca3

Please sign in to comment.