Skip to content

Commit

Permalink
Fix crashes due to mismatching responses sent to the channel when eve…
Browse files Browse the repository at this point in the history
…nt observer factories fail. (#395)

* Fix crashes due to mismatching responses sent to the channel when event observer factories fail.

* Tweak `newFailedFuture`.

* PR fixes.

* Minor comment improvements.

* PR fixes.
  • Loading branch information
MrMage authored Mar 8, 2019
1 parent d4a6366 commit 772b78e
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 8 deletions.
6 changes: 6 additions & 0 deletions Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>:
fatalError("needs to be overridden")
}

/// Needs to be implemented by this class so that subclasses can override it.
///
/// Otherwise, the subclass's implementation will simply never be called (probably because the protocol's default
/// implementation in an extension is being used instead).
public func handlerAdded(ctx: ChannelHandlerContext) { }

/// Called when the client has half-closed the stream, indicating that they won't send any further data.
///
/// Overridden by subclasses if the "end-of-stream" event is relevant.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import NIOHTTP1
/// Handles bidirectional streaming calls. Forwards incoming messages and end-of-stream events to the observer block.
///
/// - The observer block is implemented by the framework user and calls `context.sendResponse` as needed.
/// - To close the call and send the status, fulfill `context.statusPromise`.
/// If the framework user wants to return a call error (e.g. in case of authentication failure),
/// they can fail the observer block future.
/// - To close the call and send the status, complete `context.statusPromise`.
public class BidirectionalStreamingCallHandler<RequestMessage: Message, ResponseMessage: Message>: BaseCallHandler<RequestMessage, ResponseMessage> {
public typealias EventObserver = (StreamEvent<RequestMessage>) -> Void
private var eventObserver: EventLoopFuture<EventObserver>?
Expand All @@ -21,14 +23,23 @@ public class BidirectionalStreamingCallHandler<RequestMessage: Message, Response
self.context = context
let eventObserver = eventObserverFactory(context)
self.eventObserver = eventObserver
// Terminate the call if no observer is provided.
eventObserver.cascadeFailure(promise: context.statusPromise)
context.statusPromise.futureResult.whenComplete {
// When done, reset references to avoid retain cycles.
self.eventObserver = nil
self.context = nil
}
}

public override func handlerAdded(ctx: ChannelHandlerContext) {
guard let eventObserver = eventObserver,
let context = context else { return }
// Terminate the call if the future providing an observer fails.
// This is being done _after_ we have been added as a handler to ensure that the `GRPCServerCodec` required to
// translate our outgoing `GRPCServerResponsePart<ResponseMessage>` message is already present on the channel.
// Otherwise, our `OutboundOut` type would not match the `OutboundIn` type of the next handler on the channel.
eventObserver.cascadeFailure(promise: context.statusPromise)
}


public override func processMessage(_ message: RequestMessage) {
eventObserver?.whenSuccess { observer in
Expand Down
15 changes: 13 additions & 2 deletions Sources/SwiftGRPCNIO/CallHandlers/ClientStreamingCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import NIOHTTP1
/// Handles client-streaming calls. Forwards incoming messages and end-of-stream events to the observer block.
///
/// - The observer block is implemented by the framework user and fulfills `context.responsePromise` when done.
/// If the framework user wants to return a call error (e.g. in case of authentication failure),
/// they can fail the observer block future.
/// - To close the call and send the response, complete `context.responsePromise`.
public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage: Message>: BaseCallHandler<RequestMessage, ResponseMessage> {
public typealias EventObserver = (StreamEvent<RequestMessage>) -> Void
private var eventObserver: EventLoopFuture<EventObserver>?
Expand All @@ -20,15 +23,23 @@ public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage
self.context = context
let eventObserver = eventObserverFactory(context)
self.eventObserver = eventObserver
// Terminate the call if no observer is provided.
eventObserver.cascadeFailure(promise: context.responsePromise)
context.responsePromise.futureResult.whenComplete {
// When done, reset references to avoid retain cycles.
self.eventObserver = nil
self.context = nil
}
}

public override func handlerAdded(ctx: ChannelHandlerContext) {
guard let eventObserver = eventObserver,
let context = context else { return }
// Terminate the call if the future providing an observer fails.
// This is being done _after_ we have been added as a handler to ensure that the `GRPCServerCodec` required to
// translate our outgoing `GRPCServerResponsePart<ResponseMessage>` message is already present on the channel.
// Otherwise, our `OutboundOut` type would not match the `OutboundIn` type of the next handler on the channel.
eventObserver.cascadeFailure(promise: context.responsePromise)
}

public override func processMessage(_ message: RequestMessage) {
eventObserver?.whenSuccess { observer in
observer(.message(message))
Expand Down
2 changes: 1 addition & 1 deletion Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import NIOHTTP1
///
/// - The observer block is implemented by the framework user and returns a future containing the call result.
/// - To return a response to the client, the framework user should complete that future
/// (similar to e.g. serving regular HTTP requests in frameworks such as Vapor).
/// (similar to e.g. serving regular HTTP requests in frameworks such as Vapor).
public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>: BaseCallHandler<RequestMessage, ResponseMessage> {
public typealias EventObserver = (RequestMessage) -> EventLoopFuture<ResponseMessage>
private var eventObserver: EventObserver?
Expand Down
5 changes: 4 additions & 1 deletion Tests/LinuxMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,14 @@ XCTMain([
testCase(MetadataTests.allTests),
testCase(ServerCancellingTests.allTests),
testCase(ServerTestExample.allTests),
testCase(ServerThrowingTests.allTests),
testCase(SwiftGRPCTests.ServerThrowingTests.allTests),
testCase(ServerTimeoutTests.allTests),

// SwiftGRPCNIO
testCase(NIOServerTests.allTests),
testCase(SwiftGRPCNIOTests.ServerThrowingTests.allTests),
testCase(SwiftGRPCNIOTests.ServerDelayedThrowingTests.allTests),
testCase(SwiftGRPCNIOTests.ClientThrowingWhenServerReturningErrorTests.allTests),
testCase(NIOClientCancellingTests.allTests),
testCase(NIOClientTimeoutTests.allTests),
testCase(NIOServerWebTests.allTests),
Expand Down
4 changes: 3 additions & 1 deletion Tests/SwiftGRPCNIOTests/NIOBasicEchoTestCase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ class NIOBasicEchoTestCase: XCTestCase {

var clientEventLoopGroup: EventLoopGroup!
var client: Echo_EchoService_NIOClient!

func makeEchoProvider() -> Echo_EchoProvider_NIO { return EchoProviderNIO() }

override func setUp() {
super.setUp()

self.serverEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
self.server = try! GRPCServer.start(
hostname: "localhost", port: 5050, eventLoopGroup: self.serverEventLoopGroup, serviceProviders: [EchoProviderNIO()])
hostname: "localhost", port: 5050, eventLoopGroup: self.serverEventLoopGroup, serviceProviders: [makeEchoProvider()])
.wait()

self.clientEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
Expand Down
2 changes: 2 additions & 0 deletions Tests/SwiftGRPCNIOTests/NIOServerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class NIOServerTests: NIOBasicEchoTestCase {
return [
("testUnary", testUnary),
("testUnaryLotsOfRequests", testUnaryLotsOfRequests),
("testUnaryWithLargeData", testUnaryWithLargeData),
("testUnaryEmptyRequest", testUnaryEmptyRequest),
("testClientStreaming", testClientStreaming),
("testClientStreamingLotsOfMessages", testClientStreamingLotsOfMessages),
("testServerStreaming", testServerStreaming),
Expand Down
152 changes: 152 additions & 0 deletions Tests/SwiftGRPCNIOTests/ServerThrowingTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright 2018, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import Dispatch
import Foundation
import NIO
import NIOHTTP1
import NIOHTTP2
@testable import SwiftGRPCNIO
import XCTest

private let expectedError = GRPCStatus(code: .internalError, message: "expected error")

// Motivation for two different providers: Throwing immediately causes the event observer future (in the
// client-streaming and bidi-streaming cases) to throw immediately, _before_ the corresponding handler has even added
// to the channel. We want to test that case as well as the one where we throw only _after_ the handler has been added
// to the channel.
private class ImmediateThrowingEchoProviderNIO: Echo_EchoProvider_NIO {
func get(request: Echo_EchoRequest, context: StatusOnlyCallContext) -> EventLoopFuture<Echo_EchoResponse> {
return context.eventLoop.newFailedFuture(error: expectedError)
}

func expand(request: Echo_EchoRequest, context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<GRPCStatus> {
return context.eventLoop.newFailedFuture(error: expectedError)
}

func collect(context: UnaryResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
return context.eventLoop.newFailedFuture(error: expectedError)
}

func update(context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
return context.eventLoop.newFailedFuture(error: expectedError)
}
}

private extension EventLoop {
func newFailedFuture<T>(error: Error, delay: TimeInterval) -> EventLoopFuture<T> {
return self.scheduleTask(in: .nanoseconds(TimeAmount.Value(delay * 1000 * 1000 * 1000))) { () }.futureResult
.thenThrowing { _ -> T in throw error }
}
}

/// See `ImmediateThrowingEchoProviderNIO`.
private class DelayedThrowingEchoProviderNIO: Echo_EchoProvider_NIO {
func get(request: Echo_EchoRequest, context: StatusOnlyCallContext) -> EventLoopFuture<Echo_EchoResponse> {
return context.eventLoop.newFailedFuture(error: expectedError, delay: 0.01)
}

func expand(request: Echo_EchoRequest, context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<GRPCStatus> {
return context.eventLoop.newFailedFuture(error: expectedError, delay: 0.01)
}

func collect(context: UnaryResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
return context.eventLoop.newFailedFuture(error: expectedError, delay: 0.01)
}

func update(context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
return context.eventLoop.newFailedFuture(error: expectedError, delay: 0.01)
}
}

/// Ensures that fulfilling the status promise (where possible) with an error yields the same result as failing the future.
private class ErrorReturningEchoProviderNIO: ImmediateThrowingEchoProviderNIO {
// There's no status promise to fulfill for unary calls (only the response promise), so that case is omitted.

override func expand(request: Echo_EchoRequest, context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<GRPCStatus> {
return context.eventLoop.newSucceededFuture(result: expectedError)
}

override func collect(context: UnaryResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
return context.eventLoop.newSucceededFuture(result: { _ in
context.responseStatus = expectedError
context.responsePromise.succeed(result: Echo_EchoResponse())
})
}

override func update(context: StreamingResponseCallContext<Echo_EchoResponse>) -> EventLoopFuture<(StreamEvent<Echo_EchoRequest>) -> Void> {
return context.eventLoop.newSucceededFuture(result: { _ in
context.statusPromise.succeed(result: expectedError)
})
}
}

class ServerThrowingTests: NIOBasicEchoTestCase {
override func makeEchoProvider() -> Echo_EchoProvider_NIO { return ImmediateThrowingEchoProviderNIO() }

static var allTests: [(String, (ServerThrowingTests) -> () throws -> Void)] {
return [
("testUnary", testUnary),
("testClientStreaming", testClientStreaming),
("testServerStreaming", testServerStreaming),
("testBidirectionalStreaming", testBidirectionalStreaming),
]
}
}

class ServerDelayedThrowingTests: ServerThrowingTests {
override func makeEchoProvider() -> Echo_EchoProvider_NIO { return DelayedThrowingEchoProviderNIO() }
}

class ClientThrowingWhenServerReturningErrorTests: ServerThrowingTests {
override func makeEchoProvider() -> Echo_EchoProvider_NIO { return ErrorReturningEchoProviderNIO() }
}

extension ServerThrowingTests {
func testUnary() throws {
let call = client.get(Echo_EchoRequest(text: "foo"))
XCTAssertEqual(expectedError, try call.status.wait())
XCTAssertThrowsError(try call.response.wait()) {
XCTAssertEqual(expectedError, $0 as? GRPCStatus)
}
}

func testClientStreaming() {
let call = client.collect()
XCTAssertNoThrow(try call.sendEnd().wait())
XCTAssertEqual(expectedError, try call.status.wait())

if type(of: makeEchoProvider()) != ErrorReturningEchoProviderNIO.self {
// With `ErrorReturningEchoProviderNIO` we actually _return_ a response, which means that the `response` future
// will _not_ fail, so in that case this test doesn't apply.
XCTAssertThrowsError(try call.response.wait()) {
XCTAssertEqual(expectedError, $0 as? GRPCStatus)
}
}
}

func testServerStreaming() {
let call = client.expand(Echo_EchoRequest(text: "foo")) { XCTFail("no message expected, got \($0)") }
// Nothing to throw here, but the `status` should be the expected error.
XCTAssertEqual(expectedError, try call.status.wait())
}

func testBidirectionalStreaming() {
let call = client.update() { XCTFail("no message expected, got \($0)") }
XCTAssertNoThrow(try call.sendEnd().wait())
// Nothing to throw here, but the `status` should be the expected error.
XCTAssertEqual(expectedError, try call.status.wait())
}
}

0 comments on commit 772b78e

Please sign in to comment.